Coverage for src/flag_gems/fused/unpack_seq.py: 59%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

@triton.jit
def _unpack_seq_triton_kernel(
    packed_ptr,  # [B, Lmax, D], addressed as row-major contiguous
    out_ptr,  # [N, D], addressed as row-major contiguous
    lengths_ptr,  # *i32, [B]
    B,  # batch size; not read by the body, kept for call compatibility
    Lmax,  # padded number of timesteps per sequence
    D: tl.constexpr,  # features per row
    BLOCK_T: tl.constexpr,  # timesteps per program
    BLOCK_D: tl.constexpr,  # features per program
):
    """Copy the valid rows of one padded sequence into the flat output.

    Launch grid is (batch, time blocks, feature blocks).  Each program
    copies a [BLOCK_T, BLOCK_D] tile of sequence ``pid_b`` from
    ``packed_ptr`` to ``out_ptr``; rows at or beyond the sequence length
    (or beyond ``Lmax``) and columns beyond ``D`` are masked out.

    ``B`` and ``Lmax`` are plain runtime scalars rather than
    ``tl.constexpr`` so that calls with different batch shapes can share
    a compiled kernel instead of each producing a new specialization.
    """
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # First output row of this sequence = sum of all preceding lengths.
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # A timestep is copied only if it lies inside both the sequence and
    # the padded time dimension.
    valid_row = (off_t < seq_len) & (off_t < Lmax)

    # Widen the row indices to int64 BEFORE scaling by D: the element
    # offset row * D can exceed the int32 range for large tensors even
    # when the row index itself fits comfortably.
    packed_row = (pid_b * Lmax + off_t).to(tl.int64)
    out_row = (in_start + off_t).to(tl.int64)

    # Pointers into packed_ptr ([B, Lmax, D]) and out_ptr ([N, D]).
    packed_row_ptr = packed_ptr + packed_row[:, None] * D + off_d[None, :]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # The same tile mask guards the load and the store, so lanes that are
    # masked out on load are never written.
    tile_mask = valid_row[:, None] & (off_d[None, :] < D)
    packed_vals = tl.load(packed_row_ptr, mask=tile_mask)
    tl.store(out_row_ptr, packed_vals, mask=tile_mask)

51 

52 

def unpack_seq_triton(
    packed_tensor: torch.Tensor,
    lengths: torch.Tensor,
    block_t: int = 64,
    block_d: int = 64,
) -> torch.Tensor:
    """Unpack a padded ``[B, Lmax, ...]`` tensor into a flat ``[N, ...]`` one.

    For every batch entry ``b`` the first ``lengths[b]`` timesteps are
    copied, in batch order, into consecutive rows of the result, where
    ``N = lengths.sum()``.

    Args:
        packed_tensor: Tensor with at least three dimensions, laid out as
            ``[B, Lmax, ...]``.  Trailing dimensions are flattened for the
            copy and restored on the result.
        lengths: Per-sequence lengths, one entry per batch element.  It is
            converted to a contiguous int32 tensor on ``packed_tensor``'s
            device before the launch.  Entries larger than ``Lmax`` are not
            validated; the output rows they account for are left
            uninitialised.
        block_t: Timesteps handled by one kernel program.
        block_d: Features handled by one kernel program.

    Returns:
        A tensor of shape ``(N, *packed_tensor.shape[2:])`` with the dtype
        and device of ``packed_tensor``.

    Raises:
        ValueError: If ``packed_tensor`` has fewer than three dimensions.
    """
    logger.debug("GEMS UNPACK_SEQ_TRITON")
    original_shape = packed_tensor.shape
    if len(original_shape) < 3:
        raise ValueError(
            "packed_tensor must have at least 3 dimensions [B, Lmax, ...], "
            f"got shape {tuple(original_shape)}"
        )

    B, Lmax = original_shape[:2]
    # Flattened feature size.  Computed explicitly because reshape(..., -1)
    # is ambiguous when the tensor has zero elements.
    D = 1
    for dim in original_shape[2:]:
        D *= dim

    # .item() copies the total to the host; it is needed to size the output.
    N = int(lengths.sum().item())

    out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype)

    # Launch only when there is something to copy: a grid with a zero-sized
    # dimension is not a valid launch configuration.
    if min(B, Lmax, D, N) > 0:
        # The kernel addresses memory as row-major contiguous [B, Lmax, D];
        # reshape alone may hand back a strided view, so force contiguity.
        packed_reshaped = packed_tensor.reshape(B, Lmax, D).contiguous()
        lengths_i32 = lengths.to(
            device=packed_tensor.device, dtype=torch.int32
        ).contiguous()

        grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
        _unpack_seq_triton_kernel[grid](
            packed_reshaped,
            out,
            lengths_i32,
            B,
            Lmax,
            D,
            BLOCK_T=block_t,
            BLOCK_D=block_d,
            num_warps=4,
            num_stages=2,
        )

    # Restore the trailing dimensions that were flattened into D.
    if len(original_shape) > 3:
        out = out.reshape((N, *original_shape[2:]))

    return out