Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Customizable request schema #1347

Merged
merged 7 commits into from
Mar 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bentoml/adapters/base_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class BaseInputAdapter:
def __init__(self, http_input_example=None, **base_config):
    """Capture the adapter's keyword configuration.

    :param http_input_example: optional sample payload shown in the API docs
    :param base_config: arbitrary adapter options; the ``request_schema`` key,
        when present, overrides the adapter's default OpenAPI request schema
    """
    self._http_input_example = http_input_example
    self._config = base_config
    # None when the user did not supply a custom OpenAPI schema.
    self.custom_request_schema = base_config.get('request_schema')

@property
def config(self):
Expand Down
6 changes: 5 additions & 1 deletion bentoml/service/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ def request_schema(self):
"""
:return: the HTTP API request schema in OpenAPI/Swagger format
"""
schema = self.input_adapter.request_schema
if self.input_adapter.custom_request_schema is None:
schema = self.input_adapter.request_schema
else:
schema = self.input_adapter.custom_request_schema

if schema.get('application/json'):
schema.get('application/json')[
'example'
Expand Down
17 changes: 17 additions & 0 deletions tests/integration/projects/general/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,23 @@ def predict_json(self, input_datas):
def customezed_route(self, input_datas):
return input_datas

# Hand-written OpenAPI request-body schema used to override the JsonInput
# adapter's default schema for the endpoint below.
# NOTE(review): "uuid" is not a standard OpenAPI type ("string" with
# format "uuid" is) — kept verbatim to preserve the fixture's behavior.
CUSTOM_SCHEMA = {
    "application/json": {
        "schema": {
            "type": "object",
            "required": ["field1", "field2"],
            "properties": {"field1": {"type": "string"}, "field2": {"type": "uuid"}},
        },
    }
}

@bentoml.api(input=JsonInput(request_schema=CUSTOM_SCHEMA), batch=True)
def customezed_schema(self, input_datas):
    """Echo endpoint whose docs should advertise CUSTOM_SCHEMA."""
    return input_datas

@bentoml.api(input=JsonInput(), batch=True)
def predict_strict_json(self, input_datas, tasks: Sequence[InferenceTask] = None):
filtered_jsons = []
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/projects/general/tests/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,17 @@ def path_in_docs(response_body):
data=json.dumps("hello"),
assert_data=bytes('"hello"', 'ascii'),
)


@pytest.mark.asyncio
async def test_customized_request_schema(host):
    """The served OpenAPI doc should include the user-supplied schema."""

    def _mentions_custom_field(doc_bytes):
        # The customized schema declares "field1"; its presence in the
        # rendered docs shows the override was picked up.
        return "field1" in doc_bytes.decode()

    await pytest.assert_request(
        "GET",
        f"http://{host}/docs.json",
        headers=(("Content-Type", "application/json"),),
        assert_data=_mentions_custom_field,
    )
22 changes: 11 additions & 11 deletions tests/integration/projects/general/tests/test_microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ async def test_slow_server(host):
await asyncio.gather(*tasks)
assert time.time() - time_start < 12


@pytest.mark.asyncio
async def test_fast_server(host):
if not pytest.enable_microbatch:
pytest.skip()

A, B = 0.0002, 0.01
data = '{"a": %s, "b": %s}' % (A, B)

req_count = 100
tasks = tuple(
pytest.assert_request(
Expand All @@ -46,17 +55,8 @@ async def test_slow_server(host):
)
await asyncio.gather(*tasks)


@pytest.mark.asyncio
async def test_fast_server(host):
if not pytest.enable_microbatch:
pytest.skip()

A, B = 0.0002, 0.01
data = '{"a": %s, "b": %s}' % (A, B)

time_start = time.time()
req_count = 500
req_count = 200
tasks = tuple(
pytest.assert_request(
"POST",
Expand All @@ -70,4 +70,4 @@ async def test_fast_server(host):
for i in range(req_count)
)
await asyncio.gather(*tasks)
assert time.time() - time_start < 5
assert time.time() - time_start < 2
24 changes: 14 additions & 10 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@

def _wait_until_api_server_ready(host_url, timeout, container=None, check_interval=1):
start_time = time.time()
proxy_handler = urllib.request.ProxyHandler({})
opener = urllib.request.build_opener(proxy_handler)
ex = None
while time.time() - start_time < timeout:
try:
if (
urllib.request.urlopen(f'http://{host_url}/healthz', timeout=1).status
== 200
):
break
if opener.open(f'http://{host_url}/healthz', timeout=1).status == 200:
return
elif container.status != "running":
break
else:
logger.info("Waiting for host %s to be ready..", host_url)
time.sleep(check_interval)
except Exception as e: # pylint:disable=broad-except
logger.info(f"'{e}', retrying to connect to the host {host_url}...")
logger.info(f"retrying to connect to the host {host_url}...")
ex = e
time.sleep(check_interval)
finally:
if container:
Expand All @@ -40,7 +41,8 @@ def _wait_until_api_server_ready(host_url, timeout, container=None, check_interv
logger.info(f">>> {log_record}")
else:
raise AssertionError(
f"Timed out waiting {timeout} seconds for Server {host_url} to be ready"
f"Timed out waiting {timeout} seconds for Server {host_url} to be ready, "
f"exception: {ex}"
)


Expand Down Expand Up @@ -148,6 +150,8 @@ def print_log(p):
) as p:
host_url = f"127.0.0.1:{port}"
threading.Thread(target=print_log, args=(p,), daemon=True).start()
_wait_until_api_server_ready(host_url, timeout=timeout)
yield host_url
p.terminate()
try:
_wait_until_api_server_ready(host_url, timeout=timeout)
yield host_url
finally:
p.terminate()