Coverage for src/flag_gems/utils/triton_version_utils.py: 65%

26 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import re 

2 

3import triton 

4from packaging.version import InvalidVersion, Version 

5 

6 

def _coerce_triton_version(version: str) -> Version:
    """Parse a Triton version string into a comparable ``Version``.

    Strings that ``packaging`` accepts are returned exactly as it parses them
    (including any pre-, post-, dev- or local-release segments). Anything else
    is reduced to a best-effort ``major.minor.patch`` release:

    * surrounding whitespace and a single leading ``v``/``V`` are removed
      (packaging itself tolerates both on valid versions),
    * the local part after ``+`` is dropped,
    * at most the first three dot-separated components are kept,
    * each component is replaced by its leading run of ASCII digits, or
      ``0`` if it has none,
    * missing components are padded with ``0``.

    Args:
        version: Raw version string, e.g. ``str(triton.__version__)``.

    Returns:
        A ``packaging.version.Version``. The fallback path always builds a
        string of the form ``N.N.N`` with ASCII digits, so it does not raise.
    """
    try:
        return Version(version)
    except InvalidVersion:
        pass

    public = version.strip().split("+", 1)[0]
    # Without this, "v3.3.0-custom" would yield a first component "v3" with no
    # leading digits and be misread as 0.3.0.
    if public[:1] in ("v", "V"):
        public = public[1:]

    release = []
    for part in public.split(".")[:3]:
        # ASCII-only on purpose: ``\d`` also matches non-ASCII digits, which
        # packaging's release pattern ([0-9]) rejects. Version() would then
        # raise InvalidVersion from inside this fallback.
        digits = re.match(r"[0-9]+", part)
        release.append(digits.group(0) if digits else "0")
    release += ["0"] * (3 - len(release))
    return Version(".".join(release))

18 

19 

def _triton_version_at_least(major: int, minor: int, patch: int = 0) -> bool:
    """Return whether the installed Triton is at least ``major.minor.patch``.

    The installed version is read from ``triton.__version__``. If that
    attribute is missing, it falls back to ``"0.0.0"``. The version is then
    normalised with ``_coerce_triton_version``. The comparison uses PEP 440
    ordering via ``packaging.version.Version``.
    """
    raw_installed = str(getattr(triton, "__version__", "0.0.0"))
    installed = _coerce_triton_version(raw_installed)
    required = Version(f"{major}.{minor}.{patch}")
    return installed >= required

23 

24 

def has_triton_tle(major: int = 0, minor: int = 0, patch: int = 0) -> bool:
    """Report whether Triton's experimental TLE language module is usable.

    Args:
        major: Minimum required Triton major version.
        minor: Minimum required Triton minor version.
        patch: Minimum required Triton patch version.

    Returns:
        ``False`` in either of two cases: the installed Triton is older than
        ``major.minor.patch``, or importing
        ``triton.experimental.tle.language`` raises ``ImportError``.
        Otherwise ``True``. A successful probe leaves that module imported.
        Exceptions other than ``ImportError`` raised during the import are
        not caught.
    """
    if _triton_version_at_least(major, minor, patch):
        try:
            import triton.experimental.tle.language as _tle_probe  # noqa: F401
        except ImportError:
            return False
        else:
            return True
    return False

34 

35 

# Computed once when this module is imported. It is True when
# triton.experimental.tle.language can be imported. The version floor
# passed here is only 0.0.0.
HAS_TLE: bool = has_triton_tle(major=0, minor=0, patch=0)