Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_conv1d.py: 0%
70 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
"""Monkey-patch causal_conv1d_update on Qwen3_5GatedDeltaNet (and the
Qwen3_5MoeGatedDeltaNet / Qwen3NextGatedDeltaNet variants) with a TLE-CPU
fused depthwise conv1d update kernel.
4The HF fallback torch_causal_conv1d_update goes through aten::conv1d
5(groups=conv_dim, kernel_size=4) which uses MKL-DNN with high dispatch
6overhead (~700us/call on ARM ACL). Profile shows 7% of decode time spent
7here on Qwen3.5-2B BF16.
9Replaces the entire torch_causal_conv1d_update body with one C kernel call
10that does cat-with-state + conv + (optional silu) + state-roll in a single
11NEON OMP loop.
13Decode (T=1, kernel_size=4, BF16) only — other shapes fall back to torch.
14"""
15import logging
17import torch
18import triton
19import triton.language as tl
20from triton.language.extra.cpu.tle_ops import (
21 causal_conv1d_update as _tle_causal_conv1d_update,
22)
logger = logging.getLogger(name=__name__)

# id() of every GatedDeltaNet module whose causal_conv1d_update has been
# replaced; consulted by the patch/unpatch entry points.
_PATCHED: set = set()

# One-element BF16 placeholder handed to the kernel as its bias pointer when
# the caller passes bias=None (the launch then uses has_bias=0).
_DUMMY_BIAS = torch.zeros((1,), dtype=torch.bfloat16)
# Triton JIT trampoline: the whole body is one call into the TLE-CPU
# ``causal_conv1d_update`` extension op, which does the actual work.
#
# Launched from the wrapper built by ``_make_patched_fn`` with grid (1,),
# kernel_size=4 and 0/1 integer flags for ``silu`` and ``has_bias``.
# Arguments as passed by that launch site:
#   hidden_ptr -- contiguous [B, C] BF16 tensor (the T=1 input squeezed).
#   state_ptr  -- contiguous BF16 conv state; the op updates it in place
#                 (the caller copies it back when it had to stage a copy).
#   weight_ptr -- contiguous BF16 depthwise weights, kernel width last.
#   bias_ptr   -- BF16 bias, or a 1-element placeholder when has_bias == 0
#                 (presumably unread in that case -- TODO confirm in TLE op).
#   out_ptr    -- freshly allocated [B, C] BF16 output buffer.
#   B, C       -- batch size and channel count (compile-time constants, so
#                 each distinct shape presumably triggers a recompile).
#   silu       -- 1 to apply SiLU to the conv output, 0 for none.
#   has_bias   -- 1 if bias_ptr holds a real [C] bias, else 0.
@triton.jit
def _causal_conv1d_update_kernel(
    hidden_ptr,
    state_ptr,
    weight_ptr,
    bias_ptr,
    out_ptr,
    B: tl.constexpr,
    C: tl.constexpr,
    kernel_size: tl.constexpr,
    silu: tl.constexpr,
    has_bias: tl.constexpr,
):
    _tle_causal_conv1d_update(
        hidden_ptr,
        state_ptr,
        weight_ptr,
        bias_ptr,
        out_ptr,
        B,
        C,
        kernel_size,
        silu,
        has_bias,
    )
def _make_patched_fn(torch_causal_fn):
    """Wrap ``torch_causal_fn`` so decode-shaped BF16 calls use the TLE kernel.

    The returned ``fn(hidden_states, conv_state, weight, bias=None,
    activation=None)`` has the same call signature as the function it wraps.
    When the inputs match the fused kernel's specialization it launches
    ``_causal_conv1d_update_kernel``, updates ``conv_state`` in place and
    returns a ``[B, C, 1]`` BF16 tensor. Otherwise it forwards every argument
    unchanged to ``torch_causal_fn`` and returns that result.

    Fused-path requirements (anything else falls back):
      * ``activation`` is ``"silu"`` or ``None``;
      * every tensor handed to the kernel (including ``bias``) is BF16 and on
        the CPU, because the kernel receives raw pointers and assumes BF16;
      * ``hidden_states`` is 3-D ``[B, C, 1]`` with B, C > 0 (one decode step);
      * ``weight`` has C leading, width 4 trailing and exactly ``C * 4`` elements;
      * ``conv_state`` is 3-D with leading dims ``[B, C]``;
      * ``bias`` is ``None`` or has exactly C elements.

    Args:
        torch_causal_fn: The module's original ``causal_conv1d_update``; used
            as the fallback for unsupported inputs.

    Returns:
        The wrapping callable described above.
    """
    kernel_size = 4  # the only conv width the fused kernel is launched with

    def _fused_supported(hidden_states, conv_state, weight, bias, activation):
        # True when the inputs can be handed to the fused kernel without the
        # kernel reading past a buffer or misinterpreting its element type.
        if activation not in ("silu", None):
            return False
        tensors = [hidden_states, conv_state, weight]
        if bias is not None:
            tensors.append(bias)
        if any(t.dtype != torch.bfloat16 or t.device.type != "cpu" for t in tensors):
            return False
        if hidden_states.dim() != 3 or hidden_states.shape[-1] != 1:
            return False
        batch, channels, _ = hidden_states.shape
        if batch == 0 or channels == 0:
            return False
        if (
            weight.dim() < 2
            or weight.shape[0] != channels
            or weight.shape[-1] != kernel_size
            or weight.numel() != channels * kernel_size
        ):
            return False
        # NOTE(review): the trailing (state-length) dim of conv_state is not
        # validated because the length the TLE op expects is not visible here;
        # the original contract documents it as kernel_size - 1. TODO confirm.
        if conv_state.dim() != 3 or tuple(conv_state.shape[:2]) != (batch, channels):
            return False
        return bias is None or bias.numel() == channels

    def fn(hidden_states, conv_state, weight, bias=None, activation=None):
        # hidden_states: [B, C, T] bf16; weight: [C, kernel_size]; bias: None or [C]
        # conv_state: [B, C, kernel_size-1] bf16 (assumed, see note above), IN-OUT
        # activation: 'silu' or None.
        if not _fused_supported(hidden_states, conv_state, weight, bias, activation):
            return torch_causal_fn(hidden_states, conv_state, weight, bias, activation)

        B, C, _T = hidden_states.shape
        h = hidden_states.squeeze(-1).contiguous()  # [B, C]
        w = weight.contiguous()
        # The kernel updates the state through a raw pointer, so it needs a
        # contiguous buffer; a non-contiguous state is staged and copied back.
        state = conv_state if conv_state.is_contiguous() else conv_state.contiguous()
        out = torch.empty(B, C, dtype=torch.bfloat16, device=hidden_states.device)
        if bias is None:
            # Placeholder pointer; presumably unread by the op when has_bias=0.
            b_t, has_bias = _DUMMY_BIAS, 0
        else:
            b_t, has_bias = bias.contiguous(), 1

        _causal_conv1d_update_kernel[(1,)](
            h,
            state,
            w,
            b_t,
            out,
            B=B,
            C=C,
            kernel_size=kernel_size,
            silu=1 if activation == "silu" else 0,
            has_bias=has_bias,
        )

        # Propagate the in-place update back if we worked on a staged copy.
        if state is not conv_state:
            conv_state.copy_(state)

        return out.unsqueeze(-1)

    return fn
