Skip to content

refactored httpclient for kerberos auth #543

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 4 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 137 additions & 52 deletions splitio/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
TELEMETRY_URL = 'https://telemetry.split.io/api'

_LOGGER = logging.getLogger(__name__)

_EXC_MSG = '{source} library is throwing exceptions'

HttpResponse = namedtuple('HttpResponse', ['status_code', 'body', 'headers'])

Expand Down Expand Up @@ -122,7 +122,7 @@ def _get_headers(self, extra_headers, sdk_key):
class HttpClient(HttpClientBase):
"""HttpClient wrapper."""

def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, telemetry_url=None, authentication_scheme=None, authentication_params=None):
def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, telemetry_url=None):
"""
Class constructor.

Expand All @@ -140,8 +140,6 @@ def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, t
_LOGGER.debug("Initializing httpclient")
self._timeout = timeout/1000 if timeout else None # Convert ms to seconds.
self._urls = _construct_urls(sdk_url, events_url, auth_url, telemetry_url)
self._authentication_scheme = authentication_scheme
self._authentication_params = authentication_params
self._lock = threading.RLock()

def get(self, server, path, sdk_key, query=None, extra_headers=None): # pylint: disable=too-many-arguments
Expand All @@ -162,22 +160,19 @@ def get(self, server, path, sdk_key, query=None, extra_headers=None): # pylint:
:return: Tuple of status_code & response text
:rtype: HttpResponse
"""
with self._lock:
start = get_current_epoch_time_ms()
with requests.Session() as session:
self._set_authentication(session)
try:
response = session.get(
_build_url(server, path, self._urls),
params=query,
headers=self._get_headers(extra_headers, sdk_key),
timeout=self._timeout
)
self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start)
return HttpResponse(response.status_code, response.text, response.headers)
start = get_current_epoch_time_ms()
try:
response = requests.get(
_build_url(server, path, self._urls),
params=query,
headers=self._get_headers(extra_headers, sdk_key),
timeout=self._timeout
)
self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start)
return HttpResponse(response.status_code, response.text, response.headers)

except Exception as exc: # pylint: disable=broad-except
raise HttpClientException('requests library is throwing exceptions') from exc
except Exception as exc: # pylint: disable=broad-except
raise HttpClientException(_EXC_MSG.format(source='request')) from exc

def post(self, server, path, sdk_key, body, query=None, extra_headers=None): # pylint: disable=too-many-arguments
"""
Expand All @@ -199,37 +194,19 @@ def post(self, server, path, sdk_key, body, query=None, extra_headers=None): #
:return: Tuple of status_code & response text
:rtype: HttpResponse
"""
with self._lock:
start = get_current_epoch_time_ms()
with requests.Session() as session:
self._set_authentication(session)
try:
response = session.post(
_build_url(server, path, self._urls),
json=body,
params=query,
headers=self._get_headers(extra_headers, sdk_key),
timeout=self._timeout,
)
self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start)
return HttpResponse(response.status_code, response.text, response.headers)
except Exception as exc: # pylint: disable=broad-except
raise HttpClientException('requests library is throwing exceptions') from exc

def _set_authentication(self, session):
if self._authentication_scheme == AuthenticateScheme.KERBEROS_SPNEGO:
_LOGGER.debug("Using Kerberos Spnego Authentication")
if self._authentication_params != [None, None]:
session.auth = HTTPKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1], mutual_authentication=OPTIONAL)
else:
session.auth = HTTPKerberosAuth(mutual_authentication=OPTIONAL)
elif self._authentication_scheme == AuthenticateScheme.KERBEROS_PROXY:
_LOGGER.debug("Using Kerberos Proxy Authentication")
if self._authentication_params != [None, None]:
session.mount('https://', HTTPAdapterWithProxyKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1]))
else:
session.mount('https://', HTTPAdapterWithProxyKerberosAuth())

start = get_current_epoch_time_ms()
try:
response = requests.post(
_build_url(server, path, self._urls),
json=body,
params=query,
headers=self._get_headers(extra_headers, sdk_key),
timeout=self._timeout,
)
self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start)
return HttpResponse(response.status_code, response.text, response.headers)
except Exception as exc: # pylint: disable=broad-except
raise HttpClientException(_EXC_MSG.format(source='request')) from exc

def _record_telemetry(self, status_code, elapsed):
"""
Expand Down Expand Up @@ -306,7 +283,7 @@ async def get(self, server, path, apikey, query=None, extra_headers=None): # py
return HttpResponse(response.status, body, response.headers)

except aiohttp.ClientError as exc: # pylint: disable=broad-except
raise HttpClientException('aiohttp library is throwing exceptions') from exc
raise HttpClientException(_EXC_MSG.format(source='aiohttp')) from exc

async def post(self, server, path, apikey, body, query=None, extra_headers=None): # pylint: disable=too-many-arguments
"""
Expand Down Expand Up @@ -350,7 +327,7 @@ async def post(self, server, path, apikey, body, query=None, extra_headers=None)
return HttpResponse(response.status, body, response.headers)

except aiohttp.ClientError as exc: # pylint: disable=broad-except
raise HttpClientException('aiohttp library is throwing exceptions') from exc
raise HttpClientException(_EXC_MSG.format(source='aiohttp')) from exc

async def _record_telemetry(self, status_code, elapsed):
"""
Expand All @@ -372,3 +349,111 @@ async def _record_telemetry(self, status_code, elapsed):
async def close_session(self):
if not self._session.closed:
await self._session.close()

class HttpClientKerberos(HttpClient):
"""HttpClient wrapper."""

def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, telemetry_url=None, authentication_scheme=None, authentication_params=None):
"""
Class constructor.

