Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather_dispatch.py: 0%
12 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1from typing import Callable
3import torch
5from flag_gems.runtime.backend._ascend.ops.gather_ascend import gather
6from flag_gems.runtime.backend._ascend.ops.gather_collapsed_uintdiv import (
7 apply_prefix_narrows,
8 can_collapse_axes,
9 gather_collapsed,
10)
def gather_auto(
    inp: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    out: torch.Tensor,
    grid_fn,
    magic_map=None,
    use_collapsed=False,
    with_negative_index=False,
) -> Callable[[], None]:
    """Pick a gather implementation and return a zero-argument launcher for it.

    If ``use_collapsed`` is true and ``can_collapse_axes(inp, index, dim)``
    reports the axes as collapsible, ``inp`` is narrowed with
    ``apply_prefix_narrows`` and the launcher is whatever
    ``gather_collapsed(..., return_run_kernel=True)`` returns.

    Note that ``gather_collapsed`` is invoked immediately, while building
    the launcher. Presumably it does its setup work at that point — confirm
    against its implementation.

    In every other case the launcher is a closure. That closure calls
    ``gather`` with the original (un-narrowed) arguments only when invoked.

    Args:
        inp: Source tensor.
        dim: Dimension to gather along.
        index: Index tensor.
        out: Output tensor handed to the kernel (presumably filled in place;
            TODO confirm against ``gather`` / ``gather_collapsed``).
        grid_fn: Forwarded as ``grid_fn`` to whichever kernel is selected.
        magic_map: Forwarded to ``gather`` only; not used by the collapsed path.
        use_collapsed: Opt in to the collapsed-axes path when it applies.
        with_negative_index: Forwarded to both kernels.

    Returns:
        A callable taking no arguments that launches the selected kernel.
    """
    if use_collapsed:
        # Only run the collapsibility analysis when the caller opted in;
        # its result is irrelevant on the default (non-collapsed) path.
        collapsible, narrows = can_collapse_axes(inp, index, dim)
        if collapsible:
            return gather_collapsed(
                apply_prefix_narrows(inp, narrows),
                dim,
                index,
                out,
                grid_fn=grid_fn,
                return_run_kernel=True,
                with_negative_index=with_negative_index,
            )

    def run_kernel():
        # Deferred launch: nothing is executed until the caller invokes us.
        gather(
            inp,
            dim,
            index,
            out,
            grid_fn=grid_fn,
            magic_map=magic_map,
            with_negative_index=with_negative_index,
        )

    return run_kernel