Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/zeros.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device, torch_device_fn 

8from flag_gems.utils import libentry, libtuner 

9from flag_gems.utils.shape_utils import volume 

10 

# Multiprocessor count reported by the active device's properties; the host
# wrappers below use it as the upper bound on the launch grid size.
TOTAL_CORE_NUM = torch_device_fn.get_device_properties().multi_processor_count

# Child of the shared "flag_gems" logger, named after this module with any
# leading dots removed from the module name.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

# Module-level alias for the runtime device object, so functions that take a
# `device` keyword argument can still reach it.
device_ = device

15 

16 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block_size}, num_stages=1, num_warps=1)
        for block_size in (1024, 4096, 16384, 65536)
    ],
    strategy=["align32"],
    key=["n_elements"],
    warmup=1,
    rep=2,
)
@triton.jit
def zeros_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Write zero to the first ``n_elements`` elements starting at ``output_ptr``.

    Each program handles blocks of ``BLOCK_SIZE`` consecutive elements and then
    advances by ``num_programs * BLOCK_SIZE``, so a launch grid smaller than
    the number of blocks still covers the whole range.
    """
    # 1D launch grid: only axis 0 is used.
    program_index = tl.program_id(axis=0)
    program_count = tl.num_programs(axis=0)

    # Distance between two consecutive blocks handled by the same program.
    stride = program_count * BLOCK_SIZE
    # Widen to int64 before it is used as the loop's starting offset.
    first_offset = (program_index * BLOCK_SIZE).to(tl.int64)

    for chunk_start in range(first_offset, n_elements, stride):
        element_offsets = chunk_start + tl.arange(0, BLOCK_SIZE)
        # Mask off lanes that fall past the end of the buffer in the last block.
        in_bounds = element_offsets < n_elements
        tl.store(output_ptr + element_offsets, 0.0, mask=in_bounds)

45 

46 

def zeros(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a newly allocated tensor of shape ``size`` filled with zeros.

    Args:
        size: Shape of the tensor to create.
        dtype: Element type; defaults to ``torch.get_default_dtype()``.
        layout: Accepted for signature compatibility; not used by this
            implementation.
        device: Target device; defaults to the device named by the runtime's
            ``device_.name``.
        pin_memory: Accepted for signature compatibility; not used by this
            implementation.

    Returns:
        The zero-filled tensor.
    """
    logger.debug("GEMS_TSINGMICRO ZEROS")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to write: avoid tuning and launching a zero-sized grid.
        return out

    def grid_fn(meta):
        # One program per block, capped at the device's core count.
        return (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(device):
        zeros_kernel[grid_fn](out, N)
    return out

60 

61 

def _is_dense_non_overlapping(x: torch.Tensor) -> bool:
    """Return True when ``x``'s elements fill one gap-free, non-overlapping run of memory."""
    expected_stride = 1
    # Walk dimensions from the fastest-varying stride outwards; size-1
    # dimensions are skipped because their stride does not affect addressing.
    dims = sorted(
        (
            (stride, extent)
            for extent, stride in zip(x.shape, x.stride())
            if extent != 1
        ),
    )
    for stride, extent in dims:
        if stride != expected_stride:
            return False
        expected_stride *= extent
    return True


def zero_(x: torch.Tensor) -> torch.Tensor:
    """Fill ``x`` with zeros in place and return it.

    The kernel writes ``x.numel()`` consecutive elements starting at the
    tensor's data pointer, which is only correct when the tensor's elements
    occupy exactly that range. Other views (for example a strided slice) are
    zeroed through a strided copy from a freshly zeroed buffer instead, so no
    memory outside the view is touched.

    Args:
        x: Tensor to overwrite.

    Returns:
        The same tensor ``x``.
    """
    logger.debug("GEMS_TSINGMICRO ZERO_")
    N = x.numel()
    if N == 0:
        # Nothing to write: avoid tuning and launching a zero-sized grid.
        return x

    if not _is_dense_non_overlapping(x):
        # A linear fill would miss parts of the view and overwrite elements
        # that do not belong to it; copy_ honours the view's strides.
        x.copy_(zeros(x.shape, dtype=x.dtype, device=x.device))
        return x

    def grid_fn(meta):
        # One program per block, capped at the device's core count.
        return (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(x.device):
        zeros_kernel[grid_fn](x, N)
    return x