diff --git a/easy_rec/python/utils/io_util.py b/easy_rec/python/utils/io_util.py index 5cef1d9ef..7b4666391 100644 --- a/easy_rec/python/utils/io_util.py +++ b/easy_rec/python/utils/io_util.py @@ -213,12 +213,24 @@ def convert_tf_flags_to_argparse(flags): flag.choices if hasattr(flag, 'choices') else None ] + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + for flag_name, (multi, flag_type, default, help_str, choices) in args.items(): if flag_type == bool: parser.add_argument( '--' + flag_name, - dest=flag_name, - action='store_true' if default else 'store_false', + type=str2bool, + nargs='?', + const=True, + default=False, help=help_str) elif flag_type == str: if choices: