Skip to content

Commit

Permalink
Merge 8a4d84b into c3725ae
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickwmcgee authored Oct 9, 2023
2 parents c3725ae + 8a4d84b commit 7902583
Show file tree
Hide file tree
Showing 10 changed files with 765 additions and 0 deletions.
3 changes: 3 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ val shapelessVersion = "2.3.10"
val sparkeyVersion = "3.2.5"
val tensorFlowVersion = "0.4.2"
val testContainersVersion = "0.41.0"
val voyagerVersion = "1.2.6"
val zoltarVersion = "0.6.0"
// dependent versions
val scalatestplusVersion = s"$scalatestVersion.0"
Expand Down Expand Up @@ -896,7 +897,9 @@ lazy val `scio-extra`: Project = project
"com.google.zetasketch" % "zetasketch" % zetasketchVersion,
"com.nrinaudo" %% "kantan.codecs" % kantanCodecsVersion,
"com.nrinaudo" %% "kantan.csv" % kantanCsvVersion,
"com.softwaremill.magnolia1_2" %% "magnolia" % magnoliaVersion,
"com.spotify" % "annoy" % annoyVersion,
"com.spotify" % "voyager" % voyagerVersion,
"com.spotify.sparkey" % "sparkey" % sparkeyVersion,
"com.twitter" %% "algebird-core" % algebirdVersion,
"io.circe" %% "circe-core" % circeVersion,
Expand Down
104 changes: 104 additions & 0 deletions scio-extra/src/it/scala/com/spotify/scio/extra/voyager/VoyagerIT.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2023 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.spotify.scio.extra.voyager

import com.spotify.scio.testing.PipelineSpec
import com.spotify.scio.testing.util.ItUtils
import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType}
import org.apache.beam.sdk.io.FileSystems
import org.apache.beam.sdk.options.PipelineOptionsFactory
import org.apache.beam.sdk.util.MimeTypes

import java.nio.ByteBuffer
import scala.jdk.CollectionConverters._

/**
 * Integration tests for the Voyager syntax that run against a GCS temp location.
 *
 * Both tests create remote files; the cleanup runs in `finally` blocks so that a failing
 * assertion or pipeline does not leave test artifacts behind.
 */
class VoyagerIT extends PipelineSpec {
  val dim: Int = 2
  val storageType: StorageDataType = StorageDataType.E4M3
  val distanceMeasure: SpaceType = SpaceType.Cosine
  val ef = 200
  val m = 16L

  // Named vectors used both to build the index and as the queries issued against it.
  val sideData: Seq[(String, Array[Float])] = Seq(
    "1" -> Array(2.5f, 7.2f),
    "2" -> Array(1.2f, 2.2f),
    "3" -> Array(5.6f, 3.4f)
  )

  FileSystems.setDefaultPipelineOptions(PipelineOptionsFactory.create())

  it should "support .asVoyagerSideInput using GCS tempLocation" in {
    val tempLocation = ItUtils.gcpTempLocation("voyager-it")
    try {
      runWithContext { sc =>
        sc.options.setTempLocation(tempLocation)
        val (names, vectors) = sideData.unzip

        val voyagerReader = sc
          .parallelize(sideData)
          .asVoyagerSideInput(distanceMeasure, storageType, dim)

        // Query the index with each original vector and keep the single nearest result.
        val result = sc
          .parallelize(vectors)
          .withSideInputs(voyagerReader)
          .flatMap { case (v, ctx) =>
            ctx(voyagerReader).getNearest(v, 1, 100)
          }
          .toSCollection
          .map(_.name)

        result should containInAnyOrder(names)
      }
    } finally {
      // remove files uploaded by voyager, even when the pipeline or the assertion failed
      val files = FileSystems
        .`match`(s"$tempLocation/voyager-*")
        .metadata()
        .asScala
        .map(_.resourceId())

      FileSystems.delete(files.asJava)
    }
  }

