From 7f5ddad1707c04c9925e0547c8ad347086efdf77 Mon Sep 17 00:00:00 2001 From: yangxudong Date: Tue, 31 Dec 2024 16:19:31 +0800 Subject: [PATCH] fix bug of undefined flags of easyrec tools run with DeepRec --- .../python/tools/add_boundaries_to_config.py | 3 +++ .../tools/add_feature_info_to_config.py | 3 +++ easy_rec/python/tools/faiss_index_pai.py | 3 +++ easy_rec/python/tools/feature_selection.py | 3 +++ easy_rec/python/tools/hit_rate_ds.py | 3 +++ easy_rec/python/tools/hit_rate_pai.py | 3 +++ easy_rec/python/tools/pre_check.py | 3 +++ easy_rec/python/tools/split_model_pai.py | 3 +++ easy_rec/python/tools/split_pdn_model_pai.py | 3 +++ easy_rec/python/utils/io_util.py | 21 +++++++++++++++++++ 10 files changed, 48 insertions(+) diff --git a/easy_rec/python/tools/add_boundaries_to_config.py b/easy_rec/python/tools/add_boundaries_to_config.py index 09d2d9a1d..18d5f6037 100644 --- a/easy_rec/python/tools/add_boundaries_to_config.py +++ b/easy_rec/python/tools/add_boundaries_to_config.py @@ -3,11 +3,13 @@ import json import logging import os +import sys import common_io import tensorflow as tf from easy_rec.python.utils import config_util +from easy_rec.python.utils import io_util if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -61,4 +63,5 @@ def main(argv): if __name__ == '__main__': + sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) tf.app.run() diff --git a/easy_rec/python/tools/add_feature_info_to_config.py b/easy_rec/python/tools/add_feature_info_to_config.py index b11cfc0a7..7594d038b 100644 --- a/easy_rec/python/tools/add_feature_info_to_config.py +++ b/easy_rec/python/tools/add_feature_info_to_config.py @@ -3,10 +3,12 @@ import json import logging import os +import sys import tensorflow as tf from easy_rec.python.utils import config_util +from easy_rec.python.utils import io_util from easy_rec.python.utils.hive_utils import HiveUtils if tf.__version__ >= '2.0': @@ -139,4 +141,5 @@ def main(argv): if __name__ == '__main__': + sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) tf.app.run() diff --git a/easy_rec/python/tools/faiss_index_pai.py b/easy_rec/python/tools/faiss_index_pai.py index 718382733..b7eb66bc0 100644 --- a/easy_rec/python/tools/faiss_index_pai.py +++ b/easy_rec/python/tools/faiss_index_pai.py @@ -4,10 +4,12 @@ import logging import os +import sys import faiss import numpy as np import tensorflow as tf +from easy_rec.python.utils import io_util logging.basicConfig( level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s') @@ -109,4 +111,5 @@ def main(argv): if __name__ == '__main__': + sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) tf.app.run() diff --git a/easy_rec/python/tools/feature_selection.py b/easy_rec/python/tools/feature_selection.py index 6d9f59911..f50a00fac 100644 --- a/easy_rec/python/tools/feature_selection.py +++ b/easy_rec/python/tools/feature_selection.py @@ -3,6 +3,7 @@ import json import os +import sys from collections import OrderedDict import numpy as np @@ -11,6 +12,7 @@ from tensorflow.python.framework.meta_graph import read_meta_graph_file from easy_rec.python.utils import config_util +from easy_rec.python.utils import io_util if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -299,6 +301,7 @@ def _visualize_feature_importance(self, feature_importance, group_name): if __name__ == '__main__': + sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) if FLAGS.model_type == 'variational_dropout': fs = VariationalDropoutFS( FLAGS.config_path, diff --git a/easy_rec/python/tools/hit_rate_ds.py b/easy_rec/python/tools/hit_rate_ds.py index 552b96aad..5528e0aa2 100644 --- a/easy_rec/python/tools/hit_rate_ds.py +++ b/easy_rec/python/tools/hit_rate_ds.py @@ -20,12 +20,14 @@ import json import logging import os +import sys import graphlearn as gl import tensorflow as tf from easy_rec.python.protos.dataset_pb2 import DatasetConfig from easy_rec.python.utils import config_util +from easy_rec.python.utils import io_util from easy_rec.python.utils.config_util import process_multi_file_input_path from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch from easy_rec.python.utils.hit_rate_utils import load_graph @@ -217,4 +219,5 @@ def main(): if __name__ == '__main__': + sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) main() diff --git a/easy_rec/python/tools/hit_rate_pai.py b/easy_rec/python/tools/hit_rate_pai.py index 73c8a2095..5f97b3429 100644 --- a/easy_rec/python/tools/hit_rate_pai.py +++ b/easy_rec/python/tools/hit_rate_pai.py @@ -17,8 +17,10 @@ from __future__ import division from __future__ import print_function +import sys import tensorflow as tf +from easy_rec.python.utils import io_util from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch from easy_rec.python.utils.hit_rate_utils import load_graph from easy_rec.python.utils.hit_rate_utils import reduce_hitrate @@ -131,4 +133,5 @@ def main(): if __name__ == '__main__': + sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) main() diff --git a/easy_rec/python/tools/pre_check.py b/easy_rec/python/tools/pre_check.py index 8fcaa2caf..da7f1923b 100644 --- a/easy_rec/python/tools/pre_check.py +++ b/easy_rec/python/tools/pre_check.py @@ -3,12 +3,14 @@ import json import logging import os +import sys import tensorflow as tf from easy_rec.python.input.input import Input from easy_rec.python.utils import config_util from easy_rec.python.utils import fg_util +from easy_rec.python.utils import io_util from easy_rec.python.utils.check_utils import check_env_and_input_path from easy_rec.python.utils.check_utils import check_sequence @@ -114,4 +116,5 @@ def main(argv): if __name__ == '__main__': + sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) tf.app.run() diff --git a/easy_rec/python/tools/split_model_pai.py b/easy_rec/python/tools/split_model_pai.py index bdb2087de..cf1657deb 100644 --- a/easy_rec/python/tools/split_model_pai.py +++ b/easy_rec/python/tools/split_model_pai.py @@ -2,6 +2,7 @@ import copy import logging import os +import sys import tensorflow as tf from tensorflow.core.framework import graph_pb2 @@ -11,6 +12,7 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import saver as tf_saver +from easy_rec.python.utils import io_util if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -282,4 +284,5 @@ def main(argv): if __name__ == '__main__': + sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) tf.app.run() diff --git a/easy_rec/python/tools/split_pdn_model_pai.py b/easy_rec/python/tools/split_pdn_model_pai.py index e2341d57d..849250b37 100644 --- a/easy_rec/python/tools/split_pdn_model_pai.py +++ b/easy_rec/python/tools/split_pdn_model_pai.py @@ -2,6 +2,7 @@ import copy import logging import os +import sys import tensorflow as tf from tensorflow.core.framework import graph_pb2 @@ -12,6 +13,7 @@ from tensorflow.python.saved_model.utils_impl import get_variables_path from tensorflow.python.tools import saved_model_utils from tensorflow.python.training import saver as tf_saver +from easy_rec.python.utils import io_util FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('model_dir', '', '') @@ -265,4 +267,5 @@ def main(argv): if __name__ == '__main__': + sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv) tf.app.run() diff --git a/easy_rec/python/utils/io_util.py b/easy_rec/python/utils/io_util.py index 091e10e07..431394e5e 100644 --- a/easy_rec/python/utils/io_util.py +++ b/easy_rec/python/utils/io_util.py @@ -185,3 +185,24 @@ def read_data_from_json_path(json_path): else: logging.info('json_path not exists, return None') return None + +def filter_unknown_args(flags, args): + """Filter unknown args.""" + defined_flags = set(flag.name for flag in flags._flags().values()) + logging.info('defined arguments: %s', ', '.join(defined_flags)) + logging.info('actual arguments: %s', ', '.join(args[1:])) + known_args = [args[0]] + unknown = False + for arg in args[1:]: + if arg.startswith('--'): + flag_name = arg.split('=')[0][2:] + if flag_name in defined_flags: + known_args.append(arg) + unknown = False + else: + unknown = True + logging.warning('Ignore unknown arg: %s' % arg) + elif not unknown: + known_args.append(arg) + logging.info('keep arguments: %s', ', '.join(known_args[1:])) + return known_args