Coverage for src/flag_gems/runtime/backend/_cambricon/ops/gelu.py: 0%
86 statements
coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from triton.language.extra.mlu.libdevice import fast_erf, fast_tanh
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner, tl_extra_shim
11from ..utils import TOTAL_CORE_NUM
12from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, created as a child of the "flag_gems" logger (leading dots of
# __name__ are stripped before use as the child name).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# exp helper taken from tl_extra_shim; used by gelu_backward_none in this file.
# NOTE(review): presumably the shim picks the backend-appropriate exp — confirm.
exp = tl_extra_shim.exp
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block}, num_stages=1, num_warps=1)
        for block in (4096, 16384, 65536, 131072)
    ],
    key=["n_elements"],
)
@triton.jit
def gelu_none_kernel(X_ptr, OUT_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Exact (erf-based) GELU over a flat buffer of ``n_elements`` values.

    Computes ``0.5 * x + 0.5 * x * erf(x / sqrt(2))`` in fp32 and stores the
    result cast back to the loaded element type. Each program starts at its
    own BLOCK_SIZE tile and advances by ``num_programs * BLOCK_SIZE``
    (grid-stride loop), so a grid smaller than the tile count still covers
    the whole buffer. Input and output may alias.
    """
    inv_sqrt2: tl.constexpr = 0.7071067811
    prog_id = tl.program_id(0)
    prog_count = tl.num_programs(0)
    stride = prog_count * BLOCK_SIZE
    # Carry the loop offsets in int64 so large n_elements do not wrap int32.
    first = (prog_id * BLOCK_SIZE).to(tl.int64)
    for base in range(first, n_elements, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        live = idx < n_elements  # masks the ragged tail of the last tile
        raw = tl.load(X_ptr + idx, mask=live)
        val = raw.to(tl.float32)
        half = 0.5 * val
        act = half + half * fast_erf(val * inv_sqrt2)
        tl.store(OUT_ptr + idx, act.to(raw.dtype), mask=live)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def gelu_tanh(x, inplace):
    """Tanh-approximated GELU, element-wise.

    Evaluates ``0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x**3)))``
    (expanded form) in fp32. ``inplace`` is a non-tensor argument that this
    body never reads; the gelu/gelu_ wrappers pass False/True for it.
    """
    v = x.to(tl.float32)
    scaled = v * 0.79788456
    inner = scaled + scaled * 0.044715 * v * v
    half = 0.5 * v
    return half + half * fast_tanh(inner)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_none(x, dy):
    """Gradient of exact GELU: ``dy * (Phi(x) + x * phi(x))``.

    ``Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))`` is the standard normal CDF and
    ``phi(x) = exp(-x**2 / 2) / sqrt(2 * pi)`` its density. ``x`` is upcast
    to fp32 before the math.
    """
    inv_sqrt2: tl.constexpr = 0.7071067811
    inv_sqrt_2pi: tl.constexpr = 0.3989422803
    v = x.to(tl.float32)
    u = inv_sqrt2 * v
    # u * u == x**2 / 2, so exp(-u * u) is the Gaussian kernel term.
    slope = inv_sqrt_2pi * v * exp(-u * u) + 0.5 * fast_erf(u) + 0.5
    return slope * dy
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_backward_tanh(x, dy):
    """Gradient of tanh-approximated GELU, multiplied by ``dy``.

    With ``t = tanh(c1 * (x + c2 * x**3))`` and ``y = 0.5 * x * (1 + t)``:

        dy/dx = 0.5 * (1 + t) + 0.5 * x * (1 - t**2) * (c1 + 3 * c1 * c2 * x**2)

    All arithmetic uses the fp32 copy of ``x``, consistent with the forward
    kernels in this file.
    """
    c1: tl.constexpr = 0.79788456  # sqrt(2 / pi)
    c2: tl.constexpr = 0.044715
    c3: tl.constexpr = 0.1070322243  # 3 * c1 * c2
    x_fp32 = x.to(tl.float32)
    tanh_out = fast_tanh(c1 * x_fp32 + c1 * x_fp32 * c2 * x_fp32 * x_fp32)
    # Use x_fp32 (not the raw-dtype x) so the derivative term never depends on
    # implicit mixed-dtype promotion.
    dydx = (
        0.5
        * (
            (x_fp32 - x_fp32 * tanh_out * tanh_out)
            * (c1 + c3 * x_fp32 * x_fp32)
        )
        + 0.5
        + 0.5 * tanh_out
    )
    dx = dydx * dy
    return dx
def gelu(self, *, approximate="none"):
    """Return GELU(self) as a new tensor.

    Args:
        self: input tensor.
        approximate: ``"none"`` for the exact erf-based GELU (launches
            ``gelu_none_kernel`` on a contiguous copy of ``self``), or
            ``"tanh"`` for the tanh approximation (``gelu_tanh``).

    Returns:
        A new tensor holding GELU of ``self``.

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
            This mirrors the check PyTorch's native GELU performs, instead of
            silently treating unknown values as ``"none"``.
    """
    if approximate not in ("none", "tanh"):
        raise RuntimeError("approximate argument must be either none or tanh.")
    logger.debug("GEMS_CAMBRICON GELU FORWARD")
    if approximate == "tanh":
        return gelu_tanh(self, False)

    A = self.contiguous()
    out = torch.empty_like(A)
    N = A.numel()
    if N == 0:
        return out

    def grid_fn(meta):
        # Cap the grid at TOTAL_CORE_NUM; the kernel loops over leftover tiles.
        return (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        gelu_none_kernel[grid_fn](A, out, N)
    return out
def gelu_backward(grad_output, self, *, approximate="none"):
    """Return the GELU input gradient ``grad_output * d GELU(self) / d self``.

    Args:
        grad_output: upstream gradient.
        self: the forward-pass input.
        approximate: ``"none"`` (exact, ``gelu_backward_none``) or ``"tanh"``
            (``gelu_backward_tanh``).

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
            This mirrors PyTorch's native check rather than silently using the
            exact gradient.
    """
    if approximate not in ("none", "tanh"):
        raise RuntimeError("approximate argument must be either none or tanh.")
    logger.debug("GEMS_CAMBRICON GELU BACKWARD")
    if approximate == "tanh":
        return gelu_backward_tanh(self, grad_output)
    return gelu_backward_none(self, grad_output)
def gelu_(A, *, approximate="none"):
    """Apply GELU to ``A`` in place and return ``A``.

    Args:
        A: tensor to overwrite with GELU(A).
        approximate: ``"none"`` (exact, ``gelu_none_kernel``) or ``"tanh"``
            (``gelu_tanh`` writing into ``A`` via ``out0``).

    Returns:
        ``A`` (or, on the tanh path, whatever ``gelu_tanh`` returns for
        ``out0=A``).

    Raises:
        RuntimeError: if ``approximate`` is neither ``"none"`` nor ``"tanh"``.
            This mirrors PyTorch's native check rather than silently using the
            exact form.
    """
    if approximate not in ("none", "tanh"):
        raise RuntimeError("approximate argument must be either none or tanh.")
    logger.debug("GEMS_CAMBRICON GELU_ FORWARD")
    if approximate == "tanh":
        return gelu_tanh(A, True, out0=A)

    # Same storage as A when A is already contiguous, otherwise a fresh copy.
    A_contig = A.contiguous()
    N = A_contig.numel()
    if N == 0:
        return A

    def grid_fn(meta):
        # Cap the grid at TOTAL_CORE_NUM; the kernel loops over leftover tiles.
        return (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        # Input and output alias: each element is read and then overwritten.
        gelu_none_kernel[grid_fn](A_contig, A_contig, N)
    if not A.is_contiguous():
        # The kernel wrote into a temporary copy; propagate results back to A.
        A.copy_(A_contig)
    return A