Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_5_conv1d.py: 0%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1"""Monkey-patch Qwen3_5GatedDeltaNet.causal_conv1d_update with a TLE-CPU 

2fused depthwise conv1d update kernel. 

3 

4The HF fallback torch_causal_conv1d_update goes through aten::conv1d 

5(groups=conv_dim, kernel_size=4) which uses MKL-DNN with high dispatch 

6overhead (~700us/call on ARM ACL). Profile shows 7% of decode time spent 

7here on Qwen3.5-2B BF16. 

8 

9Replaces the entire torch_causal_conv1d_update body with one C kernel call 

10that does cat-with-state + conv + (optional silu) + state-roll in a single 

11NEON OMP loop. 

12 

13Decode (T=1, kernel_size=4, BF16) only — other shapes fall back to torch. 

14""" 

15import logging 

16 

17import torch 

18import triton 

19import triton.language as tl 

20from triton.language.extra.cpu.tle_ops import ( 

21 causal_conv1d_update as _tle_causal_conv1d_update, 

22) 

23 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# id() values of the modules whose causal_conv1d_update is currently patched;
# maintained by patch_qwen3_5_conv1d / unpatch_qwen3_5_conv1d.
_PATCHED: set = set()
# One-element bf16 placeholder handed to the kernel when the caller supplies
# no bias (the launch then passes has_bias=0).
_DUMMY_BIAS = torch.zeros(1, dtype=torch.bfloat16)

28 

29 

@triton.jit
def _causal_conv1d_update_kernel(
    hidden_ptr, state_ptr, weight_ptr, bias_ptr, out_ptr,
    B: tl.constexpr, C: tl.constexpr, kernel_size: tl.constexpr,
    silu: tl.constexpr, has_bias: tl.constexpr,
):
    """Triton entry point that forwards straight to the TLE conv1d-update op.

    The kernel does no work of its own: every argument is passed through,
    unchanged and in the same order, to ``causal_conv1d_update`` from
    ``triton.language.extra.cpu.tle_ops``.

    As launched from this file, the pointer arguments are the contiguous
    hidden-state, conv-state, weight, bias (or a dummy placeholder) and
    output tensors; ``B``/``C`` are the batch and channel sizes, and
    ``silu``/``has_bias`` are 0/1 flags. All non-pointer arguments are
    compile-time constants.
    """
    _tle_causal_conv1d_update(
        hidden_ptr, state_ptr, weight_ptr, bias_ptr, out_ptr,
        B, C, kernel_size, silu, has_bias,
    )

55 

56 

def _make_patched_fn(torch_causal_fn):
    """Build a drop-in replacement for ``causal_conv1d_update``.

    The returned callable runs the fused TLE kernel for the single-token
    decode case and delegates every other call to ``torch_causal_fn``.

    Args:
        torch_causal_fn: The original update function. It is invoked as
            ``torch_causal_fn(hidden_states, conv_state, weight, bias,
            activation)`` whenever the fast path does not apply.

    Returns:
        A function with the signature
        ``fn(hidden_states, conv_state, weight, bias=None, activation=None)``.
    """

    def _kernel_ready(t):
        # The kernel launch below only ever receives CPU bf16 buffers (the
        # output and the dummy bias are allocated that way), so any tensor
        # that differs must not be handed to it.
        return t.dtype == torch.bfloat16 and t.device.type == "cpu"

    def fn(hidden_states, conv_state, weight, bias=None, activation=None):
        """Run the fused update, or fall back to the original function.

        The fast path requires a 3-D ``hidden_states`` with a last dimension
        of 1, a ``weight`` whose last dimension is 4, ``activation`` equal to
        ``"silu"`` or ``None``, and CPU bf16 tensors throughout (including
        ``bias`` when given). ``conv_state`` is updated in place.

        Returns:
            A bf16 tensor of shape ``[B, C, 1]`` on the fast path, otherwise
            whatever ``torch_causal_fn`` returns.
        """
        # hidden_states: [B, C, T] bf16; weight: [C, kernel_size]; bias: None or [C]
        # conv_state: [B, C, kernel_size-1] bf16 IN-OUT
        # activation: 'silu' or None.
        # NOTE(review): conv_state's last dimension is not validated here;
        # confirm the width the kernel expects against the caller.
        fast_path = (
            hidden_states.dim() == 3
            and hidden_states.shape[-1] == 1
            and weight.shape[-1] == 4
            and activation in ("silu", None)
            and all(_kernel_ready(t) for t in (hidden_states, conv_state, weight))
            # A bias of another dtype/device would be misread by the kernel.
            and (bias is None or _kernel_ready(bias))
        )
        if not fast_path:
            return torch_causal_fn(hidden_states, conv_state, weight, bias, activation)

        B, C, _T = hidden_states.shape
        # [B, C] contiguous
        h = hidden_states.squeeze(-1).contiguous()
        w = weight.contiguous()
        # conv_state must be contiguous so the kernel can update it in place.
        if conv_state.is_contiguous():
            conv_state_c = conv_state
        else:
            conv_state_c = conv_state.contiguous()
        # Same dtype (bf16) and device (CPU) as the input, regardless of the
        # process-wide default device.
        out = hidden_states.new_empty((B, C))
        if bias is None:
            b_t = _DUMMY_BIAS
            has_bias = 0
        else:
            b_t = bias.contiguous()
            has_bias = 1
        silu_flag = 1 if activation == "silu" else 0

        _causal_conv1d_update_kernel[(1,)](
            h,
            conv_state_c,
            w,
            b_t,
            out,
            B=B,
            C=C,
            kernel_size=4,
            silu=silu_flag,
            has_bias=has_bias,
        )

        # If we made a contiguous copy of conv_state, write the updated
        # state back into the caller's tensor.
        if conv_state_c is not conv_state:
            conv_state.copy_(conv_state_c)

        return out.unsqueeze(-1)

    return fn

