Skip to content

Commit b98d89e

Browse files
authored
[Misc] Medusa supports custom bias (#10361)
1 parent 8b6725b commit b98d89e

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

vllm/model_executor/models/medusa.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414

1515
class ResidualBlock(nn.Module):
1616

17-
def __init__(self, hidden_size: int, num_layers: int) -> None:
17+
def __init__(self, config: VllmConfig, hidden_size: int,
18+
num_layers: int) -> None:
1819
super().__init__()
1920

2021
self.layers = nn.ModuleList([
21-
nn.Linear(hidden_size, hidden_size, bias=False)
22+
nn.Linear(hidden_size,
23+
hidden_size,
24+
bias=getattr(config, "medusa_fc_bias", False))
2225
for _ in range(num_layers)
2326
])
2427
self.act = nn.SiLU()
@@ -49,7 +52,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
4952
super().__init__()
5053
self.config = config
5154
self.blocks = nn.ModuleList([
52-
ResidualBlock(hidden_size=self.config.hidden_size,
55+
ResidualBlock(config=config,
56+
hidden_size=self.config.hidden_size,
5357
num_layers=self.config.num_hidden_layers)
5458
for _ in range(self.config.num_heads)
5559
])

0 commit comments

Comments
 (0)