Coverage for src/flag_gems/runtime/backend/_arm/ops/bmm.py: 0%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import triton_lang_extension as tle
11# @libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    """Compute one (TILE_M, TILE_N) output tile of O[b] = A[b] @ B[b].

    The index arithmetic below addresses A as (batch, M, K), B as
    (batch, K, N) and O as (batch, M, N), each densely packed in row-major
    order. Program axes 0 and 1 select the output tile; axis 2 selects the
    batch entry.

    TILE_M / TILE_N / TILE_K are the tile extents and GROUP_M controls how
    programs are regrouped along M. The DIVISIBLE_* flags select the
    unmasked code paths.
    NOTE(review): TILE_* / GROUP_M presumably come from the "bmm" tuned
    config and DIVISIBLE_* from the "bmm" heuristic config -- confirm.
    """
    # Move all three base pointers to the current batch entry.
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Reorder CTAs: linearize the 2-D program index, then hand tiles out
        # in groups of GROUP_M consecutive M-tiles spanning every N-tile.
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The trailing group is narrower when gridx is not a multiple of
        # GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Edge masks are only materialized when the dimension has a ragged tail.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # Accumulate in float32 regardless of the input element type.
    o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Every masked load passes `other=0.0`: without it the masked-off
        # lanes are undefined, and along K they would be multiplied into the
        # accumulator by tl.dot. `other` may only accompany a mask, so the
        # fully divisible paths use a plain unmasked load.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                a = tl.load(a_ptrs)
            else:
                a = tl.load(a_ptrs, mask=mask_m[:, None], other=0.0)
            if DIVISIBLE_N:
                b = tl.load(b_ptrs)
            else:
                b = tl.load(b_ptrs, mask=mask_n[None, :], other=0.0)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)

        # Advance to the next K-slice: one tile to the right in A, one tile
        # down in B.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    # Only store the in-range part of the tile.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)
def bmm(A, B):
    """Batched matrix multiply: return ``out[b] = A[b] @ B[b]``.

    Args:
        A: 3-D tensor of shape (batch, M, K).
        B: 3-D tensor of shape (batch, K, N).

    Returns:
        A new tensor of shape (batch, M, N) with A's dtype and device.

    Raises:
        ValueError: if either input is not 3-D (from shape unpacking).
        RuntimeError: if the batch or inner (K) dimensions of A and B differ.
    """
    logging.debug("GEMS BMM")
    batch, M, K = A.shape
    batch_b, K_b, N = B.shape
    # The kernel derives every offset from (batch, M, N, K); a mismatched B
    # would be read out of bounds instead of failing, so reject it up front.
    if batch_b != batch or K_b != K:
        raise RuntimeError(
            f"bmm: incompatible shapes {tuple(A.shape)} and {tuple(B.shape)}; "
            "expected (batch, M, K) and (batch, K, N)"
        )
    # The kernel indexes both operands as densely packed row-major buffers.
    A = A.contiguous()
    B = B.contiguous()
    out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def grid_fn(meta):
        # One program per (M-tile, N-tile, batch entry).
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    # with torch_device_fn.device(A.device):
    bmm_kernel[grid_fn](A, B, out, M, N, K)
    return out