Skip to content

Commit

Permalink
feat: adding modeling utilities as extras. [bump major] (#29)
Browse files Browse the repository at this point in the history
* feat: adding modeling utilities as extras. [bump major]

Currently supports rxnfp inference.

Included unit tests.

Including pydantic-settings.

Dropping support for py<3.7.

Signed-off-by: Matteo Manica <drugilsberg@gmail.com>

* tests: more flexibility in test_attrs.py.

Signed-off-by: Matteo Manica <drugilsberg@gmail.com>

---------

Signed-off-by: Matteo Manica <drugilsberg@gmail.com>
  • Loading branch information
drugilsberg authored Feb 13, 2024
1 parent 74b1f03 commit 914a6c2
Show file tree
Hide file tree
Showing 13 changed files with 1,661 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
with:
python-version: 3.9
- name: Install Dependencies
run: pip install -e .[dev]
run: pip install -e ".[dev, modeling]"
- name: Install additional dependencies (for pydantic>2)
run: pip install pydantic_settings
- name: Check black
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ requires = ["setuptools >= 59.2.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.mypy]
strict = true
strict = false

[[tool.mypy.overrides]]
module = [
"diskcache.*",
"pymongo.*",
"pydantic.*"
"pydantic.*",
"transformers.*"
]
ignore_missing_imports = true

Expand Down
7 changes: 5 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ project_urls =
classifiers =
Operating System :: OS Independent
Programming Language :: Python :: 3
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Expand All @@ -25,14 +24,15 @@ classifiers =
package_dir =
= src
packages = find_namespace:
python_requires = >= 3.6
python_requires = >= 3.7
zip_safe = False
include_package_data = True
install_requires =
attrs>=21.2.0
click>=8.0
diskcache>=5.2.1
pydantic>=1.9.0
pydantic_settings>=2.1.0
pymongo>=3.9.0
tqdm>=4.31.0
typing-extensions>=4.1.1
Expand All @@ -55,6 +55,9 @@ dev =
pytest>=7.0.1
types-setuptools>=57.4.14
types-tqdm>=4.64.0
modeling =
torch>=1.5.0,<2.0.0
transformers>=4.21.0

[options.entry_points]
console_scripts =
Expand Down
1 change: 1 addition & 0 deletions src/rxn/utilities/databases/pymongo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""PyMongo-related utilities."""

import os
from functools import lru_cache
from typing import Any, Dict, Optional
Expand Down
1 change: 1 addition & 0 deletions src/rxn/utilities/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Module for modeling utilities."""
200 changes: 200 additions & 0 deletions src/rxn/utilities/modeling/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""Core modeling utilities."""

from collections import OrderedDict
from enum import Enum
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Optional, Union

import torch
from transformers import ( # type:ignore
AutoModel,
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
)
from transformers.modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
SequenceClassifierOutput,
)

from ..containers import chunker
from ..files import PathLike
from .tokenization import SmilesTokenizer
from .utils import device_claim, get_associated_max_len, map_dict_sequences_to_tensors


class ModelType(str, Enum):
regression = "regression"
classification = "classification"
mlm = "masked_language_modeling"
fingerprint = "fingerprint"


MODEL_TYPE_TO_MODEL_CLASS = OrderedDict(
[
(ModelType.regression, AutoModelForSequenceClassification),
(ModelType.classification, AutoModelForSequenceClassification),
(ModelType.mlm, AutoModelForMaskedLM),
(ModelType.fingerprint, AutoModel),
]
)
MODEL_TYPE_TO_KWARGS = {
ModelType.regression: {"num_labels": 1},
ModelType.classification: {},
ModelType.mlm: {},
ModelType.fingerprint: {},
}
RawModelOutput = Union[SequenceClassifierOutput, MaskedLMOutput, BaseModelOutput]
ModelOutput = Union[float, str, List[float]]


class RXNTransformersModelForReactions:
def __init__(
self,
model_name_or_path: PathLike,
model_type: ModelType,
tokenizer: SmilesTokenizer,
maximum_length: Optional[int] = None,
batch_size: Optional[int] = None,
device: Optional[Union[torch.device, str]] = None,
) -> None:
"""Construct a RXNTransformersModel.
Args:
model_name_or_path: model name or path.
model_type: model type.
tokenizer: a tokenizer for reactions.
maximum_length: maximum tokenized sequence length. If not specified,
will be determined from the model config file.
batch_size: number of reactions to predict per batch. Defaults to
having one single batch for all the reactions.
device: device where the inference is running either as a dedicated class or
a string. If not provided is inferred.
Raises:
ValueError: in case the model type is not supported.
"""
self.model_name_or_path = str(model_name_or_path)
self.model_type = model_type
if (
self.model_type in MODEL_TYPE_TO_MODEL_CLASS
and self.model_type in MODEL_TYPE_TO_KWARGS
):
self.model = MODEL_TYPE_TO_MODEL_CLASS[self.model_type].from_pretrained(
self.model_name_or_path, **MODEL_TYPE_TO_KWARGS[self.model_type]
)
else:
raise ValueError(
f"model_type={self.model_type} not supported! Select one from: {MODEL_TYPE_TO_MODEL_CLASS.keys()}"
)
self.tokenizer = tokenizer
if maximum_length is not None:
self.maximum_length = maximum_length
else:
self.maximum_length = get_associated_max_len(Path(model_name_or_path))
self.batch_size = batch_size
self.device = device_claim(device)
self.model.to(self.device)

def tokenize_batch(self, rxns: Iterable[str]) -> Dict[str, torch.Tensor]:
"""Tokenize a batch of reactions.
Args:
rxns: a list of reactions.
Returns:
a dictionary containing the token ids as well as token type ids and attention masks.
"""
return map_dict_sequences_to_tensors(
self.tokenizer.batch_encode_plus(
rxns,
add_special_tokens=True,
max_length=self.maximum_length,
padding="max_length",
truncation=True,
return_token_type_ids=True,
return_tensors="pt",
),
device=self.device,
)

def _postprocess_outputs(self, outputs: RawModelOutput) -> List[ModelOutput]:
"""Post-process raw model predictions to get the output.
Args:
outputs: raw model forward pass output.
Raises:
NotImplementedError: not implemented for base class RXNTransformersModelForReactions.
Returns:
the post-processed output.
"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _postprocess_outputs!"
)

def predict(self, rxns: Iterable[str]) -> Iterator[ModelOutput]:
"""Run the model on a list of examples.
Args:
rxns: a list of reactions.
Returns:
an iterator over predictions.
"""
if self.batch_size is None:
chunks: Iterable[List[str]] = [list(rxns)]
else:
chunks = chunker(rxns, self.batch_size)

for batch in chunks:
with torch.no_grad():
tokenized_batch = self.tokenize_batch(batch)
outputs = self.model(**tokenized_batch)
yield from self._postprocess_outputs(outputs)


class RXNFPModel(RXNTransformersModelForReactions):
def __init__(
self,
model_name_or_path: PathLike,
maximum_length: Optional[int] = None,
tokenizer: Optional[SmilesTokenizer] = None,
batch_size: int = 16,
device: Optional[Union[torch.device, str]] = None,
) -> None:
"""Construct a RXNFPModel.
Args:
model_name_or_path: model name or path.
maximum_length: maximum tokenized sequence length. If not specified,
will be determined from the model config file.
tokenizer: a tokenizer for reactions. Defaults to a SmilesTokenizer
loaded from the vocabulary in the model directory.
batch_size: batch size for prediction
device: device where the inference is running either as a dedicated class or
a string. If not provided is inferred.
"""
if tokenizer is None:
tokenizer = SmilesTokenizer.from_pretrained(model_name_or_path)

super().__init__(
model_name_or_path=model_name_or_path,
model_type=ModelType.fingerprint,
maximum_length=maximum_length,
tokenizer=tokenizer,
batch_size=batch_size,
device=device,
)

def _postprocess_outputs(self, outputs: RawModelOutput) -> List[ModelOutput]:
"""Post-process raw model predictions to get the output.
Args:
outputs: raw model forward pass output.
Returns:
the post-processed output.
"""
return outputs["last_hidden_state"][:, 0, :].tolist()
Loading

0 comments on commit 914a6c2

Please # to comment.