From 4c5fd3482fbacc7d5da99adf0e81715c6eca9e74 Mon Sep 17 00:00:00 2001
From: "huangsai.hs"
Date: Tue, 18 Jul 2023 19:28:35 +0800
Subject: [PATCH] add arguments MAX_DIST & FILTER; add unit test for getdistance

---
 tair/tairvector.py       |  9 +++++
 tests/test_tairvector.py | 71 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 80 insertions(+)

diff --git a/tair/tairvector.py b/tair/tairvector.py
index a2652cc..85c45fb 100644
--- a/tair/tairvector.py
+++ b/tair/tairvector.py
@@ -506,6 +506,11 @@ def _tvs_getdistance(
         args = list(keys)
         if top_n is not None:
             args += ("TOPN", top_n)
+        if max_dist is not None:
+            args += ("MAX_DIST", max_dist)
+        if filter_str is not None:
+            args += ("FILTER", filter_str)
+
         if (not isinstance(vector, str)) and (not isinstance(vector, bytes)):
             vector_str = self.encode_vector(vector)
         else:
@@ -541,6 +546,10 @@ def tvs_getdistance(
             k = min(k, top_n)
 
         args = ["TOPN", k]
+        if max_dist is not None:
+            args += ("MAX_DIST", max_dist)
+        if filter_str is not None:
+            args += ("FILTER", filter_str)
 
         def process_batch(batch):
             return self.execute_command(
diff --git a/tests/test_tairvector.py b/tests/test_tairvector.py
index 82d68b2..89bbf57 100644
--- a/tests/test_tairvector.py
+++ b/tests/test_tairvector.py
@@ -628,6 +628,77 @@ def test_9_cleanup(self):
         self.assertEqual(ret, 1)
 
 
+class GetDistanceTest(unittest.TestCase):
+    index_name = "getdistance_test"
+
+    def test_0_create(self):
+        vectors = [[random() for _ in range(dim)] for _ in range(1000)]
+        ret = client.tvs_create_index(
+            self.index_name,
+            dim,
+            distance_type=DistanceMetric.L2,
+        )
+        self.assertTrue(ret)
+        for i, v in enumerate(vectors):
+            ret = client.tvs_hset(self.index_name, str(i), v, attr=i)
+            self.assertEqual(ret, 2)
+
+    def test_1_getdistance(self):
+        query = [random() for _ in range(dim)]
+        keys = [str(i) for i in range(1000)]
+
+        # test low level interface
+        results = client._tvs_getdistance(self.index_name, query, keys, top_n=10)
+        self.assertEqual(20, len(results))
+        for i in range(0, len(results) - 2, 2):
+            self.assertTrue(float(results[i + 1]) <= float(results[i + 3]))
+
+        # test wrapped interface
+        results = client.tvs_getdistance(
+            self.index_name, query, keys, batch_size=100, top_n=10
+        )
+        self.assertEqual(10, len(results))
+        for i in range(len(results) - 1):
+            self.assertTrue(results[i][1] <= results[i + 1][1])
+
+    def test_2_getdistance_with_max_dist(self):
+        query = [random() for _ in range(dim)]
+        keys = [str(i) for i in range(1000)]
+
+        results = client.tvs_getdistance(
+            self.index_name, query, keys, batch_size=100, top_n=10, max_dist=3.0
+        )
+        self.assertGreater(len(results), 0)
+        self.assertGreaterEqual(10, len(results))
+        for i in range(len(results) - 1):
+            self.assertTrue(results[i][1] <= results[i + 1][1])
+        for _, score in results:
+            self.assertGreater(3.0, score)
+
+    def test_3_getdistance_with_filter(self):
+        query = [random() for _ in range(dim)]
+        keys = [str(i) for i in range(1000)]
+
+        results = client.tvs_getdistance(
+            self.index_name,
+            query,
+            keys,
+            batch_size=100,
+            top_n=10,
+            filter_str="attr<500",
+        )
+        self.assertGreater(len(results), 0)
+        self.assertGreaterEqual(10, len(results))
+        for i in range(len(results) - 1):
+            self.assertTrue(results[i][1] <= results[i + 1][1])
+        for key, _ in results:
+            self.assertLess(int(key), 500)
+
+    def test_9_cleanup(self):
+        ret = client.tvs_del_index(self.index_name)
+        self.assertEqual(ret, 1)
+
+
 if __name__ == "__main__":
     unittest.main()
     client.close()