Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/elu.py: 0%
29 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import triton
4import triton.language as tl
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the "flag_gems" logger, named after this module
# with any leading dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# Element-wise ELU forward:
#   x > 0  ->  scale * x
#   x <= 0 ->  scale * alpha * (exp(x * input_scale) - 1)
# Only the first argument (x) is a tensor; alpha, scale and input_scale are
# scalars, as declared by ``is_tensor`` below.
@pointwise_dynamic(
    is_tensor=[True, False, False, False], promotion_methods=[(0, "DEFAULT")]
)
@triton.jit
def elu_forward_kernel(x, alpha, scale, input_scale):
    # Compute in fp32 so the exponential is not evaluated in a narrower dtype.
    x_fp32 = x.to(tl.float32)
    return tl.where(
        x_fp32 > 0,
        # The positive branch is scaled by `scale` only. `input_scale` applies
        # solely inside the exponential; multiplying it in here gives wrong
        # results whenever input_scale != 1 (e.g. CELU, where it is 1/alpha).
        scale * x_fp32,
        scale * alpha * (tl.exp(x_fp32 * input_scale) - 1),
    )
# Element-wise ELU backward. `grad_output` and `x` are tensors; the remaining
# arguments are scalars, as declared by ``is_tensor`` below.
# `x` holds either the forward input or the forward result, selected by
# `is_result`:
#   x > 0                   ->  grad_output * scale
#   x <= 0, is_result       ->  grad_output * input_scale * (x + scale * alpha)
#   x <= 0, not is_result   ->  grad_output * scale * alpha * input_scale
#                                 * exp(x * input_scale)
@pointwise_dynamic(
    is_tensor=[True, True, False, False, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def elu_backward_kernel(grad_output, x, alpha, scale, input_scale, is_result):
    x_fp32 = x.to(tl.float32)
    # The derivative of the positive branch (scale * x) is `scale` alone;
    # `input_scale` only appears in the negative (exponential) branch.
    grad_pos = grad_output * scale
    if is_result:
        # x is the forward output y = scale * alpha * (exp(.) - 1), so
        # scale * alpha * exp(.) is recovered as y + scale * alpha.
        grad_neg = grad_output * input_scale * (x_fp32 + scale * alpha)
    else:
        grad_neg = (
            grad_output * scale * alpha * input_scale * tl.exp(x_fp32 * input_scale)
        )

    return tl.where(x_fp32 > 0, grad_pos, grad_neg)
def elu(A, alpha=1.0, scale=1.0, input_scale=1.0):
    """Apply ELU to ``A`` via ``elu_forward_kernel`` and return the result.

    ``alpha``, ``scale`` and ``input_scale`` are forwarded to the kernel
    unchanged; each defaults to 1.0.
    """
    logger.debug("GEMS_KUNLUNXIN ELU")
    result = elu_forward_kernel(A, alpha, scale, input_scale)
    return result
def elu_(A, alpha=1.0, scale=1.0, input_scale=1.0):
    """Apply ELU to ``A`` with ``A`` itself passed as the kernel's ``out0``.

    Returns whatever ``elu_forward_kernel`` returns for that call.
    """
    logger.debug("GEMS_KUNLUNXIN ELU_")
    # NOTE(review): out0=A presumably makes the kernel write its output into
    # A (in-place variant) -- confirm against pointwise_dynamic.
    result = elu_forward_kernel(A, alpha, scale, input_scale, out0=A)
    return result
def elu_backward(grad_output, alpha, scale, input_scale, is_result, self_or_result):
    """Compute the ELU input gradient via ``elu_backward_kernel``.

    ``self_or_result`` is passed to the kernel as its second tensor argument;
    ``is_result`` is forwarded so the kernel can select the matching formula.
    Note the argument order differs between this wrapper and the kernel.
    """
    logger.debug("GEMS_KUNLUNXIN ELU BACKWARD")
    return elu_backward_kernel(
        grad_output, self_or_result, alpha, scale, input_scale, is_result
    )