Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/soft_margin_loss.py: 0%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12from ..utils.pointwise_dynamic import pointwise_dynamic 

13 

# Module-level logger, created as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _soft_margin_loss_elementwise(x, y):
    # Elementwise soft margin loss log(1 + exp(-x * y)); used for
    # reduction="none". Math is done in float32. The result dtype is governed
    # by promotion_methods=[(0, "DEFAULT")] (presumably it follows x --
    # NOTE(review): confirm against pointwise_dynamic).
    xf = x.to(tl.float32)
    yf = y.to(tl.float32)
    z = -xf * yf
    absz = tl.abs(z)
    # Numerically stable softplus:
    #   log(1 + exp(z)) == max(z, 0) + log(1 + exp(-|z|)),
    # so exp() only ever sees a non-positive argument and cannot overflow.
    return tl.maximum(z, 0.0) + tl.log(1.0 + tl.exp(-absz))

25 

26 

@libentry()
@triton.jit
def kernel_1(
    x_ptr,
    y_ptr,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
    reduction: tl.constexpr,
):
    # Stage 1 of a two-stage soft-margin-loss reduction.
    # Program `pid` covers flat elements [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)
    # of x_ptr / y_ptr (M elements in total) and stores one partial result to
    # mid[pid]. Both inputs are read with the same flat offset, so this assumes
    # x and y are contiguous and each hold M elements -- TODO confirm at callers.
    # For reduction == 1 ("mean") each partial is already divided by M, so the
    # partials only need to be summed afterwards.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < M

    xf = tl.load(x_ptr + offset, mask=mask, other=0).to(tl.float32)
    yf = tl.load(y_ptr + offset, mask=mask, other=0).to(tl.float32)

    # Same stable softplus as the elementwise kernel:
    # log(1 + exp(z)) == max(z, 0) + log(1 + exp(-|z|)).
    z = -xf * yf
    absz = tl.abs(z)
    vals = tl.maximum(z, 0.0) + tl.log(1.0 + tl.exp(-absz))
    # Zero out contributions from out-of-bounds elements
    # (soft_margin_loss(0,0) = log(2) != 0, so masking is required)
    vals = tl.where(mask, vals, 0.0)

    # Reduction.MEAN.value: 1, Reduction.SUM.value: 2
    if reduction == 1:
        sum_val = tl.sum(vals) / M
    else:
        sum_val = tl.sum(vals)

    tl.store(mid + pid, sum_val)

58 

59 

@libentry()
@triton.jit
def kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the two-stage reduction: sums the first mid_size partials in
    # `mid` (in float32) and stores the scalar to `out`. Only BLOCK_MID lanes
    # are loaded, so BLOCK_MID is assumed to be >= mid_size; otherwise the
    # trailing partials would be silently dropped.
    offset = tl.arange(0, BLOCK_MID)
    mask = offset < mid_size
    # Lanes past mid_size contribute 0 to the sum.
    mid_val = tl.load(mid + offset, mask=mask, other=0).to(tl.float32)
    sum_val = tl.sum(mid_val)
    tl.store(out, sum_val)

68 

69 

def _normalize_reduction(reduction):
    """Translate a reduction spec into PyTorch's integer reduction code.

    Args:
        reduction: one of the strings "none", "mean", "sum" (matched
            case-insensitively), or one of the integers 0, 1, 2.

    Returns:
        0 for "none", 1 for "mean", 2 for "sum". An integer argument in
        range is returned unchanged.

    Raises:
        ValueError: for an unknown string, an out-of-range integer, or an
            argument that is neither ``str`` nor ``int``.
    """
    if isinstance(reduction, str):
        code = {"none": 0, "mean": 1, "sum": 2}.get(reduction.lower())
        if code is None:
            raise ValueError(f"Invalid reduction: {reduction}")
        return code

    if not isinstance(reduction, int):
        raise ValueError(f"Unsupported reduction type: {type(reduction)}")

    if reduction not in (0, 1, 2):
        raise ValueError(f"Invalid reduction int: {reduction}")
    return reduction

