Skip to content

Commit

Permalink
[ENH] updates & invalid operations should also trigger persisting of local HNSW (#2499)
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb authored Jul 15, 2024
1 parent 4d063ba commit 37c54df
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 2 deletions.
8 changes: 8 additions & 0 deletions chromadb/segment/impl/vector/local_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class LocalHnswSegment(VectorReader):
_index: Optional[hnswlib.Index]
_dimensionality: Optional[int]
_total_elements_added: int
_total_elements_updated: int
_total_invalid_operations: int
_max_seq_id: SeqId

_lock: ReadWriteLock
Expand All @@ -66,6 +68,8 @@ def __init__(self, system: System, segment: Segment):
self._index = None
self._dimensionality = None
self._total_elements_added = 0
self._total_elements_updated = 0
self._total_invalid_operations = 0
self._max_seq_id = self._consumer.min_seqid()

self._id_to_seq_id = {}
Expand Down Expand Up @@ -275,6 +279,7 @@ def _apply_batch(self, batch: Batch) -> None:

# If that succeeds, update the total count
self._total_elements_added += batch.add_count
self._total_elements_updated += batch.update_count

# If that succeeds, finally the seq ID
self._max_seq_id = batch.max_seq_id
Expand All @@ -300,6 +305,7 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
batch.apply(record)
else:
logger.warning(f"Delete of nonexisting embedding ID: {id}")
self._total_invalid_operations += 1

elif op == Operation.UPDATE:
if record["record"]["embedding"] is not None:
Expand All @@ -309,11 +315,13 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
logger.warning(
f"Update of nonexisting embedding ID: {record['record']['id']}"
)
self._total_invalid_operations += 1
elif op == Operation.ADD:
if not label:
batch.apply(record, False)
else:
logger.warning(f"Add of existing embedding ID: {id}")
self._total_invalid_operations += 1
elif op == Operation.UPSERT:
batch.apply(record, label is not None)

Expand Down
37 changes: 35 additions & 2 deletions chromadb/segment/impl/vector/local_persistent_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class PersistentData:

dimensionality: Optional[int]
total_elements_added: int
total_elements_updated: int
total_invalid_operations: int
max_seq_id: SeqId

id_to_label: Dict[str, int]
Expand All @@ -51,18 +53,28 @@ def __init__(
self,
dimensionality: Optional[int],
total_elements_added: int,
total_elements_updated: int,
total_invalid_operations: int,
max_seq_id: int,
id_to_label: Dict[str, int],
label_to_id: Dict[int, str],
id_to_seq_id: Dict[str, SeqId],
):
self.dimensionality = dimensionality
self.total_elements_added = total_elements_added
self.total_elements_updated = total_elements_updated
self.total_invalid_operations = total_invalid_operations
self.max_seq_id = max_seq_id
self.id_to_label = id_to_label
self.label_to_id = label_to_id
self.id_to_seq_id = id_to_seq_id

def __setstate__(self, state):
# Fields were added after the initial implementation
self.total_elements_updated = 0
self.total_invalid_operations = 0
self.__dict__.update(state)

@staticmethod
def load_from_file(filename: str) -> "PersistentData":
"""Load persistent data from a file"""
Expand Down Expand Up @@ -121,6 +133,8 @@ def __init__(self, system: System, segment: Segment):
self._persist_data = PersistentData(
self._dimensionality,
self._total_elements_added,
self._total_elements_updated,
self._total_invalid_operations,
self._max_seq_id,
self._id_to_label,
self._label_to_id,
Expand Down Expand Up @@ -195,6 +209,7 @@ def _persist(self) -> None:
# Persist the metadata
self._persist_data.dimensionality = self._dimensionality
self._persist_data.total_elements_added = self._total_elements_added
self._persist_data.total_elements_updated = self._total_elements_updated
self._persist_data.max_seq_id = self._max_seq_id

# TODO: This should really be stored in sqlite, the index itself, or a better
Expand All @@ -212,8 +227,20 @@ def _persist(self) -> None:
@override
def _apply_batch(self, batch: Batch) -> None:
super()._apply_batch(batch)
if (
num_elements_added_since_last_persist = (
self._total_elements_added - self._persist_data.total_elements_added
)
num_elements_updated_since_last_persist = (
self._total_elements_updated - self._persist_data.total_elements_updated
)
num_invalid_operations_since_last_persist = (
self._total_invalid_operations - self._persist_data.total_invalid_operations
)

if (
num_elements_added_since_last_persist
+ num_elements_updated_since_last_persist
+ num_invalid_operations_since_last_persist
>= self._sync_threshold
):
self._persist()
Expand Down Expand Up @@ -264,18 +291,24 @@ def _write_records(self, records: Sequence[LogRecord]) -> None:
logger.warning(
f"Update of nonexisting embedding ID: {record['record']['id']}"
)
self._total_invalid_operations += 1
elif op == Operation.ADD:
if record["record"]["embedding"] is not None:
if exists_in_index and not id_is_pending_delete:
logger.warning(f"Add of existing embedding ID: {id}")
self._total_invalid_operations += 1
else:
self._curr_batch.apply(record, not exists_in_index)
self._brute_force_index.upsert([record])
elif op == Operation.UPSERT:
if record["record"]["embedding"] is not None:
self._curr_batch.apply(record, exists_in_index)
self._brute_force_index.upsert([record])
if len(self._curr_batch) >= self._batch_size:

if (
len(self._curr_batch) + self._total_invalid_operations
>= self._batch_size
):
self._apply_batch(self._curr_batch)
self._curr_batch = Batch()
self._brute_force_index.clear()
Expand Down
7 changes: 7 additions & 0 deletions chromadb/test/property/test_cross_version_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ def persist_generated_data_with_old_version(
strategies.collections(
with_hnsw_params=True,
has_embeddings=True,
# By default, these are set to 2000, which makes it unlikely that index mutations will ever be fully flushed
max_hnsw_sync_threshold=10,
max_hnsw_batch_size=10,
with_persistent_hnsw_params=st.booleans(),
),
key="coll",
Expand Down Expand Up @@ -341,6 +344,10 @@ def test_cycle_versions(
name=collection_strategy.name,
embedding_function=not_implemented_ef(), # type: ignore
)

# Should be able to add embeddings
coll.add(**embeddings_strategy) # type: ignore

invariants.count(coll, embeddings_strategy)
invariants.metadatas_match(coll, embeddings_strategy)
invariants.documents_match(coll, embeddings_strategy)
Expand Down
61 changes: 61 additions & 0 deletions chromadb/test/property/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import multiprocessing
from multiprocessing.connection import Connection
import multiprocessing.context
import time
from typing import Generator, Callable
from uuid import UUID
from hypothesis import given
Expand All @@ -10,6 +11,7 @@
import chromadb
from chromadb.api import ClientAPI, ServerAPI
from chromadb.config import Settings, System
from chromadb.segment import SegmentManager, VectorReader
import chromadb.test.property.strategies as strategies
import chromadb.test.property.invariants as invariants
from chromadb.test.property.test_embeddings import (
Expand Down Expand Up @@ -129,6 +131,65 @@ def test_persist(
del system_2


def test_sync_threshold(settings: Settings) -> None:
system = System(settings)
system.start()
client = ClientCreator.from_system(system)

collection = client.create_collection(
name="test", metadata={"hnsw:batch_size": 3, "hnsw:sync_threshold": 3}
)

manager = system.instance(SegmentManager)
segment = manager.get_segment(collection.id, VectorReader)

def get_index_last_modified_at() -> float:
# Time resolution on Windows can be up to 10ms
time.sleep(0.1)
try:
return os.path.getmtime(segment._get_metadata_file()) # type: ignore[attr-defined]
except FileNotFoundError:
return -1

last_modified_at = get_index_last_modified_at()

collection.add(ids=["1", "2"], embeddings=[[1.0], [2.0]])

# Should not have yet persisted
assert get_index_last_modified_at() == last_modified_at
last_modified_at = get_index_last_modified_at()

# Now there's 3 additions, and the sync threshold is 3...
collection.add(ids=["3"], embeddings=[[3.0]])

# ...so it should have persisted
assert get_index_last_modified_at() > last_modified_at
last_modified_at = get_index_last_modified_at()

# The same thing should happen with upserts
collection.upsert(ids=["1", "2", "3"], embeddings=[[1.0], [2.0], [3.0]])

# Should have persisted
assert get_index_last_modified_at() > last_modified_at
last_modified_at = get_index_last_modified_at()

# Mixed usage should also trigger persistence
collection.add(ids=["4"], embeddings=[[4.0]])
collection.upsert(ids=["1", "2"], embeddings=[[1.0], [2.0]])

# Should have persisted
assert get_index_last_modified_at() > last_modified_at
last_modified_at = get_index_last_modified_at()

# Invalid updates should also trigger persistence
collection.add(ids=["5"], embeddings=[[5.0]])
collection.add(ids=["1", "2"], embeddings=[[1.0], [2.0]])

# Should have persisted
assert get_index_last_modified_at() > last_modified_at
last_modified_at = get_index_last_modified_at()


def load_and_check(
settings: Settings,
collection_name: str,
Expand Down

0 comments on commit 37c54df

Please sign in to comment.