Coverage for src/flag_gems/ops/margin_ranking_loss.py: 43%
111 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
10logger = logging.getLogger(__name__)
@triton.jit
def _margin_ranking_loss_kernel(
    x1_ptr, x2_ptr, target_ptr, out_ptr, n_elements, margin, BLOCK_SIZE: tl.constexpr
):
    """
    Elementwise forward kernel for the margin ranking loss.

    Each program covers one run of ``BLOCK_SIZE`` consecutive offsets and, for
    every offset below ``n_elements``, stores
    ``max(0, -y * (x1 - x2) + margin)``.

    Args:
        x1_ptr: Pointer to the first input.
        x2_ptr: Pointer to the second input.
        target_ptr: Pointer to the target values ``y``.
        out_ptr: Pointer the per-element loss is written to.
        n_elements: Number of valid elements; offsets at or past it are masked.
        margin: Scalar margin added before clamping at zero.
        BLOCK_SIZE: Compile-time number of offsets handled by one program.
    """
    # Offsets owned by this program along launch axis 0.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Out-of-range lanes read 0 and are never written back.
    first = tl.load(x1_ptr + idx, mask=in_bounds, other=0)
    second = tl.load(x2_ptr + idx, mask=in_bounds, other=0)
    label = tl.load(target_ptr + idx, mask=in_bounds, other=0)

    # The margin and the clamp floor are materialised as float32 blocks.
    margin_block = tl.full([BLOCK_SIZE], margin, tl.float32)
    floor = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    delta = first - second
    hinge = tl.maximum(-label * delta + margin_block, floor)

    # Store using the element type that was loaded from x1.
    tl.store(out_ptr + idx, hinge.to(first.dtype), mask=in_bounds)
