Coverage for src/flag_gems/ops/smooth_l1_loss.py: 59%

167 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device, torch_device_fn 

8 

# Rebind `device` from the object imported from flag_gems.runtime to its
# `name` attribute; the validation helpers in this module compare it against
# `tensor.device.type`.
device = device.name
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

11 

12 

@triton.jit
def _smooth_l1_loss_kernel(
    inp,
    target,
    out,
    n_elements,
    beta: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Store the element-wise smooth L1 loss of ``inp`` against ``target``.

    Each program covers ``BLOCK_SIZE`` consecutive flat offsets starting at
    ``program_id(0) * BLOCK_SIZE``. Offsets at or beyond ``n_elements`` are
    masked out of both loads and the store. Operands are converted to float32
    before the arithmetic.

    With ``beta == 0`` the loss is the absolute difference; otherwise it is
    ``0.5 * d * d / beta`` where ``d < beta`` and ``d - 0.5 * beta`` elsewhere,
    ``d`` being the absolute difference.
    """
    block_start = tl.program_id(0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    x = tl.load(inp + idx, mask=in_bounds, other=0.0).to(tl.float32)
    y = tl.load(target + idx, mask=in_bounds, other=0.0).to(tl.float32)
    abs_err = tl.abs(x - y)
    # `beta` is declared tl.constexpr; the zero case skips the division by
    # beta entirely.
    if beta == 0.0:
        elementwise = abs_err
    else:
        elementwise = tl.where(
            abs_err < beta,
            0.5 * abs_err * abs_err / beta,
            abs_err - 0.5 * beta,
        )
    tl.store(out + idx, elementwise, mask=in_bounds)

34 

35 

@triton.jit
def _smooth_l1_loss_partial_sum_kernel(
    inp,
    target,
    mid,
    n_elements,
    beta: tl.constexpr,
    reduction: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Reduce one block of smooth L1 losses to a single partial sum.

    Program ``p`` computes the loss for flat offsets
    ``[p * BLOCK_SIZE, (p + 1) * BLOCK_SIZE)``, sums the in-range lanes and
    writes the result to ``mid[p]``. When ``reduction == 1`` the partial sum is
    divided by ``n_elements`` before it is stored, so adding up all entries of
    ``mid`` yields the mean rather than the sum.
    """
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    x = tl.load(inp + idx, mask=in_bounds, other=0.0).to(tl.float32)
    y = tl.load(target + idx, mask=in_bounds, other=0.0).to(tl.float32)
    abs_err = tl.abs(x - y)
    if beta == 0.0:
        elementwise = abs_err
    else:
        elementwise = tl.where(
            abs_err < beta,
            0.5 * abs_err * abs_err / beta,
            abs_err - 0.5 * beta,
        )
    # Force out-of-range lanes to zero so they cannot contribute to the sum.
    elementwise = tl.where(in_bounds, elementwise, 0.0)
    partial = tl.sum(elementwise, axis=0)
    if reduction == 1:
        partial = partial / n_elements
    tl.store(mid + block_id, partial)

62 

63 

@triton.jit
def _smooth_l1_loss_sum_kernel(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Sum the first ``mid_size`` entries of ``mid`` into the scalar ``out``.

    All entries are loaded as one ``BLOCK_MID``-wide block (lanes at or beyond
    ``mid_size`` read as 0.0), accumulated in float32 and stored at ``out``.
    The kernel does not use the program id, so it only covers indices
    ``[0, BLOCK_MID)``.
    """
    lane = tl.arange(0, BLOCK_MID)
    valid = lane < mid_size
    partials = tl.load(mid + lane, mask=valid, other=0.0).to(tl.float32)
    total = tl.sum(partials, axis=0)
    tl.store(out, total)

71 

72 

@triton.jit
def _smooth_l1_loss_backward_kernel(
    grad_output,
    inp,
    target,
    out,
    n_elements,
    reduction_elements,
    beta: tl.constexpr,
    reduction: tl.constexpr,
    GRAD_OUTPUT_SCALAR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute the gradient of smooth L1 loss with respect to ``inp``.

    Each program handles ``BLOCK_SIZE`` consecutive flat offsets; offsets at or
    beyond ``n_elements`` are masked out of the loads and the store.

    Args:
        grad_output: upstream gradient. Read as a single value when
            ``GRAD_OUTPUT_SCALAR`` is true, otherwise read per element at the
            same offsets as ``inp``.
        inp, target: operands of the forward loss, read at the same offsets.
        out: destination for the per-element gradient.
        n_elements: number of elements to process.
        reduction_elements: divisor applied to the upstream gradient when
            ``reduction == 1``.
        beta: threshold between the quadratic and the linear region.
        reduction: integer reduction code; only the value 1 changes the
            computation here.
        GRAD_OUTPUT_SCALAR: whether ``grad_output`` holds a single value.
        BLOCK_SIZE: number of elements handled per program.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    inp_val = tl.load(inp + offsets, mask=mask, other=0.0).to(tl.float32)
    target_val = tl.load(target + offsets, mask=mask, other=0.0).to(tl.float32)
    diff = inp_val - target_val

    if beta == 0.0:
        # Sign of the difference, with NaN wherever the difference is neither
        # positive nor negative. That covers an exactly-zero difference (NaN
        # was already produced there) and a NaN difference, which previously
        # fell through both comparisons and was reported as -1.0 instead of
        # propagating.
        grad = tl.where(
            diff > 0.0, 1.0, tl.where(diff < 0.0, -1.0, float("nan"))
        )
    else:
        # Slope is -1 / +1 outside [-beta, beta] and diff / beta inside; a NaN
        # difference fails both comparisons and propagates through diff / beta.
        grad = tl.where(
            diff < -beta, -1.0, tl.where(diff > beta, 1.0, diff / beta)
        )

    if GRAD_OUTPUT_SCALAR:
        grad_out = tl.load(grad_output).to(tl.float32)
    else:
        grad_out = tl.load(grad_output + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
    # The mean scaling is the same for scalar and per-element upstream
    # gradients, so it is applied once after the load.
    if reduction == 1:
        grad_out = grad_out * (1.0 / reduction_elements)
    tl.store(out + offsets, grad * grad_out, mask=mask)

108 

109 

110def _normalize_reduction(reduction): 

111 if isinstance(reduction, str): 

112 if reduction == "none": 

113 return 0 

114 if reduction == "mean": 

115 return 1 

116 if reduction == "sum": 

117 return 2 

118 elif isinstance(reduction, int) and reduction in (0, 1, 2): 

119 return reduction 

120 raise ValueError("reduction must be one of 'none', 'mean', or 'sum'") 

121 

122 

123def _check_input(input, target, beta): 

124 if beta < 0: 

125 raise RuntimeError("smooth_l1_loss does not support negative values for beta.") 

126 if input.device.type != device or target.device.type != device: 

127 raise AssertionError("smooth_l1_loss: input and target must be CUDA tensors.") 

128 if input.device != target.device: 

129 raise AssertionError( 

130 "smooth_l1_loss: input and target must be on the same device." 

131 ) 

132 input, target = torch.broadcast_tensors(input, target) 

133 return input.contiguous(), target.contiguous() 

134 

135 

def _check_backward_input(grad_output, input, target, beta):
    """Validate the backward operands and prepare them for the kernel.

    Returns:
        ``(grad_output, input, target, reduction_elements)`` where ``input``
        and ``target`` come from ``_check_input``, ``grad_output`` is
        contiguous (and expanded to ``input``'s broadcast shape unless it
        holds exactly one element), and ``reduction_elements`` is the element
        count of ``input`` as it was passed in.

    Raises:
        AssertionError: if ``grad_output`` is on the wrong device type or on a
            different device than ``input``; also anything ``_check_input``
            raises.
    """
    # Taken before broadcasting, so it counts the caller's `input`, not the
    # expanded one. NOTE(review): the forward mean divides by the broadcast
    # element count -- confirm this difference is intended.
    mean_divisor = input.numel()
    input, target = _check_input(input, target, beta)
    if grad_output.device.type != device:
        raise AssertionError(
            "smooth_l1_loss_backward: grad_output must be a CUDA tensor."
        )
    if grad_output.device != input.device:
        raise AssertionError(
            "smooth_l1_loss_backward: grad_output must be on the same device."
        )
    is_single_value = grad_output.numel() == 1
    if not is_single_value:
        grad_output = torch.broadcast_to(grad_output, input.shape)
    return grad_output.contiguous(), input, target, mean_divisor

150 

151 

152def _empty_reduction(input, reduction): 

153 if reduction == 0: 

154 return torch.empty_like(input) 

155 if reduction == 1: 

156 return torch.full((), float("nan"), device=input.device, dtype=input.dtype) 

157 return torch.zeros((), device=input.device, dtype=input.dtype) 

158 

159 

160def _smooth_l1_loss_none(input, target, beta, out=None): 

161 n_elements = input.numel() 

162 if out is None: 

163 out = torch.empty_like(input) 

164 out_contiguous = out 

165 else: 

166 if out.device != input.device: 

167 raise AssertionError("smooth_l1_loss.out: out must be on the same device.") 

168 if tuple(out.shape) != tuple(input.shape): 

169 out.resize_(input.shape) 

170 out_contiguous = out if out.is_contiguous() else torch.empty_like(input) 

171 

172 if n_elements > 0: 

173 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

174 with torch_device_fn.device(input.device): 

175 _smooth_l1_loss_kernel[grid]( 

176 input, 

177 target, 

178 out_contiguous, 

179 n_elements, 

180 beta=float(beta), 

181 BLOCK_SIZE=1024, 

182 ) 

183 if out_contiguous is not out: 

184 out.copy_(out_contiguous) 

185 return out 

186 

187 

def _smooth_l1_loss_reduce(input, target, beta, reduction, out=None):
    """Compute the mean- or sum-reduced smooth L1 loss as a 0-dim tensor.

    The reduction runs in two stages: one kernel writes a float32 partial sum
    per block into a scratch buffer, then a single-program kernel adds those
    partial sums into the result.

    Args:
        input, target: operands, read by the kernel at identical flat offsets.
        beta: smooth L1 threshold, passed to the kernel as a float.
        reduction: integer reduction code (1 = mean, otherwise sum).
        out: optional destination; resized to 0-dim if it is not already.

    Returns:
        ``out`` if given, otherwise a new 0-dim tensor with ``input``'s dtype.
        For an empty ``input`` the value comes from ``_empty_reduction``.

    Raises:
        AssertionError: if ``out`` is on a different device than ``input``.
    """
    import math  # function-local so the module's import block is untouched

    # Validate and shape `out` once, before any allocation or kernel launch.
    if out is not None:
        if out.device != input.device:
            raise AssertionError("smooth_l1_loss.out: out must be on the same device.")
        if out.dim() != 0:
            out.resize_(())

    n_elements = input.numel()
    if n_elements == 0:
        result = _empty_reduction(input, reduction)
        if out is None:
            return result
        out.copy_(result)
        return out

    # The second stage loads every partial sum in one block, so its width must
    # stay bounded. A fixed first-stage block of 1024 makes that width grow
    # linearly with the input; scaling the block with sqrt(n) keeps both stages
    # near sqrt(n). The 1024 floor leaves inputs of up to 2**20 elements
    # exactly as before.
    block_size = max(
        1024, triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    )
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)
    mid = torch.empty((mid_size,), device=input.device, dtype=torch.float32)
    if out is None:
        result = torch.empty((), device=input.device, dtype=input.dtype)
    else:
        result = out

    with torch_device_fn.device(input.device):
        _smooth_l1_loss_partial_sum_kernel[(mid_size,)](
            input,
            target,
            mid,
            n_elements,
            beta=float(beta),
            reduction=reduction,
            BLOCK_SIZE=block_size,
        )
        _smooth_l1_loss_sum_kernel[(1,)](mid, result, mid_size, BLOCK_MID=block_mid)
    return result

226 

227 

def smooth_l1_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction=1,
    beta: float = 1.0,
) -> torch.Tensor:
    """Smooth L1 loss between ``input`` and ``target``.

    Args:
        input, target: tensors to compare; validated and broadcast by
            ``_check_input``.
        reduction: ``"none"``/``"mean"``/``"sum"`` or the integer code 0/1/2.
        beta: non-negative threshold between the quadratic and linear regions.

    Returns:
        The element-wise loss for reduction code 0, otherwise a 0-dim tensor.
    """
    logger.debug("GEMS SMOOTH_L1_LOSS")
    mode = _normalize_reduction(reduction)
    beta_value = float(beta)
    lhs, rhs = _check_input(input, target, beta_value)
    if mode != 0:
        return _smooth_l1_loss_reduce(lhs, rhs, beta_value, mode)
    return _smooth_l1_loss_none(lhs, rhs, beta_value)

240 

241 

def smooth_l1_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction=1,
    beta: float = 1.0,
    *,
    out: torch.Tensor,
) -> torch.Tensor:
    """Smooth L1 loss between ``input`` and ``target``, written into ``out``.

    Args:
        input, target: tensors to compare; validated and broadcast by
            ``_check_input``.
        reduction: ``"none"``/``"mean"``/``"sum"`` or the integer code 0/1/2.
        beta: non-negative threshold between the quadratic and linear regions.
        out: keyword-only destination tensor, forwarded to the helper that
            performs the computation.

    Returns:
        Whatever the selected helper returns for the given ``out``.
    """
    logger.debug("GEMS SMOOTH_L1_LOSS OUT")
    mode = _normalize_reduction(reduction)
    beta_value = float(beta)
    lhs, rhs = _check_input(input, target, beta_value)
    if mode != 0:
        return _smooth_l1_loss_reduce(lhs, rhs, beta_value, mode, out=out)
    return _smooth_l1_loss_none(lhs, rhs, beta_value, out=out)

256 

257 

def smooth_l1_loss_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    target: torch.Tensor,
    reduction,
    beta: float,
) -> torch.Tensor:
    """Gradient of smooth L1 loss with respect to ``input``.

    Args:
        grad_output: upstream gradient; either a single value or something
            that broadcasts to the operands' shape.
        input, target: operands of the forward loss.
        reduction: ``"none"``/``"mean"``/``"sum"`` or the integer code 0/1/2.
        beta: non-negative threshold between the quadratic and linear regions.

    Returns:
        A tensor created like the validated (broadcast, contiguous) ``input``.
        It is returned without launching the kernel when there are no
        elements.
    """
    logger.debug("GEMS SMOOTH_L1_LOSS BACKWARD")
    mode = _normalize_reduction(reduction)
    beta_value = float(beta)
    grad_output, input, target, reduction_elements = _check_backward_input(
        grad_output, input, target, beta_value
    )
    grad_input = torch.empty_like(input)
    total = input.numel()
    if total == 0:
        return grad_input

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    # A single-element upstream gradient is read once by the kernel instead of
    # per offset.
    scalar_grad = grad_output.numel() == 1
    with torch_device_fn.device(input.device):
        _smooth_l1_loss_backward_kernel[grid](
            grad_output,
            input,
            target,
            grad_input,
            total,
            reduction_elements,
            beta=beta_value,
            reduction=mode,
            GRAD_OUTPUT_SCALAR=scalar_grad,
            BLOCK_SIZE=1024,
        )
    return grad_input