Coverage for src/flag_gems/ops/mv.py: 56%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as ext 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
    strategy=["align32", "align32"],
    flagtune_op_name="mv",
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    """Matrix-vector product kernel: C[n] = sum over m of A[n, m] * B[m].

    Each program instance handles BLOCK_N consecutive rows of A and walks the
    reduction dimension in steps of BLOCK_M columns.

    Args:
        A: Base pointer of the matrix; element (n, m) is addressed as
            ``A + n * stride_an + m * stride_am``.
        B: Base pointer of the vector; element m is ``B + m * stride_bm``.
        C: Base pointer of the output; element n is ``C + n * stride_cn``.
        N: Number of rows (bound for the row mask).
        M: Length of the reduction dimension (bound for the column mask).
        stride_an, stride_am: Row and column strides of A, in elements.
        stride_bm: Stride of B, in elements.
        stride_cn: Stride of C, in elements.
        BLOCK_N: Compile-time number of rows per program.
        BLOCK_M: Compile-time number of columns per loop iteration.
    """
    # Row indices as a (BLOCK_N, 1) column, column indices as a (1, BLOCK_M)
    # row, so that combining them broadcasts to a (BLOCK_N, BLOCK_M) tile.
    row_idx = ext.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    col_idx = tl.arange(0, BLOCK_M)[None, :]
    row_in_range = row_idx < N

    a_tile_ptrs = A + row_idx * stride_an + col_idx * stride_am
    b_tile_ptrs = B + col_idx * stride_bm

    # Partial products are kept per column in float32 and only reduced across
    # columns once, after the loop.
    partial = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for col_start in range(0, M, BLOCK_M):
        col_in_range = col_start + col_idx < M
        # Out-of-range lanes load 0.0 and therefore add nothing to the sum.
        a_tile = tl.load(a_tile_ptrs, mask=row_in_range & col_in_range, other=0.0)
        b_tile = tl.load(b_tile_ptrs, mask=col_in_range, other=0.0)
        partial += a_tile.to(tl.float32) * b_tile.to(tl.float32)
        # Slide both tiles BLOCK_M columns to the right.
        a_tile_ptrs += BLOCK_M * stride_am
        b_tile_ptrs += BLOCK_M * stride_bm

    row_sums = tl.sum(partial, axis=1)
    out_ptrs = C + row_idx * stride_cn
    # Restore the trailing axis so the values line up with the (BLOCK_N, 1)
    # pointer block and row mask.
    tl.store(out_ptrs, row_sums[:, None], mask=row_in_range)

54 

55 

def mv(inp, vec):
    """Compute the matrix-vector product of ``inp`` and ``vec`` with Triton.

    Args:
        inp: Matrix tensor; its shape is unpacked as ``(N, M)``.
        vec: Vector tensor whose first dimension must equal ``M``.

    Returns:
        A new tensor of shape ``(N,)`` on ``inp``'s device with ``inp``'s dtype.

    Raises:
        AssertionError: If ``inp.shape[1]`` differs from ``vec.shape[0]``.
    """
    logger.debug("GEMS MV")
    # Raised explicitly instead of via `assert` so the check survives
    # `python -O`. mv_kernel masks its loads from `vec` only against M (taken
    # from `inp`), so a shorter `vec` would otherwise be read past its end.
    # AssertionError is kept so existing callers see the same exception type.
    if inp.shape[1] != vec.shape[0]:
        raise AssertionError("incompatible dimensions")
    N, M = inp.shape
    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)
    # One program per BLOCK_N rows; BLOCK_N is supplied by the tuned config.
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(inp.device):
        mv_kernel[grid](
            inp,
            vec,
            out,
            N,
            M,
            inp.stride(0),
            inp.stride(1),
            vec.stride(0),
            out.stride(0),
        )
    return out