Coverage for src/flag_gems/ops/lift_fresh_copy.py: 39%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

# Module-level logger; the op entry points in this file emit a debug record
# through it each time they are called.
logger = logging.getLogger(__name__)

11 

12 

@triton.jit
def _copy_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Copy ``n_elements`` consecutive elements from ``in_ptr`` to ``out_ptr``.

    Each program instance handles one ``BLOCK_SIZE``-wide slice of the flat
    buffer; lanes past ``n_elements`` in the final slice are masked off so no
    out-of-bounds address is loaded or stored. Callers are expected to pass
    contiguous buffers of identical dtype.
    """
    # Promote the program id to int64 before scaling: in int32 arithmetic
    # ``pid * BLOCK_SIZE`` wraps around once a tensor exceeds 2**31 - 1
    # elements, which would make the offsets (and the mask) wrong.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)

21 

22 

def lift_fresh_copy(*args, **kwargs):
    """Return a fresh, contiguous copy of the input tensor.

    The input tensor is taken from the first positional argument if it is a
    Tensor, else from the ``self`` keyword argument, else from the first
    Tensor found among all positional and keyword values (in that order).

    The result never aliases the input and always has contiguous layout and
    the input's shape and dtype.

    Raises:
        ValueError: if no Tensor argument is found, or the tensor's device
            type is not ``flag_gems.device``.
    """
    logger.debug("GEMS LIFT_FRESH_COPY")
    # Attempt to find the input tensor from args/kwargs
    x = None
    if len(args) > 0 and isinstance(args[0], torch.Tensor):
        x = args[0]
    elif "self" in kwargs and isinstance(kwargs["self"], torch.Tensor):
        x = kwargs["self"]
    else:
        for v in list(args) + list(kwargs.values()):
            if isinstance(v, torch.Tensor):
                x = v
                break
    if x is None:
        raise ValueError("lift_fresh_copy expects a Tensor argument")

    if x.device.type != flag_gems.device:
        raise ValueError(
            f"lift_fresh_copy Triton kernel requires a {flag_gems.device} tensor"
        )

    x_contig = x.contiguous()
    if x_contig is not x:
        # ``contiguous()`` returns ``self`` for an already-contiguous tensor;
        # otherwise it has just materialised a new, unaliased contiguous
        # buffer, which is already the requested fresh copy. Copying it a
        # second time would only double the memory traffic.
        return x_contig

    out = torch.empty_like(x_contig, memory_format=torch.contiguous_format)

    n_elements = x_contig.numel()
    if n_elements == 0:
        # Nothing to copy; skip the kernel launch entirely.
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _copy_kernel[grid](x_contig, out, n_elements, BLOCK_SIZE=1024)

    # ``out`` was allocated with ``x_contig``'s shape, so no reshaping view
    # is needed before returning it.
    return out

52 

53 

def lift_fresh_copy_out(x: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """Copy ``x`` into ``out`` and return ``out``.

    If ``out`` is None, a new contiguous tensor shaped like ``x`` is
    allocated and returned. Otherwise ``out`` is resized to ``x``'s shape
    when the shapes differ, filled in place, and returned itself, so the
    caller's tensor always receives the data (even when it is
    non-contiguous).

    Args:
        x: source tensor; must be on ``flag_gems.device``.
        out: optional destination; must be on ``flag_gems.device`` and have
            the same dtype as ``x``.

    Returns:
        The destination tensor, with ``x``'s shape and values.

    Raises:
        ValueError: if ``x`` is not a Tensor, either tensor is on the wrong
            device type, or the dtypes differ.
    """
    logger.debug("GEMS LIFT_FRESH_COPY_OUT")
    # ``isinstance`` is False for None, so this also rejects a missing input.
    if not isinstance(x, torch.Tensor):
        raise ValueError("lift_fresh_copy_out expects 'x' to be a Tensor")
    if x.device.type != flag_gems.device:
        raise ValueError(
            f"lift_fresh_copy_out Triton kernel requires {flag_gems.device} tensors"
        )

    x_contig = x.contiguous()

    if out is None:
        out = torch.empty_like(x_contig, memory_format=torch.contiguous_format)
    else:
        if out.device.type != flag_gems.device:
            raise ValueError(f"Output tensor 'out' must be on {flag_gems.device}")
        if out.dtype != x_contig.dtype:
            raise ValueError("Output tensor 'out' must have the same dtype as input")
        # Follow out= semantics: the destination takes the result's shape.
        # A shape change makes resize_ reinterpret the storage C-contiguously.
        if out.shape != x_contig.shape:
            out.resize_(x_contig.shape)

    n_elements = x_contig.numel()
    if n_elements == 0:
        # Nothing to copy; skip the kernel launch entirely.
        return out

    # The kernel writes a flat, contiguous buffer. When ``out`` is
    # non-contiguous (resize_ leaves strides alone if the shape already
    # matched), write into a contiguous scratch tensor and copy it back, so the
    # caller's ``out`` is actually updated rather than a temporary.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty_like(x_contig, memory_format=torch.contiguous_format)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _copy_kernel[grid](x_contig, dst, n_elements, BLOCK_SIZE=1024)

    if dst is not out:
        out.copy_(dst)
    return out