Coverage for src/flag_gems/ops/soft_margin_loss.py: 38%

126 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

# Module-level logger; the public ops below log a debug line on each call.
logger = logging.getLogger(name=__name__)

11 

12 

# Element-wise soft margin loss: out[i] = log(1 + exp(-x[i] * y[i])) for
# i < n_elements. Each program handles one BLOCK_SIZE-wide chunk of the flat
# inputs. Math is done in float32; tl.store converts to out_ptr's element type.
@triton.jit
def _soft_margin_loss_elementwise_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    # Promote before multiplying so offsets do not wrap for > 2**31 elements.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    xf = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    yf = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    # Stable softplus(z) = max(z, 0) + log1p(exp(-|z|)), with z = -x * y.
    z = -xf * yf
    t = tl.exp(-tl.abs(z))  # in [0, 1]
    # log(1.0 + t) loses relative precision as t shrinks and becomes exactly 0
    # once 1.0 + t rounds to 1.0 (t below ~6e-8 in float32). For small t use
    # the log1p Taylor series t - t^2/2 + t^3/3 - t^4/4 (error < t^5 / 5).
    series = t * (1.0 - t * (0.5 - t * (0.3333333333333333 - t * 0.25)))
    log1p_t = tl.where(t < 1e-2, series, tl.log(1.0 + t))
    vals = tl.maximum(z, 0.0) + log1p_t

    tl.store(out_ptr + offsets, vals, mask=mask)

32 

33 

# Summed soft margin loss: each program computes log(1 + exp(-x * y)) for its
# BLOCK_SIZE-wide chunk in float32, reduces the chunk, and atomically adds the
# partial sum into the scalar at out_ptr (which the caller zero-initialises).
# The atomic accumulation order is not fixed, so the float result may differ
# in the last bits between runs.
@triton.jit
def _soft_margin_loss_sum_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    # Promote before multiplying so offsets do not wrap for > 2**31 elements.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    xf = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    yf = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    # Stable softplus(z) = max(z, 0) + log1p(exp(-|z|)), with z = -x * y.
    z = -xf * yf
    t = tl.exp(-tl.abs(z))  # in [0, 1]
    # log(1.0 + t) loses relative precision as t shrinks and becomes exactly 0
    # once 1.0 + t rounds to 1.0 (t below ~6e-8 in float32). For small t use
    # the log1p Taylor series t - t^2/2 + t^3/3 - t^4/4 (error < t^5 / 5).
    series = t * (1.0 - t * (0.5 - t * (0.3333333333333333 - t * 0.25)))
    log1p_t = tl.where(t < 1e-2, series, tl.log(1.0 + t))
    vals = tl.maximum(z, 0.0) + log1p_t
    # Out-of-range lanes loaded x = y = 0, which would contribute log(2) each.
    vals = tl.where(mask, vals, 0.0)

    acc = tl.sum(vals, axis=0)
    tl.atomic_add(out_ptr, acc)

55 

56 

57def _normalize_reduction(reduction): 

58 # Accept both string and enum/int forms: 0=none,1=mean,2=sum 

59 if isinstance(reduction, str): 

60 r = reduction.lower() 

61 if r == "none": 

62 return 0 

63 if r == "mean": 

64 return 1 

65 if r == "sum": 

66 return 2 

67 raise ValueError(f"Invalid reduction: {reduction}") 

68 if isinstance(reduction, int): 

69 if reduction in (0, 1, 2): 

70 return reduction 

71 raise ValueError(f"Invalid reduction int: {reduction}") 

72 raise ValueError(f"Unsupported reduction type: {type(reduction)}") 

73 

74 

def _check_tensors(input: torch.Tensor, target: torch.Tensor):
    """Validate ``input``/``target`` and return contiguous tensors of input's shape.

    Both tensors must live on the same ``flag_gems.device`` device. ``target``
    must either have ``input``'s shape or be broadcastable to it; a broadcast
    target is materialized so the kernels can pair elements by flat index.
    PyTorch's reference implementation likewise computes the loss in a buffer
    shaped like ``input`` and multiplies ``target`` into it.

    Returns:
        ``(input, target)``, both contiguous and both with ``input.shape``.

    Raises:
        AssertionError: on a device mismatch, or when ``target`` cannot be
            broadcast to ``input``'s shape.
    """
    if input.device.type != flag_gems.device or target.device.type != flag_gems.device:
        raise AssertionError(
            f"soft_margin_loss: input and target must be {flag_gems.device} tensors for Triton kernel."
        )
    if input.device != target.device:
        raise AssertionError(
            "soft_margin_loss: input and target must be on the same device."
        )
    if target.shape != input.shape:
        # Comparing only numel() would pair e.g. a (2, 3) input with a (3, 2)
        # target by flat index and silently return wrong values.
        try:
            target = target.expand(input.shape)
        except RuntimeError:
            raise AssertionError(
                "soft_margin_loss: target with shape "
                f"{tuple(target.shape)} cannot be broadcast to input shape "
                f"{tuple(input.shape)}."
            ) from None
    # contiguous() returns the tensor itself when it is already contiguous.
    return input.contiguous(), target.contiguous()

