Add ScaNN to the list of supported ANN frameworks #553

Merged · 9 commits · Dec 1, 2023
3 changes: 2 additions & 1 deletion README.md
@@ -141,7 +141,8 @@

| Supported framework | Cornac wrapper | Examples |
| :---: | :---: | :---: |
| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb)
| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb), [ann_all.ipynb](examples/ann_all.ipynb)
| [google/scann](https://github.com/google-research/google-research/tree/master/scann) | [ScaNNANN](cornac/models/ann/recom_ann_scann.py) | [ann_all.ipynb](examples/ann_all.ipynb)


## Models
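For context on the table above: both wrappers sit behind the same `BaseANN` interface, so swapping backends is a one-line change. A minimal sketch, assuming an ANN-compatible model such as MF trained the way the linked notebooks do (hyperparameter values are illustrative, not tuned):

```python
# Minimal sketch following the pattern of tutorials/ann_hnswlib.ipynb and
# examples/ann_all.ipynb; model choice and values are illustrative.
from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit
from cornac.models import MF, HNSWLibANN, ScaNNANN

data = movielens.load_feedback(variant="100K")
rs = RatioSplit(data=data, test_size=0.2, seed=123)

mf = MF(k=50, max_iter=25, seed=123).fit(rs.train_set)  # ANN-compatible model

ann = HNSWLibANN(model=mf)    # hnswlib-backed index
# ann = ScaNNANN(model=mf)    # ScaNN-backed index, same interface
ann.build_index()

# top-10 nearest items for the first user's vector
neighbors, distances = ann.knn_query(ann.user_vectors[:1], k=10)
```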
1 change: 1 addition & 0 deletions cornac/models/__init__.py
@@ -18,6 +18,7 @@

from .amr import AMR
from .ann import HNSWLibANN
from .ann import ScaNNANN
from .baseline_only import BaselineOnly
from .bivaecf import BiVAECF
from .bpr import BPR
1 change: 1 addition & 0 deletions cornac/models/ann/__init__.py
@@ -1 +1,2 @@
from .recom_ann_hnswlib import HNSWLibANN
from .recom_ann_scann import ScaNNANN
7 changes: 6 additions & 1 deletion cornac/models/ann/recom_ann_base.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================


import copy
import numpy as np

from ..recommender import Recommender
@@ -41,6 +41,11 @@ def __init__(self, model, name="BaseANN", verbose=False):
if not is_ann_supported(model):
raise ValueError(f"{model.name} doesn't support ANN search")

# ANN required attributes
self.measure = copy.deepcopy(model.get_vector_measure())
self.user_vectors = copy.deepcopy(model.get_user_vectors())
self.item_vectors = copy.deepcopy(model.get_item_vectors())

# get basic attributes to be a proper recommender
super().fit(train_set=model.train_set, val_set=model.val_set)

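The hunk above moves the required ANN attributes (measure, user vectors, item vectors) into `BaseANN` itself, deep-copied from the wrapped model. Purely as an illustration of what that buys (this class is not part of the PR), a hypothetical backend now only has to build its index and answer k-NN queries:

```python
# Illustrative only: a hypothetical brute-force backend relying on the
# attributes that BaseANN.__init__ now copies from the wrapped model.
import numpy as np

from cornac.models.ann.recom_ann_base import BaseANN


class ExactANN(BaseANN):
    """Hypothetical exact-search backend (dot-product scoring only)."""

    def __init__(self, model, name="ExactANN", verbose=False):
        super().__init__(model=model, name=name, verbose=verbose)
        self.index = None

    def build_index(self):
        # self.item_vectors was copied from the wrapped model by BaseANN
        self.index = np.asarray(self.item_vectors)

    def knn_query(self, query, k):
        # dot-product scores, then top-k per query
        scores = np.asarray(query) @ self.index.T
        neighbors = np.argsort(-scores, axis=1)[:, :k]
        distances = np.take_along_axis(scores, neighbors, axis=1)
        return neighbors, distances
```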
5 changes: 0 additions & 5 deletions cornac/models/ann/recom_ann_hnswlib.py
@@ -86,11 +86,6 @@ def __init__(
)
self.seed = seed

# ANN required attributes
self.measure = model.get_vector_measure()
self.user_vectors = model.get_user_vectors()
self.item_vectors = model.get_item_vectors()

self.index = None
self.ignored_attrs.extend(
[
164 changes: 164 additions & 0 deletions cornac/models/ann/recom_ann_scann.py
@@ -0,0 +1,164 @@
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


import os
import multiprocessing
import numpy as np

from ..recommender import MEASURE_L2, MEASURE_DOT, MEASURE_COSINE
from .recom_ann_base import BaseANN


SUPPORTED_MEASURES = {MEASURE_L2: "squared_l2", MEASURE_DOT: "dot_product"}


class ScaNNANN(BaseANN):
"""Approximate Nearest Neighbor Search with ScaNN
(https://github.com/google-research/google-research/tree/master/scann).
    ScaNN performs vector search in three phases: partitioning, scoring, and rescoring.
    More on the algorithms and parameters: https://github.com/google-research/google-research/blob/master/scann/docs/algorithms.md

Parameters
----------------
model: object: :obj:`cornac.models.Recommender`, required
        Trained recommender model from which to get user/item vectors.

    partition_params: dict, optional
        Parameters for the partitioning phase, passed to ScaNN's tree() call.

    score_params: dict, optional
        Parameters for the scoring phase, passed to ScaNN's score_ah() call,
        or to score_brute_force() if score_brute_force is True.

score_brute_force: bool, optional, default: False
Whether to call score_brute_force() for the scoring phase.

    rescore_params: dict, optional
        Parameters for the rescoring phase, passed to ScaNN's reorder() call.

    num_threads: int, optional, default: -1
        Number of threads used for training the index. If num_threads = -1, all cores will be used.

seed: int, optional, default: None
Random seed for reproducibility.

    name: str, optional, default: 'ScaNNANN'
Name of the recommender model.

verbose: boolean, optional, default: False
When True, running logs are displayed.
"""

def __init__(
self,
model,
partition_params=None,
score_params=None,
score_brute_force=False,
rescore_params=None,
num_threads=-1,
seed=None,
name="ScaNNANN",
verbose=False,
):
super().__init__(model=model, name=name, verbose=verbose)

        if partition_params is None:
            partition_params = {}  # build_index() sets keys on this dict
        if score_params is None:
            score_params = {}

self.model = model
self.partition_params = partition_params
self.score_params = score_params
self.score_brute_force = score_brute_force
self.rescore_params = rescore_params
self.num_threads = (
num_threads if num_threads != -1 else multiprocessing.cpu_count()
)
self.seed = seed

self.index = None
self.ignored_attrs.extend(
[
"index", # will be saved separately
"item_vectors", # redundant after index is built
]
)

def build_index(self):
"""Building index from the base recommender model."""
import scann

        # Cosine is handled by normalizing the item vectors and searching with
        # dot product, so convert the measure before checking support.
        if self.measure == MEASURE_COSINE:
            self.partition_params["spherical"] = True
            self.item_vectors /= np.linalg.norm(self.item_vectors, axis=1)[
                :, np.newaxis
            ]
            self.measure = MEASURE_DOT
        else:
            self.partition_params["spherical"] = False

        assert self.measure in SUPPORTED_MEASURES

index_builder = scann.scann_ops_pybind.builder(
self.item_vectors, 10, SUPPORTED_MEASURES[self.measure]
)
index_builder.set_n_training_threads(self.num_threads)

# partitioning
if self.partition_params:
self.partition_params.setdefault(
"training_sample_size", self.item_vectors.shape[0]
)
index_builder = index_builder.tree(**self.partition_params)

# scoring
if self.score_brute_force:
index_builder = index_builder.score_brute_force(**self.score_params)
else:
index_builder = index_builder.score_ah(**self.score_params)

# rescoring
if self.rescore_params:
index_builder = index_builder.reorder(**self.rescore_params)

self.index = index_builder.build()

def knn_query(self, query, k):
"""Implementing ANN search for a given query.

Returns
-------
neighbors, distances: numpy.array and numpy.array
Array of k-nearest neighbors and corresponding distances for the given query.
"""
neighbors, distances = self.index.search_batched(query, final_num_neighbors=k)
return neighbors, distances

def save(self, save_dir=None):
saved_path = super().save(save_dir)
idx_path = saved_path + ".idx"
os.makedirs(idx_path, exist_ok=True)
self.index.searcher.serialize(idx_path)
return saved_path

@staticmethod
def load(model_path, trainable=False):
from scann.scann_ops.py import scann_ops_pybind

ann = BaseANN.load(model_path, trainable)
idx_path = ann.load_from + ".idx"
ann.index = scann_ops_pybind.load_searcher(idx_path)
return ann
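Finally, a usage sketch for the new wrapper itself, reusing the trained `mf` from the earlier snippet. The dict keys mirror ScaNN's tree()/score_ah()/reorder() arguments; the values are illustrative, not tuned:

```python
# Sketch: build a ScaNN index over the item vectors of a trained model `mf`
# (see the earlier snippet); parameter values are illustrative.
from cornac.models import ScaNNANN

ann = ScaNNANN(
    model=mf,
    partition_params={"num_leaves": 100, "num_leaves_to_search": 10},
    score_params={"dimensions_per_block": 2, "anisotropic_quantization_threshold": 0.2},
    rescore_params={"reordering_num_neighbors": 100},
    num_threads=-1,
    seed=123,
)
ann.build_index()

# top-10 items for a small batch of user queries
neighbors, distances = ann.knn_query(ann.user_vectors[:5], k=10)

# the ScaNN index is serialized next to the wrapper under "<path>.idx"
saved_path = ann.save("save_dir")
ann2 = ScaNNANN.load(saved_path)
```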