Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add README documentation for scikit-learn MNIST example #21887

Merged
merged 2 commits into from
Jun 15, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions sdks/python/apache_beam/examples/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,54 @@ He looked up and saw the sun and stars .;moon
Each line has data separated by a semicolon ";".
The first item is the sentence with the last word masked. The second item
is the word that the model predicts for the mask.

---
## MNITST digit classification
[`sklearn_mnist_classification.py`](./sklearn_mnist_classification.py) contains
an implementation for a RunInference pipeline that performs image classification on handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) database.

The pipeline reads rows of pixels corresponding to a digit, performs basic preprocessing, passes the pixels to the Scikit-learn implementation of RunInference, and then writes the predictions to a text file.

### Dataset and model for language modeling
- **Required**: A path to a file called `INPUT` that contains label and pixels to
feed into the model. Each row should have elements that are comma-separated. The first element is the label. All subsuequent values are pixels from pixel0 to pixel784. It should look something like this:
yeandy marked this conversation as resolved.
Show resolved Hide resolved
```
1,0,0,0...
0,0,0,0...
1,0,0,0...
4,0,0,0...
...
```
- **Required**: A path to a file called `OUTPUT`, to which the pipeline will
write the predictions.
- **Required**: A path to a file called `MODEL_PATH` that contains the pickled file of a scikit-learn model trained on MNIST data. Please refer to this scikit-learn [documentation](https://scikit-learn.org/stable/model_persistence.html) on how to serialize models.


### Running `sklearn_mnist_classification.py`

To run the MNIST classification pipeline locally, use the following command:
```sh
python -m apache_beam.examples.inference.sklearn_mnist_classification.py \
--input_file INPUT \
--output OUTPUT \
--model_path MODEL_PATH
```
For example:
```sh
python -m apache_beam.examples.inference.sklearn_mnist_classification.py \
--input_file mnist_data.csv \
--output predictions.csv \
--model_path mnist_model_svm.pickle
```

This writes the output to the `predictions.csv` with contents like:
yeandy marked this conversation as resolved.
Show resolved Hide resolved
```
1,1
4,9
7,1
0,0
...
```
Each line has data separated by a comma ",".
The first item is the actual label of the digit. The second item
is the predicted label of the digit.