Coverage for src/flag_gems/runtime/backend/_ascend/ops/bmm.py: 0%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

13logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

14 

15 

16# avoid 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["bmm"])
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    # Batched matrix multiply: each program computes one (TILE_M, TILE_N)
    # tile of O[b] = A[b] @ B[b].
    #
    # A, B, O : base pointers, indexed here as densely packed row-major
    #           (batch, M, K), (batch, K, N) and (batch, M, N) buffers.
    # M, N, K : matrix extents.
    # TILE_*  : compile-time tile sizes.
    # GROUP_M : compile-time number of row-tiles per group used when
    #           reordering programs; 1 disables the reordering.
    # DIVISIBLE_* : accepted but not read by this body (all bounds are
    #           always masked); presumably produced by the heuristics
    #           decorator above -- confirm before removing.
    #
    # Grid axes: 0 -> row tiles, 1 -> column tiles, 2 -> batch index.

    # batch offsets: advance every base pointer to this program's matrix
    pid_b = ext.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = ext.program_id(0)
    pidy = ext.program_id(1)
    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: linearise the 2-D program id, then walk groups of
        # GROUP_M row-tiles before moving on to the next group
        gridx = ext.num_programs(0)
        gridy = ext.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # the last group may hold fewer than GROUP_M row-tiles
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Row/column validity is loop-invariant: compute it once and reuse it
    # for both the loads and the final store.
    mask_m = offs_m < M
    mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # accumulate in float32 regardless of the input element type
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for i in range(num_iters):
        k_remaining = K - i * TILE_K
        # Guard rows/columns as well as the K tail so that tiles which
        # overhang the matrix never read outside it.
        mask_a = mask_m[:, None] & (offs_k[None, :] < k_remaining)
        mask_b = (offs_k[:, None] < k_remaining) & mask_n[None, :]
        # `other=0.0` makes masked-out lanes contribute exactly zero to the
        # dot product; without it their contents are unspecified.
        a = tl.load(a_ptrs, mask=mask_a, other=0.0)
        b = tl.load(b_ptrs, mask=mask_b, other=0.0)

        # step both operands one tile along K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

90 

91 

def bmm(A, B):
    """Compute the batched matrix product of ``A`` and ``B`` with ``bmm_kernel``.

    Args:
        A: 3-D tensor of shape ``(batch, M, K)``.
        B: 3-D tensor of shape ``(batch, K, N)``.

    Returns:
        A new tensor of shape ``(batch, M, N)`` with ``A``'s dtype and device.

    Raises:
        ValueError: if either input is not 3-D (raised by shape unpacking).
        RuntimeError: if the batch or inner (``K``) extents of ``A`` and ``B``
            differ.
    """
    logger.debug("GEMS_ASCEND BMM")
    batch, M, K = A.shape
    batch_b, K_b, N = B.shape
    # The kernel derives every offset for B from A's batch index and K, so
    # mismatched operands would read the wrong memory instead of failing.
    if batch_b != batch or K_b != K:
        raise RuntimeError(
            f"bmm: incompatible shapes {tuple(A.shape)} and {tuple(B.shape)}; "
            "expected (b, m, k) and (b, k, n)"
        )

    # The kernel indexes both operands as densely packed row-major buffers.
    A = A.contiguous()
    B = B.contiguous()
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    # Nothing to compute; also avoids launching with a zero-sized grid axis.
    if out.numel() == 0:
        return out

    def grid_fn(meta):
        # One program per (row tile, column tile, batch element).
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, out, M, N, K)
    return out

108 return out