From 51a154661c6420c412ec9e276d889f069128ea45 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 19 Sep 2024 12:31:47 -0700 Subject: [PATCH 1/3] Add anchor_generator, box_matcher and non_max_supression --- keras_hub/src/models/retinanet/__init__.py | 13 + .../src/models/retinanet/anchor_generator.py | 176 ++++++ .../models/retinanet/anchor_generator_test.py | 111 ++++ keras_hub/src/models/retinanet/box_matcher.py | 266 ++++++++ .../src/models/retinanet/box_matcher_test.py | 128 ++++ .../models/retinanet/non_max_supression.py | 586 ++++++++++++++++++ .../retinanet/non_max_supression_test.py | 72 +++ 7 files changed, 1352 insertions(+) create mode 100644 keras_hub/src/models/retinanet/__init__.py create mode 100644 keras_hub/src/models/retinanet/anchor_generator.py create mode 100644 keras_hub/src/models/retinanet/anchor_generator_test.py create mode 100644 keras_hub/src/models/retinanet/box_matcher.py create mode 100644 keras_hub/src/models/retinanet/box_matcher_test.py create mode 100644 keras_hub/src/models/retinanet/non_max_supression.py create mode 100644 keras_hub/src/models/retinanet/non_max_supression_test.py diff --git a/keras_hub/src/models/retinanet/__init__.py b/keras_hub/src/models/retinanet/__init__.py new file mode 100644 index 000000000..fd48fde00 --- /dev/null +++ b/keras_hub/src/models/retinanet/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# File: keras_hub/src/models/retinanet/anchor_generator.py
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import math

import keras
from keras import ops

from keras_hub.src.bounding_box.converters import convert_format


class AnchorGenerator(keras.layers.Layer):
    """Generates anchor boxes for object detection tasks.

    This layer creates a set of anchor boxes (also known as default boxes or
    priors) for use in object detection models, particularly those utilizing
    Feature Pyramid Networks (FPN). It generates anchors across multiple
    pyramid levels, with various scales and aspect ratios.

    Feature Pyramid Levels:
    - Levels typically range from 2 to 7 (P2 to P7), corresponding to
        different resolutions of the input image.
    - Each level l has a stride of 2^l pixels relative to the input image.
    - Lower levels (e.g., P2) have higher resolution and are used for
        detecting smaller objects.
    - Higher levels (e.g., P7) have lower resolution and are used
        for larger objects.

    Args:
        bounding_box_format: str. The format of the bounding boxes
            to be generated, e.g. `"xyxy"`, `"xywh"`, `"yxyx"`.
        min_level: int. Minimum level of the output feature pyramid.
        max_level: int. Maximum level of the output feature pyramid.
        num_scales: int. Number of intermediate scales added on each level.
            For example, `num_scales=2` adds one additional intermediate
            anchor scale `[2^0, 2^0.5]` on each level.
        aspect_ratios: list of floats. Aspect ratios of anchors added on
            each level. Each number indicates the ratio of width to height.
        anchor_size: float. Scale of size of the base anchor relative to the
            feature stride `2^level`.

    Call arguments:
        images: An image tensor with shape `[B, H, W, C]` or `[H, W, C]`.
            Only its static height and width are read; they determine the
            anchor grid of every level.

    Returns:
        OrderedDict: A dictionary mapping feature levels
        (e.g., 'P3', 'P4', etc.) to anchor boxes. Each entry contains a tensor
        of shape
        `(ceil(H / stride) * ceil(W / stride) * anchors_per_location, 4)`,
        where H and W are the height and width of the image, stride is
        `2^level`, and `anchors_per_location` is
        `num_scales * len(aspect_ratios)`.

    Example:
    ```python
    anchor_generator = AnchorGenerator(
        bounding_box_format='xyxy',
        min_level=3,
        max_level=7,
        num_scales=3,
        aspect_ratios=[0.5, 1.0, 2.0],
        anchor_size=4.0,
    )
    images = keras.ops.ones((1, 640, 480, 3))
    anchors = anchor_generator(images=images)
    ```
    """

    def __init__(
        self,
        bounding_box_format,
        min_level,
        max_level,
        num_scales,
        aspect_ratios,
        anchor_size,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bounding_box_format = bounding_box_format
        self.min_level = min_level
        self.max_level = max_level
        self.num_scales = num_scales
        self.aspect_ratios = aspect_ratios
        self.anchor_size = anchor_size
        # The layer owns no weights, so there is nothing left to build.
        self.built = True

    def call(self, images):
        images_shape = images.shape
        # Drop the batch axis (rank 4 only) and the trailing channel axis.
        if len(images_shape) == 4:
            image_shape = images_shape[1:-1]
        else:
            image_shape = images_shape[:-1]
        image_shape = tuple(image_shape)
        if any(dim is None for dim in image_shape):
            # The grid size is computed with Python arithmetic below, which
            # needs concrete integers.
            raise ValueError(
                "`AnchorGenerator` requires `images` with a statically known "
                f"height and width, got spatial shape {image_shape}."
            )

        multilevel_boxes = collections.OrderedDict()
        for level in range(self.min_level, self.max_level + 1):
            level_stride = 2**level
            # Feature map size for this level (rounded up so images that are
            # not divisible by the stride are still fully covered).
            feat_size_y = math.ceil(image_shape[0] / level_stride)
            feat_size_x = math.ceil(image_shape[1] / level_stride)

            # Effective step between anchor centers for this level.
            stride_y = ops.cast(image_shape[0] / feat_size_y, "float32")
            stride_x = ops.cast(image_shape[1] / feat_size_x, "float32")

            # Anchor centers start at stride / 2 so they sit in the middle of
            # each feature-map cell.
            cx = ops.arange(stride_x / 2, image_shape[1], stride_x)
            cy = ops.arange(stride_y / 2, image_shape[0], stride_y)
            cx_grid, cy_grid = ops.meshgrid(cx, cy)

            boxes_l = []
            for scale in range(self.num_scales):
                # Depends on the scale only, so it is computed once per
                # scale rather than once per aspect ratio.
                intermediate_scale = 2 ** (scale / self.num_scales)
                base_anchor_size = (
                    self.anchor_size * level_stride * intermediate_scale
                )
                for aspect_ratio in self.aspect_ratios:
                    # aspect_ratio = width / height at constant area.
                    aspect_x = aspect_ratio**0.5
                    aspect_y = aspect_ratio**-0.5
                    half_anchor_size_x = base_anchor_size * aspect_x / 2.0
                    half_anchor_size_y = base_anchor_size * aspect_y / 2.0

                    # Anchor boxes in (y1, x1, y2, x2) format.
                    boxes = ops.stack(
                        [
                            cy_grid - half_anchor_size_y,
                            cx_grid - half_anchor_size_x,
                            cy_grid + half_anchor_size_y,
                            cx_grid + half_anchor_size_x,
                        ],
                        axis=-1,
                    )
                    boxes_l.append(boxes)
            # Concat anchors of the same level into an HxWx(Ax4) tensor, then
            # flatten to one box per row.
            boxes_l = ops.concatenate(boxes_l, axis=-1)
            boxes_l = ops.reshape(boxes_l, (-1, 4))
            # Convert to the user defined format.
            multilevel_boxes[f"P{level}"] = convert_format(
                boxes_l,
                source="yxyx",
                target=self.bounding_box_format,
            )
        return multilevel_boxes

    def compute_output_shape(self, input_shape):
        """Returns the per-level `(num_anchors, 4)` output shapes.

        `num_anchors` is `None` when the input height or width is unknown.
        """
        if len(input_shape) == 4:
            spatial_shape = tuple(input_shape[1:-1])
        else:
            spatial_shape = tuple(input_shape[:-1])
        is_static = len(spatial_shape) == 2 and None not in spatial_shape

        multilevel_boxes_shape = {}
        for level in range(self.min_level, self.max_level + 1):
            num_anchors = None
            if is_static:
                num_anchors = (
                    math.ceil(spatial_shape[0] / 2**level)
                    * math.ceil(spatial_shape[1] / 2**level)
                    * self.anchors_per_location
                )
            multilevel_boxes_shape[f"P{level}"] = (num_anchors, 4)
        return multilevel_boxes_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "bounding_box_format": self.bounding_box_format,
                "min_level": self.min_level,
                "max_level": self.max_level,
                "num_scales": self.num_scales,
                "aspect_ratios": self.aspect_ratios,
                "anchor_size": self.anchor_size,
            }
        )
        return config

    @property
    def anchors_per_location(self):
        """
        The `anchors_per_location` property returns the number of anchors
        generated per pixel location, which is equal to
        `num_scales * len(aspect_ratios)`.
        """
        return self.num_scales * len(self.aspect_ratios)


# File: keras_hub/src/models/retinanet/anchor_generator_test.py
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
from keras import ops

from keras_hub.src.bounding_box.converters import convert_format
from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
from keras_hub.src.tests.test_case import TestCase


