Skip to content

Commit

Permalink
add support for cuml hdbscan membership_vector
Browse files Browse the repository at this point in the history
cuml version 23.04 now has membership_vector, so this allows
topic_model.transform() to calculate the probability matrix if
using a cuml-based hdbscan model
  • Loading branch information
stevetracvc committed Jun 8, 2023
1 parent fca5a4f commit 2fce29b
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions bertopic/cluster/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import hdbscan
import numpy as np


def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
""" Function used to select the HDBSCAN-like model for generating

def hdbscan_delegator(model, func: str,
embeddings: np.ndarray = None,
batch_size: int = 4096 ):
""" Function used to select the HDBSCAN-like model for generating
predictions and probabilities.
Arguments:
Expand All @@ -14,6 +16,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
- "membership_vector"
embeddings: Input embeddings for "approximate_predict"
and "membership_vector"
batch_size: batch_size for cuml hdbscan
"""

# Approximate predict
Expand Down Expand Up @@ -42,7 +45,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
return cuml_hdbscan.all_points_membership_vectors(model)

return None

# membership_vector
if func == "membership_vector":
if isinstance(model, hdbscan.HDBSCAN):
Expand All @@ -51,8 +54,15 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):

str_type_model = str(type(model)).lower()
if "cuml" in str_type_model and "hdbscan" in str_type_model:
from cuml.cluster.hdbscan.prediction import approximate_predict
probabilities = approximate_predict(model, embeddings)
from cuml.cluster.hdbscan import prediction
try:
probabilities = prediction.membership_vector(
model, embeddings,
# bacth size cannot be larger than the number of docs
batch_size=min(embeddings.shape[0], batch_size))
# membership_vector available in cuml 23.04 and up
except ImportError:
probabilities = prediction.approximate_predict(model, embeddings)
return probabilities

return None
Expand Down

0 comments on commit 2fce29b

Please sign in to comment.