Coverage for src/flag_gems/runtime/backend/_arm/ops/pow.py: 0%
296 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import os
4import numpy as np
5import torch
6import triton
7import triton.language as tl
9from flag_gems.ops.pow import pow_scalar as base_pow_scalar
10from flag_gems.ops.pow import pow_tensor_scalar as base_pow_tensor_scalar
11from flag_gems.ops.pow import pow_tensor_scalar_ as base_pow_tensor_scalar_
12from flag_gems.ops.pow import pow_tensor_tensor as base_pow_tensor_tensor
13from flag_gems.ops.pow import pow_tensor_tensor_ as base_pow_tensor_tensor_
15# For small tensors, bypass Triton entirely via numpy (zero-copy views).
16_POW_NATIVE_THRESHOLD = 4096
18_PREWARM_POW_DONE = False
19_POW_SQUARE_HOT_ENABLED = os.environ.get("GEMS_ARM_POW_SQUARE_HOT", "1") == "1"
20_POW_TRITON_ENABLED = os.environ.get("GEMS_ARM_POW_TRITON", "1") == "1"
21_POW_PREWARM_ENABLED = os.environ.get("GEMS_ARM_POW_PREWARM", "1") == "1"
@triton.jit
def _pow_square_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Store ``x * x`` for the first ``n_elements`` values read from ``x_ptr``.

    Each program begins at block ``program_id(0)`` and advances by
    ``num_programs(0)`` blocks per iteration until ``n_elements`` is reached.
    """
    lanes = tl.arange(0, BLOCK_SIZE)
    first = tl.program_id(0) * BLOCK_SIZE
    stride = tl.num_programs(0) * BLOCK_SIZE
    for block_start in range(first, n_elements, stride):
        idx = block_start + lanes
        # Lanes past the end of the buffer are neither loaded nor stored.
        in_bounds = idx < n_elements
        vals = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
        tl.store(out_ptr + idx, vals * vals, mask=in_bounds)
@triton.jit
def _pow_square_single_program_kernel(
    x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Square ``n_elements`` values in one program, ``BLOCK_SIZE`` at a time.

    The program id is never consulted: the loop itself walks the whole
    buffer from offset 0.
    """
    lanes = tl.arange(0, BLOCK_SIZE)
    for block_start in range(0, n_elements, BLOCK_SIZE):
        pos = block_start + lanes
        # Only the final block can be partial; mask its out-of-range lanes.
        valid = pos < n_elements
        vals = tl.load(x_ptr + pos, mask=valid, other=0.0)
        tl.store(out_ptr + pos, vals * vals, mask=valid)
@triton.jit
def _pow_square_1024_hot_kernel(
    x_ptr,
    out_ptr,
):
    """Square exactly 1024 elements as four unmasked 256-wide chunks.

    No bounds mask is applied, so both buffers must hold at least 1024
    elements.
    """
    lanes = tl.arange(0, 256)
    for chunk_start in range(0, 1024, 256):
        pos = chunk_start + lanes
        vals = tl.load(x_ptr + pos)
        tl.store(out_ptr + pos, vals * vals)
@triton.jit
def _pow_square_2048_hot_kernel(
    x_ptr,
    out_ptr,
):
    """Square exactly 2048 elements as eight unmasked 256-wide chunks.

    No bounds mask is applied, so both buffers must hold at least 2048
    elements.
    """
    lanes = tl.arange(0, 256)
    for chunk_start in range(0, 2048, 256):
        pos = chunk_start + lanes
        vals = tl.load(x_ptr + pos)
        tl.store(out_ptr + pos, vals * vals)
@triton.jit(do_not_specialize=["rows"])
def _pow_square_rows128_hot_kernel(
    x_ptr,
    out_ptr,
    rows,
    MAX_ROWS: tl.constexpr,
):
    """Square ``rows`` consecutive rows of 128 elements each, unmasked.

    The loop bound is the compile-time ``MAX_ROWS``; the runtime ``rows``
    value (excluded from specialization) only gates which iterations do work,
    so ``rows`` must not exceed ``MAX_ROWS`` for every row to be processed.
    """
    lanes = tl.arange(0, 128)
    for r in range(0, MAX_ROWS):
        if r < rows:
            pos = r * 128 + lanes
            vals = tl.load(x_ptr + pos)
            tl.store(out_ptr + pos, vals * vals)
