From 2a2b482daa171b5fe169b43f942ab29a2ef172bb Mon Sep 17 00:00:00 2001 From: lingfeng Date: Mon, 7 Aug 2023 13:00:48 +0800 Subject: [PATCH] feat: Encapsulated aiohttp methods to replace __del__ to avoid http connection leaks (#55) --- src/casdoor/async_main.py | 308 ++++++++++++++++------------------ src/tests/test_async_oauth.py | 97 ++++++----- 2 files changed, 204 insertions(+), 201 deletions(-) diff --git a/src/casdoor/async_main.py b/src/casdoor/async_main.py index 505956d..5037e5e 100644 --- a/src/casdoor/async_main.py +++ b/src/casdoor/async_main.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import json +import base64 +from functools import cached_property from typing import Dict, List, Optional import aiohttp @@ -23,19 +23,54 @@ import jwt +from yarl import URL + from .user import User +class AioHttpClient: + def __init__(self, base_url): + self.base_url = base_url + self.session = None + + async def fetch(self, path, method="GET", **kwargs): + url = self.base_url + path + async with self.session.request(method, url, **kwargs) as response: + if ( + response.status != 200 + and "application/json" not in response.headers["Content-Type"] + ): + raise ValueError(f"Casdoor response error:{response.text}") + return await response.json() + + async def get(self, path, **kwargs): + return await self.fetch(path, method="GET", **kwargs) + + async def post(self, path, **kwargs): + return await self.fetch(path, method="POST", **kwargs) + + async def __aenter__(self): + self.session = await aiohttp.ClientSession().__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type: + raise exc_val + finally: + await self.session.__aexit__(exc_type, exc_val, exc_tb) + + class AsyncCasdoorSDK: def __init__( - self, - endpoint: str, - client_id: str, - client_secret: str, - certificate: str, - org_name: str, - application_name: str, - front_endpoint: str = None + self, + endpoint: str, + client_id: str, + client_secret: str, + certificate: str, + org_name: str, + application_name: str, + front_endpoint: str = None, ): self.endpoint = endpoint if front_endpoint: @@ -50,24 +85,31 @@ def __init__( self.grant_type = "authorization_code" self.algorithms = ["RS256"] - self._session = aiohttp.ClientSession() + self._session = AioHttpClient(base_url=self.endpoint) - def __del__(self): - loop = asyncio.get_running_loop() - loop.create_task(self._session.close()) + @cached_property + def headers(self) -> Dict: + basic_auth = base64.b64encode( + f"{self.client_id}:{self.client_secret}".encode("utf-8") + ).decode("utf-8") + return { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Basic {basic_auth}", + } @property def certification(self) -> bytes: - if type(self.certificate) is not str: + if not isinstance(self.certificate, str): raise TypeError("certificate field must be str type") return self.certificate.encode("utf-8") async def get_auth_link( - self, - redirect_uri: str, - response_type: str = "code", - scope: str = "read" - ): + self, + redirect_uri: str, + response_type: str = "code", + scope: str = "read", + ) -> str: url = self.front_endpoint + "/login/oauth/authorize" params = { "client_id": self.client_id, @@ -76,14 +118,13 @@ async def get_auth_link( "scope": scope, "state": self.application_name, } - async with self._session.request("", url, params=params) as request: - return request.url + return str(URL(url).with_query(params)) async def get_oauth_token( - self, - code: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None + self, + code: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, ) -> Dict: """ Request the Casdoor server to get OAuth token. @@ -99,10 +140,10 @@ async def get_oauth_token( return await self.oauth_token_request(code, username, password) def _get_payload_for_access_token_request( - self, - code: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None + self, + code: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, ) -> Dict: """ Return payload for request body which was selecting by strategy. @@ -111,8 +152,7 @@ def _get_payload_for_access_token_request( return self.__get_payload_for_authorization_code(code=code) elif username and password: return self.__get_payload_for_password_credentials( - username=username, - password=password + username=username, password=password ) else: return self.__get_payload_for_client_credentials() @@ -129,9 +169,7 @@ def __get_payload_for_authorization_code(self, code: str) -> Dict: } def __get_payload_for_password_credentials( - self, - username: str, - password: str + self, username: str, password: str ) -> Dict: """ Return payload for auth request with resource owner password @@ -142,7 +180,7 @@ def __get_payload_for_password_credentials( "client_id": self.client_id, "client_secret": self.client_secret, "username": username, - "password": password + "password": password, } def __get_payload_for_client_credentials(self) -> Dict: @@ -156,10 +194,10 @@ def __get_payload_for_client_credentials(self) -> Dict: } async def oauth_token_request( - self, - code: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None + self, + code: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, ) -> Dict: """ Request the Casdoor server to get access_token. @@ -174,9 +212,7 @@ async def oauth_token_request( :return: Response from Casdoor """ params = self._get_payload_for_access_token_request( - code=code, - username=username, - password=password + code=code, username=username, password=password ) return await self._oauth_token_request(payload=params) @@ -187,14 +223,12 @@ async def _oauth_token_request(self, payload: Dict) -> Dict: :param payload: Body for POST request. :return: Response from Casdoor """ - url = self.endpoint + "/api/login/oauth/access_token" - async with self._session.post(url, data=payload) as response: - return await response.json() + path = "/api/login/oauth/access_token" + async with self._session as session: + return await session.post(path, data=payload) async def refresh_token_request( - self, - refresh_token: str, - scope: str = "" + self, refresh_token: str, scope: str = "" ) -> Dict: """ Request the Casdoor server to get access_token. @@ -203,7 +237,7 @@ async def refresh_token_request( :param scope: OAuth scope :return: Response from Casdoor """ - url = self.endpoint + "/api/login/oauth/refresh_token" + path = "/api/login/oauth/refresh_token" params = { "grant_type": "refresh_token", "client_id": self.client_id, @@ -211,13 +245,11 @@ async def refresh_token_request( "scope": scope, "refresh_token": refresh_token, } - async with self._session.post(url, data=params) as request: - return await request.json() + async with self._session as session: + return await session.post(path, data=params) async def refresh_oauth_token( - self, - refresh_token: str, - scope: str = "" + self, refresh_token: str, scope: str = "" ) -> str: """ Request the Casdoor server to get access_token. @@ -226,10 +258,8 @@ async def refresh_oauth_token( :param scope: OAuth scope :return: Response from Casdoor """ - r = await self.refresh_token_request(refresh_token, scope) - access_token = r.get("access_token") - - return access_token + token = await self.refresh_token_request(refresh_token, scope) + return token.get("access_token") def parse_jwt_token(self, token: str) -> Dict: """ @@ -240,8 +270,7 @@ def parse_jwt_token(self, token: str) -> Dict: :return: the data in dict format """ certificate = x509.load_pem_x509_certificate( - self.certification, - default_backend() + self.certification, default_backend() ) return_json = jwt.decode( @@ -253,17 +282,18 @@ def parse_jwt_token(self, token: str) -> Dict: return return_json async def enforce( - self, - permission_model_name: str, - sub: str, - obj: str, - act: str, - v3: Optional[str] = None, - v4: Optional[str] = None, - v5: Optional[str] = None, + self, + permission_model_name: str, + sub: str, + obj: str, + act: str, + v3: Optional[str] = None, + v4: Optional[str] = None, + v5: Optional[str] = None, ) -> bool: """ Send data to Casdoor enforce API + # https://casdoor.org/docs/permission/exposed-casbin-apis#enforce :param permission_model_name: Name permission model :param sub: sub from Casbin @@ -273,11 +303,7 @@ async def enforce( :param v4: v4 from Casbin :param v5: v5 from Casbin """ - url = self.endpoint + "/api/enforce" - query_params = { - "clientId": self.client_id, - "clientSecret": self.client_secret - } + path = "/api/enforce" params = { "id": permission_model_name, "v0": sub, @@ -287,28 +313,16 @@ async def enforce( "v4": v4, "v5": v5, } - async with self._session.post( - url, params=query_params, json=params - ) as response: - if ( - response.status != 200 or - "json" not in response.headers["content-type"] - ): - error_str = "Casdoor response error:\n" + str(response.text) - raise ValueError(error_str) - - has_permission = await response.json() - + async with self._session as session: + has_permission = await session.post( + path, headers=self.headers, json=params + ) if not isinstance(has_permission, bool): - error_str = "Casdoor response error:\n" + await response.text() - raise ValueError(error_str) - + raise ValueError(f"Casdoor response error: {has_permission}") return has_permission async def batch_enforce( - self, - permission_model_name: str, - permission_rules: List[List[str]] + self, permission_model_name: str, permission_rules: List[List[str]] ) -> List[bool]: """ Send data to Casdoor enforce API @@ -322,58 +336,43 @@ async def batch_enforce( [][4] -> v4: v4 from Casbin (optional) [][5] -> v5: v5 from Casbin (optional) """ - url = self.endpoint + "/api/batch-enforce" - query_params = { - "clientId": self.client_id, - "clientSecret": self.client_secret - } + path = "/api/batch-enforce" def map_rule(rule: List[str], idx) -> Dict: if len(rule) < 3: - raise ValueError("Invalid permission rule[{0}]: {1}" - .format(idx, rule)) - result = { - "id": permission_model_name - } - for i in range(0, len(rule)): - result.update({"v{0}".format(i): rule[i]}) + raise ValueError(f"Invalid permission rule[{idx}]: {rule}") + result = {"id": permission_model_name} + for i in range(len(rule)): + result.update({f"v{i}": rule[i]}) return result - params = [map_rule(permission_rules[i], i) - for i in range(0, len(permission_rules))] - async with self._session.post( - url, params=query_params, json=params - ) as response: - if ( - response.status != 200 or - "json" not in response.headers["content-type"] - ): - error_str = "Casdoor response error:\n" + str(response.text) - raise ValueError(error_str) - - enforce_results = await response.json() + params = [ + map_rule(permission_rules[i], i) + for i in range(len(permission_rules)) + ] + async with self._session as session: + enforce_results = await session.post( + path, headers=self.headers, json=params + ) if not isinstance(enforce_results, bool): - error_str = "Casdoor response error:\n" + await response.text() - raise ValueError(error_str) + raise ValueError(f"Casdoor response error:{enforce_results}") return enforce_results - async def get_users(self) -> List[Dict]: + async def get_users(self) -> Dict: """ Get the users from Casdoor. :return: a list of dicts containing user info """ - url = self.endpoint + "/api/get-users" - params = { - "owner": self.org_name, - "clientId": self.client_id, - "clientSecret": self.client_secret, - } - async with self._session.get(url, params=params) as request: - users = await request.json() - return users + path = "/api/get-users" + params = {"owner": self.org_name} + async with self._session as session: + users = await session.get( + path, headers=self.headers, params=params + ) + return users["data"] async def get_user(self, user_id: str) -> Dict: """ @@ -382,15 +381,11 @@ async def get_user(self, user_id: str) -> Dict: :param user_id: the id of the user :return: a dict that contains user's info """ - url = self.endpoint + "/api/get-user" - params = { - "id": f"{self.org_name}/{user_id}", - "clientId": self.client_id, - "clientSecret": self.client_secret, - } - async with self._session.get(url, params=params) as request: - user = await request.json() - return user + path = "/api/get-user" + params = {"id": f"{self.org_name}/{user_id}"} + async with self._session as session: + user = await session.get(path, headers=self.headers, params=params) + return user["data"] async def get_user_count(self, is_online: bool = None) -> int: """ @@ -399,11 +394,9 @@ async def get_user_count(self, is_online: bool = None) -> int: None for all users :return: the count of filtered users for an organization """ - url = self.endpoint + "/api/get-user-count" + path = "/api/get-user-count" params = { "owner": self.org_name, - "clientId": self.client_id, - "clientSecret": self.client_secret, } if is_online is None: @@ -411,33 +404,26 @@ async def get_user_count(self, is_online: bool = None) -> int: else: params["isOnline"] = "1" if is_online else "0" - async with self._session.get(url, params=params) as request: - count = await request.json() - return count + async with self._session as session: + count = await session.get( + path, headers=self.headers, params=params + ) + return count["data"] - async def modify_user(self, method: str, user: User) -> Dict: - url = self.endpoint + f"/api/{method}" - user.owner = self.org_name - params = { - "id": f"{user.owner}/{user.name}", - "clientId": self.client_id, - "clientSecret": self.client_secret, - } - user_info = json.dumps(user.to_dict()) - async with self._session.post( - url, - params=params, - data=user_info - ) as request: - response = await request.json() - return response + async def modify_user(self, method: str, user: User, params=None) -> Dict: + path = f"/api/{method}" + async with self._session as session: + return await session.post( + path, params=params, headers=self.headers, json=user.to_dict() + ) async def add_user(self, user: User) -> Dict: response = await self.modify_user("add-user", user) return response async def update_user(self, user: User) -> Dict: - response = await self.modify_user("update-user", user) + params = {"id": f"{user.owner}/{user.name}"} + response = await self.modify_user("update-user", user, params) return response async def delete_user(self, user: User) -> Dict: diff --git a/src/tests/test_async_oauth.py b/src/tests/test_async_oauth.py index 940d383..11f2639 100644 --- a/src/tests/test_async_oauth.py +++ b/src/tests/test_async_oauth.py @@ -36,7 +36,6 @@ class TestOAuth(IsolatedAsyncioTestCase): @staticmethod def get_sdk(): - sdk = AsyncCasdoorSDK( endpoint="http://test.casbin.com:8000", client_id="3267f876b11e7d1cb217", @@ -55,8 +54,8 @@ async def test__oauth_token_request(self): "client_secret": sdk.client_secret, "code": self.code, } - response = await sdk._oauth_token_request(payload=data) - self.assertIsInstance(response, dict) + auth_token = await sdk._oauth_token_request(payload=data) + self.assertIn("access_token", auth_token) async def test__get_payload_for_authorization_code(self): sdk = self.get_sdk() @@ -68,14 +67,15 @@ async def test__get_payload_for_authorization_code(self): async def test__get_payload_for_password_credentials(self): sdk = self.get_sdk() result = sdk._AsyncCasdoorSDK__get_payload_for_password_credentials( # noqa: It's private method - username="test", - password="test" + username="test", password="test" ) self.assertEqual("password", result.get("grant_type")) async def test__get_payload_for_client_credentials(self): sdk = self.get_sdk() - result = sdk._AsyncCasdoorSDK__get_payload_for_client_credentials() # noqa: It's private method + result = ( + sdk._AsyncCasdoorSDK__get_payload_for_client_credentials() + ) # noqa: It's private method self.assertEqual("client_credentials", result.get("grant_type")) async def test__get_payload_for_access_token_request_with_code(self): @@ -86,8 +86,7 @@ async def test__get_payload_for_access_token_request_with_code(self): async def test__get_payload_for_access_token_request_with_cred(self): sdk = self.get_sdk() result = sdk._get_payload_for_access_token_request( - username="test", - password="test" + username="test", password="test" ) self.assertEqual("password", result.get("grant_type")) @@ -99,8 +98,7 @@ async def test_get_payload_for_access_token_request_with_client_cred(self): async def test_get_oauth_token_with_password(self): sdk = self.get_sdk() token = await sdk.get_oauth_token( - username=self.username, - password=self.password + username=self.username, password=self.password ) access_token = token.get("access_token") self.assertIsInstance(access_token, str) @@ -126,6 +124,7 @@ async def test_refresh_token_request(self): sdk = self.get_sdk() response = await sdk.oauth_token_request(self.code) refresh_token = response.get("refresh_token") + self.assertIsInstance(refresh_token, str) response = await sdk.refresh_token_request(refresh_token) self.assertIsInstance(response, dict) @@ -152,25 +151,29 @@ async def test_enforce(self): def mocked_enforce_requests_post(*args, **kwargs): class MockResponse: - def __init__(self, - json_data, - status_code=200, - headers={'content-type': 'json'}): + def __init__( + self, + json_data, + status_code=200, + headers={"content-type": "json"}, + ): self.json_data = json_data self.status_code = status_code self.headers = headers def json(self): return self.json_data + result = True for i in range(0, 5): - if kwargs.get('json').get(f"v{i}") != f"v{i}": + if kwargs.get("json").get(f"v{i}") != f"v{i}": result = False return MockResponse(result) - @mock.patch("aiohttp.ClientSession.post", - side_effect=mocked_enforce_requests_post) + @mock.patch( + "aiohttp.ClientSession.post", side_effect=mocked_enforce_requests_post + ) async def test_enforce_parmas(self, mock_post): sdk = self.get_sdk() status = await sdk.enforce( @@ -178,25 +181,28 @@ async def test_enforce_parmas(self, mock_post): "v0", "v1", "v2", - v3='v3', - v4='v4', - v5='v5' + v3="v3", + v4="v4", + v5="v5", ) self.assertEqual(status, True) def mocked_batch_enforce_requests_post(*args, **kwargs): class MockResponse: - def __init__(self, - json_data, - status_code=200, - headers={'content-type': 'json'}): + def __init__( + self, + json_data, + status_code=200, + headers={"content-type": "json"}, + ): self.json_data = json_data self.status_code = status_code self.headers = headers def json(self): return self.json_data - json = kwargs.get('json') + + json = kwargs.get("json") result = [True for i in range(0, len(json))] for k in range(0, len(json)): for i in range(0, len(json[k]) - 1): @@ -205,34 +211,34 @@ def json(self): return MockResponse(result) - @mock.patch("aiohttp.ClientSession.post", - side_effect=mocked_batch_enforce_requests_post) + @mock.patch( + "aiohttp.ClientSession.post", + side_effect=mocked_batch_enforce_requests_post, + ) def test_batch_enforce(self, mock_post): sdk = self.get_sdk() status = sdk.batch_enforce( "built-in/permission-built-in", [ - ["v0", "v1", "v2", "v3", "v4", 'v5'], - ["v0", "v1", "v2", "v3", "v4", "v1"] - ] + ["v0", "v1", "v2", "v3", "v4", "v5"], + ["v0", "v1", "v2", "v3", "v4", "v1"], + ], ) self.assertEqual(len(status), 2) self.assertEqual(status[0], True) self.assertEqual(status[1], False) - @mock.patch("aiohttp.ClientSession.post", - side_effect=mocked_batch_enforce_requests_post) + @mock.patch( + "aiohttp.ClientSession.post", + side_effect=mocked_batch_enforce_requests_post, + ) def test_batch_enforce_raise(self, mock_post): sdk = self.get_sdk() with self.assertRaises(ValueError) as context: - sdk.batch_enforce( - "built-in/permission-built-in", - [ - ["v0", "v1"] - ] - ) - self.assertEqual("Invalid permission rule[0]: ['v0', 'v1']", - str(context.exception)) + sdk.batch_enforce("built-in/permission-built-in", [["v0", "v1"]]) + self.assertEqual( + "Invalid permission rule[0]: ['v0', 'v1']", str(context.exception) + ) async def test_get_users(self): sdk = self.get_sdk() @@ -243,6 +249,7 @@ async def test_get_user(self): sdk = self.get_sdk() user = await sdk.get_user("admin") self.assertIsInstance(user, dict) + self.assertEqual(user["name"], "admin") async def test_get_user_count(self): sdk = self.get_sdk() @@ -258,6 +265,7 @@ async def test_modify_user(self): sdk = self.get_sdk() user = User() user.name = "test_ffyuanda" + user.owner = sdk.org_name await sdk.delete_user(user) response = await sdk.add_user(user) @@ -278,3 +286,12 @@ async def test_modify_user(self): def check_enforce_request(*args, **kwargs): return True + + async def test_auth_link(self): + sdk = self.get_sdk() + redirect_uri = "http://localhost:9000/callback" + response = await sdk.get_auth_link(redirect_uri=redirect_uri) + self.assertEqual( + response, + f"{sdk.front_endpoint}/login/oauth/authorize?client_id={sdk.client_id}&response_type=code&redirect_uri={redirect_uri}&scope=read&state={sdk.application_name}", # noqa + )