Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/sqrt.py: 0%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import libentry, libtuner 

9 

10logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.sqrt") 

11 

12 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("sqrt"),
    key=["n_elements"],
    strategy=["align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def sqrt_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise square root over a flat run of ``n_elements`` values.

    Each program instance handles one ``BLOCK_SIZE``-wide slice: it loads the
    slice from ``input_ptr``, computes the square root in float32 and stores
    the result, cast to the element type of ``output_ptr``, at the same
    offsets in ``output_ptr``.

    Args:
        input_ptr: Pointer to the first input element.
        output_ptr: Pointer to the first output element.
        n_elements: Number of elements to process; also the tuning key.
        BLOCK_SIZE: Compile-time number of elements per program instance.
    """
    # program_id is a 32-bit value; widen it before scaling so that
    # pid * BLOCK_SIZE cannot wrap around for tensors with >= 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # The last block may extend past the end of the data; mask those lanes.
    mask = offsets < n_elements
    x = tl.load(input_ptr + offsets, mask=mask)
    # The square root is always evaluated in float32, whatever the input type.
    x_fp32 = x.to(tl.float32)
    output = tl.sqrt(x_fp32)
    output = output.to(output_ptr.dtype.element_ty)
    tl.store(output_ptr + offsets, output, mask=mask)

37 

38 

def sqrt(A):
    """Return a new tensor holding the element-wise square root of ``A``.

    The result is allocated with ``torch.empty_like``, so it has the dtype
    and device of ``A``.

    Args:
        A: Input tensor.

    Returns:
        A newly allocated tensor with the same shape as ``A``.
    """
    logger.debug("GEMS_HOPPER SQRT")
    output = torch.empty_like(A)
    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; skip the kernel launch entirely.
        return output
    # sqrt_kernel reads and writes n_elements consecutive elements starting
    # at each base pointer, which pairs elements correctly only when input
    # and output share one dense layout. A stride mismatch with the fresh
    # allocation means A is a non-dense view (e.g. x[::2]); compact it first.
    if output.stride() != A.stride():
        A = A.contiguous()
        output = torch.empty_like(A)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    sqrt_kernel[grid](A, output, n_elements)
    return output

46 

47 

def sqrt_(A):
    """Replace every element of ``A`` with its square root, in place.

    The result is computed into a temporary tensor and then copied back
    into ``A`` with ``Tensor.copy_``.

    Args:
        A: Tensor to update in place.

    Returns:
        The same tensor object ``A``, now holding the square roots.
    """
    logger.debug("GEMS_HOPPER SQRT_")
    n_elements = A.numel()
    if n_elements == 0:
        # Nothing to compute; skip the allocation and the kernel launch.
        return A
    source = A
    output = torch.empty_like(A)
    # sqrt_kernel reads and writes n_elements consecutive elements starting
    # at each base pointer, which pairs elements correctly only when input
    # and output share one dense layout. A stride mismatch with the fresh
    # allocation means A is a non-dense view (e.g. x[::2]); read from a
    # compact copy instead. The copy_ below maps results back onto A's strides.
    if output.stride() != A.stride():
        source = A.contiguous()
        output = torch.empty_like(source)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    sqrt_kernel[grid](source, output, n_elements)
    A.copy_(output)
    return A