diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index 33b0320808..f43152abf1 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -83,12 +83,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast: def make_predictions(prediction_net, inputs: dict): - # MXNet predictors only support positional arguments - class_name = prediction_net.__class__.__module__ - if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"): - return prediction_net(*inputs.values()) - else: - return prediction_net(**inputs) + try: + # Feed inputs as positional arguments for MXNet block predictors + import mxnet as mx + + if isinstance(prediction_net, mx.gluon.HybridBlock): + return prediction_net(*inputs.values()) + except ImportError: + pass + return prediction_net(**inputs) class ForecastGenerator: