diff --git a/cmd/antrea-agent/agent.go b/cmd/antrea-agent/agent.go index 22b4c4af730..fd2ebad68c7 100644 --- a/cmd/antrea-agent/agent.go +++ b/cmd/antrea-agent/agent.go @@ -222,6 +222,9 @@ func run(o *Options) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Must start after registering all event handlers. + go serviceCIDRProvider.Run(stopCh) + // Get all available NodePort addresses. var nodePortAddressesIPv4, nodePortAddressesIPv6 []net.IP if o.config.AntreaProxy.ProxyAll { diff --git a/pkg/agent/route/route_linux.go b/pkg/agent/route/route_linux.go index fabee668733..1e18ba6dbdb 100644 --- a/pkg/agent/route/route_linux.go +++ b/pkg/agent/route/route_linux.go @@ -41,6 +41,7 @@ import ( binding "antrea.io/antrea/pkg/ovs/openflow" "antrea.io/antrea/pkg/ovs/ovsconfig" "antrea.io/antrea/pkg/util/env" + utilip "antrea.io/antrea/pkg/util/ip" ) const ( @@ -1363,7 +1364,22 @@ func (c *Client) addServiceCIDRRoute(serviceCIDR *net.IPNet) error { return fmt.Errorf("error listing ip routes: %w", err) } for i := 0; i < len(routes); i++ { - if routes[i].Gw.Equal(gw) && !routes[i].Dst.IP.Equal(serviceCIDR.IP) && routes[i].Dst.Contains(serviceCIDR.IP) { + // Not the routes we are interested in. + if !routes[i].Gw.Equal(gw) { + continue + } + // It's the latest route we just installed. + if utilip.IPNetEqual(routes[i].Dst, serviceCIDR) { + continue + } + // The route covers the desired route. It was installed when the calculated ServiceCIDR was larger than the + // current one, which could happen after some Services are deleted. + if utilip.IPNetContains(routes[i].Dst, serviceCIDR) { + staleRoutes = append(staleRoutes, &routes[i]) + } + // The desired route covers the route. It was installed when the calculated ServiceCIDR was smaller than the + // current one, which could happen after some Services are added. + if utilip.IPNetContains(serviceCIDR, routes[i].Dst) { staleRoutes = append(staleRoutes, &routes[i]) } } diff --git a/pkg/agent/route/route_linux_test.go b/pkg/agent/route/route_linux_test.go index 708e0a12264..9b675304025 100644 --- a/pkg/agent/route/route_linux_test.go +++ b/pkg/agent/route/route_linux_test.go @@ -1332,6 +1332,30 @@ func TestAddServiceCIDRRoute(t *testing.T) { }) }, }, + { + name: "Add route for Service IPv4 CIDR and clean up stale routes", + curServiceIPv4CIDR: nil, + newServiceIPv4CIDR: ip.MustParseCIDR("10.96.0.0/28"), + expectedCalls: func(mockNetlink *netlinktest.MockInterfaceMockRecorder) { + mockNetlink.RouteReplace(&netlink.Route{ + Dst: &net.IPNet{IP: net.ParseIP("10.96.0.0").To4(), Mask: net.CIDRMask(28, 32)}, + Gw: config.VirtualServiceIPv4, + Scope: netlink.SCOPE_UNIVERSE, + LinkIndex: 10, + }) + mockNetlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{LinkIndex: 10}, netlink.RT_FILTER_OIF).Return([]netlink.Route{ + {Dst: ip.MustParseCIDR("10.96.0.0/24"), Gw: config.VirtualServiceIPv4}, + {Dst: ip.MustParseCIDR("10.96.0.0/30"), Gw: config.VirtualServiceIPv4}, + }, nil) + mockNetlink.RouteListFiltered(netlink.FAMILY_V6, &netlink.Route{LinkIndex: 10}, netlink.RT_FILTER_OIF).Return([]netlink.Route{}, nil) + mockNetlink.RouteDel(&netlink.Route{ + Dst: ip.MustParseCIDR("10.96.0.0/24"), Gw: config.VirtualServiceIPv4, + }) + mockNetlink.RouteDel(&netlink.Route{ + Dst: ip.MustParseCIDR("10.96.0.0/30"), Gw: config.VirtualServiceIPv4, + }) + }, + }, { name: "Update route for Service IPv4 CIDR", curServiceIPv4CIDR: serviceIPv4CIDR1, diff --git a/pkg/agent/servicecidr/discoverer.go b/pkg/agent/servicecidr/discoverer.go index 5136081e0c7..a6a2dc74ba5 100644 --- a/pkg/agent/servicecidr/discoverer.go +++ b/pkg/agent/servicecidr/discoverer.go @@ -21,10 +21,15 @@ import ( "time" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/types" coreinformers "k8s.io/client-go/informers/core/v1" + corelisters "k8s.io/client-go/listers/core/v1" "k8s.io/client-go/tools/cache" + "k8s.io/client-go/util/workqueue" "k8s.io/klog/v2" utilnet "k8s.io/utils/net" + "k8s.io/utils/strings/slices" "antrea.io/antrea/pkg/agent/util" ) @@ -42,17 +47,22 @@ type Interface interface { AddEventHandler(handler EventHandler) } -type discoverer struct { +type Discoverer struct { serviceInformer cache.SharedIndexInformer + serviceLister corelisters.ServiceLister sync.RWMutex serviceIPv4CIDR *net.IPNet serviceIPv6CIDR *net.IPNet eventHandlers []EventHandler + // queue maintains the Service objects that need to be synced. + queue workqueue.Interface } -func NewServiceCIDRDiscoverer(serviceInformer coreinformers.ServiceInformer) Interface { - d := &discoverer{ +func NewServiceCIDRDiscoverer(serviceInformer coreinformers.ServiceInformer) *Discoverer { + d := &Discoverer{ serviceInformer: serviceInformer.Informer(), + serviceLister: serviceInformer.Lister(), + queue: workqueue.New(), } d.serviceInformer.AddEventHandlerWithResyncPeriod( cache.ResourceEventHandlerFuncs{ @@ -64,7 +74,37 @@ func NewServiceCIDRDiscoverer(serviceInformer coreinformers.ServiceInformer) Int return d } -func (d *discoverer) GetServiceCIDR(isIPv6 bool) (*net.IPNet, error) { +func (d *Discoverer) Run(stopCh <-chan struct{}) { + defer d.queue.ShutDown() + + klog.Info("Starting ServiceCIDRDiscoverer") + defer klog.Info("Stopping ServiceCIDRDiscoverer") + if !cache.WaitForCacheSync(stopCh, d.serviceInformer.HasSynced) { + return + } + svcs, _ := d.serviceLister.List(labels.Everything()) + d.updateServiceCIDR(svcs...) + + go func() { + for { + obj, quit := d.queue.Get() + if quit { + return + } + nn := obj.(types.NamespacedName) + + svc, _ := d.serviceLister.Services(nn.Namespace).Get(nn.Name) + // Ignore it if not found. + if svc != nil { + d.updateServiceCIDR(svc) + } + d.queue.Done(obj) + } + }() + <-stopCh +} + +func (d *Discoverer) GetServiceCIDR(isIPv6 bool) (*net.IPNet, error) { d.RLock() defer d.RUnlock() if isIPv6 { @@ -79,32 +119,37 @@ func (d *discoverer) GetServiceCIDR(isIPv6 bool) (*net.IPNet, error) { return d.serviceIPv4CIDR, nil } -func (d *discoverer) AddEventHandler(handler EventHandler) { +func (d *Discoverer) AddEventHandler(handler EventHandler) { d.eventHandlers = append(d.eventHandlers, handler) } -func (d *discoverer) addService(obj interface{}) { - svc := obj.(*corev1.Service) - d.updateServiceCIDR(svc) -} - -func (d *discoverer) updateService(_, obj interface{}) { +func (d *Discoverer) addService(obj interface{}) { svc := obj.(*corev1.Service) - d.updateServiceCIDR(svc) + klog.V(2).InfoS("Processing Service ADD event", "Service", klog.KObj(svc)) + d.queue.Add(types.NamespacedName{Namespace: svc.Namespace, Name: svc.Name}) } -func (d *discoverer) updateServiceCIDR(svc *corev1.Service) { - clusterIPs := svc.Spec.ClusterIPs - if len(clusterIPs) == 0 { - return +func (d *Discoverer) updateService(old, obj interface{}) { + oldSvc := old.(*corev1.Service) + curSvc := obj.(*corev1.Service) + klog.V(2).InfoS("Processing Service UPDATE event", "Service", klog.KObj(curSvc)) + if !slices.Equal(oldSvc.Spec.ClusterIPs, curSvc.Spec.ClusterIPs) { + d.queue.Add(types.NamespacedName{Namespace: curSvc.Namespace, Name: curSvc.Name}) } +} +func (d *Discoverer) updateServiceCIDR(svcs ...*corev1.Service) { var newServiceCIDRs []*net.IPNet - klog.V(2).InfoS("Processing Service ADD or UPDATE event", "Service", klog.KObj(svc)) - func() { - d.Lock() - defer d.Unlock() - for _, clusterIPStr := range clusterIPs { + + curServiceIPv4CIDR, curServiceIPv6CIDR := func() (*net.IPNet, *net.IPNet) { + d.RLock() + defer d.RUnlock() + return d.serviceIPv4CIDR, d.serviceIPv6CIDR + }() + + updated := false + for _, svc := range svcs { + for _, clusterIPStr := range svc.Spec.ClusterIPs { clusterIP := net.ParseIP(clusterIPStr) if clusterIP == nil { klog.V(2).InfoS("Skip invalid ClusterIP", "ClusterIP", clusterIPStr) @@ -112,10 +157,10 @@ func (d *discoverer) updateServiceCIDR(svc *corev1.Service) { } isIPv6 := utilnet.IsIPv6(clusterIP) - curServiceCIDR := d.serviceIPv4CIDR + curServiceCIDR := curServiceIPv4CIDR mask := net.IPv4len * 8 if isIPv6 { - curServiceCIDR = d.serviceIPv6CIDR + curServiceCIDR = curServiceIPv6CIDR mask = net.IPv6len * 8 } @@ -138,16 +183,31 @@ func (d *discoverer) updateServiceCIDR(svc *corev1.Service) { } if isIPv6 { - d.serviceIPv6CIDR = newServiceCIDR - klog.V(4).InfoS("Service IPv6 CIDR was updated", "ServiceCIDR", newServiceCIDR) + curServiceIPv6CIDR = newServiceCIDR } else { - d.serviceIPv4CIDR = newServiceCIDR - klog.V(4).InfoS("Service IPv4 CIDR was updated", "ServiceCIDR", newServiceCIDR) + curServiceIPv4CIDR = newServiceCIDR } - newServiceCIDRs = append(newServiceCIDRs, newServiceCIDR) + updated = true } - }() + } + if !updated { + return + } + func() { + d.Lock() + defer d.Unlock() + if d.serviceIPv4CIDR != curServiceIPv4CIDR { + d.serviceIPv4CIDR = curServiceIPv4CIDR + klog.InfoS("Service IPv4 CIDR was updated", "ServiceCIDR", curServiceIPv4CIDR) + newServiceCIDRs = append(newServiceCIDRs, curServiceIPv4CIDR) + } + if d.serviceIPv6CIDR != curServiceIPv6CIDR { + d.serviceIPv6CIDR = curServiceIPv6CIDR + klog.InfoS("Service IPv6 CIDR was updated", "ServiceCIDR", curServiceIPv6CIDR) + newServiceCIDRs = append(newServiceCIDRs, curServiceIPv6CIDR) + } + }() for _, handler := range d.eventHandlers { handler(newServiceCIDRs) } diff --git a/pkg/agent/servicecidr/discoverer_test.go b/pkg/agent/servicecidr/discoverer_test.go index 308c9801124..850db83de6a 100644 --- a/pkg/agent/servicecidr/discoverer_test.go +++ b/pkg/agent/servicecidr/discoverer_test.go @@ -67,6 +67,7 @@ func TestServiceCIDRProvider(t *testing.T) { defer close(stopCh) informerFactory.Start(stopCh) informerFactory.WaitForCacheSync(stopCh) + go serviceCIDRProvider.Run(stopCh) check := func(expectedServiceCIDR string, isServiceCIDRUpdated, isIPv6 bool) { if isServiceCIDRUpdated { @@ -84,15 +85,18 @@ func TestServiceCIDRProvider(t *testing.T) { } } serviceCIDR, err := serviceCIDRProvider.GetServiceCIDR(isIPv6) - assert.NoError(t, err) - assert.Equal(t, expectedServiceCIDR, serviceCIDR.String()) + if expectedServiceCIDR != "" { + assert.NoError(t, err) + assert.Equal(t, expectedServiceCIDR, serviceCIDR.String()) + } else { + assert.ErrorContains(t, err, "CIDR is not available yet") + } } svc := makeService("ns1", "svc0", "None", corev1.ProtocolTCP) _, err := client.CoreV1().Services("ns1").Create(context.TODO(), svc, metav1.CreateOptions{}) assert.NoError(t, err) - _, err = serviceCIDRProvider.GetServiceCIDR(false) - assert.ErrorContains(t, err, "Service IPv4 CIDR is not available yet") + check("", false, false) svc = makeService("ns1", "svc1", "10.10.0.1", corev1.ProtocolTCP) _, err = client.CoreV1().Services("ns1").Create(context.TODO(), svc, metav1.CreateOptions{}) @@ -121,8 +125,7 @@ func TestServiceCIDRProvider(t *testing.T) { svc = makeService("ns1", "svc60", "None", corev1.ProtocolTCP) _, err = client.CoreV1().Services("ns1").Create(context.TODO(), svc, metav1.CreateOptions{}) assert.NoError(t, err) - _, err = serviceCIDRProvider.GetServiceCIDR(true) - assert.ErrorContains(t, err, "Service IPv6 CIDR is not available yet") + check("", false, true) svc = makeService("ns1", "svc61", "10::1", corev1.ProtocolTCP) _, err = client.CoreV1().Services("ns1").Create(context.TODO(), svc, metav1.CreateOptions{}) diff --git a/pkg/util/ip/ip.go b/pkg/util/ip/ip.go index fc504508a43..bd82ea18989 100644 --- a/pkg/util/ip/ip.go +++ b/pkg/util/ip/ip.go @@ -195,6 +195,49 @@ func MustParseCIDR(cidr string) *net.IPNet { return ipNet } +// IPNetEqual returns if the provided IPNets are the same subnet. +func IPNetEqual(ipNet1, ipNet2 *net.IPNet) bool { + if ipNet1 == nil && ipNet2 == nil { + return true + } + if ipNet1 == nil || ipNet2 == nil { + return false + } + if !bytes.Equal(ipNet1.Mask, ipNet2.Mask) { + return false + } + if !ipNet1.IP.Equal(ipNet2.IP) { + return false + } + return true +} + +// IPNetContains returns if the first IPNet contains the second IPNet. +// For example: +// +// 10.0.0.0/24 contains 10.0.0.0/24. +// 10.0.0.0/24 contains 10.0.0.0/25. +// 10.0.0.0/24 contains 10.0.0.128/25. +// 10.0.0.0/24 does not contain 10.0.0.0/23. +// 10.0.0.0/24 does not contain 10.0.1.0/25. +func IPNetContains(ipNet1, ipNet2 *net.IPNet) bool { + if ipNet1 == nil || ipNet2 == nil { + return false + } + ones1, bits1 := ipNet1.Mask.Size() + ones2, bits2 := ipNet2.Mask.Size() + if bits1 != bits2 { + return false + } + if ones1 > ones2 { + return false + } + if !ipNet1.Contains(ipNet2.IP) { + return false + } + return true +} + func MustIPv6(s string) net.IP { ip := net.ParseIP(s) if !utilnet.IsIPv6(ip) { diff --git a/pkg/util/ip/ip_test.go b/pkg/util/ip/ip_test.go index ef812257fba..caf12b50f52 100644 --- a/pkg/util/ip/ip_test.go +++ b/pkg/util/ip/ip_test.go @@ -239,3 +239,93 @@ func TestAppendPortIfMissing(t *testing.T) { }) } } + +func TestIPNetEqual(t *testing.T) { + tests := []struct { + name string + ipNet1 *net.IPNet + ipNet2 *net.IPNet + want bool + }{ + { + name: "equal", + ipNet1: MustParseCIDR("1.1.1.0/30"), + ipNet2: MustParseCIDR("1.1.1.0/30"), + want: true, + }, + { + name: "different mask", + ipNet1: MustParseCIDR("1.1.1.0/30"), + ipNet2: MustParseCIDR("1.1.1.0/29"), + want: false, + }, + { + name: "different prefix", + ipNet1: MustParseCIDR("1.1.1.4/30"), + ipNet2: MustParseCIDR("1.1.1.0/30"), + want: false, + }, + { + name: "different family", + ipNet1: MustParseCIDR("1.1.1.4/30"), + ipNet2: MustParseCIDR("1:1:1:4::/30"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IPNetEqual(tt.ipNet1, tt.ipNet2)) + }) + } +} + +func TestIPNetContains(t *testing.T) { + tests := []struct { + name string + ipNet1 *net.IPNet + ipNet2 *net.IPNet + want bool + }{ + { + name: "equal", + ipNet1: MustParseCIDR("10.0.0.0/24"), + ipNet2: MustParseCIDR("10.0.0.0/24"), + want: true, + }, + { + name: "contain smaller subnet", + ipNet1: MustParseCIDR("10.0.0.0/24"), + ipNet2: MustParseCIDR("10.0.0.0/25"), + want: true, + }, + { + name: "contain smaller subnet with different prefix", + ipNet1: MustParseCIDR("10.0.0.0/24"), + ipNet2: MustParseCIDR("10.0.0.128/25"), + want: true, + }, + { + name: "not contain larger subnet", + ipNet1: MustParseCIDR("10.0.0.0/24"), + ipNet2: MustParseCIDR("10.0.0.0/23"), + want: false, + }, + { + name: "not contain smaller subnet with different prefix", + ipNet1: MustParseCIDR("10.0.0.0/24"), + ipNet2: MustParseCIDR("10.0.1.0/25"), + want: false, + }, + { + name: "not contain subnet of different family", + ipNet1: MustParseCIDR("1.1.1.4/30"), + ipNet2: MustParseCIDR("1:1:1:4::/30"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IPNetContains(tt.ipNet1, tt.ipNet2)) + }) + } +}