Coverage for src/flag_gems/utils/pointwise_dynamic.py: 94%
1016 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import importlib
2import os
3from dataclasses import dataclass
4from enum import Enum, auto
5from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
7import torch
8import triton
9from triton.runtime.jit import JITFunction
11from flag_gems.utils.code_cache import code_cache_dir
12from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
13from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config
14from flag_gems.utils.device_info import get_device_capability
15from flag_gems.utils.shape_utils import (
16 MemOverlap,
17 all_c_contiguous,
18 all_the_same_shape,
19 all_the_same_stride,
20 broadcast_shapes,
21 broadcasted_stride,
22 check_tensor_attributes,
23 has_internal_overlapping,
24)
25from flag_gems.utils.tensor_wrapper import StridedBuffer
26from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion
29# ------------------ Operation Description ---------------------------
def _type_name(type) -> str:
    """Return a printable name for ``type``.

    The builtin scalar types ``bool``, ``int``, ``float`` and ``str`` are
    rendered by their bare ``__name__`` (e.g. ``"float"``). Anything else,
    such as a ``torch.dtype``, is rendered with ``str()``.
    """
    is_builtin_scalar = type in (bool, int, float, str)
    return type.__name__ if is_builtin_scalar else str(type)
def _check_typed_list(container, type):
    """Assert that every element of ``container`` is an instance of ``type``.

    ``type`` may be a single class or a tuple of classes, as accepted by
    ``isinstance``. Raises AssertionError when an element does not match.
    """
    assert all(isinstance(element, type) for element in container)
def _check_sized_list(container, size):
    """Assert that ``container`` holds exactly ``size`` elements."""
    actual_size = len(container)
    assert actual_size == size
def _tuple_content(strings: Sequence[str]) -> str:
    """Render ``strings`` as the inside of a tuple literal (without parentheses).

    A single element keeps a trailing comma, so that wrapping the result in
    parentheses still yields a tuple (``(x,)``). An empty sequence renders as
    the empty string.
    """
    count = len(strings)
    if count == 0:
        return ""
    if count == 1:
        # trailing comma: "(x)" would be a parenthesized expression, not a tuple
        return f"{strings[0]},"
    return ", ".join(strings)
def _cs(strings: Iterable[str]) -> str:
    """Join ``strings`` into a comma-separated list (``"a, b, c"``, no trailing comma)."""
    separator = ", "
    return separator.join(strings)
def _broadcast_vec(i, ndim):
    """Return an indexing suffix that keeps axis ``i`` and inserts ``None`` elsewhere.

    For example ``_broadcast_vec(1, 3) == "[None, :, None]"``. The generated
    kernels append it to per-axis vectors (offsets, masks) so they broadcast
    against each other across ``ndim`` axes.
    """
    axes = ", ".join(":" if axis == i else "None" for axis in range(ndim))
    return "[" + axes + "]"
