diff --git a/cli_demo.py b/cli_demo.py
index b1415b5..041c975 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation.utils import GenerationConfig
+import torch.mps
 
 
 def init_model():
@@ -69,17 +70,15 @@ def main(stream=True):
             for response in model.chat(tokenizer, messages, stream=True):
                 print(response[position:], end='', flush=True)
                 position = len(response)
-                if torch.backends.mps.is_available():
-                    torch.mps.empty_cache()
         except KeyboardInterrupt:
             pass
         print()
     else:
         response = model.chat(tokenizer, messages)
         print(response)
-        if torch.backends.mps.is_available():
-            torch.mps.empty_cache()
     messages.append({"role": "assistant", "content": response})
+    if torch.backends.mps.is_available():
+        torch.mps.empty_cache()
     print(Style.RESET_ALL)