From 2fce29b0e609a52d6668163350149ac574ff67fb Mon Sep 17 00:00:00 2001 From: Steve Date: Thu, 8 Jun 2023 11:10:32 -0600 Subject: [PATCH 1/6] add support for cuml hdbscan membership_vector cuml version 23.04 now has membership_vector, so this allows topic_model.transform() to calculate the probability matrix if using a cuml-based hdbscan model --- bertopic/cluster/_utils.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 355a53f6..656a9c42 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -1,9 +1,11 @@ import hdbscan import numpy as np - -def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): - """ Function used to select the HDBSCAN-like model for generating + +def hdbscan_delegator(model, func: str, + embeddings: np.ndarray = None, + batch_size: int = 4096 ): + """ Function used to select the HDBSCAN-like model for generating predictions and probabilities. 
Arguments: @@ -14,6 +16,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): - "membership_vector" embeddings: Input embeddings for "approximate_predict" and "membership_vector" + batch_size: batch_size for cuml hdbscan """ # Approximate predict @@ -42,7 +45,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): return cuml_hdbscan.all_points_membership_vectors(model) return None - + # membership_vector if func == "membership_vector": if isinstance(model, hdbscan.HDBSCAN): @@ -51,8 +54,15 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: - from cuml.cluster.hdbscan.prediction import approximate_predict - probabilities = approximate_predict(model, embeddings) + from cuml.cluster.hdbscan import prediction + try: + probabilities = prediction.membership_vector( + model, embeddings, + # batch size cannot be larger than the number of docs + batch_size=min(embeddings.shape[0], batch_size)) + # membership_vector available in cuml 23.04 and up + except ImportError: + probabilities = prediction.approximate_predict(model, embeddings) return probabilities return None From 0c7058ca18a36b80b46facb5cb09bd3c3caf64fd Mon Sep 17 00:00:00 2001 From: Steve Date: Fri, 9 Jun 2023 15:50:42 -0600 Subject: [PATCH 2/6] removed batch_size param --- bertopic/cluster/_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 656a9c42..72a73719 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -2,9 +2,7 @@ import numpy as np -def hdbscan_delegator(model, func: str, - embeddings: np.ndarray = None, - batch_size: int = 4096 ): +def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): """ Function used to select the HDBSCAN-like model for generating predictions and probabilities. 
@@ -16,7 +14,6 @@ def hdbscan_delegator(model, func: str, - "membership_vector" embeddings: Input embeddings for "approximate_predict" and "membership_vector" - batch_size: batch_size for cuml hdbscan """ # Approximate predict @@ -59,7 +56,8 @@ def hdbscan_delegator(model, func: str, probabilities = prediction.membership_vector( model, embeddings, # batch size cannot be larger than the number of docs - batch_size=min(embeddings.shape[0], batch_size)) + # this will be unnecessary in cuml 23.08 + batch_size=min(embeddings.shape[0], 4096)) # membership_vector available in cuml 23.04 and up - except ImportError: + except AttributeError: probabilities = prediction.approximate_predict(model, embeddings) From 4e5f0280b5de2395561c229d0ad442d2e7b38e5a Mon Sep 17 00:00:00 2001 From: Steve Date: Fri, 9 Jun 2023 15:51:03 -0600 Subject: [PATCH 3/6] fixed exception handler --- bertopic/cluster/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 72a73719..7054bed3 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -59,7 +59,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): # this will be unnecessary in cuml 23.08 batch_size=min(embeddings.shape[0], 4096)) # membership_vector available in cuml 23.04 and up - except ImportError: + except AttributeError: probabilities = prediction.approximate_predict(model, embeddings) return probabilities From b971bf602ee739499b4ceb54128f17bba0122c6b Mon Sep 17 00:00:00 2001 From: Steve Date: Fri, 9 Jun 2023 15:52:01 -0600 Subject: [PATCH 4/6] fixed return value for earlier versions of cuml --- bertopic/cluster/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 7054bed3..0f174a18 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -60,7 +60,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): 
batch_size=min(embeddings.shape[0], 4096)) # membership_vector available in cuml 23.04 and up except AttributeError: - probabilities = prediction.approximate_predict(model, embeddings) + _, probabilities = prediction.approximate_predict(model, embeddings) return probabilities return None From 9ee5bc59e33879b88eb546c6d6c3b2431574f15d Mon Sep 17 00:00:00 2001 From: Steve Date: Fri, 9 Jun 2023 17:53:32 -0600 Subject: [PATCH 5/6] new tests for cuml models --- tests/conftest.py | 13 +++++++++++++ tests/test_bertopic.py | 11 ++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0418ac5e..2f013b29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,3 +126,16 @@ def online_topic_model(documents, document_embeddings, embedding_model): topics.extend(model.topics_) model.topics_ = topics return model + + +@pytest.fixture(scope="session") +def cuml_base_topic_model(documents, document_embeddings, embedding_model): + from cuml import HDBSCAN as cuml_hdbscan, UMAP as cuml_umap + model = BERTopic(embedding_model=embedding_model, + calculate_probabilities=True, + umap_model=cuml_umap(random_state=42), + hdbscan_model=cuml_hdbscan( + min_cluster_size=3, + prediction_data=True)) + model.fit(documents, document_embeddings) + return model diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index be5904e7..64559bdc 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -3,7 +3,7 @@ from bertopic import BERTopic -@pytest.mark.parametrize('model', [("base_topic_model"), ('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model'), ('supervised_topic_model'), ('representation_topic_model')]) +@pytest.mark.parametrize('model', [("base_topic_model"), ('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model'), ('supervised_topic_model'), ('representation_topic_model'), 
('cuml_base_topic_model')]) def test_full_model(model, documents, request): """ Tests the entire pipeline in one go. This serves as a sanity check to see if the default settings result in a good separation of topics. @@ -110,3 +110,12 @@ def test_full_model(model, documents, request): # if topic_model.topic_embeddings_ is not None: # topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True) # loaded_model = BERTopic.load("model_dir") + +def test_cuml(cuml_base_topic_model, documents, request, monkeypatch): + """Specific tests for cuml-based models.""" + + # make sure calculating probabilities does not fail if the cuml version + # does not yet support membership_vector (cuml 23.04 and higher) + with monkeypatch.context() as m: + m.delattr('cuml.cluster.hdbscan.prediction.membership_vector', raising=False) + predictions, probabilities = cuml_base_topic_model.transform(documents) From 74335fa690944976dd595c062a19cbdf0d777c84 Mon Sep 17 00:00:00 2001 From: Steve Date: Sat, 10 Jun 2023 07:07:03 -0600 Subject: [PATCH 6/6] fix tests if cuml not installed --- tests/conftest.py | 21 ++++++++++++--------- tests/test_bertopic.py | 6 ++++++ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2f013b29..57ab7223 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -130,12 +130,15 @@ def online_topic_model(documents, document_embeddings, embedding_model): @pytest.fixture(scope="session") def cuml_base_topic_model(documents, document_embeddings, embedding_model): - from cuml import HDBSCAN as cuml_hdbscan, UMAP as cuml_umap - model = BERTopic(embedding_model=embedding_model, - calculate_probabilities=True, - umap_model=cuml_umap(random_state=42), - hdbscan_model=cuml_hdbscan( - min_cluster_size=3, - prediction_data=True)) - model.fit(documents, document_embeddings) - return model + try: + from cuml import HDBSCAN as cuml_hdbscan, UMAP as cuml_umap + model = BERTopic(embedding_model=embedding_model, + 
calculate_probabilities=True, + umap_model=cuml_umap(random_state=42), + hdbscan_model=cuml_hdbscan( + min_cluster_size=3, + prediction_data=True)) + model.fit(documents, document_embeddings) + return model + except ModuleNotFoundError: + return None diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index 64559bdc..8b5952eb 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -11,6 +11,9 @@ def test_full_model(model, documents, request): NOTE: This does not cover all cases but merely combines it all together """ topic_model = copy.deepcopy(request.getfixturevalue(model)) + if model == 'cuml_base_topic_model' and topic_model is None: + # cuml not installed, can't run test + return if model == "base_topic_model": topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True, save_embedding_model="sentence-transformers/all-MiniLM-L6-v2") topic_model = BERTopic.load("model_dir") @@ -114,6 +117,9 @@ def test_full_model(model, documents, request): def test_cuml(cuml_base_topic_model, documents, request, monkeypatch): """Specific tests for cuml-based models.""" + if cuml_base_topic_model is None: + # cuml not installed, can't run test + return # make sure calculating probabilities does not fail if the cuml version # does not yet support membership_vector (cuml 23.04 and higher) with monkeypatch.context() as m: