diff --git a/predict.py b/predict.py index d695394..cbd70d0 100644 --- a/predict.py +++ b/predict.py @@ -171,6 +171,11 @@ def main(args): elif predictions.isna().any().any(): logging.error("Pickle function returned at least 1 NaN prediction") exit_with_help(1) + elif not (predictions.between(0, 1).all().all()): + logging.error( + "Pickle function returned invalid predictions. Ensure values are between 0 and 1." + ) + exit_with_help(1) except TypeError as e: logging.error(f"Pickle function is invalid - {e}") if args.debug: