Coverage for src/flag_gems/runtime/common.py: 100%
32 statements
coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import os
2from enum import Enum
class vendors(Enum):
    """Enumeration of the vendors known to this runtime module.

    Every member carries a distinct integer value (0 through 14) and the
    members are declared in ascending value order.
    """

    NVIDIA = 0
    CAMBRICON = 1
    METAX = 2
    ILUVATAR = 3
    MTHREADS = 4
    KUNLUNXIN = 5
    HYGON = 6
    AMD = 7
    AIPU = 8
    ASCEND = 9
    TSINGMICRO = 10
    SUNRISE = 11
    ENFLAME = 12
    SPACEMIT = 13
    THEAD = 14

    @classmethod
    def get_all_vendors(cls) -> dict:
        """Map each member's lower-cased name to the member itself.

        Iterating the Enum class yields its canonical members in definition
        order, so the returned dict keeps that same order.
        """
        return {entry.name.lower(): entry for entry in cls}
# Capability deny-lists keyed by dtype. A vendor's presence in one of these
# frozensets marks it as lacking support for that dtype (judging by the
# constant names; how callers act on membership is decided elsewhere).
# Members are listed alphabetically; order is irrelevant for a frozenset.

UNSUPPORT_FP64 = frozenset(
    (
        vendors.AIPU,
        vendors.ASCEND,
        vendors.CAMBRICON,
        vendors.ENFLAME,
        vendors.ILUVATAR,
        vendors.KUNLUNXIN,
        vendors.MTHREADS,
        vendors.SPACEMIT,
        vendors.SUNRISE,
        vendors.TSINGMICRO,
    )
)

UNSUPPORT_BF16 = frozenset(
    (
        vendors.AIPU,
        vendors.SPACEMIT,
        vendors.SUNRISE,
    )
)

UNSUPPORT_INT64 = frozenset(
    (
        vendors.AIPU,
        vendors.ENFLAME,
        vendors.SPACEMIT,
        vendors.SUNRISE,
        vendors.TSINGMICRO,
    )
)
# Path of the general-ops expand-config YAML file, located at
# <this module's directory>/../utils/configs/ and passed through normpath so
# the ".." component is collapsed where possible.
DEFAULT_EXPAND_CONFIG_PATH = os.path.normpath(
    os.path.join(
        os.path.join(os.path.dirname(__file__), "..", "utils", "configs"),
        "general_ops_expand_configs.yaml",
    )
)
# Default strategy names per op. For every op here, the list has the same
# length as that op's key list in OP_KEY_ORDERS; presumably strategy i is
# applied to key i -- confirm against the consumer of this table.
# Each value is built as a fresh list, so no two ops share a list object.
DEFAULT_STRATEGIES = dict(
    addmm=["align32"] * 3,
    addmm_sqmma=["align32"] * 3,
    baddbmm=["align32"] * 3,
    bmm=["align32"] * 5,
    bmm_sqmma=["align32"] * 3,
    gemv=["align32"] * 3 + ["default"],
    mm=["align32"] * 5,
    mm_general_tma=["align32"] * 5 + ["default"],
    mv=["align32"] * 2,
    sparse_attention=["align32"] * 3,
    w8a8_block_fp8_general=["align32"] * 5,
    w8a8_block_fp8_general_splitk=["align32"] * 5,
    w8a8_block_fp8_general_tma=["align32"] * 5 + ["default"],
    mm_splitk=["align32"] * 5,
)
# Ordered key names for each op. Every op listed here also appears in
# DEFAULT_STRATEGIES with a strategy list of the same length.
OP_KEY_ORDERS = dict(
    addmm=["M", "N", "K"],
    addmm_sqmma=["M", "N", "K"],
    bmm=["M", "N", "K", "stride_am", "stride_bk"],
    bmm_sqmma=["M", "N", "K"],
    baddbmm=["M", "N", "K"],
    gemv=["M", "K", "stride_am", "stride_bk"],
    mm=["M", "N", "K", "stride_am", "stride_bk"],
    mm_general_tma=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    mv=["M", "N"],
    sparse_attention=["topk", "H_ACTUAL", "D"],
    w8a8_block_fp8_general=["M", "N", "K", "stride_am", "stride_bk"],
    w8a8_block_fp8_general_splitk=["M", "N", "K", "stride_am", "stride_bk"],
    w8a8_block_fp8_general_tma=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    mm_splitk=["M", "N", "K", "stride_am", "stride_bk"],
)
# Quick-detection table: lower-case vendor name -> name of the attribute on
# the torch module whose presence identifies that vendor's build.
# Only the vendors listed below have such an attribute recorded here.
_VENDOR_TORCH_ATTR = dict(
    ascend="npu",
    cambricon="mlu",
    enflame="gcu",
    hygon="__hcu_version__",
    iluvatar="corex",
    mthreads="musa",
    sunrise="ptpu",
)
# Names exported by `from <this module> import *`.
# DEFAULT_EXPAND_CONFIG_PATH is a public module-level constant like the
# others, so it is listed too (previously it was silently left out of
# star-imports). _VENDOR_TORCH_ATTR is listed explicitly because a leading
# underscore would otherwise keep it out of a star-import.
__all__ = [
    "vendors",
    "UNSUPPORT_FP64",
    "UNSUPPORT_BF16",
    "UNSUPPORT_INT64",
    "DEFAULT_EXPAND_CONFIG_PATH",
    "DEFAULT_STRATEGIES",
    "OP_KEY_ORDERS",
    "_VENDOR_TORCH_ATTR",
]