Coverage for src/flag_gems/runtime/backend/_sunrise/ops/clamp.py: 0%

152 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8from flag_gems.utils.pointwise_dynamic import CodeGenConfig 

9 

logger = logging.getLogger(__name__)

# Upper bound on the launch grid along each of the three grid axes.
# NOTE(review): presumably the device limit for this backend — confirm.
MAX_GRID_SIZES = (65535,) * 3

# Code-generation settings shared by the ``*_f16`` kernel variants defined in
# this module (larger tiles / more warps than the default configuration).
config_f16 = CodeGenConfig(
    max_tile_size=1024,
    max_grid_size=MAX_GRID_SIZES,
    max_num_warps_per_cta=32,
    prefer_block_pointer=True,
    prefer_1d_tile=True,
)

20 

21 

@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")])
@triton.jit
def clamp_func_tensor(x, mini, maxi):
    # Element-wise clamp of x into [mini, maxi] with tensor bounds.
    # x is converted to float32 before comparing, so float64 or large-integer
    # inputs can lose precision. NOTE(review): confirm this is intended.
    x_fp32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_fp32)
    return tl.minimum(maxi, floored)

26 

27 

@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_tensor_f16(x, mini, maxi):
    # Element-wise clamp of x into [mini, maxi] with tensor bounds, generated
    # with the half-precision code-gen config. x is converted to float32
    # before comparing.
    x_fp32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_fp32)
    return tl.minimum(maxi, floored)

32 

33 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_min_tensor(x, mini):
    # Lower-bound clamp (max(mini, x)) with a tensor bound; x is converted to
    # float32 before comparing.
    x_fp32 = x.to(tl.float32)
    return tl.maximum(mini, x_fp32)

38 

39 

@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_min_tensor_f16(x, mini):
    # Lower-bound clamp (max(mini, x)) with a tensor bound, generated with the
    # half-precision code-gen config; x is converted to float32 first.
    x_fp32 = x.to(tl.float32)
    return tl.maximum(mini, x_fp32)

44 

45 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_max_tensor(x, maxi):
    # Upper-bound clamp (min(maxi, x)) with a tensor bound; x is converted to
    # float32 before comparing.
    x_fp32 = x.to(tl.float32)
    return tl.minimum(maxi, x_fp32)

50 

51 

@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_max_tensor_f16(x, maxi):
    # Upper-bound clamp (min(maxi, x)) with a tensor bound, generated with the
    # half-precision code-gen config; x is converted to float32 first.
    x_fp32 = x.to(tl.float32)
    return tl.minimum(maxi, x_fp32)

56 

57 

def clamp_tensor(A, mini=None, maxi=None):
    """Clamp ``A`` element-wise between the tensor bounds ``mini`` and ``maxi``.

    Either bound may be ``None`` to leave that side unbounded, but not both.
    ``torch.half`` inputs are routed to the ``*_f16`` kernels (generated with
    ``config_f16``); every other dtype uses the default-config kernels.

    Args:
        A: Input tensor.
        mini: Lower-bound tensor, or ``None``.
        maxi: Upper-bound tensor, or ``None``.

    Returns:
        The output of the selected pointwise kernel.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    # Use the module-level logger (not the root ``logging`` module) like every
    # other entry point in this file, so per-module log configuration applies.
    logger.debug("GEMS CLAMP TENSOR")
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")

    use_f16 = A.dtype == torch.half
    if mini is None:
        kernel = clamp_func_max_tensor_f16 if use_f16 else clamp_func_max_tensor
        return kernel(A, maxi)
    if maxi is None:
        kernel = clamp_func_min_tensor_f16 if use_f16 else clamp_func_min_tensor
        return kernel(A, mini)
    kernel = clamp_func_tensor_f16 if use_f16 else clamp_func_tensor
    return kernel(A, mini, maxi)

78 

79 

def clamp_tensor_(A, mini=None, maxi=None):
    """Clamp ``A`` in place between the tensor bounds ``mini`` and ``maxi``.

    The selected kernel writes its result into ``A`` (``out0=A``). Either
    bound may be ``None`` to leave that side unbounded, but not both.
    ``torch.half`` inputs use the ``*_f16`` kernels.

    Args:
        A: Tensor to clamp; also the output buffer.
        mini: Lower-bound tensor, or ``None``.
        maxi: Upper-bound tensor, or ``None``.

    Returns:
        The output of the selected pointwise kernel.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP_ TENSOR")
    half_input = A.dtype == torch.half
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")
    if mini is None:
        upper_only = clamp_func_max_tensor_f16 if half_input else clamp_func_max_tensor
        return upper_only(A, maxi, out0=A)
    if maxi is None:
        lower_only = clamp_func_min_tensor_f16 if half_input else clamp_func_min_tensor
        return lower_only(A, mini, out0=A)
    both_sides = clamp_func_tensor_f16 if half_input else clamp_func_tensor
    return both_sides(A, mini, maxi, out0=A)

