From 74f4ad89f7fc961f9373678e97d198d74ed27db0 Mon Sep 17 00:00:00 2001
From: Austin Liu
Date: Mon, 27 Jan 2025 14:43:45 +0800
Subject: [PATCH] Format files (#541)

## Summary

Format the test files: regroup imports into standard-library, third-party, `test`, and `liger_kernel` blocks, and wrap an over-long `concatenated_forward` call in `test/utils.py`.

## Testing Done

- Hardware Type:
- [ ] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Austin Liu
---
 .../test_mini_models_multimodal.py      | 19 +++++++++----------
 .../test_fused_linear_cross_entropy.py  |  7 ++++---
 test/transformers/test_rms_norm.py      |  8 ++++----
 test/utils.py                           |  4 +++-
 4 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py
index c4d08b716..a7f8296a0 100644
--- a/test/convergence/test_mini_models_multimodal.py
+++ b/test/convergence/test_mini_models_multimodal.py
@@ -1,6 +1,15 @@
 import functools
 import os
 
+import pytest
+import torch
+
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerFast
+
+from liger_kernel.transformers import apply_liger_kernel_to_mllama
+from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
 from test.utils import FAKE_CONFIGS_PATH
 from test.utils import UNTOKENIZED_DATASET_PATH
 from test.utils import MiniModelConfig
@@ -13,16 +22,6 @@
 from test.utils import supports_bfloat16
 from test.utils import train_bpe_tokenizer
 
-import pytest
-import torch
-
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from transformers import PreTrainedTokenizerFast
-
-from liger_kernel.transformers import apply_liger_kernel_to_mllama
-from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
-
 try:
     # Qwen2-VL is only available in transformers>=4.45.0
     from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py
index 78a21b02b..ffbe52275 100644
--- a/test/transformers/test_fused_linear_cross_entropy.py
+++ b/test/transformers/test_fused_linear_cross_entropy.py
@@ -1,11 +1,12 @@
-from test.transformers.test_cross_entropy import CrossEntropyWithZLoss
-from test.utils import assert_verbose_allclose
-from test.utils import set_seed
 from typing import Optional
 
 import pytest
 import torch
 
+from test.transformers.test_cross_entropy import CrossEntropyWithZLoss
+from test.utils import assert_verbose_allclose
+from test.utils import set_seed
+
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
diff --git a/test/transformers/test_rms_norm.py b/test/transformers/test_rms_norm.py
index d829785f0..4f1972c7c 100644
--- a/test/transformers/test_rms_norm.py
+++ b/test/transformers/test_rms_norm.py
@@ -1,13 +1,13 @@
 import os
 
-from test.utils import assert_verbose_allclose
-from test.utils import set_seed
-from test.utils import supports_bfloat16
-
 import pytest
 import torch
 import torch.nn as nn
 
+from test.utils import assert_verbose_allclose
+from test.utils import set_seed
+from test.utils import supports_bfloat16
+
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.transformers.functional import liger_rms_norm
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
diff --git a/test/utils.py b/test/utils.py
index 3fcb07b71..004c3780b 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -541,7 +541,9 @@ def get_batch_loss_metrics(
         **loss_kwargs,
     ):
         """Compute the loss metrics for the given batch of inputs for train or test."""
-        forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob, preference_labels, nll_target)
+        forward_output = self.concatenated_forward(
+            _input, weight, target, bias, average_log_prob, preference_labels, nll_target
+        )
         (
             policy_chosen_logps,
             policy_rejected_logps,
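
For reference, the import layout these hunks converge on can be sketched as below. This is an illustrative reading of the diff, not the repository's formatter configuration: the module names are taken from the hunks above, and the group ordering (standard library, then third-party, then `test` helpers, then `liger_kernel`) is inferred from the post-patch state of `test/transformers/test_rms_norm.py`.

```python
# Illustrative sketch of the import grouping this patch applies:
# one import per line, groups separated by a single blank line.
import os  # 1) standard library

import pytest  # 2) third-party packages
import torch

from test.utils import set_seed  # 3) local test helpers

from liger_kernel.transformers.rms_norm import LigerRMSNorm  # 4) project code
```

The `test/utils.py` hunk serves the same style goal from the other direction: rather than regrouping imports, it wraps a call whose argument list overflows the line length limit, moving the arguments onto their own indented line.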