diff --git a/README.md b/README.md index c48ddcfa0a79a..e0954f6cb329f 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) +- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 5d7f401cc6e2c..8bc747770e098 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -62,6 +62,9 @@ Alongside each architecture, we include some popular models that use it. * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. + * - :code:`OLMoForCausalLM` + - OLMo + - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc. * - :code:`OPTForCausalLM` - OPT, OPT-IML - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 40858a517b311..e44452e9893cf 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -5,11 +5,20 @@ import pytest MODELS = [ - "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", - "mistralai/Mistral-7B-v0.1", "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2", - "bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b", - "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b", - "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t" + "facebook/opt-125m", + "meta-llama/Llama-2-7b-hf", + "mistralai/Mistral-7B-v0.1", + "Deci/DeciLM-7b", + "tiiuae/falcon-7b", + "gpt2", + "bigcode/tiny_starcoder_py", + "EleutherAI/gpt-j-6b", + "EleutherAI/pythia-70m", + "bigscience/bloom-560m", + "mosaicml/mpt-7b", + "microsoft/phi-2", + "stabilityai/stablelm-3b-4e1t", + "allenai/OLMo-1B", ] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 5cba1cf0414db..0f6a4bd9a4ad6 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -35,6 +35,7 @@ # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py new file mode 100644 index 0000000000000..2eb42935e8bfd --- /dev/null +++ b/vllm/model_executor/models/olmo.py @@ -0,0 +1,378 @@ +# coding=utf-8 +# Adapted from +# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and +# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py +# Copyright 2023 The vLLM team. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# +# BSD 3-Clause License +# +# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, ) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.olmo import OLMoConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class SwiGLU(nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + @property + def output_multiplier(self) -> float: + return 0.5 + + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). 
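    In this implementation that corresponds to a non-affine ``attn_norm``, a fused ``att_proj`` QKV
    projection, optional rotary embeddings (when ``config.rope`` is set), vLLM's ``PagedAttention``
    kernel, and the ``attn_out`` output projection.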
+ """ + + def __init__( + self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.d_model + assert config.d_model % config.n_heads == 0 + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = self.config.n_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // self.total_num_heads + + # Layer norms. + self.attn_norm = nn.LayerNorm(config.d_model, + elementwise_affine=False, + bias=False) + # Attention input projection. Projects x -> (q, k, v) + self.att_proj = QKVParallelLinear( + config.d_model, + self.head_dim, + self.total_num_heads, + bias=config.include_bias, + linear_method=linear_method, + ) + + # Rotary embeddings. + if self.config.rope: + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scaling) + + # Attention output projection. + self.attn_out = RowParallelLinear( + config.d_model, + config.d_model, + bias=config.include_bias, + linear_method=linear_method, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.attn_norm(hidden_states) + qkv, _ = self.att_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + if self.config.rope: + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.attn_out(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size + is not None else config.mlp_ratio * config.d_model) + + # Layer norms. + self.ff_norm = nn.LayerNorm(config.d_model, + elementwise_affine=False, + bias=False) + + # Feed-forward input projection. + self.ff_proj = ColumnParallelLinear( + config.d_model, + self.hidden_size, + bias=config.include_bias, + linear_method=linear_method, + ) + + # Activation function. + # self.act = SiluAndMul() + # self.act.output_multiplier = 0.5 + self.act = SwiGLU() + assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + + # Feed-forward output projection. + self.ff_out = RowParallelLinear( + int(self.act.output_multiplier * self.hidden_size), + config.d_model, + bias=config.include_bias, + linear_method=linear_method, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + # Add feed-forward projection. 
+ # shape: (batch_size, seq_len, d_model) + og_x = x + x = self.ff_norm(x) + x, _ = self.ff_proj(x) + x = self.act(x) + x, _ = self.ff_out(x) + x = og_x + x + + return x + + +class OlmoBlock(nn.Module): + """ + This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + # Attention block. + self.attn = OlmoAttention(config, linear_method) + + # MLP block. + self.mlp = OlmoMLP(config, linear_method) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + og_x = hidden_states + x = self.attn(positions, hidden_states, kv_cache, input_metadata) + x = x + og_x + + # MLP block. + hidden_states = self.mlp(x) + return hidden_states + + +class OlmoModel(nn.Module): + + def __init__(self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + + self.transformer = nn.ModuleDict( + dict( + wte=VocabParallelEmbedding( + config.embedding_size or config.vocab_size, + config.d_model, + ), + ln_f=nn.LayerNorm(config.d_model, + elementwise_affine=False, + bias=False), + )) + + blocks = [ + OlmoBlock(config, linear_method) for i in range(config.n_layers) + ] + if self.config.block_group_size > 1: + raise NotImplementedError("Block group size > 1 not supported yet") + else: + self.transformer.update({"blocks": nn.ModuleList(blocks)}) + + if not config.weight_tying: + self.transformer.update({ + "ff_out": + ColumnParallelLinear( + config.d_model, + config.embedding_size or config.vocab_size, + bias=config.include_bias, + linear_method=linear_method, + ) + }) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.wte(input_ids) # type: ignore + + # Apply blocks one-by-one. + for block_idx, block in enumerate(self.transformer.blocks): + # shape: (batch_size, seq_len, d_model) + x = block( + positions, + x, + kv_caches[block_idx], + input_metadata, + ) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + x = self.transformer.ln_f(x) # type: ignore + return x + + +class OLMoForCausalLM(nn.Module): + """ + Extremely barebones HF model wrapper. 
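    When ``weight_tying`` is enabled, the logits weight is the shared ``wte`` embedding matrix;
    otherwise the untied ``transformer.ff_out`` weight is used.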
+ """ + + def __init__(self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = OlmoModel(config, linear_method) + self.lm_head_weight = (self.model.transformer.wte.weight + if config.weight_tying else + self.model.transformer.ff_out.weight) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + ) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + # attention + if ".att" in name: + name = name.replace(".att", ".attn.att") + # mlp + if ".ff" in name and "transformer.ff_out" not in name: + name = name.replace(".ff", ".mlp.ff") + # there is no bias in olmo + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index bbba741ca536a..47bcc2b9594be 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,6 +1,7 @@ from vllm.transformers_utils.configs.baichuan import BaiChuanConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.mpt import MPTConfig +from vllm.transformers_utils.configs.olmo import OLMoConfig from vllm.transformers_utils.configs.qwen import QWenConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the @@ -11,6 +12,7 @@ "BaiChuanConfig", "ChatGLMConfig", "MPTConfig", + "OLMoConfig", "QWenConfig", "RWConfig", ] diff --git a/vllm/transformers_utils/configs/olmo.py b/vllm/transformers_utils/configs/olmo.py new file mode 100644 index 0000000000000..a9dfc6ec88ca6 --- /dev/null +++ b/vllm/transformers_utils/configs/olmo.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# adapted from https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/configuration_olmo.py +"""OLMo configuration""" +from transformers import PretrainedConfig + + +class OLMoConfig(PretrainedConfig): + model_type = 'olmo' + attribute_map = { + 'num_attention_heads': 'n_heads', + 'hidden_size': 'd_model', + 'num_hidden_layers': 'n_layers', + } + + # Note that the defaults for these attributes are equivalent to the base GPT2 model. 
+ def __init__( + self, + d_model=768, + n_heads=12, + n_layers=12, + mlp_ratio=4, + mlp_hidden_size=None, + activation_type="swiglu", + block_type="sequential", + block_group_size=1, + alibi=False, + alibi_bias_max=8.0, + rope=False, + rope_full_precision=True, + multi_query_attention=False, + attention_layer_norm=False, + layer_norm_type="default", + layer_norm_with_affine=True, + attention_layer_norm_with_affine=True, + max_sequence_length=1024, + include_bias=True, + bias_for_layer_norm=None, + scale_logits=False, + vocab_size=50257, + embedding_size=50304, + weight_tying=True, + eos_token_id=50256, + pad_token_id=50256, + **kwargs, + ): + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.mlp_ratio = mlp_ratio + self.mlp_hidden_size = mlp_hidden_size + self.activation_type = activation_type + self.block_type = block_type + self.block_group_size = block_group_size + self.alibi = alibi + self.alibi_bias_max = alibi_bias_max + self.rope = rope + self.rope_full_precision = rope_full_precision + self.multi_query_attention = multi_query_attention + self.attention_layer_norm = attention_layer_norm + self.layer_norm_type = layer_norm_type + self.layer_norm_with_affine = layer_norm_with_affine + self.attention_layer_norm_with_affine = attention_layer_norm_with_affine + self.max_sequence_length = max_sequence_length + self.include_bias = include_bias + self.bias_for_layer_norm = bias_for_layer_norm + self.scale_logits = scale_logits + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.weight_tying = weight_tying + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + super().__init__(**kwargs)
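
A quick way to exercise the new model path, beyond the entry added to tests/models/test_models.py, is a small generation script. This is only an illustrative sketch, not part of the diff: the prompt and sampling settings are arbitrary, and trust_remote_code is passed because the OLMo checkpoints referenced here ship a custom Hugging Face config.

from vllm import LLM, SamplingParams

# Load one of the newly supported OLMo checkpoints and run a single
# greedy generation to exercise OLMoForCausalLM end to end.
llm = LLM(model="allenai/OLMo-1B", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)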