Refactor sanity checks to use pytest (#2221)
* Remove wait time when stopping and starting torchserve in tests

Make sanity tests run with pytest (90%)

Added missing test for snapshotting

Use pytest tests in torchserve sanity checks

* Remove trailing whitespace

* Skip sanity test folder in regression test

---------

Co-authored-by: Geeta Chauhan <4461127+chauhang@users.noreply.github.com>
mreso and chauhang committed Mar 11, 2024
1 parent 616c1ad commit c56715c
Showing 6 changed files with 208 additions and 113 deletions.
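For orientation: the refactor replaces the hand-rolled sanity flow with a pytest suite under test/pytest/sanity. A minimal sketch of the standalone invocation, mirroring the command built in ts_scripts/sanity_utils.py later in this diff and assuming it is run from the repository root with pytest-cov installed:

import os

# Run the pytest-based sanity suite with coverage over the `ts` package,
# as the refactored test_sanity() in ts_scripts/sanity_utils.py does.
cmd = (
    "python -m pytest "
    "--cov-report xml:test/pytest/sanity/coverage.xml "
    "--cov=ts test/pytest/sanity"
)
print(f"## Executing command: {cmd}")
if os.system(cmd) != 0:
    raise SystemExit("## TorchServe sanity test failed !")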
27 changes: 27 additions & 0 deletions test/pytest/sanity/conftest.py
@@ -0,0 +1,27 @@
import json
import sys
from pathlib import Path

import pytest

REPO_ROOT = Path(__file__).parents[3]


MAR_CONFIG = REPO_ROOT.joinpath("ts_scripts", "mar_config.json")


@pytest.fixture(name="gen_models", scope="module")
def load_gen_models() -> dict:
with open(MAR_CONFIG) as f:
models = json.load(f)
models = {m["model_name"]: m for m in models}
return models


@pytest.fixture(scope="module")
def ts_scripts_path():
sys.path.append(REPO_ROOT.as_posix())

yield

sys.path.pop()
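For reference, the gen_models fixture above only assumes that ts_scripts/mar_config.json is a JSON list of objects carrying a "model_name" key, plus the fields that generate_model() in ts_scripts/marsgen.py reads later in this diff. A hypothetical entry, for illustration only:

# Hypothetical mar_config.json entry; the real fields and values live in the
# repository config and are only partially visible in this diff.
example_entries = [
    {
        "model_name": "resnet-18",
        "version": "1.0",
        "serialized_file_remote": "resnet18-placeholder.pth",  # hypothetical filename
        "handler": "image_classifier",
    }
]
# Mirrors the dict comprehension in the gen_models fixture above.
gen_models = {m["model_name"]: m for m in example_entries}
assert "resnet-18" in gen_models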
54 changes: 54 additions & 0 deletions test/pytest/sanity/test_config_snapshot.py
@@ -0,0 +1,54 @@
import json
from pathlib import Path

import pytest
import test_utils

REPO_ROOT = Path(__file__).parents[3]
SANITY_MODELS_CONFIG = REPO_ROOT.joinpath("ts_scripts", "configs", "sanity_models.json")


def load_resnet18() -> dict:
with open(SANITY_MODELS_CONFIG) as f:
models = json.load(f)
return list(filter(lambda x: x["name"] == "resnet-18", models))[0]


@pytest.fixture(name="resnet18")
def generate_resnet18(model_store, gen_models, ts_scripts_path):
model = load_resnet18()

from ts_scripts.marsgen import generate_model

generate_model(gen_models[model["name"]], model_store)

yield model


@pytest.fixture(scope="module")
def torchserve_with_snapshot(model_store):
test_utils.torchserve_cleanup()

test_utils.start_torchserve(
model_store=model_store, no_config_snapshots=False, gen_mar=False
)

yield

test_utils.torchserve_cleanup()


def test_config_snapshotting(
resnet18, model_store, torchserve_with_snapshot, ts_scripts_path
):
from ts_scripts.sanity_utils import run_rest_test

run_rest_test(resnet18, unregister_model=False)

test_utils.stop_torchserve()

test_utils.start_torchserve(
model_store=model_store, no_config_snapshots=False, gen_mar=False
)

run_rest_test(resnet18, register_model=False)
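The test above exercises the snapshot round trip: register resnet-18, stop TorchServe, restart with config snapshots enabled, and verify the model is served again without re-registering. A minimal sketch for running only this file, assuming execution from the repository root:

import os

# Run only the snapshot test; pytest picks up the module-scoped fixtures from
# the sanity conftest.py and the shared test_utils helpers.
os.system("python -m pytest -v test/pytest/sanity/test_config_snapshot.py")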
55 changes: 55 additions & 0 deletions test/pytest/sanity/test_model_registering.py
@@ -0,0 +1,55 @@
import json
from pathlib import Path

import pytest

REPO_ROOT = Path(__file__).parents[3]
SANITY_MODELS_CONFIG = REPO_ROOT.joinpath("ts_scripts", "configs", "sanity_models.json")


@pytest.fixture(scope="module")
def grpc_client_stubs(ts_scripts_path):
from ts_scripts.shell_utils import rm_file
from ts_scripts.tsutils import generate_grpc_client_stubs

generate_grpc_client_stubs()

yield

rm_file(REPO_ROOT.joinpath("ts_scripts", "*_pb2*.py").as_posix(), True)


def load_models() -> list:
with open(SANITY_MODELS_CONFIG) as f:
models = json.load(f)
return models


@pytest.fixture(name="model", params=load_models(), scope="module")
def models_to_validate(request, model_store, gen_models, ts_scripts_path):
model = request.param

if model["name"] in gen_models:
from ts_scripts.marsgen import generate_model

generate_model(gen_models[model["name"]], model_store)

yield model


def test_models_with_grpc(model, torchserve, ts_scripts_path, grpc_client_stubs):
from ts_scripts.sanity_utils import run_grpc_test

run_grpc_test(model)


def test_models_with_rest(model, torchserve, ts_scripts_path):
from ts_scripts.sanity_utils import run_rest_test

run_rest_test(model)


def test_gpu_setup(ts_scripts_path):
from ts_scripts.sanity_utils import test_gpu_setup

test_gpu_setup()
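The parametrized model fixture above reads ts_scripts/configs/sanity_models.json at collection time, so each entry becomes its own test case for both the gRPC and REST paths. A hypothetical sketch of the expected shape, based only on the keys used in this diff:

# Hypothetical sanity_models.json content; only the "name" key is exercised by
# the fixtures in this diff, and entries also present in mar_config.json get a
# fresh .mar generated before serving.
example_models = [
    {"name": "resnet-18"},
    {"name": "another_model"},  # placeholder entry
]
resnet18 = next(m for m in example_models if m["name"] == "resnet-18")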
125 changes: 59 additions & 66 deletions ts_scripts/marsgen.py
@@ -43,6 +43,64 @@ def gen_mar(model_store=None):
print(f"## Symlink {src}, {dst} successfully.")


def generate_model(model, model_store_dir):
serialized_file_path = None
if model.get("serialized_file_remote", None):
if model.get("gen_scripted_file_path", None):
subprocess.run(["python", model["gen_scripted_file_path"]])
else:
serialized_model_file_url = (
f"https://download.pytorch.org/models/{model['serialized_file_remote']}"
)
urllib.request.urlretrieve(
serialized_model_file_url,
f'{model_store_dir}/{model["serialized_file_remote"]}',
)
serialized_file_path = os.path.join(
model_store_dir, model["serialized_file_remote"]
)
elif model.get("serialized_file_local", None):
serialized_file_path = model["serialized_file_local"]

handler = model.get("handler", None)

extra_files = model.get("extra_files", None)

runtime = model.get("runtime", None)

archive_format = model.get("archive_format", "zip-store")

requirements_file = model.get("requirements_file", None)

export_path = model.get("export_path", model_store_dir)

cmd = model_archiver_command_builder(
model["model_name"],
model["version"],
model.get("model_file", None),
serialized_file_path,
handler,
extra_files,
runtime,
archive_format,
requirements_file,
export_path,
)
print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n")
try:
subprocess.check_call(cmd, shell=True)
marfile = "{}.mar".format(model["model_name"])
print("## {} is generated.\n".format(marfile))
mar_set.add(marfile)
except subprocess.CalledProcessError as exc:
print("## {} creation failed !, error: {}\n".format(model["model_name"], exc))

if model.get("serialized_file_remote", None) and os.path.exists(
serialized_file_path
):
os.remove(serialized_file_path)


def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_DIR):
"""
By default generate_mars reads ts_scripts/mar_config.json and outputs mar files in dir model_store_gen
@@ -67,72 +125,7 @@ def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_DIR):
models = json.loads(f.read())

for model in models:
serialized_file_path = None
if model.get("serialized_file_remote") and model["serialized_file_remote"]:
if (
model.get("gen_scripted_file_path")
and model["gen_scripted_file_path"]
):
subprocess.run(["python", model["gen_scripted_file_path"]])
else:
serialized_model_file_url = (
"https://download.pytorch.org/models/{}".format(
model["serialized_file_remote"]
)
)
urllib.request.urlretrieve(
serialized_model_file_url,
f'{model_store_dir}/{model["serialized_file_remote"]}',
)
serialized_file_path = os.path.join(
model_store_dir, model["serialized_file_remote"]
)
elif model.get("serialized_file_local") and model["serialized_file_local"]:
serialized_file_path = model["serialized_file_local"]

handler = model.get("handler", None)

extra_files = model.get("extra_files", None)

runtime = model.get("runtime", None)

archive_format = model.get("archive_format", "zip-store")

requirements_file = model.get("requirements_file", None)

export_path = model.get("export_path", model_store_dir)

cmd = model_archiver_command_builder(
model["model_name"],
model["version"],
model.get("model_file", None),
serialized_file_path,
handler,
extra_files,
runtime,
archive_format,
requirements_file,
export_path,
)
print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n")
try:
subprocess.check_call(cmd, shell=True)
marfile = "{}.mar".format(model["model_name"])
print("## {} is generated.\n".format(marfile))
mar_set.add(marfile)
except subprocess.CalledProcessError as exc:
print(
"## {} creation failed !, error: {}\n".format(
model["model_name"], exc
)
)

if (
model.get("serialized_file_remote")
and model["serialized_file_remote"]
and os.path.exists(serialized_file_path)
):
os.remove(serialized_file_path)
generate_model(model, model_store_dir)
os.chdir(cwd)


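The new generate_model() above is what the sanity fixtures call directly (see conftest.py and the model fixture earlier in this diff). A minimal usage sketch, assuming the repository root is on sys.path (which is what the ts_scripts_path fixture arranges), the working directory is the repo root, and a local model_store directory exists:

import json

from ts_scripts.marsgen import generate_model

# Illustrative only: build the .mar for one mar_config.json entry into a local
# model store, much as the pytest fixtures do.
with open("ts_scripts/mar_config.json") as f:
    entries = {m["model_name"]: m for m in json.load(f)}

generate_model(entries["resnet-18"], "model_store")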
2 changes: 1 addition & 1 deletion ts_scripts/regression_utils.py
@@ -62,7 +62,7 @@ def run_pytest():
if status != 0:
print("Could not generate gRPC client stubs")
sys.exit(1)
cmd = "python -m pytest -v ./"
cmd = "python -m pytest -v ./ --ignore=sanity"
print(f"## In directory: {os.getcwd()} | Executing command: {cmd}")
status = os.system(cmd)
rm_file("*_pb2*.py", True)
58 changes: 12 additions & 46 deletions ts_scripts/sanity_utils.py
@@ -11,7 +11,6 @@
from ts_scripts import marsgen as mg
from ts_scripts import tsutils as ts
from ts_scripts import utils
from ts_scripts.tsutils import generate_grpc_client_stubs

REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
sys.path.append(REPO_ROOT)
@@ -204,51 +203,18 @@ def run_rest_test(model, register_model=True, unregister_model=True):


def test_sanity():
generate_grpc_client_stubs()

print("## Started sanity tests")

models_to_validate = load_model_to_validate()

test_gpu_setup()

ts_log_file = os.path.join("logs", "ts_console.log")

os.makedirs("model_store", exist_ok=True)
os.makedirs("logs", exist_ok=True)

mg.mar_set = set(os.listdir("model_store"))
started = ts.start_torchserve(log_file=ts_log_file, gen_mar=False)
if not started:
sys.exit(1)

resnet18_model = models_to_validate["resnet-18"]

models_to_validate = {
k: v for k, v in models_to_validate.items() if k != "resnet-18"
}

for _, model in models_to_validate.items():
run_grpc_test(model)
run_rest_test(model)

run_rest_test(resnet18_model, unregister_model=False)

stopped = ts.stop_torchserve()
if not stopped:
sys.exit(1)

# Restarting torchserve
# This should restart with the generated snapshot and resnet-18 model should be automatically registered
started = ts.start_torchserve(log_file=ts_log_file, gen_mar=False)
if not started:
sys.exit(1)

run_rest_test(resnet18_model, register_model=False)

stopped = ts.stop_torchserve()
if not stopped:
sys.exit(1)
# Execute python tests
print("## Started TorchServe sanity pytests")
test_dir = os.path.join("test", "pytest", "sanity")
coverage_dir = os.path.join("ts")
report_output_dir = os.path.join(test_dir, "coverage.xml")

ts_test_cmd = f"python -m pytest --cov-report xml:{report_output_dir} --cov={coverage_dir} {test_dir}"
print(f"## In directory: {os.getcwd()} | Executing command: {ts_test_cmd}")
ts_test_error_code = os.system(ts_test_cmd)

if ts_test_error_code != 0:
sys.exit("## TorchServe sanity test failed !")


def test_workflow_sanity():
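One behavioral note on the rewritten test_sanity(): sys.exit() is called with a string, which Python prints to stderr and converts to exit status 1, so CI callers still observe a non-zero code on failure. A tiny sketch of that behavior, for illustration:

import subprocess
import sys

# sys.exit("message") raises SystemExit with the string, which the interpreter
# prints to stderr and converts to exit status 1.
proc = subprocess.run(
    [sys.executable, "-c", "import sys; sys.exit('sanity test failed !')"],
    capture_output=True,
    text=True,
)
assert proc.returncode == 1
assert "sanity test failed" in proc.stderr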
