diff --git a/mmdet/evaluation/metrics/coco_metric.py b/mmdet/evaluation/metrics/coco_metric.py
index cfdc66e03b9..5733fb4f373 100644
--- a/mmdet/evaluation/metrics/coco_metric.py
+++ b/mmdet/evaluation/metrics/coco_metric.py
@@ -18,6 +18,13 @@
 from mmdet.structures.mask import encode_mask_results
 from ..functional import eval_recalls
 
+try:
+    from faster_coco_eval import COCO as FasterCOCO
+    from faster_coco_eval import COCOeval_faster
+except ImportError:
+    FasterCOCO = None
+    COCOeval_faster = None
+
 
 @METRICS.register_module()
 class CocoMetric(BaseMetric):
@@ -64,6 +71,7 @@ class CocoMetric(BaseMetric):
         sort_categories (bool): Whether sort categories in annotations. Only
             used for `Objects365V1Dataset`. Defaults to False.
         use_mp_eval (bool): Whether to use mul-processing evaluation
+        use_faster_coco_eval (bool): Whether to use Faster-COCO-Eval.
     """
     default_prefix: Optional[str] = 'coco'
 
@@ -81,7 +89,8 @@ def __init__(self,
                  collect_device: str = 'cpu',
                  prefix: Optional[str] = None,
                  sort_categories: bool = False,
-                 use_mp_eval: bool = False) -> None:
+                 use_mp_eval: bool = False,
+                 use_faster_coco_eval: bool = False) -> None:
         super().__init__(collect_device=collect_device, prefix=prefix)
         # coco evaluation metrics
         self.metrics = metric if isinstance(metric, list) else [metric]
@@ -96,6 +105,11 @@ def __init__(self,
         self.classwise = classwise
         # whether to use multi processing evaluation, default False
         self.use_mp_eval = use_mp_eval
+        # whether to use Faster-COCO-Eval, default False
+        self.use_faster_coco_eval = use_faster_coco_eval
+        if self.use_faster_coco_eval and FasterCOCO is None:
+            raise RuntimeError(
+                'faster-coco-eval is not installed')
 
         # proposal_nums used to compute recall or precision.
         self.proposal_nums = list(proposal_nums)
@@ -127,7 +141,10 @@ def __init__(self,
         if ann_file is not None:
             with get_local_path(
                     ann_file, backend_args=self.backend_args) as local_path:
-                self._coco_api = COCO(local_path)
+                if self.use_faster_coco_eval:
+                    self._coco_api = FasterCOCO(local_path)
+                else:
+                    self._coco_api = COCO(local_path)
                 if sort_categories:
                     # 'categories' list in objects365_train.json and
                     # objects365_val.json is inconsistent, need sort
@@ -410,7 +427,10 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
             logger.info('Converting ground truth to coco format...')
             coco_json_path = self.gt_to_coco_json(
                 gt_dicts=gts, outfile_prefix=outfile_prefix)
-            self._coco_api = COCO(coco_json_path)
+            if self.use_faster_coco_eval:
+                self._coco_api = FasterCOCO(coco_json_path)
+            else:
+                self._coco_api = COCO(coco_json_path)
 
         # handle lazy init
         if self.cat_ids is None:
@@ -468,6 +488,13 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
 
             if self.use_mp_eval:
                 coco_eval = COCOevalMP(self._coco_api, coco_dt, iou_type)
+            elif self.use_faster_coco_eval:
+                coco_eval = COCOeval_faster(
+                    self._coco_api,
+                    coco_dt,
+                    iou_type,
+                    print_function=logger.info,
+                )
             else:
                 coco_eval = COCOeval(self._coco_api, coco_dt, iou_type)
diff --git a/requirements/optional.txt b/requirements/optional.txt
index 31bdde50bea..3e65e25ef14 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -1,5 +1,6 @@
 cityscapesscripts
 emoji
 fairscale
+faster-coco-eval
 imagecorruptions
 scikit-learn
diff --git a/tests/test_evaluation/test_metrics/test_coco_metric.py b/tests/test_evaluation/test_metrics/test_coco_metric.py
index 547b8f21e0f..9dd58d0f91d 100644
--- a/tests/test_evaluation/test_metrics/test_coco_metric.py
+++ b/tests/test_evaluation/test_metrics/test_coco_metric.py
@@ -1,14 +1,21 @@
 import os.path as osp
 import tempfile
+import unittest
 from unittest import TestCase
 
 import numpy as np
 import pycocotools.mask as mask_util
 import torch
 from mmengine.fileio import dump
+from parameterized import parameterized
 
 from mmdet.evaluation import CocoMetric
 
+try:
+    from faster_coco_eval import COCO as FasterCOCO
+except ImportError:
+    FasterCOCO = None
+
 
 class TestCocoMetric(TestCase):
 
@@ -111,7 +118,11 @@ def test_init(self):
         with self.assertRaisesRegex(KeyError, 'metric should be one of'):
             CocoMetric(ann_file=fake_json_file, metric='unknown')
 
-    def test_evaluate(self):
+    @parameterized.expand([False, True])
+    def test_evaluate(self, use_faster_coco_eval):
+        if use_faster_coco_eval and FasterCOCO is None:
+            raise unittest.SkipTest('faster-coco-eval is not installed')
+
         # create dummy data
         fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
         self._create_dummy_coco_json(fake_json_file)
@@ -121,7 +132,9 @@ def test_evaluate(self):
         coco_metric = CocoMetric(
             ann_file=fake_json_file,
             classwise=False,
-            outfile_prefix=f'{self.tmp_dir.name}/test')
+            outfile_prefix=f'{self.tmp_dir.name}/test',
+            use_faster_coco_eval=use_faster_coco_eval,
+        )
         coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         coco_metric.process(
             {},
@@ -144,7 +157,9 @@ def test_evaluate(self):
             ann_file=fake_json_file,
             metric=['bbox', 'segm'],
             classwise=False,
-            outfile_prefix=f'{self.tmp_dir.name}/test')
+            outfile_prefix=f'{self.tmp_dir.name}/test',
+            use_faster_coco_eval=use_faster_coco_eval,
+        )
         coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         coco_metric.process(
             {},
@@ -174,7 +189,10 @@ def test_evaluate(self):
         with self.assertRaisesRegex(KeyError,
                                     'metric item "invalid" is not supported'):
             coco_metric = CocoMetric(
-                ann_file=fake_json_file, metric_items=['invalid'])
+                ann_file=fake_json_file,
+                metric_items=['invalid'],
+                use_faster_coco_eval=use_faster_coco_eval,
+            )
             coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
             coco_metric.process({}, [
                 dict(
@@ -184,7 +202,10 @@ def test_evaluate(self):
 
         # test custom metric_items
         coco_metric = CocoMetric(
-            ann_file=fake_json_file, metric_items=['mAP_m'])
+            ann_file=fake_json_file,
+            metric_items=['mAP_m'],
+            use_faster_coco_eval=use_faster_coco_eval,
+        )
         coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         coco_metric.process(
             {},
@@ -195,7 +216,11 @@ def test_evaluate(self):
         }
         self.assertDictEqual(eval_results, target)
 
-    def test_classwise_evaluate(self):
+    @parameterized.expand([False, True])
+    def test_classwise_evaluate(self, use_faster_coco_eval):
+        if use_faster_coco_eval and FasterCOCO is None:
+            raise unittest.SkipTest('faster-coco-eval is not installed')
+
         # create dummy data
         fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
         self._create_dummy_coco_json(fake_json_file)
@@ -203,9 +228,12 @@ def test_classwise_evaluate(self):
 
         # test single coco dataset evaluation
         coco_metric = CocoMetric(
-            ann_file=fake_json_file, metric='bbox', classwise=True)
-        # coco_metric1 = CocoMetric(
-        #     ann_file=fake_json_file, metric='bbox', classwise=True)
+            ann_file=fake_json_file,
+            metric='bbox',
+            classwise=True,
+            use_faster_coco_eval=use_faster_coco_eval,
+        )
+
         coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         coco_metric.process(
             {},
@@ -223,18 +251,30 @@ def test_classwise_evaluate(self):
         }
         self.assertDictEqual(eval_results, target)
 
-    def test_manually_set_iou_thrs(self):
+    @parameterized.expand([False, True])
+    def test_manually_set_iou_thrs(self, use_faster_coco_eval):
+        if use_faster_coco_eval and FasterCOCO is None:
+            raise unittest.SkipTest('faster-coco-eval is not installed')
+
         # create dummy data
         fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
         self._create_dummy_coco_json(fake_json_file)
 
         # test single coco dataset evaluation
         coco_metric = CocoMetric(
-            ann_file=fake_json_file, metric='bbox', iou_thrs=[0.3, 0.6])
+            ann_file=fake_json_file,
+            metric='bbox',
+            iou_thrs=[0.3, 0.6],
+            use_faster_coco_eval=use_faster_coco_eval,
+        )
         coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         self.assertEqual(coco_metric.iou_thrs, [0.3, 0.6])
 
-    def test_fast_eval_recall(self):
+    @parameterized.expand([False, True])
+    def test_fast_eval_recall(self, use_faster_coco_eval):
+        if use_faster_coco_eval and FasterCOCO is None:
+            raise unittest.SkipTest('faster-coco-eval is not installed')
+
         # create dummy data
         fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
         self._create_dummy_coco_json(fake_json_file)
@@ -242,7 +282,10 @@ def test_fast_eval_recall(self):
 
         # test default proposal nums
         coco_metric = CocoMetric(
-            ann_file=fake_json_file, metric='proposal_fast')
+            ann_file=fake_json_file,
+            metric='proposal_fast',
+            use_faster_coco_eval=use_faster_coco_eval,
+        )
         coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         coco_metric.process(
             {},
@@ -264,13 +307,21 @@ def test_fast_eval_recall(self):
         target = {'coco/AR@2': 0.5, 'coco/AR@4': 1.0}
         self.assertDictEqual(eval_results, target)
 
-    def test_evaluate_proposal(self):
+    @parameterized.expand([False, True])
+    def test_evaluate_proposal(self, use_faster_coco_eval):
+        if use_faster_coco_eval and FasterCOCO is None:
+            raise unittest.SkipTest('faster-coco-eval is not installed')
+
         # create dummy data
         fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
         self._create_dummy_coco_json(fake_json_file)
         dummy_pred = self._create_dummy_results()
 
-        coco_metric = CocoMetric(ann_file=fake_json_file, metric='proposal')
+        coco_metric = CocoMetric(
+            ann_file=fake_json_file,
+            metric='proposal',
+            use_faster_coco_eval=use_faster_coco_eval,
+        )
         coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         coco_metric.process(
             {},
@@ -287,11 +338,19 @@ def test_evaluate_proposal(self):
         }
         self.assertDictEqual(eval_results, target)
 
-    def test_empty_results(self):
+    @parameterized.expand([False, True])
+    def test_empty_results(self, use_faster_coco_eval):
+        if use_faster_coco_eval and FasterCOCO is None:
+            raise unittest.SkipTest('faster-coco-eval is not installed')
+
         # create dummy data
         fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
         self._create_dummy_coco_json(fake_json_file)
-        coco_metric = CocoMetric(ann_file=fake_json_file, metric='bbox')
+        coco_metric = CocoMetric(
+            ann_file=fake_json_file,
+            metric='bbox',
+            use_faster_coco_eval=use_faster_coco_eval,
+        )
         coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         bboxes = np.zeros((0, 4))
         labels = np.array([])
@@ -308,7 +367,11 @@ def test_empty_results(self):
         # coco api Index error will be caught
         coco_metric.evaluate(size=1)
 
-    def test_evaluate_without_json(self):
+    @parameterized.expand([False, True])
+    def test_evaluate_without_json(self, use_faster_coco_eval):
+        if use_faster_coco_eval and FasterCOCO is None:
+            raise unittest.SkipTest('faster-coco-eval is not installed')
+
         dummy_pred = self._create_dummy_results()
 
         dummy_mask = np.zeros((10, 10), order='F', dtype=np.uint8)
@@ -340,7 +403,8 @@ def test_evaluate_without_json(self):
             ann_file=None,
             metric=['bbox', 'segm'],
             classwise=False,
-            outfile_prefix=f'{self.tmp_dir.name}/test')
+            outfile_prefix=f'{self.tmp_dir.name}/test',
+            use_faster_coco_eval=use_faster_coco_eval)
         coco_metric.dataset_meta = dict(classes=['car', 'bicycle'])
         coco_metric.process({}, [
             dict(
@@ -373,7 +437,11 @@ def test_evaluate_without_json(self):
         self.assertTrue(
             osp.isfile(osp.join(self.tmp_dir.name, 'test.gt.json')))
 
-    def test_format_only(self):
+    @parameterized.expand([False, True])
+    def test_format_only(self, use_faster_coco_eval):
+        if use_faster_coco_eval and FasterCOCO is None:
+            raise unittest.SkipTest('faster-coco-eval is not installed')
+
         # create dummy data
         fake_json_file = osp.join(self.tmp_dir.name, 'fake_data.json')
         self._create_dummy_coco_json(fake_json_file)
@@ -384,7 +452,8 @@ def test_format_only(self):
                 ann_file=fake_json_file,
                 classwise=False,
                 format_only=True,
-                outfile_prefix=None)
+                outfile_prefix=None,
+                use_faster_coco_eval=use_faster_coco_eval)
 
         coco_metric = CocoMetric(
             ann_file=fake_json_file,
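
A minimal usage sketch for the new flag, assuming faster-coco-eval has been installed (e.g. from the updated requirements/optional.txt); the annotation path is illustrative:

    from mmdet.evaluation import CocoMetric

    # Opt in to the Faster-COCO-Eval backend added by this patch. With the
    # guard in __init__, construction raises RuntimeError right away when
    # faster-coco-eval is missing, instead of failing later during evaluation.
    coco_metric = CocoMetric(
        ann_file='data/coco/annotations/instances_val2017.json',  # illustrative
        metric=['bbox', 'segm'],
        use_faster_coco_eval=True,
    )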