[Feature Suggest] Tensor Parallelism for Accelerating LLM #29
Nice measurements! @zhengpeirong please check the 0.3.1 version. Now all tasks are executed in parallel, so it should be a bit better.
@b4rtaz The 'qkv' change has been reverted. Do you plan to deal with this issue? Not only does the 'MulHead' cost time, but the 'Finalize' also takes a big portion of the time.
@zhengpeirong Yes, I know. Yes, I want to keep working on this project. More hands are welcome. :-)
@b4rtaz Thanks for your persistence and endeavor.
If you combine all those mechanisms, the non-parallel functions will be optimized! Here is the draft workflow:

TransformerArch buildLlama2Arch(TransformerSpec* spec) {
TransformerArch a;
// Inference
a.I(sendPoke, TASK_TYPE_TRANSFER);
for (int i = 0; i < spec->nLayers; i++) {
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE); // Combine the existing llamaRmsAtt and llamaRmsAttNorm
a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE); // Quantization
a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); // Sending
a.I(llamaQkv, TASK_TYPE_INFERENCE); // Compute Q K V
a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
a.I(llamaAttOutput, TASK_TYPE_INFERENCE); // Worker computes W_O matrix
a.I(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
a.I(llamaSyncAtt, TASK_TYPE_TRANSFER); // First communication time-consuming
a.I(llamaDequantizeAtt, TASK_TYPE_INFERENCE);
a.I(llamaMergeAtt, TASK_TYPE_INFERENCE); // Merge all attention matrices
a.I(llamaRmfFfn, TASK_TYPE_INFERENCE);
a.I(llamaRmfFfnNorm, TASK_TYPE_INFERENCE);
a.I(llamaQuantizeRmfFfn, TASK_TYPE_INFERENCE);
a.I(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
a.I(llamaFfn, TASK_TYPE_INFERENCE); // Compute SwiGLU activation
a.I(llamaFfn2, TASK_TYPE_INFERENCE); // Compute the second FFN
a.I(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
a.I(llamaSyncFfn2, TASK_TYPE_TRANSFER); // Second communication time-consuming
a.I(llamaDequantizeFfn2, TASK_TYPE_INFERENCE);
a.I(llamaMergeFfn2, TASK_TYPE_INFERENCE);
a.I(llamaNextBlock, TASK_TYPE_INFERENCE);
}
a.I(llamaRmsFinal, TASK_TYPE_INFERENCE);
a.I(llamaRmsFinalNorm, TASK_TYPE_INFERENCE);
a.I(llamaLogits, TASK_TYPE_INFERENCE);
a.I(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
a.I(llamaSyncLogits, TASK_TYPE_TRANSFER);
a.I(llamaDequantizeLogits, TASK_TYPE_INFERENCE);
a.I(llamaMergeLogits, TASK_TYPE_INFERENCE);
// Worker
for (int i = 0; i < spec->nLayers; i++) {
a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
a.W(llamaQkv, TASK_TYPE_INFERENCE); // Compute Q K V
a.W(llamaMultiheadAtt, TASK_TYPE_INFERENCE); // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
a.W(llamaAttOutput, TASK_TYPE_INFERENCE); // Worker computes W_O matrix
a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
a.W(llamaSyncAtt, TASK_TYPE_TRANSFER);
a.W(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
a.W(llamaFfn, TASK_TYPE_INFERENCE);
a.W(llamaFfn2, TASK_TYPE_INFERENCE);
a.W(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
a.W(llamaSyncFfn2, TASK_TYPE_TRANSFER);
a.W(llamaNextBlock, TASK_TYPE_INFERENCE);
}
a.W(llamaLogits, TASK_TYPE_INFERENCE);
a.W(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
a.W(llamaSyncLogits, TASK_TYPE_TRANSFER);
return a;
}

I hope this repo can catch up with the state-of-the-art algorithm as soon as possible!
@zhengpeirong
The issue you are currently facing lies in separately calculating the QKV matrices, which are split according to the output dimension. However, if tensor parallelism is supported, the splitting is performed along the head dimension. In summary, the RoPE computation and the multi-head attention computation are orthogonal, operating on different dimensions: the former works on the head_size dimension (inside each head), while the latter is split across heads.
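For illustration, here is a minimal sketch (hypothetical names such as ropeAndAttendLocalHeads, not taken from distributed-llama) of why the two splits do not conflict: each node owns a contiguous range of heads, and RoPE only rotates value pairs inside each head, so both steps can run locally on that node's slice without extra synchronization.

#include <cmath>
#include <vector>

void ropeAndAttendLocalHeads(
    std::vector<float>& q,          // [nHeads * headSize], full Q for this token
    std::vector<float>& k,          // [nHeads * headSize], full K for this token
    int nHeads, int headSize,
    int nodeIndex, int nNodes,
    int pos, float theta = 10000.0f)
{
    int headsPerNode = nHeads / nNodes;       // assumes nHeads % nNodes == 0
    int firstHead = nodeIndex * headsPerNode;

    for (int h = firstHead; h < firstHead + headsPerNode; h++) {
        float* qh = &q[h * headSize];
        float* kh = &k[h * headSize];
        // RoPE touches only pairs (i, i+1) inside this head, so it is
        // independent of how heads are distributed across nodes.
        for (int i = 0; i < headSize; i += 2) {
            float freq = 1.0f / std::pow(theta, (float)i / (float)headSize);
            float c = std::cos(pos * freq), s = std::sin(pos * freq);
            float q0 = qh[i], q1 = qh[i + 1];
            float k0 = kh[i], k1 = kh[i + 1];
            qh[i] = q0 * c - q1 * s; qh[i + 1] = q0 * s + q1 * c;
            kh[i] = k0 * c - k1 * s; kh[i + 1] = k0 * s + k1 * c;
        }
        // ... per-head attention (softmax over q_h · K_h / sqrt(headSize) against the
        //     local kv-cache, then weighting V) is also purely per-head, so still local.
    }
}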
I needed a bit of time to notice my thinking error. After all, the RoPE layer is split out to the root node and workers. 🎉 Tested it with 1, 2 and 4 nodes, and the macbeth test generates the same output on different topologies *.

* The macbeth test doesn't work with the buffer quantization (it generates a different output), because now the RoPE is applied before the transfer quantization. Previously, it was applied after the transfer dequantization. I expect this affects the perplexity somehow; probably this will be resolved later. Now all nodes have the RoPE cache, and the size of the cache is different for each node. This may be optimized a bit, but "so far so good".
Next, I'll try to split out the multihead layer.
Finally I split out the multihead layer into all nodes (still not merged, I need to fix the Mixtral & Grok architectures). First measurements:

Model: Llama 3 8B Q40
Transfer size / token
Avg tokens / second
* I think the switch used is completely non-deterministic; it achieves a random speed at different times. So I recommend comparing only the avg inference time. It looks like that gave a tiny speed-up (maybe 3%). I expected a bit more. 🤔
Update: I changed the implementation a bit; now there is no synchronization in between.

Transfer size / token
The final state of the attention synchronization looks like this for a single block:
The previous implementation:
@b4rtaz 🎉 You have completed the SOTA tensor parallelism for the attention layer!
In summary, at most a 12% acceleration can be achieved over the current version. As the number of workers increases (4 workers in this issue), this acceleration benefits from more parallelism.
@b4rtaz Just for your reference, this code implements the FFN layer of LLaMA with tensor-parallel acceleration.
@zhengpeirong it seems that after I adjusted the MLP layers to your suggestion, the transfer has dropped by ~40% per token. 🤯
Later I'll check the impact on the generation time.
Where do you see the generation time data? 🤔
'272 kB' is consistent with the theoretical analysis.
Apart from the transfer for the embedding layers, we can treat this as 2 All-Reduce transfers in a single Transformer block. Congratulations on finishing this feature suggestion!
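As a rough sanity check of that figure (assumed values, not taken from the repo): for Llama 3 8B (32 layers, dim = 4096) with a Q80 buffer of 32 int8 values plus a 2-byte scale per block (34 bytes per 32 values), two All-Reduce transfers per block come out to roughly 272 kB per token.

#include <cstdio>

int main() {
    const int nLayers = 32;                    // assumed: Llama 3 8B
    const int dim = 4096;                      // assumed: model dimension
    const int syncsPerBlock = 2;               // attention output + second FFN output
    const double bytesPerValue = 34.0 / 32.0;  // assumed Q80: 32 int8 values + 2-byte scale
    double bytesPerToken = (double)nLayers * syncsPerBlock * dim * bytesPerValue;
    printf("~%.0f kB per token\n", bytesPerToken / 1024.0);  // prints ~272 kB
    return 0;
}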
Llama 2 7B Q40, nTokens = 90, buffer = Q80
4 x Raspberry Pi 5 8GB
2 x Raspberry Pi 5 8GB

TinyLlama 1.3B 3T Q40, nTokens = 128, buffer = Q80
2 x Raspberry Pi 5 8GB

Llama 3 8B Q40, nTokens = 90, buffer = Q80
2 x AMD EPYC 7402P 24-Core Processor
In all cases the average transfer time has dropped. Interestingly, the non-blocking sockets reduce the speed on the Raspberry Pi, but not on a strong machine. Maybe this mode should be optional.
Do you mean the blocking sockets reduce the speed? Could you try 8 x Raspberry Pi? Since there are obvious transfer delays with 8 devices, I am curious whether it's because of network traffic congestion. BTW, I think it's time to update the README.
No. The non-blocking sockets, I think. From version 0.6.1, Distributed Llama has enabled non-blocking sockets for root <> node communication (see the sketch below).
Unfortunately I don't have 8 devices anymore. I have only 4 x Raspberry Pi 5 8GB.
You're right. I'll do it soon.
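For readers unfamiliar with the mode being discussed, here is a minimal, hypothetical sketch of non-blocking sockets on POSIX (the fcntl/O_NONBLOCK mechanism); it is not the project's actual networking code. It also hints at where extra CPU use on slow devices may come from: a send that would have slept now returns EAGAIN and the caller keeps spinning.

#include <cerrno>
#include <fcntl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

// Switch an already-connected socket to non-blocking mode.
bool setNonBlocking(int fd) {
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags < 0) return false;
    return fcntl(fd, F_SETFL, flags | O_NONBLOCK) == 0;
}

// Send a buffer on a non-blocking socket: instead of sleeping inside send(),
// the call returns EAGAIN/EWOULDBLOCK and the caller can do other work or spin.
ssize_t sendAll(int fd, const char* data, size_t size) {
    size_t sent = 0;
    while (sent < size) {
        ssize_t n = send(fd, data + sent, size - sent, 0);
        if (n > 0) { sent += (size_t)n; continue; }
        if (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) continue; // busy-wait here
        return -1; // real error
    }
    return (ssize_t)sent;
}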
The non-blocking sockets will make the CPU do other jobs instead of waiting. But what's the logical connection between non-blocking sockets and increased inference time?
In this discussion, you are invited to conduct experiments with more devices and find what number of devices is the best choice, then present it in the README.
I think this problem appears only on slow devices like the Raspberry Pi. I cannot explain it, but you can see the drop in speed from 0.6.0 to 0.7.0. This was only a minor change between these versions. Maybe we need more tests.
Dear Author,
Your contribution is critical for the open-source community. The distributed-llama repo has implemented tensor parallelism from scratch, and the result is amazingly significant. However, there are still improvements that could be made. Because of my poor coding ability, I am not able to make the improvements myself, so I hope you can look at my suggestions below.
Challenge: root node's special task and synchronization
When I run the repo version '0.1.0', I find that the softmax operations in MultiHead are conducted on the root node only. This operation costs a significant portion of the total time. Second, the synFfnA and synFfn2 functions also cost a lot of time.
Mature solutions
In fact, these challenges have been addressed in this paper: https://arxiv.org/abs/1909.08053. Its solution is shown in the image:
It conducts the attention mechanism (softmax) on every worker. Second, the matrix segmentation uses a column split and a row split in two consecutive matrices, thus reducing the cost to one synchronization operation instead of two.
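As an illustration of that column-then-row split (a hedged sketch with made-up names such as ffnLocalPart, not code from either repo): the first matrix is split by columns so each node computes its own slice of the hidden activation, and the second is split by rows so the partial outputs only need to be summed once.

#include <vector>

// One node's share of two consecutive matmuls, Megatron style:
//   Y = X * W1 (W1 split by columns)  ->  Z_partial = Y_local * W2 (W2 split by rows)
// Only Z_partial needs an all-reduce (sum across nodes); Y is never synchronized.
std::vector<float> ffnLocalPart(
    const std::vector<float>& x,        // [dim] input activation (replicated on every node)
    const std::vector<float>& w1Slice,  // [dim x hiddenSlice], this node's columns of W1
    const std::vector<float>& w2Slice,  // [hiddenSlice x dim], this node's rows of W2
    int dim, int hiddenSlice)
{
    // Column-parallel matmul: local slice of the hidden vector.
    std::vector<float> yLocal(hiddenSlice, 0.0f);
    for (int j = 0; j < hiddenSlice; j++)
        for (int i = 0; i < dim; i++)
            yLocal[j] += x[i] * w1Slice[i * hiddenSlice + j];

    // (an element-wise activation such as SwiGLU would be applied to yLocal here)

    // Row-parallel matmul: partial output of full width `dim`.
    std::vector<float> zPartial(dim, 0.0f);
    for (int i = 0; i < dim; i++)
        for (int j = 0; j < hiddenSlice; j++)
            zPartial[i] += yLocal[j] * w2Slice[j * dim + i];

    return zPartial; // a single all-reduce (sum) over zPartial finishes the layer
}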
If you are willing to make further improvements to the repo, the following is the mature solution for every component of llama2 using tensor parallelism and sequence parallelism: https://pytorch.org/tutorials/intermediate/TP_tutorial.html
However, it's implemented in Python, and you will be the first one to implement the solution in C++.
Thanks for your contribution!!!
Best Regards