From 078cffddfa8ee514ff740fca035ddac03dc37ae8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com> Date: Thu, 25 Jul 2024 10:55:47 -0400 Subject: [PATCH] refactor: Make `SQLSink` a generic with a `SQLConnector` type parameter (#2564) --- samples/sample_target_sqlite/__init__.py | 2 +- singer_sdk/sinks/sql.py | 12 +++++++----- tests/conftest.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/samples/sample_target_sqlite/__init__.py b/samples/sample_target_sqlite/__init__.py index 07e5172aa..d296c5fb0 100644 --- a/samples/sample_target_sqlite/__init__.py +++ b/samples/sample_target_sqlite/__init__.py @@ -26,7 +26,7 @@ def get_sqlalchemy_url(self, config: dict[str, t.Any]) -> str: # noqa: PLR6301 return f"sqlite:///{config[DB_PATH_CONFIG]}" -class SQLiteSink(SQLSink): +class SQLiteSink(SQLSink[SQLiteConnector]): """The Sink class for SQLite. This class allows developers to optionally override `get_records()` and other diff --git a/singer_sdk/sinks/sql.py b/singer_sdk/sinks/sql.py index 9b8823f27..33a741614 100644 --- a/singer_sdk/sinks/sql.py +++ b/singer_sdk/sinks/sql.py @@ -23,11 +23,13 @@ from singer_sdk.target_base import Target +_C = t.TypeVar("_C", bound=SQLConnector) -class SQLSink(BatchSink): + +class SQLSink(BatchSink, t.Generic[_C]): """SQL-type sink type.""" - connector_class: type[SQLConnector] + connector_class: type[_C] soft_delete_column_name = "_sdc_deleted_at" version_column_name = "_sdc_table_version" @@ -37,7 +39,7 @@ def __init__( stream_name: str, schema: dict, key_properties: t.Sequence[str] | None, - connector: SQLConnector | None = None, + connector: _C | None = None, ) -> None: """Initialize SQL Sink. @@ -48,12 +50,12 @@ def __init__( key_properties: The primary key columns. connector: Optional connector to reuse. """ - self._connector: SQLConnector + self._connector: _C self._connector = connector or self.connector_class(dict(target.config)) super().__init__(target, stream_name, schema, key_properties) @property - def connector(self) -> SQLConnector: + def connector(self) -> _C: """The connector object. Returns: diff --git a/tests/conftest.py b/tests/conftest.py index d2961722f..0b5bc4b74 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -138,7 +138,7 @@ class SQLConnectorMock(SQLConnector): """A Mock SQLConnector class.""" -class SQLSinkMock(SQLSink): +class SQLSinkMock(SQLSink[SQLConnectorMock]): """A mock Sink class.""" name = "sql-sink-mock"