Coverage for src/flag_gems/runtime/backend/_arm/ops/sub.py: 0%
279 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import pointwise_dynamic
logger = logging.getLogger(__name__)

# Flipped to True once _maybe_prewarm_sub_kernels has run (or was disabled).
_PREWARM_SUB_DONE = False

# Floating dtypes accepted by the hand-written float fast paths.
_SUPPORTED_FAST_DTYPES = (torch.float16, torch.bfloat16, torch.float32, torch.float64)

# Integer dtypes accepted by the integer scalar fast paths (used only when
# alpha == 1, since the integer kernels take no alpha argument).
_SUPPORTED_INT_FAST_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
# Gridded elementwise kernel: out = x - y * alpha over flat buffers.
# Program `pid` handles the BLOCK_SIZE-wide tile starting at pid * BLOCK_SIZE;
# _launch_sub_tensor_tensor launches cdiv(n_elements, BLOCK_SIZE) programs.
# alpha and n_elements are listed in do_not_specialize so Triton does not
# specialize the compiled kernel on their runtime values.
@triton.jit(do_not_specialize=["alpha", "n_elements"])
def _sub_contiguous_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the buffer are masked off for loads and store.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, x - y * alpha, mask=mask)
# Single-program variant of _sub_contiguous_kernel: one program walks the
# whole flat buffer in BLOCK_SIZE steps computing out = x - y * alpha.
# _launch_sub_tensor_tensor uses it with grid (1,) for 1 < n_elements <= 8192.
@triton.jit(do_not_specialize=["alpha", "n_elements"])
def _sub_contiguous_single_program_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    offs = tl.arange(0, BLOCK_SIZE)
    for base in range(0, n_elements, BLOCK_SIZE):
        idx = base + offs
        # Only the final chunk can run past n_elements; mask those lanes.
        mask = idx < n_elements
        x = tl.load(x_ptr + idx, mask=mask, other=0.0)
        y = tl.load(y_ptr + idx, mask=mask, other=0.0)
        tl.store(out_ptr + idx, x - y * alpha, mask=mask)
# Row-broadcast kernel: out[r, c] = x[r, c] - y[r] * alpha, where x/out are
# addressed as `rows` consecutive rows of `cols` elements and y holds one
# value per row. One program per row; each walks its row in BLOCK_SIZE chunks.
@triton.jit(do_not_specialize=["alpha", "rows", "cols"])
def _sub_broadcast_lastdim1_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    alpha,
    rows,
    cols,
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    # Defensive guard; _launch_sub_broadcast_lastdim1 launches exactly `rows`
    # programs, so this should not trigger.
    if row >= rows:
        return

    # The row's single subtrahend, loaded once before the column loop.
    y = tl.load(y_ptr + row)
    offs = tl.arange(0, BLOCK_SIZE)
    # NOTE(review): row * cols is computed in the integer width of
    # program_id/cols (likely int32); tensors with >= 2**31 elements may
    # overflow here — confirm.
    row_start = row * cols
    for base in range(0, cols, BLOCK_SIZE):
        col = base + offs
        mask = col < cols
        x = tl.load(x_ptr + row_start + col, mask=mask, other=0.0)
        tl.store(out_ptr + row_start + col, x - y * alpha, mask=mask)
# Gridded tensor-minus-scalar kernel: out = x - scalar * alpha over a flat
# buffer, one BLOCK_SIZE tile per program; scalar, alpha and n_elements are
# excluded from value specialization.
@triton.jit(do_not_specialize=["scalar", "alpha", "n_elements"])
def _sub_tensor_scalar_kernel(
    x_ptr,
    scalar,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, x - scalar * alpha, mask=mask)
