Skip to content

Commit

Permalink
custom models, text file as input and cache results
Browse files Browse the repository at this point in the history
  • Loading branch information
daswer123 committed Dec 21, 2023
1 parent 2576f3e commit fc033f1
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 35 deletions.
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ docker compose up # or with -d to run in background
Use the `--deepspeed` flag to process the result fast ( 2-3x acceleration )

```
usage: xtts_api_server [-h] [-hs HOST] [-p PORT] [-sf SPEAKER_FOLDER] [-o OUTPUT] [-t TUNNEL_URL] [-ms MODEL_SOURCE] [--lowvram] [--deepspeed] [--streaming-mode] [--stream-play-sync]
usage: xtts_api_server [-h] [-hs HOST] [-p PORT] [-sf SPEAKER_FOLDER] [-o OUTPUT] [-t TUNNEL_URL] [-ms MODEL_SOURCE] [--use-cache] [--lowvram] [--deepspeed] [--streaming-mode] [--stream-play-sync]
Run XTTSv2 within a FastAPI application
Expand All @@ -85,14 +85,19 @@ options:
-o OUTPUT, --output Output folder
-t TUNNEL_URL, --tunnel URL of tunnel used (e.g: ngrok, localtunnel)
-ms MODEL_SOURCE, --model-source ["api","apiManual","local"]
-v MODEL_VERSION, --version You can choose any version of the model, keep in mind that if you choose model-source api, only the latest version will be loaded
-v MODEL_VERSION, --version You can download the official model or use your own. Official versions can be found [here](https://huggingface.co/coqui/XTTS-v2/tree/main); the model version name is the same as the branch name (v2.0.2, v2.0.3, main, etc.)
--use-cache Enables caching of results: generated audio is saved, and a repeated request returns the cached file instead of generating it again
--lowvram The mode in which the model will be stored in RAM and when the processing will move to VRAM, the difference in speed is small
--deepspeed allows you to speed up processing by several times, automatically downloads the necessary libraries
--streaming-mode Enables streaming mode, currently has certain limitations, as described below.
--streaming-mode-improve Enables streaming mode, includes an improved streaming mode that consumes 2gb more VRAM and uses a better tokenizer and more context.
--stream-play-sync Additional flag for streaming mode that allows you to play all audio one at a time without interruption
```

You can pass the path to a text file as the input text; the file will be read and its contents will be voiced.

You can load your own model: create a folder inside `models` and place the model there with its configs. Note that the folder must contain 3 files: `config.json`, `vocab.json` and `model.pth`.

If you want your host to listen, use -hs 0.0.0.0

