Coverage for src/flag_gems/utils/pointwise_dynamic.py: 94%
1019 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import importlib
2import os
3from dataclasses import dataclass
4from enum import Enum, auto
5from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
7import torch
8import triton
9from triton.runtime.jit import JITFunction
11from flag_gems.utils.code_cache import code_cache_dir
12from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
13from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config
14from flag_gems.utils.device_info import get_device_capability
15from flag_gems.utils.shape_utils import (
16 MemOverlap,
17 all_c_contiguous,
18 all_the_same_shape,
19 all_the_same_stride,
20 broadcast_shapes,
21 broadcasted_stride,
22 check_tensor_attributes,
23 has_internal_overlapping,
24)
25from flag_gems.utils.tensor_wrapper import StridedBuffer
26from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion
29# ------------------ Operation Description ---------------------------
30def _type_name(type) -> str:
31 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object"
32 if type in (bool, int, float, str):
33 return type.__name__
34 if isinstance(type, torch.dtype):
35 return str(type)
36 return str(type)
39def _check_typed_list(container, type):
40 for item in container:
41 assert isinstance(item, type)
44def _check_sized_list(container, size):
45 assert len(container) == size
48def _tuple_content(strings: Sequence[str]) -> str:
49 # comma separated list
50 if len(strings) == 0:
51 return ""
52 if len(strings) == 1:
53 return f"{strings[0]},"
54 else:
55 return ", ".join(strings)
58def _cs(strings: Iterable[str]) -> str:
59 return ", ".join(strings)
def _broadcast_vec(i, ndim):
    """Return an indexing suffix that lifts a 1-D vector to ``ndim`` dimensions.

    The vector keeps axis ``i`` (``:``) and receives a new axis (``None``) at
    every other position, e.g. ``_broadcast_vec(1, 3) == "[None, :, None]"``.
    """
    parts = []
    for axis in range(ndim):
        parts.append(":" if axis == i else "None")
    return "[" + ", ".join(parts) + "]"
