Coverage for src/flag_gems/runtime/backend/_mthreads/ops/arange.py: 0%

89 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.ops.arange import arange_start as default_arange_start 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry 

13from flag_gems.utils import triton_lang_extension as ext 

14 

# Logger named after this backend op module, e.g. "...._mthreads.ops.arange".
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[2]
)

# FlagGems runtime device descriptor; its ``.name`` is the default device used
# by ``arange_start`` when the caller passes ``device=None``.
device_ = runtime.device

# Output dtypes the MUSA Triton kernel is allowed to produce. Other dtypes are
# routed to the generic FlagGems implementation by ``_use_triton``.
_SUPPORTED_DTYPES = {
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.int32,
    torch.int64,
}

# Candidate launch configurations for ``arange_kernel``'s autotuner as
# (BLOCK_SIZE, num_warps) pairs, all single-stage.
_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCK_SIZE": block_size}, num_warps=warps, num_stages=1)
    for block_size, warps in ((256, 4), (512, 8), (1024, 8), (2048, 8))
]

33 

34 

# Triton kernel that writes ``start + i * step`` into ``out_ptr[i]`` for every
# ``i < n_elements``. Each program handles one contiguous chunk of BLOCK_SIZE
# output positions, and positions at or past ``n_elements`` are masked out.
#
# - ``@libentry()`` is a FlagGems helper from ``flag_gems.utils``; see it for
#   what it adds on top of the autotuner.
# - ``@triton.autotune`` chooses BLOCK_SIZE / num_warps from _AUTOTUNE_CONFIGS
#   and keys its tuning cache on (n_elements, USE_INT64).
# - ``do_not_specialize`` keeps ``start`` / ``step`` as plain runtime
#   arguments, so their concrete values do not select a new specialization.
#
# Constexpr flags (set by ``_launch_triton_kernel``):
#   IS_FLOAT  -- compute the values in float32 (floating output dtypes).
#   USE_INT64 -- use 64-bit offsets and, for integer outputs, 64-bit values.
@libentry()
@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=["n_elements", "USE_INT64"])
@triton.jit(do_not_specialize=["start", "step"])
def arange_kernel(
    out_ptr,
    start,
    step,
    n_elements,
    IS_FLOAT: tl.constexpr,
    USE_INT64: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # NOTE(review): if ext.program_id returns int32, ``pid * BLOCK_SIZE`` is
    # evaluated in int32 before the USE_INT64 widening below -- confirm that
    # the helper already returns a 64-bit id.
    pid = ext.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if USE_INT64:
        offsets = offsets.to(tl.int64)
        n_elements = tl.full((1,), n_elements, tl.int64)
    else:
        # Narrowing to int32 is only safe while every offset of the last
        # block still fits in int32.
        offsets = offsets.to(tl.int32)
        n_elements = tl.full((1,), n_elements, tl.int32)
    # Tail mask for the final, partially filled block.
    mask = offsets < n_elements

    if IS_FLOAT:
        # Values are computed in float32 whatever the output width. Integers
        # above 2**24 are not exactly representable in float32, so very long
        # ranges lose precision in idx.
        idx = offsets.to(tl.float32)
        step_val = tl.full((1,), step, tl.float32)
        start_val = tl.full((1,), start, tl.float32)
        # Fused multiply-add: idx * step + start.
        values = tl.fma(idx, step_val, start_val)
    else:
        # Integer outputs: values use the same width as the offsets.
        value_dtype = tl.int64 if USE_INT64 else tl.int32
        idx = offsets.to(value_dtype)
        step_val = tl.full((1,), step, value_dtype)
        start_val = tl.full((1,), start, value_dtype)
        values = start_val + idx * step_val

    # NOTE(review): relies on tl.store converting ``values`` to out_ptr's
    # element type (e.g. float32 -> float16/bfloat16) -- confirm on MUSA.
    tl.store(out_ptr + offsets, values, mask=mask)

70 

71 

def _normalize_scalar(value):
    """Unwrap a tensor argument into a Python scalar via ``.item()``.

    Non-tensor values (ints, floats, ...) are passed through unchanged.
    """
    return value.item() if isinstance(value, torch.Tensor) else value

76 

77 

def _compute_size(start, end, step, is_float_dtype: bool) -> int:
    """Return how many elements ``arange(start, end, step)`` produces.

    The exact integer formula ``(end - start + step - sign(step)) // step`` is
    used only for integral dtypes whose bounds are all non-float values. In
    every other case the count is ``ceil((end - start) / step)``, the size rule
    documented for ``torch.arange``. Applying the floor-based integer formula
    to float bounds under-counts: ``(0, 4.5, 1)`` would give 4 elements
    instead of 5.

    Args:
        start: First value of the range (inclusive).
        end: Bound of the range (exclusive).
        step: Spacing between consecutive values; must be non-zero.
        is_float_dtype: Whether the output dtype is floating point.

    Returns:
        The element count, clamped to 0 for empty or backwards ranges.

    Raises:
        ValueError: if ``step`` is zero.
    """
    if step == 0:
        raise ValueError("arange(): step must be non-zero.")
    has_float_bound = any(isinstance(v, float) for v in (start, end, step))
    if is_float_dtype or has_float_bound:
        size = math.ceil((end - start) / step)
    else:
        # Integer ceil-division that stays exact for arbitrarily large ints.
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    return int(size) if size > 0 else 0

87 

88 

def _use_triton(dtype: torch.dtype, device: torch.device, size: int) -> bool:
    """Decide whether the MUSA Triton kernel should build this range.

    True only for a MUSA device, an output dtype listed in
    ``_SUPPORTED_DTYPES`` and a non-empty result. Every other case is served
    by the generic fallback.
    """
    on_musa = device.type == "musa"
    return on_musa and dtype in _SUPPORTED_DTYPES and size > 0

95 

96 

def _launch_triton_kernel(
    out: torch.Tensor,
    start,
    step,
    size: int,
    *,
    is_float_dtype: bool,
    use_int64: bool,
):
    """Run ``arange_kernel`` to fill ``out`` with ``start + i * step``.

    Args:
        out: Preallocated 1-D output tensor with ``size`` elements.
        start: First value of the range.
        step: Spacing between values.
        size: Number of elements to write.
        is_float_dtype: Selects the kernel's float32 compute path.
        use_int64: Selects 64-bit offsets (and 64-bit integer values).

    Returns:
        ``out``, the same tensor that was passed in.
    """

    def _grid(meta):
        # One program per BLOCK_SIZE chunk, resolved after autotuning picks BLOCK_SIZE.
        return (triton.cdiv(size, meta["BLOCK_SIZE"]),)

    kernel_flags = {"IS_FLOAT": is_float_dtype, "USE_INT64": use_int64}
    # Launch under torch_device_fn.device(out.device); presumably this makes
    # out's device current for the launch.
    with torch_device_fn.device(out.device):
        arange_kernel[_grid](out, start, step, size, **kernel_flags)
    return out

117 

118 

def arange_start(
    start,
    end,
    step=1,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device=None,
    pin_memory: Optional[bool] = None,
):
    """Return a 1-D tensor of ``start, start + step, ...`` stopping before ``end``.

    This is the MUSA implementation of ``torch.arange(start, end, step)``. The
    Triton kernel is used only when ``_use_triton`` accepts the device, dtype
    and size; everything else is delegated to the generic FlagGems
    ``arange_start``.

    Args:
        start: First value (inclusive). A tensor is unwrapped with ``.item()``.
        end: Exclusive bound. A tensor is unwrapped with ``.item()``.
        step: Spacing between values. A tensor is unwrapped with ``.item()``.
        dtype: Output dtype. When ``None`` it is inferred as in
            ``torch.arange``: ``torch.int64`` if no argument is a Python float,
            otherwise ``torch.get_default_dtype()``.
        layout: Only forwarded to the generic fallback.
        device: Target device; defaults to ``runtime.device.name``.
        pin_memory: Defaults to ``False``.

    Returns:
        The filled tensor.

    Raises:
        RuntimeError: if ``dtype`` is ``torch.int64`` and ``step`` is zero
            after float arguments are truncated to int.
        ValueError: from ``_compute_size`` when ``step`` is zero for other
            dtypes.
    """
    logger.debug("GEMS_MTHREADS ARANGE")
    start = _normalize_scalar(start)
    end = _normalize_scalar(end)
    step = _normalize_scalar(step)
    has_float_arg = any(isinstance(v, float) for v in (start, end, step))

    if dtype is None:
        # Unconditionally defaulting to int64 would truncate float arguments
        # below. For example, arange(0, 1, 0.1) would turn step into 0 and
        # raise instead of returning a float tensor.
        dtype = torch.get_default_dtype() if has_float_arg else torch.int64
    if pin_memory is None:
        pin_memory = False
    if device is None:
        device = torch.device(device_.name)
    else:
        device = torch.device(device)

    if dtype is torch.int64:
        # Explicit int64 output with float bounds: truncate toward zero
        # before sizing and filling.
        if has_float_arg:
            start = int(start) if isinstance(start, float) else start
            end = int(end) if isinstance(end, float) else end
            step = int(step) if isinstance(step, float) else step
        if step == 0:
            raise RuntimeError("step must be nonzero")

    is_float_dtype = dtype.is_floating_point
    size = _compute_size(start, end, step, is_float_dtype)

    # The kernel narrows offsets to int32 unless USE_INT64 is set. Widen them
    # whenever the last block could run past the int32 range, not only for
    # int64 outputs; otherwise offsets wrap for very large float/int32 ranges.
    max_block = max(cfg.kwargs["BLOCK_SIZE"] for cfg in _AUTOTUNE_CONFIGS)
    use_int64 = dtype == torch.int64 or (
        size + max_block > torch.iinfo(torch.int32).max
    )

    if not _use_triton(dtype, device, size):
        return default_arange_start(
            start,
            end,
            step,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
        )

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    return _launch_triton_kernel(
        result,
        start,
        step,
        size,
        is_float_dtype=is_float_dtype,
        use_int64=use_int64,
    )

180 

181 

def arange(
    end,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device=None,
    pin_memory: Optional[bool] = None,
):
    """Single-bound form ``arange(end)``, i.e. ``arange_start(0, end, 1, ...)``.

    All tensor-factory options are forwarded unchanged to ``arange_start``.
    """
    factory_options = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return arange_start(0, end, 1, **factory_options)