85 

86 

def soft_margin_loss(input: torch.Tensor, target: torch.Tensor, reduction="mean"):
    """Soft margin loss ``log(1 + exp(-input * target))`` on Kunlunxin.

    Args:
        input: prediction tensor.
        target: label tensor. It must be broadcastable to ``input``'s shape
            and is expanded to that shape, matching ATen, which multiplies
            ``target`` in place into a buffer shaped like ``input``.
        reduction: "none" | "mean" | "sum", or the integer codes 0 | 1 | 2.

    Returns:
        For "none", the elementwise loss. Otherwise a 0-dim tensor with
        ``input``'s dtype; an empty input yields 0 for "sum" and NaN for
        "mean".

    Raises:
        ValueError: if ``reduction`` is not recognised.
        RuntimeError: if ``target`` cannot be expanded to ``input``'s shape.
    """
    import os

    logger.debug("GEMS_KUNLUNXIN SOFT_MARGIN_LOSS")
    red = _normalize_reduction(reduction)

    # The reduction kernels read input and target with one shared flat offset,
    # so both must have exactly the same shape; a smaller broadcastable target
    # would otherwise be read out of bounds.
    if target.shape != input.shape:
        target = target.expand_as(input)

    if not input.is_contiguous():
        input = input.contiguous()
    if not target.is_contiguous():
        target = target.contiguous()

    n_elements = input.numel()

    if red == 0:
        # reduction = 'none': use pointwise kernel (no atomic_add, no masked load issues)
        if n_elements == 0:
            return torch.empty_like(input)
        return _soft_margin_loss_elementwise(input, target)

    # reduction = 'sum' (red==2) or 'mean' (red==1)
    if n_elements == 0:
        if red == 2:
            return torch.zeros((), device=input.device, dtype=input.dtype)
        return torch.full((), float("nan"), device=input.device, dtype=input.dtype)

    # Two-stage reduction: ~sqrt(n) blocks of ~sqrt(n) elements each, then a
    # single program sums the per-block partials.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.float32, device=input.device)

    # The kernels are launched with TRITONXPU_OTHER_SIM=1. NOTE(review): this
    # presumably works around masked-load `other=` handling on XPU -- confirm.
    # The caller's previous value is restored afterwards, even if a launch
    # raises, instead of being unconditionally deleted.
    env_key = "TRITONXPU_OTHER_SIM"
    prev_env = os.environ.get(env_key)
    os.environ[env_key] = "1"
    try:
        with torch_device_fn.device(input.device):
            kernel_1[(mid_size, 1, 1)](input, target, mid, n_elements, block_size, red)
            if mid_size == 1:
                # A single block already holds the full (mean-scaled) total.
                return mid.reshape([]).to(dtype=input.dtype)
            out = torch.empty([], dtype=torch.float32, device=input.device)
            kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
            return out.to(dtype=input.dtype)
    finally:
        if prev_env is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = prev_env

135 

136 

def soft_margin_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction="mean",
    out: torch.Tensor = None,
):
    """Out-variant of soft_margin_loss.

    Computes the loss with ``soft_margin_loss``. When ``out`` is given, it is
    resized to the result's shape if needed and then filled with the result.
    ATen's ``soft_margin_loss.out`` likewise resizes its output, e.g. to a
    0-dim tensor for "mean"/"sum".

    Returns:
        ``out`` when provided, otherwise the freshly computed result.
    """
    logger.debug("GEMS_KUNLUNXIN SOFT_MARGIN_LOSS_OUT")
    result = soft_margin_loss(input, target, reduction)
    if out is None:
        return result
    # Without this, a scalar result would be broadcast into a non-scalar
    # `out` (leaving the wrong shape), and a "none" result could not be
    # copied into a 0-dim `out` at all.
    if out.shape != result.shape:
        out.resize_(result.shape)
    out.copy_(result)
    return out