Coverage for src/flag_gems/runtime/backend/_cambricon/ops/ceil.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, libtuner 

9 

10from ..utils import TOTAL_CORE_NUM 

11 

# Module logger: fetch the package-wide "flag_gems" logger first, then derive
# a child named after this module so records propagate up to it.
logger = logging.getLogger("flag_gems")
logger = logger.getChild(__name__.lstrip("."))

13 

14 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
)
@triton.jit
def ceil_kernel(
    X_ptr,
    OUT_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Write ceil(X[i]) to OUT[i] for i in [0, n_elements).

    X_ptr / OUT_ptr are addressed as flat, densely packed buffers. Each
    program starts at chunk ``pid`` and advances by
    ``num_programs * BLOCK_SIZE`` until every element has been visited, so
    the launch grid may be smaller than the number of chunks.
    """
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # 64-bit offsets so very large tensors do not overflow int32 indexing.
    block_start = block_start.to(tl.int64)
    for off in range(block_start, n_elements, step):
        offsets = off + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(X_ptr + offsets, mask=mask)
        # The conditions below depend only on the element dtype, so they are
        # resolved at compile time and only one branch is generated.
        if x.dtype.is_fp64():
            # tl.ceil accepts fp64 directly. Routing it through fp32 would
            # round away low-order bits first (16777216.5 -> 16777216.0).
            result = tl.ceil(x)
        elif x.dtype.is_floating():
            # tl.ceil only accepts fp32/fp64, so half-precision types are
            # computed in fp32. fp32 represents every fp16/bf16 value
            # exactly, so this round-trip is lossless.
            result = tl.ceil(x.to(tl.float32)).to(x.dtype)
        else:
            # Integer/bool values are already integral: ceil is the identity.
            # A float32 round-trip would corrupt integers above 2**24.
            result = x
        tl.store(OUT_ptr + offsets, result, mask=mask)

43 

44 

def ceil(A):
    """Return a new tensor holding the elementwise ceiling of ``A``.

    The result has ``A``'s shape and dtype and is laid out like
    ``A.contiguous()``. An empty input returns an empty tensor without
    launching the kernel.
    """
    logger.debug("GEMS_CAMBRICON CEIL")
    src = A.contiguous()
    result = torch.empty_like(src)
    numel = src.numel()

    if numel > 0:

        def grid(meta):
            # One program per BLOCK_SIZE chunk, capped at the core count;
            # the kernel loops to cover any chunks beyond the grid.
            chunks = triton.cdiv(numel, meta["BLOCK_SIZE"])
            return (min(chunks, TOTAL_CORE_NUM),)

        with torch_device_fn.device(src.device):
            ceil_kernel[grid](src, result, numel)

    return result

56 

57 

def ceil_out(A, *, out=None):
    """Compute the elementwise ceiling of ``A`` into ``out`` and return it.

    Args:
        A: input tensor; a contiguous copy is used if it is not contiguous.
        out: optional destination. When omitted, a tensor like
            ``A.contiguous()`` is allocated. When its shape differs from
            ``A``'s, it is resized in place to ``A.shape``. When it is not
            contiguous, results are computed in a contiguous staging tensor
            and copied into it.

    Returns:
        ``out`` (the newly allocated tensor when none was given).
    """
    logger.debug("GEMS_CAMBRICON CEIL_OUT")
    A = A.contiguous()
    if out is None:
        out = torch.empty_like(A)
    elif out.shape != A.shape:
        # The kernel writes A.numel() elements starting at out's data
        # pointer; an undersized destination would be written past its end.
        out.resize_(A.shape)
    N = A.numel()
    if N == 0:
        return out
    # The kernel addresses OUT_ptr as a flat, densely packed buffer, so a
    # strided destination gets a contiguous staging tensor copied back below.
    dst = out if out.is_contiguous() else torch.empty_like(A)
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(A.device):
        ceil_kernel[grid_fn](A, dst, N)
    if dst is not out:
        out.copy_(dst)
    return out

70 

71 

def ceil_(A):
    """Replace every element of ``A`` with its ceiling, in place, and return ``A``.

    A non-contiguous ``A`` is processed through a contiguous temporary whose
    contents are then copied back into ``A``. Empty tensors are returned
    untouched.
    """
    logger.debug("GEMS_CAMBRICON CEIL_")
    buf = A.contiguous()
    count = buf.numel()
    if count == 0:
        return A

    def _grid(meta):
        # One program per BLOCK_SIZE chunk, capped at the core count.
        blocks = triton.cdiv(count, meta["BLOCK_SIZE"])
        return (min(blocks, TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        # Input and output are the same buffer: each element is read and then
        # overwritten at the same offset.
        ceil_kernel[_grid](buf, buf, count)

    needs_writeback = not A.is_contiguous()
    if needs_writeback:
        A.copy_(buf)
    return A