Coverage for src/flag_gems/utils/triton_version_utils.py: 60%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import re
3import triton
4from packaging.version import InvalidVersion, Version
def _coerce_triton_version(version: str) -> Version:
    """Turn a Triton version string into a ``packaging`` ``Version``.

    The string is first handed to ``Version`` unchanged. If that raises
    ``InvalidVersion``, a ``major.minor.patch`` release is rebuilt from the
    text before the first ``+``: the leading run of digits of each of the
    first three dot-separated components is kept, a component that does not
    start with a digit becomes ``"0"``, and missing components are filled
    with ``"0"``.

    Args:
        version: Version string to parse.

    Returns:
        The parsed ``Version``, or the rebuilt three-component ``Version``
        when the string could not be parsed directly.
    """
    try:
        return Version(version)
    except InvalidVersion:
        # Everything after the first "+" is ignored by the fallback parse.
        public_part, _, _ = version.partition("+")
        numeric = []
        for chunk in public_part.split(".")[:3]:
            digits = re.match(r"\d+", chunk)
            numeric.append("0" if digits is None else digits.group(0))
        # Pad to exactly three components (no-op when three were found).
        numeric += ["0"] * (3 - len(numeric))
        return Version(".".join(numeric))
def _triton_version_at_least(major: int, minor: int, patch: int = 0) -> bool:
    """Report whether ``triton.__version__`` is at least ``major.minor.patch``.

    If the ``triton`` module has no ``__version__`` attribute, ``"0.0.0"`` is
    used in its place. The value is converted with ``str`` and parsed by
    ``_coerce_triton_version`` before being compared.

    Args:
        major: Required major version.
        minor: Required minor version.
        patch: Required patch version (defaults to 0).

    Returns:
        ``True`` when the parsed Triton version is greater than or equal to
        the requested version, ``False`` otherwise.
    """
    raw_version = getattr(triton, "__version__", "0.0.0")
    current = _coerce_triton_version(str(raw_version))
    required = Version(f"{major}.{minor}.{patch}")
    return current >= required
def has_triton_tle(major: int = 0, minor: int = 0, patch: int = 0) -> bool:
    """Check whether ``triton.experimental.tle.language`` can be imported.

    The import is only attempted when the Triton version satisfies the
    requested ``major.minor.patch`` minimum.

    Args:
        major: Minimum required Triton major version (defaults to 0).
        minor: Minimum required Triton minor version (defaults to 0).
        patch: Minimum required Triton patch version (defaults to 0).

    Returns:
        ``True`` if the version requirement holds and the import succeeds;
        ``False`` if the version is too old or the import raises
        ``ImportError``.
    """
    version_ok = _triton_version_at_least(major, minor, patch)
    if not version_ok:
        return False

    try:
        # Imported only to probe availability; the module itself is unused.
        import triton.experimental.tle.language as _tle  # noqa: F401
    except ImportError:
        tle_importable = False
    else:
        tle_importable = True
    return tle_importable
# Result of the TLE availability probe with the default 0.0.0 version minimum,
# evaluated once when this module is imported.
HAS_TLE: bool = has_triton_tle()
def has_tle_device_mesh() -> bool:
    """Check whether the TLE language module exposes ``device_mesh``.

    Returns:
        ``True`` only if ``HAS_TLE`` is truthy,
        ``triton.experimental.tle.language`` imports, and the module has a
        ``device_mesh`` attribute. ``False`` otherwise, including when an
        ``ImportError`` is raised.
    """
    available = False
    if HAS_TLE:
        try:
            import triton.experimental.tle.language as tle_module

            # The attribute probe stays inside the try block so that an
            # ImportError raised during the lookup is also reported as
            # "not available".
            available = hasattr(tle_module, "device_mesh")
        except ImportError:
            available = False
    return available
# Result of the device_mesh availability probe, evaluated once when this
# module is imported.
HAS_TLE_DEVICE_MESH: bool = has_tle_device_mesh()