-
-
Notifications
You must be signed in to change notification settings - Fork 16.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1487bc8
commit 7e457e0
Showing
3 changed files
with
91 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
## Flask Rest API | ||
Rest API are commonly used to expose machine learning models to other services. This folder contains an example rest API created using Flask to expose the `yolov5s` model from pytorch hub. | ||
|
||
Install Flask in your environment and run: | ||
|
||
`$ python3 restapi.py --port 5000` | ||
|
||
Then use [curl](https://curl.se/) to perform a request: | ||
|
||
`$ curl -X POST -F image=@zidane.jpg 'http://localhost:5000/v1/object-detection/yolov5s'` | ||
|
||
The model inference results are returned: | ||
|
||
``` | ||
[{'class': 0, | ||
'confidence': 0.8197850585, | ||
'name': 'person', | ||
'xmax': 1159.1403808594, | ||
'xmin': 750.912902832, | ||
'ymax': 711.2583007812, | ||
'ymin': 44.0350036621}, | ||
{'class': 0, | ||
'confidence': 0.5667674541, | ||
'name': 'person', | ||
'xmax': 1065.5523681641, | ||
'xmin': 116.0448303223, | ||
'ymax': 713.8904418945, | ||
'ymin': 198.4603881836}, | ||
{'class': 27, | ||
'confidence': 0.5661227107, | ||
'name': 'tie', | ||
'xmax': 516.7975463867, | ||
'xmin': 416.6880187988, | ||
'ymax': 717.0524902344, | ||
'ymin': 429.2020568848}] | ||
``` | ||
|
||
An example python script to perform inference using [requests](https://docs.python-requests.org/en/master/) is given in `example_request.py` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""Perform test request""" | ||
import pprint | ||
import requests | ||
|
||
DETECTION_URL = "http://localhost:5000/v1/object-detection/yolov5s" | ||
TEST_IMAGE = "zidane.jpg" | ||
|
||
image_data = open(TEST_IMAGE, "rb").read() | ||
|
||
response = requests.post(DETECTION_URL, files={"image": image_data}).json() | ||
|
||
pprint.pprint(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
""" | ||
Run a rest API exposing the yolov5s object detection model | ||
""" | ||
import argparse | ||
import io | ||
from PIL import Image | ||
|
||
import torch | ||
from flask import Flask, request | ||
|
||
app = Flask(__name__) | ||
|
||
DETECTION_URL = "/v1/object-detection/yolov5s" | ||
|
||
|
||
@app.route(DETECTION_URL, methods=["POST"]) | ||
def predict(): | ||
if not request.method == "POST": | ||
return | ||
|
||
if request.files.get("image"): | ||
image_file = request.files["image"] | ||
image_bytes = image_file.read() | ||
|
||
img = Image.open(io.BytesIO(image_bytes)) | ||
|
||
results = model(img, size=640) | ||
data = results.pandas().xyxy[0].to_json(orient="records") | ||
return data | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Flask api exposing yolov5 model") | ||
parser.add_argument("--port", default=5000, type=int, help="port number") | ||
args = parser.parse_args() | ||
|
||
model = torch.hub.load( | ||
"ultralytics/yolov5", "yolov5s", pretrained=True, force_reload=True | ||
).autoshape() # force_reload = recache latest code | ||
model.eval() | ||
app.run(host="0.0.0.0", port=args.port) # debug=True causes Restarting with stat |