@triton.jit(do_not_specialize=["rows"])
def _pow_square_rows1024_hot_kernel(
    x_ptr,
    out_ptr,
    rows,
    MAX_ROWS: tl.constexpr,
):
    """Square ``rows`` consecutive rows of 1024 elements each, unmasked.

    Every active row is processed as four 256-wide chunks. The loop bound is
    the compile-time ``MAX_ROWS``; the runtime ``rows`` value (excluded from
    specialization) decides which iterations actually touch memory.
    """
    lanes = tl.arange(0, 256)
    for r in range(0, MAX_ROWS):
        if r < rows:
            row_start = r * 1024
            for chunk_start in range(0, 1024, 256):
                pos = row_start + chunk_start + lanes
                vals = tl.load(x_ptr + pos)
                tl.store(out_ptr + pos, vals * vals)
@triton.jit
def _pow_square_3584_hot_kernel(
    x_ptr,
    out_ptr,
):
    """Square exactly 3584 elements as fourteen unmasked 256-wide chunks.

    No bounds mask is applied, so both buffers must hold at least 3584
    elements.
    """
    lanes = tl.arange(0, 256)
    for chunk_start in range(0, 3584, 256):
        pos = chunk_start + lanes
        vals = tl.load(x_ptr + pos)
        tl.store(out_ptr + pos, vals * vals)
@triton.jit(do_not_specialize=["rows"])
def _pow_square_rows3584_hot_kernel(
    x_ptr,
    out_ptr,
    rows,
    MAX_ROWS: tl.constexpr,
):
    """Square ``rows`` consecutive rows of 3584 elements each, unmasked.

    Every active row is processed as fourteen 256-wide chunks. The loop bound
    is the compile-time ``MAX_ROWS``; the runtime ``rows`` value (excluded
    from specialization) decides which iterations actually touch memory.
    """
    lanes = tl.arange(0, 256)
    for r in range(0, MAX_ROWS):
        if r < rows:
            row_start = r * 3584
            for chunk_start in range(0, 3584, 256):
                pos = row_start + chunk_start + lanes
                vals = tl.load(x_ptr + pos)
                tl.store(out_ptr + pos, vals * vals)
@triton.jit
def _pow_sqrt_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Store ``sqrt(x)`` for the first ``n_elements`` values at ``x_ptr``.

    The square root is evaluated in float32 and then cast to the element type
    of ``out_ptr``. Each program begins at block ``program_id(0)`` and
    advances by ``num_programs(0)`` blocks per iteration.
    """
    lanes = tl.arange(0, BLOCK_SIZE)
    first = tl.program_id(0) * BLOCK_SIZE
    stride = tl.num_programs(0) * BLOCK_SIZE
    for block_start in range(first, n_elements, stride):
        idx = block_start + lanes
        in_bounds = idx < n_elements
        vals = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
        root = tl.sqrt(vals.to(tl.float32))
        tl.store(out_ptr + idx, root.to(out_ptr.dtype.element_ty), mask=in_bounds)
@triton.jit
def _pow_sqrt_single_program_kernel(
    x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Compute ``sqrt(x)`` over ``n_elements`` values in a single program.

    The square root is evaluated in float32 and then cast to the element type
    of ``out_ptr``. The program id is never consulted.
    """
    lanes = tl.arange(0, BLOCK_SIZE)
    for block_start in range(0, n_elements, BLOCK_SIZE):
        pos = block_start + lanes
        valid = pos < n_elements
        vals = tl.load(x_ptr + pos, mask=valid, other=0.0)
        root = tl.sqrt(vals.to(tl.float32))
        tl.store(out_ptr + pos, root.to(out_ptr.dtype.element_ty), mask=valid)
