Coverage for src/flag_gems/runtime/backend/_sunrise/ops/sigmoid.py: 0%
32 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic, tl_extra_shim
7from flag_gems.utils.pointwise_dynamic import CodeGenConfig
# Module logger: a child of the "flag_gems" logger, named after this module's
# __name__ with any leading dots stripped.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Local alias for the exp2 exposed by tl_extra_shim; called inside sigmoid_forward.
# NOTE(review): presumably an elementwise base-2 exponential — confirm against the shim.
exp2 = tl_extra_shim.exp2
# Grid-size limit handed to CodeGenConfig below.
# NOTE(review): presumably one upper bound per launch-grid axis — confirm.
MAX_GRID_SIZES = (65535, 65535, 65535)
# Code-generation settings passed to pointwise_dynamic for sigmoid_forward.
# NOTE(review): the exact meaning of each field is defined by CodeGenConfig in
# flag_gems.utils.pointwise_dynamic — verify there before tuning these values.
config = CodeGenConfig(
    max_tile_size=1024,
    max_grid_size=MAX_GRID_SIZES,
    max_num_warps_per_cta=32,
    prefer_block_pointer=True,
    prefer_1d_tile=True,
)
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")], config=config)
@triton.jit
def sigmoid_forward(x):
    """Elementwise logistic function, evaluated as 1 / (1 + 2**(-x * log2(e))).

    The input is cast to float32 before the arithmetic. The exponential is
    taken through ``exp2`` with the argument pre-scaled by log2(e).
    """
    # log2(e) is written as a numeric literal rather than
    # math.log2(math.e): triton 3.0.0 disallows calling a non-jitted function
    # inside a jitted function, even on the right-hand side of an assignment
    # to a constexpr.
    log2e: tl.constexpr = 1.4426950408889634
    x_f32 = x.to(tl.float32)
    denominator = 1 + exp2(-x_f32 * log2e)
    return 1 / denominator
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def sigmoid_backward_kernel(dy, y):
    """Elementwise sigmoid gradient: dy * (1 - y) * y, computed in float32.

    ``dy`` is the incoming gradient and ``y`` is the value the formula treats
    as the sigmoid output; both are cast to float32 before the arithmetic.
    """
    out_f32 = y.to(tl.float32)
    grad_f32 = dy.to(tl.float32)
    one_minus_out = 1.0 - out_f32
    return grad_f32 * one_minus_out * out_f32
def sigmoid(self):
    """Log a debug message, then return ``sigmoid_forward`` applied to ``self``."""
    logger.debug("GEMS SIGMOID FORWARD")
    return sigmoid_forward(self)
def sigmoid_backward(grad_output, output):
    """Log a debug message, then return the sigmoid gradient.

    Delegates to ``sigmoid_backward_kernel(grad_output, output)``, which
    evaluates ``grad_output * (1 - output) * output`` elementwise.
    """
    logger.debug("GEMS SIGMOID BACKWARD")
    return sigmoid_backward_kernel(grad_output, output)
def sigmoid_(A):
    """Log a debug message, then run ``sigmoid_forward`` on ``A`` with ``out0=A``.

    Returns whatever ``sigmoid_forward`` returns.
    NOTE(review): passing ``A`` as ``out0`` presumably makes the kernel write
    its result back into ``A`` (in-place variant) — confirm against
    pointwise_dynamic's handling of ``out0``.
    """
    logger.debug("GEMS SIGMOID_ FORWARD")
    return sigmoid_forward(A, out0=A)