Coverage for src/flag_gems/runtime/backend/_spacemit/ops/mv.py: 0%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Compute ``C = A @ B`` for a matrix ``A`` and a vector ``B``.

    Each program instance handles one tile of ``BLOCK_N`` consecutive rows
    of ``A`` and walks the ``M`` dimension in steps of ``BLOCK_M``.

    Args:
        A: pointer to the matrix, addressed as shape ``[N, M]`` with strides
            ``(stride_an, stride_am)``.
        B: pointer to the vector, addressed as shape ``[M]`` with stride
            ``stride_bm``.
        C: pointer to the output, addressed as shape ``[N]`` with stride
            ``stride_cn``.
        N: number of rows of ``A`` / length of ``C``.
        M: number of columns of ``A`` / length of ``B``.
        BLOCK_M: tile width along ``M`` (compile-time constant).
        BLOCK_N: tile height along ``N`` (compile-time constant).
    """
    pid = tl.program_id(0)
    block_start_n = pid * BLOCK_N
    # NOTE(review): partial products are accumulated in A's element type;
    # confirm this is intentional for low-precision inputs on this backend.
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=A.dtype.element_ty)

    for m in range(0, M, BLOCK_M):
        a_block_ptr = tl.make_block_ptr(
            base=A,
            shape=[N, M],
            strides=[stride_an, stride_am],
            offsets=[block_start_n, m],
            block_shape=[BLOCK_N, BLOCK_M],
            order=[1, 0],
        )
        # Out-of-bounds elements must be zero: without an explicit
        # padding_option their values are undefined, and the tail columns
        # (when M is not a multiple of BLOCK_M) would otherwise be summed
        # into every row's result below.
        a = tl.load(
            a_block_ptr,
            boundary_check=(0, 1),
            padding_option="zero",
        ).to(A.dtype.element_ty)

        b_block_ptr = tl.make_block_ptr(
            base=B,
            shape=[M],
            strides=[stride_bm],
            offsets=[m],
            block_shape=[BLOCK_M],
            order=[0],
        )
        # Zero-pad the vector tail as well so that padded products are
        # exactly 0 * 0 rather than 0 * <undefined> (which could be NaN/Inf).
        b = tl.load(
            b_block_ptr,
            boundary_check=(0,),
            padding_option="zero",
        ).to(A.dtype.element_ty)
        # Broadcast the vector tile across the rows of the matrix tile.
        acc += a * b[None, :]

    # Reduce the per-column partial products to one value per row.
    result = tl.sum(acc, axis=1)
    c_block_ptr = tl.make_block_ptr(
        base=C,
        shape=[N],
        strides=[stride_cn],
        offsets=[block_start_n],
        block_shape=[BLOCK_N],
        order=[0],
    )
    # The boundary check drops rows past N in the last tile.
    tl.store(c_block_ptr, result.to(C.dtype.element_ty), boundary_check=(0,))

69 

70 

def mv(inp, vec):
    """Return the matrix-vector product of ``inp`` and ``vec``.

    Launches ``mv_kernel`` with one program per ``BLOCK_N`` rows of ``inp``.

    Args:
        inp: matrix whose second dimension must equal ``vec.shape[0]``.
        vec: vector operand.

    Returns:
        A new tensor of shape ``(inp.shape[0],)`` on ``inp``'s device with
        ``inp``'s dtype.

    Raises:
        AssertionError: if ``inp.shape[1] != vec.shape[0]``.
    """
    logger.debug("GEMS_SPACEMIT MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"

    n_rows, n_cols = inp.shape
    result = torch.empty((n_rows,), device=inp.device, dtype=inp.dtype)

    def launch_grid(meta):
        # One program instance per BLOCK_N-row tile of the output.
        return (triton.cdiv(n_rows, meta["BLOCK_N"]),)

    row_stride = inp.stride(0)
    col_stride = inp.stride(1)
    vec_stride = vec.stride(0)
    out_stride = result.stride(0)

    with torch_device_fn.device(inp.device):
        mv_kernel[launch_grid](
            inp,
            vec,
            result,
            n_rows,
            n_cols,
            row_stride,
            col_stride,
            vec_stride,
            out_stride,
        )
    return result