def _get_qwen_gdn_classes() -> tuple:
    """Collect the GatedDeltaNet classes available in the installed transformers.

    Each (module path, class name) candidate is tried independently. A
    candidate whose module cannot be imported, or that lacks the class, is
    skipped, so the result may be an empty tuple.
    """
    candidates = (
        ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5GatedDeltaNet"),
        (
            "transformers.models.qwen3_5_moe.modeling_qwen3_5_moe",
            "Qwen3_5MoeGatedDeltaNet",
        ),
        (
            "transformers.models.qwen3_next.modeling_qwen3_next",
            "Qwen3NextGatedDeltaNet",
        ),
    )
    found = []
    for module_path, class_name in candidates:
        try:
            gdn_cls = getattr(
                __import__(module_path, fromlist=[class_name]), class_name
            )
        except (ImportError, AttributeError):
            continue
        found.append(gdn_cls)
    return tuple(found)
def patch_qwen3_5_conv1d(model) -> int:
    """Install the fused TLE conv1d update on every supported GDN module.

    For each instance of a supported GatedDeltaNet class found in
    ``model.named_modules()`` that is not already patched, the current
    ``causal_conv1d_update`` is stashed on the module as
    ``_original_causal_conv1d_update`` and replaced with the wrapper from
    ``_make_patched_fn``. The module's id is recorded in ``_PATCHED``.

    Args:
        model: Any object exposing ``named_modules()`` (e.g. ``torch.nn.Module``).

    Returns:
        The number of modules patched by this call (0 when transformers
        provides none of the supported classes).
    """
    gdn_classes = _get_qwen_gdn_classes()
    if not gdn_classes:
        return 0
    n = 0
    for _name, module in list(model.named_modules()):
        if not isinstance(module, gdn_classes):
            continue
        # Decide "already patched" from the stash on the module itself, not
        # from id(module) alone. Ids are reused after a module is garbage
        # collected, so a stale id would make a fresh module be skipped. A
        # deep-copied patched module has a new id but still carries the stash,
        # and re-patching it would wrap the wrapper and lose the true original.
        if hasattr(module, "_original_causal_conv1d_update"):
            _PATCHED.add(id(module))  # keep bookkeeping in sync
            continue
        torch_fn = module.causal_conv1d_update
        module._original_causal_conv1d_update = torch_fn
        module.causal_conv1d_update = _make_patched_fn(torch_fn)
        _PATCHED.add(id(module))
        n += 1
    if n > 0:
        logger.info("Patched %d GDN causal_conv1d_update with TLE kernel", n)
    return n
def unpatch_qwen3_5_conv1d(model) -> int:
    """Restore the original ``causal_conv1d_update`` on patched GDN modules.

    For each instance of a supported GatedDeltaNet class found in
    ``model.named_modules()`` that carries the stashed
    ``_original_causal_conv1d_update``, that callable is put back as
    ``causal_conv1d_update`` and the stash is deleted. Each such module's id
    is discarded from ``_PATCHED`` whether or not it was restored, so stale
    entries do not linger.

    Args:
        model: Any object exposing ``named_modules()`` (e.g. ``torch.nn.Module``).

    Returns:
        The number of modules actually restored (0 when transformers provides
        none of the supported classes).
    """
    gdn_classes = _get_qwen_gdn_classes()
    if not gdn_classes:
        return 0
    n = 0
    for _name, module in list(model.named_modules()):
        if not isinstance(module, gdn_classes):
            continue
        # The stash on the module decides whether it is patched. An id in
        # _PATCHED may be stale (ids are reused after GC) and must not count as
        # a restore. A patched module whose id was never recorded (e.g. a deep
        # copy) must still be restored.
        if hasattr(module, "_original_causal_conv1d_update"):
            module.causal_conv1d_update = module._original_causal_conv1d_update
            del module._original_causal_conv1d_update
            n += 1
        _PATCHED.discard(id(module))
    return n