Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Improve connection lock handling; always use context manager #1895

Merged
merged 2 commits into from
Sep 3, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 151 additions & 126 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,21 +593,30 @@ def _try_authenticate_plain(self, future):
self.config['sasl_plain_username'],
self.config['sasl_plain_password']]).encode('utf-8'))
size = Int32.encode(len(msg))
try:
with self._lock:
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))
self._send_bytes_blocking(size + msg)

# The server will send a zero sized message (that is Int32(0)) on success.
# The connection is closed on failure
data = self._recv_bytes_blocking(4)
err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
try:
self._send_bytes_blocking(size + msg)

# The server will send a zero sized message (that is Int32(0)) on success.
# The connection is closed on failure
data = self._recv_bytes_blocking(4)

except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
self.close(error=error)
return future.failure(error)
except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

if data != b'\x00\x00\x00\x00':
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
Expand All @@ -625,61 +634,67 @@ def _try_authenticate_gssapi(self, future):
).canonicalize(gssapi.MechType.kerberos)
log.debug('%s: GSSAPI name: %s', self, gssapi_name)

self._lock.acquire()
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))
# Establish security context and negotiate protection level
# For reference RFC 2222, section 7.2.1
try:
# Exchange tokens until authentication either succeeds or fails
client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate')
received_token = None
while not client_ctx.complete:
# calculate an output token from kafka token (or None if first iteration)
output_token = client_ctx.step(received_token)

# pass output token to kafka, or send empty response if the security
# context is complete (output token is None in that case)
if output_token is None:
self._send_bytes_blocking(Int32.encode(0))
else:
msg = output_token
err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
# Establish security context and negotiate protection level
# For reference RFC 2222, section 7.2.1
try:
# Exchange tokens until authentication either succeeds or fails
client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate')
received_token = None
while not client_ctx.complete:
# calculate an output token from kafka token (or None if first iteration)
output_token = client_ctx.step(received_token)

# pass output token to kafka, or send empty response if the security
# context is complete (output token is None in that case)
if output_token is None:
self._send_bytes_blocking(Int32.encode(0))
else:
msg = output_token
size = Int32.encode(len(msg))
self._send_bytes_blocking(size + msg)

# The server will send a token back. Processing of this token either
# establishes a security context, or it needs further token exchange.
# The gssapi will be able to identify the needed next step.
# The connection is closed on failure.
header = self._recv_bytes_blocking(4)
(token_size,) = struct.unpack('>i', header)
received_token = self._recv_bytes_blocking(token_size)

# Process the security layer negotiation token, sent by the server
# once the security context is established.

# unwraps message containing supported protection levels and msg size
msg = client_ctx.unwrap(received_token).message
# Kafka currently doesn't support integrity or confidentiality security layers, so we
# simply set QoP to 'auth' only (first octet). We reuse the max message size proposed
# by the server
msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:]
# add authorization identity to the response, GSS-wrap and send it
msg = client_ctx.wrap(msg + auth_id.encode(), False).message
size = Int32.encode(len(msg))
self._send_bytes_blocking(size + msg)

# The server will send a token back. Processing of this token either
# establishes a security context, or it needs further token exchange.
# The gssapi will be able to identify the needed next step.
# The connection is closed on failure.
header = self._recv_bytes_blocking(4)
(token_size,) = struct.unpack('>i', header)
received_token = self._recv_bytes_blocking(token_size)

# Process the security layer negotiation token, sent by the server
# once the security context is established.

# unwraps message containing supported protection levels and msg size
msg = client_ctx.unwrap(received_token).message
# Kafka currently doesn't support integrity or confidentiality security layers, so we
# simply set QoP to 'auth' only (first octet). We reuse the max message size proposed
# by the server
msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:]
# add authorization identity to the response, GSS-wrap and send it
msg = client_ctx.wrap(msg + auth_id.encode(), False).message
size = Int32.encode(len(msg))
self._send_bytes_blocking(size + msg)
except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True
except Exception as e:
err = e
close = True

except (ConnectionError, TimeoutError) as e:
self._lock.release()
log.exception("%s: Error receiving reply from server", self)
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
self.close(error=error)
return future.failure(error)
except Exception as e:
self._lock.release()
return future.failure(e)
if err is not None:
if close:
self.close(error=err)
return future.failure(err)

self._lock.release()
log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
return future.success(True)