  it should "throw exception when Voyager file exists" in {
    val uri = VoyagerUri(ItUtils.gcpTempLocation("voyager-it"))
    val indexUri = uri.value.resolve(VoyagerUri.IndexFile)
    val nameUri = uri.value.resolve(VoyagerUri.NamesFile)
    val indexResourceId = FileSystems.matchNewResource(indexUri.toString, false)
    val nameResourceId = FileSystems.matchNewResource(nameUri.toString, false)

    // write some data in the names and index files so that the target URI already exists;
    // each channel is closed independently so a failure on one cannot leak the other
    Seq(nameResourceId, indexResourceId).foreach { resourceId =>
      val channel = FileSystems.create(resourceId, MimeTypes.BINARY)
      try channel.write(ByteBuffer.wrap("test-data".getBytes()))
      finally channel.close()
    }

    try {
      val e = the[IllegalArgumentException] thrownBy {
        runWithContext { sc =>
          sc.parallelize(sideData).asVoyager(uri, distanceMeasure, storageType, dim, 200L, 16)
        }
      }

      e.getMessage shouldBe s"requirement failed: Voyager URI ${uri.value} already exists"
    } finally {
      FileSystems.delete(Seq(nameResourceId, indexResourceId).asJava)
    }
  }
}
188 changes: 188 additions & 0 deletions scio-extra/src/main/scala/com/spotify/scio/extra/voyager/Voyager.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Copyright 2023 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.extra.voyager

import com.spotify.scio.util.{RemoteFileUtil, ScioUtil}
import com.spotify.scio.values.SideInput
import com.spotify.voyager.jni.Index.{SpaceType, StorageDataType}
import com.spotify.voyager.jni.{Index, StringIndex}
import org.apache.beam.sdk.transforms.DoFn
import org.apache.beam.sdk.values.PCollectionView

import java.net.URI
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path, Paths}
import java.util.concurrent.ConcurrentHashMap

/**
 * Represents the base URI for a voyager index, either on a local or a remote file system. For
 * remote file systems, `value` should be in the form 'scheme://<bucket>/<path>/'. For local
 * files, it should be in the form '/<path>/'. `value` represents the directory where the
 * `index.hnsw` and `names.json` files are.
 *
 * @param value
 *   Directory URI that the Voyager file names are resolved against.
 */
final case class VoyagerUri(value: URI) extends AnyVal {

  import VoyagerUri._

  /**
   * Checks whether any of the Voyager files is already present under this URI.
   *
   * @param remoteFileUtil
   *   Used to probe files when `value` is not a local URI.
   * @return
   *   `true` if at least one of the Voyager files exists, `false` otherwise.
   */
  def exists(implicit remoteFileUtil: RemoteFileUtil): Boolean = {
    if (ScioUtil.isLocalUri(value)) {
      VoyagerFiles.exists { f =>
        val fileUri = value.resolve(f)
        // Paths.get(URI) rejects URIs without a scheme, which is exactly the documented
        // '/<path>/' local form, so scheme-less URIs are resolved through their path instead.
        val localPath =
          if (fileUri.getScheme == null) Paths.get(fileUri.getPath) else Paths.get(fileUri)
        localPath.toFile.exists()
      }
    } else {
      VoyagerFiles.exists(f => remoteFileUtil.remoteExists(value.resolve(f)))
    }
  }
}

object VoyagerUri {

  // File names that are resolved against the base URI of a Voyager index.
  private[voyager] val IndexFile = "index.hnsw"
  private[voyager] val NamesFile = "names.json"

  // Every file that belongs to a Voyager index, in the order they are probed.
  private[voyager] val VoyagerFiles: Seq[String] = IndexFile :: NamesFile :: Nil

  /** Builds a [[VoyagerUri]] by parsing `value` as a [[java.net.URI]]. */
  def apply(value: String): VoyagerUri = {
    val parsed = new URI(value)
    new VoyagerUri(parsed)
  }
}

/**
 * Result of a voyager query.
 *
 * @param name
 *   Name associated with the matched item.
 * @param distance
 *   Distance reported for the matched item.
 */
final case class VoyagerResult(
  name: String,
  distance: Float
)

