diff --git a/stripe/__init__.py b/stripe/__init__.py index 85cb3c7dc..937cff5b2 100644 --- a/stripe/__init__.py +++ b/stripe/__init__.py @@ -97,9 +97,6 @@ def _warn_if_mismatched_proxy(): WebhookSignature as WebhookSignature, ) -from stripe._raw_request import raw_request as raw_request # noqa -from stripe._raw_request import deserialize as deserialize # noqa - # StripeClient from stripe._stripe_client import StripeClient as StripeClient # noqa diff --git a/stripe/_api_requestor.py b/stripe/_api_requestor.py index 1ade3619e..65bb449fe 100644 --- a/stripe/_api_requestor.py +++ b/stripe/_api_requestor.py @@ -529,7 +529,7 @@ def _args_for_request_with_retries( usage: Optional[List[str]] = None, ): """ - Mechanism for issuing an API call + Mechanism for issuing an API call. Used by request_raw and request_raw_async. """ request_options = merge_options(self._options, options) @@ -800,7 +800,7 @@ def _interpret_response( rbody: object, rcode: int, rheaders: Mapping[str, str], - api_mode: Optional[ApiMode], + api_mode: ApiMode, ) -> StripeResponse: try: if hasattr(rbody, "decode"): @@ -831,7 +831,7 @@ def _interpret_streaming_response( stream: IOBase, rcode: int, rheaders: Mapping[str, str], - api_mode: Optional[ApiMode], + api_mode: ApiMode, ) -> StripeStreamResponse: # Streaming response are handled with minimal processing for the success # case (ie. we don't want to read the content). When an error is @@ -862,7 +862,7 @@ async def _interpret_streaming_response_async( stream: AsyncIterable[bytes], rcode: int, rheaders: Mapping[str, str], - api_mode: Optional[ApiMode], + api_mode: ApiMode, ) -> StripeStreamResponseAsync: if self._should_handle_code_as_error(rcode): json_content = b"".join([chunk async for chunk in stream]) diff --git a/stripe/_raw_request.py b/stripe/_raw_request.py deleted file mode 100644 index 99149b5a5..000000000 --- a/stripe/_raw_request.py +++ /dev/null @@ -1,50 +0,0 @@ -from stripe._api_requestor import _APIRequestor -from stripe._util import _convert_to_stripe_object, get_api_mode - -from typing import Any, Dict, Optional, Union - -from stripe._stripe_object import StripeObject -from stripe._stripe_response import StripeResponse -from stripe._request_options import extract_options_from_dict -from stripe._api_mode import ApiMode - - -def raw_request(method_, url_, **params): - params = params.copy() - options, params = extract_options_from_dict(params) - api_mode = get_api_mode(url_) - base_address = params.pop("base", "api") - - requestor = _APIRequestor._global_instance() - - rbody, rcode, rheaders = requestor.request_raw( - method_, - url_, - params=params, - options=options, - base_address=base_address, - api_mode=api_mode, - ) - - return requestor._interpret_response(rbody, rcode, rheaders, api_mode) - - -def deserialize( - resp: Union[StripeResponse, Dict[str, Any]], - api_key: Optional[str] = None, - stripe_version: Optional[str] = None, - stripe_account: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, - *, - api_mode: ApiMode, -) -> StripeObject: - return _convert_to_stripe_object( - resp=resp, - params=params, - requestor=_APIRequestor._global_with_options( - api_key=api_key, - stripe_version=stripe_version, - stripe_account=stripe_account, - ), - api_mode=api_mode, - ) diff --git a/stripe/_stripe_client.py b/stripe/_stripe_client.py index d3e1284b7..b51fe7e38 100644 --- a/stripe/_stripe_client.py +++ b/stripe/_stripe_client.py @@ -10,8 +10,10 @@ DEFAULT_METER_EVENTS_API_BASE, ) +from stripe._api_mode import ApiMode from stripe._error import AuthenticationError from stripe._api_requestor import _APIRequestor +from stripe._request_options import extract_options_from_dict from stripe._requestor_options import RequestorOptions, BaseAddresses from stripe._client_options import _ClientOptions from stripe._http_client import ( @@ -20,11 +22,14 @@ new_http_client_async_fallback, ) from stripe._api_version import _ApiVersion +from stripe._stripe_object import StripeObject +from stripe._stripe_response import StripeResponse +from stripe._util import _convert_to_stripe_object, get_api_mode from stripe._webhook import Webhook, WebhookSignature from stripe._event import Event from stripe.v2._event import ThinEvent -from typing import Optional, Union, cast +from typing import Any, Dict, Optional, Union, cast # Non-generated services from stripe._oauth_service import OAuthService @@ -298,3 +303,64 @@ def parse_snapshot_event( ) return event + + def raw_request(self, method_: str, url_: str, **params): + params = params.copy() + options, params = extract_options_from_dict(params) + api_mode = get_api_mode(url_) + base_address = params.pop("base", "api") + + stripe_context = params.pop("stripe_context", None) + + # stripe-context goes *here* and not in api_requestor. Properties + # go on api_requestor when you want them to persist onto requests + # made when you call instance methods on APIResources that come from + # the first request. No need for that here, as we aren't deserializing APIResources + if stripe_context is not None: + options["headers"] = options.get("headers", {}) + assert isinstance(options["headers"], dict) + options["headers"].update({"Stripe-Context": stripe_context}) + + rbody, rcode, rheaders = self._requestor.request_raw( + method_, + url_, + params=params, + options=options, + base_address=base_address, + api_mode=api_mode, + usage=["raw_request"], + ) + + return self._requestor._interpret_response(rbody, rcode, rheaders, api_mode) + + async def raw_request_async(self, method_: str, url_: str, **params): + params = params.copy() + options, params = extract_options_from_dict(params) + api_mode = get_api_mode(url_) + base_address = params.pop("base", "api") + + rbody, rcode, rheaders = await self._requestor.request_raw_async( + method_, + url_, + params=params, + options=options, + base_address=base_address, + api_mode=api_mode, + usage=["raw_request"], + ) + + return self._requestor._interpret_response(rbody, rcode, rheaders, api_mode) + + def deserialize( + self, + resp: Union[StripeResponse, Dict[str, Any]], + params: Optional[Dict[str, Any]] = None, + *, + api_mode: ApiMode + ) -> StripeObject: + return _convert_to_stripe_object( + resp=resp, + params=params, + requestor=self._requestor, + api_mode=api_mode, + ) diff --git a/tests/test_integration.py b/tests/test_integration.py index f3742ca06..843937541 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -13,6 +13,8 @@ from collections import defaultdict from typing import List, Dict, Tuple, Optional +from stripe._stripe_client import StripeClient + if platform.python_implementation() == "PyPy": pytest.skip("skip integration tests with PyPy", allow_module_level=True) @@ -347,7 +349,9 @@ async def async_http_client(self, request, anyio_backend): async def set_global_async_http_client(self, async_http_client): stripe.default_http_client = async_http_client - async def test_async_success(self, set_global_async_http_client): + async def test_async_raw_request_success( + self, set_global_async_http_client + ): class MockServerRequestHandler(MyTestHandler): default_body = '{"id": "cus_123", "object": "customer"}'.encode( "utf-8" @@ -356,11 +360,16 @@ class MockServerRequestHandler(MyTestHandler): self.setup_mock_server(MockServerRequestHandler) - stripe.api_base = "http://localhost:%s" % self.mock_server_port - - cus = await stripe.Customer.create_async( - description="My test customer" + client = StripeClient( + "sk_test_123", + base_addresses={ + "api": "http://localhost:%s" % self.mock_server_port + }, + ) + resp = await client.raw_request_async( + "post", "/v1/customers", description="My test customer" ) + cus = client.deserialize(resp.data, api_mode="V1") reqs = MockServerRequestHandler.get_requests(1) req = reqs[0] @@ -369,14 +378,15 @@ class MockServerRequestHandler(MyTestHandler): assert req.command == "POST" assert isinstance(cus, stripe.Customer) - async def test_async_timeout(self, set_global_async_http_client): + async def test_async_raw_request_timeout( + self, set_global_async_http_client + ): class MockServerRequestHandler(MyTestHandler): def do_request(self, n): time.sleep(0.02) return super().do_request(n) self.setup_mock_server(MockServerRequestHandler) - stripe.api_base = "http://localhost:%s" % self.mock_server_port # If we set HTTPX's generic timeout the test is flaky (sometimes it's a ReadTimeout, sometimes its a ConnectTimeout) # so we set only the read timeout specifically. hc = stripe.default_http_client @@ -390,11 +400,20 @@ def do_request(self, n): expected_message = "A ServerTimeoutError was raised" else: raise ValueError(f"Unknown http client: {hc.name}") - stripe.max_network_retries = 0 exception = None try: - await stripe.Customer.create_async(description="My test customer") + client = StripeClient( + "sk_test_123", + http_client=hc, + base_addresses={ + "api": "http://localhost:%s" % self.mock_server_port + }, + max_network_retries=0, + ) + await client.raw_request_async( + "post", "/v1/customers", description="My test customer" + ) except stripe.APIConnectionError as e: exception = e @@ -402,7 +421,9 @@ def do_request(self, n): assert expected_message in str(exception.user_message) - async def test_async_retries(self, set_global_async_http_client): + async def test_async_raw_request_retries( + self, set_global_async_http_client + ): class MockServerRequestHandler(MyTestHandler): def do_request(self, n): if n == 0: @@ -416,16 +437,26 @@ def do_request(self, n): pass self.setup_mock_server(MockServerRequestHandler) - stripe.api_base = "http://localhost:%s" % self.mock_server_port - await stripe.Customer.create_async(description="My test customer") + client = StripeClient( + "sk_test_123", + base_addresses={ + "api": "http://localhost:%s" % self.mock_server_port + }, + max_network_retries=stripe.max_network_retries, + ) + await client.raw_request_async( + "post", "/v1/customers", description="My test customer" + ) reqs = MockServerRequestHandler.get_requests(2) req = reqs[0] assert req.path == "/v1/customers" - async def test_async_unretryable(self, set_global_async_http_client): + async def test_async_raw_request_unretryable( + self, set_global_async_http_client + ): class MockServerRequestHandler(MyTestHandler): def do_request(self, n): return ( @@ -437,11 +468,18 @@ def do_request(self, n): pass self.setup_mock_server(MockServerRequestHandler) - stripe.api_base = "http://localhost:%s" % self.mock_server_port exception = None try: - await stripe.Customer.create_async(description="My test customer") + client = StripeClient( + "sk_test_123", + base_addresses={ + "api": "http://localhost:%s" % self.mock_server_port + }, + ) + await client.raw_request_async( + "post", "/v1/customers", description="My test customer" + ) except stripe.AuthenticationError as e: exception = e diff --git a/tests/test_raw_request.py b/tests/test_raw_request.py index a9a9aab9c..49fb007ef 100644 --- a/tests/test_raw_request.py +++ b/tests/test_raw_request.py @@ -2,6 +2,8 @@ import datetime +import pytest + import stripe from tests.test_api_requestor import GMT1 @@ -18,7 +20,9 @@ class TestRawRequest(object): POST_REL_URL_V2 = "/v2/billing/meter_event_session" GET_REL_URL_V2 = "/v2/accounts/acct_123" - def test_form_request_get(self, http_client_mock): + def test_form_request_get( + self, http_client_mock, stripe_mock_stripe_client + ): http_client_mock.stub_request( "get", path=self.GET_REL_URL, @@ -27,13 +31,15 @@ def test_form_request_get(self, http_client_mock): rheaders={}, ) - resp = stripe.raw_request("get", self.GET_REL_URL) + resp = stripe_mock_stripe_client.raw_request("get", self.GET_REL_URL) http_client_mock.assert_requested("get", path=self.GET_REL_URL) - deserialized = stripe.deserialize(resp, api_mode="V1") + deserialized = stripe_mock_stripe_client.deserialize(resp, api_mode="V1") assert isinstance(deserialized, stripe.Account) - def test_form_request_post(self, http_client_mock): + def test_form_request_post( + self, http_client_mock, stripe_mock_stripe_client + ): http_client_mock.stub_request( "post", path=self.POST_REL_URL, @@ -44,7 +50,7 @@ def test_form_request_post(self, http_client_mock): expectation = "type=standard&int=123&datetime=1356994801" - resp = stripe.raw_request( + resp = stripe_mock_stripe_client.raw_request( "post", self.POST_REL_URL, **self.ENCODE_INPUTS ) @@ -55,10 +61,12 @@ def test_form_request_post(self, http_client_mock): post_data=expectation, ) - deserialized = stripe.deserialize(resp, api_mode="V1") + deserialized = stripe_mock_stripe_client.deserialize(resp, api_mode="V1") assert isinstance(deserialized, stripe.Account) - def test_preview_request_post(self, http_client_mock): + def test_preview_request_post( + self, http_client_mock, stripe_mock_stripe_client + ): http_client_mock.stub_request( "post", path=self.POST_REL_URL_V2, @@ -72,7 +80,9 @@ def test_preview_request_post(self, http_client_mock): '{"type": "standard", "int": 123, "datetime": 1356994801}' ) - resp = stripe.raw_request("post", self.POST_REL_URL_V2, **params) + resp = stripe_mock_stripe_client.raw_request( + "post", self.POST_REL_URL_V2, **params + ) http_client_mock.assert_requested( "post", @@ -82,10 +92,14 @@ def test_preview_request_post(self, http_client_mock): is_json=True, ) - deserialized = stripe.deserialize(resp, api_mode="V2") - assert isinstance(deserialized, stripe.v2.billing.MeterEventSession) + deserialized = stripe_mock_stripe_client.deserialize( + resp, api_mode="V2" + ) + assert isinstance(deserialized, stripe.Account) - def test_form_request_with_extra_headers(self, http_client_mock): + def test_form_request_with_extra_headers( + self, http_client_mock, stripe_mock_stripe_client + ): http_client_mock.stub_request( "get", path=self.GET_REL_URL, @@ -97,7 +111,9 @@ def test_form_request_with_extra_headers(self, http_client_mock): extra_headers = {"foo": "bar", "Stripe-Account": "acct_123"} params = {"headers": extra_headers} - stripe.raw_request("get", self.GET_REL_URL, **params) + stripe_mock_stripe_client.raw_request( + "get", self.GET_REL_URL, **params + ) http_client_mock.assert_requested( "get", @@ -105,7 +121,9 @@ def test_form_request_with_extra_headers(self, http_client_mock): extra_headers=extra_headers, ) - def test_preview_request_default_api_version(self, http_client_mock): + def test_preview_request_default_api_version( + self, http_client_mock, stripe_mock_stripe_client + ): http_client_mock.stub_request( "get", path=self.GET_REL_URL_V2, @@ -115,14 +133,18 @@ def test_preview_request_default_api_version(self, http_client_mock): ) params = {} - stripe.raw_request("get", self.GET_REL_URL_V2, **params) + stripe_mock_stripe_client.raw_request( + "get", self.GET_REL_URL_V2, **params + ) http_client_mock.assert_requested( "get", path=self.GET_REL_URL_V2, ) - def test_preview_request_overridden_api_version(self, http_client_mock): + def test_preview_request_overridden_api_version( + self, http_client_mock, stripe_mock_stripe_client + ): http_client_mock.stub_request( "post", path=self.POST_REL_URL_V2, @@ -135,7 +157,9 @@ def test_preview_request_overridden_api_version(self, http_client_mock): "stripe_version": stripe_version_override, } - stripe.raw_request("post", self.POST_REL_URL_V2, **params) + stripe_mock_stripe_client.raw_request( + "post", self.POST_REL_URL_V2, **params + ) http_client_mock.assert_requested( "post", @@ -145,3 +169,53 @@ def test_preview_request_overridden_api_version(self, http_client_mock): post_data="{}", is_json=True, ) + +# TODO(jar) this test is not applicable yet, but may be some day +# @pytest.mark.anyio +# async def test_form_request_get_async( +# self, http_client_mock, stripe_mock_stripe_client +# ): +# http_client_mock.stub_request( +# "get", +# path=self.GET_REL_URL, +# rbody='{"id": "acct_123", "object": "account"}', +# rcode=200, +# rheaders={}, +# ) +# +# resp = await stripe_mock_stripe_client.raw_request_async( +# "get", self.GET_REL_URL +# ) +# +# http_client_mock.assert_requested("get", path=self.GET_REL_URL) +# +# deserialized = stripe_mock_stripe_client.deserialize(resp) +# assert isinstance(deserialized, stripe.Account) +# + def test_raw_request_usage_reported( + self, http_client_mock, stripe_mock_stripe_client + ): + http_client_mock.stub_request( + "post", + path=self.POST_REL_URL, + rbody='{"id": "acct_123", "object": "account"}', + rcode=200, + rheaders={}, + ) + + expectation = "type=standard&int=123&datetime=1356994801" + + resp = stripe_mock_stripe_client.raw_request( + "post", self.POST_REL_URL, **self.ENCODE_INPUTS + ) + + http_client_mock.assert_requested( + "post", + path=self.POST_REL_URL, + content_type="application/x-www-form-urlencoded", + post_data=expectation, + usage=["raw_request"], + ) + + deserialized = stripe_mock_stripe_client.deserialize(resp, api_mode="V1") + assert isinstance(deserialized, stripe.Account)