Coverage for src/flag_gems/runtime/backend/_sunrise/ops/mv.py: 0%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as ext 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@libentry()
@libtuner(
    # The USE_FLAGTUNE environment variable is read when this module is
    # imported: "1" selects runtime.ops_get_configs / get_expand_config,
    # anything else falls back to the tuned config and "align32" keys.
    configs=runtime.ops_get_configs("mv", pre_hook=None)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("mv"),
    key=["M", "N"],
    strategy=runtime.get_expand_config("mv")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32"],
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    """Matrix-vector product kernel: C[n] = sum over m of A[n, m] * B[m].

    Each program instance produces BLOCK_N consecutive entries of C by
    walking across the M dimension in steps of BLOCK_M.

    Args:
        A: pointer to the matrix, addressed as n * stride_an + m * stride_am.
        B: pointer to the vector, addressed as m * stride_bm.
        C: pointer to the output, addressed as n * stride_cn.
        N: number of rows of A (and entries of C).
        M: number of columns of A (and entries of B).
        stride_an, stride_am: element strides of A along n and m.
        stride_bm: element stride of B.
        stride_cn: element stride of C.
        BLOCK_N: compile-time tile height (rows handled per program).
        BLOCK_M: compile-time tile width (columns consumed per iteration).
    """
    # Row indices as a column vector, column indices as a row vector, so the
    # pointer arithmetic below broadcasts to a (BLOCK_N, BLOCK_M) tile.
    row_idx = ext.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    col_idx = tl.arange(0, BLOCK_M)[None, :]
    row_ok = row_idx < N

    a_tile_ptrs = A + row_idx * stride_an + col_idx * stride_am
    b_tile_ptrs = B + col_idx * stride_bm

    # Products are accumulated in float32 regardless of the input dtype and
    # only reduced across columns once, after the loop.
    partial = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for col_start in range(0, M, BLOCK_M):
        col_ok = col_start + col_idx < M
        # Out-of-range lanes load 0.0 so they contribute nothing to the sum.
        a_tile = tl.load(a_tile_ptrs, mask=row_ok & col_ok, other=0.0).to(tl.float32)
        b_tile = tl.load(b_tile_ptrs, mask=col_ok, other=0.0).to(tl.float32)
        partial += a_tile * b_tile
        # Slide both tiles BLOCK_M columns to the right.
        a_tile_ptrs += BLOCK_M * stride_am
        b_tile_ptrs += BLOCK_M * stride_bm

    row_sums = tl.sum(partial, axis=1)
    out_ptrs = C + row_idx * stride_cn
    # Restore the trailing axis so the value shape matches out_ptrs / row_ok.
    tl.store(out_ptrs, row_sums[:, None], mask=row_ok)

59 

60 

def mv(inp, vec):
    """Compute the matrix-vector product ``inp @ vec`` with ``mv_kernel``.

    Args:
        inp: 2-D tensor of shape (N, M).
        vec: tensor whose first dimension has length M.

    Returns:
        A new 1-D tensor of shape (N,) with ``inp``'s dtype, allocated on
        ``inp``'s device.

    Raises:
        AssertionError: if ``inp.shape[1]`` differs from ``vec.shape[0]``.
    """
    logger.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape
    # Allocate the result directly on the target device; building it on the
    # default device and then calling .to() costs an extra allocation plus a
    # transfer of uninitialized memory.
    out = torch.empty((N,), dtype=inp.dtype, device=inp.device)
    # One program per BLOCK_N rows; BLOCK_N comes from the selected config.
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(inp.device):
        mv_kernel[grid](
            inp,
            vec,
            out,
            N,
            M,
            inp.stride(0),
            inp.stride(1),
            vec.stride(0),
            out.stride(0),
        )
    return out