diff --git a/examples/text_pair_classification.py b/examples/text_pair_classification.py index 715787f8b..cbe8c0538 100644 --- a/examples/text_pair_classification.py +++ b/examples/text_pair_classification.py @@ -98,9 +98,13 @@ def text_pair_classification(): # 9. Load it & harvest your fruits (Inference) # Add your own text adapted to the dataset you provide + # For correct Text Pair Classification on raw dictionaries (inference mode), we need to put both + # texts (text, text_b) into a tuple. + # See corresponding conversion in the file_to_dicts() method of TextPairClassificationProcessor: https://github.com/deepset-ai/FARM/blob/5ab5b1620cb51ceb874d4b30c887e377ad1a6e9a/farm/data_handler/processor.py#L744 basic_texts = [ - {"text": "how many times have real madrid won the champions league in a row", "text_b": "They have also won the competition the most times in a row, winning it five times from 1956 to 1960"}, - {"text": "how many seasons of the blacklist are there on netflix", "text_b": "Retrieved March 27 , 2018 ."}, + {"text": ("how many times have real madrid won the champions league in a row", + "They have also won the competition the most times in a row, winning it five times from 1956 to 1960")}, + {"text": ("how many seasons of the blacklist are there on netflix", "Retrieved March 27 , 2018 .")}, ] model = Inferencer.load(save_dir) diff --git a/farm/infer.py b/farm/infer.py index f0fea6c0a..6340385ef 100644 --- a/farm/infer.py +++ b/farm/infer.py @@ -473,8 +473,9 @@ def _inference_without_multiprocessing(self, dicts, return_json, aggregate_preds :return: list of predictions :rtype: list """ + indices = list(range(len(dicts))) dataset, tensor_names, problematic_ids, baskets = self.processor.dataset_from_dicts( - dicts, indices=[i for i in range(len(dicts))], return_baskets=True + dicts, indices=indices, return_baskets=True ) self.problematic_sample_ids = problematic_ids if self.benchmarking: diff --git a/test/test_text_pair.py b/test/test_text_pair.py index e4aaaee81..0094e5b46 100644 --- a/test/test_text_pair.py +++ b/test/test_text_pair.py @@ -16,7 +16,7 @@ from farm.utils import set_all_seeds, initialize_device_settings -def test_text_pair_classification(caplog): +def test_text_pair_classification(caplog=None): if caplog: caplog.set_level(logging.CRITICAL) @@ -83,9 +83,12 @@ def test_text_pair_classification(caplog): model.save(save_dir) processor.save(save_dir) + + # For correct Text Pair Classification on raw dictionaries, we need to put both texts (text, text_b) into a tuple + # See corresponding operation in the file_to_dicts method of TextPairClassificationProcessor here: https://github.com/deepset-ai/FARM/blob/5ab5b1620cb51ceb874d4b30c887e377ad1a6e9a/farm/data_handler/processor.py#L744 basic_texts = [ - {"text": "how many times have real madrid won the champions league in a row", "text_b": "They have also won the competition the most times in a row, winning it five times from 1956 to 1960"}, - {"text": "how many seasons of the blacklist are there on netflix", "text_b": "Retrieved March 27 , 2018 ."}, + {"text": ("how many times have real madrid won the champions league in a row", "They have also won the competition the most times in a row, winning it five times from 1956 to 1960")}, + {"text": ("how many seasons of the blacklist are there on netflix", "Retrieved March 27 , 2018 .")}, ] model = Inferencer.load(save_dir) @@ -96,7 +99,7 @@ def test_text_pair_classification(caplog): model.close_multiprocessing_pool() -def test_text_pair_regression(caplog): +def test_text_pair_regression(caplog=None): if caplog: caplog.set_level(logging.CRITICAL) @@ -163,8 +166,8 @@ def test_text_pair_regression(caplog): processor.save(save_dir) basic_texts = [ - {"text": "how many times have real madrid won the champions league in a row", "text_b": "They have also won the competition the most times in a row, winning it five times from 1956 to 1960"}, - {"text": "how many seasons of the blacklist are there on netflix", "text_b": "Retrieved March 27 , 2018 ."}, + {"text": ("how many times have real madrid won the champions league in a row", "They have also won the competition the most times in a row, winning it five times from 1956 to 1960")}, + {"text": ("how many seasons of the blacklist are there on netflix", "Retrieved March 27 , 2018 .")}, ] model = Inferencer.load(save_dir) @@ -173,7 +176,7 @@ def test_text_pair_regression(caplog): assert np.isclose(result[0]["predictions"][0]["pred"], 0.7976, rtol=0.05) model.close_multiprocessing_pool() -def test_segment_ids(caplog): +def test_segment_ids(caplog=None): if caplog: caplog.set_level(logging.CRITICAL) lang_model = "microsoft/MiniLM-L12-H384-uncased" @@ -198,6 +201,7 @@ def test_segment_ids(caplog): if __name__ == "__main__": + test_text_pair_classification() test_text_pair_regression() # fmt: on