@@ -1103,6 +1103,47 @@ cudaError_t moereduction_allreduce_fusion_kernel_launcher(
1103
1103
return cudaSuccess;
1104
1104
}
1105
1105
1106
// Lifts the runtime condition `expr` into a compile-time boolean.
//
// Expands to an immediately-invoked lambda that returns cudaError_t. Inside the lambda a
// `constexpr bool` named `const_expr` is declared (false when `expr` converts to false,
// true otherwise) and the callable supplied as the trailing argument(s) is invoked with
// that constant in scope, so the callable may use it e.g. as a template argument.
// The callable's return value is forwarded unchanged.
//
// The callable is expanded once per branch, so it is compiled for both values of
// `const_expr`.
#define DISPATCH_BOOL_(expr, const_expr, ...) \
  [&]() -> cudaError_t {                      \
    if (!static_cast<bool>(expr)) {           \
      constexpr bool const_expr = false;      \
      return __VA_ARGS__();                   \
    }                                         \
    constexpr bool const_expr = true;         \
    return __VA_ARGS__();                     \
  }()
1116
+
1117
// Emits one `case` label of the switch generated by DISPATCH_MOEREDUCTION.
//
// For the rank count `n_ranks_val` it declares `constexpr int N_RANKS_VAR = n_ranks_val`
// and then feeds the four runtime conditions `ar`, `res`, `rms` and `quant` through nested
// DISPATCH_BOOL_ expansions, producing the compile-time booleans named by `AR`, `RES`,
// `RMS` and `QUANT`. The trailing callable is invoked in the innermost scope, where all
// five constants are visible, and its cudaError_t result is returned from the enclosing
// lambda.
//
// The helper carries a trailing underscore (like DISPATCH_BOOL_) rather than a leading
// one: identifiers that begin with an underscore followed by an uppercase letter are
// reserved for the implementation in C++.
#define DISPATCH_MOEREDUCTION_CASE_(n_ranks_val, N_RANKS_VAR, ar, res, rms, quant, AR, RES, RMS, \
                                    QUANT, ...)                                                  \
  case n_ranks_val: {                                                                            \
    constexpr int N_RANKS_VAR = n_ranks_val;                                                     \
    return DISPATCH_BOOL_(ar, AR, [&]() -> cudaError_t {                                         \
      return DISPATCH_BOOL_(res, RES, [&]() -> cudaError_t {                                     \
        return DISPATCH_BOOL_(rms, RMS, [&]() -> cudaError_t {                                   \
          return DISPATCH_BOOL_(quant, QUANT, [&]() -> cudaError_t { return __VA_ARGS__(); });   \
        });                                                                                      \
      });                                                                                        \
    });                                                                                          \
  }

// Dispatches a callable on the runtime rank count and four runtime flags.
//
// Expands to an immediately-invoked lambda returning cudaError_t:
//   - `n_ranks` is switched over the supported values 2, 4, 8 and 16; the matching value
//     is exposed to the callable as a `constexpr int` named by `N_RANKS`.
//   - `ar`, `res`, `rms`, `quant` are each converted to a `constexpr bool` named by `AR`,
//     `RES`, `RMS`, `QUANT` respectively.
//   - the trailing callable is invoked with those constants in scope and its result is
//     returned.
// Any other `n_ranks` triggers FLASHINFER_CHECK(false, ...) and, should that check return
// control, yields cudaErrorNotSupported.
//
// Every combination is expanded, so the callable is compiled for 4 rank counts x 16 flag
// combinations = 64 variants.
#define DISPATCH_MOEREDUCTION(n_ranks, ar, res, rms, quant, N_RANKS, AR, RES, RMS, QUANT, ...) \
  [&]() -> cudaError_t {                                                                       \
    switch (n_ranks) {                                                                         \
      DISPATCH_MOEREDUCTION_CASE_(2, N_RANKS, ar, res, rms, quant, AR, RES, RMS, QUANT,        \
                                  __VA_ARGS__)                                                 \
      DISPATCH_MOEREDUCTION_CASE_(4, N_RANKS, ar, res, rms, quant, AR, RES, RMS, QUANT,        \
                                  __VA_ARGS__)                                                 \
      DISPATCH_MOEREDUCTION_CASE_(8, N_RANKS, ar, res, rms, quant, AR, RES, RMS, QUANT,        \
                                  __VA_ARGS__)                                                 \
      DISPATCH_MOEREDUCTION_CASE_(16, N_RANKS, ar, res, rms, quant, AR, RES, RMS, QUANT,       \
                                  __VA_ARGS__)                                                 \
      default:                                                                                 \
        FLASHINFER_CHECK(false, " Unsupported n_ranks");                                       \
        return cudaErrorNotSupported;                                                          \
    }                                                                                          \
  }()
