def s_test_cg(x_test, y_test, model, train_loader, damp, gpu=-1, verbose=True, loss_func="cross_entropy"):
    """Computes s_test with a Newton conjugate-gradient solver.

    Solves (H + damp * I) x = v for x, where v is the flattened gradient of
    the test loss (from grad_z) and H x is evaluated batch by batch over
    train_loader and averaged over the number of batches.

    Arguments:
        x_test: torch tensor, test data points
        y_test: torch tensor, test data labels
        model: torch NN, model used to evaluate the dataset. Its parameters
            are temporarily replaced by plain tensors while Hessian-vector
            products are evaluated and are restored afterwards.
        train_loader: torch DataLoader, can load the training dataset
        damp: float, damping term added to the Hessian
        gpu: int, tensors are moved with .cuda() only when this is > 0
        verbose: bool, print the objective value after every solver step
        loss_func: str, loss name forwarded to grad_z and calc_loss

    Returns:
        result: torch tensor, flattened solution of the linear system"""

    if gpu > 0:
        x_test, y_test = x_test.cuda(), y_test.cuda()

    v_flat = parameters_to_vector(grad_z(x_test, y_test, model, gpu, loss_func=loss_func))

    def hvp_fn(x):
        # The solver hands over numpy arrays whose dtype may differ from the
        # parameters' dtype; align dtype and device with the gradient vector.
        x_tensor = torch.as_tensor(x, dtype=v_flat.dtype).to(v_flat.device)

        params, names = make_functional(model)
        # Make params regular Tensors instead of nn.Parameter
        params = tuple(p.detach().requires_grad_() for p in params)

        try:
            flat_params = parameters_to_vector(params)
            hvp = torch.zeros_like(flat_params)
            num_batches = float(len(train_loader))

            for x_train, y_train in train_loader:

                if gpu > 0:
                    x_train, y_train = x_train.cuda(), y_train.cuda()

                def f(flat_params_):
                    split_params = tensor_to_tuple(flat_params_, params)
                    load_weights(model, names, split_params)
                    out = model(x_train)
                    # Use the same loss as the test gradient and as s_test.
                    return calc_loss(out, y_train, loss_func=loss_func)

                batch_hvp = vhp(f, flat_params, x_tensor, strict=True)[1]
                hvp += batch_hvp / num_batches
        finally:
            # Always hand the model back with real nn.Parameters, even when
            # a batch raised half-way through.
            with torch.no_grad():
                load_weights(model, names, params, as_params=True)

        # Damping must scale the *current* vector x so that the operator is
        # the linear map (H + damp * I) x expected by the solver.
        damped_hvp = hvp + damp * x_tensor

        return damped_hvp.detach().cpu().numpy()

    def print_function_value(_, f_linear, f_quadratic):
        print(
            f"Conjugate function value: {f_linear + f_quadratic}, lin: {f_linear}, quad: {f_quadratic}"
        )

    debug_callback = print_function_value if verbose else None

    result = conjugate_gradient(
        hvp_fn,
        v_flat.cpu().numpy(),
        debug_callback=debug_callback,
        avextol=1e-8,
        maxiter=100,
    )

    result = torch.tensor(result)
    if gpu > 0:
        result = result.cuda()

    return result
def calc_loss(logits, labels, loss_func="cross_entropy"):
    """Calculates the loss

    Arguments:
        logits: torch tensor, model output. A trailing dimension of size 1
            selects the binary (sigmoid) cross entropy, otherwise the
            multi-class cross entropy is used.
        labels: torch tensor, targets. For the binary case they are reshaped
            to the shape of `logits`, so both (N,) and (N, 1) are accepted.
            Not used when loss_func is "mean".
        loss_func: str, "cross_entropy" or "mean" (mean of the logits)

    Returns:
        loss: scalar torch tensor, the loss

    Raises:
        ValueError: if loss_func is not one of the supported names"""

    if loss_func == "cross_entropy":
        if logits.shape[-1] == 1:
            # binary_cross_entropy_with_logits requires target and input to
            # have identical shapes; class-index labels usually arrive as (N,).
            targets = labels.reshape(logits.shape).to(logits.dtype)
            return F.binary_cross_entropy_with_logits(logits, targets)
        return F.cross_entropy(logits, labels)

    if loss_func == "mean":
        return torch.mean(logits)

    raise ValueError("{} is not a valid value for loss_func".format(loss_func))


