-
I read the paper, but it's still not clear to me whether L²QER changes anything at inference time or not. Please correct me if I'm wrong, but my understanding is that we first do LQER by running SVD on the quantization error to get the LoRA, and then further optimize it via L²QER (around page 6 of the paper). Anyway,
FYI, someone has already done a Python implementation of the GGML quants in the HF transformers library. Hopefully this can help your implementation.
I'm really interested in this, because control vectors could benefit from SVD (we're currently using PCA).
Yes, I think it would be useful to know if we should call The current
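FWIW my reading is the same and, as far as I can tell, nothing changes at inference time: both LQER and L²QER produce a static low-rank pair (A, B) that is applied next to the quantized weight, and only how the SVD is computed at quantization time differs. A minimal sketch of the forward pass under that assumption (the names `W_q`, `lora_A`, `lora_B` are illustrative, not from the paper or llama.cpp):

```python
import torch

def forward_with_error_correction(x: torch.Tensor,
                                  W_q: torch.Tensor,     # dequantized weight, [out, in]
                                  lora_A: torch.Tensor,  # [rank, in]
                                  lora_B: torch.Tensor   # [out, rank]
                                  ) -> torch.Tensor:
    # Identical whether (A, B) came from LQER or L2QER:
    # y = x @ W_q^T + (x @ A^T) @ B^T  ~=  x @ W^T
    return x @ W_q.T + (x @ lora_A.T) @ lora_B.T
```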
-
If it's any use, I tidied up the Mergekit code to extract LoRAs from fine-tuned models. It works on all but the weirdest models now (eg: ...). Might be useful to make a reference version using PyTorch and then compare it with a native C/C++ reimplementation. EDIT: The only part of Mergekit it's using is the "lazy tensor loading"; other than that it's mostly based on: https://github.com/thomasgauthier/LoRD (I'm not sure how it ended up part of Mergekit)
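The core of the extraction (as in LoRD) is just a truncated SVD of the weight delta between the fine-tuned and base models; a rough PyTorch sketch of the idea (tensor names and the fixed rank are illustrative):

```python
import torch

def extract_lora(W_base: torch.Tensor, W_ft: torch.Tensor, rank: int):
    """Approximate W_ft - W_base with a rank-`rank` product lora_B @ lora_A."""
    delta = (W_ft - W_base).float()               # [out, in]
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    sqrt_S = torch.sqrt(S[:rank])
    lora_B = U[:, :rank] * sqrt_S                 # [out, rank]
    lora_A = sqrt_S.unsqueeze(1) * Vh[:rank]      # [rank, in]
    return lora_A, lora_B
```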
-
Interesting work, it will be cool to try to implement this approach and see how the perplexity improves for different ranks. I'm looking at some of the results in the paper and I'm not sure how to interpret Appendix B: based on the graph, it seems L2QER performs worse (i.e. higher error) compared to LQER, while the text states the opposite. Am I reading it wrong?
-
So I'm just experimenting with this now, but so far have done the opposite of

The biggest hurdle in

The GNU Scientific Library mostly implements all this old, robust Fortran code: https://www.gnu.org/software/gsl/doc/html/linalg.html (for dense BLAS - the sparse BLAS stuff we don't really care about much anyway).

Obviously this isn't much use on its own, as everything will be CPU-based and not actually use the GGML back-ends... BUT: it actually only relies on an implementation of "CBLAS", which it provides itself here: https://git.savannah.gnu.org/cgit/gsl.git/tree/cblas but it can actually use any implementation, such as the one in MKL (it's a bit of a PITA to link though): https://stackoverflow.com/questions/52989133/linking-gsl-c-program-with-intel-mkl and I have seen people also link with the Netlib "wrapper for legacy Fortran" successfully.

So to use this with the GGML back-ends, all you would actually need to do is implement these level 1/2/3 CBLAS functions: https://git.savannah.gnu.org/cgit/gsl.git/tree/cblas/gsl_cblas.h and then you would get the full power of the GGML back-ends, but with the carefully curated set of linear algebra algorithms!

You could also skip all the "C" (complex) and "Z" (double complex) functions and many of the matrix types for now. If you look at the source folder: https://git.savannah.gnu.org/cgit/gsl.git/tree/cblas then it's clearly not a huge amount of code to write, and most of it is just pure boilerplate (the tests and the main

It's actually quite likely that in practice you could start with the pure-CPU version of GSL's provided BLAS implementation and convert it bit by bit to use the GGML code.

@ggerganov What is your opinion on the idea of having the GSL dependency in

https://en.wikipedia.org/wiki/Comparison_of_linear_algebra_libraries

I should make it clear that there are two things here: We would only need to implement the
-
The only other viable option I can see is to use: https://www.boost.org/doc/libs/1_87_0/libs/python/doc/html/index.html which would then let us call the PyTorch linear algebra code (I think the C++ Torch API is pretty much dead now AFAIK?), but then this brings in massive bloat and long compile times, and I'm pretty sure this won't be wanted as a dependency...
-
```python
import numpy as np
import torch
import gguf

def ggml_quantize_residual(tensor: torch.Tensor, quant_type: gguf.GGMLQuantizationType) -> torch.Tensor:
    """
    Returns the residual between the original tensor and its quantized-dequantized version.

    Args:
        tensor: Input torch tensor
        quant_type: GGML quantization type to use

    Returns:
        Residual tensor (original - reconstructed) on the same device as the input
    """
    # Save the original device and move to CPU for numpy conversion
    orig_device = tensor.device
    cpu_tensor = tensor.cpu()

    # Convert to numpy
    np_tensor = cpu_tensor.numpy()

    # Quantize and dequantize
    if quant_type in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]:
        dtype = np.float32 if quant_type == gguf.GGMLQuantizationType.F32 else np.float16
        quant = np_tensor.astype(dtype)
        dequant = quant.astype(np.float32)
    else:
        quant = gguf.quants.quantize(np_tensor, quant_type)
        dequant = gguf.quants.dequantize(quant, quant_type)

    # Convert the dequantized tensor back to torch
    reconstructed = torch.from_numpy(dequant).to(orig_device)

    # Calculate and return the residual
    return tensor - reconstructed


def svd_compress_residual(tensor: torch.Tensor, rank: int):
    """
    SVD compression for the quantized-dequantized residual.
    Returns LoRA matrices (A and B).
    """
    assert tensor.dtype == torch.float32, f"Expected float32 input tensor, got {tensor.dtype}"

    # Compute the truncated SVD (helper defined elsewhere in the script)
    U_r, S_r, Vh_r, var_expl = truncated_svd(tensor, r=rank)

    # Create LoRA matrices by splitting sqrt(S) between the two factors
    sqrtS = torch.sqrt(S_r)
    lora_A = (sqrtS.unsqueeze(1) * Vh_r).contiguous()  # [rank, input_dim]
    lora_B = (U_r * sqrtS.unsqueeze(0)).contiguous()   # [output_dim, rank]

    new_size = rank * (tensor.shape[0] + tensor.shape[1])
    orig_size = tensor.shape[0] * tensor.shape[1]
    compression_ratio = new_size / orig_size

    print(f"- Rank               : {rank}")
    print(f"- Compression Ratio  : {compression_ratio*100:.2f}%")
    print(f"- Variance Explained : {var_expl*100:.2f}%")
    print(f"- LoRA Shapes        : A {lora_A.shape}, B {lora_B.shape}")

    return lora_A, lora_B

# ...

w_ggml_residual = ggml_quantize_residual(w_deq, quant_type)
lora_a, lora_b = svd_compress_residual(w_ggml_residual, rank=rank)

# ...

parser.add_argument("--quant-type", type=str, default="Q4_0", choices=["Q4_0", "Q4_1", "Q5_0", "Q5_1", "Q8_0"],
                    help="Quantization type for GGUF export (default: Q4_0)")
```

Just trying https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py (ignore the
-
I'm a little sceptical if this will work. These are the truncated-SVD stats for:

- Rank-256 - adds around 16% extra overhead: (256×(7168 + 2048))/(7168×2048)
- Rank-64 - adds around 4% extra overhead: (64×(7168 + 2048))/(7168×2048)

This distribution of singular values does look remarkably flat compared to my previous (failed) "reverse LQER" attempt.
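The overhead numbers are just the parameter count of the rank-r A/B pair relative to the full 7168×2048 matrix; for example:

```python
def lora_overhead(rank: int, d_out: int, d_in: int) -> float:
    """Extra parameters of a rank-r A/B pair relative to the full matrix."""
    return rank * (d_out + d_in) / (d_out * d_in)

for r in (64, 256):
    print(f"rank {r:4d}: {lora_overhead(r, 7168, 2048) * 100:.1f}% extra")
# rank   64: 4.0% extra
# rank  256: 16.1% extra
```

Note this compares parameter counts; if the A/B factors are stored in F16 while the base weights are ~4-bit, the relative size overhead is correspondingly larger.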
-
It's a little better on the early layers:

But the "reverse LQER" attempt was getting 50%+ Variance Explained for these...
-
Slightly better than "reverse LQER", but still not worth bothering with IMO - just adding an extra bit to the quant would give far more improvement than LQER with all the extra overhead it adds... The only interesting thing is it does show that the early layers (of the MoE tensors of

Overall LQER seems a waste of time (can't comment on L2QER though).
-
Since the recent LoRA refactor by @ngxson in #8332, I think it should be possible to improve existing quantization schemes with Low-Rank Quantization Error Reconstruction (see https://arxiv.org/abs/2402.02446).
It would only need two things:
gguf-py/gguf/quants.py to make this easier.
And also I think L²QER could be implemented with the existing imatrix files.
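The L²QER variant scales the quantization error by a diagonal built from activation statistics before the SVD, which is essentially the per-channel information the imatrix already collects. A rough sketch of the idea, assuming `imatrix_sq_acts` holds the mean squared activation per input channel (the exact scaling used in the paper may differ):

```python
import torch

def l2qer_factors(W: torch.Tensor, W_deq: torch.Tensor,
                  imatrix_sq_acts: torch.Tensor, rank: int):
    """
    Activation-weighted low-rank approximation of the quantization error.

    W               : original weight, [out, in]
    W_deq           : dequantized (quantized) weight, [out, in]
    imatrix_sq_acts : mean squared activation per input channel, [in]
    """
    E = (W - W_deq).float()                      # quantization error
    s = imatrix_sq_acts.clamp_min(1e-8).sqrt()   # per-input-channel scale

    # Weight the error columns by the activation scale, SVD, then undo the scale
    U, S, Vh = torch.linalg.svd(E * s, full_matrices=False)
    lora_B = U[:, :rank] * S[:rank]              # [out, rank]
    lora_A = Vh[:rank] / s                       # [rank, in]
    return lora_A, lora_B                        # lora_B @ lora_A ~= E
```

The plain LQER case is the same sketch with `s` set to all ones.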