class FunctionSchema:
    """Calling convention of a pointwise scalar function.

    Records, for every input argument, whether it is a tensor and (optionally)
    its Python type, together with the number of outputs and one type
    promotion rule per output. The code generators query it to name and order
    kernel and wrapper parameters.

    Any one of ``num_inputs``, ``is_tensor`` or ``dtypes`` fixes the number of
    inputs; missing information is filled in as follows:

    * ``is_tensor`` defaults to all ``True`` when ``num_inputs`` is given, or to
      ``dtype is None`` per argument when only ``dtypes`` is given;
    * ``dtypes`` defaults to all ``None``;
    * ``num_outputs`` defaults to ``len(promotion_methods)``.

    Each promotion method is a tuple ``(*arg_indices, kind_name)``; it is stored
    with ``kind_name`` replaced by ``ELEMENTWISE_TYPE_PROMOTION_KIND[kind_name]``.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    # position of each input among the inputs of the same kind (tensor / non-tensor)
    _input_id: List[int]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build and validate a schema.

        Raises:
            ValueError: if ``promotion_methods`` is missing, or if none of
                ``num_inputs``, ``is_tensor`` and ``dtypes`` is given.
            AssertionError: if ``is_tensor``/``dtypes`` contain wrongly typed
                items or have inconsistent lengths, or if the schema ends up
                with no inputs or no outputs (checks use ``assert``).
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # Only dtypes were given (is_tensor is necessarily None here):
            # arguments without a declared type are treated as tensors.
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing kind name of each ``(*arg_indices, kind)`` item with its enum value."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Number of input arguments (outputs not included)."""
        return self._num_inputs

    def num_outputs(self):
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared type of input ``arg_id``, or ``None`` when not declared."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion method ``(*arg_indices, kind)`` for output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Human-readable signature, e.g. ``"Pointwise: StridedBuffer, scalar -> StridedBuffer"``.

        With ``outputs_in_arg=True`` the outputs are also appended to the
        argument list, each tagged with its own marker ``(a<i>!)``.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # one distinct marker per output (previously every output got "a1")
            output_types = [
                f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())
            ]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()
        sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
        return sig

    def _compute_input_id(self):
        """Map each input to its index among inputs of the same kind (tensor vs. non-tensor)."""
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)
class KernelGenerator:
    """Generate the source of a Triton kernel for a pointwise scalar function.

    The scalar function is re-emitted under ``@triton.jit`` and called from a
    generated kernel that loads every tensor input, applies the function, and
    stores every output. Two tiling layouts are available:

    * nd tiles (``codegen_nd_tile``): one ``tile_size<i>`` per axis; operands
      are addressed either with block pointers or with broadcast offset vectors,
      depending on ``config.prefer_block_pointer``;
    * 1d tiles (``codegen_1d_tile``): a flat range of task ids that is turned
      back into a multi-index with the task-space sizes ``s<i>``.

    Both layouts emit a ``one_tile_per_cta`` branch (one tile per program) and a
    grid-stride-loop branch (``tile_id = pid + j * num_ctas``). Rank-0 kernels
    get a dedicated body without tiling. Every ``gen_*``/``codegen_*`` method
    appends lines to the ``IndentedBuffer`` it receives.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    # ------------------------------------------------------------------
    # prologue: scalar function, decorators, signature
    # ------------------------------------------------------------------
    def gen_import_function(self, code: IndentedBuffer):
        """Re-emit the scalar function's source under ``@triton.jit`` so the kernel can call it."""
        code.writeline("@triton.jit")
        code.writemultiline(self.fn.src)
        code.newline()

    def gen_decorators(self, code):
        """Emit ``@libentry()`` and ``@triton.jit`` (with ``do_not_specialize`` for scalar args)."""
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # Non-tensor args are only forwarded to the inlined scalar function,
            # so specializing the kernel on their values is unlikely to pay off.
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def input_name(self, i):
        """Kernel-side name of input ``i``: ``in<k>`` for tensors, ``val<k>`` otherwise."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        return f"out{i}"

    def _gen_pointer_and_value_params(self, code):
        """Emit input pointers / scalar values (schema order) followed by output pointers."""
        schema = self.fx
        input_tensor_index = 0
        non_tensor_index = 0
        for i in range(schema.num_inputs()):
            if schema.is_tensor(i):
                code.writeline(
                    f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                input_tensor_index += 1
            else:
                arg_type = schema.input_type(i)
                if arg_type is not None:
                    code.writeline(f"val{non_tensor_index}: {_type_name(arg_type)},")
                else:
                    code.writeline(f"val{non_tensor_index},")
                non_tensor_index += 1

        for i in range(schema.num_outputs()):
            code.writeline(f"out{i}_ptr: tl.tensor, # of tl.pointer_type")

    def _gen_stride_params(self, code, stride_annotation, with_block_pointer):
        """Emit one stride per axis (and optionally one stride order per axis) for every operand."""
        schema = self.fx
        ndim = self.ndim
        operands = [f"in{i}" for i in range(schema.num_input_tensors())]
        operands += [f"out{i}" for i in range(schema.num_output_tensors())]
        for operand in operands:
            stride_args = _cs(
                f"{operand}_stride{j}: {stride_annotation}" for j in range(ndim)
            )
            code.writeline(f"{stride_args}, # strides for {operand}")
            if with_block_pointer:
                stride_order_args = _cs(
                    f"{operand}_stride_order{j}: tl.constexpr" for j in range(ndim)
                )
                code.writeline(f"{stride_order_args}, # stride order for {operand}")

    def _gen_signature(self, code, stride_annotation, with_block_pointer, tile_size_params):
        """Emit ``def <name>(...):`` shared by the nd-tile and 1d-tile kernels.

        Parameter order: input pointers / scalar values, output pointers and,
        only when ``ndim > 0``: per-operand strides (plus stride orders when
        ``with_block_pointer``), task-space sizes ``s0..``, ``num_tasks``,
        ``tiles_per_cta``, ``tile_size_params`` and ``one_tile_per_cta``.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_pointer_and_value_params(code)

            ndim = self.ndim
            if ndim > 0:
                self._gen_stride_params(code, stride_annotation, with_block_pointer)
                # task space, used to reconstruct the multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")
                # number of tasks, used to compute the mask
                code.writeline("num_tasks,")
                # tile size(s) & tiles_per_cta for the grid-stride loop
                code.writeline("tiles_per_cta: int,")
                code.writeline(f"{tile_size_params},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature(self, code, with_block_pointer=False):
        """Signature of the nd-tile kernel: strides and tile sizes are ``tl.constexpr``."""
        tile_size_params = _cs(f"tile_size{i}: tl.constexpr" for i in range(self.ndim))
        self._gen_signature(code, "tl.constexpr", with_block_pointer, tile_size_params)

    def gen_signature_1d_tile(self, code):
        """Signature of the 1d-tile kernel: ``int`` strides and a single ``tile_size``."""
        self._gen_signature(code, "int", False, "tile_size: tl.constexpr")

    # ------------------------------------------------------------------
    # body pieces
    # ------------------------------------------------------------------
    def gen_num_tiles(self, code):
        """Emit the number of tiles along each axis (``num_tiles<i>``)."""
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def _gen_compute(self, code):
        """Emit ``out0, ... = <scalar_fn>(in0, val0, ...)`` with inputs in schema order."""
        schema = self.fx
        inputs_to_scalar_fn = _cs(self.input_name(i) for i in range(schema.num_inputs()))
        outputs_to_scalar_fn = _cs(
            self.output_name(i) for i in range(schema.num_output_tensors())
        )
        code.writeline("# compute")
        code.writeline(f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})")
        code.newline()

    def _gen_tile_multi_index(self, code):
        """Split the flat ``tile_id`` into ``tile_id<i>`` per axis (C order: last axis fastest)."""
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

    def _gen_grid_stride_loop(self, code, gen_tile_body):
        """Emit a loop over ``tiles_per_cta`` tiles, ``num_ctas`` apart, running ``gen_tile_body`` for each."""
        code.writeline("num_ctas = ext.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            gen_tile_body(code)

    def gen_body_for_0d(self, code):
        """Body of a rank-0 kernel: scalar loads, compute, scalar stores."""
        schema = self.fx
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # nd tile, 1d grid, block pointers
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Process the tile ``tile_id`` using one block pointer per operand."""
        import flag_gems  # local import, as before; only ``vendor_name`` is read

        ndim = self.ndim
        schema = self.fx

        # block pointer geometry shared by all operands
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        self._gen_tile_multi_index(code)

        code.writeline("# tile offsets")
        for i in range(ndim):
            # Block pointers only support 32-bit `offsets/block_shape`; without the
            # cast Triton raises "AssertionError: Block pointers only support 32 bit".
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        code.writeline("# loads")
        # NOTE(review): on "spacemit" input block pointers use a fixed descending
        # order while outputs still use the runtime stride order -- confirm intent.
        fixed_input_order = flag_gems.vendor_name == "spacemit"
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            if fixed_input_order:
                order = _tuple_content(tuple(f"{ndim - j - 1}" for j in range(ndim)))
            else:
                order = _tuple_content(
                    tuple(f"in{i}_stride_order{j}" for j in range(ndim))
                )
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_with_bptr)

    # nd tile, 1d grid, plain pointer arithmetic
    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Process the tile ``tile_id`` with broadcast per-axis offsets and an nd mask."""
        ndim = self.ndim
        schema = self.fx

        self._gen_tile_multi_index(code)

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks: one per axis, then combined after broadcasting to the tile shape
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        mask_combine = " & ".join(
            f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim)
        )
        code.writeline(f"mask = {mask_combine}")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: tl.store is a statement; its result is not assigned to anything
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)")

    def gen_body_gsl_without_bptr(self, code):
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_without_bptr)

    # 1d tile, 1d grid
    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Process ``tile_size`` consecutive task ids, mapped back to a multi-index via ``s<i>``."""
        ndim = self.ndim
        schema = self.fx

        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction (C order: last axis fastest)
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: tl.store is a statement; its result is not assigned to anything
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            code.writeline(f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)")

    def gen_body_gsl_1d_tile(self, code):
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_1d_tile)

    # ------------------------------------------------------------------
    # whole kernels
    # ------------------------------------------------------------------
    def _gen_tile_dispatch(self, code, gen_one_tile, gen_grid_stride_loop):
        """Emit the ``one_tile_per_cta`` constexpr branch between the two kernel styles."""
        # monolithic kernel: one tile per CTA, may require a very large grid
        code.writeline("if one_tile_per_cta: # monolitic kernel style")
        with code.indent():
            code.writeline("tile_id = pid")
            gen_one_tile(code)
        # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
        code.writeline("else: # grid-stride-loop style kernel")
        with code.indent():
            gen_grid_stride_loop(code)

    def _codegen_nd_tile(self, code, with_block_pointer):
        """Generate the full nd-tile kernel (1d grid, grid-stride-loop capable)."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=with_block_pointer)

        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        if with_block_pointer:
            gen_one_tile = self.gen_body_one_tile_per_cta_with_bptr
            gen_gsl = self.gen_body_gsl_with_bptr
        else:
            gen_one_tile = self.gen_body_one_tile_per_cta_without_bptr
            gen_gsl = self.gen_body_gsl_without_bptr

        with code.indent():
            code.writeline("pid = ext.program_id(0)")
            self.gen_num_tiles(code)
            self._gen_tile_dispatch(code, gen_one_tile, gen_gsl)
        code.newline()
        return code

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        return self._codegen_nd_tile(code, with_block_pointer=True)

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with plain pointer arithmetic."""
        return self._codegen_nd_tile(code, with_block_pointer=False)

    def codegen_nd_tile(self, code):
        """Generate the nd-tile kernel, choosing block pointers per ``config.prefer_block_pointer``."""
        if self.config.prefer_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)

        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = ext.program_id(0)")
            self._gen_tile_dispatch(
                code,
                self.gen_body_one_tile_per_cta_1d_tile,
                self.gen_body_gsl_1d_tile,
            )
        code.newline()
        return code
753class WrapperGenerator:
754 def __init__(
755 self,
756 function_schema: FunctionSchema,
757 jit_fn_name: str,
758 ndim: int,
759 name: str,
760 config: CodeGenConfig,
761 ):
762 self.fx = function_schema
763 self.jit_fn_name = jit_fn_name
764 self.ndim = ndim
765 self.name = name
766 self.config = config
768 def input_name(self, i):
769 is_tensor = self.fx.is_tensor(i)
770 name = "in" if is_tensor else "val"
771 index = self.fx.input_index(i)
772 return f"{name}{index}"
774 def output_name(self, i):
775 return f"out{i}"
777 def gen_signature(self, code: IndentedBuffer):
778 # TODO: check if triton handles constexprs transitively
779 schema = self.fx
780 params: List[str] = []
781 for i in range(schema.num_inputs()):
782 if schema.is_tensor(i):
783 params.append(
784 f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]"
785 )
786 else:
787 arg_type = schema.input_type(i)
788 if arg_type is not None:
789 params.append(f"{self.input_name(i)}: {_type_name(arg_type)}")
790 else:
791 params.append(f"{self.input_name(i)}")
792 # NOTE: [the wrapper's signature and rules for passing parameters ]
793 # input params: must be passed by position, since the names are renamed to
794 # in0, in1, val0, val1, ..., So passing these parameters by keyword is wierd
795 # So we enforce that these parameters must be passed by position.
796 # maybe we can fix it later
797 # output parameters: must be passed by keyword, since the scalar function
798 # do not have output parameters(think of it as some scalar function, output
799 # parameter does not make sense in this case.) They are added to allow destination
800 # passing style API. Output parameter is convenient in cases where we want
801 # to use some pre-defiend outputs(especially when they are some views of other
802 # tensors). We emphasize that these parameters are added in-addition, we enforce
803 # that they be passed by keyword. After all, out0, out1, ... does not mismatch
804 # names form the scalar function, since it does not have output parameters.
805 params.append("/")
806 params.append("*") # output params must be passed by keyword
808 for i in range(schema.num_output_tensors()):
809 params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
810 code.writeline(f"def {self.name}({_cs(params)}): ")
812 def gen_docstring(self, code: IndentedBuffer):
813 schema = self.fx
814 doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
815 code.writeline(doc)
817 def gen_same_shape_check(self, code: IndentedBuffer):
818 schema: FunctionSchema = self.fx
819 params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [
820 f"out{i}.shape" for i in range(schema.num_output_tensors())
821 ]
822 check: str = " == ".join(params)
823 code.writeline(f"assert {check}, 'operand shapes mismatch'")
825 def gen_task_partition(self, code: IndentedBuffer):
826 code.writeline("# task partitioning")
827 ndim = self.ndim
828 if ndim == 0:
829 code.writeline("num_warps = 1")
830 code.writeline("num_ctas = 1")
831 else:
832 code.writeline("shape = out0.shape")
833 code.writeline("num_tasks = out0.numel()")
834 code.writeline("if num_tasks == 0:")
835 with code.indent():
836 self.gen_return(code)
837 max_tile_size = self.config.max_tile_size
838 # Check if all input and output dtypes are complex
839 all_complex = True
840 for i in range(self.fx.num_inputs()):
841 if self.fx.is_tensor(i):
842 input_dtype = self.fx.input_type(i)
843 if input_dtype is not None and not (
844 input_dtype == torch.complex64
845 or input_dtype == torch.complex128
846 ):
847 all_complex = False
848 break
849 if all_complex:
850 # If all inputs are complex, set max_tile_size to half
851 max_tile_size = max_tile_size // 2
852 major, _ = get_device_capability()
853 if self.name.find("fill_scalar") != -1 and major >= 9:
854 code.writeline("tile_sizes = tuple([64])")
855 else:
856 code.writeline(
857 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
858 )
859 code.writeline("tile_size = math.prod(tile_sizes)")
860 code.writeline(
861 "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
862 )
864 if self.name.find("fill_scalar") != -1 and major >= 9:
865 code.writeline("num_ctas = num_tiles")
866 else:
867 max_grid_size0 = self.config.max_grid_size[0]
868 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")
870 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
871 code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
872 code.writeline("one_tile_per_cta = tiles_per_cta==1")
873 code.writeline("grid = (num_ctas, 1, 1)")
875 def gen_task_partition_1d(self, code: IndentedBuffer):
876 code.writeline("# task partitioning")
877 ndim = self.ndim
878 if ndim == 0:
879 code.writeline("num_warps = 1")
880 code.writeline("num_ctas = 1")
881 else:
882 code.writeline("shape = out0.shape")
883 code.writeline("num_tasks = out0.numel()")
884 code.writeline("if num_tasks == 0:")
885 with code.indent():
886 self.gen_return(code)
887 max_tile_size = self.config.max_tile_size
888 # Check if all input and output dtypes are complex
889 all_complex = True
890 for i in range(self.fx.num_inputs()):
891 if self.fx.is_tensor(i):
892 input_dtype = self.fx.input_type(i)
893 if input_dtype is not None and not (
894 input_dtype == torch.complex64
895 or input_dtype == torch.complex128
896 ):
897 all_complex = False
898 break
899 if all_complex:
900 max_tile_size = max_tile_size // 2
901 major, _ = get_device_capability()
902 if self.name.find("fill_scalar") != -1 and major >= 9:
903 code.writeline("tile_sizes = tuple([1024])")
904 else:
905 code.writeline(
906 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
907 )
909 code.writeline("tile_size = tile_sizes[0]")
910 code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")
912 if self.name.find("fill_scalar") != -1 and major >= 9:
913 code.writeline("num_ctas = num_tiles")
914 else:
915 max_grid_size0 = self.config.max_grid_size[0]
916 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")
918 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
919 code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
920 code.writeline("one_tile_per_cta = tiles_per_cta==1")
921 code.writeline("grid = (num_ctas, 1, 1)")
923 def gen_kernel_launch(
924 self,
925 code: IndentedBuffer,
926 ):
927 schema = self.fx
928 ndim = self.ndim
930 with_block_pointer = self.config.prefer_block_pointer
932 code.writeline("# kernel launch")
933 for i in range(schema.num_input_tensors()):
934 code.writeline(f"in{i}_strides = in{i}.stride()")
935 if not with_block_pointer:
936 continue
937 if ndim >= 2: # where ndim is 1, we don't need to compute stride order
938 code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)")
939 else:
940 code.writeline(f"in{i}_stride_order = (0,)")
941 for i in range(schema.num_output_tensors()):
942 code.writeline(f"out{i}_strides = out{i}.stride()")
943 if not with_block_pointer:
944 continue
945 if ndim >= 2:
946 code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)")
947 else:
948 code.writeline(f"out{i}_stride_order = (0,)")
950 code.writeline("with torch_device_fn.device(in0.device.index):")
951 with code.indent():
952 code.writeline(f"{self.jit_fn_name}[grid](")
953 with code.indent():
954 params = []
955 # NOTE: WRAP
956 for i in range(schema.num_inputs()):
957 if schema.is_tensor(i):
958 params.append(f"{self.input_name(i)}")
959 else:
960 params.append(self.input_name(i))
961 for i in range(schema.num_output_tensors()):
962 params.append(f"{self.output_name(i)}")
964 code.writeline(f"{_cs(params)},")
966 if ndim > 0:
967 for i in range(schema.num_input_tensors()):
968 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
969 code.writeline(f"{s}, # stride for in{i}")
970 if not with_block_pointer:
971 continue
972 order = ", ".join(
973 f"in{i}_stride_order[{j}]" for j in range(ndim)
974 )
975 code.writeline(f"{order}, # stride order for in{i}")
977 for i in range(schema.num_output_tensors()):
978 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
979 code.writeline(f"{s}, # stride for out{i}")
980 if not with_block_pointer:
981 continue
982 order = ", ".join(
983 f"out{i}_stride_order[{j}]" for j in range(ndim)
984 )
985 code.writeline(f"{order}, # stride orderfor out{i}")
987 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
988 code.writeline(f"{shape_args}, # task indexing space")
989 code.writeline("num_tasks, # num tasks")
990 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
991 for i in range(ndim):
992 code.writeline(f"tile_size{i}=tile_sizes[{i}],")
993 code.writeline("one_tile_per_cta=one_tile_per_cta,")
994 code.writeline("num_warps=num_warps,")
995 code.writeline(")")
def gen_kernel_launch_1d(
    self,
    code: IndentedBuffer,
):
    """Emit the 1d-tile kernel launch into the generated wrapper body.

    The emitted code first binds ``in{i}_strides`` / ``out{i}_strides`` for
    every input/output tensor. It then calls ``<jit_fn_name>[grid](...)``
    inside ``with torch_device_fn.device(in0.device.index):``.

    The call passes every input (tensor or not) and every output by name.
    When ``ndim > 0`` it also passes all strides, the task shape, ``num_tasks``,
    ``tiles_per_cta``, ``tile_size`` and ``one_tile_per_cta``. ``num_warps`` is
    always passed.

    NOTE(review): ``grid``, ``shape``, ``num_tasks``, ``tiles_per_cta``,
    ``tile_size``, ``one_tile_per_cta`` and ``num_warps`` must already exist in
    the generated code. Presumably the 1d task-partition step emits them.
    """
    fx = self.fx
    rank = self.ndim
    n_in = fx.num_input_tensors()
    n_out = fx.num_output_tensors()
    # (name prefix, count): inputs first, then outputs, matching launch order.
    tensor_groups = (("in", n_in), ("out", n_out))

    code.writeline("# kernel launch")
    for prefix, count in tensor_groups:
        for k in range(count):
            code.writeline(f"{prefix}{k}_strides = {prefix}{k}.stride()")

    code.writeline("with torch_device_fn.device(in0.device.index):")
    with code.indent():
        code.writeline(f"{self.jit_fn_name}[grid](")
        with code.indent():
            # NOTE: WRAP -- tensor and non-tensor inputs alike are forwarded
            # by name, followed by the output buffers.
            call_args = [self.input_name(k) for k in range(fx.num_inputs())]
            call_args.extend(self.output_name(k) for k in range(n_out))
            code.writeline(", ".join(call_args) + ",")

            if rank > 0:
                for prefix, count in tensor_groups:
                    for k in range(count):
                        strides = ", ".join(
                            f"{prefix}{k}_strides[{d}]" for d in range(rank)
                        )
                        code.writeline(f"{strides}, # stride for {prefix}{k}")
                dims = ", ".join(f"shape[{d}]" for d in range(rank))
                code.writeline(f"{dims}, # task indexing space")
                code.writeline("num_tasks, # num tasks")
                code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                code.writeline("tile_size=tile_size,")
                code.writeline("one_tile_per_cta=one_tile_per_cta,")
            code.writeline("num_warps=num_warps,")
        code.writeline(")")
def gen_return(self, code: IndentedBuffer):
    """Emit the wrapper's ``return`` statement listing ``out0, out1, ...`` in order."""
    outputs = [f"out{k}" for k in range(self.fx.num_output_tensors())]
    code.writeline("return " + ", ".join(outputs))
