Coverage for src/flag_gems/runtime/backend/_cambricon/utils/pointwise_dynamic.py: 0%
1111 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import importlib
2import os
3from dataclasses import dataclass
4from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
6import torch
7import triton
8from triton.runtime.jit import JITFunction
10from flag_gems.utils.code_cache import code_cache_dir
11from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
12from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config
13from flag_gems.utils.shape_utils import (
14 MemOverlap,
15 all_c_contiguous,
16 all_the_same_shape,
17 all_the_same_stride,
18 broadcast_shapes,
19 broadcasted_stride,
20 check_tensor_attributes,
21 has_internal_overlapping,
22)
23from flag_gems.utils.tensor_wrapper import StridedBuffer
24from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion
27# ------------------ Operation Description ---------------------------
28def _type_name(type) -> str:
29 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object"
30 if type in (bool, int, float, str):
31 return type.__name__
32 if isinstance(type, torch.dtype):
33 return str(type)
34 return str(type)
37def _check_typed_list(container, type):
38 for item in container:
39 assert isinstance(item, type)
42def _check_sized_list(container, size):
43 assert len(container) == size
46def _tuple_content(strings: Sequence[str]) -> str:
47 # comma separated list
48 if len(strings) == 0:
49 return ""
50 if len(strings) == 1:
51 return f"{strings[0]},"
52 else:
53 return ", ".join(strings)
56def _cs(strings: Iterable[str]) -> str:
57 return ", ".join(strings)
60def _broadcast_vec(i, ndim):
61 axes = [":" if j == i else "None" for j in range(ndim)]
62 return f"[{_cs(axes)}]"
class FunctionSchema:
    """Describe the arguments, outputs and output type promotion of a pointwise op.

    Each input is either a tensor or a non-tensor value. Each output has a
    promotion rule ``(*arg_indices, method)``; ``method`` is a key of
    ``ELEMENTWISE_TYPE_PROMOTION_KIND`` and is canonicalized on construction.
    The leading entries are presumably the indices of the arguments that take
    part in promotion (confirm against the promotion code).
    """

    _num_inputs: int  # number of arguments, outputs excluded
    _is_tensor: List[bool]  # per input: True if it is a tensor argument
    _dtypes: List[Optional[type]]  # per input: declared Python type, or None

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]  # per output: (*arg_indices, kind)

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema from any one of ``num_inputs``, ``is_tensor`` or ``dtypes``.

        The missing input descriptors are inferred:

        * ``num_inputs`` alone: every input is a tensor with no declared type.
        * ``is_tensor`` alone: no declared types.
        * ``dtypes`` alone: an input without a declared type (``None``) is a tensor.

        ``num_outputs`` defaults to ``len(promotion_methods)``.

        Raises:
            ValueError: if ``promotion_methods`` is missing, or none of
                ``num_inputs``/``is_tensor``/``dtypes`` is given.
            AssertionError: if list lengths or element types disagree, or the
                schema would have no inputs or no outputs.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # is_tensor is necessarily None here (the branch above took it), so
            # infer it: an input without a declared Python type is a tensor.
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        # Position of each argument within its own kind (tensor / non-tensor).
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing promotion-method key of each rule with the
        corresponding ``ELEMENTWISE_TYPE_PROMOTION_KIND`` value."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Number of arguments, outputs not included."""
        return self._num_inputs

    def num_outputs(self):
        """Number of outputs."""
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Whether argument ``arg_id`` is a tensor."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared Python type of argument ``arg_id``, or None."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion rule of output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        """Number of tensor arguments."""
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        """Number of output tensors (every output is a tensor)."""
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        """Number of non-tensor arguments."""
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Human-readable signature such as ``Pointwise: StridedBuffer, float -> StridedBuffer``.

        With ``outputs_in_arg`` the outputs are also listed among the
        arguments as ``StridedBuffer(a<i>!)``, one alias tag per output.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        output_types = []
        if outputs_in_arg:
            for i in range(self.num_outputs()):
                # Each output gets its own alias tag (previously always "a1!").
                output_types.append(f"StridedBuffer(a{i}!)")
            input_types.extend(output_types)
        else:
            for _ in range(self.num_outputs()):
                output_types.append("StridedBuffer")
        sig = f"Pointwise: {', '.join(input_types)} -> {', '.join(output_types)}"
        return sig

    def _compute_input_id(self):
        """Map each argument index to its index among arguments of the same kind.

        Tensor and non-tensor arguments are numbered separately, so a schema
        ``(tensor, scalar, tensor)`` maps to ``[0, 0, 1]``.
        """
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Index of argument ``idx`` among arguments of the same kind."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)
223class KernelGenerator:
224 def __init__(
225 self,
226 function_schema: FunctionSchema,
227 scalar_fn: triton.JITFunction,
228 rank: int,
229 name: str,
230 config: CodeGenConfig,
231 ):
232 self.fx = function_schema
233 self.fn = scalar_fn
234 self.ndim = rank
235 self.name = name
236 self.config = config
238 self.fn_name = scalar_fn.__name__
239 self.fn_module = scalar_fn.__module__
241 def gen_import_function(self, code: IndentedBuffer):
242 code.writeline("@triton.jit")
243 code.writemultiline(self.fn.src)
245 def gen_config_prune(self, code):
246 code.newline()
247 code.newline()
248 code.writeline("def config_prune(configs, named_args, **kwargs):")
249 with code.indent():
250 code.writeline("new_configs = []")
251 code.writeline("elem_sizes = []")
252 for i in range(self.fx.num_input_tensors()):
253 code.writeline(
254 f"elem_sizes.append(named_args['in{i}_ptr'].dtype.itemsize)"
255 )
256 for i in range(self.fx.num_output_tensors()):
257 code.writeline(
258 f"elem_sizes.append(named_args['out{i}_ptr'].dtype.itemsize)"
259 )
261 code.writeline("max_elem_size = max(elem_sizes)")
262 shape = ", ".join(f"s{i}" for i in range(self.ndim))
263 named_shape = ", ".join(f"named_args['s{i}']" for i in range(self.ndim))
264 code.writeline(f"{shape} = {named_shape}")
265 tile_sizes = ", ".join(f"tile_size{i}" for i in range(self.ndim))
266 tile_size_dict = ", ".join(
267 f"'tile_size{i}': tile_size{i}" for i in range(self.ndim)
268 )
270 code.writeline("if max_elem_size < 8:")
271 with code.indent():
272 code.writeline("max_tile_sizes = [1024, 2048, 4096, 8192, 16000]")
273 code.writeline("for max_tile_size in max_tile_sizes:")
274 with code.indent():
275 code.writeline(
276 f"({tile_sizes}, ) = heuristics_for_tile_size(max_tile_size, {shape})"
277 )
278 code.writeline(
279 f"new_configs.append(triton.Config({{{tile_size_dict}}}, num_stages=3, num_warps=1))"
280 )
281 code.writeline("else:")
282 with code.indent():
283 code.writeline("max_tile_sizes = [1024, 2048, 4096, 8000]")
284 code.writeline("for max_tile_size in max_tile_sizes:")
285 with code.indent():
286 code.writeline(
287 f"({tile_sizes}, ) = heuristics_for_tile_size(max_tile_size, {shape})"
288 )
289 code.writeline(
290 f"new_configs.append(triton.Config({{{tile_size_dict}}}, num_stages=3, num_warps=1))"
291 )
293 code.writeline("return new_configs")
295 def gen_hooks(self, code):
296 code.newline()
297 code.newline()
298 code.writeline("restore_copies = {}")
299 code.writeline(
300 "KEYSET = torch._C.DispatchKeySet(torch._C.DispatchKey.PrivateUse1)"
301 )
302 code.writeline("def pre_hook(kwargs, reset_only=False):")
303 with code.indent():
304 code.writeline("if not reset_only:")
305 with code.indent():
306 code.writeline(
307 "torch_copy_ = flag_gems.current_work_registrar.torch_ops_map['aten::copy_']"
308 )
309 code.writeline(f"for name in {self.name}.fn.restore_value:")
310 with code.indent():
311 code.writeline("restore_copy = torch.empty_like(kwargs[name])")
312 code.writeline(
313 "restore_copies[name] = torch_copy_.call_boxed(KEYSET, restore_copy, kwargs[name])"
314 )
316 code.writeline("def post_hook(kwargs, exception):")
317 with code.indent():
318 code.writeline(f"for name in {self.name}.fn.restore_value:")
319 with code.indent():
320 code.writeline(
321 "torch_copy_ = flag_gems.current_work_registrar.torch_ops_map['aten::copy_']"
322 )
323 code.writeline(
324 "kwargs[name] = torch_copy_.call_boxed(KEYSET, kwargs[name], restore_copies[name])"
325 )
327 def gen_decorators(self, code):
328 if self.ndim in [1, 2, 3, 4] and (not self.config.prefer_1d_tile):
329 self.gen_config_prune(code)
331 if self.fn_name == "_copy_kernel":
332 self.gen_hooks(code)
334 num_non_tensor_args = self.fx.num_non_tensor_args()
335 if num_non_tensor_args > 0:
336 non_tensor_arg_names = ", ".join(
337 f"'val{i}'" for i in range(num_non_tensor_args)
338 )
340 shapes = ", ".join(f"'s{i}'" for i in range(self.ndim))
341 stride_args = []
342 for i in range(self.fx.num_input_tensors()):
343 stride_args.append(_cs(f"'in{i}_stride{j}'" for j in range(self.ndim)))
344 for i in range(self.fx.num_output_tensors()):
345 stride_args.append(_cs(f"'out{i}_stride{j}'" for j in range(self.ndim)))
347 code.writeline("@libentry()")
348 if self.ndim == 1 and (not self.config.prefer_1d_tile):
349 code.writeline("@libtuner(")
350 with code.indent():
351 code.writeline("configs=[")
352 with code.indent():
353 code.writeline(
354 "triton.Config({'tile_size0': 1024}, num_stages=3, num_warps=1),"
355 )
356 code.writeline(
357 "triton.Config({'tile_size0': 2048}, num_stages=3, num_warps=1),"
358 )
359 code.writeline("],")
360 if num_non_tensor_args > 0:
361 code.writeline(
362 f"key=['num_tasks', {_cs(stride_args)}, {non_tensor_arg_names}],"
363 )
364 else:
365 code.writeline(f"key=['num_tasks', {_cs(stride_args)}],")
366 code.writeline("prune_configs_by={'early_config_prune': config_prune},")
367 output_params = [
368 f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
369 ]
370 output_elements = ", ".join(f"'{name}'" for name in output_params)
371 code.writeline(f"restore_value=[{output_elements}],")
372 if self.fn_name == "_copy_kernel":
373 code.writeline("pre_hook=pre_hook,")
374 code.writeline("post_hook=post_hook,")
375 code.writeline(")")
377 if self.ndim == 2 and (not self.config.prefer_1d_tile):
378 code.writeline("@libtuner(")
379 with code.indent():
380 code.writeline("configs=[")
381 with code.indent():
382 code.writeline(
383 "triton.Config({'tile_size0': 1, 'tile_size1': 1024}, num_stages=3, num_warps=1),"
384 )
385 code.writeline(
386 "triton.Config({'tile_size0': 1, 'tile_size1': 2048}, num_stages=3, num_warps=1),"
387 )
388 code.writeline("],")
389 if num_non_tensor_args > 0:
390 code.writeline(
391 f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}],"
392 )
393 else:
394 code.writeline(f"key=[{shapes}, {_cs(stride_args)}],")
395 code.writeline("prune_configs_by={'early_config_prune': config_prune},")
396 output_params = [
397 f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
398 ]
399 output_elements = ", ".join(f"'{name}'" for name in output_params)
400 code.writeline(f"restore_value=[{output_elements}],")
401 if self.fn_name == "_copy_kernel":
402 code.writeline("pre_hook=pre_hook,")
403 code.writeline("post_hook=post_hook,")
404 code.writeline(")")
406 if self.ndim == 3 and (not self.config.prefer_1d_tile):
407 code.writeline("@libtuner(")
408 with code.indent():
409 code.writeline("configs=[")
410 with code.indent():
411 code.writeline(
412 """
413 triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 1024}, num_stages=3, num_warps=1),
414 """
415 )
416 code.writeline(
417 """
418 triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 2048}, num_stages=3, num_warps=1),
419 """
420 )
421 code.writeline("],")
422 if num_non_tensor_args > 0:
423 code.writeline(
424 f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}],"
425 )
426 else:
427 code.writeline(f"key=[{shapes}, {_cs(stride_args)}],")
428 code.writeline("prune_configs_by={'early_config_prune': config_prune},")
429 output_params = [
430 f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
431 ]
432 output_elements = ", ".join(f"'{name}'" for name in output_params)
433 code.writeline(f"restore_value=[{output_elements}],")
434 if self.fn_name == "_copy_kernel":
435 code.writeline("pre_hook=pre_hook,")
436 code.writeline("post_hook=post_hook,")
437 code.writeline(")")
439 if self.ndim == 4 and (not self.config.prefer_1d_tile):
440 code.writeline("@libtuner(")
441 with code.indent():
442 code.writeline("configs=[")
443 with code.indent():
444 code.writeline(
445 """
446 triton.Config({'tile_size0': 1,'tile_size1': 1,'tile_size2': 1,'tile_size3': 1024},num_stages=3,num_warps=1),
447 """
448 )
449 code.writeline(
450 """
451 triton.Config({'tile_size0': 1,'tile_size1': 1,'tile_size2': 1,'tile_size3': 2048},num_stages=3,num_warps=1),
452 """
453 )
454 code.writeline("],")
455 if num_non_tensor_args > 0:
456 code.writeline(
457 f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}],"
458 )
459 else:
460 code.writeline(f"key=[{shapes}, {_cs(stride_args)}],")
461 code.writeline("prune_configs_by={'early_config_prune': config_prune},")
462 output_params = [
463 f"out{i}_ptr" for i in range(self.fx.num_output_tensors())
464 ]
465 output_elements = ", ".join(f"'{name}'" for name in output_params)
466 code.writeline(f"restore_value=[{output_elements}],")
467 if self.fn_name == "_copy_kernel":
468 code.writeline("pre_hook=pre_hook,")
469 code.writeline("post_hook=post_hook,")
470 code.writeline(")")
472 if num_non_tensor_args > 0:
473 # we do not specialize non tensor args since they are passed into the inlined function
474 # which means that their values may not deserve specialization
475 non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
476 code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
477 else:
478 code.writeline("@triton.jit")
480 def input_name(self, i):
481 is_tensor = self.fx.is_tensor(i)
482 name = "in" if is_tensor else "val"
483 index = self.fx.input_index(i)
484 return f"{name}{index}"
486 def output_name(self, i):
487 return f"out{i}"
    def gen_signature(self, code, with_block_pointer=False):
        """Emit the ``def <name>(...):`` header of the nd-tile kernel.

        Parameters are emitted in this order:

        * ``in<k>_ptr`` for each tensor input and ``val<k>`` for each non-tensor
          input, interleaved in schema order. ``val<k>`` is annotated with its
          declared Python type when the schema has one.
        * ``out<k>_ptr`` for each output.
        * For rank > 0:
          - per-tensor strides, inputs then outputs;
          - with ``with_block_pointer``, constexpr stride-order and zero-stride
            flags for each tensor;
          - the task-space shape ``s0..``;
          - ``num_tasks``;
          - ``FALLBACK_BPTR`` (constexpr), if ``config.prefer_block_pointer``.
        * For rank > 0: one ``tile_size<i>`` constexpr per dimension. For
          rank > 4, also ``tiles_per_cta`` and ``one_tile_per_cta``.

        NOTE(review): the launch site is not visible here. The argument order
        presumably must match it exactly.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    # annotate with the declared Python type when the schema has one
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor arguments
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")
                    if with_block_pointer:
                        # compile-time flags consumed only by the block-pointer body
                        stride_order_args = _cs(
                            f"in{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(f"{stride_order_args}, # stride order for in{i}")
                        zero_stride_args = _cs(
                            f"in{i}_zero_stride{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{zero_stride_args}, # zero stride flag for in{i}"
                        )

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")
                    if with_block_pointer:
                        stride_order_args = _cs(
                            f"out{i}_stride_order{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{stride_order_args}, # stride order for out{i}"
                        )
                        zero_stride_args = _cs(
                            f"out{i}_zero_stride{j}: tl.constexpr" for j in range(ndim)
                        )
                        code.writeline(
                            f"{zero_stride_args}, # zero stride flag for out{i}"
                        )

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

                # number of tasks, used to compute mask
                code.writeline("num_tasks,")
                if self.config.prefer_block_pointer:
                    # selects the non-block-pointer path when offsets exceed int32
                    code.writeline("FALLBACK_BPTR: tl.constexpr,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
                # for rank <= 4 the kernel derives tiles_per_cta itself
                if ndim > 4:
                    code.writeline("tiles_per_cta: int,")
                    code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")
    def gen_signature_1d_tile(self, code):
        """Emit the ``def <name>(...):`` header of the 1-D-tile kernel.

        Parameters are emitted in this order:

        * ``in<k>_ptr`` for each tensor input and ``val<k>`` for each non-tensor
          input, interleaved in schema order.
        * ``out<k>_ptr`` for each output.
        * For rank > 0:
          - per-tensor strides, inputs then outputs;
          - the task-space shape ``s0..``;
          - ``num_tasks``;
          - ``FALLBACK_BPTR`` (constexpr), if ``config.prefer_block_pointer``.
        * For rank > 0: ``tiles_per_cta``, a single flat ``tile_size``
          constexpr, and ``one_tile_per_cta``.

        No stride-order or zero-stride flags are emitted.

        NOTE(review): the argument order presumably must match the launch
        site, which is not visible here.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            input_tensor_index = 0
            non_tensor_index = 0
            output_tensor_index = 0

            schema = self.fx
            # signature: inputs ptrs & non tensor inputs
            for i in range(schema.num_inputs()):
                if schema.is_tensor(i):
                    code.writeline(
                        f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                    )
                    input_tensor_index += 1
                else:
                    # annotate with the declared Python type when the schema has one
                    if schema.input_type(i) is not None:
                        code.writeline(
                            f"val{non_tensor_index}: {_type_name(schema.input_type(i))},"
                        )
                    else:
                        code.writeline(f"val{non_tensor_index},")
                    non_tensor_index += 1

            # signature: output ptrs
            for i in range(schema.num_outputs()):
                code.writeline(
                    f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type"
                )
                output_tensor_index += 1

            # signature: strides, for each tensor arguments
            ndim = self.ndim
            if ndim > 0:
                # strides for inputs
                for i in range(schema.num_input_tensors()):
                    stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for in{i}")

                # strides for outputs
                for i in range(schema.num_output_tensors()):
                    stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim))
                    code.writeline(f"{stride_args}, # strides for out{i}")

                # task space, used to reconstruct multi index
                task_space_args = _cs(f"s{i}" for i in range(ndim))
                code.writeline(f"{task_space_args}, # task_space")

                # number of tasks, used to compute mask
                code.writeline("num_tasks,")

                if self.config.prefer_block_pointer:
                    code.writeline("FALLBACK_BPTR: tl.constexpr,")

            # tile size & tiles_per_cta, gsl style
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")
636 def gen_num_tiles(self, code):
637 # tile-grid size
638 ndim = self.ndim
639 for i in range(ndim):
640 if i < ndim:
641 code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")
643 def gen_body_for_0d(self, code):
644 schema = self.fx
645 inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
646 outputs_to_scalar_fn = [
647 self.output_name(i) for i in range(schema.num_output_tensors())
648 ]
649 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
650 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)
652 code.writeline("# loads")
653 for i in range(schema.num_input_tensors()):
654 code.writeline(
655 f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
656 "# workaround the bug on bool, we should use the pointer's dtype)"
657 )
658 code.newline()
660 code.writeline("# compute")
661 code.writeline(
662 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
663 )
664 code.newline()
666 code.writeline("# stores")
667 for i in range(schema.num_output_tensors()):
668 code.writeline(
669 f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
670 )
671 code.newline()
672 return code
674 # nd tile 1d grid kernel with block pointer
675 def gen_body_one_tile_per_cta_with_bptr(self, code):
676 ndim = self.ndim
677 schema = self.fx
679 # block pointer for each operand
680 shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
681 offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
682 tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))
684 # reconstruct pid multi index
685 code.writeline(
686 "# pid multi index recontruction: we use c ordering, right axes changes fastest"
687 )
688 for i in reversed(range(ndim)):
689 if i > 0:
690 code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
691 code.writeline(f"tile_id //= num_tiles{i}")
692 else:
693 code.writeline(f"tile_id{i} = tile_id")
694 code.newline()
696 # cta_offsets
697 code.writeline("# tile offsets")
699 # Because block pointer only support `tl.int32` indexing, when max offsets
700 # of ptrs exceeding 2^31, we should fallback it to noraml indexing method.
701 code.writeline("if not FALLBACK_BPTR:")
702 with code.indent():
703 for i in range(ndim):
704 # Or else: AssertionError: Block pointers only support 32 bit
705 # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing
706 # for 64 bit support
707 code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")
709 # loads
710 code.writeline("# loads")
711 for i in range(schema.num_input_tensors()):
712 strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
713 order = _tuple_content(
714 tuple(f"in{i}_stride_order{j}" for j in range(ndim))
715 )
717 for j in range(ndim):
718 code.writeline(f"if in{i}_zero_stride{j}:")
719 with code.indent():
720 code.writeline(f"in{i}_stride{j} = 0")
722 code.writeline(
723 f"in{i}_bptr = tl.make_block_ptr("
724 f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
725 )
727 code.writeline(
728 f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
729 )
730 code.newline()
732 # compute
733 # TODO: sepearate this part
734 inputs_to_scalar_fn = [
735 self.input_name(i) for i in range(schema.num_inputs())
736 ]
737 outputs_to_scalar_fn = [
738 self.output_name(i) for i in range(schema.num_output_tensors())
739 ]
740 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
741 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)
743 code.writeline("# compute")
744 code.writeline(
745 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
746 )
747 code.newline()
749 # stores
750 for i in range(schema.num_output_tensors()):
751 strides = _tuple_content(
752 tuple(f"out{i}_stride{j}" for j in range(ndim))
753 )
754 order = _tuple_content(
755 tuple(f"out{i}_stride_order{j}" for j in range(ndim))
756 )
758 for j in range(ndim):
759 code.writeline(f"if out{i}_zero_stride{j}:")
760 with code.indent():
761 code.writeline(f"out{i}_stride{j} = 0")
763 code.writeline(
764 f"out{i}_bptr = tl.make_block_ptr("
765 f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
766 )
768 code.writeline(
769 f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
770 )
771 code.writeline("else:")
772 with code.indent():
773 # offsets
774 for i in range(ndim):
775 code.writeline(
776 f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
777 )
779 # masks
780 for i in range(ndim):
781 code.writeline(f"mask{i} = offsets{i} < s{i}")
782 masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
783 mask_combine = " & ".join(masks)
784 code.writeline(f"mask = {mask_combine}")
786 # loads
787 code.writeline("# loads")
788 for i in range(schema.num_input_tensors()):
789 strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
790 order = _tuple_content(
791 tuple(f"in{i}_stride_order{j}" for j in range(ndim))
792 )
794 for j in range(ndim):
795 code.writeline(f"if in{i}_zero_stride{j}:")
796 with code.indent():
797 code.writeline(f"in{i}_stride{j} = 0")
798 offsets = tuple(
799 f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
800 for j in range(ndim)
801 )
802 offset_combine = " + ".join(offsets)
803 code.writeline(
804 f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
805 )
807 code.newline()
809 # compute
810 inputs_to_scalar_fn = [
811 self.input_name(i) for i in range(schema.num_inputs())
812 ]
813 outputs_to_scalar_fn = [
814 self.output_name(i) for i in range(schema.num_output_tensors())
815 ]
816 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
817 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)
819 code.writeline("# compute")
820 code.writeline(
821 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
822 )
823 code.newline()
825 # store
826 for i in range(schema.num_output_tensors()):
827 strides = _tuple_content(
828 tuple(f"out{i}_stride{j}" for j in range(ndim))
829 )
830 order = _tuple_content(
831 tuple(f"out{i}_stride_order{j}" for j in range(ndim))
832 )
834 for j in range(ndim):
835 code.writeline(f"if out{i}_zero_stride{j}:")
836 with code.indent():
837 code.writeline(f"out{i}_stride{j} = 0")
839 offsets = tuple(
840 f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
841 for j in range(ndim)
842 )
843 offset_combine = " + ".join(offsets)
844 code.writeline(
845 f"in{i} = tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
846 )
848 def gen_body_gsl_with_bptr(self, code):
849 code.writeline("num_ctas = ext.num_programs(0)")
850 if self.ndim <= 4:
851 num_tiles = " * ".join([f"num_tiles{i}" for i in range(self.ndim)])
852 code.writeline(
853 f"tiles_per_cta = tl.cdiv({num_tiles}, num_ctas).to(tl.int32)"
854 )
855 code.writeline("for j in range(0, tiles_per_cta):")
856 with code.indent():
857 code.writeline("tile_id = pid + j * num_ctas")
858 self.gen_body_one_tile_per_cta_with_bptr(code)
860 def gen_body_one_tile_per_cta_without_bptr(self, code):
861 ndim = self.ndim
862 schema = self.fx
864 # reconstruct pid multi index
865 code.writeline(
866 "# pid multi index recontruction: we use c ordering, right axes changes fastest"
867 )
868 for i in reversed(range(ndim)):
869 if i > 0:
870 code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
871 code.writeline(f"tile_id //= num_tiles{i}")
872 else:
873 code.writeline(f"tile_id{i} = tile_id")
874 code.newline()
876 # offsets
877 for i in range(ndim):
878 code.writeline(
879 f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
880 )
882 # masks
883 for i in range(ndim):
884 code.writeline(f"mask{i} = offsets{i} < s{i}")
885 masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
886 mask_combine = " & ".join(masks)
887 code.writeline(f"mask = {mask_combine}")
889 # loads
890 code.writeline("# loads")
891 for i in range(schema.num_input_tensors()):
892 offsets = tuple(
893 f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
894 for j in range(ndim)
895 )
896 offset_combine = " + ".join(offsets)
897 code.writeline(
898 f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
899 )
901 code.newline()
903 # compute
904 # TODO: sepearate this part
905 inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())]
906 outputs_to_scalar_fn = [
907 self.output_name(i) for i in range(schema.num_output_tensors())
908 ]
909 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn)
910 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn)
912 code.writeline("# compute")
913 code.writeline(
914 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
915 )
916 code.newline()
918 # stores
919 for i in range(schema.num_output_tensors()):
920 offsets = tuple(
921 f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
922 for j in range(ndim)
923 )
924 offset_combine = " + ".join(offsets)
925 code.writeline(
926 f"in{i} = tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
927 )
929 def gen_body_gsl_without_bptr(self, code):
930 code.writeline("num_ctas = ext.num_programs(0)")
931 if self.ndim <= 4:
932 num_tiles = " * ".join([f"num_tiles{i}" for i in range(self.ndim)])
933 code.writeline(f"tiles_per_cta = tl.cdiv({num_tiles}, num_ctas)")
934 code.writeline("for j in range(0, tiles_per_cta):")
935 with code.indent():
936 code.writeline("tile_id = pid + j * num_ctas")
937 self.gen_body_one_tile_per_cta_without_bptr(code)
939 def codegen_nd_tile_with_bptr(self, code):
940 """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
941 self.gen_import_function(code)
942 self.gen_decorators(code)
943 self.gen_signature(code, with_block_pointer=True)
945 # function body for rank-0
946 if self.ndim == 0:
947 with code.indent():
948 self.gen_body_for_0d(code)
949 return code
951 with code.indent():
952 code.writeline("pid = ext.program_id(0)")
953 self.gen_num_tiles(code)
954 # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
955 if self.ndim > 4:
956 code.writeline("if one_tile_per_cta: # monolitic kernel style")
957 with code.indent():
958 code.writeline("tile_id = pid")
959 self.gen_body_one_tile_per_cta_with_bptr(code)
960 # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
961 code.writeline("else: # grid-stride-loop style kernel")
962 with code.indent():
963 self.gen_body_gsl_with_bptr(code)
964 else:
965 self.gen_body_gsl_with_bptr(code)
966 code.newline()
967 return code
969 def codegen_nd_tile_without_bptr(self, code):
970 self.gen_import_function(code)
971 self.gen_decorators(code)
972 self.gen_signature(code, with_block_pointer=False)
974 # function body for rank-0
975 if self.ndim == 0:
976 with code.indent():
977 self.gen_body_for_0d(code)
978 return code
980 with code.indent():
981 code.writeline("pid = ext.program_id(0)")
982 self.gen_num_tiles(code)
983 # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
984 if self.ndim > 4:
985 code.writeline("if one_tile_per_cta: # monolitic kernel style")
986 with code.indent():
987 code.writeline("tile_id = pid")
988 self.gen_body_one_tile_per_cta_without_bptr(code)
989 # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
990 code.writeline("else: # grid-stride-loop style kernel")
991 with code.indent():
992 self.gen_body_gsl_without_bptr(code)
993 else:
994 self.gen_body_gsl_without_bptr(code)
995 code.newline()
996 return code
def codegen_nd_tile(self, code):
    """Emit the N-d tiled kernel, choosing the block-pointer variant when configured.

    ``self.config.prefer_block_pointer`` selects between
    ``codegen_nd_tile_with_bptr`` and ``codegen_nd_tile_without_bptr``.
    Returns the same ``code`` buffer.
    """
    emit_kernel = (
        self.codegen_nd_tile_with_bptr
        if self.config.prefer_block_pointer
        else self.codegen_nd_tile_without_bptr
    )
    emit_kernel(code)
    return code
def gen_body_one_tile_per_cta_1d_tile(self, code):
    """Emit the body that processes one 1-d tile of the flattened task space.

    The generated code relies on ``tile_id``, ``tile_size``, ``num_tasks``, the
    task shape ``s0..s{ndim-1}``, the per-operand strides ``<op>_stride<j>``
    and the operand pointers ``<op>_ptr`` being in scope. It:

    1. computes the linear task ids of the tile and a bounds mask;
    2. unravels each linear id into a multi-index ``i0..i{ndim-1}``
       (the last dimension varies fastest);
    3. loads every input tensor, calls the scalar function and stores every
       output tensor under the mask.
    """
    ndim = self.ndim
    schema = self.fx

    def strided_offset(operand):
        # e.g. "i0 * in0_stride0 + i1 * in0_stride1"
        return " + ".join(f"i{j} * {operand}_stride{j}" for j in range(ndim))

    # tile id
    code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
    code.writeline("mask = tid < num_tasks")

    # multi index reconstruction: peel dimensions from the last one
    for i in reversed(range(ndim)):
        if i > 0:
            code.writeline(f"i{i} = tid % s{i}")
            code.writeline(f"tid //= s{i}")
        else:
            code.writeline(f"i{i} = tid")
    code.newline()

    # loads
    code.writeline("# loads")
    for i in range(schema.num_input_tensors()):
        code.writeline(
            f"in{i} = tl.load(in{i}_ptr + {strided_offset(f'in{i}')}, mask=mask)"
            f".to(in{i}_ptr.type.element_ty)"
        )
    code.newline()

    # compute
    # TODO: separate this part
    inputs_to_scalar_fn = ", ".join(
        self.input_name(i) for i in range(schema.num_inputs())
    )
    outputs_to_scalar_fn = ", ".join(
        self.output_name(i) for i in range(schema.num_output_tensors())
    )
    code.writeline("# compute")
    code.writeline(f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})")
    code.newline()

    # stores: tl.store produces no value, so its result is not bound to a name.
    # (Binding it to in{i}, as before, overwrote the loaded input of that index.)
    code.writeline("# stores")
    for i in range(schema.num_output_tensors()):
        code.writeline(
            f"tl.store(out{i}_ptr + {strided_offset(f'out{i}')}, out{i}, mask=mask)"
        )
def gen_body_gsl_1d_tile(self, code):
    """Emit a grid-stride loop over 1-d tiles.

    Program ``pid`` visits tiles ``pid, pid + num_ctas, pid + 2 * num_ctas, ...``
    for ``tiles_per_cta`` iterations and emits the single-tile body for each one.
    """
    emit = code.writeline
    emit("num_ctas = ext.num_programs(0)")
    emit("for j in range(0, tiles_per_cta):")
    with code.indent():
        emit("tile_id = pid + j * num_ctas")
        self.gen_body_one_tile_per_cta_1d_tile(code)
def codegen_1d_tile(self, code):
    """Emit a kernel that uses 1-d tiles over the flattened task space and a 1-d grid.

    After the prologue (imports, decorators, 1-d-tile signature), a rank-0
    task space gets its dedicated body. Every other rank gets a run-time
    branch on ``one_tile_per_cta``: when true, the program handles the single
    tile ``tile_id = pid``; otherwise it runs the grid-stride-loop body.

    Returns the same ``code`` buffer.
    """
    for emit_prologue in (
        self.gen_import_function,
        self.gen_decorators,
        self.gen_signature_1d_tile,
    ):
        emit_prologue(code)

    if self.ndim == 0:
        with code.indent():
            self.gen_body_for_0d(code)
        return code

    with code.indent():
        code.writeline("pid = ext.program_id(0)")
        # monolithic variant: one tile per CTA (may require a very large grid)
        code.writeline("if one_tile_per_cta: # monolitic kernel style")
        with code.indent():
            code.writeline("tile_id = pid")
            self.gen_body_one_tile_per_cta_1d_tile(code)
        # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
        code.writeline("else: # grid-stride-loop style kernel")
        with code.indent():
            self.gen_body_gsl_1d_tile(code)
    code.newline()
    return code
class WrapperGenerator:
    """Generate the Python wrapper that partitions the task and launches the kernel.

    Each ``gen_*`` method appends source lines to an ``IndentedBuffer``. The
    generated wrapper, named ``name``:
    - takes inputs positionally and outputs by keyword;
    - checks that all tensor operands have the same shape;
    - computes the launch configuration (tile sizes, ``num_ctas``, ``grid``);
    - launches the Triton kernel ``jit_fn_name``;
    - returns the outputs.

    NOTE(review): the nesting of the rank-0 (``ndim == 0``) paths was
    reconstructed. ``FALLBACK_BPTR`` is computed and passed only for
    ``ndim > 0`` here; confirm this matches the rank-0 kernel signature.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        jit_fn_name: str,
        ndim: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.jit_fn_name = jit_fn_name
        self.ndim = ndim
        self.name = name
        self.config = config

    # ------------------------------------------------------------------ names
    def input_name(self, i):
        """Generated-code name of the i-th input: ``in<k>`` for tensors, ``val<k>`` otherwise."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Generated-code name of the i-th output tensor."""
        return f"out{i}"

    def _tensor_operand_names(self):
        """Names of every tensor operand in generated code: inputs first, then outputs."""
        schema = self.fx
        inputs = [f"in{i}" for i in range(schema.num_input_tensors())]
        outputs = [f"out{i}" for i in range(schema.num_output_tensors())]
        return inputs + outputs

    def _device_anchor(self):
        """Operand whose device selects the launch device: first input, else first output."""
        return "in0" if self.fx.num_input_tensors() > 0 else "out0"

    # -------------------------------------------------------------- prologue
    def gen_signature(self, code: IndentedBuffer):
        """Emit ``def <name>(<inputs>, /, *, <outputs>):``.

        Tensor inputs and all outputs are annotated
        ``Union[torch.Tensor, StridedBuffer]``. Non-tensor inputs carry their
        declared type when the schema provides one.
        """
        # TODO: check if triton handles constexprs transitively
        schema = self.fx
        params: List[str] = []
        for i in range(schema.num_inputs()):
            if schema.is_tensor(i):
                params.append(
                    f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]"
                )
            else:
                arg_type = schema.input_type(i)
                if arg_type is not None:
                    params.append(f"{self.input_name(i)}: {_type_name(arg_type)}")
                else:
                    params.append(f"{self.input_name(i)}")
        # NOTE: [the wrapper's signature and rules for passing parameters]
        # Input params must be passed by position: they are renamed to
        # in0, in1, val0, val1, ..., so passing them by keyword would be weird.
        # Output params must be passed by keyword: the scalar function has no
        # output parameters. They are added to allow a destination-passing style
        # API, which is convenient when writing into pre-defined outputs (e.g.
        # views of other tensors).
        # "/" and "*" are only emitted when something sits on their side; a
        # bare "/" with nothing before it, or a bare "*" with nothing after it,
        # is a SyntaxError in the generated module.
        if params:
            params.append("/")
        num_outputs = schema.num_output_tensors()
        if num_outputs:
            params.append("*")  # output params must be passed by keyword
        for i in range(num_outputs):
            params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
        code.writeline(f"def {self.name}({_cs(params)}): ")

    def gen_docstring(self, code: IndentedBuffer):
        """Emit the wrapper's docstring describing the schema signature."""
        schema = self.fx
        doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
        code.writeline(doc)

    def gen_same_shape_check(self, code: IndentedBuffer):
        """Emit an assertion that every tensor operand has the same shape."""
        operands = self._tensor_operand_names()
        if len(operands) < 2:
            # Nothing to compare. This also avoids emitting `assert x.shape`,
            # which fails for a rank-0 tensor (its shape is empty, hence falsy).
            return
        check: str = " == ".join(f"{name}.shape" for name in operands)
        code.writeline(f"assert {check}, 'operand shapes mismatch'")

    def gen_fallback_bptr(self, code: IndentedBuffer):
        """Emit module-level ``fallback_bptr(t)``.

        It returns True when any dimension of a non-empty ``t`` has size
        >= 2**31 (2147483648).
        """
        code.writeline("def fallback_bptr(t):")
        with code.indent():
            code.writeline("ndim = t.dim()")
            code.writeline("sizes = t.size()")
            code.writeline("if t.numel() == 0:")
            with code.indent():
                code.writeline("return False")
            code.writeline("for i in range(ndim):")
            with code.indent():
                code.writeline("if sizes[i] >= 2147483648:")
                with code.indent():
                    code.writeline("return True")
            code.writeline("return False")
        code.newline()
        code.newline()

    def _gen_fallback_bptr_check(self, code: IndentedBuffer):
        """Emit code setting ``FALLBACK_BPTR`` to True if any tensor operand needs the fallback."""
        code.writeline("FALLBACK_BPTR = False")
        # a single join keeps the list valid even when there are no input tensors
        code.writeline(f"all_tensors = [{', '.join(self._tensor_operand_names())}]")
        code.writeline("for t in all_tensors:")
        with code.indent():
            code.writeline("if fallback_bptr(t):")
            with code.indent():
                code.writeline("FALLBACK_BPTR = True")
                code.writeline("break")

    # ------------------------------------------------------- task partition
    def gen_task_partition(self, code: IndentedBuffer):
        """Emit the launch configuration for the N-d tile kernel.

        For ``0 < ndim <= 4`` the grid is a lambda over the kernel's meta
        parameters ``s<i>`` / ``tile_size<i>``. Otherwise it is the fixed
        tuple ``(num_ctas, 1, 1)``.
        """
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
            )
            code.writeline("tile_size = math.prod(tile_sizes)")
            code.writeline(
                "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
            )
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0} // num_warps, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
            if self.config.prefer_block_pointer:
                self._gen_fallback_bptr_check(code)
        if 0 < ndim <= 4:
            max_grid_size0 = self.config.max_grid_size[0]
            dynamic_num_tiles = " * ".join(
                f"triton.cdiv(meta['s{i}'], meta['tile_size{i}'])" for i in range(ndim)
            )
            code.writeline(
                f"grid = lambda meta: (min({max_grid_size0} // num_warps, {dynamic_num_tiles}), )"
            )
        else:
            code.writeline("grid = (num_ctas, 1, 1)")

    def gen_task_partition_1d(self, code: IndentedBuffer):
        """Emit the launch configuration for the 1-d tile kernel.

        The task space is ``num_tasks`` flat elements split into tiles of
        ``tile_size``. The grid is always ``(num_ctas, 1, 1)``.
        """
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
            )
            code.writeline("tile_size = tile_sizes[0]")
            code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
            if self.config.prefer_block_pointer:
                self._gen_fallback_bptr_check(code)
        code.writeline("grid = (num_ctas, 1, 1)")

    # ----------------------------------------------------------------- launch
    def _gen_stride_metadata(self, code, prefix, count, with_block_pointer):
        """Emit ``<op>_strides`` for each operand ``<prefix><i>``.

        With block pointers, also emit ``<op>_stride_order`` and
        ``<op>_zero_strides``.
        """
        for i in range(count):
            name = f"{prefix}{i}"
            code.writeline(f"{name}_strides = {name}.stride()")
            if not with_block_pointer:
                continue
            if self.ndim >= 2:
                code.writeline(f"{name}_stride_order = stride_order({name}_strides)")
            else:
                # with a single dimension there is no order to compute
                code.writeline(f"{name}_stride_order = (0,)")
            code.writeline(
                f"{name}_zero_strides = [True if s == 0 else False for s in {name}_strides]"
            )

    def _gen_stride_args(self, code, prefix, count, with_block_pointer):
        """Emit per-dimension stride (and, with block pointers, order / zero-flag) kernel args."""
        ndim = self.ndim
        for i in range(count):
            name = f"{prefix}{i}"
            strides = ", ".join(f"{name}_strides[{j}]" for j in range(ndim))
            code.writeline(f"{strides}, # stride for {name}")
            if with_block_pointer:
                order = ", ".join(f"{name}_stride_order[{j}]" for j in range(ndim))
                code.writeline(f"{order}, # stride order for {name}")
                zero_strides = ", ".join(
                    f"{name}_zero_strides[{j}]" for j in range(ndim)
                )
                code.writeline(f"{zero_strides}, # zero stride flag for {name}")

    def _gen_operand_args(self, code):
        """Emit the leading positional kernel args: all inputs, then all outputs."""
        schema = self.fx
        params = [self.input_name(i) for i in range(schema.num_inputs())]
        params += [self.output_name(i) for i in range(schema.num_output_tensors())]
        code.writeline(f"{_cs(params)},")

    def gen_kernel_launch(
        self,
        code: IndentedBuffer,
    ):
        """Emit the N-d kernel launch on the operands' device.

        Rank > 4 kernels also receive ``tiles_per_cta``, ``tile_size<i>`` and
        ``one_tile_per_cta`` from the wrapper.
        """
        schema = self.fx
        ndim = self.ndim
        with_block_pointer = self.config.prefer_block_pointer
        num_in = schema.num_input_tensors()
        num_out = schema.num_output_tensors()

        code.writeline("# kernel launch")
        self._gen_stride_metadata(code, "in", num_in, with_block_pointer)
        self._gen_stride_metadata(code, "out", num_out, with_block_pointer)

        code.writeline(f"with torch_device_fn.device({self._device_anchor()}.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                self._gen_operand_args(code)
                if ndim > 0:
                    self._gen_stride_args(code, "in", num_in, with_block_pointer)
                    self._gen_stride_args(code, "out", num_out, with_block_pointer)
                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                    if with_block_pointer:
                        code.writeline("FALLBACK_BPTR=FALLBACK_BPTR,")
                    if ndim > 4:
                        code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                        for i in range(ndim):
                            code.writeline(f"tile_size{i}=tile_sizes[{i}],")
                        code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_kernel_launch_1d(
        self,
        code: IndentedBuffer,
    ):
        """Emit the 1-d-tile kernel launch on the operands' device."""
        schema = self.fx
        ndim = self.ndim
        num_in = schema.num_input_tensors()
        num_out = schema.num_output_tensors()

        code.writeline("# kernel launch")
        self._gen_stride_metadata(code, "in", num_in, False)
        self._gen_stride_metadata(code, "out", num_out, False)

        code.writeline(f"with torch_device_fn.device({self._device_anchor()}.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                self._gen_operand_args(code)
                if ndim > 0:
                    self._gen_stride_args(code, "in", num_in, False)
                    self._gen_stride_args(code, "out", num_out, False)
                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                    if self.config.prefer_block_pointer:
                        code.writeline("FALLBACK_BPTR=FALLBACK_BPTR,")
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                    code.writeline("tile_size=tile_size,")
                    code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_return(self, code: IndentedBuffer):
        """Emit ``return out0[, out1, ...]``."""
        return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors()))
        code.writeline(f"return {return_exprs}")

    # ---------------------------------------------------------- entry points
    def codegen_nd_tile(self, code):
        """Emit the complete wrapper for the N-d tile kernel; returns ``code``."""
        if self.config.prefer_block_pointer:
            self.gen_fallback_bptr(code)
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition(code)
            self.gen_kernel_launch(code)
            self.gen_return(code)
        code.newline()
        return code

    def codegen_1d_tile(self, code):
        """Emit the complete wrapper for the 1-d tile kernel; returns ``code``."""
        if self.config.prefer_block_pointer:
            self.gen_fallback_bptr(code)
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition_1d(code)
            self.gen_kernel_launch_1d(code)
            self.gen_return(code)
        code.newline()
        return code
