-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
63 lines (45 loc) · 1.75 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class AgeGenderModel(nn.Module):
    """ResNet-18 backbone with a gender head and an age head.

    The pooled backbone features (512 values per sample) feed two independent
    branches:

    * gender: ``fc2`` -> ReLU -> ``gen_cls_pred`` (2 logits)
    * age: ``fc1`` -> ReLU -> ``age_cls_pred`` (10 logits), followed by
      ``age_reg_pred`` which maps those 10 logits to a single value.

    The backbone's own ``fc`` layer is left in place (so parameter names stay
    stable for checkpoints) but is never called by this model.
    """

    def __init__(self, pretrained: bool = True):
        """Build the backbone and both prediction heads.

        Args:
            pretrained: When true (the default, matching the previous
                hard-coded behaviour) the backbone is created with
                torchvision's pretrained ResNet-18 weights. Pass ``False`` to
                skip loading them, e.g. when a checkpoint is restored right
                after construction or no network access is available.
        """
        super().__init__()

        # Attribute names and creation order are kept unchanged so existing
        # state dicts load and seeded initialisation stays reproducible.
        self.resNet = self._build_backbone(pretrained)

        self.fc1 = nn.Linear(512, 512)
        self.age_cls_pred = nn.Linear(512, 10)
        self.age_reg_pred = nn.Linear(10, 1)

        self.fc2 = nn.Linear(512, 512)
        self.gen_cls_pred = nn.Linear(512, 2)

        self.dropout = nn.Dropout(0.5)

    @staticmethod
    def _build_backbone(pretrained):
        """Create the ResNet-18 backbone, optionally with pretrained weights."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only understand the boolean flag.
            return models.resnet18(pretrained=bool(pretrained))
        # Newer torchvision deprecates ``pretrained=`` in favour of
        # ``weights=``; IMAGENET1K_V1 is the weight set the legacy flag loads.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet18(weights=weights)

    def get_resnet_convs_out(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone up to and including ``layer4``.

        The backbone's average pooling and ``fc`` layer are not applied here.

        Args:
            x: Input batch passed to the backbone's first convolution.

        Returns:
            The output of the backbone's ``layer4``.
        """
        stages = (
            "conv1", "bn1", "relu", "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        for stage in stages:
            x = getattr(self.resNet, stage)(x)
        return x

    def get_age_gender(self, last_conv_out: torch.Tensor):
        """Compute gender and age predictions from backbone feature maps.

        Args:
            last_conv_out: Feature maps as returned by
                ``get_resnet_convs_out``.

        Returns:
            A tuple ``(gen_pred, age_cls_pred, age_reg_pred)`` with shapes
            ``(batch, 2)``, ``(batch, 10)`` and ``(batch, 1)``.
        """
        pooled = self.resNet.avgpool(last_conv_out)
        # Collapse everything but the batch dimension; the heads expect 512
        # features per sample.
        features = torch.flatten(pooled, 1)
        # One dropout mask is shared by both heads.
        features = self.dropout(features)

        age_hidden = F.relu(self.fc1(features))
        age_cls_pred = self.age_cls_pred(age_hidden)
        # The regression output is a linear function of the raw class logits.
        age_reg_pred = self.age_reg_pred(age_cls_pred)

        gen_hidden = F.relu(self.fc2(features))
        gen_pred = self.gen_cls_pred(gen_hidden)

        return gen_pred, age_cls_pred, age_reg_pred

    def forward(self, x: torch.Tensor):
        """Return ``(gen_pred, age_cls_pred, age_reg_pred)`` for batch ``x``."""
        return self.get_age_gender(self.get_resnet_convs_out(x))
if __name__ == '__main__':
    # Smoke test: push one all-zero 112x112 RGB image through the model and
    # check that every head returns the expected shape before reporting success.
    model = AgeGenderModel()
    model.eval()  # disable dropout and keep batch-norm statistics untouched
    with torch.no_grad():
        gen_pred, age_cls_pred, age_reg_pred = model(torch.zeros((1, 3, 112, 112)))

    actual_shapes = [tuple(t.shape) for t in (gen_pred, age_cls_pred, age_reg_pred)]
    expected_shapes = [(1, 2), (1, 10), (1, 1)]
    if actual_shapes != expected_shapes:
        raise RuntimeError(
            'Unexpected output shapes: {} (expected {})'.format(
                actual_shapes, expected_shapes
            )
        )
    print('All good')