From 274d95dab2f8795d51d8e21f5d7c7a0343f6bd84 Mon Sep 17 00:00:00 2001 From: jimpang Date: Tue, 21 Nov 2023 15:40:53 +0800 Subject: [PATCH] feat: support baichuan2 --- vllm/model_executor/model_loader.py | 10 ++ vllm/model_executor/models/__init__.py | 6 +- vllm/model_executor/models/baichuan.py | 148 ++++++++++++++++--------- 3 files changed, 113 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 54b87c4b866e3..90383bc0d1245 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -17,6 +17,8 @@ "AquilaForCausalLM": AquilaForCausalLM, # AquilaChat2 "BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b "BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b + "BaiChuan2ForCausalLM": BaiChuan2ForCausalLM, # baichuan2-rope + "Baichuan2ForCausalLM": Baichuan2ForCausalLM, # baichuan2-alibi "BloomForCausalLM": BloomForCausalLM, "ChatGLMModel": ChatGLMForCausalLM, "FalconForCausalLM": FalconForCausalLM, @@ -52,6 +54,14 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: architectures = getattr(config, "architectures", []) for arch in architectures: if arch in _MODEL_REGISTRY: + # baichuan 2 has different vocab size + if ("baichuan" in arch.lower()) and (getattr(config, "vocab_size") + == 125696): + # baichuan 2 7b and 13b have different intermediate size + if getattr(config, "intermediate_size") == 11008: + return BaiChuan2ForCausalLM + elif getattr(config, "intermediate_size") == 13696: + return Baichuan2ForCausalLM return _MODEL_REGISTRY[arch] raise ValueError( f"Model architectures {architectures} are not supported for now. " diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 078d3d74719df..9babb8e0cd8ef 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,6 +1,8 @@ from vllm.model_executor.models.aquila import AquilaForCausalLM from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM, - BaichuanForCausalLM) + BaichuanForCausalLM, + BaiChuan2ForCausalLM, + Baichuan2ForCausalLM) from vllm.model_executor.models.bloom import BloomForCausalLM from vllm.model_executor.models.falcon import FalconForCausalLM from vllm.model_executor.models.gpt2 import GPT2LMHeadModel @@ -21,6 +23,8 @@ "AquilaForCausalLM", "BaiChuanForCausalLM", "BaichuanForCausalLM", + "BaiChuan2ForCausalLM", + "Baichuan2ForCausalLM", "BloomForCausalLM", "ChatGLMForCausalLM", "FalconForCausalLM", diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 93cbc1a8516a7..739d2fb86df9d 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -28,18 +28,17 @@ import torch from torch import nn +from vllm.logger import init_logger from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE, - PagedAttentionWithALiBi) +from vllm.model_executor.layers.attention import (PagedAttentionWithALiBi, PagedAttentionWithRoPE) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) + QKVParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.layers.vocab_parallel_embedding import (ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.weight_utils import (default_weight_loader, @@ -47,13 +46,15 @@ from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.baichuan import BaiChuanConfig +logger = init_logger(__name__) + KVCache = Tuple[torch.Tensor, torch.Tensor] def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) @@ -61,7 +62,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, ) num_remaining_heads = min(closest_power_of_2, @@ -78,11 +79,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: class BaiChuanMLP(nn.Module): def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -109,13 +110,13 @@ class BaiChuanAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( - self, - hidden_size: int, - num_heads: int, - position_embedding: str, - rope_theta: float = 10000, - max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + self, + hidden_size: int, + num_heads: int, + position_embedding: str, + rope_theta: float = 10000, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.hidden_size = hidden_size @@ -153,11 +154,11 @@ def __init__( alibi_slopes = _get_alibi_slopes(self.total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end].tolist() - scaling = self.head_dim**-0.5 + scaling = self.head_dim ** -0.5 self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim, scaling, alibi_slopes) else: - self.scaling = self.head_dim**-0.5 + self.scaling = self.head_dim ** -0.5 self.attn = PagedAttentionWithRoPE( self.num_heads, self.head_dim, @@ -167,12 +168,12 @@ def __init__( max_position=self.max_position_embeddings) def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.W_pack(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) @@ -219,13 +220,13 @@ def __init__(self, eps=config.rms_norm_eps) def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: KVCache, - input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], - residual: Optional[torch.Tensor], + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -271,12 +272,12 @@ def __init__(self, self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None @@ -295,29 +296,62 @@ def forward( return hidden_states +class NormHead(ColumnParallelLinear): + + def __init__(self, hidden_size, vocab_size, bias=False): + super().__init__(hidden_size, + vocab_size, + bias=False, + gather_output=False) + self.first_flag = True + + def get_weight(self): + if self.first_flag: + self.first_flag = False + self.weight = nn.Parameter(nn.functional.normalize(self.weight)) + return self.weight + + def forward(self, hidden_states): + if self.first_flag: + self.first_flag = False + self.weight = nn.Parameter(nn.functional.normalize(self.weight)) + return ColumnParallelLinear.forward(self, hidden_states) + + class BaiChuanBaseForCausalLM(nn.Module): def __init__(self, config, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + linear_method: Optional[LinearMethodBase] = None, version: str = "1"): super().__init__() self.config = config self.linear_method = linear_method self.model = BaiChuanModel(config, position_embedding, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.version = version + if self.version == "1": + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + elif self.version == "2": + self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False) + else: + raise ValueError("Only support baichuan version 1 and 2") + self.sampler = Sampler(config.vocab_size) def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], ) -> SamplerOutput: hidden_states = self.model(input_ids, positions, kv_caches, input_metadata, cache_events) + + lm_head_weight = self.lm_head.weight + if self.version == "2": + lm_head_weight = self.lm_head.get_weight() next_tokens = self.sampler(self.lm_head.weight, hidden_states, input_metadata) return next_tokens @@ -365,3 +399,17 @@ def __init__(self, config, linear_method: Optional[LinearMethodBase] = None): super().__init__(config, "ROPE", linear_method) + + +class Baichuan2ForCausalLM(BaiChuanBaseForCausalLM): # baichuan2 13b + + def __init__(self, config, linear_method: Optional[LinearMethodBase] = None): + logger.info("start init Baichuan2ForCausalLM for 13B version") + super().__init__(config, "ALIBI", linear_method, "2") + + +class BaiChuan2ForCausalLM(BaiChuanBaseForCausalLM): # baichuan2 7b + + def __init__(self, config, linear_method: Optional[LinearMethodBase] = None): + logger.info("start init Baichuan2ForCausalLM for 7B version") + super().__init__(config, "ROPE", linear_method, "2")