Expand All @@ -688,25 +703,31 @@ def _try_authenticate_oauth(self, future):

msg = bytes(self._build_oauth_client_request().encode("utf-8"))
size = Int32.encode(len(msg))
self._lock.acquire()
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))
try:
# Send SASL OAuthBearer request with OAuth token
self._send_bytes_blocking(size + msg)

# The server will send a zero sized message (that is Int32(0)) on success.
# The connection is closed on failure
data = self._recv_bytes_blocking(4)
err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
try:
# Send SASL OAuthBearer request with OAuth token
self._send_bytes_blocking(size + msg)

except (ConnectionError, TimeoutError) as e:
self._lock.release()
log.exception("%s: Error receiving reply from server", self)
error = Errors.KafkaConnectionError("%s: %s" % (self, e))
self.close(error=error)
return future.failure(error)
# The server will send a zero sized message (that is Int32(0)) on success.
# The connection is closed on failure
data = self._recv_bytes_blocking(4)

self._lock.release()
except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

if data != b'\x00\x00\x00\x00':
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
Expand Down Expand Up @@ -857,6 +878,9 @@ def _send(self, request, blocking=True):
future = Future()
with self._lock:
if not self._can_send_recv():
# In this case, since we created the future above,
# we know there are no callbacks/errbacks that could fire w/
# lock. So failing + returning inline should be safe
return future.failure(Errors.NodeNotReadyError(str(self)))

correlation_id = self._protocol.send_request(request)
Expand Down Expand Up @@ -935,56 +959,57 @@ def recv(self):
def _recv(self):
"""Take all available bytes from socket, return list of any responses from parser"""
recvd = []
self._lock.acquire()
if not self._can_send_recv():
log.warning('%s cannot recv: socket not connected', self)
self._lock.release()
return ()

while len(recvd) < self.config['sock_chunk_buffer_count']:
try:
data = self._sock.recv(self.config['sock_chunk_bytes'])
# We expect socket.recv to raise an exception if there are no
# bytes available to read from the socket in non-blocking mode.
# but if the socket is disconnected, we will get empty data
# without an exception raised
if not data:
log.error('%s: socket disconnected', self)
self._lock.release()
self.close(error=Errors.KafkaConnectionError('socket disconnected'))
return []
else:
recvd.append(data)
err = None
with self._lock:
if not self._can_send_recv():
log.warning('%s cannot recv: socket not connected', self)
return ()

except SSLWantReadError:
break
except (ConnectionError, TimeoutError) as e:
if six.PY2 and e.errno == errno.EWOULDBLOCK:
while len(recvd) < self.config['sock_chunk_buffer_count']:
try:
data = self._sock.recv(self.config['sock_chunk_bytes'])
# We expect socket.recv to raise an exception if there are no
# bytes available to read from the socket in non-blocking mode.
# but if the socket is disconnected, we will get empty data
# without an exception raised
if not data:
log.error('%s: socket disconnected', self)
err = Errors.KafkaConnectionError('socket disconnected')
break
else:
recvd.append(data)

except SSLWantReadError:
break
log.exception('%s: Error receiving network data'
' closing socket', self)
self._lock.release()
self.close(error=Errors.KafkaConnectionError(e))
return []
except BlockingIOError:
if six.PY3:
except (ConnectionError, TimeoutError) as e:
if six.PY2 and e.errno == errno.EWOULDBLOCK:
break
log.exception('%s: Error receiving network data'
' closing socket', self)
err = Errors.KafkaConnectionError(e)
break
self._lock.release()
raise

recvd_data = b''.join(recvd)
if self._sensors:
self._sensors.bytes_received.record(len(recvd_data))

try:
responses = self._protocol.receive_bytes(recvd_data)
except Errors.KafkaProtocolError as e:
self._lock.release()
self.close(e)
return []
else:
self._lock.release()
return responses
except BlockingIOError:
if six.PY3:
break
# For PY2 this is a catchall and should be re-raised
raise

# Only process bytes if there was no connection exception
if err is None:
recvd_data = b''.join(recvd)
if self._sensors:
self._sensors.bytes_received.record(len(recvd_data))

# We need to keep the lock through protocol receipt
# so that we ensure that the processed byte order is the
# same as the received byte order
try:
return self._protocol.receive_bytes(recvd_data)
except Errors.KafkaProtocolError as e:
err = e

self.close(error=err)
return ()

def requests_timed_out(self):
with self._lock:
Expand Down