Coverage for src/flag_gems/runtime/common.py: 100%
30 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1from enum import Enum
class vendors(Enum):
    """Accelerator vendors known to the runtime.

    Each member carries a fixed integer value; ``get_all_vendors`` exposes
    the members keyed by their lower-cased name (e.g. ``"nvidia"``).
    """

    NVIDIA = 0
    CAMBRICON = 1
    METAX = 2
    ILUVATAR = 3
    MTHREADS = 4
    KUNLUNXIN = 5
    HYGON = 6
    AMD = 7
    AIPU = 8
    ASCEND = 9
    TSINGMICRO = 10
    SUNRISE = 11
    ENFLAME = 12
    SPACEMIT = 13
    THEAD = 14

    @classmethod
    def get_all_vendors(cls) -> dict:
        """Return a mapping from each member's lower-cased name to the member.

        Iteration follows definition order, so the returned dict does too.
        A new dict is built on every call, so callers may mutate it freely.

        Example: ``vendors.get_all_vendors()["nvidia"] is vendors.NVIDIA``.
        """
        return {member.name.lower(): member for member in cls}
# Vendors flagged in this module as lacking float64 support.
# NOTE(review): presumably consulted by callers to skip or adapt FP64 work on
# these backends — confirm at the use sites.
UNSUPPORT_FP64 = frozenset(
    (
        vendors.AIPU,
        vendors.ASCEND,
        vendors.CAMBRICON,
        vendors.ENFLAME,
        vendors.ILUVATAR,
        vendors.KUNLUNXIN,
        vendors.MTHREADS,
        vendors.SPACEMIT,
        vendors.SUNRISE,
        vendors.TSINGMICRO,
    )
)
# Vendors flagged in this module as lacking bfloat16 support.
# NOTE(review): presumably checked by callers before dispatching BF16 work —
# confirm at the use sites.
UNSUPPORT_BF16 = frozenset(
    (
        vendors.AIPU,
        vendors.SPACEMIT,
        vendors.SUNRISE,
    )
)
# Vendors flagged in this module as lacking int64 support.
# NOTE(review): presumably checked by callers before dispatching INT64 work —
# confirm at the use sites.
UNSUPPORT_INT64 = frozenset(
    (
        vendors.AIPU,
        vendors.ENFLAME,
        vendors.SPACEMIT,
        vendors.SUNRISE,
        vendors.TSINGMICRO,
    )
)
# Default per-key strategy names for each op. For every op, the list has one
# entry per key name listed for that op in OP_KEY_ORDERS (same length,
# positional correspondence).
# NOTE(review): "align32" presumably aligns the key value to a multiple of 32
# and "default" leaves it untouched — confirm against the strategy consumer.
# Each `[...] * n` expression builds a fresh list, so no two ops share a list.
DEFAULT_STRATEGIES = {
    "addmm": ["align32"] * 3,
    "addmm_sqmma": ["align32"] * 3,
    "baddbmm": ["align32"] * 3,
    "bmm": ["align32"] * 5,
    "bmm_sqmma": ["align32"] * 3,
    "gemv": ["align32"] * 3 + ["default"],
    "mm": ["align32"] * 5,
    "mm_general_tma": ["align32"] * 5 + ["default"],
    "mv": ["align32"] * 2,
    "sparse_attention": ["align32"] * 3,
    "w8a8_block_fp8_general": ["align32"] * 5,
    "w8a8_block_fp8_general_splitk": ["align32"] * 5,
    "w8a8_block_fp8_general_tma": ["align32"] * 5 + ["default"],
    "mm_splitk": ["align32"] * 5,
}
# Ordered key names for each op; entry i pairs with DEFAULT_STRATEGIES[op][i].
# NOTE(review): these presumably name the shape/stride/dtype arguments used to
# build a lookup key for the op — confirm against the consumer.
OP_KEY_ORDERS = dict(
    addmm=["M", "N", "K"],
    addmm_sqmma=["M", "N", "K"],
    bmm=["M", "N", "K", "stride_am", "stride_bk"],
    bmm_sqmma=["M", "N", "K"],
    baddbmm=["M", "N", "K"],
    gemv=["M", "K", "stride_am", "stride_bk"],
    mm=["M", "N", "K", "stride_am", "stride_bk"],
    mm_general_tma=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    mv=["M", "N"],
    sparse_attention=["topk", "H_ACTUAL", "D"],
    w8a8_block_fp8_general=["M", "N", "K", "stride_am", "stride_bk"],
    w8a8_block_fp8_general_splitk=["M", "N", "K", "stride_am", "stride_bk"],
    w8a8_block_fp8_general_tma=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    mm_splitk=["M", "N", "K", "stride_am", "stride_bk"],
)
# Quick-detection table: lower-cased vendor name -> attribute name probed on
# the torch module to recognise that vendor's build. Keys use the same
# lower-cased spelling as the `vendors` enum member names.
# NOTE(review): presumably checked via hasattr(torch, attr) — confirm at caller.
_VENDOR_TORCH_ATTR = dict(
    ascend="npu",
    cambricon="mlu",
    enflame="gcu",
    hygon="__hcu_version__",
    iluvatar="corex",
    mthreads="musa",
    sunrise="ptpu",
)
# Names exported by `from <this module> import *`. `_VENDOR_TORCH_ATTR` is
# listed explicitly; without that, its leading underscore would keep it out of
# star-imports.
__all__ = [
    "vendors",
    # dtype capability sets
    "UNSUPPORT_FP64",
    "UNSUPPORT_BF16",
    "UNSUPPORT_INT64",
    # per-op strategy / key-order tables
    "DEFAULT_STRATEGIES",
    "OP_KEY_ORDERS",
    # vendor detection table
    "_VENDOR_TORCH_ATTR",
]