diff --git a/Cargo.lock b/Cargo.lock index b34975b..d6619ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1821,7 +1821,10 @@ dependencies = [ "itertools 0.10.5", "lazy_static", "libc", + "linfa", + "linfa-clustering", "lru", + "ndarray", "parallel-hnsw", "rand", "rand_pcg", diff --git a/vectorlink/Cargo.toml b/vectorlink/Cargo.toml index b8e9181..b4226e2 100644 --- a/vectorlink/Cargo.toml +++ b/vectorlink/Cargo.toml @@ -33,6 +33,9 @@ itertools = "0.10" chrono = "0.4.26" rayon = "1.8.0" libc = "0.2.153" +linfa = "0.7.0" +linfa-clustering = "0.7.0" +ndarray = "0.15.6" [dev-dependencies] assert_float_eq = "1.1.3" diff --git a/vectorlink/src/batch.rs b/vectorlink/src/batch.rs index b87be13..f878851 100644 --- a/vectorlink/src/batch.rs +++ b/vectorlink/src/batch.rs @@ -20,12 +20,15 @@ use urlencoding::encode; use crate::{ comparator::{ - Centroid16Comparator, DiskOpenAIComparator, OpenAIComparator, Quantized16Comparator, + Centroid16Comparator, DiskOpenAIComparator, DomainQuantizer, HnswQuantizer16, + OpenAIComparator, Quantized16Comparator, }, configuration::HnswConfiguration, + domain::{PqDerivedDomainInfo16, PqDerivedDomainInitializer16}, indexer::{create_index_name, index_serialization_path}, openai::{embeddings_for, EmbeddingError, Model}, server::Operation, + store::VectorFile, vecmath::{Embedding, CENTROID_16_LENGTH, EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH}, vectors::VectorStore, }; @@ -56,36 +59,13 @@ pub enum VectorizationError { Io(#[from] io::Error), } -async fn save_embeddings( - vec_file: &mut File, - offset: usize, - embeddings: &[Embedding], -) -> Result<(), VectorizationError> { - let transmuted = unsafe { - std::slice::from_raw_parts( - embeddings.as_ptr() as *const u8, - std::mem::size_of_val(embeddings), - ) - }; - vec_file - .seek(SeekFrom::Start( - (offset * std::mem::size_of::()) as u64, - )) - .await?; - vec_file.write_all(transmuted).await?; - vec_file.flush().await?; - vec_file.sync_data().await?; - - Ok(()) -} - pub async fn vectorize_from_operations< S: Stream>, P: AsRef + Unpin, >( api_key: &str, model: Model, - vec_file: &mut File, + vec_file: &mut VectorFile, op_stream: S, progress_file_path: P, ) -> Result { @@ -122,7 +102,7 @@ pub async fn vectorize_from_operations< let (embeddings, chunk_failures) = embeds.unwrap()?; eprintln!("retrieved embeddings"); - save_embeddings(vec_file, offset as usize, &embeddings).await?; + vec_file.append_vector_range(&embeddings)?; eprintln!("saved embeddings"); failures += chunk_failures; offset += embeddings.len() as u64; @@ -190,7 +170,7 @@ pub async fn index_using_operations_and_vectors< op_file_path: P2, size: usize, id_offset: u64, - quantize_hnsw: bool, + quantize_hnsw: Option<&str>, model: Model, ) -> Result<(), IndexingError> { // Start at last hnsw offset @@ -257,20 +237,37 @@ pub async fn index_using_operations_and_vectors< .collect(); eprintln!("ready to generate hnsw"); - let hnsw = if quantize_hnsw { + let hnsw = if let Some(pq_name) = quantize_hnsw { let number_of_vectors = NUMBER_OF_CENTROIDS / 10; let c = DiskOpenAIComparator::new( domain_obj.name().to_owned(), Arc::new(domain_obj.immutable_file()), ); + + let derived_domain_info = domain_obj.get_derived_domain_info(pq_name); + if derived_domain_info.is_none() { + eprintln!("pq derived domain ({pq_name}) doesn't exist yet. constructing now"); + domain_obj + .create_derived(pq_name.to_string(), PqDerivedDomainInitializer16::default()) + .unwrap(); // TODO + } + // lazy - we just look it up again and now it should exist + let derived_domain_info: PqDerivedDomainInfo16 = + domain_obj.get_derived_domain_info(pq_name).unwrap(); + + let quantizer = derived_domain_info.quantizer.clone()); + + let quantized_comparator = + Quantized16Comparator::load(&vs, domain.to_string(), pq_name.to_string())?; + let hnsw: QuantizedHnsw< EMBEDDING_LENGTH, CENTROID_16_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, - Centroid16Comparator, Quantized16Comparator, DiskOpenAIComparator, - > = QuantizedHnsw::new(number_of_vectors, c); + Arc, + > = QuantizedHnsw::generate(quantizer, quantized_comparator, c, vecs); HnswConfiguration::SmallQuantizedOpenAi(model, hnsw) } else { let hnsw = Hnsw::generate(comparator, vecs, 24, 48, 12); @@ -303,11 +300,7 @@ pub async fn index_from_operations_file>( let mut vector_path = staging_path.clone(); vector_path.push("vectors"); - let mut vec_file = OpenOptions::new() - .create(true) - .write(true) - .open(&vector_path) - .await?; + let mut vec_file = VectorFile::open_create(&vector_path, true)?; let mut progress_file_path = staging_path.clone(); progress_file_path.push("progress"); diff --git a/vectorlink/src/comparator.rs b/vectorlink/src/comparator.rs index 4f04ac9..97e9a5b 100644 --- a/vectorlink/src/comparator.rs +++ b/vectorlink/src/comparator.rs @@ -1,23 +1,17 @@ -use parallel_hnsw::pq::{ - CentroidComparatorConstructor, PartialDistance, QuantizedComparatorConstructor, -}; -use rand::distributions::Uniform; -use rand::{thread_rng, Rng}; +use parallel_hnsw::pq::{HnswQuantizer, PartialDistance, Quantizer}; use serde::{Deserialize, Serialize}; -use std::collections::HashSet; use std::fs::OpenOptions; -use std::io::{Read, Write}; +use std::io::{self, BufReader, Read, Write}; use std::marker::PhantomData; -use std::path::PathBuf; use std::{path::Path, sync::Arc}; -use parallel_hnsw::{pq, Comparator, Serializable, SerializationError, VectorId}; +use parallel_hnsw::{Comparator, Serializable, SerializationError, VectorId}; +use crate::domain::PqDerivedDomainInfo; use crate::store::{ImmutableVectorFile, LoadedVectorRange, VectorFile}; use crate::vecmath::{ - self, EuclideanDistance16, EuclideanDistance32, Quantized16Embedding, Quantized32Embedding, - CENTROID_16_LENGTH, CENTROID_32_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, - QUANTIZED_32_EMBEDDING_LENGTH, + self, EuclideanDistance16, EuclideanDistance32, CENTROID_16_LENGTH, CENTROID_32_LENGTH, + EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, }; use crate::{ vecmath::{normalized_cosine_distance, Embedding}, @@ -72,7 +66,7 @@ impl Serializable for DiskOpenAIComparator { fn deserialize>( path: P, - store: Arc, + store: &Arc, ) -> Result { let mut comparator_file = OpenOptions::new().read(true).open(path)?; let mut contents = String::new(); @@ -86,35 +80,6 @@ impl Serializable for DiskOpenAIComparator { } } -impl pq::VectorSelector for DiskOpenAIComparator { - type T = Embedding; - - fn selection(&self, size: usize) -> Vec { - // TODO do something else for sizes close to number of vecs - if size >= self.vectors.num_vecs() { - return self.vectors.all_vectors().unwrap().clone().into_vec(); - } - let mut rng = thread_rng(); - let mut set = HashSet::new(); - let range = Uniform::from(0_usize..self.vectors.num_vecs()); - while set.len() != size { - let candidate = rng.sample(&range); - set.insert(candidate); - } - - set.into_iter() - .map(|index| self.vectors.vec(index).unwrap()) - .collect() - } - - fn vector_chunks(&self) -> impl Iterator> { - self.vectors - .vector_chunks(1_000_000) - .unwrap() - .map(|x| x.unwrap()) - } -} - #[derive(Clone)] pub struct OpenAIComparator { domain_name: String, @@ -166,7 +131,7 @@ impl Serializable for OpenAIComparator { fn deserialize>( path: P, - store: Arc, + store: &Arc, ) -> Result { let mut comparator_file = OpenOptions::new().read(true).open(path)?; let mut contents = String::new(); @@ -230,6 +195,17 @@ pub struct ArrayCentroidComparator { calculator: PhantomData, } +impl + Default> ArrayCentroidComparator { + pub fn new(centroids: Vec<[f32; N]>) -> Self { + let len = centroids.len(); + Self { + distances: Arc::new(MemoizedPartialDistances::new(C::default(), ¢roids)), + centroids: Arc::new(LoadedVectorRange::new(centroids, 0..len)), + calculator: PhantomData, + } + } +} + impl Clone for ArrayCentroidComparator { fn clone(&self) -> Self { Self { @@ -244,19 +220,6 @@ unsafe impl Sync for ArrayCentroidComparator {} pub type Centroid16Comparator = ArrayCentroidComparator; pub type Centroid32Comparator = ArrayCentroidComparator; -impl + Default> - CentroidComparatorConstructor for ArrayCentroidComparator -{ - fn new(centroids: Vec) -> Self { - let len = centroids.len(); - Self { - distances: Arc::new(MemoizedPartialDistances::new(C::default(), ¢roids)), - centroids: Arc::new(LoadedVectorRange::new(centroids, 0..len)), - calculator: PhantomData, - } - } -} - impl + Default> Comparator for ArrayCentroidComparator { @@ -294,7 +257,7 @@ impl + Default> Serializable fn deserialize>( path: P, - _params: Self::Params, + _params: &Self::Params, ) -> Result { let vector_file: VectorFile<[f32; N]> = VectorFile::open(path, true)?; let centroids = Arc::new(vector_file.all_vectors()?); @@ -310,234 +273,307 @@ impl + Default> Serializable } } -#[derive(Clone)] -pub struct Quantized32Comparator { - pub cc: Centroid32Comparator, - pub data: Arc>, +pub struct QuantizedDomainComparator< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + domain: String, + subdomain: String, + cc: ArrayCentroidComparator, + data: Arc>, } -impl QuantizedComparatorConstructor for Quantized32Comparator { - type CentroidComparator = Centroid32Comparator; - - fn new(cc: &Self::CentroidComparator) -> Self { +impl Clone + for QuantizedDomainComparator +{ + fn clone(&self) -> Self { Self { - cc: cc.clone(), - data: Default::default(), + domain: self.domain.clone(), + subdomain: self.subdomain.clone(), + cc: self.cc.clone(), + data: self.data.clone(), } } } -#[derive(Clone)] -pub struct Quantized16Comparator { - pub cc: Centroid16Comparator, - pub data: Arc>, -} - -impl QuantizedComparatorConstructor for Quantized16Comparator { - type CentroidComparator = Centroid16Comparator; +pub type Quantized16Comparator = QuantizedDomainComparator< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, +>; +pub type Quantized32Comparator = QuantizedDomainComparator< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, +>; - fn new(cc: &Self::CentroidComparator) -> Self { - Self { - cc: cc.clone(), - data: Default::default(), - } - } +#[derive(Serialize, Deserialize)] +struct QuantizedDomainComparatorMeta { + domain: String, + subdomain: String, } -impl PartialDistance for Quantized32Comparator { - fn partial_distance(&self, i: u16, j: u16) -> f32 { - self.cc.partial_distance(i, j) +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator, + > QuantizedDomainComparator +where + ArrayCentroidComparator: 'static + Comparator, +{ + pub fn load(store: &VectorStore, domain: String, subdomain: String) -> io::Result { + assert_eq!(SIZE, CENTROID_SIZE * QUANTIZED_SIZE); // TODO compile-time macro check this + let domain_info = store.get_domain(&domain)?; + let derived_domain_info: PqDerivedDomainInfo = + domain_info + .get_derived_domain_info(&subdomain) + .expect("pq subdomain not found"); + + Ok(Self { + domain, + subdomain, + cc: derived_domain_info.quantizer.quantizer.comparator().clone(), + data: Arc::new(derived_domain_info.file.all_vectors()?), + }) } } - -impl PartialDistance for Quantized16Comparator { +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator, + > PartialDistance for QuantizedDomainComparator +where + ArrayCentroidComparator: 'static + Comparator, +{ fn partial_distance(&self, i: u16, j: u16) -> f32 { self.cc.partial_distance(i, j) } } -impl Comparator for Quantized32Comparator +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator, + > Serializable for QuantizedDomainComparator where - Quantized32Comparator: PartialDistance, + ArrayCentroidComparator: 'static + Comparator, { - type T = Quantized32Embedding; - - type Borrowable<'a> = &'a Quantized32Embedding; - - fn lookup(&self, v: VectorId) -> Self::Borrowable<'_> { - &self.data[v.0] - } - - fn compare_raw(&self, v1: &Self::T, v2: &Self::T) -> f32 { - let mut partial_distances = [0.0_f32; QUANTIZED_32_EMBEDDING_LENGTH]; - for ix in 0..QUANTIZED_32_EMBEDDING_LENGTH { - let partial_1 = v1[ix]; - let partial_2 = v2[ix]; - let partial_distance = self.cc.partial_distance(partial_1, partial_2); - partial_distances[ix] = partial_distance; - } - - vecmath::sum_48(&partial_distances).sqrt() - } -} - -impl Serializable for Quantized32Comparator { - type Params = (); + type Params = Arc; fn serialize>(&self, path: P) -> Result<(), SerializationError> { - let path_buf: PathBuf = path.as_ref().into(); - std::fs::create_dir_all(&path_buf)?; - - let index_path = path_buf.join("index"); - self.cc.serialize(index_path)?; + let meta = QuantizedDomainComparatorMeta { + domain: self.domain.clone(), + subdomain: self.subdomain.clone(), + }; + let meta_string = serde_json::to_string(&meta)?; + std::fs::write(path, meta_string)?; - let vector_path = path_buf.join("vectors"); - let mut vector_file = VectorFile::open(vector_path, true)?; - vector_file.append_vector_range(self.data.vecs())?; Ok(()) } fn deserialize>( path: P, - _params: Self::Params, + params: &Self::Params, ) -> Result { - let path_buf: PathBuf = path.as_ref().into(); - let index_path = path_buf.join("index"); - let cc = Centroid32Comparator::deserialize(index_path, ())?; - - let vector_path = path_buf.join("vectors"); - let vector_file = VectorFile::open(vector_path, true)?; - let range = vector_file.all_vectors()?; + let comparator_file = OpenOptions::new().read(true).open(path)?; + let QuantizedDomainComparatorMeta { domain, subdomain } = + serde_json::from_reader(BufReader::new(comparator_file))?; - let data = Arc::new(range); - Ok(Self { cc, data }) + Ok(Self::load(¶ms, domain, subdomain)?) } } -impl pq::VectorStore for Quantized32Comparator { - type T = ::T; - - fn store(&mut self, i: Box>) -> Vec { - // this is p retty stupid, but then, these comparators should not be storing in the first place - let mut new_contents: Vec = Vec::with_capacity(self.data.len() + i.size_hint().0); - new_contents.extend(self.data.vecs().iter()); - let vid = self.data.len(); - let mut vectors: Vec = Vec::new(); - new_contents.extend(i.enumerate().map(|(i, v)| { - vectors.push(VectorId(vid + i)); - v - })); - let end = new_contents.len(); +pub type QuantizedDomainComparator16 = QuantizedDomainComparator< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + Centroid16Comparator, +>; +pub type QuantizedDomainComparator32 = QuantizedDomainComparator< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + Centroid32Comparator, +>; + +pub struct QuantizedEmbeddingSizeCombination< + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, +>; +pub trait ImplementedQuantizedEmbeddingSizeCombination< + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, +> +{ + fn compare_quantized( + comparator: &C, + v1: &[u16; QUANTIZED_SIZE], + v2: &[u16; QUANTIZED_SIZE], + ) -> f32; +} - let data = LoadedVectorRange::new(new_contents, 0..end); - self.data = Arc::new(data); +impl ImplementedQuantizedEmbeddingSizeCombination + for QuantizedEmbeddingSizeCombination +{ + fn compare_quantized( + comparator: &C, + v1: &[u16; QUANTIZED_16_EMBEDDING_LENGTH], + v2: &[u16; QUANTIZED_16_EMBEDDING_LENGTH], + ) -> f32 { + let mut partial_distances = [0.0_f32; QUANTIZED_16_EMBEDDING_LENGTH]; + for ix in 0..QUANTIZED_16_EMBEDDING_LENGTH { + let partial_1 = v1[ix]; + let partial_2 = v2[ix]; + let partial_distance = comparator.partial_distance(partial_1, partial_2); + partial_distances[ix] = partial_distance; + } - vectors + vecmath::sum_96(&partial_distances).sqrt() } } -impl pq::VectorSelector for OpenAIComparator { - type T = Embedding; - - fn selection(&self, size: usize) -> Vec { - // TODO do something else for sizes close to number of vecs - let mut rng = thread_rng(); - let mut set = HashSet::new(); - let range = Uniform::from(0_usize..size); - while set.len() != size { - let candidate = rng.sample(&range); - set.insert(candidate); +impl ImplementedQuantizedEmbeddingSizeCombination + for QuantizedEmbeddingSizeCombination +{ + fn compare_quantized( + comparator: &C, + v1: &[u16; QUANTIZED_32_EMBEDDING_LENGTH], + v2: &[u16; QUANTIZED_32_EMBEDDING_LENGTH], + ) -> f32 { + let mut partial_distances = [0.0_f32; QUANTIZED_32_EMBEDDING_LENGTH]; + for ix in 0..QUANTIZED_32_EMBEDDING_LENGTH { + let partial_1 = v1[ix]; + let partial_2 = v2[ix]; + let partial_distance = comparator.partial_distance(partial_1, partial_2); + partial_distances[ix] = partial_distance; } - set.into_iter() - .map(|index| *self.range.vec(index)) - .collect() - } - - fn vector_chunks(&self) -> impl Iterator> { - // low quality make better - self.range.vecs().chunks(1_000_000).map(|c| c.to_vec()) + vecmath::sum_48(&partial_distances).sqrt() } } -impl Comparator for Quantized16Comparator +impl + Comparator for QuantizedDomainComparator where - Quantized16Comparator: PartialDistance, + QuantizedEmbeddingSizeCombination: + ImplementedQuantizedEmbeddingSizeCombination, { - type T = Quantized16Embedding; + type T = [u16; QUANTIZED_SIZE]; - type Borrowable<'a> = &'a Self::T; + type Borrowable<'a> = &'a [u16; QUANTIZED_SIZE]; fn lookup(&self, v: VectorId) -> Self::Borrowable<'_> { - self.data.vec(v.0) + &self.data[v.0] } fn compare_raw(&self, v1: &Self::T, v2: &Self::T) -> f32 { - let mut partial_distances = [0.0_f32; QUANTIZED_16_EMBEDDING_LENGTH]; - for ix in 0..QUANTIZED_16_EMBEDDING_LENGTH { - let partial_1 = v1[ix]; - let partial_2 = v2[ix]; - let partial_distance = self.cc.partial_distance(partial_1, partial_2); - partial_distances[ix] = partial_distance; + QuantizedEmbeddingSizeCombination::::compare_quantized( + &self.cc, v1, v2, + ) + } +} + +pub type HnswQuantizer16 = HnswQuantizer< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + Centroid16Comparator, +>; +pub type HnswQuantizer32 = HnswQuantizer< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + Centroid32Comparator, +>; + +pub struct DomainQuantizer< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + domain: String, + derived_domain: String, + quantizer: Arc< + HnswQuantizer< + SIZE, + CENTROID_SIZE, + QUANTIZED_SIZE, + ArrayCentroidComparator, + >, + >, +} + +impl Clone + for DomainQuantizer +{ + fn clone(&self) -> Self { + Self { + domain: self.domain.clone(), + derived_domain: self.derived_domain.clone(), + quantizer: self.quantizer.clone(), } + } +} - vecmath::sum_96(&partial_distances).sqrt() +#[derive(Serialize, Deserialize)] +pub struct DomainQuantizerMeta { + domain: String, + derived_domain: String, +} + +impl + Quantizer for DomainQuantizer +where + ArrayCentroidComparator: Comparator, +{ + fn quantize(&self, vec: &[f32; SIZE]) -> [u16; QUANTIZED_SIZE] { + self.quantizer.quantize(vec) + } + + fn reconstruct(&self, qvec: &[u16; QUANTIZED_SIZE]) -> [f32; SIZE] { + self.quantizer.reconstruct(qvec) } } -impl Serializable for Quantized16Comparator { - type Params = (); +impl + Serializable for DomainQuantizer +{ + type Params = Arc; fn serialize>(&self, path: P) -> Result<(), SerializationError> { - let path_buf: PathBuf = path.as_ref().into(); - std::fs::create_dir_all(&path_buf)?; - - let index_path = path_buf.join("index"); - self.cc.serialize(index_path)?; + let meta = DomainQuantizerMeta { + domain: self.domain.clone(), + derived_domain: self.derived_domain.clone(), + }; + let data = serde_json::to_string(&meta)?; + std::fs::write(path, data)?; - let vector_path = path_buf.join("vectors"); - let mut vector_file = VectorFile::create(vector_path, true)?; - vector_file.append_vector_range(self.data.vecs())?; Ok(()) } fn deserialize>( path: P, - _params: Self::Params, + params: &Self::Params, ) -> Result { - let path_buf: PathBuf = path.as_ref().into(); - let index_path = path_buf.join("index"); - let cc = Centroid16Comparator::deserialize(index_path, ())?; - - let vector_path = path_buf.join("vectors"); - let vector_file = VectorFile::open(vector_path, true)?; - let range = vector_file.all_vectors()?; - - let data = Arc::new(range); - Ok(Self { cc, data }) - } -} - -impl pq::VectorStore for Quantized16Comparator { - type T = ::T; - - fn store(&mut self, i: Box>) -> Vec { - // this is p retty stupid, but then, these comparators should not be storing in the first place - let mut new_contents: Vec = Vec::with_capacity(self.data.len() + i.size_hint().0); - new_contents.extend(self.data.vecs().iter()); - let vid = self.data.len(); - let mut vectors: Vec = Vec::new(); - new_contents.extend(i.enumerate().map(|(i, v)| { - vectors.push(VectorId(vid + i)); - v - })); - - let end = new_contents.len(); + let DomainQuantizerMeta { + domain, + derived_domain, + } = serde_json::from_reader(BufReader::new(std::fs::File::open(path)?))?; - let data = LoadedVectorRange::new(new_contents, 0..end); - self.data = Arc::new(data); + let d = params.get_domain(&domain).expect("domain not found"); + let dd: PqDerivedDomainInfo = d + .get_derived_domain_info(&derived_domain) + .expect("derived domain not found"); - vectors + Ok(dd.quantizer.clone()) } } diff --git a/vectorlink/src/configuration.rs b/vectorlink/src/configuration.rs index dd2509d..d3ca518 100644 --- a/vectorlink/src/configuration.rs +++ b/vectorlink/src/configuration.rs @@ -1,19 +1,23 @@ use std::{fs::OpenOptions, path::PathBuf, sync::Arc}; use itertools::Either; -use parallel_hnsw::{pq::QuantizedHnsw, AbstractVector, Hnsw, Serializable, VectorId}; +use parallel_hnsw::{ + pq::{HnswQuantizer, QuantizedHnsw}, + AbstractVector, Hnsw, Serializable, VectorId, +}; use rayon::iter::IndexedParallelIterator; use serde::{Deserialize, Serialize}; use crate::{ comparator::{ - Centroid16Comparator, Centroid32Comparator, DiskOpenAIComparator, OpenAIComparator, - Quantized16Comparator, Quantized32Comparator, + Centroid16Comparator, Centroid32Comparator, DiskOpenAIComparator, DomainQuantizer, + OpenAIComparator, Quantized16Comparator, Quantized32Comparator, }, openai::Model, vecmath::{ - Embedding, CENTROID_16_LENGTH, CENTROID_32_LENGTH, EMBEDDING_LENGTH, - QUANTIZED_16_EMBEDDING_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, + Embedding, EuclideanDistance16, EuclideanDistance32, CENTROID_16_LENGTH, + CENTROID_32_LENGTH, EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, }, vectors::VectorStore, }; @@ -42,9 +46,14 @@ pub enum HnswConfiguration { EMBEDDING_LENGTH, CENTROID_32_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH, - Centroid32Comparator, Quantized32Comparator, DiskOpenAIComparator, + DomainQuantizer< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, + >, >, ), SmallQuantizedOpenAi( @@ -53,9 +62,14 @@ pub enum HnswConfiguration { EMBEDDING_LENGTH, CENTROID_16_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, - Centroid16Comparator, Quantized16Comparator, DiskOpenAIComparator, + DomainQuantizer< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, + >, >, ), UnquantizedOpenAi(Model, OpenAIHnsw), @@ -173,12 +187,12 @@ impl Serializable for HnswConfiguration { path: P, ) -> Result<(), parallel_hnsw::SerializationError> { match self { - HnswConfiguration::QuantizedOpenAi(_, hnsw) => { - hnsw.serialize(&path)?; - } - HnswConfiguration::UnquantizedOpenAi(_, qhnsw) => { + HnswConfiguration::QuantizedOpenAi(_, qhnsw) => { qhnsw.serialize(&path)?; } + HnswConfiguration::UnquantizedOpenAi(_, hnsw) => { + hnsw.serialize(&path)?; + } HnswConfiguration::SmallQuantizedOpenAi(_, qhnsw) => { qhnsw.serialize(&path)?; } @@ -196,7 +210,7 @@ impl Serializable for HnswConfiguration { fn deserialize>( path: P, - params: Self::Params, + params: &Self::Params, ) -> Result { let state_path: PathBuf = path.as_ref().join("state.json"); let mut state_file = OpenOptions::new() @@ -209,14 +223,14 @@ impl Serializable for HnswConfiguration { Ok(match state.typ { HnswConfigurationType::QuantizedOpenAi => HnswConfiguration::QuantizedOpenAi( state.model, - QuantizedHnsw::deserialize(path, params)?, + QuantizedHnsw::deserialize(path, &(params.clone(), params.clone()))?, ), HnswConfigurationType::UnquantizedOpenAi => { HnswConfiguration::UnquantizedOpenAi(state.model, Hnsw::deserialize(path, params)?) } HnswConfigurationType::SmallQuantizedOpenAi => HnswConfiguration::SmallQuantizedOpenAi( state.model, - QuantizedHnsw::deserialize(path, params)?, + QuantizedHnsw::deserialize(path, &(params.clone(), params.clone()))?, ), }) } diff --git a/vectorlink/src/domain.rs b/vectorlink/src/domain.rs index e625c60..0616c51 100644 --- a/vectorlink/src/domain.rs +++ b/vectorlink/src/domain.rs @@ -1,14 +1,38 @@ use std::{ - any::Any, + any::{Any, TypeId}, + collections::{HashMap, HashSet}, + error::Error, io, + marker::PhantomData, ops::{Deref, DerefMut, Range}, - path::Path, + path::{Path, PathBuf}, sync::{Arc, RwLock}, }; +use clap::ValueEnum; +use linfa::{traits::Fit, DatasetBase}; +use linfa_clustering::KMeans; +use ndarray::{Array, Array2}; +use parallel_hnsw::{ + pq::{HnswQuantizer, Quantizer}, + Comparator, Hnsw, Serializable, VectorId, +}; +use rand::{distributions::Uniform, rngs::StdRng, thread_rng, Rng, SeedableRng}; +use serde::{Deserialize, Serialize}; use urlencoding::encode; -use crate::store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}; +use crate::{ + comparator::{ + ArrayCentroidComparator, DistanceCalculator, DomainQuantizer, HnswQuantizer16, + HnswQuantizer32, + }, + store::{ImmutableVectorFile, LoadedVectorRange, SequentialVectorLoader, VectorFile}, + vecmath::{ + Embedding, EuclideanDistance16, EuclideanDistance32, CENTROID_16_LENGTH, + CENTROID_32_LENGTH, EMBEDDING_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + }, +}; pub trait GenericDomain: 'static + Any + Send + Sync { fn name(&self) -> &str; @@ -22,9 +46,389 @@ pub fn downcast_generic_domain( .expect("Could not downcast domain to expected embedding size") } +pub trait Deriver: Any { + type From: Copy; + + fn concatenate_derived(&self, loader: SequentialVectorLoader) -> io::Result<()>; + fn configuration(&self) -> DerivedDomainConfiguration; + fn get_derived_domain_info(&self) -> Box; + fn concatenate_file(&self, file: &VectorFile) -> io::Result<()> { + self.concatenate_derived(file.vector_chunks(self.chunk_size())?)?; + + Ok(()) + } + fn chunk_size(&self) -> usize { + 1_000 + } +} + +pub trait DerivedDomainInfo: Any {} + +pub trait DerivedDomainInitializer { + fn initialize( + &self, + path: PathBuf, + vectors: &VectorFile, + ) -> Result + Send + Sync>, Box>; +} + +// interestingly, we're required to provide our own trait object implementation. Rust is not able to derive it for us. +impl DerivedDomainInitializer + for Box + Send + Sync> +{ + fn initialize( + &self, + path: PathBuf, + vectors: &VectorFile, + ) -> Result + Send + Sync>, Box> { + (**self).initialize(path, vectors) + } +} + +pub struct PqDerivedDomain< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + file: RwLock>, + quantizer: DomainQuantizer, +} + +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator, + > PqDerivedDomain +where + ArrayCentroidComparator: 'static + Comparator, +{ + fn as_arc( + self, + ) -> Option + Send + Sync + 'static>> { + let expected_type_id = TypeId::of::<[f32; SIZE]>(); + let actual_type_id = TypeId::of::(); + if expected_type_id == actual_type_id { + let result = Arc::new(self) as Arc>; + // this should be safe as we asserted at runtime that these types are the same + let transmuted: Arc + Send + Sync + 'static> = + unsafe { std::mem::transmute(result) }; + + Some(transmuted) + } else { + None + } + } +} + +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator, + > Deriver for PqDerivedDomain +where + ArrayCentroidComparator: 'static + Comparator, +{ + type From = [f32; SIZE]; + + fn concatenate_derived(&self, loader: SequentialVectorLoader) -> io::Result<()> { + for chunk in loader { + let chunk = chunk?; + let mut result = Vec::with_capacity(chunk.len()); + for vec in chunk.iter() { + let quantized = self.quantizer.quantize(vec); + result.push(quantized); + } + let mut file = self.file.write().unwrap(); + file.append_vector_range(&result)?; + } + + Ok(()) + } + + fn configuration(&self) -> DerivedDomainConfiguration { + match (SIZE, CENTROID_SIZE, QUANTIZED_SIZE) { + (EMBEDDING_LENGTH, CENTROID_16_LENGTH, QUANTIZED_16_EMBEDDING_LENGTH) => { + DerivedDomainConfiguration::SmallPq + } + (EMBEDDING_LENGTH, CENTROID_32_LENGTH, QUANTIZED_32_EMBEDDING_LENGTH) => { + DerivedDomainConfiguration::LargePq + } + _ => panic!("unserializable pq derived domain"), + } + } + + fn get_derived_domain_info(&self) -> Box { + let info = PqDerivedDomainInfo { + file: self.file.read().unwrap().as_immutable(), + quantizer: self.quantizer.clone(), + }; + Box::new(info) + } +} + +pub struct PqDerivedDomainInfo< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + pub file: ImmutableVectorFile<[u16; QUANTIZED_SIZE]>, + pub quantizer: DomainQuantizer, +} + +impl + DerivedDomainInfo for PqDerivedDomainInfo +{ +} + +pub type PqDerivedDomainInfo16 = PqDerivedDomainInfo< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, +>; +pub type PqDerivedDomainInfo32 = PqDerivedDomainInfo< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, +>; +pub type PqDerivedDomain16 = PqDerivedDomain< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, +>; +pub type PqDerivedDomain32 = PqDerivedDomain< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, +>; +pub type PqDerivedDomainInitializer16 = PqDerivedDomainInitializer< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, +>; +pub type PqDerivedDomainInitializer32 = PqDerivedDomainInitializer< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, +>; + +#[derive(Serialize, Deserialize, ValueEnum, Debug, Clone, Copy)] +pub enum DerivedDomainConfiguration { + SmallPq, + LargePq, +} + +impl DerivedDomainConfiguration { + pub fn new>( + &self, + path: P, + ) -> Result + Send + Sync + 'static>, io::Error> { + let vecs_path = path.as_ref().join("quantized.vecs"); + let quantizer_path = path.as_ref().join("quantizer"); + match self { + Self::SmallPq => { + let file = RwLock::new(VectorFile::open(&vecs_path, true)?); + + let quantizer: HnswQuantizer16 = HnswQuantizer::deserialize(&quantizer_path, &()) + .expect("hnsw deserialization failed (small)"); + let domain: PqDerivedDomain16 = PqDerivedDomain { + file, + quantizer: quantizer, + }; + + Ok(domain.as_arc::().unwrap()) + } + Self::LargePq => { + let file = RwLock::new(VectorFile::open(&vecs_path, true)?); + let quantizer: HnswQuantizer32 = HnswQuantizer::deserialize(&quantizer_path, &()) + .expect("hnsw deserialization failed (large)"); + + let domain: PqDerivedDomain32 = PqDerivedDomain { + file, + quantizer: quantizer, + }; + + Ok(domain.as_arc::().unwrap()) + } + } + } + + pub fn initializer( + &self, + ) -> Box + 'static + Send + Sync> { + assert_eq!(TypeId::of::(), TypeId::of::()); + match self { + DerivedDomainConfiguration::SmallPq => { + let initializer = PqDerivedDomainInitializer::< + EMBEDDING_LENGTH, + CENTROID_16_LENGTH, + QUANTIZED_16_EMBEDDING_LENGTH, + EuclideanDistance16, + >::default(); + + let boxed: Box + 'static + Send + Sync> = + Box::new(initializer); + + unsafe { std::mem::transmute(boxed) } + } + DerivedDomainConfiguration::LargePq => { + let initializer = PqDerivedDomainInitializer::< + EMBEDDING_LENGTH, + CENTROID_32_LENGTH, + QUANTIZED_32_EMBEDDING_LENGTH, + EuclideanDistance32, + >::default(); + + let boxed: Box + 'static + Send + Sync> = + Box::new(initializer); + + unsafe { std::mem::transmute(boxed) } + } + } + } +} + +pub struct PqDerivedDomainInitializer< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C, +> { + _x: PhantomData, +} +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator, + > PqDerivedDomainInitializer +{ +} + +impl Default + for PqDerivedDomainInitializer +{ + fn default() -> Self { + Self { _x: PhantomData } + } +} + +impl< + const SIZE: usize, + const CENTROID_SIZE: usize, + const QUANTIZED_SIZE: usize, + C: 'static + DistanceCalculator + Default + Send, + > DerivedDomainInitializer<[f32; SIZE]> + for PqDerivedDomainInitializer +where + ArrayCentroidComparator: 'static + Comparator, +{ + fn initialize( + &self, + path: PathBuf, + vectors: &VectorFile<[f32; SIZE]>, + ) -> Result + Send + Sync>, Box> { + // TODO do something else for sizes close to number of vecs + const NUMBER_OF_CENTROIDS: usize = 10_000; + const SAMPLE_SIZE: usize = NUMBER_OF_CENTROIDS / 10; + let selection = if SAMPLE_SIZE >= vectors.num_vecs() { + vectors.all_vectors().unwrap().clone().into_vec() + } else { + let mut rng = thread_rng(); + let mut set = HashSet::new(); + let range = Uniform::from(0_usize..vectors.num_vecs()); + while set.len() != SAMPLE_SIZE { + let candidate = rng.sample(&range); + set.insert(candidate); + } + + set.into_iter() + .map(|index| vectors.vec(index).unwrap()) + .collect() + }; + + // Linfa + let data: Vec = selection.into_iter().flat_map(|v| v.into_iter()).collect(); + let sub_length = data.len() / CENTROID_SIZE; + let sub_arrays = Array::from_shape_vec((sub_length, CENTROID_SIZE), data).unwrap(); + eprintln!("sub_arrays: {sub_arrays:?}"); + let observations = DatasetBase::from(sub_arrays); + // TODO review this number + let number_of_clusters = usize::min(sub_length, 1_000); + let prng = StdRng::seed_from_u64(42); + eprintln!("Running kmeans"); + let model = KMeans::params_with_rng(number_of_clusters, prng.clone()) + .tolerance(1e-2) + .fit(&observations) + .expect("KMeans fitted"); + let centroid_array: Array2 = model.centroids().clone(); + centroid_array.len(); + let centroid_flat: Vec = centroid_array + .into_shape(number_of_clusters * CENTROID_SIZE) + .unwrap() + .to_vec(); + eprintln!("centroid flat len: {}", centroid_flat.len()); + let centroids: Vec<[f32; CENTROID_SIZE]> = centroid_flat + .chunks(CENTROID_SIZE) + .map(|v| { + let mut array = [0.0; CENTROID_SIZE]; + array.copy_from_slice(v); + array + }) + .collect(); + // + eprintln!("Number of centroids: {}", centroids.len()); + + let vector_ids = (0..centroids.len()).map(VectorId).collect(); + let centroid_comparator = ArrayCentroidComparator::new(centroids); + let centroid_m = 24; + let centroid_m0 = 48; + let centroid_order = 12; + let mut centroid_hnsw: Hnsw> = Hnsw::generate( + centroid_comparator, + vector_ids, + centroid_m, + centroid_m0, + centroid_order, + ); + //centroid_hnsw.improve_index(); + centroid_hnsw.improve_neighbors(0.01, 1.0); + + let centroid_quantizer: HnswQuantizer< + SIZE, + CENTROID_SIZE, + QUANTIZED_SIZE, + ArrayCentroidComparator, + > = HnswQuantizer::new(centroid_hnsw); + + let quantizer_path = path.join("quantizer"); + centroid_quantizer.serialize(quantizer_path)?; + + let quantized_path = path.join("quantized.vecs"); + let quantized_file: VectorFile<[u16; QUANTIZED_SIZE]> = + VectorFile::create(quantized_path, true)?; + + let deriver = PqDerivedDomain { + file: RwLock::new(quantized_file), + quantizer: Arc::new(centroid_quantizer), + }; + Ok(Arc::new(deriver)) + } +} + pub struct Domain { name: String, file: RwLock>, + derived_domains: RwLock + Send + Sync>>>, } impl GenericDomain for Domain { @@ -38,7 +442,7 @@ impl GenericDomain for Domain { } #[allow(unused)] -impl Domain { +impl Domain { pub fn name(&self) -> &str { &self.name } @@ -47,14 +451,36 @@ impl Domain { self.file().num_vecs() } - pub fn open>(dir: P, name: &str) -> io::Result { + pub fn open>(dir: P, name: &str) -> Result { let mut path = dir.as_ref().to_path_buf(); let encoded_name = encode(name); path.push(format!("{encoded_name}.vecs")); let file = RwLock::new(VectorFile::open_create(&path, true)?); + // load derived domains + let mut derived_path = path.clone(); + derived_path.set_extension("derived"); + let mut derived_domains = HashMap::new(); + if derived_path.exists() { + for file in std::fs::read_dir(derived_path)? { + let derived = file?; + // now we have to discover what kind of derived domain this is + // the options are hardcoded. + let name = derived.file_name().into_string().unwrap(); + let config_file = derived.path().join("config.json"); + if config_file.exists() { + let mut file = std::fs::File::open(config_file)?; + let config: DerivedDomainConfiguration = serde_json::from_reader(file)?; + let derived_domain = config.new::(derived.path()).expect("TODO"); + + derived_domains.insert(name, derived_domain); + } + } + } + Ok(Domain { name: name.to_string(), + derived_domains: RwLock::new(derived_domains), file, }) } @@ -71,20 +497,13 @@ impl Domain { self.file().as_immutable() } - fn add_vecs<'a, I: Iterator>(&self, vecs: I) -> io::Result<(usize, usize)> - where - T: 'a, - { - let mut vector_file = self.file_mut(); - let old_len = vector_file.num_vecs(); - let count = vector_file.append_vectors(vecs)?; - - Ok((old_len, count)) - } - pub fn concatenate_file>(&self, path: P) -> io::Result<(usize, usize)> { let read_vector_file = VectorFile::open(path, true)?; let old_size = self.num_vecs(); + let derived_domains = self.derived_domains.read().unwrap(); + for derived in derived_domains.values() { + derived.concatenate_file(&read_vector_file)?; + } Ok(( old_size, self.file_mut().append_vector_file(&read_vector_file)?, @@ -106,4 +525,53 @@ impl Domain { pub fn vector_chunks(&self, chunk_size: usize) -> io::Result> { self.file().vector_chunks(chunk_size) } + + pub fn create_derived>( + &self, + name: String, + derived_domain_initializer: N, + ) -> Result<(), Box> { + // first, let's take a read lock on the internal file to stop + // others from doing things to this domain. + // Makes deadlocks less likely as the only hold-and-wait + // pattern then remaining has to involve both file and derived + // domains. + let file = self.file(); + let mut derived_domains = self.derived_domains.write().unwrap(); + assert!( + !derived_domains.contains_key(&name), + "tried to create derived domain that already exists" + ); + + // create a directory for this derived domain + let mut path = file.path().clone(); + path.set_extension("derived"); + path.push(&name); + std::fs::create_dir_all(&path)?; + + // write a config so we can recognize later on what this domain is + let config_path = path.join("config.json"); + let deriver = derived_domain_initializer.initialize(path, &*file)?; + let config = deriver.configuration(); + let config_string = serde_json::to_string(&config).unwrap(); + std::fs::write(config_path, config_string)?; + + // convert all already-existing vectors to this domain + deriver.concatenate_file(&*file)?; + + derived_domains.insert(name, deriver); + + Ok(()) + } + + pub fn get_derived_domain_info(&self, name: &str) -> Option { + let domains = self.derived_domains.read().unwrap(); + let deriver = domains.get(name)?; + let info = deriver.get_derived_domain_info() as Box; + let downcast_info: Box = info + .downcast() + .expect("derived domain info not of expected type"); + + Some(*downcast_info) + } } diff --git a/vectorlink/src/main.rs b/vectorlink/src/main.rs index f2323e3..a759740 100644 --- a/vectorlink/src/main.rs +++ b/vectorlink/src/main.rs @@ -9,13 +9,13 @@ use std::sync::Arc; mod batch; mod comparator; mod configuration; +mod domain; mod indexer; mod openai; mod server; mod store; mod vecmath; mod vectors; -mod domain; mod search_server; @@ -23,15 +23,16 @@ use batch::index_from_operations_file; use clap::CommandFactory; use clap::{Parser, Subcommand, ValueEnum}; use configuration::HnswConfiguration; +use domain::DerivedDomainConfiguration; //use hnsw::Hnsw; use openai::Model; use parallel_hnsw::pq::Quantizer; -use parallel_hnsw::pq::VectorSelector; use parallel_hnsw::AbstractVector; use parallel_hnsw::Comparator; use parallel_hnsw::Serializable; use std::fs::File; use std::io; +use vecmath::Embedding; use vecmath::Quantized32Embedding; use vecmath::EMBEDDING_BYTE_LENGTH; use vecmath::EMBEDDING_LENGTH; @@ -212,6 +213,16 @@ enum Commands { #[arg(short, long)] key: Option, }, + Quantize { + #[arg(short, long)] + directory: String, + #[arg(short, long)] + domain: String, + #[arg(short, long)] + derived: String, + #[arg(short, long, value_enum, default_value_t = DerivedDomainConfiguration::SmallPq)] + method: DerivedDomainConfiguration, + }, } #[derive(Clone, Copy, Debug, ValueEnum)] @@ -636,6 +647,17 @@ async fn main() -> Result<(), Box> { .await .unwrap() } + Commands::Quantize { + directory, + domain, + derived, + method, + } => { + let store = VectorStore::new(directory, 10_000); // num bufs is actually obsolete now. + let domain = store.get_domain(&domain).unwrap(); + let initializer = method.initializer::(); + domain.create_derived(derived.clone(), initializer).unwrap(); + } } Ok(()) diff --git a/vectorlink/src/server.rs b/vectorlink/src/server.rs index 181cb71..7dac9d0 100644 --- a/vectorlink/src/server.rs +++ b/vectorlink/src/server.rs @@ -460,7 +460,7 @@ impl Service { let index_path = index_serialization_path(path, index_id); Ok(Arc::new(OpenAIHnsw::deserialize( index_path, - self.vector_store.clone(), + &self.vector_store, )?)) } } diff --git a/vectorlink/src/store.rs b/vectorlink/src/store.rs index a994b66..7c0997e 100644 --- a/vectorlink/src/store.rs +++ b/vectorlink/src/store.rs @@ -277,34 +277,10 @@ impl VectorFile { (self.num_vecs * std::mem::size_of::()) as u64, )?; self.num_vecs = self.num_vecs + vectors.len(); - self.file.sync_data()?; + self.file.sync_data()?; // TODO probably don't do it here cause we might want to append multiple ranges Ok(vectors.len()) } - pub fn append_vectors<'a, I: Iterator>(&mut self, vectors: I) -> io::Result - where - T: 'a, - { - // wouldn't it be more straightforward to just use the file as a cursor? - let mut offset = (self.num_vecs * std::mem::size_of::()) as u64; - let mut count = 0; - for vector in vectors { - let bytes = unsafe { - std::slice::from_raw_parts( - vector as *const T as *const u8, - std::mem::size_of::(), - ) - }; - self.file.write_all_at(bytes, offset)?; - self.num_vecs += 1; - offset += std::mem::size_of::() as u64; - count += 1; - } - - self.file.sync_data()?; - - Ok(count) - } pub fn append_vector_file(&mut self, file: &VectorFile) -> io::Result { let mut read_offset = 0; @@ -361,6 +337,10 @@ impl VectorFile { _x: PhantomData, }) } + + pub fn path(&self) -> &PathBuf { + &self.path + } } pub struct ImmutableVectorFile(VectorFile); diff --git a/vectorlink/src/vecmath.rs b/vectorlink/src/vecmath.rs index 4e2db8f..518458f 100644 --- a/vectorlink/src/vecmath.rs +++ b/vectorlink/src/vecmath.rs @@ -128,7 +128,7 @@ pub fn cosine_partial_distance_32(v1: &Centroid32, v2: &Centroid32) -> f32 { simd::cosine_partial_distance_32_simd(v1, v2) } -#[derive(Default)] +#[derive(Default, Clone)] pub struct EuclideanDistance32; impl DistanceCalculator for EuclideanDistance32 { type T = Centroid32; @@ -148,7 +148,7 @@ impl DistanceCalculator for EuclideanDistance32 { } } -#[derive(Default)] +#[derive(Default, Clone)] pub struct EuclideanDistance16; impl DistanceCalculator for EuclideanDistance16 { type T = Centroid16; diff --git a/vectorlink/src/vectors.rs b/vectorlink/src/vectors.rs index d798f2a..de93875 100644 --- a/vectorlink/src/vectors.rs +++ b/vectorlink/src/vectors.rs @@ -3,6 +3,7 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; +use std::error::Error; use std::fmt; use std::fs::{File, OpenOptions}; use std::io::{self, Seek, SeekFrom, Write}; @@ -45,7 +46,8 @@ impl VectorStore { } } - pub fn get_domain(&self, name: &str) -> io::Result>> { + // TODO better error + pub fn get_domain(&self, name: &str) -> Result>, io::Error> { let domains = self.domains.read().unwrap(); if let Some(domain) = domains.get(name) { Ok(downcast_generic_domain(domain.clone())) @@ -55,7 +57,10 @@ impl VectorStore { if let Some(domain) = domains.get(name) { Ok(downcast_generic_domain(domain.clone())) } else { - let domain = Arc::new(Domain::open(&self.dir, name)?); + let domain = Arc::new( + Domain::open(&self.dir, name) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?, + ); domains.insert(name.to_string(), domain.clone()); Ok(domain)