From d41e16a88d7106cde7fb633ff37bf9928a6bd1a3 Mon Sep 17 00:00:00 2001 From: Gutnar Leede Date: Tue, 12 Nov 2019 10:42:03 +0200 Subject: [PATCH 1/2] Add fix for using cpu device --- code/demo.py | 11 +++++++---- code/run_SiamRPN.py | 8 ++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/code/demo.py b/code/demo.py index 48985ec..bb93ff1 100644 --- a/code/demo.py +++ b/code/demo.py @@ -13,10 +13,13 @@ from run_SiamRPN import SiamRPN_init, SiamRPN_track from utils import get_axis_aligned_bbox, cxy_wh_2_rect +# get supported device +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + # load net net = SiamRPNvot() -net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'SiamRPNVOT.model'))) -net.eval().cuda() +net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'SiamRPNVOT.model'), map_location=device)) +net.eval().to(device) # image and init box image_files = sorted(glob.glob('./bag/*.jpg')) @@ -26,14 +29,14 @@ # tracker init target_pos, target_sz = np.array([cx, cy]), np.array([w, h]) im = cv2.imread(image_files[0]) # HxWxC -state = SiamRPN_init(im, target_pos, target_sz, net) +state = SiamRPN_init(im, target_pos, target_sz, net, device) # tracking and visualization toc = 0 for f, image_file in enumerate(image_files): im = cv2.imread(image_file) tic = cv2.getTickCount() - state = SiamRPN_track(state, im) # track + state = SiamRPN_track(state, im, device) # track toc += cv2.getTickCount()-tic res = cxy_wh_2_rect(state['target_pos'], state['target_sz']) res = [int(l) for l in res] diff --git a/code/run_SiamRPN.py b/code/run_SiamRPN.py index fce0287..e3449de 100644 --- a/code/run_SiamRPN.py +++ b/code/run_SiamRPN.py @@ -114,7 +114,7 @@ def sz_wh(wh): return target_pos, target_sz, score[best_pscore_id] -def SiamRPN_init(im, target_pos, target_sz, net): +def SiamRPN_init(im, target_pos, target_sz, net, device="cuda"): state = dict() p = TrackerConfig() p.update(net.cfg) @@ -140,7 +140,7 @@ def SiamRPN_init(im, target_pos, target_sz, net): z_crop = get_subwindow_tracking(im, target_pos, p.exemplar_size, s_z, avg_chans) z = Variable(z_crop.unsqueeze(0)) - net.temple(z.cuda()) + net.temple(z.to(device)) if p.windowing == 'cosine': window = np.outer(np.hanning(p.score_size), np.hanning(p.score_size)) @@ -157,7 +157,7 @@ def SiamRPN_init(im, target_pos, target_sz, net): return state -def SiamRPN_track(state, im): +def SiamRPN_track(state, im, device="cuda"): p = state['p'] net = state['net'] avg_chans = state['avg_chans'] @@ -176,7 +176,7 @@ def SiamRPN_track(state, im): # extract scaled crops for search region x at previous target position x_crop = Variable(get_subwindow_tracking(im, target_pos, p.instance_size, round(s_x), avg_chans).unsqueeze(0)) - target_pos, target_sz, score = tracker_eval(net, x_crop.cuda(), target_pos, target_sz * scale_z, window, scale_z, p) + target_pos, target_sz, score = tracker_eval(net, x_crop.to(device), target_pos, target_sz * scale_z, window, scale_z, p) target_pos[0] = max(0, min(state['im_w'], target_pos[0])) target_pos[1] = max(0, min(state['im_h'], target_pos[1])) target_sz[0] = max(10, min(state['im_w'], target_sz[0])) From e65915aaf888eacee609975d08a32bbd4d15d99a Mon Sep 17 00:00:00 2001 From: Gutnar Leede Date: Tue, 12 Nov 2019 10:47:53 +0200 Subject: [PATCH 2/2] Specify 'cuda:0' as default device --- code/run_SiamRPN.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/code/run_SiamRPN.py b/code/run_SiamRPN.py index e3449de..7f3b54d 100644 --- a/code/run_SiamRPN.py +++ b/code/run_SiamRPN.py @@ -114,7 +114,7 @@ def sz_wh(wh): return target_pos, target_sz, score[best_pscore_id] -def SiamRPN_init(im, target_pos, target_sz, net, device="cuda"): +def SiamRPN_init(im, target_pos, target_sz, net, device="cuda:0"): state = dict() p = TrackerConfig() p.update(net.cfg) @@ -157,7 +157,7 @@ def SiamRPN_init(im, target_pos, target_sz, net, device="cuda"): return state -def SiamRPN_track(state, im, device="cuda"): +def SiamRPN_track(state, im, device="cuda:0"): p = state['p'] net = state['net'] avg_chans = state['avg_chans']