Coverage for src/flag_gems/runtime/backend/_sunrise/ops/softmax.py: 0%

247 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.ops.zeros import zero_ 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax over the n index of data addressed as m * N * K + n * K + k.
    # Program axis 0 selects m; program axis 1 selects a TILE_K-wide slab of k.
    pid_k = ext.program_id(1)
    pid_m = ext.program_id(0)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)
    k_in_bounds = k_offsets < K
    slab_base = pid_m * N * K

    if ONE_TILE_PER_CTA:
        # A single TILE_N-tall tile is loaded and reduced in one shot.
        n_offsets = tl.arange(0, TILE_N)
        offset = slab_base + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & k_in_bounds
        # Masked lanes read as -inf so they contribute exp(-inf) == 0 to the sum.
        inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf"))
        col_max = tl.max(inp, 0)
        numer = tl.exp(inp - col_max[None, :])
        denom = tl.sum(numer, 0)
        tl.store(output_ptr + offset, numer / denom, mask=mask)
    else:
        # Element-wise running maximum and rescaled running sum, kept per
        # (tile row, k) lane and only collapsed along axis 0 after the loop.
        run_max = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        run_sum = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = slab_base + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & k_in_bounds
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            new_max = tl.maximum(run_max, inp)
            # Lanes that have only seen -inf keep their sum untouched; this
            # avoids evaluating exp(-inf - (-inf)) into the accumulator.
            only_neg_inf = new_max == float("-inf")
            rescaled = run_sum * tl.exp(run_max - new_max) + tl.exp(inp - new_max)
            run_sum = tl.where(only_neg_inf, run_sum, rescaled)
            run_max = new_max

        # Collapse the per-lane statistics along axis 0.
        col_max = tl.max(run_max, 0)  # (TILE_K,)
        col_sum = tl.sum(run_sum * tl.exp(run_max - col_max[None, :]), 0)  # (TILE_K,)

        # specialization does not improve performance in this example, as tested
        # Second pass visits tiles from the last one back to the first.
        last_tile_start = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (last_tile_start - start_n) + tl.arange(0, TILE_N)
            offsets = slab_base + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            o = tl.exp(inp - col_max[None, :]) / col_sum[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)

74 

75 

@triton.jit
def next_multiple_of(a, b):
    # Smallest x >= a such that x % b == 0.
    # Fixed: the previous body called `tl.cidv`, which is not a triton.language
    # function; the ceiling-division helper is `tl.cdiv`.
    return tl.cdiv(a, b) * b

80 

81 

@triton.jit
def prev_multiple_of(a, b):
    # Largest x < a such that x % b == 0: round a up to a multiple of b,
    # then step back by one b.
    rounded_up = tl.cdiv(a, b) * b
    return rounded_up - b

86 

