-
Notifications
You must be signed in to change notification settings - Fork 4
/
ImageFeature.py
34 lines (27 loc) · 926 Bytes
/
ImageFeature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import numpy as np
import matplotlib.pyplot as plt
import LoadData
class ExtractImageFeature(torch.nn.Module):
    """Project per-region image features and mean-pool them over regions.

    A single shared ``torch.nn.Linear`` layer (2048 -> 1024 by default)
    followed by ReLU is applied to every region vector. The module returns
    both the mean over regions and the full projected region sequence.
    """

    def __init__(self, in_features=2048, out_features=1024):
        """Create the shared projection layer.

        Args:
            in_features: Size of each input region vector (default 2048).
            out_features: Size of each projected region vector (default 1024).
        """
        super(ExtractImageFeature, self).__init__()
        # The attribute name stays ``Linear`` so existing checkpoints
        # ("Linear.weight" / "Linear.bias" state_dict keys) still load.
        self.Linear = torch.nn.Linear(in_features, out_features)

    def forward(self, input):
        """Project every region and average the results.

        Args:
            input: 3-D tensor laid out as (batch, regions, in_features). The
                previous implementation only accepted exactly 196 regions. It
                silently dropped extras and raised IndexError on fewer. Any
                region count is now processed in full.

        Returns:
            A tuple ``(mean, output)``:
            ``output`` has shape (regions, batch, out_features) and holds
            ``relu(Linear(region))`` for every region, region-major.
            ``mean`` is ``output`` averaged over the region axis, with shape
            (batch, out_features).
        """
        # nn.Linear acts on the last dimension and broadcasts over the leading
        # ones, so a single call replaces the former per-region Python loop.
        projected = torch.nn.functional.relu(self.Linear(input))
        # Region-major layout as before. .contiguous() keeps the memory layout
        # that torch.stack used to produce, so downstream .view() calls still
        # work.
        output = projected.permute(1, 0, 2).contiguous()
        mean = torch.mean(output, 0)
        return mean, output
if __name__ == "__main__":
    # Smoke test: run the extractor on one training batch and print shapes.
    extractor = ExtractImageFeature()
    # Inference only: no_grad avoids building an autograd graph that is never
    # used. The printed shapes are unaffected.
    with torch.no_grad():
        # Each batch unpacks into 5 fields; only image_feature is used here.
        # The others are discarded as ``_`` (the original bound one of them to
        # ``id``, shadowing the builtin).
        for _, image_feature, _, _, _ in LoadData.train_loader:
            result, seq = extractor(image_feature)
            # original note: [2, 1024]
            print(result.shape)
            # original note: [196, 2, 1024]
            print(seq.shape)
            break