Coverage for src/flag_gems/runtime/backend/_sunrise/ops/log_softmax.py: 0%

189 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


# Autotuning search space for ``log_softmax_kernel_inner``: every
# (TILE_N, num_warps) pair whose thread count, num_warps * 32, does not exceed
# TILE_N, so each thread owns at least one element of the tile. (Assumes a
# 32-lane warp -- TODO confirm for this backend.) This drops gross
# over-subscription, which would leave most lanes idle on small
# ONE_TILE_PER_CTA launches.
_INNER_CONFIGS = [
    triton.Config({"TILE_N": tile_n}, num_warps=num_warps)
    for tile_n in (64, 128, 256, 512, 1024)
    for num_warps in (1, 2, 4, 8, 16)
    if num_warps * 32 <= tile_n
]

24 

25 

26def _one_tile_per_cta(args): 

27 return args["TILE_N"] >= args["N"] 

28 

29 

@triton.jit
def _prev_multiple_of(a, b):
    # Start offset of the last (possibly partial) b-wide tile covering [0, a),
    # i.e. (ceil(a / b) - 1) * b. For a > 0 this is the largest multiple of b
    # strictly below a, so [result, a) is the tail that needs masking.
    num_tiles = tl.cdiv(a, b)
    return (num_tiles - 1) * b

33 

34 

@libentry()
@triton.jit
def log_softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr = 8,
    BLOCK_N: tl.constexpr = 256,
):
    """Log-softmax over the middle axis of a contiguous (M, N, K) view.

    Program (pid_m, pid_k) handles BLOCK_M consecutive indices of the M axis
    at position pid_k of the K axis. It walks the length-N axis (stride K) in
    chunks of BLOCK_N.

    Pass 1 keeps a per-lane online max and rescaled sum-of-exp in float32,
    then folds the lanes of each row. Pass 2 re-reads the input and writes
    ``x - max - log(sum)``.
    """
    # TODO(chenfeiyu): consider float64 and add a utility function to get the
    # accumulator type.
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Parts of the address and mask that do not depend on the N position.
    row_base = rows[:, None] * N * K + pid_k
    row_valid = rows[:, None] < M

    # Pass 1: per-lane running max and running sum of exp(x - running max).
    run_max = tl.full([BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32)
    run_sum = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        valid = row_valid & (cols[None, :] < N)
        x = tl.load(
            input_ptr + (row_base + cols[None, :] * K),
            mask=valid,
            other=-float("inf"),
        ).to(tl.float32)
        new_max = tl.maximum(x, run_max)
        # Lanes that have only seen -inf keep their sum; rescaling them would
        # evaluate exp(-inf - (-inf)) = NaN.
        run_sum = tl.where(
            new_max == float("-inf"),
            run_sum,
            run_sum * tl.exp(run_max - new_max) + tl.exp(x - new_max),
        )
        run_max = new_max

    # Fold the BLOCK_N lanes of every row into a single max and sum.
    row_max = tl.max(run_max, 1)
    row_sum = tl.sum(run_sum * tl.exp(run_max - row_max[:, None]), 1)
    log_sum = tl.log(row_sum)

    # Pass 2: emit x - max - log(sum).
    for start_n in range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        offsets = row_base + cols[None, :] * K
        valid = row_valid & (cols[None, :] < N)
        x = tl.load(input_ptr + offsets, mask=valid, other=-float("inf")).to(
            tl.float32
        )
        result = x - row_max[:, None] - log_sum[:, None]
        tl.store(output_ptr + offsets, result, mask=valid)

76 

77 

@libentry()
@triton.autotune(configs=_INNER_CONFIGS, key=["M", "N"])
@triton.heuristics({"ONE_TILE_PER_CTA": _one_tile_per_cta})
@triton.jit
def log_softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Log-softmax along the last axis of a contiguous (M, N) view.

    Program ``pid_m`` handles row ``pid_m``. When one TILE_N-wide tile covers
    the whole row (``ONE_TILE_PER_CTA``), the row is reduced from a single
    load. Otherwise pass 1 runs an online max / sum-of-exp over the row, and
    pass 2 writes ``x - max - log(sum)``. Both paths compute in float32.
    """
    pid_m = ext.program_id(0)

    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        mask = n_offsets < N
        inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(tl.float32)
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = inp - m - tl.log(z)
        tl.store(output_ptr + offset, out, mask=mask)
    else:
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # Start of the last (possibly partial) tile. Every tile before it is
        # full, so those tiles are accessed without masks in both passes.
        previous_multiple = _prev_multiple_of(N, TILE_N)

        # Pass 1: mask-free hot loop + masked tail.
        # Loads are upcast to float32 as in the ONE_TILE_PER_CTA path. Without
        # the cast, a float64 input would change the type of the loop-carried
        # float32 `m`/`z`, which Triton rejects at compile time.
        for start_n in range(0, previous_multiple, TILE_N):
            n_offset = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offset).to(tl.float32)
            m_new = tl.maximum(m, inp)
            # Lanes that have only seen -inf keep z (avoids exp(-inf - -inf) = NaN).
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        for start_n in range(previous_multiple, N, TILE_N):
            n_offset = start_n + tl.arange(0, TILE_N)
            mask = n_offset < N
            inp = tl.load(input_ptr + n_offset, mask=mask, other=-float("inf")).to(
                tl.float32
            )
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Fold the per-lane statistics into a scalar max and sum for the row.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced
        log_z = tl.log(z)

        # Pass 2: reverse traversal with eviction hints. The masked last tile
        # is written first (the loop below runs exactly once), then the full
        # tiles from back to front. Presumably the reverse order is there to
        # reuse the tiles pass 1 touched most recently -- TODO confirm.
        for start_n in range(0, TILE_N, TILE_N):
            n_offset = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offset < N
            inp = tl.load(
                input_ptr + n_offset,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            ).to(tl.float32)
            o = inp - m - log_z
            tl.store(output_ptr + n_offset, o, mask=mask)
        for start_n in range(TILE_N, N, TILE_N):
            n_offset = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offset, eviction_policy="evict_first").to(
                tl.float32
            )
            o = inp - m - log_z
            tl.store(output_ptr + n_offset, o)

149 

150 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Log-softmax backward over the middle axis of a contiguous (M, N, K) view.

    For each (m, k) column, computes
    ``in_grad = out_grad - exp(out) * sum_N(out_grad)``, where ``out`` is the
    forward log-softmax result. Program (pid_m, pid_k) handles BLOCK_M
    indices of the M axis at position pid_k of the K axis.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Pass 1: sum out_grad along N (per-lane partial sums, folded afterwards).
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_grad_ptrs = out_grad_ptr + offsets
        # Masked-off lanes are summed below, so they must read as 0; without
        # `other`, tl.load leaves them undefined.
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Pass 2: in_grad = out_grad - softmax * sum(out_grad).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)

189 

190 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("log_softmax"), key=["M", "N"])
@triton.jit
def log_softmax_backward_kernel_opt(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Log-softmax backward over the last axis of a contiguous (M, N) view.

    For each row, computes ``in_grad = out_grad - exp(out) * sum_N(out_grad)``,
    where ``out`` is the forward log-softmax result. Program ``pid_m`` handles
    BLOCK_M consecutive rows.
    """
    pid_m = ext.program_id(0)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Pass 1: sum out_grad along N (per-lane partial sums, folded afterwards).
    scale = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N + n_offset[None, :]
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_grad_ptrs = out_grad_ptr + offsets
        # Masked-off lanes are summed below, so they must read as 0; without
        # `other`, tl.load leaves them undefined.
        out_grad = tl.load(out_grad_ptrs, mask=mask, other=0.0).to(tl.float32)
        scale += out_grad
    scale = tl.sum(scale, 1)

    # Pass 2: in_grad = out_grad - softmax * sum(out_grad).
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offsets = m_offset[:, None] * N + n_offset[None, :]
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        out_ptrs = out_ptr + offsets
        out = tl.load(out_ptrs, mask=mask).to(tl.float32)
        out_grad_ptrs = out_grad_ptr + offsets
        out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)
        in_grad = out_grad - tl.exp(out) * scale[:, None]
        in_grad_ptrs = in_grad_ptr + offsets
        tl.store(in_grad_ptrs, in_grad, mask=mask)

