From c29e38ec49b380a1f353d491c19c7e321be2746a Mon Sep 17 00:00:00 2001 From: Quinn Damerell Date: Mon, 10 Feb 2025 21:11:42 -0800 Subject: [PATCH] work --- .../homeway_linuxhost/sage/fibermanager.py | 78 +++++++++++++------ homeway/homeway_linuxhost/sage/sagehandler.py | 27 +++++-- .../sage/sagetranscribehandler.py | 44 ++++++----- 3 files changed, 102 insertions(+), 47 deletions(-) diff --git a/homeway/homeway_linuxhost/sage/fibermanager.py b/homeway/homeway_linuxhost/sage/fibermanager.py index a1e319c..6b201b4 100644 --- a/homeway/homeway_linuxhost/sage/fibermanager.py +++ b/homeway/homeway_linuxhost/sage/fibermanager.py @@ -3,7 +3,7 @@ import struct import logging import threading -from typing import List, Dict +from typing import List, Dict, Optional import octoflatbuffers from homeway.sentry import Sentry @@ -81,12 +81,16 @@ def createDataContextOffset(builder:octoflatbuffers.Builder) -> int: class ResponseContext: Text:str = None StatusCode:int = None + ErrorMessage:str = None response:ResponseContext = ResponseContext() - async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool): + async def onDataStreamReceived(statusCode:int, errorMessage:Optional[str], data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool): # For listen, this should only be called once if response.StatusCode is not None: raise Exception("Sage Listen onDataStreamReceived called more than once.") + # Set the error string, to either a string or None + response.ErrorMessage = errorMessage + # Check for a failure, which can happen at anytime. # If we have anything but a 200, stop processing now. response.StatusCode = statusCode @@ -107,8 +111,12 @@ async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageD # Do the operation, stream or wait for the response. result = await self._SendAndReceive(SageOperationTypes.Listen, audio, createDataContextOffset, onDataStreamReceived, isTransmissionDone) - # If the status code is set at any time and not 200, we failed, regardless of the mode. - # Check this before the function result, so we get the status code. + # If the error string or status code is set at any time, we failed, regardless of the mode. + # Check these before anything else, to ensure they get handled. + if response.ErrorMessage is not None: + # Return the raw string, since it's intended to send to the user. + return ListenResult.Failure(response.ErrorMessage) + if response.StatusCode is not None and response.StatusCode != 200: # In some special cases, we want to map the status code to a user message. errorStr = self._MapErrorStatusCodeToUserStr(response.StatusCode) @@ -148,8 +156,11 @@ def createDataContextOffset(builder:octoflatbuffers.Builder) -> int: # to stream back. class ResponseContext: StatusCode:int = None + ErrorMessage:str = None response:ResponseContext = ResponseContext() - async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool): + async def onDataStreamReceived(statusCode:int, errorMessage:Optional[str], data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool): + # Set the error string, to either a string or None + response.ErrorMessage = errorMessage # Check for a failure, which can happen at anytime. # If we have anything but a 200, stop processing now. @@ -167,8 +178,10 @@ async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageD requestBytes = json.dumps(request).encode("utf-8") result = await self._SendAndReceive(SageOperationTypes.Speak, requestBytes, createDataContextOffset, onDataStreamReceived) - # If the status code is set at any time and not 200, we failed, regardless of the mode. - # Check this before the function result, so we get the status code. + # Check if either of the error systems are set. + if response.ErrorMessage is not None: + self.Logger.error(f"Sage Speak failed with error msg: {response.ErrorMessage}") + return False if response.StatusCode is not None and response.StatusCode != 200: self.Logger.error(f"Sage Speak failed with status code {response.StatusCode}") return False @@ -182,7 +195,7 @@ async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageD # Takes a chat json object and returns the assistant's response text. # Returns None on failure. - async def Chat(self, requestJson:str, homeContext_CanBeNone:CompressionResult, states_CanBeNone:CompressionResult) -> bytearray: + async def Chat(self, requestJson:str, homeContext_CanBeNone:CompressionResult, states_CanBeNone:CompressionResult) -> str: # Our data type is a json string. def createDataContextOffset(builder:octoflatbuffers.Builder) -> int: @@ -191,13 +204,17 @@ def createDataContextOffset(builder:octoflatbuffers.Builder) -> int: # We expect the onDataStreamReceived handler to be called once, with the full response. class ResponseContext: Bytes = None - StatusCode = None + StatusCode:int = None + ErrorMessage:str = None response:ResponseContext = ResponseContext() - async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool): + async def onDataStreamReceived(statusCode:int, errorMessage:Optional[str], data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool): # For Chat, this should only be called once if response.StatusCode is not None: raise Exception("Sage Chat onDataStreamReceived called more than once.") + # Set the error string, to either a string or None + response.ErrorMessage = errorMessage + # Check for a failure, which can happen at anytime. # If we have anything but a 200, stop processing now. response.StatusCode = statusCode @@ -220,13 +237,18 @@ async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageD # Do the operation, wait for the result. result = await self._SendAndReceive(SageOperationTypes.Chat, data, createDataContextOffset, onDataStreamReceived, True) - # If the status code is set at any time and not 200, we failed. - # Check this before the function result, so we get the status code. + # If the error string or status code is set at any time, we failed, regardless of the mode. + # Check these before anything else, to ensure they get handled. + # Note, the server might also return a valid response but override the response text with the error. + if response.ErrorMessage is not None: + # Return the raw string, since it's intended to send to the user. + return response.ErrorMessage + if response.StatusCode is not None and response.StatusCode != 200: # In some special cases, we want to map the status code to a user message. - userError = self._MapErrorStatusCodeToUserStr(response.StatusCode) - if userError is not None: - return userError + errorStr = self._MapErrorStatusCodeToUserStr(response.StatusCode) + if errorStr is not None: + return errorStr self.Logger.error(f"Sage Chat failed with status code {response.StatusCode}") return None @@ -316,7 +338,8 @@ async def _SendAndReceive(self, # If we are doing the upload stream, check that the stream wasn't aborted from the server side before returning. if context.StatusCode is not None or context.IsAborted: # If the stream was aborted, fire the callback and return False - await onDataStreamReceivedCallbackAsync(context.StatusCode, bytearray(), None, True) + # ErrorMessage is either None or a string to send to the user. + await onDataStreamReceivedCallbackAsync(context.StatusCode, context.ErrorMessage, bytearray(), None, True) return False # Otherwise, keep the context around and return success. @@ -333,6 +356,7 @@ async def _SendAndReceive(self, data:List[bytearray] = None dataContext:SageDataContext = None statusCode:int = None + errorMessage:str = None isDataDownloadComplete:bool = True with self.StreamContextLock: # Grab all of the data we currently have and process it. @@ -341,21 +365,22 @@ async def _SendAndReceive(self, isDataDownloadComplete = context.IsDataDownloadComplete dataContext = context.DataContext statusCode = context.StatusCode + errorMessage = context.ErrorMessage # Clear the event under lock to ensure we don't miss a set. context.Event.clear() # Check for a stream abort, before we check if there's no data. if context.IsAborted: - self.Logger.error(f"Sage message stream was aborted: {context.StatusCode}") - await onDataStreamReceivedCallbackAsync(context.StatusCode, bytearray(), None, True) + self.Logger.error(f"Sage message stream was aborted: {statusCode}") + await onDataStreamReceivedCallbackAsync(statusCode, errorMessage, bytearray(), None, True) return False # Regardless of the other vars, if we didn't get any data, the response wait timed out. if len(data) == 0: self.Logger.error("Sage message timed out while waiting for a response.") context.StatusCode = 608 - await onDataStreamReceivedCallbackAsync(context.StatusCode, bytearray(), None, True) + await onDataStreamReceivedCallbackAsync(statusCode, errorMessage, bytearray(), None, True) return False # Process the data @@ -368,7 +393,7 @@ async def _SendAndReceive(self, for d in data: count += 1 isLastChunk = isDataDownloadComplete and count == len(data) - if await onDataStreamReceivedCallbackAsync(statusCode, d, dataContext, isLastChunk) is False: + if await onDataStreamReceivedCallbackAsync(statusCode, errorMessage, d, dataContext, isLastChunk) is False: return False # If we processed all the data and the stream is done, we're done. @@ -530,6 +555,10 @@ def OnIncomingMessage(self, buf:bytearray): self.Logger.debug(f"Sage got a message for stream [{streamId}] but couldn't find the context.") return + # If there's an error string, always set or update it. + if msg.ErrorMessage() is not None: + context.ErrorMessage = msg.ErrorMessage().decode("utf-8") + # The abort message doesn't require the other checks. if msg.IsAbortMsg(): # The abort message should always have a non-success status code. @@ -575,7 +604,7 @@ def OnIncomingMessage(self, buf:bytearray): raise Exception("Sage Fiber message has is missing the data context.") context.DataContext = dataContext - # Append the data. + # Append the data. context.Data.append(data) # Set the event so the caller can consume the data. @@ -636,8 +665,13 @@ def __init__(self, streamId:int, requestType:SageOperationTypes): # Note this will be set multiple times if the response is being streamed. self.Event = threading.Event() - # Set when the first response is received. + # Usually set to 200 for success otherwise it can be an error. self.StatusCode:int = None + + # Optional - If set, this is an error message that should be shown to the user. + self.ErrorMessage:str = None + + # Set when the first response is received. self.DataContext:SageDataContext = None # When set, the data has been fully downloaded. diff --git a/homeway/homeway_linuxhost/sage/sagehandler.py b/homeway/homeway_linuxhost/sage/sagehandler.py index a6d8df9..8538373 100644 --- a/homeway/homeway_linuxhost/sage/sagehandler.py +++ b/homeway/homeway_linuxhost/sage/sagehandler.py @@ -29,8 +29,10 @@ class SageHandler(AsyncEventHandler): # This is more of a sanity check, the server will enforce it's own limits. c_MaxStringLength = 500 - # A quick static cache of the info event, since it's gets called in rapid succession sometimes. - c_InfoEventMaxCacheTimeSec = 15.0 + # Cache the describe info because it doesn't change often and it's sometimes requested many times rapidly. + # Sometimes the describe event is called before an action for some reason and is blocking, so the cache helps there as well. + # Since this hardly ever changes, there's no reason to update it often. It will update when the addon restarts. + c_InfoEventMaxCacheTimeSec = 60 * 60 * 24 * 3 s_InfoEvent:Info = None s_InfoEventTime:float = 0.0 s_InfoEventLock = threading.Lock() @@ -238,19 +240,30 @@ async def _HandleDescribe(self) -> bool: attempt += 1 # Attempt getting a valid response. + serviceInfo:Info = None try: self.Logger.debug(f"Sage - Starting Info Service Request - {url}") + # Attempt to get and parse the info object. response = requests.get(url, timeout=10) if response.status_code == 200: - info = self._BuildInfoEvent(response.json()) - if info is None: + serviceInfo = self._BuildInfoEvent(response.json()) + if serviceInfo is None: raise Exception("Failed to build info event.") - await self._CacheAndWriteInfoEvent(info) - return True - self.Logger.warning(f"Sage - Failed to get models from service. Attempt: {attempt} - {response.status_code}") + else: + self.Logger.warning(f"Sage - Failed to get models from service. Attempt: {attempt} - {response.status_code}") except Exception as e: self.Logger.warning(f"Sage - Failed to get models from service. Attempt: {attempt} - {e}") + # If we got service info, try to write it to the client. + if serviceInfo is not None: + self.Logger.debug("Sage - Service request successful, sending to Wyoming protocol.") + try: + await self._CacheAndWriteInfoEvent(serviceInfo) + # Success + return True + except Exception as e: + self.Logger.warning(f"Sage - Failed to send info to wyoming protocol. Attempt: {attempt} - {e}") + # If we fail, try a few times. Throw when we hit the limit. if attempt > 3: raise Exception("Failed to get models from service after 3 attempts.") diff --git a/homeway/homeway_linuxhost/sage/sagetranscribehandler.py b/homeway/homeway_linuxhost/sage/sagetranscribehandler.py index 691f3c1..b47e19c 100644 --- a/homeway/homeway_linuxhost/sage/sagetranscribehandler.py +++ b/homeway/homeway_linuxhost/sage/sagetranscribehandler.py @@ -38,9 +38,9 @@ def __init__(self, logger:logging.Logger, sageHandler, fiberManager:FiberManager raise ValueError("SageTranscribeHandler must be created with a AudioStart event.") self.IncomingAudioStartEvent:AudioStart = AudioStart.from_event(startAudioEvent) - # If set to true, this stream has failed, and all future requests should be ignored - # until the next audio stream is started. - self.HasErrored:bool = False + # If this is not None, we have hit an error and can ignore the rest of this listen stream. + # But due to the way the protocol works, we still need to handle the rest of the stream and send the error at the end as the result. + self.ErrorMessage:str = None # This holds the incoming audio buffer that is being streamed to the server. # If it's none, that means we haven't gotten any chunks since the last start. @@ -54,10 +54,10 @@ def __init__(self, logger:logging.Logger, sageHandler, fiberManager:FiberManager # Logs and error and writes an error message back to the client. - async def _WriteError(self, text:str, code:str=None) -> None: + def _SetError(self, text:str) -> None: # If we ever write an error back, set this boolean so we stop handing audio for this stream request. - self.HasErrored = True - await self.SageHandler.WriteError(text, code) + self.Logger.warning(f"Sage Listen Failed - {text}") + self.ErrorMessage = text # Helper for writing an event back to the client. @@ -67,13 +67,14 @@ async def _WriteEvent(self, event:Event) -> None: # Handles all streaming audio for speech to text. async def HandleStreamingAudio(self, event: Event) -> bool: - # If this is set, we failed to handle this stream some time in the past. - # We should ignore all future requests for this stream. - if self.HasErrored: - return True # Called when audio is being streamed to the server. if AudioChunk.is_type(event.type): + + # If we have errored and this is more stream, we just ignore it until the stream ends. + if self.ErrorMessage is not None: + return True + e = AudioChunk.from_event(event) if e.audio is None or len(e.audio) == 0: # This would be ok, if it ever happened. We have logic that will detect if we never got any audio. @@ -86,7 +87,7 @@ async def HandleStreamingAudio(self, event: Event) -> bool: else: streamTimeSec = time.time() - self.AudioStreamStartTimeSec if streamTimeSec > SageTranscribeHandler.c_MaxStreamTimeSec: - await self._WriteError(f"Homeway Sage Hit The Audio Stream Time Limit Of {int(SageTranscribeHandler.c_MaxStreamTimeSec)}s") + self._SetError(f"You hit the max speech-to-text time limit of {int(SageTranscribeHandler.c_MaxStreamTimeSec)} seconds.") return True # If this is the start of a new buffer, create the buffer now and start the timer. @@ -110,14 +111,13 @@ async def HandleStreamingAudio(self, event: Event) -> bool: # This should never happen. if result is None: - await self._WriteError("Homeway Sage Audio Stream Failed - No Result") + self._SetError("Homeway Sage Audio Stream Failed - No Result") return True # If the error text is set, we failed to send the audio. # We try to send the error string, since it might help the user. if result.Error is not None: - await self._WriteEvent(Transcript(text=result.Error).event()) - await self._WriteError("Homeway Sage Audio Stream Failed - " + result.Error) + self._SetError(result.Error) return True # Ensure the operation didn't take too long. @@ -138,6 +138,12 @@ async def HandleStreamingAudio(self, event: Event) -> bool: await self._WriteEvent(Transcript("").event()) return True + # If we have an error that ended the stream early, return it now. + # Note we can't write this as any kind of error or it gets lost, so we write it as a result. + if self.ErrorMessage is not None: + await self._WriteEvent(Transcript(self.ErrorMessage).event()) + return True + # Send the final audio chunk indicating that the audio stream is done. # This will now block and wait for a response. # Note that this incoming audio buffer can be empty if we don't have any buffered audio, which is fine. @@ -146,19 +152,21 @@ async def HandleStreamingAudio(self, event: Event) -> bool: # This should never happen. if result is None: - await self._WriteError("Homeway Sage Listen - No Result Returned.") + self._SetError("Homeway Sage Listen - No Result Returned.") + await self._WriteEvent(Transcript(self.ErrorMessage).event()) return True # If the error text is set, we failed to send the audio. # We try to send the error string, since it might help the user. if result.Error is not None: - await self._WriteEvent(Transcript(text=result.Error).event()) - await self._WriteError("Homeway Sage Audio Stream Failed - " + result.Error) + self._SetError(result.Error) + await self._WriteEvent(Transcript(self.ErrorMessage).event()) return True # This shouldn't happen. if result.Result is None: - await self._WriteError("Homeway Sage Listen - No Result Returned.") + self._SetError("Homeway Sage Listen - No Result Returned") + await self._WriteEvent(Transcript(self.ErrorMessage).event()) return True # Send the text back to the client.