87 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax over rows of length N stored back to back; program axis 0
    # selects the row (row r starts at element r * N).
    pid_m = ext.program_id(0)
    if ONE_TILE_PER_CTA:
        # The row is loaded as a single TILE_N-wide tile.
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        mask = n_offsets < N
        # Masked lanes read as -inf so they add exp(-inf) == 0 to the sum.
        inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        row_max = tl.max(inp, 0)
        numer = tl.exp(inp - row_max)
        denom = tl.sum(numer, 0)
        tl.store(output_ptr + offset, numer / denom, mask=mask)
    else:
        # Per-lane running maximum and rescaled running sum; collapsed to
        # scalars after the row has been scanned.
        run_max = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        run_sum = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        row_in = input_ptr + pid_m * N
        row_out = output_ptr + pid_m * N

        last_tile_start = prev_multiple_of(N, TILE_N)

        # Tiles before the last one lie fully inside the row: no mask needed.
        for start_n in range(0, last_tile_start, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(row_in + n_offsets)
            new_max = tl.maximum(run_max, inp)
            # it is possible that there are -inf's in the input
            only_neg_inf = new_max == float("-inf")
            rescaled = run_sum * tl.exp(run_max - new_max) + tl.exp(inp - new_max)
            run_sum = tl.where(only_neg_inf, run_sum, rescaled)
            run_max = new_max
        # specialize the last iteration (the only tile that can overrun N)
        for start_n in range(last_tile_start, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(row_in + n_offsets, mask=mask, other=-float("inf"))
            new_max = tl.maximum(run_max, inp)
            only_neg_inf = new_max == float("-inf")
            rescaled = run_sum * tl.exp(run_max - new_max) + tl.exp(inp - new_max)
            run_sum = tl.where(only_neg_inf, run_sum, rescaled)
            run_max = new_max

        # Collapse the per-lane statistics into one max and one sum for the row.
        row_max = tl.max(run_max, 0)
        row_sum = tl.sum(run_sum * tl.exp(run_max - row_max), 0)

        # Second pass walks the tiles from the last one back to the first.
        # specialize the first iteration (the masked tail tile)
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (last_tile_start - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                row_in + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - row_max) / row_sum
            tl.store(row_out + n_offsets, o, mask=mask)
        # Remaining tiles are full, so they are loaded and stored unmasked.
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (last_tile_start - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(row_in + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - row_max) / row_sum
            tl.store(row_out + n_offsets, o)

161 

162 

163# ------------------------ backward ------------------------------- 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax backward over the n index of data addressed as
    # m * N * K + n * K + k:
    #     in_grad = out * (out_grad - sum_n(out * out_grad))
    # Program axis 0 selects m; program axis 1 selects a TILE_K-wide slab of k.
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        # other=0.0: these tiles feed tl.sum, so masked-off lanes must hold a
        # defined value that does not change the reduction.
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad per lane, then reduce along axis 0.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            # other=0.0 keeps out-of-range lanes from polluting the accumulator.
            out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            scale += out_tile * out_grad_tile
            offsets_n += TILE_N
            offsets += TILE_N * K
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: re-read the tiles and write the gradient. Masked lanes are
        # never stored and are not reduced here, so no fill value is required.
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        for _ in range(0, N, TILE_N):
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            offsets_n += TILE_N
            offsets += TILE_N * K

222 

223 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Softmax backward over rows of length N stored back to back:
    #     in_grad = out * (out_grad - sum_n(out * out_grad))
    # Each program handles TILE_M consecutive rows.
    pid_m = ext.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        # other=0.0: these tiles feed tl.sum, so masked-off lanes must hold a
        # defined value that does not change the reduction.
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad per lane, then reduce along axis 1.
        scale = tl.zeros([TILE_M, TILE_N], dtype=tl.float32)

        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            # other=0.0 keeps out-of-range lanes from polluting the accumulator.
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, other=0.0, eviction_policy="evict_last"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            scale += out_tile * out_grad_tile
            n_offsets += TILE_N
            offsets += TILE_N
        scale = tl.sum(scale, 1)  # (TILE_M,)

        # Pass 2: re-read the tiles and write the gradient. Masked lanes are
        # never stored and are not reduced here, so no fill value is required.
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        for _ in range(0, N, TILE_N):
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
            n_offsets += TILE_N
            offsets += TILE_N

282 

283 

def softmax_out(self, dim, half_to_float=False, *, out):
    """Compute softmax of ``self`` along ``dim`` and return the result tensor.

    Args:
        self: Input tensor; a contiguous copy is taken before the launch.
        dim: Reduction dimension, in ``[-self.ndim, self.ndim)``.
        half_to_float: When true, ``out`` must have dtype ``torch.float32``;
            otherwise it must match ``self.dtype``.
        out: Destination tensor. If its shape differs from ``self.shape`` a
            resized tensor is produced through a CPU round-trip and bound in
            its place, so callers should rely on the return value.

    Returns:
        The (possibly replaced) ``out`` tensor.

    Raises:
        AssertionError: ``dim`` is out of range.
        RuntimeError: ``out`` has an unexpected dtype (non-empty input only).
    """
    logger.debug("GEMS SOFTMAX_OUT")

    assert -self.ndim <= dim < self.ndim, "Invalid dim"

    shape_mismatch = tuple(out.shape) != tuple(self.shape)

    # Empty input: nothing to launch, just hand back a correctly shaped tensor.
    if self.numel() == 0:
        if shape_mismatch:
            # out.resize_(self.shape)  # [sunrise fix][PTPU] out.resize_(shape) not supported.
            out = out.cpu().resize_(self.shape).to(out.device)
        zero_(out)
        return out

    dim = dim % self.ndim
    # View the data as (M, N, K): M = product of the dims before `dim`,
    # N = the softmax extent, K = whatever remains.
    N = self.shape[dim]
    M = 1
    for extent in self.shape[:dim]:
        M *= extent

    inp = self.contiguous()
    expected_dtype = torch.float32 if half_to_float else inp.dtype
    if shape_mismatch:
        # out.resize_(self.shape)  # [sunrise fix][PTPU] out.resize_(shape) not supported.
        out = out.cpu().resize_(inp.shape).to(out.device)
    if out.dtype != expected_dtype:
        raise RuntimeError(
            f"_softmax.out: expected out dtype {expected_dtype}, got {out.dtype}"
        )
    K = inp.numel() // M // N

    # NOTE(review): the kernels receive only M, N, K (no strides), so `out`
    # is presumably expected to be contiguous — confirm with callers.
    with torch_device_fn.device(inp.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_kernel_non_inner[grid](out, inp, M, N, K)
        else:
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](out, inp, M, N)
    return out

329 

330 

def softmax(self, dim, half_to_float=False):
    """Return the softmax of ``self`` along ``dim`` in a newly allocated tensor.

    Args:
        self: Input tensor.
        dim: Reduction dimension, in ``[-self.ndim, self.ndim)``.
        half_to_float: When true the result dtype is ``torch.float32``;
            otherwise it is ``self.dtype``.

    Returns:
        A new tensor with the shape of ``self``.

    Raises:
        AssertionError: ``dim`` is out of range.
    """
    logger.debug("GEMS SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    # Resolve the result dtype once so the empty and non-empty paths agree.
    # Previously the empty path always used self.dtype and ignored
    # half_to_float.
    dtype = torch.float32 if half_to_float else self.dtype

    if self.numel() == 0:
        out = torch.empty(list(self.shape), dtype=dtype, device=self.device)
        zero_(out)
        return out

    # softmax_out makes the input contiguous and launches kernels with sizes
    # only, so the output must be row-major too. The default preserve_format
    # of empty_like could copy non-row-major strides from a dense,
    # non-contiguous input (e.g. a transposed tensor).
    out = torch.empty_like(self, dtype=dtype, memory_format=torch.contiguous_format)
    return softmax_out(self, dim, half_to_float, out=out)

345 

346 

def softmax_backward_out(grad_output, output, dim, input_dtype, *, grad_input):
    """Compute the softmax input gradient and return the ``grad_input`` tensor.

    Args:
        grad_output: Gradient w.r.t. the softmax output; made contiguous here.
        output: Forward softmax result; made contiguous here.
        dim: Dimension the softmax was taken over, in
            ``[-output.ndim, output.ndim)``.
        input_dtype: Required dtype of ``grad_input``.
        grad_input: Destination tensor. If its shape differs from
            ``output.shape`` a resized tensor is produced through a CPU
            round-trip and bound in its place, so callers should rely on the
            return value.

    Returns:
        The (possibly replaced) ``grad_input`` tensor.

    Raises:
        AssertionError: ``dim`` is out of range.
        RuntimeError: ``grad_input`` does not have dtype ``input_dtype``.
    """
    logger.debug("GEMS SOFTMAX_BACKWARD_OUT")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    # View the data as (M, N, K): M = product of the dims before `dim`,
    # N = the softmax extent, K = whatever remains.
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    # The kernels are launched with sizes only (no strides), so both tensors
    # they read must be row-major. contiguous() is a no-op when they already are.
    grad_output = grad_output.contiguous()
    output = output.contiguous()
    if tuple(grad_input.shape) != tuple(output.shape):
        # grad_input.resize_(output.shape)  # [sunrise fix][PTPU] out.resize_(shape) not supported.
        grad_input = grad_input.cpu().resize_(output.shape).to(grad_input.device)
    if grad_input.dtype != input_dtype:
        raise RuntimeError(
            f"_softmax_backward_data.out: expected grad_input dtype {input_dtype}, got {grad_input.dtype}"
        )

    # Empty tensors have nothing to compute. Returning here also avoids the
    # ZeroDivisionError that `numel() // M // N` raises when M or N is 0.
    if output.numel() == 0:
        return grad_input

    K = output.numel() // M // N

    # NOTE(review): grad_input is written with the same stride-less indexing,
    # so it is presumably expected to be contiguous — confirm with callers.
    with torch_device_fn.device(grad_input.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                grad_input,
                M,
                N,
                K,
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](
                output,
                grad_output,
                grad_input,
                M,
                N,
            )
    return grad_input

388 

389 

def softmax_backward(grad_output, output, dim, input_dtype):
    """Return the softmax input gradient in a newly allocated tensor.

    Args:
        grad_output: Gradient w.r.t. the softmax output.
        output: Forward softmax result.
        dim: Dimension the softmax was taken over.
        input_dtype: Dtype of the returned gradient.

    Returns:
        A new tensor with the shape of ``output`` and dtype ``input_dtype``.
    """
    logger.debug("GEMS SOFTMAX_BACKWARD")
    # Allocate the gradient row-major explicitly: softmax_backward_out launches
    # its kernels with sizes only, and the default preserve_format of
    # empty_like could copy non-row-major strides from a dense,
    # non-contiguous `output`.
    in_grad = torch.empty_like(
        output, dtype=input_dtype, memory_format=torch.contiguous_format
    )
    return softmax_backward_out(
        grad_output, output, dim, input_dtype, grad_input=in_grad
    )