Coverage for src/flag_gems/ops/new_full.py: 33%

15 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4 

5from flag_gems.ops.full import check_dtype, full_func, full_func_scalar 

6 

logger = logging.getLogger(__name__)


def new_full(
    self,
    size,
    fill_value,
    *,
    dtype=None,
    layout=None,
    device=None,
    requires_grad=False,
    pin_memory=False,
):
    """Allocate a tensor of shape ``size`` and fill it with ``fill_value``.

    ``device`` and ``dtype`` fall back to ``self.device`` / ``self.dtype``
    when they are not given. The fill value is first passed through
    ``check_dtype`` against the target dtype/device (its exact conversion
    rules live in ``flag_gems.ops.full``). The filling is then delegated:
    ``full_func`` is used when the checked value is a ``torch.Tensor``,
    ``full_func_scalar`` otherwise.

    Args:
        self: Source tensor supplying the default device and dtype.
        size: Shape of the tensor to create, forwarded to ``torch.empty``.
        fill_value: Value to fill with (tensor or scalar).
        dtype: Target dtype; ``None`` means ``self.dtype``.
        layout: Accepted for signature compatibility; not used here.
        device: Target device; ``None`` means ``self.device``.
        requires_grad: Accepted for signature compatibility; not used here.
        pin_memory: Accepted for signature compatibility; not used here.

    Returns:
        Whatever ``full_func`` / ``full_func_scalar`` returns for the freshly
        allocated buffer (presumably the filled buffer itself — confirm in
        ``flag_gems.ops.full``).
    """
    logger.debug("GEMS NEW_FULL")

    # Only consult ``self`` for attributes the caller left unspecified.
    target_device = self.device if device is None else device
    target_dtype = self.dtype if dtype is None else dtype

    # Normalize the fill value before allocating, matching the original order.
    checked_fill = check_dtype(fill_value, target_dtype, target_device)
    buffer = torch.empty(size, device=target_device, dtype=target_dtype)

    # Tensor-valued fills and plain scalars go through different kernels.
    filler = (
        full_func if isinstance(checked_fill, torch.Tensor) else full_func_scalar
    )
    return filler(buffer, checked_fill)