Coverage for src/flag_gems/utils/pointwise_dynamic.py: 94%
1019 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import importlib
2import os
3from dataclasses import dataclass
4from enum import Enum, auto
5from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
7import torch
8import triton
9from triton.runtime.jit import JITFunction
11from flag_gems.utils.code_cache import code_cache_dir
12from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
13from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config
14from flag_gems.utils.device_info import get_device_capability
15from flag_gems.utils.shape_utils import (
16 MemOverlap,
17 all_c_contiguous,
18 all_the_same_shape,
19 all_the_same_stride,
20 broadcast_shapes,
21 broadcasted_stride,
22 check_tensor_attributes,
23 has_internal_overlapping,
24)
25from flag_gems.utils.tensor_wrapper import StridedBuffer
26from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion
29# ------------------ Operation Description ---------------------------
30def _type_name(type) -> str:
31 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object"
32 if type in (bool, int, float, str):
33 return type.__name__
34 if isinstance(type, torch.dtype):
35 return str(type)
36 return str(type)
39def _check_typed_list(container, type):
40 for item in container:
41 assert isinstance(item, type)
44def _check_sized_list(container, size):
45 assert len(container) == size
48def _tuple_content(strings: Sequence[str]) -> str:
49 # comma separated list
50 if len(strings) == 0:
51 return ""
52 if len(strings) == 1:
53 return f"{strings[0]},"
54 else:
55 return ", ".join(strings)
58def _cs(strings: Iterable[str]) -> str:
59 return ", ".join(strings)
62def _broadcast_vec(i, ndim):
63 axes = [":" if j == i else "None" for j in range(ndim)]
64 return f"[{_cs(axes)}]"
class FunctionSchema:
    """Describe the operands of a pointwise scalar function.

    Inputs are either tensors or non-tensor values (optionally with a declared
    Python type). Outputs are tensors, each described by a type-promotion rule
    ``(*arg_indices, method)``. Inputs are numbered separately per kind (see
    ``input_index``), which is what generated code uses to name them
    ``in0, in1, ...`` and ``val0, val1, ...``.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema from any of ``num_inputs``, ``is_tensor`` or ``dtypes``.

        Args:
            num_inputs: number of inputs. When given, ``is_tensor`` defaults to
                all ``True`` and ``dtypes`` to all ``None``.
            is_tensor: per-input flag, ``True`` for tensor inputs.
            dtypes: per-input declared type, ``None`` when undeclared. When it is
                the only input description, entries that are ``None`` mark
                tensor inputs and the rest mark non-tensor inputs.
            num_outputs: number of outputs; defaults to ``len(promotion_methods)``.
            promotion_methods: one ``(*arg_indices, method)`` rule per output,
                ``method`` being a key of ``ELEMENTWISE_TYPE_PROMOTION_KIND``.

        Raises:
            ValueError: ``promotion_methods`` is missing, or none of
                ``num_inputs`` / ``is_tensor`` / ``dtypes`` is given.
            AssertionError: element types or list lengths are inconsistent, or
                there are no inputs or no outputs.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # Only dtypes were given (is_tensor is None on this branch):
            # undeclared (None) slots are tensors, typed slots are non-tensors.
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing method name of each rule with its promotion kind.

        ``(0, 1, "DEFAULT")`` becomes
        ``(0, 1, ELEMENTWISE_TYPE_PROMOTION_KIND["DEFAULT"])``.
        """
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Number of inputs (tensor and non-tensor); outputs are not included."""
        return self._num_inputs

    def num_outputs(self):
        """Number of outputs."""
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Whether input ``arg_id`` is a tensor."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared type of input ``arg_id``, or ``None`` when undeclared."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion rule ``(*arg_indices, kind)`` of output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        """Number of tensor inputs."""
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        """Number of output tensors (every output is a tensor)."""
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        """Number of non-tensor inputs."""
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a readable signature, e.g. ``Pointwise: StridedBuffer, scalar -> StridedBuffer``.

        With ``outputs_in_arg`` the outputs are also listed among the arguments,
        each tagged with its own alias marker ``StridedBuffer(a<i>!)``.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # One alias tag per output: a0!, a1!, ...
            output_types = [f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()
        return f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'

    def _compute_input_id(self):
        """For each input, its position among inputs of the same kind (tensor / non-tensor)."""
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Position of input ``idx`` among inputs of its kind."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)
class KernelGenerator:
    """Emit the source of a Triton pointwise kernel specialised for one rank.

    The generated kernel re-emits ``scalar_fn`` under ``@triton.jit`` and calls
    it on tiles of the task space. Three layouts are supported:

    * n-d tiles addressed through block pointers (``codegen_nd_tile`` when
      ``config.prefer_block_pointer`` is set),
    * n-d tiles addressed with explicit pointer arithmetic and masks,
    * flat 1-d tiles whose multi-index is recovered with div/mod
      (``codegen_1d_tile``).

    Each layout has a "one tile per CTA" branch and a loop branch in which
    program ``pid`` handles tiles ``pid + j * num_ctas`` for
    ``j < tiles_per_cta``. Rank-0 kernels get a plain load/compute/store body.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        """Store the inputs needed to generate the kernel.

        Args:
            function_schema: operand description of ``scalar_fn``.
            scalar_fn: jitted elementwise function inlined into the kernel.
            rank: rank of the task space (stored as ``self.ndim``).
            name: name of the generated kernel function.
            config: code generation options.
        """
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    # ------------------------------------------------------------------ shared

    def gen_import_function(self, code: IndentedBuffer):
        """Re-emit the scalar function's source under ``@triton.jit``.

        Assumes ``self.fn.src`` carries no decorator of its own (TODO confirm).
        """
        code.writeline("@triton.jit")
        code.writemultiline(self.fn.src)
        code.newline()

    def gen_decorators(self, code):
        """Emit ``@libentry()`` and the ``@triton.jit`` decorator of the kernel."""
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # Non-tensor args are not specialized: they are only forwarded to the
            # inlined scalar function, so their values may not deserve it.
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def input_name(self, i):
        """Generated-code name of input ``i``: ``in<k>`` (tensor) or ``val<k>`` (non-tensor)."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Generated-code name of output ``i``."""
        return f"out{i}"

    def _operand_names(self):
        """Tensor operand names in kernel argument order: inputs, then outputs."""
        schema = self.fx
        return [f"in{i}" for i in range(schema.num_input_tensors())] + [
            f"out{i}" for i in range(schema.num_output_tensors())
        ]

    # -------------------------------------------------------------- signatures

    def _gen_signature(self, code, *, with_stride_order, tile_size_params):
        """Emit ``def <name>(...):`` shared by the n-d and 1-d tile kernels.

        Args:
            with_stride_order: also declare ``<operand>_stride_order<j>``
                constexpr args right after each operand's strides.
            tile_size_params: text declaring the tile-size constexpr arg(s).
        """
        schema = self.fx
        ndim = self.ndim
        code.writeline(f"def {self.name}(")
        with code.indent():
            # input pointers & non tensor inputs, in schema order
            for i in range(schema.num_inputs()):
                arg = self.input_name(i)
                if schema.is_tensor(i):
                    code.writeline(f"{arg}_ptr: tl.tensor, # of tl.pointer_type")
                elif schema.input_type(i) is not None:
                    code.writeline(f"{arg}: {_type_name(schema.input_type(i))},")
                else:
                    code.writeline(f"{arg},")

            # output pointers
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"{self.output_name(i)}_ptr: tl.tensor, # of tl.pointer_type"
                )

            # Rank-0 kernels take no strides, task space or tiling arguments.
            if ndim > 0:
                for operand in self._operand_names():
                    strides = _cs(f"{operand}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{strides}, # strides for {operand}")
                    if with_stride_order:
                        orders = _cs(
                            f"{operand}_stride_order{j}: tl.constexpr"
                            for j in range(ndim)
                        )
                        code.writeline(f"{orders}, # stride order for {operand}")

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")
                # number of tasks, used to compute mask
                code.writeline("num_tasks,")
                # tile size(s) & tiles_per_cta for the loop variant
                code.writeline("tiles_per_cta: int,")
                code.writeline(f"{tile_size_params},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature(self, code, with_block_pointer=False):
        """Emit the n-d tile kernel signature (one ``tile_size<i>`` per axis)."""
        tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(self.ndim))
        self._gen_signature(
            code, with_stride_order=with_block_pointer, tile_size_params=tile_sizes
        )

    def gen_signature_1d_tile(self, code):
        """Emit the 1-d tile kernel signature (a single ``tile_size``)."""
        self._gen_signature(
            code, with_stride_order=False, tile_size_params="tile_size: tl.constexpr"
        )

    # ------------------------------------------------------------ body pieces

    def gen_num_tiles(self, code):
        """Emit the number of tiles along every axis of the task space."""
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def _gen_compute(self, code):
        """Emit the call of the scalar function on the loaded operands."""
        schema = self.fx
        inputs = _cs(self.input_name(i) for i in range(schema.num_inputs()))
        outputs = _cs(self.output_name(i) for i in range(schema.num_output_tensors()))
        code.writeline("# compute")
        code.writeline(f"{outputs} = {self.fn_name}({inputs})")
        code.newline()

    def _gen_tile_multi_index(self, code):
        """Emit ``tile_id<i>`` from the flat ``tile_id`` (C order, last axis fastest)."""
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

    def _gen_gsl_loop(self, code, gen_one_tile_body):
        """Emit a loop in which program ``pid`` handles tiles ``pid + j * num_ctas``."""
        code.writeline("num_ctas = ext.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            gen_one_tile_body(code)

    def gen_body_for_0d(self, code):
        """Emit load / compute / store on single elements (rank-0 task space)."""
        schema = self.fx

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # ------------------------------------------------ n-d tile, block pointers

    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Emit the body handling tile ``tile_id`` through block pointers."""
        # Imported lazily, presumably to avoid a circular import (TODO confirm).
        import flag_gems

        ndim = self.ndim
        schema = self.fx

        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        self._gen_tile_multi_index(code)

        code.writeline("# tile offsets")
        for i in range(ndim):
            # Block pointers only support 32 bit `offsets/block_shape`; without
            # `.to(tl.int32)` Triton raises an AssertionError.
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        code.writeline("# loads")
        # On spacemit, inputs use the fixed order (ndim-1, ..., 0) instead of the
        # per-tensor stride order arguments.
        use_fixed_order = flag_gems.vendor_name == "spacemit"
        fixed_order = _tuple_content(tuple(f"{ndim - j - 1}" for j in range(ndim)))
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            if use_fixed_order:
                order = fixed_order
            else:
                order = _tuple_content(
                    tuple(f"in{i}_stride_order{j}" for j in range(ndim))
                )
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        """Emit the multi-tile loop around the block-pointer tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_with_bptr)

    # --------------------------------------------- n-d tile, pointer arithmetic

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Emit the body handling tile ``tile_id`` with pointer arithmetic and masks."""
        ndim = self.ndim
        schema = self.fx

        self._gen_tile_multi_index(code)

        # per-axis offsets of the tile
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # per-axis masks, broadcast and combined into one n-d mask
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        mask_combine = " & ".join(
            f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim)
        )
        code.writeline(f"mask = {mask_combine}")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: tl.store's result is not bound to anything (it used to
        # overwrite the input variables in0, in1, ...)
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)")

    def gen_body_gsl_without_bptr(self, code):
        """Emit the multi-tile loop around the pointer-arithmetic tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_without_bptr)

    # ------------------------------------------------------------- n-d kernels

    def _gen_kernel(self, code, gen_signature, gen_one_tile_body, gen_gsl_body, *, with_num_tiles):
        """Emit a complete kernel: scalar fn, decorators, signature and body.

        Rank 0 gets ``gen_body_for_0d``. Otherwise the body branches on the
        ``one_tile_per_cta`` constexpr between a single tile per program and
        the multi-tile loop.
        """
        self.gen_import_function(code)
        self.gen_decorators(code)
        gen_signature(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = ext.program_id(0)")
            if with_num_tiles:
                self.gen_num_tiles(code)
            # monolithic kernel: one tile per CTA, may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                gen_one_tile_body(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                gen_gsl_body(code)
        code.newline()
        return code

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        return self._gen_kernel(
            code,
            lambda c: self.gen_signature(c, with_block_pointer=True),
            self.gen_body_one_tile_per_cta_with_bptr,
            self.gen_body_gsl_with_bptr,
            with_num_tiles=True,
        )

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support, without block pointer."""
        return self._gen_kernel(
            code,
            lambda c: self.gen_signature(c, with_block_pointer=False),
            self.gen_body_one_tile_per_cta_without_bptr,
            self.gen_body_gsl_without_bptr,
            with_num_tiles=True,
        )

    def codegen_nd_tile(self, code):
        """Generate the n-d tile kernel, using block pointers if the config prefers them."""
        if self.config.prefer_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    # ---------------------------------------------------------------- 1-d tile

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Emit the body handling flat tile ``tile_id`` of ``tile_size`` tasks."""
        ndim = self.ndim
        schema = self.fx

        # flat task ids of this tile and their bound mask
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction, C order (last axis fastest)
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: tl.store's result is not bound (it used to overwrite in<i>)
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            code.writeline(f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)")

    def gen_body_gsl_1d_tile(self, code):
        """Emit the multi-tile loop around the 1-d tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_1d_tile)

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        return self._gen_kernel(
            code,
            self.gen_signature_1d_tile,
            self.gen_body_one_tile_per_cta_1d_tile,
            self.gen_body_gsl_1d_tile,
            with_num_tiles=False,
        )
749class WrapperGenerator:
750 def __init__(
751 self,
752 function_schema: FunctionSchema,
753 jit_fn_name: str,
754 ndim: int,
755 name: str,
756 config: CodeGenConfig,
757 ):
758 self.fx = function_schema
759 self.jit_fn_name = jit_fn_name
760 self.ndim = ndim
761 self.name = name
762 self.config = config
764 def input_name(self, i):
765 is_tensor = self.fx.is_tensor(i)
766 name = "in" if is_tensor else "val"
767 index = self.fx.input_index(i)
768 return f"{name}{index}"
770 def output_name(self, i):
771 return f"out{i}"
773 def gen_signature(self, code: IndentedBuffer):
774 # TODO: check if triton handles constexprs transitively
775 schema = self.fx
776 params: List[str] = []
777 for i in range(schema.num_inputs()):
778 if schema.is_tensor(i):
779 params.append(
780 f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]"
781 )
782 else:
783 arg_type = schema.input_type(i)
784 if arg_type is not None:
785 params.append(f"{self.input_name(i)}: {_type_name(arg_type)}")
786 else:
787 params.append(f"{self.input_name(i)}")
788 # NOTE: [the wrapper's signature and rules for passing parameters ]
789 # input params: must be passed by position, since the names are renamed to
790 # in0, in1, val0, val1, ..., So passing these parameters by keyword is wierd
791 # So we enforce that these parameters must be passed by position.
792 # maybe we can fix it later
793 # output parameters: must be passed by keyword, since the scalar function
794 # do not have output parameters(think of it as some scalar function, output
795 # parameter does not make sense in this case.) They are added to allow destination
796 # passing style API. Output parameter is convenient in cases where we want
797 # to use some pre-defiend outputs(especially when they are some views of other
798 # tensors). We emphasize that these parameters are added in-addition, we enforce
799 # that they be passed by keyword. After all, out0, out1, ... does not mismatch
800 # names form the scalar function, since it does not have output parameters.
801 params.append("/")
802 params.append("*") # output params must be passed by keyword
804 for i in range(schema.num_output_tensors()):
805 params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
806 code.writeline(f"def {self.name}({_cs(params)}): ")
808 def gen_docstring(self, code: IndentedBuffer):
809 schema = self.fx
810 doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
811 code.writeline(doc)
813 def gen_same_shape_check(self, code: IndentedBuffer):
814 schema: FunctionSchema = self.fx
815 params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [
816 f"out{i}.shape" for i in range(schema.num_output_tensors())
817 ]
818 check: str = " == ".join(params)
819 code.writeline(f"assert {check}, 'operand shapes mismatch'")
821 def gen_task_partition(self, code: IndentedBuffer):
822 code.writeline("# task partitioning")
823 ndim = self.ndim
824 if ndim == 0:
825 code.writeline("num_warps = 1")
826 code.writeline("num_ctas = 1")
827 else:
828 code.writeline("shape = out0.shape")
829 code.writeline("num_tasks = out0.numel()")
830 code.writeline("if num_tasks == 0:")
831 with code.indent():
832 self.gen_return(code)
833 max_tile_size = self.config.max_tile_size
834 # Check if all input and output dtypes are complex
835 all_complex = True
836 for i in range(self.fx.num_inputs()):
837 if self.fx.is_tensor(i):
838 input_dtype = self.fx.input_type(i)
839 if input_dtype is not None and not (
840 input_dtype == torch.complex64
841 or input_dtype == torch.complex128
842 ):
843 all_complex = False
844 break
845 if all_complex:
846 # If all inputs are complex, set max_tile_size to half
847 max_tile_size = max_tile_size // 2
848 major, _ = get_device_capability()
849 if self.name.find("fill_scalar") != -1 and major >= 9:
850 code.writeline("tile_sizes = tuple([64])")
851 else:
852 code.writeline(
853 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
854 )
855 code.writeline("tile_size = math.prod(tile_sizes)")
856 code.writeline(
857 "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
858 )
860 if self.name.find("fill_scalar") != -1 and major >= 9:
861 code.writeline("num_ctas = num_tiles")
862 else:
863 max_grid_size0 = self.config.max_grid_size[0]
864 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")
866 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
867 code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
868 code.writeline("one_tile_per_cta = tiles_per_cta==1")
869 code.writeline("grid = (num_ctas, 1, 1)")
871 def gen_task_partition_1d(self, code: IndentedBuffer):
872 code.writeline("# task partitioning")
873 ndim = self.ndim
874 if ndim == 0:
875 code.writeline("num_warps = 1")
876 code.writeline("num_ctas = 1")
877 else:
878 code.writeline("shape = out0.shape")
879 code.writeline("num_tasks = out0.numel()")
880 code.writeline("if num_tasks == 0:")
881 with code.indent():
882 self.gen_return(code)
883 max_tile_size = self.config.max_tile_size
884 # Check if all input and output dtypes are complex
885 all_complex = True
886 for i in range(self.fx.num_inputs()):
887 if self.fx.is_tensor(i):
888 input_dtype = self.fx.input_type(i)
889 if input_dtype is not None and not (
890 input_dtype == torch.complex64
891 or input_dtype == torch.complex128
892 ):
893 all_complex = False
894 break
895 if all_complex:
896 max_tile_size = max_tile_size // 2
897 major, _ = get_device_capability()
898 if self.name.find("fill_scalar") != -1 and major >= 9:
899 code.writeline("tile_sizes = tuple([1024])")
900 else:
901 code.writeline(
902 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
903 )
905 code.writeline("tile_size = tile_sizes[0]")
906 code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")
908 if self.name.find("fill_scalar") != -1 and major >= 9:
909 code.writeline("num_ctas = num_tiles")
910 else:
911 max_grid_size0 = self.config.max_grid_size[0]
912 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")
914 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
915 code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
916 code.writeline("one_tile_per_cta = tiles_per_cta==1")
917 code.writeline("grid = (num_ctas, 1, 1)")
919 def gen_kernel_launch(
920 self,
921 code: IndentedBuffer,
922 ):
923 schema = self.fx
924 ndim = self.ndim
926 with_block_pointer = self.config.prefer_block_pointer
928 code.writeline("# kernel launch")
929 for i in range(schema.num_input_tensors()):
930 code.writeline(f"in{i}_strides = in{i}.stride()")
931 if not with_block_pointer:
932 continue
933 if ndim >= 2: # where ndim is 1, we don't need to compute stride order
934 code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)")
935 else:
936 code.writeline(f"in{i}_stride_order = (0,)")
937 for i in range(schema.num_output_tensors()):
938 code.writeline(f"out{i}_strides = out{i}.stride()")
939 if not with_block_pointer:
940 continue
941 if ndim >= 2:
942 code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)")
943 else:
944 code.writeline(f"out{i}_stride_order = (0,)")
946 code.writeline("with torch_device_fn.device(in0.device.index):")
947 with code.indent():
948 code.writeline(f"{self.jit_fn_name}[grid](")
949 with code.indent():
950 params = []
951 # NOTE: WRAP
952 for i in range(schema.num_inputs()):
953 if schema.is_tensor(i):
954 params.append(f"{self.input_name(i)}")
955 else:
956 params.append(self.input_name(i))
957 for i in range(schema.num_output_tensors()):
958 params.append(f"{self.output_name(i)}")
960 code.writeline(f"{_cs(params)},")
962 if ndim > 0:
963 for i in range(schema.num_input_tensors()):
964 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
965 code.writeline(f"{s}, # stride for in{i}")
966 if not with_block_pointer:
967 continue
968 order = ", ".join(
969 f"in{i}_stride_order[{j}]" for j in range(ndim)
970 )
971 code.writeline(f"{order}, # stride order for in{i}")
973 for i in range(schema.num_output_tensors()):
974 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
975 code.writeline(f"{s}, # stride for out{i}")
976 if not with_block_pointer:
977 continue
978 order = ", ".join(
979 f"out{i}_stride_order[{j}]" for j in range(ndim)
980 )
981 code.writeline(f"{order}, # stride orderfor out{i}")
983 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
984 code.writeline(f"{shape_args}, # task indexing space")
985 code.writeline("num_tasks, # num tasks")
986 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
987 for i in range(ndim):
988 code.writeline(f"tile_size{i}=tile_sizes[{i}],")
989 code.writeline("one_tile_per_cta=one_tile_per_cta,")
990 code.writeline("num_warps=num_warps,")
991 code.writeline(")")
993 def gen_kernel_launch_1d(
994 self,
995 code: IndentedBuffer,
996 ):
997 schema = self.fx
998 ndim = self.ndim
1000 code.writeline("# kernel launch")
1001 for i in range(schema.num_input_tensors()):
1002 code.writeline(f"in{i}_strides = in{i}.stride()")
1003 for i in range(schema.num_output_tensors()):
1004 code.writeline(f"out{i}_strides = out{i}.stride()")
1006 code.writeline("with torch_device_fn.device(in0.device.index):")
1007 with code.indent():
1008 code.writeline(f"{self.jit_fn_name}[grid](")
1009 with code.indent():
1010 params = []
1011 # NOTE: WRAP
1012 for i in range(schema.num_inputs()):
1013 if schema.is_tensor(i):
1014 params.append(f"{self.input_name(i)}")
1015 else:
1016 params.append(self.input_name(i))
1017 for i in range(schema.num_output_tensors()):
1018 params.append(f"{self.output_name(i)}")
1020 code.writeline(f"{_cs(params)},")
1022 if ndim > 0:
1023 for i in range(schema.num_input_tensors()):
1024 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
1025 code.writeline(f"{s}, # stride for in{i}")
1026 for i in range(schema.num_output_tensors()):
1027 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
1028 code.writeline(f"{s}, # stride for out{i}")
1030 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
1031 code.writeline(f"{shape_args}, # task indexing space")
1032 code.writeline("num_tasks, # num tasks")
1033 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
1034 code.writeline("tile_size=tile_size,")
1035 code.writeline("one_tile_per_cta=one_tile_per_cta,")
1036 code.writeline("num_warps=num_warps,")
1037 code.writeline(")")
1039 def gen_return(self, code: IndentedBuffer):
1040 return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors()))
1041 code.writeline(f"return {return_exprs}")
1043 def codegen_nd_tile(self, code):
1044 self.gen_signature(code)
1046 with code.indent():
1047 self.gen_docstring(code)
1048 self.gen_same_shape_check(code)
1049 self.gen_task_partition(code)
1050 self.gen_kernel_launch(code)
1051 self.gen_return(code)
1052 code.newline()
1053 return code
1055 def codegen_1d_tile(self, code):
1056 self.gen_signature(code)
1058 with code.indent():
1059 self.gen_docstring(code)
1060 self.gen_same_shape_check(code)
1061 self.gen_task_partition_1d(code)
1062 self.gen_kernel_launch_1d(code)
1063 self.gen_return(code)
1064 code.newline()
1065 return code
class ModuleGenerator:
    """Generate the source of one pointwise module: imports, wrapper and JIT kernel.

    Combines a WrapperGenerator and a KernelGenerator built from the same schema,
    rank and config. It prepends an import header plus any extra imports and
    local ``@triton.jit`` helpers found in the module that defines the scalar
    function.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.scalar_fn = scalar_fn
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def _collect_jit_deps(scalar_fn):
        """Collect extra imports and local @triton.jit helper sources.

        Parses the source module where ``scalar_fn`` is defined using AST.

        Returns:
            A tuple ``(extra_imports, local_sources)``:
            - extra_imports: dict of absolute module path -> set of import names.
              Each name is either ``"name"`` or ``"name as alias"``. Relative
              imports are resolved against the defining module's package.
              Modules in ALREADY_IMPORTED are skipped.
            - local_sources: list of source strings of top-level functions whose
              decorators mention ``jit`` but not ``pointwise_dynamic``.
            Both are empty when the module or its source cannot be located or
            parsed.
        """
        import ast
        import importlib.util
        import inspect
        import io
        import tokenize

        py_fn = getattr(scalar_fn, "fn", scalar_fn)
        module_name = getattr(py_fn, "__module__", None)
        if not module_name:
            return {}, []
        try:
            mod = importlib.import_module(module_name)
            source_file = inspect.getfile(mod)
        except (ImportError, TypeError, OSError):
            return {}, []
        try:
            # tokenize.open honours PEP 263 coding cookies / BOM (default UTF-8)
            # rather than the locale encoding used by a plain open().
            with tokenize.open(source_file) as f:
                module_source = f.read()
            # Split on "\n" only: str.splitlines() also breaks on \x0c, \u2028, ...,
            # which the Python tokenizer does not count as line breaks, and would
            # shift the ast line numbers used below.
            source_lines = io.StringIO(module_source).readlines()
            tree = ast.parse(module_source)
        except (OSError, SyntaxError, UnicodeDecodeError):
            return {}, []

        # Collect non-standard import-from lines
        ALREADY_IMPORTED = {
            "math",
            "typing",
            "torch",
            "triton",
            "triton.language",
            "flag_gems.utils.shape_utils",
            "flag_gems.utils.tensor_wrapper",
            "flag_gems.utils.libentry",
            "flag_gems.utils",
            "flag_gems.runtime",
            "flag_gems.utils.pointwise_dynamic",
        }
        package = getattr(mod, "__package__", None)
        extra_imports = {}
        for node in ast.iter_child_nodes(tree):
            if not (isinstance(node, ast.ImportFrom) and node.module):
                continue
            module_path = node.module
            if node.level:
                # The generated module lives in the code cache, outside the source
                # package, so a relative import has to be rewritten in absolute form.
                try:
                    module_path = importlib.util.resolve_name(
                        "." * node.level + node.module, package
                    )
                except (ImportError, ValueError):
                    continue
            if module_path in ALREADY_IMPORTED:
                continue
            # keep aliases so helper bodies that use the alias still resolve
            names = {
                alias.name
                if alias.asname is None
                else f"{alias.name} as {alias.asname}"
                for alias in node.names
            }
            extra_imports.setdefault(module_path, set()).update(names)

        # Collect local @triton.jit functions (without @pointwise_dynamic)
        def _has_decorator(func_node, name):
            # textual match on the decorator's source lines
            for dec in func_node.decorator_list:
                src = "".join(source_lines[dec.lineno - 1 : dec.end_lineno])
                if name in src:
                    return True
            return False

        def _extract_source(func_node):
            # include the decorators, which start before the ``def`` line
            start = func_node.lineno - 1
            if func_node.decorator_list:
                start = func_node.decorator_list[0].lineno - 1
            end = func_node.end_lineno
            return "".join(source_lines[start:end])

        local_sources = []
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.FunctionDef):
                continue
            # "jit" is a substring of "triton.jit", so one check covers both
            if not _has_decorator(node, "jit"):
                continue
            if _has_decorator(node, "pointwise_dynamic"):
                continue
            local_sources.append(_extract_source(node))

        return extra_imports, local_sources

    def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer:
        """Write the generated module's import header and local JIT helpers.

        Emits a fixed header, then one ``from ... import ...`` line per extra
        module found by ``_collect_jit_deps``, then the source of each local
        ``@triton.jit`` helper followed by a blank line. Returns ``code``.
        """
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry")
        code.writeline("from flag_gems.utils import triton_lang_extension as ext")
        code.writeline("from flag_gems.runtime import torch_device_fn")

        # Generate extra imports and local JIT deps of the scalar function
        jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn)
        for module_path, names in sorted(jit_dep_imports.items()):
            if "*" in names:
                # a star import cannot share a statement with other names
                code.writeline(f"from {module_path} import *")
                names = names - {"*"}
            if names:
                sorted_names = ", ".join(sorted(names))
                code.writeline(f"from {module_path} import {sorted_names}")

        code.newline()
        code.newline()

        # Emit local @triton.jit helper functions
        for source in local_jit_sources:
            for line in source.splitlines():
                code.writeline(line)
            code.newline()

        return code

    def codegen(self, code: IndentedBuffer):
        """Generate imports, the wrapper and the kernel into ``code``; return it.

        Uses the 1d-tile generators when ``config.prefer_1d_tile`` is set,
        otherwise the N-d tile generators.
        """
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code
@dataclass(init=True, repr=True, eq=True)
class KernelInfo:
    """Where a generated kernel lives and what it is called, for C++ integration.

    Built by ``PointwiseDynamicFunction.instantiate`` for each instantiated rank.
    """

    # path of the generated Python module on disk
    file_path: str
    # name of the JIT kernel defined in that module
    kernel_name: str
    # name of the Python wrapper in that module that launches the kernel
    wrapper_name: str
    # rank of the task space the kernel is specialized for
    ndim: int
