Coverage for src/flag_gems/runtime/common.py: 100%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1from enum import Enum
class vendors(Enum):
    """Enumeration of the hardware vendors known to this runtime.

    Members carry distinct integer values 0-15 in definition order. The
    lower-case class name is part of this module's public interface (it is
    exported via ``__all__``), so it is intentionally not renamed.
    """

    NVIDIA = 0
    CAMBRICON = 1
    METAX = 2
    ILUVATAR = 3
    MTHREADS = 4
    KUNLUNXIN = 5
    HYGON = 6
    AMD = 7
    AIPU = 8
    ASCEND = 9
    TSINGMICRO = 10
    SUNRISE = 11
    ENFLAME = 12
    SPACEMIT = 13
    THEAD = 14
    ARM = 15

    @classmethod
    def get_all_vendors(cls) -> "dict[str, vendors]":
        """Return a mapping from lower-cased member name to enum member.

        Iterating ``cls`` yields only canonical members (enum aliases are
        skipped) in definition order, so the result is ordered the same way,
        e.g. ``{"nvidia": vendors.NVIDIA, "cambricon": vendors.CAMBRICON, ...}``.

        Returns:
            A newly built dict on every call; mutating it does not affect
            later calls.
        """
        return {member.name.lower(): member for member in cls}
# Vendors recorded as lacking FP64 (float64) support, per this constant's
# name. How callers act on membership is decided outside this file.
# Kept immutable (frozenset) so importers cannot alter it by accident.
UNSUPPORT_FP64 = frozenset(
    (
        vendors.AIPU,
        vendors.ASCEND,
        vendors.CAMBRICON,
        vendors.ENFLAME,
        vendors.ILUVATAR,
        vendors.KUNLUNXIN,
        vendors.MTHREADS,
        vendors.SPACEMIT,
        vendors.SUNRISE,
        vendors.TSINGMICRO,
    )
)
# Vendors recorded as lacking BF16 (bfloat16) support, per this constant's
# name. How callers act on membership is decided outside this file.
UNSUPPORT_BF16 = frozenset(
    (
        vendors.AIPU,
        vendors.SPACEMIT,
        vendors.SUNRISE,
    )
)
# Vendors recorded as lacking INT64 support, per this constant's name.
# How callers act on membership is decided outside this file.
UNSUPPORT_INT64 = frozenset(
    (
        vendors.AIPU,
        vendors.ENFLAME,
        vendors.SPACEMIT,
        vendors.SUNRISE,
        vendors.TSINGMICRO,
    )
)
# Default strategy names per op. Every list here has the same length as the
# matching OP_KEY_ORDERS entry, so the two tables presumably pair up position
# by position (strategy i applies to key i) — TODO confirm with the consumer.
# The meaning of "align32" / "default" is defined by the code that reads this
# table, not here. Each value is a separate list object.
DEFAULT_STRATEGIES = {
    "addmm": ["align32"] * 3,
    "addmm_sqmma": ["align32"] * 3,
    "baddbmm": ["align32"] * 3,
    "bmm": ["align32"] * 5,
    "bmm_sqmma": ["align32"] * 3,
    "gemv": ["align32"] * 3 + ["default"],
    "mm": ["align32"] * 5,
    "mm_general_tma": ["align32"] * 5 + ["default"],
    "mv": ["align32"] * 2,
    "sparse_attention": ["align32"] * 3,
    "w8a8_block_fp8_general": ["align32"] * 5,
    "w8a8_block_fp8_general_splitk": ["align32"] * 5,
    "w8a8_block_fp8_general_tma": ["align32"] * 5 + ["default"],
    "mm_splitk": ["align32"] * 5,
}
# Ordered key names per op. Insertion order of the ops and the order of names
# inside each list are both preserved. Each list has the same length as the
# matching DEFAULT_STRATEGIES entry. What reads these names (and from where)
# is outside this file — presumably kernel/launch parameters; TODO confirm.
OP_KEY_ORDERS = dict(
    addmm=["M", "N", "K"],
    addmm_sqmma=["M", "N", "K"],
    bmm=["M", "N", "K", "stride_am", "stride_bk"],
    bmm_sqmma=["M", "N", "K"],
    baddbmm=["M", "N", "K"],
    gemv=["M", "K", "stride_am", "stride_bk"],
    mm=["M", "N", "K", "stride_am", "stride_bk"],
    mm_general_tma=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    mv=["M", "N"],
    sparse_attention=["topk", "H_ACTUAL", "D"],
    w8a8_block_fp8_general=["M", "N", "K", "stride_am", "stride_bk"],
    w8a8_block_fp8_general_splitk=["M", "N", "K", "stride_am", "stride_bk"],
    w8a8_block_fp8_general_tma=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    mm_splitk=["M", "N", "K", "stride_am", "stride_bk"],
)
# Vendor name -> torch attribute name, used for quick vendor detection.
# Keys are lower-cased ``vendors`` member names (the same form that
# ``vendors.get_all_vendors()`` produces); only the vendors listed here have
# such a shortcut. How the attribute is probed is up to the caller —
# presumably a ``hasattr(torch, ...)``-style check; TODO confirm.
_VENDOR_TORCH_ATTR = dict(
    ascend="npu",
    cambricon="mlu",
    enflame="gcu",
    hygon="__hcu_version__",
    iluvatar="corex",
    mthreads="musa",
    sunrise="ptpu",
)
# Public names exported by ``from <module> import *``, built in grouped steps.
# Each step uses a literal list so static analysers can still read it.
# Listing ``_VENDOR_TORCH_ATTR`` makes ``import *`` export it despite its
# leading underscore.
__all__ = ["vendors"]
# Vendor capability sets.
__all__ += ["UNSUPPORT_FP64", "UNSUPPORT_BF16", "UNSUPPORT_INT64"]
# Per-op tables.
__all__ += ["DEFAULT_STRATEGIES", "OP_KEY_ORDERS"]
# Vendor detection map.
__all__ += ["_VENDOR_TORCH_ATTR"]