add support for cuml hdbscan membership_vector #1324

Open · wants to merge 6 commits into master
18 changes: 13 additions & 5 deletions bertopic/cluster/_utils.py
@@ -1,9 +1,9 @@
import hdbscan
import numpy as np


def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
""" Function used to select the HDBSCAN-like model for generating
""" Function used to select the HDBSCAN-like model for generating
predictions and probabilities.

Arguments:
@@ -42,7 +42,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
            return cuml_hdbscan.all_points_membership_vectors(model)

        return None

    # membership_vector
    if func == "membership_vector":
        if isinstance(model, hdbscan.HDBSCAN):
@@ -51,8 +51,16 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):

        str_type_model = str(type(model)).lower()
        if "cuml" in str_type_model and "hdbscan" in str_type_model:
-            from cuml.cluster.hdbscan.prediction import approximate_predict
-            probabilities = approximate_predict(model, embeddings)
+            from cuml.cluster.hdbscan import prediction
+            try:
+                probabilities = prediction.membership_vector(
+                    model, embeddings,
+                    # batch size cannot be larger than the number of docs
+                    # this will be unnecessary in cuml 23.08
+                    batch_size=min(embeddings.shape[0], 4096))
+            # membership_vector available in cuml 23.04 and up
+            except AttributeError:
+                _, probabilities = prediction.approximate_predict(model, embeddings)
            return probabilities

        return None
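For readers unfamiliar with the cuml side of this change, here is a minimal standalone sketch of the fallback pattern the patched delegator relies on. It assumes cuml and a GPU are available; the toy data and variable names are illustrative and not part of the PR.

import numpy as np
from sklearn.datasets import make_blobs
from cuml.cluster import HDBSCAN as cuml_HDBSCAN
from cuml.cluster.hdbscan import prediction

# Fit a small cuML HDBSCAN model with prediction data enabled (toy blobs).
points, _ = make_blobs(n_samples=200, centers=4, n_features=5, random_state=42)
clusterer = cuml_HDBSCAN(min_cluster_size=10, prediction_data=True).fit(points.astype(np.float32))

# Soft-cluster some new points.
new_points, _ = make_blobs(n_samples=10, centers=4, n_features=5, random_state=0)
new_points = new_points.astype(np.float32)
try:
    # cuml >= 23.04 exposes membership_vector; the batch size must not exceed
    # the number of points being predicted (a limitation until cuml 23.08).
    probabilities = prediction.membership_vector(
        clusterer, new_points, batch_size=min(new_points.shape[0], 4096))
except AttributeError:
    # Older cuml versions: fall back to hard assignments and their probabilities.
    _, probabilities = prediction.approximate_predict(clusterer, new_points)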
16 changes: 16 additions & 0 deletions tests/conftest.py
@@ -126,3 +126,19 @@ def online_topic_model(documents, document_embeddings, embedding_model):
    topics.extend(model.topics_)
    model.topics_ = topics
    return model


@pytest.fixture(scope="session")
def cuml_base_topic_model(documents, document_embeddings, embedding_model):
    try:
        from cuml import HDBSCAN as cuml_hdbscan, UMAP as cuml_umap
        model = BERTopic(embedding_model=embedding_model,
                         calculate_probabilities=True,
                         umap_model=cuml_umap(random_state=42),
                         hdbscan_model=cuml_hdbscan(
                             min_cluster_size=3,
                             prediction_data=True))
        model.fit(documents, document_embeddings)
        return model
    except ModuleNotFoundError:
        return None
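As a usage note, a rough sketch of how a model configured like this fixture might be built and queried outside the test suite. It assumes cuml and a GPU are available and that BERTopic's default embedding model can be downloaded; the 20 newsgroups sample and the min_cluster_size value are only illustrative choices.

from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic
from cuml import HDBSCAN as cuml_hdbscan, UMAP as cuml_umap

docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'][:500]

# As in the fixture: cuML backends plus calculate_probabilities=True so that
# transform() also returns topic probabilities per document.
topic_model = BERTopic(calculate_probabilities=True,
                       umap_model=cuml_umap(random_state=42),
                       hdbscan_model=cuml_hdbscan(min_cluster_size=10, prediction_data=True))
topic_model.fit(docs)
topics, probabilities = topic_model.transform(docs)
# probabilities contains one entry per document (a full topic distribution on cuml >= 23.04)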
17 changes: 16 additions & 1 deletion tests/test_bertopic.py
@@ -3,14 +3,17 @@
from bertopic import BERTopic


-@pytest.mark.parametrize('model', [("base_topic_model"), ('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model'), ('supervised_topic_model'), ('representation_topic_model')])
+@pytest.mark.parametrize('model', [("base_topic_model"), ('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model'), ('supervised_topic_model'), ('representation_topic_model'), ('cuml_base_topic_model')])
def test_full_model(model, documents, request):
    """ Tests the entire pipeline in one go. This serves as a sanity check to see if the default
    settings result in a good separation of topics.

    NOTE: This does not cover all cases but merely combines it all together
    """
    topic_model = copy.deepcopy(request.getfixturevalue(model))
    if model == 'cuml_base_topic_model' and topic_model is None:
        # cuml not installed, can't run test
        return
    if model == "base_topic_model":
        topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True, save_embedding_model="sentence-transformers/all-MiniLM-L6-v2")
        topic_model = BERTopic.load("model_dir")
@@ -110,3 +113,15 @@ def test_full_model(model, documents, request):
    # if topic_model.topic_embeddings_ is not None:
    #     topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True)
    #     loaded_model = BERTopic.load("model_dir")

def test_cuml(cuml_base_topic_model, documents, request, monkeypatch):
    """Specific tests for cuml-based models."""

    if cuml_base_topic_model is None:
        # cuml not installed, can't run test
        return
    # make sure calculating probabilities does not fail if the cuml version
    # does not yet support membership_vector (cuml 23.04 and higher)
    with monkeypatch.context() as m:
        m.delattr('cuml.cluster.hdbscan.prediction.membership_vector', raising=False)
        predictions, probabilities = cuml_base_topic_model.transform(documents)
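One extra check that could be added here, sketched below and not part of the PR, is to assert that the fallback path still returns one prediction and one probability entry per document:

        # Illustrative assertions; these would sit inside the monkeypatch.context() block above:
        assert len(predictions) == len(documents)
        assert len(probabilities) == len(documents)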