Coverage for src/flag_gems/runtime/backend/_ascend/ops/bmm.py: 0%
63 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
# Module-level logger named after the leaf of this module's dotted name,
# e.g. "...ops.bmm" for this file.
_op_name = __name__.rpartition(".")[2]
logger = logging.getLogger("flag_gems.runtime._ascend.ops." + _op_name)
# Batched GEMM kernel: for every batch index b, O[b] = A[b] @ B[b].
#
# Launch grid: axis 0 tiles M (TILE_M rows per program), axis 1 tiles N
# (TILE_N columns per program), axis 2 is the batch index. The pointer math
# assumes row-major contiguous operands of shape (batch, M, K), (batch, K, N)
# and (batch, M, N), which is what `bmm` passes after `.contiguous()`.
# Accumulation is in float32; tl.store converts to O's element type.
# TILE_*, GROUP_M and DIVISIBLE_* are not passed at launch: they come from the
# autotune configs / backend heuristics. DIVISIBLE_M/N are presumably "M (N)
# is a multiple of TILE_M (TILE_N)" -- TODO confirm in heuristics_config_utils;
# here they only allow skipping bounds masks on the operand loads.
# NOTE(review): a comment at this spot previously read only "# avoid" (truncated).
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["bmm"])
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    # Advance the base pointers to this program's batch slice.
    # NOTE(review): assumes ext.program_id yields an integer wide enough that
    # pid_b * M * K does not overflow for large batches -- confirm.
    pid_b = ext.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = ext.program_id(0)
    pidy = ext.program_id(1)
    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Reorder programs: linearize (pidx, pidy), then walk groups of
        # GROUP_M row-tiles so consecutive programs cover a GROUP_M x gridy
        # block of output tiles (presumably for operand reuse). The last group
        # shrinks to gridx % GROUP_M row-tiles when gridx is not a multiple.
        gridx = ext.num_programs(0)
        gridy = ext.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)
    # Row/column validity of this tile; reused for the loads and the store.
    mask_m = offs_m < M
    mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for i in range(num_iters):
        # K columns of this slice that are still inside [0, K).
        mask_k = offs_k < K - i * TILE_K
        # Edge tiles must also be bounded in M / N, otherwise the last batch
        # reads past the end of A / B.
        if DIVISIBLE_M:
            mask_a = mask_k[None, :]
        else:
            mask_a = mask_m[:, None] & mask_k[None, :]
        if DIVISIBLE_N:
            mask_b = mask_k[:, None]
        else:
            mask_b = mask_k[:, None] & mask_n[None, :]
        # other=0.0: masked-out lanes must contribute nothing to the dot
        # product; without it their loaded values are undefined.
        a = tl.load(a_ptrs, mask=mask_a, other=0.0)
        b = tl.load(b_ptrs, mask=mask_b, other=0.0)

        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)
def bmm(A, B):
    """Batched matrix multiply ``A @ B`` on the Ascend backend via Triton.

    Args:
        A: 3-D tensor of shape (batch, M, K).
        B: 3-D tensor of shape (batch, K, N).

    Returns:
        A new tensor of shape (batch, M, N) with ``A``'s dtype and device.

    Raises:
        ValueError: if ``A`` or ``B`` is not 3-D (from unpacking ``.shape``).
        RuntimeError: if the batch sizes or the inner (K) dimensions of
            ``A`` and ``B`` disagree.
    """
    logger.debug("GEMS_ASCEND BMM")
    batch, M, K = A.shape
    batch_b, K_b, N = B.shape
    # The kernel indexes B with A's batch/K sizes; a mismatch would silently
    # read out of bounds, so reject it up front.
    if batch_b != batch or K_b != K:
        raise RuntimeError(
            f"bmm: expected B of shape ({batch}, {K}, N) to match "
            f"A of shape {tuple(A.shape)}, got {tuple(B.shape)}"
        )

    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)
    # Nothing to compute (batch, M or N is 0); avoid launching an empty grid.
    # K == 0 with a non-empty output still launches: the kernel stores zeros.
    if out.numel() == 0:
        return out

    # The kernel's pointer arithmetic assumes row-major contiguous operands.
    A = A.contiguous()
    B = B.contiguous()

    def grid_fn(meta):
        # One program per (TILE_M x TILE_N) output tile, per batch element.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](A, B, out, M, N, K)
    return out