100 

101 

@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def clamp_func(x, mini, maxi):
    # Element-wise clamp of tensor x into [mini, maxi]; both bounds are
    # scalars here (is_tensor=[True, False, False]).
    # x is converted to float32 before comparing, so float64 or large-integer
    # inputs can lose precision. NOTE(review): confirm this is intended.
    x_fp32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_fp32)
    return tl.minimum(maxi, floored)

108 

109 

@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_f16(x, mini, maxi):
    # Element-wise clamp of tensor x into scalar bounds [mini, maxi], generated
    # with the half-precision code-gen config; x is converted to float32 first.
    x_fp32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_fp32)
    return tl.minimum(maxi, floored)

118 

119 

@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_min(x, mini):
    # Lower-bound clamp (max(mini, x)) with a scalar bound; x is converted to
    # float32 before comparing.
    x_fp32 = x.to(tl.float32)
    return tl.maximum(mini, x_fp32)

124 

125 

@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_min_f16(x, mini):
    # Lower-bound clamp (max(mini, x)) with a scalar bound, generated with the
    # half-precision code-gen config; x is converted to float32 first.
    x_fp32 = x.to(tl.float32)
    return tl.maximum(mini, x_fp32)

132 

133 

@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_max(x, maxi):
    # Upper-bound clamp (min(maxi, x)) with a scalar bound; x is converted to
    # float32 before comparing.
    x_fp32 = x.to(tl.float32)
    return tl.minimum(maxi, x_fp32)

138 

139 

@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_max_f16(x, maxi):
    # Upper-bound clamp (min(maxi, x)) with a scalar bound, generated with the
    # half-precision code-gen config; x is converted to float32 first.
    x_fp32 = x.to(tl.float32)
    return tl.minimum(maxi, x_fp32)

146 

147 

def clamp_min(A, mini):
    """Clamp ``A`` element-wise from below: ``max(mini, A)``.

    ``mini`` may be a tensor (routed to the ``*_tensor`` kernels) or a scalar
    (routed to the ``is_tensor=[True, False]`` kernels). In both cases
    ``torch.half`` inputs use the ``*_f16`` kernel built with ``config_f16``.

    Args:
        A: Input tensor.
        mini: Lower bound, either a ``torch.Tensor`` or a scalar.

    Returns:
        The output of the selected pointwise kernel.

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP MIN")
    if mini is None:
        raise ValueError("Mini must not be None")
    use_f16 = A.dtype == torch.half
    if isinstance(mini, torch.Tensor):
        kernel = clamp_func_min_tensor_f16 if use_f16 else clamp_func_min_tensor
    else:
        # The scalar path previously ignored the half-precision config. Select
        # the f16 kernel here, as the tensor path and ``clamp`` already do.
        kernel = clamp_func_min_f16 if use_f16 else clamp_func_min
    return kernel(A, mini)

157 

158 

def clamp_min_(A, mini):
    """Clamp ``A`` in place from below: ``A = max(mini, A)``.

    The selected kernel writes its result into ``A`` (``out0=A``). ``mini``
    may be a tensor or a scalar. ``torch.half`` inputs use the ``*_f16``
    kernel for either kind of bound.

    Args:
        A: Tensor to clamp; also the output buffer.
        mini: Lower bound, either a ``torch.Tensor`` or a scalar.

    Returns:
        The output of the selected pointwise kernel.

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP_ MIN")
    if mini is None:
        raise ValueError("Mini must not be None")
    use_f16 = A.dtype == torch.half
    if isinstance(mini, torch.Tensor):
        kernel = clamp_func_min_tensor_f16 if use_f16 else clamp_func_min_tensor
    else:
        # The scalar path previously ignored the half-precision config. Select
        # the f16 kernel here, as the tensor path and ``clamp_`` already do.
        kernel = clamp_func_min_f16 if use_f16 else clamp_func_min
    return kernel(A, mini, out0=A)

