[ENH] updates & invalid operations should also trigger persisting of local HNSW #2499

Merged: 15 commits, Jul 15, 2024. Showing changes from 6 commits.
3 changes: 3 additions & 0 deletions chromadb/segment/impl/vector/local_hnsw.py
@@ -43,6 +43,7 @@ class LocalHnswSegment(VectorReader):
     _index: Optional[hnswlib.Index]
     _dimensionality: Optional[int]
     _total_elements_added: int
+    _total_elements_updated: int
     _max_seq_id: SeqId
 
     _lock: ReadWriteLock
@@ -66,6 +67,7 @@ def __init__(self, system: System, segment: Segment):
         self._index = None
         self._dimensionality = None
         self._total_elements_added = 0
+        self._total_elements_updated = 0
         self._max_seq_id = self._consumer.min_seqid()
 
         self._id_to_seq_id = {}
@@ -275,6 +277,7 @@ def _apply_batch(self, batch: Batch) -> None:
 
         # If that succeeds, update the total count
         self._total_elements_added += batch.add_count
+        self._total_elements_updated += batch.update_count
 
         # If that succeeds, finally the seq ID
         self._max_seq_id = batch.max_seq_id
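In miniature, the bookkeeping this file now does: every applied batch advances an updated-elements counter alongside the added-elements one. A standalone sketch (the `FakeBatch` stand-in and its field names are illustrative assumptions, not chromadb's actual class):

```python
from dataclasses import dataclass

@dataclass
class FakeBatch:  # stand-in for chromadb's internal Batch
    add_count: int = 0
    update_count: int = 0

total_elements_added = 0
total_elements_updated = 0

# A batch of 2 adds and 1 update advances both counters, so update-heavy
# workloads now make measurable progress toward the persist threshold.
batch = FakeBatch(add_count=2, update_count=1)
total_elements_added += batch.add_count        # -> 2
total_elements_updated += batch.update_count   # -> 1
```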
33 changes: 31 additions & 2 deletions chromadb/segment/impl/vector/local_persistent_hnsw.py
@@ -41,6 +41,7 @@ class PersistentData:
 
     dimensionality: Optional[int]
     total_elements_added: int
+    total_elements_updated: int
     max_seq_id: SeqId
 
     id_to_label: Dict[str, int]
@@ -51,18 +52,25 @@ def __init__(
         self,
         dimensionality: Optional[int],
         total_elements_added: int,
+        total_elements_updated: int,
         max_seq_id: int,
         id_to_label: Dict[str, int],
         label_to_id: Dict[int, str],
         id_to_seq_id: Dict[str, SeqId],
     ):
         self.dimensionality = dimensionality
         self.total_elements_added = total_elements_added
+        self.total_elements_updated = total_elements_updated
         self.max_seq_id = max_seq_id
         self.id_to_label = id_to_label
         self.label_to_id = label_to_id
         self.id_to_seq_id = id_to_seq_id
 
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # Field was added after the initial implementation
+        self.total_elements_updated = 0
+
     @staticmethod
     def load_from_file(filename: str) -> "PersistentData":
         """Load persistent data from a file"""
@@ -86,6 +94,8 @@ class PersistentLocalHnswSegment(LocalHnswSegment):
     _persist_directory: str
     _allow_reset: bool
 
+    _invalid_operations_since_last_persist: int = 0
+
     _opentelemtry_client: OpenTelemetryClient
 
     def __init__(self, system: System, segment: Segment):
@@ -121,6 +131,7 @@ def __init__(self, system: System, segment: Segment):
             self._persist_data = PersistentData(
                 self._dimensionality,
                 self._total_elements_added,
+                self._total_elements_updated,
                 self._max_seq_id,
                 self._id_to_label,
                 self._label_to_id,
@@ -195,6 +206,7 @@ def _persist(self) -> None:
         # Persist the metadata
         self._persist_data.dimensionality = self._dimensionality
         self._persist_data.total_elements_added = self._total_elements_added
+        self._persist_data.total_elements_updated = self._total_elements_updated
         self._persist_data.max_seq_id = self._max_seq_id
 
         # TODO: This should really be stored in sqlite, the index itself, or a better
@@ -206,14 +218,25 @@ def _persist(self) -> None:
         with open(self._get_metadata_file(), "wb") as metadata_file:
             pickle.dump(self._persist_data, metadata_file, pickle.HIGHEST_PROTOCOL)
 
+        self._invalid_operations_since_last_persist = 0
+
     @trace_method(
         "PersistentLocalHnswSegment._apply_batch", OpenTelemetryGranularity.ALL
     )
     @override
     def _apply_batch(self, batch: Batch) -> None:
         super()._apply_batch(batch)
-        if (
-            self._total_elements_added - self._persist_data.total_elements_added
-            >= self._sync_threshold
-        ):
+        num_elements_added_since_last_persist = (
+            self._total_elements_added - self._persist_data.total_elements_added
+        )
+        num_elements_updated_since_last_persist = (
+            self._total_elements_updated - self._persist_data.total_elements_updated
+        )
+
+        if (
+            num_elements_added_since_last_persist
+            + num_elements_updated_since_last_persist
+            + self._invalid_operations_since_last_persist
+            >= self._sync_threshold
+        ):
             self._persist()
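The new trigger condition in `_apply_batch`, reduced to plain arithmetic (illustrative values; a sync threshold of 3 matches the test added later in this PR):

```python
sync_threshold = 3

# Progress since the last persist, mirroring the three counters above:
num_added = 1    # one successful add
num_updated = 1  # one successful update
num_invalid = 1  # e.g. an add of an ID that already exists

# Before this PR only num_added was compared against the threshold, so a
# workload of pure updates or invalid writes never persisted the index.
if num_added + num_updated + num_invalid >= sync_threshold:
    print("persist() runs here")
```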
@@ -262,18 +285,24 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
                             logger.warning(
                                 f"Update of nonexisting embedding ID: {record['record']['id']}"
                             )
+                            self._invalid_operations_since_last_persist += 1
                 elif op == Operation.ADD:
                     if record["record"]["embedding"] is not None:
                         if not exists_in_index:
                             self._curr_batch.apply(record, not exists_in_index)
                             self._brute_force_index.upsert([record])
                         else:
                             logger.warning(f"Add of existing embedding ID: {id}")
+                            self._invalid_operations_since_last_persist += 1
                 elif op == Operation.UPSERT:
                     if record["record"]["embedding"] is not None:
                         self._curr_batch.apply(record, exists_in_index)
                         self._brute_force_index.upsert([record])
-                if len(self._curr_batch) >= self._batch_size:
+
+                if (
+                    len(self._curr_batch) + self._invalid_operations_since_last_persist
+                    >= self._batch_size
+                ):
                     self._apply_batch(self._curr_batch)
                     self._curr_batch = Batch()
                     self._brute_force_index.clear()
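The matching change to the batch flush check, reduced the same way: invalid operations never enter `_curr_batch`, so without counting them a mostly-invalid write stream could stall the flush indefinitely (illustrative values):

```python
batch_size = 3

curr_batch_len = 1         # only one record in the stream was valid
invalid_since_persist = 2  # two writes hit the warning branches above

# Old check: curr_batch_len >= batch_size -> never flushes in this scenario.
if curr_batch_len + invalid_since_persist >= batch_size:
    print("batch applied and cleared here")
```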
8 changes: 6 additions & 2 deletions chromadb/test/property/strategies.py
@@ -282,6 +282,8 @@ def collections(
     has_embeddings: Optional[bool] = None,
     has_documents: Optional[bool] = None,
     with_persistent_hnsw_params: bool = False,
+    max_hnsw_batch_size: int = 2000,
+    max_hnsw_sync_threshold: int = 2000,
 ) -> Collection:
     """Strategy to generate a Collection object. If add_filterable_data is True, then known_metadata_keys and known_document_keywords will be populated with consistent data."""
 
@@ -302,9 +304,11 @@ def collections(
     metadata = {}
     metadata.update(test_hnsw_config)
     if with_persistent_hnsw_params:
-        metadata["hnsw:batch_size"] = draw(st.integers(min_value=3, max_value=2000))
+        metadata["hnsw:batch_size"] = draw(
+            st.integers(min_value=3, max_value=max_hnsw_batch_size)
+        )
         metadata["hnsw:sync_threshold"] = draw(
-            st.integers(min_value=3, max_value=2000)
+            st.integers(min_value=3, max_value=max_hnsw_sync_threshold)
        )
     # Sometimes, select a space at random
     if draw(st.booleans()):
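A hypothetical caller of the new parameters (the test name and body here are illustrative; the real usage appears in test_cross_version_persist.py below):

```python
from hypothesis import given
import chromadb.test.property.strategies as strategies

@given(
    collection=strategies.collections(
        with_hnsw_params=True,
        with_persistent_hnsw_params=True,
        has_embeddings=True,
        # Cap the drawn values so small record sets can still cross the thresholds
        max_hnsw_batch_size=10,
        max_hnsw_sync_threshold=10,
    )
)
def test_flushes_with_small_thresholds(collection: strategies.Collection) -> None:
    ...
```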
14 changes: 13 additions & 1 deletion chromadb/test/property/test_cross_version_persist.py
@@ -272,7 +272,15 @@ def persist_generated_data_with_old_version(
 
 # Since we can't pickle the embedding function, we always generate record sets with embeddings
 collection_st: st.SearchStrategy[strategies.Collection] = st.shared(
-    strategies.collections(with_hnsw_params=True, has_embeddings=True), key="coll"
+    strategies.collections(
+        with_hnsw_params=True,
+        has_embeddings=True,
+        with_persistent_hnsw_params=True,
+        # By default, these are set to 2000, which makes it unlikely that index mutations will ever be fully flushed
+        max_hnsw_sync_threshold=10,
+        max_hnsw_batch_size=10,
+    ),
+    key="coll",
 )

Review comment from codetheweb (author) on the lowered thresholds: without reducing the sync threshold, this test will almost never fully flush to disk, which makes the correctness guarantee weaker. Without this change, the pickled metadata file was never loaded.

@@ -336,6 +344,10 @@ def test_cycle_versions(
         name=collection_strategy.name,
         embedding_function=not_implemented_ef(),  # type: ignore
     )
+
+    # Should be able to add embeddings
+    coll.add(**embeddings_strategy)  # type: ignore
+
     invariants.count(coll, embeddings_strategy)
     invariants.metadatas_match(coll, embeddings_strategy)
     invariants.documents_match(coll, embeddings_strategy)

Review comment from codetheweb (author) on the added coll.add call: this causes the pickled metadata file to be loaded, a good thing to test since the class def changed and it needs special handling.

61 changes: 61 additions & 0 deletions chromadb/test/property/test_persist.py
@@ -2,13 +2,15 @@
 import multiprocessing
 from multiprocessing.connection import Connection
 import multiprocessing.context
+import time
 from typing import Generator, Callable
 from hypothesis import given
 import hypothesis.strategies as st
 import pytest
 import chromadb
 from chromadb.api import ClientAPI, ServerAPI
 from chromadb.config import Settings, System
+from chromadb.segment import SegmentManager, VectorReader
 import chromadb.test.property.strategies as strategies
 import chromadb.test.property.invariants as invariants
 from chromadb.test.property.test_embeddings import (
@@ -122,6 +124,65 @@ def test_persist(
     del system_2
 
 
+def test_sync_threshold(settings: Settings) -> None:
+    system = System(settings)
+    system.start()
+    client = ClientCreator.from_system(system)
+
+    collection = client.create_collection(
+        name="test", metadata={"hnsw:batch_size": 3, "hnsw:sync_threshold": 3}
+    )
+
+    manager = system.instance(SegmentManager)
+    segment = manager.get_segment(collection.id, VectorReader)
+
+    def get_index_last_modified_at() -> float:
+        # Time resolution on Windows can be up to 10ms
+        time.sleep(0.1)
+        try:
+            return os.path.getmtime(segment._get_metadata_file())  # type: ignore[attr-defined]
+        except FileNotFoundError:
+            return -1
+
+    last_modified_at = get_index_last_modified_at()
+
+    collection.add(ids=["1", "2"], embeddings=[[1.0], [2.0]])
+
+    # Should not have persisted yet
+    assert get_index_last_modified_at() == last_modified_at
+    last_modified_at = get_index_last_modified_at()
+
+    # Now there are 3 additions, and the sync threshold is 3...
+    collection.add(ids=["3"], embeddings=[[3.0]])
+
+    # ...so it should have persisted
+    assert get_index_last_modified_at() > last_modified_at
+    last_modified_at = get_index_last_modified_at()
+
+    # The same thing should happen with upserts
+    collection.upsert(ids=["1", "2", "3"], embeddings=[[1.0], [2.0], [3.0]])
+
+    # Should have persisted
+    assert get_index_last_modified_at() > last_modified_at
+    last_modified_at = get_index_last_modified_at()
+
+    # Mixed usage should also trigger persistence
+    collection.add(ids=["4"], embeddings=[[4.0]])
+    collection.upsert(ids=["1", "2"], embeddings=[[1.0], [2.0]])
+
+    # Should have persisted
+    assert get_index_last_modified_at() > last_modified_at
+    last_modified_at = get_index_last_modified_at()
+
+    # Invalid operations (adds of existing IDs) should also trigger persistence
+    collection.add(ids=["5"], embeddings=[[5.0]])
+    collection.add(ids=["1", "2"], embeddings=[[1.0], [2.0]])
+
+    # Should have persisted
+    assert get_index_last_modified_at() > last_modified_at
+    last_modified_at = get_index_last_modified_at()


def load_and_check(
settings: Settings,
collection_name: str,
Expand Down
Loading
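One caveat on `get_index_last_modified_at` in the test above: the equality assertion relies on filesystem timestamp resolution, which the `time.sleep(0.1)` papers over. A resolution-independent probe (a sketch with a hypothetical helper, not part of this PR) could fingerprint the metadata file's contents instead:

```python
import hashlib

def get_metadata_fingerprint(path: str) -> str:
    # The pickled metadata changes on every persist (max_seq_id advances),
    # so its hash changes too, regardless of mtime granularity.
    try:
        with open(path, "rb") as f:
            return hashlib.sha256(f.read()).hexdigest()
    except FileNotFoundError:
        return ""
```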