Coverage for src/flag_gems/runtime/backend/_spacemit/ops/gelu.py: 0%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import tl_extra_shim 

8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

9 

# Backend math primitives from the vendor shim, bound to short module-level
# names so the Triton kernels in this module can call them directly.
# NOTE: `pow` shadows the Python builtin for the rest of this module.
erf, exp, pow, tanh = (
    tl_extra_shim.erf,
    tl_extra_shim.exp,
    tl_extra_shim.pow,
    tl_extra_shim.tanh,
)
# Fused GELU primitives from the shim, used by the forward kernels for
# approximate="tanh" and approximate="none" respectively.
geluTanh, geluNone = tl_extra_shim.gelu_tanh, tl_extra_shim.gelu_none

16 

17 

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    # Forward GELU (approximate="none"). The math is delegated to the shim's
    # fused gelu_none on an fp32 upcast of the input; the reference formula
    # is 0.5 * x * (1 + erf(x / sqrt(2))).
    # Output dtype is governed by the (0, "DEFAULT") promotion rule on the
    # decorator (presumably the input's dtype -- confirm in pointwise_dynamic).
    return geluNone(x.to(tl.float32))

25 

26 

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    # Forward GELU (approximate="tanh"). The math is delegated to the shim's
    # fused gelu_tanh on an fp32 upcast of the input; the reference formula
    # is 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    # Output dtype is governed by the (0, "DEFAULT") promotion rule on the
    # decorator (presumably the input's dtype -- confirm in pointwise_dynamic).
    return geluTanh(x.to(tl.float32))

35 

36 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    # Input gradient of the exact (erf-based) GELU:
    #   d/dx [0.5 * x * (1 + erf(x / sqrt(2)))]
    #     = x * exp(-x^2 / 2) / sqrt(2 * pi) + 0.5 * erf(x / sqrt(2)) + 0.5
    # All math runs on an fp32 upcast of x.
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt_2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    xf = x.to(tl.float32)
    z = inv_sqrt2 * xf
    pdf_part = inv_sqrt_2pi * xf * exp(-pow(z, 2))
    # Summation order kept as ((pdf_part + erf part) + 0.5) on purpose.
    grad = pdf_part + 0.5 * erf(z) + 0.5
    return grad * dy

50 

51 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    # Input gradient of the tanh-approximated GELU:
    #   y      = 0.5 * x * (1 + tanh(u)),  u = sqrt(2/pi) * (x + 0.044715 * x^3)
    #   dy/dx  = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * du/dx
    #   du/dx  = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
    #          = 0.79788456 + 0.1070322243 * x^2
    # Use the fp32 upcast of x everywhere (not the raw `x`), so that
    # low-precision inputs are not multiplied at input precision before
    # reaching tanh (exact effect depends on Triton's scalar-promotion rules).
    x_fp32 = x.to(tl.float32)
    x_sq = pow(x_fp32, 2)  # reused in both u and du/dx
    # 0.79788456 = math.sqrt(2 / math.pi)
    tanh_out = tanh(0.79788456 * x_fp32 * (1 + 0.044715 * x_sq))
    dydx = 0.5 * x_fp32 * (
        (1 - pow(tanh_out, 2)) * (0.79788456 + 0.1070322243 * x_sq)
    ) + 0.5 * (1 + tanh_out)
    dx = dydx * dy
    return dx

63 

64 

class Gelu(torch.autograd.Function):
    """Autograd wrapper routing GELU forward/backward to the Triton kernels.

    ``approximate == "tanh"`` selects the tanh-approximation kernels; any
    other value falls through to the erf-based ("none") kernels.
    """

    @staticmethod
    def forward(ctx, A, approximate):
        logging.debug("GEMS_SPACEMIT GELU_FORWARD")
        kernel = gelu_tanh if approximate == "tanh" else gelu_none
        out = kernel(A)
        # Both backward kernels recompute from the original input x.
        ctx.save_for_backward(A)
        ctx.approximate = approximate
        return out

    @staticmethod
    def backward(ctx, out_grad):
        logging.debug("GEMS_SPACEMIT GELU_BACKWARD")
        (inp,) = ctx.saved_tensors
        grad_kernel = (
            gelu_backward_tanh if ctx.approximate == "tanh" else gelu_backward_none
        )
        # Second slot: no gradient for the `approximate` string argument.
        return grad_kernel(inp, out_grad), None

87 

88 

def gelu(A, *, approximate="none"):
    """Apply GELU using the SpaceMIT Triton kernels.

    Args:
        A: Input tensor.
        approximate: ``"tanh"`` selects the tanh approximation; any other
            value (default ``"none"``) uses the erf-based kernel.

    Returns:
        The result of ``Gelu.apply``, differentiable through ``Gelu``.
    """
    return Gelu.apply(A, approximate)