Add custom labels to get topic tree #2125

Open
wants to merge 2 commits into
base: master
43 changes: 40 additions & 3 deletions bertopic/_bertopic.py
@@ -1847,11 +1847,12 @@ def get_representative_docs(self, topic: int = None) -> List[str]:
         else:
             return self.representative_docs_

-    @staticmethod
     def get_topic_tree(
+        self,
         hier_topics: pd.DataFrame,
         max_distance: float = None,
         tight_layout: bool = False,
+        custom_labels: Union[bool, str] = False,
     ) -> str:
         """Extract the topic tree such that it can be printed.

@@ -1862,6 +1863,11 @@ def get_topic_tree(
                           based on the Distance column in `hier_topics`.
             tight_layout: Whether to use a tight layout (narrow width) for
                           easier readability if you have hundreds of topics.
+            custom_labels: If bool, whether to use custom topic labels that were defined using
+                           `topic_model.set_topic_labels`.
+                           If `str`, it uses labels from other aspects, e.g., "Aspect1".
+                           NOTE: Custom labels are only generated for the original
+                           un-merged topics.

         Returns:
             A tree that has the following structure when printed:
@@ -1897,9 +1903,40 @@ def get_topic_tree(

         max_original_topic = hier_topics.Parent_ID.astype(int).min() - 1

+        # Prepare tree labels to print
+        child_left_ids = hier_topics.Child_Left_ID.astype(int)
+        child_right_ids = hier_topics.Child_Right_ID.astype(int)
+
+        # Get the new parent labels generated from `hierarchical_topics`
+        new_left_labels = {int(row["Child_Left_ID"]): row["Child_Left_Name"] for idx, row in hier_topics.iterrows()}
+        new_right_labels = {int(row["Child_Right_ID"]): row["Child_Right_Name"] for idx, row in hier_topics.iterrows()}
+
+        if custom_labels:
+            left_labels = {}
+            if isinstance(custom_labels, str):
+                for topic, kws_info in self.topic_aspects_[custom_labels].items():
+                    label = "_".join([kw[0] for kw in kws_info[:5]]) # displaying top 5 kws
Owner

Note that some topic aspects might not have 5 keywords but actually a single label. In those cases, the kws_info[0] is likely to be filled with a label (or even a summary) and each instance in kws_info[1:] will be an empty string. As a result, you might get the following label: "artificial intelligence____".
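
A minimal sketch of one way to avoid the trailing underscores described in this comment, assuming each entry of `kws_info` is a `(keyword, score)` tuple as the diff implies. `build_label` is a hypothetical helper written for illustration; it is not part of BERTopic or of this PR.

```python
from typing import List, Tuple


def build_label(kws_info: List[Tuple[str, float]], top_n: int = 5) -> str:
    """Join the top keywords with underscores, skipping empty placeholders."""
    keywords = [kw for kw, _ in kws_info[:top_n] if kw]
    return "_".join(keywords)


# An aspect that stores a single label followed by empty strings
single = [("artificial intelligence", 1.0), ("", 0.0), ("", 0.0), ("", 0.0), ("", 0.0)]
# An aspect that stores ordinary keywords
multi = [("car", 0.9), ("engine", 0.8), ("wheel", 0.7)]

print(build_label(single))  # artificial intelligence
print(build_label(multi))   # car_engine_wheel
```

Filtering before joining leaves ordinary keyword aspects unchanged and only affects aspects padded with empty strings.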

+                    left_labels[topic] = label
Comment on lines +1916 to +1919
Owner

Should we add a check for when a user attempts to use a specific aspect but it is not found in self.topic_aspects_? I can imagine that a user (myself included) would mistype the name and the resulting error might not be entirely clear.

Author

This makes sense, I can add a check here. One thing I noticed when looking through how custom labels were handled in the plotting functions is that they all follow a pretty similar pattern:

if isinstance(custom_labels, str):

if isinstance(custom_labels, str):

if isinstance(custom_labels, str):

if isinstance(custom_labels, str):

if isinstance(custom_labels, str):

Maybe a more generic and reusable function can be created for this, which would check if the user-specified aspect exists in self.topic_aspects_.

Owner

You're right, it makes no sense to do that check here but not at all of those other instances. For now, I'm alright either way. I can imagine wanting to stick with what I already use everywhere else for consistency as this might be a bit out of the scope of this particular PR. So it's up to you.
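
For reference, the reusable check discussed in this thread could look roughly like the sketch below. `check_aspect_exists` is a hypothetical name, and passing the aspects dictionary in as an argument (rather than reading `self.topic_aspects_`) is an assumption made so the example runs on its own; nothing here is taken from the BERTopic codebase.

```python
from typing import Any, Dict, Optional


def check_aspect_exists(topic_aspects: Optional[Dict[str, Any]], aspect: str) -> None:
    """Raise a descriptive error when `aspect` is not a known topic aspect."""
    available = sorted(topic_aspects or {})
    if aspect not in available:
        raise ValueError(
            f"Aspect '{aspect}' was not found in `topic_aspects_`. "
            f"Available aspects: {available}"
        )


aspects = {"Main": {}, "Aspect1": {}}
check_aspect_exists(aspects, "Aspect1")  # passes silently

try:
    check_aspect_exists(aspects, "Aspcet1")  # mistyped name
except ValueError as err:
    print(err)
```

Listing the available aspects in the message is what makes a typo easy to spot, compared with the bare `KeyError` that the dictionary lookup would raise.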

+            elif self.custom_labels_ is not None and custom_labels:
+                left_labels = {topic_id: label for topic_id, label in enumerate(self.custom_labels_, -self._outliers)}
Owner

I just wanted to mention that I am going to steal that enumerate trick. That is so much nicer coding than how I have approached it thus far.
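
The trick relies on the optional `start` argument of Python's built-in `enumerate`. A small standalone illustration, where `outliers` is a hypothetical stand-in for `self._outliers` and is assumed to be 1 when an outlier topic `-1` exists:

```python
custom_labels = ["Outliers", "Sports", "Politics"]
outliers = 1  # assumption: 1 if topic -1 is present, otherwise 0

# Start counting at -outliers so the first label maps to topic -1
labels_by_topic = {topic_id: label for topic_id, label in enumerate(custom_labels, -outliers)}
print(labels_by_topic)  # {-1: 'Outliers', 0: 'Sports', 1: 'Politics'}
```

With `outliers = 0` the same expression starts at topic 0, so no separate branch is needed for models without an outlier topic.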


+            right_labels = left_labels.copy()
+
+            # We want to preserve the original labels from `topic_aspects_` or `custom_labels_`
+            # while adding in those generated from `hierarchical_topics`
+            new_left_labels.update(left_labels)
+            new_right_labels.update(right_labels)
+
+            child_left_names = [new_left_labels[topic] for topic in child_left_ids]
+            child_right_names = [new_right_labels[topic] for topic in child_right_ids]
+
+        else:
+            child_left_names = hier_topics.Child_Left_Name
+            child_right_names = hier_topics.Child_Right_Name
+
         # Extract mapping from ID to name
-        topic_to_name = dict(zip(hier_topics.Child_Left_ID, hier_topics.Child_Left_Name))
-        topic_to_name.update(dict(zip(hier_topics.Child_Right_ID, hier_topics.Child_Right_Name)))
+        topic_to_name = dict(zip(child_left_ids.astype(str), child_left_names))
+        topic_to_name.update(dict(zip(child_right_ids.astype(str), child_right_names)))
         topic_to_name = {topic: name[:100] for topic, name in topic_to_name.items()}

         # Create tree