class FunctionSchema:
    """Signature description of a pointwise scalar function.

    Records, for each input argument, whether it is a tensor and its optional
    declared Python type, plus one type-promotion rule per output. Inputs are
    numbered separately within the tensor group and within the non-tensor
    group (see ``input_index``).
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    # for each input, its position among inputs of the same kind (tensor / non-tensor)
    _input_id: List[int]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema.

        Args:
            num_inputs: number of input arguments. When omitted it is inferred
                from ``is_tensor`` or, failing that, from ``dtypes``.
            is_tensor: per-input flag, True for tensor arguments. Defaults to
                all True when ``num_inputs`` is given. When only ``dtypes`` is
                given, an input is treated as a tensor iff its dtype is None.
            dtypes: per-input declared Python type, or None when unspecified.
            num_outputs: number of outputs; defaults to ``len(promotion_methods)``.
            promotion_methods: one item per output, each ``(*arg_indices, method)``
                where ``method`` is a key of ``ELEMENTWISE_TYPE_PROMOTION_KIND``.

        Raises:
            ValueError: if ``promotion_methods`` is missing, or none of
                ``num_inputs``, ``is_tensor`` and ``dtypes`` is given.
            AssertionError: on inconsistent sizes or element types.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # is_tensor is necessarily None on this path: an input without a
            # declared type is taken to be a tensor.
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Map each ``(*arg_indices, method_name)`` to ``(*arg_indices, promotion_kind)``."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Number of input arguments (outputs not included)."""
        return self._num_inputs

    def num_outputs(self):
        """Number of outputs."""
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Whether input ``arg_id`` is a tensor argument."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared Python type of input ``arg_id``, or None when unspecified."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion rule of output ``i``: ``(*arg_indices, kind)``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        """Number of tensor inputs."""
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        """Number of output tensors (same as ``num_outputs``)."""
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        """Number of non-tensor inputs."""
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a readable signature, e.g. ``Pointwise: StridedBuffer, scalar -> StridedBuffer``.

        Tensor inputs render as ``StridedBuffer``, untyped non-tensor inputs as
        ``scalar`` and typed ones by their type name. When ``outputs_in_arg``
        is true, outputs are also appended to the argument list, each tagged
        with its own index: ``StridedBuffer(a0!)``, ``StridedBuffer(a1!)``, ...
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # one alias tag per output, indexed by the output position
            output_types = [
                f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())
            ]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()
        sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
        return sig

    def _compute_input_id(self):
        """For each input, return its index among inputs of the same kind."""
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Index of input ``idx`` within its group (used to name it ``in<k>`` / ``val<k>``)."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)
class KernelGenerator:
    """Emit the Triton source of a pointwise kernel for one scalar function and rank.

    The scalar function is copied into the generated module and called on the
    loaded values. The generated kernel is named ``name`` and is specialized
    for a task space of rank ``rank``. Two tilings are supported:
    ``codegen_nd_tile`` uses nd tiles, with or without block pointers
    depending on ``config.prefer_block_pointer``; ``codegen_1d_tile`` uses
    flat 1d tiles. For rank > 0, both emit a one-tile-per-program path and a
    loop over ``tiles_per_cta`` tiles, selected by the ``one_tile_per_cta``
    constexpr.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    # ------------------------------------------------------------- naming
    def input_name(self, i):
        """Name of input ``i`` in generated code: ``in<k>`` for tensors, ``val<k>`` otherwise."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Name of output ``i`` in generated code."""
        return f"out{i}"

    # ----------------------------------------------------------- preamble
    def gen_import_function(self, code: IndentedBuffer):
        """Copy the scalar function's source into the generated module under ``@triton.jit``."""
        code.writeline("@triton.jit")
        code.writemultiline(self.fn.src)
        code.newline()

    def gen_decorators(self, code):
        """Write the kernel decorators.

        Non-tensor arguments (``val<k>``) are listed in ``do_not_specialize``:
        they are passed into the inlined scalar function, so specializing the
        kernel on their values is not worthwhile.
        """
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    # --------------------------------------------------------- signatures
    def _gen_tensor_and_scalar_params(self, code):
        """Write input pointers / non-tensor params in schema order, then output pointers."""
        schema = self.fx
        for i in range(schema.num_inputs()):
            arg_name = self.input_name(i)
            if schema.is_tensor(i):
                code.writeline(f"{arg_name}_ptr: tl.tensor, # of tl.pointer_type")
            elif schema.input_type(i) is not None:
                code.writeline(f"{arg_name}: {_type_name(schema.input_type(i))},")
            else:
                code.writeline(f"{arg_name},")

        for i in range(schema.num_outputs()):
            code.writeline(
                f"{self.output_name(i)}_ptr: tl.tensor, # of tl.pointer_type"
            )

    def _gen_stride_params(self, code, with_block_pointer):
        """Write per-tensor stride params (inputs first, then outputs).

        With block pointers, each tensor's strides are followed by its
        constexpr stride order.
        """
        ndim = self.ndim
        schema = self.fx
        groups = (
            ("in", schema.num_input_tensors()),
            ("out", schema.num_output_tensors()),
        )
        for prefix, count in groups:
            for i in range(count):
                stride_args = _cs(f"{prefix}{i}_stride{j}: int" for j in range(ndim))
                code.writeline(f"{stride_args}, # strides for {prefix}{i}")
                if with_block_pointer:
                    stride_order_args = _cs(
                        f"{prefix}{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                    )
                    code.writeline(
                        f"{stride_order_args}, # stride order for {prefix}{i}"
                    )

    def _gen_task_params(self, code, tile_size_params):
        """Write task-space, task-count and tiling params; ``tile_size_params`` is the tile-size part."""
        # task space, used to reconstruct multi index
        task_space_args = _cs(f"s{i}" for i in range(self.ndim))
        code.writeline(f"{task_space_args}, # task_space")
        # number of tasks, used to compute mask
        code.writeline("num_tasks,")
        # tile size & tiles_per_cta, gsl style
        code.writeline("tiles_per_cta: int,")
        code.writeline(f"{tile_size_params},")
        code.writeline("one_tile_per_cta: tl.constexpr,")

    def gen_signature(self, code, with_block_pointer=False):
        """Write the ``def`` of the nd-tile kernel (one ``tile_size<i>`` per axis).

        Rank-0 kernels only take the tensor pointers and non-tensor values.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_tensor_and_scalar_params(code)
            ndim = self.ndim
            if ndim > 0:
                self._gen_stride_params(code, with_block_pointer)
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                self._gen_task_params(code, tile_sizes)
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        """Write the ``def`` of the 1d-tile kernel (a single ``tile_size``, no stride orders)."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_tensor_and_scalar_params(code)
            if self.ndim > 0:
                self._gen_stride_params(code, with_block_pointer=False)
                self._gen_task_params(code, "tile_size: tl.constexpr")
        code.writeline("):")

    # ------------------------------------------------------- body pieces
    def gen_num_tiles(self, code):
        """Write ``num_tiles<i>``, the number of tiles along each axis."""
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def _gen_compute(self, code):
        """Write the call to the inlined scalar function: ``out0, ... = fn(in0, val0, ...)``."""
        schema = self.fx
        inputs_to_scalar_fn = _cs(
            self.input_name(i) for i in range(schema.num_inputs())
        )
        outputs_to_scalar_fn = _cs(
            self.output_name(i) for i in range(schema.num_output_tensors())
        )
        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

    def _gen_tile_id_reconstruction(self, code):
        """Split the flat ``tile_id`` into per-axis ``tile_id<i>`` (C order: last axis fastest)."""
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

    def _gen_gsl_loop(self, code, gen_tile_body):
        """Wrap ``gen_tile_body`` in a loop over ``tiles_per_cta`` tiles, strided by the program count."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            gen_tile_body(code)

    # -------------------------------------------------------- rank 0
    def gen_body_for_0d(self, code):
        """Write the rank-0 body: load each input scalar, compute, store each output."""
        schema = self.fx

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # --------------------------------------- nd tile, with block pointers
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Write the body for one nd tile using ``tl.make_block_ptr`` loads/stores."""
        ndim = self.ndim
        schema = self.fx

        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        self._gen_tile_id_reconstruction(code)

        code.writeline("# tile offsets")
        for i in range(ndim):
            # Or else: AssertionError: Block pointers only support 32 bit
            # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing
            # for 64 bit support
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(tuple(f"in{i}_stride_order{j}" for j in range(ndim)))
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        """Write the multi-tile loop around the block-pointer tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_with_bptr)

    # ------------------------------------ nd tile, without block pointers
    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Write the body for one nd tile using broadcast offsets and an nd mask."""
        ndim = self.ndim
        schema = self.fx

        self._gen_tile_id_reconstruction(code)

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks: per-axis bounds checks, broadcast and and-ed together
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        mask_combine = " & ".join(
            f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim)
        )
        code.writeline(f"mask = {mask_combine}")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: tl.store returns nothing useful, so its result is not bound
        # (binding it to in{i} would clobber the loaded input names).
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        """Write the multi-tile loop around the plain-pointer tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_without_bptr)

    # ----------------------------------------------------- nd tile kernel
    def _gen_tile_dispatch(self, code, gen_one_tile, gen_gsl, with_num_tiles):
        """Write the rank > 0 body: one-tile path vs. multi-tile loop, chosen by ``one_tile_per_cta``."""
        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            if with_num_tiles:
                self.gen_num_tiles(code)
            # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                gen_one_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                gen_gsl(code)
            code.newline()

    def _codegen_nd_tile(self, code, with_block_pointer):
        """Write the complete nd-tile kernel (function import, decorators, signature, body)."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=with_block_pointer)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        if with_block_pointer:
            gen_one_tile = self.gen_body_one_tile_per_cta_with_bptr
            gen_gsl = self.gen_body_gsl_with_bptr
        else:
            gen_one_tile = self.gen_body_one_tile_per_cta_without_bptr
            gen_gsl = self.gen_body_gsl_without_bptr
        self._gen_tile_dispatch(code, gen_one_tile, gen_gsl, with_num_tiles=True)
        return code

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        return self._codegen_nd_tile(code, with_block_pointer=True)

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support without block pointer."""
        return self._codegen_nd_tile(code, with_block_pointer=False)

    def codegen_nd_tile(self, code):
        """Generate the nd-tile kernel, using block pointers if ``config.prefer_block_pointer``."""
        if self.config.prefer_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    # ----------------------------------------------------- 1d tile kernel
    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Write the body for one flat tile: decompose flat task ids into per-axis indices."""
        ndim = self.ndim
        schema = self.fx

        # flat task ids of this tile and their bounds mask
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction (C order, last axis fastest)
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: the (void) result of tl.store is intentionally not bound
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        """Write the multi-tile loop around the flat-tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_1d_tile)

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        self._gen_tile_dispatch(
            code,
            self.gen_body_one_tile_per_cta_1d_tile,
            self.gen_body_gsl_1d_tile,
            with_num_tiles=False,
        )
        return code
742class WrapperGenerator:
743 def __init__(
744 self,
745 function_schema: FunctionSchema,
746 jit_fn_name: str,
747 ndim: int,
748 name: str,
749 config: CodeGenConfig,
750 ):
751 self.fx = function_schema
752 self.jit_fn_name = jit_fn_name
753 self.ndim = ndim
754 self.name = name
755 self.config = config
757 def input_name(self, i):
758 is_tensor = self.fx.is_tensor(i)
759 name = "in" if is_tensor else "val"
760 index = self.fx.input_index(i)
761 return f"{name}{index}"
763 def output_name(self, i):
764 return f"out{i}"
766 def gen_signature(self, code: IndentedBuffer):
767 # TODO: check if triton handles constexprs transitively
768 schema = self.fx
769 params: List[str] = []
770 for i in range(schema.num_inputs()):
771 if schema.is_tensor(i):
772 params.append(
773 f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]"
774 )
775 else:
776 arg_type = schema.input_type(i)
777 if arg_type is not None:
778 params.append(f"{self.input_name(i)}: {_type_name(arg_type)}")
779 else:
780 params.append(f"{self.input_name(i)}")
781 # NOTE: [the wrapper's signature and rules for passing parameters ]
782 # input params: must be passed by position, since the names are renamed to
783 # in0, in1, val0, val1, ..., So passing these parameters by keyword is wierd
784 # So we enforce that these parameters must be passed by position.
785 # maybe we can fix it later
786 # output parameters: must be passed by keyword, since the scalar function
787 # do not have output parameters(think of it as some scalar function, output
788 # parameter does not make sense in this case.) They are added to allow destination
789 # passing style API. Output parameter is convenient in cases where we want
790 # to use some pre-defiend outputs(especially when they are some views of other
791 # tensors). We emphasize that these parameters are added in-addition, we enforce
792 # that they be passed by keyword. After all, out0, out1, ... does not mismatch
793 # names form the scalar function, since it does not have output parameters.
794 params.append("/")
795 params.append("*") # output params must be passed by keyword
797 for i in range(schema.num_output_tensors()):
798 params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
799 code.writeline(f"def {self.name}({_cs(params)}): ")
801 def gen_docstring(self, code: IndentedBuffer):
802 schema = self.fx
803 doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
804 code.writeline(doc)
806 def gen_same_shape_check(self, code: IndentedBuffer):
807 schema: FunctionSchema = self.fx
808 params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [
809 f"out{i}.shape" for i in range(schema.num_output_tensors())
810 ]
811 check: str = " == ".join(params)
812 code.writeline(f"assert {check}, 'operand shapes mismatch'")
814 def gen_task_partition(self, code: IndentedBuffer):
815 code.writeline("# task partitioning")
816 ndim = self.ndim
817 if ndim == 0:
818 code.writeline("num_warps = 1")
819 code.writeline("num_ctas = 1")
820 else:
821 code.writeline("shape = out0.shape")
822 code.writeline("num_tasks = out0.numel()")
823 code.writeline("if num_tasks == 0:")
824 with code.indent():
825 self.gen_return(code)
826 max_tile_size = self.config.max_tile_size
827 # Check if all input and output dtypes are complex
828 all_complex = True
829 for i in range(self.fx.num_inputs()):
830 if self.fx.is_tensor(i):
831 input_dtype = self.fx.input_type(i)
832 if input_dtype is not None and not (
833 input_dtype == torch.complex64
834 or input_dtype == torch.complex128
835 ):
836 all_complex = False
837 break
838 if all_complex:
839 # If all inputs are complex, set max_tile_size to half
840 max_tile_size = max_tile_size // 2
841 major, _ = get_device_capability()
842 if self.name.find("fill_scalar") != -1 and major >= 9:
843 code.writeline("tile_sizes = tuple([64])")
844 else:
845 code.writeline(
846 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
847 )
848 code.writeline("tile_size = math.prod(tile_sizes)")
849 code.writeline(
850 "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
851 )
853 if self.name.find("fill_scalar") != -1 and major >= 9:
854 code.writeline("num_ctas = num_tiles")
855 else:
856 max_grid_size0 = self.config.max_grid_size[0]
857 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")
859 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
860 code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
861 code.writeline("one_tile_per_cta = tiles_per_cta==1")
862 code.writeline("grid = (num_ctas, 1, 1)")
864 def gen_task_partition_1d(self, code: IndentedBuffer):
865 code.writeline("# task partitioning")
866 ndim = self.ndim
867 if ndim == 0:
868 code.writeline("num_warps = 1")
869 code.writeline("num_ctas = 1")
870 else:
871 code.writeline("shape = out0.shape")
872 code.writeline("num_tasks = out0.numel()")
873 code.writeline("if num_tasks == 0:")
874 with code.indent():
875 self.gen_return(code)
876 max_tile_size = self.config.max_tile_size
877 # Check if all input and output dtypes are complex
878 all_complex = True
879 for i in range(self.fx.num_inputs()):
880 if self.fx.is_tensor(i):
881 input_dtype = self.fx.input_type(i)
882 if input_dtype is not None and not (
883 input_dtype == torch.complex64
884 or input_dtype == torch.complex128
885 ):
886 all_complex = False
887 break
888 if all_complex:
889 max_tile_size = max_tile_size // 2
890 major, _ = get_device_capability()
891 if self.name.find("fill_scalar") != -1 and major >= 9:
892 code.writeline("tile_sizes = tuple([1024])")
893 else:
894 code.writeline(
895 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
896 )
898 code.writeline("tile_size = tile_sizes[0]")
899 code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")
901 if self.name.find("fill_scalar") != -1 and major >= 9:
902 code.writeline("num_ctas = num_tiles")
903 else:
904 max_grid_size0 = self.config.max_grid_size[0]
905 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")
907 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
908 code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
909 code.writeline("one_tile_per_cta = tiles_per_cta==1")
910 code.writeline("grid = (num_ctas, 1, 1)")
912 def gen_kernel_launch(
913 self,
914 code: IndentedBuffer,
915 ):
916 schema = self.fx
917 ndim = self.ndim
919 with_block_pointer = self.config.prefer_block_pointer
921 code.writeline("# kernel launch")
922 for i in range(schema.num_input_tensors()):
923 code.writeline(f"in{i}_strides = in{i}.stride()")
924 if not with_block_pointer:
925 continue
926 if ndim >= 2: # where ndim is 1, we don't need to compute stride order
927 code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)")
928 else:
929 code.writeline(f"in{i}_stride_order = (0,)")
930 for i in range(schema.num_output_tensors()):
931 code.writeline(f"out{i}_strides = out{i}.stride()")
932 if not with_block_pointer:
933 continue
934 if ndim >= 2:
935 code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)")
936 else:
937 code.writeline(f"out{i}_stride_order = (0,)")
939 code.writeline("with torch_device_fn.device(in0.device.index):")
940 with code.indent():
941 code.writeline(f"{self.jit_fn_name}[grid](")
942 with code.indent():
943 params = []
944 # NOTE: WRAP
945 for i in range(schema.num_inputs()):
946 if schema.is_tensor(i):
947 params.append(f"{self.input_name(i)}")
948 else:
949 params.append(self.input_name(i))
950 for i in range(schema.num_output_tensors()):
951 params.append(f"{self.output_name(i)}")
953 code.writeline(f"{_cs(params)},")
955 if ndim > 0:
956 for i in range(schema.num_input_tensors()):
957 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
958 code.writeline(f"{s}, # stride for in{i}")
959 if not with_block_pointer:
960 continue
961 order = ", ".join(
962 f"in{i}_stride_order[{j}]" for j in range(ndim)
963 )
964 code.writeline(f"{order}, # stride order for in{i}")
966 for i in range(schema.num_output_tensors()):
967 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
968 code.writeline(f"{s}, # stride for out{i}")
969 if not with_block_pointer:
970 continue
971 order = ", ".join(
972 f"out{i}_stride_order[{j}]" for j in range(ndim)
973 )
974 code.writeline(f"{order}, # stride orderfor out{i}")
976 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
977 code.writeline(f"{shape_args}, # task indexing space")
978 code.writeline("num_tasks, # num tasks")
979 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
980 for i in range(ndim):
981 code.writeline(f"tile_size{i}=tile_sizes[{i}],")
982 code.writeline("one_tile_per_cta=one_tile_per_cta,")
983 code.writeline("num_warps=num_warps,")
984 code.writeline(")")
def gen_kernel_launch_1d(
    self,
    code: IndentedBuffer,
):
    """Emit the launch of the 1-d tiled kernel into the generated wrapper.

    The emitted code first reads the strides of every input and output
    tensor, then, under ``torch_device_fn.device(in0.device.index)``, calls
    ``<jit_fn_name>[grid](...)`` with these positional arguments:

    1. all inputs in schema order (tensors and non-tensors alike),
    2. all outputs,
    3. when ``ndim > 0``: the strides of each input, then of each output,
       followed by the task shape,

    and finally ``num_tasks`` plus the launch keywords ``tiles_per_cta``,
    ``tile_size``, ``one_tile_per_cta`` and ``num_warps``.
    """
    fx = self.fx
    rank = self.ndim
    n_in_tensors = fx.num_input_tensors()
    n_out_tensors = fx.num_output_tensors()

    code.writeline("# kernel launch")
    for k in range(n_in_tensors):
        code.writeline(f"in{k}_strides = in{k}.stride()")
    for k in range(n_out_tensors):
        code.writeline(f"out{k}_strides = out{k}.stride()")

    code.writeline("with torch_device_fn.device(in0.device.index):")
    with code.indent():
        code.writeline(f"{self.jit_fn_name}[grid](")
        with code.indent():
            # NOTE: WRAP
            call_args = [self.input_name(k) for k in range(fx.num_inputs())]
            call_args.extend(self.output_name(k) for k in range(n_out_tensors))
            code.writeline(", ".join(call_args) + ",")

            if rank > 0:
                # strides: every input first, then every output
                for prefix, count in (("in", n_in_tensors), ("out", n_out_tensors)):
                    for k in range(count):
                        stride_list = ", ".join(
                            f"{prefix}{k}_strides[{d}]" for d in range(rank)
                        )
                        code.writeline(f"{stride_list}, # stride for {prefix}{k}")

                task_space = ", ".join(f"shape[{d}]" for d in range(rank))
                code.writeline(f"{task_space}, # task indexing space")

            for trailing in (
                "num_tasks, # num tasks",
                "tiles_per_cta=tiles_per_cta, # tiles_per_cta",
                "tile_size=tile_size,",
                "one_tile_per_cta=one_tile_per_cta,",
                "num_warps=num_warps,",
            ):
                code.writeline(trailing)
        code.writeline(")")
def gen_return(self, code: IndentedBuffer):
    """Emit the wrapper's ``return`` statement listing every output buffer."""
    output_names = [f"out{idx}" for idx in range(self.fx.num_output_tensors())]
    code.writeline("return " + ", ".join(output_names))
