Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Jan 3, 2025
1 parent 7f5ddad commit 74b1edc
Showing 1 changed file with 60 additions and 16 deletions.
76 changes: 60 additions & 16 deletions easy_rec/python/utils/io_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,23 +186,67 @@ def read_data_from_json_path(json_path):
logging.info('json_path not exists, return None')
return None


def convert_tf_flags_to_argparse(flags):
"""Convert tf.app.flags.FLAGS to argparse.ArgumentParser.
Args:
flags: tf.app.flags.FLAGS
Returns:
argparse.ArgumentParser: configurate ArgumentParser object
"""
import argparse
import ast
parser = argparse.ArgumentParser()

args = set()
for flag in flags._flags().values():
flag_name = flag.name
if flag_name in args:
continue
args.add(flag_name)
default = flag.value
flag_type = type(default)
help_str = flag.help or ''
if flag_type == bool:
parser.add_argument(
'--' + flag_name,
dest=flag_name,
action='store_true' if default else 'store_false',
help=help_str)
elif flag_type == str:
if hasattr(flag, 'choices') and flag.choices:
parser.add_argument(
'--' + flag_name,
type=str,
choices=flag.choices,
default=default,
help=help_str)
else:
parser.add_argument(
'--' + flag_name, type=str, default=default, help=help_str)
elif flag_type in (list, dict):
parser.add_argument(
'--' + flag_name,
type=lambda s: ast.literal_eval(s),
default=default,
help=help_str)
else:
parser.add_argument(
'--' + flag_name, type=flag_type, default=default, help=help_str)
return parser


def filter_unknown_args(flags, args):
"""Filter unknown args."""
defined_flags = set(flag.name for flag in flags._flags().values())
logging.info('defined arguments: %s', ', '.join(defined_flags))
logging.info('actual arguments: %s', ', '.join(args[1:]))
known_args = [args[0]]
unknown = False
for arg in args[1:]:
if arg.startswith('--'):
flag_name = arg.split('=')[0][2:]
if flag_name in defined_flags:
known_args.append(arg)
unknown = False
else:
unknown = True
logging.warning('Ignore unknown arg: %s' % arg)
elif not unknown:
known_args.append(arg)
logging.info('keep arguments: %s', ', '.join(known_args[1:]))
parser = convert_tf_flags_to_argparse(flags)
args, unknown = parser.parse_known_args(args)
if len(unknown) > 1:
logging.info('undefined arguments: %s', ', '.join(unknown[1:]))
for key, value in vars(args).items():
if type(value) != bool and not value:
continue
known_args.append('--' + key + '=' + str(value))
logging.info('defined arguments: %s', ', '.join(known_args[1:]))
return known_args

0 comments on commit 74b1edc

Please # to comment.