Skip to content

Commit

Permalink
Use tuple input for TextPairClassification inference (#723)
Browse files (browse the repository at this point in the history)
* Change input to textpair inference dataset_from_dicts
  • Loading branch information
Timoeller authored Feb 23, 2021
1 parent d3658be commit c7e5d1a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
8 changes: 6 additions & 2 deletions examples/text_pair_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,13 @@ def text_pair_classification():

# 9. Load it & harvest your fruits (Inference)
# Add your own text adapted to the dataset you provide
# For correct Text Pair Classification on raw dictionaries (inference mode), we need to put both
# texts (text, text_b) into a tuple.
# See corresponding conversion in the file_to_dicts() method of TextPairClassificationProcessor: https://github.com/deepset-ai/FARM/blob/5ab5b1620cb51ceb874d4b30c887e377ad1a6e9a/farm/data_handler/processor.py#L744
basic_texts = [
{"text": "how many times have real madrid won the champions league in a row", "text_b": "They have also won the competition the most times in a row, winning it five times from 1956 to 1960"},
{"text": "how many seasons of the blacklist are there on netflix", "text_b": "Retrieved March 27 , 2018 ."},
{"text": ("how many times have real madrid won the champions league in a row",
"They have also won the competition the most times in a row, winning it five times from 1956 to 1960")},
{"text": ("how many seasons of the blacklist are there on netflix", "Retrieved March 27 , 2018 .")},
]

model = Inferencer.load(save_dir)
Expand Down
3 changes: 2 additions & 1 deletion farm/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,9 @@ def _inference_without_multiprocessing(self, dicts, return_json, aggregate_preds
:return: list of predictions
:rtype: list
"""
indices = list(range(len(dicts)))
dataset, tensor_names, problematic_ids, baskets = self.processor.dataset_from_dicts(
dicts, indices=[i for i in range(len(dicts))], return_baskets=True
dicts, indices=indices, return_baskets=True
)
self.problematic_sample_ids = problematic_ids
if self.benchmarking:
Expand Down
18 changes: 11 additions & 7 deletions test/test_text_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from farm.utils import set_all_seeds, initialize_device_settings


def test_text_pair_classification(caplog):
def test_text_pair_classification(caplog=None):
if caplog:
caplog.set_level(logging.CRITICAL)

Expand Down Expand Up @@ -83,9 +83,12 @@ def test_text_pair_classification(caplog):
model.save(save_dir)
processor.save(save_dir)


# For correct Text Pair Classification on raw dictionaries, we need to put both texts (text, text_b) into a tuple
# See corresponding operation in the file_to_dicts method of TextPairClassificationProcessor here: https://github.com/deepset-ai/FARM/blob/5ab5b1620cb51ceb874d4b30c887e377ad1a6e9a/farm/data_handler/processor.py#L744
basic_texts = [
{"text": "how many times have real madrid won the champions league in a row", "text_b": "They have also won the competition the most times in a row, winning it five times from 1956 to 1960"},
{"text": "how many seasons of the blacklist are there on netflix", "text_b": "Retrieved March 27 , 2018 ."},
{"text": ("how many times have real madrid won the champions league in a row", "They have also won the competition the most times in a row, winning it five times from 1956 to 1960")},
{"text": ("how many seasons of the blacklist are there on netflix", "Retrieved March 27 , 2018 .")},
]

model = Inferencer.load(save_dir)
Expand All @@ -96,7 +99,7 @@ def test_text_pair_classification(caplog):
model.close_multiprocessing_pool()


def test_text_pair_regression(caplog):
def test_text_pair_regression(caplog=None):
if caplog:
caplog.set_level(logging.CRITICAL)

Expand Down Expand Up @@ -163,8 +166,8 @@ def test_text_pair_regression(caplog):
processor.save(save_dir)

basic_texts = [
{"text": "how many times have real madrid won the champions league in a row", "text_b": "They have also won the competition the most times in a row, winning it five times from 1956 to 1960"},
{"text": "how many seasons of the blacklist are there on netflix", "text_b": "Retrieved March 27 , 2018 ."},
{"text": ("how many times have real madrid won the champions league in a row", "They have also won the competition the most times in a row, winning it five times from 1956 to 1960")},
{"text": ("how many seasons of the blacklist are there on netflix", "Retrieved March 27 , 2018 .")},
]

model = Inferencer.load(save_dir)
Expand All @@ -173,7 +176,7 @@ def test_text_pair_regression(caplog):
assert np.isclose(result[0]["predictions"][0]["pred"], 0.7976, rtol=0.05)
model.close_multiprocessing_pool()

def test_segment_ids(caplog):
def test_segment_ids(caplog=None):
if caplog:
caplog.set_level(logging.CRITICAL)
lang_model = "microsoft/MiniLM-L12-H384-uncased"
Expand All @@ -198,6 +201,7 @@ def test_segment_ids(caplog):


if __name__ == "__main__":
test_text_pair_classification()
test_text_pair_regression()

# fmt: on

0 comments on commit c7e5d1a

Please sign in to comment.