Coverage for src/flag_gems/runtime/backend/_mthreads/ops/tile.py: 0%
177 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import triton_lang_extension as tle
9from flag_gems.utils.libentry import libentry
# Module-level logger; `tile` emits a debug record on it for every call.
logger = logging.getLogger(__name__)

# Dispatch key set that `tile` passes to `redispatch` when it defers to the
# stock ATen implementation instead of launching a Triton kernel.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(torch._C.DispatchKey.CompositeImplicitAutograd)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": block_m, "BLOCK_N": block_n}, num_warps=warps)
        for block_m, block_n, warps in (
            (32, 32, 4),
            (64, 32, 4),
            (32, 64, 4),
            (64, 64, 8),
            (128, 32, 4),
            (32, 128, 4),
            (128, 64, 8),
            (64, 128, 8),
        )
    ],
    key=["out_shape0", "out_shape1"],
)
@triton.jit
def tile_kernel_2d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    out_stride0,
    out_stride1,
    inp_shape0,
    inp_shape1,
    out_shape0,
    out_shape1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Fill one BLOCK_M x BLOCK_N block of a 2-D tiled output.

    The program ids on axes 0 and 1 pick the block. Every output coordinate
    is reduced modulo the matching input extent to find the element it
    copies, so the input pattern repeats along both axes.
    """
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tle.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    # Lanes past the output extent neither load nor store.
    row_ok = rows < out_shape0
    col_ok = cols < out_shape1
    in_bounds = row_ok[:, None] & col_ok[None, :]

    # Wrap output coordinates back into the input.
    src_rows = rows % inp_shape0
    src_cols = cols % inp_shape1

    src_ptrs = (
        inp_ptr + src_rows[:, None] * inp_stride0 + src_cols[None, :] * inp_stride1
    )
    dst_ptrs = out_ptr + rows[:, None] * out_stride0 + cols[None, :] * out_stride1

    values = tl.load(src_ptrs, mask=in_bounds, other=0.0)
    tl.store(dst_ptrs, values, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": block_size}, num_warps=warps)
        for block_size, warps in ((256, 4), (512, 4), (1024, 8), (2048, 8))
    ],
    key=["out_shape0"],
)
@triton.jit
def tile_kernel_1d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    out_stride0,
    inp_shape0,
    out_shape0,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill BLOCK_SIZE consecutive elements of a 1-D tiled output.

    Output position ``i`` copies input position ``i % inp_shape0``.
    """
    positions = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = positions < out_shape0

    # Wrap the output position back into the input.
    src_positions = positions % inp_shape0

    values = tl.load(inp_ptr + src_positions * inp_stride0, mask=in_bounds)
    tl.store(out_ptr + positions * out_stride0, values, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n, "BLOCK_K": block_k}, num_warps=warps)
        for block_n, block_k, warps in (
            (32, 32, 4),
            (64, 32, 4),
            (32, 64, 4),
            (64, 64, 8),
            (128, 32, 4),
            (32, 128, 4),
        )
    ],
    key=["out_shape1", "out_shape2"],
)
@triton.jit
def tile_kernel_3d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    out_stride0,
    out_stride1,
    out_stride2,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    out_shape0,
    out_shape1,
    out_shape2,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Fill one BLOCK_N x BLOCK_K block of one leading-dim slice of a 3-D output.

    Grid axis 0 is the index along output dim 0. Grid axis 1 is a flattened
    (n_block, k_block) pair, split here with the number of K blocks.
    Source coordinates are the output coordinates modulo the input extents.
    """
    slab = tle.program_id(0)
    block_id = tle.program_id(1)

    # Programs beyond the leading extent have nothing to do.
    if slab >= out_shape0:
        return

    # Un-flatten the block id into its N and K components.
    k_blocks = tl.cdiv(out_shape2, BLOCK_K)
    n_block = block_id // k_blocks
    k_block = block_id % k_blocks

    n_coords = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
    k_coords = k_block * BLOCK_K + tl.arange(0, BLOCK_K)

    n_ok = n_coords < out_shape1
    k_ok = k_coords < out_shape2
    in_bounds = n_ok[:, None] & k_ok[None, :]

    # Wrap every output coordinate back into the input.
    src_slab = slab % inp_shape0
    src_n = n_coords % inp_shape1
    src_k = k_coords % inp_shape2

    src_ptrs = (
        inp_ptr
        + src_slab * inp_stride0
        + src_n[:, None] * inp_stride1
        + src_k[None, :] * inp_stride2
    )
    dst_ptrs = (
        out_ptr
        + slab * out_stride0
        + n_coords[:, None] * out_stride1
        + k_coords[None, :] * out_stride2
    )

    values = tl.load(src_ptrs, mask=in_bounds, other=0.0)
    tl.store(dst_ptrs, values, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K": block_k, "BLOCK_L": block_l}, num_warps=warps)
        for block_k, block_l, warps in (
            (32, 32, 4),
            (64, 32, 4),
            (32, 64, 4),
            (64, 64, 8),
            (128, 32, 4),
            (32, 128, 4),
        )
    ],
    key=["out_shape2", "out_shape3"],
)
@triton.jit
def tile_kernel_4d(
    inp_ptr,
    out_ptr,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    BLOCK_K: tl.constexpr,
    BLOCK_L: tl.constexpr,
):
    """Fill one BLOCK_K x BLOCK_L block of one (dim0, dim1) slice of a 4-D output.

    Grid axis 0 is a flattened (dim0, dim1) index, split with ``out_shape1``.
    Grid axis 1 is a flattened (k_block, l_block) pair, split with the number
    of L blocks. Source coordinates are the output coordinates modulo the
    input extents.
    """
    slice_id = tle.program_id(0)
    block_id = tle.program_id(1)

    # Un-flatten the two leading output indices.
    idx0 = slice_id // out_shape1
    idx1 = slice_id % out_shape1

    # Programs beyond the leading extent have nothing to do.
    if idx0 >= out_shape0:
        return

    # Un-flatten the block id into its K and L components.
    l_blocks = tl.cdiv(out_shape3, BLOCK_L)
    k_block = block_id // l_blocks
    l_block = block_id % l_blocks

    k_coords = k_block * BLOCK_K + tl.arange(0, BLOCK_K)
    l_coords = l_block * BLOCK_L + tl.arange(0, BLOCK_L)

    k_ok = k_coords < out_shape2
    l_ok = l_coords < out_shape3
    in_bounds = k_ok[:, None] & l_ok[None, :]

    # Wrap every output coordinate back into the input.
    src0 = idx0 % inp_shape0
    src1 = idx1 % inp_shape1
    src_k = k_coords % inp_shape2
    src_l = l_coords % inp_shape3

    src_ptrs = (
        inp_ptr
        + src0 * inp_stride0
        + src1 * inp_stride1
        + src_k[:, None] * inp_stride2
        + src_l[None, :] * inp_stride3
    )
    dst_ptrs = (
        out_ptr
        + idx0 * out_stride0
        + idx1 * out_stride1
        + k_coords[:, None] * out_stride2
        + l_coords[None, :] * out_stride3
    )

    values = tl.load(src_ptrs, mask=in_bounds, other=0.0)
    tl.store(dst_ptrs, values, mask=in_bounds)
@libentry()
@triton.jit
def tile_kernel_nd_flat(
    inp_ptr,
    out_ptr,
    num_tasks,
    inp_shape0,
    inp_shape1,
    inp_shape2,
    inp_shape3,
    inp_shape4,
    out_shape0,
    out_shape1,
    out_shape2,
    out_shape3,
    out_shape4,
    inp_stride0,
    inp_stride1,
    inp_stride2,
    inp_stride3,
    inp_stride4,
    out_stride0,
    out_stride1,
    out_stride2,
    out_stride3,
    out_stride4,
    rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Tile a tensor of up to five dimensions through a flat element index.

    Each program starts at ``program_id * BLOCK_SIZE`` and advances by
    ``num_programs * BLOCK_SIZE`` until ``num_tasks`` is reached. A flat
    index is peeled into per-dimension output indices from dim 4 down to
    dim 0; dimensions that ``rank`` does not reach get index zero. Each
    output index is reduced modulo the input extent to locate its source.
    """
    first = tle.program_id(0) * BLOCK_SIZE
    step = tle.num_programs(0) * BLOCK_SIZE

    for base in range(first, num_tasks, step):
        flat = base + tl.arange(0, BLOCK_SIZE)
        in_bounds = flat < num_tasks

        rest = flat

        # `rank` is a compile-time constant, so only the taken branches are
        # emitted for a given specialization.
        if rank >= 5:
            dst4 = rest % out_shape4
            src4 = dst4 % inp_shape4
            rest = rest // out_shape4
        else:
            dst4 = tl.zeros_like(flat)
            src4 = tl.zeros_like(flat)

        if rank >= 4:
            dst3 = rest % out_shape3
            src3 = dst3 % inp_shape3
            rest = rest // out_shape3
        else:
            dst3 = tl.zeros_like(flat)
            src3 = tl.zeros_like(flat)

        if rank >= 3:
            dst2 = rest % out_shape2
            src2 = dst2 % inp_shape2
            rest = rest // out_shape2
        else:
            dst2 = tl.zeros_like(flat)
            src2 = tl.zeros_like(flat)

        if rank >= 2:
            dst1 = rest % out_shape1
            src1 = dst1 % inp_shape1
            rest = rest // out_shape1
        else:
            dst1 = tl.zeros_like(flat)
            src1 = tl.zeros_like(flat)

        # Whatever is left is the index along the leading dimension.
        dst0 = rest
        src0 = dst0 % inp_shape0

        src_offset = (
            src0 * inp_stride0
            + src1 * inp_stride1
            + src2 * inp_stride2
            + src3 * inp_stride3
            + src4 * inp_stride4
        )
        dst_offset = (
            dst0 * out_stride0
            + dst1 * out_stride1
            + dst2 * out_stride2
            + dst3 * out_stride3
            + dst4 * out_stride4
        )

        values = tl.load(inp_ptr + src_offset, mask=in_bounds)
        tl.store(out_ptr + dst_offset, values, mask=in_bounds)
