Coverage for src/flag_gems/runtime/backend/_sunrise/ops/pow.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

7from flag_gems.utils.pointwise_dynamic import CodeGenConfig 

8 

9_pow = tl_extra_shim.pow 

10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

11 

12MAX_GRID_SIZES = (65535, 65535, 65535) 

13config = CodeGenConfig( 

14 max_tile_size=512, 

15 max_grid_size=MAX_GRID_SIZES, 

16 max_num_warps_per_cta=32, 

17 prefer_block_pointer=False, 

18 prefer_1d_tile=True, 

19) 

20 

21 

22@pointwise_dynamic(promotion_methods=[(0, 1, "BOOL_TO_LONG")], config=config) 

23@triton.jit 

24def pow_func(x, exponent): 

25 return _pow(x.to(tl.float32), exponent.to(tl.float32)) 

26 

27 

28def pow_tensor_tensor(A, exponent): 

29 logger.debug("GEMS POW_TENSOR_TENSOR") 

30 return pow_func(A, exponent) 

31 

32 

33def pow_tensor_tensor_(A, exponent): 

34 logger.debug("GEMS POW_TENSOR_TENSOR_") 

35 return pow_func(A, exponent, out0=A) 

36 

37 

38@pointwise_dynamic( 

39 is_tensor=[True, False], promotion_methods=[(0, 1, "BOOL_TO_LONG")], config=config 

40) 

41@triton.jit 

42def pow_func_tensor_scalar(x, exponent): 

43 return _pow(x.to(tl.float32), exponent.to(tl.float32)) 

44 

45 

46def pow_tensor_scalar(A, exponent): 

47 logger.debug("GEMS POW_TENSOR_SCALAR") 

48 return pow_func_tensor_scalar(A, exponent) 

49 

50 

51def pow_tensor_scalar_(A, exponent): 

52 logger.debug("GEMS POW_TENSOR_SCALAR_") 

53 return pow_func_tensor_scalar(A, exponent, out0=A) 

54 

55 

56@pointwise_dynamic( 

57 is_tensor=[False, True], promotion_methods=[(0, 1, "BOOL_TO_LONG")], config=config 

58) 

59@triton.jit 

60def pow_func_scalar_tensor(x, exponent): 

61 return _pow(x.to(tl.float32), exponent.to(tl.float32)) 

62 

63 

64def pow_scalar(A, exponent): 

65 logger.debug("GEMS POW_SCALAR") 

66 return pow_func_scalar_tensor(A, exponent)