Coverage for src/flag_gems/runtime/backend/_cambricon/ops/arange.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry, libtuner 

10 

11from ..utils import TOTAL_CORE_NUM 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

# Triton kernel that fills ``y_ptr[i] = i * step + start`` for ``0 <= i < size``.
#
# Each program starts at the block ``pid * BLOCK_SIZE`` and then advances by
# ``num_programs * BLOCK_SIZE`` until it passes ``size``, so the launch grid may
# hold fewer programs than there are blocks.  ``end`` is accepted to keep the
# launch signature stable but is not read by the kernel body.  ``BLOCK_SIZE`` is
# picked by ``libtuner`` from the configs below, keyed on ``size``.
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block_size}, num_stages=3, num_warps=1)
        for block_size in (1024, 4096, 8192, 16384)
    ],
    key=["size"],
    strategy=["log"],
)
@triton.jit
def arange_func(y_ptr, start, end, step, size, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    program_count = tl.num_programs(axis=0)
    first_block = pid * BLOCK_SIZE
    stride = program_count * BLOCK_SIZE
    for block_begin in range(first_block, size, stride):
        idx = tl.arange(0, BLOCK_SIZE) + block_begin
        values = idx * step + start
        # The last block may run past ``size``; mask the tail lanes out.
        tl.store(y_ptr + idx, values, mask=idx < size)

38 

39 

def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Return a 1-D tensor of ``start + i * step`` values filled by ``arange_func``.

    Args:
        start: First value of the sequence.
        end: Bound used (together with ``start`` and ``step``) to compute the
            number of elements.
        step: Spacing between consecutive values; must be nonzero.
        dtype: Element type of the result.  ``None`` is replaced by
            ``torch.int64``.  When ``dtype`` is ``torch.int64`` the three
            scalars are truncated with ``int()`` and the element count is
            computed with integer arithmetic; otherwise it is
            ``ceil((end - start) / step)``.
        layout: Accepted for signature compatibility; not used here.
        device: Target device; ``None`` selects ``runtime.device.name``.
        pin_memory: Forwarded to ``torch.empty``; ``None`` is treated as
            ``False``.

    Returns:
        A tensor of shape ``(size,)`` on ``device`` with the requested dtype.

    Raises:
        RuntimeError: If ``step`` is zero (after integer truncation when
            ``dtype`` is ``torch.int64``).
        AssertionError: If the element count is not below the int32 maximum.
    """
    logger.debug("GEMS_CAMBRICON ARANGE")

    use_int_math = dtype is torch.int64
    if use_int_math:
        start = int(start)
        end = int(end)
        step = int(step)

    # Validate once for both paths so a zero step never surfaces as a bare
    # ZeroDivisionError from the size computation below.
    if step == 0:
        raise RuntimeError("step must be nonzero")

    if use_int_math:
        # Ceiling division that rounds away from zero for either step sign.
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    else:
        size = math.ceil((end - start) / step)
    size = int(size)

    int32_max = torch.iinfo(torch.int32).max
    assert (
        size < int32_max
    ), f"Size {size} is not less than the maximum int32 value {int32_max}"

    def grid(META):
        # Never launch more programs than cores; the kernel loops over any
        # remaining blocks itself.
        return (min(triton.cdiv(size, META["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    # NOTE(review): a missing dtype always becomes int64, even when float
    # arguments were passed -- confirm this matches the intended semantics.
    if dtype is None:
        dtype = torch.int64

    if pin_memory is None:
        pin_memory = False

    if device is None:
        # Note(Zhengzekang): Torch default value is CPU, but triton is target to GPU.
        device = runtime.device.name

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    arange_func[grid](result, start, end, step, size)
    return result

84 

85 

def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return ``arange_start`` called with ``start=0`` and ``step=1``.

    All keyword arguments are forwarded unchanged.
    """
    return arange_start(
        start=0,
        end=end,
        step=1,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )

89 )