From 8c3d322e736c418e08619b69c089dd300971987f Mon Sep 17 00:00:00 2001 From: Anand Inguva Date: Fri, 10 Jun 2022 15:47:47 -0400 Subject: [PATCH] fixup: bug --- .../apache_beam/ml/inference/sklearn_inference_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py index e23acf4ab413c..91eb86e2de4b0 100644 --- a/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/sklearn_inference_test.py @@ -240,7 +240,7 @@ def test_pipeline_pandas(self): splits = [dataframe.loc[[i]] for i in dataframe.index] pcoll = pipeline | 'start' >> beam.Create(splits) actual = pcoll | api.RunInference( - SklearnModelLoader(model_uri=temp_file_name)) + SklearnModelHandler(model_uri=temp_file_name)) expected = [ api.PredictionResult(splits[0], 5), @@ -265,7 +265,7 @@ def test_pipeline_pandas_with_keys(self): pcoll = pipeline | 'start' >> beam.Create(keyed_rows) actual = pcoll | api.RunInference( - SklearnModelLoader(model_uri=temp_file_name)) + SklearnModelHandler(model_uri=temp_file_name)) expected = [ ('0', api.PredictionResult(splits[0], 5)), ('1', api.PredictionResult(splits[1], 8)), @@ -279,14 +279,14 @@ def test_pipeline_pandas_with_keys(self): def test_infer_invalid_data_type(self): with self.assertRaises(ValueError): unexpected_input_type = [[1, 2, 3, 4], [5, 6, 7, 8]] - inference_runner = SklearnModelLoader(model_uri=unused) + inference_runner = SklearnModelHandler(model_uri='unused') fake_model = FakeModel() inference_runner.run_inference(unexpected_input_type, fake_model) def test_infer_too_many_rows_in_dataframe(self): with self.assertRaises(ValueError): data_frame_too_many_rows = pandas_dataframe() - inference_runner = SklearnModelLoader(model_uri=unused) + inference_runner = SklearnModelHandler(model_uri='unused') fake_model = FakeModel() inference_runner.run_inference([data_frame_too_many_rows], fake_model)