Coverage for src/flag_gems/runtime/backend/_arm/ops/div.py: 0%
115 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems.ops import div as base_div
@triton.jit
def _div_tensor_scalar_kernel(
    x_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise out[i] = x[i] / scalar over a flat buffer of n_elements.
    # Program p starts at block p and then strides by
    # num_programs * BLOCK_SIZE elements (grid-stride loop), so any grid
    # size covers the whole buffer.
    first = tl.program_id(0) * BLOCK_SIZE
    stride = tl.num_programs(0) * BLOCK_SIZE
    lanes = tl.arange(0, BLOCK_SIZE)
    for base in range(first, n_elements, stride):
        idx = base + lanes
        # Lanes past the end of the buffer are neither loaded nor stored.
        in_bounds = idx < n_elements
        vals = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
        tl.store(out_ptr + idx, vals / scalar, mask=in_bounds)
31def _select_block_size(n_elements, dtype):
32 if n_elements >= (1 << 20):
33 return 512 if dtype in (torch.float16, torch.bfloat16) else 256
34 if n_elements >= (1 << 18):
35 return 256 if dtype in (torch.float16, torch.bfloat16) else 128
36 return 256 if dtype in (torch.float16, torch.bfloat16) else 128
39def _maybe_contiguous(x, out):
40 if x.is_contiguous():
41 return x, out, False
42 if out is None:
43 return x.contiguous(), out, True
44 if out.is_contiguous():
45 return x.contiguous(), out, True
46 return x, out, False
49def _div_tensor_scalar_triton(x, scalar, out=None):
50 n_elements = x.numel()
51 if n_elements == 0:
52 return x if out is None else out
53 if n_elements == 1 and x.dtype is torch.bfloat16:
54 val = x.item() / scalar
55 if out is None:
56 out = torch.empty_like(x)
57 out.fill_(val)
58 return out
60 block_size = _select_block_size(n_elements, x.dtype)
61 block_size = min(block_size, triton.next_power_of_2(max(n_elements, 1)))
62 num_blocks = triton.cdiv(n_elements, block_size)
63 grid = (num_blocks,)
64 x_contig, out_contig, _ = _maybe_contiguous(x, out)
65 if out_contig is None:
66 out_contig = torch.empty_like(x_contig)
67 num_warps = 1
68 _div_tensor_scalar_kernel[grid](
69 x_contig,
70 out_contig,
71 scalar,
72 n_elements,
73 BLOCK_SIZE=block_size,
74 num_warps=num_warps,
75 )
76 return out_contig
79def _maybe_get_scalar_tensor(val):
80 if isinstance(val, torch.Tensor) and val.numel() == 1:
81 return val.item()
82 return None
def true_divide(A, B):
    """Out-of-place true division ``A / B`` for the ARM backend.

    A floating-point tensor ``A`` divided by a Python number or a 0-dim
    tensor ``B`` uses the dedicated Triton scalar kernel, provided torch
    type promotion keeps ``A``'s dtype. Every other combination is delegated
    to the generic ``flag_gems`` implementation:
    integer/bool/complex ``A``, a dimensioned ``B`` that may broadcast, or a
    ``B`` that promotes the dtype.

    Setting the environment variable ``GEMS_DEBUG_DIV=1`` prints the operand
    shapes on every call.
    """
    logging.debug("GEMS_ARM TRUE_DIVIDE")
    if os.environ.get("GEMS_DEBUG_DIV") == "1":
        a_shape = tuple(A.shape) if isinstance(A, torch.Tensor) else None
        b_shape = tuple(B.shape) if isinstance(B, torch.Tensor) else None
        print(f"[GEMS_DEBUG_DIV] true_divide: A={a_shape} B={b_shape}")
    if (
        isinstance(A, torch.Tensor)
        # Integer/bool inputs must produce a float result, which the kernel
        # (writing into A's dtype) cannot provide.
        and A.is_floating_point()
        # A dimensioned B, even with one element, can broadcast A to a larger
        # shape and participates in promotion; only 0-dim acts like a scalar.
        and (not isinstance(B, torch.Tensor) or B.dim() == 0)
        # The output buffer has A's dtype, so promotion must not change it
        # (e.g. rules out complex divisors).
        and torch.result_type(A, B) == A.dtype
    ):
        scalar = _maybe_get_scalar_tensor(B) if isinstance(B, torch.Tensor) else B
        return _div_tensor_scalar_triton(A, scalar)
    return base_div.true_divide(A, B)
def true_divide_(A, B):
    """In-place true division ``A /= B`` for the ARM backend.

    The Triton scalar kernel writes the result directly into ``A`` only when
    all of the following hold:

    - ``A`` is a contiguous floating-point tensor;
    - ``B`` is a Python number or a 0-dim tensor;
    - dtype promotion keeps ``A``'s dtype.

    Everything else goes to the generic ``flag_gems`` in-place op. That
    includes integer ``A`` (true division yields a float), so the fast path
    never silently truncates a result into ``A``.
    """
    logging.debug("GEMS_ARM TRUE_DIVIDE_")
    if (
        isinstance(A, torch.Tensor)
        and A.is_floating_point()
        # The kernel writes ``A`` through flat offsets.
        and A.is_contiguous()
        # Only a 0-dim B is guaranteed to broadcast to exactly A's shape.
        and (not isinstance(B, torch.Tensor) or B.dim() == 0)
        # e.g. a complex divisor cannot be stored back into a real ``A``.
        and torch.result_type(A, B) == A.dtype
    ):
        scalar = _maybe_get_scalar_tensor(B) if isinstance(B, torch.Tensor) else B
        return _div_tensor_scalar_triton(A, scalar, out=A)
    return base_div.true_divide_(A, B)
def trunc_divide(A, B):
    """Handle ``rounding_mode="trunc"`` division via the generic flag_gems op."""
    logging.debug("GEMS_ARM TRUNC_DIVIDE")
    impl = base_div.trunc_divide
    return impl(A, B)
def trunc_divide_(A, B):
    """In-place ``rounding_mode="trunc"`` division via the generic flag_gems op."""
    logging.debug("GEMS_ARM TRUNC_DIVIDE_")
    impl = base_div.trunc_divide_
    return impl(A, B)
def floor_divide(A, B):
    """Handle ``rounding_mode="floor"`` division via the generic flag_gems op."""
    logging.debug("GEMS_ARM FLOOR_DIVIDE")
    impl = base_div.floor_divide
    return impl(A, B)
def floor_divide_(A, B):
    """In-place ``rounding_mode="floor"`` division via the generic flag_gems op."""
    logging.debug("GEMS_ARM FLOOR_DIVIDE_")
    impl = base_div.floor_divide_
    return impl(A, B)
132def div_mode(A, B, rounding_mode=None):
133 if rounding_mode is None:
134 return true_divide(A, B)
135 if rounding_mode == "trunc":
136 return trunc_divide(A, B)
137 if rounding_mode == "floor":
138 return floor_divide(A, B)
139 msg = (
140 "div expected rounding_mode to be one of None, 'trunc', or 'floor' "
141 f"but found {rounding_mode}."
142 )
143 raise ValueError(msg)
146def div_mode_(A, B, rounding_mode=None):
147 if rounding_mode is None:
148 return true_divide_(A, B)
149 if rounding_mode == "trunc":
150 return trunc_divide_(A, B)
151 if rounding_mode == "floor":
152 return floor_divide_(A, B)
153 msg = (
154 "div expected rounding_mode to be one of None, 'trunc', or 'floor' "
155 f"but found {rounding_mode}."
156 )
157 raise ValueError(msg)
def remainder(A, B):
    """Elementwise ``remainder(A, B)``, delegated to the generic flag_gems op."""
    logging.debug("GEMS_ARM REMAINDER")
    impl = base_div.remainder
    return impl(A, B)
def remainder_(A, B):
    """In-place ``remainder_(A, B)``, delegated to the generic flag_gems op."""
    logging.debug("GEMS_ARM REMAINDER_")
    impl = base_div.remainder_
    return impl(A, B)