Coverage for src/flag_gems/ops/soft_margin_loss.py: 38%
126 statements
coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
10logger = logging.getLogger(__name__)
@triton.jit
def _soft_margin_loss_elementwise_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Per-element soft-margin loss: out[i] = log(1 + exp(-x[i] * y[i])).

    Each program handles one BLOCK_SIZE-wide slice of the flat inputs; lanes
    at or beyond ``n_elements`` are masked off. The math runs in float32.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    # softplus(z) with z = -x*y, in the overflow-free form
    #   max(z, 0) + log1p(exp(-|z|)).
    z = -x * y
    t = tl.exp(-tl.abs(z))  # in [0, 1]
    # log1p(t) without a dedicated intrinsic. Naively log(1 + t) collapses to
    # exactly 0 once t < ~6e-8 because 1 + t rounds to 1, losing every small
    # loss value. Rescaling by t / (u - 1) cancels the rounding error made
    # when forming u = 1 + t (Goldberg's trick). When u == 1 exactly,
    # log1p(t) equals t to within float32 precision.
    u = 1.0 + t
    log1p_t = tl.where(u == 1.0, t, tl.log(u) * (t / (u - 1.0)))
    vals = tl.maximum(z, 0.0) + log1p_t

    tl.store(out_ptr + offsets, vals, mask=mask)
@triton.jit
def _soft_margin_loss_sum_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Accumulate sum_i log(1 + exp(-x[i] * y[i])) into the scalar ``out_ptr``.

    Each program reduces one BLOCK_SIZE-wide slice in float32 and adds its
    partial sum to ``*out_ptr`` with an atomic add. The caller must
    zero-initialise the accumulator.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    # Stable softplus of z = -x*y: max(z, 0) + log1p(exp(-|z|)).
    z = -x * y
    t = tl.exp(-tl.abs(z))  # in [0, 1]
    # Accurate log1p(t): plain log(1 + t) returns 0 for t below float32
    # epsilon. Rescaling by t / (u - 1) corrects the rounding of u = 1 + t.
    u = 1.0 + t
    log1p_t = tl.where(u == 1.0, t, tl.log(u) * (t / (u - 1.0)))
    vals = tl.maximum(z, 0.0) + log1p_t
    # Masked lanes loaded x = y = 0, which would contribute log(2); drop them.
    vals = tl.where(mask, vals, 0.0)

    block_sum = tl.sum(vals, axis=0)
    tl.atomic_add(out_ptr, block_sum)
57def _normalize_reduction(reduction):
58 # Accept both string and enum/int forms: 0=none,1=mean,2=sum
59 if isinstance(reduction, str):
60 r = reduction.lower()
61 if r == "none":
62 return 0
63 if r == "mean":
64 return 1
65 if r == "sum":
66 return 2
67 raise ValueError(f"Invalid reduction: {reduction}")
68 if isinstance(reduction, int):
69 if reduction in (0, 1, 2):
70 return reduction
71 raise ValueError(f"Invalid reduction int: {reduction}")
72 raise ValueError(f"Unsupported reduction type: {type(reduction)}")
def _check_tensors(input: torch.Tensor, target: torch.Tensor):
    """Validate the operands and return contiguous versions of both.

    Raises:
        AssertionError: if either tensor is not on ``flag_gems.device``, if
            they sit on different devices, or if their element counts differ.

    Returns:
        ``(input, target)``, each made contiguous. ``.contiguous()`` hands back
        the same tensor when it already is contiguous.
    """
    backend = flag_gems.device
    if not (input.device.type == backend and target.device.type == backend):
        raise AssertionError(
            f"soft_margin_loss: input and target must be {flag_gems.device} tensors for Triton kernel."
        )
    if input.device != target.device:
        raise AssertionError(
            "soft_margin_loss: input and target must be on the same device."
        )
    if input.numel() != target.numel():
        raise AssertionError(
            "soft_margin_loss: input and target must have the same number of elements."
        )
    return input.contiguous(), target.contiguous()
