diff --git a/jobs/v4beta1/requirements-test.txt b/jobs/v4beta1/requirements-test.txt index 10b405fa748e..bbef4fe11b9a 100644 --- a/jobs/v4beta1/requirements-test.txt +++ b/jobs/v4beta1/requirements-test.txt @@ -1 +1 @@ -pytest==5.4.2 \ No newline at end of file +pytest==5.4.2 diff --git a/tables/automl/batch_predict_test.py b/tables/automl/batch_predict_test.py new file mode 100644 index 000000000000..37b5f0e09c34 --- /dev/null +++ b/tables/automl/batch_predict_test.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.cloud.automl_v1beta1.gapic import enums + +import pytest + +import automl_tables_model +import automl_tables_predict +import model_test + + +PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"] +REGION = "us-central1" +STATIC_MODEL = model_test.STATIC_MODEL +GCS_INPUT = "gs://{}-automl-tables-test/bank-marketing.csv".format(PROJECT) +GCS_OUTPUT = "gs://{}-automl-tables-test/TABLE_TEST_OUTPUT/".format(PROJECT) + + +@pytest.mark.slow +def test_batch_predict(capsys): + ensure_model_online() + automl_tables_predict.batch_predict( + PROJECT, REGION, STATIC_MODEL, GCS_INPUT, GCS_OUTPUT + ) + out, _ = capsys.readouterr() + assert "Batch prediction complete" in out + + +def ensure_model_online(): + model = model_test.ensure_model_ready() + if model.deployment_state != enums.Model.DeploymentState.DEPLOYED: + automl_tables_model.deploy_model(PROJECT, REGION, model.display_name) + + return automl_tables_model.get_model(PROJECT, REGION, model.display_name)