Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/gelu.py: 0%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import tl_extra_shim
8from ..utils.pointwise_dynamic import pointwise_dynamic
# Child logger under the shared "flag_gems" root, named after this module
# with any leading dots stripped from __name__.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Math helpers re-exported from tl_extra_shim so the kernels below can call
# them by bare name. Of these, only erf and tanh are referenced by the
# definitions visible in this file section; exp and pow are kept because
# other code may import them from here.
erf = tl_extra_shim.erf
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow
tanh = tl_extra_shim.tanh
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    # Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    erf_term = erf(x * inv_sqrt2)
    return 0.5 * x * (1 + erf_term)
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    # Tanh-approximated GELU:
    #   0.5 * x * (1 + tanh(sqrt(2 / pi) * x * (1 + 0.044715 * x^2)))
    # where 0.79788456 = math.sqrt(2 / math.pi).
    x_f32 = x.to(tl.float32)
    # The polynomial factor is evaluated on the float32 copy of the input.
    cubic_factor = 1 + 0.044715 * x_f32 * x_f32
    # NOTE(review): the leading `x` here and in the return expression is the
    # un-cast input, not x_f32 -- confirm this mixed-precision form is
    # intentional before changing it.
    tanh_arg = x * 0.79788456 * cubic_factor
    return 0.5 * x * (1 + tanh(tanh_arg))
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    # Gradient of the exact (erf-based) GELU with respect to x, scaled by the
    # incoming gradient dy:
    #   d/dx gelu(x) = x * pdf(x) + cdf(x)
    # with pdf/cdf of the standard normal distribution.
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt_2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    x_f32 = x.to(tl.float32)
    z = inv_sqrt2 * x_f32
    # x * pdf(x): exp(-z * z) equals exp(-x^2 / 2) because z = x / sqrt(2).
    density_part = inv_sqrt_2pi * x_f32 * tl.exp(-z * z)
    # cdf(x) = 0.5 * erf(x / sqrt(2)) + 0.5; the trailing 0.5 is added below.
    erf_part = 0.5 * erf(z)
    local_grad = density_part + erf_part + 0.5
    return local_grad * dy
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    # Gradient of the tanh-approximated GELU with respect to x, scaled by the
    # incoming gradient dy.
    x_f32 = x.to(tl.float32)
    sq = x_f32 * x_f32
    # 0.79788456 = math.sqrt(2 / math.pi)
    # NOTE(review): `x` below is the un-cast input while `sq` is float32 --
    # confirm the mixed-precision form is intentional before changing it.
    t = tanh(0.79788456 * x * (1 + 0.044715 * sq))
    # Derivative of tanh(u) with respect to u is 1 - tanh(u)^2.
    one_minus_t_sq = 1 - t * t
    # Derivative of the tanh argument with respect to x;
    # 0.1070322243 = 3 * 0.044715 * 0.79788456.
    arg_grad = 0.79788456 + 0.1070322243 * sq
    local_grad = 0.5 * x * (one_minus_t_sq * arg_grad) + 0.5 * (1 + t)
    return local_grad * dy
def gelu(self, *, approximate="none"):
    """Apply GELU to ``self`` and return the result.

    Args:
        self: Input passed straight to the selected kernel.
        approximate: ``"tanh"`` selects ``gelu_tanh``; any other value
            (including the default ``"none"``) selects ``gelu_none``.

    Returns:
        Whatever the selected kernel returns for ``self``.
    """
    logger.debug("GEMS_KUNLUNXIN GELU FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    return kernel(self)
def gelu_backward(grad_output, self, *, approximate="none"):
    """Compute the GELU input gradient from ``grad_output`` and ``self``.

    Args:
        grad_output: Incoming gradient; forwarded as the kernel's ``dy``.
        self: Forward-pass input; forwarded as the kernel's ``x``.
        approximate: ``"tanh"`` selects ``gelu_backward_tanh``; any other
            value (including the default ``"none"``) selects
            ``gelu_backward_none``.

    Returns:
        Whatever the selected backward kernel returns.
    """
    logger.debug("GEMS_KUNLUNXIN GELU BACKWARD")
    # Note the argument order swap: the kernels take (x, dy).
    if approximate == "tanh":
        return gelu_backward_tanh(self, grad_output)
    return gelu_backward_none(self, grad_output)
def gelu_(A, *, approximate="none"):
    """Apply GELU to ``A``, passing ``A`` itself as the kernel's ``out0``.

    Args:
        A: Input, also supplied as ``out0`` to the selected kernel
            (presumably so the result is written in place -- confirm against
            ``pointwise_dynamic``).
        approximate: ``"tanh"`` selects ``gelu_tanh``; any other value
            (including the default ``"none"``) selects ``gelu_none``.

    Returns:
        Whatever the selected kernel returns.
    """
    logger.debug("GEMS_KUNLUNXIN GELU_ FORWARD")
    kernel = gelu_tanh if approximate == "tanh" else gelu_none
    return kernel(A, out0=A)