-
Notifications
You must be signed in to change notification settings - Fork 2
/
evaluate.py
105 lines (85 loc) · 3.6 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import sys
import warnings
import pandas as pd
import numpy as np
import torch
from torchvision.transforms import v2
sys.path.append('../src')
sys.path.append('./methods')
from model.frame import FrameModel
from data.datasets import DeepfakeDataset
from evaluation.compute_metrics import computeExplanationMetrics
from evaluation.generate_ff_test_data import getFFPath
# Explanation method(s) to evaluate.
# One of: "GradCAM++", "RISE", "SHAP", "LIME", "SOBOL", or "All" for every method.
evaluation_explanation_methods = "All"
valid_methods = ["GradCAM++", "RISE", "SHAP", "LIME", "SOBOL", "All"]
if evaluation_explanation_methods not in valid_methods:
    # Report on stderr and exit non-zero: the original exited with status 0,
    # which signals success to shells/CI and hides the misconfiguration.
    print("Invalid explanation method(s) to evaluate", file=sys.stderr)
    sys.exit(1)
# ----- Model and transform pipelines -----
rs_size = 224
task = "multiclass"
# Load the pretrained attribution checkpoint onto the GPU and freeze it for inference.
model = FrameModel.load_from_checkpoint(
    "../model/checkpoint/ff_attribution.ckpt", map_location='cuda'
).eval()
interpolation = 3  # NOTE(review): presumably bicubic (PIL code 3) — confirm against v2.Resize
# Both pipelines convert to image tensors, resize to rs_size, and scale to float32
# in [0, 1]; only the inference pipeline additionally applies ImageNet normalization.
_shared_steps = [
    v2.ToImage(),
    v2.Resize(rs_size, interpolation=interpolation, antialias=False),
    v2.ToDtype(torch.float32, scale=True),
]
inference_transforms = v2.Compose(
    _shared_steps + [v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)
visualize_transforms = v2.Compose(list(_shared_steps))
# ----- Test data -----
# Resolve the FaceForensics++ test split and build two views of the same
# LMDB-backed dataset: one with inference (normalized) transforms and one with
# visualization (unnormalized) transforms.
ds_path = getFFPath("../data/csvs/ff_test.csv")
_lmdb_path = "../data/xai_test_data.lmdb"

def target_transforms(label):
    # Convert the label to a float32 tensor.
    return torch.tensor(label, dtype=torch.float32)

ds = DeepfakeDataset(
    ds_path,
    _lmdb_path,
    transforms=inference_transforms,
    target_transforms=target_transforms,
    task=task,
)
ds_visualize = DeepfakeDataset(
    ds_path,
    _lmdb_path,
    transforms=visualize_transforms,
    target_transforms=target_transforms,
    task=task,
)
# Ensure the output directory exists; exist_ok avoids the race between the
# existence check and creation that the original check-then-makedirs had.
os.makedirs('./results', exist_ok=True)
# Run the metric computation; the reporting code below expects the results
# to be saved as ./results/results_<methods>.npy.
computeExplanationMetrics(model, ds, ds_visualize, inference_transforms, evaluation_explanation_methods)
# ----- Aggregate and report -----
save_name = f"results_{evaluation_explanation_methods}"
scores = np.load(f"./results/{save_name}.npy")
# Average over samples, ignoring NaNs; an all-NaN slice raises a RuntimeWarning
# which is deliberately silenced here.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    scores = np.nanmean(scores, axis=0)
# Row labels: "Original" plus the evaluated explanation method(s).
if evaluation_explanation_methods == "All":
    index_values = ["Original", "GradCAM++", "RISE", "SHAP", "LIME", "SOBOL"]
else:
    index_values = ["Original", evaluation_explanation_methods]
# Column labels follow a fixed pattern per manipulation type
# (DF/F2F/FS/NT): three sufficiency columns and one stability column each,
# then overall accuracy, then three accuracy columns each.
_manipulations = ["DF", "F2F", "FS", "NT"]
column_values = []
for m in _manipulations:
    column_values += [f"Sufficiency {m} Top {k}" for k in (1, 2, 3)]
    column_values.append(f"Stability {m}")
column_values.append("Accuracy")
for m in _manipulations:
    column_values.append(f"Accuracy {m} (Top 1)")
    column_values += [f"Accuracy {m} Top {k}" for k in (2, 3)]
# Save the rounded scores as CSV and echo them to stdout.
df = pd.DataFrame(data=scores, index=index_values, columns=column_values)
rounded = df.round(3)
rounded.to_csv(f"./results/scores_{evaluation_explanation_methods}.csv", sep=',')
print(rounded.to_string())