Coverage for src/flag_gems/runtime/backend/_spacemit/ops/silu.py: 0%
54 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
# Backend math helpers pulled from flag_gems' Triton shim module so the
# kernels below call whatever implementation tl_extra_shim exposes.
_silu = tl_extra_shim.silu  # SiLU primitive used by silu_forward
div_rn = tl_extra_shim.div_rn  # division helper used by silu_backward
# NOTE(review): the "_rn" suffix presumably means round-to-nearest
# division — confirm against tl_extra_shim.
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_forward(x):
    # Elementwise SiLU, mathematically x / (1 + exp(-x)) == x * sigmoid(x).
    # The value is evaluated in fp32 regardless of the input precision; the
    # output dtype is governed by pointwise_dynamic's promotion rule
    # (0, "DEFAULT") — presumably the dtype of x, confirm in pointwise_dynamic.
    # The actual formula is delegated to the tl_extra_shim SiLU helper
    # (bound to _silu at module level) instead of being spelled out here.
    x_fp32 = x.to(tl.float32)
    y = _silu(x_fp32)
    return y
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def silu_backward(x, dy):
    # Gradient of SiLU(x) = x * s(x), where s is the logistic sigmoid:
    #     dSiLU/dx = s(x) * (1 + x * (1 - s(x)))
    # The result is that derivative scaled by the incoming gradient dy.
    # Everything is computed in fp32; the output dtype follows
    # pointwise_dynamic's (0, "DEFAULT") promotion rule.
    x_f = x.to(tl.float32)
    dy_f = dy.to(tl.float32)
    exp_neg_x = tl.exp(-x_f)
    sig = div_rn(1.0, 1.0 + exp_neg_x)
    one_plus_term = 1.0 + x_f * (1.0 - sig)
    # Same association as before: (dy * s) * (1 + x * (1 - s)).
    grad = dy_f * sig * one_plus_term
    return grad
class Silu(torch.autograd.Function):
    """Out-of-place SiLU with a custom backward built on the Triton kernels.

    ``forward`` runs ``silu_forward`` on the input and keeps the original
    input on ``ctx``. ``backward`` feeds that saved input together with the
    upstream gradient to ``silu_backward``.
    """

    @staticmethod
    def forward(ctx, A):
        logging.debug("GEMS_SPACEMIT SILU_FORWARD")
        # The backward formula is expressed in terms of x, not SiLU(x),
        # so the input itself is what gets saved.
        ctx.save_for_backward(A)
        return silu_forward(A)

    @staticmethod
    def backward(ctx, out_grad):
        logging.debug("GEMS_SPACEMIT SILU_BACKWARD")
        saved_input = ctx.saved_tensors[0]
        return silu_backward(saved_input, out_grad)
def silu(A):
    """Return SiLU(A) as a new tensor, differentiable through ``Silu``.

    No output buffer is handed to the kernel, so ``A`` is only read.
    """
    result = Silu.apply(A)
    return result
class InplaceSilu(torch.autograd.Function):
    """SiLU that overwrites its input, with a custom backward.

    The forward kernel writes its result into ``A`` itself (``out0=A``), and
    ``A`` is marked dirty so autograd knows it was modified in place. The
    backward formula needs the pre-activation values, which the in-place
    write destroys. A copy of the input is therefore saved first, but only
    when a gradient with respect to ``A`` is actually requested.
    """

    @staticmethod
    def forward(ctx, A):
        logging.debug("GEMS_SPACEMIT SILU__FORWARD")
        # Snapshot the input before it is overwritten. Skipping the clone when
        # no input gradient is needed (e.g. inference) avoids allocating and
        # copying a full-size tensor that backward would never read.
        if ctx.needs_input_grad[0]:
            ctx.save_for_backward(A.clone())
        ctx.mark_dirty(A)
        out = silu_forward(A, out0=A)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        logging.debug("GEMS_SPACEMIT SILU__BACKWARD")
        # backward only runs when A required grad, in which case forward
        # saved the pre-activation clone above.
        (inp,) = ctx.saved_tensors
        in_grad = silu_backward(inp, out_grad)
        return in_grad
def silu_(A):
    """Apply SiLU to ``A`` in place and return the same tensor object.

    The update goes through ``InplaceSilu`` so autograd records it.
    """
    _ = InplaceSilu.apply(A)  # result not needed: A now holds the output
    return A