Coverage for src/flag_gems/runtime/backend/_spacemit/ops/silu.py: 0%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import tl_extra_shim 

8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

9 

10div_rn = tl_extra_shim.div_rn 

11_silu = tl_extra_shim.silu 

12 

13 

14@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) 

15@triton.jit 

16def silu_forward(x): 

17 x_fp32 = x.to(tl.float32) 

18 # y = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32))) 

19 y = _silu(x_fp32) 

20 return y 

21 

22 

23@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")]) 

24@triton.jit 

25def silu_backward(x, dy): 

26 dy_fp32 = dy.to(tl.float32) 

27 x_fp32 = x.to(tl.float32) 

28 sigma = div_rn(1.0, 1.0 + tl.exp(-x_fp32)) 

29 dx = dy_fp32 * sigma * (1.0 + x_fp32 * (1.0 - sigma)) 

30 return dx 

31 

32 

33class Silu(torch.autograd.Function): 

34 @staticmethod 

35 def forward(ctx, A): 

36 logging.debug("GEMS_SPACEMIT SILU_FORWARD") 

37 out = silu_forward(A) 

38 ctx.save_for_backward(A) 

39 return out 

40 

41 @staticmethod 

42 def backward(ctx, out_grad): 

43 logging.debug("GEMS_SPACEMIT SILU_BACKWARD") 

44 (inp,) = ctx.saved_tensors 

45 in_grad = silu_backward(inp, out_grad) 

46 return in_grad 

47 

48 

49def silu(A): 

50 return Silu.apply(A) 

51 

52 

53class InplaceSilu(torch.autograd.Function): 

54 @staticmethod 

55 def forward(ctx, A): 

56 logging.debug("GEMS_SPACEMIT SILU__FORWARD") 

57 ctx.save_for_backward(A.clone()) 

58 ctx.mark_dirty(A) 

59 out = silu_forward(A, out0=A) 

60 return out 

61 

62 @staticmethod 

63 def backward(ctx, out_grad): 

64 logging.debug("GEMS_SPACEMIT SILU__BACKWARD") 

65 (inp,) = ctx.saved_tensors 

66 in_grad = silu_backward(inp, out_grad) 

67 return in_grad 

68 

69 

70def silu_(A): 

71 InplaceSilu.apply(A) 

72 return A