diff --git a/stripe/http_client.py b/stripe/http_client.py index 076d01947..c5255bf1e 100644 --- a/stripe/http_client.py +++ b/stripe/http_client.py @@ -8,6 +8,7 @@ import random import threading import json +import threading import stripe from stripe import error, util, six @@ -110,12 +111,12 @@ def __init__(self, verify_ssl_certs=True, proxy=None): self._thread_local = threading.local() - self._last_request_metrics = None - def request_with_retries(self, method, url, headers, post_data=None): - if stripe.enable_telemetry and self._last_request_metrics: + if stripe.enable_telemetry and self._last_request_metrics(): headers["X-Stripe-Client-Telemetry"] = json.dumps( - {"last_request_metrics": self._last_request_metrics.payload()} + { + "last_request_metrics": self._last_request_metrics().payload() + } ) num_retries = 0 @@ -153,8 +154,8 @@ def request_with_retries(self, method, url, headers, post_data=None): if "Request-Id" in rheaders and stripe.enable_telemetry: request_id = rheaders["Request-Id"] request_duration_ms = _now_ms() - request_start - self._last_request_metrics = RequestMetrics( - request_id, request_duration_ms + self._set_last_request_metrics( + RequestMetrics(request_id, request_duration_ms) ) return response @@ -205,6 +206,12 @@ def _add_jitter_time(self, sleep_seconds): sleep_seconds *= 0.5 * (1 + random.uniform(0, 1)) return sleep_seconds + def _last_request_metrics(self): + return getattr(self._thread_local, "last_request_metrics", None) + + def _set_last_request_metrics(self, metrics): + self._thread_local.last_request_metrics = metrics + def close(self): raise NotImplementedError( "HTTPClient subclasses must implement `close`" diff --git a/tests/test_integration.py b/tests/test_integration.py index e7ba65b2a..5355c5aa0 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function import sys -from threading import Thread +from threading import Thread, Lock import json import warnings import time @@ -31,18 +31,21 @@ def setup_stripe(self): "api_key": stripe.api_key, "default_http_client": stripe.default_http_client, "enable_telemetry": stripe.enable_telemetry, + "max_network_retries": stripe.max_network_retries, "proxy": stripe.proxy, } stripe.api_base = "http://localhost:12111" # stripe-mock stripe.api_key = "sk_test_123" stripe.default_http_client = None stripe.enable_telemetry = False + stripe.max_network_retries = 3 stripe.proxy = None yield stripe.api_base = orig_attrs["api_base"] stripe.api_key = orig_attrs["api_key"] stripe.default_http_client = orig_attrs["default_http_client"] stripe.enable_telemetry = orig_attrs["enable_telemetry"] + stripe.max_network_retries = orig_attrs["max_network_retries"] stripe.proxy = orig_attrs["proxy"] def setup_mock_server(self, handler): @@ -202,3 +205,48 @@ def do_GET(self): stripe.Balance.retrieve() stripe.Balance.retrieve() assert MockServerRequestHandler.num_requests == 2 + + def test_uses_thread_local_client_telemetry(self): + class MockServerRequestHandler(BaseHTTPRequestHandler): + num_requests = 0 + seen_metrics = set() + stats_lock = Lock() + + def do_GET(self): + with self.__class__.stats_lock: + self.__class__.num_requests += 1 + req_num = self.__class__.num_requests + + if self.headers.get("X-Stripe-Client-Telemetry"): + telemetry = json.loads( + self.headers.get("X-Stripe-Client-Telemetry") + ) + req_id = telemetry["last_request_metrics"]["request_id"] + with self.__class__.stats_lock: + self.__class__.seen_metrics.add(req_id) + + self.send_response(200) + self.send_header( + "Content-Type", "application/json; charset=utf-8" + ) + self.send_header("Request-Id", "req_%d" % req_num) + self.end_headers() + self.wfile.write(json.dumps({}).encode("utf-8")) + + self.setup_mock_server(MockServerRequestHandler) + stripe.api_base = "http://localhost:%s" % self.mock_server_port + stripe.enable_telemetry = True + stripe.default_http_client = stripe.http_client.RequestsClient() + + def work(): + stripe.Balance.retrieve() + stripe.Balance.retrieve() + + threads = [Thread(target=work) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert MockServerRequestHandler.num_requests == 20 + assert len(MockServerRequestHandler.seen_metrics) == 10