Skip to content

Commit

Permalink
skip certain columns of dubious utility
Browse files Browse the repository at this point in the history
  • Loading branch information
mikemccand committed Sep 7, 2024
1 parent 339cc97 commit ebefd8c
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 81 deletions.
157 changes: 81 additions & 76 deletions src/main/knn/KnnGraphTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package knn;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
Expand All @@ -31,9 +29,9 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileTime;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
Expand All @@ -42,16 +40,16 @@
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.Executors;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene912.Lucene912Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
Expand All @@ -68,6 +66,7 @@
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.IndexSearcher;
Expand All @@ -84,15 +83,17 @@
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.PrintStreamInfoStream;
import org.apache.lucene.util.SuppressForbidden;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
//TODO Lucene may make these unavailable; we should pull this in from hppc directly
import org.apache.lucene.internal.hppc.IntIntHashMap;

// e.g. to compile with zero build tooling!:
//
Expand Down Expand Up @@ -531,87 +532,90 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][]
long start;
ThreadMXBean bean = ManagementFactory.getThreadMXBean();
long cpuTimeStartNs;
try (Directory dir = FSDirectory.open(indexPath);
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
dir.setPreload((x, ctx) -> x.endsWith(".vec") || x.endsWith(".veq"));
try (
DirectoryReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = new IndexSearcher(reader);
numDocs = reader.maxDoc();
Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
for (int i = 0; i < numIters; i++) {
// warm up
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
byte[] target = targetReaderByte.nextBytes();
if (prefilter) {
doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
doKnnByteVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
}
} else {
float[] target = targetReader.next();
if (prefilter) {
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
IndexSearcher searcher = new IndexSearcher(reader);
numDocs = reader.maxDoc();
Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
for (int i = 0; i < numIters; i++) {
// warm up
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
byte[] target = targetReaderByte.nextBytes();
if (prefilter) {
doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
doKnnByteVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
}
} else {
doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
float[] target = targetReader.next();
if (prefilter) {
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
}
}
}
}
targetReader.reset();
start = System.nanoTime();
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
for (int i = 0; i < numIters; i++) {
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
byte[] target = targetReaderByte.nextBytes();
if (prefilter) {
results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
}
} else {
float[] target = targetReader.next();
if (prefilter) {
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
targetReader.reset();
start = System.nanoTime();
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
for (int i = 0; i < numIters; i++) {
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
byte[] target = targetReaderByte.nextBytes();
if (prefilter) {
results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
}
} else {
results[i] =
doKnnVectorQuery(
searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
float[] target = targetReader.next();
if (prefilter) {
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
} else {
results[i] =
doKnnVectorQuery(
searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
}
}
}
if (prefilter == false && matchDocs != null) {
results[i].scoreDocs =
if (prefilter == false && matchDocs != null) {
results[i].scoreDocs =
Arrays.stream(results[i].scoreDocs)
.filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
.toArray(ScoreDoc[]::new);
.filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
.toArray(ScoreDoc[]::new);
}
}
}
totalCpuTimeMS =
totalCpuTimeMS =
TimeUnit.NANOSECONDS.toMillis(bean.getCurrentThreadCpuTime() - cpuTimeStartNs);
elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); // ns -> ms
StoredFields storedFields = reader.storedFields();
for (int i = 0; i < numIters; i++) {
totalVisited += results[i].totalHits.value;
for (ScoreDoc doc : results[i].scoreDocs) {
if (doc.doc != NO_MORE_DOCS) {
// there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens
// in some degenerate case (like input query has NaN in it?) that causes no results to
// be returned from HNSW search?
doc.doc = Integer.parseInt(storedFields.document(doc.doc).get("id"));
} else {
System.out.println("NO_MORE_DOCS!");
elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); // ns -> ms
StoredFields storedFields = reader.storedFields();
for (int i = 0; i < numIters; i++) {
totalVisited += results[i].totalHits.value;
for (ScoreDoc doc : results[i].scoreDocs) {
if (doc.doc != NO_MORE_DOCS) {
// there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens
// in some degenerate case (like input query has NaN in it?) that causes no results to
// be returned from HNSW search?
doc.doc = Integer.parseInt(storedFields.document(doc.doc).get("id"));
} else {
System.out.println("NO_MORE_DOCS!");
}
}
}
}
}
if (quiet == false) {
System.out.println(
"completed "
+ numIters
+ " searches in "
+ elapsed
+ " ms: "
+ ((1000 * numIters) / elapsed)
+ " QPS "
+ "CPU time="
+ totalCpuTimeMS
+ "ms");
if (quiet == false) {
System.out.println(
"completed "
+ numIters
+ " searches in "
+ elapsed
+ " ms: "
+ ((1000 * numIters) / elapsed)
+ " QPS "
+ "CPU time="
+ totalCpuTimeMS
+ "ms");
}
}
} finally {
executorService.shutdown();
Expand Down Expand Up @@ -642,10 +646,11 @@ private void testSearch(Path indexPath, Path queryPath, Path outputPath, int[][]
}
System.out.printf(
Locale.ROOT,
"SUMMARY: %5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%s\t%d\t%d\t%.2f\t%s\n",
"SUMMARY: %5.3f\t%5.2f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%d\t%.2f\t%s\n",
recall,
totalCpuTimeMS / (float) numIters,
numDocs,
topK,
fanout,
maxConn,
beamWidth,
Expand Down
19 changes: 14 additions & 5 deletions src/python/knnPerfTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,36 @@ def run_knn_benchmark(checkout, values):
all_results.append(summary)
print('\nResults:')

header = 'recall\tlatency (ms)\tnDoc\tfanout\tmaxConn\tbeamWidth\tquantized\tvisited\tindex ms\tselectivity\tfilterType'
header = 'recall\tlatency (ms)\tnDoc\ttopK\tfanout\tmaxConn\tbeamWidth\tquantized\tvisited\tindex ms\tselectivity\tfilterType'

# crazy logic to make everything fixed width so rendering in fixed width font "aligns":
num_columns = len(header.split('\t'))
headers = header.split('\t')
num_columns = len(headers)
# print(f'{num_columns} columns')
max_by_col = [0] * num_columns

rows_to_print = [header] + all_results

# TODO: be more careful about when we skip/show headers, e.g. if some of the runs involve filtering,
# turn filterType/selectivity back on for all runs
skip_headers = {'selectivity', 'filterType', 'visited'}

skip_column_index = {headers.index(h) for h in skip_headers}

for row in rows_to_print:
by_column = row.split('\t')
if len(by_column) != num_columns:
raise RuntimeError(f'wrong number of columns: expected {num_columns} but got {len(by_column)}')
raise RuntimeError(f'wrong number of columns: expected {num_columns} but got {len(by_column)} in "{row}"')
for i, s in enumerate(by_column):
max_by_col[i] = max(max_by_col[i], len(s))

row_fmt = ' '.join([f'%{max_by_col[i]}s' for i in range(num_columns)])
row_fmt = ' '.join([f'%{max_by_col[i]}s' for i in range(num_columns) if i not in skip_column_index])
# print(f'using row format {row_fmt}')

for row in rows_to_print:
print(row_fmt % tuple(row.split('\t')))
cols = row.split('\t')
cols = tuple(cols[x] for x in range(len(cols)) if x not in skip_column_index)
print(row_fmt % cols)


run_knn_benchmark(LUCENE_CHECKOUT, PARAMS)

0 comments on commit ebefd8c

Please sign in to comment.