class AnchorGeneratorTest(TestCase):
    @parameterized.parameters(
        # Single scale anchor
        ("yxyx", 5, 5, 1, [1.0], 2.0, [64, 64])
        + (
            {
                "P5": [
                    [-16.0, -16.0, 48.0, 48.0],
                    [-16.0, 16.0, 48.0, 80.0],
                    [16.0, -16.0, 80.0, 48.0],
                    [16.0, 16.0, 80.0, 80.0],
                ]
            },
        ),
        # Multi scale anchor
        ("xywh", 5, 6, 1, [1.0], 2.0, [64, 64])
        + (
            {
                "P5": [
                    [-16.0, -16.0, 48.0, 48.0],
                    [-16.0, 16.0, 48.0, 80.0],
                    [16.0, -16.0, 80.0, 48.0],
                    [16.0, 16.0, 80.0, 80.0],
                ],
                "P6": [[-32, -32, 96, 96]],
            },
        ),
        # Multi aspect ratio anchor
        ("xyxy", 6, 6, 1, [1.0, 4.0, 0.25], 2.0, [64, 64])
        + (
            {
                "P6": [
                    [-32.0, -32.0, 96.0, 96.0],
                    [0.0, -96.0, 64.0, 160.0],
                    [-96.0, 0.0, 160.0, 64.0],
                ]
            },
        ),
        # Intermediate scales
        ("yxyx", 5, 5, 2, [1.0], 1.0, [32, 32])
        + (
            {
                "P5": [
                    [0.0, 0.0, 32.0, 32.0],
                    [
                        16 - 16 * 2**0.5,
                        16 - 16 * 2**0.5,
                        16 + 16 * 2**0.5,
                        16 + 16 * 2**0.5,
                    ],
                ]
            },
        ),
        # Non-square
        ("xywh", 5, 5, 1, [1.0], 1.0, [64, 32])
        + ({"P5": [[0, 0, 32, 32], [32, 0, 64, 32]]},),
        # Indivisible by 2^level
        ("xyxy", 5, 5, 1, [1.0], 1.0, [40, 32])
        + ({"P5": [[-6, 0, 26, 32], [14, 0, 46, 32]]},),
    )
    def test_anchor_generator(
        self,
        bounding_box_format,
        min_level,
        max_level,
        num_scales,
        aspect_ratios,
        anchor_size,
        image_shape,
        expected_boxes,
    ):
        anchor_generator = AnchorGenerator(
            bounding_box_format,
            min_level,
            max_level,
            num_scales,
            aspect_ratios,
            anchor_size,
        )
        images = ops.ones(shape=(1, image_shape[0], image_shape[1], 3))
        multilevel_boxes = anchor_generator(images=images)
        # The expected boxes above are written in "yxyx"; convert them to the
        # format the layer was asked to produce before comparing.
        for key in expected_boxes:
            expected_boxes[key] = ops.convert_to_tensor(expected_boxes[key])
            expected_boxes[key] = convert_format(
                expected_boxes[key],
                source="yxyx",
                target=bounding_box_format,
            )
        self.assertAllClose(expected_boxes, multilevel_boxes)


# File: keras_hub/src/models/retinanet/box_matcher.py
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import ops


class BoxMatcher(keras.layers.Layer):
    """Box matching logic based on argmax of highest value (e.g., IOU).

    This class computes matches from a similarity matrix. Each row will be
    matched to at least one column, the matched result can either be positive
    / negative, or simply ignored depending on the setting.

    The settings include `thresholds` and `match_values`, for example if:
    1) `thresholds=[negative_threshold, positive_threshold]`, and
       `match_values=[negative_value=0, ignore_value=-1, positive_value=1]`:
       the rows will be assigned to positive_value if its argmax result >=
       positive_threshold; the rows will be assigned to negative_value if its
       argmax result < negative_threshold, and the rows will be assigned to
       ignore_value if its argmax result is between [negative_threshold,
       positive_threshold).
    2) `thresholds=[negative_threshold, positive_threshold]`, and
       `match_values=[ignore_value=-1, negative_value=0, positive_value=1]`:
       the rows will be assigned to positive_value if its argmax result >=
       positive_threshold; the rows will be assigned to ignore_value if its
       argmax result < negative_threshold, and the rows will be assigned to
       negative_value if its argmax result is between [negative_threshold,
       positive_threshold). This is different from case 1) by swapping first
       two values.
    3) `thresholds=[positive_threshold]`, and
       `match_values=[negative_value, positive_value]`: the rows will be
       assigned to positive_value if its argmax result >= positive_threshold;
       the rows will be assigned to negative_value if its argmax result <
       positive_threshold.

    Args:
        thresholds: A sorted sequence of floats to classify the matches into
            different results (e.g. positive or negative or ignored match).
            Internally a copy is stored with -Inf prepended and +Inf
            appended; the argument itself is not modified.
        match_values: A sequence of integers representing matched results
            (e.g. positive or negative or ignored match).
            len(`match_values`) must equal len(`thresholds`) + 1.
        force_match_for_each_col: each row will be argmax matched to at
            least one column. This means some columns will be matched to
            multiple rows while some columns will not be matched to any rows.
            Filtering by `thresholds` will make less columns match to positive
            result. Setting this to True guarantees that each column will be
            matched to positive result to at least one row.

    Raises:
        ValueError: if `thresholds` not sorted or
            len(`match_values`) != len(`thresholds`) + 1

    Example:

    ```python
    box_matcher = BoxMatcher([0.3, 0.7], [-1, 0, 1])
    # `iou_metric` is a `[num_anchors, num_boxes]` similarity matrix.
    matched_columns, matched_match_values = box_matcher(iou_metric)
    cls_mask = ops.less_equal(matched_match_values, 0)
    ```

    """

    def __init__(
        self,
        thresholds,
        match_values,
        force_match_for_each_col=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Private copies: the sentinels added below must not leak into the
        # caller's list, and a sorted tuple must compare equal to `sorted()`
        # (which always returns a list).
        thresholds = list(thresholds)
        match_values = list(match_values)
        if sorted(thresholds) != thresholds:
            raise ValueError(f"`threshold` must be sorted, got {thresholds}")
        if len(match_values) != len(thresholds) + 1:
            raise ValueError(
                f"len(`match_values`) must be len(`thresholds`) + 1, got "
                f"match_values {match_values}, thresholds {thresholds}"
            )
        self.match_values = match_values
        # The +/-Inf sentinels turn the thresholds into len(match_values)
        # half-open buckets [low, high).
        self.thresholds = [-float("inf"), *thresholds, float("inf")]
        self.force_match_for_each_col = force_match_for_each_col
        self.built = True

    def call(self, similarity_matrix):
        """Matches each row to a column based on argmax

        Args:
            similarity_matrix: A float Tensor of shape `[num_rows, num_cols]`
                or `[batch_size, num_rows, num_cols]` representing any
                similarity metric.

        Returns:
            matched_columns: An integer tensor of shape `[num_rows]` or
                `[batch_size, num_rows]` storing the index of the matched
                column for each row.
            matched_values: An integer tensor of shape [num_rows] or
                `[batch_size, num_rows]` storing the match result
                `(positive match, negative match, ignored match)`.
        """
        squeeze_result = False
        if len(similarity_matrix.shape) == 2:
            squeeze_result = True
            similarity_matrix = ops.expand_dims(similarity_matrix, axis=0)
        static_shape = list(similarity_matrix.shape)
        # Prefer static dimensions; fall back to the dynamic shape when a
        # dimension is unknown (None) or zero.
        num_rows = static_shape[1] or ops.shape(similarity_matrix)[1]
        batch_size = static_shape[0] or ops.shape(similarity_matrix)[0]
        num_boxes = static_shape[-1] or ops.shape(similarity_matrix)[-1]

        def _match_when_cols_are_empty():
            """Performs matching when the similarity matrix has no columns.

            With no columns there is nothing to match against, so every row
            gets column index 0 and match value -1.

            Returns:
                matched_columns: An integer tensor of shape [batch_size,
                    num_rows] storing the index of the matched column for each
                    row.
                matched_values: An integer tensor of shape [batch_size,
                    num_rows] storing the match type indicator (e.g. positive
                    or negative or ignored match).
            """
            with keras.name_scope("empty_boxes"):
                matched_columns = ops.zeros(
                    [batch_size, num_rows], dtype="int32"
                )
                matched_values = -ops.ones(
                    [batch_size, num_rows], dtype="int32"
                )
                return matched_columns, matched_values

        def _match_when_cols_are_non_empty():
            """Performs matching when the similarity matrix has columns.

            Returns:
                matched_columns: An integer tensor of shape [batch_size,
                    num_rows] storing the index of the matched column for each
                    row.
                matched_values: An integer tensor of shape [batch_size,
                    num_rows] storing the match type indicator (e.g. positive
                    or negative or ignored match).
            """
            with keras.name_scope("non_empty_boxes"):
                # Jax traces this function even when running eagerly and the
                # columns are non-empty. Therefore, we need to handle the case
                # where the similarity matrix is empty. We do this by padding
                # some -1s to the end. -1s are guaranteed to not affect argmax
                # matching because all values in a similarity matrix are [0,1]
                # and the indexing won't change because these are added at the
                # end.
                padded_similarity_matrix = ops.concatenate(
                    [similarity_matrix, -ops.ones((batch_size, num_rows, 1))],
                    axis=-1,
                )

                matched_columns = ops.argmax(
                    padded_similarity_matrix,
                    axis=-1,
                )

                # Bucket each row's best similarity into its match value.
                matched_vals = ops.max(padded_similarity_matrix, axis=-1)
                matched_values = ops.zeros([batch_size, num_rows], "int32")

                match_dtype = matched_vals.dtype
                for ind, low, high in zip(
                    self.match_values, self.thresholds[:-1], self.thresholds[1:]
                ):
                    low_threshold = ops.cast(low, match_dtype)
                    high_threshold = ops.cast(high, match_dtype)
                    mask = ops.logical_and(
                        ops.greater_equal(matched_vals, low_threshold),
                        ops.less(matched_vals, high_threshold),
                    )
                    matched_values = self._set_values_using_indicator(
                        matched_values, mask, ind
                    )

                if self.force_match_for_each_col:
                    # [batch_size, num_cols + 1], for each column
                    # (groundtruth_box), find the best matching row (anchor).
                    matching_rows = ops.argmax(
                        padded_similarity_matrix,
                        axis=1,
                    )
                    # [batch_size, num_cols + 1, num_rows], a transposed 0-1
                    # mapping matrix M, where M[j, i] = 1 means column j is
                    # matched to row i.
                    column_to_row_match_mapping = ops.one_hot(
                        matching_rows, num_rows
                    )
                    # The last column is the artificial -1 padding, not a
                    # ground-truth box. Its argmax is always row 0, so
                    # without this mask row 0 would always be forced to a
                    # positive match (possibly with the out-of-range column
                    # index `num_cols`).
                    is_real_column = ops.cast(
                        ops.arange(num_boxes + 1) < num_boxes,
                        column_to_row_match_mapping.dtype,
                    )
                    column_to_row_match_mapping = (
                        column_to_row_match_mapping
                        * ops.reshape(is_real_column, (1, -1, 1))
                    )
                    # [batch_size, num_rows], for each row (anchor), find the
                    # matched column (groundtruth_box).
                    force_matched_columns = ops.argmax(
                        column_to_row_match_mapping,
                        axis=1,
                    )
                    # [batch_size, num_rows]
                    force_matched_column_mask = ops.cast(
                        ops.max(column_to_row_match_mapping, axis=1),
                        "bool",
                    )
                    # [batch_size, num_rows]
                    matched_columns = ops.where(
                        force_matched_column_mask,
                        force_matched_columns,
                        matched_columns,
                    )
                    matched_values = ops.where(
                        force_matched_column_mask,
                        self.match_values[-1]
                        * ops.ones([batch_size, num_rows], dtype="int32"),
                        matched_values,
                    )

                return ops.cast(matched_columns, "int32"), matched_values

        matched_columns, matched_values = ops.cond(
            pred=ops.greater(num_boxes, 0),
            true_fn=_match_when_cols_are_non_empty,
            false_fn=_match_when_cols_are_empty,
        )

        if squeeze_result:
            matched_columns = ops.squeeze(matched_columns, axis=0)
            matched_values = ops.squeeze(matched_values, axis=0)

        return matched_columns, matched_values

    def _set_values_using_indicator(self, x, indicator, val):
        """Set the indicated fields of x to val.

        Args:
            x: tensor.
            indicator: boolean with same shape as x.
            val: scalar with value to set.
        Returns:
            modified tensor.
        """
        indicator = ops.cast(indicator, x.dtype)
        return ops.add(ops.multiply(x, 1 - indicator), val * indicator)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                # Strip the +/-Inf sentinels added in `__init__`.
                "thresholds": self.thresholds[1:-1],
                "match_values": self.match_values,
                "force_match_for_each_col": self.force_match_for_each_col,
            }
        )
        return config