# Single-program variant of _sub_tensor_scalar_kernel: one program loops over
# the flat buffer in BLOCK_SIZE steps computing out = x - scalar * alpha.
@triton.jit(do_not_specialize=["scalar", "alpha", "n_elements"])
def _sub_tensor_scalar_single_program_kernel(
    x_ptr,
    scalar,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    offs = tl.arange(0, BLOCK_SIZE)
    for base in range(0, n_elements, BLOCK_SIZE):
        idx = base + offs
        mask = idx < n_elements
        x = tl.load(x_ptr + idx, mask=mask, other=0.0)
        tl.store(out_ptr + idx, x - scalar * alpha, mask=mask)
# Gridded integer tensor-minus-scalar kernel: out = x - scalar. There is no
# alpha argument; callers use it only when alpha == 1.
@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _sub_tensor_scalar_int_kernel(
    x_ptr,
    scalar,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Integer "other" so masked lanes stay in the integer domain.
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    tl.store(out_ptr + offsets, x - scalar, mask=mask)
# Single-program variant of _sub_tensor_scalar_int_kernel: one program loops
# over the flat buffer in BLOCK_SIZE steps computing out = x - scalar.
@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _sub_tensor_scalar_int_single_program_kernel(
    x_ptr,
    scalar,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    offs = tl.arange(0, BLOCK_SIZE)
    for base in range(0, n_elements, BLOCK_SIZE):
        idx = base + offs
        mask = idx < n_elements
        x = tl.load(x_ptr + idx, mask=mask, other=0)
        tl.store(out_ptr + idx, x - scalar, mask=mask)
# Gridded scalar-minus-tensor kernel: out = scalar - y * alpha over a flat
# buffer, one BLOCK_SIZE tile per program.
@triton.jit(do_not_specialize=["scalar", "alpha", "n_elements"])
def _sub_scalar_tensor_kernel(
    scalar,
    y_ptr,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, scalar - y * alpha, mask=mask)
# Single-program variant of _sub_scalar_tensor_kernel: one program loops over
# the flat buffer in BLOCK_SIZE steps computing out = scalar - y * alpha.
@triton.jit(do_not_specialize=["scalar", "alpha", "n_elements"])
def _sub_scalar_tensor_single_program_kernel(
    scalar,
    y_ptr,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    offs = tl.arange(0, BLOCK_SIZE)
    for base in range(0, n_elements, BLOCK_SIZE):
        idx = base + offs
        mask = idx < n_elements
        y = tl.load(y_ptr + idx, mask=mask, other=0.0)
        tl.store(out_ptr + idx, scalar - y * alpha, mask=mask)
# Gridded integer scalar-minus-tensor kernel: out = scalar - y. There is no
# alpha argument; callers use it only when alpha == 1.
@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _sub_scalar_tensor_int_kernel(
    scalar,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    y = tl.load(y_ptr + offsets, mask=mask, other=0)
    tl.store(out_ptr + offsets, scalar - y, mask=mask)
# Single-program variant of _sub_scalar_tensor_int_kernel: one program loops
# over the flat buffer in BLOCK_SIZE steps computing out = scalar - y.
@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _sub_scalar_tensor_int_single_program_kernel(
    scalar,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    offs = tl.arange(0, BLOCK_SIZE)
    for base in range(0, n_elements, BLOCK_SIZE):
        idx = base + offs
        mask = idx < n_elements
        y = tl.load(y_ptr + idx, mask=mask, other=0)
        tl.store(out_ptr + idx, scalar - y, mask=mask)
# Generic tensor - tensor fallback used by sub/sub_ when no fast path applies.
# is_tensor marks x and y as tensors and alpha as a plain value;
# promotion_methods presumably requests default type promotion over
# arguments 0 and 1 (confirm against flag_gems.utils.pointwise_dynamic).
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def sub_func(x, y, alpha):
    return x - y * alpha
# Generic tensor - scalar fallback: only x is a tensor; y and alpha are
# passed as plain values.
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def sub_func_tensor_scalar(x, y, alpha):
    return x - y * alpha
# Generic scalar - tensor fallback: only y is a tensor; x and alpha are
# passed as plain values.
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def sub_func_scalar_tensor(x, y, alpha):
    return x - y * alpha
236def _select_block_size(n_elements, dtype):
237 if n_elements <= 32:
238 return 32
239 if n_elements <= 1024:
240 return 32
241 if n_elements <= 8192:
242 return 64
243 return 256 if dtype in (torch.float16, torch.bfloat16) else 128
246def _single_program_block(n_elements):
247 if n_elements <= 256:
248 return 32
249 if n_elements <= 2048:
250 return 128
251 return 256
def _launch_sub_tensor_tensor(x, y, out, alpha, n_elements, block_size):
    """Launch ``out = x - y * alpha`` over ``n_elements`` flat elements.

    Sizes in ``(1, 8192]`` run in one looping program whose tile comes from
    ``_single_program_block``; every other size (including 0 and 1) uses a
    1-D grid of ``block_size`` tiles.
    """
    if 1 < n_elements <= 8192:
        kernel, grid, tile = (
            _sub_contiguous_single_program_kernel,
            (1,),
            _single_program_block(n_elements),
        )
    else:
        kernel, grid, tile = (
            _sub_contiguous_kernel,
            (triton.cdiv(n_elements, block_size),),
            block_size,
        )
    kernel[grid](
        x, y, out, alpha, n_elements,
        BLOCK_SIZE=tile, num_warps=1, num_stages=1,
    )
def _launch_sub_tensor_scalar(x, scalar, out, alpha, n_elements, block_size):
    """Launch ``out = x - scalar * alpha`` over ``n_elements`` flat elements.

    Sizes in ``(1, 8192]`` run in one looping program whose tile comes from
    ``_single_program_block``; every other size uses a 1-D grid of
    ``block_size`` tiles.
    """
    if 1 < n_elements <= 8192:
        kernel, grid, tile = (
            _sub_tensor_scalar_single_program_kernel,
            (1,),
            _single_program_block(n_elements),
        )
    else:
        kernel, grid, tile = (
            _sub_tensor_scalar_kernel,
            (triton.cdiv(n_elements, block_size),),
            block_size,
        )
    kernel[grid](
        x, scalar, out, alpha, n_elements,
        BLOCK_SIZE=tile, num_warps=1, num_stages=1,
    )
def _launch_sub_tensor_scalar_int(x, scalar, out, n_elements, block_size):
    """Launch the integer kernel ``out = x - scalar`` (no alpha).

    Sizes in ``(1, 8192]`` run in one looping program whose tile comes from
    ``_single_program_block``; every other size uses a 1-D grid of
    ``block_size`` tiles.
    """
    if 1 < n_elements <= 8192:
        kernel, grid, tile = (
            _sub_tensor_scalar_int_single_program_kernel,
            (1,),
            _single_program_block(n_elements),
        )
    else:
        kernel, grid, tile = (
            _sub_tensor_scalar_int_kernel,
            (triton.cdiv(n_elements, block_size),),
            block_size,
        )
    kernel[grid](
        x, scalar, out, n_elements,
        BLOCK_SIZE=tile, num_warps=1, num_stages=1,
    )
336def _launch_sub_broadcast_lastdim1(x, y, out, alpha):
337 rows = x.numel() // x.shape[-1]
338 cols = x.shape[-1]
339 if rows == 0 or cols == 0:
340 return
341 if cols <= 1024:
342 block_size = 64
343 elif cols <= 4096:
344 block_size = 128
345 else:
346 block_size = 256
347 grid = (rows,)
348 _sub_broadcast_lastdim1_kernel[grid](
349 x,
350 y,
351 out,
352 alpha,
353 rows,
354 cols,
355 BLOCK_SIZE=block_size,
356 num_warps=1,
357 num_stages=1,
358 )
def _launch_sub_scalar_tensor(scalar, y, out, alpha, n_elements, block_size):
    """Launch ``out = scalar - y * alpha`` over ``n_elements`` flat elements.

    Sizes in ``(1, 8192]`` run in one looping program whose tile comes from
    ``_single_program_block``; every other size uses a 1-D grid of
    ``block_size`` tiles.
    """
    if 1 < n_elements <= 8192:
        kernel, grid, tile = (
            _sub_scalar_tensor_single_program_kernel,
            (1,),
            _single_program_block(n_elements),
        )
    else:
        kernel, grid, tile = (
            _sub_scalar_tensor_kernel,
            (triton.cdiv(n_elements, block_size),),
            block_size,
        )
    kernel[grid](
        scalar, y, out, alpha, n_elements,
        BLOCK_SIZE=tile, num_warps=1, num_stages=1,
    )
def _launch_sub_scalar_tensor_int(scalar, y, out, n_elements, block_size):
    """Launch the integer kernel ``out = scalar - y`` (no alpha).

    Sizes in ``(1, 8192]`` run in one looping program whose tile comes from
    ``_single_program_block``; every other size uses a 1-D grid of
    ``block_size`` tiles.
    """
    if 1 < n_elements <= 8192:
        kernel, grid, tile = (
            _sub_scalar_tensor_int_single_program_kernel,
            (1,),
            _single_program_block(n_elements),
        )
    else:
        kernel, grid, tile = (
            _sub_scalar_tensor_int_kernel,
            (triton.cdiv(n_elements, block_size),),
            block_size,
        )
    kernel[grid](
        scalar, y, out, n_elements,
        BLOCK_SIZE=tile, num_warps=1, num_stages=1,
    )
415def _can_use_contiguous_fastpath(a, b):
416 return (
417 isinstance(a, torch.Tensor)
418 and isinstance(b, torch.Tensor)
419 and a.device.type == "cpu"
420 and b.device == a.device
421 and a.is_contiguous()
422 and b.is_contiguous()
423 and a.shape == b.shape
424 and a.dtype == b.dtype
425 and a.dtype in _SUPPORTED_FAST_DTYPES
426 )
429def _can_use_broadcast_lastdim1_fastpath(a, b):
430 return (
431 isinstance(a, torch.Tensor)
432 and isinstance(b, torch.Tensor)
433 and a.device.type == "cpu"
434 and b.device == a.device
435 and a.is_contiguous()
436 and b.is_contiguous()
437 and a.ndim >= 1
438 and b.ndim == a.ndim
439 and a.shape[:-1] == b.shape[:-1]
440 and b.shape[-1] == 1
441 and a.dtype == b.dtype
442 and a.dtype in _SUPPORTED_FAST_DTYPES
443 )
446def _can_use_tensor_scalar_int_fastpath(a, scalar, alpha):
447 return (
448 isinstance(a, torch.Tensor)
449 and a.device.type == "cpu"
450 and a.is_contiguous()
451 and a.dtype in _SUPPORTED_INT_FAST_DTYPES
452 and isinstance(scalar, int)
453 and int(alpha) == 1
454 and float(alpha) == 1.0
455 )
458def _can_use_scalar_tensor_fastpath(b, scalar):
459 return (
460 isinstance(b, torch.Tensor)
461 and b.device.type == "cpu"
462 and b.is_contiguous()
463 and b.dtype in _SUPPORTED_FAST_DTYPES
464 and isinstance(scalar, (int, float))
465 )
468def _can_use_scalar_tensor_int_fastpath(b, scalar, alpha):
469 return (
470 isinstance(b, torch.Tensor)
471 and b.device.type == "cpu"
472 and b.is_contiguous()
473 and b.dtype in _SUPPORTED_INT_FAST_DTYPES
474 and isinstance(scalar, int)
475 and int(alpha) == 1
476 and float(alpha) == 1.0
477 )
480def _maybe_scalar(v):
481 if isinstance(v, torch.Tensor) and v.numel() == 1:
482 return v.item()
483 if isinstance(v, (int, float)):
484 return v
485 return None
def _maybe_prewarm_sub_kernels():
    """Launch every fast-path kernel once on tiny CPU inputs.

    Presumably this front-loads Triton JIT compilation before the first real
    call. The module-level ``_PREWARM_SUB_DONE`` flag makes later calls
    return immediately. Setting ``GEMS_ARM_SUB_PREWARM`` to anything other
    than ``"1"`` skips the launches. Any exception is logged at debug level
    and otherwise ignored.
    """
    global _PREWARM_SUB_DONE
    if _PREWARM_SUB_DONE:
        return
    if os.environ.get("GEMS_ARM_SUB_PREWARM", "1") == "1":
        try:
            src = torch.zeros(8, dtype=torch.float32, device="cpu")
            ones = torch.ones(8, dtype=torch.float32, device="cpu")
            dst = torch.empty_like(src)
            count = src.numel()
            _launch_sub_tensor_tensor(src, ones, dst, 1.0, count, 32)
            _launch_sub_tensor_scalar(src, 1.0, dst, 1.0, count, 32)
            _launch_sub_scalar_tensor(1.0, src, dst, 1.0, count, 32)

            ints = torch.arange(8, dtype=torch.int64, device="cpu")
            int_dst = torch.empty_like(ints)
            int_count = ints.numel()
            _launch_sub_tensor_scalar_int(ints, 1, int_dst, int_count, 32)
            _launch_sub_scalar_tensor_int(1, ints, int_dst, int_count, 32)

            row_src = torch.zeros((1, 5, 32), dtype=torch.float32, device="cpu")
            row_sub = torch.zeros((1, 5, 1), dtype=torch.float32, device="cpu")
            row_dst = torch.empty_like(row_src)
            _launch_sub_broadcast_lastdim1(row_src, row_sub.view(-1), row_dst, 1.0)
        except Exception:
            logger.debug("GEMS ARM sub prewarm failed", exc_info=True)
    _PREWARM_SUB_DONE = True
def sub(A, B, *, alpha=1):
    """Compute ``A - B * alpha`` for the ARM backend.

    Dispatch:
      * tensor - tensor: flat kernel for matching contiguous CPU tensors;
        per-row kernel when ``B`` matches ``A`` except for a trailing size-1
        dimension; otherwise ``sub_func``.
      * tensor - scalar: float kernel for contiguous CPU tensors with a dtype
        in ``_SUPPORTED_FAST_DTYPES``; integer kernel for supported integer
        dtypes with an ``int`` scalar and ``alpha`` equal to 1; otherwise
        ``sub_func_tensor_scalar``.
      * scalar - tensor: the mirror image, falling back to
        ``sub_func_scalar_tensor``.
      * scalar - scalar: ``torch.tensor(A - B * alpha)``.

    NOTE(review): Triton normally specializes Python float arguments as
    fp32, so float64 inputs whose alpha/scalar are not exactly representable
    in fp32 may lose precision on the fast paths — confirm.
    """
    logger.debug("GEMS SUB")
    _maybe_prewarm_sub_kernels()

    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        if _can_use_contiguous_fastpath(A, B):
            result = torch.empty_like(A)
            numel = A.numel()
            tile = _select_block_size(numel, A.dtype)
            _launch_sub_tensor_tensor(A, B, result, float(alpha), numel, tile)
            return result
        if _can_use_broadcast_lastdim1_fastpath(A, B):
            result = torch.empty_like(A)
            _launch_sub_broadcast_lastdim1(A, B.view(-1), result, float(alpha))
            return result
        return sub_func(A, B, alpha)

    if a_is_tensor:
        value = _maybe_scalar(B)
        float_kernel_ok = (
            value is not None
            and A.device.type == "cpu"
            and A.is_contiguous()
            and A.dtype in _SUPPORTED_FAST_DTYPES
        )
        if float_kernel_ok:
            result = torch.empty_like(A)
            numel = A.numel()
            tile = _select_block_size(numel, A.dtype)
            _launch_sub_tensor_scalar(
                A, float(value), result, float(alpha), numel, tile
            )
            return result
        if _can_use_tensor_scalar_int_fastpath(A, value, alpha):
            result = torch.empty_like(A)
            numel = A.numel()
            tile = _select_block_size(numel, A.dtype)
            _launch_sub_tensor_scalar_int(A, int(value), result, numel, tile)
            return result
        return sub_func_tensor_scalar(A, B, alpha)

    if b_is_tensor:
        value = _maybe_scalar(A)
        if _can_use_scalar_tensor_fastpath(B, value):
            result = torch.empty_like(B)
            numel = B.numel()
            tile = _select_block_size(numel, B.dtype)
            _launch_sub_scalar_tensor(
                float(value), B, result, float(alpha), numel, tile
            )
            return result
        if _can_use_scalar_tensor_int_fastpath(B, value, alpha):
            result = torch.empty_like(B)
            numel = B.numel()
            tile = _select_block_size(numel, B.dtype)
            _launch_sub_scalar_tensor_int(int(value), B, result, numel, tile)
            return result
        return sub_func_scalar_tensor(A, B, alpha)

    return torch.tensor(A - B * alpha)
def sub_(A, B, *, alpha=1):
    """In-place ``A -= B * alpha`` for the ARM backend; returns ``A``.

    Tensor ``B``:
      * flat kernel when ``A``/``B`` are matching contiguous CPU tensors;
      * per-row kernel when ``B`` matches ``A`` except for a trailing size-1
        dimension;
      * both fast paths are skipped when ``A`` and ``B`` share storage, and
        such calls go to ``sub_func(A, B, alpha, out0=A)`` as do all other
        tensor cases.
    Scalar ``B``: float kernel for contiguous CPU ``A`` with a dtype in
    ``_SUPPORTED_FAST_DTYPES``; integer kernel when
    ``_can_use_tensor_scalar_int_fastpath`` accepts; otherwise
    ``sub_func_tensor_scalar(A, B, alpha, out0=A)``.

    NOTE(review): Triton normally specializes Python float arguments as
    fp32, so float64 inputs whose alpha/scalar are not exactly representable
    in fp32 may lose precision on the fast paths — confirm.
    """
    logger.debug("GEMS SUB_")
    _maybe_prewarm_sub_kernels()

    if isinstance(B, torch.Tensor):
        if _can_use_contiguous_fastpath(A, B):
            if A.untyped_storage().data_ptr() == B.untyped_storage().data_ptr():
                return sub_func(A, B, alpha, out0=A)
            block_size = _select_block_size(A.numel(), A.dtype)
            _launch_sub_tensor_tensor(A, B, A, float(alpha), A.numel(), block_size)
            return A
        if _can_use_broadcast_lastdim1_fastpath(A, B):
            # The row kernel reads B while overwriting A. If both live in the
            # same storage, rows written earlier could feed the B values of
            # later rows. Apply the same aliasing guard as the contiguous
            # branch and defer to the generic path.
            if A.untyped_storage().data_ptr() != B.untyped_storage().data_ptr():
                _launch_sub_broadcast_lastdim1(A, B.view(-1), A, float(alpha))
                return A
        return sub_func(A, B, alpha, out0=A)

    scalar = _maybe_scalar(B)
    if (
        scalar is not None
        and isinstance(A, torch.Tensor)
        and A.device.type == "cpu"
        and A.is_contiguous()
        and A.dtype in _SUPPORTED_FAST_DTYPES
    ):
        block_size = _select_block_size(A.numel(), A.dtype)
        _launch_sub_tensor_scalar(
            A, float(scalar), A, float(alpha), A.numel(), block_size
        )
        return A
    if _can_use_tensor_scalar_int_fastpath(A, scalar, alpha):
        block_size = _select_block_size(A.numel(), A.dtype)
        _launch_sub_tensor_scalar_int(A, int(scalar), A, A.numel(), block_size)
        return A

    return sub_func_tensor_scalar(A, B, alpha, out0=A)
# Import-time warm-up of the fast-path kernels. It is skipped when
# GEMS_ARM_SUB_PREWARM != "1"; sub()/sub_() call it again, and the
# _PREWARM_SUB_DONE flag turns those repeat calls into no-ops.
_maybe_prewarm_sub_kernels()