class ModuleGenerator:
    """Generate the source of a complete kernel module.

    The module contains:
    - the import preamble;
    - extra imports and copies of local ``@triton.jit`` helpers taken from
      the scalar function's module;
    - the wrapper (``WrapperGenerator``);
    - the kernel (``KernelGenerator``).
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.scalar_fn = scalar_fn
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def _collect_jit_deps(scalar_fn):
        """Collect extra imports and local @triton.jit helper sources.

        Parses the source module where scalar_fn is defined using AST.
        Returns a tuple of:
          - extra_imports: dict of module_path -> set of imported names
            (``"name"`` or ``"name as alias"``); only absolute
            ``from ... import`` statements are collected
          - local_sources: list of source strings for local @triton.jit
            functions (those NOT decorated with @pointwise_dynamic)
        Any failure to locate, decode or parse the module yields ``({}, [])``.
        """
        import ast
        import inspect
        import io
        import tokenize

        py_fn = getattr(scalar_fn, "fn", scalar_fn)
        module_name = getattr(py_fn, "__module__", None)
        if not module_name:
            return {}, []
        try:
            mod = importlib.import_module(module_name)
            source_file = inspect.getfile(mod)
        except (ImportError, TypeError, OSError):
            return {}, []
        try:
            # tokenize.open decodes using the file's PEP 263 coding cookie
            # (UTF-8 by default) instead of the process locale.
            with tokenize.open(source_file) as f:
                module_source = f.read()
            tree = ast.parse(module_source)
        except (OSError, SyntaxError, ValueError):
            # ValueError covers UnicodeDecodeError and NUL bytes in the source
            return {}, []
        # Split on "\n" only, so line indices agree with AST line numbers
        # (str.splitlines would also split on form feeds, U+2028, ...).
        source_lines = io.StringIO(module_source).readlines()

        # Collect non-standard import-from lines
        ALREADY_IMPORTED = {
            "math",
            "typing",
            "torch",
            "triton",
            "triton.language",
            "flag_gems.utils.shape_utils",
            "flag_gems.utils.tensor_wrapper",
            "flag_gems.utils.libentry",
            "flag_gems.utils",
            "flag_gems.runtime",
            "flag_gems.utils.pointwise_dynamic",
            "utils.pointwise_dynamic",
            "randn",
            "utils",
            "all",
        }
        extra_imports = {}
        for node in ast.iter_child_nodes(tree):
            if not (isinstance(node, ast.ImportFrom) and node.module):
                continue
            if node.level:
                # Relative import. The generated module is loaded from the code
                # cache without a parent package, so it cannot be replayed
                # there. Re-emitting it as an absolute import would import an
                # unrelated module or fail.
                continue
            if node.module in ALREADY_IMPORTED:
                continue
            names = {
                alias.name if alias.asname is None else f"{alias.name} as {alias.asname}"
                for alias in node.names
            }
            extra_imports.setdefault(node.module, set()).update(names)

        # Collect local @triton.jit functions (without @pointwise_dynamic)
        def _has_decorator(func_node, name):
            # substring match over the full source text of each decorator
            for dec in func_node.decorator_list:
                src = "".join(source_lines[dec.lineno - 1 : dec.end_lineno])
                if name in src:
                    return True
            return False

        def _extract_source(func_node):
            # include the decorators, which start before the `def` line
            start = func_node.lineno - 1
            if func_node.decorator_list:
                start = func_node.decorator_list[0].lineno - 1
            end = func_node.end_lineno
            return "".join(source_lines[start:end])

        local_sources = []
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.FunctionDef):
                continue
            # "jit" also matches "triton.jit"
            if not _has_decorator(node, "jit"):
                continue
            if _has_decorator(node, "pointwise_dynamic"):
                continue
            local_sources.append(_extract_source(node))

        return extra_imports, local_sources

    def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer:
        """Emit the import preamble, extra imports and local JIT helpers; returns ``code``."""
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("import flag_gems")
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry, libtuner")
        code.writeline("from flag_gems.utils import triton_lang_extension as ext")
        code.writeline("from flag_gems.runtime import torch_device_fn")

        # Generate extra imports and local JIT deps of the scalar function
        jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn)
        for module_path, names in sorted(jit_dep_imports.items()):
            # `import *` cannot be combined with explicit names in one statement
            sorted_names = "*" if "*" in names else ", ".join(sorted(names))
            code.writeline(f"from {module_path} import {sorted_names}")

        code.newline()
        code.newline()

        # Emit local @triton.jit helper functions
        for source in local_jit_sources:
            for line in source.splitlines():
                code.writeline(line)
            code.newline()

        return code

    def codegen(self, code: IndentedBuffer):
        """Emit the whole module: imports, then wrapper, then kernel; returns ``code``."""
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code
@dataclass
class KernelInfo:
    """Where a generated kernel module lives and what it defines, for C++ integration.

    Built by ``PointwiseDynamicFunction.instantiate`` and returned by
    ``PointwiseDynamicFunction.get_kernel_info``.
    """

    # path of the generated .py module inside the code cache directory
    file_path: str
    # name of the Triton kernel defined in that module
    kernel_name: str
    # name of the Python wrapper that launches the kernel
    wrapper_name: str
    # rank of the task space the kernel is specialized for
    ndim: int
class PointwiseDynamicFunction:
    """Utility to generate functions for general pointwise operations.

    It generates a wrapper and a JITFunction specialized for the rank of the
    task space (the broadcasted shape of all input tensors). The generated
    code is written to the code cache directory (defaults to ~/.flaggems).
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads, keyed by `_overload_key(ndim)`
        self.overloads: Mapping[str, Callable] = {}
        # cached kernel info for C++ integration, same keys as `overloads`
        self._kernel_info_cache: Mapping[str, KernelInfo] = {}

    def __call__(self, *args, **kwargs):
        """Run the operation.

        Inputs must be passed by position; outputs (``out0``, ...) by keyword.
        """
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        # NOTE: the overload keeps the type of its outputs: a pre-defined output
        # that is a Tensor (or StridedBuffer) comes back as a Tensor (or
        # StridedBuffer). prepare_args wraps every argument, so here all outputs
        # are StridedBuffers. If an instantiated overload is called directly,
        # take care of that manually.
        return self._unwrap(out)

    @staticmethod
    def use_fast_path(tensors):
        """True when all tensors can be treated as one flat 1-d task.

        That holds when they share a shape and are either all C-contiguous,
        or share strides with the first one non-overlapping and dense.
        """
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def prepare_args(self, *args, **kwargs):
        """Allocate missing outputs and wrap all tensors into StridedBuffers.

        Returns ``(ndim, args, kwargs)``:
        - ``ndim`` is the rank of the task space (1 on the fast path);
        - tensor arguments are wrapped into StridedBuffers viewed over the
          task shape;
        - every output ``out<i>`` is present in ``kwargs``.

        Raises:
            ValueError: the positional arguments do not match the schema's
                tensor/non-tensor layout.
            RuntimeError: a provided output has the wrong shape or internal
                overlap.
        """
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # A simple strategy: undefined outputs follow the first tensor that
            # is not broadcasted. No task simplification, no reordering, no
            # dimension collapsing.
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            for index, item in enumerate(out_tensors):
                if list(item.shape) != list(task_shape):
                    raise RuntimeError(
                        f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                    )
                # output arguments must not have internal overlapping for pointwise operation
                if has_internal_overlapping(item) == MemOverlap.Yes:
                    raise RuntimeError(
                        "Pointwise output arguments should not have internal overlapping."
                    )

            ndim = len(task_shape)
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        """Unwrap StridedBuffer(s) back to Tensor(s): a single value or a tuple."""
        if self.fx.num_output_tensors() == 1:
            item = tensors
            return item.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def _overload_key(self, ndim: int) -> str:
        """Cache key shared by ``instantiate`` and ``get_kernel_info``."""
        return f"{ndim}_{self.config.prefer_block_pointer}"

    def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
        """Compute kernel name, wrapper name, and file path for a given ndim.

        This is the single source of truth for naming, used by both instantiate()
        and get_kernel_info() to ensure consistency.

        Returns:
            Tuple of (kernel_name, wrapper_name, file_path)
        """
        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"

        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )
        file_path = str(code_cache_dir() / file_name)

        return kernel_name, wrapper_name, file_path

    def instantiate(self, ndim):
        """Generate (or fetch from cache) the wrapper specialized for rank ``ndim``.

        NOTE: a manually instantiated overload does not get `prepare_args` as
        preprocessing. You must allocate the outputs yourself and make sure the
        inputs and outputs actually fit the overload.
        """
        # importlib.util is a submodule; the file-level `import importlib`
        # alone does not guarantee that `importlib.util` is bound.
        import importlib.util

        key = self._overload_key(ndim)
        if key in self.overloads:
            return self.overloads[key]

        code = IndentedBuffer()

        # Use helper to compute names (single source of truth)
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of a jitted function, and
        # inspect cannot find the source of functions created via exec of a
        # string. We could hack linecache instead, but generating a module is
        # simpler: it holds both the kernel and the wrapper that calls it.
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules

        # NOTE: [why not import the scalar function]
        # The generated kernel **calls** the scalar function, but we do not
        # re-import it by name: a function's __name__ may have been changed,
        # and that name may be rebound to something else, so importing by name
        # cannot guarantee we get the scalar function. Instead we copy the
        # scalar function and its __globals__ into the generated module.
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload

        # Cache kernel info for C++ integration
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

        return overload

    def get_kernel_info(self, ndim: int) -> KernelInfo:
        """Get kernel information for a given ndim.

        This method is useful for C++ integration to get the file path and
        kernel name without duplicating the naming logic.

        If the kernel hasn't been instantiated yet, this will instantiate it first.

        Args:
            ndim: The rank of the task space

        Returns:
            KernelInfo with file_path, kernel_name, wrapper_name, and ndim
        """
        key = self._overload_key(ndim)

        # Ensure the kernel is instantiated
        if key not in self._kernel_info_cache:
            self.instantiate(ndim)

        return self._kernel_info_cache[key]
def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Turn a scalar ``JITFunction`` into a ``PointwiseDynamicFunction``.

    Usable bare (``@pointwise_dynamic``) or with keyword options
    (``@pointwise_dynamic(num_outputs=..., promotion_methods=...)``).
    When none of ``num_inputs``, ``is_tensor`` and ``dtypes`` is given, the
    number of inputs is taken from the scalar function's ``arg_names``.

    NOTE(review): ``promotion_methods`` is indexed per output and unpacked as
    ``(*arg_indices, method)`` in ``prepare_args``. The ``Tuple[int, ...]``
    annotation looks narrower than what is actually passed — confirm.
    """

    def decorator(fn):
        # Infer into a local rather than rebinding the closure variable: the
        # decorator object may be applied to several functions, and each one
        # must get its own inferred input count.
        resolved_num_inputs = num_inputs
        if (num_inputs is None) and (is_tensor is None) and (dtypes is None):
            resolved_num_inputs = len(fn.arg_names)
        op_desc = FunctionSchema(
            num_inputs=resolved_num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(op_desc, fn, config)

    if f is not None:
        return decorator(f)
    return decorator