Skip to content

added support for spnego/kerberos auth #534

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 6 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:

- name: Install dependencies
run: |
sudo apt-get install -y libkrb5-dev
pip install -U setuptools pip wheel
pip install -e .[cpphash,redis,uwsgi]

Expand Down
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
'attrs==22.1.0',
'pytest-asyncio==0.21.0',
'aiohttp>=3.8.4',
'aiofiles>=23.1.0'
'aiofiles>=23.1.0',
'requests-kerberos>=0.14.0'
]

INSTALL_REQUIRES = [
Expand Down Expand Up @@ -46,7 +47,8 @@
'redis': ['redis>=2.10.5'],
'uwsgi': ['uwsgi>=2.0.0'],
'cpphash': ['mmh3cffi==0.2.1'],
'asyncio': ['aiohttp>=3.8.4', 'aiofiles>=23.1.0']
'asyncio': ['aiohttp>=3.8.4', 'aiofiles>=23.1.0'],
'kerberos': ['requests-kerberos>=0.14.0']
},
setup_requires=['pytest-runner', 'pluggy==1.0.0;python_version<"3.8"'],
classifiers=[
Expand Down
25 changes: 21 additions & 4 deletions splitio/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import abc
import logging
import json
from splitio.optional.loaders import HTTPKerberosAuth, OPTIONAL

from splitio.client.config import AuthenticateScheme
from splitio.optional.loaders import aiohttp
from splitio.util.time import get_current_epoch_time_ms

Expand Down Expand Up @@ -95,7 +97,7 @@ def set_telemetry_data(self, metric_name, telemetry_runtime_producer):
class HttpClient(HttpClientBase):
"""HttpClient wrapper."""

def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, telemetry_url=None):
def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, telemetry_url=None, authentication_scheme=None, authentication_params=None):
"""
Class constructor.

Expand All @@ -111,6 +113,8 @@ def __init__(self, timeout=None, sdk_url=None, events_url=None, auth_url=None, t
:type telemetry_url: str
"""
self._timeout = timeout/1000 if timeout else None # Convert ms to seconds.
self._authentication_scheme = authentication_scheme
self._authentication_params = authentication_params
self._urls = _construct_urls(sdk_url, events_url, auth_url, telemetry_url)

def get(self, server, path, sdk_key, query=None, extra_headers=None): # pylint: disable=too-many-arguments
Expand All @@ -135,13 +139,15 @@ def get(self, server, path, sdk_key, query=None, extra_headers=None): # pylint:
if extra_headers is not None:
headers.update(extra_headers)

authentication = self._get_authentication()
start = get_current_epoch_time_ms()
try:
response = requests.get(
_build_url(server, path, self._urls),
params=query,
headers=headers,
timeout=self._timeout
timeout=self._timeout,
auth=authentication
)
self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start)
return HttpResponse(response.status_code, response.text, response.headers)
Expand Down Expand Up @@ -174,21 +180,32 @@ def post(self, server, path, sdk_key, body, query=None, extra_headers=None): #
if extra_headers is not None:
headers.update(extra_headers)

authentication = self._get_authentication()
start = get_current_epoch_time_ms()
try:
response = requests.post(
_build_url(server, path, self._urls),
json=body,
params=query,
headers=headers,
timeout=self._timeout
timeout=self._timeout,
auth=authentication
)
self._record_telemetry(response.status_code, get_current_epoch_time_ms() - start)
return HttpResponse(response.status_code, response.text, response.headers)

except Exception as exc: # pylint: disable=broad-except
raise HttpClientException('requests library is throwing exceptions') from exc

def _get_authentication(self):
authentication = None
if self._authentication_scheme == AuthenticateScheme.KERBEROS:
if self._authentication_params is not None:
authentication = HTTPKerberosAuth(principal=self._authentication_params[0], password=self._authentication_params[1], mutual_authentication=OPTIONAL)
else:
authentication = HTTPKerberosAuth(mutual_authentication=OPTIONAL)
return authentication