/**
 * Builds a Voyager index from named vectors and writes it to local files.
 *
 * @param indexFile
 *   Destination of the serialized index.
 * @param namesFile
 *   Destination of the JSON array of names; the position of a name in the array is the id of
 *   its vector in the index.
 * @param spaceType
 *   Space type handed to the Voyager `Index`.
 * @param storageDataType
 *   Storage data type handed to the Voyager `Index`.
 * @param dim
 *   Number of dimensions in vectors.
 * @param ef
 *   Forwarded to the `Index` constructor.
 * @param m
 *   Forwarded to the `Index` constructor.
 */
class VoyagerWriter private[voyager] (
  indexFile: Path,
  namesFile: Path,
  spaceType: SpaceType,
  storageDataType: StorageDataType,
  dim: Int,
  ef: Long = 200L,
  m: Long = 16L
) {
  import VoyagerWriter._

  /**
   * Indexes `vectors` and writes both the index file and the names file.
   *
   * Items get consecutive ids following the iteration order of `vectors`. The native index and
   * both output streams are released even when indexing or writing fails.
   *
   * @param vectors
   *   Pairs of name and vector to index.
   */
  def write(vectors: Iterable[(String, Array[Float])]): Unit = {
    val names = List.newBuilder[String]
    val index = new Index(spaceType, dim, m, ef, RandomSeed, ChunkSize.toLong, storageDataType)

    try {
      // iterate lazily so the input is only materialized one chunk at a time
      vectors.iterator.zipWithIndex
        .map { case ((name, vector), idx) => (name, vector, idx.toLong) }
        .grouped(ChunkSize)
        .foreach { chunk =>
          val (ns, vs, is) = chunk.unzip3
          names ++= ns
          // NOTE(review): -1 presumably lets Voyager pick the number of threads — confirm
          index.addItems(vs.toArray, is.toArray, -1)
        }

      // save index
      val indexOutputStream = Files.newOutputStream(indexFile)
      try index.saveIndex(indexOutputStream)
      finally indexOutputStream.close()
    } finally {
      index.close()
    }

    // save names; quoting each element individually yields "[]" for an empty input
    // rather than a list holding a single empty name
    // NOTE(review): names are written verbatim, quotes and backslashes are not escaped
    val json = names.result().map(n => "\"" + n + "\"").mkString("[", ",", "]")
    val namesOutputStream = Files.newOutputStream(namesFile)
    try namesOutputStream.write(json.getBytes(StandardCharsets.UTF_8))
    finally namesOutputStream.close()
  }
}

private object VoyagerWriter {
  // Fixed seed handed to the Index constructor.
  private val RandomSeed: Long = 1L
  // Number of vectors added to the index per batch, also passed to the Index constructor.
  // The previous literal 32786 was a digit transposition of the intended 2^15.
  private val ChunkSize: Int = 1 << 15 // 2^15 = 32768
}

/**
 * Reader over a Voyager index used for nearest neighbor lookups. Given a query vector it returns
 * the string labels of the closest entries together with their distances.
 *
 * @param indexFile
 *   The `index.hnsw` file.
 * @param namesFile
 *   The `names.json` file.
 * @param spaceType
 *   The measurement for computing distance between entities. One of Euclidean, Cosine or Dot (inner
 *   product).
 * @param storageDataType
 *   The Storage type of the vectors at rest. One of Float8, Float32 or E4M3.
 * @param dim
 *   Number of dimensions in vectors. Must be strictly positive.
 */
class VoyagerReader private[voyager] (
  indexFile: Path,
  namesFile: Path,
  spaceType: SpaceType,
  storageDataType: StorageDataType,
  dim: Int
) {
  require(dim > 0, "Vector dimension should be > 0")

  // The index is loaded on first use rather than when the reader is constructed.
  @transient private lazy val index: StringIndex = {
    val indexPath = indexFile.toString
    val namesPath = namesFile.toString
    StringIndex.load(indexPath, namesPath, spaceType, dim, storageDataType)
  }

  /**
   * Looks up the nearest neighbors of a vector.
   *
   * @param v
   *   Query vector.
   * @param maxNumResults
   *   Maximum number of neighbors to return.
   * @param ef
   *   Size of the dynamic list for the nearest neighbors during search.
   * @return
   *   The names returned by the query paired with their distances.
   */
  def getNearest(v: Array[Float], maxNumResults: Int, ef: Int): Array[VoyagerResult] = {
    val hits = index.query(v, maxNumResults, ef)
    val hitNames = hits.getNames
    val hitDistances = hits.getDistances
    hitNames.zip(hitDistances).map { case (n, d) => VoyagerResult(n, d) }
  }
}

