[Refactor] earlier fork
bojiang committed Mar 2, 2021
1 parent b040add commit 78433d8
Showing 7 changed files with 151 additions and 114 deletions.
62 changes: 44 additions & 18 deletions bentoml/cli/bento_service.py
@@ -1,34 +1,36 @@
import argparse
import click
import sys

import json
import multiprocessing
import re
import sys

import click
import psutil

from bentoml import __version__
from bentoml.configuration import BENTOML_CONFIG
from bentoml.configuration.containers import BentoMLConfiguration, BentoMLContainer
from bentoml.utils.lazy_loader import LazyLoader
from bentoml.server import start_dev_server, start_prod_server
from bentoml.server.open_api import get_open_api_spec_json
from bentoml.utils import (
ProtoMessageToDict,
resolve_bundle_path,
)
from bentoml.cli.click_utils import (
CLI_COLOR_SUCCESS,
_echo,
BentoMLCommandGroup,
_echo,
conditional_argument,
)
from bentoml.cli.utils import Spinner
from bentoml.configuration import BENTOML_CONFIG
from bentoml.configuration.containers import BentoMLConfiguration, BentoMLContainer
from bentoml.saved_bundle import (
load_from_dir,
load_bento_service_api,
load_bento_service_metadata,
load_from_dir,
)
from bentoml.server import (
start_dev_server,
start_prod_batching_server,
start_prod_server,
)
from bentoml.server.open_api import get_open_api_spec_json
from bentoml.utils import ProtoMessageToDict, reserve_free_port, resolve_bundle_path
from bentoml.utils.docker_utils import validate_tag
from bentoml.utils.lazy_loader import LazyLoader
from bentoml.yatai.client import get_yatai_client

try:
@@ -322,11 +324,35 @@ def serve_gunicorn(
config.override(["marshal_server", "workers"], microbatch_workers)
container.config.from_dict(config.as_dict())

from bentoml import marshal, server

container.wire(packages=[marshal, server])
if enable_microbatch:
from bentoml import marshal, server

container.wire(packages=[marshal, server])
prometheus_lock = multiprocessing.Lock()
with reserve_free_port() as api_server_port:
pass

model_server_job = multiprocessing.Process(
target=start_prod_server,
kwargs=dict(
saved_bundle_path=saved_bundle_path,
port=api_server_port,
prometheus_lock=prometheus_lock,
),
daemon=True,
)
model_server_job.start()
start_prod_batching_server(
saved_bundle_path=saved_bundle_path,
api_server_port=api_server_port,
prometheus_lock=prometheus_lock,
)
model_server_job.join()
else:
from bentoml import server

start_prod_server(saved_bundle_path)
container.wire(packages=[server])
start_prod_server(saved_bundle_path)

@bentoml_cli.command(
help="Install shell command completion",
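
A note on the serve_gunicorn change above: when microbatching is enabled, the CLI now reserves a free port, immediately releases it (hence the bare pass inside the with block), forks a daemon process that runs start_prod_server on that port, and then runs the new start_prod_batching_server in the foreground, so the model is not loaded before gunicorn forks its workers. reserve_free_port is BentoML's own utility; purely as an illustration of the idea, not necessarily its actual implementation, such a helper can look like this:

import socket
from contextlib import contextmanager

@contextmanager
def reserve_free_port(host="localhost"):
    # Bind to port 0 so the OS assigns an unused port, report that port to
    # the caller, then close the socket on exit so the port is free again
    # for the process that will actually serve on it.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind((host, 0))
    port = sock.getsockname()[1]
    try:
        yield port
    finally:
        sock.close()
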
34 changes: 28 additions & 6 deletions bentoml/marshal/marshal.py
@@ -150,7 +150,9 @@ def __init__(
BentoMLContainer.config.api_server.max_request_size
],
zipkin_api_url: str = Provide[BentoMLContainer.config.tracing.zipkin_api_url],
outbound_unix_socket: str = None,
):
self.outbound_unix_socket = outbound_unix_socket
self.outbound_host = outbound_host
self.outbound_port = outbound_port
self.outbound_workers = outbound_workers
@@ -178,6 +180,7 @@ def __init__(
"or launch more microbatch instances to accept more concurrent connection.",
self.CONNECTION_LIMIT,
)
self._client = None

def set_outbound_port(self, outbound_port):
self.outbound_port = outbound_port
@@ -187,6 +190,22 @@ def fetch_sema(self):
self._outbound_sema = NonBlockSema(self.outbound_workers)
return self._outbound_sema

def get_client(self):
if self._client is None:
jar = aiohttp.DummyCookieJar()
if self.outbound_unix_socket:
conn = aiohttp.UnixConnector(path=self.outbound_unix_socket,)
else:
conn = aiohttp.TCPConnector(limit=30)
self._client = aiohttp.ClientSession(
connector=conn, auto_decompress=False, cookie_jar=jar,
)
return self._client

def __del__(self):
if self._client is not None and not self._client.closed:
self._client.close()

def add_batch_handler(self, api_route, max_latency, max_batch_size):
'''
Params:
@@ -268,11 +287,16 @@ async def relay_handler(self, request):
span_name=f"[2]{url.path} relay",
) as trace_ctx:
headers.update(make_http_headers(trace_ctx))
async with aiohttp.ClientSession(auto_decompress=False) as client:
try:
client = self.get_client()
async with client.request(
request.method, url, data=data, headers=request.headers
) as resp:
body = await resp.read()
except aiohttp.client_exceptions.ClientConnectionError as e:
raise RemoteException(
e, payload=HTTPResponse(status=503, body=b"Service Unavailable")
)
return aiohttp.web.Response(
status=resp.status, body=body, headers=resp.headers,
)
@@ -298,11 +322,9 @@ async def _batch_handler_template(self, requests, api_route):
headers.update(make_http_headers(trace_ctx))
reqs_s = DataLoader.merge_requests(requests)
try:
async with aiohttp.ClientSession(auto_decompress=False) as client:
async with client.post(
api_url, data=reqs_s, headers=headers
) as resp:
raw = await resp.read()
client = self.get_client()
async with client.post(api_url, data=reqs_s, headers=headers) as resp:
raw = await resp.read()
except aiohttp.client_exceptions.ClientConnectionError as e:
raise RemoteException(
e, payload=HTTPResponse(status=503, body=b"Service Unavailable")
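
The marshal.py changes above replace the per-request aiohttp.ClientSession with a single, lazily created session, so outbound connections to the model server are pooled and reused across requests, and they make it possible to route that traffic over a Unix domain socket instead of TCP. A condensed, self-contained sketch of the same pattern (OutboundClient and relay are illustrative names, not BentoML APIs):

from typing import Optional

import aiohttp


class OutboundClient:
    # Lazily create one shared aiohttp session, optionally bound to a Unix
    # domain socket, instead of opening a new session for every request.
    def __init__(self, unix_socket: Optional[str] = None):
        self.unix_socket = unix_socket
        self._client: Optional[aiohttp.ClientSession] = None

    def get_client(self) -> aiohttp.ClientSession:
        if self._client is None:
            if self.unix_socket:
                connector = aiohttp.UnixConnector(path=self.unix_socket)
            else:
                connector = aiohttp.TCPConnector(limit=30)
            self._client = aiohttp.ClientSession(
                connector=connector,
                auto_decompress=False,
                cookie_jar=aiohttp.DummyCookieJar(),
            )
        return self._client


async def relay(client: OutboundClient, url: str, data: bytes) -> bytes:
    # Every call reuses the pooled connections held by the shared session.
    async with client.get_client().post(url, data=data) as resp:
        return await resp.read()
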
107 changes: 54 additions & 53 deletions bentoml/server/__init__.py
@@ -13,7 +13,10 @@
# limitations under the License.

import logging
from dependency_injector.wiring import inject, Provide
import multiprocessing
from typing import Optional

from dependency_injector.wiring import Provide, inject

from bentoml.configuration.containers import BentoMLContainer

@@ -36,8 +39,6 @@ def start_dev_server(
):
logger.info("Starting BentoML API server in development mode..")

import multiprocessing

from bentoml.saved_bundle import load_from_dir
from bentoml.server.api_server import BentoAPIServer
from bentoml.utils import reserve_free_port
@@ -70,16 +71,12 @@ def start_dev_server(
marshal_proc.start()

bento_service = load_from_dir(saved_bundle_path)
api_server = BentoAPIServer(
bento_service, port=api_server_port, enable_swagger=enable_swagger
)
api_server.start()
api_server = BentoAPIServer(bento_service, enable_swagger=enable_swagger)
api_server.start(port=api_server_port)
else:
bento_service = load_from_dir(saved_bundle_path)
api_server = BentoAPIServer(
bento_service, port=port, enable_swagger=enable_swagger
)
api_server.start()
api_server = BentoAPIServer(bento_service, enable_swagger=enable_swagger)
api_server.start(port=port)


def start_dev_batching_server(
@@ -110,58 +107,62 @@ def start_prod_server(
port: int = Provide[BentoMLContainer.config.api_server.port],
timeout: int = Provide[BentoMLContainer.config.api_server.timeout],
workers: int = Provide[BentoMLContainer.api_server_workers],
enable_microbatch: bool = Provide[
BentoMLContainer.config.api_server.enable_microbatch
],
mb_max_batch_size: int = Provide[
BentoMLContainer.config.marshal_server.max_batch_size
],
mb_max_latency: int = Provide[BentoMLContainer.config.marshal_server.max_latency],
microbatch_workers: int = Provide[BentoMLContainer.config.marshal_server.workers],
enable_swagger: bool = Provide[BentoMLContainer.config.api_server.enable_swagger],
prometheus_lock: Optional[multiprocessing.Lock] = None,
):
logger.info("Starting BentoML API server in production mode..")

import multiprocessing

import psutil

assert (
psutil.POSIX
), "BentoML API Server production mode only supports POSIX platforms"

from bentoml.server.gunicorn_server import GunicornBentoServer
from bentoml.server.marshal_server import GunicornMarshalServer
from bentoml.utils import reserve_free_port

if enable_microbatch:
prometheus_lock = multiprocessing.Lock()
# avoid load model before gunicorn fork
with reserve_free_port() as api_server_port:
marshal_server = GunicornMarshalServer(
bundle_path=saved_bundle_path,
port=port,
workers=microbatch_workers,
prometheus_lock=prometheus_lock,
outbound_host="localhost",
outbound_port=api_server_port,
outbound_workers=workers,
mb_max_batch_size=mb_max_batch_size,
mb_max_latency=mb_max_latency,
)
gunicorn_app = GunicornBentoServer(
saved_bundle_path,
bind=f"0.0.0.0:{port}",
workers=workers,
timeout=timeout,
enable_swagger=enable_swagger,
prometheus_lock=prometheus_lock,
)
gunicorn_app.run()

gunicorn_app = GunicornBentoServer(
saved_bundle_path,
api_server_port,
workers,
timeout,
prometheus_lock,
enable_swagger,
)
marshal_server.async_run()
gunicorn_app.run()
else:
gunicorn_app = GunicornBentoServer(
saved_bundle_path, port, workers, timeout, enable_swagger=enable_swagger
)
gunicorn_app.run()

@inject
def start_prod_batching_server(
saved_bundle_path: str,
api_server_port: int,
api_server_workers: int = Provide[BentoMLContainer.api_server_workers],
port: int = Provide[BentoMLContainer.config.api_server.port],
workers: int = Provide[BentoMLContainer.config.marshal_server.workers],
mb_max_batch_size: int = Provide[
BentoMLContainer.config.marshal_server.max_batch_size
],
mb_max_latency: int = Provide[BentoMLContainer.config.marshal_server.max_latency],
prometheus_lock: Optional[multiprocessing.Lock] = None,
):

import psutil

assert (
psutil.POSIX
), "BentoML Batching Server production mode only supports POSIX platforms"

from bentoml.server.marshal_server import GunicornMarshalServer

# avoid load model before gunicorn fork
marshal_server = GunicornMarshalServer(
bundle_path=saved_bundle_path,
port=port,
workers=workers,
prometheus_lock=prometheus_lock,
outbound_host="localhost",
outbound_port=api_server_port,
outbound_workers=api_server_workers,
mb_max_batch_size=mb_max_batch_size,
mb_max_latency=mb_max_latency,
)
marshal_server.run()
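
After this refactor, start_prod_server only runs the gunicorn API server: its microbatch parameters are gone, it accepts an optional prometheus_lock, and the new start_prod_batching_server owns the gunicorn marshal server, with the CLI (see bento_service.py above) responsible for forking one and running the other. As a rough illustration of how the two entry points are now called independently (the bundle path is a placeholder, each call blocks while its server runs, and in practice the CLI wires BentoML's dependency-injection container first so the remaining Provide[...] defaults resolve):

from bentoml.server import start_prod_batching_server, start_prod_server

# API server only, e.g. when microbatching is disabled:
start_prod_server("/path/to/saved_bundle", port=5000)

# Marshal/batching server only, relaying batched requests to an API server
# already listening on api_server_port; this is what the CLI runs in the
# foreground after forking the API server process:
start_prod_batching_server(
    "/path/to/saved_bundle",
    api_server_port=5000,
    port=5001,
)
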
14 changes: 8 additions & 6 deletions bentoml/server/api_server.py
@@ -15,9 +15,9 @@
import logging
import os
import sys
from dependency_injector.wiring import inject, Provide
from functools import partial

from dependency_injector.wiring import Provide, inject
from flask import Flask, Response, jsonify, make_response, request, send_from_directory
from google.protobuf.json_format import MessageToJson
from werkzeug.exceptions import BadRequest, NotFound
@@ -27,10 +27,10 @@
from bentoml.configuration.containers import BentoMLContainer
from bentoml.exceptions import BentoMLException
from bentoml.marshal.utils import DataLoader
from bentoml.tracing.__init__ import trace
from bentoml.server.instruments import InstrumentMiddleware
from bentoml.server.open_api import get_open_api_spec_json
from bentoml.service import InferenceAPI
from bentoml.tracing.__init__ import trace

CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")

@@ -150,7 +150,6 @@ class BentoAPIServer:
def __init__(
self,
bento_service: BentoService,
port: int = Provide[BentoMLContainer.config.api_server.port],
app_name: str = None,
enable_swagger: bool = True,
enable_metrics: bool = Provide[
@@ -165,7 +164,6 @@ def __init__(
):
app_name = bento_service.name if app_name is None else app_name

self.port = port
self.bento_service = bento_service
self.app = Flask(app_name, static_folder=None)
self.static_path = self.bento_service.get_web_static_content_path()
@@ -183,14 +181,18 @@ def __init__(

self.setup_routes()

def start(self):
def start(self, port: int, host: str = "127.0.0.1"):
"""
Start a REST server on the given host and port.
"""
# Bentoml api service is not thread safe.
# Flask dev server enabled threaded by default, disable it.
self.app.run(
port=self.port, threaded=False, debug=get_debug_mode(), use_reloader=False,
host=host,
port=port,
threaded=False,
debug=get_debug_mode(),
use_reloader=False,
)

@staticmethod
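
The api_server.py change moves the port off the BentoAPIServer constructor and onto start(), which also gains a host argument. An illustrative use of the new signature (the bundle path is a placeholder, and the container configuration the CLI normally performs is omitted):

from bentoml.saved_bundle import load_from_dir
from bentoml.server.api_server import BentoAPIServer

bento_service = load_from_dir("/path/to/saved_bundle")  # placeholder path
api_server = BentoAPIServer(bento_service, enable_swagger=True)
# The port (and optionally the host) is supplied at start time, not at
# construction time.
api_server.start(port=5000, host="127.0.0.1")
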
(Diffs for the remaining 3 changed files were not loaded in this view.)
