Coverage for src/flag_gems/runtime/backend/_ascend/ops/pow.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

7 

8_pow = tl_extra_shim.pow 

9 

10logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

11 

12 

13@pointwise_dynamic(promotion_methods=[(0, 1, "BOOL_TO_LONG")]) 

14@triton.jit 

15def pow_func(x, exponent): 

16 return _pow(x.to(tl.float32), exponent) 

17 

18 

19def pow_tensor_tensor(A, exponent): 

20 logger.debug("GEMS_ASCEND POW_TENSOR_TENSOR") 

21 return pow_func(A, exponent) 

22 

23 

24def pow_tensor_tensor_(A, exponent): 

25 logger.debug("GEMS_ASCEND POW_TENSOR_TENSOR_") 

26 out = pow_func(A, exponent) 

27 A.copy_(out) 

28 return A 

29 

30 

31@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "BOOL_TO_LONG")]) 

32@triton.jit 

33def pow_func_tensor_scalar(x, exponent): 

34 return _pow(x.to(tl.float32), exponent) 

35 

36 

37def pow_tensor_scalar(A, exponent): 

38 logger.debug("GEMS_ASCEND POW_TENSOR_SCALAR") 

39 return pow_func_tensor_scalar(A, exponent) 

40 

41 

42def pow_tensor_scalar_(A, exponent): 

43 logger.debug("GEMS_ASCEND POW_TENSOR_SCALAR_") 

44 out = pow_func_tensor_scalar(A, exponent) 

45 A.copy_(out) 

46 return A 

47 

48 

49@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")]) 

50@triton.jit 

51def pow_func_scalar_tensor(x, exponent): 

52 return _pow(x.to(tl.float32), exponent) 

53 

54 

55def pow_scalar(A, exponent): 

56 logger.debug("GEMS_ASCEND POW_SCALAR") 

57 return pow_func_scalar_tensor(A, exponent)