Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

The serializer error occurs when the example of werewolf exits error #1688

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def create_result_data(self, round: int, score: float, avg_cost: float, total_co
return {"round": round, "score": score, "avg_cost": avg_cost, "total_cost": total_cost, "time": now}

def save_results(self, json_file_path: str, data: list):
write_json_file(json_file_path, data, encoding="utf-8", indent=4)
write_json_file(json_file_path, data, encoding="utf-8", indent=4, cls=NumpyJSONEncoder)

def _load_scores(self, path=None, mode="Graph"):
if mode == "Graph":
Expand All @@ -147,3 +147,30 @@ def _load_scores(self, path=None, mode="Graph"):
self.top_scores.sort(key=lambda x: x["score"], reverse=True)

return self.top_scores

class NumpyJSONEncoder(json.JSONEncoder):
"""customized JSON encoder for numpy data type

features:
1. support numpy array serialization (automatically convert to list)
2. support numpy scalar type (int32, float64, etc.) serialization
3. keep the original data type precision
4. compatible with regular JSON data types
"""
def default(self, obj):
"""override the default serialize method"""

# handle numpy array type
if isinstance(obj, np.ndarray):
return {
'__ndarray__': obj.tolist(), # 转换为Python列表
'dtype': str(obj.dtype), # 保留数据类型信息
'shape': obj.shape # 保留数组形状
}

# handle numpy scalar type
elif isinstance(obj, np.generic):
return obj.item() # convert to python native type

# handle other type using default method
return super().default(obj)
14 changes: 12 additions & 2 deletions metagpt/ext/werewolf/schema.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any

import json
from pydantic import BaseModel, Field, field_validator

from metagpt.schema import Message
from metagpt.utils.common import any_to_str_set

from metagpt.configs.llm_config import LLMType

class RoleExperience(BaseModel):
id: str = ""
Expand All @@ -31,3 +31,13 @@ class WwMessage(Message):
@classmethod
def check_restricted_to(cls, restricted_to: Any):
return any_to_str_set(restricted_to if restricted_to else set())


class WwJsonEncoder(json.JSONEncoder):
def _default(self, obj):
if isinstance(obj, type): # handle class
return {
"__type__": obj.__name__,
"__module__": obj.__module__
}
return super().default(obj)
3 changes: 2 additions & 1 deletion metagpt/ext/werewolf/werewolf_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from metagpt.actions.add_requirement import UserRequirement
from metagpt.context import Context
from metagpt.environment.werewolf.werewolf_env import WerewolfEnv
from metagpt.ext.werewolf.schema import WwMessage
from metagpt.ext.werewolf.schema import WwMessage, WwJsonEncoder
from metagpt.team import Team


Expand All @@ -14,6 +14,7 @@ class WerewolfGame(Team):

def __init__(self, context: Context = None, **data: Any):
super(Team, self).__init__(**data)
self.json_encoder = WwJsonEncoder
ctx = context or Context()
if not self.env:
self.env = WerewolfEnv(context=ctx)
Expand Down
4 changes: 3 additions & 1 deletion metagpt/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

import warnings
import json
from pathlib import Path
from typing import Any, Optional

Expand Down Expand Up @@ -40,6 +41,7 @@ class Team(BaseModel):
env: Optional[Environment] = None
investment: float = Field(default=10.0)
idea: str = Field(default="")
json_encoder: json.JSONEncoder = Field(default=None, exclude=True)

def __init__(self, context: Context = None, **data: Any):
super(Team, self).__init__(**data)
Expand All @@ -59,7 +61,7 @@ def serialize(self, stg_path: Path = None):
serialized_data = self.model_dump()
serialized_data["context"] = self.env.context.serialize()

write_json_file(team_info_path, serialized_data)
write_json_file(team_info_path, serialized_data, cls=self.json_encoder)

@classmethod
def deserialize(cls, stg_path: Path, context: Context = None) -> "Team":
Expand Down
15 changes: 13 additions & 2 deletions metagpt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,14 +571,25 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
raise ValueError(f"read json file: {json_file} failed")
return data

def wrapper_none_error(func):
"""Wrapper for ValueError"""
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except ValueError as e:
# the to_jsonable_python of pydantic will raise PydanticSerializationError
# return None to call the custom JSONEncoder
return None
return wrapper

def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4):
def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4, cls: json.JSONEncoder=None):
folder_path = Path(json_file).parent
if not folder_path.exists():
folder_path.mkdir(parents=True, exist_ok=True)

with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python)
json.dump(data, fout, ensure_ascii=False, cls=cls, indent=indent,
default=wrapper_none_error(to_jsonable_python) if cls else to_jsonable_python)


def read_jsonl_file(jsonl_file: str, encoding="utf-8") -> list[dict]:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def run(self):
"llama-index-postprocessor-cohere-rerank==0.1.4",
"llama-index-postprocessor-colbert-rerank==0.1.1",
"llama-index-postprocessor-flag-embedding-reranker==0.1.2",
# "llama-index-vector-stores-milvus==0.1.23",
"llama-index-vector-stores-milvus==0.1.23",
"docx2txt==0.8",
],
}
Expand Down
15 changes: 15 additions & 0 deletions tests/metagpt/ext/werewolf/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from metagpt.ext.werewolf.schema import WwJsonEncoder
from metagpt.ext.werewolf.actions.common_actions import Speak
from metagpt.environment.werewolf.const import RoleType, RoleState, RoleActionRes
import json
from metagpt.utils.common import to_jsonable_python
def test_ww_json_encoder():
encoder = WwJsonEncoder
data = {
"test": RoleType.VILLAGER,
"test2": RoleState.ALIVE,
"test3": RoleActionRes.PASS,
"test4": [Speak],
}
encoded = json.dumps(data, cls=encoder, default=to_jsonable_python)
# print(encoded)
Loading