Coverage for src/flag_gems/fused/pack_seq.py: 54%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

# Module-scoped logger; named after this module so its records can be
# filtered/configured independently of the rest of the package.
logger = logging.getLogger(name=__name__)

8 

9 

# Triton kernel: copy rows of a flat, row-major [N, D] tensor into a padded,
# row-major [B, Lmax, D] tensor.
# - Grid axes: 0 = batch entry, 1 = block of BLOCK_T timesteps,
#   2 = block of BLOCK_D features.
# - Each program fills one BLOCK_T x BLOCK_D tile of one batch entry. It writes
#   x rows where t < lengths[b] and PAD_VALUE for the remaining t < Lmax.
#
# N and Lmax are runtime arguments, not tl.constexpr. Both vary from batch to
# batch, and as constexprs every new value would trigger a fresh JIT
# compilation. D and the block sizes are stable per model and stay constexpr.
@triton.jit
def _pack_seq_kernel(
    x_ptr,  # [N, D], row-major (row stride D)
    out_ptr,  # [B, Lmax, D], row-major
    lengths_ptr,  # *i32, [B]
    N,  # number of rows in x; not read by the kernel body
    D: tl.constexpr,
    Lmax,
    PAD_VALUE: tl.constexpr,
    PAD_IS_UINT8: tl.constexpr,
    BLOCK_T: tl.constexpr,  # timesteps per program
    BLOCK_D: tl.constexpr,  # features per program
):
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # Sequence pid_b starts after all preceding sequences: sum(lengths[:pid_b]).
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    t_mask = off_t < Lmax  # inside the padded output
    d_mask = off_d[None, :] < D  # [1, BLOCK_D]
    valid_row = (off_t < seq_len) & t_mask  # inside this sequence's real data

    # Row indices are promoted to int64 *before* multiplying by D. program_id
    # and arange are int32, so row * D overflows once a tensor holds more
    # than 2**31 elements.
    in_row = in_start + off_t.to(tl.int64)  # [BLOCK_T], int64
    out_row = pid_b.to(tl.int64) * Lmax + off_t  # [BLOCK_T], int64

    x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # PAD_IS_UINT8 selects the pad tensor's dtype so integer-typed outputs
    # (e.g. MXFP4 packed nibbles, ue8m0 scale bytes) get an exact-byte pad
    # rather than going through an fp32->uint8 cast that's
    # implementation-defined outside of value 0.
    if PAD_IS_UINT8:
        pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.uint8)
    else:
        pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    # Cast explicitly to the output element type. This is the same conversion
    # tl.store would apply, and it keeps tl.where from promoting x_vals
    # (e.g. large int32 values would lose precision through fp32).
    pad_vals = pad_vals.to(out_ptr.dtype.element_ty)

    # Read x only where the row lies inside the sequence. Masked-off lanes
    # are replaced by the pad, so every output element is written exactly
    # once instead of pad-then-overwrite.
    load_mask = valid_row[:, None] & d_mask
    x_vals = tl.load(x_row_ptr, mask=load_mask)
    out_vals = tl.where(load_mask, x_vals, pad_vals)
    tl.store(out_row_ptr, out_vals, mask=t_mask[:, None] & d_mask)

63 

64 

def pack_seq_triton(
    x: torch.Tensor,
    lengths: torch.Tensor,
    pad_value: float | int = -float("inf"),
    block_t: int = 64,
    block_d: int = 64,
) -> torch.Tensor:
    """Pack back-to-back variable-length sequences into a padded batch.

    ``x`` stores ``B = lengths.numel()`` sequences concatenated along dim 0.
    Sequence ``b`` occupies rows ``[sum(lengths[:b]), sum(lengths[:b + 1]))``.
    The result has shape ``(B, Lmax, *x.shape[1:])`` with
    ``Lmax = max(lengths)``. ``out[b, t]`` is row ``t`` of sequence ``b`` when
    ``t < lengths[b]``, and ``pad_value`` otherwise.

    Args:
        x: Input of shape ``(N, *feature_dims)``. Any memory layout is
            accepted; it is made contiguous before the kernel runs. A 1-D
            input is treated as one feature per row.
        lengths: Per-sequence lengths. They are flattened, converted to int32
            and moved to ``x.device``. NOTE(review): ``sum(lengths) <= N`` is
            assumed but not checked here, because the kernel does not bound
            its reads by ``N``.
        pad_value: Fill value for padded positions. For ``torch.uint8``
            inputs it must be an ``int`` in ``[0, 255]``. Otherwise it is
            converted with ``float()``.
        block_t: Timesteps per Triton program. Must be a power of two, as
            ``tl.arange`` requires.
        block_d: Features per Triton program. Must be a power of two.

    Returns:
        A new tensor with ``x``'s dtype and device and shape
        ``(B, Lmax, *x.shape[1:])``. An empty ``lengths`` gives ``B = Lmax = 0``.

    Raises:
        AssertionError: ``x`` is uint8 and ``pad_value`` is not an int in
            ``[0, 255]`` (checked with ``assert``).
    """
    logger.debug("GEMS PACK_SEQ_TRITON")
    is_uint8 = x.dtype == torch.uint8
    if is_uint8:
        # uint8 outputs receive an exact-byte pad, so the value must be a byte.
        assert (
            isinstance(pad_value, int) and 0 <= pad_value <= 255
        ), f"uint8 pack requires an integer pad in [0, 255], got {pad_value!r}"
        pad_constexpr: int | float = int(pad_value)
    else:
        pad_constexpr = float(pad_value)

    original_shape = x.shape
    N = original_shape[0]
    # Collapse all trailing dims into one feature axis of size D. An empty
    # torch.Size has numel() == 1, so a 1-D x becomes (N, 1). An explicit D
    # also avoids reshape(N, -1) being ambiguous when N == 0.
    D = original_shape[1:].numel()
    # The kernel addresses x as row-major with row stride D and ignores the
    # tensor's actual strides. A transposed or sliced view must therefore be
    # materialized first (no-op when x is already contiguous).
    x_2d = x.reshape(N, D).contiguous()

    # The kernel reads lengths as a flat *i32 buffer on x's device.
    lengths_i32 = (
        lengths.reshape(-1).to(device=x.device, dtype=torch.int32).contiguous()
    )
    B = lengths_i32.numel()
    # .max() raises on an empty tensor; an empty batch simply has Lmax = 0.
    Lmax = int(lengths_i32.max().item()) if B > 0 else 0

    out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)

    # Nothing to write when B, Lmax or D is 0: skip compiling/launching.
    if out.numel() > 0:
        grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
        _pack_seq_kernel[grid](
            x_2d,
            out,
            lengths_i32,
            N,
            D,
            Lmax,
            PAD_VALUE=pad_constexpr,
            PAD_IS_UINT8=is_uint8,
            BLOCK_T=block_t,
            BLOCK_D=block_d,
            num_warps=4,
            num_stages=2,
        )

    # Restore the original feature dims. For 2-D input this is a view of the
    # same (B, Lmax, D) shape.
    return out.reshape((B, Lmax) + original_shape[1:])