Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/zeros.py: 0%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device, torch_device_fn
8from flag_gems.utils import libentry, libtuner
9from flag_gems.utils.shape_utils import volume
# Multiprocessor count reported by the runtime device; the launchers below use
# it as the upper bound on the number of programs in the launch grid.
_device_props = torch_device_fn.get_device_properties()
TOTAL_CORE_NUM = _device_props.multi_processor_count
del _device_props

# Child of the package-wide "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Module-level alias of the imported runtime ``device``: ``zeros`` takes a
# parameter called ``device`` that shadows it, so it reads the default via
# this alias instead.
device_ = device


# Triton kernel that writes zeros into the first ``n_elements`` slots starting
# at ``output_ptr``.
#
# Each program starts at ``program_id * BLOCK_SIZE`` and advances by
# ``num_programs * BLOCK_SIZE`` per iteration. A launch with fewer programs than
# blocks therefore still covers the whole range. Lanes at or beyond
# ``n_elements`` are masked off.
#
# Candidate block sizes are autotuned per distinct ``n_elements`` value.
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block}, num_stages=1, num_warps=1)
        for block in (1024, 4096, 16384, 65536)
    ],
    strategy=["align32"],
    key=["n_elements"],
    warmup=1,
    rep=2,
)
@triton.jit
def zeros_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # 1D launch grid, so axis 0 identifies the program. The starting offset is
    # widened to int64 so element offsets beyond 2**31 do not wrap.
    start = (tl.program_id(axis=0) * BLOCK_SIZE).to(tl.int64)
    stride = tl.num_programs(axis=0) * BLOCK_SIZE
    for base in range(start, n_elements, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        # NOTE(review): the float literal 0.0 is stored for every output dtype;
        # this relies on tl.store converting it to the pointee element type.
        tl.store(output_ptr + idx, 0.0, mask=idx < n_elements)


def zeros(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` filled with zeros.

    Args:
        size: Output shape, forwarded unchanged to ``torch.empty``.
        dtype: Element type; defaults to ``torch.get_default_dtype()``.
        layout: Accepted for signature compatibility; currently ignored.
        device: Target device; defaults to ``torch.device(device_.name)``.
        pin_memory: Accepted for signature compatibility; currently ignored.

    Returns:
        The freshly allocated, zero-filled tensor.
    """
    logger.debug("GEMS_TSINGMICRO ZEROS")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)

    # torch.empty returns a contiguous tensor, so the kernel's flat [0, N)
    # addressing covers exactly its elements.
    out = torch.empty(size, device=device, dtype=dtype)
    # Take the element count from the allocated tensor rather than from
    # ``size``. This accepts every shape form torch.empty accepts (including a
    # bare int) and matches zero_, which also uses numel().
    N = out.numel()
    if N == 0:
        # Nothing to write: skip the kernel launch entirely.
        return out

    # Never launch more programs than there are cores; the kernel loops over
    # any remaining blocks itself.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(device):
        zeros_kernel[grid_fn](out, N)
    return out


def zero_(x: torch.Tensor) -> torch.Tensor:
    """Fill ``x`` with zeros in place and return it.

    The kernel addresses memory as a flat range ``[0, numel)`` from the data
    pointer it is given. That range matches the tensor's elements only when the
    tensor is contiguous. For other layouts (slices, transposes, other strided
    views) the zeros are written into a contiguous scratch buffer and copied
    into ``x``. This avoids overwriting unrelated storage.

    Args:
        x: Tensor to zero in place.

    Returns:
        ``x`` itself.
    """
    logger.debug("GEMS_TSINGMICRO ZERO_")
    N = x.numel()
    if N == 0:
        # Nothing to write: skip the kernel launch entirely.
        return x

    if x.is_contiguous():
        target = x
    else:
        target = torch.empty_like(x, memory_format=torch.contiguous_format)

    # Never launch more programs than there are cores; the kernel loops over
    # any remaining blocks itself.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(x.device):
        zeros_kernel[grid_fn](target, N)

    if target is not x:
        # Scatter the zeros back through x's real strides.
        x.copy_(target)
    return x