@@ -673,6 +673,44 @@ def set_gguf_parameters(self):
673
673
self .gguf_writer .add_parallel_residual (self .hparams .get ("use_parallel_residual" , True ))
674
674
self .gguf_writer .add_layer_norm_eps (self .hparams ["layer_norm_eps" ])
675
675
676
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
    """Map one checkpoint tensor to its GGUF name, re-packing fused QKV tensors.

    GPT-NeoX stores query_key_value with the heads interleaved
    (bloom-style: [n_head, 3, head_dim, ...]); GGUF expects the
    gpt-2 layout with Q, K and V each contiguous. For the fused
    attention weight/bias this reorders the rows accordingly; every
    other tensor passes through unchanged.

    bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
    gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
    """
    del bid  # unused

    head_count = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
    embed_dim = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

    # Both fused-QKV tensors share this name prefix; only the suffix differs.
    qkv_prefix = r"gpt_neox\.layers\.\d+\.attention\.query_key_value\."

    if re.match(qkv_prefix + r"weight", name):
        # [3 * n_embed, n_embed] -> [n_head, 3, head_dim, n_embed],
        # then stack the Q, K and V planes contiguously along dim 0.
        packed = data_torch.reshape((head_count, 3, embed_dim // head_count, embed_dim))
        planes = [packed[:, part, :, :].reshape((-1, embed_dim)) for part in range(3)]
        data_torch = torch.cat(planes, dim=0)
        logger.info("re-format attention.linear_qkv.weight")
    elif re.match(qkv_prefix + r"bias", name):
        # Same reordering for the 1-D bias: [3 * n_embed] -> Q|K|V contiguous.
        packed = data_torch.reshape((head_count, 3, embed_dim // head_count))
        planes = [packed[:, part, :].reshape((embed_dim,)) for part in range(3)]
        data_torch = torch.cat(planes, dim=0)
        logger.info("re-format attention.linear_qkv.bias")

    return [(self.map_tensor_name(name), data_torch)]
713
+
676
714
677
715
@Model .register ("BloomForCausalLM" )
678
716
class BloomModel (Model ):
0 commit comments