Coverage for src/flag_gems/runtime/backend/_mthreads/ops/utils.py: 0%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import os
2from collections import OrderedDict
4import numpy as np
5import torch
6import triton
7import triton.language as tl
# Maximum number of device descriptors retained before the least recently
# used entry is evicted by get_cached_tma_device_descriptor.
_TMA_DESCRIPTOR_CACHE_MAXSIZE = 256

# LRU store for device-side TMA descriptors, keyed by
# _tma_descriptor_cache_key(...). Entries are moved to the end on every hit,
# so the front of the OrderedDict always holds the oldest entry.
_tma_descriptor_cache: OrderedDict = OrderedDict()
# Detect once whether fill_2d_tma_descriptor expects a pointer (int) or numpy array.
# triton >= 3.2 changed the last parameter from numpy array to int pointer.
_fill_2d_tma = triton.runtime.driver.active.utils.fill_2d_tma_descriptor


def _parse_major_minor(version):
    """Return ``(major, minor)`` as ints parsed from a version string.

    Only the leading digits of each component are used, so strings such as
    ``"3.2rc1"``, ``"3.2.0+git1a2b3c"`` or ``"3.3.0.dev0"`` parse cleanly
    instead of raising. A missing or non-numeric component counts as 0.
    """
    import re

    parts = []
    for piece in version.split(".")[:2]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    # Pad so that e.g. "3" compares as (3, 0), keeping the tuple comparison
    # below well defined.
    while len(parts) < 2:
        parts.append(0)
    return tuple(parts)


# Evaluated at import time. A naive int() over each component would raise
# ValueError on suffixed versions and make this whole module unimportable.
_tma_desc_wants_ptr = _parse_major_minor(triton.__version__) >= (3, 2)
def _tma_desc_arg(desc_np):
    """Adapt a host descriptor buffer to the calling convention of ``_fill_2d_tma``.

    When ``_tma_desc_wants_ptr`` is true, the buffer's data address is
    returned as a Python int. Otherwise the numpy array itself is passed
    through unchanged.
    """
    if not _tma_desc_wants_ptr:
        return desc_np
    address = desc_np.ctypes.data
    return int(address)
def create_tma_device_descriptor(tensor, block_m, block_n, device):
    """Build a 2D TMA descriptor for ``tensor`` and place it on ``device``.

    The descriptor bytes are written on the host into a 64-byte int8 numpy
    buffer by the active Triton driver's ``fill_2d_tma_descriptor``. That
    buffer is then copied into an int8 ``torch.Tensor`` on ``device``.

    Args:
        tensor: 2D tensor that is either contiguous, or a transposed view of a
            contiguous tensor (``stride(0) == 1`` and
            ``stride(1) == shape[0]``).
        block_m: first block dimension forwarded to the driver fill call.
        block_n: second block dimension forwarded to the driver fill call.
        device: device on which the returned descriptor tensor is allocated.

    Returns:
        A 64-element int8 tensor on ``device`` holding the descriptor bytes.

    Raises:
        ValueError: if ``tensor`` is not 2D, or is neither contiguous nor a
            transposed-contiguous view.
    """
    # Explicit raises instead of ``assert`` so the validation survives
    # ``python -O``. An unsupported layout would otherwise be encoded into a
    # silently wrong descriptor.
    if tensor.dim() != 2:
        raise ValueError("TMA descriptor only supports 2D tensors")
    TMA_DESCRIPTOR_SIZE = 64
    shapes = [tensor.shape[0], tensor.shape[1]]
    if not tensor.is_contiguous():
        if not (tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]):
            raise ValueError(
                "TMA descriptor only supports contiguous or transposed 2D tensors"
            )
        # Column-major storage: dim 0 is the unit-stride dimension, so the
        # memory is described as the row-major (shape[1], shape[0]) matrix.
        shapes.reverse()
    desc_np = np.empty(TMA_DESCRIPTOR_SIZE, dtype=np.int8)
    # desc_np stays referenced across the call, so the raw address handed
    # over by _tma_desc_arg (triton >= 3.2) remains valid while it is filled.
    _fill_2d_tma(
        tensor.data_ptr(),
        shapes[0],
        shapes[1],
        block_m,
        block_n,
        tensor.element_size(),
        _tma_desc_arg(desc_np),
    )
    return torch.tensor(desc_np, device=device)
def _tma_descriptor_cache_key(tensor, block_m, block_n, device):
    """Build the hashable lookup key for the TMA descriptor cache.

    The key captures every input that the descriptor is built from: the
    data address, shape and strides, dtype (as a string), the block sizes,
    and the target device (as a string).
    """
    dims = tuple(tensor.shape)
    strides = tuple(tensor.stride())
    dtype_name = str(tensor.dtype)
    device_name = str(device)
    return (tensor.data_ptr(), dims, strides, dtype_name, block_m, block_n, device_name)
def get_cached_tma_device_descriptor(tensor, block_m, block_n, device):
    """Return a TMA device descriptor for ``tensor``, reusing a cached one if present.

    Descriptors live in the module-level ``_tma_descriptor_cache``
    OrderedDict, which is used as an LRU. A hit is moved to the end. A miss
    builds a new descriptor with ``create_tma_device_descriptor``, appends
    it, and evicts the oldest entry once the size exceeds
    ``_TMA_DESCRIPTOR_CACHE_MAXSIZE``.
    """
    cache = _tma_descriptor_cache
    lookup = _tma_descriptor_cache_key(tensor, block_m, block_n, device)
    if lookup in cache:
        # Refresh recency so this entry is evicted last.
        cache.move_to_end(lookup)
        return cache[lookup]

    fresh = create_tma_device_descriptor(tensor, block_m, block_n, device)
    cache[lookup] = fresh
    if len(cache) > _TMA_DESCRIPTOR_CACHE_MAXSIZE:
        # Front of the OrderedDict is the least recently used entry.
        cache.popitem(last=False)
    return fresh
def get_triton_dtype(dtype):
    """Translate a torch dtype into the matching ``triton.language`` dtype.

    Only float16, bfloat16 and float32 are supported. Any other dtype maps
    to ``None``.
    """
    torch_to_tl = {
        torch.float32: tl.float32,
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
    }
    return torch_to_tl.get(dtype)
80def should_enable_sqmma(a_dtype, b_dtype, M, N, K):
81 return (
82 (os.getenv("MUSA_ENABLE_SQMMA", "0") == "1")
83 and (a_dtype in [torch.float16, torch.bfloat16] and a_dtype.itemsize == 2)
84 and ((M, N, K) not in [(1, 1, 32), (15, 160, 1024), (495, 5333, 71)])
85 )