From 26574aaddbc52b946b855c6443e6203680f1ae2a Mon Sep 17 00:00:00 2001 From: Keming Date: Sun, 19 Jan 2025 21:17:26 +0800 Subject: [PATCH] feat: support other content-type Signed-off-by: Keming --- .gitignore | 1 + defspec/spec.py | 22 ++++++++++++++++++---- pyproject.toml | 2 +- tests/test_spec.py | 16 +++++++++++++++- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index ebab65a..3471ded 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ cython_debug/ .ruff_cache/ *.json +*.html diff --git a/defspec/spec.py b/defspec/spec.py index 63913c6..3abcc1d 100644 --- a/defspec/spec.py +++ b/defspec/spec.py @@ -15,6 +15,9 @@ from typing_extensions import Self +DEFAULT_CONTENT_TYPE = "application/json" + + class OpenAPIInfo(msgspec.Struct, kw_only=True): title: str = "OpenAPI" description: str = "OpenAPI generated by defspec" @@ -35,10 +38,13 @@ class JSONSchema(msgspec.Struct, kw_only=True, omit_defaults=True): content: dict[str, dict] = msgspec.field(default_factory=dict) @classmethod - def with_json_schema(cls, schema: dict[str, dict]) -> Self: + def with_schema_content_type( + cls, schema: dict[str, dict], content_type: Optional[str] = None + ) -> Self: instance = cls() + content_type = content_type or DEFAULT_CONTENT_TYPE if schema.get("type", "") != "null": - instance.content["application/json"] = {"schema": schema} + instance.content[content_type] = {"schema": schema} return instance @@ -114,7 +120,9 @@ def register_route( method: HTTP_METHODS, summary: Optional[str] = None, request_type: Optional[Type] = None, + request_content_type: Optional[str] = None, response_type: Optional[Type] = None, + response_content_type: Optional[str] = None, query_type: Optional[Type] = None, header_type: Optional[Type] = None, cookie_type: Optional[Type] = None, @@ -142,8 +150,14 @@ def register_route( self.paths[path][method] = OpenAPIRoute( summary=summary or f"{method} from {path.replace('/', ' ')}", operation_id=f"{method}_{path.replace('/', '_')}", - request_body=OpenAPIRequestBody.with_json_schema(request_schema), - responses={"200": OpenAPIResponse.with_json_schema(response_schema)}, + request_body=OpenAPIRequestBody.with_schema_content_type( + request_schema, request_content_type + ), + responses={ + "200": OpenAPIResponse.with_schema_content_type( + response_schema, response_content_type + ) + }, deprecated=deprecated, ) diff --git a/pyproject.toml b/pyproject.toml index 1e18146..1b2555e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,4 +44,4 @@ ignore = ["E501"] [tool.ruff.lint.isort] known-first-party = ["defspec"] [tool.ruff.lint.pylint] -max-args = 10 +max-args = 12 diff --git a/tests/test_spec.py b/tests/test_spec.py index 944febd..8b1bb58 100644 --- a/tests/test_spec.py +++ b/tests/test_spec.py @@ -177,6 +177,14 @@ def openapi_spec(request): header_type=parameter.header, cookie_type=parameter.cookie, ) + openapi.register_route( + path="/test/msgpack", + method="post", + request_content_type="application/msgpack", + response_content_type="application/msgpack", + request_type=parameter.request, + response_type=parameter.response, + ) openapi.register_route( path="/", method="get", @@ -187,7 +195,7 @@ def openapi_spec(request): def test_openapi_spec(openapi_spec): spec = openapi_spec.to_dict() - assert list(spec["paths"].keys()) == ["/test", "/"] + assert list(spec["paths"].keys()) == ["/test", "/test/msgpack", "/"] assert list(spec["paths"]["/test"].keys()) == ["post"] assert list(spec["paths"]["/"].keys()) == ["get"] @@ -214,3 +222,9 @@ def test_openapi_spec(openapi_spec): assert header["schema"]["$ref"].startswith("#/$defs/Header") assert header["description"] == "Set your API key here." assert cookie["schema"]["$ref"].startswith("#/$defs/Cookie") + + msgpack = spec["paths"]["/test/msgpack"]["post"] + request = msgpack["requestBody"]["content"]["application/msgpack"]["schema"] + assert request["$ref"].startswith("#/$defs/RequestBody") + response = msgpack["responses"]["200"]["content"]["application/msgpack"]["schema"] + assert response["$ref"].startswith("#/$defs/Response")