def soft_margin_loss(input: torch.Tensor, target: torch.Tensor, reduction="mean"):
    """Soft-margin loss ``log(1 + exp(-input * target))`` via Triton kernels.

    Args:
        input: Prediction tensor on ``flag_gems.device``.
        target: Tensor with the same number of elements as ``input``.
        reduction: "none" / "mean" / "sum", or the integer codes 0 / 1 / 2.

    Returns:
        For "none", a tensor shaped and typed like ``input``. For "mean" and
        "sum", a 0-dim tensor of ``input.dtype``; the total is accumulated in
        float32 before the final cast. Empty inputs yield 0 for "sum" and NaN
        for "mean".
    """
    logger.debug("GEMS SOFT_MARGIN_LOSS")
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    numel = input.numel()
    block = 1024
    grid = (triton.cdiv(numel, block),)

    if red == 0:
        losses = torch.empty_like(input)
        if numel:
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, losses, numel, BLOCK_SIZE=block
            )
        return losses

    if not numel:
        # Follow PyTorch on empty input: sum -> 0, mean -> NaN.
        if red == 2:
            return torch.zeros((), device=input.device, dtype=input.dtype)
        return torch.full((), float("nan"), device=input.device, dtype=input.dtype)

    acc = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[grid](input, target, acc, numel, BLOCK_SIZE=block)
    reduced = acc if red == 2 else acc / float(numel)
    return reduced.to(dtype=input.dtype)
def soft_margin_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction="mean",
    out: torch.Tensor = None,
):
    """Soft-margin loss written into ``out`` (allocated when ``out`` is None).

    Args:
        input: Prediction tensor on ``flag_gems.device``.
        target: Tensor with the same number of elements as ``input``.
        reduction: "none" / "mean" / "sum", or the integer codes 0 / 1 / 2.
        out: Destination tensor. For "none" it needs ``input.numel()``
            elements and may be non-contiguous. For "mean"/"sum" it must hold
            exactly one element.

    Returns:
        ``out``, filled with the per-element losses ("none") or the reduced
        value ("mean"/"sum"). Empty inputs give 0 for "sum" and NaN for "mean".

    Raises:
        AssertionError: if ``out`` is on the wrong device or has the wrong
            number of elements for the requested reduction.
    """
    logger.debug("GEMS SOFT_MARGIN_LOSS_OUT")
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    n_elements = input.numel()

    if out is None:
        # Allocate output based on reduction
        if red == 0:
            out = torch.empty_like(input)
        else:
            out = torch.empty((), device=input.device, dtype=input.dtype)
    else:
        if out.device.type != flag_gems.device:
            raise AssertionError(
                f"soft_margin_loss_out: out must be a {flag_gems.device} tensor."
            )
        if red == 0:
            if out.numel() != n_elements:
                raise AssertionError(
                    "soft_margin_loss_out: for reduction='none', out must match input shape."
                )
        else:
            if out.numel() != 1:
                raise AssertionError(
                    "soft_margin_loss_out: for reduction='sum' or 'mean', out must be a scalar tensor."
                )
        if out.device != input.device:
            raise AssertionError(
                "soft_margin_loss_out: out must be on the same device as input."
            )

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    if red == 0:
        if n_elements > 0:
            # The kernel writes a flat, densely packed buffer, so it can only
            # target `out` directly when `out` is contiguous. For any other
            # layout (e.g. a transposed view), write to a scratch buffer and
            # copy it over in logical (row-major) order.
            if out.is_contiguous():
                dst = out
            else:
                dst = torch.empty(n_elements, device=out.device, dtype=out.dtype)
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, dst, n_elements, BLOCK_SIZE=BLOCK_SIZE
            )
            if dst is not out:
                out.copy_(dst.view(out.shape))
        return out

    if n_elements == 0:
        # Follow PyTorch on empty input: sum -> 0, mean -> NaN.
        out.fill_(0 if red == 2 else float("nan"))
        return out

    # float32 accumulator regardless of input dtype; the kernel adds into it.
    tmp_sum = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[grid](
        input, target, tmp_sum, n_elements, BLOCK_SIZE=BLOCK_SIZE
    )
    if red == 2:
        out.fill_(tmp_sum.to(dtype=input.dtype))
    else:
        out.fill_((tmp_sum / float(n_elements)).to(dtype=input.dtype))
    return out