-
Notifications
You must be signed in to change notification settings - Fork 5
/
model.py
358 lines (330 loc) · 21.8 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
import torch
import torch.nn as nn
class MultiConv(nn.Module):
    """Parallel "double conv" branches with growing kernel sizes, concatenated.

    One branch is built for every odd kernel size 3, 5, ..., ``max_kernel_size``
    (an even ``max_kernel_size`` behaves like the odd value just below it).
    Each branch is conv -> LeakyReLU(0.2) -> conv [-> InstanceNorm] ->
    LeakyReLU(0.2), using reflect padding of ``kernel_size // 2`` so the
    spatial size is preserved. Branch outputs are concatenated along dim 1,
    so the module emits ``out_channels * num_branches`` channels, optionally
    re-weighted by ``ChannelAttention``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by *each* branch.
        max_kernel_size: largest kernel size; must be >= 3.
        instance_norm: if true, insert an affine InstanceNorm2d after the
            second conv of every branch (that conv then has no bias).
        channel_attention: if true, apply ChannelAttention to the
            concatenated branch outputs.

    Raises:
        ValueError: if ``max_kernel_size`` < 3 (no branch would be built).
    """

    def __init__(self, in_channels, out_channels, max_kernel_size=3, instance_norm=False, channel_attention=True):
        super().__init__()
        if max_kernel_size < 3:
            raise ValueError(f"max_kernel_size must be >= 3, got {max_kernel_size}")
        # Branches are built in increasing kernel order, matching the
        # parameter-initialisation order and state_dict indices.
        self.multiconv_layers = nn.ModuleList(
            self._double_conv(in_channels, out_channels, kernel_size, instance_norm)
            for kernel_size in range(3, max_kernel_size + 1, 2)
        )
        num_branches = len(self.multiconv_layers)
        if channel_attention:
            # Bottleneck width = out_channels (total channels / branch count).
            self.channel_attention = ChannelAttention(out_channels * num_branches, reduction_ratio=num_branches)
        else:
            self.channel_attention = None

    @staticmethod
    def _double_conv(in_channels, out_channels, kernel_size, instance_norm):
        """Build one conv-LeakyReLU-conv[-InstanceNorm]-LeakyReLU branch."""
        padding = kernel_size // 2
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=True, padding=padding, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
            # A per-channel bias here would be cancelled by InstanceNorm's
            # mean subtraction, so it is only kept when no norm follows.
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, bias=not instance_norm, padding=padding, padding_mode='reflect'),
        ]
        if instance_norm:
            layers.append(nn.InstanceNorm2d(out_channels, affine=True))
        layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run every branch on ``x``, concatenate on dim 1, apply attention if enabled."""
        features = torch.cat([branch(x) for branch in self.multiconv_layers], 1)
        if self.channel_attention is None:
            return features
        return self.channel_attention(features)
class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel gating.

    The input is globally average-pooled to 1x1, passed through a 1x1-conv
    bottleneck (``channels -> channels // reduction_ratio -> channels``) with
    a ReLU in between, and the input is multiplied by the sigmoid of that
    result (one gate value per channel, broadcast over the spatial dims).
    """

    def __init__(self, channels, reduction_ratio=1):
        super(ChannelAttention, self).__init__()
        squeezed = channels // reduction_ratio
        # Attribute name and layer order are part of the checkpoint format:
        # parameters live under channel_attention.1.* and channel_attention.3.*.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(squeezed, channels, kernel_size=1),
        )

    def forward(self, x):
        """Scale each channel of ``x`` by its learned sigmoid gate."""
        gate = self.channel_attention(x).sigmoid()
        return x * gate
