From f866c805c2f78271de9f2b61254363d009cee8c6 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Thu, 25 Jul 2024 22:42:34 -0700 Subject: [PATCH] util: SocketAddress.toString() cannot be used for equality Some addresses are equal even though their toString is different (InetSocketAddress ignores the hostname when it has an address). And some addresses are not equal even though their toString might be the same (AnonymousInProcessSocketAddress doesn't override toString()). InetSocketAddress/InetAddress do not cache the toString() result. Thus, even in the worst case that uses a HashSet, this should use less memory than the earlier approach, as no strings are formatted. It probably also significantly improves performance in the reasonably common case when an Endpoint is created just for looking up a key, because the string creation in the constructor isn't then amorized. updateChildrenWithResolvedAddresses(), for example, creates n^2 Endpoint objects for lookups. --- .../io/grpc/util/MultiChildLoadBalancer.java | 33 ++++++----- .../grpc/util/MultiChildLoadBalancerTest.java | 55 ++++++++----------- .../java/io/grpc/util/AbstractTestHelper.java | 16 +++++- 3 files changed, 54 insertions(+), 50 deletions(-) diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java index c5f774984fe..893dd1e1598 100644 --- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java @@ -37,10 +37,10 @@ import io.grpc.internal.PickFirstLoadBalancerProvider; import java.net.SocketAddress; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -494,25 +494,27 @@ protected Helper delegate() { /** * Endpoint is an optimization to quickly lookup and compare EquivalentAddressGroup address sets. - * Ignores the attributes, orders the addresses in a deterministic manner and converts each - * address into a string for easy comparison. Also caches the hashcode. - * Is used as a key for ChildLbState for most load balancers (ClusterManagerLB uses a String). + * It ignores the attributes. Is used as a key for ChildLbState for most load balancers + * (ClusterManagerLB uses a String). */ protected static class Endpoint { - final String[] addrs; + final Collection addrs; final int hashCode; public Endpoint(EquivalentAddressGroup eag) { checkNotNull(eag, "eag"); - addrs = new String[eag.getAddresses().size()]; - int i = 0; + if (eag.getAddresses().size() < 10) { + addrs = eag.getAddresses(); + } else { + // This is expected to be very unlikely in practice + addrs = new HashSet<>(eag.getAddresses()); + } + int sum = 0; for (SocketAddress address : eag.getAddresses()) { - addrs[i++] = address.toString(); + sum += address.hashCode(); } - Arrays.sort(addrs); - - hashCode = Arrays.hashCode(addrs); + hashCode = sum; } @Override @@ -525,24 +527,21 @@ public boolean equals(Object other) { if (this == other) { return true; } - if (other == null) { - return false; - } if (!(other instanceof Endpoint)) { return false; } Endpoint o = (Endpoint) other; - if (o.hashCode != hashCode || o.addrs.length != addrs.length) { + if (o.hashCode != hashCode || o.addrs.size() != addrs.size()) { return false; } - return Arrays.equals(o.addrs, this.addrs); + return o.addrs.containsAll(addrs); } @Override public String toString() { - return Arrays.toString(addrs); + return addrs.toString(); } } diff --git a/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java b/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java index df226d5aee8..6bfd6d7a659 100644 --- a/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java @@ -21,7 +21,6 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -34,6 +33,7 @@ import static org.mockito.Mockito.verify; import com.google.common.collect.Lists; +import com.google.common.testing.EqualsTester; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -244,37 +244,28 @@ public void testEndpoint_toString() { @Test public void testEndpoint_equals() { - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1"), - createEndpoint(Attributes.EMPTY, "addr1")); - - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(Attributes.EMPTY, "addr2", "addr1")); - - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(affinity, "addr2", "addr1")); - - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2").hashCode(), - createEndpoint(affinity, "addr2", "addr1").hashCode()); - - } - - @Test - public void testEndpoint_notEquals() { - assertNotEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(Attributes.EMPTY, "addr1", "addr3")); - - assertNotEquals( - createEndpoint(Attributes.EMPTY, "addr1"), - createEndpoint(Attributes.EMPTY, "addr1", "addr2")); - - assertNotEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(Attributes.EMPTY, "addr1")); + new EqualsTester() + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1"), + createEndpoint(Attributes.EMPTY, "addr1")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2"), + createEndpoint(Attributes.EMPTY, "addr2", "addr1"), + createEndpoint(affinity, "addr1", "addr2")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr3")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr10"), + createEndpoint(Attributes.EMPTY, "addr2", "addr1", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr10")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr11")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr10", "addr11")) + .testEquals(); } private String addressesOnlyString(EquivalentAddressGroup eag) { diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java index b0239c56703..bdeff9d17c5 100644 --- a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java +++ b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java @@ -276,7 +276,7 @@ public String toString() { } } - public static class FakeSocketAddress extends SocketAddress { + public static final class FakeSocketAddress extends SocketAddress { private static final long serialVersionUID = 0L; final String name; @@ -288,6 +288,20 @@ public static class FakeSocketAddress extends SocketAddress { public String toString() { return "FakeSocketAddress-" + name; } + + @Override + public boolean equals(Object o) { + if (!(o instanceof FakeSocketAddress)) { + return false; + } + FakeSocketAddress that = (FakeSocketAddress) o; + return this.name.equals(that.name); + } + + @Override + public int hashCode() { + return name.hashCode(); + } } }