Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/zero.py: 0%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

# Module logger: a child of the "flag_gems" logger, named after this module
# with any leading dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Kunlunxin XPU has 12 compute clusters; distribute work evenly across them.
# This value is used both as the launch-grid width and as the divisor when
# deriving the per-program BLOCK_SIZE.
CLUSTER_NUM = 12

15 

16 

@libentry()
@triton.jit
def zero_kernel(
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store zeros into one BLOCK_SIZE-wide slice of the buffer at ``out_ptr``.

    The kernel is write-only: nothing is loaded from ``out_ptr``. The literal
    ``0.0`` is passed straight to ``tl.store``, which is left to convert it to
    the element type of the pointer.

    Args:
        out_ptr: Base pointer of the buffer to clear.
        n_elements: Number of elements to clear; lanes whose index is at or
            beyond this count are masked off.
        BLOCK_SIZE: Compile-time number of lanes handled by each program.
    """
    program_index = ext.program_id(axis=0)
    # Each program owns the index range starting at program_index * BLOCK_SIZE.
    element_ids = program_index * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the tail so the last program never writes past n_elements.
    in_bounds = element_ids < n_elements
    tl.store(out_ptr + element_ids, 0.0, mask=in_bounds)

30 

31 

def _launch_zero_kernel(tensor: torch.Tensor) -> torch.Tensor:
    """Fill ``tensor`` with zeros in place and return it.

    Contiguous tensors are cleared by ``zero_kernel``, launched on a fixed
    grid of ``CLUSTER_NUM`` programs. Non-contiguous tensors are cleared by
    copying from a contiguous zero tensor, because the kernel addresses
    memory as ``base + [0, numel)`` and would otherwise write to storage
    locations that do not belong to the view.

    Args:
        tensor: Tensor to overwrite with zeros.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    n_elements = tensor.numel()
    if n_elements == 0:
        # Nothing to write; skip the launch entirely.
        return tensor
    if not tensor.is_contiguous():
        # The flat store in zero_kernel is only valid for a contiguous layout.
        # For strided views (e.g. x[::2] or x[:, 0]) it would zero the wrong
        # storage elements and leave part of the view untouched, so let the
        # stride-aware copy_ place the zeros instead. The temporary is
        # contiguous, so producing it never re-enters this branch.
        zeros = torch.zeros_like(tensor, memory_format=torch.contiguous_format)
        tensor.copy_(zeros)
        return tensor
    # BLOCK_SIZE: distribute n_elements evenly across CLUSTER_NUM clusters,
    # rounded up to the next power of 2 for aligned vectorised stores.
    block_size = triton.next_power_of_2(triton.cdiv(n_elements, CLUSTER_NUM))
    grid = (CLUSTER_NUM, 1, 1)
    with torch_device_fn.device(tensor.device):
        zero_kernel[grid](
            tensor,
            n_elements,
            BLOCK_SIZE=block_size,
            buffer_size_limit=2048,
            isCloseDtypeConvert=True,
        )
    return tensor

49 

50 

def zero(self: torch.Tensor) -> torch.Tensor:
    """Zero-fill ``self`` in place and return the result of the launch helper.

    Labelled by the file as the ``aten::zero(Tensor self) -> Tensor`` entry
    point. The work is delegated to ``_launch_zero_kernel``, which writes into
    the tensor it is given rather than allocating a new one.

    Args:
        self: Tensor to overwrite with zeros.

    Returns:
        Whatever ``_launch_zero_kernel(self)`` returns.
    """
    logger.debug("GEMS_KUNLUNXIN ZERO")
    zeroed = _launch_zero_kernel(self)
    return zeroed

55 

56 

def zero_out(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    """Write zeros into ``out`` and return the result of the launch helper.

    Labelled by the file as the
    ``aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)`` entry
    point.

    Args:
        self: Accepted for signature compatibility; never read here.
        out: Keyword-only destination tensor that is overwritten with zeros.

    Returns:
        Whatever ``_launch_zero_kernel(out)`` returns.
    """
    # NOTE(review): `out` is used with its current shape and is not resized
    # to match `self` -- confirm this is the intended out-variant contract.
    logger.debug("GEMS_KUNLUNXIN ZERO_OUT")
    destination = out
    return _launch_zero_kernel(destination)