def _record_telemetry(self, status_code, elapsed):
"""
Record Telemetry info
Expand Down Expand Up @@ -333,4 +350,4 @@ async def _record_telemetry(self, status_code, elapsed):

async def close_session(self):
if not self._session.closed:
await self._session.close()
await self._session.close()
22 changes: 21 additions & 1 deletion splitio/client/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
"""Default settings for the Split.IO SDK Python client."""
import os.path
import logging
from enum import Enum

from splitio.engine.impressions import ImpressionsMode
from splitio.client.input_validator import validate_flag_sets

_LOGGER = logging.getLogger(__name__)
DEFAULT_DATA_SAMPLING = 1

class AuthenticateScheme(Enum):
"""Authentication Scheme."""
NONE = 'NONE'
KERBEROS = 'KERBEROS'


DEFAULT_CONFIG = {
'operationMode': 'standalone',
'connectionTimeout': 1500,
Expand Down Expand Up @@ -59,7 +66,10 @@
'storageWrapper': None,
'storagePrefix': None,
'storageType': None,
'flagSetsFilter': None
'flagSetsFilter': None,
'httpAuthenticateScheme': AuthenticateScheme.NONE,
'kerberosPrincipalUser': None,
'kerberosPrincipalPassword': None
}

def _parse_operation_mode(sdk_key, config):
Expand Down Expand Up @@ -148,4 +158,14 @@ def sanitize(sdk_key, config):
else:
processed['flagSetsFilter'] = sorted(validate_flag_sets(processed['flagSetsFilter'], 'SDK Config')) if processed['flagSetsFilter'] is not None else None

if config.get('httpAuthenticateScheme') is not None:
try:
authenticate_scheme = AuthenticateScheme(config['httpAuthenticateScheme'].upper())
except (ValueError, AttributeError):
authenticate_scheme = AuthenticateScheme.NONE
_LOGGER.warning('You passed an invalid HttpAuthenticationScheme, HttpAuthenticationScheme should be ' \
'one of the following values: `none` or `kerberos`. '
' Defaulting to `none` mode.')
processed["httpAuthenticateScheme"] = authenticate_scheme

return processed
11 changes: 9 additions & 2 deletions splitio/client/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from splitio.optional.loaders import asyncio
from splitio.client.client import Client, ClientAsync
from splitio.client import input_validator
from splitio.client.config import sanitize as sanitize_config, DEFAULT_DATA_SAMPLING, AuthenticateScheme
from splitio.client.manager import SplitManager, SplitManagerAsync
from splitio.client.config import sanitize as sanitize_config, DEFAULT_DATA_SAMPLING
from splitio.client import util
from splitio.client.listener import ImpressionListenerWrapper, ImpressionListenerWrapperAsync
from splitio.engine.impressions.impressions import Manager as ImpressionsManager
Expand Down Expand Up @@ -508,12 +508,19 @@ def _build_in_memory_factory(api_key, cfg, sdk_url=None, events_url=None, # pyl
telemetry_evaluation_producer = telemetry_producer.get_telemetry_evaluation_producer()
telemetry_init_producer = telemetry_producer.get_telemetry_init_producer()

authentication_params = None
if cfg.get("httpAuthenticateScheme") == AuthenticateScheme.KERBEROS:
authentication_params = [cfg.get("kerberosPrincipalUser"),
cfg.get("kerberosPrincipalPassword")]

http_client = HttpClient(
sdk_url=sdk_url,
events_url=events_url,
auth_url=auth_api_base_url,
telemetry_url=telemetry_api_base_url,
timeout=cfg.get('connectionTimeout')
timeout=cfg.get('connectionTimeout'),
authentication_scheme = cfg.get("httpAuthenticateScheme"),
authentication_params = authentication_params
)

sdk_metadata = util.get_metadata(cfg)
Expand Down
12 changes: 12 additions & 0 deletions splitio/optional/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,17 @@ def missing_asyncio_dependencies(*_, **__):
asyncio = missing_asyncio_dependencies
aiofiles = missing_asyncio_dependencies

try:
from requests_kerberos import HTTPKerberosAuth, OPTIONAL
except ImportError:
def missing_auth_dependencies(*_, **__):
"""Fail if missing dependencies are used."""
raise NotImplementedError(
'Missing kerberos auth dependency. '
'Please use `pip install splitio_client[kerberos]` to install the sdk with kerberos auth support'
)
HTTPKerberosAuth = missing_auth_dependencies
OPTIONAL = missing_auth_dependencies

async def _anext(it):
return await it.__anext__()
2 changes: 1 addition & 1 deletion splitio/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '10.0.1'
__version__ = '10.1.0rc1'
48 changes: 40 additions & 8 deletions tests/api/test_httpclient.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""HTTPClient test module."""
from requests_kerberos import HTTPKerberosAuth, OPTIONAL
import pytest
import unittest.mock as mock

