add method to save predictions to csv #50

Merged: 7 commits, Aug 23, 2023
46 changes: 39 additions & 7 deletions wildlifeml/training/evaluator.py
@@ -6,19 +6,20 @@
 )
 
 import numpy as np
+import pandas as pd
 from sklearn.metrics import (
     accuracy_score,
     f1_score,
     precision_score,
     recall_score,
 )
+from tensorflow.keras import Model
 
 from wildlifeml.data import (
     BBoxMapper,
     WildlifeDataset,
     subset_dataset,
 )
-from wildlifeml.training.trainer import BaseTrainer
 from wildlifeml.utils.datasets import (
     map_bbox_to_img,
     map_preds_to_img,
@@ -32,21 +33,25 @@ class Evaluator:
 
     def __init__(
         self,
-        label_file_path: str,
         detector_file_path: str,
         dataset: WildlifeDataset,
         num_classes: int,
+        label_file_path: Optional[str] = None,
         empty_class_id: Optional[int] = None,
         conf_threshold: Optional[float] = None,
         batch_size: int = 64,
     ) -> None:
         """Initialize evaluator object."""
         self.detector_dict = load_json(detector_file_path)
-        self.label_dict = {key: float(val) for key, val in load_csv(label_file_path)}
         self.batch_size = batch_size
         self.num_classes = num_classes
         self.conf_threshold = conf_threshold
 
+        if label_file_path is not None:
+            self.label_dict = {k: float(v) for k, v in load_csv(label_file_path)}
+        else:
+            self.label_dict = {}
+
         # Index what images are contained in the eval dataset
         self.dataset_imgs = set([map_bbox_to_img(k) for k in dataset.keys])
         # Get mapping of img -> bboxs for dataset
@@ -90,10 +95,10 @@ def __init__(
         self.truth_imgs_clf: List = []
         self.truth_imgs_ppl: List = []
 
-    def evaluate(self, trainer: BaseTrainer) -> None:
+    def evaluate(self, model: Model) -> None:
         """Obtain metrics for a supplied model."""
         # Get predictions for bboxs
-        self.preds = trainer.predict(self.dataset)
+        self.preds = model.predict(self.dataset)
 
         # Above predictions are on bbox level, but image level prediction is desired.
         # For this every prediction is reweighted with the MD confidence score.
@@ -116,8 +121,13 @@ def evaluate(self, trainer: BaseTrainer) -> None:
             detector_dict=self.detector_dict,
             empty_class_id=self.empty_class_id,
         )
-        self.truth_imgs_clf = [self.label_dict[k] for k in self.preds_imgs_clf.keys()]
-        self.truth_imgs_ppl = [self.label_dict[k] for k in self.preds_imgs_ppl.keys()]
+        if len(self.label_dict) > 0:
+            self.truth_imgs_clf = [
+                self.label_dict[k] for k in self.preds_imgs_clf.keys()
+            ]
+            self.truth_imgs_ppl = [
+                self.label_dict[k] for k in self.preds_imgs_ppl.keys()
+            ]
 
     def get_details(self) -> Dict:
         """Obtain further details about predictions."""
@@ -132,8 +142,30 @@ def get_details(self) -> Dict:
             'truth_imgs_ppl': self.truth_imgs_ppl,
         }
 
+    def save_predictions(self, filepath: str, img_level: bool = True) -> None:
+        """Save predictions to csv file."""
+        details = self.get_details()
+        keys = details['keys_bbox_empty'] + details['keys_bbox_nonempty']
+        if img_level:
+            keys = list(details['preds_imgs_ppl'].keys())
+            preds = list(details['preds_imgs_ppl'].values())
+        else:
+            preds = list(np.concatenate([self.empty_pred_arr, self.preds]))
+        labels = [np.argmax(x) for x in preds]
+        df = pd.DataFrame(
+            list(zip(keys, labels, preds)),
+            columns=['img_key', 'hard_label', 'prediction'],
+        )
+        df[[f'prob_class_{i}' for i in range(self.num_classes)]] = pd.DataFrame(
+            df.prediction.to_list(), index=df.index
+        ).astype(float)
+        df = df.drop(columns=['prediction'])
+        df.to_csv(filepath, float_format='%.6f')
+
     def compute_metrics(self) -> Dict:
         """Compute eval metrics for predictions."""
+        if len(self.truth_imgs_ppl) == 0:
+            raise ValueError('Metrics can only be computed with ground-truth labels.')
         return self._compute_metrics(
             np.array(self.truth_imgs_ppl),
             np.array([np.argmax(v) for v in self.preds_imgs_ppl.values()]),
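
A minimal usage sketch of the workflow this PR enables, assuming a trained tf.keras.Model and a WildlifeDataset already exist; the paths, the class count, and the variable names below are placeholders, not values taken from the PR:

from wildlifeml.training.evaluator import Evaluator

# `clf_model` (a trained tf.keras.Model) and `eval_dataset` (a WildlifeDataset)
# are assumed to have been created elsewhere; the file paths are hypothetical.
evaluator = Evaluator(
    detector_file_path='megadetector_output.json',
    dataset=eval_dataset,
    num_classes=6,
    # label_file_path is now optional, so fully unlabeled data can be scored
)
evaluator.evaluate(clf_model)                   # evaluate() now takes a keras Model
evaluator.save_predictions('predictions.csv')   # img_level=True by default
# Without labels, compute_metrics() would raise the new ValueError.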
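
In the default image-level mode, the csv written by save_predictions has one row per image key, with a 'hard_label' column (argmax class) and one 'prob_class_{i}' column per class, formatted to six decimal places; since to_csv is called without index=False, pandas also writes its default integer index as an unnamed first column. A read-back sketch, with 'predictions.csv' as an assumed filename:

import pandas as pd

# Hypothetical read-back of a file produced by Evaluator.save_predictions()
df = pd.read_csv('predictions.csv', index_col=0)
hard_labels = df['hard_label']                      # predicted class id per image
probs = df.filter(like='prob_class_').to_numpy()    # per-class soft scores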