From 13fae6850dadb04b82d3792cbc64f065dc4117a7 Mon Sep 17 00:00:00 2001 From: Kenny Lov - MBP Date: Mon, 19 Aug 2024 01:55:03 -0700 Subject: [PATCH 1/2] Adding option for custom label in get_topic_tree --- bertopic/_bertopic.py | 47 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 171136df..70024b5e 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -1847,11 +1847,12 @@ def get_representative_docs(self, topic: int = None) -> List[str]: else: return self.representative_docs_ - @staticmethod def get_topic_tree( + self, hier_topics: pd.DataFrame, max_distance: float = None, tight_layout: bool = False, + custom_labels: Union[bool, str] = False ) -> str: """Extract the topic tree such that it can be printed. @@ -1862,6 +1863,11 @@ def get_topic_tree( based on the Distance column in `hier_topics`. tight_layout: Whether to use a tight layout (narrow width) for easier readability if you have hundreds of topics. + custom_labels: If bool, whether to use custom topic labels that were defined using + `topic_model.set_topic_labels`. + If `str`, it uses labels from other aspects, e.g., "Aspect1". + NOTE: Custom labels are only generated for the original + un-merged topics. Returns: A tree that has the following structure when printed: @@ -1897,9 +1903,44 @@ def get_topic_tree( max_original_topic = hier_topics.Parent_ID.astype(int).min() - 1 + # Prepare tree labels to print + child_left_ids = hier_topics.Child_Left_ID.astype(int) + child_right_ids = hier_topics.Child_Right_ID.astype(int) + + # Get the new parent labels generated from `hierarchical_topics` + new_left_labels = {int(row['Child_Left_ID']): row['Child_Left_Name'] for idx, row in hier_topics.iterrows()} + new_right_labels = {int(row['Child_Right_ID']): row['Child_Right_Name'] for idx, row in hier_topics.iterrows()} + + if custom_labels: + left_labels = {} + if isinstance(custom_labels, str): + for topic, kws_info in self.topic_aspects_[custom_labels].items(): + label = '_'.join([kw[0] for kw in kws_info[:5]]) # displaying top 5 kws + left_labels[topic] = label + elif self.custom_labels_ is not None and custom_labels: + left_labels = {topic_id: label for topic_id, label in enumerate(self.custom_labels_, -self._outliers)} + + right_labels = left_labels.copy() + + # We want to preserve the original labels from `topic_aspects_` or `custom_labels_` + # while adding in those generated from `hierarchical_topics` + new_left_labels.update(left_labels) + new_right_labels.update(right_labels) + + child_left_names = [ + new_left_labels[topic] for topic in child_left_ids + ] + child_right_names = [ + new_right_labels[topic] for topic in child_right_ids + ] + + else: + child_left_names = hier_topics.Child_Left_Name + child_right_names = hier_topics.Child_Right_Name + # Extract mapping from ID to name - topic_to_name = dict(zip(hier_topics.Child_Left_ID, hier_topics.Child_Left_Name)) - topic_to_name.update(dict(zip(hier_topics.Child_Right_ID, hier_topics.Child_Right_Name))) + topic_to_name = dict(zip(child_left_ids.astype(str), child_left_names)) + topic_to_name.update(dict(zip(child_right_ids.astype(str), child_right_names))) topic_to_name = {topic: name[:100] for topic, name in topic_to_name.items()} # Create tree From 0d03f5135ac86fb45b0bf035c799dc14cb13102c Mon Sep 17 00:00:00 2001 From: Kenny Lov Date: Mon, 19 Aug 2024 09:16:21 -0700 Subject: [PATCH 2/2] Applying ruff formatter --- bertopic/_bertopic.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 70024b5e..a2bd5119 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -1852,7 +1852,7 @@ def get_topic_tree( hier_topics: pd.DataFrame, max_distance: float = None, tight_layout: bool = False, - custom_labels: Union[bool, str] = False + custom_labels: Union[bool, str] = False, ) -> str: """Extract the topic tree such that it can be printed. @@ -1908,14 +1908,14 @@ def get_topic_tree( child_right_ids = hier_topics.Child_Right_ID.astype(int) # Get the new parent labels generated from `hierarchical_topics` - new_left_labels = {int(row['Child_Left_ID']): row['Child_Left_Name'] for idx, row in hier_topics.iterrows()} - new_right_labels = {int(row['Child_Right_ID']): row['Child_Right_Name'] for idx, row in hier_topics.iterrows()} + new_left_labels = {int(row["Child_Left_ID"]): row["Child_Left_Name"] for idx, row in hier_topics.iterrows()} + new_right_labels = {int(row["Child_Right_ID"]): row["Child_Right_Name"] for idx, row in hier_topics.iterrows()} if custom_labels: left_labels = {} if isinstance(custom_labels, str): for topic, kws_info in self.topic_aspects_[custom_labels].items(): - label = '_'.join([kw[0] for kw in kws_info[:5]]) # displaying top 5 kws + label = "_".join([kw[0] for kw in kws_info[:5]]) # displaying top 5 kws left_labels[topic] = label elif self.custom_labels_ is not None and custom_labels: left_labels = {topic_id: label for topic_id, label in enumerate(self.custom_labels_, -self._outliers)} @@ -1927,12 +1927,8 @@ def get_topic_tree( new_left_labels.update(left_labels) new_right_labels.update(right_labels) - child_left_names = [ - new_left_labels[topic] for topic in child_left_ids - ] - child_right_names = [ - new_right_labels[topic] for topic in child_right_ids - ] + child_left_names = [new_left_labels[topic] for topic in child_left_ids] + child_right_names = [new_right_labels[topic] for topic in child_right_ids] else: child_left_names = hier_topics.Child_Left_Name