From a385530eaf34ad6bfaf21558bf16aa12a4a9980d Mon Sep 17 00:00:00 2001 From: tqtg Date: Tue, 28 Nov 2023 18:40:50 +0000 Subject: [PATCH 1/9] add ScaNN --- cornac/models/ann/recom_ann_base.py | 7 +- cornac/models/ann/recom_ann_hnswlib.py | 5 - cornac/models/ann/recom_ann_scann.py | 164 +++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 6 deletions(-) create mode 100644 cornac/models/ann/recom_ann_scann.py diff --git a/cornac/models/ann/recom_ann_base.py b/cornac/models/ann/recom_ann_base.py index e29270b0a..c256a721e 100644 --- a/cornac/models/ann/recom_ann_base.py +++ b/cornac/models/ann/recom_ann_base.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ - +import copy import numpy as np from ..recommender import Recommender @@ -41,6 +41,11 @@ def __init__(self, model, name="BaseANN", verbose=False): if not is_ann_supported(model): raise ValueError(f"{model.name} doesn't support ANN search") + # ANN required attributes + self.measure = copy.deepcopy(model.get_vector_measure()) + self.user_vectors = copy.deepcopy(model.get_user_vectors()) + self.item_vectors = copy.deepcopy(model.get_item_vectors()) + # get basic attributes to be a proper recommender super().fit(train_set=model.train_set, val_set=model.val_set) diff --git a/cornac/models/ann/recom_ann_hnswlib.py b/cornac/models/ann/recom_ann_hnswlib.py index 5f960d628..ff4e68547 100644 --- a/cornac/models/ann/recom_ann_hnswlib.py +++ b/cornac/models/ann/recom_ann_hnswlib.py @@ -86,11 +86,6 @@ def __init__( ) self.seed = seed - # ANN required attributes - self.measure = model.get_vector_measure() - self.user_vectors = model.get_user_vectors() - self.item_vectors = model.get_item_vectors() - self.index = None self.ignored_attrs.extend( [ diff --git a/cornac/models/ann/recom_ann_scann.py b/cornac/models/ann/recom_ann_scann.py new file mode 100644 index 000000000..79b62cbc9 --- /dev/null +++ b/cornac/models/ann/recom_ann_scann.py @@ -0,0 +1,164 @@ +# Copyright 2023 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + + +import multiprocessing +import numpy as np + +from ..recommender import MEASURE_L2, MEASURE_DOT, MEASURE_COSINE +from .recom_ann_base import BaseANN + + +SUPPORTED_MEASURES = {MEASURE_L2: "squared_l2", MEASURE_DOT: "dot_product"} + + +class ScaNNANN(BaseANN): + """Approximate Nearest Neighbor Search with ScaNN + (https://github.com/google-research/google-research/tree/master/scann). + ScaNN performs vector search in three phases: paritioning, scoring, and rescoring. + More on the algorithms and parameter description: https://github.com/google-research/google-research/blob/master/scann/docs/algorithms.md + + Parameters + ---------------- + model: object: :obj:`cornac.models.Recommender`, required + Trained recommender model which to get user/item vectors from. + + num_neighbors: int, optional + The default number of neighbors/items to be returned. + + parition_params: dict, optional + Parameters for the partitioning phase, to send to the tree() call in ScaNN. + + score_params: dict, optional + Parameters for the scoring phase, to send to the score_ah() call in ScaNN. + score_brute_force() will be called if score_brute_force is True. + + score_brute_force: bool, optional, default: False + Whether to call score_brute_force() for the scoring phase. + + rescore_params: dict, optional + Parameters for the rescoring phase, to send to the reorder() call in ScaNN. + + num_threads: int, optional, default: -1 + Default number of threads to use when querying. If num_threads = -1, all cores will be used. + + seed: int, optional, default: None + Random seed for reproducibility. + + name: str, required + Name of the recommender model. + + verbose: boolean, optional, default: False + When True, running logs are displayed. + """ + + def __init__( + self, + model, + num_neighbors=10, + parition_params=None, + score_params=None, + score_brute_force=False, + rescore_params=None, + num_threads=-1, + seed=None, + name="ScaNNANN", + verbose=False, + ): + super().__init__(model=model, name=name, verbose=verbose) + + if score_params is None: + score_params = {} + + self.model = model + self.num_neighbors = num_neighbors + self.parition_params = parition_params + self.score_params = score_params + self.score_brute_force = score_brute_force + self.rescore_params = rescore_params + self.num_threads = ( + num_threads if num_threads != -1 else multiprocessing.cpu_count() + ) + self.seed = seed + + self.index = None + self.ignored_attrs.extend( + [ + "index", # will be saved separately + "item_vectors", # redundant after index is built + ] + ) + + def build_index(self): + """Building index from the base recommender model.""" + import scann + + assert self.measure in SUPPORTED_MEASURES + + if self.measure == MEASURE_COSINE: + self.item_vectors = ( + self.item_vectors + / np.linalg.norm(self.item_vectors, axis=1)[:, np.newaxis] + ) + self.measure = MEASURE_DOT + + index_builder = scann.scann_ops_pybind.builder( + self.item_vectors, self.num_neighbors, SUPPORTED_MEASURES[self.measure] + ) + + # partitioning + if self.parition_params: + index_builder = index_builder.tree(**self.parition_params) + + # scoring + if self.score_brute_force: + index_builder = index_builder.score_brute_force(**self.score_params) + else: + index_builder = index_builder.score_ah(**self.score_params) + + # rescoring + if self.rescore_params: + index_builder = index_builder.reorder(**self.rescore_params) + + self.index = index_builder.build() + + def knn_query(self, query, k): + """Implementing ANN search for a given query. + + Returns + ------- + neighbors, distances: numpy.array and numpy.array + Array of k-nearest neighbors and corresponding distances for the given query. + """ + neighbors, distances = self.index.search_batched(query, leaves_to_search=k) + return neighbors, distances + + # def save(self, save_dir=None): + # saved_path = super().save(save_dir) + # self.index.save_index(saved_path + ".idx") + # return saved_path + + # @staticmethod + # def load(model_path, trainable=False): + # import hnswlib + + # ann = BaseANN.load(model_path, trainable) + # ann.index = hnswlib.Index( + # space=SUPPORTED_MEASURES[ann.measure], dim=ann.user_vectors.shape[1] + # ) + # ann.index.load_index(ann.load_from + ".idx") + # ann.index.set_ef(ann.ef) + # ann.index.set_num_threads(ann.num_threads) + # return ann From 9d729cb255317edb983ce68dcd19284b15dd21c7 Mon Sep 17 00:00:00 2001 From: tqtg Date: Fri, 1 Dec 2023 00:18:21 +0000 Subject: [PATCH 2/9] add ScaNNANN --- cornac/models/__init__.py | 1 + cornac/models/ann/__init__.py | 1 + cornac/models/ann/recom_ann_scann.py | 46 +++++++++++++--------------- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index c43c3fbd3..05db7ea53 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -18,6 +18,7 @@ from .amr import AMR from .ann import HNSWLibANN +from .ann import ScaNNANN from .baseline_only import BaselineOnly from .bivaecf import BiVAECF from .bpr import BPR diff --git a/cornac/models/ann/__init__.py b/cornac/models/ann/__init__.py index c89556c69..77fd8630b 100644 --- a/cornac/models/ann/__init__.py +++ b/cornac/models/ann/__init__.py @@ -1 +1,2 @@ from .recom_ann_hnswlib import HNSWLibANN +from .recom_ann_scann import ScaNNANN diff --git a/cornac/models/ann/recom_ann_scann.py b/cornac/models/ann/recom_ann_scann.py index 79b62cbc9..bdb779dce 100644 --- a/cornac/models/ann/recom_ann_scann.py +++ b/cornac/models/ann/recom_ann_scann.py @@ -14,6 +14,7 @@ # ============================================================================ +import os import multiprocessing import numpy as np @@ -35,10 +36,10 @@ class ScaNNANN(BaseANN): model: object: :obj:`cornac.models.Recommender`, required Trained recommender model which to get user/item vectors from. - num_neighbors: int, optional + num_neighbors: int, optional, default: 100 The default number of neighbors/items to be returned. - parition_params: dict, optional + partition_params: dict, optional Parameters for the partitioning phase, to send to the tree() call in ScaNN. score_params: dict, optional @@ -67,8 +68,8 @@ class ScaNNANN(BaseANN): def __init__( self, model, - num_neighbors=10, - parition_params=None, + num_neighbors=100, + partition_params=None, score_params=None, score_brute_force=False, rescore_params=None, @@ -84,7 +85,7 @@ def __init__( self.model = model self.num_neighbors = num_neighbors - self.parition_params = parition_params + self.partition_params = partition_params self.score_params = score_params self.score_brute_force = score_brute_force self.rescore_params = rescore_params @@ -119,8 +120,8 @@ def build_index(self): ) # partitioning - if self.parition_params: - index_builder = index_builder.tree(**self.parition_params) + if self.partition_params: + index_builder = index_builder.tree(**self.partition_params) # scoring if self.score_brute_force: @@ -142,23 +143,18 @@ def knn_query(self, query, k): neighbors, distances: numpy.array and numpy.array Array of k-nearest neighbors and corresponding distances for the given query. """ - neighbors, distances = self.index.search_batched(query, leaves_to_search=k) + neighbors, distances = self.index.search_batched(query, final_num_neighbors=k) return neighbors, distances - # def save(self, save_dir=None): - # saved_path = super().save(save_dir) - # self.index.save_index(saved_path + ".idx") - # return saved_path - - # @staticmethod - # def load(model_path, trainable=False): - # import hnswlib - - # ann = BaseANN.load(model_path, trainable) - # ann.index = hnswlib.Index( - # space=SUPPORTED_MEASURES[ann.measure], dim=ann.user_vectors.shape[1] - # ) - # ann.index.load_index(ann.load_from + ".idx") - # ann.index.set_ef(ann.ef) - # ann.index.set_num_threads(ann.num_threads) - # return ann + def save(self, save_dir=None): + saved_path = super().save(save_dir) + self.index.searcher.serialize(os.path.dirname(saved_path)) + return saved_path + + @staticmethod + def load(model_path, trainable=False): + from scann.scann_ops.py import scann_ops_pybind + + ann = BaseANN.load(model_path, trainable) + ann.index = scann_ops_pybind.load_searcher(os.path.dirname(model_path)) + return ann From ed63068df77c64d3294c3fc9a32cad16b78ccd6c Mon Sep 17 00:00:00 2001 From: tqtg Date: Fri, 1 Dec 2023 00:19:17 +0000 Subject: [PATCH 3/9] add ann_all.ipynb --- cornac/models/ann/ann_all.ipynb | 328 ++++++++++++++++++++++++++++++++ 1 file changed, 328 insertions(+) create mode 100644 cornac/models/ann/ann_all.ipynb diff --git a/cornac/models/ann/ann_all.ipynb b/cornac/models/ann/ann_all.ipynb new file mode 100644 index 000000000..55507f580 --- /dev/null +++ b/cornac/models/ann/ann_all.ipynb @@ -0,0 +1,328 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "b9a4225b-1a05-4b58-9e1d-1511650ef225", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q hnswlib scann" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "74a9e78f-3e8a-4ee2-89fe-b3a3f4784b53", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import cornac\n", + "from cornac.data import Reader\n", + "from cornac.datasets import netflix\n", + "from cornac.eval_methods import RatioSplit\n", + "from cornac.models import MF" + ] + }, + { + "cell_type": "markdown", + "id": "cf6bb9a5-ffb5-4221-8122-9aa286af1d9c", + "metadata": {}, + "source": [ + "## Train a base recommender model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "76a0c130-7dd7-4004-a613-5b123dcc75d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rating_threshold = 1.0\n", + "exclude_unknowns = True\n", + "---\n", + "Training data:\n", + "Number of users = 9986\n", + "Number of items = 4921\n", + "Number of ratings = 547022\n", + "Max rating = 1.0\n", + "Min rating = 1.0\n", + "Global mean = 1.0\n", + "---\n", + "Test data:\n", + "Number of users = 9986\n", + "Number of items = 4921\n", + "Number of ratings = 60747\n", + "Number of unknown users = 0\n", + "Number of unknown items = 0\n", + "---\n", + "Total users = 9986\n", + "Total items = 4921\n", + "\n", + "[MF] Training started!\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e509a63955224732bcd98492e8ade2ea", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/25 [00:00 Date: Fri, 1 Dec 2023 00:22:26 +0000 Subject: [PATCH 4/9] move ann_all into examples --- {cornac/models/ann => examples}/ann_all.ipynb | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) rename {cornac/models/ann => examples}/ann_all.ipynb (89%) diff --git a/cornac/models/ann/ann_all.ipynb b/examples/ann_all.ipynb similarity index 89% rename from cornac/models/ann/ann_all.ipynb rename to examples/ann_all.ipynb index 55507f580..6c31b6f10 100644 --- a/cornac/models/ann/ann_all.ipynb +++ b/examples/ann_all.ipynb @@ -70,7 +70,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e509a63955224732bcd98492e8ade2ea", + "model_id": "86da24ba43bf48efb96f5cdc22b56af3", "version_major": 2, "version_minor": 0 }, @@ -93,7 +93,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fb774e86fed042c9939855b25d5387ae", + "model_id": "294b3c7a3bb04a4db74de6fcaaede852", "version_major": 2, "version_minor": 0 }, @@ -113,7 +113,7 @@ "...\n", " | AUC | Recall@20 | Train (s) | Test (s)\n", "-- + ------ + --------- + --------- + --------\n", - "MF | 0.8530 | 0.0669 | 1.9536 | 10.4698\n", + "MF | 0.8530 | 0.0669 | 1.9686 | 10.3095\n", "\n" ] } @@ -189,8 +189,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 3min 37s, sys: 11.8 ms, total: 3min 37s\n", - "Wall time: 4.54 s\n" + "CPU times: user 3min 35s, sys: 15.4 ms, total: 3min 35s\n", + "Wall time: 4.5 s\n" ] } ], @@ -225,7 +225,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 9, "id": "fa280b9d-ec04-41eb-9de2-acfb67fbeb80", "metadata": {}, "outputs": [ @@ -233,8 +233,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "HNSWLibANN\t\tIndexing=94ms\t\tRetrieval=340ms\t\tRecall=0.99875\n", - "ScaNNANN\t\tIndexing=100ms\t\tRetrieval=505ms\t\tRecall=0.99998\n" + "HNSWLibANN\t\tIndexing=92ms\t\tRetrieval=213ms\t\tRecall=0.99875\n", + "ScaNNANN\t\tIndexing=99ms\t\tRetrieval=506ms\t\tRecall=0.99998\n" ] } ], @@ -245,11 +245,13 @@ "\n", "anns = [\n", " HNSWLibANN(model=mf, M=16, ef_construction=100, ef=50, seed=123, num_threads=-1),\n", - " ScaNNANN(model=mf, seed=123, num_threads=-1, num_neighbors=K,\n", - " partition_params={\"num_leaves\": 100, \"num_leaves_to_search\": 50},\n", - " score_params={\"dimensions_per_block\": 2, \"anisotropic_quantization_threshold\": 0.2}, \n", - " rescore_params={\"reordering_num_neighbors\": 100})\n", - " \n", + " ScaNNANN(\n", + " model=mf, num_neighbors=K,\n", + " partition_params={\"num_leaves\": 100, \"num_leaves_to_search\": 50},\n", + " score_params={\"dimensions_per_block\": 2, \"anisotropic_quantization_threshold\": 0.2}, \n", + " rescore_params={\"reordering_num_neighbors\": 100},\n", + " seed=123, num_threads=-1,\n", + " ),\n", "]\n", "\n", "for ann in anns:\n", @@ -277,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 10, "id": "1fec7413-482b-417b-a12f-362a93aefb85", "metadata": {}, "outputs": [ From 4b46cf34d9837f0896f4cb1eca787f782bfcaede Mon Sep 17 00:00:00 2001 From: tqtg Date: Fri, 1 Dec 2023 00:25:20 +0000 Subject: [PATCH 5/9] update README --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4eb90b9b4..243c9c26e 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,8 @@ One important aspect of deploying recommender model is efficient retrieval via A | Supported framework | Cornac wrapper | Examples | | :---: | :---: | :---: | -| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb) +| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb), [ann_all.ipynb](examples/ann_all.ipynb) +| [google/scann](https://github.com/google-research/google-research/tree/master/scann) | [ScaNNANN](cornac/models/ann/recom_ann_scann.py) | [ann_all.ipynb](examples/ann_all.ipynb) ## Models From fa8a6f3bc898907d664e84c02f9102fcbb303cf9 Mon Sep 17 00:00:00 2001 From: tqtg Date: Fri, 1 Dec 2023 00:38:27 +0000 Subject: [PATCH 6/9] mkdir for index save() and load() --- cornac/models/ann/recom_ann_scann.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cornac/models/ann/recom_ann_scann.py b/cornac/models/ann/recom_ann_scann.py index bdb779dce..c37b7b4b1 100644 --- a/cornac/models/ann/recom_ann_scann.py +++ b/cornac/models/ann/recom_ann_scann.py @@ -116,7 +116,9 @@ def build_index(self): self.measure = MEASURE_DOT index_builder = scann.scann_ops_pybind.builder( - self.item_vectors, self.num_neighbors, SUPPORTED_MEASURES[self.measure] + db=self.item_vectors, + num_neighbors=self.num_neighbors, + distance_measure=SUPPORTED_MEASURES[self.measure], ) # partitioning @@ -148,7 +150,9 @@ def knn_query(self, query, k): def save(self, save_dir=None): saved_path = super().save(save_dir) - self.index.searcher.serialize(os.path.dirname(saved_path)) + idx_path = saved_path + ".idx" + os.makedirs(idx_path, exist_ok=True) + self.index.searcher.serialize(idx_path) return saved_path @staticmethod @@ -156,5 +160,6 @@ def load(model_path, trainable=False): from scann.scann_ops.py import scann_ops_pybind ann = BaseANN.load(model_path, trainable) - ann.index = scann_ops_pybind.load_searcher(os.path.dirname(model_path)) + idx_path = ann.load_from + ".idx" + ann.index = scann_ops_pybind.load_searcher(idx_path) return ann From 96f3aed8b7de130b71ce626230cef2446fc374b8 Mon Sep 17 00:00:00 2001 From: tqtg Date: Fri, 1 Dec 2023 00:38:36 +0000 Subject: [PATCH 7/9] update example --- examples/ann_all.ipynb | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/ann_all.ipynb b/examples/ann_all.ipynb index 6c31b6f10..41b6097b5 100644 --- a/examples/ann_all.ipynb +++ b/examples/ann_all.ipynb @@ -70,7 +70,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "86da24ba43bf48efb96f5cdc22b56af3", + "model_id": "8eb492e01d5640749891b03f83fdc922", "version_major": 2, "version_minor": 0 }, @@ -93,7 +93,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "294b3c7a3bb04a4db74de6fcaaede852", + "model_id": "8be2c032170e4b54b50931f31c1eaefb", "version_major": 2, "version_minor": 0 }, @@ -113,7 +113,7 @@ "...\n", " | AUC | Recall@20 | Train (s) | Test (s)\n", "-- + ------ + --------- + --------- + --------\n", - "MF | 0.8530 | 0.0669 | 1.9686 | 10.3095\n", + "MF | 0.8530 | 0.0669 | 1.9428 | 10.6296\n", "\n" ] } @@ -189,8 +189,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 3min 35s, sys: 15.4 ms, total: 3min 35s\n", - "Wall time: 4.5 s\n" + "CPU times: user 3min 38s, sys: 19.5 ms, total: 3min 38s\n", + "Wall time: 4.57 s\n" ] } ], @@ -233,8 +233,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "HNSWLibANN\t\tIndexing=92ms\t\tRetrieval=213ms\t\tRecall=0.99875\n", - "ScaNNANN\t\tIndexing=99ms\t\tRetrieval=506ms\t\tRecall=0.99998\n" + "HNSWLibANN\t\tIndexing=91ms\t\tRetrieval=212ms\t\tRecall=0.99875\n", + "ScaNNANN\t\tIndexing=106ms\t\tRetrieval=513ms\t\tRecall=0.99998\n" ] } ], From 38140fbc1b1aba0b3221420d00dddc9a0223dd6f Mon Sep 17 00:00:00 2001 From: tqtg Date: Fri, 1 Dec 2023 01:54:27 +0000 Subject: [PATCH 8/9] set_n_training_threads() --- cornac/models/ann/recom_ann_scann.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cornac/models/ann/recom_ann_scann.py b/cornac/models/ann/recom_ann_scann.py index c37b7b4b1..0f66821fd 100644 --- a/cornac/models/ann/recom_ann_scann.py +++ b/cornac/models/ann/recom_ann_scann.py @@ -53,7 +53,7 @@ class ScaNNANN(BaseANN): Parameters for the rescoring phase, to send to the reorder() call in ScaNN. num_threads: int, optional, default: -1 - Default number of threads to use when querying. If num_threads = -1, all cores will be used. + Default number of threads used for training. If num_threads = -1, all cores will be used. seed: int, optional, default: None Random seed for reproducibility. @@ -120,6 +120,7 @@ def build_index(self): num_neighbors=self.num_neighbors, distance_measure=SUPPORTED_MEASURES[self.measure], ) + index_builder.set_n_training_threads(self.num_threads) # partitioning if self.partition_params: From f45b089f9e1f697486a580f3d5be10c1b4d89954 Mon Sep 17 00:00:00 2001 From: tqtg Date: Fri, 1 Dec 2023 02:42:29 +0000 Subject: [PATCH 9/9] remove default num_neighbors --- cornac/models/ann/recom_ann_scann.py | 22 ++++++++++------------ examples/ann_all.ipynb | 24 ++++++++++++++++-------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/cornac/models/ann/recom_ann_scann.py b/cornac/models/ann/recom_ann_scann.py index 0f66821fd..af116f1af 100644 --- a/cornac/models/ann/recom_ann_scann.py +++ b/cornac/models/ann/recom_ann_scann.py @@ -36,9 +36,6 @@ class ScaNNANN(BaseANN): model: object: :obj:`cornac.models.Recommender`, required Trained recommender model which to get user/item vectors from. - num_neighbors: int, optional, default: 100 - The default number of neighbors/items to be returned. - partition_params: dict, optional Parameters for the partitioning phase, to send to the tree() call in ScaNN. @@ -68,7 +65,6 @@ class ScaNNANN(BaseANN): def __init__( self, model, - num_neighbors=100, partition_params=None, score_params=None, score_brute_force=False, @@ -84,7 +80,6 @@ def __init__( score_params = {} self.model = model - self.num_neighbors = num_neighbors self.partition_params = partition_params self.score_params = score_params self.score_brute_force = score_brute_force @@ -109,21 +104,24 @@ def build_index(self): assert self.measure in SUPPORTED_MEASURES if self.measure == MEASURE_COSINE: - self.item_vectors = ( - self.item_vectors - / np.linalg.norm(self.item_vectors, axis=1)[:, np.newaxis] - ) + self.partition_params["spherical"] = True + self.item_vectors /= np.linalg.norm(self.item_vectors, axis=1)[ + :, np.newaxis + ] self.measure = MEASURE_DOT + else: + self.partition_params["spherical"] = False index_builder = scann.scann_ops_pybind.builder( - db=self.item_vectors, - num_neighbors=self.num_neighbors, - distance_measure=SUPPORTED_MEASURES[self.measure], + self.item_vectors, 10, SUPPORTED_MEASURES[self.measure] ) index_builder.set_n_training_threads(self.num_threads) # partitioning if self.partition_params: + self.partition_params.setdefault( + "training_sample_size", self.item_vectors.shape[0] + ) index_builder = index_builder.tree(**self.partition_params) # scoring diff --git a/examples/ann_all.ipynb b/examples/ann_all.ipynb index 41b6097b5..856adccdb 100644 --- a/examples/ann_all.ipynb +++ b/examples/ann_all.ipynb @@ -70,7 +70,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8eb492e01d5640749891b03f83fdc922", + "model_id": "e984dd7e18d74f0090247ab9e8247797", "version_major": 2, "version_minor": 0 }, @@ -93,7 +93,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8be2c032170e4b54b50931f31c1eaefb", + "model_id": "d563eecf55a045dda657a59890e7ec37", "version_major": 2, "version_minor": 0 }, @@ -113,7 +113,7 @@ "...\n", " | AUC | Recall@20 | Train (s) | Test (s)\n", "-- + ------ + --------- + --------- + --------\n", - "MF | 0.8530 | 0.0669 | 1.9428 | 10.6296\n", + "MF | 0.8530 | 0.0669 | 1.9484 | 10.3675\n", "\n" ] } @@ -189,8 +189,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 3min 38s, sys: 19.5 ms, total: 3min 38s\n", - "Wall time: 4.57 s\n" + "CPU times: user 3min 36s, sys: 19.4 ms, total: 3min 36s\n", + "Wall time: 4.51 s\n" ] } ], @@ -233,8 +233,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "HNSWLibANN\t\tIndexing=91ms\t\tRetrieval=212ms\t\tRecall=0.99875\n", - "ScaNNANN\t\tIndexing=106ms\t\tRetrieval=513ms\t\tRecall=0.99998\n" + "HNSWLibANN\t\tIndexing=93ms\t\tRetrieval=211ms\t\tRecall=0.99875\n", + "ScaNNANN\t\tIndexing=99ms\t\tRetrieval=511ms\t\tRecall=0.99998\n" ] } ], @@ -246,7 +246,7 @@ "anns = [\n", " HNSWLibANN(model=mf, M=16, ef_construction=100, ef=50, seed=123, num_threads=-1),\n", " ScaNNANN(\n", - " model=mf, num_neighbors=K,\n", + " model=mf,\n", " partition_params={\"num_leaves\": 100, \"num_leaves_to_search\": 50},\n", " score_params={\"dimensions_per_block\": 2, \"anisotropic_quantization_threshold\": 0.2}, \n", " rescore_params={\"reordering_num_neighbors\": 100},\n", @@ -304,6 +304,14 @@ " )\n", " ) " ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9791a743-5cbe-462c-b3f3-f7909e8476ad", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {