[Bugfix] Fix the LoRA weight sharding in ColumnParallelLinearWithLoRA #10450

Merged: 25 commits merged into vllm-project:main from fix-merged-linear-lora on Nov 24, 2024
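The change concerns how ColumnParallelLinearWithLoRA and its merged variants (e.g. QKV and gate/up projections) slice LoRA weights across tensor-parallel ranks. As a rough, hypothetical sketch of the idea rather than the PR's actual implementation (tensor layouts here are assumptions, not vLLM's internals): a column-parallel layer owns a contiguous slice of the output features on each rank, so the LoRA B matrix has to be sliced with the same per-rank bounds, and for merged layers each stacked sub-adapter must be sliced separately instead of slicing the concatenated matrix once.

```python
from typing import List

import torch


def shard_lora_b(lora_b: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Illustrative only: slice a LoRA B matrix, assumed [output_size, lora_rank],
    for one rank of a column-parallel linear layer."""
    shard_size = lora_b.shape[0] // tp_size      # output features owned per rank
    start = tp_rank * shard_size
    return lora_b[start:start + shard_size, :]


def shard_merged_lora_b(sub_lora_bs: List[torch.Tensor],
                        tp_rank: int, tp_size: int) -> List[torch.Tensor]:
    """For merged column-parallel layers (gate+up, QKV), slice each stacked
    projection's LoRA B separately; slicing the concatenated matrix as a whole
    would hand every rank the wrong rows."""
    return [shard_lora_b(b, tp_rank, tp_size) for b in sub_lora_bs]
```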
Commits (25):
161982b  Init (jeejeelee, Nov 19, 2024)
5d25c64  Merge branch 'vllm-project:main' into fix-merged-linear-lora (jeejeelee, Nov 19, 2024)
74767cb  Complete weight shard logic (jeejeelee, Nov 19, 2024)
a9ad377  Merge branch 'vllm-project:main' into fix-merged-linear-lora (jeejeelee, Nov 19, 2024)
c054ddf  Merge branch 'vllm-project:main' into fix-merged-linear-lora (jeejeelee, Nov 20, 2024)
f0e8f31  Add todo for bias slice (jeejeelee, Nov 21, 2024)
097b003  Merge branch 'vllm-project:main' into fix-merged-linear-lora (jeejeelee, Nov 21, 2024)
3fa2fb7  Done (jeejeelee, Nov 22, 2024)
5de45db  Merge branch 'vllm-project:main' into fix-merged-linear-lora (jeejeelee, Nov 22, 2024)
ff771cd  Update vllm/lora/layers.py (jeejeelee, Nov 22, 2024)
5f66271  Update vllm/lora/fully_sharded_layers.py (jeejeelee, Nov 22, 2024)
a76016e  Format code (jeejeelee, Nov 22, 2024)
9abad3c  Add LoRA TP test (jeejeelee, Nov 22, 2024)
3aa890f  Done (jeejeelee, Nov 22, 2024)
efb37a4  Optimize unit test (jeejeelee, Nov 22, 2024)
d7ae951  Merge branch 'vllm-project:main' into fix-merged-linear-lora (jeejeelee, Nov 22, 2024)
fe826eb  Configure LoRA TP test (jeejeelee, Nov 22, 2024)
0f38dde  Make yapf happy (jeejeelee, Nov 22, 2024)
b99b893  Optimize unit test (jeejeelee, Nov 22, 2024)
80a238b  Delete empty line (jeejeelee, Nov 22, 2024)
b7f0479  Fix conftext bug (jeejeelee, Nov 22, 2024)
83b76c6  Fix chatglm bug (jeejeelee, Nov 22, 2024)
251ab41  Fix chatglm bug (jeejeelee, Nov 22, 2024)
7f0da81  Merge branch 'vllm-project:main' into fix-merged-linear-lora (jeejeelee, Nov 23, 2024)
34f9381  Merge branch 'vllm-project:main' into fix-merged-linear-lora (jeejeelee, Nov 23, 2024)
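Two of the commits above touch vllm/lora/layers.py and vllm/lora/fully_sharded_layers.py, and the new tests below run the TP=4 case both with and without fully_sharded_loras. As a hedged, simplified illustration of the difference between those two paths (shapes and layouts are assumptions, not vLLM's actual code): the default path keeps the LoRA A matrix replicated on every rank and shards only B, while the fully sharded path splits A across ranks as well so the LoRA computation itself is also distributed.

```python
from typing import Tuple

import torch


def shard_lora_pair(lora_a: torch.Tensor,   # assumed [lora_rank, input_size]
                    lora_b: torch.Tensor,   # assumed [output_size, lora_rank]
                    tp_rank: int,
                    tp_size: int,
                    fully_sharded: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative sketch of default vs. fully sharded LoRA slicing."""
    out_shard = lora_b.shape[0] // tp_size
    b = lora_b[tp_rank * out_shard:(tp_rank + 1) * out_shard, :]
    if not fully_sharded:
        # Default path: every rank keeps the full A, only B is sharded.
        return lora_a, b
    # Fully sharded path: A is split across ranks too (here along the LoRA
    # rank dimension), so no rank holds the whole adapter.
    rank_shard = lora_a.shape[0] // tp_size
    a = lora_a[tp_rank * rank_shard:(tp_rank + 1) * rank_shard, :]
    return a, b
```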
13 changes: 9 additions & 4 deletions .buildkite/test-pipeline.yaml
@@ -230,7 +230,7 @@ steps:
source_file_dependencies:
- vllm/lora
- tests/lora
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore lora/test_long_context.py lora/test_chatglm3_tp.py lora/test_llama_tp.py
parallelism: 4

- label: "PyTorch Fullgraph Smoke Test" # 9min
@@ -475,18 +475,23 @@ steps:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py

- label: LoRA Long Context (Distributed) # 11min
# This test runs llama 13B, so it is required to run on 4 GPUs.
- label: LoRA TP Test (Distributed)
num_gpus: 4
soft_fail: true
source_file_dependencies:
- vllm/lora
- tests/lora/test_long_context
- tests/lora
commands:
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# This test runs llama 13B, so it is required to run on 4 GPUs.
- pytest -v -s -x lora/test_long_context.py
# There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py


- label: Weight Loading Multiple GPU Test # 33min
working_dir: "/vllm-workspace/tests"
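When running the new distributed LoRA tests outside of Buildkite, the same multiprocessing setting from the CI step above has to be in place before anything initializes CUDA. A hypothetical local-run helper, assuming the working directory is the vLLM tests/ directory and four GPUs are available:

```python
import os

# Mirror the CI step above: force 'spawn' so vLLM worker processes do not
# inherit an already-initialized CUDA context from the parent process.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

import pytest  # noqa: E402  (import only after the env var is set)

if __name__ == "__main__":
    # Run the new multi-GPU LoRA tests; like the CI step, this needs 4 GPUs.
    raise SystemExit(
        pytest.main(["-v", "-s", "-x",
                     "lora/test_chatglm3_tp.py",
                     "lora/test_llama_tp.py"]))
```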
63 changes: 53 additions & 10 deletions tests/lora/test_chatglm3.py → tests/lora/test_chatglm3_tp.py
@@ -1,12 +1,21 @@
from typing import List

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "THUDM/chatglm3-6b"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501

EXPECTED_LORA_OUTPUT = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
@@ -20,7 +29,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
"Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
),
]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
@@ -37,23 +45,58 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True)

expected_lora_output = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age",
]
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=False)

output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=True)
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i] == expected_lora_output[i]
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i] == expected_lora_output[i]
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
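The middle of do_sample() is collapsed in the diff above, so the part that actually attaches the adapter is not visible. A plausible, hypothetical sketch of that wiring, using only the imports shown in the file and not necessarily its exact code:

```python
from typing import List

import vllm
from vllm.lora.request import LoRARequest


def generate_with_lora(llm: vllm.LLM, prompts: List[str], lora_path: str,
                       lora_id: int) -> List[str]:
    # Sketch of the collapsed body of do_sample(): attach the adapter per
    # request via LoRARequest(name, int_id, local_path), then collect the
    # generated SQL strings for comparison against EXPECTED_LORA_OUTPUT.
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None,
    )
    return [output.outputs[0].text.strip() for output in outputs]
```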
146 changes: 0 additions & 146 deletions tests/lora/test_llama.py

This file was deleted.
