diff --git a/official/vision/configs/retinanet.py b/official/vision/configs/retinanet.py index 9fdd258ac1d..98da0a3721b 100644 --- a/official/vision/configs/retinanet.py +++ b/official/vision/configs/retinanet.py @@ -135,6 +135,8 @@ class DetectionGenerator(hyperparams.Config): ) # Return decoded boxes/scores even if apply_nms is set `True`. return_decoded: Optional[bool] = None + # Only works when nms_version='v2'. + use_class_agnostic_nms: Optional[bool] = False @dataclasses.dataclass diff --git a/official/vision/modeling/factory.py b/official/vision/modeling/factory.py index 927326cb527..05cbd18c75e 100644 --- a/official/vision/modeling/factory.py +++ b/official/vision/modeling/factory.py @@ -323,6 +323,7 @@ def build_retinanet( soft_nms_sigma=generator_config.soft_nms_sigma, tflite_post_processing_config=tflite_post_processing_config, return_decoded=generator_config.return_decoded, + use_class_agnostic_nms=generator_config.use_class_agnostic_nms, ) model = retinanet_model.RetinaNetModel(