class PyNetCA(nn.Module):
    """PyNet-style multi-level image-to-image network with channel attention.

    Encoder: ``level1_conv1`` runs at input resolution; ``level2_conv1`` ..
    ``level5_conv1`` each max-pool by 2 first. Their outputs double as skip
    connections. Decoding starts at level 5 and walks up one level at a time
    (bilinear 2x upsampling); ``forward(x, level)`` stops after ``level`` and
    returns ``tanh`` of that level's output head. Level ``k`` (1..5) outputs
    at ``1 / 2**(k-1)`` of the input size; level 0 adds one more block plus
    ``PixelShuffle(2)``, so it outputs at twice the input size.

    Because of the four 2x max-pools and the matching upsamples, input height
    and width must be divisible by 16 for the skip concatenations to line up.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of every output head.
        hidden_channels: base feature width ``h``; levels 1..5 use multiples
            ``h*2``, ``h*4``, ``h*8``, ``h*16``, ``h*32``.
        instance_norm: enable InstanceNorm inside most MultiConv blocks.
        channel_attention: enable ChannelAttention in every MultiConv block.
    """

    # Valid pyramid levels for ``forward`` and ``get_parameters``.
    _LEVELS = (0, 1, 2, 3, 4, 5)

    def __init__(self, in_channels=4, out_channels=3, hidden_channels=16, instance_norm=True, channel_attention=True):
        super().__init__()
        h = hidden_channels
        norm = instance_norm
        ca = channel_attention
        up = self._upsample_block
        head = self._output_head
        down = self._downsample_block
        # NOTE: construction order and attribute names are kept exactly as
        # before, so seeded initialisation and state_dict keys are unchanged.

        # Level 0: extra block on top of level 1, then 2x PixelShuffle.
        self.level0_conv1 = MultiConv(h*2, h*2, 9, False, ca)  # -> h*8 channels
        self.level0_conv2 = nn.Conv2d(h*8, out_channels*2*2, 1)
        self.level0_up1 = nn.PixelShuffle(2)

        # Level 1: full input resolution.
        self.level1_conv1 = MultiConv(in_channels, h*2, 3, False, ca)
        self.level1_conv2 = MultiConv(h*4, h*2, 5, False, ca)
        self.level1_conv3 = MultiConv(h*6, h*2, 7, norm, ca)
        self.level1_conv4 = MultiConv(h*6, h*2, 9, norm, ca)
        self.level1_conv5 = MultiConv(h*8, h*2, 9, norm, ca)
        self.level1_conv6 = MultiConv(h*8, h*2, 9, norm, ca)
        self.level1_conv7 = MultiConv(h*8, h*2, 9, norm, ca)
        self.level1_conv8 = MultiConv(h*8, h*2, 7, norm, ca)
        self.level1_conv9 = MultiConv(h*8, h*2, 5, norm, ca)
        self.level1_conv10 = MultiConv(h*8, h*2, 3, False, ca)
        self.level1_conv11 = head(h*2, out_channels)
        self.level1_up1 = up(h*4, h*2)
        self.level1_up2 = up(h*4, h*2)

        # Level 2: 1/2 resolution.
        self.level2_conv1 = down(h*2, h*4, norm, ca)
        self.level2_conv2 = MultiConv(h*8, h*4, 5, norm, ca)
        self.level2_conv3 = MultiConv(h*12, h*4, 7, norm, ca)
        self.level2_conv4 = MultiConv(h*12, h*4, 7, norm, ca)
        self.level2_conv5 = MultiConv(h*12, h*4, 7, norm, ca)
        self.level2_conv6 = MultiConv(h*12, h*4, 7, norm, ca)
        self.level2_conv7 = MultiConv(h*16, h*4, 5, norm, ca)
        self.level2_conv8 = MultiConv(h*12, h*4, 3, norm, ca)
        self.level2_conv9 = head(h*4, out_channels)
        self.level2_up1 = up(h*8, h*4)
        self.level2_up2 = up(h*8, h*4)

        # Level 3: 1/4 resolution.
        self.level3_conv1 = down(h*4, h*8, norm, ca)
        self.level3_conv2 = MultiConv(h*16, h*8, 5, norm, ca)
        self.level3_conv3 = MultiConv(h*16, h*8, 5, norm, ca)
        self.level3_conv4 = MultiConv(h*16, h*8, 5, norm, ca)
        self.level3_conv5 = MultiConv(h*16, h*8, 5, norm, ca)
        self.level3_conv6 = MultiConv(h*32, h*8, 3, norm, ca)
        self.level3_conv7 = head(h*8, out_channels)
        self.level3_up1 = up(h*16, h*8)
        self.level3_up2 = up(h*16, h*8)

        # Level 4: 1/8 resolution.
        self.level4_conv1 = down(h*8, h*16, norm, ca)
        self.level4_conv2 = MultiConv(h*32, h*16, 3, norm, ca)
        self.level4_conv3 = MultiConv(h*16, h*16, 3, norm, ca)
        self.level4_conv4 = MultiConv(h*16, h*16, 3, norm, ca)
        self.level4_conv5 = MultiConv(h*16, h*16, 3, norm, ca)
        self.level4_conv6 = MultiConv(h*32, h*16, 3, norm, ca)
        self.level4_conv7 = head(h*16, out_channels)
        self.level4_up1 = up(h*32, h*16)
        self.level4_up2 = up(h*32, h*16)

        # Level 5: 1/16 resolution (coarsest).
        self.level5_conv1 = down(h*16, h*32, norm, ca)
        self.level5_conv2 = MultiConv(h*32, h*32, 3, norm, ca)
        self.level5_conv3 = MultiConv(h*32, h*32, 3, norm, ca)
        self.level5_conv4 = MultiConv(h*32, h*32, 3, norm, ca)
        self.level5_conv5 = head(h*32, out_channels)

    @staticmethod
    def _upsample_block(in_channels, out_channels):
        """2x bilinear upsample, then 3x3 reflect-padded conv and LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, 3, padding=3//2, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )

    @staticmethod
    def _output_head(in_channels, out_channels):
        """3x3 reflect-padded conv projecting features to the output channels."""
        return nn.Conv2d(in_channels, out_channels, 3, padding=3//2, padding_mode='reflect')

    @staticmethod
    def _downsample_block(in_channels, out_channels, instance_norm, channel_attention):
        """2x2 max-pool followed by a single-branch (3x3) MultiConv."""
        return nn.Sequential(
            nn.MaxPool2d(2, 2),
            MultiConv(in_channels, out_channels, 3, instance_norm, channel_attention),
        )

    def get_parameters(self, level=None):
        """Return the parameters to optimise when training ``level``.

        ``level=None`` returns ``self.parameters()``. Otherwise a list is
        returned with every parameter whose top-level submodule either
        belongs to ``level`` (name starts with ``level{level}_``) or is one of
        the ``levelN_conv1`` blocks.

        Raises:
            ValueError: if ``level`` is neither None nor one of 0..5.
        """
        if level is None:
            return self.parameters()
        if level not in self._LEVELS:
            raise ValueError(f'level must be None or one of {self._LEVELS}, got {level!r}')
        prefix = f'level{level}_'
        params = []
        for name, param in self.named_parameters():
            module_name = name.split('.', 1)[0]
            # Compare whole module names: a bare substring test for 'conv1'
            # would also select level1_conv10 and level1_conv11.
            if module_name.startswith(prefix) or module_name.endswith('_conv1'):
                params.append(param)
        return params

    def forward(self, x, level):
        """Run the pyramid down to ``level`` and return ``tanh`` of its head.

        Raises:
            ValueError: if ``level`` is not one of 0..5.
        """
        if level not in self._LEVELS:
            raise ValueError(f'level must be one of {self._LEVELS}, got {level!r}')

        # Encoder: each levelN_conv1 output is kept as a skip connection.
        x1 = self.level1_conv1(x)
        x2 = self.level2_conv1(x1)
        x3 = self.level3_conv1(x2)
        x4 = self.level4_conv1(x3)

        # Level 5 (always computed).
        x = self.level5_conv1(x4)
        x = self.level5_conv2(x)
        x = self.level5_conv3(x)
        x = self.level5_conv4(x)
        y5 = x

        if level <= 4:
            x = torch.cat([x4, self.level4_up1(y5)], 1)
            x = self.level4_conv2(x)
            x = self.level4_conv3(x) + x
            x = self.level4_conv4(x) + x
            x = self.level4_conv5(x)
            x = torch.cat([x, self.level4_up2(y5)], 1)
            x = self.level4_conv6(x)
            y4 = x

        if level <= 3:
            x = torch.cat([x3, self.level3_up1(y4)], 1)
            x = self.level3_conv2(x) + x
            x = self.level3_conv3(x) + x
            x = self.level3_conv4(x) + x
            x = self.level3_conv5(x)
            x = torch.cat([x, x3, self.level3_up2(y4)], 1)
            x = self.level3_conv6(x)
            y3 = x

        if level <= 2:
            x = torch.cat([x2, self.level2_up1(y3)], 1)
            x = self.level2_conv2(x)
            x = torch.cat([x, x2], 1)
            x = self.level2_conv3(x) + x
            x = self.level2_conv4(x) + x
            x = self.level2_conv5(x) + x
            x = self.level2_conv6(x)
            x = torch.cat([x, x2], 1)
            x = self.level2_conv7(x)
            x = torch.cat([x, self.level2_up2(y3)], 1)
            x = self.level2_conv8(x)
            y2 = x

        if level <= 1:
            x = torch.cat([x1, self.level1_up1(y2)], 1)
            x = self.level1_conv2(x)
            x = torch.cat([x, x1], 1)
            x = self.level1_conv3(x)
            x = self.level1_conv4(x)
            x = self.level1_conv5(x) + x
            x = self.level1_conv6(x) + x
            x = self.level1_conv7(x) + x
            x = self.level1_conv8(x)
            x = torch.cat([x, x1], 1)
            x = self.level1_conv9(x)
            x = torch.cat([x, self.level1_up2(y2), x1], 1)
            x = self.level1_conv10(x)

        # Output head of the requested level.
        if level == 0:
            x = self.level0_conv1(x)
            x = self.level0_conv2(x)
            x = self.level0_up1(x)
        elif level == 1:
            x = self.level1_conv11(x)
        elif level == 2:
            x = self.level2_conv9(x)
        elif level == 3:
            x = self.level3_conv7(x)
        elif level == 4:
            x = self.level4_conv7(x)
        else:  # level == 5
            x = self.level5_conv5(x)
        return torch.tanh(x)
def count_parameters(model):
    """Return how many scalar values ``model`` has in parameters with ``requires_grad`` set.

    Frozen parameters are excluded; an empty model yields 0.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