1146
+
1106
1147
template <typename T>
1107
1148
cudaError_t moereduction_allreduce_fusion_op (MoeReductionAllReduceFusionParams<T> const & params,
1108
1149
bool launch_with_pdl) {
@@ -1119,72 +1160,6 @@ cudaError_t moereduction_allreduce_fusion_op(MoeReductionAllReduceFusionParams<T
1119
1160
params.allreduce_out || params.residual_out || params.norm_out || params.quant_out ,
1120
1161
" at least one of allreduce_out, residual_out, norm_out, quant_out must be set" );
1121
1162
1122
- #define DISPATCH_MOEREDUCTION_KERNEL_RANKS (T, PARAMS, LAUNCH_WITH_PDL, AR, RES, RMS, QUANT ) \
1123
- do { \
1124
- switch ((PARAMS).nranks ) { \
1125
- case 2 : \
1126
- FLASHINFER_CUDA_CALL ( \
1127
- (moereduction_allreduce_fusion_kernel_launcher<T, 2 , AR, RES, RMS, QUANT>( \
1128
- (PARAMS), (LAUNCH_WITH_PDL)))); \
1129
- break ; \
1130
- case 4 : \
1131
- FLASHINFER_CUDA_CALL ( \
1132
- (moereduction_allreduce_fusion_kernel_launcher<T, 4 , AR, RES, RMS, QUANT>( \
1133
- (PARAMS), (LAUNCH_WITH_PDL)))); \
1134
- break ; \
1135
- case 8 : \
1136
- FLASHINFER_CUDA_CALL ( \
1137
- (moereduction_allreduce_fusion_kernel_launcher<T, 8 , AR, RES, RMS, QUANT>( \
1138
- (PARAMS), (LAUNCH_WITH_PDL)))); \
1139
- break ; \
1140
- case 16 : \
1141
- FLASHINFER_CUDA_CALL ( \
1142
- (moereduction_allreduce_fusion_kernel_launcher<T, 16 , AR, RES, RMS, QUANT>( \
1143
- (PARAMS), (LAUNCH_WITH_PDL)))); \
1144
- break ; \
1145
- default : \
1146
- FLASHINFER_CHECK (false , " Unsupported nranks" ); \
1147
- } \
1148
- return cudaSuccess; \
1149
- } while (0 )
1150
-
1151
- #define DISPATCH_MOEREDUCTION_PATTERN (T, PARAMS, LAUNCH_WITH_PDL ) \
1152
- do { \
1153
- bool AR = (PARAMS).allreduce_out ; \
1154
- bool RES = (PARAMS).residual_out ; \
1155
- bool NORM = (PARAMS).norm_out ; \
1156
- bool QUANT = (PARAMS).quant_out ; \
1157
- \
1158
- if (AR && RES && !NORM && QUANT) { \
1159
- /* AR + Residual + Quant */ \
1160
- DISPATCH_MOEREDUCTION_KERNEL_RANKS (T, PARAMS, LAUNCH_WITH_PDL, true , true , false , true ); \
1161
- } else if (!AR && RES && !NORM && QUANT) { \
1162
- /* Residual + Quant */ \
1163
- DISPATCH_MOEREDUCTION_KERNEL_RANKS (T, PARAMS, LAUNCH_WITH_PDL, false , true , false , true ); \
1164
- } else if (AR && !RES && NORM && !QUANT) { \
1165
- /* AR + RMS */ \
1166
- DISPATCH_MOEREDUCTION_KERNEL_RANKS (T, PARAMS, LAUNCH_WITH_PDL, true , false , true , false ); \
1167
- } else if (!AR && !RES && NORM && !QUANT) { \
1168
- /* RMS only */ \
1169
- DISPATCH_MOEREDUCTION_KERNEL_RANKS (T, PARAMS, LAUNCH_WITH_PDL, false , false , true , false ); \
1170
- } else if (AR && RES && NORM && !QUANT) { \
1171
- /* AR + Add + RMS */ \
1172
- DISPATCH_MOEREDUCTION_KERNEL_RANKS (T, PARAMS, LAUNCH_WITH_PDL, true , true , true , false ); \
1173
- } else if (!AR && RES && NORM && !QUANT) { \
1174
- /* Add + RMS */ \
1175
- DISPATCH_MOEREDUCTION_KERNEL_RANKS (T, PARAMS, LAUNCH_WITH_PDL, false , true , true , false ); \
1176
- } else if (AR && RES && NORM && QUANT) { \
1177
- /* AR + Add + RMS + Quant */ \
1178
- DISPATCH_MOEREDUCTION_KERNEL_RANKS (T, PARAMS, LAUNCH_WITH_PDL, true , true , true , true ); \
1179
- } else if (!AR && RES && NORM && QUANT) { \
1180
- /* Add + RMS + Quant */ \
1181
- DISPATCH_MOEREDUCTION_KERNEL_RANKS (T, PARAMS, LAUNCH_WITH_PDL, false , true , true , true ); \
1182
- } else { \
1183
- FLASHINFER_CHECK (false , " allreduce_fusion_kernel: unsupported pattern!" ); \
1184
- return cudaErrorNotSupported; \
1185
- } \
1186
- } while (0 )
1187
-
1188
1163
// hidden_dim (d) = 7168 for dpsk moe, and hence 128 tokens as one-shot threshold
1189
1164
// AR outputs are optional, since we always have fused options followed.
1190
1165
// pattern1: AR+Residual+Add_RMS+Quant
@@ -1203,10 +1178,14 @@ cudaError_t moereduction_allreduce_fusion_op(MoeReductionAllReduceFusionParams<T
1203
1178
// [m, d] bf16 allreduce_in, [m, d] bf16 residual_in
1204
1179
// [m, d] bf16 residual_out, [m, d] bf16 norm_out, [m, d] fp4 quant_out
1205
1180
1206
- DISPATCH_MOEREDUCTION_PATTERN (T, params, launch_with_pdl);
1207
-
1208
- FLASHINFER_CHECK (false , " allreduce_fusion_kernel: unsupported pattern!" );
1209
- return cudaErrorNotSupported;
1181
+ auto status = DISPATCH_MOEREDUCTION (
1182
+ params.nranks , params.allreduce_out , params.residual_out , params.rms_gamma , params.quant_out ,
1183
+ N_RANKS, AR, RES, RMS, QUANT, [&]() -> cudaError_t {
1184
+ FLASHINFER_CUDA_CALL (
1185
+ (moereduction_allreduce_fusion_kernel_launcher<T, N_RANKS, AR, RES, RMS, QUANT>(
1186
+ (params), (launch_with_pdl))));
1187
+ });
1188
+ return status;
1210
1189
}
1211
1190
} // namespace trtllm_moe_allreduce_fusion
1212
1191
} // namespace flashinfer
0 commit comments