diff --git a/easy_rec/python/utils/io_util.py b/easy_rec/python/utils/io_util.py index 431394e5e..cfe20d4ac 100644 --- a/easy_rec/python/utils/io_util.py +++ b/easy_rec/python/utils/io_util.py @@ -186,23 +186,67 @@ def read_data_from_json_path(json_path): logging.info('json_path not exists, return None') return None + +def convert_tf_flags_to_argparse(flags): + """Convert tf.app.flags.FLAGS to argparse.ArgumentParser. + + Args: + flags: tf.app.flags.FLAGS + Returns: + argparse.ArgumentParser: configurate ArgumentParser object + """ + import argparse + import ast + parser = argparse.ArgumentParser() + + args = set() + for flag in flags._flags().values(): + flag_name = flag.name + if flag_name in args: + continue + args.add(flag_name) + default = flag.value + flag_type = type(default) + help_str = flag.help or '' + if flag_type == bool: + parser.add_argument( + '--' + flag_name, + dest=flag_name, + action='store_true' if default else 'store_false', + help=help_str) + elif flag_type == str: + if hasattr(flag, 'choices') and flag.choices: + parser.add_argument( + '--' + flag_name, + type=str, + choices=flag.choices, + default=default, + help=help_str) + else: + parser.add_argument( + '--' + flag_name, type=str, default=default, help=help_str) + elif flag_type in (list, dict): + parser.add_argument( + '--' + flag_name, + type=lambda s: ast.literal_eval(s), + default=default, + help=help_str) + else: + parser.add_argument( + '--' + flag_name, type=flag_type, default=default, help=help_str) + return parser + + def filter_unknown_args(flags, args): """Filter unknown args.""" - defined_flags = set(flag.name for flag in flags._flags().values()) - logging.info('defined arguments: %s', ', '.join(defined_flags)) - logging.info('actual arguments: %s', ', '.join(args[1:])) known_args = [args[0]] - unknown = False - for arg in args[1:]: - if arg.startswith('--'): - flag_name = arg.split('=')[0][2:] - if flag_name in defined_flags: - known_args.append(arg) - unknown = False - else: - unknown = True - logging.warning('Ignore unknown arg: %s' % arg) - elif not unknown: - known_args.append(arg) - logging.info('keep arguments: %s', ', '.join(known_args[1:])) + parser = convert_tf_flags_to_argparse(flags) + args, unknown = parser.parse_known_args(args) + if len(unknown) > 1: + logging.info('undefined arguments: %s', ', '.join(unknown[1:])) + for key, value in vars(args).items(): + if type(value) != bool and not value: + continue + known_args.append('--' + key + '=' + str(value)) + logging.info('defined arguments: %s', ', '.join(known_args[1:])) return known_args