### Baseline PyNet model is implemented but not used ###
# class PyNet(nn.Module):
# def __init__(self, in_channels=4, out_channels=3, hidden_channels=16, instance_norm=True, channel_attention=True):
# super(PyNet, self).__init__()
# self.level0_up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(hidden_channels*2, hidden_channels, 3, padding=3//2, padding_mode='reflect'), nn.LeakyReLU(0.2))
# self.level0_conv1 = nn.Conv2d(hidden_channels, out_channels, 3, padding=3//2, padding_mode='reflect')
#
# self.level1_conv1 = MultiConv(in_channels, hidden_channels*2, 3, False, channel_attention)
# self.level1_conv2 = MultiConv(hidden_channels*4, hidden_channels*2, 5, False, channel_attention)
# self.level1_conv3 = MultiConv(hidden_channels*6, hidden_channels*2, 7, instance_norm, channel_attention)
# self.level1_conv4 = MultiConv(hidden_channels*6, hidden_channels*2, 9, instance_norm, channel_attention)
# self.level1_conv5 = MultiConv(hidden_channels*8, hidden_channels*2, 9, instance_norm, channel_attention)
# self.level1_conv6 = MultiConv(hidden_channels*8, hidden_channels*2, 9, instance_norm, channel_attention)
# self.level1_conv7 = MultiConv(hidden_channels*8, hidden_channels*2, 9, instance_norm, channel_attention)
# self.level1_conv8 = MultiConv(hidden_channels*8, hidden_channels*2, 7, instance_norm, channel_attention)
# self.level1_conv9 = MultiConv(hidden_channels*8, hidden_channels*2, 5, instance_norm, channel_attention)
# self.level1_conv10 = MultiConv(hidden_channels*8, hidden_channels*2, 3, False, channel_attention)
# self.level1_conv11 = nn.Conv2d(hidden_channels*2, out_channels, 3, padding=3//2, padding_mode='reflect')
# self.level1_up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(hidden_channels*4, hidden_channels*2, 3, padding=3//2, padding_mode='reflect'), nn.LeakyReLU(0.2))
# self.level1_up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(hidden_channels*4, hidden_channels*2, 3, padding=3//2, padding_mode='reflect'), nn.LeakyReLU(0.2))
#
# self.level2_conv1 = nn.Sequential(nn.MaxPool2d(2, 2),MultiConv(hidden_channels*2, hidden_channels*4, 3, instance_norm, channel_attention))
# self.level2_conv2 = MultiConv(hidden_channels*8, hidden_channels*4, 5, instance_norm, channel_attention)
# self.level2_conv3 = MultiConv(hidden_channels*12, hidden_channels*4, 7, instance_norm, channel_attention)
# self.level2_conv4 = MultiConv(hidden_channels*12, hidden_channels*4, 7, instance_norm, channel_attention)
# self.level2_conv5 = MultiConv(hidden_channels*12, hidden_channels*4, 7, instance_norm, channel_attention)
# self.level2_conv6 = MultiConv(hidden_channels*12, hidden_channels*4, 7, instance_norm, channel_attention)
# self.level2_conv7 = MultiConv(hidden_channels*16, hidden_channels*4, 5, instance_norm, channel_attention)
# self.level2_conv8 = MultiConv(hidden_channels*12, hidden_channels*4, 3, instance_norm, channel_attention)
# self.level2_conv9 = nn.Conv2d(hidden_channels*4, out_channels, 3, padding=3//2, padding_mode='reflect')
# self.level2_up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(hidden_channels*8, hidden_channels*4, 3, padding=3//2, padding_mode='reflect'), nn.LeakyReLU(0.2))
# self.level2_up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(hidden_channels*8, hidden_channels*4, 3, padding=3//2, padding_mode='reflect'), nn.LeakyReLU(0.2))
#
# self.level3_conv1 = nn.Sequential(nn.MaxPool2d(2, 2), MultiConv(hidden_channels*4, hidden_channels*8, 3, instance_norm, channel_attention))
# self.level3_conv2 = MultiConv(hidden_channels*16, hidden_channels*8, 5, instance_norm, channel_attention)
# self.level3_conv3 = MultiConv(hidden_channels*16, hidden_channels*8, 5, instance_norm, channel_attention)
# self.level3_conv4 = MultiConv(hidden_channels*16, hidden_channels*8, 5, instance_norm, channel_attention)
# self.level3_conv5 = MultiConv(hidden_channels*16, hidden_channels*8, 5, instance_norm, channel_attention)
# self.level3_conv6 = MultiConv(hidden_channels*32, hidden_channels*8, 3, instance_norm, channel_attention)
# self.level3_conv7 = nn.Conv2d(hidden_channels*8, out_channels, 3, padding=3//2, padding_mode='reflect')
# self.level3_up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(hidden_channels*16, hidden_channels*8, 3, padding=3//2, padding_mode='reflect'), nn.LeakyReLU(0.2))
# self.level3_up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(hidden_channels*16, hidden_channels*8, 3, padding=3//2, padding_mode='reflect'), nn.LeakyReLU(0.2))
#
# self.level4_conv1 = nn.Sequential(nn.MaxPool2d(2, 2), MultiConv(hidden_channels*8, hidden_channels*16, 3, instance_norm, channel_attention))
# self.level4_conv2 = MultiConv(hidden_channels*32, hidden_channels*16, 3, instance_norm, channel_attention)
# self.level4_conv3 = MultiConv(hidden_channels*16, hidden_channels*16, 3, instance_norm, channel_attention)
# self.level4_conv4 = MultiConv(hidden_channels*16, hidden_channels*16, 3, instance_norm, channel_attention)
# self.level4_conv5 = MultiConv(hidden_channels*16, hidden_channels*16, 3, instance_norm, channel_attention)
# self.level4_conv6 = MultiConv(hidden_channels*32, hidden_channels*16, 3, instance_norm, channel_attention)
# self.level4_conv7 = nn.Conv2d(hidden_channels*16, out_channels, 3, padding=3//2, padding_mode='reflect')
# self.level4_up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(hidden_channels*32, hidden_channels*16, 3, padding=3//2, padding_mode='reflect'), nn.LeakyReLU(0.2))
# self.level4_up2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(hidden_channels*32, hidden_channels*16, 3, padding=3//2, padding_mode='reflect'), nn.LeakyReLU(0.2))
#
# self.level5_conv1 = nn.Sequential(nn.MaxPool2d(2, 2), MultiConv(hidden_channels*16, hidden_channels*32, 3, instance_norm, channel_attention))
# self.level5_conv2 = MultiConv(hidden_channels*32, hidden_channels*32, 3, instance_norm, channel_attention)
# self.level5_conv3 = MultiConv(hidden_channels*32, hidden_channels*32, 3, instance_norm, channel_attention)
# self.level5_conv4 = MultiConv(hidden_channels*32, hidden_channels*32, 3, instance_norm, channel_attention)
# self.level5_conv5 = nn.Conv2d(hidden_channels*32, out_channels, 3, padding=3//2, padding_mode='reflect')
#
#
# def get_parameters(self, level=None):
# assert level in [None, 0, 1, 2, 3, 4, 5]
# if level is None:
# return self.parameters()
# else:
# params = []
# for name, param in self.named_parameters():
# if (f'level{level}' in name) or ('conv1' in name):
# params.append(param)
# return params
#
# def count_parameters(self):
# return count_parameters(self)
#
#
# def forward(self, x, level):
# if level<=5:
# x = self.level1_conv1(x)
# x1 = x
# x = self.level2_conv1(x)
# x2 = x
# x = self.level3_conv1(x)
# x3 = x
# x = self.level4_conv1(x)
# x4 = x
# x = self.level5_conv1(x)
# x = self.level5_conv2(x)
# x = self.level5_conv3(x)
# x = self.level5_conv4(x)
# y5 = x
#
# if level<=4:
# x = torch.cat([x4, self.level4_up1(y5)], 1)
# x = self.level4_conv2(x)
# x = self.level4_conv3(x) + x
# x = self.level4_conv4(x) + x
# x = self.level4_conv5(x)
# x = torch.cat([x, self.level4_up2(y5)], 1)
# x = self.level4_conv6(x)
# y4 = x
#
# if level<=3:
# x = torch.cat([x3, self.level3_up1(y4)], 1)
# x = self.level3_conv2(x) + x
# x = self.level3_conv3(x) + x
# x = self.level3_conv4(x) + x
# x = self.level3_conv5(x)
# x = torch.cat([x, x3, self.level3_up2(y4)], 1)
# x = self.level3_conv6(x)
# y3 = x
#
# if level<=2:
# x = torch.cat([x2, self.level2_up1(y3)], 1)
# x = self.level2_conv2(x)
# x = torch.cat([x, x2], 1)
# x = self.level2_conv3(x) + x
# x = self.level2_conv4(x) + x
# x = self.level2_conv5(x) + x
# x = self.level2_conv6(x)
# x = torch.cat([x, x2], 1)
# x = self.level2_conv7(x)
# x = torch.cat([x, self.level2_up2(y3)], 1)
# x = self.level2_conv8(x)
# y2 = x
#
# if level<=1:
# x = torch.cat([x1, self.level1_up1(y2)], 1)
# x = self.level1_conv2(x)
# x = torch.cat([x, x1], 1)
# x = self.level1_conv3(x)
# x = self.level1_conv4(x)
# x = self.level1_conv5(x) + x
# x = self.level1_conv6(x) + x
# x = self.level1_conv7(x) + x
# x = self.level1_conv8(x)
# x = torch.cat([x, x1], 1)
# x = self.level1_conv9(x)
# x = torch.cat([x, self.level1_up2(y2), x1], 1)
# x = self.level1_conv10(x)
#
# if level==0:
# x = self.level0_up1(x)
# x = self.level0_conv1(x)
# elif level==1:
# x = self.level1_conv11(x)
# elif level==2:
# x = self.level2_conv9(x)
# elif level==3:
# x = self.level3_conv7(x)
# elif level==4:
# x = self.level4_conv7(x)
# elif level==5:
# x = self.level5_conv5(x)
# return torch.tanh(x)