Coverage for src/flag_gems/runtime/backend/_cambricon/ops/fill.py: 0%
91 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, libtuner
10from ..utils import TOTAL_CORE_NUM
# Module logger, created as a child of the "flag_gems" logger so that levels and
# handlers configured on "flag_gems" also govern the debug messages emitted below.
# lstrip(".") removes any leading dots from __name__ before it is used as the
# child-name suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
    ],
    key=["N"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["value_scalar"])
def fill_scalar_kernel(
    out_ptr,
    N,
    value_scalar,
    BLOCK_SIZE: tl.constexpr,
):
    # Store the runtime scalar `value_scalar` into out_ptr[0:N].
    #
    # Program p handles the BLOCK_SIZE-wide chunks starting at p * BLOCK_SIZE,
    # p * BLOCK_SIZE + num_programs * BLOCK_SIZE, ... so a launch grid smaller
    # than cdiv(N, BLOCK_SIZE) still covers every element. `value_scalar` is
    # excluded from specialization so distinct fill values reuse one compile.
    #
    # The first offset is promoted to int64 so the loop offsets are 64-bit.
    first = (tl.program_id(0) * BLOCK_SIZE).to(tl.int64)
    stride = tl.num_programs(axis=0) * BLOCK_SIZE
    lanes = tl.arange(0, BLOCK_SIZE)
    for base in range(first, N, stride):
        idx = base + lanes
        # The mask drops lanes of the final, partial chunk that lie past N.
        tl.store(out_ptr + idx, value_scalar, mask=idx < N)
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["N"],
)
@triton.jit
def fill_tensor_kernel(
    out_ptr,
    N,
    value_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # Store the single element at value_ptr into out_ptr[0:N].
    #
    # Program p handles the BLOCK_SIZE-wide chunks starting at p * BLOCK_SIZE,
    # then advances by num_programs * BLOCK_SIZE, so a launch grid smaller than
    # cdiv(N, BLOCK_SIZE) still covers every element.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    # Promote to int64 so the loop offsets are computed in 64-bit.
    block_start = (pid * BLOCK_SIZE).to(tl.int64)
    step = num_jobs * BLOCK_SIZE
    # value_ptr does not depend on the loop variable: load the fill value once
    # per program instead of re-loading it on every iteration.
    value_scalar = tl.load(value_ptr)
    for block_start_offset in range(block_start, N, step):
        offset = block_start_offset + tl.arange(0, BLOCK_SIZE)
        # The mask drops lanes of the final, partial chunk that lie past N.
        tl.store(out_ptr + offset, value_scalar, mask=offset < N)
def _check_fill_value(value):
    """Raise ``RuntimeError`` unless the tensor fill ``value`` is 0-dimensional."""
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )


def _is_non_overlapping_and_dense(t):
    """Return True if ``t``'s elements occupy exactly ``t.numel()`` consecutive slots.

    Any dimension order is accepted (contiguous, transposed, channels-last, ...):
    after sorting the non-trivial dims by stride, each stride must equal the
    product of the sizes of all faster-varying dims. Intended for non-empty
    tensors (a size-0 dim would make every later expected stride 0).
    """
    expected_stride = 1
    # Size-1 dims never advance the address, so their stride is irrelevant.
    dims = sorted(
        ((size, stride) for size, stride in zip(t.shape, t.stride()) if size != 1),
        key=lambda dim: dim[1],
    )
    for size, stride in dims:
        if stride != expected_stride:
            return False
        expected_stride *= size
    return True


def _launch_fill(kernel, dst, value):
    """Set every element of ``dst`` to ``value`` with a flat fill ``kernel``; return ``dst``.

    ``kernel`` is ``fill_scalar_kernel`` (``value`` is a Python scalar) or
    ``fill_tensor_kernel`` (``value`` is a 0-dim tensor). The launch runs under
    ``dst``'s device context, since ``dst`` is the tensor being written.
    """
    n = dst.numel()
    if n == 0:
        # Nothing to write; also avoids launching an empty grid.
        return dst
    if isinstance(value, torch.Tensor) and value.device != dst.device:
        # fill_tensor_kernel reads the value through a device pointer, so a
        # 0-dim tensor that lives elsewhere (e.g. torch.tensor(1.0) on the host)
        # is moved to dst's device first.
        value = value.to(dst.device)
    # Both kernels write n consecutive elements starting at the output's data
    # pointer. That matches dst only when its layout is non-overlapping and
    # dense; for other views (slices with gaps, expanded tensors, ...) fill a
    # contiguous scratch buffer and let copy_ write it through dst's strides.
    # NOTE(review): copy_ is expected to reject internally overlapping
    # destinations (e.g. expanded views) instead of writing past the storage.
    if _is_non_overlapping_and_dense(dst):
        flat = dst
    else:
        flat = torch.empty(dst.shape, dtype=dst.dtype, device=dst.device)

    def grid(meta):
        # At most TOTAL_CORE_NUM programs; the kernels loop over the remaining chunks.
        return (min(triton.cdiv(n, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(dst.device):
        kernel[grid](flat, n, value)
        if flat is not dst:
            dst.copy_(flat)
    return dst


def fill_tensor(input, value):
    """Return a new tensor like ``input`` with every element set to the 0-dim tensor ``value``.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    logger.debug("GEMS_CAMBRICON FILL TENSOR")
    _check_fill_value(value)
    out = torch.empty_like(input)
    return _launch_fill(fill_tensor_kernel, out, value)


def fill_scalar(input, value):
    """Return a new tensor like ``input`` with every element set to the scalar ``value``."""
    logger.debug("GEMS_CAMBRICON FILL SCALAR")
    # Always allocate: an out-of-place op must not hand back its own input,
    # even when the input is empty.
    out = torch.empty_like(input)
    return _launch_fill(fill_scalar_kernel, out, value)


def fill_scalar_out(input, value, *, out=None):
    """Fill ``out`` with the scalar ``value`` and return it.

    When ``out`` is None this behaves like ``fill_scalar(input, value)``.
    Otherwise ``out`` is filled in its current shape; it is not resized to
    match ``input``.
    """
    logger.debug("GEMS_CAMBRICON FILL SCALAR_OUT")
    if out is None:
        return fill_scalar(input, value)
    return _launch_fill(fill_scalar_kernel, out, value)


def fill_tensor_out(input, value, *, out=None):
    """Fill ``out`` with the 0-dim tensor ``value`` and return it.

    When ``out`` is None this behaves like ``fill_tensor(input, value)``.
    Otherwise ``out`` is filled in its current shape; it is not resized to
    match ``input``.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    logger.debug("GEMS_CAMBRICON FILL_TENSOR_OUT")
    _check_fill_value(value)
    if out is None:
        return fill_tensor(input, value)
    return _launch_fill(fill_tensor_kernel, out, value)


def fill_tensor_(self, value):
    """Set every element of ``self`` to the 0-dim tensor ``value`` in place; return ``self``.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    logger.debug("GEMS_CAMBRICON FILL_TENSOR_")
    _check_fill_value(value)
    return _launch_fill(fill_tensor_kernel, self, value)


def fill_scalar_(self, value):
    """Set every element of ``self`` to the scalar ``value`` in place; return ``self``."""
    logger.debug("GEMS_CAMBRICON FILL_SCALAR_")
    return _launch_fill(fill_scalar_kernel, self, value)