Coverage for src/flag_gems/ops/zero.py: 61%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def zero_kernel(
    out_ptr,  # *Pointer* to tensor to be zeroed
    n_elements,  # Number of elements
    BLOCK_SIZE: tl.constexpr,
):
    """Store zeros over the first ``n_elements`` elements addressed by ``out_ptr``.

    Each program instance covers one ``BLOCK_SIZE``-wide run of offsets,
    selected by its index along grid axis 0. Offsets at or beyond
    ``n_elements`` are masked out of the store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Take the element dtype from the pointer type instead of from a
    # throw-away load: zeroing never needs to read the previous contents.
    z = tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty)
    tl.store(out_ptr + offsets, z, mask=mask)

27 

28 

29def _launch_zero_kernel(tensor: torch.Tensor): 

30 assert isinstance(tensor, torch.Tensor), "Expected a torch.Tensor" 

31 if tensor.device.type != flag_gems.device: 

32 raise ValueError(f"Tensor must be on {flag_gems.device} device") 

33 assert tensor.is_contiguous(), "Tensor must be contiguous" 

34 assert tensor.numel() >= 0 

35 n_elements = tensor.numel() 

36 if n_elements == 0: 

37 return tensor 

38 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

39 zero_kernel[grid](tensor, n_elements, BLOCK_SIZE=1024) 

40 return tensor 

41 

42 

def zero(*args, **kwargs):
    """Zero a tensor in place and return it.

    The target is the first of the following that is a ``torch.Tensor``:
    the first positional argument, then the ``self``, ``input`` and ``out``
    keyword arguments, in that order.

    Raises:
        ValueError: If none of those candidates is a ``torch.Tensor``.
    """
    logger.debug("GEMS ZERO")
    # Build the candidates in priority order; a missing keyword yields None,
    # which fails the isinstance test just like an absent key would.
    candidates = list(args[:1])
    candidates.extend(kwargs.get(key) for key in ("self", "input", "out"))
    for candidate in candidates:
        if isinstance(candidate, torch.Tensor):
            return _launch_zero_kernel(candidate)
    raise ValueError(
        "zero expects a Tensor as the first argument or in kwargs as 'self', 'input', or 'out'"
    )

60 

61 

def zero_out(*args, **kwargs):
    """Zero an output tensor in place and return it.

    The ``out`` keyword argument is used when it is a ``torch.Tensor``;
    otherwise the first positional argument is used when it is one.

    Raises:
        ValueError: If neither candidate is a ``torch.Tensor``.
    """
    logger.debug("GEMS ZERO_OUT")
    keyword_out = kwargs.get("out")
    if isinstance(keyword_out, torch.Tensor):
        return _launch_zero_kernel(keyword_out)
    if args and isinstance(args[0], torch.Tensor):
        return _launch_zero_kernel(args[0])
    raise ValueError(
        "zero_out expects an output Tensor as the first positional argument or 'out' kwarg"
    )