Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Stripe Client Telemetry #530

Merged
merged 7 commits into from
Jan 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions stripe/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from stripe import error, oauth_error, http_client, version, util, six
from stripe.multipart_data_generator import MultipartDataGenerator
from stripe.six.moves.urllib.parse import urlencode, urlsplit, urlunsplit
from stripe.request_metrics import RequestMetrics
from stripe.stripe_response import StripeResponse


Expand All @@ -32,10 +31,6 @@ def _encode_nested_dict(key, data, fmt="%s[%s]"):
return d


def _now_ms():
return int(round(time.time() * 1000))


def _api_encode(data):
for key, value in six.iteritems(data):
key = util.utf8(key)
Expand Down Expand Up @@ -110,8 +105,6 @@ def __init__(
self._client = stripe.default_http_client
self._default_proxy = proxy

self._last_request_metrics = None

@classmethod
def format_app_info(cls, info):
str = info["name"]
Expand Down Expand Up @@ -277,11 +270,6 @@ def request_headers(self, api_key, method):
if self.api_version is not None:
headers["Stripe-Version"] = self.api_version

if stripe.enable_telemetry and self._last_request_metrics:
headers["X-Stripe-Client-Telemetry"] = json.dumps(
{"last_request_metrics": self._last_request_metrics.payload()}
)

return headers

def request_raw(self, method, url, params=None, supplied_headers=None):
Expand Down Expand Up @@ -351,8 +339,6 @@ def request_raw(self, method, url, params=None, supplied_headers=None):
api_version=self.api_version,
)

request_start = _now_ms()

rbody, rcode, rheaders = self._client.request_with_retries(
method, abs_url, headers, post_data
)
Expand All @@ -366,11 +352,6 @@ def request_raw(self, method, url, params=None, supplied_headers=None):
"Dashboard link for request",
link=util.dashboard_link(request_id),
)
if stripe.enable_telemetry:
request_duration_ms = _now_ms() - request_start
self._last_request_metrics = RequestMetrics(
request_id, request_duration_ms
)

return rbody, rcode, rheaders, my_api_key

Expand Down
31 changes: 31 additions & 0 deletions stripe/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import time
import random
import threading
import json

import stripe
from stripe import error, util, six
from stripe.request_metrics import RequestMetrics

# - Requests is the preferred HTTP library
# - Google App Engine has urlfetch
Expand Down Expand Up @@ -61,6 +63,10 @@
from stripe.six.moves.urllib.parse import urlparse


def _now_ms():
return int(round(time.time() * 1000))


def new_default_http_client(*args, **kwargs):
if urlfetch:
impl = UrlFetchClient
Expand Down Expand Up @@ -105,9 +111,13 @@ def __init__(self, verify_ssl_certs=True, proxy=None):
self._thread_local = threading.local()

def request_with_retries(self, method, url, headers, post_data=None):
self._add_telemetry_header(headers)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just now that I'm taking a second look at this, would you mind moving this into a helper method called something like _add_telemetry_header(headers)? Just so we have less code in the main body here.

num_retries = 0

while True:
request_start = _now_ms()

try:
num_retries += 1
response = self.request(method, url, headers, post_data)
Expand All @@ -134,6 +144,8 @@ def request_with_retries(self, method, url, headers, post_data=None):
time.sleep(sleep_time)
else:
if response is not None:
self._record_request_metrics(response, request_start)

return response
else:
raise connection_error
Expand Down Expand Up @@ -182,6 +194,25 @@ def _add_jitter_time(self, sleep_seconds):
sleep_seconds *= 0.5 * (1 + random.uniform(0, 1))
return sleep_seconds

def _add_telemetry_header(self, headers):
last_request_metrics = getattr(
self._thread_local, "last_request_metrics", None
)
if stripe.enable_telemetry and last_request_metrics:
telemetry = {
"last_request_metrics": last_request_metrics.payload()
}
headers["X-Stripe-Client-Telemetry"] = json.dumps(telemetry)

def _record_request_metrics(self, response, request_start):
_, _, rheaders = response
if "Request-Id" in rheaders and stripe.enable_telemetry:
request_id = rheaders["Request-Id"]
request_duration_ms = _now_ms() - request_start
self._thread_local.last_request_metrics = RequestMetrics(
request_id, request_duration_ms
)

def close(self):
raise NotImplementedError(
"HTTPClient subclasses must implement `close`"
Expand Down
49 changes: 0 additions & 49 deletions tests/test_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,13 @@ def __init__(
user_agent=None,
app_info=None,
idempotency_key=None,
client_telemetry=None,
):
self.request_method = request_method
self.api_key = api_key or stripe.api_key
self.extra = extra
self.user_agent = user_agent
self.app_info = app_info
self.idempotency_key = idempotency_key
self.client_telemetry = client_telemetry

def __eq__(self, other):
return (
Expand All @@ -63,7 +61,6 @@ def __eq__(self, other):
and self._x_stripe_ua_contains_app_info(other)
and self._idempotency_key_match(other)
and self._extra_match(other)
and self._check_telemetry(other)
)

def __repr__(self):
Expand All @@ -87,8 +84,6 @@ def _keys_match(self, other):
and self.request_method in self.METHOD_EXTRA_KEYS
):
expected_keys.extend(self.METHOD_EXTRA_KEYS[self.request_method])
if self.client_telemetry:
expected_keys.append("X-Stripe-Client-Telemetry")
return sorted(other.keys()) == sorted(expected_keys)