168 

169 

def clamp_min_out(A, mini, *, out=None):
    """Clamp ``A`` from below into ``out``: ``out = max(mini, A)``.

    ``out`` is forwarded to the kernel as ``out0``. ``mini`` may be a tensor
    or a scalar. ``torch.half`` inputs use the ``*_f16`` kernel for either
    kind of bound.

    Args:
        A: Input tensor.
        mini: Lower bound, either a ``torch.Tensor`` or a scalar.
        out: Output buffer passed through as ``out0`` (may be ``None``).

    Returns:
        The output of the selected pointwise kernel.

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP MIN OUT")
    if mini is None:
        raise ValueError("Mini must not be None")
    use_f16 = A.dtype == torch.half
    if isinstance(mini, torch.Tensor):
        kernel = clamp_func_min_tensor_f16 if use_f16 else clamp_func_min_tensor
    else:
        # The scalar path previously ignored the half-precision config. Select
        # the f16 kernel here, matching the tensor path.
        kernel = clamp_func_min_f16 if use_f16 else clamp_func_min
    return kernel(A, mini, out0=out)

179 

180 

def clamp(A, mini=None, maxi=None):
    """Clamp ``A`` element-wise between the scalar bounds ``mini`` and ``maxi``.

    Either bound may be ``None`` to leave that side unbounded, but not both.
    ``torch.half`` inputs use the ``*_f16`` kernels (generated with
    ``config_f16``); every other dtype uses the default-config kernels.

    Args:
        A: Input tensor.
        mini: Scalar lower bound, or ``None``.
        maxi: Scalar upper bound, or ``None``.

    Returns:
        The output of the selected pointwise kernel.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP")
    half_input = A.dtype == torch.half
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")
    if mini is None:
        upper_only = clamp_func_max_f16 if half_input else clamp_func_max
        return upper_only(A, maxi)
    if maxi is None:
        lower_only = clamp_func_min_f16 if half_input else clamp_func_min
        return lower_only(A, mini)
    both_sides = clamp_func_f16 if half_input else clamp_func
    return both_sides(A, mini, maxi)

201 

202 

def clamp_(A, mini=None, maxi=None):
    """Clamp ``A`` in place between the scalar bounds ``mini`` and ``maxi``.

    The selected kernel writes its result into ``A`` (``out0=A``). Either
    bound may be ``None`` to leave that side unbounded, but not both.
    ``torch.half`` inputs use the ``*_f16`` kernels.

    Args:
        A: Tensor to clamp; also the output buffer.
        mini: Scalar lower bound, or ``None``.
        maxi: Scalar upper bound, or ``None``.

    Returns:
        The output of the selected pointwise kernel.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    # Previously logged "GEMS CLAMP", the same as the out-of-place ``clamp``.
    # Use the trailing underscore like the other in-place variants in this
    # file, so in-place calls can be told apart in debug logs.
    logger.debug("GEMS CLAMP_")
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")

    use_f16 = A.dtype == torch.half
    if mini is None:
        kernel = clamp_func_max_f16 if use_f16 else clamp_func_max
        return kernel(A, maxi, out0=A)
    if maxi is None:
        kernel = clamp_func_min_f16 if use_f16 else clamp_func_min
        return kernel(A, mini, out0=A)
    kernel = clamp_func_f16 if use_f16 else clamp_func
    return kernel(A, mini, maxi, out0=A)