Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/arange.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11 

# Module logger: a child of the shared "flag_gems" logger, named after this
# module's ``__name__`` with any leading "." characters stripped.
_package_logger = logging.getLogger("flag_gems")
logger = _package_logger.getChild(__name__.lstrip("."))

13 

14 

@libentry()
@triton.jit
def arange_func(
    y_ptr,
    start,
    end,
    step,
    size,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program fills one BLOCK_SIZE-wide slice of the output with
    # start + index * step. `end` is part of the signature but is not read
    # here; the element count arrives precomputed in `size`.
    block_id = ext.program_id(0)
    block_start = block_id * BLOCK_SIZE
    # Value contribution of the elements that precede this block.
    block_base = block_start * step

    lane = tl.arange(0, BLOCK_SIZE)
    # The last block may extend past `size`; mask those lanes out of the store.
    in_bounds = lane + block_start < size
    values = lane * step + block_base + start
    tl.store(y_ptr + block_start + lane, values, mask=in_bounds)

33 

34 

def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Build a 1-D tensor holding ``start, start + step, ...`` short of ``end``.

    Args:
        start: First value of the sequence.
        end: Bound of the sequence; it is not included in the result.
        step: Distance between consecutive values; must be nonzero.
        dtype: Element type of the result; ``torch.int64`` when ``None``.
        layout: Accepted for signature compatibility; not used here.
        device: Target device; ``runtime.device.name`` when ``None``.
        pin_memory: Forwarded to ``torch.empty``; ``False`` when ``None``.

    Returns:
        A tensor of shape ``(size,)`` whose elements are written by
        ``arange_func``.

    Raises:
        RuntimeError: If ``step`` is zero (for ``torch.int64``, after ``step``
            has been truncated to an integer).
    """
    logger.debug("GEMS_KUNLUNXIN ARANGE")
    if dtype is torch.int64:
        # Integer output: truncate the arguments first so the element count
        # is computed with exact integer arithmetic.
        start = int(start)
        end = int(end)
        step = int(step)
        if step == 0:
            raise RuntimeError("step must be nonzero")
        sgn = (step > 0) - (step < 0)
        # Ceiling division that also holds for negative steps.
        size = (end - start + step - sgn) // step
    else:
        # Validate in this branch too, so a zero step is reported the same
        # way for every dtype instead of surfacing as ZeroDivisionError from
        # the division below.
        if step == 0:
            raise RuntimeError("step must be nonzero")
        size = math.ceil((end - start) / step)
    size = int(size)

    if dtype is None:
        dtype = torch.int64

    if pin_memory is None:
        pin_memory = False

    if device is None:
        device = runtime.device.name

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    if size == 0:
        # Nothing to fill; skip launching the kernel with an empty grid.
        return result

    # Size-based heuristic for BLOCK_SIZE and num_warps
    if size <= 1024:
        BLOCK_SIZE = 256
        num_warps = 2
    elif size <= 8192:
        BLOCK_SIZE = 1024
        num_warps = 4
    elif size <= 65536:
        BLOCK_SIZE = 4096
        num_warps = 8
    else:
        BLOCK_SIZE = 8192
        num_warps = 8

    # One program per BLOCK_SIZE-wide slice of the output.
    grid = triton.cdiv(size, BLOCK_SIZE)

    arange_func[grid,](result, start, end, step, size, BLOCK_SIZE, num_warps=num_warps)
    return result

87 

88 

def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return the sequence from 0 towards ``end`` in steps of 1.

    Delegates to ``arange_start`` with ``start=0`` and ``step=1``; every
    keyword option is forwarded unchanged.
    """
    tensor_options = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return arange_start(0, end, 1, **tensor_options)