import json
import logging
import uuid
import warnings
from abc import abstractmethod
from datetime import datetime
from functools import wraps
from threading import Lock
from typing import Callable, List, Optional, Tuple, Union
from fastapi import HTTPException
import memgpt.constants as constants
import memgpt.presets.presets as presets
import memgpt.server.utils as server_utils
import memgpt.system as system
from memgpt.agent import Agent, save_agent
from memgpt.agent_store.storage import StorageConnector, TableType
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.cli.cli_config import get_model_options
from memgpt.config import MemGPTConfig
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.credentials import MemGPTCredentials
from memgpt.data_sources.connectors import DataConnector, load_data
from memgpt.data_types import (
AgentState,
EmbeddingConfig,
LLMConfig,
Message,
Preset,
Source,
Token,
User,
)
# TODO use custom interface
from memgpt.interface import AgentInterface # abstract
from memgpt.interface import CLIInterface # for printing to terminal
from memgpt.metadata import MetadataStore
from memgpt.models.pydantic_models import (
DocumentModel,
PassageModel,
PresetModel,
SourceModel,
ToolModel,
)
from memgpt.settings import settings
logger = logging.getLogger(__name__)
class Server(object):
"""Abstract server class that supports multi-agent multi-user"""
@abstractmethod
def list_agents(self, user_id: uuid.UUID) -> dict:
"""List all available agents to a user"""
raise NotImplementedError
@abstractmethod
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
"""Paginated query of in-context messages in agent message queue"""
raise NotImplementedError
@abstractmethod
def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the memory of an agent (core memory + non-core statistics)"""
raise NotImplementedError
@abstractmethod
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the config of an agent"""
raise NotImplementedError
@abstractmethod
def get_server_config(self, user_id: uuid.UUID) -> dict:
"""Return the base config"""
raise NotImplementedError
@abstractmethod
def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict:
"""Update the agents core memory block, return the new state"""
raise NotImplementedError
@abstractmethod
def create_agent(
self,
user_id: uuid.UUID,
agent_config: Union[dict, AgentState],
interface: Union[AgentInterface, None],
# persistence_manager: Union[PersistenceManager, None],
) -> str:
"""Create a new agent using a config"""
raise NotImplementedError
@abstractmethod
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
"""Process a message from the user, internally calls step"""
raise NotImplementedError
@abstractmethod
def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
"""Process a message from the system, internally calls step"""
raise NotImplementedError
@abstractmethod
def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
"""Run a command on the agent, e.g. /memory
May return a string with a message generated by the command
"""
raise NotImplementedError
class LockingServer(Server):
"""Basic support for concurrency protections (all requests that modify an agent lock the agent until the operation is complete)"""
# Locks for each agent
_agent_locks = {}
@staticmethod
def agent_lock_decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(self, user_id: uuid.UUID, agent_id: uuid.UUID, *args, **kwargs):
# logger.info("Locking check")
# Initialize the lock for the agent_id if it doesn't exist
if agent_id not in self._agent_locks:
# logger.info(f"Creating lock for agent_id = {agent_id}")
self._agent_locks[agent_id] = Lock()
# Check if the agent is currently locked
if not self._agent_locks[agent_id].acquire(blocking=False):
# logger.info(f"agent_id = {agent_id} is busy")
raise HTTPException(status_code=423, detail=f"Agent '{agent_id}' is currently busy.")
try:
# Execute the function
# logger.info(f"running function on agent_id = {agent_id}")
print("USERID", user_id)
return func(self, user_id, agent_id, *args, **kwargs)
finally:
# Release the lock
# logger.info(f"releasing lock on agent_id = {agent_id}")
self._agent_locks[agent_id].release()
return wrapper
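    # Illustrative sketch only (not part of this module): how an HTTP client might
    # react to the 423 "busy" response raised by agent_lock_decorator above. The
    # endpoint path and retry budget are assumptions, not real routes from this repo.
    #
    #     import time
    #     import requests
    #
    #     def send_with_retry(base_url, agent_id, payload, retries=3):
    #         for _ in range(retries):
    #             resp = requests.post(f"{base_url}/agents/{agent_id}/messages", json=payload)
    #             if resp.status_code != 423:  # agent not locked, hand back the response
    #                 return resp
    #             time.sleep(1)  # agent is busy, back off and try again
    #         raise RuntimeError(f"Agent {agent_id} stayed busy after {retries} attempts")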
@agent_lock_decorator
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
raise NotImplementedError
@agent_lock_decorator
def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
raise NotImplementedError
class SyncServer(LockingServer):
"""Simple single-threaded / blocking server process"""
def __init__(
self,
chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
# default_interface_cls: AgentInterface = CLIInterface,
default_interface: AgentInterface = CLIInterface(),
# default_persistence_manager_cls: PersistenceManager = LocalStateManager,
# auth_mode: str = "none", # "none, "jwt", "external"
):
"""Server process holds in-memory agents that are being run"""
# Server supports several auth modes:
# "none":
# no authentication, trust the incoming requests to have access to the user_id being modified
# "jwt_local":
# clients send bearer JWT tokens, which decode to user_ids
# JWT tokens are generated by the server process (using pyJWT) and stored in a database table
# "jwt_external":
# clients still send bearer JWT tokens, but token generation and validation is handled by an external service
# ie the server process will call 'external.decode(token)' to get the user_id
# if auth_mode == "none":
# self.auth_mode = auth_mode
# raise NotImplementedError # TODO
# elif auth_mode == "jwt_local":
# self.auth_mode = auth_mode
# elif auth_mode == "jwt_external":
# self.auth_mode = auth_mode
# raise NotImplementedError # TODO
# else:
# raise ValueError(auth_mode)
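        # Illustrative sketch only: if the "jwt_local" mode were enabled, decoding a
        # bearer token with pyJWT might look like the snippet below. The secret source
        # and the "sub" claim name are assumptions; none of this is implemented here.
        #
        #     import jwt  # pyJWT
        #     payload = jwt.decode(token, secret_key, algorithms=["HS256"])
        #     user_id = uuid.UUID(payload["sub"])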
# List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
self.active_agents = []
# chaining = whether or not to run again if request_heartbeat=true
self.chaining = chaining
# if chaining == true, what's the max number of times we'll chain before yielding?
# none = no limit, can go on forever
self.max_chaining_steps = max_chaining_steps
# The default interface that will get assigned to agents ON LOAD
# self.default_interface_cls = default_interface_cls
self.default_interface = default_interface
# The default persistence manager that will get assigned to agents ON CREATION
# self.default_persistence_manager_cls = default_persistence_manager_cls
# Initialize the connection to the DB
self.config = MemGPTConfig.load()
print(f"server :: loading configuration from '{self.config.config_path}'")
assert self.config.persona is not None, "Persona must be set in the config"
assert self.config.human is not None, "Human must be set in the config"
# Update storage URI to match passed in settings
# TODO: very hack, fix in the future
for memory_type in ("archival", "recall", "metadata"):
if settings.memgpt_pg_uri:
# override with env
setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri)
self.config.save()
# TODO figure out how to handle credentials for the server
self.credentials = MemGPTCredentials.load()
# check credentials
# TODO: add checks for other providers
if (
self.config.default_embedding_config.embedding_endpoint_type == "openai"
or self.config.default_llm_config.model_endpoint_type == "openai"
):
assert self.credentials.openai_key is not None, "OpenAI key must be set in the credentials file"
# Ensure valid database configuration
# TODO: add back once tests are matched
# assert (
# self.config.metadata_storage_type == "postgres"
# ), f"Invalid metadata_storage_type for server: {self.config.metadata_storage_type}"
# assert (
# self.config.archival_storage_type == "postgres"
# ), f"Invalid archival_storage_type for server: {self.config.archival_storage_type}"
# assert self.config.recall_storage_type == "postgres", f"Invalid recall_storage_type for server: {self.config.recall_storage_type}"
# Generate default LLM/Embedding configs for the server
# TODO: we may also want to do the same thing with default persona/human/etc.
self.server_llm_config = LLMConfig(
model=self.config.default_llm_config.model,
model_endpoint_type=self.config.default_llm_config.model_endpoint_type,
model_endpoint=self.config.default_llm_config.model_endpoint,
model_wrapper=self.config.default_llm_config.model_wrapper,
context_window=self.config.default_llm_config.context_window,
# openai_key=self.credentials.openai_key,
# azure_key=self.credentials.azure_key,
# azure_endpoint=self.credentials.azure_endpoint,
# azure_version=self.credentials.azure_version,
# azure_deployment=self.credentials.azure_deployment,
)
self.server_embedding_config = EmbeddingConfig(
embedding_endpoint_type=self.config.default_embedding_config.embedding_endpoint_type,
embedding_endpoint=self.config.default_embedding_config.embedding_endpoint,
embedding_dim=self.config.default_embedding_config.embedding_dim,
embedding_model=self.config.default_embedding_config.embedding_model,
embedding_chunk_size=self.config.default_embedding_config.embedding_chunk_size,
)
assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)
# Initialize the metadata store
self.ms = MetadataStore(self.config)
# pre-fill database (users, presets, humans, personas)
# TODO: figure out how to handle default users (server is technically multi-user)
user_id = uuid.UUID(self.config.anon_clientid)
user = User(
id=uuid.UUID(self.config.anon_clientid),
)
if self.ms.get_user(user_id):
# update user
self.ms.update_user(user)
else:
self.ms.create_user(user)
presets.add_default_presets(user_id, self.ms)
# NOTE: removed, since server should be multi-user
## Create the default user
# base_user_id = uuid.UUID(self.config.anon_clientid)
# if not self.ms.get_user(user_id=base_user_id):
# base_user = User(id=base_user_id)
# self.ms.create_user(base_user)
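    # Minimal usage sketch (assumes a valid MemGPT config on disk and at least one
    # previously created agent; the lookups below are placeholders, not fixtures):
    #
    #     server = SyncServer()
    #     user_id = uuid.UUID(server.config.anon_clientid)
    #     agents = server.list_agents(user_id=user_id)
    #     if agents["num_agents"] > 0:
    #         agent_id = agents["agents"][0]["id"]
    #         server.user_message(user_id=user_id, agent_id=agent_id, message="hello")
    #         server.save_agents()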
def save_agents(self):
"""Saves all the agents that are in the in-memory object store"""
for agent_d in self.active_agents:
try:
# agent_d["agent"].save()
save_agent(agent_d["agent"], self.ms)
logger.info(f"Saved agent {agent_d['agent_id']}")
except Exception as e:
logger.exception(f"Error occurred while trying to save agent {agent_d['agent_id']}:\n{e}")
def _get_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Union[Agent, None]:
"""Get the agent object from the in-memory object store"""
for d in self.active_agents:
if d["user_id"] == str(user_id) and d["agent_id"] == str(agent_id):
return d["agent"]
return None
def _add_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, agent_obj: Agent) -> None:
"""Put an agent object inside the in-memory object store"""
# Make sure the agent doesn't already exist
if self._get_agent(user_id=user_id, agent_id=agent_id) is not None:
            # Can be triggered by a concurrent request, so don't throw a full error
# raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is already loaded")
logger.exception(f"Agent (user={user_id}, agent={agent_id}) is already loaded")
return
# Add Agent instance to the in-memory list
self.active_agents.append(
{
"user_id": str(user_id),
"agent_id": str(agent_id),
"agent": agent_obj,
}
)
def _load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, interface: Union[AgentInterface, None] = None) -> Agent:
"""Loads a saved agent into memory (if it doesn't exist, throw an error)"""
assert isinstance(user_id, uuid.UUID), user_id
assert isinstance(agent_id, uuid.UUID), agent_id
# If an interface isn't specified, use the default
if interface is None:
interface = self.default_interface
try:
logger.info(f"Grabbing agent user_id={user_id} agent_id={agent_id} from database")
agent_state = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
if not agent_state:
logger.exception(f"agent_id {agent_id} does not exist")
raise ValueError(f"agent_id {agent_id} does not exist")
# print(f"server._load_agent :: load got agent state {agent_id}, messages = {agent_state.state['messages']}")
# Instantiate an agent object using the state retrieved
logger.info(f"Creating an agent object")
memgpt_agent = Agent(agent_state=agent_state, interface=interface)
# Add the agent to the in-memory store and return its reference
logger.info(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}")
self._add_agent(user_id=user_id, agent_id=agent_id, agent_obj=memgpt_agent)
return memgpt_agent
except Exception as e:
logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}")
raise
def _get_or_load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Agent:
"""Check if the agent is in-memory, then load"""
logger.info(f"Checking for agent user_id={user_id} agent_id={agent_id}")
memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
if not memgpt_agent:
logger.info(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
return memgpt_agent
def _step(self, user_id: uuid.UUID, agent_id: uuid.UUID, input_message: Union[str, Message]) -> int:
"""Send the input message through the agent"""
logger.debug(f"Got input message: {input_message}")
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
if memgpt_agent is None:
raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded")
logger.debug(f"Starting agent step")
no_verify = True
next_input_message = input_message
counter = 0
while True:
new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step(
next_input_message,
first_message=False,
skip_verify=no_verify,
return_dicts=False,
)
counter += 1
# Chain stops
if not self.chaining:
logger.debug("No chaining, stopping after one step")
break
elif self.max_chaining_steps is not None and counter > self.max_chaining_steps:
logger.debug(f"Hit max chaining steps, stopping after {counter} steps")
break
# Chain handlers
elif token_warning:
next_input_message = system.get_token_limit_warning()
continue # always chain
elif function_failed:
next_input_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
continue # always chain
elif heartbeat_request:
next_input_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
continue # always chain
# MemGPT no-op / yield
else:
break
memgpt_agent.interface.step_yield()
logger.debug(f"Finished agent step")
return tokens_accumulated
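    # Chaining summary for _step() above (derived directly from the loop):
    #   token_warning     -> inject system.get_token_limit_warning() and step again
    #   function_failed   -> inject system.get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE) and step again
    #   heartbeat_request -> inject system.get_heartbeat(REQ_HEARTBEAT_MESSAGE) and step again
    #   otherwise         -> stop; chaining is also capped by max_chaining_steps when set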
def _command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
"""Process a CLI command"""
logger.debug(f"Got command: {command}")
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
# print("AGENT", memgpt_agent.agent_state.id, memgpt_agent.agent_state.user_id)
if command.lower() == "exit":
# exit not supported on server.py
raise ValueError(command)
elif command.lower() == "save" or command.lower() == "savechat":
save_agent(memgpt_agent, self.ms)
elif command.lower() == "attach":
# Different from CLI, we extract the data source name from the command
command = command.strip().split()
            try:
                data_source = command[1]
            except IndexError:
                raise ValueError(command)
# attach data to agent from source
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
memgpt_agent.attach_source(data_source, source_connector, self.ms)
elif command.lower() == "dump" or command.lower().startswith("dump "):
# Check if there's an additional argument that's an integer
command = command.strip().split()
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0
if amount == 0:
memgpt_agent.interface.print_messages(memgpt_agent.messages, dump=True)
else:
memgpt_agent.interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
elif command.lower() == "dumpraw":
memgpt_agent.interface.print_messages_raw(memgpt_agent.messages)
elif command.lower() == "memory":
ret_str = (
f"\nDumping memory contents:\n"
+ f"\n{str(memgpt_agent.memory)}"
+ f"\n{str(memgpt_agent.persistence_manager.archival_memory)}"
+ f"\n{str(memgpt_agent.persistence_manager.recall_memory)}"
)
return ret_str
elif command.lower() == "pop" or command.lower().startswith("pop "):
# Check if there's an additional argument that's an integer
command = command.strip().split()
pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3
n_messages = len(memgpt_agent.messages)
MIN_MESSAGES = 2
if n_messages <= MIN_MESSAGES:
logger.info(f"Agent only has {n_messages} messages in stack, none left to pop")
elif n_messages - pop_amount < MIN_MESSAGES:
logger.info(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}")
else:
logger.info(f"Popping last {pop_amount} messages from stack")
for _ in range(min(pop_amount, len(memgpt_agent.messages))):
memgpt_agent.messages.pop()
elif command.lower() == "retry":
# TODO this needs to also modify the persistence manager
logger.info(f"Retrying for another answer")
while len(memgpt_agent.messages) > 0:
if memgpt_agent.messages[-1].get("role") == "user":
# we want to pop up to the last user message and send it again
memgpt_agent.messages[-1].get("content")
memgpt_agent.messages.pop()
break
memgpt_agent.messages.pop()
elif command.lower() == "rethink" or command.lower().startswith("rethink "):
# TODO this needs to also modify the persistence manager
if len(command) < len("rethink "):
logger.warning("Missing text after the command")
else:
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
if memgpt_agent.messages[x].get("role") == "assistant":
text = command[len("rethink ") :].strip()
memgpt_agent.messages[x].update({"content": text})
break
elif command.lower() == "rewrite" or command.lower().startswith("rewrite "):
# TODO this needs to also modify the persistence manager
if len(command) < len("rewrite "):
logger.warning("Missing text after the command")
else:
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
if memgpt_agent.messages[x].get("role") == "assistant":
text = command[len("rewrite ") :].strip()
args = json.loads(memgpt_agent.messages[x].get("function_call").get("arguments"), strict=JSON_LOADS_STRICT)
args["message"] = text
memgpt_agent.messages[x].get("function_call").update(
{"arguments": json.dumps(args, ensure_ascii=JSON_ENSURE_ASCII)}
)
break
# No skip options
elif command.lower() == "wipe":
# exit not supported on server.py
raise ValueError(command)
elif command.lower() == "heartbeat":
input_message = system.get_heartbeat()
self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
elif command.lower() == "memorywarning":
input_message = system.get_token_limit_warning()
self._step(user_id=user_id, agent_id=agent_id, input_message=input_message)
@LockingServer.agent_lock_decorator
def user_message(
self, user_id: uuid.UUID, agent_id: uuid.UUID, message: Union[str, Message], timestamp: Optional[datetime] = None
) -> None:
"""Process an incoming user message and feed it through the MemGPT agent"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# Basic input sanitization
if isinstance(message, str):
if len(message) == 0:
raise ValueError(f"Invalid input: '{message}'")
# If the input begins with a command prefix, reject
elif message.startswith("/"):
raise ValueError(f"Invalid input: '{message}'")
packaged_user_message = system.package_user_message(user_message=message)
# NOTE: eventually deprecate and only allow passing Message types
# Convert to a Message object
message = Message(
user_id=user_id,
agent_id=agent_id,
role="user",
text=packaged_user_message,
# name=None, # TODO handle name via API
)
if isinstance(message, Message):
# Can't have a null text field
            if message.text is None or len(message.text) == 0:
raise ValueError(f"Invalid input: '{message.text}'")
# If the input begins with a command prefix, reject
elif message.text.startswith("/"):
raise ValueError(f"Invalid input: '{message.text}'")
else:
raise TypeError(f"Invalid input: '{message}' - type {type(message)}")
if timestamp:
# Override the timestamp with what the caller provided
message.created_at = timestamp
# Run the agent state forward
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_user_message)
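    # Example sketch for user_message() above (ids are placeholders): replaying a
    # message with an explicit timestamp, e.g. when importing history from elsewhere:
    #
    #     from datetime import datetime
    #     server.user_message(
    #         user_id=user_id,
    #         agent_id=agent_id,
    #         message="hello again",
    #         timestamp=datetime(2024, 1, 1),
    #     )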
@LockingServer.agent_lock_decorator
def system_message(
self, user_id: uuid.UUID, agent_id: uuid.UUID, message: Union[str, Message], timestamp: Optional[datetime] = None
) -> None:
"""Process an incoming system message and feed it through the MemGPT agent"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# Basic input sanitization
if isinstance(message, str):
if len(message) == 0:
raise ValueError(f"Invalid input: '{message}'")
# If the input begins with a command prefix, reject
elif message.startswith("/"):
raise ValueError(f"Invalid input: '{message}'")
packaged_system_message = system.package_system_message(system_message=message)
# NOTE: eventually deprecate and only allow passing Message types
# Convert to a Message object
message = Message(
user_id=user_id,
agent_id=agent_id,
role="user",
text=packaged_system_message,
# name=None, # TODO handle name via API
)
if isinstance(message, Message):
# Can't have a null text field
            if message.text is None or len(message.text) == 0:
raise ValueError(f"Invalid input: '{message.text}'")
# If the input begins with a command prefix, reject
elif message.text.startswith("/"):
raise ValueError(f"Invalid input: '{message.text}'")
else:
raise TypeError(f"Invalid input: '{message}' - type {type(message)}")
if timestamp:
# Override the timestamp with what the caller provided
message.created_at = timestamp
# Run the agent state forward
self._step(user_id=user_id, agent_id=agent_id, input_message=packaged_system_message)
@LockingServer.agent_lock_decorator
def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
"""Run a command on the agent"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# If the input begins with a command prefix, attempt to process it as a command
if command.startswith("/"):
if len(command) > 1:
command = command[1:] # strip the prefix
return self._command(user_id=user_id, agent_id=agent_id, command=command)
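    # Illustrative calls for run_command() above (ids are placeholders); the slash
    # commands shown are ones _command() actually understands:
    #
    #     server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")
    #     server.run_command(user_id=user_id, agent_id=agent_id, command="/pop 3")
    #     server.run_command(user_id=user_id, agent_id=agent_id, command="/save")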
    def create_user(
        self,
        user_config: Optional[dict] = None,
    ):
        """Create a new user using a config"""
        user_config = user_config if user_config is not None else {}
        if not isinstance(user_config, dict):
            raise ValueError("user_config must be provided as a dictionary")
        user = User(
            id=user_config["id"] if "id" in user_config else None,
        )
self.ms.create_user(user)
logger.info(f"Created new user from config: {user}")
return user
def create_agent(
self,
user_id: uuid.UUID,
name: Optional[str] = None,
preset: Optional[str] = None,
persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
persona_name: Optional[str] = None,
human_name: Optional[str] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
interface: Union[AgentInterface, None] = None,
# persistence_manager: Union[PersistenceManager, None] = None,
function_names: Optional[List[str]] = None, # TODO remove
) -> AgentState:
"""Create a new agent using a config"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if interface is None:
# interface = self.default_interface_cls()
interface = self.default_interface
# if persistence_manager is None:
# persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config)
logger.debug(f"Attempting to find user: {user_id}")
user = self.ms.get_user(user_id=user_id)
if not user:
raise ValueError(f"cannot find user with associated client id: {user_id}")
# NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation
# TODO: fix this db dependency and remove
# self.ms.create_agent(agent_state)
# TODO modify to do creation via preset
try:
preset_obj = self.ms.get_preset(name=preset if preset else self.config.preset, user_id=user_id)
preset_override = False
assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist"
logger.debug(f"Attempting to create agent from preset:\n{preset_obj}")
# Overwrite fields in the preset if they were specified
if human is not None and human != preset_obj.human:
preset_override = True
preset_obj.human = human
# This is a check for a common bug where users were providing filenames instead of values
# try:
# get_human_text(human)
# raise ValueError(human)
# raise UserWarning(
# f"It looks like there is a human file named {human} - did you mean to pass the file contents to the `human` arg?"
# )
# except:
# pass
if persona is not None:
preset_override = True
preset_obj.persona = persona
# try:
# get_persona_text(persona)
# raise ValueError(persona)
# raise UserWarning(
# f"It looks like there is a persona file named {persona} - did you mean to pass the file contents to the `persona` arg?"
# )
# except:
# pass
if human_name is not None and human_name != preset_obj.human_name:
preset_override = True
preset_obj.human_name = human_name
if persona_name is not None and persona_name != preset_obj.persona_name:
preset_override = True
preset_obj.persona_name = persona_name
llm_config = llm_config if llm_config else self.server_llm_config
embedding_config = embedding_config if embedding_config else self.server_embedding_config
# TODO remove (https://github.com/cpacker/MemGPT/issues/1138)
if function_names is not None:
preset_override = True
# available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
available_tools = self.ms.list_tools()
available_tools_names = [t.name for t in available_tools]
assert all([f_name in available_tools_names for f_name in function_names])
preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names]
print("overriding preset_obj tools with:", preset_obj.functions_schema)
# If the user overrode any parts of the preset, we need to create a new preset to refer back to
if preset_override:
# Change the name and uuid
preset_obj = Preset.clone(preset_obj=preset_obj)
# Then write out to the database for storage
self.ms.create_preset(preset=preset_obj)
agent = Agent(
interface=interface,
preset=preset_obj,
name=name,
created_by=user.id,
llm_config=llm_config,
embedding_config=embedding_config,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
)
# FIXME: this is a hacky way to get the system prompts injected into agent into the DB
# self.ms.update_agent(agent.agent_state)
except Exception as e:
logger.exception(e)
try:
self.ms.delete_agent(agent_id=agent.agent_state.id)
except Exception as delete_e:
logger.exception(f"Failed to delete_agent:\n{delete_e}")
raise e
save_agent(agent, self.ms)
logger.info(f"Created new agent from config: {agent}")
return agent.agent_state
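    # Example sketch for create_agent() above; the preset, persona, and human names
    # are placeholders and must already exist in the metadata store:
    #
    #     agent_state = server.create_agent(
    #         user_id=user_id,
    #         name="my_agent",
    #         preset="memgpt_chat",
    #         persona_name="sam_pov",
    #         human_name="basic",
    #     )
    #     print(agent_state.id)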
def delete_agent(
self,
user_id: uuid.UUID,
agent_id: uuid.UUID,
):
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# TODO: Make sure the user owns the agent
agent = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
if agent is not None:
self.ms.delete_agent(agent_id=agent_id)
def delete_preset(self, user_id: uuid.UUID, preset_id: uuid.UUID) -> Preset:
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
# first get the preset by name
preset = self.get_preset(preset_id=preset_id, user_id=user_id)
if preset is None:
raise ValueError(f"Could not find preset_id {preset_id}")
# then delete via name
# TODO allow delete-by-id, eg via server.delete_preset function
self.ms.delete_preset(name=preset.name, user_id=user_id)
return preset
def initialize_default_presets(self, user_id: uuid.UUID):
"""Add default preset options into the metadata store"""
presets.add_default_presets(user_id, self.ms)
def create_preset(self, preset: Preset):
"""Create a new preset using a config"""
if preset.user_id is not None and self.ms.get_user(user_id=preset.user_id) is None:
raise ValueError(f"User user_id={preset.user_id} does not exist")
self.ms.create_preset(preset)
return preset
def get_preset(
        self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
) -> Preset:
"""Get the preset"""
return self.ms.get_preset(preset_id=preset_id, name=preset_name, user_id=user_id)
def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]:
# TODO update once we strip Preset in favor of PresetModel
presets = self.ms.list_presets(user_id=user_id)
presets = [PresetModel(**vars(p)) for p in presets]
return presets
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
"""Convert AgentState to a dict for a JSON response"""
assert agent_state is not None
agent_config = {
"id": agent_state.id,
"name": agent_state.name,
"human": agent_state.human,
"persona": agent_state.persona,
"created_at": agent_state.created_at.isoformat(),
}
return agent_config
# TODO make return type pydantic
def list_agents(
self,
user_id: uuid.UUID,
) -> dict:
"""List all available agents to a user"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
agents_states = self.ms.list_agents(user_id=user_id)
agents_states_dicts = [self._agent_state_to_config(state) for state in agents_states]
# TODO add a get_message_obj_from_message_id(...) function
# this would allow grabbing Message.created_by without having to load the agent object
# all_available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
all_available_tools = self.ms.list_tools()
for agent_state, return_dict in zip(agents_states, agents_states_dicts):
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
# TODO remove this eventually when return type get pydanticfied
# this is to add persona_name and human_name so that the columns in UI can populate
preset = self.ms.get_preset(name=agent_state.preset, user_id=user_id)
# TODO hack for frontend, remove
# (top level .persona is persona_name, and nested memory.persona is the state)
return_dict["persona"] = preset.persona_name
return_dict["human"] = preset.human_name
# Add information about tools
# TODO memgpt_agent should really have a field of List[ToolModel]
# then we could just pull that field and return it here
return_dict["tools"] = [tool for tool in all_available_tools if tool.json_schema in memgpt_agent.functions]
# Add information about memory (raw core, size of recall, size of archival)
core_memory = memgpt_agent.memory
recall_memory = memgpt_agent.persistence_manager.recall_memory
archival_memory = memgpt_agent.persistence_manager.archival_memory
memory_obj = {
"core_memory": {
"persona": core_memory.persona,
"human": core_memory.human,
},
"recall_memory": len(recall_memory) if recall_memory is not None else None,
"archival_memory": len(archival_memory) if archival_memory is not None else None,
}
return_dict["memory"] = memory_obj
# Add information about last run
# NOTE: 'last_run' is just the timestamp on the latest message in the buffer
# Retrieve the Message object via the recall storage or by directly access _messages
last_msg_obj = memgpt_agent._messages[-1]
return_dict["last_run"] = last_msg_obj.created_at
# Add information about attached sources
sources_ids = self.ms.list_attached_sources(agent_id=agent_state.id)
sources = [self.ms.get_source(source_id=s_id) for s_id in sources_ids]
return_dict["sources"] = [vars(s) for s in sources]
# Sort agents by "last_run" in descending order, most recent first
agents_states_dicts.sort(key=lambda x: x["last_run"], reverse=True)
logger.info(f"Retrieved {len(agents_states)} agents for user {user_id}:\n{[vars(s) for s in agents_states]}")
return {
"num_agents": len(agents_states),
"agents": agents_states_dicts,
}
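    # Shape of the dict returned by list_agents() above (field values are illustrative):
    #
    #     {
    #         "num_agents": 1,
    #         "agents": [
    #             {
    #                 "id": UUID("..."), "name": "my_agent",
    #                 "human": "basic", "persona": "sam_pov",
    #                 "created_at": "2024-01-01T00:00:00",
    #                 "tools": [...], "memory": {...},
    #                 "last_run": datetime(...), "sources": [...],
    #             }
    #         ],
    #     }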
def get_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID):
"""Get the agent state"""
return self.ms.get_agent(agent_id=agent_id, user_id=user_id)
def get_user(self, user_id: uuid.UUID) -> User:
"""Get the user"""
return self.ms.get_user(user_id=user_id)
def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the memory of an agent (core memory + non-core statistics)"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
core_memory = memgpt_agent.memory
recall_memory = memgpt_agent.persistence_manager.recall_memory
archival_memory = memgpt_agent.persistence_manager.archival_memory
memory_obj = {
"core_memory": {
"persona": core_memory.persona,
"human": core_memory.human,
},
"recall_memory": len(recall_memory) if recall_memory is not None else None,
"archival_memory": len(archival_memory) if archival_memory is not None else None,
}
return memory_obj
def get_in_context_message_ids(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> List[uuid.UUID]:
"""Get the message ids of the in-context messages in the agent's memory"""
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
return [m.id for m in memgpt_agent._messages]
def get_agent_message(self, agent_id: uuid.UUID, message_id: uuid.UUID) -> Message:
"""Get message based on agent and message ID"""
agent_state = self.ms.get_agent(agent_id=agent_id)
if agent_state is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
user_id = agent_state.user_id
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
message = memgpt_agent.persistence_manager.recall_memory.storage.get(message_id=message_id)
return message
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
"""Paginated query of all messages in agent message queue"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
if start < 0 or count < 0:
raise ValueError("Start and count values should be non-negative")
        if start + count < len(memgpt_agent._messages):  # messages can be returned from what's in memory
# Reverse the list to make it in reverse chronological order
reversed_messages = memgpt_agent._messages[::-1]
# Check if start is within the range of the list
if start >= len(reversed_messages):
raise IndexError("Start index is out of range")
# Calculate the end index, ensuring it does not exceed the list length
end_index = min(start + count, len(reversed_messages))
# Slice the list for pagination
messages = reversed_messages[start:end_index]
# Convert to json
# Add a tag indicating in-context or not
json_messages = [{**record.to_json(), "in_context": True} for record in messages]
else:
# need to access persistence manager for additional messages
db_iterator = memgpt_agent.persistence_manager.recall_memory.storage.get_all_paginated(page_size=count, offset=start)
# get a single page of messages
# TODO: handle stop iteration
page = next(db_iterator, [])
# return messages in reverse chronological order
messages = sorted(page, key=lambda x: x.created_at, reverse=True)