-
Notifications
You must be signed in to change notification settings - Fork 6
/
refine.py
47 lines (39 loc) · 1.3 KB
/
refine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
from collections import deque

import numpy as np
from tqdm import tqdm


def extend(si, sj, instance_label, global_label, human_label, class_map):
    """Grow one instance outward from a seed pixel by breadth-first search.

    Starting at ``(si, sj)``, each 8-connected neighbour that is still
    unlabelled (``instance_label == 0``) and whose ``global_label`` equals the
    global class of the seed's instance is claimed. Its ``instance_label`` is
    set to the seed's instance id and its ``human_label`` to the seed's human
    value. Claiming then repeats from every newly claimed pixel. Both label
    arrays are modified in place; nothing is returned.

    Args:
        si: Row index of the seed pixel.
        sj: Column index of the seed pixel.
        instance_label: 2-D array of instance ids, where 0 means unlabelled.
            Modified in place.
        global_label: 2-D array of global class labels, indexed like
            ``instance_label``.
        human_label: Array indexed like ``instance_label``. Modified in place
            for every pixel the instance grows into.
        class_map: Indexable (e.g. dict or array) mapping an instance id to
            its global class.
    """
    inst_class = instance_label[si, sj]
    if inst_class == 0:
        # An unlabelled seed has no instance to grow. Filling with 0 would
        # leave claimed pixels looking unlabelled, so they would be re-queued
        # forever.
        return
    human_class = human_label[si, sj]
    global_class = class_map[inst_class]

    # 8-connected neighbourhood: four axis-aligned steps plus four diagonals.
    directions = ((-1, 0), (0, 1), (1, 0), (0, -1),
                  (1, 1), (1, -1), (-1, 1), (-1, -1))
    height, width = instance_label.shape[0], instance_label.shape[1]

    # deque.popleft is O(1); the previous list.pop(0) was O(n) per pop,
    # which made filling large regions quadratic.
    queue = deque([(si, sj)])
    while queue:
        ci, cj = queue.popleft()
        for di, dj in directions:
            ni = ci + di
            nj = cj + dj
            if (0 <= ni < height and 0 <= nj < width
                    and instance_label[ni, nj] == 0
                    and global_label[ni, nj] == global_class):
                # Claim the pixel when it is enqueued, so it is never
                # enqueued twice.
                instance_label[ni, nj] = inst_class
                # Use the refined instance label to refine the human label.
                human_label[ni, nj] = human_class
                queue.append((ni, nj))
def refine(instance_label, human_label, global_label, class_map):
    """Spread every labelled instance into neighbouring unlabelled pixels.

    Pixels are scanned in row-major order. A pixel whose instance id is
    non-zero when the scan reaches it becomes a seed for ``extend``. That
    includes pixels filled by an earlier seed during this same scan.

    Inputs:
        [ instance_label ] np.array() with shape [h, w]; 0 marks an
            unlabelled pixel. Modified in place via extend().
        [ human_label ] np.array() indexed like instance_label; modified in
            place via extend().
        [ global_label ] np.array() with shape [h, w].
        [ class_map ] maps an instance id to its global class.
    """
    rows, cols = instance_label.shape[0], instance_label.shape[1]
    for row, col in np.ndindex(rows, cols):
        # Read the label at visit time: an earlier extend() may have just
        # filled this pixel, and it then acts as a seed too.
        if instance_label[row, col] == 0:
            continue
        extend(row, col, instance_label, global_label, human_label, class_map)