Coverage for src/flag_gems/runtime/backend/_cambricon/ops/sqrt.py: 0%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from triton.language.extra.mlu.libdevice import sqrt as _sqrt
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
11from ..utils import TOTAL_CORE_NUM
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block}, num_stages=1, num_warps=1)
        for block in (4096, 16384, 65536, 131072)
    ],
    key=["n_elements"],
)
@triton.jit
def sqrt_kernel(X_ptr, OUT_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Write the element-wise square root of a flat buffer to ``OUT_ptr``.

    Each program handles chunks of ``BLOCK_SIZE`` elements, beginning at
    ``program_id * BLOCK_SIZE`` and advancing by
    ``num_programs * BLOCK_SIZE`` until ``n_elements`` is reached.

    Args:
        X_ptr: Pointer to the input elements.
        OUT_ptr: Pointer to the output elements.
        n_elements: Number of elements to process.
        BLOCK_SIZE: Compile-time chunk size, chosen by the tuner.
    """
    program = tl.program_id(0)
    program_count = tl.num_programs(0)
    # Distance between two consecutive chunks handled by the same program.
    stride = program_count * BLOCK_SIZE
    # The loop start is widened to int64 before it is used for offsets.
    first = (program * BLOCK_SIZE).to(tl.int64)
    for base in range(first, n_elements, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        # Mask off lanes that fall past the end of the buffer.
        in_bounds = idx < n_elements
        values = tl.load(X_ptr + idx, mask=in_bounds)
        # The square root is evaluated in float32, then cast back to the
        # dtype of the loaded values before being stored.
        root = _sqrt(values.to(tl.float32))
        tl.store(OUT_ptr + idx, root.to(values.dtype), mask=in_bounds)
def sqrt(A):
    """Return the element-wise square root of ``A`` as a new tensor.

    The input is made contiguous and passed to ``sqrt_kernel`` as a flat
    buffer of ``A.numel()`` elements. The launch grid is capped at
    ``TOTAL_CORE_NUM`` programs.

    Integer and boolean inputs are converted to the default floating dtype
    before the launch. Without that step the output would be allocated with
    the integer dtype and the kernel would cast every result back to it,
    truncating the roots; ``torch.sqrt`` returns a floating tensor for such
    inputs.

    Args:
        A: Input tensor.

    Returns:
        A new contiguous tensor with the shape of ``A``. Its dtype is that
        of ``A`` for floating inputs and the default floating dtype for
        integer/bool inputs. An empty input yields an empty output without
        launching the kernel.
    """
    logger.debug("GEMS_CAMBRICON SQRT")
    # Promote int/bool inputs so the result is not truncated to an integer.
    # Complex inputs are passed through unchanged, as before.
    if not (A.is_floating_point() or A.is_complex()):
        A = A.to(torch.get_default_dtype())
    A = A.contiguous()
    out = torch.empty_like(A)
    N = A.numel()
    if N == 0:
        return out

    def grid_fn(meta):
        # One program per BLOCK_SIZE chunk, but never more than the core count.
        return (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        sqrt_kernel[grid_fn](A, out, N)
    return out
def sqrt_(A):
    """Replace every element of ``A`` with its square root and return ``A``.

    The kernel reads from and writes to a contiguous version of ``A``. If
    ``A`` itself is not contiguous, the computed values are copied back
    into ``A`` afterwards.

    Args:
        A: Tensor to update in place.

    Returns:
        The same tensor object ``A``. An empty tensor is returned untouched
        without launching the kernel.
    """
    logger.debug("GEMS_CAMBRICON SQRT_")
    work = A.contiguous()
    total = work.numel()
    if total == 0:
        return A

    def launch_grid(meta):
        # One program per BLOCK_SIZE chunk, capped at the core count.
        return (min(triton.cdiv(total, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        # Same buffer is used as input and output.
        sqrt_kernel[launch_grid](work, work, total)
    if not A.is_contiguous():
        A.copy_(work)
    return A