def codegen_nd_tile(self, code):
    """Append the complete N-d tile wrapper function to ``code`` and return ``code``.

    The signature is written at the current indentation level. The docstring,
    the same-shape check, task partitioning, the kernel launch and the return
    statement are written one level deeper. ``code.newline()`` is called last.
    """
    self.gen_signature(code)
    body_steps = (
        self.gen_docstring,
        self.gen_same_shape_check,
        self.gen_task_partition,
        self.gen_kernel_launch,
        self.gen_return,
    )
    with code.indent():
        for emit in body_steps:
            emit(code)
    code.newline()
    return code
def codegen_1d_tile(self, code):
    """Append the complete 1d-tile wrapper function to ``code`` and return ``code``.

    Same layout as the N-d variant, except that the 1d task partitioning and
    the 1d kernel launch are used. The signature is at the current level and
    the body one level deeper; ``code.newline()`` is called last.
    """
    self.gen_signature(code)
    with code.indent():
        for emit in (
            self.gen_docstring,
            self.gen_same_shape_check,
            self.gen_task_partition_1d,
            self.gen_kernel_launch_1d,
            self.gen_return,
        ):
            emit(code)
    code.newline()
    return code
class ModuleGenerator:
    """Generate the source of one self-contained pointwise module.

    The module consists of, in order:

    - a fixed import header;
    - extra imports and local ``@triton.jit`` helpers harvested from the
      module that defines the scalar function;
    - the wrapper (``WrapperGenerator``);
    - the Triton kernel (``KernelGenerator``).
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.scalar_fn = scalar_fn
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def _collect_jit_deps(scalar_fn):
        """Collect extra imports and local @triton.jit helper sources.

        Parses the source module where scalar_fn is defined using AST.

        Returns a tuple of:
          - extra_imports: dict mapping an absolute module path to a set of
            import specifiers (``"name"`` or ``"name as alias"``). These come
            from the module's top-level ``from ... import ...`` statements.
            Relative imports are resolved against the defining module's package.
          - local_sources: list of source strings, decorators included, for
            top-level functions decorated with ``jit`` (e.g. ``@triton.jit``)
            but NOT with ``pointwise_dynamic``.

        Any failure to locate, decode or parse the module source yields
        ``({}, [])``.
        """
        import ast
        import importlib.util
        import inspect
        import tokenize

        # `from` imports of these modules are not copied. Most are covered by
        # the fixed header emitted in generate_imports. __future__ imports are
        # only legal at the very top of a file, so copying them would make the
        # generated module a SyntaxError.
        skipped_modules = frozenset(
            {
                "__future__",
                "math",
                "typing",
                "torch",
                "triton",
                "triton.language",
                "flag_gems.utils.shape_utils",
                "flag_gems.utils.tensor_wrapper",
                "flag_gems.utils.libentry",
                "flag_gems.utils",
                "flag_gems.runtime",
                "flag_gems.utils.pointwise_dynamic",
            }
        )

        py_fn = getattr(scalar_fn, "fn", scalar_fn)
        module_name = getattr(py_fn, "__module__", None)
        if not module_name:
            return {}, []
        try:
            mod = importlib.import_module(module_name)
            # getsourcefile returns None for bytecode-only modules instead of a
            # .pyc path that cannot be parsed as text.
            source_file = inspect.getsourcefile(mod)
        except (ImportError, TypeError, OSError):
            return {}, []
        if not source_file:
            return {}, []
        try:
            # tokenize.open honours the PEP 263 coding cookie (UTF-8 by
            # default) rather than the locale-dependent default of open().
            # readlines() splits only on real line endings, so indices stay
            # aligned with AST line numbers (str.splitlines would also split
            # on e.g. form feeds).
            with tokenize.open(source_file) as f:
                source_lines = f.readlines()
            module_source = "".join(source_lines)
            tree = ast.parse(module_source)
        except (OSError, SyntaxError, ValueError):
            # ValueError includes UnicodeDecodeError.
            return {}, []
        package = getattr(mod, "__package__", None)

        # Collect non-standard import-from lines
        extra_imports = {}
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.ImportFrom):
                continue
            module_path = node.module
            if node.level:
                # The generated module lives in the code cache, not inside the
                # defining package, so relative imports must become absolute.
                relative_name = "." * node.level + (node.module or "")
                try:
                    module_path = importlib.util.resolve_name(relative_name, package)
                except (ImportError, ValueError):
                    continue
            if not module_path or module_path in skipped_modules:
                continue
            # Keep `as` aliases: the copied helpers refer to the alias name.
            specs = {
                alias.name
                if alias.asname is None
                else f"{alias.name} as {alias.asname}"
                for alias in node.names
            }
            extra_imports.setdefault(module_path, set()).update(specs)

        def _decorator_name(dec):
            # Last dotted component of the decorator's callable. For example,
            # "jit" for both `@triton.jit` and `@triton.jit(...)`. Returns None
            # for any other decorator expression.
            target = dec.func if isinstance(dec, ast.Call) else dec
            if isinstance(target, ast.Attribute):
                return target.attr
            if isinstance(target, ast.Name):
                return target.id
            return None

        def _extract_source(func_node):
            # Source text from the first decorator through the end of the body.
            start = func_node.lineno - 1
            if func_node.decorator_list:
                start = func_node.decorator_list[0].lineno - 1
            return "".join(source_lines[start : func_node.end_lineno])

        # Collect local @triton.jit functions (without @pointwise_dynamic).
        # Decorator names are compared exactly rather than by substring, so
        # e.g. `@torch.jit.script` or `@numba.njit` functions are not copied.
        local_sources = []
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.FunctionDef):
                continue
            decorator_names = {_decorator_name(d) for d in node.decorator_list}
            if "jit" not in decorator_names:
                continue
            if "pointwise_dynamic" in decorator_names:
                continue
            local_sources.append(_extract_source(node))

        return extra_imports, local_sources

    def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer:
        """Write the module header, extra imports and local JIT helpers to ``code``."""
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry")
        code.writeline("from flag_gems.utils import triton_lang_extension as ext")
        code.writeline("from flag_gems.runtime import torch_device_fn")

        # Generate extra imports and local JIT deps of the scalar function
        jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn)
        for module_path, names in sorted(jit_dep_imports.items()):
            sorted_names = ", ".join(sorted(names))
            code.writeline(f"from {module_path} import {sorted_names}")

        code.newline()
        code.newline()

        # Emit local @triton.jit helper functions
        for source in local_jit_sources:
            for line in source.splitlines():
                code.writeline(line)
            code.newline()

        return code

    def codegen(self, code: IndentedBuffer):
        """Write the whole module (imports, wrapper, kernel) into ``code`` and return it.

        ``config.prefer_1d_tile`` selects the 1d-tile generators; otherwise
        the N-d tile generators are used.
        """
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code
@dataclass
class KernelInfo:
    """Location and names of one generated pointwise kernel (for C++ integration).

    Built by ``PointwiseDynamicFunction.instantiate`` and returned by
    ``PointwiseDynamicFunction.get_kernel_info``.
    """

    # Path of the generated Python module on disk.
    file_path: str
    # Name of the Triton kernel function defined in that module.
    kernel_name: str
    # Name of the Python wrapper function that launches the kernel.
    wrapper_name: str
    # Rank of the task space the kernel is specialized for.
    ndim: int
class ComplexMode(Enum):
    """How a ``PointwiseDynamicFunction`` treats complex operands."""

    # Complex path disabled: every call goes through the real implementation.
    NONE = 1
    # add/sub: view_as_real -> same real kernel -> view_as_complex.
    ELEMENTWISE = 2
    # mul/div: split real/imag parts of both operands -> dedicated cross kernel.
    CROSS = 3
@dataclass
class ComplexStrategy:
    """Complex-number dispatch settings of a ``PointwiseDynamicFunction``.

    The default instance (``mode=ComplexMode.NONE``) keeps the complex path
    disabled. ``PointwiseDynamicFunction.register_complex`` builds a configured one.
    """

    # Which complex handling to apply.
    mode: ComplexMode = ComplexMode.NONE
    # CROSS mode: called as cross_kernel(ar, ai, br, bi) and returns (real, imag).
    cross_kernel: object = None
    # When True (and fallback_target is set), scalar operands are turned into
    # tensors and the call is delegated to fallback_target.
    tensorize_scalars: bool = False
    # Callable receiving the rebuilt positional args when tensorizing scalars.
    fallback_target: object = None
# Complex dtype used when a real floating tensor is promoted to complex.
# Both half-precision formats share complex32; presumably this is because torch
# provides no bfloat16-based complex dtype (confirm before relying on precision).
_REAL_TO_COMPLEX = {
    **dict.fromkeys((torch.float16, torch.bfloat16), torch.complex32),
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}
class PointwiseDynamicFunction:
    """Utility to generate function for general pointwise operation.

    It generates a wrapper & JITFunction specialized according to the rank of
    the task space (the broadcasted shape of all input tensors). The generated
    code is written out to the cache directory (defaults to ~/.flaggems).

    Calling an instance runs the real-valued path (``prepare_args`` ->
    ``instantiate`` -> generated wrapper). The exception is when complex support
    was registered with ``register_complex`` and one of the operands is complex.
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads, keyed by f"{ndim}_{prefer_block_pointer}"
        self.overloads: Mapping[str, Callable] = {}
        # cached kernel info for C++ integration (same keys as ``overloads``)
        self._kernel_info_cache: Mapping[str, KernelInfo] = {}

        # complex dispatch support; ComplexMode.NONE until register_complex()
        self.complex_strategy = ComplexStrategy()
        self._operand_indices = self._infer_operand_indices()

    # -------------------- operand index inference --------------------

    def _infer_operand_indices(self):
        """Infer operand indices from schema._promotion_methods, done once at init.

        Each promotion method is ``(*arg_indices, method)``; the operands are the
        union of all ``arg_indices``.
        """
        indices = set()
        for pm in self.fx._promotion_methods:
            indices.update(pm[:-1])
        return frozenset(indices)

    # -------------------- register_complex --------------------

    def register_complex(
        self, mode, cross_kernel=None, tensorize_scalars=False, fallback_target=None
    ):
        """Register complex number support for this kernel.

        Args:
            mode: ComplexMode.ELEMENTWISE (add/sub) or ComplexMode.CROSS (mul/div).
            cross_kernel: A PointwiseDynamicFunction for cross-term ops (mul/div).
            tensorize_scalars: If True, scalar operands are converted to tensors
                before delegating to fallback_target.
            fallback_target: A PointwiseDynamicFunction (tensor-tensor version)
                to delegate to after tensorizing scalar operands.

        Returns:
            ``self``, so the call can be chained.
        """
        self.complex_strategy = ComplexStrategy(
            mode=mode,
            cross_kernel=cross_kernel,
            tensorize_scalars=tensorize_scalars,
            fallback_target=fallback_target,
        )
        return self

    # -------------------- call entry --------------------

    def __call__(self, *args, **kwargs):
        if self._should_use_complex_path(args):
            return self._call_complex_dispatch(*args, **kwargs)
        return self._call_real_impl(*args, **kwargs)

    def _call_real_impl(self, *args, **kwargs):
        """Single entry point for real kernel invocation."""
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        return self._unwrap(out)

    # -------------------- complex helpers --------------------

    @staticmethod
    def _is_complex_arg(a):
        """True for complex tensors and Python ``complex`` scalars."""
        return (isinstance(a, torch.Tensor) and a.is_complex()) or isinstance(
            a, complex
        )

    def _should_use_complex_path(self, args):
        """True when complex support is registered and some operand is complex."""
        if self.complex_strategy.mode == ComplexMode.NONE:
            return False
        return any(
            self._is_complex_arg(args[i])
            for i in self._operand_indices
            if i < len(args)
        )

    def _split_args(self, args):
        """Split args into operands and others by original position index."""
        operands = {}
        others = {}
        for i, a in enumerate(args):
            if i in self._operand_indices:
                operands[i] = a
            else:
                others[i] = a
        return operands, others

    def _merge_args(self, operands, others):
        """Rebuild args tuple from operands and others by original position index."""
        total = len(operands) + len(others)
        merged = [None] * total
        for i, v in operands.items():
            merged[i] = v
        for i, v in others.items():
            merged[i] = v
        return tuple(merged)

    def _classify_complex_inputs(self, operands):
        """Classify operands as 'all_complex', 'mixed', or 'real'."""
        complex_count = sum(1 for v in operands.values() if self._is_complex_arg(v))
        if complex_count == len(operands):
            return "all_complex"
        elif complex_count > 0:
            return "mixed"
        return "real"

    def _infer_device(self, operands):
        """Device of the first tensor operand, or None if all are scalars."""
        for v in operands.values():
            if isinstance(v, torch.Tensor):
                return v.device
        return None

    def _infer_complex_dtype(self, operands):
        """Result dtype of the operands under torch's type-promotion rules.

        ``torch.result_type`` takes exactly two arguments, so a lone operand
        (unary op) is handled here directly instead of crashing.
        """
        values = list(operands.values())
        if len(values) == 1:
            (only,) = values
            if isinstance(only, torch.Tensor):
                return only.dtype
            # Python scalar: the dtype torch assigns it as a tensor.
            return torch.tensor(only).dtype
        return torch.result_type(*values)

    def _tensorize_scalar_operands(self, operands, dtype, device):
        """Convert scalar operands to tensors.

        complex -> ``dtype``, float -> float32, int/bool -> int64; tensors and
        other values pass through unchanged.
        """
        result = {}
        for i, v in operands.items():
            if not isinstance(v, torch.Tensor):
                if isinstance(v, complex):
                    result[i] = torch.tensor(v, dtype=dtype, device=device)
                elif isinstance(v, float):
                    result[i] = torch.tensor(v, dtype=torch.float32, device=device)
                elif isinstance(v, (int, bool)):
                    result[i] = torch.tensor(v, dtype=torch.int64, device=device)
                else:
                    result[i] = v
            else:
                result[i] = v
        return result

    def _to_complex_tensor(self, a, target_dtype, device):
        """Convert a scalar or real tensor to a complex tensor.

        Real floating tensors map through ``_REAL_TO_COMPLEX`` (complex64 if
        unlisted). Non-floating tensors go through float32 to complex64. Scalars
        become 0-dim tensors of ``target_dtype``. Anything else is returned as is.
        """
        if isinstance(a, torch.Tensor):
            if a.is_complex():
                return a
            if a.is_floating_point():
                cdtype = _REAL_TO_COMPLEX.get(a.dtype, torch.complex64)
            else:
                a = a.to(torch.float32)
                cdtype = torch.complex64
            return torch.complex(a, torch.zeros_like(a)).to(cdtype)
        elif isinstance(a, complex):
            return torch.tensor(a, dtype=target_dtype, device=device)
        elif isinstance(a, (int, float)):
            return torch.tensor(complex(a, 0), dtype=target_dtype, device=device)
        return a

    # -------------------- complex dispatch --------------------

    def _call_complex_dispatch(self, *args, **kwargs):
        """Unified complex dispatch entry point."""
        strategy = self.complex_strategy
        operands, others = self._split_args(args)

        device = self._infer_device(operands)
        result_dtype = self._infer_complex_dtype(operands)

        # tensorize scalar operands and delegate to fallback_target
        if strategy.tensorize_scalars and strategy.fallback_target is not None:
            operands = self._tensorize_scalar_operands(operands, result_dtype, device)
            new_args = self._merge_args(operands, others)
            return strategy.fallback_target(*new_args, **kwargs)

        # convert all operands to complex tensors
        for i in list(operands.keys()):
            operands[i] = self._to_complex_tensor(operands[i], result_dtype, device)

        # broadcast complex tensor operands
        complex_tensors = [operands[i] for i in sorted(operands.keys())]
        complex_tensors = torch.broadcast_tensors(*complex_tensors)
        for idx, key in enumerate(sorted(operands.keys())):
            operands[key] = complex_tensors[idx]

        classification = self._classify_complex_inputs(operands)

        if strategy.mode == ComplexMode.CROSS and classification == "all_complex":
            return self._call_complex_cross(operands, result_dtype)
        elif classification in ("all_complex", "mixed"):
            return self._call_complex_elementwise(
                operands, others, result_dtype, kwargs
            )
        else:
            new_args = self._merge_args(operands, others)
            return self._call_real_impl(*new_args, **kwargs)

    def _call_complex_elementwise(self, operands, others, result_dtype, kwargs):
        """Elementwise: view_as_real -> call real kernel -> view_as_complex."""
        real_tensors = {i: torch.view_as_real(t) for i, t in operands.items()}

        # promote to common real dtype
        dtypes = [t.dtype for t in real_tensors.values()]
        common_dtype = dtypes[0]
        for d in dtypes[1:]:
            common_dtype = torch.promote_types(common_dtype, d)
        real_tensors = {i: t.to(common_dtype) for i, t in real_tensors.items()}

        new_args = self._merge_args(real_tensors, others)
        out_real = self._call_real_impl(*new_args, **kwargs)
        return torch.view_as_complex(out_real.contiguous()).to(result_dtype)

    def _call_complex_cross(self, operands, result_dtype):
        """Cross-term: split ar/ai/br/bi -> call cross_kernel -> stack -> view_as_complex."""
        sorted_keys = sorted(operands.keys())
        A, B = operands[sorted_keys[0]], operands[sorted_keys[1]]
        Ar = torch.view_as_real(A)
        Br = torch.view_as_real(B)
        ar, ai = Ar[..., 0], Ar[..., 1]
        br, bi = Br[..., 0], Br[..., 1]

        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        br, bi = br.to(common_dtype), bi.to(common_dtype)

        real, imag = self.complex_strategy.cross_kernel(ar, ai, br, bi)

        out = torch.stack((real, imag), dim=-1)
        return torch.view_as_complex(out.contiguous()).to(result_dtype)

    @staticmethod
    def use_fast_path(tensors):
        """True when all tensors share a shape and a C-contiguous or dense common layout."""
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def prepare_args(self, *args, _skip_tensor_check=False, **kwargs):
        """Allocate missing outputs and wrap tensors for the generated wrapper.

        Performs output allocation (when needed), task simplification, task-rank
        inference and input-output reinterpretation.

        Returns:
            ``(ndim, args, kwargs)``. ``ndim`` is the rank of the task space.
            Every tensor argument and output is wrapped in a ``StridedBuffer``
            over the task shape. Outputs that were not provided are added to
            ``kwargs`` as ``out{i}``.

        Raises:
            ValueError: ``check_tensor_attributes`` rejects the positional args.
            RuntimeError: a provided output has the wrong shape or has internal
                overlapping.
        """
        import math

        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if not _skip_tensor_check and schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        INT32_MAX = torch.iinfo(torch.int32).max
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            task_numel = tensors[0].numel()
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            task_shape = (task_numel,)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors will follow the first
            # tensor that is not broadcasted, no attempts to simplify task, no
            # reordering, no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)
            task_numel = math.prod(task_shape)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise Input arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )

        # Block pointers are disabled once the task space exceeds the int32
        # range. The size must come from the (possibly broadcast) task space:
        # tensors[0] may be a smaller broadcast input.
        # NOTE(review): this flips the flag on self.config permanently; if the
        # config object is shared, other functions are affected as well.
        if task_numel > INT32_MAX:
            self.config.prefer_block_pointer = False
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        """Unwrap StridedBuffer output(s) to Tensor(s); a tuple for multiple outputs."""
        if self.fx.num_output_tensors() == 1:
            return tensors.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
        """Compute kernel name, wrapper name, and file path for a given ndim.

        This is the single source of truth for naming, used by both instantiate()
        and get_kernel_info() to ensure consistency.

        Returns:
            Tuple of (kernel_name, wrapper_name, file_path)
        """
        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"

        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )
        file_path = str(code_cache_dir() / file_name)

        return kernel_name, wrapper_name, file_path

    def instantiate(self, ndim):
        """Generate, write, load and cache the overload for a task space of rank ``ndim``.

        NOTE: manually instantiated overload does not have `prepare_args` as
        preprocessing, so you have to manually allocate output and make sure that
        the inputs & outputs actually fit the manually instantiated overload.
        """
        # `importlib.util` is a submodule; `import importlib` alone does not
        # guarantee that the attribute is bound.
        import importlib.util

        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        # Use helper to compute names (single source of truth)
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which
        # requires that the source code can be found by inspect.
        # We write it into a file, since inspect cannot find the source of
        # functions dynamically created via exec string. We could help inspect
        # by hacking the linecache library, but generating a module is simpler,
        # since we generate 2 functions: the kernel and the wrapper, and the
        # wrapper calls the kernel.
        write_atomic(file_path, code.getvalue())

        # load (the module is deliberately not registered in sys.modules)
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)

        # NOTE: [why not import the scalar function]
        # We do not re-import the scalar function, although the generated kernel
        # **calls** it. A function's __name__ may be changed, so importing its
        # __name__ from the module where it is defined may not give the same
        # object. The name may also be rebound to something else, so importing
        # by name cannot guarantee that the scalar function is imported.
        # So we copy the scalar function and its __globals__ to the generated module.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload

        # Cache kernel info for C++ integration
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

        return overload

    def get_kernel_info(self, ndim: int) -> KernelInfo:
        """Get kernel information for a given ndim.

        This method is useful for C++ integration to get the file path and
        kernel name without duplicating the naming logic.

        If the kernel hasn't been instantiated yet, this will instantiate it first.

        Args:
            ndim: The rank of the task space

        Returns:
            KernelInfo with file_path, kernel_name, wrapper_name, and ndim
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"

        # Ensure the kernel is instantiated
        if key not in self._kernel_info_cache:
            self.instantiate(ndim)

        return self._kernel_info_cache[key]
def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Turn a scalar ``JITFunction`` into a ``PointwiseDynamicFunction``.

    Usable bare (``@pointwise_dynamic``) or with keyword arguments
    (``@pointwise_dynamic(promotion_methods=...)``); in the latter form the
    returned decorator may be applied to several functions.

    Args:
        f: The scalar function when used bare; otherwise None.
        num_inputs, is_tensor, dtypes, num_outputs, promotion_methods: Forwarded
            to ``FunctionSchema``. If ``num_inputs``, ``is_tensor`` and
            ``dtypes`` are all None, ``num_inputs`` is taken from the decorated
            function's ``arg_names``.
        config: Code generation config passed to ``PointwiseDynamicFunction``.

    Returns:
        A ``PointwiseDynamicFunction`` when ``f`` is given, else a decorator.
    """

    def decorator(fn):
        # Infer the arity per decorated function. Writing it back into the
        # enclosing scope (nonlocal) would make a reused decorator apply the
        # first function's arity to every later function.
        fn_num_inputs = num_inputs
        if (fn_num_inputs is None) and (is_tensor is None) and (dtypes is None):
            fn_num_inputs = len(fn.arg_names)
        op_desc = FunctionSchema(
            num_inputs=fn_num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(op_desc, fn, config)

    if f is not None:
        return decorator(f)
    return decorator