Coverage for src/flag_gems/runtime/backend/_sunrise/ops/triu.py: 0%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("triu"), key=["M", "N"])
@triton.jit(do_not_specialize=["diagonal"])
def triu_kernel(
    X,
    Y,
    M,
    N,
    diagonal,
    M_BLOCK_SIZE: tl.constexpr,
    N_BLOCK_SIZE: tl.constexpr,
):
    # Upper-triangular copy of a single M x N matrix from X into Y.
    # Both buffers are indexed as row * N + col. Each program owns
    # M_BLOCK_SIZE rows and sweeps the columns N_BLOCK_SIZE at a time.
    row_idx = ext.program_id(0) * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE)[:, None]
    row_in_bounds = row_idx < M

    # Pointers to the first element of every row handled by this program.
    src_rows = X + row_idx * N
    dst_rows = Y + row_idx * N

    for col_start in range(0, N, N_BLOCK_SIZE):
        col_idx = col_start + tl.arange(0, N_BLOCK_SIZE)[None, :]
        in_bounds = row_in_bounds & (col_idx < N)

        values = tl.load(src_rows + col_idx, mask=in_bounds, other=0.0)
        # Keep elements on or above the requested diagonal, zero the rest.
        kept = tl.where(col_idx >= row_idx + diagonal, values, 0.0)
        tl.store(dst_rows + col_idx, kept, mask=in_bounds)

41 

42 

# NOTE(review): "diagonal" is part of the autotune key although it is listed in
# do_not_specialize; presumably every new diagonal value triggers a fresh
# tuning run -- confirm this is intended.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("triu_batch"),
    key=["batch", "MN", "N", "diagonal"],
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_batch_kernel(
    X,
    Y,
    batch,
    MN,
    N,
    diagonal,
    BATCH_BLOCK_SIZE: tl.constexpr,
    MN_BLOCK_SIZE: tl.constexpr,
):
    # Upper-triangular copy for a stack of matrices. X and Y are indexed as
    # batch_index * MN + flat_index, where flat_index runs over one flattened
    # matrix whose rows hold N elements each.
    batch_idx = ext.program_id(0) * BATCH_BLOCK_SIZE + tl.arange(0, BATCH_BLOCK_SIZE)[:, None]
    flat_idx = ext.program_id(1) * MN_BLOCK_SIZE + tl.arange(0, MN_BLOCK_SIZE)[None, :]
    in_bounds = (batch_idx < batch) & (flat_idx < MN)

    # Pointers to the start of every matrix handled by this program.
    src_mats = X + batch_idx * MN
    dst_mats = Y + batch_idx * MN

    values = tl.load(src_mats + flat_idx, mask=in_bounds, other=0.0)

    # Recover (row, column) inside the matrix from the flattened offset.
    mat_row = flat_idx // N
    mat_col = flat_idx % N

    kept = tl.where(mat_col >= mat_row + diagonal, values, 0.0)
    tl.store(dst_mats + flat_idx, kept, mask=in_bounds)

74 

75 

76def _check_batch_contiguous(tensor, allow_zero_stride=True): 

77 if tensor.is_contiguous(): 

78 return True, tensor 

79 

80 dims = tensor.dim() 

81 

82 if dims >= 2: 

83 n = tensor.size(-1) 

84 stride_row, stride_col = tensor.stride(-2), tensor.stride(-1) 

85 

86 if not (stride_col == 1 and stride_row == n): 

87 return False, tensor.contiguous() 

88 

89 if allow_zero_stride and dims <= 3: 

90 return True, tensor 

91 

92 expected_stride = tensor.size(-1) * tensor.size(-2) 

93 for i in range(dims - 3, -1, -1): 

94 if ( 

95 allow_zero_stride 

96 and i == 0 

97 and (tensor.stride(i) == 0 or tensor.size(i) == 1) 

98 ): 

99 continue 

100 

101 if tensor.stride(i) != expected_stride: 

102 return False, tensor.contiguous() 

103 

104 expected_stride *= tensor.size(i) 

105 

106 return True, tensor 

107 

108 

def triu(A, diagonal=0):
    """Return a new tensor holding the upper-triangular part of ``A``.

    Elements below the ``diagonal``-th diagonal of each trailing 2-D matrix
    are set to zero; everything else is copied from ``A``.

    Args:
        A: Input tensor with at least two dimensions.
        diagonal: Diagonal to keep from (0 is the main diagonal).

    Returns:
        A freshly allocated contiguous tensor with the shape, dtype and
        device of ``A``.

    Raises:
        AssertionError: If ``A`` has fewer than two dimensions.
    """
    logger.debug("GEMS TRIU")
    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    diagonal = int(diagonal)

    out = torch.empty(
        A.shape, dtype=A.dtype, device=A.device, memory_format=torch.contiguous_format
    )

    # Nothing to compute, and M * N may be zero, so skip the kernel launch.
    if A.numel() == 0:
        return out

    # The kernels index their input as a dense row-major buffer; fall back to
    # a contiguous copy whenever A is not laid out that way.
    _, A_input = _check_batch_contiguous(A, allow_zero_stride=False)

    M, N = A_input.shape[-2:]

    with torch_device_fn.device(A_input.device):
        if A_input.dim() == 2:
            grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),)
            triu_kernel[grid](A_input, out, M, N, diagonal)
        else:
            mn = M * N
            # Integer division keeps the batch count exact for large tensors.
            batch = A_input.numel() // mn
            B = A_input.view(batch, -1)
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(mn, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](B, out, batch, mn, N, diagonal)

    return out

138 

139 

def _launch_triu(src, dst, M, N, diagonal):
    """Run the matching triu kernel, reading ``src`` and writing ``dst``.

    Both tensors are indexed by the kernels as dense row-major buffers, so
    callers must pass tensors laid out that way. ``src`` and ``dst`` may be
    the same tensor.
    """
    with torch_device_fn.device(src.device):
        if src.dim() == 2:
            grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),)
            triu_kernel[grid](src, dst, M, N, diagonal)
        else:
            mn = M * N
            # Integer division keeps the batch count exact for large tensors.
            batch = src.numel() // mn
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(mn, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](
                src.view(batch, -1), dst.view(batch, -1), batch, mn, N, diagonal
            )


def triu_(A, diagonal=0):
    """Zero the part of ``A`` below the ``diagonal``-th diagonal, in place.

    Args:
        A: Tensor with at least two dimensions; modified in place.
        diagonal: Diagonal to keep from (0 is the main diagonal).

    Returns:
        ``A`` itself.

    Raises:
        AssertionError: If ``A`` has fewer than two dimensions.
    """
    logger.debug("GEMS TRIU_ (inplace)")

    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    diagonal = int(diagonal)
    M, N = A.shape[-2:]

    # Nothing to modify, and M * N may be zero, so skip the kernel launch.
    if A.numel() == 0:
        return A

    can_use_directly, A_to_use = _check_batch_contiguous(A, allow_zero_stride=True)

    if can_use_directly:
        _launch_triu(A, A, M, N, diagonal)
        return A

    logger.debug(
        "Input tensor does not satisfy contiguity requirements, "
        "using temporary tensor for computation"
    )

    # Compute into a dense scratch buffer, then copy back through A's strides.
    result_temp = torch.empty_like(A_to_use, memory_format=torch.contiguous_format)
    _launch_triu(A_to_use, result_temp, M, N, diagonal)
    A.copy_(result_temp)

    return A