def grad_z(x, y, model, gpu=-1, loss_func="cross_entropy"):
    """Calculates the gradient z. One grad_z should be computed for each
    training sample.

    Arguments:
        x: torch tensor, training data points
            e.g. an image sample (batch_size, 3, 256, 256)
        y: torch tensor, training data labels
        model: torch NN, model used to evaluate the dataset; it is switched
            to eval mode by this call
        gpu: int, x and y are moved with .cuda() only when this is > 0
        loss_func: str, loss name forwarded to calc_loss

    Returns:
        grad_z: tuple of torch tensors, containing the gradients
            from model parameters to loss"""
    model.eval()

    if gpu > 0:
        x, y = x.cuda(), y.cuda()

    prediction = model(x)
    loss = calc_loss(prediction, y, loss_func=loss_func)

    # Gradient of the loss with respect to every model parameter
    return grad(loss, tuple(model.parameters()))
+ r: int, number of iterations of which to take the avg. + of the h_estimate calculation; r*recursion_depth should equal the + training dataset size. + loss_func: cross_entropy + + Returns: + s_test_vec: torch tensor, contains s_test for a single test image + """ + + """ + initialize inverse_hvp as a list of tensors with zeros, which should be first s_test as described in the paper + H_0^(-1)v = v + """ + inverse_hvp = [ + torch.zeros_like(params, dtype=torch.float) for params in model.parameters() + ] + + for i in range(r): # repeat r times to get average + + hessian_loader = DataLoader( + train_loader.dataset, + sampler=torch.utils.data.RandomSampler( + train_loader.dataset, True, num_samples=recursion_depth # as mentioned in paper, use "enought" samples + ), + batch_size=1, + # num_workers=4, + ) + + cur_estimate = s_test( + x_test, y_test, model, i, hessian_loader, gpu=gpu, damp=damp, scale=scale, loss_func=loss_func, + ) + + with torch.no_grad(): + inverse_hvp = [ + old + (cur / scale) for old, cur in zip(inverse_hvp, cur_estimate) # update inverse_hvp by adding new cur_estimate + ] + + with torch.no_grad(): + inverse_hvp = [component / r for component in inverse_hvp] + + return inverse_hvp diff --git a/fig2_linear_approx/influence_functions_toolkits/influence_functions.py b/fig2_linear_approx/influence_functions_toolkits/influence_functions.py new file mode 100644 index 0000000..f47b7da --- /dev/null +++ b/fig2_linear_approx/influence_functions_toolkits/influence_functions.py @@ -0,0 +1,554 @@ +#! 
def calc_s_test(
    model,
    test_loader,
    train_loader,
    save=False,
    gpu=-1,
    damp=0.01,
    scale=25,
    recursion_depth=5000,
    r=1,
    start=0,
    loss_func="cross_entropy",
):
    """Calculates s_test for the whole test dataset taking into account all
    training data images.

    Arguments:
        model: pytorch model, for which s_test should be calculated
        test_loader: pytorch dataloader, which can load the test data
        train_loader: pytorch dataloader, which can load the train data
        save: Path or str, folder where to save the s_test files if desired.
            The folder is created when missing. Omitting this argument will
            skip saving
        gpu: int, device flag forwarded to s_test_sample
        damp: float, influence function damping factor
        scale: float, influence calculation scaling factor
        recursion_depth: int, number of recursions to perform during s_test
            calculation, increases accuracy. r*recursion_depth should equal the
            training dataset size.
        r: int, number of iterations of which to take the avg.
            of the h_estimate calculation; r*recursion_depth should equal the
            training dataset size.
        start: int, index of the first test index to use. default is 0
        loss_func: str, loss name forwarded to s_test_sample

    Returns:
        s_tests: list of torch vectors, contain all s_test for the whole
            dataset. Can be huge. Stays empty when the vectors are saved
            to disk instead.
        save: Path, path to the folder where the s_test files were saved to or
            False if they were not saved."""
    if save:
        save = Path(save)
        # torch.save does not create missing folders by itself.
        save.mkdir(parents=True, exist_ok=True)
    else:
        logging.info("ATTENTION: not saving s_test files.")

    num_test_samples = len(test_loader.dataset)
    s_tests = []
    for i in range(start, num_test_samples):
        z_test, t_test = test_loader.dataset[i]
        # Batch the single sample the same way the loader would
        z_test = test_loader.collate_fn([z_test])
        t_test = test_loader.collate_fn([t_test])

        s_test_vec = s_test_sample(
            model,
            z_test,
            t_test,
            train_loader,
            gpu,
            damp,
            scale,
            recursion_depth,
            r,
            loss_func=loss_func,
        )

        if save:
            s_test_vec = [s.cpu() for s in s_test_vec]
            torch.save(
                s_test_vec, save.joinpath(f"{i}_recdep{recursion_depth}_r{r}.s_test")
            )
        else:
            s_tests.append(s_test_vec)
        display_progress(
            "Calc. z_test (s_test): ", i - start, num_test_samples - start
        )

    return s_tests, save
def load_s_test(
    s_test_dir=Path("./s_test/"), s_test_id=0, r_sample_size=10, train_dataset_size=-1
):
    """Loads all s_test data required to calculate the influence function
    and returns a list of it.

    Every file in `s_test_dir` whose name starts with "<s_test_id>_" and ends
    with ".s_test" is loaded (sorted by file name).

    Arguments:
        s_test_dir: Path or str, folder containing files storing the s_test
            values
        s_test_id: int, number of the test data sample s_test was calculated
            for
        r_sample_size: int, number of s_tests expected per test dataset
            point; a mismatch is only logged as a warning
        train_dataset_size: int, currently unused by this function

    Returns:
        e_s_test: list of torch tensors, element-wise average over all loaded
            s_test samples
        s_test: list of the loaded s_test samples. Can be huge.

    Raises:
        FileNotFoundError: if no matching s_test file exists"""
    s_test_dir = Path(s_test_dir)

    logging.info(f"Loading s_test from: {s_test_dir} ...")
    # glob() returns a generator, so materialise it before counting.
    s_test_files = sorted(s_test_dir.glob(f"{s_test_id}_*.s_test"))
    if not s_test_files:
        raise FileNotFoundError(
            f"No s_test files for test sample {s_test_id} found in {s_test_dir}"
        )
    if len(s_test_files) != r_sample_size:
        logging.warning(
            "Load Influence Data: number of s_test sample files"
            " mismatches the available samples"
        )

    s_test = []
    for i, s_test_file in enumerate(s_test_files):
        # NOTE: torch.load unpickles data; only load files you created yourself.
        s_test.append(torch.load(s_test_file))
        display_progress("s_test files loaded: ", i, len(s_test_files))

    # Element-wise mean over all samples: every sample is a list of tensors
    # with one entry per model parameter.
    e_s_test = [sum(parts) / len(s_test) for parts in zip(*s_test)]

    return e_s_test, s_test


def load_grad_z(grad_z_dir=Path("./grad_z/"), train_dataset_size=-1):
    """Loads all grad_z data required to calculate the influence function and
    returns it.

    Arguments:
        grad_z_dir: Path or str, folder containing files storing the grad_z
            values, named "<index>.grad_z"
        train_dataset_size: int, number of total samples in dataset;
            -1 indicates to use all available grad_z files

    Returns:
        grad_z_vecs: list of torch tensors, contains the grad_z tensors"""
    grad_z_dir = Path(grad_z_dir)

    logging.info(f"Loading grad_z from: {grad_z_dir} ...")
    # glob() returns a generator, so count by iterating.
    available_grad_z_files = sum(1 for _ in grad_z_dir.glob("*.grad_z"))
    if train_dataset_size == -1:
        train_dataset_size = available_grad_z_files
    elif available_grad_z_files != train_dataset_size:
        logging.warning(
            "Load Influence Data: number of grad_z files mismatches" " the dataset size"
        )

    grad_z_vecs = []
    for i in range(train_dataset_size):
        # NOTE: torch.load unpickles data; only load files you created yourself.
        grad_z_vecs.append(torch.load(grad_z_dir / f"{i}.grad_z"))
        display_progress("grad_z files loaded: ", i, train_dataset_size)

    return grad_z_vecs
def calc_influence_single(
    model,
    train_loader,
    test_loader,
    test_id_num,
    recursion_depth,
    r,
    gpu=0,
    damp=0.01,
    scale=25,
    s_test_vec=None,
    time_logging=False,
    loss_func="cross_entropy",
):
    """Calculates the influences of all training data points on a single
    test dataset image.

    Arguments:
        model: pytorch model
        train_loader: DataLoader, loads the training dataset
        test_loader: DataLoader, loads the test dataset
        test_id_num: int, id of the test sample for which to calculate the
            influence function
        recursion_depth: int, number of recursions to perform during s_test
            calculation, increases accuracy. | in the paper use 5000
        r: int, number of repetitions of which to take the avg.
            of the h_estimate calculation | in the paper use 10
        gpu: int, device flag forwarded to s_test_sample and grad_z
        damp: float, damping factor forwarded to s_test_sample
        scale: float, scaling factor forwarded to s_test_sample
        s_test_vec: list of torch tensor, contains s_test vectors. If left
            empty it will also be calculated
        time_logging: bool, log the time spent in every grad_z call
        loss_func: str, loss name forwarded to s_test_sample.
            NOTE(review): the per-training-sample grad_z call uses grad_z's
            default loss - confirm this is intended.

    Returns:
        influence: list of 0-dim torch tensors, influences of all training
            data samples for one test sample
        harmful: list of int, training indices sorted by ascending influence
        helpful: list of int, the same indices in reverse order
        test_id_num: int, the number of the test dataset point
            the influence was calculated for"""
    # Calculate s_test vectors if not provided
    if s_test_vec is None:
        z_test, t_test = test_loader.dataset[test_id_num]  # image, label
        # collate_fn turns a list of samples into a batch of size one
        z_test = test_loader.collate_fn([z_test])
        t_test = test_loader.collate_fn([t_test])
        s_test_vec = s_test_sample(
            model,
            z_test,
            t_test,
            train_loader,
            gpu,
            recursion_depth=recursion_depth,
            r=r,
            damp=damp,
            scale=scale,
            loss_func=loss_func,
        )

    # Calculate the influence function
    train_dataset_size = len(train_loader.dataset)
    influences = []
    for i in tqdm(range(train_dataset_size)):
        z, t = train_loader.dataset[i]
        z = train_loader.collate_fn([z])
        t = train_loader.collate_fn([t])

        if time_logging:
            time_a = datetime.datetime.now()

        grad_z_vec = grad_z(z, t, model, gpu=gpu)

        if time_logging:
            time_delta = datetime.datetime.now() - time_a
            logging.info(
                f"Time for grad_z iter:" f" {time_delta.total_seconds() * 1000}"
            )

        with torch.no_grad():
            dot_product = sum(
                torch.sum(k * j).data for k, j in zip(grad_z_vec, s_test_vec)
            )
            tmp_influence = -dot_product / train_dataset_size

        influences.append(tmp_influence)

    # Sort on plain Python floats: numpy cannot convert CUDA tensors.
    harmful = np.argsort([float(infl) for infl in influences])
    helpful = harmful[::-1]

    return influences, harmful.tolist(), helpful.tolist(), test_id_num
def calc_img_wise(config, model, train_loader, test_loader, loss_func="cross_entropy"):
    """Calculates the influence function one test point at a time. Calculates
    the `s_test` and `grad_z` values on the fly and discards them afterwards.

    Intermediate and final results are written as json files into
    config["outdir"], which is created when missing.

    Arguments:
        config: dict, contains the configuration from cli params. The keys
            read here are "test_sample_num", "test_start_index", "outdir",
            "num_classes", "gpu", "recursion_depth" and "r_averaging".
        model: pytorch model
        train_loader: DataLoader, loads the training dataset
        test_loader: DataLoader, loads the test dataset
        loss_func: str, loss name forwarded to calc_influence_single"""
    influences_meta = copy.deepcopy(config)
    test_sample_num = config["test_sample_num"]
    test_start_index = config["test_start_index"]
    outdir = Path(config["outdir"])
    outdir.mkdir(parents=True, exist_ok=True)

    # If calculating the influence for a subset of the whole dataset,
    # calculate it evenly for the same number of samples from all classes.
    # `test_start_index` is `False` when it hasn't been set by the user. It can
    # also be set to `0`, which is a valid start index.
    use_class_sampling = bool(test_sample_num) and test_start_index is not False
    if use_class_sampling:
        test_dataset_iter_len = test_sample_num * config["num_classes"]
        _, sample_list = get_dataset_sample_ids(
            test_sample_num, test_loader, config["num_classes"], test_start_index
        )
    else:
        test_dataset_iter_len = len(test_loader.dataset)
        # Without sampling every test index is processed in order.
        sample_list = list(range(test_dataset_iter_len))

    # Set up logging and save the metadata conf file
    logging.info(f"Running on: {test_sample_num} images per class.")
    logging.info(f"Starting at img number: {test_start_index} per class.")
    influences_meta["test_sample_index_list"] = sample_list
    influences_meta_fn = (
        f"influences_results_meta_{test_start_index}-" f"{test_sample_num}.json"
    )
    influences_meta_path = outdir.joinpath(influences_meta_fn)
    save_json(influences_meta, influences_meta_path)

    influences = {}
    influence = harmful = helpful = None
    # Main loop for calculating the influence function one test sample per
    # iteration.
    for j in range(test_dataset_iter_len):
        # A class may hold fewer samples than requested, which leaves the
        # sample list shorter than the planned number of iterations.
        if j >= len(sample_list):
            logging.warning(
                "ERROR: the test sample id is out of index of the"
                " defined test set. Jumping to next test sample."
            )
            continue
        i = sample_list[j]

        start_time = time.time()
        influence, harmful, helpful, _ = calc_influence_single(
            model,
            train_loader,
            test_loader,
            test_id_num=i,
            gpu=config["gpu"],
            recursion_depth=config["recursion_depth"],
            r=config["r_averaging"],
            loss_func=loss_func,
        )
        end_time = time.time()

        _, label = test_loader.dataset[i]
        influences[str(i)] = {
            # Tensors / numpy scalars are not json serialisable
            "label": label.tolist() if hasattr(label, "tolist") else label,
            "num_in_dataset": j,
            "time_calc_influence_s": end_time - start_time,
            "influence": [x.cpu().numpy().tolist() for x in influence],
            "harmful": harmful[:500],
            "helpful": helpful[:500],
        }

        tmp_influences_path = outdir.joinpath(
            f"influence_results_tmp_"
            f"{test_start_index}_"
            f"{test_sample_num}"
            f"_last-i_{i}.json"
        )
        save_json(influences, tmp_influences_path)
        display_progress("Test samples processed: ", j, test_dataset_iter_len)

    # Only report when at least one test sample was processed.
    if influence is not None:
        logging.info("The results for this run are:")
        logging.info("Influences: ")
        logging.info(influence[:3])
        logging.info("Most harmful img IDs: ")
        logging.info(harmful[:3])
        logging.info("Most helpful img IDs: ")
        logging.info(helpful[:3])

    influences_path = outdir.joinpath(
        f"influence_results_{test_start_index}_" f"{test_sample_num}.json"
    )
    save_json(influences, influences_path)
def save_json(
    json_obj,
    json_path,
    append_if_exists=False,
    overwrite_if_exists=False,
    unique_fn_if_exists=True,
):
    """Saves a json file

    The flags are evaluated in this order of precedence:
    overwrite_if_exists, then unique_fn_if_exists, then append_if_exists.
    Missing parent folders of the target file are created.

    Arguments:
        json_obj: json, json object
        json_path: Path or str, path including the file name where the json
            object should be saved to
        append_if_exists: bool, merge `json_obj` into the top-level object of
            an existing target file (only effective when both other flags
            are False)
        overwrite_if_exists: bool, overwrites any existing target file
        unique_fn_if_exists: bool, appends the current date and time to the
            file name if the target file exists already.
    """
    json_path = Path(json_path)
    payload = json_obj

    if overwrite_if_exists:
        # Write straight to the requested path.
        pass
    elif unique_fn_if_exists:
        if json_path.exists():
            stamp = dt.now().strftime("%Y-%m-%d-%H-%M-%S")
            json_path = json_path.with_name(
                f"{json_path.stem}_{stamp}{json_path.suffix}"
            )
    elif append_if_exists and json_path.exists():
        with open(json_path, "r", encoding="utf-8") as fin:
            payload = json.load(fin)
        payload.update(json_obj)

    # open() does not create missing folders by itself.
    json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(json_path, "w", encoding="utf-8") as fout:
        json.dump(payload, fout, indent=2)
def init_logging(filename=None):
    """Configures the root logger at INFO level with a timestamped format.

    Arguments:
        filename: str, when given, log records go to this file and to stdout;
            otherwise they go to stdout only"""
    settings = {"level": logging.INFO, "format": "%(asctime)s: %(message)s"}
    if not filename:
        settings["stream"] = sys.stdout
    else:
        settings["handlers"] = [
            logging.FileHandler(filename),
            logging.StreamHandler(sys.stdout),
        ]
    logging.basicConfig(**settings)


def get_default_config():
    """Builds and returns a fresh dict holding the default configuration"""
    return dict(
        outdir="outdir",
        seed=42,
        gpu=0,
        dataset="CIFAR10",
        num_classes=10,
        test_sample_num=1,
        test_start_index=0,
        recursion_depth=1,
        r_averaging=1,
        scale=None,
        damp=None,
        calc_method="img_wise",
        log_filename=None,
    )
def del_attr(obj, names):
    """Deletes the attribute reached by following the path `names`
    (a list of attribute names) starting at `obj`."""
    owner = obj
    for part in names[:-1]:
        owner = getattr(owner, part)
    delattr(owner, names[-1])


def set_attr(obj, names, val):
    """Sets the attribute reached by following the path `names`
    (a list of attribute names) starting at `obj` to `val`."""
    owner = obj
    for part in names[:-1]:
        owner = getattr(owner, part)
    setattr(owner, names[-1], val)


def make_functional(model):
    """Strips all parameters from `model`.

    Returns:
        orig_params: tuple with the parameters the model had before
        names: list of str, dotted names of the removed parameters"""
    orig_params = tuple(model.parameters())
    # Collect the names first; deleting while iterating would skip entries
    names = [name for name, _ in model.named_parameters()]
    for dotted_name in names:
        del_attr(model, dotted_name.split("."))

    return orig_params, names


def load_weights(model, names, params, as_params=False):
    """Attaches `params` to `model` under the dotted `names`.

    With as_params=True each tensor is wrapped in torch.nn.Parameter first,
    otherwise the tensor itself is set as attribute."""
    for dotted_name, tensor in zip(names, params):
        value = torch.nn.Parameter(tensor) if as_params else tensor
        set_attr(model, dotted_name.split("."), value)
vector represents the parameters of a model. + parameters (Iterable[Tensor]): an iterator of Tensors that are the + parameters of a model. + """ + if not isinstance(vec, torch.Tensor): + raise TypeError('expected torch.Tensor, but got: {}' + .format(torch.typename(vec))) + + # Pointer for slicing the vector for each parameter + pointer = 0 + + split_tensors = [] + for param in parameters: + + # The length of the parameter + num_param = param.numel() + # Slice the vector, reshape it, and replace the old data of the parameter + split_tensors.append(vec[pointer:pointer + num_param].view_as(param)) + + # Increment the pointer + pointer += num_param + + return tuple(split_tensors) + + +def parameters_to_vector(parameters): + r"""Convert parameters to one vector + + Arguments: + parameters (Iterable[Tensor]): an iterator of Tensors that are the + parameters of a model. + + Returns: + The parameters represented by a single vector + """ + # Flag for the device where the parameter is located + + vec = [] + for param in parameters: + vec.append(param.view(-1)) + + return torch.cat(vec) diff --git a/fig2_linear_approx/leave_one_retraining.py b/fig2_linear_approx/leave_one_retraining.py new file mode 100644 index 0000000..e5fd25f --- /dev/null +++ b/fig2_linear_approx/leave_one_retraining.py @@ -0,0 +1,166 @@ +""" +reproduce the fig2 middle plot in the paper, remove one training sample and retrain the logistic regression model on MINIST 10 classes +2023-10-29 +""" + +import torch +from sklearn import linear_model +import numpy as np +from tqdm import tqdm +import pickle +from utils import get_mnist_data, visualize_result +from model import LogisticRegression as LR + +from influence_functions_toolkits.influence_functions import ( + calc_influence_single, +) + +# HYPARAMS +EPOCH = 10 +BATCH_SIZE = 100 +CLASS_A, CLASS_B = 1, 7 +TEST_INDEX = 5 +WEIGHT_DECAY = 0.01 # same as original paper +OUTPUT_DIR = '../results' +SAMPLE_NUM = 100 +RECURSION_DEPTH = 1000 +R = 10 +SEED = 17 + +# 
set seed +np.random.seed(SEED) +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.cuda.manual_seed_all(SEED) + + +class DataSet: + def __init__(self, data, targets): + self.data = data + self.targets = targets + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + out_data = self.data[idx] + out_label = self.targets[idx] + + return out_data, out_label + + +def get_accuracy(model, test_loader): + """ + test whether the weight transferred from sklearn model to pytorch model is correct + """ + correct = 0 + total = 0 + + with torch.no_grad(): + for data in tqdm(test_loader): + images, labels = data + + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + + total += labels.size(0) + correct += (predicted == labels).sum().item() + print('Accuracy of the model on the test images: %d %%' % (100 * correct / total)) + return correct / total + + +def leave_one_out(): + (x_train, y_train), (x_test, y_test) = get_mnist_data() + # print(x_train.shape, y_train.shape, x_test.shape, y_test.shape) + train_sample_num = len(x_train) + print("len(x_train):", len(x_train)) + + train_data = DataSet(x_train, y_train) + test_data = DataSet(x_test, y_test) + train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False) + test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False) + + # prepare sklearn model to train w as used in original paper code + C = 1.0 / (train_sample_num * WEIGHT_DECAY) + sklearn_model = linear_model.LogisticRegression(C=C, solver='lbfgs', tol=1e-8, fit_intercept=False, + multi_class='multinomial', warm_start=True) + + # prepare pytorch model to compute influence function + torch_model = LR(weight_decay=WEIGHT_DECAY, is_multi=True) + + # train + sklearn_model.fit(x_train, y_train.ravel()) + print('LBFGS training took %s iter.' 
% sklearn_model.n_iter_) + + # assign W into pytorch model + w_opt = sklearn_model.coef_ + with torch.no_grad(): + torch_model.w = torch.nn.Parameter( + torch.tensor(w_opt, dtype=torch.float) # torch.Size([10, 784]) + ) + get_accuracy(torch_model, test_loader) + + # calculate original loss + x_test_input = torch.FloatTensor(x_test[TEST_INDEX: TEST_INDEX + 1]) + y_test_input = torch.LongTensor(y_test[TEST_INDEX: TEST_INDEX + 1]) + + test_data = DataSet(x_test[TEST_INDEX: TEST_INDEX + 1], y_test[TEST_INDEX: TEST_INDEX + 1]) + test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True) + + + test_loss_ori = torch_model.loss(torch_model(x_test_input), y_test_input, train=False).detach().cpu().numpy() + + print('Original loss :{}'.format(test_loss_ori)) + + loss_diff_approx, _, _, _, = calc_influence_single(torch_model, train_loader, test_loader, test_id_num=0, + recursion_depth=RECURSION_DEPTH, r=R, damp=0, scale=25) + loss_diff_approx = - torch.FloatTensor(loss_diff_approx).cpu().numpy() + + # get high and low loss diff indice, checking stability + sorted_indice = np.argsort(loss_diff_approx) + sample_indice = np.concatenate([sorted_indice[-int(SAMPLE_NUM / 2):], sorted_indice[:int(SAMPLE_NUM / 2)]]) + + # calculate true loss diff + loss_diff_true = np.zeros(SAMPLE_NUM) + for i, index in zip(range(SAMPLE_NUM), sample_indice): + print('[{}/{}]'.format(i + 1, SAMPLE_NUM)) + + # get minus one dataset + x_train_minus_one = np.delete(x_train, index, axis=0) + y_train_minus_one = np.delete(y_train, index, axis=0) + + # retrain + C = 1.0 / ((train_sample_num - 1) * WEIGHT_DECAY) + sklearn_model_minus_one = linear_model.LogisticRegression(C=C, fit_intercept=False, tol=1e-8, solver='lbfgs') + sklearn_model_minus_one.fit(x_train_minus_one, y_train_minus_one.ravel()) + print('LBFGS training took {} iter.'.format(sklearn_model_minus_one.n_iter_)) + + # assign w on tensorflow model + w_retrain = sklearn_model_minus_one.coef_ + with torch.no_grad(): + 
torch_model.w = torch.nn.Parameter( + torch.tensor(w_retrain, dtype=torch.float) + ) + + # get retrain loss + test_loss_retrain = torch_model.loss(torch_model(x_test_input), y_test_input, + train=False).detach().cpu().numpy() + + # get true loss diff + loss_diff_true[i] = test_loss_retrain - test_loss_ori + + print('Original loss :{}'.format(test_loss_ori)) + print('Retrain loss :{}'.format(test_loss_retrain)) + print('True loss diff :{}'.format(loss_diff_true[i])) + print('Estimated loss diff :{}'.format(loss_diff_approx[index])) + + pickle.dump(loss_diff_true, open('loss_diff_true.pkl', 'wb')) + pickle.dump(loss_diff_approx[sample_indice], open('loss_diff_approx.pkl', 'wb')) + r2_score = visualize_result(loss_diff_true, loss_diff_approx[sample_indice], OUTPUT_DIR) + + +if __name__ == "__main__": + leave_one_out() + loss_diff_true = pickle.load(open('loss_diff_true.pkl', 'rb')) + loss_diff_approx = pickle.load(open('loss_diff_approx.pkl', 'rb')) + visualize_result(loss_diff_true, loss_diff_approx, OUTPUT_DIR) diff --git a/fig2_linear_approx/model.py b/fig2_linear_approx/model.py new file mode 100644 index 0000000..62b3785 --- /dev/null +++ b/fig2_linear_approx/model.py @@ -0,0 +1,40 @@ +import torch +import numpy as np + +def log_clip(x): + return torch.log(torch.clamp(x, 1e-10, None)) + + +class LogisticRegression(torch.nn.Module): + def __init__(self, weight_decay, is_multi=False): + super(LogisticRegression, self).__init__() + self.is_multi = is_multi + # self.wd = torch.FloatTensor([weight_decay]).cuda() + if self.is_multi: + self.w = torch.nn.Parameter(torch.zeros([10, 784], requires_grad=True)) + else: + self.w = torch.nn.Parameter(torch.zeros([784], requires_grad=True)) + + def forward(self, x): + if self.is_multi: + logits = torch.matmul(x, self.w.T) + else: + logits = torch.matmul(x, torch.reshape(self.w, [-1, 1])) + return logits + + def loss(self, logits, y, train=True): + if self.is_multi: + criterion = torch.nn.CrossEntropyLoss() + # set dtype to 
float + y = y.type(torch.FloatTensor) + loss = criterion(logits, y.long()) + else: + preds = torch.sigmoid(logits) + + if train: + loss = -torch.mean( + y * log_clip(preds) + (1 - y) * log_clip(1 - preds)) # + torch.norm(self.w, 2) * self.wd + else: + loss = -torch.mean(y * log_clip(preds) + (1 - y) * log_clip(1 - preds)) + + return loss diff --git a/fig2_linear_approx/result.png b/fig2_linear_approx/result.png new file mode 100644 index 0000000..4e4cbec Binary files /dev/null and b/fig2_linear_approx/result.png differ diff --git a/fig2_linear_approx/utils.py b/fig2_linear_approx/utils.py new file mode 100644 index 0000000..2fc6a32 --- /dev/null +++ b/fig2_linear_approx/utils.py @@ -0,0 +1,68 @@ +from torchvision import datasets, transforms +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import mean_absolute_error, r2_score +import os + +def get_mnist_data(): + train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([ + transforms.ToTensor(), + ])) + x_train, y_train = train_data.data, train_data.targets + + test_data = datasets.MNIST('data', train=False, download=True, transform=transforms.Compose([ + transforms.ToTensor(), + ])) + + # set train data's last 5000 data as validation data, as in the paper + x_train, y_train = x_train[:-5000], y_train[:-5000] + + x_test, y_test = test_data.data, test_data.targets + x_train, x_test = x_train.reshape([-1, 784]) / 255.0, x_test.reshape([-1, 784]) / 255.0 # divided by 255, improve convergence + + return (x_train, y_train), (x_test, y_test) + + +def visualize_result(actual_loss_diff, estimated_loss_diff, save_path=None): + from matplotlib.ticker import MaxNLocator, FuncFormatter + + r2_s = r2_score(actual_loss_diff, estimated_loss_diff) + + max_abs = np.max([np.abs(actual_loss_diff), np.abs(estimated_loss_diff)]) + min_, max_ = -max_abs * 1.1, max_abs * 1.1 + plt.rcParams.update({'font.size': 15}) + tick_label_size = 8 + + fig, ax = plt.subplots() + + 
ax.scatter(actual_loss_diff, estimated_loss_diff, zorder=2, s=10) + ax.set_title('Linear(approx)') + ax.set_xlabel('Actual diff in loss') + ax.set_ylabel('Predicted diff in loss') + range_ = [min_, max_] + ax.plot(range_, range_, 'k-', alpha=0.2, zorder=1) + text = 'MAE = {:.03}\nR2 score = {:.03}'.format(mean_absolute_error(actual_loss_diff, estimated_loss_diff), + r2_s) + ax.text(max_abs, -max_abs, text, verticalalignment='bottom', horizontalalignment='right') + ax.set_xlim(min_, max_) + ax.set_ylim(min_, max_) + + # Using scientific notation for xticks and yticks + ax.ticklabel_format(style='sci', axis='both', scilimits=(0, 0)) + + # Adjusting the x and y ticks to be symmetric + ax.xaxis.set_major_locator(MaxNLocator(nbins=5, symmetric=True)) + ax.yaxis.set_major_locator(MaxNLocator(nbins=5, symmetric=True)) + ax.xaxis.set_major_formatter(FuncFormatter('{:.0e}'.format)) + ax.yaxis.set_major_formatter(FuncFormatter('{:.0e}'.format)) + # smaller the tick size + ax.xaxis.set_tick_params(labelsize=tick_label_size) + ax.yaxis.set_tick_params(labelsize=tick_label_size) + # make plt to be a square + plt.gca().set_aspect('equal', adjustable='box') + if save_path is not None: + plt.savefig(os.path.join(save_path, "result.png")) + else: + plt.show() + + return r2_s