@triton.jit
def _pow_rsqrt_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Store ``1 / sqrt(x)`` for the first ``n_elements`` values at ``x_ptr``.

    The value is evaluated in float32 and then cast to the element type of
    ``out_ptr``. Each program begins at block ``program_id(0)`` and advances
    by ``num_programs(0)`` blocks per iteration.
    """
    lanes = tl.arange(0, BLOCK_SIZE)
    first = tl.program_id(0) * BLOCK_SIZE
    stride = tl.num_programs(0) * BLOCK_SIZE
    for block_start in range(first, n_elements, stride):
        idx = block_start + lanes
        in_bounds = idx < n_elements
        # Masked-off lanes read 0.0; their results are never stored.
        vals = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
        inv_root = 1.0 / tl.sqrt(vals.to(tl.float32))
        tl.store(
            out_ptr + idx, inv_root.to(out_ptr.dtype.element_ty), mask=in_bounds
        )
@triton.jit
def _pow_rsqrt_single_program_kernel(
    x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Compute ``1 / sqrt(x)`` over ``n_elements`` values in a single program.

    The value is evaluated in float32 and then cast to the element type of
    ``out_ptr``. The program id is never consulted.
    """
    lanes = tl.arange(0, BLOCK_SIZE)
    for block_start in range(0, n_elements, BLOCK_SIZE):
        pos = block_start + lanes
        valid = pos < n_elements
        # Masked-off lanes read 0.0; their results are never stored.
        vals = tl.load(x_ptr + pos, mask=valid, other=0.0)
        inv_root = 1.0 / tl.sqrt(vals.to(tl.float32))
        tl.store(out_ptr + pos, inv_root.to(out_ptr.dtype.element_ty), mask=valid)
183def _select_block_size(n_elements, dtype):
184 # Tuned for Qwen decode hotspot shapes on triton-cpu.
185 if n_elements <= 32:
186 return 32
187 if n_elements <= 1024:
188 return 128
189 if n_elements <= 2048:
190 return 128
191 if n_elements <= 4096:
192 return 128
193 if n_elements <= (1 << 16):
194 return 128
195 return 256 if dtype in (torch.float16, torch.bfloat16) else 128
198def _single_program_block(n_elements):
199 if n_elements <= 256:
200 return 32
201 if n_elements <= 2048:
202 return 128
203 return 256
206def _maybe_scalar(v):
207 if isinstance(v, torch.Tensor) and v.numel() == 1:
208 return float(v.item())
209 if isinstance(v, (int, float)):
210 return float(v)
211 return None
214def _is_supported_tensor(t):
215 return (
216 isinstance(t, torch.Tensor)
217 and t.device.type == "cpu"
218 and t.dtype
219 in (
220 torch.float16,
221 torch.bfloat16,
222 torch.float32,
223 torch.float64,
224 )
225 )
def _launch_pow_kernel(
    multi_kernel, single_kernel, x, out_tensor, n_elements, block_size
):
    """Launch either the single-program or the multi-program kernel variant.

    Sizes from 2 through 8192 elements run ``single_kernel`` on a one-program
    grid with a block size from ``_single_program_block``. Every other size
    runs ``multi_kernel`` with one program per ``block_size`` elements.

    Args:
        multi_kernel: Kernel taking ``(x, out, n_elements, BLOCK_SIZE=...)``.
        single_kernel: Single-program kernel with the same signature.
        x: Input tensor.
        out_tensor: Destination tensor.
        n_elements: Number of elements to process.
        block_size: BLOCK_SIZE for the multi-program launch.
    """
    if 1 < n_elements <= 8192:
        kernel = single_kernel
        grid = (1,)
        chosen_block = _single_program_block(n_elements)
    else:
        kernel = multi_kernel
        grid = (triton.cdiv(n_elements, block_size),)
        chosen_block = block_size
    kernel[grid](
        x,
        out_tensor,
        n_elements,
        BLOCK_SIZE=chosen_block,
        num_warps=1,
        num_stages=1,
    )
