[FLINK-XXX] Add set_deserializer method to Python KafkaSourceBuilder #148

Draft · wants to merge 1 commit into base: main
52 changes: 51 additions & 1 deletion flink-python/pyflink/datastream/connectors/kafka.py
@@ -44,7 +44,9 @@
'KafkaOffsetResetStrategy',
'KafkaRecordSerializationSchema',
'KafkaRecordSerializationSchemaBuilder',
'KafkaTopicSelector'
'KafkaTopicSelector',
'KafkaRecordDeserializationSchema',
'SimpleStringValueKafkaRecordDeserializationSchema'
]


@@ -353,6 +355,38 @@ def ignore_failures_after_transaction_timeout(self) -> 'FlinkKafkaProducer':

# ---- KafkaRecordDeserializationSchema ----

class KafkaRecordDeserializationSchema:
"""
Base class for KafkaRecordDeserializationSchema. A Kafka record deserialization schema
describes how to turn the byte messages delivered by Apache Kafka into data types
(Java/Scala objects) that are processed by Flink.

In addition, the KafkaRecordDeserializationSchema describes the produced type, which lets
Flink create internal serializers and structures to handle the type.
"""
def __init__(self, j_kafka_record_deserialization_schema=None):
self.j_kafka_record_deserialization_schema = j_kafka_record_deserialization_schema


class SimpleStringValueKafkaRecordDeserializationSchema(KafkaRecordDeserializationSchema):
"""
A very simple deserialization schema for string values. By default, the deserializer uses
'UTF-8' for the byte-to-string conversion.
"""

def __init__(self, charset: str = 'UTF-8'):
gate_way = get_gateway()
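# SimpleStringSchema on the JVM side performs the byte-to-string conversion with the given charset.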
j_char_set = gate_way.jvm.java.nio.charset.Charset.forName(charset)
j_simple_string_serialization_schema = gate_way.jvm \
.org.apache.flink.api.common.serialization.SimpleStringSchema(j_char_set)
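# Wrap the string schema into a value-only Kafka record deserializer, mirroring
# KafkaRecordDeserializationSchema.valueOnly() in the Java API.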
j_kafka_record_deserialization_schema = gate_way.jvm \
.org.apache.flink.connector.kafka.source.reader.deserializer \
.KafkaRecordDeserializationSchema.valueOnly(j_simple_string_serialization_schema)
KafkaRecordDeserializationSchema.__init__(
self, j_kafka_record_deserialization_schema=j_kafka_record_deserialization_schema)


# ---- KafkaSource ----

class KafkaSource(Source):
"""
@@ -611,6 +645,22 @@ def set_value_only_deserializer(self, deserialization_schema: DeserializationSchema)
self._j_builder.setValueOnlyDeserializer(deserialization_schema._j_deserialization_schema)
return self

def set_deserializer(
self,
kafka_record_deserialization_schema: KafkaRecordDeserializationSchema
) -> 'KafkaSourceBuilder':
"""
Sets the :class:`~pyflink.datastream.connectors.kafka.KafkaRecordDeserializationSchema`
for deserializing Kafka ConsumerRecords.

:param kafka_record_deserialization_schema: the :class:`KafkaRecordDeserializationSchema`
to use for deserialization.
:return: this KafkaSourceBuilder.
"""
self._j_builder.setDeserializer(
kafka_record_deserialization_schema.j_kafka_record_deserialization_schema)
return self

def set_client_id_prefix(self, prefix: str) -> 'KafkaSourceBuilder':
"""
Sets the client id prefix of this KafkaSource.
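For context, here is a minimal usage sketch of the API this PR adds (not part of the diff): it builds a KafkaSource whose record values are deserialized as UTF-8 strings via the new set_deserializer method. The broker address, topic, and group id are placeholders.

from pyflink.common import WatermarkStrategy
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.connectors.kafka import KafkaOffsetsInitializer, KafkaSource, \
    SimpleStringValueKafkaRecordDeserializationSchema

env = StreamExecutionEnvironment.get_execution_environment()

# Placeholder connection details; replace with real broker/topic/group values.
source = KafkaSource.builder() \
    .set_bootstrap_servers('localhost:9092') \
    .set_topics('input-topic') \
    .set_group_id('example-group') \
    .set_starting_offsets(KafkaOffsetsInitializer.earliest()) \
    .set_deserializer(SimpleStringValueKafkaRecordDeserializationSchema()) \
    .build()

# Each record value arrives as a Python str decoded with UTF-8.
ds = env.from_source(source, WatermarkStrategy.no_watermarks(), 'kafka_source')
ds.print()
env.execute('kafka_set_deserializer_example')

Because the Python class routes through KafkaRecordDeserializationSchema.valueOnly on the Java side, setDeserializer receives the same wrapper type that set_value_only_deserializer produces, which is what the new test below asserts.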
19 changes: 18 additions & 1 deletion flink-python/pyflink/datastream/connectors/tests/test_kafka.py
@@ -29,7 +29,8 @@
from pyflink.datastream.connectors.base import DeliveryGuarantee
from pyflink.datastream.connectors.kafka import KafkaSource, KafkaTopicPartition, \
KafkaOffsetsInitializer, KafkaOffsetResetStrategy, KafkaRecordSerializationSchema, KafkaSink, \
FlinkKafkaProducer, FlinkKafkaConsumer
FlinkKafkaProducer, FlinkKafkaConsumer, KafkaRecordDeserializationSchema, \
SimpleStringValueKafkaRecordDeserializationSchema
from pyflink.datastream.formats.avro import AvroRowDeserializationSchema, AvroRowSerializationSchema
from pyflink.datastream.formats.csv import CsvRowDeserializationSchema, CsvRowSerializationSchema
from pyflink.datastream.formats.json import JsonRowDeserializationSchema, JsonRowSerializationSchema
@@ -332,6 +333,22 @@ def _check(schema: DeserializationSchema, class_name: str):
'org.apache.flink.formats.avro.AvroRowDeserializationSchema'
)

def test_set_kafka_record_deserialization_schema(self):
def _check(schema: KafkaRecordDeserializationSchema, java_class_name: str):
source = KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_topics('test_topic') \
.set_deserializer(schema) \
.build()
kafka_record_deserialization_schema = get_field_value(source.get_java_function(),
'deserializationSchema')
self.assertEqual(kafka_record_deserialization_schema.getClass().getCanonicalName(),
java_class_name)

_check(SimpleStringValueKafkaRecordDeserializationSchema(),
'org.apache.flink.connector.kafka.source.reader.deserializer.'
'KafkaValueOnlyDeserializationSchemaWrapper')

def _check_reader_handled_offsets_initializer(self,
source: KafkaSource,
offset: int,
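Since the KafkaRecordDeserializationSchema base class simply carries a Java-side object, a custom deserializer implemented in Java could in principle be wrapped the same way. A hypothetical sketch (the Java class name is invented for illustration and is not part of this PR):

from pyflink.java_gateway import get_gateway
from pyflink.datastream.connectors.kafka import KafkaRecordDeserializationSchema

gateway = get_gateway()
# Hypothetical user-provided Java implementation, assumed to be on the classpath.
j_custom = gateway.jvm.com.example.MyKafkaRecordDeserializationSchema()
custom_schema = KafkaRecordDeserializationSchema(
    j_kafka_record_deserialization_schema=j_custom)
# Then pass it to the builder: KafkaSource.builder()...set_deserializer(custom_schema)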