Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/soft_margin_loss.py: 0%
98 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
12from ..utils.pointwise_dynamic import pointwise_dynamic
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _soft_margin_loss_elementwise(x, y):
    # Per-element soft margin loss, log(1 + exp(-x * y)), evaluated in float32.
    neg_margin = -x.to(tl.float32) * y.to(tl.float32)
    # log(1 + exp(z)) rewritten as max(z, 0) + log(1 + exp(-|z|)) so that the
    # exponential only ever sees a non-positive argument.
    softplus_tail = tl.log(1.0 + tl.exp(-tl.abs(neg_margin)))
    return tl.maximum(neg_margin, 0.0) + softplus_tail
@libentry()
@triton.jit
def kernel_1(
    x_ptr,
    y_ptr,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
    reduction: tl.constexpr,
):
    # First reduction stage: each program reduces one BLOCK_SIZE-wide slice of
    # the M input elements and writes a single partial result to mid[program id].
    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M

    inp = tl.load(x_ptr + idx, mask=in_bounds, other=0).to(tl.float32)
    tgt = tl.load(y_ptr + idx, mask=in_bounds, other=0).to(tl.float32)

    # log(1 + exp(z)) with z = -input * target, written as
    # max(z, 0) + log(1 + exp(-|z|)).
    neg_margin = -inp * tgt
    loss = tl.maximum(neg_margin, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(neg_margin)))
    # Padding lanes were loaded as (0, 0), whose loss is log(2) rather than 0,
    # so they must be cleared explicitly before summing.
    loss = tl.where(in_bounds, loss, 0.0)

    # reduction codes: 1 -> mean, anything else reaching this kernel -> sum.
    # For the mean, each partial sum is divided by M here, so adding the
    # partials afterwards yields the overall mean.
    if reduction == 1:
        partial = tl.sum(loss) / M
    else:
        partial = tl.sum(loss)

    tl.store(mid + block_id, partial)
@libentry()
@triton.jit
def kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Second reduction stage: a single program adds up the first `mid_size`
    # partial results stored in `mid` and writes the scalar total to `out`.
    idx = tl.arange(0, BLOCK_MID)
    # Lanes at or beyond mid_size read as 0 and therefore do not affect the sum.
    partials = tl.load(mid + idx, mask=idx < mid_size, other=0).to(tl.float32)
    tl.store(out, tl.sum(partials))
70def _normalize_reduction(reduction):
71 if isinstance(reduction, str):
72 r = reduction.lower()
73 if r == "none":
74 return 0
75 if r == "mean":
76 return 1
77 if r == "sum":
78 return 2
79 raise ValueError(f"Invalid reduction: {reduction}")
80 if isinstance(reduction, int):
81 if reduction in (0, 1, 2):
82 return reduction
83 raise ValueError(f"Invalid reduction int: {reduction}")
84 raise ValueError(f"Unsupported reduction type: {type(reduction)}")
def soft_margin_loss(input: torch.Tensor, target: torch.Tensor, reduction="mean"):
    """Compute the soft margin loss log(1 + exp(-input * target)).

    Args:
        input: Tensor of predictions.
        target: Tensor of targets; it is read element-for-element alongside
            ``input`` after both have been made contiguous.
        reduction: "none", "mean" or "sum" (case-insensitive), or the
            equivalent integer code 0, 1 or 2.

    Returns:
        For "none", the per-element loss produced by the pointwise kernel.
        For "sum"/"mean", a 0-dim tensor with ``input``'s dtype. On empty
        input, "sum" yields 0 and "mean" yields NaN.

    Raises:
        ValueError: if ``reduction`` is not a recognised value.
    """
    import os

    logger.debug("GEMS_KUNLUNXIN SOFT_MARGIN_LOSS")
    red = _normalize_reduction(reduction)

    # .contiguous() returns the tensor itself when it is already contiguous.
    input = input.contiguous()
    target = target.contiguous()

    n_elements = input.numel()

    if red == 0:
        # reduction = 'none': handled entirely by the pointwise kernel.
        if n_elements == 0:
            return torch.empty_like(input)
        return _soft_margin_loss_elementwise(input, target)

    # reduction = 'sum' (red == 2) or 'mean' (red == 1)
    if n_elements == 0:
        if red == 2:
            return torch.zeros((), device=input.device, dtype=input.dtype)
        return torch.full((), float("nan"), device=input.device, dtype=input.dtype)

    # Two-stage reduction: about sqrt(n) programs each reduce about sqrt(n)
    # elements into `mid`, then one program reduces `mid` into `out`.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.float32, device=input.device)
    out = torch.empty([], dtype=torch.float32, device=input.device)

    # NOTE(review): TRITONXPU_OTHER_SIM presumably changes how the XPU Triton
    # backend treats the `other=` value of masked loads -- confirm against the
    # backend documentation.
    # The variable is process-global, so remember the caller's value and always
    # put it back, including when a kernel launch raises. Mutating os.environ
    # is not safe against concurrent callers.
    env_key = "TRITONXPU_OTHER_SIM"
    previous = os.environ.get(env_key)
    os.environ[env_key] = "1"
    try:
        with torch_device_fn.device(input.device):
            kernel_1[(mid_size, 1, 1)](input, target, mid, n_elements, block_size, red)
            if mid_size == 1:
                # A single partial result already is the final answer.
                return mid.reshape([]).to(dtype=input.dtype)
            kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    finally:
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous

    return out.to(dtype=input.dtype)
def soft_margin_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction="mean",
    out: torch.Tensor = None,
):
    """Compute the soft margin loss, optionally writing it into ``out``.

    The loss is computed by ``soft_margin_loss``. When ``out`` is given, the
    result is copied into it with ``copy_`` and ``out`` is returned; otherwise
    the freshly computed tensor is returned.
    """
    logger.debug("GEMS_KUNLUNXIN SOFT_MARGIN_LOSS_OUT")
    loss = soft_margin_loss(input, target, reduction)
    if out is not None:
        out.copy_(loss)
        return out
    return loss
148 return out