Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/fill.py: 0%
107 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
9logger = logging.getLogger(__name__)
@triton.jit
def fill_scalar_kernel(
    out_ptr,
    value_scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Write ``value_scalar`` to the first ``n_elements`` flat slots of ``out_ptr``.

    Each program instance covers one ``BLOCK_SIZE``-wide run of flat offsets;
    the mask stops the final block from writing past ``n_elements``.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Read the element type off the pointer instead of loading from the
    # destination: the callers in this file hand over freshly allocated
    # buffers, so a load would touch uninitialised memory just to learn a
    # dtype that is already known at compile time.
    fill_val = tl.full([BLOCK_SIZE], value_scalar, dtype=out_ptr.dtype.element_ty)
    tl.store(out_ptr + offsets, fill_val, mask=mask)
@triton.jit
def fill_tensor_kernel(
    out_ptr,
    value_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Broadcast the single element stored at ``value_ptr`` over ``out_ptr``.

    Only the first ``n_elements`` flat offsets of ``out_ptr`` are written.
    """
    # Every program instance reads the one fill value from ``value_ptr``.
    fill_value = tl.load(value_ptr)

    block_id = tl.program_id(axis=0)
    first_offset = block_id * BLOCK_SIZE
    lane_offsets = first_offset + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_offsets < n_elements
    tl.store(out_ptr + lane_offsets, fill_value, mask=in_bounds)
def _as_contiguous(tensor):
    """Pair ``tensor`` with a flag telling whether a contiguous copy was made.

    Returns ``(tensor, False)`` when the tensor is already contiguous and
    ``(tensor.contiguous(), True)`` otherwise. A ``True`` flag means the
    returned buffer is a separate allocation, so callers that want in-place
    semantics have to copy the result back themselves.
    """
    needs_copy = not tensor.is_contiguous()
    buffer = tensor.contiguous() if needs_copy else tensor
    return buffer, needs_copy
def fill_scalar(input, value):
    """Return a new tensor allocated like ``input`` with every element set to ``value``."""
    logger.debug("GEMS_HOPPER FILL_SCALAR")
    block_size = 1024
    result = torch.empty_like(input)
    total = result.numel()
    # One program instance per ``block_size`` elements, rounded up.
    launch_grid = (triton.cdiv(total, block_size),)
    with torch_device_fn.device(input.device):
        fill_scalar_kernel[launch_grid](result, value, total, BLOCK_SIZE=block_size)
    return result
def fill_scalar_out(input, value, *, out=None):
    """Fill ``out`` with the scalar ``value`` and return it.

    When ``out`` is ``None`` this defers to ``fill_scalar`` and returns a new
    tensor modelled on ``input``. Otherwise ``out`` is overwritten in place
    and ``input`` is not consulted.
    """
    logger.debug("GEMS_HOPPER FILL_SCALAR_OUT")
    if out is None:
        return fill_scalar(input, value)

    # The kernel addresses memory by flat offset, so a non-contiguous ``out``
    # is filled through a contiguous scratch buffer and copied back. The
    # scratch is fully overwritten, so it is allocated empty rather than
    # seeded with a copy of ``out``.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty_like(out, memory_format=torch.contiguous_format)
    n_elements = target.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # Launch on the device that owns the buffer being written.
    with torch_device_fn.device(out.device):
        fill_scalar_kernel[grid](target, value, n_elements, BLOCK_SIZE=1024)
    if target is not out:
        out.copy_(target)
    return out
def fill_tensor(input, value):
    """Return a new tensor allocated like ``input`` filled with the tensor ``value``.

    ``value`` must be a 0-dimensional tensor. A value that is not on a CUDA
    device is read with ``.item()`` and routed through ``fill_scalar``.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    # Validate before the host-side shortcut so that a one-element 1-D CPU
    # tensor is rejected exactly like its device counterpart instead of
    # slipping through ``.item()``.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    logger.debug("GEMS_HOPPER FILL_TENSOR")
    out = torch.empty_like(input)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    with torch_device_fn.device(input.device):
        fill_tensor_kernel[grid](out, value, n_elements, BLOCK_SIZE=1024)
    return out
def fill_tensor_out(input, value, *, out=None):
    """Fill ``out`` with the 0-dimensional tensor ``value`` and return it.

    When ``out`` is ``None`` this defers to ``fill_tensor``. A value that is
    not on a CUDA device is read with ``.item()`` and routed through
    ``fill_scalar_out``.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    # Validate up front so every path (new tensor, host value, device value)
    # rejects a non-scalar ``value`` with the same error.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    logger.debug("GEMS_HOPPER FILL_TENSOR_OUT")
    if out is None:
        return fill_tensor(input, value)
    if not value.is_cuda:
        return fill_scalar_out(input, value.item(), out=out)

    # The kernel addresses memory by flat offset, so a non-contiguous ``out``
    # is filled through a contiguous scratch buffer and copied back. The
    # scratch is fully overwritten, so it is allocated empty rather than
    # seeded with a copy of ``out``.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty_like(out, memory_format=torch.contiguous_format)
    n_elements = target.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # Launch on the device that owns the buffer being written.
    with torch_device_fn.device(out.device):
        fill_tensor_kernel[grid](target, value, n_elements, BLOCK_SIZE=1024)
    if target is not out:
        out.copy_(target)
    return out
def fill_tensor_(self, value):
    """Fill ``self`` in place with the 0-dimensional tensor ``value``; return ``self``.

    A value that is not on a CUDA device is read with ``.item()`` and routed
    through ``fill_scalar_``.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    # Validate before the host-side shortcut so that a one-element 1-D CPU
    # tensor is rejected exactly like its device counterpart.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS_HOPPER FILL_TENSOR_")

    # The kernel addresses memory by flat offset, so a non-contiguous ``self``
    # is filled through a contiguous scratch buffer and copied back. The
    # scratch is fully overwritten, so it is allocated empty rather than
    # seeded with a copy of ``self``.
    if self.is_contiguous():
        target = self
    else:
        target = torch.empty_like(self, memory_format=torch.contiguous_format)
    n_elements = target.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    with torch_device_fn.device(self.device):
        fill_tensor_kernel[grid](target, value, n_elements, BLOCK_SIZE=1024)
    if target is not self:
        self.copy_(target)
    return self
def fill_scalar_(self, value):
    """Fill ``self`` in place with the scalar ``value`` and return ``self``.

    A non-contiguous ``self`` is filled through a contiguous copy that is then
    written back with ``copy_``.
    """
    logger.debug("GEMS_HOPPER FILL_SCALAR_")
    dense = self.is_contiguous()
    # The kernel indexes by flat offset, so it needs a contiguous buffer.
    target = self if dense else self.contiguous()
    count = target.numel()
    launch_grid = (triton.cdiv(count, 1024),)
    with torch_device_fn.device(self.device):
        fill_scalar_kernel[launch_grid](target, value, count, BLOCK_SIZE=1024)
    if not dense:
        self.copy_(target)
    return self