Coverage for src/flag_gems/ops/stack.py: 60%
86 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2from typing import List, Tuple, Union
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
@triton.jit
def stack_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a,
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    """Copy up to four flat input buffers into their slots of a stacked output.

    Grid axis 0 walks the input in chunks of ``BLOCK_X`` elements; grid axis 1
    picks which of the four (pointer, dim_offset, total_elements) triples this
    program handles.  Inputs are read linearly (``in_ptr + i``), so they are
    expected to be contiguous.

    For a linear input index ``i`` the destination index is
    ``(i // dim_prod_post) * dim_size_out * dim_prod_post
    + dim_offset * dim_prod_post + (i % dim_prod_post)``.
    Elements with ``i >= total_elements`` are masked out of both the load and
    the store.
    """
    chunk_id = tl.program_id(0)
    slot = tl.program_id(1)

    # Select the source triple for this program; any slot >= 3 uses "d".
    if slot == 0:
        src_ptr = in_ptr_a
        slot_offset = dim_offset_a
        src_numel = total_elements_a
    elif slot == 1:
        src_ptr = in_ptr_b
        slot_offset = dim_offset_b
        src_numel = total_elements_b
    elif slot == 2:
        src_ptr = in_ptr_c
        slot_offset = dim_offset_c
        src_numel = total_elements_c
    else:
        src_ptr = in_ptr_d
        slot_offset = dim_offset_d
        src_numel = total_elements_d

    # Linear element indices covered by this program, kept in int64.
    lanes = tl.arange(0, BLOCK_X).to(tl.int64)
    chunk_base = chunk_id.to(tl.int64) * BLOCK_X
    src_idx = chunk_base + lanes

    # Adding an int64 zero vector turns the scalar arguments into int64
    # per-lane values, so the index arithmetic below is carried out in 64-bit
    # (presumably to avoid overflow on large tensors).
    zeros_i64 = lanes * 0
    out_dim_len = dim_size_out + zeros_i64
    inner_len = dim_prod_post + zeros_i64
    slot_pos = slot_offset + zeros_i64
    valid_count = src_numel + zeros_i64

    in_bounds = src_idx < valid_count

    # Split the linear index around the insertion point of the new dimension.
    outer_idx = src_idx // inner_len
    inner_idx = src_idx % inner_len

    outer_term = outer_idx * out_dim_len * inner_len
    slot_term = slot_pos * inner_len
    dst_idx = outer_term + slot_term + inner_idx

    values = tl.load(src_ptr + src_idx, mask=in_bounds)
    tl.store(out_ptr + dst_idx, values, mask=in_bounds)
def stack(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Stack equally-shaped tensors along a new dimension.

    The inputs are promoted to a common dtype, made contiguous, and copied
    into a freshly allocated output by ``stack_copy_func_kernel_4``, which is
    launched once per group of up to four inputs.

    Args:
        tensors: Non-empty sequence of tensors that all have the same shape.
        dim: Position of the new dimension in the result.  Must lie in
            ``[-(ndim + 1), ndim]`` where ``ndim`` is the rank of the inputs;
            negative values count from the end of the output shape.

    Returns:
        A tensor of shape ``shape[:dim] + [len(tensors)] + shape[dim:]`` on
        the device of ``tensors[0]``, with the promoted dtype.

    Raises:
        RuntimeError: If ``tensors`` is empty or the shapes differ.
        IndexError: If ``dim`` is outside the valid range.
    """
    logger.debug("GEMS STACK")

    if len(tensors) == 0:
        raise RuntimeError("stack expected a non-empty TensorList")

    inp_shapes = [list(t.shape) for t in tensors]
    inp0_shape = inp_shapes[0]
    ndim = len(inp0_shape)

    # Validate ``dim`` once against the first tensor's rank, before comparing
    # shapes.  Checking only inside the loop over tensors[1:] would let a
    # single-tensor call through with an out-of-range ``dim``.
    if dim < -ndim - 1 or dim > ndim:
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                -ndim - 1, ndim, dim
            )
        )

    for entry, shape in enumerate(inp_shapes[1:], start=1):
        if shape != inp0_shape:
            raise RuntimeError(
                f"stack expects each tensor to be equal size, but got {inp0_shape} at entry 0 and {shape} at entry {entry}"
            )

    if dim < 0:
        dim += ndim + 1

    # Type promotion: find the common dtype for all tensors.
    dtype = tensors[0].dtype
    for t in tensors[1:]:
        dtype = torch.promote_types(dtype, t.dtype)
    # Convert only the tensors whose dtype differs from the result dtype.
    tensors = [t if t.dtype == dtype else t.to(dtype) for t in tensors]

    num_tensors = len(tensors)
    device = tensors[0].device
    out_shape = inp0_shape[:dim] + [num_tensors] + inp0_shape[dim:]
    out = torch.empty(out_shape, dtype=dtype, device=device)

    # Nothing to copy for zero-element inputs; this also keeps a zero
    # ``dim_prod_post`` away from the kernel's division/modulo.
    if out.numel() == 0:
        return out

    # Number of elements spanned by the dimensions after the insertion point.
    dim_prod_post = 1
    for extent in inp0_shape[dim:]:
        dim_prod_post *= extent

    BLOCK = 1024
    group_size = 4  # the kernel takes exactly four input slots
    # All inputs share one shape, so one element count serves every slot.
    numel = tensors[0].numel()
    grid_x = triton.cdiv(numel, BLOCK)

    for start in range(0, num_tensors, group_size):
        # The kernel reads its inputs linearly, so hand it contiguous data.
        batch = [t.contiguous() for t in tensors[start : start + group_size]]
        batch_len = len(batch)
        pad = group_size - batch_len

        # Unused slots get a valid pointer with zero elements; grid axis 1 is
        # limited to ``batch_len`` so those slots are never selected.
        sources = batch + [batch[0]] * pad
        dim_offsets = [start + j for j in range(batch_len)] + [0] * pad
        element_counts = [numel] * batch_len + [0] * pad

        stack_copy_func_kernel_4[(grid_x, batch_len)](
            out,
            sources[0],
            sources[1],
            sources[2],
            sources[3],
            num_tensors,
            dim_prod_post,
            dim_offsets[0],
            dim_offsets[1],
            dim_offsets[2],
            dim_offsets[3],
            element_counts[0],
            element_counts[1],
            element_counts[2],
            element_counts[3],
            BLOCK_X=BLOCK,
        )

    return out