Coverage for src/flag_gems/ops/margin_ranking_loss.py: 43%

111 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

# Module-level logger; the op's forward and backward passes emit debug traces through it.
logger = logging.getLogger(name=__name__)

11 

12 

@triton.jit
def _margin_ranking_loss_kernel(
    x1_ptr, x2_ptr, target_ptr, out_ptr, n_elements, margin, BLOCK_SIZE: tl.constexpr
):
    """
    Triton kernel for the element-wise margin ranking loss forward pass.

    Computes: loss = max(0, -y * (x1 - x2) + margin)
    where y is the target (typically +1 or -1).

    Each program handles one chunk of BLOCK_SIZE consecutive elements of the
    1-D inputs; lanes at or beyond ``n_elements`` are masked off. A NaN in the
    unclamped value is propagated to the output (as ``torch.clamp_min`` does);
    ``tl.maximum`` would silently turn it into 0.

    Args:
        x1_ptr: Pointer to first input tensor
        x2_ptr: Pointer to second input tensor
        target_ptr: Pointer to target tensor (labels)
        out_ptr: Pointer to output loss tensor
        n_elements: Total number of elements to process
        margin: Margin value for the loss
        BLOCK_SIZE: Number of elements processed per program (compile-time)
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked lanes read 0 and are never stored, so their value is irrelevant.
    x1 = tl.load(x1_ptr + offsets, mask=mask, other=0)
    x2 = tl.load(x2_ptr + offsets, mask=mask, other=0)
    y = tl.load(target_ptr + offsets, mask=mask, other=0)

    # Adding a float32 margin tensor promotes half/bfloat16 math to float32
    # while leaving float64 inputs in float64.
    m = tl.full([BLOCK_SIZE], margin, tl.float32)
    val = -y * (x1 - x2) + m

    # Clamp at zero with NaN propagation: `NaN < 0` is False, so NaN is kept.
    loss = tl.where(val < 0, 0.0, val)

    # Store the result cast back to the input dtype.
    tl.store(out_ptr + offsets, loss.to(x1.dtype), mask=mask)

52 

53 

@triton.jit
def _margin_ranking_loss_backward_kernel(
    grad_output_ptr,
    x1_ptr,
    x2_ptr,
    y_ptr,
    grad_x1_ptr,
    grad_x2_ptr,
    margin,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Triton kernel for the element-wise margin ranking loss backward pass.

    Recomputes ``val = -y * (x1 - x2) + margin`` and writes, where ``val >= 0``:
        grad_x1 = -y * grad_output
        grad_x2 =  y * grad_output
    and 0 elsewhere. ``>=`` (rather than ``>``) matches the subgradient that
    PyTorch's ``clamp_min`` backward uses at the hinge, so elements with
    ``val == 0`` still receive gradient. A NaN ``val`` compares False and gets
    zero gradient, as in PyTorch.

    Args:
        grad_output_ptr: Pointer to the per-element upstream gradient
            (already expanded/scaled for the reduction by the caller)
        x1_ptr: Pointer to first input tensor
        x2_ptr: Pointer to second input tensor
        y_ptr: Pointer to target tensor
        grad_x1_ptr: Pointer to gradient output for x1
        grad_x2_ptr: Pointer to gradient output for x2
        margin: Margin value used in forward pass
        n_elements: Total number of elements to process
        BLOCK_SIZE: Number of elements processed per program (compile-time)
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    grad_output = tl.load(grad_output_ptr + offsets, mask=mask, other=0.0)
    x1 = tl.load(x1_ptr + offsets, mask=mask, other=0.0)
    x2 = tl.load(x2_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)

    # Recompute the unclamped forward value with the same expression/precision
    # as the forward kernel so both agree on which elements are active.
    m = tl.full([BLOCK_SIZE], margin, tl.float32)
    val = -y * (x1 - x2) + m
    active = val >= 0

    grad_x1 = tl.where(active, -y * grad_output, 0.0)
    grad_x2 = tl.where(active, y * grad_output, 0.0)

    tl.store(grad_x1_ptr + offsets, grad_x1.to(x1.dtype), mask=mask)
    tl.store(grad_x2_ptr + offsets, grad_x2.to(x1.dtype), mask=mask)

112 

113 

class MarginRankingLossOp(torch.autograd.Function):
    """
    Custom autograd function for margin ranking loss with Triton kernel acceleration.

    Implements the margin ranking loss: loss = max(0, -y * (x1 - x2) + margin)
    This loss is used to learn rankings where x1 should be ranked higher than x2
    when y = 1, and x2 should be ranked higher than x1 when y = -1.

    When all three inputs live on ``flag_gems.device`` the Triton kernels are
    used. Otherwise the forward pass delegates to
    ``torch.ops.aten.margin_ranking_loss`` and the backward pass rebuilds the
    gradients with PyTorch ops.
    """

    # Elements handled by one Triton program in both kernels.
    _BLOCK_SIZE = 1024

    @staticmethod
    def _common_dtype(x1, x2, target):
        """
        Dtype PyTorch would give ``-target * (x1 - x2) + margin``.

        All inputs are floating point (validated in ``forward``), so only the
        "dimensioned tensors outrank zero-dim tensors" promotion rule matters.
        """
        diff_dtype = torch.result_type(x1, x2)
        diff_is_0d = x1.dim() == 0 and x2.dim() == 0
        target_is_0d = target.dim() == 0
        if target_is_0d and not diff_is_0d:
            return diff_dtype
        if diff_is_0d and not target_is_0d:
            return target.dtype
        return torch.promote_types(diff_dtype, target.dtype)

    @staticmethod
    def _active_mask(x1, x2, y, margin):
        """
        Bool mask of elements where ``-y * (x1 - x2) + margin >= 0``.

        ``>=`` mirrors the subgradient PyTorch's ``clamp_min`` backward uses,
        so elements exactly on the hinge still receive gradient.
        """
        return (-y * (x1 - x2) + margin) >= 0

    @staticmethod
    def _fallback_backward(x1, x2, y, margin, reduction, grad_output, need_target_grad):
        """
        Gradients for the ATen-delegated path, computed with PyTorch ops.

        Gradients come out in the broadcast shape of the inputs; the autograd
        engine sums them back to each input's own shape and casts to its dtype.
        """
        active = MarginRankingLossOp._active_mask(x1, x2, y, margin)
        grad = grad_output
        if reduction == "mean":
            # Tensor division: an empty input yields inf here but an empty grad below.
            grad = grad / active.numel()
        grad_x1 = torch.where(active, -y * grad, 0.0)
        grad_x2 = torch.where(active, y * grad, 0.0)
        grad_target = (
            torch.where(active, (x2 - x1) * grad, 0.0) if need_target_grad else None
        )
        return grad_x1, grad_x2, grad_target, None, None

    @staticmethod
    def forward(ctx, x1, x2, target, margin=0.0, reduction="mean"):
        """
        Forward pass for margin ranking loss.

        Args:
            ctx: Context object for saving tensors for backward pass
            x1: First input tensor
            x2: Second input tensor (broadcast against x1 and target)
            target: Target tensor with values typically +1 or -1
            margin: Margin value (default: 0.0)
            reduction: 'none', 'mean' or 'sum', or the ATen codes 0/1/2
                (default: 'mean'); unknown integer codes are treated as 'mean'

        Returns:
            Per-element losses in the broadcast shape for 'none', otherwise a
            0-dim tensor.

        Raises:
            ValueError: if an input is not floating point or ``reduction`` is
                not recognised.
        """
        logger.debug("GEMS MARGIN_RANKING_LOSS")

        if not (
            x1.is_floating_point()
            and x2.is_floating_point()
            and target.is_floating_point()
        ):
            raise ValueError("All inputs must be floating point tensors")

        # Normalize reduction parameter (handle both string and int formats)
        if isinstance(reduction, int):
            reduction = {0: "none", 1: "mean", 2: "sum"}.get(reduction, "mean")
        if reduction not in ("none", "mean", "sum"):
            raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")

        margin = float(margin)
        ctx.margin = margin
        ctx.reduction = reduction

        # The kernels take raw pointers, so every input must be on the Gems device.
        device = x1.device
        ctx.use_triton = (
            device.type == flag_gems.device
            and x2.device == device
            and target.device == device
        )
        if not ctx.use_triton:
            # ATen handles (or rejects) other devices; keep raw inputs for backward.
            ctx.save_for_backward(x1, x2, target)
            return torch.ops.aten.margin_ranking_loss(
                x1,
                x2,
                target,
                margin,
                {"none": 0, "mean": 1, "sum": 2}[reduction],
            )

        x1_b, x2_b, tgt_b = torch.broadcast_tensors(x1, x2, target)
        common_dtype = MarginRankingLossOp._common_dtype(x1, x2, target)

        # Materialize the broadcast views as flat contiguous buffers for the 1-D kernels.
        x1_c = x1_b.to(dtype=common_dtype).contiguous().view(-1)
        x2_c = x2_b.to(dtype=common_dtype).contiguous().view(-1)
        tgt_c = tgt_b.to(dtype=common_dtype).contiguous().view(-1)

        out = torch.empty_like(x1_c)
        n_elements = out.numel()

        # Save state before any early exit so backward always finds it.
        ctx.save_for_backward(x1_c, x2_c, tgt_c)
        ctx.n_elements = n_elements
        ctx.original_shape = x1_b.shape

        if n_elements > 0:
            block = MarginRankingLossOp._BLOCK_SIZE
            _margin_ranking_loss_kernel[(triton.cdiv(n_elements, block),)](
                x1_c, x2_c, tgt_c, out, n_elements, margin, BLOCK_SIZE=block
            )

        # Empty inputs fall through: sum() -> 0, mean() -> nan, as in PyTorch.
        if reduction == "none":
            return out.view(ctx.original_shape)
        if reduction == "sum":
            return out.sum()
        return out.mean()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass for margin ranking loss.

        Args:
            ctx: Context object with saved tensors from forward pass
            grad_output: Gradient from upstream layers

        Returns:
            Tuple (grad_x1, grad_x2, grad_target, None, None). grad_target is
            None when target does not require grad; margin and reduction get
            None. Gradients are in the broadcast shape; autograd reduces them
            to each input's shape.
        """
        logger.debug("GEMS MARGIN_RANKING_LOSS_BACKWARD")

        x1, x2, y = ctx.saved_tensors
        margin = ctx.margin
        reduction = ctx.reduction
        need_target_grad = ctx.needs_input_grad[2]

        if not ctx.use_triton:
            return MarginRankingLossOp._fallback_backward(
                x1, x2, y, margin, reduction, grad_output, need_target_grad
            )

        n_elements = ctx.n_elements
        shape = ctx.original_shape

        if n_elements == 0:
            grad_target = y.new_zeros(shape) if need_target_grad else None
            return x1.new_zeros(shape), x2.new_zeros(shape), grad_target, None, None

        # Turn the upstream gradient into one value per flattened element.
        if reduction == "none":
            grad_flat = grad_output.reshape(-1)
        elif reduction == "sum":
            grad_flat = grad_output.expand(n_elements)
        else:
            grad_flat = grad_output.expand(n_elements) / n_elements
        grad_flat = grad_flat.contiguous()

        grad_x1 = torch.empty_like(x1)
        grad_x2 = torch.empty_like(x2)
        block = MarginRankingLossOp._BLOCK_SIZE
        _margin_ranking_loss_backward_kernel[(triton.cdiv(n_elements, block),)](
            grad_flat,
            x1,
            x2,
            y,
            grad_x1,
            grad_x2,
            margin,
            n_elements,
            BLOCK_SIZE=block,
        )

        # d(loss)/d(target) = -(x1 - x2) on active elements, 0 elsewhere.
        grad_target = None
        if need_target_grad:
            active = MarginRankingLossOp._active_mask(x1, x2, y, margin)
            grad_target = torch.where(active, (x2 - x1) * grad_flat, 0.0).view(shape)

        return grad_x1.view(shape), grad_x2.view(shape), grad_target, None, None

279 

280 

def margin_ranking_loss(x1, x2, target, margin=0.0, reduction="mean"):
    """
    Margin ranking loss, dispatched through ``MarginRankingLossOp``.

    Element-wise the loss is ``max(0, -y * (x1 - x2) + margin)``, where ``y``
    comes from ``target``. It favours ranking ``x1`` above ``x2`` when
    ``y = 1`` and ``x2`` above ``x1`` when ``y = -1``.

    Args:
        x1: First input tensor.
        x2: Second input tensor.
        target: Label tensor, typically holding +1 or -1.
        margin: Margin added before clamping at zero (default: 0.0).
        reduction: How to reduce the per-element losses:
            'none' keeps them, 'mean' averages them, 'sum' adds them up
            (default: 'mean').

    Returns:
        The loss tensor; its shape depends on ``reduction``.

    Example:
        >>> x1 = torch.randn(4, device='cuda')
        >>> x2 = torch.randn(4, device='cuda')
        >>> target = torch.ones(4, device='cuda')
        >>> loss = margin_ranking_loss(x1, x2, target, margin=1.0)
    """
    run_op = MarginRankingLossOp.apply
    return run_op(x1, x2, target, margin, reduction)