Coverage for src/flag_gems/runtime/backend/_ascend/ops/full.py: 0%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import triton_lang_extension as tle 

10from flag_gems.utils.shape_utils import volume 

11 

# Module logger, namespaced under the Ascend ops hierarchy and suffixed with the
# last dotted component of this module's import name.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)

13 

14 

@triton.jit(do_not_specialize=["fill_value_or_ptr"])
def full_kernel(
    output_ptr,
    n_elements,
    fill_value_or_ptr,
    FILL_VALUE_IS_PTR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Write a single fill value into the first ``n_elements`` slots of ``output_ptr``.

    Each program instance covers one contiguous tile of ``BLOCK_SIZE`` slots
    starting at ``program_id * BLOCK_SIZE``; lanes at or beyond ``n_elements``
    are masked off. When the compile-time flag ``FILL_VALUE_IS_PTR`` is set,
    ``fill_value_or_ptr`` is treated as a pointer and the value is loaded from
    it; otherwise it is used directly as a scalar.
    """
    tile_idx = tle.program_id(axis=0)
    lanes = tile_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < n_elements
    # FILL_VALUE_IS_PTR is constexpr, so only one branch is compiled.
    if FILL_VALUE_IS_PTR:
        value = tl.load(fill_value_or_ptr)
    else:
        value = fill_value_or_ptr
    tl.store(output_ptr + lanes, value, mask=in_bounds)

32 

33 

# Integer dtypes whose fill values ``check_dtype`` range-checks with torch.iinfo.
# torch.uint8 is included so out-of-range values (e.g. 256 or -1) raise an
# overflow error instead of being handed to the kernel unchecked.
ALL_INT_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
# Floating dtypes for which ``check_dtype`` wraps the fill value in a device
# tensor of that exact dtype (the kernel then loads it through a pointer).
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)

36 

37 

def check_dtype(fill_value, dtype, device):
    """Validate ``fill_value`` against ``dtype`` and adapt it for ``full_kernel``.

    * A ``bool`` fill value targeting a non-bool dtype becomes a Python ``int``.
    * A non-bool fill value targeting a dtype in ``ALL_INT_DTYPES`` must lie
      within that dtype's ``torch.iinfo`` range.
    * For a dtype in ``ALL_FLOAT_DTYPES`` the value is wrapped in a 0-d tensor
      of that dtype on ``device``.

    Args:
        fill_value: Python scalar requested by the caller.
        dtype: Target torch dtype of the output.
        device: Device on which a float fill tensor is created.

    Returns:
        The (possibly converted) fill value.

    Raises:
        RuntimeError: If a non-bool fill value falls outside an integer dtype's range.
    """
    is_bool_value = isinstance(fill_value, bool)
    if is_bool_value and dtype != torch.bool:
        fill_value = int(fill_value)
    elif not is_bool_value and dtype in ALL_INT_DTYPES:
        limits = torch.iinfo(dtype)
        # Two separate comparisons on purpose: a chained `min <= v <= max`
        # test would reject NaN, which this check lets through.
        if fill_value < limits.min or fill_value > limits.max:
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )

    if dtype in ALL_FLOAT_DTYPES:
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)
    return fill_value

51 

52 

def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` with every element set to ``fill_value``.

    Mirrors the ``torch.full`` signature; ``layout`` and ``pin_memory`` are
    accepted for interface compatibility but are not used here.

    Args:
        size: Output shape, forwarded to ``torch.empty``.
        fill_value: Python scalar written into every element.
        dtype: Output dtype. When ``None`` it is inferred from ``fill_value``:
            ``bool`` -> ``torch.bool``, ``int`` -> ``torch.int64``, anything
            else -> ``torch.get_default_dtype()``.
        layout: Unused.
        device: Output device; defaults to CPU when ``None``.
        pin_memory: Unused.

    Returns:
        The filled output tensor.

    Raises:
        RuntimeError: If ``fill_value`` overflows an integer ``dtype``
            (raised by ``check_dtype``).
    """
    logger.debug("GEMS_ASCEND FULL")
    if device is None:
        device = torch.device("cpu")
    if dtype is None:
        if isinstance(fill_value, bool):
            dtype = torch.bool
        elif isinstance(fill_value, int):
            dtype = torch.int64
        else:
            dtype = torch.get_default_dtype()
    # Validate/convert for inferred dtypes too: an int beyond the int64 range
    # must raise rather than reach the kernel, and a float64 default dtype must
    # be passed as a float64 tensor instead of a scalar kernel argument.
    fill_value = check_dtype(fill_value, dtype, device)

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to fill. Also avoids triton.next_power_of_2(0) == 0, which
        # would give BLOCK_SIZE == 0 and a division by zero in the grid.
        return out

    BLOCK_SIZE = min(triton.next_power_of_2(math.ceil(math.sqrt(N))), 2048)
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    with torch_device_fn.device(device):
        full_kernel[grid](
            out,
            N,
            fill_value,
            # A tensor here means check_dtype materialized the value on device.
            FILL_VALUE_IS_PTR=isinstance(fill_value, torch.Tensor),
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return out