-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathmodelutils_mixtral.py
149 lines (124 loc) · 5.67 KB
/
modelutils_mixtral.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import gc
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
from qLinearLayer import find_qlinear_layers
from qMixtralLayer import QMixtralDecoderLayer
from gptq import GPTQ, Quantizer_GPTQ
from functools import partial
from quant import quantize_activation_wrapper, quantize_attn_v_wrapper, quantize_attn_k_wrapper
def reorder_model_mixtral(model, device, args, reorder_index):
    """Reorder the channels of every Mixtral decoder layer in-place.

    Each ``MixtralDecoderLayer`` is wrapped in a ``QMixtralDecoderLayer``
    (already-wrapped layers are reused), then the attention projections and
    MoE expert weights are permuted according to precomputed reorder indices.
    Layers are moved to ``device`` one at a time and returned to CPU after
    processing to bound GPU memory usage.

    Args:
        model: A (HF transformers) Mixtral model; ``model.model.layers`` is
            mutated in place.
        device: Device used while reordering each layer.
        args: Quantization configuration forwarded to ``QMixtralDecoderLayer``.
        reorder_index: Dict mapping dotted module names
            (e.g. ``layers.10.self_attn.k_proj.input``) to index tensors.

    Returns:
        The same ``model`` with its layers wrapped and reordered.
    """
    model.config.use_cache = False
    layers = model.model.layers

    assert reorder_index is not None, "Reorder index is None"

    for i in tqdm(range(len(layers))):
        layers[i] = layers[i].to(device)

        if isinstance(layers[i], MixtralDecoderLayer):
            m = QMixtralDecoderLayer(
                originalLayer=layers[i],
                args=args,
            )
        elif isinstance(layers[i], QMixtralDecoderLayer):
            m = layers[i]
        else:
            # Unexpected layer type: skip it. Without this guard `m` is
            # unbound (the trailing `del m` removes it each iteration) and the
            # next access raises NameError; this also matches the guard in
            # add_act_quant_wrapper_mixtral / quantize_model_mixtral.
            continue

        # ---- Attention reordering ----
        nameTemplate = 'layers.{}.{}.{}.{}'  # Something like layers.10.self_attn.q_proj
        # Random choose one from k,q,v proj — q/k/v share one input, so one
        # index serves all three. Hoisted: the same lookup was repeated below.
        attn_in_index = reorder_index[nameTemplate.format(i, 'self_attn', 'k_proj', 'input')]
        o_proj_in_index = reorder_index[nameTemplate.format(i, 'self_attn', 'o_proj', 'input')]

        m.input_layernorm.register_buffer('reorder_index', attn_in_index)
        # K has outlier should be kept.
        # Not reorder due to the RoPE embedding.
        m.self_attn.q_proj.reorder(
            in_reorder_index=attn_in_index,
            out_reorder_index=None
        )
        m.self_attn.k_proj.reorder(
            in_reorder_index=attn_in_index,
            out_reorder_index=None
        )
        m.self_attn.v_proj.reorder(
            in_reorder_index=attn_in_index,
            out_reorder_index=None
        )
        m.self_attn.o_proj.reorder(
            in_reorder_index=o_proj_in_index,
            out_reorder_index=None
        )
        m.self_attn.register_buffer('reorder_index', o_proj_in_index)

        # ---- MoE reordering ----
        nameTemplate_moe = 'layers.{}.{}.{}.{}.{}.{}'  # Something like layers.10.block_sparse_moe.experts.1.w1
        # Pick expert 0's order and reorder all related modules with it, so
        # the gate, every expert, and the post-attention layernorm agree on
        # one shared permutation.
        moe_w1_in_index = reorder_index[nameTemplate_moe.format(i, 'block_sparse_moe', 'experts', 0, 'w1', 'input')]
        moe_w2_in_index = reorder_index[nameTemplate_moe.format(i, 'block_sparse_moe', 'experts', 0, 'w2', 'input')]

        m.block_sparse_moe.gate.reorder(
            in_reorder_index=moe_w1_in_index,
            out_reorder_index=None
        )
        num_experts = m.block_sparse_moe.num_experts
        for j in range(num_experts):
            m.block_sparse_moe.experts[j].w1.reorder(
                in_reorder_index=moe_w1_in_index,
                out_reorder_index=moe_w2_in_index
            )
            m.block_sparse_moe.experts[j].w3.reorder(
                in_reorder_index=moe_w1_in_index,
                out_reorder_index=moe_w2_in_index
            )
            m.block_sparse_moe.experts[j].w2.reorder(
                in_reorder_index=moe_w2_in_index,
                out_reorder_index=None
            )
        m.post_attention_layernorm.register_buffer('reorder_index', moe_w1_in_index)

        # Move the original layer off the GPU before installing the wrapper.
        layers[i] = layers[i].cpu()
        layers[i] = m.cpu()
        del m
        torch.cuda.empty_cache()
    return model
