diff --git a/build.sbt b/build.sbt index b27a2e7ff0..b4022d2966 100644 --- a/build.sbt +++ b/build.sbt @@ -138,6 +138,7 @@ val shapelessVersion = "2.3.10" val sparkeyVersion = "3.2.5" val tensorFlowVersion = "0.4.2" val testContainersVersion = "0.41.0" +val voyagerVersion = "1.2.6" val zoltarVersion = "0.6.0" // dependent versions val scalatestplusVersion = s"$scalatestVersion.0" @@ -895,7 +896,9 @@ lazy val `scio-extra`: Project = project "com.google.zetasketch" % "zetasketch" % zetasketchVersion, "com.nrinaudo" %% "kantan.codecs" % kantanCodecsVersion, "com.nrinaudo" %% "kantan.csv" % kantanCsvVersion, + "com.softwaremill.magnolia1_2" %% "magnolia" % magnoliaVersion, "com.spotify" % "annoy" % annoyVersion, + "com.spotify" % "voyager" % voyagerVersion, "com.spotify.sparkey" % "sparkey" % sparkeyVersion, "com.twitter" %% "algebird-core" % algebirdVersion, "io.circe" %% "circe-core" % circeVersion, diff --git a/scio-extra/src/it/scala/com/spotify/scio/extra/voyager/VoyagerIT.scala b/scio-extra/src/it/scala/com/spotify/scio/extra/voyager/VoyagerIT.scala new file mode 100644 index 0000000000..7095b9c631 --- /dev/null +++ b/scio-extra/src/it/scala/com/spotify/scio/extra/voyager/VoyagerIT.scala @@ -0,0 +1,104 @@ +/* + * Copyright 2023 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package com.spotify.scio.extra.voyager + +import com.spotify.scio.testing.PipelineSpec +import com.spotify.scio.testing.util.ItUtils +import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} +import org.apache.beam.sdk.io.FileSystems +import org.apache.beam.sdk.options.PipelineOptionsFactory +import org.apache.beam.sdk.util.MimeTypes + +import java.nio.ByteBuffer +import scala.jdk.CollectionConverters._ + +class VoyagerIT extends PipelineSpec { + val dim: Int = 2 + val storageType: StorageDataType = StorageDataType.E4M3 + val distanceMeasure: SpaceType = SpaceType.Cosine + val ef = 200 + val m = 16L + + val sideData: Seq[(String, Array[Float])] = Seq( + "1" -> Array(2.5f, 7.2f), + "2" -> Array(1.2f, 2.2f), + "3" -> Array(5.6f, 3.4f) + ) + + FileSystems.setDefaultPipelineOptions(PipelineOptionsFactory.create()) + + it should "support .asVoyagerSideInput using GCS tempLocation" in { + val tempLocation = ItUtils.gcpTempLocation("voyager-it") + runWithContext { sc => + sc.options.setTempLocation(tempLocation) + val (names, vectors) = sideData.unzip + + val voyagerReader = sc + .parallelize(sideData) + .asVoyagerSideInput(distanceMeasure, storageType, dim) + + val result = sc + .parallelize(vectors) + .withSideInputs(voyagerReader) + .flatMap { case (v, ctx) => + ctx(voyagerReader).getNearest(v, 1, 100) + } + .toSCollection + .map(_.name) + + result should containInAnyOrder(names) + } + + // check files uploaded by voyager + val files = FileSystems + .`match`(s"$tempLocation/voyager-*") + .metadata() + .asScala + .map(_.resourceId()) + + FileSystems.delete(files.asJava) + } + + it should "throw exception when Voyager file exists" in { + val uri = VoyagerUri(ItUtils.gcpTempLocation("voyager-it")) + val indexUri = uri.value.resolve(VoyagerUri.IndexFile) + val nameUri = uri.value.resolve(VoyagerUri.NamesFile) + val indexResourceId = FileSystems.matchNewResource(indexUri.toString, false) + val nameResourceId = FileSystems.matchNewResource(nameUri.toString, false) + + // write some data in the + val f1 = FileSystems.create(nameResourceId, MimeTypes.BINARY) + val f2 = FileSystems.create(indexResourceId, MimeTypes.BINARY) + try { + f1.write(ByteBuffer.wrap("test-data".getBytes())) + f2.write(ByteBuffer.wrap("test-data".getBytes())) + } finally { + f1.close() + f2.close() + } + + val e = the[IllegalArgumentException] thrownBy { + runWithContext { sc => + sc.parallelize(sideData).asVoyager(uri, distanceMeasure, storageType, dim, 200L, 16) + } + } + + e.getMessage shouldBe s"requirement failed: Voyager URI ${uri.value} already exists" + + FileSystems.delete(Seq(nameResourceId, indexResourceId).asJava) + } +} diff --git a/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/Voyager.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/Voyager.scala new file mode 100644 index 0000000000..cf8b101599 --- /dev/null +++ b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/Voyager.scala @@ -0,0 +1,188 @@ +/* + * Copyright 2023 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.spotify.scio.extra.voyager + +import com.spotify.scio.util.{RemoteFileUtil, ScioUtil} +import com.spotify.scio.values.SideInput +import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} +import com.spotify.voyager.jni.{Index, StringIndex} +import org.apache.beam.sdk.transforms.DoFn +import org.apache.beam.sdk.values.PCollectionView + +import java.net.URI +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Path, Paths} +import java.util.concurrent.ConcurrentHashMap + +/** + * Represents the base URI for a voyager index, either on a local or a remote file system. For + * remote file systems, the `path` should be in the form 'scheme:////'. For local + * files, it should be in the form '//'. The `uri` specified represents the directory where + * the `index.hnsw` and `names.json` are. + */ +final case class VoyagerUri(value: URI) extends AnyVal { + + import VoyagerUri._ + + def exists(implicit remoteFileUtil: RemoteFileUtil): Boolean = { + if (ScioUtil.isLocalUri(value)) { + VoyagerFiles.exists(f => Paths.get(value.resolve(f)).toFile.exists()) + } else { + VoyagerFiles.exists(f => remoteFileUtil.remoteExists(value.resolve(f))) + } + } +} + +object VoyagerUri { + def apply(value: String): VoyagerUri = new VoyagerUri(new URI(value)) + + private[voyager] val IndexFile = "index.hnsw" + private[voyager] val NamesFile = "names.json" + + private[voyager] val VoyagerFiles: Seq[String] = Seq(IndexFile, NamesFile) + +} + +/** Result of a voyager query */ +final case class VoyagerResult(name: String, distance: Float) + +class VoyagerWriter private[voyager] ( + indexFile: Path, + namesFile: Path, + spaceType: SpaceType, + storageDataType: StorageDataType, + dim: Int, + ef: Long = 200L, + m: Long = 16L +) { + import VoyagerWriter._ + + def write(vectors: Iterable[(String, Array[Float])]): Unit = { + val indexOutputStream = Files.newOutputStream(indexFile) + val namesOutputStream = Files.newOutputStream(namesFile) + + val names = List.newBuilder[String] + val index = new Index(spaceType, dim, m, ef, RandomSeed, ChunkSize.toLong, storageDataType) + + vectors.zipWithIndex + .map { case ((name, vector), idx) => (name, vector, idx.toLong) } + .grouped(ChunkSize) + .map(_.unzip3) + .foreach { case (ns, vs, is) => + names ++= ns + index.addItems(vs.toArray, is.toArray, -1) + } + + // save index + index.saveIndex(indexOutputStream) + index.close() + // save names + val json = names.result().mkString("[\"", "\",\"", "\"]") + namesOutputStream.write(json.getBytes(StandardCharsets.UTF_8)) + // close + indexOutputStream.close() + namesOutputStream.close() + } +} + +private object VoyagerWriter { + private val RandomSeed: Long = 1L + private val ChunkSize: Int = 32786 // 2^15 +} + +/** + * Voyager reader class for nearest neighbor lookups. Supports looking up neighbors for a vector and + * returning the string labels and distances associated. + * + * @param indexFile + * The `index.hnsw` file. + * @param namesFile + * The `names.json` file. + * @param spaceType + * The measurement for computing distance between entities. One of Euclidean, Cosine or Dot (inner + * product). + * @param storageDataType + * The Storage type of the vectors at rest. One of Float8, Float32 or E4M3. + * @param dim + * Number of dimensions in vectors. + */ +class VoyagerReader private[voyager] ( + indexFile: Path, + namesFile: Path, + spaceType: SpaceType, + storageDataType: StorageDataType, + dim: Int +) { + require(dim > 0, "Vector dimension should be > 0") + + @transient private lazy val index: StringIndex = + StringIndex.load(indexFile.toString, namesFile.toString, spaceType, dim, storageDataType) + + /** + * Gets maxNumResults nearest neighbors for vector v using ef (where ef is the size of the dynamic + * list for the nearest neighbors during search). + */ + def getNearest(v: Array[Float], maxNumResults: Int, ef: Int): Array[VoyagerResult] = { + val queryResults = index.query(v, maxNumResults, ef) + queryResults.getNames + .zip(queryResults.getDistances) + .map { case (name, distance) => VoyagerResult(name, distance) } + } +} + +/** + * Construction for a VoyagerSide input that leverages a synchronized map to ensure that the reader + * is only loaded once per [[VoyagerUri]]. + */ +private[voyager] class VoyagerSideInput( + val view: PCollectionView[VoyagerUri], + remoteFileUtil: RemoteFileUtil, + distanceMeasure: SpaceType, + storageType: StorageDataType, + dim: Int +) extends SideInput[VoyagerReader] { + + import VoyagerSideInput._ + + private def createReader(uri: VoyagerUri): VoyagerReader = { + val indexUri = uri.value.resolve(VoyagerUri.IndexFile) + val namesUri = uri.value.resolve(VoyagerUri.NamesFile) + + val (localIndex, localNames) = if (ScioUtil.isLocalUri(uri.value)) { + (Paths.get(indexUri), Paths.get(namesUri)) + } else { + val downloadedIndex = remoteFileUtil.download(indexUri) + val downloadedNames = remoteFileUtil.download(namesUri) + (downloadedIndex, downloadedNames) + } + new VoyagerReader(localIndex, localNames, distanceMeasure, storageType, dim) + } + + override def get[I, O](context: DoFn[I, O]#ProcessContext): VoyagerReader = { + val uri = context.sideInput(view) + VoyagerReaderSharedCache.computeIfAbsent(uri, createReader) + } +} + +private object VoyagerSideInput { + // cache the VoyagerUri to VoyagerReader per JVM so workers with multiple + // voyager side-input steps load the index only once + @transient private lazy val VoyagerReaderSharedCache + : ConcurrentHashMap[VoyagerUri, VoyagerReader] = + new ConcurrentHashMap() +} diff --git a/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/package.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/package.scala new file mode 100644 index 0000000000..6d7c426cb8 --- /dev/null +++ b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/package.scala @@ -0,0 +1,23 @@ +/* + * Copyright 2023 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.spotify.scio.extra + +import com.spotify.scio.extra.voyager.syntax.AllSyntax + +/** Main package for Voyager side input APIs. */ +package object voyager extends AllSyntax diff --git a/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/AllSyntax.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/AllSyntax.scala new file mode 100644 index 0000000000..b9a3ef5852 --- /dev/null +++ b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/AllSyntax.scala @@ -0,0 +1,20 @@ +/* + * Copyright 2023 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.spotify.scio.extra.voyager.syntax + +trait AllSyntax extends ScioContextSyntax with SCollectionSyntax diff --git a/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/SCollectionSyntax.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/SCollectionSyntax.scala new file mode 100644 index 0000000000..804fe99422 --- /dev/null +++ b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/SCollectionSyntax.scala @@ -0,0 +1,209 @@ +/* + * Copyright 2023 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.spotify.scio.extra.voyager.syntax + +import com.spotify.scio.annotations.experimental +import com.spotify.scio.extra.voyager.{VoyagerReader, VoyagerSideInput, VoyagerUri, VoyagerWriter} +import com.spotify.scio.util.{RemoteFileUtil, ScioUtil} +import com.spotify.scio.values.{SCollection, SideInput} +import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} +import org.apache.beam.sdk.transforms.View + +import java.nio.file.{Files, Paths} +import java.util.UUID + +class VoyagerSCollectionOps(@transient private val self: SCollection[VoyagerUri]) extends AnyVal { + + /** + * Load the Voyager index stored at [[VoyagerUri]] in this + * [[com.spotify.scio.values.SCollection SCollection]]. + * + * @param spaceType + * The measurement for computing distance between entities. One of Euclidean, Cosine or Dot + * (inner product). + * @param storageType + * The Storage type of the vectors at rest. One of Float8, Float32 or E4M3. + * @param dim + * Number of dimensions in vectors. + * @return + * SideInput[VoyagerReader] + */ + @experimental + def asVoyagerSideInput( + spaceType: SpaceType, + storageType: StorageDataType, + dim: Int + ): SideInput[VoyagerReader] = { + val view = self.applyInternal(View.asSingleton()) + new VoyagerSideInput( + view, + RemoteFileUtil.create(self.context.options), + spaceType, + storageType, + dim + ) + } + +} + +class VoyagerPairSCollectionOps( + @transient private val self: SCollection[(String, Array[Float])] +) extends AnyVal { + + /** + * Write the key-value pairs of this SCollection as a Voyager index to a specified location using + * the parameters specified. + * + * @param uri + * The [[VoyagerUri]]. + * @param spaceType + * The measurement for computing distance between entities. One of Euclidean, Cosine or Dot + * (inner product). + * @param storageDataType + * The storage data type of the vectors at rest. One of Float8, Float32 or E4M3. + * @param dim + * Number of dimensions in vectors. + * @param ef + * The size of the dynamic list of neighbors used during construction time. This parameter + * controls query time/accuracy tradeoff. More information can be found in the hnswlib + * documentation https://github.com/nmslib/hnswlib. + * @param m + * The number of outgoing connections in the graph. + * @return + * A [[VoyagerUri]] representing where the index was written to. + */ + @experimental + def asVoyager( + uri: VoyagerUri, + spaceType: SpaceType, + storageDataType: StorageDataType, + dim: Int, + ef: Long, + m: Long + ): SCollection[VoyagerUri] = { + implicit val remoteFileUtil: RemoteFileUtil = RemoteFileUtil.create(self.context.options) + require(!uri.exists, s"Voyager URI ${uri.value} already exists") + + self.transform { in => + in.reifyAsIterableInGlobalWindow + .map { xs => + val indexUri = uri.value.resolve(VoyagerUri.IndexFile) + val namesUri = uri.value.resolve(VoyagerUri.NamesFile) + val isLocal = ScioUtil.isLocalUri(uri.value) + + val (localIndex, localNames) = if (isLocal) { + (Paths.get(indexUri), Paths.get(namesUri)) + } else { + val tmpDir = Files.createTempDirectory("voyager-") + val tmpIndex = tmpDir.resolve(VoyagerUri.IndexFile) + val tmpNames = tmpDir.resolve(VoyagerUri.NamesFile) + (tmpIndex, tmpNames) + } + + val writer = + new VoyagerWriter(localIndex, localNames, spaceType, storageDataType, dim, ef, m) + writer.write(xs) + + if (!isLocal) { + remoteFileUtil.upload(localIndex, indexUri) + remoteFileUtil.upload(localNames, namesUri) + } + + uri + } + } + } + + /** + * Write the key-value pairs of this SCollection as a Voyager index to a temporary location and + * building the index using the parameters specified. + * + * @param distanceMeasure + * The measurement for computing distance between entities. One of Euclidean, Cosine or Dot + * (inner product). + * @param storageDataType + * The Storage type of the vectors at rest. One of Float8, Float32 or E4M3. + * @param dim + * Number of dimensions in vectors. + * @param ef + * The size of the dynamic list of neighbors used during construction time. This parameter + * controls query time/accuracy tradeoff. More information can be found in the hnswlib + * documentation https://github.com/nmslib/hnswlib. + * @param m + * The number of outgoing connections in the graph. + * @return + * A [[VoyagerUri]] representing where the index was written to. + */ + @experimental + def asVoyager( + distanceMeasure: SpaceType, + storageDataType: StorageDataType, + dim: Int, + ef: Long = 200L, + m: Long = 16L + ): SCollection[VoyagerUri] = { + val uuid = UUID.randomUUID() + val tempLocation: String = self.context.options.getTempLocation + require(tempLocation != null, s"Voyager writes require --tempLocation to be set.") + val uri = VoyagerUri(s"${tempLocation.stripSuffix("/")}/voyager-build-$uuid") + asVoyager(uri, distanceMeasure, storageDataType, dim, ef, m) + } + + /** + * Write the key-value pairs of this SCollection as a Voyager index to a temporary location, + * building the index using the parameters specified and then loading the reader into a side + * input. + * + * @param spaceType + * The measurement for computing distance between entities. One of Euclidean, Cosine or Dot + * (inner product). + * @param storageType + * The Storage type of the vectors at rest. One of Float8, Float32 or E4M3. + * @param dim + * Number of dimensions in vectors. + * @param ef + * The size of the dynamic list of neighbors used during construction time. This parameter + * controls query time/accuracy tradeoff. More information can be found in the hnswlib + * documentation https://github.com/nmslib/hnswlib. + * @param m + * The number of outgoing connections in the graph. + * @return + * A SideInput with a [[VoyagerReader]] + */ + @experimental + def asVoyagerSideInput( + spaceType: SpaceType, + storageType: StorageDataType, + dim: Int, + ef: Long = 200L, + m: Long = 16L + ): SideInput[VoyagerReader] = + new VoyagerSCollectionOps(asVoyager(spaceType, storageType, dim, ef, m)) + .asVoyagerSideInput(spaceType, storageType, dim) +} + +trait SCollectionSyntax { + implicit def voyagerSCollectionOps(coll: SCollection[VoyagerUri]): VoyagerSCollectionOps = + new VoyagerSCollectionOps(coll) + + implicit def VoyagerPairSCollectionOps( + coll: SCollection[(String, Array[Float])] + ): VoyagerPairSCollectionOps = + new VoyagerPairSCollectionOps(coll) + +} diff --git a/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/ScioContextSyntax.scala b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/ScioContextSyntax.scala new file mode 100644 index 0000000000..cff6f37512 --- /dev/null +++ b/scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/ScioContextSyntax.scala @@ -0,0 +1,62 @@ +/* + * Copyright 2023 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.spotify.scio.extra.voyager.syntax + +import com.spotify.scio.ScioContext +import com.spotify.scio.annotations.experimental +import com.spotify.scio.extra.voyager.{VoyagerReader, VoyagerSideInput, VoyagerUri} +import com.spotify.scio.util.RemoteFileUtil +import com.spotify.scio.values.SideInput +import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} +import org.apache.beam.sdk.transforms.View + +/** Enhanced version of [[ScioContext]] with Voyager methods */ +class VoyagerScioContextOps(private val self: ScioContext) extends AnyVal { + + /** + * Creates a SideInput of [[VoyagerReader]] from an [[VoyagerUri]] base path. To be used with + * [[com.spotify.scio.values.SCollection.withSideInputs SCollection.withSideInputs]] + * + * @param uri + * The [[VoyagerUri]]. + * @param spaceType + * The measurement for computing distance between entities. One of Euclidean, Cosine or Dot + * (inner product). + * @param storageDataType + * The Storage type of the vectors at rest. One of Float8, Float32 or E4M3. + * @param dim + * Number of dimensions in vectors. + * @return + * A [[SideInput]] of the [[VoyagerReader]] to be used for querying. + */ + @experimental + def voyagerSideInput( + uri: VoyagerUri, + spaceType: SpaceType, + storageDataType: StorageDataType, + dim: Int + ): SideInput[VoyagerReader] = { + val view = self.parallelize(Seq(uri)).applyInternal(View.asSingleton()) + new VoyagerSideInput(view, RemoteFileUtil.create(self.options), spaceType, storageDataType, dim) + } +} + +trait ScioContextSyntax { + implicit def voyagerScioContextOps(sc: ScioContext): VoyagerScioContextOps = + new VoyagerScioContextOps(sc) +} diff --git a/scio-extra/src/test/scala/com/spotify/scio/extra/voyager/VoyagerTest.scala b/scio-extra/src/test/scala/com/spotify/scio/extra/voyager/VoyagerTest.scala new file mode 100644 index 0000000000..09f615dd94 --- /dev/null +++ b/scio-extra/src/test/scala/com/spotify/scio/extra/voyager/VoyagerTest.scala @@ -0,0 +1,79 @@ +/* + * Copyright 2023 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.spotify.scio.extra.voyager + +import com.spotify.scio.testing.CoderAssertions.{notFallback, ValueShouldSyntax} +import com.spotify.scio.testing.PipelineSpec +import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} +import com.spotify.voyager.jni.StringIndex + +import java.nio.file.Files + +class VoyagerTest extends PipelineSpec { + val spaceType: SpaceType = SpaceType.Cosine + val storageDataType: StorageDataType = StorageDataType.E4M3 + val dim: Int = 2 + + val sideData: Seq[(String, Array[Float])] = + Seq(("1", Array(2.5f, 7.2f)), ("2", Array(1.2f, 2.2f)), ("3", Array(5.6f, 3.4f))) + + "SCollection" should "support .asVoyager with specified local file" in { + val tmpDir = Files.createTempDirectory("voyager-test") + val uri = VoyagerUri(tmpDir.toUri) + + runWithContext { sc => + sc.parallelize(sideData).asVoyager(uri, spaceType, storageDataType, dim, 200L, 16L) + } + + val index = StringIndex.load( + tmpDir.resolve(VoyagerUri.IndexFile).toString, + tmpDir.resolve(VoyagerUri.NamesFile).toString, + SpaceType.Cosine, + dim, + StorageDataType.E4M3 + ) + + sideData.foreach { data => + val result = index.query(data._2, 2, 100) + result.getNames.length shouldEqual 2 + result.getDistances.length shouldEqual 2 + result.getNames should contain(data._1) + } + } + + it should "throw exception when the Voyager files already exists" in { + val tmpDir = Files.createTempDirectory("voyager-test") + val uri = VoyagerUri(tmpDir.toUri) + + val index = tmpDir.resolve("index.hnsw") + val names = tmpDir.resolve("names.json") + Files.createFile(index) + Files.createFile(names) + + the[IllegalArgumentException] thrownBy { + runWithContext { sc => + sc.parallelize(sideData).asVoyager(uri, spaceType, storageDataType, dim, 200L, 16L) + } + } should have message s"requirement failed: Voyager URI ${uri.value} already exists" + } + + "VoyagerUri" should "not use Kryo" in { + val uri = VoyagerUri("gs://this-that") + uri coderShould notFallback() + } +} diff --git a/site/src/main/paradox/extras/Voyager.md b/site/src/main/paradox/extras/Voyager.md new file mode 100644 index 0000000000..485e89ffa9 --- /dev/null +++ b/site/src/main/paradox/extras/Voyager.md @@ -0,0 +1,76 @@ +# Voyager + +Scio supports Spotify's [Voyager](https://github.com/spotify/voyager), which provides an easy to use API on top of `hnswlib` that that +works in python and java. + +## Write + +A keyed `SCollection` with `String` keys and `Array[Float]` vector values can be saved with @scaladoc[asVoyager](com.spotify.scio.extra.voyager.syntax.VoyagerPairSCollectionOps#asVoyager(uri:com.spotify.scio.extra.voyager.VoyagerUri,spaceType:com.spotify.voyager.jni.Index.SpaceType,storageDataType:com.spotify.voyager.jni.Index.StorageDataType,dim:Int,ef:Long,m:Long):com.spotify.scio.values.SCollection[com.spotify.scio.extra.voyager.VoyagerUri]): + +```scala +import com.spotify.scio.values.SCollection +import com.spotify.scio.extra.voyager._ +import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} + +val voyagerUri = VoyagerUri("gs://output-path") +val dim: Int = ??? +val ef: Long = ??? +val m: Int = ??? +val storageType: StorageDataType = ??? +val spaceType: SpaceType = ??? +val itemVectors: SCollection[(String, Array[Float])] = ??? +itemVectors.asVoyager(voyagerUri, spaceType, storageType, dim, ef, m) +``` + +## Side Input + +A Voyager index can be read directly as a `SideInput` with @scaladoc[asVoyagerSideInput](com.spotify.scio.extra.voyager.syntax.VoyagerScioContextOps#voyagerSideInput(uri:com.spotify.scio.extra.voyager.VoyagerUri,spaceType:com.spotify.voyager.jni.Index.SpaceType,storageDataType:com.spotify.voyager.jni.Index.StorageDataType,dim:Int):com.spotify.scio.values.SideInput[com.spotify.scio.extra.voyager.VoyagerReader]): + +```scala +import com.spotify.scio._ +import com.spotify.scio.values.SCollection +import com.spotify.scio.extra.voyager._ +import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} + +val sc: ScioContext = ??? + +val voyagerUri = VoyagerUri("gs://output-path") +val dim: Int = ??? +val storageType: StorageDataType = ??? +val spaceType: SpaceType = ??? +val itemVectors: SCollection[(String, Array[Float])] = ??? +val voyagerSI: SideInput[VoyagerReader] = sc.voyagerSideInput(voyagerUri, spaceType, storageType, dim) +``` + +Alternatively, an `SCollection` can be converted directly to a `SideInput` with @scaladoc[asVoyagerSideInput](com.spotify.scio.extra.voyager.syntax.VoyagerSCollectionOps#asVoyagerSideInput(spaceType:com.spotify.voyager.jni.Index.SpaceType,storageType:com.spotify.voyager.jni.Index.StorageDataType,dim:Int):com.spotify.scio.values.SideInput[com.spotify.scio.extra.voyager.VoyagerReader]): +```scala +import com.spotify.scio.values.SCollection +import com.spotify.scio.extra.voyager._ +import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} + +val dim: Int = ??? +val ef: Long = ??? +val m: Int = ??? +val storageType: StorageDataType = ??? +val spaceType: SpaceType = ??? +val itemVectors: SCollection[(String, Array[Float])] = ??? +val voyagerSI: SideInput[VoyagerReader] = itemVectors.asVoyagerSideInput(spaceType, storageType, dim, ef, m) +``` + +An @scaladoc[VoyagerReader](com.spotify.scio.extra.voyager.VoyagerReader) provides access to querying the Voyager index to get their nearest neighbors. +```scala +import com.spotify.scio.values.{SCollection, SideInput} +import com.spotify.scio.extra.voyager._ + +val voyagerSI: SideInput[VoyagerReader] = ??? +val elements: SCollection[(String, Array[Float])] = ??? +val maxNumResults: Int = ??? +val ef: Int = ??? + +val queryResults: SCollection[(String, Array[VoyagerResult])] = elements + .withSideInputs(voyagerSI) + .map { case ((label, vector), ctx) => + val voyagerReader: VoyagerReader = ctx(voyagerSI) + (label, voyagerReader.getNearest(vector, maxNumResults, ef)) + } +``` diff --git a/site/src/main/paradox/extras/index.md b/site/src/main/paradox/extras/index.md index dba8a6656d..85a597c1d3 100644 --- a/site/src/main/paradox/extras/index.md +++ b/site/src/main/paradox/extras/index.md @@ -17,5 +17,6 @@ * @ref:[Sparkey](Sparkey.md) * @ref:[REPL](Scio-REPL.md) * @ref:[Transforms](Transforms.md) +* @ref:[Voyager](Voyager.md) @@@