Commit 10ebc9e
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
1 parent f7e23fb · Showing 2 changed files with 76 additions and 0 deletions.
@@ -0,0 +1,74 @@
from typing import Optional

from torch import nn
from transformers import RobertaConfig

from vllm.config import CacheConfig
from vllm.model_executor.layers.pooler import Pooler, PoolingConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.models.bert import (BertEmbedding, BertEmbeddingModel,
                                             BertEncoder, BertModel)


class RobertaModel(BertModel):

    def __init__(
        self,
        config: RobertaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        # Skip BertModel.__init__()
        nn.Module.__init__(self)
        self.embeddings = RobertaEmbedding(config)
        self.encoder = BertEncoder(config, cache_config, quant_config)
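Note on the pattern used throughout this file: each subclass deliberately skips its parent's __init__() and calls nn.Module.__init__() directly, so it inherits the parent's forward() and weight-loading logic while registering its own submodules. A minimal sketch of the idea, with illustrative class names that are not part of this diff:

from torch import nn


class Base(nn.Module):

    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)  # replaced by the subclass

    def forward(self, x):
        return self.embeddings(x)  # inherited by the subclass unchanged


class Derived(Base):

    def __init__(self):
        # Skip Base.__init__() but still set up nn.Module state, then
        # register a different module under the same attribute name.
        nn.Module.__init__(self)
        self.embeddings = nn.Embedding(20, 4)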


class RobertaEmbedding(BertEmbedding):

    def __init__(self, config: RobertaConfig):
        # Skip BertEmbedding.__init__()
        nn.Module.__init__(self)
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      config.hidden_size)
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size,
                                                padding_idx=self.padding_idx)

        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type" +
                             " is supported")
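For context on the padding_idx handling above: in Hugging Face's RoBERTa, absolute position ids are offset past the padding index, and padding tokens are pinned to padding_idx itself, whose embedding row nn.Embedding keeps at zero. A sketch of that computation, adapted from the transformers implementation (not part of this diff):

import torch


def roberta_position_ids(input_ids: torch.Tensor,
                         padding_idx: int) -> torch.Tensor:
    # Non-padding tokens get positions padding_idx + 1, padding_idx + 2, ...
    # while padding tokens are all mapped to padding_idx.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices.long() + padding_idx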


class RobertaEmbeddingModel(BertEmbeddingModel):
    """A model that uses RoBERTa to provide embedding functionalities.

    This class encapsulates the RobertaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of RobertaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(self,
                 config: RobertaConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 pooling_config: Optional[PoolingConfig] = None) -> None:
        # Skip BertEmbeddingModel.__init__()
        nn.Module.__init__(self)
        self.model = RobertaModel(config, cache_config, quant_config)
        self._pooler = Pooler(pooling_config.pooling_type,
                              pooling_config.normalize)
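For reference, a sketch of how a RoBERTa-based embedding model might be exercised through vLLM's offline API once this lands; the checkpoint name here is illustrative and the exact invocation depends on the vLLM version:

from vllm import LLM

# Hypothetical checkpoint; any RoBERTa-based embedding model would do.
llm = LLM(model="sentence-transformers/all-roberta-large-v1")
outputs = llm.encode(["vLLM now supports RoBERTa embedding models."])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality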