diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 355a53f6..656a9c42 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -1,9 +1,11 @@ import hdbscan import numpy as np - -def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): - """ Function used to select the HDBSCAN-like model for generating + +def hdbscan_delegator(model, func: str, + embeddings: np.ndarray = None, + batch_size: int = 4096): + """ Function used to select the HDBSCAN-like model for generating predictions and probabilities. Arguments: @@ -14,6 +16,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): - "membership_vector" embeddings: Input embeddings for "approximate_predict" and "membership_vector" + batch_size: Maximum batch size passed to cuML's "membership_vector" (capped at the number of embeddings) """ # Approximate predict @@ -42,7 +45,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): return cuml_hdbscan.all_points_membership_vectors(model) return None - + # membership_vector if func == "membership_vector": if isinstance(model, hdbscan.HDBSCAN): @@ -51,8 +54,15 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: - from cuml.cluster.hdbscan.prediction import approximate_predict - probabilities = approximate_predict(model, embeddings) + from cuml.cluster.hdbscan import prediction + try: + probabilities = prediction.membership_vector( + model, embeddings, + # batch size cannot be larger than the number of docs + batch_size=min(embeddings.shape[0], batch_size)) + # membership_vector only exists in cuml 23.04 and up; on older versions the attribute lookup raises AttributeError + except AttributeError: + _, probabilities = prediction.approximate_predict(model, embeddings)  # returns (labels, probabilities) return probabilities return None