def codegen_nd_tile(self, code):
    """Emit the complete N-d tiled wrapper into ``code`` and return ``code``.

    The signature is written at the current indentation level. The docstring,
    the same-shape check, the task partition, the kernel launch and the return
    statement follow one level deeper. A blank line is written afterwards.
    """
    self.gen_signature(code)

    body_emitters = (
        self.gen_docstring,
        self.gen_same_shape_check,
        self.gen_task_partition,
        self.gen_kernel_launch,
        self.gen_return,
    )
    with code.indent():
        for emit in body_emitters:
            emit(code)
    code.newline()
    return code
def codegen_1d_tile(self, code):
    """Emit the complete 1-d tiled wrapper into ``code`` and return ``code``.

    This mirrors the N-d variant, but partitions the task and launches the
    kernel through the ``*_1d`` emitters.
    """
    self.gen_signature(code)

    body_emitters = (
        self.gen_docstring,
        self.gen_same_shape_check,
        self.gen_task_partition_1d,
        self.gen_kernel_launch_1d,
        self.gen_return,
    )
    with code.indent():
        for emit in body_emitters:
            emit(code)
    code.newline()
    return code
class ModuleGenerator:
    """Assemble the source of one generated module: imports, wrapper and kernel.

    The wrapper is produced by a ``WrapperGenerator`` and the Triton kernel by a
    ``KernelGenerator``. ``config.prefer_1d_tile`` selects between their 1-d and
    N-d tiling code paths.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.scalar_fn = scalar_fn
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def _collect_jit_deps(scalar_fn):
        """Collect extra imports and local @triton.jit helper sources.

        Parses, using ``ast``, the source of the module where ``scalar_fn`` is
        defined. If that source cannot be located, read or parsed, the result
        is ``({}, [])``.

        Returns:
            A tuple ``(extra_imports, local_sources)``:

            - ``extra_imports``: dict mapping an absolute module path to a set
              of import items (``"name"`` or ``"name as alias"``). The items
              come from the module's top-level ``from ... import ...``
              statements. Relative imports are resolved against the module's
              package, because the generated file lives in the code cache and
              is not part of that package. ``__future__`` imports, unresolvable
              relative imports and the module paths listed in
              ``ALREADY_IMPORTED`` are skipped.
            - ``local_sources``: source text, decorators included, of the
              module's top-level functions whose decorators mention ``jit`` but
              not ``pointwise_dynamic``.
        """
        import ast
        import importlib.util
        import inspect

        py_fn = getattr(scalar_fn, "fn", scalar_fn)
        module_name = getattr(py_fn, "__module__", None)
        if not module_name:
            return {}, []
        try:
            mod = importlib.import_module(module_name)
            source_file = inspect.getfile(mod)
        except (ImportError, TypeError, OSError):
            return {}, []
        try:
            # Python source files are UTF-8 by default (PEP 3120). Do not depend
            # on the locale's preferred encoding.
            with open(source_file, encoding="utf-8") as f:
                module_source = f.read()
            source_lines = module_source.splitlines(keepends=True)
            tree = ast.parse(module_source)
        except (OSError, SyntaxError, ValueError):
            # ValueError also covers UnicodeDecodeError and null bytes in source
            return {}, []

        # Module paths whose from-imports are not re-emitted (mostly ones the
        # fixed header written by generate_imports already imports from).
        ALREADY_IMPORTED = {
            "math",
            "typing",
            "torch",
            "triton",
            "triton.language",
            "flag_gems.utils.shape_utils",
            "flag_gems.utils.tensor_wrapper",
            "flag_gems.utils.libentry",
            "flag_gems.utils",
            "flag_gems.runtime",
            "flag_gems.utils.pointwise_dynamic",
        }
        package = getattr(mod, "__package__", None) or ""
        extra_imports = {}
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.ImportFrom):
                continue
            module_path = node.module
            if node.level:
                # `from ..x import y` has module "x" and level 2. Emitting it
                # verbatim as `from x import y` would import the wrong module.
                try:
                    module_path = importlib.util.resolve_name(
                        "." * node.level + (node.module or ""), package
                    )
                except (ImportError, ValueError):
                    continue
            if not module_path or module_path == "__future__":
                # __future__ imports are only legal at the very top of a file
                continue
            if module_path in ALREADY_IMPORTED:
                continue
            # keep `as` aliases, since the helper sources refer to the alias
            names = {
                f"{alias.name} as {alias.asname}" if alias.asname else alias.name
                for alias in node.names
            }
            extra_imports.setdefault(module_path, set()).update(names)

        # Collect local @triton.jit functions (without @pointwise_dynamic)
        def _has_decorator(func_node, name):
            """Return True if the source text of any decorator contains ``name``."""
            return any(
                name in "".join(source_lines[dec.lineno - 1 : dec.end_lineno])
                for dec in func_node.decorator_list
            )

        def _extract_source(func_node):
            """Return the source of ``func_node``, starting at its first decorator."""
            first = (
                func_node.decorator_list[0] if func_node.decorator_list else func_node
            )
            return "".join(source_lines[first.lineno - 1 : func_node.end_lineno])

        local_sources = []
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.FunctionDef):
                continue
            # a substring match on "jit" also covers "triton.jit"
            if not _has_decorator(node, "jit"):
                continue
            if _has_decorator(node, "pointwise_dynamic"):
                continue
            local_sources.append(_extract_source(node))

        return extra_imports, local_sources

    def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer:
        """Write the generated module's header into ``code`` and return ``code``.

        The header consists of a fixed import block, then the extra imports
        and local @triton.jit helper functions collected from the scalar
        function's defining module.
        """
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry")
        code.writeline("from flag_gems.utils import triton_lang_extension as tle")
        code.writeline("from flag_gems.runtime import torch_device_fn")

        # Generate extra imports and local JIT deps of the scalar function
        jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn)
        for module_path, names in sorted(jit_dep_imports.items()):
            sorted_names = ", ".join(sorted(names))
            code.writeline(f"from {module_path} import {sorted_names}")

        code.newline()
        code.newline()

        # Emit local @triton.jit helper functions
        for source in local_jit_sources:
            for line in source.splitlines():
                code.writeline(line)
            code.newline()

        return code

    def codegen(self, code: IndentedBuffer):
        """Write imports, wrapper and kernel into ``code`` and return it.

        ``config.prefer_1d_tile`` selects 1-d tiling; otherwise N-d tiling is
        used.
        """
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code
@dataclass
class KernelInfo:
    """Location and names of one generated kernel module, for C++ integration.

    The record tells native code which generated file to load and which
    functions inside it to use, so it does not have to re-derive the naming
    scheme itself.
    """

    file_path: str  # generated .py module in the code cache directory
    kernel_name: str  # name of the Triton kernel defined in that module
    wrapper_name: str  # name of the Python wrapper that launches the kernel
    ndim: int  # rank of the task space this kernel is specialized for
class ComplexMode(Enum):
    """Strategy a pointwise function uses for complex operands."""

    # no complex handling registered
    NONE = auto()
    # real and imaginary parts are independent (add/sub):
    # view_as_real -> same real kernel -> view_as_complex
    ELEMENTWISE = auto()
    # parts interact (mul/div): split into ar/ai/br/bi -> dedicated cross kernel
    CROSS = auto()
@dataclass
class ComplexStrategy:
    """Complex-number handling configured for one pointwise function.

    The default instance (mode ``NONE``, everything else unset) means no
    complex support has been registered.
    """

    # how complex operands are processed
    mode: ComplexMode = ComplexMode.NONE
    # callable producing (real, imag) from (ar, ai, br, bi); used by CROSS mode
    cross_kernel: object = None
    # when True (together with fallback_target), scalar operands become tensors
    tensorize_scalars: bool = False
    # callable that receives the call after scalar operands were tensorized
    fallback_target: object = None
1230_REAL_TO_COMPLEX = {
1231 torch.float16: torch.complex32,
1232 torch.bfloat16: torch.complex32,
1233 torch.float32: torch.complex64,
1234 torch.float64: torch.complex128,
1235}
class PointwiseDynamicFunction:
    """Utility to generate functions for a general pointwise operation.

    It generates a wrapper and a JITFunction that are specialized for the rank
    of the task space (the broadcast shape of all input tensors). The generated
    code is written to the cache directory (defaults to ~/.flaggems) and loaded
    from there.

    Once ``register_complex`` has been called, calls whose operands include
    complex tensors or Python complex scalars are routed through a complex
    dispatch path.
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads
        self.overloads: Mapping[str, Callable] = {}
        # cached kernel info for C++ integration
        self._kernel_info_cache: Mapping[str, KernelInfo] = {}

        # complex dispatch support
        self.complex_strategy = ComplexStrategy()
        self._operand_indices = self._infer_operand_indices()

    # -------------------- operand index inference --------------------

    def _infer_operand_indices(self):
        """Return the argument positions used by any output-dtype promotion rule.

        Each entry of ``schema._promotion_methods`` is ``(*arg_indices, method)``.
        The union of the ``arg_indices`` forms the "operands" considered by the
        complex dispatch. This runs once, at init.
        """
        indices = set()
        for pm in self.fx._promotion_methods:
            for idx in pm[:-1]:
                indices.add(idx)
        return frozenset(indices)

    # -------------------- register_complex --------------------

    def register_complex(
        self, mode, cross_kernel=None, tensorize_scalars=False, fallback_target=None
    ):
        """Register complex number support for this kernel.

        Args:
            mode: ComplexMode.ELEMENTWISE (add/sub) or ComplexMode.CROSS (mul/div).
            cross_kernel: A PointwiseDynamicFunction for cross-term ops (mul/div).
            tensorize_scalars: If True, scalar operands are converted to tensors
                before delegating to fallback_target.
            fallback_target: A PointwiseDynamicFunction (tensor-tensor version)
                to delegate to after tensorizing scalar operands.

        Returns:
            ``self``, so the call can be chained.
        """
        self.complex_strategy = ComplexStrategy(
            mode=mode,
            cross_kernel=cross_kernel,
            tensorize_scalars=tensorize_scalars,
            fallback_target=fallback_target,
        )
        return self

    # -------------------- call entry --------------------

    def __call__(self, *args, **kwargs):
        """Run the operation, using the complex path when it applies."""
        if self._should_use_complex_path(args):
            return self._call_complex_dispatch(*args, **kwargs)
        return self._call_real_impl(*args, **kwargs)

    def _call_real_impl(self, *args, **kwargs):
        """Single entry point for real kernel invocation."""
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        return self._unwrap(out)

    # -------------------- complex helpers --------------------

    @staticmethod
    def _is_complex_arg(a):
        """Return True for complex tensors and Python ``complex`` scalars."""
        return (isinstance(a, torch.Tensor) and a.is_complex()) or isinstance(
            a, complex
        )

    def _should_use_complex_path(self, args):
        """Return True if complex support is registered and any operand is complex."""
        if self.complex_strategy.mode == ComplexMode.NONE:
            return False
        return any(
            self._is_complex_arg(args[i])
            for i in self._operand_indices
            if i < len(args)
        )

    def _split_args(self, args):
        """Split args into operands and others by original position index."""
        operands = {}
        others = {}
        for i, a in enumerate(args):
            if i in self._operand_indices:
                operands[i] = a
            else:
                others[i] = a
        return operands, others

    def _merge_args(self, operands, others):
        """Rebuild args tuple from operands and others by original position index.

        The keys of both dicts together must cover ``0 .. n-1`` exactly, which
        is the case for dicts produced by ``_split_args``.
        """
        total = len(operands) + len(others)
        merged = [None] * total
        for i, v in operands.items():
            merged[i] = v
        for i, v in others.items():
            merged[i] = v
        return tuple(merged)

    def _classify_complex_inputs(self, operands):
        """Classify operands as 'all_complex', 'mixed', or 'real'."""
        complex_count = sum(1 for v in operands.values() if self._is_complex_arg(v))
        if complex_count == len(operands):
            return "all_complex"
        elif complex_count > 0:
            return "mixed"
        return "real"

    def _infer_device(self, operands):
        """Return the device of the first tensor operand, or None if there is none."""
        for v in operands.values():
            if isinstance(v, torch.Tensor):
                return v.device
        return None

    def _infer_complex_dtype(self, operands):
        """Return the PyTorch result dtype of the operands.

        NOTE(review): ``torch.result_type`` takes exactly two arguments, so this
        only supports operations with two operands.
        """
        return torch.result_type(*operands.values())

    def _tensorize_scalar_operands(self, operands, dtype, device):
        """Convert scalar operands to tensors.

        ``complex`` scalars become ``dtype`` tensors, ``float`` scalars become
        float32 tensors and ``int``/``bool`` scalars become int64 tensors, all on
        ``device``. Tensors and other values are passed through unchanged.
        """
        result = {}
        for i, v in operands.items():
            if not isinstance(v, torch.Tensor):
                if isinstance(v, complex):
                    result[i] = torch.tensor(v, dtype=dtype, device=device)
                elif isinstance(v, float):
                    result[i] = torch.tensor(v, dtype=torch.float32, device=device)
                elif isinstance(v, (int, bool)):
                    result[i] = torch.tensor(v, dtype=torch.int64, device=device)
                else:
                    result[i] = v
            else:
                result[i] = v
        return result

    def _to_complex_tensor(self, a, target_dtype, device):
        """Convert a scalar or real tensor to a complex tensor.

        - Complex tensors are returned unchanged.
        - Real floating tensors use ``_REAL_TO_COMPLEX`` (complex64 for
          unlisted dtypes). The exception is bfloat16, which goes through
          float32 and becomes complex64.
        - Non-floating tensors go through float32 and become complex64.
        - ``complex``/``int``/``float`` scalars become 0-dim ``target_dtype``
          tensors on ``device``.
        - Anything else is returned as-is.
        """
        if isinstance(a, torch.Tensor):
            if a.is_complex():
                return a
            if a.dtype == torch.bfloat16:
                # torch.complex only accepts half/float/double parts, and
                # complex32 (half parts) cannot hold bfloat16's exponent range.
                a = a.to(torch.float32)
                cdtype = torch.complex64
            elif a.is_floating_point():
                cdtype = _REAL_TO_COMPLEX.get(a.dtype, torch.complex64)
            else:
                a = a.to(torch.float32)
                cdtype = torch.complex64
            return torch.complex(a, torch.zeros_like(a)).to(cdtype)
        elif isinstance(a, complex):
            return torch.tensor(a, dtype=target_dtype, device=device)
        elif isinstance(a, (int, float)):
            return torch.tensor(complex(a, 0), dtype=target_dtype, device=device)
        return a

    # -------------------- complex dispatch --------------------

    def _call_complex_dispatch(self, *args, **kwargs):
        """Unified complex dispatch entry point.

        If ``tensorize_scalars`` is set and a ``fallback_target`` exists,
        scalar operands are tensorized and the call is delegated. Otherwise
        every operand is converted to a complex tensor and all operands are
        broadcast together. The call then goes to the cross kernel (CROSS mode,
        all operands complex), the elementwise real-view path, or the plain
        real path.
        """
        strategy = self.complex_strategy
        operands, others = self._split_args(args)

        device = self._infer_device(operands)
        result_dtype = self._infer_complex_dtype(operands)

        # tensorize scalar operands and delegate to fallback_target
        if strategy.tensorize_scalars and strategy.fallback_target is not None:
            operands = self._tensorize_scalar_operands(operands, result_dtype, device)
            new_args = self._merge_args(operands, others)
            return strategy.fallback_target(*new_args, **kwargs)

        # convert all operands to complex tensors
        for i in list(operands.keys()):
            operands[i] = self._to_complex_tensor(operands[i], result_dtype, device)

        # broadcast complex tensor operands
        complex_tensors = [operands[i] for i in sorted(operands.keys())]
        complex_tensors = torch.broadcast_tensors(*complex_tensors)
        for idx, key in enumerate(sorted(operands.keys())):
            operands[key] = complex_tensors[idx]

        classification = self._classify_complex_inputs(operands)

        if strategy.mode == ComplexMode.CROSS and classification == "all_complex":
            return self._call_complex_cross(operands, result_dtype)
        elif classification in ("all_complex", "mixed"):
            return self._call_complex_elementwise(
                operands, others, result_dtype, kwargs
            )
        else:
            new_args = self._merge_args(operands, others)
            return self._call_real_impl(*new_args, **kwargs)

    def _call_complex_elementwise(self, operands, others, result_dtype, kwargs):
        """Elementwise: view_as_real -> call real kernel -> view_as_complex."""
        real_tensors = {i: torch.view_as_real(t) for i, t in operands.items()}

        # promote to common real dtype
        dtypes = [t.dtype for t in real_tensors.values()]
        common_dtype = dtypes[0]
        for d in dtypes[1:]:
            common_dtype = torch.promote_types(common_dtype, d)
        real_tensors = {i: t.to(common_dtype) for i, t in real_tensors.items()}

        new_args = self._merge_args(real_tensors, others)
        out_real = self._call_real_impl(*new_args, **kwargs)
        # view_as_complex needs a trailing size-2 dim with stride 1
        return torch.view_as_complex(out_real.contiguous()).to(result_dtype)

    def _call_complex_cross(self, operands, result_dtype):
        """Cross-term: split ar/ai/br/bi -> call cross_kernel -> stack -> view_as_complex.

        Uses the first two operands by position. ``cross_kernel`` must return a
        ``(real, imag)`` pair.
        """
        sorted_keys = sorted(operands.keys())
        A, B = operands[sorted_keys[0]], operands[sorted_keys[1]]
        Ar = torch.view_as_real(A)
        Br = torch.view_as_real(B)
        ar, ai = Ar[..., 0], Ar[..., 1]
        br, bi = Br[..., 0], Br[..., 1]

        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        br, bi = br.to(common_dtype), bi.to(common_dtype)

        real, imag = self.complex_strategy.cross_kernel(ar, ai, br, bi)

        out = torch.stack((real, imag), dim=-1)
        return torch.view_as_complex(out.contiguous()).to(result_dtype)

    @staticmethod
    def use_fast_path(tensors):
        """Return True when all tensors can be processed as one flat 1-d task.

        That requires the same shape and either C-contiguity, or identical
        strides on a non-overlapping, dense layout.
        """
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def _disable_block_pointer_for_large_task(self, task_shape):
        """Stop preferring block pointers once a task exceeds INT32_MAX elements.

        The switch is sticky for this function: later overloads are keyed and
        generated with ``prefer_block_pointer=False``. It is applied to a
        private copy of the config, so a config object passed in by the caller,
        or shared with other pointwise functions, is left untouched.
        """
        if not self.config.prefer_block_pointer:
            return
        if torch.Size(task_shape).numel() <= torch.iinfo(torch.int32).max:
            return
        import copy

        self.config = copy.copy(self.config)
        self.config.prefer_block_pointer = False

    def prepare_args(self, *args, _skip_tensor_check=False, **kwargs):
        """Allocate outputs and wrap arguments for the overload of the inferred rank.

        Steps:

        1. Allocate outputs that were not passed as ``out{i}`` keyword arguments.
        2. Simplify the task and infer the task rank.
        3. Reinterpret every tensor as a ``StridedBuffer`` over the task shape.

        Returns:
            ``(ndim, args, kwargs)``, ready for the overload returned by
            ``instantiate(ndim)``.

        Raises:
            ValueError: if the positional arguments do not match the schema's
                tensor/non-tensor layout.
            RuntimeError: if a supplied output has the wrong shape or internal
                overlapping.
        """
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if not _skip_tensor_check and schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # A simple strategy: all the undefined tensors follow the first
            # tensor that is not broadcast. No attempt is made to simplify the
            # task, no reordering, no dimension collapsing.
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise output arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )

        # Decide on the whole task size, not on the first tensor: a broadcast
        # task can be much larger than any single input.
        self._disable_block_pointer_for_large_task(task_shape)
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        """Unwrap StridedBuffer results to Tensors (a tuple when there are several outputs)."""
        if self.fx.num_output_tensors() == 1:
            item = tensors
            return item.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
        """Compute kernel name, wrapper name, and file path for a given ndim.

        This is the single source of truth for naming, used by both instantiate()
        and get_kernel_info() to ensure consistency.

        Returns:
            Tuple of (kernel_name, wrapper_name, file_path)
        """
        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"

        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )
        file_path = str(code_cache_dir() / file_name)

        return kernel_name, wrapper_name, file_path

    def instantiate(self, ndim):
        """Generate, write, load and cache the wrapper for rank ``ndim``.

        Overloads are cached per ``(ndim, prefer_block_pointer)``.
        """
        # NOTE: a manually instantiated overload does not get `prepare_args` as
        # preprocessing, so you have to allocate outputs manually and make sure
        # that the inputs & outputs actually fit the manually instantiated overload
        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        # Use helper to compute names (single source of truth)
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which
        # requires that the source code can be found by inspect.
        # We write it into a file, since inspect cannot find the source of
        # functions dynamically created via an exec'd string. We could help
        # inspect find the source by hacking the linecache library, but
        # generating a module is simpler, since we generate 2 functions: the
        # kernel and the wrapper, and the wrapper calls the kernel.
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules
        # sys.modules["_add_module"] = m

        # NOTE: [why not import the scalar function]
        # We do not re-import the scalar function, although the generated kernel
        # **calls** it. A function's __name__ may be changed, so importing its
        # __name__ from the module where it is defined may not give the same
        # object. The name may also be rebound to something else, so importing
        # by name cannot guarantee that the scalar function is imported.
        # So we copy the scalar function and its __globals__ into the generated module.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload

        # Cache kernel info for C++ integration
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

        return overload

    def get_kernel_info(self, ndim: int) -> KernelInfo:
        """Get kernel information for a given ndim.

        This method is useful for C++ integration to get the file path and
        kernel name without duplicating the naming logic.

        If the kernel hasn't been instantiated yet, this will instantiate it first.

        Args:
            ndim: The rank of the task space

        Returns:
            KernelInfo with file_path, kernel_name, wrapper_name, and ndim
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"

        # Ensure the kernel is instantiated
        if key not in self._kernel_info_cache:
            self.instantiate(ndim)

        return self._kernel_info_cache[key]
1718def pointwise_dynamic(
1719 f: Optional[JITFunction] = None,
1720 *,
1721 num_inputs: Optional[int] = None,
1722 is_tensor: Optional[List[bool]] = None,
1723 dtypes: Optional[List[Optional[type]]] = None,
1724 num_outputs: Optional[int] = None,
1725 promotion_methods: Optional[Tuple[int, ...]] = None,
1726 config: Optional[CodeGenConfig] = None,
1727):
1728 def decorator(fn):
1729 nonlocal num_inputs
1730 if (num_inputs is None) and (is_tensor is None) and (dtypes is None):
1731 num_inputs = len(fn.arg_names)
1732 op_desc = FunctionSchema(
1733 num_inputs=num_inputs,
1734 is_tensor=is_tensor,
1735 dtypes=dtypes,
1736 num_outputs=num_outputs,
1737 promotion_methods=promotion_methods,
1738 )
1739 return PointwiseDynamicFunction(op_desc, fn, config)
1741 if f is not None:
1742 return decorator(f)
1743 return decorator