/**
 * Side input that turns a [[VoyagerUri]] into a [[VoyagerReader]]. Readers are kept in a shared
 * concurrent map so that a reader is only created once per [[VoyagerUri]].
 *
 * NOTE(review): the cache key is the [[VoyagerUri]] alone, so side inputs that point at the same
 * URI with different `distanceMeasure`, `storageType` or `dim` share the first reader created.
 *
 * @param view
 *   View holding the [[VoyagerUri]] of the index.
 * @param remoteFileUtil
 *   Used to download the index files when the URI is not local.
 * @param distanceMeasure
 *   Space type handed to the [[VoyagerReader]].
 * @param storageType
 *   Storage data type handed to the [[VoyagerReader]].
 * @param dim
 *   Number of dimensions in vectors.
 */
private[voyager] class VoyagerSideInput(
  val view: PCollectionView[VoyagerUri],
  remoteFileUtil: RemoteFileUtil,
  distanceMeasure: SpaceType,
  storageType: StorageDataType,
  dim: Int
) extends SideInput[VoyagerReader] {

  import VoyagerSideInput._

  // Paths.get(URI) rejects URIs without a scheme (e.g. '/<path>/index.hnsw'), so scheme-less
  // local URIs are resolved through their path component instead.
  private def localPath(fileUri: URI): Path =
    if (fileUri.getScheme == null) Paths.get(fileUri.getPath) else Paths.get(fileUri)

  // Locates the index and names files under `uri`, downloading them first when remote.
  private def createReader(uri: VoyagerUri): VoyagerReader = {
    val indexUri = uri.value.resolve(VoyagerUri.IndexFile)
    val namesUri = uri.value.resolve(VoyagerUri.NamesFile)

    val (localIndex, localNames) = if (ScioUtil.isLocalUri(uri.value)) {
      (localPath(indexUri), localPath(namesUri))
    } else {
      val downloadedIndex = remoteFileUtil.download(indexUri)
      val downloadedNames = remoteFileUtil.download(namesUri)
      (downloadedIndex, downloadedNames)
    }
    new VoyagerReader(localIndex, localNames, distanceMeasure, storageType, dim)
  }

  /** Returns the cached reader for the URI held by `view`, creating it when absent. */
  override def get[I, O](context: DoFn[I, O]#ProcessContext): VoyagerReader = {
    val uri = context.sideInput(view)
    VoyagerReaderSharedCache.computeIfAbsent(uri, createReader)
  }
}

private object VoyagerSideInput {
  // JVM-wide mapping from VoyagerUri to its VoyagerReader: workers running several
  // voyager side-input steps reuse the same reader and load the index only once
  @transient
  private lazy val VoyagerReaderSharedCache: ConcurrentHashMap[VoyagerUri, VoyagerReader] =
    new ConcurrentHashMap[VoyagerUri, VoyagerReader]()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2023 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.extra

import com.spotify.scio.extra.voyager.syntax.AllSyntax

/**
 * Main package for Voyager side input APIs.
 *
 * The package object mixes in [[com.spotify.scio.extra.voyager.syntax.AllSyntax]], so the members
 * of that trait are available through `import com.spotify.scio.extra.voyager._`.
 */
package object voyager extends AllSyntax
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2023 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.extra.voyager.syntax

/** Aggregates the Voyager syntax traits, `ScioContextSyntax` and `SCollectionSyntax`, in one trait. */
trait AllSyntax extends ScioContextSyntax with SCollectionSyntax
Loading

0 comments on commit 7902583

Please sign in to comment.