Coverage for src/flag_gems/ops/zero.py: 61%
49 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
10logger = logging.getLogger(__name__)
@triton.jit
def zero_kernel(
    out_ptr,  # *Pointer* to tensor to be zeroed
    n_elements,  # Number of elements
    BLOCK_SIZE: tl.constexpr,
):
    """Store zeros into the first ``n_elements`` elements addressed by ``out_ptr``.

    Each program instance covers ``BLOCK_SIZE`` consecutive element offsets
    starting at ``program_id * BLOCK_SIZE``; offsets at or beyond
    ``n_elements`` are masked out of the store.
    """
    pid = tl.program_id(axis=0)
    # Widen to 64-bit before multiplying so the offset arithmetic cannot wrap
    # around in 32-bit when the element count is very large.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Take the element type straight from the pointer instead of loading the
    # old contents: this is a write-only operation, so reading the buffer
    # first would only add memory traffic.
    z = tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty)
    tl.store(out_ptr + offsets, z, mask=mask)
29def _launch_zero_kernel(tensor: torch.Tensor):
30 assert isinstance(tensor, torch.Tensor), "Expected a torch.Tensor"
31 if tensor.device.type != flag_gems.device:
32 raise ValueError(f"Tensor must be on {flag_gems.device} device")
33 assert tensor.is_contiguous(), "Tensor must be contiguous"
34 assert tensor.numel() >= 0
35 n_elements = tensor.numel()
36 if n_elements == 0:
37 return tensor
38 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
39 zero_kernel[grid](tensor, n_elements, BLOCK_SIZE=1024)
40 return tensor
def zero(*args, **kwargs):
    """Zero a tensor in place and return it.

    The tensor is looked up in this order: the first positional argument,
    then the keyword arguments ``self``, ``input`` and ``out``. The first
    candidate that is a ``torch.Tensor`` is used.

    Raises:
        ValueError: If none of those locations holds a ``torch.Tensor``.
    """
    logger.debug("GEMS ZERO")

    # Positional form takes priority over every keyword form.
    if args and isinstance(args[0], torch.Tensor):
        return _launch_zero_kernel(args[0])

    # Keyword forms, checked in fixed priority order.
    for key in ("self", "input", "out"):
        candidate = kwargs.get(key)
        if isinstance(candidate, torch.Tensor):
            return _launch_zero_kernel(candidate)

    raise ValueError(
        "zero expects a Tensor as the first argument or in kwargs as 'self', 'input', or 'out'"
    )
def zero_out(*args, **kwargs):
    """Out-variant of zero: zero the destination tensor in place and return it.

    The ``out`` keyword argument is preferred when it is a ``torch.Tensor``;
    otherwise the first positional argument is used if it is a tensor.

    Raises:
        ValueError: If neither location holds a ``torch.Tensor``.
    """
    logger.debug("GEMS ZERO_OUT")

    keyword_out = kwargs.get("out")
    if isinstance(keyword_out, torch.Tensor):
        destination = keyword_out
    elif args and isinstance(args[0], torch.Tensor):
        destination = args[0]
    else:
        raise ValueError(
            "zero_out expects an output Tensor as the first positional argument or 'out' kwarg"
        )
    return _launch_zero_kernel(destination)