Skip to content

Commit

Permalink
Support proxy to webrtc media.
Browse files Browse the repository at this point in the history
  • Loading branch information
winlinvip committed Sep 4, 2024
1 parent 1009843 commit 2fc5b44
Show file tree
Hide file tree
Showing 4 changed files with 356 additions and 21 deletions.
4 changes: 2 additions & 2 deletions proxy/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,15 @@ func (v *HTTPStreaming) serveByBackend(ctx context.Context, w http.ResponseWrite
type HLSStreaming struct {
// The context for HLS streaming.
ctx context.Context
// The context ID for recovering the context.
ContextID string `json:"cid"`

// The spbhid, used to identify the backend server.
SRSProxyBackendHLSID string `json:"spbhid"`
// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`
// The full request URL for HLS streaming
FullURL string `json:"full_url"`
// The context ID for recovering the context.
ContextID string `json:"cid"`
}

func NewHLSStreaming(opts ...func(streaming *HLSStreaming)) *HLSStreaming {
Expand Down
300 changes: 281 additions & 19 deletions proxy/rtc.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,34 @@ package main

import (
"context"
"encoding/binary"
"fmt"
"io/ioutil"
"net"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
stdSync "sync"

"srs-proxy/errors"
"srs-proxy/logger"
"srs-proxy/sync"
)

type rtcServer struct {
// The UDP listener for WebRTC server.
listener *net.UDPConn

// Fast cache for the username to identify the connection.
// The key is username, the value is the UDP address.
usernames sync.Map[string, *RTCConnection]
// Fast cache for the udp address to identify the connection.
// The key is UDP address, the value is the username.
// TODO: Support fast earch by uint64 address.
addresses sync.Map[string, *RTCConnection]

// The wait group for server.
wg sync.WaitGroup
wg stdSync.WaitGroup
}

func newRTCServer(opts ...func(*rtcServer)) *rtcServer {
Expand Down Expand Up @@ -173,22 +183,26 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r
}

// Fetch the ice-ufrag and ice-pwd from local SDP answer.
var iceUfrag, icePwd string
if true {
ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`)
ufragMatch := ufragRe.FindStringSubmatch(localSDPAnswer)
if len(ufragMatch) <= 1 {
return errors.Errorf("no ice-ufrag in local sdp answer %v", localSDPAnswer)
}
iceUfrag = ufragMatch[1]
remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer)
if err != nil {
return errors.Wrapf(err, "parse remote sdp offer")
}
if true {
pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`)
pwdMatch := pwdRe.FindStringSubmatch(localSDPAnswer)
if len(pwdMatch) <= 1 {
return errors.Errorf("no ice-pwd in local sdp answer %v", localSDPAnswer)
}
icePwd = pwdMatch[1]

localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer)
if err != nil {
return errors.Wrapf(err, "parse local sdp answer")
}

// Save the new WebRTC connection to LB.
icePair := &RTCICEPair{
RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd,
LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd,
}
if _, err := srsLoadBalancer.LoadOrStoreWebRTC(ctx, streamURL, icePair.Ufrag(), NewRTCStreaming(func(s *RTCConnection) {
s.StreamURL, s.listenerUDP = streamURL, v.listener
s.BuildContext(ctx)
})); err != nil {
return errors.Wrapf(err, "load or store webrtc %v", streamURL)
}

// Response client with local answer.
Expand All @@ -197,7 +211,7 @@ func (v *rtcServer) serveByBackend(ctx context.Context, w http.ResponseWriter, r
}

logger.Df(ctx, "Response local answer %vB with ice-ufrag=%v, ice-pwd=%vB",
len(localSDPAnswer), iceUfrag, len(icePwd))
len(localSDPAnswer), localICEUfrag, len(localICEPwd))
return nil
}

Expand All @@ -220,5 +234,253 @@ func (v *rtcServer) Run(ctx context.Context) error {
v.listener = listener
logger.Df(ctx, "WebRTC server listen at %v", addr)

// Consume all messages from UDP media transport.
v.wg.Add(1)
go func() {
defer v.wg.Done()

for ctx.Err() == nil {
buf := make([]byte, 4096)
n, addr, err := listener.ReadFromUDP(buf)
if err != nil {
// TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit.
logger.Wf(ctx, "read from udp failed, err=%v", err)
continue
}

if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil {
logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%v", n, addr, err)
}
}
}()

return nil
}