227 

228 

def log_softmax(self, dim, half_to_float=False):
    """Log-softmax of ``self`` along ``dim``.

    The input is viewed as a contiguous (M, N, K) tensor with
    N = self.shape[dim], M = product of the leading dims and K = product of
    the trailing dims. K == 1 dispatches to the autotuned row kernel;
    otherwise the strided kernel is used.

    Args:
        self: input tensor (made contiguous internally).
        dim: dimension to normalize over; negative values count from the end.
        half_to_float: if True the result is float32, otherwise self.dtype.

    Returns:
        A new tensor with the same shape as ``self``, laid out like
        ``self.contiguous()``.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    inp = self.contiguous()
    dtype = torch.float32 if half_to_float else self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    # Nothing to compute for an empty tensor. M or N would also be 0 here,
    # so the K computation below would divide by zero.
    if out.numel() == 0:
        return out
    K = inp.numel() // M // N

    with torch_device_fn.device(inp.device):
        if K == 1:
            # Reduced axis is innermost: one program per row.
            grid = (M, 1, 1)
            log_softmax_kernel_inner[grid](out, inp, M, N)
        else:
            grid = lambda meta: (
                triton.cdiv(M, meta["BLOCK_M"]),
                K,
            )
            log_softmax_kernel[grid](
                out,
                inp,
                M,
                N,
                K,
                num_warps=16,
            )
    return out

264 

265 

def log_softmax_backward(grad_output, output, dim, input_dtype):
    """Gradient of log-softmax w.r.t. its input.

    Computes ``grad_output - exp(output) * grad_output.sum(dim)``, with the
    tensors viewed as contiguous (M, N, K) like the forward pass.

    Args:
        grad_output: gradient flowing into the log-softmax output.
        output: the forward log-softmax result; ``exp(output)`` is used as
            the softmax.
        dim: dimension that was normalized; negative values count from the end.
        input_dtype: dtype of the returned gradient.

    Returns:
        A new contiguous tensor shaped like ``output`` with dtype ``input_dtype``.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    # Both kernels index output / grad_output / in_grad as dense row-major
    # buffers. All three must therefore be contiguous; a strided `output`
    # would be misread, and empty_like would copy its strides to in_grad.
    grad_output = grad_output.contiguous()
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    # Empty tensor: nothing to do, and M or N == 0 would divide by zero below.
    if in_grad.numel() == 0:
        return in_grad
    K = output.numel() // M // N
    if K == 1:
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(in_grad.device):
            log_softmax_backward_kernel_opt[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    else:
        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )
        with torch_device_fn.device(in_grad.device):
            log_softmax_backward_kernel[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
    return in_grad