Coverage for src/flag_gems/runtime/backend/_spacemit/ops/gelu.py: 0%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
# Math helpers pulled from flag_gems.utils.tl_extra_shim and bound to short
# module-level names that the Triton kernels below call directly.
# NOTE: `pow` deliberately shadows the Python builtin inside this module,
# because the kernels refer to the shim helper by that name.
erf, exp, pow, tanh = (
    tl_extra_shim.erf,
    tl_extra_shim.exp,
    tl_extra_shim.pow,
    tl_extra_shim.tanh,
)
# Fused GELU helpers from the shim: tanh-approximation and erf ("none") forms.
geluTanh, geluNone = tl_extra_shim.gelu_tanh, tl_extra_shim.gelu_none
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_none(x):
    # Elementwise GELU, non-approximated variant. The reference formula is
    #   0.5 * x * (1 + erf(x / sqrt(2)))
    # and the work is delegated to the shim's gelu_none helper (presumably
    # implementing that formula -- see tl_extra_shim). The input is upcast
    # to float32 before the helper is called.
    return geluNone(x.to(tl.float32))
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x):
    # Elementwise GELU, tanh-approximated variant. The reference formula is
    #   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x**2)))
    # with 0.79788456 ~= sqrt(2 / pi). The work is delegated to the shim's
    # gelu_tanh helper after upcasting the input to float32.
    x_fp32 = x.to(tl.float32)
    return geluTanh(x_fp32)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    # Gradient of the erf-based GELU, y = 0.5 * x * (1 + erf(x / sqrt(2))):
    #   dy/dx = x * exp(-x**2 / 2) / sqrt(2 * pi) + 0.5 * erf(x / sqrt(2)) + 0.5
    # The result is scaled by the upstream gradient `dy`.
    inv_sqrt2: tl.constexpr = 0.7071067811  # 1 / math.sqrt(2)
    inv_sqrt_2pi: tl.constexpr = 0.3989422803  # 1 / math.sqrt(2 * math.pi)
    xf = x.to(tl.float32)
    z = inv_sqrt2 * xf
    # Summation order matches (pdf term + erf term) + 0.5 on purpose.
    dydx = inv_sqrt_2pi * xf * exp(-pow(z, 2)) + 0.5 * erf(z) + 0.5
    return dydx * dy
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    # Gradient of the tanh-approximated GELU
    #   y = 0.5 * x * (1 + tanh(u)),  u = 0.79788456 * x * (1 + 0.044715 * x**2)
    # which is
    #   dy/dx = 0.5 * x * (1 - tanh(u)**2) * du/dx + 0.5 * (1 + tanh(u))
    #   du/dx = 0.79788456 + 0.1070322243 * x**2
    # where 0.79788456 = sqrt(2 / pi) and 0.1070322243 = 3 * 0.044715 * sqrt(2 / pi).
    # The result is scaled by the upstream gradient `dy`.
    #
    # Every use of x goes through the float32 copy. In Triton, a Python-float
    # scalar times a float16/bfloat16 tensor keeps the low-precision dtype, so
    # using the raw `x` would round `0.79788456 * x` before it reaches tanh.
    x_fp32 = x.to(tl.float32)
    x_sq = pow(x_fp32, 2)
    tanh_out = tanh(0.79788456 * x_fp32 * (1 + 0.044715 * x_sq))
    dydx = 0.5 * x_fp32 * (
        (1 - pow(tanh_out, 2)) * (0.79788456 + 0.1070322243 * x_sq)
    ) + 0.5 * (1 + tanh_out)
    dx = dydx * dy
    return dx
class Gelu(torch.autograd.Function):
    """Autograd function routing GELU forward/backward to the Triton kernels.

    An ``approximate`` value equal to ``"tanh"`` selects the tanh-approximation
    kernels. Every other value selects the erf-based ("none") kernels.
    """

    @staticmethod
    def forward(ctx, A, approximate):
        """Compute GELU of ``A`` and stash what backward needs on ``ctx``."""
        logging.debug("GEMS_SPACEMIT GELU_FORWARD")
        kernel = gelu_tanh if approximate == "tanh" else gelu_none
        out = kernel(A)
        # The raw input is kept, because both backward kernels need x itself.
        ctx.save_for_backward(A)
        ctx.approximate = approximate
        return out

    @staticmethod
    def backward(ctx, out_grad):
        """Return the gradient w.r.t. ``A``; ``approximate`` gets ``None``."""
        logging.debug("GEMS_SPACEMIT GELU_BACKWARD")
        (inp,) = ctx.saved_tensors
        grad_kernel = (
            gelu_backward_tanh if ctx.approximate == "tanh" else gelu_backward_none
        )
        return grad_kernel(inp, out_grad), None
def gelu(A, *, approximate="none"):
    """Apply GELU to ``A`` with autograd support.

    Args:
        A: input tensor.
        approximate: ``"none"`` for the erf-based formula or ``"tanh"`` for
            the tanh approximation, as in ``torch.nn.functional.gelu``.

    Returns:
        The tensor produced by ``Gelu.apply(A, approximate)``.

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
    """
    # torch rejects other values with a RuntimeError. Without this check, an
    # unknown value (e.g. a typo) would silently run the erf-based kernel.
    if approximate not in ("none", "tanh"):
        raise RuntimeError(
            "approximate argument must be either none or tanh, "
            f"got {approximate!r}"
        )
    return Gelu.apply(A, approximate)