Coverage for src/flag_gems/runtime/backend/_ascend/ops/gather_dispatch.py: 0%

12 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1from typing import Callable 

2 

3import torch 

4 

5from flag_gems.runtime.backend._ascend.ops.gather_ascend import gather 

6from flag_gems.runtime.backend._ascend.ops.gather_collapsed_uintdiv import ( 

7 apply_prefix_narrows, 

8 can_collapse_axes, 

9 gather_collapsed, 

10) 

11 

12 

def gather_auto(
    inp: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    out: torch.Tensor,
    grid_fn,
    magic_map=None,
    use_collapsed=False,
    with_negative_index=False,
) -> Callable[[], None]:
    """Select a gather implementation and return a zero-argument launcher for it.

    When ``use_collapsed`` is true and ``can_collapse_axes(inp, index, dim)``
    reports the axes as collapsible, ``inp`` is narrowed with
    ``apply_prefix_narrows`` and the launcher returned is whatever
    ``gather_collapsed(..., return_run_kernel=True)`` produces. In every other
    case the launcher is a closure that calls the generic ``gather`` wrapper.

    Args:
        inp: Source tensor to gather from.
        dim: Dimension along which to gather.
        index: Index tensor.
        out: Output tensor handed to the selected kernel wrapper.
        grid_fn: Forwarded as ``grid_fn`` to whichever wrapper is selected.
        magic_map: Forwarded only to ``gather``; the collapsed path does not
            receive it.
        use_collapsed: Opt in to the collapsed path when it is applicable.
        with_negative_index: Forwarded to whichever wrapper is selected.

    Returns:
        A callable taking no arguments. On the generic path, calling it
        invokes ``gather`` and returns ``None``. On the collapsed path it is
        the object returned by ``gather_collapsed``. NOTE(review): presumably
        ``return_run_kernel=True`` makes ``gather_collapsed`` defer the launch
        to that callable; confirm.
    """
    # Only pay for the collapsibility analysis when the caller opted in.
    # The generic path never uses its result.
    if use_collapsed:
        ok, narrows = can_collapse_axes(inp, index, dim)
        if ok:
            narrowed_inp = apply_prefix_narrows(inp, narrows)
            return gather_collapsed(
                narrowed_inp,
                dim,
                index,
                out,
                grid_fn=grid_fn,
                return_run_kernel=True,
                with_negative_index=with_negative_index,
            )

    def run_kernel():
        # Discard gather's return value so the launcher matches
        # Callable[[], None].
        gather(
            inp,
            dim,
            index,
            out,
            grid_fn=grid_fn,
            magic_map=magic_map,
            with_negative_index=with_negative_index,
        )

    return run_kernel