forked from ChinaYi/ASFormer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_efficient.py
139 lines (123 loc) · 5.45 KB
/
train_efficient.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from torchvision.models import efficientnet_v2_s, efficientnet_b0, EfficientNet_B0_Weights
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from data.dataset_maker import CustomImageDataset
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import wandb
import argparse
# Training function.
def train(model, trainloader, optimizer, criterion, device=None):
    """Run one training epoch over ``trainloader``.

    Args:
        model: network to train (expected to already be on the target device).
        trainloader: DataLoader yielding ``(image, labels)`` batches.
        optimizer: optimizer updating ``model``'s parameters.
        criterion: loss function, e.g. ``nn.CrossEntropyLoss``.
        device: device to move each batch to. When ``None`` it is inferred
            from the model's parameters — backward compatible with the old
            behaviour of reading the module-level ``device`` global.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-batch loss and accuracy (in
        percent) over the whole training set.
    """
    if device is None:
        # Infer the device instead of silently depending on a global.
        device = next(model.parameters()).device
    model.train()
    print('Training')
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):
        counter += 1
        image, labels = data
        image = image.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Forward pass.
        outputs = model(image)
        # Calculate the loss.
        loss = criterion(outputs, labels)
        train_running_loss += loss.item()
        # Calculate the accuracy (argmax over class logits).
        # Note: ``torch.max(outputs, 1)`` replaces the deprecated
        # ``outputs.data`` access; identical result for argmax indices.
        _, preds = torch.max(outputs, 1)
        train_running_correct += (preds == labels).sum().item()
        # Backpropagation.
        loss.backward()
        # Update the weights.
        optimizer.step()
        # Per-batch logging (per-epoch aggregates are logged by the caller).
        wandb.log({"train_loss": loss.item(), "train_accuracy": float((preds == labels).sum().item()) / len(preds)})
    # Loss and accuracy for the complete epoch.
    epoch_loss = train_running_loss / counter
    epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
    return epoch_loss, epoch_acc
# Validation function.
def validate(model, testloader, criterion, device=None):
    """Evaluate ``model`` with one pass over the validation set.

    Runs under ``torch.no_grad()`` with the model in eval mode.

    Args:
        model: network to evaluate.
        testloader: DataLoader yielding ``(image, labels)`` batches.
        criterion: loss function, e.g. ``nn.CrossEntropyLoss``.
        device: device to move each batch to. When ``None`` it is inferred
            from the model's parameters — backward compatible with the old
            behaviour of reading the module-level ``device`` global.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-batch loss and accuracy (in
        percent) over the whole validation set.
    """
    if device is None:
        # Infer the device instead of silently depending on a global.
        device = next(model.parameters()).device
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    counter = 0
    with torch.no_grad():
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            counter += 1
            image, labels = data
            image = image.to(device)
            labels = labels.to(device)
            # Forward pass.
            outputs = model(image)
            # Calculate the loss.
            loss = criterion(outputs, labels)
            valid_running_loss += loss.item()
            # Calculate the accuracy (argmax over class logits).
            _, preds = torch.max(outputs, 1)
            valid_running_correct += (preds == labels).sum().item()
    # Loss and accuracy for the complete epoch.
    epoch_loss = valid_running_loss / counter
    epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
    return epoch_loss, epoch_acc
def run_exp(FOLD):
    """Train an EfficientNet-B0 classifier on one cross-validation fold.

    Args:
        FOLD: fold identifier used to select the train/val CSV index files
            under ``fold_indexes/`` and to name the wandb run and the
            checkpoint files under ``efficient_models/``.

    Side effects:
        Logs metrics to wandb, and saves the model/optimizer state dicts for
        the best validation accuracy seen so far to ``efficient_models/``.
        Reads the module-level globals ``device``, ``lr`` and ``epochs``.
    """
    # NOTE(review): the original code assigned a local ``WANDB_START_METHOD``
    # variable, which had no effect (removed). To actually change the start
    # method, set the ``WANDB_START_METHOD`` environment variable before
    # ``wandb.init`` or pass it via ``wandb.Settings``.
    wandb.init(project="CVSA_FINAL", entity="tandl", name=f"EFFICIENT_SPLIT_{FOLD}", save_code=True)
    # Pretrained backbone; replace the final classifier layer so it outputs
    # the 18 target classes (EfficientNet-B0's penultimate width is 1280).
    model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
    model.classifier[1] = nn.Linear(1280, 18, bias=True)
    # Load the training and validation datasets for this fold.
    dataset_train = CustomImageDataset(f'fold_indexes/fold{FOLD}_train.csv', n_samples=48000)
    dataset_valid = CustomImageDataset(f'fold_indexes/fold{FOLD}_val.csv', n_samples=10000)
    train_loader = DataLoader(dataset_train, batch_size=16, shuffle=True)
    valid_loader = DataLoader(dataset_valid, batch_size=16, shuffle=False)
    print(f"[INFO]: Number of training images: {len(dataset_train)}")
    print(f"[INFO]: Number of validation images: {len(dataset_valid)}")
    model = model.to(device)
    # Optimizer.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Loss function.
    criterion = nn.CrossEntropyLoss()
    # Lists to keep track of per-epoch losses and accuracies.
    train_loss, valid_loss = [], []
    train_acc, valid_acc = [], []
    # Start the training; checkpoint whenever validation accuracy improves.
    best_val_accuracy = 0
    for epoch in range(epochs):
        print(f"[INFO]: Epoch {epoch+1} of {epochs}")
        train_epoch_loss, train_epoch_acc = train(model, train_loader,
                                                  optimizer, criterion)
        valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader,
                                                     criterion)
        train_loss.append(train_epoch_loss)
        valid_loss.append(valid_epoch_loss)
        train_acc.append(train_epoch_acc)
        valid_acc.append(valid_epoch_acc)
        wandb.log({"Training Loss": train_epoch_loss,
                   "Training Accur.": train_epoch_acc,
                   "Valid. Loss": valid_epoch_loss,
                   "Valid. Accur.": valid_epoch_acc})
        print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
        print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
        print('-'*50)
        if valid_epoch_acc > best_val_accuracy:
            best_val_accuracy = valid_epoch_acc
            # Save the trained model weights (best validation accuracy so far).
            torch.save(model.state_dict(), f'efficient_models/fold{FOLD}_model.pkl')
            torch.save(optimizer.state_dict(), f'efficient_models/fold{FOLD}_optimizer.pkl')
    print('TRAINING COMPLETE')
# Hyper-parameters: kept at module level because train/validate/run_exp read
# them as globals.
lr = 0.001
epochs = 1
device = ('cuda' if torch.cuda.is_available() else 'cpu')

# Guard the entry point so importing this module (e.g. to reuse train/validate)
# no longer parses argv and launches a training run as a side effect.
if __name__ == '__main__':
    print(f"Computation device: {device}")
    print(f"Learning rate: {lr}")
    print(f"Epochs to train for: {epochs}\n")
    parser = argparse.ArgumentParser()
    # Fold identifier; kept as a string because it is only interpolated into
    # file names and the wandb run name.
    parser.add_argument('--FOLD', default='0')
    args = parser.parse_args()
    run_exp(args.FOLD)