Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/zero.py: 0%
32 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as ext
# Module logger, nested under the package-wide "flag_gems" logger.
logger = logging.getLogger(f"flag_gems.{__name__.lstrip('.')}")

# Kunlunxin XPU has 12 compute clusters; the zero-fill launcher spreads each
# fill evenly across them by launching exactly one program per cluster.
CLUSTER_NUM = 12
@libentry()
@triton.jit
def zero_kernel(
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store zeros into out_ptr[0:n_elements], one BLOCK_SIZE tile per program.

    Write-only: nothing is loaded from memory. The scalar 0.0 is cast by
    ``tl.store`` to the element type of ``out_ptr``, so one kernel serves
    every dtype. Lanes at or past ``n_elements`` are masked off.
    """
    tile_start = ext.program_id(axis=0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    tl.store(out_ptr + idx, 0.0, mask=idx < n_elements)
def _launch_zero_kernel(tensor: torch.Tensor) -> torch.Tensor:
    """Set every element of ``tensor`` to zero in place and return it.

    ``zero_kernel`` addresses memory as a flat run of ``numel()`` elements
    starting at the tensor's data pointer, which is only valid for contiguous
    tensors. A non-contiguous tensor is therefore zeroed through a contiguous
    scratch buffer that is copied back with ``Tensor.copy_``, which honours the
    destination's strides.

    Args:
        tensor: Tensor to fill; modified in place.

    Returns:
        ``tensor`` itself (empty tensors are returned untouched).
    """
    n_elements = tensor.numel()
    if n_elements == 0:
        return tensor

    if not tensor.is_contiguous():
        # Writing numel() consecutive elements from data_ptr would clobber
        # memory outside a strided view (e.g. x[:, ::2]) and skip elements
        # that belong to it, so fill a dense buffer and copy it back instead.
        scratch = torch.empty(
            tensor.shape, dtype=tensor.dtype, device=tensor.device
        )
        _launch_zero_kernel(scratch)
        tensor.copy_(scratch)
        return tensor

    # One program per cluster. Each program covers ceil(n / CLUSTER_NUM)
    # elements, rounded up to a power of 2 for aligned vectorised stores, so
    # CLUSTER_NUM * block_size >= n_elements and the grid covers every element.
    block_size = triton.next_power_of_2(triton.cdiv(n_elements, CLUSTER_NUM))
    grid = (CLUSTER_NUM, 1, 1)
    with torch_device_fn.device(tensor.device):
        zero_kernel[grid](
            tensor,
            n_elements,
            BLOCK_SIZE=block_size,
            # Kunlunxin-specific launch options; their exact semantics are
            # defined by the XPU Triton backend (not visible in this file).
            buffer_size_limit=2048,
            isCloseDtypeConvert=True,
        )
    return tensor
def zero(self: torch.Tensor) -> torch.Tensor:
    """Zero-fill ``self`` in place and return it (``aten::zero`` entry point).

    NOTE(review): the schema ``aten::zero(Tensor self) -> Tensor`` does not
    mark ``self`` as mutated (no ``(a!)``), yet this fills ``self`` in place.
    Confirm this is only registered for the in-place ``zero_`` overload.
    """
    logger.debug("GEMS_KUNLUNXIN ZERO")
    filled = _launch_zero_kernel(self)
    return filled
def zero_out(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    """aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) — writes zeros to out.

    Following the ``out=`` convention, ``out`` is first resized to
    ``self``'s shape when the two differ; then every element of ``out`` is
    set to zero. ``self`` contributes only its shape; its values are neither
    read nor modified. ``out`` keeps its own dtype and device.

    Args:
        self: Tensor whose shape the result takes.
        out: Destination tensor, resized if needed and filled in place.

    Returns:
        ``out``.
    """
    logger.debug("GEMS_KUNLUNXIN ZERO_OUT")
    if out.shape != self.shape:
        # The result of zero(self) has self's shape; previously `self` was
        # ignored entirely and `out` kept whatever shape it came in with.
        out.resize_(self.shape)
    return _launch_zero_kernel(out)