Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/elu.py: 0%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

8logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

9 

10 

11@pointwise_dynamic( 

12 is_tensor=[True, False, False, False], promotion_methods=[(0, "DEFAULT")] 

13) 

14@triton.jit 

15def elu_forward_kernel(x, alpha, scale, input_scale): 

16 x_fp32 = x.to(tl.float32) 

17 return tl.where( 

18 x_fp32 > 0, 

19 scale * input_scale * x_fp32, 

20 scale * alpha * (tl.exp(x_fp32 * input_scale) - 1), 

21 ) 

22 

23 

24@pointwise_dynamic( 

25 is_tensor=[True, True, False, False, False, False], 

26 promotion_methods=[(0, 1, "DEFAULT")], 

27) 

28@triton.jit 

29def elu_backward_kernel(grad_output, x, alpha, scale, input_scale, is_result): 

30 x_fp32 = x.to(tl.float32) 

31 grad_pos = grad_output * scale * input_scale 

32 if is_result: 

33 grad_neg = grad_output * input_scale * (x_fp32 + scale * alpha) 

34 else: 

35 grad_neg = ( 

36 grad_output * scale * alpha * input_scale * tl.exp(x_fp32 * input_scale) 

37 ) 

38 

39 return tl.where(x_fp32 > 0, grad_pos, grad_neg) 

40 

41 

42def elu(A, alpha=1.0, scale=1.0, input_scale=1.0): 

43 logger.debug("GEMS_KUNLUNXIN ELU") 

44 return elu_forward_kernel(A, alpha, scale, input_scale) 

45 

46 

47def elu_(A, alpha=1.0, scale=1.0, input_scale=1.0): 

48 logger.debug("GEMS_KUNLUNXIN ELU_") 

49 return elu_forward_kernel(A, alpha, scale, input_scale, out0=A) 

50 

51 

52def elu_backward(grad_output, alpha, scale, input_scale, is_result, self_or_result): 

53 logger.debug("GEMS_KUNLUNXIN ELU BACKWARD") 

54 grad_input = elu_backward_kernel( 

55 grad_output, self_or_result, alpha, scale, input_scale, is_result 

56 ) 

57 return grad_input