Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cumprod.py: 0%

240 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import functools 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8from torch._prims_common import is_boolean_dtype, is_integer_dtype 

9 

10from flag_gems.runtime import device as runtime_device 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import get_device_properties, libentry 

13from flag_gems.utils import triton_lang_extension as ext 

14 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Dispatch key set handed to ``redispatch`` when an op is sent back to
# PyTorch's own implementation instead of the Triton kernels in this file.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(torch._C.DispatchKey.CompositeExplicitAutograd)

# Block size used by ``_scan_block_size`` once the scan length exceeds the
# vendor-specific limit below.
DEFAULT_BLOCK_SIZE = 1024
# Longest scan (in elements) that is still handled by a single block on
# non-"ascend" vendors.
CUDA_SMALL_SCAN_LIMIT = 4 * 1024
# Same limit when the runtime vendor name is "ascend".
ASCEND_SCAN_LIMIT = 1024
# Fallback multiprocessor count used when the device reports a falsy value.
DEFAULT_NUM_SMS = 40

24 

25 

@functools.lru_cache
def get_num_sms(idx: int) -> int:
    """Return the multiprocessor count reported for device ``idx``.

    Falls back to ``DEFAULT_NUM_SMS`` when the reported count is falsy
    (for example ``0`` or ``None``). Results are memoized per ``idx``.
    """
    reported = get_device_properties(idx).multi_processor_count
    if reported:
        return reported
    return DEFAULT_NUM_SMS

29 

30 

31def _get_device_index(torch_device): 

32 if torch_device.index is not None: 

33 return torch_device.index 

34 return torch_device_fn.current_device() 

35 

36 

@tl.constexpr
def get_prod_accum_type(out_dtype: tl.dtype) -> tl.dtype:
    """Pick the dtype in which products are accumulated for ``out_dtype``.

    fp16 and bf16 outputs accumulate in float32, integer outputs accumulate
    in int64, and every other dtype accumulates in itself.
    """
    is_half_precision = out_dtype.is_bf16() or out_dtype.is_fp16()
    if is_half_precision:
        return tl.float32
    return tl.int64 if out_dtype.is_int() else out_dtype

44 

45 

@triton.jit
def reduce_mul(a, b):
    # Multiplicative combine function; passed as ``combine_fn`` to tl.reduce
    # by the kernels in this file.
    product = a * b
    return product

49 

