Coverage for src/flag_gems/fused/pack_seq.py: 54%
50 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.jit
def _pack_seq_kernel(
    x_ptr,  # [N, D]
    out_ptr,  # [B, Lmax, D]
    lengths_ptr,  # *i32, [B]
    N: tl.constexpr,
    D: tl.constexpr,
    Lmax: tl.constexpr,
    PAD_VALUE: tl.constexpr,
    PAD_IS_UINT8: tl.constexpr,
    BLOCK_T: tl.constexpr,  # timesteps per program
    BLOCK_D: tl.constexpr,  # features per program
):
    """Copy one (batch, time-block, feature-block) tile of packed rows into a padded output.

    Grid axes: 0 = batch index, 1 = block over the time dimension,
    2 = block over the feature dimension.

    For batch ``b`` the source rows are ``x[start_b : start_b + lengths[b]]``
    where ``start_b`` is the sum of ``lengths[0:b]``. Every position of the
    tile that lies inside ``[Lmax, D]`` is first filled with ``PAD_VALUE``;
    positions with ``t < lengths[b]`` are then overwritten with the
    corresponding input row. Both pointers are addressed as dense row-major
    buffers.

    ``N`` is accepted for interface compatibility and is not read here.
    """
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # Compute start index and sequence length from cumulative lengths
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # Valid time positions for this block, and the subset that holds real data
    t_mask = off_t < Lmax
    valid_row = (off_t < seq_len) & t_mask

    # Row indices are widened to int64 *before* multiplying by D: in int32 the
    # element offset wraps as soon as N * D or B * Lmax * D reaches 2**31.
    in_row = in_start + off_t.to(tl.int64)
    out_row = pid_b.to(tl.int64) * Lmax + off_t

    # Pointers
    # x_ptr: row-major [N, D]
    x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]

    # out_ptr: row-major [B, Lmax, D]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # Initialize with PAD. PAD_IS_UINT8 selects the pad tensor's dtype so
    # integer-typed outputs (e.g. MXFP4 packed nibbles, ue8m0 scale bytes)
    # get an exact-byte pad rather than going through an fp32->uint8 cast
    # that's implementation-defined outside of value 0.
    d_mask = off_d[None, :] < D
    if PAD_IS_UINT8:
        pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.uint8)
    else:
        pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)

    # Load & write only where within seq_len
    data_mask = valid_row[:, None] & d_mask
    x_vals = tl.load(x_row_ptr, mask=data_mask)
    tl.store(out_row_ptr, x_vals, mask=data_mask)
def pack_seq_triton(
    x: torch.Tensor,
    lengths: torch.Tensor,
    pad_value: float | int = -float("inf"),
    block_t: int = 64,
    block_d: int = 64,
) -> torch.Tensor:
    """Pack concatenated variable-length sequences into a padded batch.

    Args:
        x: Tensor of shape ``[N, ...]`` whose leading dimension concatenates
            all sequences back to back. Trailing dimensions are flattened
            into a single feature dimension for the kernel and restored on
            the result.
        lengths: Per-sequence lengths; ``lengths.numel()`` is the batch size
            ``B``. Converted to int32 on ``x``'s device before launch.
            NOTE(review): the kernel reads ``sum(lengths)`` rows of ``x``;
            that this does not exceed ``N`` is not validated here.
        pad_value: Fill value for positions past each sequence's length. For
            ``torch.uint8`` inputs it must be an ``int`` in ``[0, 255]``.
        block_t: Timesteps handled per kernel program.
        block_d: Features handled per kernel program.

    Returns:
        Tensor of shape ``[B, Lmax, ...]`` with ``Lmax = lengths.max()``,
        same dtype and device as ``x``. When ``lengths`` is empty the result
        has shape ``[0, 0, ...]``.

    Raises:
        AssertionError: if ``x`` is uint8 and ``pad_value`` is not an
            integer in ``[0, 255]``.
    """
    logger.debug("GEMS PACK_SEQ_TRITON")
    is_uint8 = x.dtype == torch.uint8
    if is_uint8:
        assert (
            isinstance(pad_value, int) and 0 <= pad_value <= 255
        ), f"uint8 pack requires an integer pad in [0, 255], got {pad_value!r}"
        pad_constexpr: int | float = int(pad_value)
    else:
        pad_constexpr = float(pad_value)

    original_shape = x.shape
    if len(original_shape) > 2:
        N = original_shape[0]
        x_reshaped = x.reshape(N, -1)
        D = x_reshaped.shape[1]
    else:
        N, D = x.shape
        x_reshaped = x
    # The kernel addresses x as a dense row-major [N, D] buffer. A strided
    # view (slice, transpose, or a reshape that stayed a view) would be read
    # incorrectly, so materialize it; this is a no-op when already contiguous.
    x_reshaped = x_reshaped.contiguous()

    B = lengths.numel()
    # max() raises on an empty tensor, so an empty batch is handled explicitly.
    Lmax = int(lengths.max().item()) if B > 0 else 0

    out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)

    # Nothing to write when any output dimension is zero; skip the launch.
    if out.numel() > 0:
        # The kernel does raw pointer arithmetic on lengths, so it must be a
        # dense int32 buffer living on the same device as x.
        lengths_i32 = lengths.to(device=x.device, dtype=torch.int32).contiguous()

        grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
        _pack_seq_kernel[grid](
            x_reshaped,
            out,
            lengths_i32,
            N,
            D,
            Lmax,
            PAD_VALUE=pad_constexpr,
            PAD_IS_UINT8=is_uint8,
            BLOCK_T=block_t,
            BLOCK_D=block_d,
            num_warps=4,
            num_stages=2,
        )

    if len(original_shape) > 2:
        out = out.reshape((B, Lmax) + original_shape[1:])

    return out