Coverage for src/flag_gems/runtime/backend/_ascend/ops/full.py: 0%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device, torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

12 

# Keep the runtime device object reachable under a second name: ``full``
# below takes its own ``device`` parameter, which hides the imported one.
device_ = device
# Logger named after the last dotted component of this module's name.
logger = logging.getLogger(
    f'flag_gems.runtime._ascend.ops.{__name__.rpartition(".")[2]}'
)

15 

# Integer dtypes whose representable range is validated by ``check_dtype``.
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
# Floating-point dtypes whose representable range is validated by
# ``check_dtype``.
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)

# Threshold for switching between pointwise_dynamic (small tensors)
# and hand-written multi-core kernel (large tensors).
SMALL_TENSOR_THRESHOLD = 100000


def check_dtype(fill_value, dtype, device):
    """Validate that ``fill_value`` is representable in ``dtype``.

    Args:
        fill_value: The value that is going to be written into the tensor.
        dtype: The requested ``torch.dtype`` of the output tensor.
        device: Not used by this function.

    Returns:
        ``fill_value`` unchanged, except that a Python ``bool`` is converted
        to ``int`` when ``dtype`` is not ``torch.bool``.

    Raises:
        RuntimeError: If ``dtype`` is one of ``ALL_INT_DTYPES`` and the value
            is NaN or lies outside the dtype's range, or if ``dtype`` is one
            of ``ALL_FLOAT_DTYPES`` and the value is finite but lies outside
            the dtype's range. Infinities and NaN are accepted for
            floating-point dtypes. Other dtypes are not range-checked.
    """
    # bool is tested first: it is a subclass of int and never range-checked.
    if isinstance(fill_value, bool):
        return fill_value if dtype == torch.bool else int(fill_value)

    if dtype in ALL_INT_DTYPES:
        limits = torch.iinfo(dtype)
        # NaN has no integer representation, but every ordered comparison
        # against NaN is False, so the range test alone would let it through.
        # NaN is the only value that is unequal to itself; this form also
        # works for arbitrarily large Python ints, unlike math.isnan().
        overflows = (
            fill_value != fill_value  # noqa: PLR0124 - deliberate NaN test
            or fill_value < limits.min
            or fill_value > limits.max
        )
    elif dtype in ALL_FLOAT_DTYPES:
        limits = torch.finfo(dtype)
        # inf / nan are valid floating-point values; only finite values can
        # fall outside the representable range.
        overflows = not (math.isinf(fill_value) or math.isnan(fill_value)) and (
            fill_value < limits.min or fill_value > limits.max
        )
    else:
        overflows = False

    if overflows:
        raise RuntimeError(
            f"value cannot be converted to type {dtype} without overflow"
        )

    return fill_value

42 

43 

# Small tensor path: pointwise_dynamic has lower launch overhead.
# Variant for a tensor fill value: both operands are declared as tensors.
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def full_func(out, fill_value):
    # `out` is not read here; the result is the fill value itself.
    return fill_value

49 

50 

# Variant for a non-tensor fill value: only `out` is declared as a tensor.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def full_func_scalar(out, fill_value):
    # Build a block with the shape and dtype of `out`, holding the scalar.
    filled = tl.full(out.shape, fill_value, out.dtype)
    return filled

55 

56 

# Large tensor path: hand-written multi-core kernel for better throughput
@libentry()
@triton.jit(do_not_specialize=["fill_value"])
def full_kernel(
    out_ptr,
    N,
    fill_value,
    BLOCK_SIZE: tl.constexpr,
    SUBBLOCK_SIZE: tl.constexpr,
):
    # Each program starts at program_id * BLOCK_SIZE and writes its range in
    # chunks of SUBBLOCK_SIZE elements.
    block_start = tle.program_id(0) * BLOCK_SIZE
    lanes = tl.arange(0, SUBBLOCK_SIZE)
    chunk_count = triton.cdiv(BLOCK_SIZE, SUBBLOCK_SIZE)
    for step in tl.range(chunk_count):
        offsets = block_start + step * SUBBLOCK_SIZE + lanes
        # The mask only bounds writes by N. When BLOCK_SIZE is not a multiple
        # of SUBBLOCK_SIZE the last chunk runs past this program's range,
        # where it stores the same fill value again.
        in_bounds = offsets < N
        tl.store(out_ptr + offsets, fill_value, mask=in_bounds)

74 

75 

def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Create a tensor of shape ``size`` filled with ``fill_value``.

    Args:
        size: Shape passed to ``torch.empty``.
        fill_value: Value to write; a Python scalar or a ``torch.Tensor``.
        dtype: Output dtype. When ``None`` it is derived from ``fill_value``:
            ``torch.bool`` for a bool, ``torch.int64`` for an int, otherwise
            ``torch.get_default_dtype()``. When given, ``fill_value`` is
            validated with ``check_dtype``.
        layout: Accepted but not used by this implementation.
        device: Target device; defaults to ``torch.device(device_.name)``.
        pin_memory: Accepted but not used by this implementation.

    Returns:
        The filled tensor. An empty tensor is returned without launching any
        kernel.
    """
    logger.debug("GEMS_ASCEND FULL")
    if device is None:
        device = torch.device(device_.name)

    if dtype is not None:
        fill_value = check_dtype(fill_value, dtype, device)
    elif isinstance(fill_value, bool):
        # Checked before int because bool is an int subclass.
        dtype = torch.bool
    elif isinstance(fill_value, int):
        dtype = torch.int64
    else:
        dtype = torch.get_default_dtype()

    out = torch.empty(size, device=device, dtype=dtype)
    total = out.numel()
    if total == 0:
        return out

    value_is_tensor = isinstance(fill_value, torch.Tensor)

    if total < SMALL_TENSOR_THRESHOLD:
        # Small tensor: use pointwise_dynamic for lower launch overhead
        small_fill = full_func if value_is_tensor else full_func_scalar
        return small_fill(out, fill_value, out0=out)

    # Large tensor: use hand-written multi-core kernel, which receives the
    # fill value as a plain Python scalar.
    if value_is_tensor:
        fill_value = fill_value.item()

    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    num_programs = min(40, total)
    # Elements per program, rounded up so all programs together cover `total`.
    block_size = triton.cdiv(total, num_programs)
    sub_block_size = min(8192, block_size)

    with torch_device_fn.device(device):
        full_kernel[num_programs,](
            out, total, fill_value, block_size, sub_block_size
        )
    return out