Skip to content

Commit

Permalink
Store thread-local request metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshageman-stripe committed Jan 30, 2019
1 parent 5bc4162 commit b976801
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 7 deletions.
18 changes: 12 additions & 6 deletions stripe/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ def __init__(self, verify_ssl_certs=True, proxy=None):

self._thread_local = threading.local()

self._last_request_metrics = None

def request_with_retries(self, method, url, headers, post_data=None):
if stripe.enable_telemetry and self._last_request_metrics:
if stripe.enable_telemetry and self._last_request_metrics():
headers["X-Stripe-Client-Telemetry"] = json.dumps(
{"last_request_metrics": self._last_request_metrics.payload()}
{
"last_request_metrics": self._last_request_metrics().payload()
}
)

num_retries = 0
Expand Down Expand Up @@ -153,8 +153,8 @@ def request_with_retries(self, method, url, headers, post_data=None):
if "Request-Id" in rheaders and stripe.enable_telemetry:
request_id = rheaders["Request-Id"]
request_duration_ms = _now_ms() - request_start
self._last_request_metrics = RequestMetrics(
request_id, request_duration_ms
self._set_last_request_metrics(
RequestMetrics(request_id, request_duration_ms)
)

return response
Expand Down Expand Up @@ -205,6 +205,12 @@ def _add_jitter_time(self, sleep_seconds):
sleep_seconds *= 0.5 * (1 + random.uniform(0, 1))
return sleep_seconds

def _last_request_metrics(self):
return getattr(self._thread_local, "last_request_metrics", None)

def _set_last_request_metrics(self, metrics):
self._thread_local.last_request_metrics = metrics

def close(self):
raise NotImplementedError(
"HTTPClient subclasses must implement `close`"
Expand Down
50 changes: 49 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import absolute_import, division, print_function

import sys
from threading import Thread
from threading import Thread, Lock
import json
import warnings
import time
Expand Down Expand Up @@ -31,18 +31,21 @@ def setup_stripe(self):
"api_key": stripe.api_key,
"default_http_client": stripe.default_http_client,
"enable_telemetry": stripe.enable_telemetry,
"max_network_retries": stripe.max_network_retries,
"proxy": stripe.proxy,
}
stripe.api_base = "http://localhost:12111" # stripe-mock
stripe.api_key = "sk_test_123"
stripe.default_http_client = None
stripe.enable_telemetry = False
stripe.max_network_retries = 3
stripe.proxy = None
yield
stripe.api_base = orig_attrs["api_base"]
stripe.api_key = orig_attrs["api_key"]
stripe.default_http_client = orig_attrs["default_http_client"]
stripe.enable_telemetry = orig_attrs["enable_telemetry"]
stripe.max_network_retries = orig_attrs["max_network_retries"]
stripe.proxy = orig_attrs["proxy"]

def setup_mock_server(self, handler):
Expand Down Expand Up @@ -202,3 +205,48 @@ def do_GET(self):
stripe.Balance.retrieve()
stripe.Balance.retrieve()
assert MockServerRequestHandler.num_requests == 2

def test_uses_thread_local_client_telemetry(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0
seen_metrics = set()
stats_lock = Lock()

def do_GET(self):
with self.__class__.stats_lock:
self.__class__.num_requests += 1
req_num = self.__class__.num_requests

if self.headers.get("X-Stripe-Client-Telemetry"):
telemetry = json.loads(
self.headers.get("X-Stripe-Client-Telemetry")
)
req_id = telemetry["last_request_metrics"]["request_id"]
with self.__class__.stats_lock:
self.__class__.seen_metrics.add(req_id)

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.send_header("Request-Id", "req_%d" % req_num)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))

self.setup_mock_server(MockServerRequestHandler)
stripe.api_base = "http://localhost:%s" % self.mock_server_port
stripe.enable_telemetry = True
stripe.default_http_client = stripe.http_client.RequestsClient()

def work():
stripe.Balance.retrieve()
stripe.Balance.retrieve()

threads = [Thread(target=work) for i in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()

assert MockServerRequestHandler.num_requests == 20
assert len(MockServerRequestHandler.seen_metrics) == 10

0 comments on commit b976801

Please sign in to comment.