-
Notifications
You must be signed in to change notification settings - Fork 513
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
765 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
scio-extra/src/it/scala/com/spotify/scio/extra/voyager/VoyagerIT.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/* | ||
* Copyright 2023 Spotify AB. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
package com.spotify.scio.extra.voyager | ||
|
||
import com.spotify.scio.testing.PipelineSpec | ||
import com.spotify.scio.testing.util.ItUtils | ||
import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} | ||
import org.apache.beam.sdk.io.FileSystems | ||
import org.apache.beam.sdk.options.PipelineOptionsFactory | ||
import org.apache.beam.sdk.util.MimeTypes | ||
|
||
import java.nio.ByteBuffer | ||
import scala.jdk.CollectionConverters._ | ||
|
||
class VoyagerIT extends PipelineSpec { | ||
val dim: Int = 2 | ||
val storageType: StorageDataType = StorageDataType.E4M3 | ||
val distanceMeasure: SpaceType = SpaceType.Cosine | ||
val ef = 200 | ||
val m = 16L | ||
|
||
val sideData: Seq[(String, Array[Float])] = Seq( | ||
"1" -> Array(2.5f, 7.2f), | ||
"2" -> Array(1.2f, 2.2f), | ||
"3" -> Array(5.6f, 3.4f) | ||
) | ||
|
||
FileSystems.setDefaultPipelineOptions(PipelineOptionsFactory.create()) | ||
|
||
it should "support .asVoyagerSideInput using GCS tempLocation" in { | ||
val tempLocation = ItUtils.gcpTempLocation("voyager-it") | ||
runWithContext { sc => | ||
sc.options.setTempLocation(tempLocation) | ||
val (names, vectors) = sideData.unzip | ||
|
||
val voyagerReader = sc | ||
.parallelize(sideData) | ||
.asVoyagerSideInput(distanceMeasure, storageType, dim) | ||
|
||
val result = sc | ||
.parallelize(vectors) | ||
.withSideInputs(voyagerReader) | ||
.flatMap { case (v, ctx) => | ||
ctx(voyagerReader).getNearest(v, 1, 100) | ||
} | ||
.toSCollection | ||
.map(_.name) | ||
|
||
result should containInAnyOrder(names) | ||
} | ||
|
||
// check files uploaded by voyager | ||
val files = FileSystems | ||
.`match`(s"$tempLocation/voyager-*") | ||
.metadata() | ||
.asScala | ||
.map(_.resourceId()) | ||
|
||
FileSystems.delete(files.asJava) | ||
} | ||
|
||
it should "throw exception when Voyager file exists" in { | ||
val uri = VoyagerUri(ItUtils.gcpTempLocation("voyager-it")) | ||
val indexUri = uri.value.resolve(VoyagerUri.IndexFile) | ||
val nameUri = uri.value.resolve(VoyagerUri.NamesFile) | ||
val indexResourceId = FileSystems.matchNewResource(indexUri.toString, false) | ||
val nameResourceId = FileSystems.matchNewResource(nameUri.toString, false) | ||
|
||
// write some data in the | ||
val f1 = FileSystems.create(nameResourceId, MimeTypes.BINARY) | ||
val f2 = FileSystems.create(indexResourceId, MimeTypes.BINARY) | ||
try { | ||
f1.write(ByteBuffer.wrap("test-data".getBytes())) | ||
f2.write(ByteBuffer.wrap("test-data".getBytes())) | ||
} finally { | ||
f1.close() | ||
f2.close() | ||
} | ||
|
||
val e = the[IllegalArgumentException] thrownBy { | ||
runWithContext { sc => | ||
sc.parallelize(sideData).asVoyager(uri, distanceMeasure, storageType, dim, 200L, 16) | ||
} | ||
} | ||
|
||
e.getMessage shouldBe s"requirement failed: Voyager URI ${uri.value} already exists" | ||
|
||
FileSystems.delete(Seq(nameResourceId, indexResourceId).asJava) | ||
} | ||
} |
188 changes: 188 additions & 0 deletions
188
scio-extra/src/main/scala/com/spotify/scio/extra/voyager/Voyager.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
/* | ||
* Copyright 2023 Spotify AB. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package com.spotify.scio.extra.voyager | ||
|
||
import com.spotify.scio.util.{RemoteFileUtil, ScioUtil} | ||
import com.spotify.scio.values.SideInput | ||
import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType} | ||
import com.spotify.voyager.jni.{Index, StringIndex} | ||
import org.apache.beam.sdk.transforms.DoFn | ||
import org.apache.beam.sdk.values.PCollectionView | ||
|
||
import java.net.URI | ||
import java.nio.charset.StandardCharsets | ||
import java.nio.file.{Files, Path, Paths} | ||
import java.util.concurrent.ConcurrentHashMap | ||
|
||
/** | ||
* Represents the base URI for a voyager index, either on a local or a remote file system. For | ||
* remote file systems, the `path` should be in the form 'scheme://<bucket>/<path>/'. For local | ||
* files, it should be in the form '/<path>/'. The `uri` specified represents the directory where | ||
* the `index.hnsw` and `names.json` are. | ||
*/ | ||
final case class VoyagerUri(value: URI) extends AnyVal { | ||
|
||
import VoyagerUri._ | ||
|
||
def exists(implicit remoteFileUtil: RemoteFileUtil): Boolean = { | ||
if (ScioUtil.isLocalUri(value)) { | ||
VoyagerFiles.exists(f => Paths.get(value.resolve(f)).toFile.exists()) | ||
} else { | ||
VoyagerFiles.exists(f => remoteFileUtil.remoteExists(value.resolve(f))) | ||
} | ||
} | ||
} | ||
|
||
object VoyagerUri { | ||
def apply(value: String): VoyagerUri = new VoyagerUri(new URI(value)) | ||
|
||
private[voyager] val IndexFile = "index.hnsw" | ||
private[voyager] val NamesFile = "names.json" | ||
|
||
private[voyager] val VoyagerFiles: Seq[String] = Seq(IndexFile, NamesFile) | ||
|
||
} | ||
|
||
/** Result of a voyager query */ | ||
final case class VoyagerResult(name: String, distance: Float) | ||
|
||
class VoyagerWriter private[voyager] ( | ||
indexFile: Path, | ||
namesFile: Path, | ||
spaceType: SpaceType, | ||
storageDataType: StorageDataType, | ||
dim: Int, | ||
ef: Long = 200L, | ||
m: Long = 16L | ||
) { | ||
import VoyagerWriter._ | ||
|
||
def write(vectors: Iterable[(String, Array[Float])]): Unit = { | ||
val indexOutputStream = Files.newOutputStream(indexFile) | ||
val namesOutputStream = Files.newOutputStream(namesFile) | ||
|
||
val names = List.newBuilder[String] | ||
val index = new Index(spaceType, dim, m, ef, RandomSeed, ChunkSize.toLong, storageDataType) | ||
|
||
vectors.zipWithIndex | ||
.map { case ((name, vector), idx) => (name, vector, idx.toLong) } | ||
.grouped(ChunkSize) | ||
.map(_.unzip3) | ||
.foreach { case (ns, vs, is) => | ||
names ++= ns | ||
index.addItems(vs.toArray, is.toArray, -1) | ||
} | ||
|
||
// save index | ||
index.saveIndex(indexOutputStream) | ||
index.close() | ||
// save names | ||
val json = names.result().mkString("[\"", "\",\"", "\"]") | ||
namesOutputStream.write(json.getBytes(StandardCharsets.UTF_8)) | ||
// close | ||
indexOutputStream.close() | ||
namesOutputStream.close() | ||
} | ||
} | ||
|
||
private object VoyagerWriter { | ||
private val RandomSeed: Long = 1L | ||
private val ChunkSize: Int = 32786 // 2^15 | ||
} | ||
|
||
/** | ||
* Voyager reader class for nearest neighbor lookups. Supports looking up neighbors for a vector and | ||
* returning the string labels and distances associated. | ||
* | ||
* @param indexFile | ||
* The `index.hnsw` file. | ||
* @param namesFile | ||
* The `names.json` file. | ||
* @param spaceType | ||
* The measurement for computing distance between entities. One of Euclidean, Cosine or Dot (inner | ||
* product). | ||
* @param storageDataType | ||
* The Storage type of the vectors at rest. One of Float8, Float32 or E4M3. | ||
* @param dim | ||
* Number of dimensions in vectors. | ||
*/ | ||
class VoyagerReader private[voyager] ( | ||
indexFile: Path, | ||
namesFile: Path, | ||
spaceType: SpaceType, | ||
storageDataType: StorageDataType, | ||
dim: Int | ||
) { | ||
require(dim > 0, "Vector dimension should be > 0") | ||
|
||
@transient private lazy val index: StringIndex = | ||
StringIndex.load(indexFile.toString, namesFile.toString, spaceType, dim, storageDataType) | ||
|
||
/** | ||
* Gets maxNumResults nearest neighbors for vector v using ef (where ef is the size of the dynamic | ||
* list for the nearest neighbors during search). | ||
*/ | ||
def getNearest(v: Array[Float], maxNumResults: Int, ef: Int): Array[VoyagerResult] = { | ||
val queryResults = index.query(v, maxNumResults, ef) | ||
queryResults.getNames | ||
.zip(queryResults.getDistances) | ||
.map { case (name, distance) => VoyagerResult(name, distance) } | ||
} | ||
} | ||
|
||
/** | ||
* Construction for a VoyagerSide input that leverages a synchronized map to ensure that the reader | ||
* is only loaded once per [[VoyagerUri]]. | ||
*/ | ||
private[voyager] class VoyagerSideInput( | ||
val view: PCollectionView[VoyagerUri], | ||
remoteFileUtil: RemoteFileUtil, | ||
distanceMeasure: SpaceType, | ||
storageType: StorageDataType, | ||
dim: Int | ||
) extends SideInput[VoyagerReader] { | ||
|
||
import VoyagerSideInput._ | ||
|
||
private def createReader(uri: VoyagerUri): VoyagerReader = { | ||
val indexUri = uri.value.resolve(VoyagerUri.IndexFile) | ||
val namesUri = uri.value.resolve(VoyagerUri.NamesFile) | ||
|
||
val (localIndex, localNames) = if (ScioUtil.isLocalUri(uri.value)) { | ||
(Paths.get(indexUri), Paths.get(namesUri)) | ||
} else { | ||
val downloadedIndex = remoteFileUtil.download(indexUri) | ||
val downloadedNames = remoteFileUtil.download(namesUri) | ||
(downloadedIndex, downloadedNames) | ||
} | ||
new VoyagerReader(localIndex, localNames, distanceMeasure, storageType, dim) | ||
} | ||
|
||
override def get[I, O](context: DoFn[I, O]#ProcessContext): VoyagerReader = { | ||
val uri = context.sideInput(view) | ||
VoyagerReaderSharedCache.computeIfAbsent(uri, createReader) | ||
} | ||
} | ||
|
||
private object VoyagerSideInput { | ||
// cache the VoyagerUri to VoyagerReader per JVM so workers with multiple | ||
// voyager side-input steps load the index only once | ||
@transient private lazy val VoyagerReaderSharedCache | ||
: ConcurrentHashMap[VoyagerUri, VoyagerReader] = | ||
new ConcurrentHashMap() | ||
} |
23 changes: 23 additions & 0 deletions
23
scio-extra/src/main/scala/com/spotify/scio/extra/voyager/package.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Copyright 2023 Spotify AB. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package com.spotify.scio.extra | ||
|
||
import com.spotify.scio.extra.voyager.syntax.AllSyntax | ||
|
||
/** Main package for Voyager side input APIs. */ | ||
package object voyager extends AllSyntax |
20 changes: 20 additions & 0 deletions
20
scio-extra/src/main/scala/com/spotify/scio/extra/voyager/syntax/AllSyntax.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/* | ||
* Copyright 2023 Spotify AB. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
package com.spotify.scio.extra.voyager.syntax | ||
|
||
trait AllSyntax extends ScioContextSyntax with SCollectionSyntax |
Oops, something went wrong.