Added unit test with pretrained model
mreso committed Jun 28, 2022
1 parent c027522 commit 212b8fc
Showing 2 changed files with 51 additions and 5 deletions.
1 change: 1 addition & 0 deletions examples/text_classification_with_scriptable_tokenizer/README.md
@@ -33,6 +33,7 @@ Subsequently, we need to add the command ```torch.save(model.state_dict(), "model.pt")```
 python sst2_classification_non_distributed.py
 ```

+A pretrained ```model.pt``` is also available for download [here](https://bert-mar-file.s3.us-west-2.amazonaws.com/text_classification_with_scriptable_tokenizer/model.pt).
 The trained model can then be combined with the tokenizer and compiled with TorchScript using the script_tokenizer_and_model.py script. Here, ```model.pt``` contains the model weights saved after training, and ```model_jit.pt``` is the combined tokenizer and model compiled with TorchScript.

 ```bash
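The pretrained weights linked above can be fetched directly; a minimal sketch using requests (the URL is the one from the README, the local filename is a choice):

```python
# Download the pretrained weights referenced in the README (a sketch).
import requests

url = (
    "https://bert-mar-file.s3.us-west-2.amazonaws.com/"
    "text_classification_with_scriptable_tokenizer/model.pt"
)
response = requests.get(url, timeout=60)
response.raise_for_status()
with open("model.pt", "wb") as f:
    f.write(response.content)
```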
55 changes: 50 additions & 5 deletions test/pytest/test_example_scriptable_tokenzier.py
@@ -3,6 +3,7 @@
 import os
 import shutil
 import sys
+import time
 from argparse import Namespace

 import pytest
@@ -145,10 +146,12 @@ def mar_file_path(work_dir, session_mocker, jit_file_path, archiver):
     """
     Create mar file and return file path.
     """
-    mar_file_path = os.path.join(work_dir, "scriptable_tokenizer.mar")
+    model_name = "scriptable_tokenizer_untrained"
+
+    mar_file_path = os.path.join(work_dir, model_name + ".mar")

     args = Namespace(
-        model_name="scriptable_tokenizer",
+        model_name=model_name,
         version="1.0",
         serialized_file=jit_file_path,
         model_file=None,
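The Namespace above mirrors the command-line flags of torch-model-archiver. Roughly the equivalent CLI call, as a sketch (paths and the handler are placeholders; the fixture's remaining arguments are cut off in this view):

```python
# Approximate CLI equivalent of the Namespace above; paths are placeholders.
import subprocess

subprocess.run(
    [
        "torch-model-archiver",
        "--model-name", "scriptable_tokenizer_untrained",
        "--version", "1.0",
        "--serialized-file", "model_jit.pt",  # jit_file_path in the fixture
        "--handler", "handler.py",  # hypothetical; the real handler is not shown here
        "--export-path", "work_dir",  # placeholder for the work_dir fixture
    ],
    check=True,
)
```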
@@ -229,9 +232,12 @@ def model_name(mar_file_path, model_store, torchserve):
     test_utils.unregister_model(model_name)


-def test_inference_with_sample_text(model_name):
+@pytest.fixture
+def test_file():
+    return os.path.join(EXAMPLE_ROOT_DIR, "sample_text.txt")
+

-    test_file = os.path.join(EXAMPLE_ROOT_DIR, "sample_text.txt")
+def test_inference_with_untrained_model_and_sample_text(model_name, test_file):

     with open(test_file, "rb") as f:
         response = requests.post(
@@ -250,7 +256,7 @@
     assert float(result_entries["Positive"]) == pytest.approx(0.36607906222343445, 1e-3)


-def test_inference_with_empty_string(model_name):
+def test_inference_with_untrained_model_and_empty_string(model_name):

     data = "".encode("utf8")

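A note on the assertions in these tests: the second positional argument to pytest.approx is the relative tolerance rel, so the comparisons allow roughly 0.1 % deviation. A minimal illustration:

```python
import pytest

# rel=1e-3 accepts values within about 0.1 % of the expected number.
assert 0.3661 == pytest.approx(0.36607906222343445, 1e-3)
```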
@@ -268,3 +274,42 @@
     # We're using an untrained model for the unit test, so results do not make sense but should be consistent
     assert float(result_entries["Negative"]) == pytest.approx(0.6082412600517273, 1e-3)
     assert float(result_entries["Positive"]) == pytest.approx(0.3917587101459503, 1e-3)
+
+
+def test_inference_with_pretrained_model(model_store, test_file, torchserve):
+    model_name = "scriptable_tokenizer"
+
+    params = (
+        ("model_name", model_name),
+        (
+            "url",
+            "https://bert-mar-file.s3.us-west-2.amazonaws.com/text_classification_with_scriptable_tokenizer/scriptable_tokenizer.mar",
+        ),
+        ("initial_workers", "1"),
+    )
+
+    test_utils.reg_resp = test_utils.register_model_with_params(params)
+
+    # Give the test some time for the model to be downloaded from the S3 bucket
+    for sleep_time in [2, 4, 8, 16, 32, 64]:
+        with open(test_file, "rb") as f:
+            response = requests.post(
+                url=f"http://localhost:8080/predictions/{model_name}", data=f
+            )
+        if response.status_code == 200:
+            break
+        time.sleep(sleep_time)
+
+    assert response.status_code == 200
+
+    result_entries = json.loads(response.text)
+
+    assert "Negative" in result_entries
+    assert "Positive" in result_entries
+
+    assert float(result_entries["Negative"]) == pytest.approx(
+        0.0001851904089562595, 1e-3
+    )
+    assert float(result_entries["Positive"]) == pytest.approx(0.9998148083686829, 1e-3)
+
+    test_utils.unregister_model(model_name)
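For context, test_utils.register_model_with_params wraps TorchServe's management API; a hedged sketch of the underlying HTTP call, assuming the default management port 8081:

```python
# Approximate form of the registration request (a sketch; TorchServe may
# answer 200 or 202 depending on whether loading is synchronous).
import requests

response = requests.post(
    "http://localhost:8081/models",
    params={
        "model_name": "scriptable_tokenizer",
        "url": "https://bert-mar-file.s3.us-west-2.amazonaws.com/text_classification_with_scriptable_tokenizer/scriptable_tokenizer.mar",
        "initial_workers": "1",
    },
)
print(response.status_code, response.text)
```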

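The retry loop in the new test could also be factored into a small helper; a sketch with hypothetical names that mirrors the loop's behavior:

```python
import time

import requests


def post_with_backoff(url, file_path, delays=(2, 4, 8, 16, 32, 64)):
    """POST the file, sleeping between attempts while the model is still loading."""
    for delay in delays:
        with open(file_path, "rb") as f:
            response = requests.post(url=url, data=f)
        if response.status_code == 200:
            break
        time.sleep(delay)
    return response
```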