:param timeout: How many milliseconds to wait until the server responds.
:type timeout: int
:param sdk_url: Optional alternative sdk URL.
:type sdk_url: str
:param events_url: Optional alternative events URL.
:type events_url: str
:param auth_url: Optional alternative auth URL.
:type auth_url: str
:param telemetry_url: Optional alternative telemetry URL.
:type telemetry_url: str
"""
_LOGGER.debug("Initializing httpclient for Kerberos auth")
HttpClient.__init__(self, timeout=timeout, sdk_url=sdk_url, events_url=events_url, auth_url=auth_url, telemetry_url=telemetry_url)
self._authentication_scheme = authentication_scheme
self._authentication_params = authentication_params

def get(self, server, path, sdk_key, query=None, extra_headers=None): # pylint: disable=too-many-arguments
"""
Issue a get request.
:param server: Whether the request is for SDK server, Events server or Auth server.
:typee server: str
:param path: path to append to the host url.
:type path: str
:param sdk_key: sdk key.
:type sdk_key: str
:param query: Query string passed as dictionary.
:type query: dict
:param extra_headers: key/value pairs of possible extra headers.
:type extra_headers: dict

:return: Tuple of status_code & response text
:rtype: HttpResponse
"""
with self._lock:
start = get_current_epoch_time_ms()
with requests.Session() as session:
self._set_authentication(session)
try:
response = session.get(
_build_url(server, path, self._urls),
headers=self._get_headers(extra_headers, sdk_key),
params=query,
timeout=self._timeout
)
self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start)
return HttpResponse(response.status_code, response.text, response.headers)

except Exception as exc: # pylint: disable=broad-except
raise HttpClientException(_EXC_MSG.format(source='request')) from exc

def post(self, server, path, sdk_key, body, query=None, extra_headers=None): # pylint: disable=too-many-arguments
"""
Issue a POST request.

:param server: Whether the request is for SDK server or Events server.
:typee server: str
:param path: path to append to the host url.
:type path: str
:param sdk_key: sdk key.
:type sdk_key: str
:param body: body sent in the request.
:type body: str
:param query: Query string passed as dictionary.
:type query: dict
:param extra_headers: key/value pairs of possible extra headers.
:type extra_headers: dict

:return: Tuple of status_code & response text
:rtype: HttpResponse
"""
with self._lock:
start = get_current_epoch_time_ms()
with requests.Session() as session:
self._set_authentication(session)
try:
response = session.post(
_build_url(server, path, self._urls),
params=query,
headers=self._get_headers(extra_headers, sdk_key),
json=body,
timeout=self._timeout,
)
self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start)
return HttpResponse(response.status_code, response.text, response.headers)
except Exception as exc: # pylint: disable=broad-except
raise HttpClientException(_EXC_MSG.format(source='request')) from exc

def _set_authentication(self, session):
if self._authentication_scheme == AuthenticateScheme.KERBEROS_SPNEGO:
_LOGGER.debug("Using Kerberos Spnego Authentication")
if self._authentication_params != [None, None]:
session.auth = HTTPKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1], mutual_authentication=OPTIONAL)
else:
session.auth = HTTPKerberosAuth(mutual_authentication=OPTIONAL)
elif self._authentication_scheme == AuthenticateScheme.KERBEROS_PROXY:
_LOGGER.debug("Using Kerberos Proxy Authentication")
if self._authentication_params != [None, None]:
session.mount('https://', HTTPAdapterWithProxyKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1]))
else:
session.mount('https://', HTTPAdapterWithProxyKerberosAuth())
29 changes: 18 additions & 11 deletions splitio/client/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
PluggableImpressionsStorageAsync, PluggableSegmentStorageAsync, PluggableSplitStorageAsync

# APIs
from splitio.api.client import HttpClient, HttpClientAsync
from splitio.api.client import HttpClient, HttpClientAsync, HttpClientKerberos
from splitio.api.splits import SplitsAPI, SplitsAPIAsync
from splitio.api.segments import SegmentsAPI, SegmentsAPIAsync
from splitio.api.impressions import ImpressionsAPI, ImpressionsAPIAsync
Expand Down Expand Up @@ -512,16 +512,23 @@ def _build_in_memory_factory(api_key, cfg, sdk_url=None, events_url=None, # pyl
if cfg.get("httpAuthenticateScheme") in [AuthenticateScheme.KERBEROS_SPNEGO, AuthenticateScheme.KERBEROS_PROXY]:
authentication_params = [cfg.get("kerberosPrincipalUser"),
cfg.get("kerberosPrincipalPassword")]

http_client = HttpClient(
sdk_url=sdk_url,
events_url=events_url,
auth_url=auth_api_base_url,
telemetry_url=telemetry_api_base_url,
timeout=cfg.get('connectionTimeout'),
authentication_scheme = cfg.get("httpAuthenticateScheme"),
authentication_params = authentication_params
)
http_client = HttpClientKerberos(
sdk_url=sdk_url,
events_url=events_url,
auth_url=auth_api_base_url,
telemetry_url=telemetry_api_base_url,
timeout=cfg.get('connectionTimeout'),
authentication_scheme = cfg.get("httpAuthenticateScheme"),
authentication_params = authentication_params
)
else:
http_client = HttpClient(
sdk_url=sdk_url,
events_url=events_url,
auth_url=auth_api_base_url,
telemetry_url=telemetry_api_base_url,
timeout=cfg.get('connectionTimeout'),
)

sdk_metadata = util.get_metadata(cfg)
apis = {
Expand Down
Loading
Loading