def tile(inp: torch.Tensor, dims) -> torch.Tensor:
    """Build a tensor by repeating ``inp`` ``dims[i]`` times along dimension ``i``.

    When ``dims`` and ``inp.shape`` differ in length, the shorter one is
    left-padded with ones before the output shape is computed.

    Args:
        inp: Tensor to repeat.
        dims: Sequence of non-negative repetition counts, one per dimension.

    Returns:
        A newly allocated tensor whose size along each dimension is the
        (padded) input size multiplied by the (padded) repetition count. When
        grad mode is enabled, or the padded rank exceeds five, the result of
        the ATen ``tile`` fallback is returned instead.

    Raises:
        AssertionError: If any repetition count is negative.
    """
    logger.debug("GEMS TILE")
    if torch.is_grad_enabled():
        return torch.ops.aten.tile.default.redispatch(_FALLBACK_KEYSET, inp, dims)

    in0_shape = list(inp.shape)
    dims_shape = list(dims)

    # Left-pad the shorter of the two with ones so both have the same rank.
    rank_gap = len(in0_shape) - len(dims_shape)
    if rank_gap > 0:
        dims_shape = [1] * rank_gap + dims_shape
    elif rank_gap < 0:
        in0_shape = [1] * (-rank_gap) + in0_shape

    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    for reps in dims_shape:
        if reps < 0:
            raise AssertionError(
                f"the number of repetitions per dimension out of range (expected to >= 0) but got {reps}"
            )

    rank = len(in0_shape)
    if rank > 5:
        # The generic kernel only receives five shape/stride slots; feeding
        # it a higher rank would decode flat indices with the wrong extents.
        # Defer to the same fallback used for the grad-enabled path.
        # NOTE(review): assumes this redispatch is also valid with grad
        # disabled - confirm on the target backend.
        return torch.ops.aten.tile.default.redispatch(_FALLBACK_KEYSET, inp, dims)

    out_shape = [size * reps for size, reps in zip(in0_shape, dims_shape)]
    out = torch.empty(out_shape, device=inp.device, dtype=inp.dtype)

    # Covers both a zero repetition count and a zero-sized input dimension:
    # there is nothing to copy, so skip the kernel launch entirely.
    num_tasks = out.numel()
    if num_tasks == 0:
        return out

    inp = inp.reshape(in0_shape)

    # Real strides are passed through unchanged. A stride of 0 (expanded
    # input) must stay 0, otherwise the kernel would read past the storage.
    inp_strides = list(inp.stride())
    out_strides = list(out.stride())

    with torch_device_fn.device(inp.device.index):
        if rank == 1:
            grid = lambda META: (triton.cdiv(out_shape[0], META["BLOCK_SIZE"]),)
            tile_kernel_1d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        elif rank == 2:
            # One program per BLOCK_M x BLOCK_N block of the output.
            grid = lambda META: (
                triton.cdiv(out_shape[0], META["BLOCK_M"]),
                triton.cdiv(out_shape[1], META["BLOCK_N"]),
            )
            tile_kernel_2d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        elif rank == 3:
            # Axis 0 walks dim 0; axis 1 is the flattened (n_block, k_block).
            grid = lambda META: (
                out_shape[0],
                triton.cdiv(out_shape[1], META["BLOCK_N"])
                * triton.cdiv(out_shape[2], META["BLOCK_K"]),
            )
            tile_kernel_3d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        elif rank == 4:
            # Axis 0 is the flattened (dim0, dim1); axis 1 the flattened
            # (k_block, l_block).
            num_mn = out_shape[0] * out_shape[1]
            grid = lambda META: (
                num_mn,
                triton.cdiv(out_shape[2], META["BLOCK_K"])
                * triton.cdiv(out_shape[3], META["BLOCK_L"]),
            )
            tile_kernel_4d[grid](
                inp, out, *inp_strides, *out_strides, *in0_shape, *out_shape
            )
        else:
            # Rank 0 or rank 5: generic flat-index kernel.
            BLOCK_SIZE = 1024
            grid = (min(65535, triton.cdiv(num_tasks, BLOCK_SIZE)),)

            # Left-pad to exactly five dims (size 1, stride 0) so every
            # kernel slot receives a value.
            pad = 5 - rank
            padded_in_shape = [1] * pad + in0_shape
            padded_out_shape = [1] * pad + out_shape
            padded_inp_strides = [0] * pad + inp_strides
            padded_out_strides = [0] * pad + out_strides

            tile_kernel_nd_flat[grid](
                inp,
                out,
                num_tasks,
                *padded_in_shape,
                *padded_out_shape,
                *padded_inp_strides,
                *padded_out_strides,
                rank=rank,
                BLOCK_SIZE=BLOCK_SIZE,
            )

    return out