From 70ee6fa533f342a7634d007a6455a82f5ee6a3c9 Mon Sep 17 00:00:00 2001 From: hackermd Date: Tue, 3 May 2022 17:25:22 -0400 Subject: [PATCH 1/3] Add session utils for OIDC auth --- requirements_test.txt | 1 + setup.py | 3 +- src/dicomweb_client/ext/gcp/session_utils.py | 4 +- src/dicomweb_client/session_utils.py | 496 ++++++++++++++++++- tests/test_session_utils.py | 151 ++++++ 5 files changed, 640 insertions(+), 15 deletions(-) create mode 100644 tests/test_session_utils.py diff --git a/requirements_test.txt b/requirements_test.txt index bfbc080..80e9562 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -3,5 +3,6 @@ pytest==6.2.4 pytest-cov==2.12.1 pytest-flake8==1.0.7 pytest-localserver==0.5.0 +responses==0.10.16 types-requests==2.27.14 types-Pillow==9.0.8 diff --git a/setup.py b/setup.py index b6d7842..c72c410 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ package_dir={'': 'src'}, extras_require={ 'gcp': [ - 'dataclasses>=0.8; python_version=="3.6"', + 'dataclasses>=0.8; python_version < "3.7.0"', 'google-auth>=1.6', 'google-oauth>=1.0', ], @@ -61,6 +61,7 @@ install_requires=[ 'numpy>=1.19', 'requests>=2.18', + 'requests-oauthlib>=1.2', 'retrying>=1.3.3', 'Pillow>=8.3', 'pydicom>=2.2', diff --git a/src/dicomweb_client/ext/gcp/session_utils.py b/src/dicomweb_client/ext/gcp/session_utils.py index ef4ae91..8b597ed 100644 --- a/src/dicomweb_client/ext/gcp/session_utils.py +++ b/src/dicomweb_client/ext/gcp/session_utils.py @@ -13,8 +13,8 @@ def create_session_from_gcp_credentials( - google_credentials: Optional[Any] = None - ) -> requests.Session: + google_credentials: Optional[Any] = None +) -> requests.Session: """Creates an authorized session for Google Cloud Platform. Parameters diff --git a/src/dicomweb_client/session_utils.py b/src/dicomweb_client/session_utils.py index a6f004a..798200f 100644 --- a/src/dicomweb_client/session_utils.py +++ b/src/dicomweb_client/session_utils.py @@ -1,13 +1,485 @@ +import base64 +import hashlib import logging import os -from typing import Optional, Any +import random +import re +import string +import time import warnings +import webbrowser +from abc import ABCMeta +from http.server import BaseHTTPRequestHandler, HTTPServer +from threading import Thread +from typing import Any, Callable, List, NamedTuple, NoReturn, Optional, Union import requests +from oauthlib.oauth2 import ( # type: ignore + Client as OAuth2Client, + BackendApplicationClient, + WebApplicationClient, +) +from requests_oauthlib.oauth2_session import OAuth2Session # type: ignore + logger = logging.getLogger(__name__) +_STORE = {} + + +class _AuthorizationCodeError(Exception): + """Exception raised when an authorization code could not be obtained. + + An authorization code is obtained from the authorization server as part of + the OAuth 2.0 Authorization Code grant type. + + """ + + pass + + +class _LocalHTTPRequestHandler(BaseHTTPRequestHandler): + """HTTP request handler. + + Handles received redirected HTTP request messages from the authorization + server as part of the OAuth 2.0 Authorization Code flow. + + """ + def _extract_value(self, param: str) -> str: + pattern = re.compile(rf"[^_]{param}=([^&]+)") + match = pattern.search(self.path) + if match is None: + raise ValueError(f'Value of parameter "{param}" not found in URI.') + return match.group(1) + + def do_GET(self) -> None: + """Respond to GET request.""" + try: + code = self._extract_value("code") + state = self._extract_value("state") + _STORE[state] = code + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + page = b""" + + + Local Server + + + + Success + + + """ + self.wfile.write(page) + except ValueError: + self.send_response(401) + self.send_header("Content-type", "text/html") + self.end_headers() + page = b""" + + + Local Server + + + + Unauthorized + + + """ + self.wfile.write(page) + + +class _LocalHTTPServer(Thread): + """Local HTTP server running in a thread. + + Receives redirected HTTP request messages from the authorization server as + part of the OAuth 2.0 Authorization Code flow. + + """ + + def __init__(self, port: int): + """Construct object. + + Parameters + ---------- + port: int + Port on localhost to which server should listen + + """ + super().__init__() + address = ("127.0.0.1", port) + self.server = HTTPServer(address, _LocalHTTPRequestHandler) + self.daemon = False + + def __enter__(self) -> '_LocalHTTPServer': + """Enter scope of with statement block.""" + logger.debug("start local server") + self.start() + return self + + def __exit__( + self, + error_type: Optional[type], + error_value: Optional[str], + error_traceback: Optional[object] + ) -> None: + """Exit scope of with statement block. + + Parameters + ---------- + except_type: type, optional + Error class + except_value: str, optional + Error message + except_trace: types.TracebackType, optional + Error traceback + + """ + if error_type is not None: + logger.error(f"an error occured: {error_value}") + logger.debug("stop local server") + self.stop() + if error_type is not None: + raise error_type(error_value) + + def run(self) -> None: + """Start serving.""" + self.server.serve_forever() + + def stop(self) -> None: + """Shut down the server.""" + self.server.shutdown() + self.join() + self.server.server_close() + self.server.socket.close() + + def get_authorization_code(self, state: str) -> str: + """Get a cached authorization code. + + Parameters + ---------- + state: str + Value of the OAuth 2.0 "state" parameter + + Returns + ------- + str + Authorization code + + """ + logger.debug("check if authorization code has been received") + try: + return _STORE[state] + except KeyError: + raise _AuthorizationCodeError(f'Code not found for state "{state}"') + + +class _AuthorizedSession(OAuth2Session, metaclass=ABCMeta): + """Abstract base class for an authorized OAuth 2.0 session.""" + + def __init__( + self, + client: OAuth2Client, + scope: Optional[List[str]] = None, + token_updater: Optional[Callable] = None, + auto_refresh_url: Optional[str] = None, + auto_refresh_kwargs: Optional[dict] = None, + redirect_uri: Optional[str] = None, + ) -> None: + """Construct object. + + Parameters + ---------- + client: oauthlib.oauth2.Client + OAuth 2.0 Client object + scope: List[str], optional + Restricted scope of client access + token_updater: Callable, optional + Function for handling retrieved access tokens + (signature: ``def token_updater(token: Dict[str, str]) -> None``) + auto_refresh_url: str, optional + URL for automatically refreshing access tokens + auto_refresh_kwargs: dict, optional + Parameters for automatically refreshing access tokens + redirect_uri: str, optional + URI of service to which authorization requests will be redirected + to + + """ + super().__init__( + client=client, + scope=scope, + token_updater=token_updater, + redirect_uri=redirect_uri, + auto_refresh_url=auto_refresh_url, + auto_refresh_kwargs=auto_refresh_kwargs, + ) + + +class PublicClientCredentials(NamedTuple): + """Credentials for a public OAuth 2.0 client.""" + + client_id: str + token_uri: str + auth_uri: str + + +class ConfidentialClientCredentials(NamedTuple): + """Credentials for a confidential OAuth 2.0 client.""" + + client_id: str + client_secret: str + token_uri: str + + +class PublicClientSession(_AuthorizedSession): + """Authorized session for public OAuth 2.0 clients. + + Should be used by clients that are incapable of maintaining the + confidentiality of their credentials. For example, a shell in an + environment that is under the control of the resource owner. + + Uses the OAuth 2.0 Authorization Code grant type with + Proof Key for Code Exchange (PKCE) challenge. + + """ + + def __init__( + self, + client_id: str, + auth_uri: str, + token_uri: str, + token_updater: Optional[Callable] = None, + scope: Optional[List[str]] = None, + redirect_port: int = 37474, + redirect_timeout: int = 30, + open_browser: bool = True + ) -> None: + """Construct object. + + Parameters + ---------- + client_id: str + Identifier of an OAuth 2.0 client. + auth_uri: str + Unique resource identifier of the authorization endpoint - used by + the client to obtain authorization via redirection + token_uri: str + Unique resource identifier of the token endpoint - used by the + client to obtain an access token via an authorization code grant + token_updater: Callable, optional + Function for handling retrieved access tokens + (signature: ``def token_updater(token: Dict[str, str]) -> None``) + scope: List[str], optional + Restricted scope of client access + redirect_port: int, optional + Local port of HTTP server to which authentication requests will + be redirected + redirect_timeout: int, optional + Seconds to wait for redirect message + open_browser: bool, optional + Whether the authorization URL should automatically be opened in + a new tab of the default browser + + Note + ---- + When `open_browser` is set to ``True``, a window should open + in your browser and prompt you for your credentials. Otherwise, the + authorization URL must be obtained from the log message and copied + manually into a browser. + + Note + ---- + The OAuth 2.0 client must be configured to authorize the redirect URI + (default: ``"http://localhost:37474/"``). The URI must match exactly, + including the final slash. + + """ + logger.info( + "create session for public client using the authentication code " + "grant type" + ) + + redirect_server = _LocalHTTPServer(redirect_port) + redirect_uri = f"http://localhost:{redirect_port}/" + + client = WebApplicationClient(client_id=client_id, token={}) + super().__init__( + client=client, + scope=scope, + token_updater=token_updater, + redirect_uri=redirect_uri, + auto_refresh_url=token_uri, + auto_refresh_kwargs={"client_id": client_id}, + ) + + code_verifier = self._create_code_verifier() + extra_query_params = { + "code_challenge": self._create_s256_code_challenge(code_verifier), + "code_challenge_method": "S256", + } + extra_query_string = "&".join( + [f"{key}={value}" for key, value in extra_query_params.items()] + ) + + logger.info("authenticate via web application") + with redirect_server as redirect_session: + authorization_url, state = self.authorization_url(auth_uri) + logger.info(f'authorization URL: "{authorization_url}"') + if open_browser: + webbrowser.open_new_tab(authorization_url) + + logger.info("wait for receipt of authorization code") + code = None + t_end = time.time() + redirect_timeout + while time.time() < t_end: + try: + code = redirect_session.get_authorization_code(state) + break + except _AuthorizationCodeError: + time.sleep(1) + continue + if not code: + raise ValueError("Could not obtain authorization code.") + + authorization_url += f"&{extra_query_string}" + logger.info("fetch access token using received authorization code") + self.fetch_token( + token_url=token_uri, + authorization_response=authorization_url, + code_verifier=code_verifier, + code=code, + ) + + @staticmethod + def _create_code_verifier() -> str: + """Create a code verifier for PKCE code challenge. + + Returns + ------- + str + Code verifier + + """ + chars = string.ascii_letters + string.digits + rand = random.SystemRandom() + return "".join(rand.choice(chars) for _ in range(48)) + + @staticmethod + def _create_s256_code_challenge(value: Union[str, int, float]) -> str: + """Create a PKCE code challenge using the S256 method. + + Parameters + ---------- + code_verifier: Union[int, float, str] + Random string of 48 random characters + + Returns + ------- + str + Code challenge + + """ + # Encodes the provided values as follows: BASE64(SHA256(ASCII(value))) + data = hashlib.sha256(bytes(str(value).encode("ascii"))).digest() + data = base64.urlsafe_b64encode(data).rstrip(b"=") + return data.decode("utf-8") + + +class ConfidentialClientSession(_AuthorizedSession): + """Authorized session for confidential OAuth 2.0 clients. + + Should be used for clients that are capable of maintaining the + confidentiality of their credentials. For example, a client used by an + application server on the backend in a secure environment. + + Uses the OAuth 2.0 Client Credentials grant type. + + """ + + def __init__( + self, + client_id: str, + client_secret: str, + token_uri: str, + scope: Optional[List[str]] = None, + token_updater: Optional[Callable] = None, + ): + """Construct object. + + Parameters + ---------- + client_id: str + Client identifier + client_secret: str + Client secret + token_uri: str + Unique resource identifier of the token endpoint - used by the + client to obtain an access token via the provided client secret + token_updater: Callable, optional + Function for handling retrieved access tokens + (signature: ``def token_updater(token: Dict[str, str]) -> None``) + scope: List[str], optional + Restricted scope of client access + + """ + logger.info( + "create session for confidential client using the client " + "authentication grant type" + ) + client = BackendApplicationClient(client_id=client_id) + super().__init__( + client=client, + scope=scope, + auto_refresh_url=token_uri, + token_updater=token_updater, + auto_refresh_kwargs={ + "client_id": client_id, + "client_secret": client_secret, + }, + ) + logger.info("fetch access token using client credentials") + self.fetch_token(client_secret=client_secret, token_url=token_uri) + + +def create_session_from_client_credentials( + credentials: Union[ConfidentialClientCredentials, PublicClientCredentials] +) -> requests.Session: + """Construct an authorized session for accessing protected web resources. + + Parameters + ---------- + credentials: Union[dicomweb_client.session_utils.ConfidentialClientCredentials, dicomweb_client.session_utils.PublicClientCredentials] + Credentials of OAuth 2.0 client (public or confidential) + + Returns + ------- + requests.Session + Authorized session object + + """ # noqa + if isinstance(credentials, PublicClientCredentials): + return PublicClientSession( + client_id=credentials.client_id, + auth_uri=credentials.auth_uri, + token_uri=credentials.token_uri, + ) + elif isinstance(credentials, ConfidentialClientCredentials): + return ConfidentialClientSession( + client_id=credentials.client_id, + client_secret=credentials.client_secret, + token_uri=credentials.token_uri, + ) + else: + raise TypeError( + 'Argument "credentials" must be of type ' + '"PublicClientCredentials" or "ConfidentialClientCredentials".' + ) + + def create_session() -> requests.Session: """Creates an unauthorized session. @@ -22,8 +494,8 @@ def create_session() -> requests.Session: def create_session_from_auth( - auth: requests.auth.AuthBase - ) -> requests.Session: + auth: requests.auth.AuthBase +) -> requests.Session: """Creates a session from a gicen AuthBase object. Parameters @@ -46,9 +518,9 @@ def create_session_from_auth( def create_session_from_user_pass( - username: str, - password: str - ) -> requests.Session: + username: str, + password: str +) -> requests.Session: """Creates a session from a given username and password. Parameters @@ -72,10 +544,10 @@ def create_session_from_user_pass( def add_certs_to_session( - session: requests.Session, - ca_bundle: Optional[str] = None, - cert: Optional[str] = None - ) -> requests.Session: + session: requests.Session, + ca_bundle: Optional[str] = None, + cert: Optional[str] = None +) -> requests.Session: """Adds CA bundle and certificate to an existing session. Parameters @@ -113,8 +585,8 @@ def add_certs_to_session( def create_session_from_gcp_credentials( - google_credentials: Optional[Any] = None - ) -> requests.Session: + google_credentials: Optional[Any] = None +) -> requests.Session: """Creates an authorized session for Google Cloud Platform. Parameters diff --git a/tests/test_session_utils.py b/tests/test_session_utils.py new file mode 100644 index 0000000..bf2f78d --- /dev/null +++ b/tests/test_session_utils.py @@ -0,0 +1,151 @@ +import time +import unittest + +import responses # type: ignore +import requests + +from dicomweb_client.session_utils import ( + ConfidentialClientCredentials, + ConfidentialClientSession, + create_session_from_client_credentials, + PublicClientCredentials, +) + + +class TestConfidentialClientCredentials(unittest.TestCase): + + def test_construction(self) -> None: + client_id = "test" + client_secret = "client_secret" + token_uri = "https://test.mghpathology.org/token" + credentials = ConfidentialClientCredentials( + client_id=client_id, + client_secret=client_secret, + token_uri=token_uri + ) + self.assertIsInstance(credentials, ConfidentialClientCredentials) + self.assertEqual(credentials.client_id, client_id) + self.assertEqual(credentials.client_secret, client_secret) + self.assertEqual(credentials.token_uri, token_uri) + + def test_construction_missing_client_secret(self) -> None: + with self.assertRaises(TypeError): + ConfidentialClientCredentials( + client_id="test", + token_uri="https://test.mghpathology.org/token" + ) + + def test_construction_extra_auth_uri(self) -> None: + with self.assertRaises(TypeError): + ConfidentialClientCredentials( + client_id="id", + client_secret="secret", + token_uri="https://test.mghpathology.org/token", + auth_uri="https://test.mghpathology.org/auth", + ) + + +class TestPublicClientCredentials(unittest.TestCase): + + def test_construction(self) -> None: + client_id = "test" + token_uri = "https://test.mghpathology.org/token" + auth_uri = "https://test.mghpathology.org/auth" + credentials = PublicClientCredentials( + client_id=client_id, + token_uri=token_uri, + auth_uri=auth_uri, + ) + self.assertIsInstance(credentials, PublicClientCredentials) + self.assertEqual(credentials.client_id, client_id) + self.assertEqual(credentials.token_uri, token_uri) + self.assertEqual(credentials.auth_uri, auth_uri) + + def test_construction_extra_client_secret(self) -> None: + with self.assertRaises(TypeError): + PublicClientCredentials( + client_id="id", + client_secret="secret", + token_uri="https://test.mghpathology.org/token", + ) + + def test_construction_missing_auth_uri(self) -> None: + with self.assertRaises(TypeError): + PublicClientCredentials( + client_id="id", + token_uri="https://test.mghpathology.org/token", + ) + + +class TestConfidentialClientSession(unittest.TestCase): + def setUp(self) -> None: + self._credentials = ConfidentialClientCredentials( + client_id="test", + client_secret="test", + token_uri="https://test.mghpathology.org/token", + ) + self._token = { + "token_type": "Bearer", + "access_token": "a", + "refresh_token": "b", + "expires_in": "1", + "expires_at": str(time.time()), + } + + @responses.activate + def test_construction(self) -> None: + responses.add( + responses.Response( + method="POST", + url=self._credentials.token_uri, + json=self._token, + status=200, + ) + ) + session = ConfidentialClientSession( + client_id=self._credentials.client_id, + client_secret=self._credentials.client_secret, + token_uri=self._credentials.token_uri, + ) + self.assertIsInstance(session, ConfidentialClientSession) + self.assertIsInstance(session, requests.Session) + self.assertIsInstance(session.token, dict) + self.assertEqual( + session.token["access_token"], + self._token["access_token"] + ) + + @responses.activate + def test_construction_using_factory_function(self) -> None: + responses.add( + responses.Response( + method="POST", + url=self._credentials.token_uri, + json=self._token, + status=200, + ) + ) + session = create_session_from_client_credentials(self._credentials) + self.assertIsInstance(session, ConfidentialClientSession) + self.assertIsInstance(session, requests.Session) + + @responses.activate + def test_construction_with_scope(self) -> None: + responses.add( + responses.Response( + method="POST", + url=self._credentials.token_uri, + json=self._token, + status=200, + ) + ) + scope = [ + "foo", + "bar", + ] + ConfidentialClientSession( + client_id=self._credentials.client_id, + client_secret=self._credentials.client_secret, + token_uri=self._credentials.token_uri, + scope=scope, + ) From 241b8dbe006130e84b8eb48792656302fc3a66ee Mon Sep 17 00:00:00 2001 From: hackermd Date: Tue, 3 May 2022 17:33:07 -0400 Subject: [PATCH 2/3] Remove obsolete import statement --- src/dicomweb_client/session_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dicomweb_client/session_utils.py b/src/dicomweb_client/session_utils.py index 798200f..4e00102 100644 --- a/src/dicomweb_client/session_utils.py +++ b/src/dicomweb_client/session_utils.py @@ -11,7 +11,7 @@ from abc import ABCMeta from http.server import BaseHTTPRequestHandler, HTTPServer from threading import Thread -from typing import Any, Callable, List, NamedTuple, NoReturn, Optional, Union +from typing import Any, Callable, List, NamedTuple, Optional, Union import requests from oauthlib.oauth2 import ( # type: ignore From 4dcc26d6a81d4d90578cddaa0ad8103b774e9f21 Mon Sep 17 00:00:00 2001 From: hackermd Date: Tue, 3 May 2022 17:38:58 -0400 Subject: [PATCH 3/3] Fix mypy errors --- src/dicomweb_client/file.py | 2 +- tests/test_session_utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dicomweb_client/file.py b/src/dicomweb_client/file.py index 22951ae..7883ed2 100644 --- a/src/dicomweb_client/file.py +++ b/src/dicomweb_client/file.py @@ -612,7 +612,7 @@ def decode_frame(self, index: int, value: bytes): n_pixels = self._pixels_per_frame pixel_offset = int(((index * n_pixels / 8) % 1) * 8) pixel_array = unpacked_frame[pixel_offset:pixel_offset + n_pixels] - return pixel_array.reshape(rows, columns) + return pixel_array.reshape(rows, columns) # type: ignore else: # This hack creates a small dataset containing a Pixel Data element # with only a single frame item, which can then be decoded using the diff --git a/tests/test_session_utils.py b/tests/test_session_utils.py index bf2f78d..b7607a5 100644 --- a/tests/test_session_utils.py +++ b/tests/test_session_utils.py @@ -30,14 +30,14 @@ def test_construction(self) -> None: def test_construction_missing_client_secret(self) -> None: with self.assertRaises(TypeError): - ConfidentialClientCredentials( + ConfidentialClientCredentials( # type: ignore client_id="test", token_uri="https://test.mghpathology.org/token" ) def test_construction_extra_auth_uri(self) -> None: with self.assertRaises(TypeError): - ConfidentialClientCredentials( + ConfidentialClientCredentials( # type: ignore client_id="id", client_secret="secret", token_uri="https://test.mghpathology.org/token", @@ -63,7 +63,7 @@ def test_construction(self) -> None: def test_construction_extra_client_secret(self) -> None: with self.assertRaises(TypeError): - PublicClientCredentials( + PublicClientCredentials( # type: ignore client_id="id", client_secret="secret", token_uri="https://test.mghpathology.org/token", @@ -71,7 +71,7 @@ def test_construction_extra_client_secret(self) -> None: def test_construction_missing_auth_uri(self) -> None: with self.assertRaises(TypeError): - PublicClientCredentials( + PublicClientCredentials( # type: ignore client_id="id", token_uri="https://test.mghpathology.org/token", )