50 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_product_kernel(
    inp,
    out,
    partial_product,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles one BLOCK_SIZE-wide chunk of a flat array:
    # it writes the chunk-local cumulative product to ``out`` and the
    # product of the whole chunk to ``partial_product[program index]``.
    block_id = ext.program_id(0)
    lanes = tl.arange(0, BLOCK_SIZE)
    positions = block_id * BLOCK_SIZE + lanes
    in_bounds = positions < n_elements

    # Accumulation dtype is derived from the output element type.
    acc_dtype: tl.constexpr = get_prod_accum_type(out.type.element_ty)
    # Out-of-range lanes load 1, the multiplicative identity.
    values = tl.load(inp + positions, mask=in_bounds, other=1).to(acc_dtype)
    chunk_scan = tl.cumprod(values, axis=0)
    chunk_total = tl.reduce(values, axis=0, combine_fn=reduce_mul)

    tl.store(out + positions, chunk_scan, mask=in_bounds)
    tl.store(partial_product + block_id, chunk_total)

72 

73 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def multiply_base_product_kernel(
    out,
    partial_product,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Scales one BLOCK_SIZE-wide chunk of ``out`` in place by the value
    # stored at ``partial_product[program index - 1]``. The first program
    # (index 0) leaves its chunk untouched.
    block_id = ext.program_id(0)
    positions = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = positions < n_elements

    chunk_vals = tl.load(out + positions, mask=in_bounds)

    if block_id > 0:
        acc_dtype: tl.constexpr = get_prod_accum_type(out.type.element_ty)
        # Value read from the slot belonging to the preceding chunk.
        carried = tl.load(partial_product + block_id - 1).to(acc_dtype)
        scaled = chunk_vals.to(acc_dtype) * carried
        tl.store(out + positions, scaled, mask=in_bounds)

94 

95 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_product_abc_kernel(
    inp,
    out,
    partial_product,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Element offsets are computed as a * B * C + b * C + c, with the scan
    # running over ``b``. Grid axis 0 selects ``a``, axis 1 selects the
    # BLOCK_SIZE-wide chunk of ``b``, axis 2 selects ``c``.
    outer = ext.program_id(0)
    chunk = ext.program_id(1)
    inner = ext.program_id(2)

    scan_pos = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = scan_pos < B

    elem_offsets = outer * B * C + scan_pos * C + inner
    # Partial products are addressed as a * part_num * C + chunk * C + c.
    partial_slot = outer * part_num * C + inner + chunk * C

    acc_dtype: tl.constexpr = get_prod_accum_type(out.type.element_ty)
    # Masked-off lanes read 1 so they do not affect the products.
    values = tl.load(inp + elem_offsets, mask=in_bounds, other=1).to(acc_dtype)
    chunk_scan = tl.cumprod(values, axis=0)
    chunk_total = tl.reduce(values, axis=0, combine_fn=reduce_mul)

    tl.store(out + elem_offsets, chunk_scan, mask=in_bounds)
    tl.store(partial_product + partial_slot, chunk_total)

127 

128 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def multiply_base_product_abc_kernel(
    out,
    partial_product,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Scales one chunk of ``out`` (offsets a * B * C + b * C + c) in place by
    # the ``partial_product`` entry of the preceding chunk along ``b``.
    # Programs with chunk index 0 on grid axis 1 store nothing.
    outer = ext.program_id(0)
    chunk = ext.program_id(1)
    inner = ext.program_id(2)

    scan_pos = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = scan_pos < B

    elem_offsets = outer * B * C + scan_pos * C + inner
    slot_base = outer * part_num * C + inner
    # Slot of the previous chunk; only dereferenced when chunk > 0.
    prev_slot = slot_base + (chunk - 1) * C

    chunk_vals = tl.load(out + elem_offsets, mask=in_bounds)

    if chunk > 0:
        acc_dtype: tl.constexpr = get_prod_accum_type(out.type.element_ty)
        carried = tl.load(partial_product + prev_slot).to(acc_dtype)
        scaled = chunk_vals.to(acc_dtype) * carried
        tl.store(out + elem_offsets, scaled, mask=in_bounds)

159 

160 

def scan_then_fan_col(inp, out, n_ele, dtype):
    """Write the cumulative product of a flat array of ``n_ele`` items to ``out``.

    The array is split into blocks whose size comes from ``_scan_block_size``.
    Each block is scanned on its own and its total product is stored in a
    scratch tensor of ``dtype``. When there is more than one block, the
    scratch tensor is scanned recursively and every block after the first is
    multiplied by the scanned value of its predecessor.
    """
    block = _scan_block_size(n_ele)
    num_blocks = math.ceil(n_ele / block)
    block_totals = torch.empty(num_blocks, dtype=dtype, device=inp.device)
    launch_grid = (num_blocks,)

    with torch_device_fn.device(inp.device):
        scan_part_product_kernel[launch_grid](
            inp, out, block_totals, n_ele, num_blocks, block
        )

    if num_blocks < 2:
        # A single block is already the complete scan.
        return

    scanned_totals = torch.empty_like(block_totals)
    scan_then_fan_col(block_totals, scanned_totals, num_blocks, dtype)
    with torch_device_fn.device(inp.device):
        multiply_base_product_kernel[launch_grid](
            out, scanned_totals, n_ele, num_blocks, block
        )

179 

180 

def scan_then_fan(inp, out, A, B, C, dtype):
    """Write the cumulative product along the middle axis of an (A, B, C) layout.

    The ``B`` axis is split into blocks sized by ``_scan_block_size``. Each
    block is scanned independently and its total product is stored in an
    ``(A, num_blocks, C)`` scratch tensor of ``dtype``. With two or more
    blocks, the scratch tensor is scanned recursively along its middle axis
    and each later block is multiplied by its predecessor's scanned total.
    """
    block = _scan_block_size(B)
    num_blocks = math.ceil(B / block)
    block_totals = torch.empty(A, num_blocks, C, dtype=dtype, device=inp.device)
    launch_grid = (A, num_blocks, C)

    with torch_device_fn.device(inp.device):
        scan_part_product_abc_kernel[launch_grid](
            inp, out, block_totals, B, C, num_blocks, block
        )

    if num_blocks < 2:
        # Only one block along B: nothing to propagate.
        return

    scanned_totals = torch.empty_like(block_totals)
    scan_then_fan(block_totals, scanned_totals, A, num_blocks, C, dtype)
    with torch_device_fn.device(inp.device):
        multiply_base_product_abc_kernel[launch_grid](
            out, scanned_totals, B, C, num_blocks, block
        )

199 

200 

201def _get_output_dtype(inp, dtype): 

202 if dtype is not None: 

203 return dtype 

204 if is_integer_dtype(inp.dtype) or is_boolean_dtype(inp.dtype): 

205 return torch.int64 

206 return inp.dtype 

207 

208 

209def _get_compute_dtype(dtype): 

210 if dtype in (torch.float16, torch.bfloat16): 

211 return torch.float32 

212 if is_integer_dtype(dtype) or is_boolean_dtype(dtype): 

213 return torch.int64 

214 return dtype 

215 

216 

def _should_redispatch_on_ascend(dtype):
    """Tell whether ``dtype`` must be redispatched on the "ascend" vendor.

    True only when the runtime vendor name is "ascend" and ``dtype`` is an
    integer or boolean dtype.
    """
    if runtime_device.vendor_name != "ascend":
        return False
    return is_integer_dtype(dtype) or is_boolean_dtype(dtype)

221 

222 

def _scan_block_size(length):
    """Choose the block size for scanning ``length`` elements.

    Lengths up to the vendor limit (``ASCEND_SCAN_LIMIT`` on "ascend",
    ``CUDA_SMALL_SCAN_LIMIT`` otherwise) get the next power of two of
    ``length``; longer scans use ``DEFAULT_BLOCK_SIZE``.
    """
    if runtime_device.vendor_name == "ascend":
        single_block_limit = ASCEND_SCAN_LIMIT
    else:
        single_block_limit = CUDA_SMALL_SCAN_LIMIT

    if length > single_block_limit:
        return DEFAULT_BLOCK_SIZE
    return triton.next_power_of_2(length)

232 

233 

def cumprod_wrapper(inp, dim, dtype=None, out=None):
    """Compute the cumulative product of ``inp`` along ``dim``.

    Args:
        inp: Input tensor. It is made contiguous before the kernels run.
        dim: Axis to scan. For tensors with at least one dimension it must
            lie in ``[-inp.ndim, inp.ndim)``; for 0-dim tensors it must be
            ``0`` or ``-1``.
        dtype: Optional result dtype. When ``None`` the dtype is chosen by
            ``_get_output_dtype``.
        out: Optional tensor that receives the result. When ``None`` a new
            tensor shaped like ``inp`` with the result dtype is allocated.

    Returns:
        The tensor holding the result (``out`` when it was supplied).

    Raises:
        AssertionError: If ``dim`` is outside the accepted range.
    """
    out_dtype = _get_output_dtype(inp, dtype)

    if inp.ndim == 0:
        # A scalar has a single element, so its cumulative product is the
        # value itself converted to the result dtype. Without this branch
        # the range check below can never pass for ndim == 0 and
        # ``dim % inp.ndim`` would divide by zero.
        assert -1 <= dim <= 0, "Invalid dim"
        if out is None:
            return inp.to(out_dtype, copy=True)
        out.copy_(inp)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim

    inp = inp.contiguous()
    if out is None:
        out = torch.empty_like(inp, dtype=out_dtype)

    if inp.numel() == 0:
        # Nothing to scan; the (empty) output is already complete.
        return out

    # View the contiguous input as (M, N, K): M = product of the sizes
    # before ``dim``, N = size of ``dim``, K = product of the sizes after it.
    shape = inp.shape
    M = math.prod(shape[:dim])
    N = shape[dim]
    K = inp.numel() // M // N
    compute_dtype = _get_compute_dtype(out.dtype)

    if K == 1:
        # The scanned axis is innermost: use the row-wise kernels.
        reduce_then_scan_row(inp, out, M, N, compute_dtype)
    else:
        scan_then_fan(inp, out, M, N, K, compute_dtype)

    return out

258 

259 

def reduce_then_scan_row(x, out, M, N, compute_dtype):
    """Write the cumulative product of each of ``M`` rows of length ``N`` to ``out``.

    Rows no longer than the persistent limit (``ASCEND_SCAN_LIMIT`` on
    "ascend", 16384 otherwise) are scanned by one program per row. Longer
    rows use three launches: per-block products, a scan over those block
    products, and a final per-block scan seeded with the scanned product of
    the preceding block.

    Args:
        x: Input tensor, indexed by the kernels as ``row * N + column``.
        out: Output tensor, indexed the same way.
        M: Number of rows.
        N: Row length.
        compute_dtype: dtype of the scratch tensors holding block products.

    Returns:
        ``out``.
    """
    persistent_limit = (
        ASCEND_SCAN_LIMIT if runtime_device.vendor_name == "ascend" else 16384
    )

    # Launch under the input's device, matching scan_then_fan and
    # scan_then_fan_col, so the kernels are not issued on whichever device
    # happens to be current.
    with torch_device_fn.device(x.device):
        if N <= persistent_limit:
            TILE_SIZE = triton.next_power_of_2(N)
            num_warps = 8 if TILE_SIZE > 2048 else 4
            reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
                x, out, N, TILE_SIZE, num_warps=num_warps
            )
            return out

        TILE_SIZE = min(_scan_block_size(N), triton.next_power_of_2(N))
        num_warps = 8 if TILE_SIZE > 2048 else 4
        num_tiles = triton.cdiv(N, TILE_SIZE)
        # Cap the programs per row at four times the multiprocessor count.
        max_ctas = get_num_sms(_get_device_index(x.device)) * 4
        num_ctas = min(num_tiles, max_ctas)
        ROOT_SCAN_TILE_SIZE = triton.next_power_of_2(num_ctas)
        tiles_per_cta = triton.cdiv(num_tiles, num_ctas)

        block_products = torch.empty(
            (M, num_ctas), dtype=compute_dtype, device=x.device
        )
        block_inclusive_prefix = torch.empty_like(block_products)

        # Grid is a 3-tuple like the other launches; the original passed a
        # fourth element here.
        reduce_then_scan_block_product_kernel_row[(M, num_ctas, 1)](
            x, block_products, N, tiles_per_cta, TILE_SIZE, num_warps=num_warps
        )
        reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
            block_products,
            block_inclusive_prefix,
            num_ctas,
            ROOT_SCAN_TILE_SIZE,
            num_warps=num_warps,
        )
        reduce_then_scan_block_scan_kernel_row[(M, num_ctas, 1)](
            x,
            block_inclusive_prefix,
            out,
            N,
            num_ctas,
            tiles_per_cta,
            TILE_SIZE,
            num_warps=num_warps,
        )
    return out

304 

305 

@triton.jit
def reduce_then_scan_block_product_kernel_row(
    in_ptr,
    block_product_ptr,
    N,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    # Grid axis 0 selects the row, axis 1 the block within the row. Each
    # program multiplies together the elements of its block
    # (tiles_per_cta tiles of TILE_SIZE, clipped to N) and stores the single
    # product at block_product_ptr[row * programs_on_axis_1 + block].
    row = tl.program_id(0).to(tl.int64)
    block = tl.program_id(1).to(tl.int64)
    blocks_per_row = tl.num_programs(1)

    span = tiles_per_cta * TILE_SIZE
    first = block * span
    last = min(first + span, N)

    acc_dtype: tl.constexpr = get_prod_accum_type(block_product_ptr.type.element_ty)
    # Lane-wise running products, started at the multiplicative identity.
    running = tl.full((TILE_SIZE,), value=1, dtype=acc_dtype)
    for tile_start in range(first, last, TILE_SIZE):
        cols = tile_start + tl.arange(0, TILE_SIZE)
        tile = tl.load(in_ptr + row * N + cols, mask=cols < N, other=1).to(
            acc_dtype
        )
        running = running * tile

    total = tl.reduce(running, axis=0, combine_fn=reduce_mul)
    tl.store(
        block_product_ptr + row * blocks_per_row + block,
        total,
        cache_modifier=".cg",
    )

334 

335 

@triton.jit
def reduce_then_scan_root_scan_kernel_row(in_ptr, out_ptr, N, TILE_SIZE: tl.constexpr):
    # One program per row: loads up to TILE_SIZE elements of the row
    # (columns >= N are masked and read as 1) and stores the row's
    # cumulative product.
    row = tl.program_id(0).to(tl.int64)
    cols = tl.arange(0, TILE_SIZE)
    in_bounds = cols < N
    acc_dtype: tl.constexpr = get_prod_accum_type(out_ptr.type.element_ty)
    row_base = row * N
    values = tl.load(in_ptr + row_base + cols, mask=in_bounds, other=1).to(acc_dtype)
    scanned = tl.cumprod(values, 0)
    tl.store(out_ptr + row_base + cols, scanned, mask=in_bounds)

345 

346 

@triton.jit
def reduce_then_scan_block_scan_kernel_row(
    in_ptr,
    previous_product_ptr,
    out_ptr,
    N,
    num_tiles_n,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    # Grid axis 0 selects the row, axis 1 the block within the row. Each
    # program scans its block tile by tile, seeded with the value stored for
    # the preceding block in previous_product_ptr (1 for the first block).
    row = tl.program_id(0).to(tl.int64)
    block = tl.program_id(1).to(tl.int64)

    span = tiles_per_cta * TILE_SIZE
    first = block * span
    last = min(first + span, N)
    acc_dtype: tl.constexpr = get_prod_accum_type(out_ptr.type.element_ty)

    # The mask keeps block 0 from reading the slot before its own row.
    carry = tl.load(
        previous_product_ptr + row * num_tiles_n + block - 1,
        mask=block > 0,
        other=1,
    ).to(acc_dtype)

    for tile_start in range(first, last, TILE_SIZE):
        cols = tile_start + tl.arange(0, TILE_SIZE)
        in_bounds = cols < N
        tile = tl.load(in_ptr + row * N + cols, mask=in_bounds, other=1).to(acc_dtype)
        # Scale with the carry from earlier tiles, then fold this tile's
        # product into the carry for the next iteration.
        scanned = carry * tl.cumprod(tile, 0)
        carry = carry * tl.reduce(tile, axis=0, combine_fn=reduce_mul)
        tl.store(
            out_ptr + row * N + cols, scanned, mask=in_bounds, cache_modifier=".cg"
        )

377 

378 

def cumprod(inp, dim, *, dtype=None):
    """Return the cumulative product of ``inp`` along ``dim``.

    Boolean inputs whose result dtype is also boolean are redispatched to
    the aten implementation. Other boolean inputs are converted to uint8
    first and then either redispatched (vendor "ascend") or computed by
    ``cumprod_wrapper``. Non-boolean inputs are redispatched when
    ``_should_redispatch_on_ascend`` says so, otherwise computed by
    ``cumprod_wrapper``.
    """
    logger.debug("GEMS CUMPROD")
    out_dtype = _get_output_dtype(inp, dtype)

    def _redispatch(tensor):
        # Hand the call to aten under the fallback dispatch key set.
        return torch.ops.aten.cumprod.default.redispatch(
            _FALLBACK_KEYSET, tensor, dim, dtype=dtype
        )

    if not is_boolean_dtype(inp.dtype):
        if _should_redispatch_on_ascend(out_dtype):
            return _redispatch(inp)
        return cumprod_wrapper(inp, dim, dtype)

    if is_boolean_dtype(out_dtype):
        return _redispatch(inp)

    uint8_inp = inp.to(torch.uint8)
    if runtime_device.vendor_name == "ascend":
        return _redispatch(uint8_inp)
    return cumprod_wrapper(uint8_inp, dim, out_dtype)

398 

399 

def cumprod_(inp, dim, *, dtype=None):
    """Replace ``inp`` with its cumulative product along ``dim`` and return it.

    Raises:
        RuntimeError: If ``dtype`` is given and differs from ``inp.dtype``.
        NotImplementedError: If ``inp`` is a boolean tensor.
    """
    logger.debug("GEMS CUMPROD_")
    own_dtype = inp.dtype

    dtype_conflict = dtype is not None and dtype != own_dtype
    if dtype_conflict:
        raise RuntimeError(
            "Bad in-place call: input tensor dtype and output tensor dtype should match"
        )
    if is_boolean_dtype(own_dtype):
        raise NotImplementedError(
            "In-place cumprod is not supported for boolean tensors"
        )
    if _should_redispatch_on_ascend(own_dtype):
        return torch.ops.aten.cumprod_.default.redispatch(
            _FALLBACK_KEYSET, inp, dim, dtype=dtype
        )

    # Compute into a fresh tensor, then copy the result back over the input.
    result = cumprod_wrapper(inp, dim, own_dtype)
    inp.copy_(result)
    return inp

416 return inp