Refactor PT2 code changes (#2222)
msaroufim committed Apr 19, 2023
1 parent 2fa042b commit cf7544b
Showing 12 changed files with 237 additions and 119 deletions.
20 changes: 10 additions & 10 deletions examples/pt2/README.md
@@ -1,30 +1,30 @@
## PyTorch 2.x integration

-PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial but for now the support will be experimental until the official release and while we are relying on the nightly builds.
+PyTorch 2.0 brings more compiler options to PyTorch; for you, that should mean better performance, either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial, but for now the support will be experimental, given that most public benchmarks have focused on training rather than inference.

We strongly recommend you leverage newer hardware; for GPUs, that means the Ampere architecture. You'll get even more benefit from server GPU deployments like A10G and A100 than from consumer cards, but you should expect some speedup on any Volta or Ampere architecture.

## Get started

-Install torchserve with nightly torch binaries
+Install torchserve and ensure that you're using at least `torch>=2.0.0`

-```
-python ts_scripts/install_dependencies.py --cuda=cu117 --nightly_torch
+```sh
+python ts_scripts/install_dependencies.py --cuda=cu117
pip install torchserve torch-model-archiver
```
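
To confirm that your environment actually picked up a 2.x build, a quick sanity check:

```sh
python -c "import torch; print(torch.__version__)"  # expect 2.0.0 or newer
```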

## Package your model

-PyTorch 2.0 supports several compiler backends and you pick which one you want by passing in an optional file `compile.json` during your model packaging
+PyTorch 2.0 supports several compiler backends; you pick the one you want by passing in an optional `model_config.yaml` file during model packaging

`{"pt2" : "inductor"}`
`pt2: "inductor"`
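
Under the hood, the `pt2` value names the backend that gets handed to `torch.compile`. A minimal standalone sketch of that idea (the stand-in model and the hard-coded backend below are illustrative, not TorchServe's actual handler code):

```python
import torch

backend = "inductor"  # in TorchServe this would come from the pt2 key in model_config.yaml

model = torch.nn.Linear(2, 2)  # stand-in for your real model
compiled = torch.compile(model, backend=backend)  # requires torch>=2.0
print(compiled(torch.randn(1, 2)))
```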

-As an example let's expand our getting started guide with the only difference being passing in the extra `compile.json` file
+As an example, let's expand our getting-started guide, the only difference being the extra `model_config.yaml` file

```
mkdir model_store
-torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json,./serve/examples/image_classifier/compile.json --handler image_classifier
-torchserve --start --ncs --model-store model_store --models densenet161.mar
+torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json --handler image_classifier
+torchserve --start --ncs --model-store model_store --models densenet161.mar --config-file model_config.yaml
```
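
Once the server is up you can smoke-test the endpoint the same way as in the standard getting-started guide; the sample image below is an assumption carried over from that guide:

```sh
curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/kitten_small.jpg
curl http://127.0.0.1:8080/predictions/densenet161 -T kitten_small.jpg
```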

The exact same approach works with any other model; what's going on under the hood is shown below
@@ -35,7 +35,7 @@ opt_mod = torch.compile(mod)
# 2. Train the optimized module
# ....
# 3. Save the original module (weights are shared)
-torch.save(model, "model.pt")
+torch.save(model, "model.pt")

# 4. Load the non-optimized model
mod = torch.load("model.pt")
20 changes: 20 additions & 0 deletions test/pytest/test_data/torch_compile/compile_handler.py
@@ -0,0 +1,20 @@
import torch

from ts.torch_handler.base_handler import BaseHandler


class CompileHandler(BaseHandler):
    def __init__(self):
        super().__init__()

    def initialize(self, context):
        super().initialize(context)

    def preprocess(self, data):
        # Extract the list of instances from the JSON request body and
        # convert it to a float32 tensor for the model
        instances = data[0]["body"]["instances"]
        input_tensor = torch.as_tensor(instances, dtype=torch.float32)
        return input_tensor

    def postprocess(self, data):
        # Convert the output tensor to a nested list and return only the
        # third prediction (the test request sends three instances)
        return data.tolist()[2]
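
As a quick illustration of the data flow through this handler, here is a standalone sketch that mirrors `preprocess` and `postprocess` around the half-plus-two model used by the test suite below (not the way TorchServe actually invokes the handler):

```python
import torch

# Mirror CompileHandler.preprocess on the test payload
instances = [[1.0], [2.0], [3.0]]
x = torch.as_tensor(instances, dtype=torch.float32)

# The test model computes y = x / 2 + 2 ("half plus two")
y = x / 2 + 2

# Mirror CompileHandler.postprocess: keep only the third prediction
print(y.tolist()[2])  # [3.5]
```
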
File renamed without changes.
1 change: 1 addition & 0 deletions test/pytest/test_data/torch_compile/pt2.yaml
@@ -0,0 +1 @@
pt2 : "inductor"
1 change: 1 addition & 0 deletions test/pytest/test_data/torch_compile/xla.yaml
@@ -0,0 +1 @@
pt2 : "torchxla_trace_once"
1 change: 0 additions & 1 deletion test/pytest/test_data/torch_xla/compile.json

This file was deleted.

104 changes: 104 additions & 0 deletions test/pytest/test_torch_compile.py
@@ -0,0 +1,104 @@
import glob
import json
import os
import subprocess
import time
from pathlib import Path

import pytest
import torch
from pkg_resources import packaging

PT_2_AVAILABLE = packaging.version.parse(torch.__version__) >= packaging.version.parse(
    "2.0"
)

CURR_FILE_PATH = Path(__file__).parent
TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data", "torch_compile")

MODEL_FILE = os.path.join(TEST_DATA_DIR, "model.py")
HANDLER_FILE = os.path.join(TEST_DATA_DIR, "compile_handler.py")
YAML_CONFIG = os.path.join(TEST_DATA_DIR, "pt2.yaml")


SERIALIZED_FILE = os.path.join(TEST_DATA_DIR, "model.pt")
MODEL_STORE_DIR = os.path.join(TEST_DATA_DIR, "model_store")
MODEL_NAME = "half_plus_two"


@pytest.mark.skipif(not PT_2_AVAILABLE, reason="torch version is < 2.0.0")
class TestTorchCompile:
    def teardown_class(self):
        subprocess.run("torchserve --stop", shell=True, check=True)
        time.sleep(10)

    def test_archive_model_artifacts(self):
        assert len(glob.glob(MODEL_FILE)) == 1
        assert len(glob.glob(YAML_CONFIG)) == 1
        subprocess.run(f"cd {TEST_DATA_DIR} && python model.py", shell=True, check=True)
        subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)
        subprocess.run(
            f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f",
            shell=True,
            check=True,
        )
        assert len(glob.glob(SERIALIZED_FILE)) == 1
        assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1

    def test_start_torchserve(self):
        cmd = f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR}"
        subprocess.run(
            cmd,
            shell=True,
            check=True,
        )
        time.sleep(10)
        assert len(glob.glob("logs/access_log.log")) == 1
        assert len(glob.glob("logs/model_log.log")) == 1
        assert len(glob.glob("logs/ts_log.log")) == 1

    def test_server_status(self):
        result = subprocess.run(
            "curl http://localhost:8080/ping",
            shell=True,
            capture_output=True,
            check=True,
        )
        expected_server_status_str = '{"status": "Healthy"}'
        expected_server_status = json.loads(expected_server_status_str)
        assert json.loads(result.stdout) == expected_server_status

    def test_registered_model(self):
        result = subprocess.run(
            "curl http://localhost:8081/models",
            shell=True,
            capture_output=True,
            check=True,
        )
        expected_registered_model_str = '{"models": [{"modelName": "half_plus_two", "modelUrl": "half_plus_two.mar"}]}'
        expected_registered_model = json.loads(expected_registered_model_str)
        assert json.loads(result.stdout) == expected_registered_model

    def test_serve_inference(self):
        request_data = {"instances": [[1.0], [2.0], [3.0]]}
        request_json = json.dumps(request_data)

        result = subprocess.run(
            f"curl -s -X POST -H \"Content-Type: application/json;\" http://localhost:8080/predictions/half_plus_two -d '{request_json}'",
            shell=True,
            capture_output=True,
            check=True,
        )

        string_result = result.stdout.decode("utf-8")
        float_result = float(string_result)
        expected_result = 3.5

        assert float_result == expected_result

        model_log_path = glob.glob("logs/model_log.log")[0]
        with open(model_log_path, "rt") as model_log_file:
            model_log = model_log_file.read()
            assert "Compiled model with backend inductor" in model_log
8 changes: 4 additions & 4 deletions test/pytest/test_torch_xla.py
@@ -21,10 +21,10 @@
TORCHXLA_AVAILABLE = False

CURR_FILE_PATH = Path(__file__).parent
-TORCH_XLA_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data")
+TORCH_XLA_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data", "torch_compile")

MODEL_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.py")
-EXTRA_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "compile.json")
+YAML_CONFIG = os.path.join(TORCH_XLA_TEST_DATA_DIR, "xla.yaml")
CONFIG_PROPERTIES = os.path.join(TORCH_XLA_TEST_DATA_DIR, "config.properties")

SERIALIZED_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.pt")
@@ -40,14 +40,14 @@ def teardown_class(self):

    def test_archive_model_artifacts(self):
        assert len(glob.glob(MODEL_FILE)) == 1
-       assert len(glob.glob(EXTRA_FILE)) == 1
+       assert len(glob.glob(YAML_CONFIG)) == 1
        assert len(glob.glob(CONFIG_PROPERTIES)) == 1
        subprocess.run(
            f"cd {TORCH_XLA_TEST_DATA_DIR} && python model.py", shell=True, check=True
        )
        subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)
        subprocess.run(
-           f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --extra-files {EXTRA_FILE} --export-path {MODEL_STORE_DIR} --handler base_handler -f",
+           f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG} --export-path {MODEL_STORE_DIR} --handler base_handler -f",
            shell=True,
            check=True,
        )
28 changes: 19 additions & 9 deletions test/pytest/test_utils.py
@@ -118,31 +118,41 @@
def model_archiver_command_builder(
    handler=None,
    extra_files=None,
    force=False,
+   config_file=None,
):
-   cmd = "torch-model-archiver"
+   # Initialize a list to store the command-line arguments
+   cmd_parts = ["torch-model-archiver"]

+   # Append arguments to the list
    if model_name:
-       cmd += " --model-name {0}".format(model_name)
+       cmd_parts.append(f"--model-name {model_name}")

    if version:
-       cmd += " --version {0}".format(version)
+       cmd_parts.append(f"--version {version}")

    if model_file:
-       cmd += " --model-file {0}".format(model_file)
+       cmd_parts.append(f"--model-file {model_file}")

    if serialized_file:
-       cmd += " --serialized-file {0}".format(serialized_file)
+       cmd_parts.append(f"--serialized-file {serialized_file}")

    if handler:
-       cmd += " --handler {0}".format(handler)
+       cmd_parts.append(f"--handler {handler}")

    if extra_files:
-       cmd += " --extra-files {0}".format(extra_files)
+       cmd_parts.append(f"--extra-files {extra_files}")

+   if config_file:
+       cmd_parts.append(f"--config-file {config_file}")

    if force:
-       cmd += " --force"
+       cmd_parts.append("--force")

+   # Append the export-path argument to the list
+   cmd_parts.append(f"--export-path {MODEL_STORE}")

-   cmd += " --export-path {0}".format(MODEL_STORE)
+   # Convert the list into a string to represent the complete command
+   cmd = " ".join(cmd_parts)

    return cmd
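
As a usage sketch (values are illustrative and assume `test_utils` is importable, i.e. this directory is on `sys.path`):

```python
from test_utils import model_archiver_command_builder

cmd = model_archiver_command_builder(
    model_name="half_plus_two",
    version="1.0",
    model_file="model.py",
    serialized_file="model.pt",
    handler="base_handler",
    config_file="model_config.yaml",
)
print(cmd)
# torch-model-archiver --model-name half_plus_two --version 1.0 --model-file model.py \
#   --serialized-file model.pt --handler base_handler --config-file model_config.yaml \
#   --export-path <MODEL_STORE>
```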
