diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py index c9bbb9717..4f2a54388 100644 --- a/kafka/consumer/fetcher.py +++ b/kafka/consumer/fetcher.py @@ -372,11 +372,6 @@ def _append(self, drained, part, max_records): tp, next_offset) for record in part_records: - # Fetched compressed messages may include additional records - if record.offset < fetch_offset: - log.debug("Skipping message offset: %s (expecting %s)", - record.offset, fetch_offset) - continue drained[tp].append(record) self._subscriptions.assignment[tp].position = next_offset @@ -843,10 +838,15 @@ def __init__(self, fetch_offset, tp, messages): # When fetching an offset that is in the middle of a # compressed batch, we will get all messages in the batch. # But we want to start 'take' at the fetch_offset + # (or the next highest offset in case the message was compacted) for i, msg in enumerate(messages): - if msg.offset == fetch_offset: + if msg.offset < fetch_offset: + log.debug("Skipping message offset: %s (expecting %s)", + msg.offset, fetch_offset) + else: self.message_idx = i break + else: self.message_idx = 0 self.messages = None @@ -868,8 +868,9 @@ def take(self, n=None): next_idx = self.message_idx + n res = self.messages[self.message_idx:next_idx] self.message_idx = next_idx - if len(self) > 0: - self.fetch_offset = self.messages[self.message_idx].offset + # fetch_offset should be incremented by 1 to parallel the + # subscription position (also incremented by 1) + self.fetch_offset = max(self.fetch_offset, res[-1].offset + 1) return res diff --git a/test/test_fetcher.py b/test/test_fetcher.py index 4547222bd..fc031f742 100644 --- a/test/test_fetcher.py +++ b/test/test_fetcher.py @@ -514,8 +514,8 @@ def test_partition_records_offset(): records = Fetcher.PartitionRecords(fetch_offset, None, messages) assert len(records) > 0 msgs = records.take(1) - assert msgs[0].offset == 123 - assert records.fetch_offset == 124 + assert msgs[0].offset == fetch_offset + assert 
records.fetch_offset == fetch_offset + 1 msgs = records.take(2) assert len(msgs) == 2 assert len(records) > 0 @@ -538,3 +538,20 @@ def test_partition_records_no_fetch_offset(): for i in range(batch_start, batch_end)] records = Fetcher.PartitionRecords(fetch_offset, None, messages) assert len(records) == 0 + + +def test_partition_records_compacted_offset(): + """Test that messagesets are handled correctly + when the fetch offset points to a message that has been compacted + """ + batch_start = 0 + batch_end = 100 + fetch_offset = 42 + tp = TopicPartition('foo', 0) + messages = [ConsumerRecord(tp.topic, tp.partition, i, + None, None, 'key', 'value', 'checksum', 0, 0) + for i in range(batch_start, batch_end) if i != fetch_offset] + records = Fetcher.PartitionRecords(fetch_offset, None, messages) + assert len(records) == batch_end - fetch_offset - 1 + msgs = records.take(1) + assert msgs[0].offset == fetch_offset + 1