Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/sqrt.py: 0%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import libentry, libtuner
10logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.sqrt")
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("sqrt"),
    key=["n_elements"],
    strategy=["align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def sqrt_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Element-wise square root over a flat range of `n_elements` values.
    #
    # Each program handles one BLOCK_SIZE-wide slice of the range, addressed
    # as plain linear offsets from `input_ptr` / `output_ptr`. Both buffers
    # are therefore indexed with the same offsets.
    lane = tl.arange(0, BLOCK_SIZE)
    idx = tl.program_id(axis=0) * BLOCK_SIZE + lane

    # The last program may run past the end of the range; mask those lanes.
    in_bounds = idx < n_elements

    values = tl.load(input_ptr + idx, mask=in_bounds)

    # The square root is evaluated in float32 and then cast to the element
    # type of the output pointer.
    # NOTE(review): for float64 inputs this round-trip through float32
    # presumably reduces precision -- confirm whether that is intended.
    root = tl.sqrt(values.to(tl.float32)).to(output_ptr.dtype.element_ty)

    tl.store(output_ptr + idx, root, mask=in_bounds)
def sqrt(A):
    """Return a new tensor holding the element-wise square root of ``A``.

    The result is allocated with ``torch.empty_like``, so it has the shape
    and dtype of ``A``. ``A`` itself is not modified.

    ``sqrt_kernel`` walks both buffers with the same linear offsets, which
    is only valid when the output shares the input's memory layout. When
    ``torch.empty_like`` cannot reproduce the strides of ``A`` (for example
    a strided slice or an expanded view), the input is first materialised
    with ``contiguous()`` so that the linear walk visits the right elements.

    Args:
        A: Input tensor.

    Returns:
        A newly allocated tensor with the square root of every element.
    """
    logger.debug("GEMS_HOPPER SQRT")
    output = torch.empty_like(A)
    if output.stride() != A.stride():
        # Layouts differ, so identical linear offsets would address
        # different logical elements in the two buffers. Fall back to a
        # contiguous copy of the input and a matching output.
        A = A.contiguous()
        output = torch.empty_like(A)

    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching the kernel on an empty grid.
        return output

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    sqrt_kernel[grid](A, output, n_elements)
    return output
def sqrt_(A):
    """Replace every element of ``A`` with its square root and return ``A``.

    The kernel writes into a scratch tensor which is then copied back into
    ``A`` with ``copy_``; the kernel never writes to ``A`` directly.
    NOTE(review): the scratch buffer is presumably deliberate -- the kernel
    is wrapped in a tuner configured with ``warmup``/``rep``, and if the
    tuner re-runs the kernel an in-place launch would apply the square root
    more than once. Confirm before changing this to an in-place launch.

    ``sqrt_kernel`` walks both buffers with the same linear offsets, which
    is only valid when the scratch tensor shares the input's memory layout.
    When ``torch.empty_like`` cannot reproduce the strides of ``A`` (for
    example a strided slice), the kernel reads from a contiguous copy of
    ``A`` instead; ``copy_`` then writes the results back element by element.

    Args:
        A: Tensor to update in place.

    Returns:
        ``A`` itself, after the update.
    """
    logger.debug("GEMS_HOPPER SQRT_")
    n_elements = A.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching the kernel on an empty grid.
        return A

    source = A
    output = torch.empty_like(source)
    if output.stride() != source.stride():
        # Layouts differ, so identical linear offsets would address
        # different logical elements. Read from a contiguous copy instead.
        source = A.contiguous()
        output = torch.empty_like(source)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    sqrt_kernel[grid](source, output, n_elements)
    A.copy_(output)
    return A