Coverage for src/flag_gems/fused/__init__.py: 100%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1from flag_gems.fused.add_rms_norm import add_rms_norm 

2from flag_gems.fused.apply_repetition_penalties import apply_repetition_penalties 

3from flag_gems.fused.bincount import bincount 

4from flag_gems.fused.chunk_gated_delta_rule import chunk_gated_delta_rule 

5from flag_gems.fused.concat_and_cache_mla import concat_and_cache_mla 

6from flag_gems.fused.cp_gather_indexer_k_quant_cache import ( 

7 cp_gather_indexer_k_quant_cache, 

8) 

9from flag_gems.fused.cross_entropy_loss import cross_entropy_loss 

10from flag_gems.fused.cutlass_scaled_mm import cutlass_scaled_mm 

11from flag_gems.fused.deepseek_v4_attention_combine_topk_swa_indices import ( 

12 combine_topk_swa_indices, 

13) 

14from flag_gems.fused.deepseek_v4_attention_compute_global_topk_indices_and_lens import ( 

15 compute_global_topk_indices_and_lens, 

16) 

17from flag_gems.fused.deepseek_v4_attention_dequantize_and_gather_k_cache import ( 

18 dequantize_and_gather_k_cache, 

19) 

20from flag_gems.fused.deepseek_v4_attention_fused_q_kv_rmsnorm import fused_q_kv_rmsnorm 

21from flag_gems.fused.DSA.bin_topk import bucket_sort_topk 

22from flag_gems.fused.FLA import ( 

23 chunk_gated_delta_rule_fwd, 

24 fused_recurrent_gated_delta_rule_fwd, 

25) 

26from flag_gems.fused.flash_mla import flash_mla 

27from flag_gems.fused.flash_mla_with_kvcache import flash_mla_with_kvcache 

28from flag_gems.fused.flashmla_sparse import flash_mla_sparse_fwd 

29from flag_gems.fused.fused_add_rms_norm import fused_add_rms_norm 

30from flag_gems.fused.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert import ( 

31 fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert, 

32) 

33from flag_gems.fused.fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant 

34from flag_gems.fused.fused_moe import ( 

35 dispatch_fused_moe_kernel, 

36 fused_experts_impl, 

37 inplace_fused_experts, 

38 invoke_fused_moe_triton_kernel, 

39 outplace_fused_experts, 

40) 

41from flag_gems.fused.geglu import dgeglu, geglu 

42from flag_gems.fused.gelu_and_mul import gelu_and_mul 

43from flag_gems.fused.grouped_topk import grouped_topk 

44from flag_gems.fused.indexer_k_quant_and_cache import indexer_k_quant_and_cache 

45from flag_gems.fused.instance_norm import instance_norm 

46from flag_gems.fused.mhc import ( 

47 hc_head_fused_kernel, 

48 hc_head_fused_kernel_ref, 

49 mhc_bwd, 

50 mhc_bwd_ref, 

51 mhc_post, 

52 mhc_pre, 

53 sinkhorn_forward, 

54) 

55from flag_gems.fused.moe_align_block_size import ( 

56 moe_align_block_size, 

57 moe_align_block_size_triton, 

58) 

59from flag_gems.fused.moe_sum import moe_sum 

60from flag_gems.fused.outer import outer 

61from flag_gems.fused.pack_seq import pack_seq_triton 

62from flag_gems.fused.reglu import dreglu, reglu 

63from flag_gems.fused.reshape_and_cache import reshape_and_cache 

64from flag_gems.fused.reshape_and_cache_flash import reshape_and_cache_flash 

65from flag_gems.fused.rotary_embedding import apply_rotary_pos_emb 

66from flag_gems.fused.rwkv_ka_fusion import rwkv_ka_fusion 

67from flag_gems.fused.rwkv_mm_sparsity import rwkv_mm_sparsity 

68from flag_gems.fused.silu_and_mul import silu_and_mul, silu_and_mul_out 

69from flag_gems.fused.silu_and_mul_with_clamp import ( 

70 silu_and_mul_with_clamp, 

71 silu_and_mul_with_clamp_out, 

72) 

73from flag_gems.fused.skip_layernorm import skip_layer_norm 

74from flag_gems.fused.sparse_attention import sparse_attn_triton 

75from flag_gems.fused.swiglu import dswiglu, swiglu 

76from flag_gems.fused.top_k_per_row_decode import top_k_per_row_decode 

77from flag_gems.fused.top_k_per_row_prefill import top_k_per_row_prefill 

78from flag_gems.fused.topk_softmax import topk_softmax 

79from flag_gems.fused.topk_softplus_sqrt import topk_softplus_sqrt 

80from flag_gems.fused.unpack_seq import unpack_seq_triton 

81from flag_gems.fused.weight_norm import weight_norm 

82 

83__all__ = [ 

84 "add_rms_norm", 

85 "apply_repetition_penalties", 

86 "apply_rotary_pos_emb", 

87 "bincount", 

88 "bucket_sort_topk", 

89 "chunk_gated_delta_rule", 

90 "chunk_gated_delta_rule_fwd", 

91 "combine_topk_swa_indices", 

92 "compute_global_topk_indices_and_lens", 

93 "concat_and_cache_mla", 

94 "cp_gather_indexer_k_quant_cache", 

95 "cross_entropy_loss", 

96 "cutlass_scaled_mm", 

97 "dequantize_and_gather_k_cache", 

98 "dgeglu", 

99 "dispatch_fused_moe_kernel", 

100 "dreglu", 

101 "dswiglu", 

102 "flash_mla", 

103 "flash_mla_sparse_fwd", 

104 "flash_mla_with_kvcache", 

105 "fused_add_rms_norm", 

106 "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", 

107 "fused_experts_impl", 

108 "fused_inv_rope_fp8_quant", 

109 "fused_q_kv_rmsnorm", 

110 "fused_recurrent_gated_delta_rule_fwd", 

111 "geglu", 

112 "gelu_and_mul", 

113 "grouped_topk", 

114 "hc_head_fused_kernel", 

115 "hc_head_fused_kernel_ref", 

116 "indexer_k_quant_and_cache", 

117 "inplace_fused_experts", 

118 "instance_norm", 

119 "invoke_fused_moe_triton_kernel", 

120 "mhc_bwd", 

121 "mhc_bwd_ref", 

122 "mhc_post", 

123 "mhc_pre", 

124 "moe_align_block_size", 

125 "moe_align_block_size_triton", 

126 "moe_sum", 

127 "outer", 

128 "outplace_fused_experts", 

129 "pack_seq_triton", 

130 "reglu", 

131 "reshape_and_cache", 

132 "reshape_and_cache_flash", 

133 "rwkv_ka_fusion", 

134 "rwkv_mm_sparsity", 

135 "silu_and_mul", 

136 "silu_and_mul_out", 

137 "silu_and_mul_with_clamp", 

138 "silu_and_mul_with_clamp_out", 

139 "sinkhorn_forward", 

140 "skip_layer_norm", 

141 "sparse_attn_triton", 

142 "swiglu", 

143 "top_k_per_row_decode", 

144 "top_k_per_row_prefill", 

145 "topk_softmax", 

146 "topk_softplus_sqrt", 

147 "unpack_seq_triton", 

148 "weight_norm", 

149]