diff --git a/backend/configs/config.yaml b/backend/configs/config.yaml
index f0e3e925..d2fcb93e 100644
--- a/backend/configs/config.yaml
+++ b/backend/configs/config.yaml
@@ -10,7 +10,7 @@ bgm_separation:
   # Whether to offload the model after the inference. Should be true if your setup has a VRAM less than <16GB
   enable_offload: true
   # Device to load BGM separation model
-  device: cuda
+  device: xpu

 # Settings that apply to the `cache' directory. The output files for `/bgm-separation` are stored in the `cache' directory,
 # (You can check out the actual generated files by testing `/bgm-separation`.)
diff --git a/backend/tests/test_backend_bgm_separation.py b/backend/tests/test_backend_bgm_separation.py
index fc44cb7e..7c90f140 100644
--- a/backend/tests/test_backend_bgm_separation.py
+++ b/backend/tests/test_backend_bgm_separation.py
@@ -12,7 +12,7 @@
 )


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip the test because CUDA is not available")
+@pytest.mark.skipif(not torch.xpu.is_available(), reason="Skip the test because XPU is not available")
 @pytest.mark.parametrize(
     "bgm_separation_params",
     [
diff --git a/configs/default_parameters.yaml b/configs/default_parameters.yaml
index 89ea2041..1100c4e4 100644
--- a/configs/default_parameters.yaml
+++ b/configs/default_parameters.yaml
@@ -1,37 +1,37 @@
 whisper:
-  model_size: "large-v2"
-  file_format: "SRT"
-  lang: "Automatic Detection"
+  model_size: large-v2
+  lang: chinese
   is_translate: false
   beam_size: 5
-  log_prob_threshold: -1
+  log_prob_threshold: -1.0
   no_speech_threshold: 0.6
+  compute_type: float32
   best_of: 5
-  patience: 1
+  patience: 1.0
   condition_on_previous_text: true
   prompt_reset_on_temperature: 0.5
   initial_prompt: null
-  temperature: 0
+  temperature: 0.0
   compression_ratio_threshold: 2.4
-  chunk_length: 30
-  batch_size: 24
-  length_penalty: 1
-  repetition_penalty: 1
+  length_penalty: 1.0
+  repetition_penalty: 1.0
   no_repeat_ngram_size: 0
   prefix: null
   suppress_blank: true
-  suppress_tokens: "[-1]"
-  max_initial_timestamp: 1
+  suppress_tokens: '[-1]'
+  max_initial_timestamp: 1.0
   word_timestamps: false
-  prepend_punctuations: "\"'“¿([{-"
-  append_punctuations: "\"'.。,,!!??::”)]}、"
+  prepend_punctuations: '"''“¿([{-'
+  append_punctuations: '"''.。,,!!??::”)]}、'
   max_new_tokens: null
+  chunk_length: 30
   hallucination_silence_threshold: null
   hotwords: null
   language_detection_threshold: 0.5
   language_detection_segments: 1
+  batch_size: 24
   add_timestamp: true
-
+  file_format: SRT
 vad:
   vad_filter: false
   threshold: 0.5
@@ -39,26 +39,25 @@ vad:
   max_speech_duration_s: 9999
   min_silence_duration_ms: 1000
   speech_pad_ms: 2000
-
 diarization:
   is_diarize: false
-  hf_token: ""
-
+  diarization_device: xpu
+  hf_token: ''
 bgm_separation:
   is_separate_bgm: false
-  uvr_model_size: "UVR-MDX-NET-Inst_HQ_4"
+  uvr_model_size: UVR-MDX-NET-Inst_HQ_4
+  uvr_device: xpu
   segment_size: 256
   save_file: false
   enable_offload: true
-
 translation:
   deepl:
-    api_key: ""
+    api_key: ''
     is_pro: false
-    source_lang: "Automatic Detection"
-    target_lang: "English"
+    source_lang: Automatic Detection
+    target_lang: English
   nllb:
-    model_size: "facebook/nllb-200-1.3B"
+    model_size: facebook/nllb-200-1.3B
     source_lang: null
     target_lang: null
     max_length: 200
diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py
index daaedf97..9e25c893 100644
--- a/modules/diarize/diarizer.py
+++ b/modules/diarize/diarizer.py
@@ -129,15 +129,15 @@ def offload(self):
         if self.pipe is not None:
             del self.pipe
             self.pipe = None
-        if self.device == "cuda":
-            torch.cuda.empty_cache()
-            torch.cuda.reset_max_memory_allocated()
+        if self.device == "xpu":
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()
         gc.collect()

     @staticmethod
     def get_device():
-        if torch.cuda.is_available():
-            return "cuda"
+        if torch.xpu.is_available():
+            return "xpu"
         elif torch.backends.mps.is_available():
             return "mps"
         else:
             return "cpu"
@@ -146,8 +146,8 @@ def get_device():
     @staticmethod
     def get_available_device():
         devices = ["cpu"]
-        if torch.cuda.is_available():
-            devices.append("cuda")
+        if torch.xpu.is_available():
+            devices.append("xpu")
         elif torch.backends.mps.is_available():
             devices.append("mps")
         return devices
\ No newline at end of file
diff --git a/modules/translation/translation_base.py b/modules/translation/translation_base.py
index 08297a40..10535197 100644
--- a/modules/translation/translation_base.py
+++ b/modules/translation/translation_base.py
@@ -127,31 +127,31 @@ def translate_file(self,
             print(f"Error translating file: {e}")
             raise
         finally:
-            self.release_cuda_memory()
+            self.release_xpu_memory()

     def offload(self):
         """Offload the model and free up the memory"""
         if self.model is not None:
             del self.model
             self.model = None
-        if self.device == "cuda":
-            self.release_cuda_memory()
+        if self.device == "xpu":
+            self.release_xpu_memory()
         gc.collect()

     @staticmethod
     def get_device():
-        if torch.cuda.is_available():
-            return "cuda"
+        if torch.xpu.is_available():
+            return "xpu"
         elif torch.backends.mps.is_available():
             return "mps"
         else:
             return "cpu"

     @staticmethod
-    def release_cuda_memory():
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.reset_max_memory_allocated()
+    def release_xpu_memory():
+        if torch.xpu.is_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()

     @staticmethod
     def remove_input_files(file_paths: List[str]):
diff --git a/modules/uvr/music_separator.py b/modules/uvr/music_separator.py
index 5fc353e4..5c14d88f 100644
--- a/modules/uvr/music_separator.py
+++ b/modules/uvr/music_separator.py
@@ -20,7 +20,7 @@ def __init__(self,
                  output_dir: Optional[str] = UVR_OUTPUT_DIR):
         self.model = None
         self.device = self.get_device()
-        self.available_devices = ["cpu", "cuda"]
+        self.available_devices = ["cpu", "xpu"]
         self.model_dir = model_dir
         self.output_dir = output_dir
         instrumental_output_dir = os.path.join(self.output_dir, "instrumental")
@@ -159,15 +159,15 @@ def separate_files(self,
     @staticmethod
     def get_device():
         """Get device for the model"""
-        return "cuda" if torch.cuda.is_available() else "cpu"
+        return "xpu" if torch.xpu.is_available() else "cpu"

     def offload(self):
         """Offload the model and free up the memory"""
         if self.model is not None:
             del self.model
             self.model = None
-        if self.device == "cuda":
-            torch.cuda.empty_cache()
+        if self.device == "xpu":
+            torch.xpu.empty_cache()
         gc.collect()
         self.audio_info = None
diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py
index c2012ff2..dcd8b0ee 100644
--- a/modules/whisper/base_transcription_pipeline.py
+++ b/modules/whisper/base_transcription_pipeline.py
@@ -265,7 +265,7 @@ def transcribe_file(self,
         except Exception as e:
             raise RuntimeError(f"Error transcribing file: {e}") from e
         finally:
-            self.release_cuda_memory()
+            self.release_xpu_memory()

     def transcribe_mic(self,
                        mic_audio: str,
@@ -328,7 +328,7 @@ def transcribe_mic(self,
         except Exception as e:
             raise RuntimeError(f"Error transcribing mic: {e}") from e
         finally:
-            self.release_cuda_memory()
+            self.release_xpu_memory()

     def transcribe_youtube(self,
                            youtube_link: str,
@@ -400,7 +400,7 @@ def transcribe_youtube(self,
         except Exception as e:
             raise RuntimeError(f"Error transcribing youtube: {e}") from e
         finally:
-            self.release_cuda_memory()
+            self.release_xpu_memory()

     def get_compute_type(self):
         if "float16" in self.available_compute_types:
@@ -421,8 +421,8 @@ def offload(self):
         if self.model is not None:
             del self.model
             self.model = None
-        if self.device == "cuda":
-            self.release_cuda_memory()
+        if self.device == "xpu":
+            self.release_xpu_memory()
         gc.collect()

     @staticmethod
@@ -454,8 +454,8 @@ def format_time(elapsed_time: float) -> str:
     @staticmethod
     def get_device():
-        if torch.cuda.is_available():
-            return "cuda"
+        if torch.xpu.is_available():
+            return "xpu"
         elif torch.backends.mps.is_available():
             if not BaseTranscriptionPipeline.is_sparse_api_supported():
                 # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
@@ -482,11 +482,11 @@ def is_sparse_api_supported():
         return False

     @staticmethod
-    def release_cuda_memory():
+    def release_xpu_memory():
         """Release memory"""
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.reset_max_memory_allocated()
+        if torch.xpu.is_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()

     @staticmethod
     def remove_input_files(file_paths: List[str]):
diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py
index ad72ee33..d2e793fe 100644
--- a/modules/whisper/data_classes.py
+++ b/modules/whisper/data_classes.py
@@ -156,7 +156,7 @@ def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components
 class DiarizationParams(BaseParams):
     """Speaker diarization parameters"""
     is_diarize: bool = Field(default=False, description="Enable speaker diarization")
-    diarization_device: str = Field(default="cuda", description="Device to run Diarization model.")
+    diarization_device: str = Field(default="xpu", description="Device to run Diarization model.")
     hf_token: str = Field(
         default="",
         description="Hugging Face token for downloading diarization models"
@@ -174,7 +174,7 @@ def to_gradio_inputs(cls,
             ),
             gr.Dropdown(
                 label=_("Device"),
-                choices=["cpu", "cuda"] if available_devices is None else available_devices,
+                choices=["cpu", "xpu"] if available_devices is None else available_devices,
                 value=defaults.get("device", device),
             ),
             gr.Textbox(
@@ -192,7 +192,7 @@ class BGMSeparationParams(BaseParams):
         default="UVR-MDX-NET-Inst_HQ_4",
         description="UVR model size"
     )
-    uvr_device: str = Field(default="cuda", description="Device to run UVR model.")
+    uvr_device: str = Field(default="xpu", description="Device to run UVR model.")
     segment_size: int = Field(
         default=256,
         gt=0,
@@ -228,7 +228,7 @@ def to_gradio_input(cls,
             ),
             gr.Dropdown(
                 label=_("Device"),
-                choices=["cpu", "cuda"] if available_devices is None else available_devices,
+                choices=["cpu", "xpu"] if available_devices is None else available_devices,
                 value=defaults.get("device", device),
             ),
             gr.Number(
diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py
index a41d799a..672ca157 100644
--- a/modules/whisper/faster_whisper_inference.py
+++ b/modules/whisper/faster_whisper_inference.py
@@ -183,8 +183,8 @@ def get_model_paths(self):

     @staticmethod
     def get_device():
-        if torch.cuda.is_available():
-            return "cuda"
+        if torch.xpu.is_available():
+            return "xpu"
         else:
             return "auto"

diff --git a/requirements.txt b/requirements.txt
index feef81a1..b0c883a6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,10 +5,9 @@
 --extra-index-url https://download.pytorch.org/whl/cu124
-torch
-torchaudio
+
 git+https://github.com/jhj0517/jhj0517-whisper.git
-faster-whisper==1.1.1
+#faster-whisper==1.1.1
 transformers
 gradio
 gradio-i18n
diff --git a/tests/test_bgm_separation.py b/tests/test_bgm_separation.py
index 95b77a0b..f8199780 100644
--- a/tests/test_bgm_separation.py
+++ b/tests/test_bgm_separation.py
@@ -11,7 +11,7 @@

 @pytest.mark.skipif(
-    not is_cuda_available(),
+    not is_xpu_available(),
     reason="Skipping because the test only works on GPU"
 )
 @pytest.mark.parametrize(
@@ -32,7 +32,7 @@ def test_bgm_separation_pipeline(

 @pytest.mark.skipif(
-    not is_cuda_available(),
+    not is_xpu_available(),
     reason="Skipping because the test only works on GPU"
 )
 @pytest.mark.parametrize(
diff --git a/tests/test_config.py b/tests/test_config.py
index 5997cfa6..266f3ad0 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -21,8 +21,8 @@

 @functools.lru_cache
-def is_cuda_available():
-    return torch.cuda.is_available()
+def is_xpu_available():
+    return torch.xpu.is_available()


 @functools.lru_cache
diff --git a/tests/test_diarization.py b/tests/test_diarization.py
index f18a2633..4a19d558 100644
--- a/tests/test_diarization.py
+++ b/tests/test_diarization.py
@@ -10,7 +10,7 @@

 @pytest.mark.skipif(
-    not is_cuda_available(),
+    not is_xpu_available(),
     reason="Skipping because the test only works on GPU"
 )
 @pytest.mark.parametrize(
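
The standalone sketch below mirrors the XPU-aware device selection and memory-release pattern this patch applies across diarizer.py, translation_base.py, music_separator.py, and base_transcription_pipeline.py. It is a minimal sketch, not part of the patch: it assumes a PyTorch build with Intel XPU support (torch.xpu), and the hasattr guards and the reset_peak_memory_stats() call are conservative stand-ins added here, since torch.xpu does not necessarily expose every torch.cuda memory helper.

# Minimal sketch (not part of the patch) of the XPU-aware helpers introduced above.
# Assumptions: PyTorch >= 2.5 with Intel XPU support; helper names follow the
# patched modules, and the hasattr guards are an extra precaution only.
import gc

import torch


def get_device() -> str:
    """Prefer Intel XPU, then Apple MPS, then CPU, as in the patched get_device()."""
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def release_xpu_memory() -> None:
    """Free cached XPU memory after offloading a model, as in the patched helpers."""
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
        # reset_peak_memory_stats() stands in for the reset call used in the patch;
        # it is only invoked when the running PyTorch build provides it.
        if hasattr(torch.xpu, "reset_peak_memory_stats"):
            torch.xpu.reset_peak_memory_stats()
    gc.collect()


if __name__ == "__main__":
    print(f"Selected device: {get_device()}")
    release_xpu_memory()

Running this on the target machine before executing the updated tests confirms that torch.xpu.is_available() returns True and that the cache-clearing calls used by the offload() methods resolve.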