Commit

add criterion, dataloader to pruner
hoonyyhoon committed May 10, 2024
1 parent 5b93d57 commit a3521b0
Showing 7 changed files with 166 additions and 10 deletions.
81 changes: 81 additions & 0 deletions pruning_method/GraSP.py
@@ -0,0 +1,81 @@
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

from pruning_method.pruner import Pruner


class GraSP(Pruner):
def __init__(
self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
        super(GraSP, self).__init__(net, device, input_shape, dataloader, criterion)

self.params_to_prune = self.get_params(
(
(nn.Conv2d, "weight"),
(nn.Conv2d, "bias"),
(nn.Linear, "weight"),
(nn.Linear, "bias"),
)
)
prune.global_unstructured(
self.params_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.0,
)
# https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
        # To get the gradient of each weight ("weight_orig" exists only after pruning at least once)
self.params_to_prune_orig = self.get_params(
(
(nn.Conv2d, "weight_orig"),
(nn.Conv2d, "bias_orig"),
(nn.Linear, "weight_orig"),
(nn.Linear, "bias_orig"),
)
)

    def prune(self, amount: float) -> None:
unit_amount = 1 - ((1 - amount) ** 0.01)
print(f"Start prune, target_sparsity: {amount*100:.2f}%")
for _ in range(100):
self.global_unstructured(
pruning_method=prune.L1Unstructured, amount=unit_amount
)
sparsity = self.mask_sparsity()
print(f"Pruning Done, sparsity: {sparsity:.2f}%")

    def get_prune_score(self) -> List[torch.Tensor]:
"""Run prune algorithm and get score."""
# Synaptic flow
signs = self.linearize()
input_ones = torch.ones([1] + self.input_shape).to(self.device)
self.model.eval()
output = self.model(input_ones)
torch.sum(output).backward()

        # Score each parameter by |w * dR/dw|, where R = sum(model(ones))
scores = []
for (p, n), (po, no) in zip(self.params_to_prune, self.params_to_prune_orig):
score = (getattr(p, n) * getattr(po, no).grad).to("cpu").detach().abs_()
scores.append(score)
getattr(po, no).grad.data.zero_()

self.nonlinearize(signs)
self.model.train()
return scores

@torch.no_grad()
def linearize(self):
signs = {}
for name, param in self.model.state_dict().items():
signs[name] = torch.sign(param)
param.abs_()
return signs

@torch.no_grad()
def nonlinearize(self, signs: Dict[str, torch.Tensor]):
for name, param in self.model.state_dict().items():
param.mul_(signs[name])
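Note on the 100-step schedule shared by the two new pruners (illustrative, not part of this commit): prune() removes weights in 100 equal multiplicative steps, and unit_amount = 1 - (1 - amount) ** 0.01 is chosen so that, assuming each step prunes that fraction of the still-unpruned weights, the surviving fraction after 100 steps is ((1 - amount) ** 0.01) ** 100 = 1 - amount, i.e. the final sparsity matches the requested amount. A quick check:

amount = 0.9                                  # requested global sparsity
unit_amount = 1 - ((1 - amount) ** 0.01)      # per-step fraction, as in prune()
remaining = (1 - unit_amount) ** 100          # fraction left after 100 steps
assert abs((1 - remaining) - amount) < 1e-9   # final sparsity is ~0.9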
4 changes: 2 additions & 2 deletions pruning_method/Mag.py
@@ -9,10 +9,10 @@

class Mag(Pruner):
def __init__(
-        self, net: nn.Module, device: torch.device, input_shape: List[int]
+        self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
"""Initialize."""
-        super(Mag, self).__init__(net, device, input_shape)
+        super(Mag, self).__init__(net, device, input_shape, dataloader, criterion)

self.params_to_prune = self.get_params(
(
4 changes: 2 additions & 2 deletions pruning_method/Rand.py
@@ -9,10 +9,10 @@

class Rand(Pruner):
def __init__(
-        self, net: nn.Module, device: torch.device, input_shape: List[int]
+        self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
"""Initialize."""
-        super(Rand, self).__init__(net, device, input_shape)
+        super(Rand, self).__init__(net, device, input_shape, dataloader, criterion)

self.params_to_prune = self.get_params(
(
72 changes: 72 additions & 0 deletions pruning_method/SNIP.py
@@ -0,0 +1,72 @@
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

from pruning_method.pruner import Pruner


class SNIP(Pruner):
def __init__(
self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
super(SNIP, self).__init__(net, device, input_shape, dataloader, criterion)

self.params_to_prune = self.get_params(
(
(nn.Conv2d, "weight"),
(nn.Conv2d, "bias"),
(nn.Linear, "weight"),
(nn.Linear, "bias"),
)
)
prune.global_unstructured(
self.params_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.0,
)
# https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
        # To get the gradient of each weight ("weight_orig" exists only after pruning at least once)
self.params_to_prune_orig = self.get_params(
(
(nn.Conv2d, "weight_orig"),
(nn.Conv2d, "bias_orig"),
(nn.Linear, "weight_orig"),
(nn.Linear, "bias_orig"),
)
)

    def prune(self, amount: float) -> None:
unit_amount = 1 - ((1 - amount) ** 0.01)
print(f"Start prune, target_sparsity: {amount*100:.2f}%")
for _ in range(100):
self.global_unstructured(
pruning_method=prune.L1Unstructured, amount=unit_amount
)
sparsity = self.mask_sparsity()
print(f"Pruning Done, sparsity: {sparsity:.2f}%")

    def get_prune_score(self) -> List[torch.Tensor]:
"""Run prune algorithm and get score."""

        # Compute the loss gradient on a single batch from the dataloader
        for data, target in self.dataloader:
data, target = data.to(self.device), target.to(self.device)
output = self.model(data)
self.criterion(output, target).backward()
break

        # SNIP saliency: the magnitude of the loss gradient of each parameter
        scores = []
        for (po, no) in self.params_to_prune_orig:
score = getattr(po, no).grad.to("cpu").detach().abs_()
scores.append(score)
getattr(po, no).grad.data.zero_()
return scores
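A minimal, hypothetical usage sketch of the new constructor signature (the synthetic DataLoader and model below are stand-ins for illustration; run.py in this commit is the actual wiring):

import torch
import torch.nn as nn
import torchvision

from pruning_method.SNIP import SNIP

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torchvision.models.resnet18(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()

# Stand-in loader yielding CIFAR-shaped (N, 3, 32, 32) batches with integer labels
dataset = torch.utils.data.TensorDataset(
    torch.randn(128, 3, 32, 32), torch.randint(0, 10, (128,))
)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

# input_shape is passed as HWC; the Pruner base reorders it to CHW
pruner = SNIP(net, device, [32, 32, 3], loader, criterion)
pruner.prune(amount=0.9)  # target 90% global sparsity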
4 changes: 2 additions & 2 deletions pruning_method/Synflow.py
@@ -9,9 +9,9 @@

class Synflow(Pruner):
def __init__(
-        self, net: nn.Module, device: torch.device, input_shape: List[int]
+        self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
-        super(Synflow, self).__init__(net, device, input_shape)
+        super(Synflow, self).__init__(net, device, input_shape, dataloader, criterion)

self.params_to_prune = self.get_params(
(
4 changes: 3 additions & 1 deletion pruning_method/pruner.py
@@ -10,7 +10,7 @@ class Pruner(ABC):
"""Pruner abstract class."""

def __init__(
-        self, net: nn.Module, device: torch.device, input_shape: List[int]
+        self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
"""Initialize."""
super(Pruner, self).__init__()
@@ -19,6 +19,8 @@ def __init__(
# need to be NCHW
self.input_shape = [input_shape[2], input_shape[0], input_shape[1]]
self.params_to_prune: Tuple[Tuple[nn.Module, str]] = None # type: ignore
+        self.dataloader = dataloader
+        self.criterion = criterion

@abstractmethod
def prune(self, amount):
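Both new pruners call get_params twice: once with "weight"/"bias" before the zero-amount global prune, and again with "weight_orig"/"bias_orig" afterwards, because torch.nn.utils.prune only creates the *_orig parameters once a module has been pruned. The helper itself lives elsewhere in pruner.py and is not part of this diff; the sketch below is only an assumption about what it returns, namely (module, parameter_name) pairs in the format prune.global_unstructured accepts:

from typing import List, Tuple

import torch.nn as nn


def get_params_sketch(
    model: nn.Module, patterns: Tuple[Tuple[type, str], ...]
) -> List[Tuple[nn.Module, str]]:
    """Hypothetical stand-in for Pruner.get_params (assumption, not the repo code)."""
    params = []
    for module in model.modules():
        for module_type, name in patterns:
            # Skip modules of other types, layers with bias=None, and
            # "*_orig" attributes that do not exist before the first prune call.
            if isinstance(module, module_type) and getattr(module, name, None) is not None:
                params.append((module, name))
    return params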
7 changes: 4 additions & 3 deletions run.py
@@ -68,17 +68,18 @@
__import__("torchvision.models", fromlist=[""]), args.model
)().to(device)

+# Train
+trainer = Trainer(net, trainloader, testloader, device, args.epoch)
+
# Apply prune
input_shape = list(trainloader.dataset.data.shape[1:])
if len(input_shape) == 2:
input_shape = input_shape + [3]
pruner = getattr(
__import__("pruning_method." + method, fromlist=[""]), method
-)(net, device, input_shape)
+)(net, device, input_shape, trainloader, trainer.criterion)
pruner.prune(amount=prune_amount)

-# Train
-trainer = Trainer(net, trainloader, testloader, device, args.epoch)
test_acc = trainer.train(args.epoch)

# Remove
