Coverage for src/flag_gems/runtime/backend/_spacemit/ops/pow.py: 0%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
# `pow` device function exposed by flag_gems' `tl_extra_shim`; the pointwise
# Triton kernels below (pow_func, pow_func_tensor_scalar,
# pow_func_scalar_tensor) compute their powers through it.
_pow = tl_extra_shim.pow
# Module logger; each public op emits a debug line naming itself.
logger = logging.getLogger(__name__)
@pointwise_dynamic(promotion_methods=[(0, 1, "BOOL_TO_LONG")])
@triton.jit
def pow_func(x, exponent):
    # Elementwise x ** exponent for a tensor base and a tensor exponent.
    # Both operands are converted to float32 before calling `_pow`; the output
    # dtype is chosen by pointwise_dynamic from `promotion_methods`
    # (BOOL_TO_LONG over arguments 0 and 1).
    # NOTE(review): float64 and large int64 inputs lose precision through the
    # float32 cast — confirm this is an accepted limitation of this backend.
    base = x.to(tl.float32)
    power = exponent.to(tl.float32)
    return _pow(base, power)
def pow_tensor_tensor(A, exponent):
    """Return the elementwise power ``A ** exponent`` of two tensors.

    The work is delegated to the ``pow_func`` pointwise kernel, which
    allocates and returns a new output tensor.
    """
    logger.debug("GEMS_SPACEMIT POW_TENSOR_TENSOR")
    result = pow_func(A, exponent)
    return result
def pow_tensor_tensor_(A, exponent):
    """In-place ``A **= exponent`` with a tensor exponent.

    ``pow_func`` is told to write its result into ``A`` (``out0=A``); whatever
    it returns is handed back to the caller.
    """
    logger.debug("GEMS_SPACEMIT POW_TENSOR_TENSOR_")
    result = pow_func(A, exponent, out0=A)
    return result
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "BOOL_TO_LONG")])
@triton.jit
def pow_func_tensor_scalar(x, exponent):
    # Tensor base `x`, scalar `exponent` (is_tensor=[True, False]).
    # Both are converted to float32 and raised with `_pow`; the output dtype
    # comes from pointwise_dynamic's BOOL_TO_LONG promotion of args 0 and 1.
    base = x.to(tl.float32)
    power = exponent.to(tl.float32)
    return _pow(base, power)
@triton.jit
def pow_by_mul_kernel(
    X_ptr,
    Out_ptr,
    n_elements,
    exponent: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Raise every element of X to the compile-time `exponent` by repeated
    # multiplication (no float32 cast), writing the result to Out.
    # X_ptr / Out_ptr are addressed as flat 1-D buffers of n_elements items;
    # each program covers one BLOCK_SIZE-wide slice and the mask guards the
    # ragged final slice.
    block_start = tl.program_id(0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    base = tl.load(X_ptr + idx, mask=in_bounds)
    # Start from base**1 and multiply in the base (exponent - 1) more times.
    acc = base
    for _ in range(1, exponent):
        acc = acc * base
    tl.store(Out_ptr + idx, acc, mask=in_bounds)
54def _pow_by_mul(A, exponent):
55 out = torch.empty_like(A)
56 n_elements = A.numel()
57 BLOCK_SIZE = 1024
58 grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE,)
59 pow_by_mul_kernel[grid](
60 A,
61 out,
62 n_elements,
63 exponent=exponent,
64 BLOCK_SIZE=BLOCK_SIZE,
65 )
66 return out
69def pow_tensor_scalar(A, exponent):
70 logger.debug("GEMS_SPACEMIT POW_TENSOR_SCALAR")
71 has_neg = bool((A < 0).any().item())
72 if has_neg and isinstance(exponent, int) and exponent > 0:
73 return _pow_by_mul(A, exponent)
74 return pow_func_tensor_scalar(A, exponent)
77def pow_tensor_scalar_(A, exponent):
78 logger.debug("GEMS_SPACEMIT POW_TENSOR_SCALAR_")
79 has_neg = bool((A < 0).any().item())
80 if has_neg and isinstance(exponent, int) and exponent > 0:
81 result = _pow_by_mul(A, exponent)
82 A.copy_(result)
83 return A
85 return pow_func_tensor_scalar(A, exponent, out0=A)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")])
@triton.jit
def pow_func_scalar_tensor(x, exponent):
    # Scalar base `x`, tensor `exponent` (is_tensor=[False, True]).
    # Only the base is converted to float32; the exponent reaches `_pow` in
    # its own dtype.
    # NOTE(review): the sibling kernels also cast the exponent to float32 —
    # confirm whether this asymmetry is intentional.
    base = x.to(tl.float32)
    return _pow(base, exponent)
def pow_scalar(A, exponent):
    """Return ``A ** exponent`` for a scalar base ``A`` and a tensor ``exponent``.

    Delegates to the ``pow_func_scalar_tensor`` pointwise kernel.
    """
    logger.debug("GEMS_SPACEMIT POW_SCALAR")
    result = pow_func_scalar_tensor(A, exponent)
    return result