Skip to content

Commit

Permalink
Added QDTrack object tracker (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
mageofboy committed Jul 19, 2021
1 parent 1518c0a commit fdba024
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 4 deletions.
14 changes: 14 additions & 0 deletions install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ mkdir -p tracking/center_track ; cd tracking/center_track
# COCO model
~/.local/bin/gdown https://drive.google.com/uc?id=1tJCEJmdtYIh8VuN8CClGNws3YO7QGd40

###### Download QDTrack models ######
cd $PYLOT_HOME/dependencies/models
mkdir -p tracking/qd_track ; cd tracking/qd_track
~/.local/bin/gdown https://drive.google.com/uc?id=1YNAQgd8rMqqEG-fRj3VWlO4G5kdwJbxz

##### Download AnyNet depth estimation models #####
echo "[x] Downloading the depth estimation models..."
cd $PYLOT_HOME/dependencies/models
Expand Down Expand Up @@ -140,6 +145,15 @@ sudo apt-get install llvm-9
export LLVM_CONFIG=/usr/bin/llvm-config-9
python3 setup.py build develop --user

###### Install QDTrack ######
cd $PYLOT_HOME/dependencies/
git clone https://github.com/mageofboy/qdtrack.git
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
python3 setup.py develop #need to add mmcv
cd $PYLOT_HOME/dependencies/qdtrack
python3 setup.py develop

##### Download the Lanenet code #####
echo "[x] Cloning the lanenet lane detection code..."
cd $PYLOT_HOME/dependencies/
Expand Down
5 changes: 5 additions & 0 deletions pylot/component_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,11 @@ def add_obstacle_tracking(center_camera_stream,
obstacles_wo_history_tracking_stream = \
pylot.operator_creator.add_center_track_tracking(
center_camera_stream, center_camera_setup)
elif FLAGS.tracker_type == 'qd_track':
logger.debug('Using QDTrack obstacle tracker...')
obstacles_wo_history_tracking_stream = \
pylot.operator_creator.add_qd_track_tracking(
center_camera_stream, center_camera_setup)
else:
logger.debug('Using obstacle tracker...')
obstacles_wo_history_tracking_stream = \
Expand Down
7 changes: 4 additions & 3 deletions pylot/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@
'True to enable obstacle tracking operator')
flags.DEFINE_bool('perfect_obstacle_tracking', False,
'True to enable perfect obstacle tracking')
flags.DEFINE_enum('tracker_type', 'sort',
['da_siam_rpn', 'deep_sort', 'sort', 'center_track'],
'Sets which obstacle tracker to use')
flags.DEFINE_enum(
'tracker_type', 'sort',
['da_siam_rpn', 'deep_sort', 'sort', 'center_track', 'qd_track'],
'Sets which obstacle tracker to use')
flags.DEFINE_bool('lane_detection', False, 'True to enable lane detection')
flags.DEFINE_bool('perfect_lane_detection', False,
'True to enable perfect lane detection')
Expand Down
13 changes: 13 additions & 0 deletions pylot/operator_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,19 @@ def add_center_track_tracking(bgr_camera_stream,
return obstacle_tracking_stream


def add_qd_track_tracking(bgr_camera_stream, camera_setup, name='qd_track'):
from pylot.perception.tracking.qd_track_operator import \
QdTrackOperator
op_config = erdos.OperatorConfig(name='qd_track_operator',
log_file_name=FLAGS.log_file_name,
csv_log_file_name=FLAGS.csv_log_file_name,
profile_file_name=FLAGS.profile_file_name)
[obstacle_tracking_stream] = erdos.connect(QdTrackOperator, op_config,
[bgr_camera_stream], FLAGS,
camera_setup)
return obstacle_tracking_stream


def add_tracking_evaluation(obstacle_tracking_stream,
ground_obstacles_stream,
evaluate_timely=False,
Expand Down
1 change: 0 additions & 1 deletion pylot/perception/detection/lane.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np

from pylot.utils import Location, Rotation, Transform, Vector3D

from shapely.geometry import Point
from shapely.geometry.polygon import Polygon

Expand Down
9 changes: 9 additions & 0 deletions pylot/perception/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@
['kitti_tracking', 'coco', 'mot', 'nuscenes'],
'CenterTrack available models')

# QDTrack tracking flags.
flags.DEFINE_string(
'qd_track_model_path', 'dependencies/models/tracking/qd_track/' +
'qdtrack-frcnn_r50_fpn_12e_bdd100k-13328aed.pth', 'Path to the model')
flags.DEFINE_string(
'qd_track_config_path',
'dependencies/qdtrack/configs/qdtrack-frcnn_r50_fpn_12e_bdd100k.py',
'Path to the model')

# Lane detection flags.
flags.DEFINE_float('lane_detection_gpu_memory_fraction', 0.3,
'GPU memory fraction allocated to Lanenet')
Expand Down
74 changes: 74 additions & 0 deletions pylot/perception/tracking/qd_track_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import time

import erdos

from pylot.perception.detection.obstacle import Obstacle
from pylot.perception.detection.utils import BoundingBox2D, \
OBSTACLE_LABELS
from pylot.perception.messages import ObstaclesMessage


class QdTrackOperator(erdos.Operator):
def __init__(self, camera_stream, obstacle_tracking_stream, flags,
camera_setup):
from qdtrack.apis import init_model

camera_stream.add_callback(self.on_frame_msg,
[obstacle_tracking_stream])
self._flags = flags
self._logger = erdos.utils.setup_logging(self.config.name,
self.config.log_file_name)
self._csv_logger = erdos.utils.setup_csv_logging(
self.config.name + '-csv', self.config.csv_log_file_name)
self._camera_setup = camera_setup
self.model = init_model(self._flags.qd_track_config_path,
checkpoint=self._flags.qd_track_model_path,
device='cuda:0',
cfg_options=None)
self.classes = ('pedestrian', 'rider', 'car', 'bus', 'truck',
'bicycle', 'motorcycle', 'train')
self.frame_id = 0

@staticmethod
def connect(camera_stream):
obstacle_tracking_stream = erdos.WriteStream()
return [obstacle_tracking_stream]

def destroy(self):
self._logger.warn('destroying {}'.format(self.config.name))

@erdos.profile_method()
def on_frame_msg(self, msg, obstacle_tracking_stream):
"""Invoked when a FrameMessage is received on the camera stream."""
from qdtrack.apis import inference_model

self._logger.debug('@{}: {} received frame'.format(
msg.timestamp, self.config.name))
assert msg.frame.encoding == 'BGR', 'Expects BGR frames'
start_time = time.time()
image_np = msg.frame.as_bgr_numpy_array()
results = inference_model(self.model, image_np, self.frame_id)
self.frame_id += 1

bbox_result, track_result = results.values()
obstacles = []
for k, v in track_result.items():
track_id = k
bbox = v['bbox'][None, :]
score = bbox[4]
label_id = v['label']
label = self.classes[label_id]
if label in ['pedestrian', 'rider']:
label = 'person'
if label in OBSTACLE_LABELS:
bounding_box_2D = BoundingBox2D(bbox[0], bbox[2], bbox[1],
bbox[3])
obstacles.append(
Obstacle(bounding_box_2D,
score,
label,
track_id,
bounding_box_2D=bounding_box_2D))
runtime = (time.time() - start_time) * 1000
obstacle_tracking_stream.send(
ObstaclesMessage(msg.timestamp, obstacles, runtime))
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@ nuscenes-devkit
progress
pyquaternion
scikit-learn==0.22.2
mmcv>=0.3.0
mmdet
##### CARLA dependencies #####
networkx==2.2

0 comments on commit fdba024

Please sign in to comment.