Coverage for src/flag_gems/ops/square.py: 90%
21 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import triton
5from flag_gems.utils import pointwise_dynamic
# Module-scoped logger; the ops below emit a debug line on each call.
logger = logging.getLogger(name=__name__)
@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def square_func(x):
    # Per-element body compiled by Triton and wrapped by pointwise_dynamic.
    # The wrapper is what callers invoke: with one positional tensor it
    # returns a result, and with ``out0=`` it writes into a given tensor
    # (see the call sites in this module).
    # Output dtype is governed by the "DEFAULT" rule applied to argument 0;
    # NOTE(review): the exact promotion semantics live in pointwise_dynamic.
    squared = x * x
    return squared
def square(A):
    """Return the elementwise square of ``A``.

    Emits the debug message "GEMS SQUARE", then delegates to the
    ``square_func`` pointwise kernel without an explicit output, returning
    whatever that call produces.
    """
    logger.debug("GEMS SQUARE")
    result = square_func(A)
    return result
def square_out(A, *, out=None):
    """Square ``A`` elementwise, optionally writing into ``out``.

    Args:
        A: Input passed straight through to ``square_func``.
        out: Keyword-only destination. When given, the kernel writes its
            result into it via ``out0=`` and ``out`` itself is returned.
            When ``None``, this behaves like a plain ``square_func(A)`` call.

    Returns:
        ``out`` if it was supplied, otherwise the result of ``square_func(A)``.
    """
    logger.debug("GEMS SQUARE_OUT")
    if out is not None:
        # Caller-provided buffer: fill it and hand the same object back.
        square_func(A, out0=out)
        return out
    return square_func(A)
def square_(A):
    """Square ``A`` elementwise in place and return ``A``.

    The kernel is told to use ``A`` as both its input and its output
    (``out0=A``), so the original values are overwritten. The same object is
    returned to allow chaining, following the trailing-underscore in-place
    naming used for this op.
    """
    logger.debug("GEMS SQUARE_")
    square_func(A, out0=A)  # input and destination alias -> in-place update
    return A