@triton.jit
def _margin_ranking_loss_backward_kernel(
    grad_output_ptr,
    x1_ptr,
    x2_ptr,
    y_ptr,
    grad_x1_ptr,
    grad_x2_ptr,
    margin,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Elementwise backward kernel for the margin ranking loss.

    The pre-clamp value ``-y * (x1 - x2) + margin`` is recomputed per element.
    Where it is strictly positive the kernel stores
    ``grad_x1 = -y * grad_output`` and ``grad_x2 = y * grad_output``;
    elsewhere both gradients are 0.

    Args:
        grad_output_ptr: Pointer to the per-element upstream gradient.
        x1_ptr: Pointer to the first input.
        x2_ptr: Pointer to the second input.
        y_ptr: Pointer to the target values.
        grad_x1_ptr: Pointer the gradient for x1 is written to.
        grad_x2_ptr: Pointer the gradient for x2 is written to.
        margin: Scalar margin used when recomputing the pre-clamp value.
        n_elements: Number of valid elements; offsets at or past it are masked.
        BLOCK_SIZE: Compile-time number of offsets handled by one program.
    """
    # Offsets owned by this program along launch axis 0.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Out-of-range lanes read 0.0 and are never written back.
    upstream = tl.load(grad_output_ptr + idx, mask=in_bounds, other=0.0)
    first = tl.load(x1_ptr + idx, mask=in_bounds, other=0.0)
    second = tl.load(x2_ptr + idx, mask=in_bounds, other=0.0)
    label = tl.load(y_ptr + idx, mask=in_bounds, other=0.0)

    # An element contributes a gradient only if its pre-clamp value is > 0.
    margin_block = tl.full([BLOCK_SIZE], margin, tl.float32)
    delta = first - second
    pre_clamp = -label * delta + margin_block
    is_active = pre_clamp > 0

    d_first = tl.where(is_active, -label * upstream, 0.0)
    d_second = tl.where(is_active, label * upstream, 0.0)

    # Both gradients are stored in the element type loaded from x1.
    tl.store(grad_x1_ptr + idx, d_first.to(first.dtype), mask=in_bounds)
    tl.store(grad_x2_ptr + idx, d_second.to(first.dtype), mask=in_bounds)
class MarginRankingLossOp(torch.autograd.Function):
    """
    Autograd function for margin ranking loss backed by Triton kernels.

    The elementwise loss is ``max(0, -y * (x1 - x2) + margin)``, optionally
    reduced with ``mean`` or ``sum``. Inputs living on a device other than
    ``flag_gems.device`` are evaluated with the ATen operator in the forward
    pass and with plain torch ops in the backward pass.
    """

    @staticmethod
    def forward(ctx, x1, x2, target, margin=0.0, reduction="mean"):
        """
        Compute the margin ranking loss.

        Args:
            ctx: Autograd context; receives everything ``backward`` needs.
            x1: First input tensor (floating point).
            x2: Second input tensor (floating point).
            target: Target tensor (floating point), typically +1 or -1.
            margin: Margin value (default: 0.0).
            reduction: ``'none'``, ``'mean'`` or ``'sum'``; the integers
                0, 1 and 2 are accepted as aliases (any other integer is
                treated as ``'mean'``).

        Returns:
            The loss: the broadcast shape of the inputs for ``'none'``,
            otherwise a scalar tensor.

        Raises:
            ValueError: If an input is not floating point or ``reduction``
                is not a recognised string.
        """
        logger.debug("GEMS MARGIN_RANKING_LOSS")

        if not (
            x1.is_floating_point()
            and x2.is_floating_point()
            and target.is_floating_point()
        ):
            raise ValueError("All inputs must be floating point tensors")

        # Normalize reduction parameter (handle both string and int formats)
        if isinstance(reduction, int):
            reduction = {0: "none", 1: "mean", 2: "sum"}.get(reduction, "mean")
        if reduction not in ("none", "mean", "sum"):
            raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")

        device = x1.device
        use_triton = (
            isinstance(device, torch.device) and device.type == flag_gems.device
        )

        # Broadcast to a common shape and promote to the widest input dtype so
        # that no operand is silently downcast to x1's dtype.
        x1_b, x2_b, tgt_b = torch.broadcast_tensors(x1, x2, target)
        common_dtype = torch.promote_types(
            torch.promote_types(x1.dtype, x2.dtype), target.dtype
        )

        # Flat, contiguous copies are what the kernels (and backward) consume.
        x1_c = x1_b.to(dtype=common_dtype).contiguous().view(-1)
        x2_c = x2_b.to(dtype=common_dtype).contiguous().view(-1)
        tgt_c = tgt_b.to(dtype=common_dtype).contiguous().view(-1)
        n_elements = x1_c.numel()

        # Populate the context BEFORE any early return: backward is invoked for
        # the fallback and empty-tensor paths too and reads all of these.
        ctx.save_for_backward(x1_c, x2_c, tgt_c)
        ctx.reduction = reduction
        ctx.margin = margin
        ctx.n_elements = n_elements
        ctx.original_shape = x1_b.shape
        ctx.use_triton = use_triton

        if not use_triton:
            # Fallback to the PyTorch implementation for tensors that are not
            # on the flag_gems device.
            return torch.ops.aten.margin_ranking_loss(
                x1,
                x2,
                target,
                float(margin),
                {"none": 0, "mean": 1, "sum": 2}[reduction],
            )

        out = torch.empty_like(x1_c)

        # Nothing to launch for empty inputs; the reductions below still give
        # the usual empty-tensor results.
        if n_elements > 0:
            BLOCK_SIZE = 1024
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            _margin_ranking_loss_kernel[grid](
                x1_c, x2_c, tgt_c, out, n_elements, float(margin), BLOCK_SIZE=BLOCK_SIZE
            )

        if reduction == "none":
            return out.view(x1_b.shape)
        if reduction == "sum":
            return out.sum()
        return out.mean()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute gradients of the margin ranking loss.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient of the final loss value(s).

        Returns:
            ``(grad_x1, grad_x2, grad_target, None, None)``. The tensor
            gradients have the broadcast shape of the inputs. ``grad_target``
            is ``None`` when the target does not require a gradient; the
            trailing ``None`` values correspond to ``margin`` and
            ``reduction``.
        """
        logger.debug("GEMS MARGIN_RANKING_LOSS_BACKWARD")

        x1, x2, y = ctx.saved_tensors
        margin = ctx.margin
        reduction = ctx.reduction
        n_elements = ctx.n_elements
        original_shape = ctx.original_shape
        need_target_grad = ctx.needs_input_grad[2]

        # Handle empty tensor case
        if n_elements == 0:
            grad_x1 = torch.zeros_like(x1).view(original_shape)
            grad_x2 = torch.zeros_like(x2).view(original_shape)
            grad_target = (
                torch.zeros_like(y).view(original_shape) if need_target_grad else None
            )
            return grad_x1, grad_x2, grad_target, None, None

        # Scale by the reduction and expand to one upstream value per element.
        if reduction == "mean":
            grad_flat = grad_output.expand(n_elements) / n_elements
        elif reduction == "sum":
            grad_flat = grad_output.expand(n_elements)
        else:
            grad_flat = grad_output.contiguous().view(-1)
        grad_flat = grad_flat.contiguous()

        # Torch-side activity mask; only needed off the Triton path or for the
        # target gradient. It is evaluated in the inputs' dtype, so it may
        # differ from the kernel's float32 evaluation exactly at the boundary.
        active = None
        if need_target_grad or not ctx.use_triton:
            active = (-y * (x1 - x2) + float(margin)) > 0

        if ctx.use_triton:
            grad_x1 = torch.empty_like(x1)
            grad_x2 = torch.empty_like(x2)

            BLOCK_SIZE = 1024
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            _margin_ranking_loss_backward_kernel[grid](
                grad_output=grad_flat,
                x1_ptr=x1,
                x2_ptr=x2,
                y_ptr=y,
                grad_x1_ptr=grad_x1,
                grad_x2_ptr=grad_x2,
                margin=float(margin),
                n_elements=n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
            ) if False else _margin_ranking_loss_backward_kernel[grid](
                grad_flat,
                x1,
                x2,
                y,
                grad_x1,
                grad_x2,
                float(margin),
                n_elements,
                BLOCK_SIZE=BLOCK_SIZE,
            )
        else:
            # Same rule as the Triton kernel: gradient only where the
            # pre-clamp value is strictly positive.
            weighted = y * grad_flat
            zeros = torch.zeros_like(weighted)
            grad_x1 = torch.where(active, -weighted, zeros)
            grad_x2 = torch.where(active, weighted, zeros)

        grad_x1 = grad_x1.view(original_shape)
        grad_x2 = grad_x2.view(original_shape)

        # d(loss)/d(y) = -(x1 - x2) on active elements, 0 elsewhere.
        grad_target = None
        if need_target_grad:
            target_term = -(x1 - x2) * grad_flat
            grad_target = torch.where(
                active, target_term, torch.zeros_like(target_term)
            ).view(original_shape)

        return grad_x1, grad_x2, grad_target, None, None
def margin_ranking_loss(x1, x2, target, margin=0.0, reduction="mean"):
    """
    Margin ranking loss, dispatched through ``MarginRankingLossOp``.

    Per element the loss is::

        loss = max(0, -y * (x1 - x2) + margin)

    so x1 is pushed to rank above x2 where ``y = 1`` and below it where
    ``y = -1``.

    Args:
        x1: First input tensor
        x2: Second input tensor
        target: Target tensor with values typically +1 or -1
        margin: Margin value (default: 0.0)
        reduction: How the per-element losses are combined:
            'none': keep every element
            'mean': average over all elements
            'sum': add up all elements

    Returns:
        Loss tensor whose shape depends on the reduction mode

    Example:
        >>> x1 = torch.randn(4, device='cuda')
        >>> x2 = torch.randn(4, device='cuda')
        >>> target = torch.ones(4, device='cuda')
        >>> loss = margin_ranking_loss(x1, x2, target, margin=1.0)
    """
    # autograd.Function.apply takes positional arguments only.
    op_args = (x1, x2, target, margin, reduction)
    return MarginRankingLossOp.apply(*op_args)