def add_act_quant_wrapper_mixtral(model, device, args, scales):
    """Install activation / KV-cache quantization hooks on every decoder layer.

    Plain ``MixtralDecoderLayer``s are first wrapped in
    ``QMixtralDecoderLayer``; layers of any other type are skipped. Each
    wrapped layer then gets ``partial``-bound quantization callables on its
    attention module, every MoE expert, the MoE block, and the layer itself.
    Layers are processed on ``device`` one at a time and parked back on CPU.

    Args:
        model: Mixtral model whose ``model.model.layers`` is mutated in place.
        device: Device the layer is moved to while being configured.
        args: Quantization configuration bound into each wrapper callable.
        scales: Unused here; kept for a uniform call signature.

    Returns:
        The same ``model`` with quantization wrappers attached.
    """
    model.config.use_cache = False
    layers = model.model.layers

    for idx in tqdm(range(len(layers))):
        layer = layers[idx]
        if isinstance(layer, MixtralDecoderLayer):
            qlayer = QMixtralDecoderLayer(originalLayer=layer, args=args)
        elif isinstance(layer, QMixtralDecoderLayer):
            qlayer = layer
        else:
            continue
        qlayer = qlayer.to(device)

        # Attention: generic activation quant plus dedicated V/K-cache quants.
        qlayer.self_attn.act_quant = partial(quantize_activation_wrapper, args=args)
        qlayer.self_attn.v_quant = partial(quantize_attn_v_wrapper, args=args)
        qlayer.self_attn.k_quant = partial(quantize_attn_k_wrapper, args=args)

        # Every MoE expert, the MoE block, and the layer share the same
        # activation quantizer configuration.
        for expert in qlayer.block_sparse_moe.experts:
            expert.act_quant = partial(quantize_activation_wrapper, args=args)
        qlayer.act_quant = partial(quantize_activation_wrapper, args=args)
        qlayer.block_sparse_moe.act_quant = partial(quantize_activation_wrapper, args=args)

        layers[idx] = qlayer.cpu()
        torch.cuda.empty_cache()
    return model
def quantize_model_mixtral(model, device, args):
    """Quantize the weights of every decoder layer in a Mixtral model.

    Plain ``MixtralDecoderLayer``s are wrapped in ``QMixtralDecoderLayer``
    first; layers of any other type are skipped. ``quant()`` is then invoked
    on each MoE expert and on the four attention projections. Each layer is
    moved to ``device`` for the work and returned to CPU afterwards to bound
    GPU memory usage.

    Args:
        model: Mixtral model whose ``model.model.layers`` is mutated in place.
        device: Device the layer occupies while being quantized.
        args: Quantization configuration forwarded to ``QMixtralDecoderLayer``.

    Returns:
        The same ``model`` with quantized layers.
    """
    model.config.use_cache = False
    layers = model.model.layers

    for idx in tqdm(range(len(layers))):
        layer = layers[idx]
        if isinstance(layer, MixtralDecoderLayer):
            qlayer = QMixtralDecoderLayer(originalLayer=layer, args=args)
        elif isinstance(layer, QMixtralDecoderLayer):
            qlayer = layer
        else:
            continue
        qlayer = qlayer.to(device)

        # Quantize MoE expert weights, then the attention projections.
        for expert in qlayer.block_sparse_moe.experts:
            expert.quant()
        for proj in (qlayer.self_attn.q_proj, qlayer.self_attn.k_proj,
                     qlayer.self_attn.v_proj, qlayer.self_attn.o_proj):
            proj.quant()

        layers[idx] = qlayer.cpu()
        torch.cuda.empty_cache()
    return model