Skip to content

Commit

Permalink
attempt to fix mypy by always using typing-extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Nov 13, 2021
1 parent 1b75ec0 commit 5430e00
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 123 deletions.
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ jinja2>=2.11.2
pydantic>=1.8.0
click>=7.1.2
python-dotenv>=0.12.0
typing-extensions>=3.7
contextvars; python_version < '3.7'
typing-extensions>=3.7; python_version < '3.9.2'
23 changes: 6 additions & 17 deletions src/prisma/_types.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,14 @@
import sys
from typing import Callable, Coroutine, TypeVar, Any
from typing_extensions import (
TypedDict as TypedDict,
Protocol as Protocol,
Literal as Literal,
runtime_checkable as runtime_checkable,
)

from pydantic import BaseModel


if sys.version_info >= (3, 9, 2):
from typing import ( # pylint: disable=no-name-in-module, unused-import
TypedDict as TypedDict,
Protocol as Protocol,
Literal as Literal,
runtime_checkable as runtime_checkable,
)
else:
from typing_extensions import (
TypedDict as TypedDict,
Protocol as Protocol,
Literal as Literal,
runtime_checkable as runtime_checkable,
)


Method = Literal['GET', 'POST']

BaseModelT = TypeVar('BaseModelT', bound=BaseModel)
Expand Down
6 changes: 1 addition & 5 deletions src/prisma/generator/templates/_header.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ from typing import (
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal


120 changes: 20 additions & 100 deletions tests/test_generation/exhaustive/__snapshots__/test_exhaustive.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template actions.py.jinja --
from . import types, errors
Expand Down Expand Up @@ -3133,11 +3129,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template builder.py.jinja --

Expand Down Expand Up @@ -3987,11 +3979,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template client.py.jinja --
from types import TracebackType
Expand Down Expand Up @@ -5787,11 +5775,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template engine/query.py.jinja --

Expand Down Expand Up @@ -6025,11 +6009,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template enums.py.jinja --
from enum import Enum
Expand Down Expand Up @@ -6072,11 +6052,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template fields.py.jinja --
import base64
Expand Down Expand Up @@ -6220,11 +6196,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template http.py.jinja --
from ._async_http import (
Expand Down Expand Up @@ -6264,11 +6236,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template models.py.jinja --
import os
Expand Down Expand Up @@ -8175,11 +8143,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template partials.py.jinja --
from pydantic import BaseModel, Field, validator
Expand Down Expand Up @@ -8225,11 +8189,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template types.py.jinja --
from .utils import _NoneType
Expand Down Expand Up @@ -18648,11 +18608,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template actions.py.jinja --
from . import types, errors
Expand Down Expand Up @@ -21754,11 +21710,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template builder.py.jinja --

Expand Down Expand Up @@ -22608,11 +22560,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template client.py.jinja --
from types import TracebackType
Expand Down Expand Up @@ -24408,11 +24356,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template engine/query.py.jinja --

Expand Down Expand Up @@ -24639,11 +24583,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template enums.py.jinja --
from enum import Enum
Expand Down Expand Up @@ -24686,11 +24626,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template fields.py.jinja --
import base64
Expand Down Expand Up @@ -24834,11 +24770,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template http.py.jinja --
from ._sync_http import (
Expand Down Expand Up @@ -24878,11 +24810,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template models.py.jinja --
import os
Expand Down Expand Up @@ -26789,11 +26717,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template partials.py.jinja --
from pydantic import BaseModel, Field, validator
Expand Down Expand Up @@ -26839,11 +26763,7 @@
overload,
cast,
)

if sys.version_info >= (3, 9, 2):
from typing import TypedDict, Literal
else:
from typing_extensions import TypedDict, Literal
from typing_extensions import TypedDict, Literal

# -- template types.py.jinja --
from .utils import _NoneType
Expand Down

0 comments on commit 5430e00

Please # to comment.