class ComplexMode(Enum):
    """How a PointwiseDynamicFunction treats calls with complex operands."""

    # no complex dispatch; calls always go to the real kernel
    NONE = 1
    # add/sub: view_as_real -> same kernel -> view_as_complex
    ELEMENTWISE = 2
    # mul/div: split ar/ai/br/bi -> cross_kernel
    CROSS = 3
@dataclass(eq=True)
class ComplexStrategy:
    """Complex-number dispatch settings attached to a PointwiseDynamicFunction."""

    # which complex path to take; NONE disables complex dispatch
    mode: ComplexMode = ComplexMode.NONE
    # cross-term kernel for mul/div, called as cross_kernel(ar, ai, br, bi)
    cross_kernel: Optional[Callable] = None
    # if True, scalar operands are turned into tensors before delegating
    tensorize_scalars: bool = False
    # callable that receives the tensorized arguments when tensorize_scalars is set
    fallback_target: Optional[Callable] = None
1237_REAL_TO_COMPLEX = {
1238 torch.float16: torch.complex32,
1239 torch.bfloat16: torch.complex32,
1240 torch.float32: torch.complex64,
1241 torch.float64: torch.complex128,
1242}
class PointwiseDynamicFunction:
    """Generate, cache and call rank-specialized kernels for a pointwise operation.

    For every task-space rank (the broadcasted shape of all tensor arguments) a
    module containing a wrapper and a JITFunction is generated, written to
    ``code_cache_dir()`` and imported. The wrapper is cached per
    ``f"{ndim}_{prefer_block_pointer}"``. Calls with complex operands can be
    routed through a strategy registered with ``register_complex``.
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads, keyed by f"{ndim}_{prefer_block_pointer}"
        self.overloads: Mapping[str, Callable] = {}
        # cached kernel info for C++ integration (same keys as self.overloads)
        self._kernel_info_cache: Mapping[str, KernelInfo] = {}

        # complex dispatch support
        self.complex_strategy = ComplexStrategy()
        self._operand_indices = self._infer_operand_indices()

    # -------------------- operand index inference --------------------

    def _infer_operand_indices(self):
        """Return the positional indices that take part in any output's type promotion.

        Each promotion method is ``(*arg_indices, method)``, so everything but
        the last element is an argument index.
        """
        indices = set()
        for pm in self.fx._promotion_methods:
            for idx in pm[:-1]:
                indices.add(idx)
        return frozenset(indices)

    # -------------------- register_complex --------------------

    def register_complex(
        self, mode, cross_kernel=None, tensorize_scalars=False, fallback_target=None
    ):
        """Register complex number support for this kernel.

        Args:
            mode: ComplexMode.ELEMENTWISE (add/sub) or ComplexMode.CROSS (mul/div).
            cross_kernel: A PointwiseDynamicFunction for cross-term ops (mul/div).
            tensorize_scalars: If True, scalar operands are converted to tensors
                before delegating to fallback_target.
            fallback_target: A PointwiseDynamicFunction (tensor-tensor version)
                to delegate to after tensorizing scalar operands.

        Returns:
            ``self``, so the call can be chained.
        """
        self.complex_strategy = ComplexStrategy(
            mode=mode,
            cross_kernel=cross_kernel,
            tensorize_scalars=tensorize_scalars,
            fallback_target=fallback_target,
        )
        return self

    # -------------------- call entry --------------------

    def __call__(self, *args, **kwargs):
        """Run the operation, routing through complex dispatch when it applies."""
        if self._should_use_complex_path(args):
            return self._call_complex_dispatch(*args, **kwargs)
        return self._call_real_impl(*args, **kwargs)

    def _call_real_impl(self, *args, **kwargs):
        """Single entry point for real kernel invocation."""
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        return self._unwrap(out)

    # -------------------- complex helpers --------------------

    @staticmethod
    def _is_complex_arg(a):
        """True for complex tensors and Python ``complex`` scalars."""
        return (isinstance(a, torch.Tensor) and a.is_complex()) or isinstance(
            a, complex
        )

    def _should_use_complex_path(self, args):
        """True when a complex strategy is registered and any positional operand is complex."""
        if self.complex_strategy.mode == ComplexMode.NONE:
            return False
        return any(
            self._is_complex_arg(args[i])
            for i in self._operand_indices
            if i < len(args)
        )

    def _split_args(self, args):
        """Split args into operands and others by original position index."""
        operands = {}
        others = {}
        for i, a in enumerate(args):
            if i in self._operand_indices:
                operands[i] = a
            else:
                others[i] = a
        return operands, others

    def _merge_args(self, operands, others):
        """Rebuild args tuple from operands and others by original position index."""
        total = len(operands) + len(others)
        merged = [None] * total
        for i, v in operands.items():
            merged[i] = v
        for i, v in others.items():
            merged[i] = v
        return tuple(merged)

    def _classify_complex_inputs(self, operands):
        """Classify operands as 'all_complex', 'mixed', or 'real'."""
        complex_count = sum(1 for v in operands.values() if self._is_complex_arg(v))
        if complex_count == len(operands):
            return "all_complex"
        elif complex_count > 0:
            return "mixed"
        return "real"

    def _infer_device(self, operands):
        """Device of the first tensor operand, or None if all operands are scalars."""
        for v in operands.values():
            if isinstance(v, torch.Tensor):
                return v.device
        return None

    def _infer_complex_dtype(self, operands):
        """Result dtype of the operands per torch type promotion.

        NOTE(review): torch.result_type takes exactly two operands, so this
        assumes a binary operation.
        """
        return torch.result_type(*operands.values())

    def _tensorize_scalar_operands(self, operands, dtype, device):
        """Convert scalar operands to tensors.

        complex -> ``dtype``, float -> float32, int/bool -> int64. Tensors and
        any other values are passed through unchanged.
        """
        result = {}
        for i, v in operands.items():
            if not isinstance(v, torch.Tensor):
                if isinstance(v, complex):
                    result[i] = torch.tensor(v, dtype=dtype, device=device)
                elif isinstance(v, float):
                    result[i] = torch.tensor(v, dtype=torch.float32, device=device)
                elif isinstance(v, (int, bool)):
                    result[i] = torch.tensor(v, dtype=torch.int64, device=device)
                else:
                    result[i] = v
            else:
                result[i] = v
        return result

    def _to_complex_tensor(self, a, target_dtype, device):
        """Convert a scalar or real tensor to a complex tensor.

        Real floating tensors map through _REAL_TO_COMPLEX (default complex64).
        Non-floating tensors go through float32 to complex64. Scalars become
        ``target_dtype`` tensors on ``device``. Any other value is returned
        unchanged.
        """
        if isinstance(a, torch.Tensor):
            if a.is_complex():
                return a
            if a.is_floating_point():
                cdtype = _REAL_TO_COMPLEX.get(a.dtype, torch.complex64)
            else:
                a = a.to(torch.float32)
                cdtype = torch.complex64
            return torch.complex(a, torch.zeros_like(a)).to(cdtype)
        elif isinstance(a, complex):
            return torch.tensor(a, dtype=target_dtype, device=device)
        elif isinstance(a, (int, float)):
            return torch.tensor(complex(a, 0), dtype=target_dtype, device=device)
        return a

    # -------------------- complex dispatch --------------------

    def _call_complex_dispatch(self, *args, **kwargs):
        """Unified complex dispatch entry point."""
        strategy = self.complex_strategy
        operands, others = self._split_args(args)

        device = self._infer_device(operands)
        result_dtype = self._infer_complex_dtype(operands)

        # tensorize scalar operands and delegate to fallback_target
        if strategy.tensorize_scalars and strategy.fallback_target is not None:
            operands = self._tensorize_scalar_operands(operands, result_dtype, device)
            new_args = self._merge_args(operands, others)
            return strategy.fallback_target(*new_args, **kwargs)

        # convert all operands to complex tensors
        for i in list(operands.keys()):
            operands[i] = self._to_complex_tensor(operands[i], result_dtype, device)

        # broadcast complex tensor operands
        complex_tensors = [operands[i] for i in sorted(operands.keys())]
        complex_tensors = torch.broadcast_tensors(*complex_tensors)
        for idx, key in enumerate(sorted(operands.keys())):
            operands[key] = complex_tensors[idx]

        classification = self._classify_complex_inputs(operands)

        if strategy.mode == ComplexMode.CROSS and classification == "all_complex":
            # NOTE(review): `others` and `kwargs` (e.g. out0=) are not forwarded on
            # this path — confirm no cross op relies on them.
            return self._call_complex_cross(operands, result_dtype)
        elif classification in ("all_complex", "mixed"):
            return self._call_complex_elementwise(
                operands, others, result_dtype, kwargs
            )
        else:
            new_args = self._merge_args(operands, others)
            return self._call_real_impl(*new_args, **kwargs)

    def _call_complex_elementwise(self, operands, others, result_dtype, kwargs):
        """Elementwise: view_as_real -> call real kernel -> view_as_complex."""
        real_tensors = {i: torch.view_as_real(t) for i, t in operands.items()}

        # promote to common real dtype
        dtypes = [t.dtype for t in real_tensors.values()]
        common_dtype = dtypes[0]
        for d in dtypes[1:]:
            common_dtype = torch.promote_types(common_dtype, d)
        real_tensors = {i: t.to(common_dtype) for i, t in real_tensors.items()}

        new_args = self._merge_args(real_tensors, others)
        out_real = self._call_real_impl(*new_args, **kwargs)
        return torch.view_as_complex(out_real.contiguous()).to(result_dtype)

    def _call_complex_cross(self, operands, result_dtype):
        """Cross-term: split ar/ai/br/bi -> call cross_kernel -> stack -> view_as_complex.

        Uses the two operands with the smallest position indices.
        """
        sorted_keys = sorted(operands.keys())
        A, B = operands[sorted_keys[0]], operands[sorted_keys[1]]
        Ar = torch.view_as_real(A)
        Br = torch.view_as_real(B)
        ar, ai = Ar[..., 0], Ar[..., 1]
        br, bi = Br[..., 0], Br[..., 1]

        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        br, bi = br.to(common_dtype), bi.to(common_dtype)

        real, imag = self.complex_strategy.cross_kernel(ar, ai, br, bi)

        out = torch.stack((real, imag), dim=-1)
        return torch.view_as_complex(out.contiguous()).to(result_dtype)

    @staticmethod
    def use_fast_path(tensors):
        """True when all tensors have one shape and can be treated as flat 1-D buffers.

        That is the case when they are all C-contiguous, or all share the same
        strides with a non-overlapping, dense layout.
        """
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def prepare_args(self, *args, _skip_tensor_check=False, **kwargs):
        """Normalize call arguments for the generated wrapper.

        Steps:
        - Allocate outputs not supplied as ``out{i}`` kwargs, using dtypes from
          the schema's promotion methods.
        - Choose the task space.
        - Wrap every tensor argument and output in a StridedBuffer over that
          task space.

        Task space:
        - Fast path (see ``use_fast_path``): collapsed to 1-D with ``numel``
          elements and stride 1.
        - Otherwise: the broadcast of the input shapes. Supplied outputs must
          match it and must not overlap internally.

        Returns:
            ``(ndim, args, kwargs)`` ready to pass to ``instantiate(ndim)``'s wrapper.

        Raises:
            ValueError: tensor arguments fail the schema's attribute check.
            RuntimeError: an out tensor has the wrong shape or internal overlap.
        """
        # output allocation(when needed)
        # task simplification & task-rank inference & input-output reinterpretation
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if not _skip_tensor_check and schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors will follow the first
            # tensor that is not broadcasted, no attempts to simplify task, no reordering,
            # no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise Input arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )

        # Disable block pointers once the task space exceeds INT32_MAX elements.
        # Measure the task space itself rather than tensors[0]: with broadcasting,
        # the first tensor can be a small input while the task space is huge.
        # NOTE(review): this mutates self.config in place and never resets it; if
        # the config object is shared (e.g. returned by get_codegen_config()),
        # later calls of every function sharing it lose block pointers too.
        INT32_MAX = torch.iinfo(torch.int32).max
        if torch.Size(task_shape).numel() > INT32_MAX:
            self.config.prefer_block_pointer = False
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        """Unwrap StridedBuffer output(s) to Tensor(s); a tuple when there are several."""
        if self.fx.num_output_tensors() == 1:
            return tensors.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
        """Compute kernel name, wrapper name, and file path for a given ndim.

        This is the single source of truth for naming, used by both instantiate()
        and get_kernel_info() to ensure consistency.

        Returns:
            Tuple of (kernel_name, wrapper_name, file_path)
        """
        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"

        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )
        file_path = str(code_cache_dir() / file_name)

        return kernel_name, wrapper_name, file_path

    def instantiate(self, ndim):
        """Return the wrapper specialized for rank ``ndim``, generating it on first use.

        NOTE: a manually instantiated overload does not have `prepare_args` as
        preprocessing, so you have to manually allocate outputs and make sure that
        the inputs & outputs actually fit the manually instantiated overload.
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        # Use helper to compute names (single source of truth)
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which requires
        # that the source code can be found by inspect.
        # We write it into a file, since inspect cannot find the source of functions dynamically
        # created via exec string. We could help inspect find the source by hacking the linecache
        # library, but generating a module is simpler, since we generate 2 functions:
        # the kernel and the wrapper, and the wrapper calls the kernel.
        write_atomic(file_path, code.getvalue())

        # load; the module is intentionally not registered in sys.modules
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)

        # NOTE: [why not import the scalar function]
        # we do not re-import the scalar function, although the generated kernel **calls** it.
        # A function's __name__ may be changed, so importing its __name__ from the module where
        # it is defined may not give the same object; the name may also be rebound to something
        # else, so importing via name cannot guarantee that the scalar function is imported.
        # So we copy the scalar function and its __globals__ into the generated module instead.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload

        # Cache kernel info for C++ integration
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

        return overload

    def get_kernel_info(self, ndim: int) -> KernelInfo:
        """Get kernel information for a given ndim.

        This method is useful for C++ integration to get the file path and
        kernel name without duplicating the naming logic.

        If the kernel hasn't been instantiated yet, this will instantiate it first.

        Args:
            ndim: The rank of the task space

        Returns:
            KernelInfo with file_path, kernel_name, wrapper_name, and ndim
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"

        # Ensure the kernel is instantiated
        if key not in self._kernel_info_cache:
            self.instantiate(ndim)

        return self._kernel_info_cache[key]
def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Turn a scalar JIT function into a PointwiseDynamicFunction.

    Usable bare (``@pointwise_dynamic``) or with keyword arguments
    (``@pointwise_dynamic(num_inputs=..., promotion_methods=...)``).

    Args:
        f: the scalar function when used as a bare decorator; None otherwise.
        num_inputs, is_tensor, dtypes, num_outputs, promotion_methods: forwarded
            to FunctionSchema. When none of num_inputs/is_tensor/dtypes is given,
            the input count is taken from ``len(fn.arg_names)`` of each
            decorated function.
        config: code-generation config; None lets PointwiseDynamicFunction pick
            its default.

    Returns:
        A PointwiseDynamicFunction if ``f`` is given, otherwise a decorator
        producing one.
    """

    def decorator(fn):
        # Infer the input count per decorated function. A local (instead of
        # rebinding the enclosing num_inputs via nonlocal) keeps a reused
        # decorator from carrying the first function's arity over to later ones.
        inferred_num_inputs = num_inputs
        if (num_inputs is None) and (is_tensor is None) and (dtypes is None):
            inferred_num_inputs = len(fn.arg_names)
        op_desc = FunctionSchema(
            num_inputs=inferred_num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(op_desc, fn, config)

    if f is not None:
        return decorator(f)
    return decorator