Coverage for src/flag_gems/runtime/backend/_mthreads/ops/tile.py: 0%
174 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import triton_lang_extension as ext
9from flag_gems.utils.libentry import libentry
11logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": block_m, "BLOCK_N": block_n}, num_warps=warps)
        for block_m, block_n, warps in (
            (32, 32, 4),
            (64, 32, 4),
            (32, 64, 4),
            (64, 64, 8),
            (128, 32, 4),
            (32, 128, 4),
            (128, 64, 8),
            (64, 128, 8),
        )
    ],
    key=["out_shape0", "out_shape1"],
)
@triton.jit
def tile_kernel_2d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    out_stride0,
    out_stride1,
    inp_shape0,
    inp_shape1,
    out_shape0,
    out_shape1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Copy one BLOCK_M x BLOCK_N block of the tiled 2D output.

    Program ids along axes 0 and 1 select the block of output rows and
    columns. Each output coordinate is wrapped with a modulo by the input
    extent to find the source element, so the input repeats along both axes.
    """
    rows = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = ext.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    # Lanes that fall outside the output are masked off for load and store.
    row_ok = rows < out_shape0
    col_ok = cols < out_shape1
    in_bounds = row_ok[:, None] & col_ok[None, :]

    # Wrap output coordinates back into the input extent.
    src_rows = rows % inp_shape0
    src_cols = cols % inp_shape1

    src_ptrs = (
        inp_ptr + src_rows[:, None] * inp_stride0 + src_cols[None, :] * inp_stride1
    )
    dst_ptrs = out_ptr + rows[:, None] * out_stride0 + cols[None, :] * out_stride1

    block = tl.load(src_ptrs, mask=in_bounds, other=0.0)
    tl.store(dst_ptrs, block, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": block_size}, num_warps=warps)
        for block_size, warps in (
            (256, 4),
            (512, 4),
            (1024, 8),
            (2048, 8),
        )
    ],
    key=["out_shape0"],
)
@triton.jit
def tile_kernel_1d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    out_stride0,
    inp_shape0,
    out_shape0,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one BLOCK_SIZE-long segment of the tiled 1D output.

    Output position ``i`` reads input position ``i % inp_shape0``; positions
    at or beyond ``out_shape0`` are masked out.
    """
    out_pos = ext.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = out_pos < out_shape0

    # Wrap the output position back into the input extent.
    src_pos = out_pos % inp_shape0

    segment = tl.load(inp_ptr + src_pos * inp_stride0, mask=in_bounds)
    tl.store(out_ptr + out_pos * out_stride0, segment, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n, "BLOCK_K": block_k}, num_warps=warps)
        for block_n, block_k, warps in (
            (32, 32, 4),
            (64, 32, 4),
            (32, 64, 4),
            (64, 64, 8),
            (128, 32, 4),
            (32, 128, 4),
        )
    ],
    key=["out_shape1", "out_shape2"],
)
@triton.jit
def tile_kernel_3d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    out_stride0,
    out_stride1,
    out_stride2,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    out_shape0,
    out_shape1,
    out_shape2,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Copy one BLOCK_N x BLOCK_K block of a single dim-0 slice of the 3D output.

    Grid axis 0 is the output index along dim 0. Grid axis 1 is a flattened
    (n_block, k_block) pair, with k_block varying fastest.
    """
    out_m = ext.program_id(0)
    block_id = ext.program_id(1)

    # Split the flattened axis-1 program id into its (n, k) block coordinates.
    k_blocks = tl.cdiv(out_shape2, BLOCK_K)
    block_n = block_id // k_blocks
    block_k = block_id % k_blocks

    if out_m >= out_shape0:
        return

    src_m = out_m % inp_shape0

    n_idx = block_n * BLOCK_N + tl.arange(0, BLOCK_N)
    k_idx = block_k * BLOCK_K + tl.arange(0, BLOCK_K)

    n_ok = n_idx < out_shape1
    k_ok = k_idx < out_shape2
    in_bounds = n_ok[:, None] & k_ok[None, :]

    # Wrap output coordinates back into the input extent.
    src_n = n_idx % inp_shape1
    src_k = k_idx % inp_shape2

    src_ptrs = (
        inp_ptr
        + src_m * inp_stride0
        + src_n[:, None] * inp_stride1
        + src_k[None, :] * inp_stride2
    )
    block = tl.load(src_ptrs, mask=in_bounds, other=0.0)

    dst_ptrs = (
        out_ptr
        + out_m * out_stride0
        + n_idx[:, None] * out_stride1
        + k_idx[None, :] * out_stride2
    )
    tl.store(dst_ptrs, block, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": block_k, "BLOCK_L": block_l}, num_warps=warps)
        for block_k, block_l, warps in (
            (32, 32, 4),
            (64, 32, 4),
            (32, 64, 4),
            (64, 64, 8),
            (128, 32, 4),
            (32, 128, 4),
        )
    ],
    key=["out_shape2", "out_shape3"],
)
@triton.jit
def tile_kernel_4d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    BLOCK_K: tl.constexpr,
    BLOCK_L: tl.constexpr,
):
    """Copy one BLOCK_K x BLOCK_L block of a single (m, n) slice of the 4D output.

    Grid axis 0 is a flattened (m, n) pair over the two leading output
    dimensions. Grid axis 1 is a flattened (k_block, l_block) pair, with
    l_block varying fastest.
    """
    slice_id = ext.program_id(0)
    block_id = ext.program_id(1)

    # Split the flattened axis-1 program id into its (k, l) block coordinates.
    l_blocks = tl.cdiv(out_shape3, BLOCK_L)
    block_k = block_id // l_blocks
    block_l = block_id % l_blocks

    # Split the flattened axis-0 program id into the leading (m, n) indices.
    out_m = slice_id // out_shape1
    out_n = slice_id % out_shape1

    if out_m >= out_shape0:
        return

    src_m = out_m % inp_shape0
    src_n = out_n % inp_shape1

    k_idx = block_k * BLOCK_K + tl.arange(0, BLOCK_K)
    l_idx = block_l * BLOCK_L + tl.arange(0, BLOCK_L)

    k_ok = k_idx < out_shape2
    l_ok = l_idx < out_shape3
    in_bounds = k_ok[:, None] & l_ok[None, :]

    # Wrap output coordinates back into the input extent.
    src_k = k_idx % inp_shape2
    src_l = l_idx % inp_shape3

    src_ptrs = (
        inp_ptr
        + src_m * inp_stride0
        + src_n * inp_stride1
        + src_k[:, None] * inp_stride2
        + src_l[None, :] * inp_stride3
    )
    block = tl.load(src_ptrs, mask=in_bounds, other=0.0)

    dst_ptrs = (
        out_ptr
        + out_m * out_stride0
        + out_n * out_stride1
        + k_idx[:, None] * out_stride2
        + l_idx[None, :] * out_stride3
    )
    tl.store(dst_ptrs, block, mask=in_bounds)
@libentry()
@triton.jit
def tile_kernel_nd_flat(
    inp_ptr,
    out_ptr,
    num_tasks,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    inp_shape4,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    out_shape4,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    inp_stride4,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    out_stride4,
    rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Tile kernel over a flat output index, with five shape/stride slots.

    Each program walks the flat range ``[0, num_tasks)`` in steps of
    ``num_programs * BLOCK_SIZE``. A flat index is peeled into per-dimension
    output indices starting from slot 4; each one is wrapped with a modulo by
    the matching input extent to find the source element. ``rank`` is a
    compile-time constant that decides which slots are peeled; the others
    contribute index zero.
    """
    first = ext.program_id(0) * BLOCK_SIZE
    step = ext.num_programs(0) * BLOCK_SIZE

    for base in range(first, num_tasks, step):
        flat = base + tl.arange(0, BLOCK_SIZE)
        in_bounds = flat < num_tasks

        rest = flat

        # Peel dimensions from slot 4 down to slot 1; slot 0 takes what is left.
        if rank >= 5:
            dst4 = rest % out_shape4
            src4 = dst4 % inp_shape4
            rest = rest // out_shape4
        else:
            dst4 = tl.zeros_like(flat)
            src4 = tl.zeros_like(flat)

        if rank >= 4:
            dst3 = rest % out_shape3
            src3 = dst3 % inp_shape3
            rest = rest // out_shape3
        else:
            dst3 = tl.zeros_like(flat)
            src3 = tl.zeros_like(flat)

        if rank >= 3:
            dst2 = rest % out_shape2
            src2 = dst2 % inp_shape2
            rest = rest // out_shape2
        else:
            dst2 = tl.zeros_like(flat)
            src2 = tl.zeros_like(flat)

        if rank >= 2:
            dst1 = rest % out_shape1
            src1 = dst1 % inp_shape1
            rest = rest // out_shape1
        else:
            dst1 = tl.zeros_like(flat)
            src1 = tl.zeros_like(flat)

        dst0 = rest
        src0 = dst0 % inp_shape0

        src_offset = (
            src0 * inp_stride0
            + src1 * inp_stride1
            + src2 * inp_stride2
            + src3 * inp_stride3
            + src4 * inp_stride4
        )
        dst_offset = (
            dst0 * out_stride0
            + dst1 * out_stride1
            + dst2 * out_stride2
            + dst3 * out_stride3
            + dst4 * out_stride4
        )

        values = tl.load(inp_ptr + src_offset, mask=in_bounds)
        tl.store(out_ptr + dst_offset, values, mask=in_bounds)
