Coverage for src/flag_gems/runtime/backend/_cambricon/ops/ceil.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, libtuner
10from ..utils import TOTAL_CORE_NUM
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
)
@triton.jit
def ceil_kernel(
    X_ptr,
    OUT_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Write the element-wise ceiling of ``X_ptr[0:n_elements]`` to ``OUT_ptr``.

    Both pointers are addressed with the same flat element offsets, so the
    caller must pass buffers for which offsets ``0 .. n_elements - 1`` are
    valid. Each program starts at block ``program_id`` and then advances by
    ``num_programs`` blocks until ``n_elements`` is reached, so the launch
    grid may contain fewer programs than there are blocks.

    Args:
        X_ptr: Pointer to the input elements.
        OUT_ptr: Pointer to the output elements.
        n_elements: Number of elements to process.
        BLOCK_SIZE: Compile-time number of elements handled per iteration;
            chosen by the tuner from the configs listed in the decorator.
    """
    program_idx = tl.program_id(0)
    program_count = tl.num_programs(0)
    # Distance, in elements, between two consecutive blocks of one program.
    stride = program_count * BLOCK_SIZE
    # The loop start is widened to int64 so the per-iteration offsets are
    # 64-bit (presumably to avoid 32-bit overflow on very large tensors).
    first = (program_idx * BLOCK_SIZE).to(tl.int64)
    for base in range(first, n_elements, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        # Lanes past the end of the data are masked out of load and store.
        in_bounds = idx < n_elements
        values = tl.load(X_ptr + idx, mask=in_bounds)
        # ceil is evaluated in float32 and cast back to the loaded dtype.
        rounded = tl.ceil(values.to(tl.float32)).to(values.dtype)
        tl.store(OUT_ptr + idx, rounded, mask=in_bounds)
def ceil(A):
    """Return a new tensor holding the element-wise ceiling of ``A``.

    Args:
        A: Input tensor. It is made contiguous before the kernel runs.

    Returns:
        A tensor allocated with ``torch.empty_like`` on the contiguous input.
        For an empty input the freshly allocated tensor is returned without
        launching the kernel.
    """
    logger.debug("GEMS_CAMBRICON CEIL")
    src = A.contiguous()
    result = torch.empty_like(src)
    total = src.numel()
    if total != 0:

        def launch_grid(meta):
            # One program per block, capped at TOTAL_CORE_NUM programs.
            blocks = triton.cdiv(total, meta["BLOCK_SIZE"])
            return (min(blocks, TOTAL_CORE_NUM),)

        with torch_device_fn.device(src.device):
            ceil_kernel[launch_grid](src, result, total)
    return result
def ceil_out(A, *, out=None):
    """Compute the element-wise ceiling of ``A`` and store it in ``out``.

    Args:
        A: Input tensor. It is made contiguous before the kernel runs.
        out: Optional destination tensor. When omitted, a new tensor is
            allocated with ``torch.empty_like`` on the contiguous input.

    Returns:
        ``out`` (or the newly allocated tensor). For an empty input ``out`` is
        returned unchanged and no kernel is launched.

    Raises:
        RuntimeError: Propagated from ``Tensor.copy_`` when ``out`` cannot
            receive the result (e.g. its shape is incompatible with ``A``).
    """
    logger.debug("GEMS_CAMBRICON CEIL_OUT")
    A = A.contiguous()
    N = A.numel()
    if out is None:
        out = torch.empty_like(A)
    if N == 0:
        return out
    # ceil_kernel addresses its output with flat offsets 0..N-1. Writing
    # straight into `out` is therefore only valid when `out` is a dense buffer
    # holding exactly N elements; a strided view or a differently sized tensor
    # would receive values at the wrong positions or be written out of range.
    writes_directly = out.is_contiguous() and out.numel() == N
    # Otherwise compute into a dense staging tensor and let copy_ handle the
    # strides (and any dtype conversion) of the real destination.
    target = out if writes_directly else torch.empty_like(A)
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(A.device):
        ceil_kernel[grid_fn](A, target, N)
    if not writes_directly:
        out.copy_(target)
    return out
def ceil_(A):
    """Replace every element of ``A`` with its ceiling, in place.

    Args:
        A: Tensor to modify.

    Returns:
        ``A`` itself. An empty tensor is returned untouched.
    """
    logger.debug("GEMS_CAMBRICON CEIL_")
    dense = A.contiguous()
    count = dense.numel()
    if count == 0:
        return A

    def launch_grid(meta):
        # One program per block, capped at TOTAL_CORE_NUM programs.
        blocks = triton.cdiv(count, meta["BLOCK_SIZE"])
        return (min(blocks, TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        # Input and output are the same dense buffer.
        ceil_kernel[launch_grid](dense, dense, count)
    if A.is_contiguous():
        return A
    # The kernel ran on the contiguous tensor obtained above, so the values
    # are copied back into the original strided tensor.
    A.copy_(dense)
    return A