Removed single_model_path; added infer_tokenizer to dpr load() #1060

Merged
merged 1 commit on Jun 14, 2021
91 changes: 40 additions & 51 deletions haystack/retriever/dense.py
@@ -37,7 +37,6 @@ def __init__(self,
                 document_store: BaseDocumentStore,
                 query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base",
                 passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base",
-                 single_model_path: Optional[Union[Path, str]] = None,
                 model_version: Optional[str] = None,
                 max_seq_len_query: int = 64,
                 max_seq_len_passage: int = 256,
@@ -74,9 +73,6 @@ def __init__(self,
        :param passage_embedding_model: Local path or remote name of passage encoder checkpoint. The format equals the
                                        one used by hugging-face transformers' modelhub models
                                        Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
-        :param single_model_path: Local path or remote name of a query and passage embedder in one single model. Those
-                                  models are typically trained within FARM.
-                                  Currently available remote names: TODO add FARM DPR model to HF modelhub
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param max_seq_len_query: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down."
        :param max_seq_len_passage: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down."
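
Migration note: code that previously pointed at one combined query-and-passage model via the removed single_model_path now works with two separate encoder directories instead. A minimal, hypothetical sketch, assuming an already constructed retriever and document_store; the saved_models/dpr path is a placeholder:

# Hypothetical migration sketch: save() writes separate encoder directories
# (matching the query_encoder_dir / passage_encoder_dir defaults of load(),
# changed further down in this diff), replacing the single combined model
# that single_model_path used to point at.
retriever.save("saved_models/dpr")

retriever = DensePassageRetriever.load(
    load_dir="saved_models/dpr",
    document_store=document_store,  # placeholder: any existing document store
)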
@@ -101,7 +97,7 @@ def __init__(self,
        # save init parameters to enable export of component config as YAML
        self.set_config(
            document_store=document_store, query_embedding_model=query_embedding_model,
-            passage_embedding_model=passage_embedding_model, single_model_path=single_model_path,
+            passage_embedding_model=passage_embedding_model,
            model_version=model_version, max_seq_len_query=max_seq_len_query, max_seq_len_passage=max_seq_len_passage,
            top_k=top_k, use_gpu=use_gpu, batch_size=batch_size, embed_title=embed_title,
            use_fast_tokenizers=use_fast_tokenizers, infer_tokenizer_classes=infer_tokenizer_classes,
@@ -137,51 +133,42 @@ def __init__(self,
            tokenizers_default_classes["passage"] = None  # type: ignore

        # Init & Load Encoders
-        if single_model_path is None:
-            self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
-                                                  revision=model_version,
-                                                  do_lower_case=True,
-                                                  use_fast=use_fast_tokenizers,
-                                                  tokenizer_class=tokenizers_default_classes["query"])
-            self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
-                                                    revision=model_version,
-                                                    language_model_class="DPRQuestionEncoder")
-            self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
-                                                    revision=model_version,
-                                                    do_lower_case=True,
-                                                    use_fast=use_fast_tokenizers,
-                                                    tokenizer_class=tokenizers_default_classes["passage"])
-            self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
-                                                      revision=model_version,
-                                                      language_model_class="DPRContextEncoder")
-
-            self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
-                                                     passage_tokenizer=self.passage_tokenizer,
-                                                     max_seq_len_passage=max_seq_len_passage,
-                                                     max_seq_len_query=max_seq_len_query,
-                                                     label_list=["hard_negative", "positive"],
-                                                     metric="text_similarity_metric",
-                                                     embed_title=embed_title,
-                                                     num_hard_negatives=0,
-                                                     num_positives=1)
-            prediction_head = TextSimilarityHead(similarity_function=similarity_function)
-            self.model = BiAdaptiveModel(
-                language_model1=self.query_encoder,
-                language_model2=self.passage_encoder,
-                prediction_heads=[prediction_head],
-                embeds_dropout_prob=0.1,
-                lm1_output_types=["per_sequence"],
-                lm2_output_types=["per_sequence"],
-                device=self.device,
-            )
-        else:
-            self.processor = TextSimilarityProcessor.load_from_dir(single_model_path)
-            self.processor.max_seq_len_passage = max_seq_len_passage
-            self.processor.max_seq_len_query = max_seq_len_query
-            self.processor.embed_title = embed_title
-            self.processor.num_hard_negatives = 0
-            self.processor.num_positives = 1  # during indexing of documents only one embedding is created
-            self.model = BiAdaptiveModel.load(single_model_path, device=self.device)
+        self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
+                                              revision=model_version,
+                                              do_lower_case=True,
+                                              use_fast=use_fast_tokenizers,
+                                              tokenizer_class=tokenizers_default_classes["query"])
+        self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
+                                                revision=model_version,
+                                                language_model_class="DPRQuestionEncoder")
+        self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                revision=model_version,
+                                                do_lower_case=True,
+                                                use_fast=use_fast_tokenizers,
+                                                tokenizer_class=tokenizers_default_classes["passage"])
+        self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                  revision=model_version,
+                                                  language_model_class="DPRContextEncoder")
+
+        self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
+                                                 passage_tokenizer=self.passage_tokenizer,
+                                                 max_seq_len_passage=max_seq_len_passage,
+                                                 max_seq_len_query=max_seq_len_query,
+                                                 label_list=["hard_negative", "positive"],
+                                                 metric="text_similarity_metric",
+                                                 embed_title=embed_title,
+                                                 num_hard_negatives=0,
+                                                 num_positives=1)
+        prediction_head = TextSimilarityHead(similarity_function=similarity_function)
+        self.model = BiAdaptiveModel(
+            language_model1=self.query_encoder,
+            language_model2=self.passage_encoder,
+            prediction_heads=[prediction_head],
+            embeds_dropout_prob=0.1,
+            lm1_output_types=["per_sequence"],
+            lm2_output_types=["per_sequence"],
+            device=self.device,
+        )

        self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)

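For orientation, a minimal constructor sketch after this change. The import paths follow this PR's file layout (haystack/retriever/dense.py) and Haystack 0.x conventions; the in-memory document store and use_gpu=False are illustrative assumptions, not part of this diff:

# Minimal sketch, assuming the Haystack 0.x API of this PR's era.
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.retriever.dense import DensePassageRetriever

document_store = InMemoryDocumentStore()
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    max_seq_len_query=64,      # queries longer than 64 tokens are cut down (default)
    max_seq_len_passage=256,   # passages longer than 256 tokens are cut down (default)
    use_gpu=False,             # illustrative; not a value asserted by this diff
)
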
@@ -414,7 +401,8 @@ def load(cls,
             use_fast_tokenizers: bool = True,
             similarity_function: str = "dot_product",
             query_encoder_dir: str = "query_encoder",
-             passage_encoder_dir: str = "passage_encoder"
+             passage_encoder_dir: str = "passage_encoder",
+             infer_tokenizer_classes: bool = False
             ):
        """
        Load DensePassageRetriever from the specified directory.
@@ -431,7 +419,8 @@
                     batch_size=batch_size,
                     embed_title=embed_title,
                     use_fast_tokenizers=use_fast_tokenizers,
-                     similarity_function=similarity_function
+                     similarity_function=similarity_function,
+                     infer_tokenizer_classes=infer_tokenizer_classes
                     )
        logger.info(f"DPR model loaded from {load_dir}")

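Finally, a sketch of the newly exposed load() flag; load_dir is a placeholder and document_store is assumed to be an existing instance:

# Minimal sketch: with infer_tokenizer_classes=True, the tokenizer classes are
# inferred from each saved encoder's config (the tokenizer_class defaults above
# are set to None) rather than forced to the standard DPR tokenizer classes.
retriever = DensePassageRetriever.load(
    load_dir="saved_models/dpr",       # placeholder path
    document_store=document_store,     # assumed existing document store
    infer_tokenizer_classes=True,
)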