[LLM INFER] Fix some bugs and chatglm_v2 support block_attn (#9271)
* chatglm2 support block_attn and fix some bugs
* fix ci
* fix more ut error
* update
yuanlehome authored Oct 25, 2024
1 parent b237ba7 commit 2e8b220
Showing 5 changed files with 265 additions and 138 deletions.
56 changes: 31 additions & 25 deletions llm/predict/predictor.py
@@ -134,16 +134,16 @@ class PredictorArgument:
},
)

-    @property
-    def total_max_length(self):
-        if self.device == "npu":
-            return self.src_length + self.max_length
-        else:
-            return 8192  # Maximum sequence length.
+    total_max_length: int = field(
+        default=4096, metadata={"help": "Super parameter. Maximum sequence length (encoder+decoder)."}
+    )

     def __post_init__(self):
         if self.append_attn:
             self.block_attn = True
+        assert (
+            self.src_length + self.max_length <= self.total_max_length
+        ), "src_length + max_length should not be larger than total_max_length."


@dataclass
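
With this hunk, total_max_length stops being a derived property (8192, or src_length + max_length on NPU) and becomes an explicit argument with a 4096 default, enforced by the new assertion. A minimal sketch of the resulting contract, using illustrative values rather than any real model config:

    from dataclasses import dataclass

    # Stand-in for PredictorArgument's length budget; the field names mirror
    # the diff above, the default values are hypothetical.
    @dataclass
    class LengthBudget:
        src_length: int = 3072        # prompt tokens
        max_length: int = 1024        # tokens to generate
        total_max_length: int = 4096  # encoder + decoder budget

        def __post_init__(self):
            assert (
                self.src_length + self.max_length <= self.total_max_length
            ), "src_length + max_length should not be larger than total_max_length."

    LengthBudget()                   # ok: 3072 + 1024 <= 4096
    # LengthBudget(src_length=4000)  # AssertionError: 4000 + 1024 > 4096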
@@ -520,7 +520,7 @@ def _preprocess(self, source):
alibi_slopes = llm_utils.get_alibi_slopes(self.model_config.n_head)
inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32")
arange_tensor_encoder = paddle.arange(self.config.total_max_length, dtype=self.config.dtype)
-            alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder
+            alibi = (alibi_slopes[None, :, None, None] * arange_tensor_encoder).astype(self.config.dtype)

if self.model_config.tensor_parallel_degree > 1:
block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree
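
The cast added here keeps the ALiBi bias in the configured dtype. For orientation, the bias itself is just per-head slopes multiplied by position indices; a self-contained sketch with the standard power-of-two slope schedule (the real helper is llm_utils.get_alibi_slopes; this stand-in assumes n_head is a power of two):

    import math
    import numpy as np

    def alibi_slopes(n_head):  # illustrative stand-in, assumes n_head is a power of two
        start = 2.0 ** (-(2.0 ** -(math.log2(n_head) - 3.0)))
        return np.array([start ** (i + 1) for i in range(n_head)], dtype="float32")

    n_head, total_max_length = 8, 16
    positions = np.arange(total_max_length, dtype="float32")
    # Broadcasts to (1, n_head, 1, total_max_length): one linear penalty per head.
    alibi = alibi_slopes(n_head)[None, :, None, None] * positions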
@@ -1352,13 +1352,19 @@ def create_predictor(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()

elif "chatglmv2forcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel as Model,
)

model = Model.from_pretrained(
predictor_args.total_max_length = config.seq_length
if predictor_args.block_attn:
config.block_size = predictor_args.block_size
config.max_seq_len = predictor_args.total_max_length
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMBlockInferenceModel as ChatGLMv2InferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel as ChatGLMv2InferenceModel,
)
model = ChatGLMv2InferenceModel.from_pretrained(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()
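
When block_attn is enabled, the predictor switches to ChatGLMv2ForCausalLMBlockInferenceModel, which manages the KV cache as fixed-size blocks addressed through per-sequence block tables instead of one contiguous buffer sized to total_max_length per sequence. A toy sketch of that bookkeeping, assuming nothing about the actual PaddleNLP kernels:

    # Toy block-table bookkeeping for a paged ("block") KV cache; illustrative only.
    block_size = 64
    free_blocks = list(range(512))  # global pool of cache blocks

    def blocks_needed(seq_len):
        return (seq_len + block_size - 1) // block_size

    def allocate(seq_len):
        """Grab a block table for one sequence from the free pool."""
        return [free_blocks.pop() for _ in range(blocks_needed(seq_len))]

    table = allocate(200)  # 200 tokens -> 4 blocks
    # Token position p of this sequence lives in cache block table[p // block_size],
    # at row p % block_size, so memory grows with actual length, not the budget.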
@@ -1522,19 +1528,19 @@ def create_predictor(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "chatglmv2forcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel,
)

cache_kvs_shape = ChatGLMv2ForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "chatglmv2forcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel,
)
predictor_args.total_max_length = config.seq_length
if predictor_args.block_attn:
config.block_size = predictor_args.block_size
config.max_seq_len = predictor_args.total_max_length
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMBlockInferenceModel as ChatGLMv2InferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel as ChatGLMv2InferenceModel,
)

cache_kvs_shape = ChatGLMv2ForCausalLMInferenceModel.get_cache_kvs_shape(
cache_kvs_shape = ChatGLMv2InferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "chatglmforcausallm" in config.architectures[0].lower():
208 changes: 198 additions & 10 deletions paddlenlp/experimental/transformers/chatglm_v2/modeling.py
@@ -21,13 +21,17 @@
from paddle.nn.quant import weight_quantize

 from paddlenlp.experimental.transformers.fused_transformer_layers import (
+    FusedBlockMultiTransformer,
+    FusedBlockMultiTransformerWeightOnly,
     FusedMultiTransformerBase,
     FusedMultiTransformerConfig,
     FusedMultiTransformerWeightOnly,
 )
 from paddlenlp.experimental.transformers.generation_utils import (
+    GenerationBlockInferenceModel,
     GenerationInferenceModel,
 )
+from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained
from paddlenlp.transformers import ChatGLMv2Config, ChatGLMv2PretrainedModel
from paddlenlp.transformers.chatglm_v2.modeling import (
Embedding,
@@ -39,9 +43,7 @@
register_base_model,
)

-__all__ = [
-    "ChatGLMv2ForCausalLMInferenceModel",
-]
+__all__ = ["ChatGLMv2ForCausalLMInferenceModel", "ChatGLMv2ForCausalLMBlockInferenceModel"]


@register_base_model
@@ -176,17 +178,20 @@ def __init__(self, config: ChatGLMv2Config, empty_init=True):
kv_num_heads=config.multi_query_group_num,
)

-        if self.use_weight_only:
-            self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config)
-        else:
-            self.transformer_block = FusedMultiTransformerBase(transformer_config)
+        self.set_transformer_block(transformer_config)

         self.post_layer_norm = config.post_layer_norm
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else nn.LayerNorm
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, epsilon=config.layernorm_epsilon, config=config)

+    def set_transformer_block(self, transformer_config):
+        if self.use_weight_only:
+            self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config)
+        else:
+            self.transformer_block = FusedMultiTransformerBase(transformer_config)

def get_input_embeddings(self):
return self.embedding.word_embeddings
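
The hunk above factors the inline if/else into a set_transformer_block hook, so the block-attention subclass added later in this file can change which fused transformer gets built by overriding a single method instead of duplicating __init__. Reduced to a pattern sketch, with the real construction elided:

    # Sketch of the override hook; class names follow the diff, bodies are placeholders.
    class BaseModel:
        def __init__(self, transformer_config):
            self.set_transformer_block(transformer_config)

        def set_transformer_block(self, transformer_config):
            self.transformer_block = ("FusedMultiTransformerBase", transformer_config)

    class BlockModel(BaseModel):
        def set_transformer_block(self, transformer_config):
            self.transformer_block = ("FusedBlockMultiTransformer", transformer_config)

    assert BlockModel({}).transformer_block[0] == "FusedBlockMultiTransformer"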

@@ -341,7 +346,7 @@ def key(name):

if self.use_weight_only:
linear_quanted_weight_tensor, linear_weight_scale_tensor = weight_quantize(
-                out_proj_weight, algo=self.quant_algo
+                paddle.to_tensor(out_proj_weight), algo=self.quant_algo
)
self.transformer_block.linear_weights[i].set_value(linear_quanted_weight_tensor)
self.transformer_block.linear_weights_scale[i].set_value(linear_weight_scale_tensor)
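
This hunk and the two that follow apply the same one-line fix: the raw state-dict value is wrapped in paddle.to_tensor before quantization, likely because loaded weights can arrive as numpy arrays while weight_quantize expects a live paddle.Tensor. A rough usage sketch, assuming the weight_only_int8 algorithm string and a GPU build (both assumptions, not taken from this diff):

    import numpy as np
    import paddle
    from paddle.nn.quant import weight_quantize

    w = np.random.randn(4096, 4096).astype("float16")  # e.g. a numpy weight from a state dict
    qw, scale = weight_quantize(paddle.to_tensor(w), algo="weight_only_int8")
    # qw holds the int8-packed weights, scale the per-channel dequantization scales.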
@@ -352,7 +357,7 @@ def key(name):

if self.use_weight_only:
ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
-                ffn1_weight, algo=self.quant_algo
+                paddle.to_tensor(ffn1_weight), algo=self.quant_algo
)
self.transformer_block.ffn1_weights[i].set_value(ffn1_quanted_weight_tensor)
self.transformer_block.ffn1_weights_scale[i].set_value(ffn1_weight_scale_tensor)
@@ -361,20 +366,87 @@

if self.use_weight_only:
ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize(
-                ffn2_weight, algo=self.quant_algo
+                paddle.to_tensor(ffn2_weight), algo=self.quant_algo
)
self.transformer_block.ffn2_weights[i].set_value(ffn2_quanted_weight_tensor)
self.transformer_block.ffn2_weights_scale[i].set_value(ffn2_weight_scale_tensor)
else:
self.transformer_block.ffn2_weights[i].set_value(ffn2_weight)


@register_base_model
class ChatGLMv2BlockInferenceModel(ChatGLMv2InferenceModel):
def __init__(self, config: ChatGLMv2Config):
super().__init__(config)
self.max_seq_len = config.max_sequence_length
self.block_size = config.block_size

def set_transformer_block(self, transformer_config):
if self.use_weight_only:
self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config)
else:
self.transformer_block = FusedBlockMultiTransformer(transformer_config)

def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset_v2

ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k
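        # (Illustrative aside, not from the commit) get_padding_offset_v2 is a custom
        # paddlenlp_ops kernel: conceptually it packs the padded [batch, max_seq_len]
        # input_ids into a flat [token_num] tensor and returns the offset arrays needed
        # to map between the two layouts. E.g. with seq_lens_this_time = [2, 3] and
        # max_seq_len = 4, ids_remove_padding holds the 5 real tokens and
        # cu_seqlens_q == [0, 2, 5].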

def forward(
self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
caches=None,
pre_caches=None,
output_attentions=False,
output_hidden_states=None,
return_dict=False,
**kwargs,
):
seq_lens_this_time = kwargs.get("seq_lens_this_time", None)
rope_emb = kwargs.get("rope_emb", None)
ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding(
input_ids, seq_lens_this_time
)
kwargs["cu_seqlens_q"] = cu_seqlens_q
kwargs["cu_seqlens_k"] = cu_seqlens_k
kwargs["padding_offsets"] = padding_offset
kwargs["max_input_length"] = self.max_seq_len

inputs_embeds = self.embedding.word_embeddings(ids_remove_padding)

with dy2st_nocheck_guard_context():
hidden_states, _ = self.transformer_block(
input_ids=input_ids,
src=inputs_embeds,
cum_offsets=cum_offsets,
attn_mask=attention_mask,
caches=caches,
pre_caches=None,
rotary_embs=rope_emb,
**kwargs,
)
hidden_states = self.final_layernorm(hidden_states)

return tuple(v for v in [hidden_states, None, None, None] if v is not None)


class ChatGLMv2ForCausalLMInferenceModel(GenerationInferenceModel, ChatGLMv2PretrainedModel):
def __init__(self, config: ChatGLMv2Config):
super().__init__(config)
self.max_sequence_length = config.max_sequence_length
self.chatglm_v2 = ChatGLMv2InferenceModel(config)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)

@classmethod
def get_cache_kvs_shape(cls, config: ChatGLMv2Config, max_batch_size: int = None, max_length: int = None):
"""get cache_kvs tensor for opt model
Expand Down Expand Up @@ -487,3 +559,119 @@ def forward(
@paddle.no_grad()
def set_state_dict(self, state_dict):
self.chatglm_v2.set_state_dict(state_dict)


class ChatGLMv2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, ChatGLMv2PretrainedModel):
def __init__(self, config):
super().__init__(config)
self.chatglm_v2 = ChatGLMv2BlockInferenceModel(config)
self.max_sequence_length = config.max_sequence_length

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)

@classmethod
def get_cache_kvs_shape(cls, config: ChatGLMv2Config, max_batch_size: int = None, max_length: int = None):
"""get cache_kvs tensor for chatglmv2 model
Args:
max_batch_size (int): the max batch size
max_length (int | None, optional): the max_length of cache_kvs. Defaults to None.
Returns:
list[list[int]]: the shape of each cache_kv tensor
"""
max_block_per_seq = (config.max_seq_len + config.block_size - 1) // config.block_size
if max_batch_size == -1:
max_block_nums = None
else:
max_block_nums = max_batch_size * max_block_per_seq

cache_kvs = []
for _ in range(config.num_hidden_layers):
cache_kv_shape = [
max_block_nums,
config.multi_query_group_num,
config.block_size,
config.hidden_size // config.num_attention_heads,
]
cache_kvs.append(cache_kv_shape)
cache_kvs.append(cache_kv_shape)
return cache_kvs
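        # (Illustrative aside, not from the commit) for a ChatGLM2-6B-like config
        # (multi_query_group_num=2, hidden_size=4096, num_attention_heads=32, so
        # head_dim=128) with block_size=64, max_seq_len=4096 and max_batch_size=8:
        #   max_block_per_seq = (4096 + 63) // 64 = 64
        #   max_block_nums    = 8 * 64 = 512
        # giving per-layer K and V caches of shape [512, 2, 64, 128].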

def prepare_inputs_for_generation(self, **kwargs):
# only last token for inputs_ids if cache is defined in kwargs
input_ids = kwargs["input_ids"]
src_mask = kwargs.get("src_mask", None)
block_tables = kwargs.get("block_tables", None)

pre_caches = kwargs.get("pre_caches", None)
caches = kwargs.get("caches", None)

rope_emb = kwargs["rope_emb"]
seq_lens_this_time = kwargs["seq_lens_this_time"]
seq_lens_encoder = kwargs["seq_lens_encoder"]
seq_lens_decoder = kwargs["seq_lens_decoder"]
k_quant_scales = kwargs.get("k_quant_scales", None)
v_quant_scales = kwargs.get("v_quant_scales", None)
k_dequant_scales = kwargs.get("k_dequant_scales", None)
v_dequant_scales = kwargs.get("v_dequant_scales", None)
model_inputs = {
"input_ids": input_ids,
"src_mask": src_mask,
"rope_emb": rope_emb,
"pre_caches": pre_caches,
"caches": caches,
"seq_lens_this_time": seq_lens_this_time,
"seq_lens_encoder": seq_lens_encoder,
"seq_lens_decoder": seq_lens_decoder,
"block_tables": block_tables,
"k_quant_scales": k_quant_scales,
"v_quant_scales": v_quant_scales,
"k_dequant_scales": k_dequant_scales,
"v_dequant_scales": v_dequant_scales,
}
return model_inputs

def forward(
self,
input_ids,
src_mask=None,
pre_caches=None,
caches=None,
seq_lens_this_time=None,
seq_lens_encoder=None,
seq_lens_decoder=None,
rope_emb=None,
block_tables=None,
k_quant_scales=None,
v_quant_scales=None,
k_dequant_scales=None,
v_dequant_scales=None,
):
outputs = self.chatglm_v2(
input_ids,
src_mask=src_mask,
caches=caches,
rope_emb=rope_emb,
block_tables=block_tables,
pre_caches=pre_caches,
seq_lens_this_time=seq_lens_this_time,
seq_lens_encoder=seq_lens_encoder,
seq_lens_decoder=seq_lens_decoder,
k_quant_scales=k_quant_scales,
v_quant_scales=v_quant_scales,
k_dequant_scales=k_dequant_scales,
v_dequant_scales=v_dequant_scales,
)

hidden_states = outputs[0]
lm_logits = self.chatglm_v2.output_layer(hidden_states)
output = (lm_logits,) + outputs[1:]

return output

@paddle.no_grad()
def set_state_dict(self, state_dict):
self.chatglm_v2.set_state_dict(state_dict)
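
On the caller side, the shapes returned by get_cache_kvs_shape are materialized into the actual cache pool before the first forward pass. A minimal sketch of that step, assuming config is the loaded ChatGLMv2Config and float16 inference (the predictor wires this up internally, so names here are illustrative):

    import paddle

    shapes = ChatGLMv2ForCausalLMBlockInferenceModel.get_cache_kvs_shape(config, max_batch_size=8)
    cache_kvs = [paddle.zeros(shape, dtype="float16") for shape in shapes]  # one K and one V per layer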
