From 4992302a7e4c1863a95a29ffee23ea31301a5df5 Mon Sep 17 00:00:00 2001
From: waterstone <27319794@qq.com>
Date: Thu, 7 Sep 2023 12:16:22 +0800
Subject: [PATCH] fix for M2 Ultra

---
 cli_demo.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/cli_demo.py b/cli_demo.py
index b1415b5..041c975 100644
--- a/cli_demo.py
+++ b/cli_demo.py
@@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation.utils import GenerationConfig
+import torch.mps
 
 
 def init_model():
@@ -69,17 +70,15 @@ def main(stream=True):
                 for response in model.chat(tokenizer, messages, stream=True):
                     print(response[position:], end='', flush=True)
                     position = len(response)
-                    if torch.backends.mps.is_available():
-                        torch.mps.empty_cache()
             except KeyboardInterrupt:
                 pass
             print()
         else:
             response = model.chat(tokenizer, messages)
             print(response)
-            if torch.backends.mps.is_available():
-                torch.mps.empty_cache()
         messages.append({"role": "assistant", "content": response})
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()
         print(Style.RESET_ALL)
 
 
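A minimal sketch (not part of the patch) of the per-turn flow the diff converges on, assuming the model.chat/tokenizer interface used by cli_demo.py: the MPS cache is released once per chat turn, after the assistant message is appended, instead of separately inside the streaming and non-streaming branches.

# Sketch only: illustrates the pattern the patch applies in cli_demo.py's main loop.
# `model`, `tokenizer`, and `messages` follow the cli_demo.py usage and are assumed here.
import torch
import torch.mps  # explicit import, as in the patch, so torch.mps is always bound

def run_turn(model, tokenizer, messages):
    response = model.chat(tokenizer, messages)
    print(response)
    messages.append({"role": "assistant", "content": response})
    # Release cached MPS memory once per turn on Apple silicon (e.g. M2 Ultra);
    # skipped entirely on machines without an MPS device.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    return messages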