Coverage for src/flag_gems/ops/smooth_l1_loss.py: 59%
167 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device, torch_device_fn
# Rebind ``device`` from the imported runtime object to its ``.name``;
# _check_input compares ``tensor.device.type`` against this value, so it is
# presumably the torch device-type string (e.g. "cuda") — confirm in runtime.
device = device.name
logger = logging.getLogger(__name__)
@triton.jit
def _smooth_l1_loss_kernel(
    inp,
    target,
    out,
    n_elements,
    beta: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise smooth L1 loss over flat buffers (reduction="none").

    Each program handles one ``BLOCK_SIZE``-wide chunk of the flattened
    tensors. Loaded values are upcast to float32 before the arithmetic.
    With ``d = |inp - target|``:

    * ``beta == 0``: ``loss = d`` (plain L1);
    * otherwise: ``0.5 * d * d / beta`` where ``d < beta``, else
      ``d - 0.5 * beta``.

    ``beta`` is a ``tl.constexpr``, so the ``beta == 0.0`` test is resolved
    when the kernel is compiled rather than per element.
    """
    # Promote the program id to int64 so ``pid * BLOCK_SIZE`` cannot wrap
    # around int32 for tensors with 2**31 or more elements; a wrapped
    # (negative) offset would still pass ``offsets < n_elements`` and touch
    # memory outside the buffers.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    inp_val = tl.load(inp + offsets, mask=mask, other=0.0).to(tl.float32)
    target_val = tl.load(target + offsets, mask=mask, other=0.0).to(tl.float32)
    diff = tl.abs(inp_val - target_val)
    if beta == 0.0:
        # Avoid dividing by beta: the loss degenerates to plain L1.
        loss = diff
    else:
        loss = tl.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
    tl.store(out + offsets, loss, mask=mask)
@triton.jit
def _smooth_l1_loss_partial_sum_kernel(
    inp,
    target,
    mid,
    n_elements,
    beta: tl.constexpr,
    reduction: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the reduced smooth L1 loss: one partial result per block.

    Program ``pid`` computes the smooth L1 loss of its ``BLOCK_SIZE`` chunk
    (float32 arithmetic, same formula as the elementwise kernel), sums it and
    writes the result to ``mid[pid]``.

    When ``reduction == 1`` ("mean") each partial is already divided by
    ``n_elements``, so summing all partials yields the mean directly.
    """
    # int64 program id: keeps ``pid * BLOCK_SIZE`` from wrapping around int32
    # on inputs with 2**31 or more elements (a negative wrapped offset would
    # pass the mask check and read out of bounds).
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    inp_val = tl.load(inp + offsets, mask=mask, other=0.0).to(tl.float32)
    target_val = tl.load(target + offsets, mask=mask, other=0.0).to(tl.float32)
    diff = tl.abs(inp_val - target_val)
    if beta == 0.0:
        loss = diff
    else:
        loss = tl.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
    # Explicitly zero the out-of-range lanes before summing.
    loss = tl.where(mask, loss, 0.0)
    acc = tl.sum(loss, axis=0)
    if reduction == 1:
        acc = acc / n_elements
    tl.store(mid + pid, acc)
@triton.jit
def _smooth_l1_loss_sum_kernel(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second pass: add up the ``mid_size`` partial results into ``out``.

    A single program loads every partial at once (``BLOCK_MID`` is expected
    to be at least ``mid_size``; padding lanes read as 0.0) and stores the
    float32 total to ``out``.

    NOTE(review): because one block must hold all partials, very large
    inputs may exceed Triton's per-block tensor size limit — confirm.
    """
    lane = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid + lane, mask=lane < mid_size, other=0.0)
    tl.store(out, tl.sum(partials.to(tl.float32), axis=0))
@triton.jit
def _smooth_l1_loss_backward_kernel(
    grad_output,
    inp,
    target,
    out,
    n_elements,
    reduction_elements,
    beta: tl.constexpr,
    reduction: tl.constexpr,
    GRAD_OUTPUT_SCALAR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Gradient of smooth L1 loss with respect to ``inp``.

    With ``d = inp - target`` the local derivative is:

    * ``beta == 0``: ``sign(d)``, and NaN where ``d == 0``;
    * otherwise: ``-1`` for ``d < -beta``, ``+1`` for ``d > beta`` and
      ``d / beta`` in between.

    The derivative is multiplied by the upstream gradient, which is either
    a single value read once (``GRAD_OUTPUT_SCALAR``) or a buffer with one
    value per element. For ``reduction == 1`` ("mean") it is additionally
    scaled by ``1 / reduction_elements``.
    """
    # int64 program id so ``pid * BLOCK_SIZE`` cannot wrap around int32 for
    # tensors with 2**31 or more elements (a wrapped offset passes the mask
    # check and would read/write out of bounds).
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    inp_val = tl.load(inp + offsets, mask=mask, other=0.0).to(tl.float32)
    target_val = tl.load(target + offsets, mask=mask, other=0.0).to(tl.float32)
    diff = inp_val - target_val

    if beta == 0.0:
        # Dividing by beta is impossible; use the sign and mark the kink
        # (d == 0) as NaN.
        grad = tl.where(diff == 0.0, float("nan"), tl.where(diff > 0.0, 1.0, -1.0))
    else:
        grad = tl.where(diff < -beta, -1.0, tl.where(diff > beta, 1.0, diff / beta))

    if GRAD_OUTPUT_SCALAR:
        grad_out = tl.load(grad_output).to(tl.float32)
    else:
        grad_out = tl.load(grad_output + offsets, mask=mask, other=0.0).to(tl.float32)
    # Both branches above are compile-time (constexpr), so the mean scaling
    # can be applied once here instead of being duplicated in each branch.
    if reduction == 1:
        grad_out = grad_out * (1.0 / reduction_elements)
    tl.store(out + offsets, grad * grad_out, mask=mask)
110def _normalize_reduction(reduction):
111 if isinstance(reduction, str):
112 if reduction == "none":
113 return 0
114 if reduction == "mean":
115 return 1
116 if reduction == "sum":
117 return 2
118 elif isinstance(reduction, int) and reduction in (0, 1, 2):
119 return reduction
120 raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")
def _check_input(input, target, beta):
    """Validate forward operands and return them broadcast and contiguous.

    Args:
        input, target: tensors on the module's ``device`` type and on the
            same device as each other.
        beta: smooth L1 threshold; must not be negative.

    Returns:
        ``(input, target)`` broadcast to a common shape with
        ``torch.broadcast_tensors`` and made contiguous.

    Raises:
        RuntimeError: if ``beta`` is negative.
        AssertionError: on a wrong device type or mismatched devices.
    """
    if beta < 0:
        raise RuntimeError("smooth_l1_loss does not support negative values for beta.")
    both_on_backend = input.device.type == device and target.device.type == device
    if not both_on_backend:
        raise AssertionError("smooth_l1_loss: input and target must be CUDA tensors.")
    if input.device != target.device:
        raise AssertionError(
            "smooth_l1_loss: input and target must be on the same device."
        )
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
    return expanded_input.contiguous(), expanded_target.contiguous()
def _check_backward_input(grad_output, input, target, beta):
    """Validate and normalize the arguments of ``smooth_l1_loss_backward``.

    Args:
        grad_output: upstream gradient; either a single-element tensor or
            one broadcastable to the broadcast shape of ``input``/``target``.
        input, target: forward operands, validated and broadcast by
            ``_check_input``.
        beta: smooth L1 threshold (must not be negative).

    Returns:
        ``(grad_output, input, target, reduction_elements)``. All tensors are
        contiguous, and ``reduction_elements`` is the divisor used for
        ``reduction="mean"``.

    Raises:
        RuntimeError / AssertionError: propagated from ``_check_input``.
        AssertionError: if ``grad_output`` has the wrong device type or is
            on a different device than ``input``.
    """
    input, target = _check_input(input, target, beta)
    # Count elements *after* broadcasting: the forward "mean" path divides by
    # the numel of the broadcast input, so the backward scale must use the
    # same divisor. (Identical to the pre-broadcast count whenever input
    # already has the broadcast shape.)
    reduction_elements = input.numel()
    if grad_output.device.type != device:
        raise AssertionError(
            "smooth_l1_loss_backward: grad_output must be a CUDA tensor."
        )
    if grad_output.device != input.device:
        raise AssertionError(
            "smooth_l1_loss_backward: grad_output must be on the same device."
        )
    # A single-element grad_output is left as-is (the backward kernel reads
    # it as a scalar); anything else is expanded to the full input shape.
    if grad_output.numel() != 1:
        grad_output = torch.broadcast_to(grad_output, input.shape)
    return grad_output.contiguous(), input, target, reduction_elements
152def _empty_reduction(input, reduction):
153 if reduction == 0:
154 return torch.empty_like(input)
155 if reduction == 1:
156 return torch.full((), float("nan"), device=input.device, dtype=input.dtype)
157 return torch.zeros((), device=input.device, dtype=input.dtype)
def _smooth_l1_loss_none(input, target, beta, out=None):
    """Unreduced smooth L1 loss (reduction="none").

    Args:
        input, target: same-shape contiguous tensors (as produced by
            ``_check_input``).
        beta: smooth L1 threshold.
        out: optional destination. It is resized to ``input``'s shape when
            the shapes differ. A non-contiguous ``out`` is filled through a
            contiguous temporary and then copied back.

    Returns:
        ``out``, or a new tensor shaped like ``input`` when ``out`` is None.

    Raises:
        AssertionError: if ``out`` is on a different device than ``input``.
    """
    if out is None:
        out = torch.empty_like(input)
        dest = out
    else:
        if out.device != input.device:
            raise AssertionError("smooth_l1_loss.out: out must be on the same device.")
        if tuple(out.shape) != tuple(input.shape):
            out.resize_(input.shape)
        # The kernel writes flat offsets, so it needs a contiguous target.
        dest = out if out.is_contiguous() else torch.empty_like(input)

    numel = input.numel()
    if numel > 0:

        def grid(meta):
            return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

        with torch_device_fn.device(input.device):
            _smooth_l1_loss_kernel[grid](
                input,
                target,
                dest,
                numel,
                beta=float(beta),
                BLOCK_SIZE=1024,
            )
        if dest is not out:
            out.copy_(dest)
    return out
def _smooth_l1_loss_reduce(input, target, beta, reduction, out=None):
    """Smooth L1 loss reduced to a 0-dim tensor ("mean" = 1, "sum" = 2).

    A two-pass reduction: one float32 partial per 1024-element block, then a
    single-program sum of the partials into the result. Empty inputs use
    ``_empty_reduction`` (NaN for mean, 0 for sum).

    Args:
        input, target: same-shape contiguous tensors.
        beta: smooth L1 threshold.
        reduction: 1 for mean, 2 for sum.
        out: optional destination; resized to 0-dim when needed.

    Returns:
        ``out``, or a new 0-dim tensor with ``input``'s dtype.

    Raises:
        AssertionError: if ``out`` is on a different device than ``input``.
    """
    if out is not None:
        if out.device != input.device:
            raise AssertionError("smooth_l1_loss.out: out must be on the same device.")
        if out.dim() != 0:
            out.resize_(())

    numel = input.numel()
    if numel == 0:
        empty_value = _empty_reduction(input, reduction)
        if out is None:
            return empty_value
        out.copy_(empty_value)
        return out

    if out is None:
        out = torch.empty((), device=input.device, dtype=input.dtype)

    block = 1024
    n_blocks = triton.cdiv(numel, block)
    partials = torch.empty((n_blocks,), device=input.device, dtype=torch.float32)
    with torch_device_fn.device(input.device):
        _smooth_l1_loss_partial_sum_kernel[(n_blocks,)](
            input,
            target,
            partials,
            numel,
            beta=float(beta),
            reduction=reduction,
            BLOCK_SIZE=block,
        )
        _smooth_l1_loss_sum_kernel[(1,)](
            partials, out, n_blocks, BLOCK_MID=triton.next_power_of_2(n_blocks)
        )
    return out
def smooth_l1_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction=1,
    beta: float = 1.0,
) -> torch.Tensor:
    """Smooth L1 loss between ``input`` and ``target``.

    Args:
        input, target: tensors on the backend device; broadcast together.
        reduction: "none"/"mean"/"sum" or the codes 0/1/2 (default 1, mean).
        beta: threshold between the quadratic and linear regions (>= 0).

    Returns:
        The elementwise loss for "none", otherwise a 0-dim tensor.
    """
    logger.debug("GEMS SMOOTH_L1_LOSS")
    mode = _normalize_reduction(reduction)
    beta = float(beta)
    input, target = _check_input(input, target, beta)
    return (
        _smooth_l1_loss_none(input, target, beta)
        if mode == 0
        else _smooth_l1_loss_reduce(input, target, beta, mode)
    )
def smooth_l1_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction=1,
    beta: float = 1.0,
    *,
    out: torch.Tensor,
) -> torch.Tensor:
    """``out=`` variant of ``smooth_l1_loss``; writes into and returns ``out``.

    ``out`` is resized to the input shape ("none") or to 0-dim (mean/sum)
    when its shape does not already match.
    """
    logger.debug("GEMS SMOOTH_L1_LOSS OUT")
    mode = _normalize_reduction(reduction)
    beta = float(beta)
    input, target = _check_input(input, target, beta)
    if mode != 0:
        return _smooth_l1_loss_reduce(input, target, beta, mode, out=out)
    return _smooth_l1_loss_none(input, target, beta, out=out)
def smooth_l1_loss_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    target: torch.Tensor,
    reduction,
    beta: float,
) -> torch.Tensor:
    """Gradient of ``smooth_l1_loss`` with respect to ``input``.

    Args:
        grad_output: upstream gradient (single element for mean/sum, or
            broadcastable to the input shape for "none").
        input, target: the forward operands.
        reduction: "none"/"mean"/"sum" or 0/1/2.
        beta: the forward ``beta``.

    Returns:
        A tensor shaped like the broadcast ``input`` holding the gradient.
    """
    logger.debug("GEMS SMOOTH_L1_LOSS BACKWARD")
    mode = _normalize_reduction(reduction)
    beta = float(beta)
    grad_output, input, target, reduction_elements = _check_backward_input(
        grad_output, input, target, beta
    )
    grad_input = torch.empty_like(input)
    numel = input.numel()
    if numel == 0:
        return grad_input

    def grid(meta):
        return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(input.device):
        _smooth_l1_loss_backward_kernel[grid](
            grad_output,
            input,
            target,
            grad_input,
            numel,
            reduction_elements,
            beta=beta,
            reduction=mode,
            GRAD_OUTPUT_SCALAR=grad_output.numel() == 1,
            BLOCK_SIZE=1024,
        )
    return grad_input