def _maybe_launch_pow_square_hotshape(x, out_tensor, n_elements):
    """Run a row-specialized squaring kernel when ``x`` has a hot shape.

    A hot shape is a non-empty, contiguous tensor with at least one dimension
    whose last dimension is 128, 1024 or 3584 and whose row count does not
    exceed the matching limit (96, 16 and 128 rows respectively).

    Args:
        x: Input tensor.
        out_tensor: Destination tensor written by the kernel.
        n_elements: Total element count to process.

    Returns:
        True if a kernel was launched, False if the caller must use another
        path (including when the hot-shape switch is disabled).
    """
    if not _POW_SQUARE_HOT_ENABLED:
        return False
    if x.ndim == 0 or x.numel() == 0 or not x.is_contiguous():
        return False

    # Row width -> (maximum row count, kernel specialized for that width).
    hot_shapes = {
        128: (96, _pow_square_rows128_hot_kernel),
        1024: (16, _pow_square_rows1024_hot_kernel),
        3584: (128, _pow_square_rows3584_hot_kernel),
    }
    width = x.shape[-1]
    entry = hot_shapes.get(width)
    if entry is None:
        return False

    max_rows, kernel = entry
    rows, leftover = divmod(n_elements, width)
    if leftover != 0 or not 0 < rows <= max_rows:
        return False

    kernel[(1,)](
        x,
        out_tensor,
        rows,
        MAX_ROWS=max_rows,
        num_warps=1,
        num_stages=1,
    )
    return True
def _pow_tensor_scalar_special(x, exponent, out=None):
    """Compute ``x ** exponent`` with a dedicated Triton kernel when possible.

    Only exponents 2.0, 0.5 and -0.5 on contiguous CPU floating-point tensors
    are handled. float64 inputs are accepted for squaring only, because the
    sqrt/rsqrt kernels evaluate in float32.

    Args:
        x: Input tensor.
        exponent: Exponent as a Python float.
        out: Optional destination tensor (the in-place callers pass ``x``).
            When omitted, a new tensor shaped like ``x`` is allocated.

    Returns:
        The tensor holding the result, or ``None`` when this fast path does
        not apply and the caller should fall back to the generic op.
    """
    if not _POW_TRITON_ENABLED:
        return None
    if not _is_supported_tensor(x) or not x.is_contiguous():
        return None
    if out is not None and not out.is_contiguous():
        return None

    if exponent == 2.0:
        kernel = _pow_square_kernel
        single_kernel = _pow_square_single_program_kernel
    elif exponent == 0.5:
        kernel = _pow_sqrt_kernel
        single_kernel = _pow_sqrt_single_program_kernel
    elif exponent == -0.5:
        kernel = _pow_rsqrt_kernel
        single_kernel = _pow_rsqrt_single_program_kernel
    else:
        return None

    # The sqrt/rsqrt kernels compute in float32 before casting back, which
    # would silently drop float64 precision; let the generic op handle it.
    if exponent != 2.0 and x.dtype == torch.float64:
        return None

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute. For the out-of-place case return a fresh empty
        # tensor so the result never aliases the input.
        return torch.empty_like(x) if out is None else out

    out_tensor = torch.empty_like(x) if out is None else out
    if exponent == 2.0:
        # Contiguity was verified above, so the fixed-size unmasked kernels
        # can be selected on element count alone.
        fixed_size_kernels = {
            1024: _pow_square_1024_hot_kernel,
            2048: _pow_square_2048_hot_kernel,
            3584: _pow_square_3584_hot_kernel,
        }
        hot_kernel = fixed_size_kernels.get(n_elements)
        if hot_kernel is not None:
            hot_kernel[(1,)](
                x,
                out_tensor,
                num_warps=1,
                num_stages=1,
            )
            return out_tensor
        if _maybe_launch_pow_square_hotshape(x, out_tensor, n_elements):
            return out_tensor

    block_size = _select_block_size(n_elements, x.dtype)
    _launch_pow_kernel(kernel, single_kernel, x, out_tensor, n_elements, block_size)
    return out_tensor
