Skip to content

Commit

Permalink
feat: add uncompressed video images from webrtc
Browse files Browse the repository at this point in the history
  • Loading branch information
tfoldi committed Sep 21, 2024
1 parent 4a4a27a commit 32ca4b2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
27 changes: 27 additions & 0 deletions go2_robot_sdk/go2_robot_sdk/go2_driver_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
import threading
import asyncio

from aiortc import MediaStreamTrack
from cv_bridge import CvBridge


from scripts.go2_constants import ROBOT_CMD, RTC_TOPIC
from scripts.go2_func import gen_command, gen_mov_command
from scripts.go2_lidar_decoder import update_meshes_for_cloud2
Expand All @@ -46,6 +50,7 @@
from sensor_msgs_py import point_cloud2
from std_msgs.msg import Header
from nav_msgs.msg import Odometry
from sensor_msgs.msg import Image


logging.basicConfig(level=logging.WARN)
Expand Down Expand Up @@ -88,6 +93,7 @@ def __init__(self):
self.go2_lidar_pub = []
self.go2_odometry_pub = []
self.imu_pub = []
self.img_pub = []

if self.conn_mode == 'single':
self.joint_pub.append(self.create_publisher(
Expand All @@ -99,6 +105,7 @@ def __init__(self):
self.go2_odometry_pub.append(
self.create_publisher(Odometry, 'odom', qos_profile))
self.imu_pub.append(self.create_publisher(IMU, 'imu', qos_profile))
self.img_pub.append(self.create_publisher(Image, 'camera/image_raw', qos_profile))

else:
for i in range(len(self.robot_ip_lst)):
Expand All @@ -112,9 +119,13 @@ def __init__(self):
Odometry, f'robot{i}/odom', qos_profile))
self.imu_pub.append(self.create_publisher(
IMU, f'robot{i}/imu', qos_profile))
self.img_pub.append(self.create_publisher(
Image, f'robot{i}/camera/image_raw', qos_profile))

self.broadcaster = TransformBroadcaster(self, qos=qos_profile)

self.bridge = CvBridge()

self.robot_cmd_vel = {}
self.robot_odom = {}
self.robot_low_cmd = {}
Expand Down Expand Up @@ -254,6 +265,21 @@ def on_validated(self, robot_num):
self.conn[robot_num].data_channel.send(
json.dumps({"type": "subscribe", "topic": topic}))

async def on_video_frame(self, track: MediaStreamTrack, robot_num):
logger.info(f"Video frame received for robot {robot_num}")

while True:
frame = await track.recv()
img = frame.to_ndarray(format="bgr24")

logger.debug(f"Shape: {img.shape}, Dimensions: {img.ndim}, Type: {img.dtype}, Size: {img.size}")

# Convert the OpenCV image to ROS Image message
ros_image = self.bridge.cv2_to_imgmsg(img, encoding="bgr8")

# Publish the image
self.img_pub[robot_num].publish(ros_image)

def on_data_channel_message(self, _, msg, robot_num):

if msg.get('topic') == RTC_TOPIC["ULIDAR_ARRAY"]:
Expand Down Expand Up @@ -524,6 +550,7 @@ async def start_node():
token=base_node.token,
on_validated=base_node.on_validated,
on_message=base_node.on_data_channel_message,
on_video_frame=base_node.on_video_frame,
)

sleep_task_lst.append(asyncio.get_event_loop(
Expand Down
28 changes: 18 additions & 10 deletions go2_robot_sdk/scripts/webrtc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@


import binascii
import time
import uuid
import aiohttp
import base64
import hashlib
import json
Expand All @@ -39,7 +37,6 @@
from Crypto.Cipher import PKCS1_v1_5
import requests
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaBlackhole


from scripts.go2_lidar_decoder import LidarDecoder
Expand Down Expand Up @@ -193,7 +190,8 @@ def __init__(
token="",
on_validated=None,
on_message=None,
on_open=None
on_open=None,
on_video_frame=None,
):

self.pc = RTCPeerConnection()
Expand All @@ -205,8 +203,7 @@ def __init__(
self.on_message = on_message
self.on_open = on_open

self.audio_track = MediaBlackhole()
self.video_track = MediaBlackhole()
self.on_video_frame = on_video_frame

self.data_channel = self.pc.createDataChannel("data", id=0)
self.data_channel.on("open", self.on_data_channel_open)
Expand All @@ -215,19 +212,22 @@ def __init__(
self.pc.on("track", self.on_track)
self.pc.on("connectionstatechange", self.on_connection_state_change)

self.pc.addTransceiver("video", direction="recvonly")

def on_connection_state_change(self):
logger.info(f"Connection state is {self.pc.connectionState}")

def on_track(self, track):
async def on_track(self, track):
logger.info(f"Receiving {track.kind}")
if track.kind == "audio":
pass
elif track.kind == "video":
pass
frame = await track.recv()
logger.info(f"Received frame {frame}")
if self.on_video_frame:
await self.on_video_frame(track, int(self.robot_num))

async def generate_offer(self):
await self.audio_track.start()
await self.video_track.start()
offer = await self.pc.createOffer()
await self.pc.setLocalDescription(offer)
return offer.sdp
Expand Down Expand Up @@ -336,6 +336,14 @@ async def connect(self):

def validate_robot_conn(self, message):
if message.get("data") == "Validation Ok.":

# turn on video
self.publish(
"",
"on",
"vid",
)

self.validation_result = "SUCCESS"
if self.on_validated:
self.on_validated(self.robot_num)
Expand Down

0 comments on commit 32ca4b2

Please sign in to comment.