Skip to content

Commit

Permalink
add arguments MAX_DIST & FILTER; add unit test for getdistance
Browse files Browse the repository at this point in the history
  • Loading branch information
seth-hg authored and yangbodong22011 committed Jul 19, 2023
1 parent 53c095a commit 4c5fd34
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tair/tairvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,11 @@ def _tvs_getdistance(
args = list(keys)
if top_n is not None:
args += ("TOPN", top_n)
if max_dist is not None:
args += ("MAX_DIST", max_dist)
if filter_str is not None:
args += ("FILTER", filter_str)

if (not isinstance(vector, str)) and (not isinstance(vector, bytes)):
vector_str = self.encode_vector(vector)
else:
Expand Down Expand Up @@ -541,6 +546,10 @@ def tvs_getdistance(
k = min(k, top_n)

args = ["TOPN", k]
if max_dist is not None:
args += ("MAX_DIST", max_dist)
if filter_str is not None:
args += ("FILTER", filter_str)

def process_batch(batch):
return self.execute_command(
Expand Down
71 changes: 71 additions & 0 deletions tests/test_tairvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,77 @@ def test_9_cleanup(self):
self.assertEqual(ret, 1)


class GetDistanceTest(unittest.TestCase):
index_name = "getdistance_test"

def test_0_create(self):
vectors = [[random() for _ in range(dim)] for _ in range(1000)]
ret = client.tvs_create_index(
self.index_name,
dim,
distance_type=DistanceMetric.L2,
)
self.assertTrue(ret)
for i, v in enumerate(vectors):
ret = client.tvs_hset(self.index_name, str(i), v, attr=i)
self.assertEqual(ret, 2)

def test_1_getdistance(self):
query = [random() for _ in range(dim)]
keys = [str(i) for i in range(1000)]

# test low level interface
results = client._tvs_getdistance(self.index_name, query, keys, top_n=10)
self.assertEqual(20, len(results))
for i in range(0, len(results) - 2, 2):
self.assertTrue(float(results[i + 1]) <= float(results[i + 3]))

# test wrapped interface
results = client.tvs_getdistance(
self.index_name, query, keys, batch_size=100, top_n=10
)
self.assertEqual(10, len(results))
for i in range(len(results) - 1):
self.assertTrue(results[i][1] <= results[i + 1][1])

def test_2_getdistance_with_max_dist(self):
query = [random() for _ in range(dim)]
keys = [str(i) for i in range(1000)]

results = client.tvs_getdistance(
self.index_name, query, keys, batch_size=100, top_n=10, max_dist=3.0
)
self.assertGreater(len(results), 0)
self.assertGreaterEqual(10, len(results))
for i in range(len(results) - 1):
self.assertTrue(results[i][1] <= results[i + 1][1])
for _, score in results:
self.assertGreater(3.0, score)

def test_3_getdistance_with_filter(self):
query = [random() for _ in range(dim)]
keys = [str(i) for i in range(1000)]

results = client.tvs_getdistance(
self.index_name,
query,
keys,
batch_size=100,
top_n=10,
filter_str="attr<500",
)
self.assertGreater(len(results), 0)
self.assertGreaterEqual(10, len(results))
for i in range(len(results) - 1):
self.assertTrue(results[i][1] <= results[i + 1][1])
for key, _ in results:
self.assertLess(int(key), 500)

def test_9_cleanup(self):
ret = client.tvs_del_index(self.index_name)
self.assertEqual(ret, 1)


if __name__ == "__main__":
unittest.main()
client.close()

0 comments on commit 4c5fd34

Please sign in to comment.