def _maybe_prewarm_pow_kernels():
    """Run the hot squaring kernels once on dummy inputs, at most one time.

    Launches are made for float32 and bfloat16 (a warm-up; presumably to
    front-load Triton JIT compilation -- confirm). Any exception raised by
    the launches is logged at debug level and otherwise ignored. The done
    flag is set afterwards either way, and also when prewarming is disabled,
    so the work is never retried.
    """
    global _PREWARM_POW_DONE
    if _PREWARM_POW_DONE:
        return
    if not _POW_PREWARM_ENABLED:
        _PREWARM_POW_DONE = True
        return

    launch_opts = {"num_warps": 1, "num_stages": 1}
    try:
        for dtype in (torch.float32, torch.bfloat16):
            src_1024 = torch.ones((1, 1, 1024), dtype=dtype, device="cpu")
            dst_1024 = torch.empty_like(src_1024)
            _pow_square_1024_hot_kernel[(1,)](src_1024, dst_1024, **launch_opts)

            src_2048 = torch.ones((1, 16, 1, 128), dtype=dtype, device="cpu")
            dst_2048 = torch.empty_like(src_2048)
            _pow_square_2048_hot_kernel[(1,)](src_2048, dst_2048, **launch_opts)

            # The same 2048-element buffer doubles as 16 rows of 128.
            row_count = src_2048.numel() // 128
            _pow_square_rows128_hot_kernel[(1,)](
                src_2048, dst_2048, row_count, MAX_ROWS=96, **launch_opts
            )

            src_3584 = torch.ones((1, 1, 3584), dtype=dtype, device="cpu")
            dst_3584 = torch.empty_like(src_3584)
            _pow_square_3584_hot_kernel[(1,)](src_3584, dst_3584, **launch_opts)

            src_rows = torch.ones((1, 128, 3584), dtype=dtype, device="cpu")
            dst_rows = torch.empty_like(src_rows)
            _pow_square_rows3584_hot_kernel[(1,)](
                src_rows, dst_rows, 128, MAX_ROWS=128, **launch_opts
            )

            # Finally exercise the generic launcher on the 1024-element input.
            count_1024 = src_1024.numel()
            _launch_pow_kernel(
                _pow_square_kernel,
                _pow_square_single_program_kernel,
                src_1024,
                dst_1024,
                count_1024,
                _select_block_size(count_1024, src_1024.dtype),
            )
    except Exception:
        logging.debug("GEMS ARM pow prewarm failed", exc_info=True)
    _PREWARM_POW_DONE = True
def pow_tensor_tensor(A, exponent):
    """Compute ``A ** exponent`` where ``exponent`` is normally a tensor.

    Small contiguous CPU tensors are computed with NumPy on zero-copy views.
    Otherwise a one-element exponent is routed through the specialized scalar
    kernels, and everything else goes to the base implementation.

    The NumPy path is taken only when it cannot crash or change the result
    type:

    * ``A`` must be a CPU floating tensor other than bfloat16, since
      ``Tensor.numpy()`` does not support bfloat16; integer inputs would
      follow NumPy rather than PyTorch type promotion.
    * ``A`` must have at least one dimension, since ``np.power`` on 0-d
      arrays returns a NumPy scalar that ``torch.from_numpy`` rejects.
    * A tensor exponent must satisfy the same device/dtype rules and share
      ``A``'s dtype.

    Args:
        A: Base tensor.
        exponent: Exponent tensor (a Python number is also accepted).

    Returns:
        A tensor holding the element-wise power.
    """
    logging.debug("GEMS_ARM POW_TENSOR_TENSOR")
    if (
        _is_supported_tensor(A)
        and A.dtype != torch.bfloat16
        and A.ndim > 0
        and A.numel() < _POW_NATIVE_THRESHOLD
        and A.is_contiguous()
    ):
        if not isinstance(exponent, torch.Tensor):
            np_exponent = float(exponent)
        elif _is_supported_tensor(exponent) and exponent.dtype == A.dtype:
            np_exponent = exponent.detach().numpy()
        else:
            # Different device or dtype: let the regular paths handle it.
            np_exponent = None
        if np_exponent is not None:
            # PyTorch yields nan/inf silently; suppress NumPy's RuntimeWarnings
            # for invalid, overflowing or divide-by-zero cases to match.
            with np.errstate(all="ignore"):
                result = np.power(A.detach().numpy(), np_exponent)
            return torch.from_numpy(result)
    _maybe_prewarm_pow_kernels()
    scalar_exp = _maybe_scalar(exponent)
    if scalar_exp is not None:
        special = _pow_tensor_scalar_special(A, scalar_exp)
        if special is not None:
            return special
        return base_pow_tensor_scalar(A, scalar_exp)
    return base_pow_tensor_tensor(A, exponent)
