Skip to content

Commit

Permalink
client: allow for custom kafka clients
Browse files Browse the repository at this point in the history
Provide the consumer, producer and admin client with the option to
create the Kafka client from a custom callable, thus allowing more
flexibility in handling certain low-level errors.
  • Loading branch information
Gabriel Tincu committed Oct 20, 2020
1 parent 6f932ba commit 234cca5
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
13 changes: 9 additions & 4 deletions kafka/admin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from kafka.admin.acl_resource import ACLOperation, ACLPermissionType, ACLFilter, ACL, ResourcePattern, ResourceType, \
ACLResourcePatternType
from kafka.client_async import KafkaClient, selectors
from kafka.client_async import selectors
from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment, ConsumerProtocol
import kafka.errors as Errors
from kafka.errors import (
Expand All @@ -26,6 +26,7 @@
from kafka.protocol.metadata import MetadataRequest
from kafka.protocol.types import Array
from kafka.structs import TopicPartition, OffsetAndMetadata, MemberInformation, GroupInformation
from kafka.util import get_client_factory
from kafka.version import __version__


Expand Down Expand Up @@ -146,6 +147,7 @@ class KafkaAdminClient(object):
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
client_factory (callable): Custom class / callable for creating KafkaClient instances
"""
DEFAULT_CONFIG = {
Expand Down Expand Up @@ -186,6 +188,7 @@ class KafkaAdminClient(object):
'metric_reporters': [],
'metrics_num_samples': 2,
'metrics_sample_window_ms': 30000,
'client_factory': None,
}

def __init__(self, **configs):
Expand All @@ -205,9 +208,11 @@ def __init__(self, **configs):
reporters = [reporter() for reporter in self.config['metric_reporters']]
self._metrics = Metrics(metric_config, reporters)

self._client = KafkaClient(metrics=self._metrics,
metric_group_prefix='admin',
**self.config)
self._client = get_client_factory(self.config)(
metrics=self._metrics,
metric_group_prefix='admin',
**self.config
)
self._client.check_version(timeout=(self.config['api_version_auto_timeout_ms'] / 1000))

# Get auto-discovered version from client if necessary
Expand Down
7 changes: 5 additions & 2 deletions kafka/consumer/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from kafka.vendor import six

from kafka.client_async import KafkaClient, selectors
from kafka.client_async import selectors
from kafka.consumer.fetcher import Fetcher
from kafka.consumer.subscription_state import SubscriptionState
from kafka.coordinator.consumer import ConsumerCoordinator
Expand All @@ -18,6 +18,7 @@
from kafka.metrics import MetricConfig, Metrics
from kafka.protocol.offset import OffsetResetStrategy
from kafka.structs import TopicPartition
from kafka.util import get_client_factory
from kafka.version import __version__

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -244,6 +245,7 @@ class KafkaConsumer(six.Iterator):
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
client_factory (callable): Custom class / callable for creating KafkaClient instances
Note:
Configuration parameters are described in more detail at
Expand Down Expand Up @@ -306,6 +308,7 @@ class KafkaConsumer(six.Iterator):
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None,
'legacy_iterator': False, # enable to revert to < 1.4.7 iterator
'client_factory': None,
}
DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000

Expand Down Expand Up @@ -353,7 +356,7 @@ def __init__(self, *topics, **configs):
log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated',
str(self.config['api_version']), str_version)

self._client = KafkaClient(metrics=self._metrics, **self.config)
self._client = get_client_factory(self.config)(metrics=self._metrics, **self.config)

# Get auto-discovered version from client if necessary
if self.config['api_version'] is None:
Expand Down
14 changes: 9 additions & 5 deletions kafka/producer/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from kafka.vendor import six

import kafka.errors as Errors
from kafka.client_async import KafkaClient, selectors
from kafka.client_async import selectors
from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd
from kafka.metrics import MetricConfig, Metrics
from kafka.partitioner.default import DefaultPartitioner
Expand All @@ -22,6 +22,7 @@
from kafka.record.legacy_records import LegacyRecordBatchBuilder
from kafka.serializer import Serializer
from kafka.structs import TopicPartition
from kafka.util import get_client_factory


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -280,6 +281,7 @@ class KafkaProducer(object):
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
client_factory (callable): Custom class / callable for creating KafkaClient instances
Note:
Configuration parameters are described in more detail at
Expand Down Expand Up @@ -332,7 +334,8 @@ class KafkaProducer(object):
'sasl_plain_password': None,
'sasl_kerberos_service_name': 'kafka',
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None
'sasl_oauth_token_provider': None,
'client_factory': None,
}

_COMPRESSORS = {
Expand Down Expand Up @@ -378,9 +381,10 @@ def __init__(self, **configs):
reporters = [reporter() for reporter in self.config['metric_reporters']]
self._metrics = Metrics(metric_config, reporters)

client = KafkaClient(metrics=self._metrics, metric_group_prefix='producer',
wakeup_timeout_ms=self.config['max_block_ms'],
**self.config)
client = get_client_factory(self.config)(
metrics=self._metrics, metric_group_prefix='producer',
wakeup_timeout_ms=self.config['max_block_ms'],
**self.config)

# Get auto-discovered version from client if necessary
if self.config['api_version'] is None:
Expand Down
10 changes: 10 additions & 0 deletions kafka/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import

import binascii
import kafka
import weakref

from kafka.vendor import six
Expand Down Expand Up @@ -64,3 +65,12 @@ class Dict(dict):
See: https://docs.python.org/2/library/weakref.html
"""
pass


def get_client_factory(config):
    """Return the callable used to construct a KafkaClient instance.

    Arguments:
        config (dict): client configuration mapping; honors the optional
            'client_factory' key, which must be a callable (e.g. a class)
            or None.

    Returns:
        callable: the user-supplied 'client_factory' when one is configured,
        otherwise the default ``kafka.client_async.KafkaClient`` class.

    Raises:
        AssertionError: if 'client_factory' is set but not callable.
    """
    client_factory = config.get('client_factory')
    if client_factory is None:
        # Fall back to the stock client. Accessed through the package
        # attribute (not a top-level `from` import) — presumably to avoid a
        # circular import with kafka.client_async at module load time.
        return kafka.client_async.KafkaClient
    # Explicit raise instead of a bare `assert` so the validation survives
    # `python -O` (asserts are stripped under optimization). AssertionError
    # is kept as the exception type for backward compatibility.
    if not callable(client_factory):
        raise AssertionError(
            "'client_factory' should be a callable or None, is {}".format(type(client_factory)))
    return client_factory

0 comments on commit 234cca5

Please sign in to comment.