from splitio.client.config import AuthenticateScheme
from splitio.api import client
from splitio.engine.telemetry import TelemetryStorageProducer, TelemetryStorageProducerAsync
from splitio.storage.inmemmory import InMemoryTelemetryStorage, InMemoryTelemetryStorageAsync
Expand All @@ -25,7 +27,8 @@ def test_get(self, mocker):
client.SDK_URL + '/test1',
headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'},
params={'param1': 123},
timeout=None
timeout=None,
auth=None
)
assert response.status_code == 200
assert response.body == 'ok'
Expand All @@ -37,7 +40,8 @@ def test_get(self, mocker):
client.EVENTS_URL + '/test1',
headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'},
params={'param1': 123},
timeout=None
timeout=None,
auth=None
)
assert get_mock.mock_calls == [call]
assert response.status_code == 200
Expand All @@ -59,7 +63,8 @@ def test_get_custom_urls(self, mocker):
'https://sdk.com/test1',
headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'},
params={'param1': 123},
timeout=None
timeout=None,
auth=None
)
assert get_mock.mock_calls == [call]
assert response.status_code == 200
Expand All @@ -71,7 +76,8 @@ def test_get_custom_urls(self, mocker):
'https://events.com/test1',
headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'},
params={'param1': 123},
timeout=None
timeout=None,
auth=None
)
assert response.status_code == 200
assert response.body == 'ok'
Expand All @@ -95,7 +101,8 @@ def test_post(self, mocker):
json={'p1': 'a'},
headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'},
params={'param1': 123},
timeout=None
timeout=None,
auth=None
)
assert response.status_code == 200
assert response.body == 'ok'
Expand All @@ -108,7 +115,8 @@ def test_post(self, mocker):
json={'p1': 'a'},
headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'},
params={'param1': 123},
timeout=None
timeout=None,
auth=None
)
assert response.status_code == 200
assert response.body == 'ok'
Expand All @@ -131,7 +139,8 @@ def test_post_custom_urls(self, mocker):
json={'p1': 'a'},
headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'},
params={'param1': 123},
timeout=None
timeout=None,
auth=None
)
assert response.status_code == 200
assert response.body == 'ok'
Expand All @@ -144,12 +153,35 @@ def test_post_custom_urls(self, mocker):
json={'p1': 'a'},
headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'},
params={'param1': 123},
timeout=None
timeout=None,
auth=None
)
assert response.status_code == 200
assert response.body == 'ok'
assert get_mock.mock_calls == [call]

def test_authentication_scheme(self, mocker):
response_mock = mocker.Mock()
response_mock.status_code = 200
response_mock.text = 'ok'
get_mock = mocker.Mock()
get_mock.return_value = response_mock
mocker.patch('splitio.api.client.requests.get', new=get_mock)
httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS)
httpclient.set_telemetry_data("metric", mocker.Mock())
response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'})
call = mocker.call(
'https://sdk.com/test1',
headers={'Authorization': 'Bearer some_api_key', 'h1': 'abc', 'Content-Type': 'application/json'},
params={'param1': 123},
timeout=None,
auth=HTTPKerberosAuth(mutual_authentication=OPTIONAL)
)

httpclient = client.HttpClient(sdk_url='https://sdk.com', authentication_scheme=AuthenticateScheme.KERBEROS, authentication_params=['bilal', 'split'])
httpclient.set_telemetry_data("metric", mocker.Mock())
response = httpclient.get('sdk', '/test1', 'some_api_key', {'param1': 123}, {'h1': 'abc'})

def test_telemetry(self, mocker):
telemetry_storage = InMemoryTelemetryStorage()
telemetry_producer = TelemetryStorageProducer(telemetry_storage)
Expand Down
10 changes: 10 additions & 0 deletions tests/client/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,19 @@ def test_sanitize(self):
processed = config.sanitize('some', configs)
assert processed['redisLocalCacheEnabled'] # check default is True
assert processed['flagSetsFilter'] is None
assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.NONE

processed = config.sanitize('some', {'redisHost': 'x', 'flagSetsFilter': ['set']})
assert processed['flagSetsFilter'] is None

processed = config.sanitize('some', {'storageType': 'pluggable', 'flagSetsFilter': ['set']})
assert processed['flagSetsFilter'] is None

processed = config.sanitize('some', {'httpAuthenticateScheme': 'KERBEROS'})
assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.KERBEROS

processed = config.sanitize('some', {'httpAuthenticateScheme': 'anything'})
assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.NONE

processed = config.sanitize('some', {'httpAuthenticateScheme': 'NONE'})
assert processed['httpAuthenticateScheme'] is config.AuthenticateScheme.NONE
Loading