func (v *rtcServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error {
var stream *RTCConnection

// If STUN binding request, parse the ufrag and identify the connection.
if err := func() error {
if rtc_is_rtp_or_rtcp(data) || !rtc_is_stun(data) {
return nil
}

var pkt RTCStunPacket
if err := pkt.UnmarshalBinary(data); err != nil {
return errors.Wrapf(err, "unmarshal stun packet")
}

// Search the stream in fast cache.
if s, ok := v.usernames.Load(pkt.Username); ok {
stream = s
return nil
}

// Load stream by username.
if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil {
return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username)
} else {
stream = s
}

// Cache stream for fast search.
if stream != nil {
v.usernames.Store(pkt.Username, stream)
}
return nil
}(); err != nil {
return err
}

// Search the stream by addr.
if s, ok := v.addresses.Load(addr.String()); ok {
stream = s
} else if stream != nil {
// Cache the address for fast search.
v.addresses.Store(addr.String(), stream)
}

// If stream is not found, ignore the packet.
if stream == nil {
// TODO: Should logging the dropped packet, only logging the first one for each address.
return nil
}

// Proxy the packet to backend.
if err := stream.Proxy(addr, data); err != nil {
return errors.Wrapf(err, "proxy %vB for %v", len(data), stream.StreamURL)
}

return nil
}

type RTCConnection struct {
// The stream context for WebRTC streaming.
ctx context.Context
// The context ID for recovering the context.
ContextID string `json:"cid"`

// The stream URL in vhost/app/stream schema.
StreamURL string `json:"stream_url"`

// The UDP connection proxy to backend.
backendUDP *net.UDPConn
// The client UDP address. Note that it may change.
clientUDP *net.UDPAddr
// The listener UDP connection, used to send messages to client.
listenerUDP *net.UDPConn
}

func NewRTCStreaming(opts ...func(*RTCConnection)) *RTCConnection {
v := &RTCConnection{}
for _, opt := range opts {
opt(v)
}
return v
}

func (v *RTCConnection) Proxy(addr *net.UDPAddr, data []byte) error {
ctx := v.ctx

// Update the current UDP address.
v.clientUDP = addr

// Start the UDP proxy to backend.
if err := v.connectBackend(ctx); err != nil {
return errors.Wrapf(err, "connect backend for %v", v.StreamURL)
}

// Proxy client message to backend.
if v.backendUDP != nil {
if _, err := v.backendUDP.Write(data); err != nil {
return errors.Wrapf(err, "write to backend %v", v.StreamURL)
}
}

return nil
}

func (v *RTCConnection) connectBackend(ctx context.Context) error {
if v.backendUDP != nil {
return nil
}

// Pick a backend SRS server to proxy the RTC stream.
backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL)
if err != nil {
return errors.Wrapf(err, "pick backend")
}

// Parse UDP port from backend.
if len(backend.RTC) == 0 {
return errors.Errorf("no udp server")
}

var udpPort int
if iv, err := strconv.ParseInt(backend.RTC[0], 10, 64); err != nil {
return errors.Wrapf(err, "parse udp port %v", backend.RTC[0])
} else {
udpPort = int(iv)
}

// Connect to backend SRS server via UDP client.
backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort}
if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil {
return errors.Wrapf(err, "dial udp to %v", backendAddr)
} else {
v.backendUDP = backendUDP
}

// Proxy all messages from backend to client.
go func() {
for ctx.Err() == nil {
buf := make([]byte, 4096)
n, _, err := v.backendUDP.ReadFromUDP(buf)
if err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "read from backend failed, err=%v", err)
break
}

if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil {
// TODO: If backend server closed unexpectedly, we should notice the stream to quit.
logger.Wf(ctx, "write to client failed, err=%v", err)
break
}
}
}()

return nil
}

func (v *RTCConnection) BuildContext(ctx context.Context) {
if v.ContextID == "" {
v.ContextID = logger.GenerateContextID()
}
v.ctx = logger.WithContextID(ctx, v.ContextID)
}

type RTCICEPair struct {
// The remote ufrag, used for ICE username and session id.
RemoteICEUfrag string `json:"remote_ufrag"`
// The remote pwd, used for ICE password.
RemoteICEPwd string `json:"remote_pwd"`
// The local ufrag, used for ICE username and session id.
LocalICEUfrag string `json:"local_ufrag"`
// The local pwd, used for ICE password.
LocalICEPwd string `json:"local_pwd"`
}

// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag.
func (v *RTCICEPair) Ufrag() string {
return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag)
}

type RTCStunPacket struct {
// The stun message type.
MessageType uint16
// The stun username, or ufrag.
Username string
}

func (v *RTCStunPacket) UnmarshalBinary(data []byte) error {
if len(data) < 20 {
return errors.Errorf("stun packet too short %v", len(data))
}

p := data
v.MessageType = binary.BigEndian.Uint16(p)
messageLen := binary.BigEndian.Uint16(p[2:])
//magicCookie := p[:8]
//transactionID := p[:20]
p = p[20:]

if len(p) != int(messageLen) {
return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen)
}

for len(p) > 0 {
typ := binary.BigEndian.Uint16(p)
length := binary.BigEndian.Uint16(p[2:])
p = p[4:]

if len(p) < int(length) {
return errors.Errorf("stun attribute length invalid %v < %v", len(p), length)
}

value := p[:length]
p = p[length:]

if length%4 != 0 {
p = p[4-length%4:]
}

switch typ {
case 0x0006:
v.Username = string(value)
}
}

return nil
}
Loading

0 comments on commit 2fc5b44

Please sign in to comment.