93 

94 

def soft_margin_loss(input: torch.Tensor, target: torch.Tensor, reduction="mean"):
    """Soft margin loss ``log(1 + exp(-input * target))`` computed with Triton.

    Args:
        input: predictions on the ``flag_gems.device`` device.
        target: labels, validated and made contiguous by ``_check_tensors``.
        reduction: "none" / "mean" / "sum" or the int codes 0 / 1 / 2.

    Returns:
        For "none", a tensor like ``input`` with the per-element loss.
        Otherwise a 0-dim tensor of ``input.dtype`` holding the sum, or the sum
        divided by the element count. An empty input gives 0 for "sum" and
        NaN for "mean".
    """
    logger.debug("GEMS SOFT_MARGIN_LOSS")
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    count = input.numel()
    block_size = 1024
    # One program per BLOCK_SIZE-wide chunk of the flattened inputs.
    launch_grid = (triton.cdiv(count, block_size),)

    if red == 0:
        losses = torch.empty_like(input)
        if count > 0:
            _soft_margin_loss_elementwise_kernel[launch_grid](
                input, target, losses, count, BLOCK_SIZE=block_size
            )
        return losses

    if count == 0:
        # Follow PyTorch behavior on empty input: sum -> 0, mean -> NaN.
        if red == 2:
            return torch.zeros((), device=input.device, dtype=input.dtype)
        return torch.full((), float("nan"), device=input.device, dtype=input.dtype)

    # float32 scalar that every kernel program atomically adds its partial sum to.
    total = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[launch_grid](
        input, target, total, count, BLOCK_SIZE=block_size
    )
    result = total if red == 2 else total / float(count)
    return result.to(dtype=input.dtype)

135 

136 

def soft_margin_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction="mean",
    out: torch.Tensor = None,
):
    """Compute the soft margin loss into ``out`` and return it.

    Args:
        input: predictions; validated and made contiguous by ``_check_tensors``.
        target: labels; validated and made contiguous by ``_check_tensors``.
        reduction: "none" / "mean" / "sum" or the int codes 0 / 1 / 2.
        out: destination tensor. When ``None``, it is allocated like ``input``
            for "none" and as a 0-dim ``input.dtype`` tensor otherwise. A
            caller-supplied ``out`` must be on input's device. For "none" it
            must hold ``input.numel()`` elements in any layout; otherwise it
            must hold exactly one element.

    Returns:
        ``out``, filled with the per-element loss ("none"), the sum, or the mean.
        An empty input gives 0 for "sum" and NaN for "mean".

    Raises:
        AssertionError: if ``out`` has the wrong device or element count.
    """
    logger.debug("GEMS SOFT_MARGIN_LOSS_OUT")
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    n_elements = input.numel()

    if out is None:
        # Allocate output based on reduction
        if red == 0:
            out = torch.empty_like(input)
        else:
            out = torch.empty((), device=input.device, dtype=input.dtype)
    else:
        if out.device.type != flag_gems.device:
            raise AssertionError(
                f"soft_margin_loss_out: out must be a {flag_gems.device} tensor."
            )
        if red == 0:
            if out.numel() != n_elements:
                raise AssertionError(
                    "soft_margin_loss_out: for reduction='none', out must match input shape."
                )
        elif out.numel() != 1:
            raise AssertionError(
                "soft_margin_loss_out: for reduction='sum' or 'mean', out must be a scalar tensor."
            )
        if out.device != input.device:
            raise AssertionError(
                "soft_margin_loss_out: out must be on the same device as input."
            )

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    if red == 0:
        if n_elements > 0:
            # The kernel writes a dense, row-major buffer, which matches `out`
            # only when `out` is contiguous. For any other layout (e.g. a
            # transposed view), write to a contiguous staging buffer and let
            # copy_ honour out's strides.
            if out.is_contiguous():
                dst = out
            else:
                dst = torch.empty(out.shape, dtype=out.dtype, device=out.device)
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, dst, n_elements, BLOCK_SIZE=BLOCK_SIZE
            )
            if dst is not out:
                out.copy_(dst)
        return out

    if n_elements == 0:
        # Empty reduction: sum -> 0, mean -> NaN.
        out.fill_(0 if red == 2 else float("nan"))
        return out

    # float32 scalar that every kernel program atomically adds its partial sum to.
    tmp_sum = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[grid](
        input, target, tmp_sum, n_elements, BLOCK_SIZE=BLOCK_SIZE
    )
    result = tmp_sum if red == 2 else tmp_sum / float(n_elements)
    out.fill_(result.to(dtype=input.dtype))
    return out