Coverage for src/flag_gems/runtime/backend/_cambricon/ops/gelu.py: 0%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from triton.language.extra.mlu.libdevice import fast_erf, fast_tanh 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner, tl_extra_shim 

10 

11from ..utils import TOTAL_CORE_NUM 

12from ..utils.pointwise_dynamic import pointwise_dynamic 

13 

# Child of the shared "flag_gems" logger, named after this module (any
# leading dots in the module name are stripped first).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Module-level alias for the exponential provided by tl_extra_shim, so the
# kernels below can call it as a plain `exp(...)`.
exp = tl_extra_shim.exp

17 

18 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block}, num_stages=1, num_warps=1)
        for block in (4096, 16384, 65536, 131072)
    ],
    key=["n_elements"],
)
@triton.jit
def gelu_none_kernel(X_ptr, OUT_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Erf-based GELU over a flat buffer: out = 0.5 * x * (1 + erf(x / sqrt(2))).

    Args:
        X_ptr: pointer to the input elements.
        OUT_ptr: pointer to the output elements (callers in this file also
            pass the same buffer as X_ptr for the in-place variant).
        n_elements: number of elements to process.
        BLOCK_SIZE: elements handled per loop iteration; chosen by the tuner.

    Each program starts at ``program_id * BLOCK_SIZE`` and advances by
    ``num_programs * BLOCK_SIZE`` until ``n_elements`` is reached, so the
    launch grid may be smaller than the number of blocks.
    """
    program = tl.program_id(0)
    program_count = tl.num_programs(0)
    stride = program_count * BLOCK_SIZE
    # Widen the starting offset to 64-bit before it is used for addressing.
    first = (program * BLOCK_SIZE).to(tl.int64)
    # Approximately 1 / sqrt(2).
    inv_sqrt2: tl.constexpr = 0.7071067811
    for base in range(first, n_elements, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < n_elements
        raw = tl.load(X_ptr + idx, mask=in_bounds)
        # Compute in float32, then cast back to the input dtype on store.
        val = raw.to(tl.float32)
        half = 0.5 * val
        gelu_val = half + half * fast_erf(val * inv_sqrt2)
        tl.store(OUT_ptr + idx, gelu_val.to(raw.dtype), mask=in_bounds)

44 

45 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x, inplace):
    """Tanh-approximated GELU: 0.5 * x * (1 + tanh(k * (x + 0.044715 * x^3))).

    Args:
        x: input values; converted to float32 for the computation.
        inplace: non-tensor flag that is not read in this body. Callers in
            this file pass True together with ``out0`` for the in-place op.

    Returns:
        The float32 result (any final cast is presumably handled by the
        pointwise_dynamic wrapper -- confirm against its implementation).
    """
    val = x.to(tl.float32)
    # 0.79788456 is approximately sqrt(2 / pi).
    scaled = val * 0.79788456
    inner = scaled + scaled * 0.044715 * val * val
    half = 0.5 * val
    return half + half * fast_tanh(inner)

54 

55 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    """Gradient of erf-based GELU with respect to its input, times ``dy``.

    Computes dy * (x * pdf(x) + cdf(x)) where
    pdf(x) = exp(-x^2 / 2) / sqrt(2 * pi) and cdf(x) = 0.5 * (1 + erf(x / sqrt(2))).

    Args:
        x: forward-pass input; converted to float32 for the computation.
        dy: upstream gradient.

    Returns:
        The gradient with respect to ``x``.
    """
    # Approximately 1 / sqrt(2) and 1 / sqrt(2 * pi) respectively.
    inv_sqrt2: tl.constexpr = 0.7071067811
    inv_sqrt_2pi: tl.constexpr = 0.3989422803
    val = x.to(tl.float32)
    z = inv_sqrt2 * val
    # z * z equals x^2 / 2, so this is x * pdf(x).
    density_term = inv_sqrt_2pi * val * exp(-z * z)
    erf_term = 0.5 * fast_erf(z)
    slope = density_term + erf_term + 0.5
    return slope * dy

66 

67 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    """Gradient of tanh-approximated GELU with respect to its input, times ``dy``.

    With u = c1 * (x + c2 * x^3) and t = tanh(u), the derivative of
    0.5 * x * (1 + t) is
        0.5 * x * (1 - t^2) * (c1 + 3 * c1 * c2 * x^2) + 0.5 * (1 + t).

    Args:
        x: forward-pass input; converted to float32 for the computation.
        dy: upstream gradient.

    Returns:
        The gradient with respect to ``x``.
    """
    x_fp32 = x.to(tl.float32)
    # c1 is approximately sqrt(2 / pi); c2 is the cubic coefficient.
    c1 = 0.79788456
    c2 = 0.044715
    tanh_out = fast_tanh(c1 * x_fp32 + c1 * x_fp32 * c2 * x_fp32 * x_fp32)
    # Use the float32 copy everywhere (the original mixed raw `x` into this
    # expression), matching the other kernels in this file.
    # 0.1070322243 is approximately 3 * c1 * c2.
    dydx = (
        0.5
        * (
            (x_fp32 - x_fp32 * tanh_out * tanh_out)
            * (c1 + 0.1070322243 * x_fp32 * x_fp32)
        )
        + 0.5
        + 0.5 * tanh_out
    )
    dx = dydx * dy
    return dx

82 

83 

def gelu(self, *, approximate="none"):
    """Out-of-place GELU forward.

    Args:
        self: input tensor.
        approximate: "tanh" selects the tanh approximation; any other value
            takes the erf-based path.

    Returns:
        A new tensor holding the result. On the erf path it is allocated
        with ``torch.empty_like`` on a contiguous version of the input.
    """
    logger.debug("GEMS_CAMBRICON GELU FORWARD")
    if approximate == "tanh":
        return gelu_tanh(self, False)

    src = self.contiguous()
    result = torch.empty_like(src)
    total = src.numel()
    if total == 0:
        # Nothing to compute; skip the kernel launch.
        return result

    def launch_grid(meta):
        # One program per block, capped at the number of cores.
        blocks = triton.cdiv(total, meta["BLOCK_SIZE"])
        return (min(blocks, TOTAL_CORE_NUM),)

    with torch_device_fn.device(src.device):
        gelu_none_kernel[launch_grid](src, result, total)
    return result

100 

101 

def gelu_backward(grad_output, self, *, approximate="none"):
    """GELU backward: gradient with respect to the forward input.

    Args:
        grad_output: upstream gradient.
        self: the tensor that was the input of the forward pass.
        approximate: "tanh" selects the tanh-approximation gradient; any
            other value takes the erf-based gradient.

    Returns:
        Whatever the selected backward kernel returns for (self, grad_output).
    """
    logger.debug("GEMS_CAMBRICON GELU BACKWARD")
    backward_fn = gelu_backward_tanh if approximate == "tanh" else gelu_backward_none
    return backward_fn(self, grad_output)

108 

109 

def gelu_(A, *, approximate="none"):
    """In-place GELU forward.

    Args:
        A: tensor to overwrite with its GELU.
        approximate: "tanh" selects the tanh approximation; any other value
            takes the erf-based path.

    Returns:
        ``A`` on the erf path; on the tanh path, whatever ``gelu_tanh``
        returns when called with ``out0=A``.
    """
    logger.debug("GEMS_CAMBRICON GELU_ FORWARD")
    if approximate == "tanh":
        return gelu_tanh(A, True, out0=A)

    work = A.contiguous()
    total = work.numel()
    if total == 0:
        # Nothing to compute; skip the kernel launch.
        return A

    def launch_grid(meta):
        # One program per block, capped at the number of cores.
        blocks = triton.cdiv(total, meta["BLOCK_SIZE"])
        return (min(blocks, TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        # Input and output are the same buffer: the kernel overwrites it.
        gelu_none_kernel[launch_grid](work, work, total)
    if not A.is_contiguous():
        # The kernel ran on the contiguous version; write the result back.
        A.copy_(work)
    return A