# File: keras_hub/src/models/retinanet/box_matcher_test.py
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from keras import ops + +from keras_hub.src.models.retinanet.box_matcher import BoxMatcher +from keras_hub.src.tests.test_case import TestCase + + +class BoxMatcherTest(TestCase): + def test_box_matcher_invalid_length(self): + fg_threshold = 0.5 + bg_thresh_hi = 0.2 + bg_thresh_lo = 0.0 + + with self.assertRaisesRegex(ValueError, "must be len"): + _ = BoxMatcher( + thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], + match_values=[-3, -2, -1], + ) + + def test_box_matcher_unsorted_thresholds(self): + fg_threshold = 0.5 + bg_thresh_hi = 0.2 + bg_thresh_lo = 0.0 + + with self.assertRaisesRegex(ValueError, "must be sorted"): + _ = BoxMatcher( + thresholds=[bg_thresh_hi, bg_thresh_lo, fg_threshold], + match_values=[-3, -2, -1, 1], + ) + + def test_box_matcher_unbatched(self): + sim_matrix = np.array([[0.04, 0, 0, 0], [0, 0, 1.0, 0]]) + + fg_threshold = 0.5 + bg_thresh_hi = 0.2 + bg_thresh_lo = 0.0 + + matcher = BoxMatcher( + thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], + match_values=[-3, -2, -1, 1], + ) + match_indices, matched_values = matcher(sim_matrix) + positive_matches = ops.greater_equal(matched_values, 0) + negative_matches = ops.equal(matched_values, -2) + + self.assertAllEqual(positive_matches, [False, True]) + self.assertAllEqual(negative_matches, [True, False]) + self.assertAllEqual(match_indices, [0, 2]) + self.assertAllEqual(matched_values, [-2, 1]) + + def test_box_matcher_batched(self): + sim_matrix = np.array([[[0.04, 0, 0, 0], [0, 0, 1.0, 0]]]) + + fg_threshold = 0.5 + bg_thresh_hi = 0.2 + bg_thresh_lo = 
0.0 + + matcher = BoxMatcher( + thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], + match_values=[-3, -2, -1, 1], + ) + match_indices, matched_values = matcher(sim_matrix) + positive_matches = ops.greater_equal(matched_values, 0) + negative_matches = ops.equal(matched_values, -2) + + self.assertAllEqual(positive_matches, [[False, True]]) + self.assertAllEqual(negative_matches, [[True, False]]) + self.assertAllEqual(match_indices, [[0, 2]]) + self.assertAllEqual(matched_values, [[-2, 1]]) + + def test_box_matcher_force_match(self): + sim_matrix = np.array( + [[0, 0.04, 0, 0.1], [0, 0, 1.0, 0], [0.1, 0, 0, 0], [0, 0, 0, 0.6]], + ) + + fg_threshold = 0.5 + bg_thresh_hi = 0.2 + bg_thresh_lo = 0.0 + + matcher = BoxMatcher( + thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], + match_values=[-3, -2, -1, 1], + force_match_for_each_col=True, + ) + match_indices, matched_values = matcher(sim_matrix) + positive_matches = ops.greater_equal(matched_values, 0) + negative_matches = ops.equal(matched_values, -2) + + self.assertAllEqual(positive_matches, [True, True, True, True]) + self.assertAllEqual(negative_matches, [False, False, False, False]) + # the first anchor cannot be matched to 4th gt box given that is matched + # to the last anchor. 
+ self.assertAllEqual(match_indices, [1, 2, 0, 3]) + self.assertAllEqual(matched_values, [1, 1, 1, 1]) + + def test_box_matcher_empty_gt_boxes(self): + sim_matrix = np.array([[], []]) + + fg_threshold = 0.5 + bg_thresh_hi = 0.2 + bg_thresh_lo = 0.0 + + matcher = BoxMatcher( + thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], + match_values=[-3, -2, -1, 1], + ) + match_indices, matched_values = matcher(sim_matrix) + positive_matches = ops.greater_equal(matched_values, 0) + ignore_matches = ops.equal(matched_values, -1) + + self.assertAllEqual(positive_matches, [False, False]) + self.assertAllEqual(ignore_matches, [True, True]) + self.assertAllEqual(match_indices, [0, 0]) + self.assertAllEqual(matched_values, [-1, -1]) diff --git a/keras_hub/src/models/retinanet/non_max_supression.py b/keras_hub/src/models/retinanet/non_max_supression.py new file mode 100644 index 000000000..39a61ab71 --- /dev/null +++ b/keras_hub/src/models/retinanet/non_max_supression.py @@ -0,0 +1,586 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import keras +from keras import ops + +from keras_hub.src.bounding_box import converters +from keras_hub.src.bounding_box import utils +from keras_hub.src.bounding_box import validate_format + +EPSILON = 1e-8 + + +class NonMaxSuppression(keras.layers.Layer): + """A Keras layer that decodes predictions of an object detection model. 
+ + Args: + bounding_box_format: The format of bounding boxes of input dataset. + Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + from_logits: boolean, True means input score is logits, False means + confidence. + iou_threshold: a float value in the range [0, 1] representing the + minimum IoU threshold for two boxes to be considered + same for suppression. Defaults to 0.5. + confidence_threshold: a float value in the range [0, 1]. All boxes with + confidence below this value will be discarded, defaults to 0.5. + max_detections: the maximum detections to consider after nms is applied. + A large number may trigger significant memory overhead, + defaults to 100. + """ # noqa: E501 + + def __init__( + self, + bounding_box_format, + from_logits, + iou_threshold=0.5, + confidence_threshold=0.5, + max_detections=100, + **kwargs, + ): + super().__init__(**kwargs) + self.bounding_box_format = bounding_box_format + self.from_logits = from_logits + self.iou_threshold = iou_threshold + self.confidence_threshold = confidence_threshold + self.max_detections = max_detections + self.built = True + + def call( + self, box_prediction, class_prediction, images=None, image_shape=None + ): + """Accepts images and raw predictions, and returns bounding box + predictions. + + Args: + box_prediction: Dense Tensor of shape [batch, boxes, 4] in the + `bounding_box_format` specified in the constructor. + class_prediction: Dense Tensor of shape [batch, boxes, num_classes]. 
+ """ + target_format = "yxyx" + if utils.is_relative(self.bounding_box_format): + target_format = utils.as_relative(target_format) + + box_prediction = converters.convert_format( + box_prediction, + source=self.bounding_box_format, + target=target_format, + images=images, + image_shape=image_shape, + ) + if self.from_logits: + class_prediction = ops.sigmoid(class_prediction) + + confidence_prediction = ops.max(class_prediction, axis=-1) + + idx, valid_det = non_max_suppression( + box_prediction, + confidence_prediction, + max_output_size=self.max_detections, + iou_threshold=self.iou_threshold, + score_threshold=self.confidence_threshold, + ) + + box_prediction = ops.take_along_axis( + box_prediction, ops.expand_dims(idx, axis=-1), axis=1 + ) + box_prediction = ops.reshape( + box_prediction, (-1, self.max_detections, 4) + ) + confidence_prediction = ops.take_along_axis( + confidence_prediction, idx, axis=1 + ) + class_prediction = ops.take_along_axis( + class_prediction, ops.expand_dims(idx, axis=-1), axis=1 + ) + + box_prediction = converters.convert_format( + box_prediction, + source=target_format, + target=self.bounding_box_format, + images=images, + image_shape=image_shape, + ) + bounding_boxes = { + "boxes": box_prediction, + "confidence": confidence_prediction, + "classes": ops.argmax(class_prediction, axis=-1), + "num_detections": valid_det, + } + + # this is required to comply with KerasCV bounding box format. 
+ return mask_invalid_detections(bounding_boxes) + + def get_config(self): + config = super().get_config() + config.update( + { + "bounding_box_format": self.bounding_box_format, + "from_logits": self.from_logits, + "iou_threshold": self.iou_threshold, + "confidence_threshold": self.confidence_threshold, + "max_detections": self.max_detections, + } + ) + return config + + +def non_max_suppression( + boxes, + scores, + max_output_size, + iou_threshold=0.5, + score_threshold=0.0, + tile_size=512, +): + # Box format must be yxyx + """Non-maximum suppression. + Ported from https://github.com/tensorflow/tensorflow/blob/v2.12.0/tensorflow/python/ops/image_ops_impl.py#L5368-L5458 + + Args: + boxes: a tensor of rank 2 or higher with a shape of + `[..., num_boxes, 4]`. Dimensions except the last two are batch + dimensions. The last dimension represents box coordinates in + yxyx format. + scores: a tensor of rank 1 or higher with a shape of `[..., num_boxes]`. + max_output_size: a scalar integer tensor representing the maximum + number of boxes to be selected by non max suppression. + iou_threshold: a float representing the threshold for + deciding whether boxes overlap too much with respect + to IoU (intersection over union). + score_threshold: a float representing the threshold for box scores. + Boxes with a score that is not larger than this threshold + will be suppressed. + tile_size: an integer representing the number of boxes in a tile, i.e., + the maximum number of boxes per image that can be used to suppress + other boxes in parallel; larger tile_size means larger parallelism + and potentially more redundant work. + + Returns: + idx: a tensor with a shape of `[..., num_boxes]` representing the + indices selected by non-max suppression. The leading dimensions + are the batch dimensions of the input boxes. All numbers are within + `[0, num_boxes)`. For each image (i.e., `idx[i]`), only the first + `num_valid[i]` indices (i.e., `idx[i][:num_valid[i]]`) are valid. 
+ num_valid: a tensor of rank 0 or higher with a shape of [...] + representing the number of valid indices in idx. Its dimensions + are the batch dimensions of the input boxes. + """ # noqa: E501 + + def _sort_scores_and_boxes(scores, boxes): + """Sort boxes based their score from highest to lowest. + + Args: + scores: a tensor with a shape of `[batch_size, num_boxes]` + representing the scores of boxes. + boxes: a tensor with a shape of `[batch_size, num_boxes, 4]` + representing the boxes. + + Returns: + sorted_scores: a tensor with a shape of + `[batch_size, num_boxes]` representing the sorted scores. + sorted_boxes: a tensor representing the sorted boxes. + sorted_scores_indices: a tensor with a shape of + `[batch_size, num_boxes]` representing the index of the scores + in a sorted descending order. + """ # noqa: E501 + with keras.name_scope("sort_scores_and_boxes"): + sorted_scores_indices = ops.flip( + ops.cast(ops.argsort(scores, axis=1), "int32"), axis=1 + ) + sorted_scores = ops.take_along_axis( + scores, + sorted_scores_indices, + axis=1, + ) + sorted_boxes = ops.take_along_axis( + boxes, + ops.expand_dims(sorted_scores_indices, axis=-1), + axis=1, + ) + return sorted_scores, sorted_boxes, sorted_scores_indices + + batch_dims = ops.shape(boxes)[:-2] + num_boxes = boxes.shape[-2] + boxes = ops.reshape(boxes, [-1, num_boxes, 4]) + scores = ops.reshape(scores, [-1, num_boxes]) + batch_size = boxes.shape[0] + if score_threshold != float("-inf"): + with keras.name_scope("filter_by_score"): + score_mask = ops.cast(scores > score_threshold, scores.dtype) + scores *= score_mask + box_mask = ops.expand_dims(ops.cast(score_mask, boxes.dtype), 2) + boxes *= box_mask + + scores, boxes, sorted_indices = _sort_scores_and_boxes(scores, boxes) + + pad = ( + math.ceil(max(num_boxes, max_output_size) / tile_size) * tile_size + - num_boxes + ) + boxes = ops.pad(ops.cast(boxes, "float32"), [[0, 0], [0, pad], [0, 0]]) + scores = ops.pad(ops.cast(scores, "float32"), [[0, 0], 
[0, pad]]) + num_boxes_after_padding = num_boxes + pad + num_iterations = num_boxes_after_padding // tile_size + + def _loop_cond(unused_boxes, unused_threshold, output_size, idx): + return ops.logical_and( + ops.min(output_size) < ops.cast(max_output_size, "int32"), + ops.cast(idx, "int32") < num_iterations, + ) + + def suppression_loop_body(boxes, iou_threshold, output_size, idx): + return _suppression_loop_body( + boxes, iou_threshold, output_size, idx, tile_size + ) + + selected_boxes, _, output_size, _ = ops.while_loop( + _loop_cond, + suppression_loop_body, + [ + boxes, + iou_threshold, + ops.zeros([batch_size], "int32"), + ops.array(0), + ], + ) + num_valid = ops.minimum(output_size, max_output_size) + idx = num_boxes_after_padding - ops.cast( + ops.top_k( + ops.cast(ops.any(selected_boxes > 0, [2]), "int32") + * ops.cast( + ops.expand_dims(ops.arange(num_boxes_after_padding, 0, -1), 0), + "int32", + ), + max_output_size, + )[0], + "int32", + ) + idx = ops.minimum(idx, num_boxes - 1) + + index_offsets = ops.cast(ops.arange(batch_size) * num_boxes, "int32") + take_along_axis_idx = ops.reshape( + idx + ops.expand_dims(index_offsets, 1), [-1] + ) + + if keras.backend.backend() != "tensorflow": + idx = ops.take_along_axis( + ops.reshape(sorted_indices, [-1]), take_along_axis_idx + ) + else: + import tensorflow as tf + + idx = tf.gather(ops.reshape(sorted_indices, [-1]), take_along_axis_idx) + idx = ops.reshape(idx, [batch_size, -1]) + + invalid_index = ops.zeros([batch_size, max_output_size], dtype="int32") + idx_index = ops.cast( + ops.expand_dims(ops.arange(max_output_size), 0), "int32" + ) + num_valid_expanded = ops.expand_dims(num_valid, 1) + idx = ops.where(idx_index < num_valid_expanded, idx, invalid_index) + + num_valid = ops.reshape(num_valid, batch_dims) + return idx, num_valid + + +def _bbox_overlap(boxes_a, boxes_b): + """Calculates the overlap (iou - intersection over union) between boxes_a + and boxes_b. 
+ + Args: + boxes_a: a tensor with a shape of `[batch_size, N, 4]`. + `N` is the number of boxes per image. The last dimension is the + pixel coordinates in `[ymin, xmin, ymax, xmax]` form. + boxes_b: a tensor with a shape of `[batch_size, M, 4]`. M is the number of + boxes. The last dimension is the pixel coordinates in + `[ymin, xmin, ymax, xmax]` form. + + Returns: + intersection_over_union: a tensor with as a shape of + `[batch_size, N, M]`, representing the ratio of intersection area + over union area (IoU) between two boxes + """ # noqa: E501 + with keras.name_scope("bbox_overlap"): + if len(boxes_a.shape) == 4: + boxes_a = ops.squeeze(boxes_a, axis=0) + a_y_min, a_x_min, a_y_max, a_x_max = ops.split(boxes_a, 4, axis=2) + b_y_min, b_x_min, b_y_max, b_x_max = ops.split(boxes_b, 4, axis=2) + + # Calculates the intersection area. + i_xmin = ops.maximum(a_x_min, ops.transpose(b_x_min, [0, 2, 1])) + i_xmax = ops.minimum(a_x_max, ops.transpose(b_x_max, [0, 2, 1])) + i_ymin = ops.maximum(a_y_min, ops.transpose(b_y_min, [0, 2, 1])) + i_ymax = ops.minimum(a_y_max, ops.transpose(b_y_max, [0, 2, 1])) + i_area = ops.maximum((i_xmax - i_xmin), 0) * ops.maximum( + (i_ymax - i_ymin), 0 + ) + + # Calculates the union area. + a_area = (a_y_max - a_y_min) * (a_x_max - a_x_min) + b_area = (b_y_max - b_y_min) * (b_x_max - b_x_min) + + # Adds a small epsilon to avoid divide-by-zero. + u_area = a_area + ops.transpose(b_area, [0, 2, 1]) - i_area + EPSILON + + intersection_over_union = i_area / u_area + + return intersection_over_union + + +def _self_suppression(iou, _, iou_sum, iou_threshold): + """Suppress boxes in the same tile. + + Compute boxes that cannot be suppressed by others (i.e., + can_suppress_others), and then use them to suppress boxes in the same tile. + + Args: + iou: a tensor of shape `[batch_size, num_boxes_with_padding]` + representing intersection over union. + iou_sum: a scalar tensor. + iou_threshold: a scalar tensor. 
+ + Returns: + iou_suppressed: a tensor of shape + `[batch_size, num_boxes_with_padding]`. + iou_diff: a scalar tensor representing whether any box is supressed in + this step. + iou_sum_new: a scalar tensor of shape `[batch_size]` that represents + the iou sum after suppression. + iou_threshold: a scalar tensor. + """ # noqa: E501 + batch_size = ops.shape(iou)[0] + can_suppress_others = ops.cast( + ops.reshape(ops.max(iou, 1) < iou_threshold, [batch_size, -1, 1]), + iou.dtype, + ) + iou_after_suppression = ( + ops.reshape( + ops.cast( + ops.max(can_suppress_others * iou, 1) < iou_threshold, iou.dtype + ), + [batch_size, -1, 1], + ) + * iou + ) + iou_sum_new = ops.sum(iou_after_suppression, [1, 2]) + return [ + iou_after_suppression, + ops.any(iou_sum - iou_sum_new > iou_threshold), + iou_sum_new, + iou_threshold, + ] + + +def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size): + """Suppress boxes between different tiles. + + Args: + boxes: a tensor of shape `[batch_size, num_boxes_with_padding, 4]` + box_slice: a tensor of shape `[batch_size, tile_size, 4]` + iou_threshold: a scalar tensor + inner_idx: a scalar tensor representing the tile index of the tile + that is used to supress box_slice + tile_size: an integer representing the number of boxes in a tile + + Returns: + boxes: unchanged boxes as input + box_slice_after_suppression: box_slice after suppression + iou_threshold: unchanged + """ + slice_index = ops.expand_dims( + ops.expand_dims( + ops.cast( + ops.linspace( + inner_idx * tile_size, + (inner_idx + 1) * tile_size - 1, + tile_size, + ), + "int32", + ), + axis=0, + ), + axis=-1, + ) + new_slice = ops.expand_dims( + ops.take_along_axis(boxes, slice_index, axis=1), 0 + ) + iou = _bbox_overlap(new_slice, box_slice) + box_slice_after_suppression = ( + ops.expand_dims( + ops.cast(ops.all(iou < iou_threshold, [1]), box_slice.dtype), 2 + ) + * box_slice + ) + return boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1 + + 
+def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size): + """Process boxes in the range [idx*tile_size, (idx+1)*tile_size). + + Args: + boxes: a tensor with a shape of [batch_size, anchors, 4]. + iou_threshold: a float representing the threshold for deciding whether + boxes overlap too much with respect to IOU. + output_size: an int32 tensor of size [batch_size]. Representing the + number of selected boxes for each batch. + idx: an integer scalar representing induction variable. + tile_size: an integer representing the number of boxes in a tile + + Returns: + boxes: updated boxes. + iou_threshold: pass down iou_threshold to the next iteration. + output_size: the updated output_size. + idx: the updated induction variable. + """ # noqa: E501 + with keras.name_scope("suppression_loop_body"): + num_tiles = boxes.shape[1] // tile_size + batch_size = boxes.shape[0] + + def cross_suppression_func(boxes, box_slice, iou_threshold, inner_idx): + return _cross_suppression( + boxes, box_slice, iou_threshold, inner_idx, tile_size + ) + + # Iterates over tiles that can possibly suppress the current tile. + slice_index = ops.expand_dims( + ops.expand_dims( + ops.cast( + ops.linspace( + idx * tile_size, (idx + 1) * tile_size - 1, tile_size + ), + "int32", + ), + axis=0, + ), + axis=-1, + ) + box_slice = ops.take_along_axis(boxes, slice_index, axis=1) + _, box_slice, _, _ = ops.while_loop( + lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx, + cross_suppression_func, + [boxes, box_slice, iou_threshold, ops.array(0)], + ) + + # Iterates over the current tile to compute self-suppression. 
+ iou = _bbox_overlap(box_slice, box_slice) + mask = ops.expand_dims( + ops.reshape(ops.arange(tile_size), [1, -1]) + > ops.reshape(ops.arange(tile_size), [-1, 1]), + 0, + ) + iou *= ops.cast(ops.logical_and(mask, iou >= iou_threshold), iou.dtype) + suppressed_iou, _, _, _ = ops.while_loop( + lambda _iou, loop_condition, _iou_sum, _: loop_condition, + _self_suppression, + [iou, ops.array(True), ops.sum(iou, [1, 2]), iou_threshold], + ) + suppressed_box = ops.sum(suppressed_iou, 1) > 0 + box_slice *= ops.expand_dims( + 1.0 - ops.cast(suppressed_box, box_slice.dtype), 2 + ) + + # Uses box_slice to update the input boxes. + mask = ops.reshape( + ops.cast(ops.equal(ops.arange(num_tiles), idx), boxes.dtype), + [1, -1, 1, 1], + ) + boxes = ops.tile( + ops.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] + ) * mask + ops.reshape(boxes, [batch_size, num_tiles, tile_size, 4]) * ( + 1 - mask + ) + boxes = ops.reshape(boxes, [batch_size, -1, 4]) + + # Updates output_size. + output_size += ops.cast( + ops.sum(ops.any(box_slice > 0, [2]), [1]), "int32" + ) + return boxes, iou_threshold, output_size, idx + 1 + + +def mask_invalid_detections(bounding_boxes): + """masks out invalid detections with -1s. + + This utility is mainly used on the output of non-max suppression operations. + The output of non-max-suppression contains all the detections, even invalid + ones. Users are expected to use `num_detections` to determine how many boxes + are in each image. + + In contrast, KerasCV expects all bounding boxes to be padded with -1s. + This function uses the value of `num_detections` to mask out + invalid boxes with -1s. + + Args: + bounding_boxes: a dictionary complying with KerasCV bounding box format. + In addition to the normal required keys, these boxes are also + expected to have a `num_detections` key. + output_ragged: whether to output RaggedTensor based bounding + boxes. + Returns: + bounding boxes with proper masking of the boxes according to + `num_detections`. 
This allows proper interop with non-max supression. + Returned boxes match the specification fed to the function, so if the + bounding box tensor uses `tf.RaggedTensor` to represent boxes the + returned value will also return `tf.RaggedTensor` representations. + """ + # ensure we are complying with KerasCV bounding box format. + info = validate_format.validate_format(bounding_boxes) + if info["ragged"]: + raise ValueError( + "`bounding_box.mask_invalid_detections()` requires inputs to be " + "Dense tensors. Please call " + "`bounding_box.to_dense(bounding_boxes)` before passing your boxes " + "to `bounding_box.mask_invalid_detections()`." + ) + if "num_detections" not in bounding_boxes: + raise ValueError( + "`bounding_boxes` must have key 'num_detections' " + "to be used with `bounding_box.mask_invalid_detections()`." + ) + + boxes = bounding_boxes.get("boxes") + classes = bounding_boxes.get("classes") + confidence = bounding_boxes.get("confidence", None) + num_detections = bounding_boxes.get("num_detections") + + # Create a mask to select only the first N boxes from each batch + mask = ops.cast( + ops.expand_dims(ops.arange(boxes.shape[1]), axis=0), + num_detections.dtype, + ) + mask = mask < num_detections[:, None] + + classes = ops.where(mask, classes, -ops.ones_like(classes)) + + if confidence is not None: + confidence = ops.where(mask, confidence, -ops.ones_like(confidence)) + + # reuse mask for boxes + mask = ops.expand_dims(mask, axis=-1) + mask = ops.repeat(mask, repeats=boxes.shape[-1], axis=-1) + boxes = ops.where(mask, boxes, -ops.ones_like(boxes)) + + result = bounding_boxes.copy() + + result["boxes"] = boxes + result["classes"] = classes + if confidence is not None: + result["confidence"] = confidence + + return result diff --git a/keras_hub/src/models/retinanet/non_max_supression_test.py b/keras_hub/src/models/retinanet/non_max_supression_test.py new file mode 100644 index 000000000..0af0207ef --- /dev/null +++ 
b/keras_hub/src/models/retinanet/non_max_supression_test.py @@ -0,0 +1,72 @@ +# Copyright 2024 The KerasHub Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from keras import ops + +from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression +from keras_hub.src.tests.test_case import TestCase + + +class NonMaxSupressionTest(TestCase): + def test_confidence_threshold(self): + boxes = np.random.uniform(low=0, high=1, size=(2, 5, 4)) + classes = ops.expand_dims( + np.array( + [[0.1, 0.1, 0.4, 0.9, 0.5], [0.7, 0.5, 0.3, 0.0, 0.0]], + "float32", + ), + axis=-1, + ) + + nms = NonMaxSuppression( + bounding_box_format="yxyx", + from_logits=False, + iou_threshold=1.0, + confidence_threshold=0.45, + max_detections=2, + ) + + outputs = nms(boxes, classes) + + self.assertAllClose( + outputs["boxes"], [boxes[0][-2:, ...], boxes[1][:2, ...]] + ) + self.assertAllClose(outputs["classes"], [[0.0, 0.0], [0.0, 0.0]]) + self.assertAllClose(outputs["confidence"], [[0.9, 0.5], [0.7, 0.5]]) + + def test_max_detections(self): + boxes = np.random.uniform(low=0, high=1, size=(2, 5, 4)) + classes = ops.expand_dims( + np.array( + [[0.1, 0.1, 0.4, 0.5, 0.9], [0.7, 0.5, 0.3, 0.0, 0.0]], + "float32", + ), + axis=-1, + ) + + nms = NonMaxSuppression( + bounding_box_format="yxyx", + from_logits=False, + iou_threshold=1.0, + confidence_threshold=0.1, + max_detections=1, + ) + + outputs = nms(boxes, classes) + + self.assertAllClose( + outputs["boxes"], 
[boxes[0][-1:, ...], boxes[1][:1, ...]] + ) + self.assertAllClose(outputs["classes"], [[0.0], [0.0]]) + self.assertAllClose(outputs["confidence"], [[0.9], [0.7]]) From 27c8894f068885aea001203508f6c43cbf5d4ffa Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 19 Sep 2024 15:13:09 -0700 Subject: [PATCH 2/3] nit --- .../src/models/retinanet/anchor_generator.py | 26 +-- keras_hub/src/models/retinanet/box_matcher.py | 165 +++++++------ .../models/retinanet/non_max_supression.py | 220 +++++++++--------- 3 files changed, 198 insertions(+), 213 deletions(-) diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/models/retinanet/anchor_generator.py index 55d965d46..548e44914 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/models/retinanet/anchor_generator.py @@ -57,24 +57,24 @@ class AnchorGenerator(keras.layers.Layer): sizes. Returns: - OrderedDict: A dictionary mapping feature levels + Dict: A dictionary mapping feature levels (e.g., 'P3', 'P4', etc.) to anchor boxes. Each entry contains a tensor of shape `(H/stride * W/stride * num_anchors_per_location, 4)`, where H and W are the height and width of the image, stride is 2^level, and num_anchors_per_location is `num_scales * len(aspect_ratios)`. 
Example: - ```python - anchor_generator = AnchorGenerator( - bounding_box_format='xyxy', - min_level=3, - max_level=7, - num_scales=3, - aspect_ratios=[0.5, 1.0, 2.0], - anchor_size=4.0, - ) - anchors = anchor_generator(image_shape=(640, 480)) - ``` + ```python + anchor_generator = AnchorGenerator( + bounding_box_format='xyxy', + min_level=3, + max_level=7, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=4.0, + ) + anchors = anchor_generator(images=keras.ops.ones(shape=(2, 640, 480, 3))) + ``` """ def __init__( @@ -97,7 +97,7 @@ def __init__( self.built = True def call(self, images): - images_shape = images.shape + images_shape = ops.shape(images) if len(images_shape) == 4: image_shape = images_shape[1:-1] else: diff --git a/keras_hub/src/models/retinanet/box_matcher.py b/keras_hub/src/models/retinanet/box_matcher.py index 1ed53d3ca..9dfdcff28 100644 --- a/keras_hub/src/models/retinanet/box_matcher.py +++ b/keras_hub/src/models/retinanet/box_matcher.py @@ -21,7 +21,7 @@ class BoxMatcher(keras.layers.Layer): This class computes matches from a similarity matrix. Each row will be matched to at least one column, the matched result can either be positive - / negative, or simply ignored depending on the setting. + or negative, or simply ignored depending on the setting. 
The settings include `thresholds` and `match_values`, for example if: 1) `thresholds=[negative_threshold, positive_threshold]`, and @@ -64,13 +64,12 @@ class BoxMatcher(keras.layers.Layer): len(`match_values`) != len(`thresholds`) + 1 Example: - - ```python - box_matcher = keras_cv.layers.BoxMatcher([0.3, 0.7], [-1, 0, 1]) - iou_metric = keras_cv.bounding_box.compute_iou(anchors, boxes) - matched_columns, matched_match_values = box_matcher(iou_metric) - cls_mask = ops.less_equal(matched_match_values, 0) - ``` + ```python + box_matcher = keras_cv.layers.BoxMatcher([0.3, 0.7], [-1, 0, 1]) + iou_metric = keras_cv.bounding_box.compute_iou(anchors, boxes) + matched_columns, matched_match_values = box_matcher(iou_metric) + cls_mask = ops.less_equal(matched_match_values, 0) + ``` """ @@ -134,14 +133,9 @@ def _match_when_cols_are_empty(): num_rows] storing the match type indicator (e.g. positive or negative or ignored match). """ - with keras.name_scope("empty_boxes"): - matched_columns = ops.zeros( - [batch_size, num_rows], dtype="int32" - ) - matched_values = -ops.ones( - [batch_size, num_rows], dtype="int32" - ) - return matched_columns, matched_values + matched_columns = ops.zeros([batch_size, num_rows], dtype="int32") + matched_values = -ops.ones([batch_size, num_rows], dtype="int32") + return matched_columns, matched_values def _match_when_cols_are_non_empty(): """Performs matching when the rows of similarity matrix are @@ -154,80 +148,79 @@ def _match_when_cols_are_non_empty(): num_rows] storing the match type indicator (e.g. positive or negative or ignored match). """ - with keras.name_scope("non_empty_boxes"): - # Jax traces this function even when running eagerly and the - # columns are non-empty. Therefore, we need to handle the case - # where the similarity matrix is empty. We do this by padding - # some -1s to the end. 
-1s are guaranteed to not affect argmax - # matching because all values in a similarity matrix are [0,1] - # and the indexing won't change because these are added at the - # end. - padded_similarity_matrix = ops.concatenate( - [similarity_matrix, -ops.ones((batch_size, num_rows, 1))], - axis=-1, - ) + # Jax traces this function even when running eagerly and the + # columns are non-empty. Therefore, we need to handle the case + # where the similarity matrix is empty. We do this by padding + # some -1s to the end. -1s are guaranteed to not affect argmax + # matching because all values in a similarity matrix are [0,1] + # and the indexing won't change because these are added at the + # end. + padded_similarity_matrix = ops.concatenate( + [similarity_matrix, -ops.ones((batch_size, num_rows, 1))], + axis=-1, + ) - matched_columns = ops.argmax( - padded_similarity_matrix, - axis=-1, - ) + matched_columns = ops.argmax( + padded_similarity_matrix, + axis=-1, + ) - # Get logical indices of ignored and unmatched columns as int32 - matched_vals = ops.max(padded_similarity_matrix, axis=-1) - matched_values = ops.zeros([batch_size, num_rows], "int32") + # Get logical indices of ignored and unmatched columns as int32 + matched_vals = ops.max(padded_similarity_matrix, axis=-1) + matched_values = ops.zeros([batch_size, num_rows], "int32") - match_dtype = matched_vals.dtype - for ind, low, high in zip( - self.match_values, self.thresholds[:-1], self.thresholds[1:] - ): - low_threshold = ops.cast(low, match_dtype) - high_threshold = ops.cast(high, match_dtype) - mask = ops.logical_and( - ops.greater_equal(matched_vals, low_threshold), - ops.less(matched_vals, high_threshold), - ) - matched_values = self._set_values_using_indicator( - matched_values, mask, ind - ) + match_dtype = matched_vals.dtype + for ind, low, high in zip( + self.match_values, self.thresholds[:-1], self.thresholds[1:] + ): + low_threshold = ops.cast(low, match_dtype) + high_threshold = ops.cast(high, match_dtype) 
+ mask = ops.logical_and( + ops.greater_equal(matched_vals, low_threshold), + ops.less(matched_vals, high_threshold), + ) + matched_values = self._set_values_using_indicator( + matched_values, mask, ind + ) - if self.force_match_for_each_col: - # [batch_size, num_cols], for each column (groundtruth_box), - # find the best matching row (anchor). - matching_rows = ops.argmax( - padded_similarity_matrix, - axis=1, - ) - # [batch_size, num_cols, num_rows], a transposed 0-1 mapping - # matrix M, where M[j, i] = 1 means column j is matched to - # row i. - column_to_row_match_mapping = ops.one_hot( - matching_rows, num_rows - ) - # [batch_size, num_rows], for each row (anchor), find the - # matched column (groundtruth_box). - force_matched_columns = ops.argmax( - column_to_row_match_mapping, - axis=1, - ) - # [batch_size, num_rows] - force_matched_column_mask = ops.cast( - ops.max(column_to_row_match_mapping, axis=1), - "bool", - ) - # [batch_size, num_rows] - matched_columns = ops.where( - force_matched_column_mask, - force_matched_columns, - matched_columns, - ) - matched_values = ops.where( - force_matched_column_mask, - self.match_values[-1] - * ops.ones([batch_size, num_rows], dtype="int32"), - matched_values, - ) + if self.force_match_for_each_col: + # [batch_size, num_cols], for each column (groundtruth_box), + # find the best matching row (anchor). + matching_rows = ops.argmax( + padded_similarity_matrix, + axis=1, + ) + # [batch_size, num_cols, num_rows], a transposed 0-1 mapping + # matrix M, where M[j, i] = 1 means column j is matched to + # row i. + column_to_row_match_mapping = ops.one_hot( + matching_rows, num_rows + ) + # [batch_size, num_rows], for each row (anchor), find the + # matched column (groundtruth_box). 
+ force_matched_columns = ops.argmax( + column_to_row_match_mapping, + axis=1, + ) + # [batch_size, num_rows] + force_matched_column_mask = ops.cast( + ops.max(column_to_row_match_mapping, axis=1), + "bool", + ) + # [batch_size, num_rows] + matched_columns = ops.where( + force_matched_column_mask, + force_matched_columns, + matched_columns, + ) + matched_values = ops.where( + force_matched_column_mask, + self.match_values[-1] + * ops.ones([batch_size, num_rows], dtype="int32"), + matched_values, + ) - return ops.cast(matched_columns, "int32"), matched_values + return ops.cast(matched_columns, "int32"), matched_values num_boxes = ( similarity_matrix.shape[-1] or ops.shape(similarity_matrix)[-1] @@ -255,7 +248,7 @@ def _set_values_using_indicator(self, x, indicator, val): modified tensor. """ indicator = ops.cast(indicator, x.dtype) - return ops.add(ops.multiply(x, 1 - indicator), val * indicator) + return ops.where(indicator == 0, x, val) def get_config(self): config = { diff --git a/keras_hub/src/models/retinanet/non_max_supression.py b/keras_hub/src/models/retinanet/non_max_supression.py index 39a61ab71..1ba6a8f5a 100644 --- a/keras_hub/src/models/retinanet/non_max_supression.py +++ b/keras_hub/src/models/retinanet/non_max_supression.py @@ -30,7 +30,7 @@ class NonMaxSuppression(keras.layers.Layer): Args: bounding_box_format: The format of bounding boxes of input dataset. Refer - [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + TODO: link keras core bounding box docs for more details on supported bounding box formats. from_logits: boolean, True means input score is logits, False means confidence. @@ -42,7 +42,7 @@ class NonMaxSuppression(keras.layers.Layer): max_detections: the maximum detections to consider after nms is applied. A large number may trigger significant memory overhead, defaults to 100. 
- """ # noqa: E501 + """ def __init__( self, @@ -64,8 +64,7 @@ def __init__( def call( self, box_prediction, class_prediction, images=None, image_shape=None ): - """Accepts images and raw predictions, and returns bounding box - predictions. + """Accepts images and raw scores, returning bounding box predictions. Args: box_prediction: Dense Tensor of shape [batch, boxes, 4] in the @@ -123,7 +122,7 @@ def call( "num_detections": valid_det, } - # this is required to comply with KerasCV bounding box format. + # this is required to comply with bounding box format. return mask_invalid_detections(bounding_boxes) def get_config(self): @@ -148,8 +147,8 @@ def non_max_suppression( score_threshold=0.0, tile_size=512, ): - # Box format must be yxyx """Non-maximum suppression. + Ported from https://github.com/tensorflow/tensorflow/blob/v2.12.0/tensorflow/python/ops/image_ops_impl.py#L5368-L5458 Args: @@ -180,7 +179,7 @@ def non_max_suppression( num_valid: a tensor of rank 0 or higher with a shape of [...] representing the number of valid indices in idx. Its dimensions are the batch dimensions of the input boxes. - """ # noqa: E501 + """ def _sort_scores_and_boxes(scores, boxes): """Sort boxes based their score from highest to lowest. @@ -198,21 +197,20 @@ def _sort_scores_and_boxes(scores, boxes): sorted_scores_indices: a tensor with a shape of `[batch_size, num_boxes]` representing the index of the scores in a sorted descending order. 
- """ # noqa: E501 - with keras.name_scope("sort_scores_and_boxes"): - sorted_scores_indices = ops.flip( - ops.cast(ops.argsort(scores, axis=1), "int32"), axis=1 - ) - sorted_scores = ops.take_along_axis( - scores, - sorted_scores_indices, - axis=1, - ) - sorted_boxes = ops.take_along_axis( - boxes, - ops.expand_dims(sorted_scores_indices, axis=-1), - axis=1, - ) + """ + sorted_scores_indices = ops.flip( + ops.cast(ops.argsort(scores, axis=1), "int32"), axis=1 + ) + sorted_scores = ops.take_along_axis( + scores, + sorted_scores_indices, + axis=1, + ) + sorted_boxes = ops.take_along_axis( + boxes, + ops.expand_dims(sorted_scores_indices, axis=-1), + axis=1, + ) return sorted_scores, sorted_boxes, sorted_scores_indices batch_dims = ops.shape(boxes)[:-2] @@ -221,11 +219,10 @@ def _sort_scores_and_boxes(scores, boxes): scores = ops.reshape(scores, [-1, num_boxes]) batch_size = boxes.shape[0] if score_threshold != float("-inf"): - with keras.name_scope("filter_by_score"): - score_mask = ops.cast(scores > score_threshold, scores.dtype) - scores *= score_mask - box_mask = ops.expand_dims(ops.cast(score_mask, boxes.dtype), 2) - boxes *= box_mask + score_mask = ops.cast(scores > score_threshold, scores.dtype) + scores *= score_mask + box_mask = ops.expand_dims(ops.cast(score_mask, boxes.dtype), 2) + boxes *= box_mask scores, boxes, sorted_indices = _sort_scores_and_boxes(scores, boxes) @@ -315,32 +312,31 @@ def _bbox_overlap(boxes_a, boxes_b): intersection_over_union: a tensor with as a shape of `[batch_size, N, M]`, representing the ratio of intersection area over union area (IoU) between two boxes - """ # noqa: E501 - with keras.name_scope("bbox_overlap"): - if len(boxes_a.shape) == 4: - boxes_a = ops.squeeze(boxes_a, axis=0) - a_y_min, a_x_min, a_y_max, a_x_max = ops.split(boxes_a, 4, axis=2) - b_y_min, b_x_min, b_y_max, b_x_max = ops.split(boxes_b, 4, axis=2) - - # Calculates the intersection area. 
- i_xmin = ops.maximum(a_x_min, ops.transpose(b_x_min, [0, 2, 1])) - i_xmax = ops.minimum(a_x_max, ops.transpose(b_x_max, [0, 2, 1])) - i_ymin = ops.maximum(a_y_min, ops.transpose(b_y_min, [0, 2, 1])) - i_ymax = ops.minimum(a_y_max, ops.transpose(b_y_max, [0, 2, 1])) - i_area = ops.maximum((i_xmax - i_xmin), 0) * ops.maximum( - (i_ymax - i_ymin), 0 - ) + """ + if len(boxes_a.shape) == 4: + boxes_a = ops.squeeze(boxes_a, axis=0) + a_y_min, a_x_min, a_y_max, a_x_max = ops.split(boxes_a, 4, axis=2) + b_y_min, b_x_min, b_y_max, b_x_max = ops.split(boxes_b, 4, axis=2) + + # Calculates the intersection area. + i_xmin = ops.maximum(a_x_min, ops.transpose(b_x_min, [0, 2, 1])) + i_xmax = ops.minimum(a_x_max, ops.transpose(b_x_max, [0, 2, 1])) + i_ymin = ops.maximum(a_y_min, ops.transpose(b_y_min, [0, 2, 1])) + i_ymax = ops.minimum(a_y_max, ops.transpose(b_y_max, [0, 2, 1])) + i_area = ops.maximum((i_xmax - i_xmin), 0) * ops.maximum( + (i_ymax - i_ymin), 0 + ) - # Calculates the union area. - a_area = (a_y_max - a_y_min) * (a_x_max - a_x_min) - b_area = (b_y_max - b_y_min) * (b_x_max - b_x_min) + # Calculates the union area. + a_area = (a_y_max - a_y_min) * (a_x_max - a_x_min) + b_area = (b_y_max - b_y_min) * (b_x_max - b_x_min) - # Adds a small epsilon to avoid divide-by-zero. - u_area = a_area + ops.transpose(b_area, [0, 2, 1]) - i_area + EPSILON + # Adds a small epsilon to avoid divide-by-zero. + u_area = a_area + ops.transpose(b_area, [0, 2, 1]) - i_area + EPSILON - intersection_over_union = i_area / u_area + intersection_over_union = i_area / u_area - return intersection_over_union + return intersection_over_union def _self_suppression(iou, _, iou_sum, iou_threshold): @@ -363,7 +359,7 @@ def _self_suppression(iou, _, iou_sum, iou_threshold): iou_sum_new: a scalar tensor of shape `[batch_size]` that represents the iou sum after suppression. iou_threshold: a scalar tensor. 
- """ # noqa: E501 + """ batch_size = ops.shape(iou)[0] can_suppress_others = ops.cast( ops.reshape(ops.max(iou, 1) < iou_threshold, [batch_size, -1, 1]), @@ -447,70 +443,67 @@ def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size): iou_threshold: pass down iou_threshold to the next iteration. output_size: the updated output_size. idx: the updated induction variable. - """ # noqa: E501 - with keras.name_scope("suppression_loop_body"): - num_tiles = boxes.shape[1] // tile_size - batch_size = boxes.shape[0] - - def cross_suppression_func(boxes, box_slice, iou_threshold, inner_idx): - return _cross_suppression( - boxes, box_slice, iou_threshold, inner_idx, tile_size - ) - - # Iterates over tiles that can possibly suppress the current tile. - slice_index = ops.expand_dims( - ops.expand_dims( - ops.cast( - ops.linspace( - idx * tile_size, (idx + 1) * tile_size - 1, tile_size - ), - "int32", + """ + num_tiles = boxes.shape[1] // tile_size + batch_size = boxes.shape[0] + + def cross_suppression_func(boxes, box_slice, iou_threshold, inner_idx): + return _cross_suppression( + boxes, box_slice, iou_threshold, inner_idx, tile_size + ) + + # Iterates over tiles that can possibly suppress the current tile. 
+ slice_index = ops.expand_dims( + ops.expand_dims( + ops.cast( + ops.linspace( + idx * tile_size, (idx + 1) * tile_size - 1, tile_size ), - axis=0, + "int32", ), - axis=-1, - ) - box_slice = ops.take_along_axis(boxes, slice_index, axis=1) - _, box_slice, _, _ = ops.while_loop( - lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx, - cross_suppression_func, - [boxes, box_slice, iou_threshold, ops.array(0)], - ) + axis=0, + ), + axis=-1, + ) + box_slice = ops.take_along_axis(boxes, slice_index, axis=1) + _, box_slice, _, _ = ops.while_loop( + lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx, + cross_suppression_func, + [boxes, box_slice, iou_threshold, ops.array(0)], + ) - # Iterates over the current tile to compute self-suppression. - iou = _bbox_overlap(box_slice, box_slice) - mask = ops.expand_dims( - ops.reshape(ops.arange(tile_size), [1, -1]) - > ops.reshape(ops.arange(tile_size), [-1, 1]), - 0, - ) - iou *= ops.cast(ops.logical_and(mask, iou >= iou_threshold), iou.dtype) - suppressed_iou, _, _, _ = ops.while_loop( - lambda _iou, loop_condition, _iou_sum, _: loop_condition, - _self_suppression, - [iou, ops.array(True), ops.sum(iou, [1, 2]), iou_threshold], - ) - suppressed_box = ops.sum(suppressed_iou, 1) > 0 - box_slice *= ops.expand_dims( - 1.0 - ops.cast(suppressed_box, box_slice.dtype), 2 - ) + # Iterates over the current tile to compute self-suppression. 
+ iou = _bbox_overlap(box_slice, box_slice) + mask = ops.expand_dims( + ops.reshape(ops.arange(tile_size), [1, -1]) + > ops.reshape(ops.arange(tile_size), [-1, 1]), + 0, + ) + iou *= ops.cast(ops.logical_and(mask, iou >= iou_threshold), iou.dtype) + suppressed_iou, _, _, _ = ops.while_loop( + lambda _iou, loop_condition, _iou_sum, _: loop_condition, + _self_suppression, + [iou, ops.array(True), ops.sum(iou, [1, 2]), iou_threshold], + ) + suppressed_box = ops.sum(suppressed_iou, 1) > 0 + box_slice *= ops.expand_dims( + 1.0 - ops.cast(suppressed_box, box_slice.dtype), 2 + ) - # Uses box_slice to update the input boxes. - mask = ops.reshape( - ops.cast(ops.equal(ops.arange(num_tiles), idx), boxes.dtype), - [1, -1, 1, 1], - ) - boxes = ops.tile( - ops.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] - ) * mask + ops.reshape(boxes, [batch_size, num_tiles, tile_size, 4]) * ( - 1 - mask - ) - boxes = ops.reshape(boxes, [batch_size, -1, 4]) + # Uses box_slice to update the input boxes. + mask = ops.reshape( + ops.cast(ops.equal(ops.arange(num_tiles), idx), boxes.dtype), + [1, -1, 1, 1], + ) + boxes = ops.tile( + ops.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] + ) * mask + ops.reshape(boxes, [batch_size, num_tiles, tile_size, 4]) * ( + 1 - mask + ) + boxes = ops.reshape(boxes, [batch_size, -1, 4]) - # Updates output_size. - output_size += ops.cast( - ops.sum(ops.any(box_slice > 0, [2]), [1]), "int32" - ) + # Updates output_size. + output_size += ops.cast(ops.sum(ops.any(box_slice > 0, [2]), [1]), "int32") return boxes, iou_threshold, output_size, idx + 1 @@ -522,16 +515,15 @@ def mask_invalid_detections(bounding_boxes): ones. Users are expected to use `num_detections` to determine how many boxes are in each image. - In contrast, KerasCV expects all bounding boxes to be padded with -1s. + In contrast, KerasHub expects all bounding boxes to be padded with -1s. This function uses the value of `num_detections` to mask out invalid boxes with -1s. 
Args: - bounding_boxes: a dictionary complying with KerasCV bounding box format. + bounding_boxes: a dictionary complying with Keras bounding box format. In addition to the normal required keys, these boxes are also expected to have a `num_detections` key. - output_ragged: whether to output RaggedTensor based bounding - boxes. + Returns: bounding boxes with proper masking of the boxes according to `num_detections`. This allows proper interop with non-max supression. @@ -539,7 +531,7 @@ def mask_invalid_detections(bounding_boxes): bounding box tensor uses `tf.RaggedTensor` to represent boxes the returned value will also return `tf.RaggedTensor` representations. """ - # ensure we are complying with KerasCV bounding box format. + # ensure we are complying with Keras bounding box format. info = validate_format.validate_format(bounding_boxes) if info["ragged"]: raise ValueError( From 27695c30024fb825658cb3cad1006ad97a5a76c7 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 19 Sep 2024 15:36:21 -0700 Subject: [PATCH 3/3] nit --- keras_hub/src/models/retinanet/anchor_generator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/models/retinanet/anchor_generator.py index 548e44914..10b562437 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/models/retinanet/anchor_generator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import math import keras @@ -105,7 +104,7 @@ def call(self, images): image_shape = tuple(image_shape) - multilevel_boxes = collections.OrderedDict() + multilevel_boxes = {} for level in range(self.min_level, self.max_level + 1): boxes_l = [] # Calculate the feature map size for this level