From 6a5a78d7459dcea63bd346c53543690075465b20 Mon Sep 17 00:00:00 2001 From: James Tauber Date: Wed, 7 Feb 2024 05:37:53 -0500 Subject: [PATCH] method for calculating term frequencies fixes #28 --- README.md | 18 +++++++++++++++++- termdoc/htdm.py | 3 +++ tests.py | 13 +++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1ac2171..2cf3ddc 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,6 @@ Entire lists of tokens can be added for a particular address in one go using `ad ``` - You can **prune** a HTDM to just `n` levels with the method `prune(n)`. You can iterate over the document-term counts at the leaves of the HTDM with the method `leaf_entries()` (this returns a generator yielding `(document_address, term, count)` tuples). This is effectively a traditional TDM (the document IDs will still reflect the hierarchy but the aggregate counts aren't present). @@ -133,6 +132,23 @@ You can deep copy an HTDM with `copy()`. You can also pass a prefix to `copy()` You can save out an HTDM with `save()` which takes a `filename` and optional `field_sep` (defaulting to tab) and `prefix` (if you just want to save out a subtree). +### Calculations + +You can get a term frequency with `tf(term)` or `tf(term, address)`. + +```python +>>> c = termdoc.HTDM() +>>> c.increment_count("1", "foo") +>>> c.increment_count("1", "bar", 3) +>>> c.increment_count("2", "foo", 3) +>>> c.increment_count("2", "bar") +>>> c.tf("foo") +0.5 +>>> c.tf("foo", "2") +0.75 + +``` + ### Duplicates Policy You can optionally pass in a `duplicates` setting to the constructor indicating the policy you want to follow if a term-document count is updated more than once. diff --git a/termdoc/htdm.py b/termdoc/htdm.py index de6ef2e..5354b60 100644 --- a/termdoc/htdm.py +++ b/termdoc/htdm.py @@ -97,3 +97,6 @@ def copy(self, prefix=None): for document, term, count in self.leaf_entries(prefix): new.increment_count(document, term, count) return new + + def tf(self, term, address=""): + return self.get_counts(address)[term] / self.get_counts(address).total() diff --git a/tests.py b/tests.py index 9f0cbc5..04eb510 100755 --- a/tests.py +++ b/tests.py @@ -347,5 +347,18 @@ def test_add(self): self.assertEqual(c.get_counts("1")["bar"], 2) +class Test4(unittest.TestCase): + def test_term_frequency(self): + import termdoc + + c = termdoc.HTDM() + c.increment_count("1", "foo") + c.increment_count("1", "bar", 3) + c.increment_count("2", "foo", 3) + c.increment_count("2", "bar") + self.assertEqual(c.tf("foo"), 0.5) + self.assertEqual(c.tf("foo", "2"), 0.75) + + if __name__ == "__main__": unittest.main()