110 

111 

def _get_qwen_gdn_classes() -> tuple:
    """Collect the GatedDeltaNet layer classes that can be imported.

    Each known (module path, class name) pair is tried in turn; a pair is
    skipped when its module cannot be imported or does not define the class.

    Returns:
        Tuple of the classes found, in candidate order. Empty when none of
        the candidate modules is available.
    """
    candidates = (
        ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5GatedDeltaNet"),
        (
            "transformers.models.qwen3_5_moe.modeling_qwen3_5_moe",
            "Qwen3_5MoeGatedDeltaNet",
        ),
        (
            "transformers.models.qwen3_next.modeling_qwen3_next",
            "Qwen3NextGatedDeltaNet",
        ),
    )
    found = []
    for module_path, class_name in candidates:
        try:
            # A non-empty fromlist makes __import__ return the leaf module
            # rather than the top-level package.
            module = __import__(module_path, fromlist=[class_name])
            cls = getattr(module, class_name)
        except (ImportError, AttributeError):
            continue
        found.append(cls)
    return tuple(found)

131 

132 

def patch_qwen3_5_conv1d(model) -> int:
    """Replace ``causal_conv1d_update`` on every GDN layer of ``model``.

    Each matching module gets its current ``causal_conv1d_update`` saved as
    ``_original_causal_conv1d_update`` and replaced by the wrapper built by
    ``_make_patched_fn``. Modules that already carry the saved original are
    left untouched, so calling this repeatedly is harmless.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).

    Returns:
        Number of modules patched by this call; 0 when none of the supported
        GDN classes can be imported.
    """
    gdn_classes = _get_qwen_gdn_classes()
    if not gdn_classes:
        return 0
    n = 0
    for _name, module in list(model.named_modules()):
        if not isinstance(module, gdn_classes):
            continue
        # The saved-original attribute is the "already patched" marker.
        # id() values can be reused once a module has been garbage-collected,
        # so an id left behind by a dead model must not block patching a new
        # one; and a copied module that already carries the wrapper must not
        # be wrapped a second time.
        if hasattr(module, "_original_causal_conv1d_update"):
            _PATCHED.add(id(module))
            continue
        torch_fn = module.causal_conv1d_update
        module._original_causal_conv1d_update = torch_fn
        module.causal_conv1d_update = _make_patched_fn(torch_fn)
        _PATCHED.add(id(module))
        n += 1
    if n > 0:
        logger.info("Patched %d GDN causal_conv1d_update with TLE kernel", n)
    return n

148 

149 

def unpatch_qwen3_5_conv1d(model) -> int:
    """Restore the original ``causal_conv1d_update`` on every GDN layer.

    A module is restored when it carries ``_original_causal_conv1d_update``;
    that attribute is then removed. The module's id is dropped from
    ``_PATCHED`` in every case, which also clears stale entries.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).

    Returns:
        Number of modules actually restored; 0 when none of the supported
        GDN classes can be imported.
    """
    gdn_classes = _get_qwen_gdn_classes()
    if not gdn_classes:
        return 0
    n = 0
    for _name, module in list(model.named_modules()):
        if not isinstance(module, gdn_classes):
            continue
        # id() bookkeeping is unreliable across object lifetimes (ids are
        # reused), so it is only cleaned up here, never used as the decision.
        _PATCHED.discard(id(module))
        if not hasattr(module, "_original_causal_conv1d_update"):
            continue
        module.causal_conv1d_update = module._original_causal_conv1d_update
        del module._original_causal_conv1d_update
        n += 1
    return n