def tile(inp: torch.Tensor, dims) -> torch.Tensor:
    """Build a tensor by repeating ``inp`` ``dims[i]`` times along dimension ``i``.

    The shorter of ``inp.shape`` and ``dims`` is left-padded with ones so both
    have the same rank. Output dimension ``i`` then has size
    ``inp.shape[i] * dims[i]``.

    Args:
        inp: Tensor to repeat.
        dims: Sequence of non-negative repeat counts, one per dimension.

    Returns:
        A newly allocated tensor with the same dtype and device as ``inp``.

    Raises:
        AssertionError: If any repeat count is negative.
    """
    logger.debug("GEMS TILE")
    in0_shape = list(inp.shape)
    dims_shape = list(dims)

    # Left-pad whichever of the two is shorter so their ranks match.
    rank_gap = len(in0_shape) - len(dims_shape)
    if rank_gap > 0:
        dims_shape = [1] * rank_gap + dims_shape
    elif rank_gap < 0:
        in0_shape = [1] * (-rank_gap) + in0_shape

    # Kept as an assert (same message) so the exception type seen by callers
    # does not change.
    for reps in dims_shape:
        assert (
            reps >= 0
        ), f"the number of repetitions per dimension out of range (expected to >= 0) but got {reps}"
    out_shape = [size * reps for size, reps in zip(in0_shape, dims_shape)]

    out = torch.empty(out_shape, device=inp.device, dtype=inp.dtype)

    # Nothing to copy when a repeat count is zero OR an input dimension is
    # empty; returning here avoids launching a kernel with an empty grid and a
    # modulo by a zero-sized input extent.
    if out.numel() == 0:
        return out

    inp = inp.reshape(in0_shape)
    rank = len(out_shape)
    num_tasks = out.numel()

    if rank > 5:
        # The generic kernel only has five shape/stride slots, so it cannot
        # index tensors of higher rank. Split every output dimension into
        # (repeat, size) and let a broadcasted copy do the repetition:
        # out[..., r * size + s, ...] = inp[..., s, ...].
        split_inp_shape = []
        split_out_shape = []
        for size, reps in zip(in0_shape, dims_shape):
            split_inp_shape += [1, size]
            split_out_shape += [reps, size]
        out.view(split_out_shape).copy_(
            inp.reshape(split_inp_shape).expand(split_out_shape)
        )
        return out

    inp_strides = list(inp.stride())
    out_strides = list(out.stride())

    with torch_device_fn.device(inp.device.index):
        if rank == 1:
            # 1D case with autotune; a zero stride is replaced by 1.
            grid = lambda META: (triton.cdiv(out_shape[0], META["BLOCK_SIZE"]),)
            tile_kernel_1d[grid](
                inp,
                out,
                inp_strides[0] or 1,
                out_strides[0] or 1,
                in0_shape[0],
                out_shape[0],
            )
        elif rank == 2:
            # 2D case: one program per (row block, column block).
            grid = lambda META: (
                triton.cdiv(out_shape[0], META["BLOCK_M"]),
                triton.cdiv(out_shape[1], META["BLOCK_N"]),
            )
            tile_kernel_2d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        elif rank == 3:
            # 3D case: axis 0 walks dim 0, axis 1 walks flattened (n, k) blocks.
            grid = lambda META: (
                out_shape[0],
                triton.cdiv(out_shape[1], META["BLOCK_N"])
                * triton.cdiv(out_shape[2], META["BLOCK_K"]),
            )
            tile_kernel_3d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        elif rank == 4:
            # 4D case: axis 0 walks flattened (m, n), axis 1 flattened (k, l) blocks.
            num_mn = out_shape[0] * out_shape[1]
            grid = lambda META: (
                num_mn,
                triton.cdiv(out_shape[2], META["BLOCK_K"])
                * triton.cdiv(out_shape[3], META["BLOCK_L"]),
            )
            tile_kernel_4d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        else:
            # Rank 0 or rank 5: generic flat-index kernel. Shapes are
            # left-padded with 1 and strides with 0 up to five entries, which
            # only matters for rank 0.
            BLOCK_SIZE = 1024
            grid = (min(65535, triton.cdiv(num_tasks, BLOCK_SIZE)),)

            missing = 5 - rank
            tile_kernel_nd_flat[grid](
                inp,
                out,
                num_tasks,
                *([1] * missing + in0_shape),
                *([1] * missing + out_shape),
                *([0] * missing + inp_strides),
                *([0] * missing + out_strides),
                rank=rank,
                BLOCK_SIZE=BLOCK_SIZE,
            )

    return out