Skip to content

Commit

Permalink
Refactor(core): replace net.IP with netip.Addr (#395)
Browse files Browse the repository at this point in the history
  • Loading branch information
xjasonlyu committed Aug 31, 2024
1 parent fd98f65 commit 1f09b4d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 24 deletions.
22 changes: 11 additions & 11 deletions core/nic.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package core

import (
"fmt"
"net"
"net/netip"

"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
Expand Down Expand Up @@ -61,7 +61,7 @@ func withSpoofing(nicID tcpip.NICID, v bool) option.Option {
}

// withMulticastGroups adds a NIC to the given multicast groups.
func withMulticastGroups(nicID tcpip.NICID, multicastGroups []net.IP) option.Option {
func withMulticastGroups(nicID tcpip.NICID, multicastGroups []netip.Addr) option.Option {
return func(s *stack.Stack) error {
if len(multicastGroups) == 0 {
return nil
Expand Down Expand Up @@ -103,15 +103,15 @@ func withMulticastGroups(nicID tcpip.NICID, multicastGroups []net.IP) option.Opt
stack.AddressProperties{PEB: stack.CanBePrimaryEndpoint},
)
for _, multicastGroup := range multicastGroups {
if ip := multicastGroup.To4(); ip != nil {
if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, tcpip.AddrFrom4Slice(ip)); err != nil {
return fmt.Errorf("join multicast group: %s", err)
}
} else {
ip := multicastGroup.To16()
if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, tcpip.AddrFrom16Slice(ip)); err != nil {
return fmt.Errorf("join multicast group: %s", err)
}
var err tcpip.Error
switch {
case multicastGroup.Is4():
err = s.JoinGroup(ipv4.ProtocolNumber, nicID, tcpip.AddrFrom4(multicastGroup.As4()))
case multicastGroup.Is6():
err = s.JoinGroup(ipv6.ProtocolNumber, nicID, tcpip.AddrFrom16(multicastGroup.As16()))
}
if err != nil {
return fmt.Errorf("join multicast group: %s", err)
}
}
return nil
Expand Down
4 changes: 2 additions & 2 deletions core/stack.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package core

import (
"net"
"net/netip"

"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
Expand All @@ -26,7 +26,7 @@ type Config struct {

// MulticastGroups is used by internal stack to add
// nic to given groups.
MulticastGroups []net.IP
MulticastGroups []netip.Addr

// Options are supplement options to apply settings
// for the internal stack.
Expand Down
3 changes: 2 additions & 1 deletion engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package engine
import (
"errors"
"net"
"net/netip"
"os/exec"
"sync"
"time"
Expand Down Expand Up @@ -197,7 +198,7 @@ func netstack(k *Key) (err error) {
return
}

var multicastGroups []net.IP
var multicastGroups []netip.Addr
if multicastGroups, err = parseMulticastGroups(k.MulticastGroups); err != nil {
return err
}
Expand Down
20 changes: 10 additions & 10 deletions engine/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/base64"
"fmt"
"net"
"net/netip"
"net/url"
"strings"

Expand Down Expand Up @@ -178,20 +179,19 @@ func parseRelay(u *url.URL) (proxy.Proxy, error) {
return proxy.NewRelay(address, username, password, opts.NoDelay)
}

func parseMulticastGroups(s string) (multicastGroups []net.IP, _ error) {
ipStrings := strings.Split(s, ",")
for _, ipString := range ipStrings {
if strings.TrimSpace(ipString) == "" {
func parseMulticastGroups(s string) (multicastGroups []netip.Addr, _ error) {
for _, ip := range strings.Split(s, ",") {
if ip = strings.TrimSpace(ip); ip == "" {
continue
}
ip := net.ParseIP(ipString)
if ip == nil {
return nil, fmt.Errorf("invalid IP format: %s", ipString)
addr, err := netip.ParseAddr(ip)
if err != nil {
return nil, err
}
if !ip.IsMulticast() {
return nil, fmt.Errorf("invalid multicast IP address: %s", ipString)
if !addr.IsMulticast() {
return nil, fmt.Errorf("invalid multicast IP: %s", addr)
}
multicastGroups = append(multicastGroups, ip)
multicastGroups = append(multicastGroups, addr)
}
return
}

0 comments on commit 1f09b4d

Please sign in to comment.