Skip to content

Commit

Permalink
Add VGG16 backbone (keras-team#1737)
Browse files Browse the repository at this point in the history
* Agg Vgg16 backbone

* update names

* update tests

* update test

* add image classifier

* incorporate review comments

* Update test case

* update backbone test

* add image classifier

* classifier cleanup

* code reformat

* add vgg16 image classifier

* make vgg generic

* update doc string

* update docstring

* add classifier test

* update tests

* update docstring

* address review comments

* code reformat

* update the configs

* address review comments

* fix task saved model test

* update init

* code reformatted
  • Loading branch information
divyashreepathihalli authored and mattdangerw committed Sep 10, 2024
1 parent 23815d6 commit 7a4bc47
Show file tree
Hide file tree
Showing 8 changed files with 514 additions and 14 deletions.
3 changes: 3 additions & 0 deletions keras_nlp/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
GPTNeoXPreprocessor,
)
from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
from keras_nlp.src.models.image_classifier import ImageClassifier
from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM
from keras_nlp.src.models.llama3.llama3_causal_lm_preprocessor import (
Expand Down Expand Up @@ -228,6 +229,8 @@
from keras_nlp.src.models.text_classifier_preprocessor import (
TextClassifierPreprocessor,
)
from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone
from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier
from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone
from keras_nlp.src.models.whisper.whisper_tokenizer import WhisperTokenizer
from keras_nlp.src.models.xlm_roberta.xlm_roberta_backbone import (
Expand Down
90 changes: 90 additions & 0 deletions keras_nlp/src/models/image_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.task import Task


@keras_nlp_export("keras_nlp.models.ImageClassifier")
class ImageClassifier(Task):
"""Base class for all image classification tasks.
`ImageClassifier` tasks wrap a `keras_nlp.models.Backbone` and
a `keras_nlp.models.Preprocessor` to create a model that can be used for
image classification. `ImageClassifier` tasks take an additional
`num_classes` argument, controlling the number of predicted output classes.
To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
All `ImageClassifier` tasks include a `from_preset()` constructor which can be
used to load a pre-trained config and weights.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Default compilation.
self.compile()

def compile(
self,
optimizer="auto",
loss="auto",
*,
metrics="auto",
**kwargs,
):
"""Configures the `ImageClassifier` task for training.
The `ImageClassifier` task extends the default compilation signature of
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
`metrics`. To override these defaults, pass any value
to these arguments during compilation.
Args:
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
instance. Defaults to `"auto"`, which uses the default optimizer
for the given model and task. See `keras.Model.compile` and
`keras.optimizers` for more info on possible `optimizer` values.
loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
Defaults to `"auto"`, where a
`keras.losses.SparseCategoricalCrossentropy` loss will be
applied for the classification task. See
`keras.Model.compile` and `keras.losses` for more info on
possible `loss` values.
metrics: `"auto"`, or a list of metrics to be evaluated by
the model during training and testing. Defaults to `"auto"`,
where a `keras.metrics.SparseCategoricalAccuracy` will be
applied to track the accuracy of the model during training.
See `keras.Model.compile` and `keras.metrics` for
more info on possible `metrics` values.
**kwargs: See `keras.Model.compile` for a full list of arguments
supported by the compile method.
"""
if optimizer == "auto":
optimizer = keras.optimizers.Adam(5e-5)
if loss == "auto":
activation = getattr(self, "activation", None)
activation = keras.activations.get(activation)
from_logits = activation != keras.activations.softmax
loss = keras.losses.SparseCategoricalCrossentropy(from_logits)
if metrics == "auto":
metrics = [keras.metrics.SparseCategoricalAccuracy()]
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
**kwargs,
)
13 changes: 13 additions & 0 deletions keras_nlp/src/models/vgg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
159 changes: 159 additions & 0 deletions keras_nlp/src/models/vgg/vgg_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
from keras import layers

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.backbone import Backbone


@keras_nlp_export("keras_nlp.models.VGGBackbone")
class VGGBackbone(Backbone):
"""
This class represents Keras Backbone of VGG model.
This class implements a VGG backbone as described in [Very Deep
Convolutional Networks for Large-Scale Image Recognition](
https://arxiv.org/abs/1409.1556)(ICLR 2015).
Args:
stackwise_num_repeats: list of ints, number of repeated convolutional
blocks per VGG block. For VGG16 this is [2, 2, 3, 3, 3] and for
VGG19 this is [2, 2, 4, 4, 4].
stackwise_num_filters: list of ints, filter size for convolutional
blocks per VGG block. For both VGG16 and VGG19 this is [
64, 128, 256, 512, 512].
include_rescaling: bool, whether to rescale the inputs. If set to
True, inputs will be passed through a `Rescaling(1/255.0)` layer.
input_shape: tuple, optional shape tuple, defaults to (224, 224, 3).
pooling: bool, Optional pooling mode for feature extraction
when `include_top` is `False`.
- `None` means that the output of the model will be
the 4D tensor output of the
last convolutional block.
- `avg` means that global average pooling
will be applied to the output of the
last convolutional block, and thus
the output of the model will be a 2D tensor.
- `max` means that global max pooling will
be applied.
Examples:
```python
input_data = np.ones((2, 224, 224, 3), dtype="float32")
# Pretrained VGG backbone.
model = keras_nlp.models.VGGBackbone.from_preset("vgg16")
model(input_data)
# Randomly initialized VGG backbone with a custom config.
model = keras_nlp.models.VGGBackbone(
stackwise_num_repeats = [2, 2, 3, 3, 3],
stackwise_num_filters = [64, 128, 256, 512, 512],
input_shape = (224, 224, 3),
include_rescaling = False,
pooling = "avg",
)
model(input_data)
```
"""

def __init__(
self,
stackwise_num_repeats,
stackwise_num_filters,
include_rescaling,
input_image_shape=(224, 224, 3),
pooling="avg",
**kwargs,
):

# === Functional Model ===
img_input = keras.layers.Input(shape=input_image_shape)
x = img_input

if include_rescaling:
x = layers.Rescaling(scale=1 / 255.0)(x)
for stack_index in range(len(stackwise_num_repeats) - 1):
x = apply_vgg_block(
x=x,
num_layers=stackwise_num_repeats[stack_index],
filters=stackwise_num_filters[stack_index],
kernel_size=(3, 3),
activation="relu",
padding="same",
max_pool=True,
name=f"block{stack_index + 1}",
)
if pooling == "avg":
x = layers.GlobalAveragePooling2D()(x)
elif pooling == "max":
x = layers.GlobalMaxPooling2D()(x)

super().__init__(inputs=img_input, outputs=x, **kwargs)

# === Config ===
self.stackwise_num_repeats = stackwise_num_repeats
self.stackwise_num_filters = stackwise_num_filters
self.include_rescaling = include_rescaling
self.input_image_shape = input_image_shape
self.pooling = pooling

def get_config(self):
return {
"stackwise_num_repeats": self.stackwise_num_repeats,
"stackwise_num_filters": self.stackwise_num_filters,
"include_rescaling": self.include_rescaling,
"input_image_shape": self.input_image_shape,
"pooling": self.pooling,
}


def apply_vgg_block(
x,
num_layers,
filters,
kernel_size,
activation,
padding,
max_pool,
name,
):
"""
Applies VGG block
Args:
x: Tensor, input tensor to pass through network
num_layers: int, number of CNN layers in the block
filters: int, filter size of each CNN layer in block
kernel_size: int (or) tuple, kernel size for CNN layer in block
activation: str (or) callable, activation function for each CNN layer in
block
padding: str (or) callable, padding function for each CNN layer in block
max_pool: bool, whether to add MaxPooling2D layer at end of block
name: str, name of the block
Returns:
keras.KerasTensor
"""
for num in range(1, num_layers + 1):
x = layers.Conv2D(
filters,
kernel_size,
activation=activation,
padding=padding,
name=f"{name}_conv{num}",
)(x)
if max_pool:
x = layers.MaxPooling2D((2, 2), (2, 2), name=f"{name}_pool")(x)
return x
48 changes: 48 additions & 0 deletions keras_nlp/src/models/vgg/vgg_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone
from keras_nlp.src.tests.test_case import TestCase


class VGGBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"stackwise_num_repeats": [2, 3, 3],
"stackwise_num_filters": [8, 64, 64],
"input_image_shape": (16, 16, 3),
"include_rescaling": False,
"pooling": "avg",
}
self.input_data = np.ones((2, 16, 16, 3), dtype="float32")

def test_backbone_basics(self):
self.run_backbone_test(
cls=VGGBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 64),
run_mixed_precision_check=False,
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=VGGBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
Loading

0 comments on commit 7a4bc47

Please sign in to comment.