diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 5e4b2b16e..37714161f 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -12,7 +12,6 @@ from stripe import error, oauth_error, http_client, version, util, six from stripe.multipart_data_generator import MultipartDataGenerator from stripe.six.moves.urllib.parse import urlencode, urlsplit, urlunsplit -from stripe.request_metrics import RequestMetrics from stripe.stripe_response import StripeResponse @@ -32,10 +31,6 @@ def _encode_nested_dict(key, data, fmt="%s[%s]"): return d -def _now_ms(): - return int(round(time.time() * 1000)) - - def _api_encode(data): for key, value in six.iteritems(data): key = util.utf8(key) @@ -110,8 +105,6 @@ def __init__( self._client = stripe.default_http_client self._default_proxy = proxy - self._last_request_metrics = None - @classmethod def format_app_info(cls, info): str = info["name"] @@ -277,11 +270,6 @@ def request_headers(self, api_key, method): if self.api_version is not None: headers["Stripe-Version"] = self.api_version - if stripe.enable_telemetry and self._last_request_metrics: - headers["X-Stripe-Client-Telemetry"] = json.dumps( - {"last_request_metrics": self._last_request_metrics.payload()} - ) - return headers def request_raw(self, method, url, params=None, supplied_headers=None): @@ -351,8 +339,6 @@ def request_raw(self, method, url, params=None, supplied_headers=None): api_version=self.api_version, ) - request_start = _now_ms() - rbody, rcode, rheaders = self._client.request_with_retries( method, abs_url, headers, post_data ) @@ -366,11 +352,6 @@ def request_raw(self, method, url, params=None, supplied_headers=None): "Dashboard link for request", link=util.dashboard_link(request_id), ) - if stripe.enable_telemetry: - request_duration_ms = _now_ms() - request_start - self._last_request_metrics = RequestMetrics( - request_id, request_duration_ms - ) return rbody, rcode, rheaders, my_api_key diff --git a/stripe/http_client.py b/stripe/http_client.py index b269f244e..f6a944df0 100644 --- a/stripe/http_client.py +++ b/stripe/http_client.py @@ -7,9 +7,11 @@ import time import random import threading +import json import stripe from stripe import error, util, six +from stripe.request_metrics import RequestMetrics # - Requests is the preferred HTTP library # - Google App Engine has urlfetch @@ -61,6 +63,10 @@ from stripe.six.moves.urllib.parse import urlparse +def _now_ms(): + return int(round(time.time() * 1000)) + + def new_default_http_client(*args, **kwargs): if urlfetch: impl = UrlFetchClient @@ -105,9 +111,13 @@ def __init__(self, verify_ssl_certs=True, proxy=None): self._thread_local = threading.local() def request_with_retries(self, method, url, headers, post_data=None): + self._add_telemetry_header(headers) + num_retries = 0 while True: + request_start = _now_ms() + try: num_retries += 1 response = self.request(method, url, headers, post_data) @@ -134,6 +144,8 @@ def request_with_retries(self, method, url, headers, post_data=None): time.sleep(sleep_time) else: if response is not None: + self._record_request_metrics(response, request_start) + return response else: raise connection_error @@ -182,6 +194,25 @@ def _add_jitter_time(self, sleep_seconds): sleep_seconds *= 0.5 * (1 + random.uniform(0, 1)) return sleep_seconds + def _add_telemetry_header(self, headers): + last_request_metrics = getattr( + self._thread_local, "last_request_metrics", None + ) + if stripe.enable_telemetry and last_request_metrics: + telemetry = { + "last_request_metrics": last_request_metrics.payload() + } + headers["X-Stripe-Client-Telemetry"] = json.dumps(telemetry) + + def _record_request_metrics(self, response, request_start): + _, _, rheaders = response + if "Request-Id" in rheaders and stripe.enable_telemetry: + request_id = rheaders["Request-Id"] + request_duration_ms = _now_ms() - request_start + self._thread_local.last_request_metrics = RequestMetrics( + request_id, request_duration_ms + ) + def close(self): raise NotImplementedError( "HTTPClient subclasses must implement `close`" diff --git a/tests/test_api_requestor.py b/tests/test_api_requestor.py index 8e199a6cb..81c4628fc 100644 --- a/tests/test_api_requestor.py +++ b/tests/test_api_requestor.py @@ -45,7 +45,6 @@ def __init__( user_agent=None, app_info=None, idempotency_key=None, - client_telemetry=None, ): self.request_method = request_method self.api_key = api_key or stripe.api_key @@ -53,7 +52,6 @@ def __init__( self.user_agent = user_agent self.app_info = app_info self.idempotency_key = idempotency_key - self.client_telemetry = client_telemetry def __eq__(self, other): return ( @@ -63,7 +61,6 @@ def __eq__(self, other): and self._x_stripe_ua_contains_app_info(other) and self._idempotency_key_match(other) and self._extra_match(other) - and self._check_telemetry(other) ) def __repr__(self): @@ -87,8 +84,6 @@ def _keys_match(self, other): and self.request_method in self.METHOD_EXTRA_KEYS ): expected_keys.extend(self.METHOD_EXTRA_KEYS[self.request_method]) - if self.client_telemetry: - expected_keys.append("X-Stripe-Client-Telemetry") return sorted(other.keys()) == sorted(expected_keys) def _auth_match(self, other): @@ -121,24 +116,6 @@ def _extra_match(self, other): return True - def _check_telemetry(self, other): - if not self.client_telemetry: - return "X-Stripe-Client-Telemetry" not in other - - if "X-Stripe-Client-Telemetry" not in other: - return False - - telemetry = json.loads(other["X-Stripe-Client-Telemetry"]) - req_id = telemetry["last_request_metrics"]["request_id"] - - if req_id != self.client_telemetry["request_id"]: - return False - - if "request_duration_ms" not in telemetry["last_request_metrics"]: - return False - - return True - class QueryMatcher(object): def __init__(self, expected): @@ -413,32 +390,6 @@ def test_uses_headers(self, requestor, mock_response, check_call): requestor.request("get", self.valid_path, {}, {"foo": "bar"}) check_call("get", headers=APIHeaderMatcher(extra={"foo": "bar"})) - def test_telemetry_headers_disabled( - self, requestor, mock_response, check_call - ): - mock_response("{}", 200, headers={"Request-Id": 1}) - requestor.request("get", self.valid_path, {}) - check_call("get", headers=APIHeaderMatcher(client_telemetry=None)) - - mock_response("{}", 200, headers={"Request-Id": 2}) - requestor.request("get", self.valid_path, {}) - check_call("get", headers=APIHeaderMatcher(client_telemetry=None)) - - def test_telemetry_headers_enabled( - self, requestor, mock_response, check_call - ): - stripe.enable_telemetry = True - - mock_response("{}", 200, headers={"Request-Id": 1}) - requestor.request("get", self.valid_path, {}) - check_call("get", headers=APIHeaderMatcher(client_telemetry=None)) - - mock_response("{}", 200, headers={"Request-Id": 2}) - requestor.request("get", self.valid_path, {}) - check_call( - "get", headers=APIHeaderMatcher(client_telemetry={"request_id": 1}) - ) - def test_uses_instance_key(self, http_client, mock_response, check_call): key = "fookey" requestor = stripe.api_requestor.APIRequestor(key, client=http_client) diff --git a/tests/test_http_client.py b/tests/test_http_client.py index 8d26b17c9..c65511102 100644 --- a/tests/test_http_client.py +++ b/tests/test_http_client.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import pytest +import json import stripe from stripe import six @@ -163,6 +164,43 @@ def test_should_retry_on_num_retries(self, mocker): ) +class TestHTTPClient(object): + @pytest.fixture(autouse=True) + def setup_stripe(self): + orig_attrs = {"enable_telemetry": stripe.enable_telemetry} + stripe.enable_telemetry = False + yield + stripe.enable_telemetry = orig_attrs["enable_telemetry"] + + def test_sends_telemetry_on_second_request(self, mocker): + class TestClient(stripe.http_client.HTTPClient): + pass + + stripe.enable_telemetry = True + + url = "http://fake.url" + + client = TestClient() + + client.request = mocker.MagicMock( + return_value=["", 200, {"Request-Id": "req_123"}] + ) + _, code, _ = client.request_with_retries("get", url, {}, None) + assert code == 200 + client.request.assert_called_with("get", url, {}, None) + + client.request = mocker.MagicMock( + return_value=["", 200, {"Request-Id": "req_234"}] + ) + _, code, _ = client.request_with_retries("get", url, {}, None) + assert code == 200 + args, _ = client.request.call_args + assert "X-Stripe-Client-Telemetry" in args[2] + + telemetry = json.loads(args[2]["X-Stripe-Client-Telemetry"]) + assert telemetry["last_request_metrics"]["request_id"] == "req_123" + + class ClientTestBase(object): @pytest.fixture def request_mock(self, request_mocks): @@ -247,6 +285,7 @@ def mock_response(mock, body, code): result = mocker.Mock() result.content = body result.status_code = code + result.headers = {} session.request = mocker.MagicMock(return_value=result) mock.Session = mocker.MagicMock(return_value=session) @@ -303,10 +342,11 @@ def test_timeout(self, request_mock, mock_response, check_call): class TestRequestClientRetryBehavior(TestRequestsClient): @pytest.fixture def response(self, mocker): - def response(code=200): + def response(code=200, headers={}): result = mocker.Mock() result.content = "{}" result.status_code = code + result.headers = headers return result return response @@ -461,6 +501,7 @@ def mock_response(mock, body, code): result = mocker.Mock() result.content = body result.status_code = code + result.headers = {} mock.fetch = mocker.Mock(return_value=result) diff --git a/tests/test_integration.py b/tests/test_integration.py index 4ca4e28d2..5355c5aa0 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,7 +1,10 @@ +from __future__ import absolute_import, division, print_function + import sys -from threading import Thread +from threading import Thread, Lock import json import warnings +import time import stripe import pytest @@ -27,16 +30,22 @@ def setup_stripe(self): "api_base": stripe.api_base, "api_key": stripe.api_key, "default_http_client": stripe.default_http_client, + "enable_telemetry": stripe.enable_telemetry, + "max_network_retries": stripe.max_network_retries, "proxy": stripe.proxy, } stripe.api_base = "http://localhost:12111" # stripe-mock stripe.api_key = "sk_test_123" stripe.default_http_client = None + stripe.enable_telemetry = False + stripe.max_network_retries = 3 stripe.proxy = None yield stripe.api_base = orig_attrs["api_base"] stripe.api_key = orig_attrs["api_key"] stripe.default_http_client = orig_attrs["default_http_client"] + stripe.enable_telemetry = orig_attrs["enable_telemetry"] + stripe.max_network_retries = orig_attrs["max_network_retries"] stripe.proxy = orig_attrs["proxy"] def setup_mock_server(self, handler): @@ -126,3 +135,118 @@ def do_GET(self): ) stripe.Balance.retrieve() assert MockServerRequestHandler.num_requests == 1 + + def test_passes_client_telemetry_when_enabled(self): + class MockServerRequestHandler(BaseHTTPRequestHandler): + num_requests = 0 + + def do_GET(self): + try: + self.__class__.num_requests += 1 + req_num = self.__class__.num_requests + if req_num == 1: + time.sleep(31 / 1000) # 31 ms + assert not self.headers.get( + "X-Stripe-Client-Telemetry" + ) + elif req_num == 2: + assert self.headers.get("X-Stripe-Client-Telemetry") + telemetry = json.loads( + self.headers.get("x-stripe-client-telemetry") + ) + assert "last_request_metrics" in telemetry + req_id = telemetry["last_request_metrics"][ + "request_id" + ] + duration_ms = telemetry["last_request_metrics"][ + "request_duration_ms" + ] + assert req_id == "req_1" + # The first request took 31 ms, so the client perceived + # latency shouldn't be outside this range. + assert 30 < duration_ms < 300 + else: + assert False, ( + "Should not have reached request %d" % req_num + ) + + self.send_response(200) + self.send_header( + "Content-Type", "application/json; charset=utf-8" + ) + self.send_header("Request-Id", "req_%d" % req_num) + self.end_headers() + self.wfile.write(json.dumps({}).encode("utf-8")) + except AssertionError as ex: + # Throwing assertions on the server side causes a + # connection error to be logged instead of an assertion + # failure. Instead, we return the assertion failure as + # json so it can be logged as a StripeError. + self.send_response(400) + self.send_header( + "Content-Type", "application/json; charset=utf-8" + ) + self.end_headers() + self.wfile.write( + json.dumps( + { + "error": { + "type": "invalid_request_error", + "message": str(ex), + } + } + ).encode("utf-8") + ) + + self.setup_mock_server(MockServerRequestHandler) + stripe.api_base = "http://localhost:%s" % self.mock_server_port + stripe.enable_telemetry = True + + stripe.Balance.retrieve() + stripe.Balance.retrieve() + assert MockServerRequestHandler.num_requests == 2 + + def test_uses_thread_local_client_telemetry(self): + class MockServerRequestHandler(BaseHTTPRequestHandler): + num_requests = 0 + seen_metrics = set() + stats_lock = Lock() + + def do_GET(self): + with self.__class__.stats_lock: + self.__class__.num_requests += 1 + req_num = self.__class__.num_requests + + if self.headers.get("X-Stripe-Client-Telemetry"): + telemetry = json.loads( + self.headers.get("X-Stripe-Client-Telemetry") + ) + req_id = telemetry["last_request_metrics"]["request_id"] + with self.__class__.stats_lock: + self.__class__.seen_metrics.add(req_id) + + self.send_response(200) + self.send_header( + "Content-Type", "application/json; charset=utf-8" + ) + self.send_header("Request-Id", "req_%d" % req_num) + self.end_headers() + self.wfile.write(json.dumps({}).encode("utf-8")) + + self.setup_mock_server(MockServerRequestHandler) + stripe.api_base = "http://localhost:%s" % self.mock_server_port + stripe.enable_telemetry = True + stripe.default_http_client = stripe.http_client.RequestsClient() + + def work(): + stripe.Balance.retrieve() + stripe.Balance.retrieve() + + threads = [Thread(target=work) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert MockServerRequestHandler.num_requests == 20 + assert len(MockServerRequestHandler.seen_metrics) == 10