Coverage for src/flag_gems/runtime/backend/_sunrise/ops/fill.py: 0%
64 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import pointwise_dynamic
9from flag_gems.utils.pointwise_dynamic import CodeGenConfig
# Module-level logger used by the fill wrappers for their debug records.
logger = logging.getLogger(__name__)

# The same cap is supplied for each of the three grid axes.
MAX_GRID_SIZES = (65535,) * 3

# Code-generation settings shared by both pointwise fill kernels in this file.
config = CodeGenConfig(
    max_tile_size=1024,
    max_grid_size=MAX_GRID_SIZES,
    max_num_warps_per_cta=32,
    prefer_block_pointer=True,
    prefer_1d_tile=True,
)
@pointwise_dynamic(
    num_outputs=1,
    is_tensor=[True, False],
    promotion_methods=[(0, "DEFAULT")],
    config=config,
)
@triton.jit
def fill_scalar_func(inp, value_scalar):
    """Build a block with the shape and dtype of ``inp`` holding ``value_scalar`` everywhere.

    ``inp`` is declared as a tensor argument and ``value_scalar`` as a
    non-tensor argument in the ``pointwise_dynamic`` decorator above.
    """
    filled = tl.full(inp.shape, value_scalar, dtype=inp.dtype)
    return filled
@pointwise_dynamic(
    num_outputs=1,
    is_tensor=[True, True],
    promotion_methods=[(0, "DEFAULT")],
    config=config,
)
@triton.jit
def fill_tensor_func(inp, value):
    """Return ``value`` unchanged; ``inp`` is not read by the body.

    Both arguments are declared as tensors in the ``pointwise_dynamic``
    decorator above, and promotion is keyed on argument 0 (``inp``).
    """
    return value
def fill_scalar(input, value):
    """Fill a freshly allocated tensor shaped like ``input`` with a scalar.

    Args:
        input: Tensor whose layout (via ``torch.empty_like``) and device are used.
        value: Scalar forwarded to ``fill_scalar_func``.

    Returns:
        Whatever ``fill_scalar_func`` returns when given the new buffer as
        ``out0`` (presumably that buffer -- confirm against the wrapper).
    """
    logger.debug("GEMS FILL (Dynamic)")
    result = torch.empty_like(input)
    with torch_device_fn.device(input.device):
        filled = fill_scalar_func(input, value, out0=result)
    return filled
def fill_scalar_out(input, value, *, out=None):
    """Scalar fill with an optional caller-provided destination.

    Args:
        input: Tensor passed to the kernel wrapper and used to pick the device.
        value: Scalar forwarded to the kernel wrapper.
        out: Destination tensor. When ``None`` the call is delegated to
            ``fill_scalar``, which allocates its own result.

    Returns:
        ``out`` when it was supplied, otherwise the result of ``fill_scalar``.
    """
    logger.debug("GEMS FILL_SCALAR_OUT")
    if out is not None:
        with torch_device_fn.device(input.device):
            fill_scalar_func(input, value, out0=out)
        return out
    return fill_scalar(input, value)
def fill_tensor(input, value):
    """Fill a freshly allocated tensor shaped like ``input`` from a 0-dim tensor.

    Args:
        input: Tensor whose layout (via ``torch.empty_like``) and device are used.
        value: Tensor holding the fill value; must have ``ndim == 0``.

    Returns:
        The result of ``fill_scalar`` when ``value`` is not a CUDA tensor,
        otherwise whatever ``fill_tensor_func`` returns when given the new
        buffer as ``out0``.

    Raises:
        RuntimeError: If ``value`` has one or more dimensions.
    """
    # Validate before the is_cuda shortcut: otherwise a non-CUDA value with
    # ndim >= 1 would skip this check and go straight to ``.item()``.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    # NOTE(review): is_cuda is the only test used to pick the scalar path;
    # confirm this is the intended check for this backend's device tensors.
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL (Dynamic)")
    out = torch.empty_like(input)
    with torch_device_fn.device(input.device):
        return fill_tensor_func(input, value, out0=out)
def fill_tensor_out(input, value, *, out=None):
    """Tensor-valued fill with an optional caller-provided destination.

    Args:
        input: Tensor passed to the kernel wrapper and used to pick the device.
        value: Tensor holding the fill value; must have ``ndim == 0``.
        out: Destination tensor. When ``None`` the call is delegated to
            ``fill_tensor``, which allocates its own result.

    Returns:
        ``out`` when it was supplied on the CUDA-value path, otherwise the
        result of ``fill_tensor`` / ``fill_scalar_out``.

    Raises:
        RuntimeError: If ``value`` has one or more dimensions.
    """
    # Validate up front so every path (out=None, non-CUDA value, CUDA value)
    # rejects a non-0-dim value with the same error instead of reaching
    # ``.item()`` first.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    logger.debug("GEMS FILL_TENSOR_OUT")
    if out is None:
        return fill_tensor(input, value)
    if not value.is_cuda:
        return fill_scalar_out(input, value.item(), out=out)
    with torch_device_fn.device(input.device):
        fill_tensor_func(input, value, out0=out)
    return out
def fill_tensor_(self, value):
    """Fill ``self`` in place from a 0-dim tensor and return ``self``.

    Args:
        self: Tensor to overwrite; also used to pick the device.
        value: Tensor holding the fill value; must have ``ndim == 0``.

    Returns:
        ``self`` on the CUDA-value path, otherwise the result of
        ``fill_scalar_``.

    Raises:
        RuntimeError: If ``value`` has one or more dimensions.
    """
    # Validate before the is_cuda shortcut: otherwise a non-CUDA value with
    # ndim >= 1 would skip this check and go straight to ``.item()``.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_")
    with torch_device_fn.device(self.device):
        # The same tensor is passed as both the input and the out0 destination.
        fill_tensor_func(self, value, out0=self)
    return self
def fill_scalar_(self, value):
    """Fill ``self`` in place with a scalar and return ``self``.

    Args:
        self: Tensor to overwrite; also used to pick the device.
        value: Scalar forwarded to ``fill_scalar_func``.

    Returns:
        ``self``.
    """
    logger.debug("GEMS FILL_SCALAR_")
    target_device = self.device
    with torch_device_fn.device(target_device):
        # The same tensor is passed as both the input and the out0 destination.
        fill_scalar_func(self, value, out0=self)
    return self