def pow_tensor_tensor_(A, exponent):
    """In-place ``A **= exponent`` for a tensor exponent.

    A one-element exponent is first tried on the specialized scalar kernels
    (writing into ``A``), then on the base in-place scalar op. Any other
    exponent goes to the base in-place tensor-tensor op.
    """
    logging.debug("GEMS_ARM POW_TENSOR_TENSOR_")
    _maybe_prewarm_pow_kernels()
    scalar_exp = _maybe_scalar(exponent)
    if scalar_exp is None:
        return base_pow_tensor_tensor_(A, exponent)
    fast_result = _pow_tensor_scalar_special(A, scalar_exp, out=A)
    if fast_result is None:
        return base_pow_tensor_scalar_(A, scalar_exp)
    return fast_result
def pow_tensor_scalar(A, exponent):
    """Compute ``A ** exponent`` for a scalar exponent.

    Small contiguous CPU tensors are computed with NumPy on a zero-copy view
    (with a plain multiply for squaring). Otherwise the specialized Triton
    kernels are tried before falling back to the base implementation.

    The NumPy path is taken only when it cannot crash or change the result
    type:

    * ``A`` must be a CPU floating tensor other than bfloat16, since
      ``Tensor.numpy()`` does not support bfloat16; integer inputs would
      follow NumPy rather than PyTorch type promotion.
    * ``A`` must have at least one dimension, since NumPy ufuncs on 0-d
      arrays return a NumPy scalar that ``torch.from_numpy`` rejects.

    Args:
        A: Base tensor.
        exponent: Python number or one-element tensor.

    Returns:
        A tensor holding the element-wise power.
    """
    logging.debug("GEMS_ARM POW_TENSOR_SCALAR")
    if (
        _is_supported_tensor(A)
        and A.dtype != torch.bfloat16
        and A.ndim > 0
        and A.numel() < _POW_NATIVE_THRESHOLD
        and A.is_contiguous()
    ):
        if isinstance(exponent, torch.Tensor):
            exp = exponent.item()
        else:
            exp = float(exponent)
        base = A.detach().numpy()
        # PyTorch yields nan/inf silently; suppress NumPy's RuntimeWarnings
        # for invalid, overflowing or divide-by-zero cases to match.
        with np.errstate(all="ignore"):
            if exp == 2.0:
                result = np.multiply(base, base)
            else:
                result = np.power(base, exp)
        return torch.from_numpy(result)
    _maybe_prewarm_pow_kernels()
    scalar_exp = _maybe_scalar(exponent)
    if scalar_exp is not None:
        special = _pow_tensor_scalar_special(A, scalar_exp)
        if special is not None:
            return special
        return base_pow_tensor_scalar(A, scalar_exp)
    return base_pow_tensor_scalar(A, exponent)
def pow_tensor_scalar_(A, exponent):
    """In-place ``A **= exponent`` for a scalar exponent.

    A scalar-like exponent is first tried on the specialized kernels (writing
    into ``A``), then on the base in-place op with the float value. An
    exponent that is not scalar-like is passed to the base op unchanged.
    """
    logging.debug("GEMS_ARM POW_TENSOR_SCALAR_")
    _maybe_prewarm_pow_kernels()
    scalar_exp = _maybe_scalar(exponent)
    if scalar_exp is None:
        return base_pow_tensor_scalar_(A, exponent)
    fast_result = _pow_tensor_scalar_special(A, scalar_exp, out=A)
    if fast_result is None:
        return base_pow_tensor_scalar_(A, scalar_exp)
    return fast_result
def pow_scalar(A, exponent):
    """Delegate to the base ``pow_scalar`` after the one-time kernel prewarm.

    No ARM-specific fast path exists for this overload; arguments are
    forwarded untouched.
    """
    logging.debug("GEMS_ARM POW_SCALAR")
    _maybe_prewarm_pow_kernels()
    result = base_pow_scalar(A, exponent)
    return result
# Run the one-time kernel prewarm as a side effect of importing this module.
# The function returns immediately on later calls, and skips the launches
# when GEMS_ARM_POW_PREWARM is set to anything other than "1".
_maybe_prewarm_pow_kernels()