Coverage for src/flag_gems/fused/unpack_seq.py: 59%
41 statements
coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(name=__name__)
# Copy the first ``lengths[b]`` timesteps of every padded sequence ``b`` into a
# dense [N, D] buffer, where the rows of sequence ``b`` start at the exclusive
# prefix sum ``lengths[0] + ... + lengths[b - 1]``.
#
# Program axes: 0 = batch index, 1 = tile over time (BLOCK_T timesteps),
# 2 = tile over features (BLOCK_D features).
#
# B, Lmax and D are ordinary runtime integers, not ``tl.constexpr``: the padded
# length changes from batch to batch, and making it constexpr forced a fresh
# JIT compilation for every new value.
@triton.jit
def _unpack_seq_triton_kernel(
    packed_ptr,  # [B, Lmax, D], row-major contiguous
    out_ptr,  # [N, D], row-major contiguous
    lengths_ptr,  # *i32, [B]
    B,  # unused in the kernel body; kept so the positional launch signature is unchanged
    Lmax,
    D,
    BLOCK_T: tl.constexpr,  # timesteps per program
    BLOCK_D: tl.constexpr,  # features per program
):
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # First output row of this sequence: exclusive prefix sum of lengths.
    # This costs pid_b scalar loads per program.
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # Only timesteps below this sequence's length (and inside the padded
    # extent) are copied; features are bounded by D.
    valid_row = (off_t < seq_len) & (off_t < Lmax)
    mask = valid_row[:, None] & (off_d[None, :] < D)

    # Row indices are widened to int64 BEFORE multiplying by D: in int32 the
    # products overflow once either tensor holds more than 2**31 elements.
    in_row = pid_b.to(tl.int64) * Lmax + off_t
    out_row = (in_start + off_t).to(tl.int64)

    packed_vals = tl.load(packed_ptr + in_row[:, None] * D + off_d[None, :], mask=mask)
    tl.store(out_ptr + out_row[:, None] * D + off_d[None, :], packed_vals, mask=mask)
def unpack_seq_triton(
    packed_tensor: torch.Tensor,
    lengths: torch.Tensor,
    block_t: int = 64,
    block_d: int = 64,
) -> torch.Tensor:
    """Gather the valid timesteps of padded sequences into one dense tensor.

    Args:
        packed_tensor: padded input of shape ``(B, Lmax, *feature_dims)``;
            only the first ``lengths[b]`` timesteps of sequence ``b`` are read.
        lengths: ``B`` sequence lengths, on any device and of any integer
            dtype. Each must satisfy ``0 <= lengths[b] <= Lmax``.
        block_t: timesteps handled per Triton program (power of two, as
            required by ``tl.arange``).
        block_d: flattened features handled per Triton program (power of two).

    Returns:
        Tensor of shape ``(N, *feature_dims)`` with ``N = lengths.sum()``.
        Sequences are concatenated in batch order. Dtype and device match
        ``packed_tensor``.

    Raises:
        ValueError: if ``packed_tensor`` has fewer than 3 dims, if
            ``lengths`` does not hold exactly ``B`` values, or if any length is
            negative or exceeds ``Lmax``.
    """
    logger.debug("GEMS UNPACK_SEQ_TRITON")
    original_shape = packed_tensor.shape
    if packed_tensor.dim() < 3:
        raise ValueError(
            f"packed_tensor must have shape (B, Lmax, *features), got {tuple(original_shape)}"
        )
    B, Lmax = original_shape[:2]
    D = original_shape[2:].numel()  # flattened feature size; works for zero-sized dims too

    # The kernel computes raw row-major offsets, so the input must be contiguous.
    packed_reshaped = packed_tensor.reshape(B, Lmax, D).contiguous()

    if lengths.numel() != B:
        raise ValueError(f"expected {B} lengths, got {lengths.numel()}")
    # A single host copy replaces the original `.sum().item()` sync and also
    # allows validating the lengths.
    lengths_host = lengths.reshape(-1).cpu()
    N = int(lengths_host.sum())
    if B > 0 and (int(lengths_host.min()) < 0 or int(lengths_host.max()) > Lmax):
        # Out-of-range lengths would leave rows of the `torch.empty` output
        # unwritten or shift later sequences to invalid offsets.
        raise ValueError(f"lengths must lie in [0, {Lmax}]")
    # The kernel dereferences this pointer, so it must live on the input's
    # device and be int32.
    lengths_dev = lengths.reshape(-1).to(
        device=packed_tensor.device, dtype=torch.int32
    ).contiguous()

    out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype)

    if out.numel() > 0:  # nothing to copy otherwise
        grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
        _unpack_seq_triton_kernel[grid](
            packed_reshaped,
            out,
            lengths_dev,
            B,
            Lmax,
            D,
            BLOCK_T=block_t,
            BLOCK_D=block_d,
            num_warps=4,
            num_stages=2,
        )

    if len(original_shape) > 3:
        # Restore the trailing feature dims that were flattened into D.
        output_shape = (N,) + tuple(original_shape[2:])
        out = out.reshape(output_shape)

    return out