Skip to content

Commit

Permalink
add criterion, dataloader on pruner
Browse files Browse the repository at this point in the history
  • Loading branch information
hoonyyhoon committed May 11, 2024
1 parent 5b93d57 commit 456ab05
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 29 deletions.
4 changes: 2 additions & 2 deletions pruning_method/Mag.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

class Mag(Pruner):
def __init__(
self, net: nn.Module, device: torch.device, input_shape: List[int]
self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
"""Initialize."""
super(Mag, self).__init__(net, device, input_shape)
super(Mag, self).__init__(net, device, input_shape, dataloader, criterion)

self.params_to_prune = self.get_params(
(
Expand Down
4 changes: 2 additions & 2 deletions pruning_method/Rand.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

class Rand(Pruner):
def __init__(
self, net: nn.Module, device: torch.device, input_shape: List[int]
self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
"""Initialize."""
super(Rand, self).__init__(net, device, input_shape)
super(Rand, self).__init__(net, device, input_shape, dataloader, criterion)

self.params_to_prune = self.get_params(
(
Expand Down
73 changes: 73 additions & 0 deletions pruning_method/SNIP.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from tqdm import tqdm

from pruning_method.pruner import Pruner


class SNIP(Pruner):
    """SNIP pruner: scores each weight by the connection sensitivity
    |w * dL/dw| accumulated over the dataloader.

    Reference: Lee et al., "SNIP: Single-shot Network Pruning based on
    Connection Sensitivity" (ICLR 2019).
    """

    def __init__(
        self,
        net: nn.Module,
        device: torch.device,
        input_shape: List[int],
        dataloader: torch.utils.data.DataLoader,
        criterion,
    ) -> None:
        """Initialize and attach zero-amount pruning masks.

        Pruning with ``amount=0.0`` reparametrizes every tensor as
        ``weight = weight_orig * mask`` so that ``weight_orig`` (the leaf
        parameter carrying gradients) exists and can be collected below.
        https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
        """
        super(SNIP, self).__init__(net, device, input_shape, dataloader, criterion)

        # Weight/bias tensors of every Conv2d and Linear layer.
        self.params_to_prune = self.get_params(
            (
                (nn.Conv2d, "weight"),
                (nn.Conv2d, "bias"),
                (nn.Linear, "weight"),
                (nn.Linear, "bias"),
            )
        )
        prune.global_unstructured(
            self.params_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=0.0,
        )
        # NOTE(review): the original code also built params_to_prune_orig
        # once BEFORE the amount=0.0 pruning above; the ``*_orig``
        # attributes do not exist at that point, so that first assignment
        # was dead code and has been removed.
        self.params_to_prune_orig = self.get_params(
            (
                (nn.Conv2d, "weight_orig"),
                (nn.Conv2d, "bias_orig"),
                (nn.Linear, "weight_orig"),
                (nn.Linear, "bias_orig"),
            )
        )

    def prune(self, amount: float) -> None:
        """Prune towards a global sparsity of ``amount``.

        ``unit_amount`` is the per-call fraction such that applying it 100
        times compounds to ``amount`` — presumably this method is invoked
        repeatedly by the caller (TODO confirm against the driver loop).
        """
        unit_amount = 1 - ((1 - amount) ** 0.01)
        print(f"Start prune, target_sparsity: {amount*100:.2f}%")
        self.global_unstructured(
            pruning_method=prune.L1Unstructured, amount=unit_amount
        )
        sparsity = self.mask_sparsity()
        print(f"Pruning Done, sparsity: {sparsity:.2f}%")

    def get_prune_score(self) -> List[float]:
        """Run one pass over the dataloader and return per-tensor scores.

        Returns a list of CPU tensors ``|weight * grad(weight_orig)|``,
        one per entry of ``params_to_prune`` (in the same order).
        """
        self.model.train()
        # Robustness fix: clear any stale gradients left by prior
        # training so they do not leak into the sensitivity scores.
        self.model.zero_grad()
        with tqdm(self.dataloader, unit="batch") as iepoch:
            for inputs, labels in iepoch:
                data, target = inputs.to(self.device), labels.to(self.device)
                output = self.model(data)
                # Gradients intentionally accumulate across all batches.
                self.criterion(output, target).backward()

        scores = []
        for (p, n), (po, no) in zip(self.params_to_prune, self.params_to_prune_orig):
            grad = getattr(po, no).grad
            # Effective (masked) weight times gradient of the leaf param.
            score = (getattr(p, n) * grad).detach().abs_().to("cpu")
            scores.append(score)
            grad.zero_()  # modern replacement for deprecated .grad.data.zero_()
        self.model.zero_grad()
        return scores
4 changes: 2 additions & 2 deletions pruning_method/Synflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

class Synflow(Pruner):
def __init__(
self, net: nn.Module, device: torch.device, input_shape: List[int]
self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
super(Synflow, self).__init__(net, device, input_shape)
super(Synflow, self).__init__(net, device, input_shape, dataloader, criterion)

self.params_to_prune = self.get_params(
(
Expand Down
4 changes: 3 additions & 1 deletion pruning_method/pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Pruner(ABC):
"""Pruner abstract class."""

def __init__(
self, net: nn.Module, device: torch.device, input_shape: List[int]
self, net: nn.Module, device: torch.device, input_shape: List[int], dataloader: torch.utils.data.DataLoader, criterion
) -> None:
"""Initialize."""
super(Pruner, self).__init__()
Expand All @@ -19,6 +19,8 @@ def __init__(
# need to be NCHW
self.input_shape = [input_shape[2], input_shape[0], input_shape[1]]
self.params_to_prune: Tuple[Tuple[nn.Module, str]] = None # type: ignore
self.dataloader = dataloader
self.criterion = criterion

@abstractmethod
def prune(self, amount):
Expand Down
16 changes: 9 additions & 7 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
help="Dataset in torchvision.datasets ex) CIFAR10, CIFAR100, MNIST",
)
parser.add_argument(
"--batch_size", default=64, type=int, help="Batch size, default: 64"
"--batch_size", default=128, type=int, help="Batch size, default: 128"
)
parser.add_argument(
"--method_list",
Expand All @@ -37,15 +37,15 @@
parser.add_argument(
"--ratio_list",
nargs="+",
default=[0.9, 0.95],
default=[0.9, 0.95, 0.98],
type=float, # type: ignore
help="List of pruning ratio. ex) --ratio_list 0 0.5 0.9 0.95 0.99", # type: ignore
)
parser.add_argument(
"--epoch",
default=50,
default=160,
type=int,
help="Number of epochs to train, default: 50",
help="Number of epochs to train, default: 160",
)

args = parser.parse_args()
Expand All @@ -67,18 +67,20 @@
net = getattr(
__import__("torchvision.models", fromlist=[""]), args.model
)().to(device)
print(net.__class__)

# Train
trainer = Trainer(net, trainloader, testloader, device, args.epoch)

# Apply prune
input_shape = list(trainloader.dataset.data.shape[1:])
if len(input_shape) == 2:
input_shape = input_shape + [3]
pruner = getattr(
__import__("pruning_method." + method, fromlist=[""]), method
)(net, device, input_shape)
)(net, device, input_shape, trainloader, trainer.criterion)
pruner.prune(amount=prune_amount)

# Train
trainer = Trainer(net, trainloader, testloader, device, args.epoch)
test_acc = trainer.train(args.epoch)

# Remove
Expand Down
12 changes: 2 additions & 10 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,9 @@ def __init__(
# Train
self.criterion = nn.CrossEntropyLoss()
self.optimizer = optim.SGD(
self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5
)
self.scheduler = CosineAnnealingWarmupRestarts(
self.optimizer,
first_cycle_steps=epoch//4,
cycle_mult=1.0,
max_lr=0.1,
min_lr=0.0001,
warmup_steps=max(10, epoch//(4*4)),
gamma=0.5,
self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=5*1e-4
)
self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[epoch*3//8, epoch*6//8], gamma=0.2)

def train(self, epochs: int) -> float:
"""Train model, return best acc."""
Expand Down
14 changes: 9 additions & 5 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
def expand_to_rgb(x):
return x.repeat(3, 1, 1)

NORMALIZE_MAP = {
"mnist": [(0.1307,0.1307,0.1307), (0.3081,0.3081,0.3081)],
"cifar10": [(0.491, 0.482, 0.447), (0.247, 0.243, 0.262)],
"cifar100": [(0.507, 0.487, 0.441), (0.267, 0.256, 0.276)]
}

def get_dataloader(
dataset: str, batch_size: int
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
"""Sloppy dataloader."""
# hard-coded normalizing params
normalize = transforms.Normalize(
mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]
)
mean, std = NORMALIZE_MAP.get(dataset.lower(), [(0.507, 0.487, 0.441), (0.267, 0.256, 0.276)])
normalize = transforms.Normalize(mean=mean,std=std)

transform_train = transforms.Compose(
[
Expand Down Expand Up @@ -48,14 +52,14 @@ def get_dataloader(
root="datasets/", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=batch_size, shuffle=True, num_workers=4
trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True,
)

testset = getattr(__import__("torchvision.datasets", fromlist=[""]), dataset)(
root="datasets/", train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
testset, batch_size=batch_size, shuffle=False, num_workers=4
testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True,
)

return trainloader, testloader
Expand Down

0 comments on commit 456ab05

Please sign in to comment.