-
Notifications
You must be signed in to change notification settings - Fork 15
/
main.py
61 lines (48 loc) · 1.3 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import matplotlib as mpl
mpl.use('Agg')
import wfdb as wf
import numpy as np
from utils import storage as us
from utils import plotters as up
from datasets import mitdb as dm
from models import abstract as ma
from matplotlib import pyplot as plt
def _ensure_output_dir(path):
    """Create directory `path` (and any parents) if it does not exist yet."""
    import os
    # plt.savefig does not create missing directories. An explicit check is
    # used instead of exist_ok so this stays Python 2 compatible.
    if not os.path.isdir(path):
        os.makedirs(path)


def _save_prediction_plots(score, labels, limit=500, step=17):
    """Save overlaid prediction/label plots for every `step`-th sample.

    Each selected sample index `it` produces tmp/result-<it>.png containing
    `score[it]` and `labels[it]` drawn on the same axes.

    Args:
        score: Per-sample model outputs, indexable by sample (as returned by
            model.process in main).
        labels: Per-sample targets aligned with `score`.
        limit: Only sample indices below this bound are considered.
        step: Stride between plotted sample indices.
    """
    # Cap at the number of samples actually available so that a short batch
    # does not raise IndexError.
    count = min(limit, len(score), len(labels))
    for it in range(0, count, step):
        plt.plot(score[it])
        plt.plot(labels[it])
        plt.savefig('tmp/result-{}.png'.format(it))
        # Reset the figure so the next sample is not drawn on top of this one.
        plt.clf()


def _save_filter_plot(model, tensor_name='conv_out/weights:0'):
    """Plot the weights of `tensor_name` as a 1-D curve into tmp/filter.png."""
    filtr = model.get_tensor(tensor_name)
    # Flatten to 1-D regardless of the stored shape. The previous
    # reshape([len(filtr)]) failed unless all trailing dims were 1.
    # NOTE(review): assumes get_tensor returns a NumPy array; confirm.
    filtr = filtr.reshape(-1)
    plt.plot(filtr)
    plt.savefig('tmp/filter.png')
    plt.clf()


def main():
    """Train a ConvEncoder, report validation loss, and save diagnostic plots.

    Steps:
      1. Build the datasets via datasets.mitdb.create_datasets() and load
         them with utils.storage.Dataset('data').
      2. Train a ConvEncoder named 'foo' with params [200] for 500 epochs,
         batch size 32 and learning rates [0.003, 0.0007, 0.0001]. How the
         rates are scheduled is defined by ConvEncoder.set_learning_rates.
      3. Draw 1000 samples with dataset.test_batch, print the loss returned
         by model.consume, and save prediction-vs-label plots plus a plot of
         the 'conv_out/weights:0' tensor under tmp/.
    """
    # Create the output directory before the long training run, so that a
    # missing tmp/ cannot crash the script after all the work is done.
    _ensure_output_dir('tmp')

    dm.create_datasets()
    dataset = us.Dataset('data')
    params = [200]
    netname = 'foo'

    # Create model
    model = ma.ConvEncoder(netname, params)
    # How much data will it see
    model.set_epochs(500)
    model.set_batch_size(32)
    # Learning tactic
    lrs = [0.003, 0.0007, 0.0001]
    model.set_learning_rates(lrs)
    model.train(dataset)

    # Extract the validation dataset
    howmany = 1000
    si, la = dataset.test_batch(howmany)
    # Calculate loss
    loss = model.consume(si, la)
    # A single pre-formatted argument prints identically on Python 2 and 3.
    print('Validation loss {}'.format(loss))

    # Show some results
    score = model.process(si)
    _save_prediction_plots(score, la)

    # Also extract the final filter
    _save_filter_plot(model)
# Entry point: run the full training/evaluation pipeline only when this file
# is executed directly, never as a side effect of importing it.
if __name__ == '__main__':
    main()