Coverage for src/flag_gems/runtime/backend/_sunrise/ops/gelu.py: 0%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

7from flag_gems.utils.pointwise_dynamic import CodeGenConfig 

8 

# Math functions taken from tl_extra_shim and bound to short module-level
# names; the kernels below call them as plain `erf`, `exp`, `pow`, `tanh`.
# NOTE: `pow` intentionally shadows the builtin inside this module.
tanh = tl_extra_shim.tanh
pow = tl_extra_shim.pow
exp = tl_extra_shim.exp
erf = tl_extra_shim.erf

# Upper bound for each of the three launch-grid dimensions.
MAX_GRID_SIZES = (65535,) * 3

# Code-generation settings for pointwise kernels.
# NOTE(review): none of the pointwise_dynamic decorators visible in this
# file pass `config=config` -- confirm whether it is consumed elsewhere.
config = CodeGenConfig(
    max_grid_size=MAX_GRID_SIZES,
    max_tile_size=1024,
    max_num_warps_per_cta=32,
    prefer_1d_tile=True,
    prefer_block_pointer=False,
)

# Child logger of the "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

24 

25 

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    # Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    # The input is cast to float32 before calling erf: per the original
    # author's note, "__ocml_erf_f32" only supports f32 input.
    xf = x.to(tl.float32)
    half_x = 0.5 * xf
    one_plus_erf = 1 + erf(xf * inv_sqrt2)
    return half_x * one_plus_erf

35 

36 

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    # Tanh approximation of GELU:
    #   0.5 * x * (1 + tanh(sqrt(2 / pi) * x * (1 + 0.044715 * x^2)))
    x_fp32 = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    inner = x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32, 2))
    # Use the float32 copy for the leading factor too, so the whole
    # expression is evaluated in float32 (consistent with gelu_none and the
    # backward kernels) instead of mixing in the raw input dtype.
    output = 0.5 * x_fp32 * (1 + tanh(inner))
    return output

43 

44 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    # Gradient of the exact (erf-based) GELU with respect to x, scaled by
    # the incoming gradient dy:
    #   d/dx gelu(x) = x * exp(-x^2 / 2) / sqrt(2 * pi)
    #                  + 0.5 * erf(x / sqrt(2)) + 0.5
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt_2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    xf = x.to(tl.float32)
    density_term = inv_sqrt_2pi * xf * exp(-pow(inv_sqrt2 * xf, 2))
    erf_term = 0.5 * erf(inv_sqrt2 * xf)
    local_grad = density_term + erf_term + 0.5
    return local_grad * dy

58 

59 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    # Gradient of the tanh-approximated GELU with respect to x, scaled by
    # the incoming gradient dy.
    xf = x.to(tl.float32)
    # 0.79788456 = math.sqrt(2 / math.pi)
    t = tanh(0.79788456 * xf * (1 + 0.044715 * pow(xf, 2)))
    # 1 - tanh^2 is the derivative of tanh at the inner argument.
    one_minus_t2 = 1 - pow(t, 2)
    # Derivative of the inner argument; 0.1070322243 = 3 * 0.044715 * 0.79788456
    inner_grad = 0.79788456 + 0.1070322243 * pow(xf, 2)
    through_tanh = 0.5 * xf * (one_minus_t2 * inner_grad)
    direct = 0.5 * (1 + t)
    local_grad = through_tanh + direct
    return local_grad * dy

71 

72 

def gelu(self, *, approximate="none"):
    """Apply GELU to ``self`` and return the kernel's result.

    ``approximate="tanh"`` dispatches to ``gelu_tanh``; every other value
    dispatches to the erf-based ``gelu_none``.
    """
    logger.debug("GEMS GELU FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    return kernel(self)

80 

81 

def gelu_backward(grad_output, self, *, approximate="none"):
    """Return the GELU input gradient for ``self`` given ``grad_output``.

    ``approximate="tanh"`` dispatches to ``gelu_backward_tanh``; every other
    value dispatches to ``gelu_backward_none``. Note the kernels take the
    forward input first and the incoming gradient second.
    """
    logger.debug("GEMS GELU BACKWARD")
    kernel = gelu_backward_tanh if approximate == "tanh" else gelu_backward_none
    return kernel(self, grad_output)

89 

90 

def gelu_(A, *, approximate="none"):
    """Apply GELU to ``A``, supplying ``A`` itself as the kernel's ``out0``.

    ``approximate="tanh"`` dispatches to ``gelu_tanh``; every other value
    dispatches to the erf-based ``gelu_none``. Returns whatever the kernel
    returns.
    """
    logger.debug("GEMS GELU_ FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    return kernel(A, out0=A)

97 return out