# Coverage for src/flag_gems/runtime/backend/_sunrise/ops/gelu.py: 0%
# 60 statements
# Report generated by coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
import logging

import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic, tl_extra_shim
from flag_gems.utils.pointwise_dynamic import CodeGenConfig

# Device math helpers resolved through the backend shim and called from the
# Triton kernels in this module.
# NOTE: the `pow` alias shadows the Python builtin for the rest of this module.
erf, exp, pow, tanh = (
    tl_extra_shim.erf,
    tl_extra_shim.exp,
    tl_extra_shim.pow,
    tl_extra_shim.tanh,
)

# Per-axis upper bound on the launch grid handed to the code generator below.
MAX_GRID_SIZES = (65535, 65535, 65535)

# Code-generation settings for this backend's pointwise kernels. They only
# take effect where a `@pointwise_dynamic(..., config=config)` decorator
# passes this object along.
config = CodeGenConfig(
    max_tile_size=1024,
    max_grid_size=MAX_GRID_SIZES,
    max_num_warps_per_cta=32,
    prefer_block_pointer=False,
    prefer_1d_tile=True,
)

logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
# The input is upcast to float32 because the backend's erf ("__ocml_erf_f32")
# only supports float32 operands. The output dtype follows the "DEFAULT"
# promotion of the input. The module-level `config` is passed so the
# backend's grid/tile limits are honoured.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")], config=config)
@triton.jit
def gelu_none(x):
    scale: tl.constexpr = 0.7071067811865476  # 1 / math.sqrt(2)
    x_fp32 = x.to(tl.float32)
    return 0.5 * x_fp32 * (1 + erf(x_fp32 * scale))
# Tanh-approximated GELU:
#   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
# Every term is evaluated in float32; the original mixed the original-dtype `x`
# into the final product, unlike the other kernels in this module. The output
# dtype follows the "DEFAULT" promotion of the input. The module-level
# `config` is passed so the backend's grid/tile limits are honoured.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")], config=config)
@triton.jit
def gelu_tanh(x):
    sqrt_2_over_pi: tl.constexpr = 0.7978845608028654  # math.sqrt(2 / math.pi)
    x_fp32 = x.to(tl.float32)
    inner = sqrt_2_over_pi * x_fp32 * (1 + 0.044715 * x_fp32 * x_fp32)
    return 0.5 * x_fp32 * (1 + tanh(inner))
# Gradient of exact GELU with respect to its input, scaled by the upstream
# gradient `dy`:
#   d/dx [x * Phi(x)] = Phi(x) + x * phi(x)
#                     = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi)
# Computed in float32. The output dtype follows the "DEFAULT" promotion of
# (x, dy). The module-level `config` is passed so the backend's grid/tile
# limits are honoured.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def gelu_backward_none(x, dy):
    scale1: tl.constexpr = 0.7071067811865476  # 1 / math.sqrt(2)
    scale2: tl.constexpr = 0.3989422804014327  # 1 / math.sqrt(2 * math.pi)
    x_fp32 = x.to(tl.float32)
    # x * phi(x): exp(-x^2 / 2) with the square computed exactly.
    pdf_term = scale2 * x_fp32 * exp(-0.5 * x_fp32 * x_fp32)
    # Phi(x), the standard normal CDF.
    cdf_term = 0.5 * erf(scale1 * x_fp32) + 0.5
    return (pdf_term + cdf_term) * dy.to(tl.float32)
# Gradient of tanh-approximated GELU, scaled by the upstream gradient `dy`.
# With t = tanh(k * (x + 0.044715 * x^3)) and k = sqrt(2 / pi):
#   dy/dx = 0.5 * (1 + t) + 0.5 * x * (1 - t^2) * (k + 3 * 0.044715 * k * x^2)
# Computed in float32. The output dtype follows the "DEFAULT" promotion of
# (x, dy). The module-level `config` is passed so the backend's grid/tile
# limits are honoured.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def gelu_backward_tanh(x, dy):
    sqrt_2_over_pi: tl.constexpr = 0.7978845608028654  # math.sqrt(2 / math.pi)
    # 3 * 0.044715 * math.sqrt(2 / math.pi): derivative of the cubic term.
    cubic_coef: tl.constexpr = 0.10703222440890037
    x_fp32 = x.to(tl.float32)
    x_sq = x_fp32 * x_fp32
    tanh_out = tanh(sqrt_2_over_pi * x_fp32 * (1 + 0.044715 * x_sq))
    # Derivative of the tanh argument, times tanh'(u) = 1 - tanh(u)^2.
    dtanh = (1 - tanh_out * tanh_out) * (sqrt_2_over_pi + cubic_coef * x_sq)
    dydx = 0.5 * x_fp32 * dtanh + 0.5 * (1 + tanh_out)
    return dydx * dy.to(tl.float32)
def gelu(self, *, approximate="none"):
    """Return GELU applied elementwise to ``self`` (a new tensor).

    Args:
        self: Input tensor.
        approximate: ``"none"`` for exact (erf-based) GELU or ``"tanh"``
            for the tanh approximation.

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
            This mirrors ``torch.nn.functional.gelu``; the original code
            silently fell back to exact GELU instead.
    """
    logger.debug("GEMS GELU FORWARD")
    if approximate == "tanh":
        return gelu_tanh(self)
    if approximate == "none":
        return gelu_none(self)
    raise RuntimeError(
        f"approximate argument must be either none or tanh, got {approximate!r}"
    )
def gelu_backward(grad_output, self, *, approximate="none"):
    """Return the gradient of GELU with respect to its input.

    Args:
        grad_output: Upstream gradient, multiplied elementwise into d gelu / dx.
        self: The forward-pass input tensor.
        approximate: ``"none"`` or ``"tanh"``; must match the forward pass.

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
            The original code silently used the exact-GELU gradient instead.
    """
    logger.debug("GEMS GELU BACKWARD")
    if approximate == "tanh":
        return gelu_backward_tanh(self, grad_output)
    if approximate == "none":
        return gelu_backward_none(self, grad_output)
    raise RuntimeError(
        f"approximate argument must be either none or tanh, got {approximate!r}"
    )
def gelu_(A, *, approximate="none"):
    """Apply GELU to ``A`` in place and return the kernel's output.

    The result is written back into ``A`` via ``out0=A``. The returned value
    is whatever the pointwise kernel returns for that output (expected to be
    ``A`` itself).

    Args:
        A: Tensor to overwrite with GELU(A).
        approximate: ``"none"`` for exact GELU or ``"tanh"`` for the tanh
            approximation.

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
            The original code silently used exact GELU instead.
    """
    logger.debug("GEMS GELU_ FORWARD")
    if approximate == "tanh":
        return gelu_tanh(A, out0=A)
    if approximate == "none":
        return gelu_none(A, out0=A)
    raise RuntimeError(
        f"approximate argument must be either none or tanh, got {approximate!r}"
    )