Coverage for src/flag_gems/runtime/backend/_spacemit/ops/gelu.py: 0%
56 statements
Report generated by coverage.py v7.6.9 at 2026-06-10 07:09 +0800.
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import tl_extra_shim
7from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
# Element-wise math helpers used by the kernels below, all taken from the
# flag_gems `tl_extra_shim` module.
# NOTE(review): `pow` shadows the Python builtin at module scope; it is kept
# under that name because the kernels in this module call it as `pow`.
erf, exp, pow, tanh = (
    tl_extra_shim.erf,
    tl_extra_shim.exp,
    tl_extra_shim.pow,
    tl_extra_shim.tanh,
)
# Fused GELU primitives (tanh-approximate and exact) from the same shim.
geluTanh, geluNone = tl_extra_shim.gelu_tanh, tl_extra_shim.gelu_none
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    # Exact (erf-based) GELU, reference form:
    #   0.5 * x * (1 + erf(x / sqrt(2)))
    # The arithmetic is delegated to the shim primitive `geluNone`; the input
    # is upcast to float32 first so half-precision inputs are evaluated in fp32.
    x_f32 = x.to(tl.float32)
    return geluNone(x_f32)
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    # Tanh-approximated GELU, reference form:
    #   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x^2)))
    # (0.79788456 ~= sqrt(2 / pi)). The arithmetic is delegated to the shim
    # primitive `geluTanh` on a float32 copy of the input.
    x_f32 = x.to(tl.float32)
    return geluTanh(x_f32)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    # Gradient of exact GELU, y = x * Phi(x):
    #   dy/dx = x * phi(x) + Phi(x)
    # where phi(x) = exp(-x^2 / 2) / sqrt(2 * pi)
    #   and Phi(x) = 0.5 * erf(x / sqrt(2)) + 0.5.
    # Returns dy/dx * dy, computed in float32.
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt_2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    xf = x.to(tl.float32)
    # x * phi(x); (x / sqrt(2))^2 == x^2 / 2
    pdf_term = inv_sqrt_2pi * xf * exp(-pow(inv_sqrt2 * xf, 2))
    erf_term = 0.5 * erf(inv_sqrt2 * xf)
    # Same left-to-right summation order as the reference formula.
    local_grad = pdf_term + erf_term + 0.5
    return local_grad * dy
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    # Gradient of tanh-approximated GELU, y = 0.5 * x * (1 + tanh(u)) with
    #   u = sqrt(2/pi) * x * (1 + 0.044715 * x^2):
    #   dy/dx = 0.5 * x * (1 - tanh(u)^2) * du/dx + 0.5 * (1 + tanh(u))
    # Returns dy/dx * dy, computed in float32.
    xf = x.to(tl.float32)
    x_sq = pow(xf, 2)
    # 0.79788456 = math.sqrt(2 / math.pi)
    t = tanh(0.79788456 * xf * (1 + 0.044715 * x_sq))
    # du/dx = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2); 0.1070322243 = sqrt(2/pi) * 0.134145
    inner_grad = 0.79788456 + 0.1070322243 * x_sq
    sech_sq = 1 - pow(t, 2)
    local_grad = 0.5 * xf * (sech_sq * inner_grad) + 0.5 * (1 + t)
    return local_grad * dy
def gelu(A, *, approximate="none"):
    """Apply GELU element-wise to ``A`` and return a new tensor.

    Args:
        A: input tensor.
        approximate: ``"none"`` selects the exact erf-based GELU kernel,
            ``"tanh"`` selects the tanh-approximation kernel (the same values
            accepted by ``torch.nn.functional.gelu``).

    Returns:
        The output produced by the selected pointwise kernel.

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
            This mirrors PyTorch's own check. Previously any other string
            silently fell back to the exact kernel.
    """
    logging.debug("GEMS_SPACEMIT GELU FORWARD")
    if approximate == "tanh":
        return gelu_tanh(A)
    if approximate == "none":
        return gelu_none(A)
    raise RuntimeError("approximate argument must be either none or tanh.")
def gelu_backward(grad_output, self, *, approximate="none"):
    """Compute the input gradient of GELU.

    Args:
        grad_output: upstream gradient (passed to the kernel as ``dy``).
        self: the original forward input (passed to the kernel as ``x``).
        approximate: ``"none"`` for the exact erf-based gradient, ``"tanh"``
            for the gradient of the tanh approximation.

    Returns:
        The gradient with respect to ``self``, as produced by the selected
        pointwise kernel.

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
            This mirrors PyTorch's own check. Previously any other string
            silently fell back to the exact gradient.
    """
    logging.debug("GEMS_SPACEMIT GELU_BACKWARD")
    if approximate == "tanh":
        return gelu_backward_tanh(self, grad_output)
    if approximate == "none":
        return gelu_backward_none(self, grad_output)
    raise RuntimeError("approximate argument must be either none or tanh.")
def gelu_(A, *, approximate="none"):
    """In-place GELU: write the result into ``A`` and return the kernel output.

    Args:
        A: tensor to transform; it is also passed as ``out0`` so the kernel
            writes its result back into ``A``.
        approximate: ``"none"`` for the exact erf-based GELU, ``"tanh"`` for
            the tanh approximation.

    Returns:
        Whatever the pointwise kernel returns for ``out0=A`` (presumably ``A``
        itself; TODO confirm against ``pointwise_dynamic``).

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
            The check runs before any kernel launch, so ``A`` is left
            untouched on error. Previously an unknown value silently
            overwrote ``A`` with the exact GELU.
    """
    logging.debug("GEMS_SPACEMIT GELU_ FORWARD")
    if approximate == "tanh":
        return gelu_tanh(A, out0=A)
    if approximate == "none":
        return gelu_none(A, out0=A)
    raise RuntimeError("approximate argument must be either none or tanh.")