diff --git a/quixstreams/kafka/configuration.py b/quixstreams/kafka/configuration.py index 0637cb1de..52f2f69e4 100644 --- a/quixstreams/kafka/configuration.py +++ b/quixstreams/kafka/configuration.py @@ -1,7 +1,7 @@ -from typing import Callable, Literal, Optional, Tuple, Type +from typing import Callable, Literal, Optional, Tuple, Type, Union import pydantic -from pydantic import AliasChoices, Field, SecretStr +from pydantic import AliasChoices, Field, ImportString, SecretStr from pydantic.functional_validators import BeforeValidator from pydantic_settings import ( BaseSettings as PydanticBaseSettings, @@ -52,11 +52,13 @@ class ConnectionConfig(BaseSettings): sasl_kerberos_principal: Optional[str] = None # for oauth_cb, see https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#pythonclient-configuration - oauth_cb: Optional[Callable[[str], Tuple[str, float]]] = pydantic.Field( - # Prevent the AliasGenerator from changing the field name to "oauth.cb" - default=None, - alias_priority=2, - serialization_alias="oauth_cb", + oauth_cb: Optional[Union[Callable[[str], Tuple[str, float]], ImportString]] = ( + pydantic.Field( + # Prevent the AliasGenerator from changing the field name to "oauth.cb" + default=None, + alias_priority=2, + serialization_alias="oauth_cb", + ) ) sasl_oauthbearer_config: Optional[str] = None diff --git a/tests/test_quixstreams/test_kafka/test_configuration.py b/tests/test_quixstreams/test_kafka/test_configuration.py index 0831400d5..7af263346 100644 --- a/tests/test_quixstreams/test_kafka/test_configuration.py +++ b/tests/test_quixstreams/test_kafka/test_configuration.py @@ -1,4 +1,5 @@ import os +from typing import Tuple from unittest.mock import patch import pydantic @@ -7,6 +8,10 @@ from quixstreams.kafka.configuration import ConnectionConfig +def example_oauth_cb(config: str) -> Tuple[str, float]: + return config, 1.0 + + class TestConnectionConfig: def test_literal_casings(self): """ @@ -139,3 +144,14 @@ def test_environment_not_read(self): with patch.dict(os.environ, {"SASL_PASSWORD": "cool_pw"}): config = ConnectionConfig(bootstrap_servers="url") assert config.sasl_password is None + + def test_oauth_cb_with_callable(self): + config = ConnectionConfig(bootstrap_servers="url", oauth_cb=example_oauth_cb) + assert config.oauth_cb == example_oauth_cb + + def test_oauth_cb_with_import_string(self): + config = ConnectionConfig( + bootstrap_servers="url", + oauth_cb="tests.test_quixstreams.test_kafka.test_configuration.example_oauth_cb", + ) + assert config.oauth_cb == example_oauth_cb