-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
24 lines (20 loc) · 853 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os

import numpy as np
from tflearn.data_utils import image_preloader

import deepneuralnet as net
# Train the network exposed as ``deepneuralnet.model`` on two image-folder
# splits (training and validation) and save the resulting model.
#
# Each split is read with tflearn's ``image_preloader`` in ``mode='folder'``,
# i.e. one sub-directory per class under TRAIN_DIR / VALIDATE_DIR.

# Images are resized to IMAGE_SIDE x IMAGE_SIDE (square, so width/height
# ordering does not matter).
IMAGE_SIDE = 100
# grayscale=False below means colour images, reshaped with 3 channels.
CHANNELS = 3
TRAIN_DIR = './train'
VALIDATE_DIR = './validate'
N_EPOCH = 50
MODEL_PATH = './ZtrainedNet/final-model.tfl'


def _load_split(target_path):
    """Load one image-folder split and return ``(images, labels)``.

    Images are resized to IMAGE_SIDE x IMAGE_SIDE, loaded in colour, scaled
    via ``normalize=True``, and reshaped to
    ``(n_samples, IMAGE_SIDE, IMAGE_SIDE, CHANNELS)``. Labels are
    requested one-hot (``categorical_labels=True``) and returned as produced
    by ``image_preloader``.
    """
    images, labels = image_preloader(target_path=target_path,
                                     image_shape=(IMAGE_SIDE, IMAGE_SIDE),
                                     mode='folder',
                                     grayscale=False,
                                     categorical_labels=True,
                                     normalize=True)
    images = np.reshape(images, (-1, IMAGE_SIDE, IMAGE_SIDE, CHANNELS))
    return images, labels


model = net.model
X, Y = _load_split(TRAIN_DIR)
# NOTE(review): each split derives its class list from its own
# sub-directories, so presumably both folders must contain the same class
# directories for the one-hot encodings to line up - confirm.
W, Z = _load_split(VALIDATE_DIR)
model.fit(X, Y, n_epoch=N_EPOCH, validation_set=(W, Z), show_metric=True)

# TensorFlow's Saver refuses to write into a missing parent directory, so
# create the output folder up front rather than failing after training.
_save_dir = os.path.dirname(MODEL_PATH)
if _save_dir and not os.path.isdir(_save_dir):
    os.makedirs(_save_dir)
model.save(MODEL_PATH)