-
Notifications
You must be signed in to change notification settings - Fork 4
/
run.py
64 lines (56 loc) · 1.98 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import keras_segmentation
from keras.models import load_model
import tensorflow as tf
import os
# Choose which GPUs CUDA exposes to this process: enumerate devices by PCI bus
# ID and make only the devices numbered 2 and 3 visible.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="2, 3",
)

# Options object asking for at most 80% of each visible GPU's memory.
# NOTE(review): nothing in the visible script passes `gpu_options` to a
# session/config, so this limit may never take effect -- confirm.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
# Dataset to train and evaluate on; must be one of the keys of _DATASETS.
DATA_NAME = "VOC"
# EPOCH[0] is the number of times the training loop runs; EPOCH[1] is the
# `epochs` value handed to each `model.train` call.
EPOCH = [5, 10]
# Passed to `model.train` as `checkpoints_path`.
CHECKOUTPOINT_PATH = "./tmp/voc_5_10"

# Per-dataset directory layout and number of segmentation classes.
_DATASETS = {
    "VOC": {
        "train_images_path": "./Datasets/VOC/train/imgs/",
        "train_segs_path": "./Datasets/VOC/train/segs/",
        "test_images_path": "./Datasets/VOC/test/imgs/",
        "test_segs_path": "./Datasets/VOC/test/segs",
        "class_num": 21,
    },
    "CUB": {
        "train_images_path": "./Datasets/CUB_200_2011/train/imgs/",
        "train_segs_path": "./Datasets/CUB_200_2011/train/segs/",
        "test_images_path": "./Datasets/CUB_200_2011/test/imgs/",
        "test_segs_path": "./Datasets/CUB_200_2011/test/segs",
        "class_num": 2,
    },
}

# Fail immediately on a typo in DATA_NAME instead of hitting a NameError
# later, when the undefined path variables are first used.
if DATA_NAME not in _DATASETS:
    raise ValueError(
        "Unknown DATA_NAME {!r}; expected one of {}".format(
            DATA_NAME, sorted(_DATASETS)
        )
    )

_dataset = _DATASETS[DATA_NAME]
train_images_path = _dataset["train_images_path"]
train_segs_path = _dataset["train_segs_path"]
test_images_path = _dataset["test_images_path"]
test_segs_path = _dataset["test_segs_path"]
class_num = _dataset["class_num"]
# Build the network.  To try a different architecture, change the model
# constructor on the next statement.
model = keras_segmentation.models.unet.resnet50_unet(
    n_classes=class_num, input_height=416, input_width=608
)
# model.load_weights("./tmp/cub_psspnet_vgg_pspnet.9")

# Collect the evaluation files once, before any training starts, so a missing
# test directory is reported up front.  os.listdir() returns entries in
# arbitrary order, so both listings are sorted to make the i-th image line up
# with the i-th annotation.
# NOTE(review): this assumes every image and its annotation share the same
# base file name -- confirm against the dataset layout.
test_image_list = [
    os.path.join(test_images_path, name)
    for name in sorted(os.listdir(test_images_path))
]
test_segs_list = [
    os.path.join(test_segs_path, name)
    for name in sorted(os.listdir(test_segs_path))
]

# NOTE(review): the original indentation was lost; the evaluation step is
# assumed to belong inside this loop (train, then evaluate, EPOCH[0] times).
for _round in range(EPOCH[0]):
    # Train
    model.train(
        train_images=train_images_path,
        train_annotations=train_segs_path,
        checkpoints_path=CHECKOUTPOINT_PATH,
        epochs=EPOCH[1],
        verify_dataset=False,
    )

    # Output (disabled): write a prediction for every 50th test image.
    # for (i, image_dir) in enumerate(os.listdir(test_images_path)):
    #     if i%50 == 0:
    #         out = model.predict_segmentation(
    #             inp= os.path.join(test_images_path, image_dir),
    #             out_fname= os.path.join("./Output/", image_dir)
    #         )
    #         # import matplotlib.pyplot as plt
    #         # plt.imshow(out)

    # Test mIoU
    # The return value is discarded here.
    # NOTE(review): confirm the library reports the metrics itself.
    model.evaluate_segmentation(
        inp_images=test_image_list,
        annotations=test_segs_list,
    )