diff --git a/examples/pytorch/image-classification/README.md b/examples/pytorch/image-classification/README.md
new file mode 100644
index 00000000000000..24a0ab22a605d1
--- /dev/null
+++ b/examples/pytorch/image-classification/README.md
@@ -0,0 +1,133 @@
+
+
+# Image classification examples
+
+The following examples showcase how to fine-tune a `ViT` for image classification using PyTorch.
+
+## Using datasets from 🤗 `datasets`
+
+Here we show how to fine-tune a `ViT` on the [beans](https://huggingface.co/datasets/beans) dataset.
+
+👀 See the results here: [nateraw/vit-base-beans](https://huggingface.co/nateraw/vit-base-beans).
+
+```bash
+python run_image_classification.py \
+    --dataset_name beans \
+    --output_dir ./beans_outputs/ \
+    --remove_unused_columns False \
+    --do_train \
+    --do_eval \
+    --push_to_hub \
+    --push_to_hub_model_id vit-base-beans \
+    --learning_rate 2e-5 \
+    --num_train_epochs 5 \
+    --per_device_train_batch_size 8 \
+    --per_device_eval_batch_size 8 \
+    --logging_strategy steps \
+    --logging_steps 10 \
+    --evaluation_strategy epoch \
+    --save_strategy epoch \
+    --load_best_model_at_end True \
+    --save_total_limit 3 \
+    --seed 1337
+```
+
+And here is how to fine-tune a `ViT` on the [cats_vs_dogs](https://huggingface.co/datasets/cats_vs_dogs) dataset.
+
+👀 See the results here: [nateraw/vit-base-cats-vs-dogs](https://huggingface.co/nateraw/vit-base-cats-vs-dogs).
+
+```bash
+python run_image_classification.py \
+    --dataset_name cats_vs_dogs \
+    --output_dir ./cats_vs_dogs_outputs/ \
+    --remove_unused_columns False \
+    --do_train \
+    --do_eval \
+    --push_to_hub \
+    --push_to_hub_model_id vit-base-cats-vs-dogs \
+    --fp16 True \
+    --learning_rate 2e-4 \
+    --num_train_epochs 5 \
+    --per_device_train_batch_size 32 \
+    --per_device_eval_batch_size 32 \
+    --logging_strategy steps \
+    --logging_steps 10 \
+    --evaluation_strategy epoch \
+    --save_strategy epoch \
+    --load_best_model_at_end True \
+    --save_total_limit 3 \
+    --seed 1337
+```
+
+## Using your own data
+
+To use your own dataset, the training script expects the following directory structure:
+
+```bash
+root/dog/xxx.png
+root/dog/xxy.png
+root/dog/[...]/xxz.png
+
+root/cat/123.png
+root/cat/nsdf3.png
+root/cat/[...]/asd932_.png
+```
+
+Once you've prepared your dataset, you can run the script like this:
+
+```bash
+python run_image_classification.py \
+    --dataset_name nateraw/image-folder \
+    --train_dir <path-to-train-root> \
+    --output_dir ./outputs/ \
+    --remove_unused_columns False \
+    --do_train \
+    --do_eval
+```
+
+### 💡 The above will split the train dir into training and evaluation sets
+  - To control the split amount, use the `--train_val_split` flag.
+  - To provide your own validation split in its own directory, you can pass the `--validation_dir <path-to-val-root>` flag.
+
+
+## Sharing your model on 🤗 Hub
+
+0. If you haven't already, [sign up](https://huggingface.co/join) for a 🤗 account.
+
+1. Make sure you have `git-lfs` installed and git set up.
+
+```bash
+$ apt install git-lfs
+$ git config --global user.email "you@example.com"
+$ git config --global user.name "Your Name"
+```
+
+2. Log in with your HuggingFace account credentials using `huggingface-cli`.
+
+```bash
+$ huggingface-cli login
+# ...follow the prompts
+```
+
+3. When running the script, pass the following arguments:
+
+```bash
+python run_image_classification.py \
+    --push_to_hub \
+    --push_to_hub_model_id <name-of-your-model> \
+    ...
+```
\ No newline at end of file
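Once one of the runs above has pushed a checkpoint to the Hub, it can be loaded like any other model. Below is a minimal inference sketch against the beans checkpoint linked above; the local image path is a hypothetical placeholder.

```python
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Checkpoint produced by the beans run above; any --push_to_hub_model_id repo works the same way.
repo = "nateraw/vit-base-beans"
feature_extractor = AutoFeatureExtractor.from_pretrained(repo)
model = AutoModelForImageClassification.from_pretrained(repo)

image = Image.open("leaf.jpg").convert("RGB")  # hypothetical local image
inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
pred = logits.argmax(-1).item()
print(model.config.id2label[pred])  # id2label keys are ints after the config round-trip
```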
diff --git a/examples/pytorch/image-classification/requirements.txt b/examples/pytorch/image-classification/requirements.txt
new file mode 100644
index 00000000000000..08c6d3bc1d217d
--- /dev/null
+++ b/examples/pytorch/image-classification/requirements.txt
@@ -0,0 +1,2 @@
+torch>=1.9.0
+torchvision>=0.10.0
\ No newline at end of file
diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py
new file mode 100644
index 00000000000000..529a19da3eabce
--- /dev/null
+++ b/examples/pytorch/image-classification/run_image_classification.py
@@ -0,0 +1,360 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+from dataclasses import dataclass, field
+from typing import Optional
+
+import datasets
+import numpy as np
+import torch
+from datasets import load_dataset
+from PIL import Image
+from torchvision.transforms import (
+    CenterCrop,
+    Compose,
+    Normalize,
+    RandomHorizontalFlip,
+    RandomResizedCrop,
+    Resize,
+    ToTensor,
+)
+
+import transformers
+from transformers import (
+    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+    AutoConfig,
+    AutoFeatureExtractor,
+    AutoModelForImageClassification,
+    HfArgumentParser,
+    Trainer,
+    TrainingArguments,
+)
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version
+from transformers.utils.versions import require_version
+
+
+"""Fine-tuning a 🤗 Transformers model for image classification."""
+
+logger = logging.getLogger(__name__)
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
+check_min_version("4.10.0.dev0")
+
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
+
+MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+def pil_loader(path: str):
+    with open(path, "rb") as f:
+        im = Image.open(f)
+        return im.convert("RGB")
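Why force `convert("RGB")` in `pil_loader`? The torchvision pipeline defined later in the script normalizes with three-channel statistics, so a grayscale or RGBA file would otherwise fail. A small sketch of the invariant, assuming a hypothetical local PNG:

```python
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

transform = Compose(
    [
        Resize(224),
        CenterCrop(224),
        ToTensor(),
        # Three-channel stats: a 1-channel (grayscale) or 4-channel (RGBA) tensor would raise here.
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

im = Image.open("example.png").convert("RGB")  # hypothetical file; convert guarantees 3 channels
print(transform(im).shape)  # torch.Size([3, 224, 224])
```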
+ """ + + dataset_name: Optional[str] = field( + default="nateraw/image-folder", metadata={"help": "Name of a dataset from the datasets package"} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the training data."}) + validation_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the validation data."}) + train_val_split: Optional[float] = field( + default=0.15, metadata={"help": "Percent to split off of train for validation."} + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."}) + + def __post_init__(self): + data_files = dict() + if self.train_dir is not None: + data_files["train"] = self.train_dir + if self.validation_dir is not None: + data_files["val"] = self.validation_dir + self.data_files = data_files if data_files else None + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + default="google/vit-base-patch16-224-in21k", + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}, + ) + model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."}) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + labels = torch.tensor([example["labels"] for example in examples]) + return {"pixel_values": pixel_values, "labels": labels} + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. 
+
+
+def main():
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()
+
+    # Log a small summary on each process:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overwrite it."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+
+    # Initialize our dataset and prepare it for the 'image-classification' task.
+    ds = load_dataset(
+        data_args.dataset_name,
+        data_args.dataset_config_name,
+        data_files=data_args.data_files,
+        cache_dir=model_args.cache_dir,
+        task="image-classification",
+    )
+
+    # Define torchvision transforms to be applied to each image.
+    normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+    _train_transforms = Compose(
+        [
+            RandomResizedCrop(data_args.image_size),
+            RandomHorizontalFlip(),
+            ToTensor(),
+            normalize,
+        ]
+    )
+    _val_transforms = Compose(
+        [
+            Resize(data_args.image_size),
+            CenterCrop(data_args.image_size),
+            ToTensor(),
+            normalize,
+        ]
+    )
+
+    def train_transforms(example_batch):
+        """Apply _train_transforms across a batch."""
+        example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
+        return example_batch
+
+    def val_transforms(example_batch):
+        """Apply _val_transforms across a batch."""
+        example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
+        return example_batch
+
+    # If we don't have a validation split, split off a percentage of train as validation.
+    data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
+    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
+        split = ds["train"].train_test_split(data_args.train_val_split)
+        ds["train"] = split["train"]
+        ds["validation"] = split["test"]
+
+    # Prepare label mappings.
+    # We'll include these in the model's config to get human-readable labels in the Inference API.
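+    # (With the beans dataset, for example, labels == ["angular_leaf_spot", "bean_rust", "healthy"],
+    # so label2id == {"angular_leaf_spot": "0", ...} and id2label == {"0": "angular_leaf_spot", ...}.)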
+ labels = ds["train"].features["labels"].names + label2id, id2label = dict(), dict() + for i, label in enumerate(labels): + label2id[label] = str(i) + id2label[str(i)] = label + + # Load the accuracy metric from the datasets package + metric = datasets.load_metric("accuracy") + + # Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a + # predictions and label_ids field) and has to return a dictionary string to float. + def compute_metrics(p): + """Computes accuracy on a batch of predictions""" + return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids) + + config = AutoConfig.from_pretrained( + model_args.config_name or model_args.model_name_or_path, + num_labels=len(labels), + label2id=label2id, + id2label=id2label, + finetuning_task="image-classification", + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForImageClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + # NOTE - We aren't directly using this feature extractor since we defined custom transforms above. + # We initialize this instance below and pass it to Trainer to ensure that the feature extraction + # config, preprocessor_config.json, is included in output directories. + # This way if we push a model to the hub, the inference widget will work. + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name or model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + size=data_args.image_size, + image_mean=normalize.mean, + image_std=normalize.std, + ) + + if training_args.do_train: + if "train" not in ds: + raise ValueError("--do_train requires a train dataset") + if data_args.max_train_samples is not None: + ds["train"] = ds["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples)) + # Set the training transforms + ds["train"].set_transform(train_transforms) + + if training_args.do_eval: + if "validation" not in ds: + raise ValueError("--do_eval requires a validation dataset") + if data_args.max_eval_samples is not None: + ds["validation"] = ( + ds["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples)) + ) + # Set the validation transforms + ds["validation"].set_transform(val_transforms) + + # Initalize our trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=ds["train"] if training_args.do_train else None, + eval_dataset=ds["validation"] if training_args.do_eval else None, + compute_metrics=compute_metrics, + tokenizer=feature_extractor, + data_collator=collate_fn, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate() + 
trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Write model card and (optionally) push to hub + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "tasks": "image-classification", + "dataset": data_args.dataset_name, + "tags": ["image-classification"], + } + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index 74f1cb28c1ef9c..200e398d0cf11e 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -38,6 +38,7 @@ "question-answering", "summarization", "translation", + "image-classification", ] ] sys.path.extend(SRC_DIRS) @@ -47,6 +48,7 @@ import run_clm import run_generation import run_glue + import run_image_classification import run_mlm import run_ner import run_qa as run_squad @@ -340,3 +342,35 @@ def test_run_translation(self): run_translation.main() result = get_results(tmp_dir) self.assertGreaterEqual(result["eval_bleu"], 30) + + def test_run_image_classification(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_image_classification.py + --output_dir {tmp_dir} + --model_name_or_path google/vit-base-patch16-224-in21k + --train_dir tests/fixtures/tests_samples/cats_and_dogs/ + --do_train + --do_eval + --learning_rate 2e-5 + --per_device_train_batch_size 2 + --per_device_eval_batch_size 1 + --remove_unused_columns False + --overwrite_output_dir True + --dataloader_num_workers 16 + --metric_for_best_model accuracy + --max_steps 30 + --train_val_split 0.1 + --seed 7 + """.split() + + if is_cuda_and_apex_available(): + testargs.append("--fp16") + + with patch.object(sys, "argv", testargs): + run_image_classification.main() + result = get_results(tmp_dir) + self.assertGreaterEqual(result["eval_accuracy"], 0.8) diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/1.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/1.jpg new file mode 100644 index 00000000000000..cfa244e7a2aab2 Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/1.jpg differ diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/2.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/2.jpg new file mode 100644 index 00000000000000..de35c029b9cd03 Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/2.jpg differ diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/3.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/3.jpg new file mode 100644 index 00000000000000..b24c743f0f2d87 Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/3.jpg differ diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/4.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/4.jpg new file mode 100644 index 00000000000000..ed2ef8a6504d35 Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/4.jpg differ diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/5.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/5.jpg new file mode 100644 index 00000000000000..ce293a8a98c36b Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/5.jpg differ diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Dog/1.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Dog/1.jpg new file mode 100644 index 00000000000000..2f9eadca529034 Binary files 
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/1.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/1.jpg
new file mode 100644
index 00000000000000..cfa244e7a2aab2
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/1.jpg differ
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/2.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/2.jpg
new file mode 100644
index 00000000000000..de35c029b9cd03
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/2.jpg differ
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/3.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/3.jpg
new file mode 100644
index 00000000000000..b24c743f0f2d87
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/3.jpg differ
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/4.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/4.jpg
new file mode 100644
index 00000000000000..ed2ef8a6504d35
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/4.jpg differ
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Cat/5.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Cat/5.jpg
new file mode 100644
index 00000000000000..ce293a8a98c36b
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Cat/5.jpg differ
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Dog/1.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Dog/1.jpg
new file mode 100644
index 00000000000000..2f9eadca529034
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Dog/1.jpg differ
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Dog/2.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Dog/2.jpg
new file mode 100644
index 00000000000000..ee511173b82b43
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Dog/2.jpg differ
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Dog/3.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Dog/3.jpg
new file mode 100644
index 00000000000000..addeee0cc23954
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Dog/3.jpg differ
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Dog/4.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Dog/4.jpg
new file mode 100644
index 00000000000000..fd15a2090b33a1
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Dog/4.jpg differ
diff --git a/tests/fixtures/tests_samples/cats_and_dogs/Dog/5.jpg b/tests/fixtures/tests_samples/cats_and_dogs/Dog/5.jpg
new file mode 100644
index 00000000000000..46105c5d053457
Binary files /dev/null and b/tests/fixtures/tests_samples/cats_and_dogs/Dog/5.jpg differ
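For a quick local smoke test of the new example against the fixture images above, the test can be approximated outside the unittest harness. A minimal sketch, assuming it is run from the repository root with the example's requirements installed (the output directory is a hypothetical scratch path):

```python
import sys
from unittest.mock import patch

# Make the example importable, mirroring how test_examples.py extends sys.path.
sys.path.append("examples/pytorch/image-classification")
import run_image_classification  # noqa: E402

argv = [
    "run_image_classification.py",
    "--output_dir", "/tmp/cats_and_dogs_outputs",  # hypothetical scratch dir
    "--model_name_or_path", "google/vit-base-patch16-224-in21k",
    "--train_dir", "tests/fixtures/tests_samples/cats_and_dogs/",
    "--do_train", "--do_eval",
    "--remove_unused_columns", "False",
    "--overwrite_output_dir", "True",
    "--max_steps", "30",
    "--train_val_split", "0.1",
    "--seed", "7",
]
with patch.object(sys, "argv", argv):
    run_image_classification.main()  # writes all_results.json under the output dir
```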