From 2d3f53c19d24a71bd2a2d6c8c584c8f3dfecb312 Mon Sep 17 00:00:00 2001 From: Chonghan Chen <33018020+PaulCCCCCCH@users.noreply.github.com> Date: Sun, 6 Nov 2022 12:52:26 -0500 Subject: [PATCH] Fixes TypeError when running on GPU A fix for #34 --- pytorch_influence_functions/calc_influence_function.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_influence_functions/calc_influence_function.py b/pytorch_influence_functions/calc_influence_function.py index 8861cd4..8499a82 100644 --- a/pytorch_influence_functions/calc_influence_function.py +++ b/pytorch_influence_functions/calc_influence_function.py @@ -269,7 +269,7 @@ def calc_influence_function(train_dataset_size, grad_z_vecs=None, # There is one grad_z per training data sample ################################### ]) / train_dataset_size - influences.append(tmp_influence) + influences.append(tmp_influence.cpu()) display_progress("Calc. influence function: ", i, train_dataset_size) harmful = np.argsort(influences) @@ -340,7 +340,7 @@ def calc_influence_single(model, train_loader, test_loader, test_id_num, gpu, torch.sum(k * j).data for k, j in zip(grad_z_vec, s_test_vec) ]) / train_dataset_size - influences.append(tmp_influence) + influences.append(tmp_influence.cpu()) display_progress("Calc. influence function: ", i, train_dataset_size) harmful = np.argsort(influences)