The -t or --tunnel flag is needed so that when you get speakers via get you get the correct link to hear the preview. More info [here](https://imgur.com/a/MvpFT59)
Expand All @@ -103,7 +108,7 @@ Model-source defines in which format you want to use xtts:
2. `apiManual` - loads version 2.0.2 by default, but you can specify the version via the -v flag; the model is saved into the models folder and uses the `tts_to_file` function from the TTS api
3. `api` - will load the latest version of the model. The -v flag won't work.

All versions of the XTTSv2 model can be found [here](https://huggingface.co/coqui/XTTS-v2/tree/v2.0.2) in the branches
All versions of the XTTSv2 model can be found [here](https://huggingface.co/coqui/XTTS-v2/tree/main); the model version name is the same as the branch name (v2.0.2, v2.0.3, main, etc.)

The first time you run or generate, you may need to confirm that you agree to use XTTS.

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "xtts-api-server"
version = "0.6.8"
version = "0.7.0"
authors = [
{ name="daswer123", email="daswerq123@gmail.com" },
]
Expand Down
27 changes: 19 additions & 8 deletions xtts_api_server/RealtimeTTS/engines/coqui_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from xtts_api_server.tts_funcs import official_model_list
from torch.multiprocessing import Process, Pipe, Event, set_start_method
from .base_engine import BaseEngine
from typing import Union, List
Expand Down Expand Up @@ -92,7 +93,16 @@ def __init__(self,
ModelManager().download_model(model_name)
else:
logging.info(f"Local XTTS Model: \"{specific_model}\" specified")
self.local_model_path = self.download_model(specific_model, local_models_path)
is_official_model = False
for model in official_model_list:
if self.specific_model == model:
is_official_model = True
break

if is_official_model:
self.local_model_path = self.download_model(specific_model, local_models_path)
else:
self.local_model_path = os.path.join(local_models_path,specific_model)

self.synthesize_process = Process(target=CoquiEngine._synthesize_worker, args=(child_synthesize_pipe, model_name, cloning_reference_wav, language, self.main_synthesize_ready_event, level, self.speed, thread_count, stream_chunk_size, full_sentences, overlap_wav_len, temperature, length_penalty, repetition_penalty, top_k, top_p, enable_text_splitting, use_mps, self.local_model_path, use_deepspeed, self.voices_path))
self.synthesize_process.start()
Expand Down Expand Up @@ -540,28 +550,29 @@ def download_file(url, destination):
progress_bar.close()

@staticmethod
def download_model(model_name = "2.0.2", local_models_path = None):
def download_model(model_name = "v2.0.2", local_models_path = None):

# Creating a unique folder for each model version
if local_models_path and len(local_models_path) > 0:
model_folder = os.path.join(local_models_path, f'v{model_name}')
model_folder = os.path.join(local_models_path, f'{model_name}')
logging.info(f"Local models path: \"{model_folder}\"")
else:
model_folder = os.path.join(os.getcwd(), 'models', f'v{model_name}')
model_folder = os.path.join(os.getcwd(), 'models', f'{model_name}')
logging.info(f"Checking for models within application directory: \"{model_folder}\"")

os.makedirs(model_folder, exist_ok=True)
print(model_name)

files = {
"config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/v{model_name}/config.json",
"model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/v{model_name}/model.pth?download=true",
"vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/v{model_name}/vocab.json"
"config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_name}/config.json",
"model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/{model_name}/model.pth?download=true",
"vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_name}/vocab.json"
}

for file_name, url in files.items():
file_path = os.path.join(model_folder, file_name)
if not os.path.exists(file_path):
logger.info(f"Downloading {file_name} for Model v{model_name}...")
logger.info(f"Downloading {file_name} for Model {model_name}...")
CoquiEngine.download_file(url, file_path)
# r = requests.get(url, allow_redirects=True)
# with open(file_path, 'wb') as f:
Expand Down
4 changes: 3 additions & 1 deletion xtts_api_server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
parser.add_argument("-t", "--tunnel", default="", type=str, help="URL of tunnel used (e.g: ngrok, localtunnel)")
parser.add_argument("-ms", "--model-source", default="local", choices=["api","apiManual", "local"],
help="Define the model source: 'api' for latest version from repository, apiManual for 2.0.2 model and api inference or 'local' for using local inference and model v2.0.2.")
parser.add_argument("-v", "--version", default="2.0.2", type=str, help="You can specify which version of xtts to use,This version will be used everywhere in local, api and apiManual.")
parser.add_argument("-v", "--version", default="v2.0.2", type=str, help="You can specify which version of xtts to use,This version will be used everywhere in local, api and apiManual.")
parser.add_argument("--lowvram", action='store_true', help="Enable low vram mode which switches the model to RAM when not actively processing.")
parser.add_argument("--deepspeed", action='store_true', help="Enables deepspeed mode, speeds up processing by several times.")
parser.add_argument("--use-cache", action='store_true', help="Enables caching of results, your results will be saved and if there will be a repeated request, you will get a file instead of generation.")
parser.add_argument("--streaming-mode", action='store_true', help="Enables streaming mode, currently needs a lot of work.")
parser.add_argument("--streaming-mode-improve", action='store_true', help="Includes an improved streaming mode that consumes 2gb more VRAM and uses a better tokenizer, good for languages such as Chinese")
parser.add_argument("--stream-play-sync", action='store_true', help="Additional flag for streaming mod that allows you to play all audio one at a time without interruption")
Expand All @@ -27,6 +28,7 @@
os.environ['TUNNEL_URL'] = args.tunnel # it is necessary to correctly return correct previews in list of speakers
os.environ['MODEL_SOURCE'] = args.model_source # Set environment variable for the model source
os.environ["MODEL_VERSION"] = args.version # Specify version of XTTS model
os.environ["USE_CACHE"] = str(args.use_cache).lower() # Set lowvram mode
os.environ["DEEPSPEED"] = str(args.deepspeed).lower() # Set lowvram mode
os.environ["LOWVRAM_MODE"] = str(args.lowvram).lower() # Set lowvram mode
os.environ["STREAM_MODE"] = str(args.streaming_mode).lower() # Enable Streaming mode
Expand Down
8 changes: 4 additions & 4 deletions xtts_api_server/modeldownloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def check_stream2sentence_version():
def download_model(this_dir,model_version):
# Define paths
base_path = this_dir / 'models'
model_path = base_path / f'v{model_version}'
model_path = base_path / f'{model_version}'

# Define files and their corresponding URLs
files_to_download = {
"config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/v{model_version}/config.json",
"model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/v{model_version}/model.pth?download=true",
"vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/v{model_version}/vocab.json"
"config.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_version}/config.json",
"model.pth": f"https://huggingface.co/coqui/XTTS-v2/resolve/{model_version}/model.pth?download=true",
"vocab.json": f"https://huggingface.co/coqui/XTTS-v2/raw/{model_version}/vocab.json"
}

# Check and create directories
Expand Down
16 changes: 9 additions & 7 deletions xtts_api_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
SPEAKER_FOLDER = os.getenv('SPEAKER', 'speakers')
BASE_URL = os.getenv('BASE_URL', '127.0.0.1:8020')
MODEL_SOURCE = os.getenv("MODEL_SOURCE", "local")
MODEL_VERSION = os.getenv("MODEL_VERSION","2.0.2")
MODEL_VERSION = os.getenv("MODEL_VERSION","v2.0.2")
LOWVRAM_MODE = os.getenv("LOWVRAM_MODE") == 'true'
DEEPSPEED = os.getenv("DEEPSPEED") == 'true'
USE_CACHE = os.getenv("USE_CACHE") == 'true'

# STREAMING VARS
STREAM_MODE = os.getenv("STREAM_MODE") == 'true'
STREAM_MODE_IMPROVE = os.getenv("STREAM_MODE_IMPROVE") == 'true'
Expand All @@ -41,17 +43,14 @@

# Create an instance of the TTSWrapper class and server
app = FastAPI()
XTTS = TTSWrapper(OUTPUT_FOLDER,SPEAKER_FOLDER,LOWVRAM_MODE,MODEL_SOURCE,MODEL_VERSION,DEVICE,DEEPSPEED)
XTTS = TTSWrapper(OUTPUT_FOLDER,SPEAKER_FOLDER,LOWVRAM_MODE,MODEL_SOURCE,MODEL_VERSION,DEVICE,DEEPSPEED,USE_CACHE)

# Create version string
version_string = ""
if MODEL_SOURCE == "api":
if MODEL_SOURCE == "api" or MODEL_VERSION == "main":
version_string = "lastest"
else:
version_string = "v"+MODEL_VERSION

if MODEL_SOURCE == "api" and MODEL_SOURCE != "2.0.2":
logger.warning("Attention you have specified flag -v but you have selected --model-source api, please change --model-souce to apiManual or local to use the specified version, otherwise the latest version of the model will be loaded.")
version_string = MODEL_VERSION

# Load model
# logger.info(f"The model {version_string} starts to load,wait until it loads")
Expand All @@ -74,6 +73,9 @@
else:
XTTS.load_model()

if USE_CACHE:
logger.info("You have enabled caching, this option enables caching of results, your results will be saved and if there is a repeat request, you will get a file instead of generation")

# Add CORS middleware
origins = ["*"]
app.add_middleware(
Expand Down
Loading

0 comments on commit fc033f1

Please sign in to comment.