diff --git a/flink-python/pyflink/datastream/connectors/kafka.py b/flink-python/pyflink/datastream/connectors/kafka.py index 062c5b2e8..cf05cb5b2 100644 --- a/flink-python/pyflink/datastream/connectors/kafka.py +++ b/flink-python/pyflink/datastream/connectors/kafka.py @@ -44,7 +44,9 @@ 'KafkaOffsetResetStrategy', 'KafkaRecordSerializationSchema', 'KafkaRecordSerializationSchemaBuilder', - 'KafkaTopicSelector' + 'KafkaTopicSelector', + 'KafkaRecordDeserializationSchema', + 'SimpleStringValueKafkaRecordDeserializationSchema' ] @@ -353,6 +355,38 @@ def ignore_failures_after_transaction_timeout(self) -> 'FlinkKafkaProducer': # ---- KafkaSource ---- +class KafkaRecordDeserializationSchema: + """ + Base class for KafkaRecordDeserializationSchema. The kafka record deserialization schema + describes how to turn the byte messages delivered by Apache Kafka into data types (Java/ + Scala objects) that are processed by Flink. + + In addition, the KafkaRecordDeserializationSchema describes the produced type which lets + Flink create internal serializers and structures to handle the type. + """ + def __init__(self, j_kafka_record_deserialization_schema=None): + self.j_kafka_record_deserialization_schema = j_kafka_record_deserialization_schema + + +class SimpleStringValueKafkaRecordDeserializationSchema(KafkaRecordDeserializationSchema): + """ + Very simple deserialization schema for string values. By default, the deserializer uses + 'UTF-8' for byte to string conversion.
+ """ + + def __init__(self, charset: str = 'UTF-8'): + gate_way = get_gateway() + j_char_set = gate_way.jvm.java.nio.charset.Charset.forName(charset) + j_simple_string_serialization_schema = gate_way.jvm \ + .org.apache.flink.api.common.serialization.SimpleStringSchema(j_char_set) + j_kafka_record_deserialization_schema = gate_way.jvm \ + .org.apache.flink.connector.kafka.source.reader.deserializer \ + .KafkaRecordDeserializationSchema.valueOnly(j_simple_string_serialization_schema) + KafkaRecordDeserializationSchema.__init__( + self, j_kafka_record_deserialization_schema=j_kafka_record_deserialization_schema) + + +# ---- KafkaSource ---- class KafkaSource(Source): """ @@ -611,6 +645,22 @@ def set_value_only_deserializer(self, deserialization_schema: DeserializationSch self._j_builder.setValueOnlyDeserializer(deserialization_schema._j_deserialization_schema) return self + def set_deserializer( + self, + kafka_record_deserialization_schema: KafkaRecordDeserializationSchema + ) -> 'KafkaSourceBuilder': + """ + Sets the :class:`~pyflink.datastream.connectors.kafka.KafkaRecordDeserializationSchema` + for deserializing Kafka ConsumerRecords. + + :param kafka_record_deserialization_schema: the :class:`KafkaRecordDeserializationSchema` + to use for deserialization. + :return: this KafkaSourceBuilder. + """ + self._j_builder.setDeserializer( + kafka_record_deserialization_schema.j_kafka_record_deserialization_schema) + return self + def set_client_id_prefix(self, prefix: str) -> 'KafkaSourceBuilder': """ Sets the client id prefix of this KafkaSource. 
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_kafka.py b/flink-python/pyflink/datastream/connectors/tests/test_kafka.py index dea06b3e0..ef89256b5 100644 --- a/flink-python/pyflink/datastream/connectors/tests/test_kafka.py +++ b/flink-python/pyflink/datastream/connectors/tests/test_kafka.py @@ -29,7 +29,8 @@ from pyflink.datastream.connectors.base import DeliveryGuarantee from pyflink.datastream.connectors.kafka import KafkaSource, KafkaTopicPartition, \ KafkaOffsetsInitializer, KafkaOffsetResetStrategy, KafkaRecordSerializationSchema, KafkaSink, \ - FlinkKafkaProducer, FlinkKafkaConsumer + FlinkKafkaProducer, FlinkKafkaConsumer, KafkaRecordDeserializationSchema, \ + SimpleStringValueKafkaRecordDeserializationSchema from pyflink.datastream.formats.avro import AvroRowDeserializationSchema, AvroRowSerializationSchema from pyflink.datastream.formats.csv import CsvRowDeserializationSchema, CsvRowSerializationSchema from pyflink.datastream.formats.json import JsonRowDeserializationSchema, JsonRowSerializationSchema @@ -332,6 +333,22 @@ def _check(schema: DeserializationSchema, class_name: str): 'org.apache.flink.formats.avro.AvroRowDeserializationSchema' ) + def test_set_kafka_record_deserialization_schema(self): + def _check(schema: KafkaRecordDeserializationSchema, java_class_name: str): + source = KafkaSource.builder() \ + .set_bootstrap_servers('localhost:9092') \ + .set_topics('test_topic') \ + .set_deserializer(schema) \ + .build() + kafka_record_deserialization_schema = get_field_value(source.get_java_function(), + 'deserializationSchema') + self.assertEqual(kafka_record_deserialization_schema.getClass().getCanonicalName(), + java_class_name) + + _check(SimpleStringValueKafkaRecordDeserializationSchema(), + 'org.apache.flink.connector.kafka.source.reader.deserializer.' + 'KafkaValueOnlyDeserializationSchemaWrapper') + def _check_reader_handled_offsets_initializer(self, source: KafkaSource, offset: int,