def _auth_match(self, other):
Expand Down Expand Up @@ -121,24 +116,6 @@ def _extra_match(self, other):

return True

def _check_telemetry(self, other):
if not self.client_telemetry:
return "X-Stripe-Client-Telemetry" not in other

if "X-Stripe-Client-Telemetry" not in other:
return False

telemetry = json.loads(other["X-Stripe-Client-Telemetry"])
req_id = telemetry["last_request_metrics"]["request_id"]

if req_id != self.client_telemetry["request_id"]:
return False

if "request_duration_ms" not in telemetry["last_request_metrics"]:
return False

return True


class QueryMatcher(object):
def __init__(self, expected):
Expand Down Expand Up @@ -413,32 +390,6 @@ def test_uses_headers(self, requestor, mock_response, check_call):
requestor.request("get", self.valid_path, {}, {"foo": "bar"})
check_call("get", headers=APIHeaderMatcher(extra={"foo": "bar"}))

def test_telemetry_headers_disabled(
self, requestor, mock_response, check_call
):
mock_response("{}", 200, headers={"Request-Id": 1})
requestor.request("get", self.valid_path, {})
check_call("get", headers=APIHeaderMatcher(client_telemetry=None))

mock_response("{}", 200, headers={"Request-Id": 2})
requestor.request("get", self.valid_path, {})
check_call("get", headers=APIHeaderMatcher(client_telemetry=None))

def test_telemetry_headers_enabled(
self, requestor, mock_response, check_call
):
stripe.enable_telemetry = True

mock_response("{}", 200, headers={"Request-Id": 1})
requestor.request("get", self.valid_path, {})
check_call("get", headers=APIHeaderMatcher(client_telemetry=None))

mock_response("{}", 200, headers={"Request-Id": 2})
requestor.request("get", self.valid_path, {})
check_call(
"get", headers=APIHeaderMatcher(client_telemetry={"request_id": 1})
)

def test_uses_instance_key(self, http_client, mock_response, check_call):
key = "fookey"
requestor = stripe.api_requestor.APIRequestor(key, client=http_client)
Expand Down
43 changes: 42 additions & 1 deletion tests/test_http_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

import pytest
import json

import stripe
from stripe import six
Expand Down Expand Up @@ -163,6 +164,43 @@ def test_should_retry_on_num_retries(self, mocker):
)


class TestHTTPClient(object):
@pytest.fixture(autouse=True)
def setup_stripe(self):
orig_attrs = {"enable_telemetry": stripe.enable_telemetry}
stripe.enable_telemetry = False
yield
stripe.enable_telemetry = orig_attrs["enable_telemetry"]

def test_sends_telemetry_on_second_request(self, mocker):
class TestClient(stripe.http_client.HTTPClient):
pass

stripe.enable_telemetry = True

url = "http://fake.url"

client = TestClient()

client.request = mocker.MagicMock(
return_value=["", 200, {"Request-Id": "req_123"}]
)
_, code, _ = client.request_with_retries("get", url, {}, None)
assert code == 200
client.request.assert_called_with("get", url, {}, None)

client.request = mocker.MagicMock(
return_value=["", 200, {"Request-Id": "req_234"}]
)
_, code, _ = client.request_with_retries("get", url, {}, None)
assert code == 200
args, _ = client.request.call_args
assert "X-Stripe-Client-Telemetry" in args[2]

telemetry = json.loads(args[2]["X-Stripe-Client-Telemetry"])
assert telemetry["last_request_metrics"]["request_id"] == "req_123"


class ClientTestBase(object):
@pytest.fixture
def request_mock(self, request_mocks):
Expand Down Expand Up @@ -247,6 +285,7 @@ def mock_response(mock, body, code):
result = mocker.Mock()
result.content = body
result.status_code = code
result.headers = {}

session.request = mocker.MagicMock(return_value=result)
mock.Session = mocker.MagicMock(return_value=session)
Expand Down Expand Up @@ -303,10 +342,11 @@ def test_timeout(self, request_mock, mock_response, check_call):
class TestRequestClientRetryBehavior(TestRequestsClient):
@pytest.fixture
def response(self, mocker):
def response(code=200):
def response(code=200, headers={}):
result = mocker.Mock()
result.content = "{}"
result.status_code = code
result.headers = headers
return result

return response
Expand Down Expand Up @@ -461,6 +501,7 @@ def mock_response(mock, body, code):
result = mocker.Mock()
result.content = body
result.status_code = code
result.headers = {}

mock.fetch = mocker.Mock(return_value=result)

Expand Down
Loading