Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Clustering to Flair #2573

Merged
merged 38 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
aa0f52e
init
OatsProduction Dec 28, 2021
0d1ea74
formated the code
OatsProduction Dec 29, 2021
5cee85a
all filenames are now lowercase
OatsProduction Dec 29, 2021
2c213ea
added convergence to em
OatsProduction Dec 29, 2021
b6ade23
finished the kMeans algorithm
OatsProduction Dec 30, 2021
f56992f
WIP cf tree
OatsProduction Dec 30, 2021
b37555d
working global parameters
OatsProduction Jan 2, 2022
61b56e1
added changed from GitHub PR
OatsProduction Jan 2, 2022
f44dbc6
Merge branch 'flairNLP:master' into master
OatsProduction Jan 3, 2022
e4acd0b
working BIRCH
OatsProduction Jan 4, 2022
f443b7a
Merge remote-tracking branch 'origin/master'
OatsProduction Jan 4, 2022
3616b10
working EM clustering
OatsProduction Jan 8, 2022
8f1666e
Merge branch 'flairNLP:master' into master
OatsProduction Jan 8, 2022
bf80ad5
removed experiments files
OatsProduction Jan 9, 2022
f82698d
written a tutorial readMe file
OatsProduction Jan 9, 2022
a27bf94
remove test.py
whoisjones Jan 10, 2022
3331784
folder restructuring
whoisjones Jan 10, 2022
8f19db8
move clustering to docs
whoisjones Jan 10, 2022
4cf3719
change folder structure
whoisjones Jan 10, 2022
f831e50
further folder refactorings
whoisjones Jan 11, 2022
f4cafe1
kmeans refactoring.
whoisjones Jan 11, 2022
14095d9
kmeans save and load functions
whoisjones Jan 11, 2022
734a5a7
kmeans tutorial
whoisjones Jan 11, 2022
d1d1952
transform corpora for clustering function
whoisjones Jan 11, 2022
964dea5
PR for reusing sklearn clustering methods
whoisjones Jan 11, 2022
daed11c
Merge pull request #1 from whoisjones/clustering_refactorings
OatsProduction Jan 12, 2022
b57529a
added predict to clustering
Jan 12, 2022
68cef6e
added saving and loading of the cluster model
OatsProduction Jan 13, 2022
1cf2c9a
labels added to the sentences
OatsProduction Jan 13, 2022
d230535
added corpus for StackOverflow data
OatsProduction Jan 15, 2022
cd8edcb
working corpus for STACKOVERFLOW
OatsProduction Jan 18, 2022
bff7484
WIP loading and saving
Jan 19, 2022
28635a2
added evaluation data
OatsProduction Jan 20, 2022
f72d757
working loading/saving of the ClusteringModel
OatsProduction Jan 20, 2022
d2f4695
Merge branch 'master' into master
alanakbik Jan 26, 2022
4817e70
fixed the TUTORIAL_12_CLUSTERING.md
OatsProduction Feb 2, 2022
1df7834
Merge remote-tracking branch 'origin/master'
OatsProduction Feb 2, 2022
d062138
added fix for the memory mode
OatsProduction Feb 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions flair/models/clustering/Clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod


class Clustering(ABC):
    """Abstract base class for the clustering algorithms in this package."""

    @abstractmethod
    def cluster(self, vectors: list) -> list:
        """Cluster ``vectors``; concrete subclasses provide the algorithm."""

    def getLabelList(self, listSenctence) -> list:
        """Return the integer ``'cluster'`` label of every sentence.

        For each element of ``listSenctence`` the first label returned by
        ``get_labels('cluster')`` is read and its ``value`` converted with
        ``int``. The order of the input is preserved.
        """
        return [
            int(sentence.get_labels('cluster')[0].value)
            for sentence in listSenctence
        ]
41 changes: 41 additions & 0 deletions flair/models/clustering/Evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import accuracy_score, normalized_mutual_info_score


def getStackOverFlowLabels():
    """Return the raw lines of the StackOverflow label file.

    The previous implementation opened ``title_StackOverflow.txt`` and so
    returned exactly the same lines as ``getStackOverFlowData``.

    NOTE(review): the label file name follows the naming of the title file
    (``label_StackOverflow.txt``); confirm it matches the file shipped in
    ``evaluation/StackOverflow``.

    Returns:
        One string per line of the file, line terminators included.
    """
    with open("evaluation/StackOverflow/label_StackOverflow.txt", "r", encoding="utf8") as myfile:
        return myfile.readlines()


def getStackOverFlowData():
    """Return the raw lines of the StackOverflow title file.

    Returns:
        One string per line of the file, line terminators included.
    """
    titlesPath = "evaluation/StackOverflow/title_StackOverflow.txt"
    with open(titlesPath, mode="r", encoding="utf8") as handle:
        return handle.readlines()


# Exclusive upper bound of the slice applied to the 20 Newsgroups data and
# targets in get20NewsData() / get20NewsLabel().
maxDocuments = 400
# Newsgroup names passed as ``categories=`` to fetch_20newsgroups().
categories = [
'rec.motorcycles',
'rec.sport.baseball',
'comp.graphics',
'sci.space',
'talk.politics.mideast'
]


def get20NewsData():
    """Return documents ``1 .. maxDocuments - 1`` of the selected newsgroups.

    NOTE(review): the slice starts at index 1, so the first fetched document
    is skipped; get20NewsLabel() uses the same slice, keeping both aligned.
    """
    documents = fetch_20newsgroups(categories=categories).data
    return documents[1:maxDocuments]


def get20NewsLabel():
    """Return targets ``1 .. maxDocuments - 1`` of the selected newsgroups.

    NOTE(review): the slice starts at index 1, so the first fetched target
    is skipped; get20NewsData() uses the same slice, keeping both aligned.
    """
    targets = fetch_20newsgroups(categories=categories).target
    return targets[1:maxDocuments]


def evaluate(labels: list, predict_labels: list):
    """Print accuracy and normalized mutual information of a prediction.

    Args:
        labels: reference labels.
        predict_labels: predicted labels, in the same order as ``labels``.
    """
    accuracy = accuracy_score(labels, predict_labels)
    mutualInfo = normalized_mutual_info_score(labels, predict_labels)
    print(f"ACC: {accuracy!s}")
    print(f"NMI: {mutualInfo!s}")
51 changes: 51 additions & 0 deletions flair/models/clustering/birch/Birch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from flair.embeddings import DocumentEmbeddings

from Clustering import Clustering
from birch.model.CfTree import CfTree
from birch.model.ClusteringFeature import ClusteringFeature
from flair.datasets import DataLoader

from kmeans.K_Means import KMeans

# Module-level BIRCH settings. Birch.__init__ overwrites threshold and both
# branching factors; other modules read them as attributes of this module.
# Set from the ``B`` constructor argument (not read in the code visible here).
branchingFactorNonLeaf = 0
# Set from the ``L`` constructor argument; LeafNode.canAddNewCf compares the
# number of CFs in a leaf against it.
branchingFactorLeaf = 0
# Initial "no candidate yet" distance in LeafNode.getClosestCF; an empty
# ClusteringFeature reports ``distanceMax - 100`` as its distance.
distanceMax = 1000000000
# Set from the ``thresholds`` constructor argument; ClusteringFeature.canAbsorbCf
# absorbs a CF when the cosine distance is <= this value.
threshold = 0


class Birch(Clustering):
    """BIRCH clustering: build a CF tree, then run k-means on the leaf CFs."""

    def __init__(self, thresholds: float, embeddings: DocumentEmbeddings, B: int, L: int,
                 numClusters: int = 3):
        """Configure the algorithm.

        Args:
            thresholds: stored in the module-level ``threshold``.
            embeddings: embedding object whose ``embed`` is called per batch.
            B: stored in the module-level ``branchingFactorNonLeaf``.
            L: stored in the module-level ``branchingFactorLeaf``.
            numClusters: argument passed to ``KMeans`` in the final step;
                defaults to 3, the value that used to be hard-coded.
        """
        # These settings live at module level because ClusteringFeature and
        # LeafNode read them as attributes of this module.
        global threshold, branchingFactorLeaf, branchingFactorNonLeaf
        threshold = thresholds
        branchingFactorLeaf = L
        branchingFactorNonLeaf = B

        self.embeddings = embeddings
        self.numClusters = numClusters
        self.cfTree = CfTree()
        self.predict = []

    def cluster(self, vectors: list, batchSize: int = 64):
        """Embed ``vectors``, insert them into the CF tree and label them.

        After the call ``self.predict[i]`` holds the k-means label assigned
        to the leaf CF that contains ``vectors[i]``.

        Args:
            vectors: items to cluster; each must expose ``.embedding`` once
                ``self.embeddings.embed`` has processed it.
            batchSize: batch size handed to the ``DataLoader``.

        Returns:
            The CF tree built during the run.
        """
        print("Starting BIRCH clustering with threshold: " + str(threshold))
        self.predict = [0] * len(vectors)

        # Nothing to insert: the tree would only hold its initial empty CF,
        # whose center cannot be computed.
        if not vectors:
            return self.cfTree

        for batch in DataLoader(vectors, batch_size=batchSize):
            self.embeddings.embed(batch)

        for idx, vector in enumerate(vectors):
            self.cfTree.insertCf(ClusteringFeature(vector.embedding, idx=idx))
            self.cfTree.validate()

        cfs = self.cfTree.getLeafCfs()
        cfVectors = self.cfTree.getVectorsFromCf(cfs)

        kMeans = KMeans(self.numClusters)
        kMeans.clusterVectors(cfVectors)

        # Every input index recorded in a leaf CF inherits that CF's label.
        for idx, cf in enumerate(cfs):
            label = kMeans.predict[idx]
            for cfIndex in cf.indices:
                self.predict[cfIndex] = label

        return self.cfTree
Empty file.
16 changes: 16 additions & 0 deletions flair/models/clustering/birch/model/CfNode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from birch.model.ClusteringFeature import ClusteringFeature


class CfNode:
    """Common base of the CF-tree nodes: a list of CFs plus a parent link."""

    def __init__(self):
        self.parent = None
        self.isLeaf = False
        self.cfs = []

    def sumAllCfs(self) -> "ClusteringFeature":
        """Return a new ClusteringFeature that has absorbed every CF of this node."""
        total = ClusteringFeature()
        for entry in self.cfs:
            total.absorbCf(entry)
        return total
172 changes: 172 additions & 0 deletions flair/models/clustering/birch/model/CfTree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import numpy as np
from birch.model import LeafNode
from birch.model.CfNode import CfNode
from birch.model.ClusteringFeature import ClusteringFeature
from birch.model.NonLeafNode import NonLeafNode
from distance import Distance


class CfTree:
    """CF tree: non-leaf nodes summarise their children, leaves hold the CFs.

    Leaves are additionally chained through ``prev`` / ``next`` so that all
    leaf CFs can be collected without walking the tree.
    """

    def __init__(self):
        self.root = NonLeafNode()
        # Head of the leaf chain walked by getLeafList().
        self.firstChild = self.root.entries[0]

    def insertCf(self, cf: "ClusteringFeature"):
        """Insert ``cf``: absorb it, append it to the closest leaf, or split."""
        leaf = self.getClosestLeaf(cf, self.root)
        cf_node = leaf.getClosestCF(cf)

        if cf_node.canAbsorbCf(cf):
            cf_node.absorbCf(cf)
            self.updatePathSimple(leaf)
            return
        if leaf.canAddNewCf():
            leaf.addCF(cf)
            self.updatePathSimple(leaf)
        else:
            newLeaf = self.splitLeaf(leaf, cf)
            self.updatePathWithNewLeaf(newLeaf)

    def splitLeaf(self, leaf: "LeafNode", cf: "ClusteringFeature") -> "LeafNode":
        """Split ``leaf`` (after adding ``cf``) around its two furthest CFs.

        ``leaf`` keeps the CFs closer to the first seed and its summary in the
        parent is refreshed; the remaining CFs go to a new leaf that is
        returned. The new leaf is NOT yet registered in the parent.
        """
        leaf.cfs.append(cf)
        indices = Distance.getFurthest2Points(leaf.cfs)
        oldSeed = leaf.cfs[indices[0]]
        newSeed = leaf.cfs[indices[1]]
        oldCf = [oldSeed]
        newCf = [newSeed]

        for entry in leaf.cfs:
            if entry is oldSeed or entry is newSeed:
                continue
            if entry.calcualteDistance(oldSeed) < entry.calcualteDistance(newSeed):
                oldCf.append(entry)
            else:
                newCf.append(entry)

        index = leaf.parent.getChildIndex(leaf)
        leaf.cfs = oldCf
        leaf.parent.cfs[index] = leaf.sumAllCfs()

        newLeaf = LeafNode.LeafNode(newCf, parent=leaf.parent)
        # Splice the new leaf into the chain right after ``leaf``. The old
        # code pointed ``newLeaf.prev`` at itself and dropped ``leaf.next``,
        # which cut every following leaf out of getLeafList().
        following = leaf.next
        newLeaf.prev = leaf
        newLeaf.next = following
        if following is not None:
            following.prev = newLeaf
        leaf.next = newLeaf

        return newLeaf

    def updatePathSimple(self, child: "LeafNode"):
        """Refresh the summary CF of ``child`` in every ancestor up to the root."""
        parent = child.parent

        while parent is not None:
            idx = parent.getChildIndex(child)
            parent.cfs[idx] = child.sumAllCfs()
            child = parent
            parent = parent.parent

    def updatePathWithNewLeaf(self, newLeaf: "LeafNode"):
        """Register ``newLeaf`` in its parent, splitting the parent if full."""
        if newLeaf.parent.canAddNode():
            newLeaf.parent.addNode(newLeaf)
        else:
            self.splitNonLeafNode(newLeaf)
        # Ancestors above the direct parent still carry summaries from before
        # the insertion; walk the path once to bring them up to date.
        self.updatePathSimple(newLeaf)

    def splitNonLeafNode(self, node: "CfNode"):
        """Add ``node`` to its (full) parent and split that parent in two.

        When ``node`` has no parent, ``node`` itself is split. If the split
        node was the root a new root is created; otherwise the new sibling is
        added to the grandparent, recursing when that one is full as well.
        """
        if node.parent is not None:
            node.parent.addNode(node)
            nonLeafNode = node.parent
        else:
            nonLeafNode = node

        nodeCfs = nonLeafNode.cfs
        nodeEntries = nonLeafNode.entries
        indices = Distance.getFurthest2Points(nodeCfs)
        oldSeed = nodeCfs[indices[0]]
        newSeed = nodeCfs[indices[1]]
        # Positions (into nodeCfs / nodeEntries) assigned to each half.
        oldCf = [indices[0]]
        newCf = [indices[1]]

        for idx, cf in enumerate(nodeCfs):
            if cf is oldSeed or cf is newSeed:
                continue
            if cf.calcualteDistance(oldSeed) < cf.calcualteDistance(newSeed):
                oldCf.append(idx)
            else:
                newCf.append(idx)

        newNode = NonLeafNode()
        newNode.cfs = [nodeCfs[i] for i in newCf]
        newNode.entries = [nodeEntries[i] for i in newCf]
        for item in newNode.entries:
            item.parent = newNode

        nonLeafNode.cfs = [nodeCfs[i] for i in oldCf]
        nonLeafNode.entries = [nodeEntries[i] for i in oldCf]
        for item in nonLeafNode.entries:
            item.parent = nonLeafNode

        parent = nonLeafNode.parent
        if parent is None:
            self.root = NonLeafNode()
            self.root.entries = []
            self.root.cfs = []
            self.root.addNode(nonLeafNode)
            self.root.addNode(newNode)
            # NOTE(review): set explicitly in case addNode does not maintain
            # the parent link; harmless if it already does.
            nonLeafNode.parent = self.root
            newNode.parent = self.root
            print("new Height -> new root")
            return

        # The parent's summary for the shrunken node still covers the
        # children that moved to newNode; refresh it (as splitLeaf does).
        parent.cfs[parent.getChildIndex(nonLeafNode)] = nonLeafNode.sumAllCfs()

        if parent.canAddNode():
            print("add Node")
            parent.addNode(newNode)
            newNode.parent = parent
        else:
            print("split again ")
            # Recurse with the NEW node so that it is attached to the full
            # parent before that parent is split. Recursing with the parent
            # itself (old behaviour) never attached newNode anywhere.
            newNode.parent = parent
            self.splitNonLeafNode(newNode)

    def getClosestLeaf(self, cf: "ClusteringFeature", nonLeafNode: "NonLeafNode") -> "LeafNode":
        """Descend from ``nonLeafNode`` along the closest children to a leaf.

        Returns ``None`` when a node reports no closest child.
        """
        cfNode = nonLeafNode.getClosestChild(cf)
        if cfNode is None:
            return None

        if cfNode.isLeaf:
            return cfNode
        return self.getClosestLeaf(cf, cfNode)

    def validate(self):
        """Run validateNode() on the root."""
        self.validateNode(self.root)

    def validateNode(self, nonLeafNode: "NonLeafNode") -> bool:
        """Placeholder consistency check; currently always returns True.

        TODO: compare calculateCfs(entry) with nonLeafNode.cfs[idx].N for
        every entry and return False on a mismatch.
        """
        return True

    def calculateCfs(self, nonLeafNode: "NonLeafNode") -> int:
        """Return the number of points stored below ``nonLeafNode``.

        For an inner node every child's count is also compared with the ``N``
        of the summary CF kept for that child; mismatches are printed.
        """
        if nonLeafNode.isLeaf:
            return nonLeafNode.sumAllCfs().N

        total = 0
        for idx, node in enumerate(nonLeafNode.entries):
            n = self.calculateCfs(node)
            nNonLeaf = nonLeafNode.cfs[idx].N
            if n != nNonLeaf:
                print(False, n, nNonLeaf)
            total += n
        return total

    def getLeafList(self) -> list:
        """Return all leaves by following ``next`` from ``self.firstChild``."""
        current = self.firstChild
        leafs = [current]
        while current.next is not None:
            print("next")
            current = current.next
            leafs.append(current)

        return leafs

    def getLeafCfs(self) -> list:
        """Return the CFs of all leaves, in leaf-chain order."""
        return [cf for leaf in self.getLeafList() for cf in leaf.cfs]

    def getVectorsFromCf(self, cfs: list) -> list:
        """Return ``getCenter()`` of every CF in ``cfs``."""
        return [cf.getCenter() for cf in cfs]
49 changes: 49 additions & 0 deletions flair/models/clustering/birch/model/ClusteringFeature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
from torch import Tensor

from birch import Birch
from distance import Distance


class ClusteringFeature:
    """BIRCH clustering feature: point count, linear sum and squared sum.

    Attributes (set in ``__init__``):
        N: number of points summarised by this CF.
        LS: element-wise sum of the points (``None`` while the CF is empty).
        SS: element-wise sum of the squared points (``None`` while empty).
        indices: the ``idx`` values of all points absorbed so far.
    """

    def __init__(self, tensor: Tensor = None, idx: int = None):
        """Create an empty CF, or a CF holding the single point ``tensor``.

        Args:
            tensor: the point to store; ``None`` creates an empty CF.
            idx: identifier recorded in ``indices`` for that point.
        """
        if tensor is None:
            self.N = 0
            self.LS = None
            self.SS = None
        else:
            self.N = 1
            # LS is the plain sum, SS the sum of squares. They used to be
            # swapped, which made getCenter() return squared coordinates.
            self.LS = tensor
            self.SS = tensor * tensor
        self.indices = [] if idx is None else [idx]

    def absorbCf(self, cf):
        """Merge ``cf`` into this CF; ``cf`` itself is left untouched."""
        self.N += cf.N
        self.indices.extend(cf.indices)
        if cf.LS is None:
            # Absorbing an empty CF contributes no sums.
            return
        if self.LS is None:
            self.LS = cf.LS
            self.SS = cf.SS
        else:
            # Out-of-place additions: after the branch above, LS/SS may be
            # the very tensors owned by the first absorbed CF, and an
            # in-place ``+=`` would silently modify that CF as well.
            self.LS = self.LS + cf.LS
            self.SS = self.SS + cf.SS

    def getCenter(self) -> Tensor:
        """Return ``LS / N``. Not defined for an empty CF (``LS`` is None)."""
        return self.LS / self.N

    def calcualteDistance(self, vector) -> Tensor:
        """Return the cosine distance between this CF's center and ``vector``'s.

        ``vector`` must provide ``getCenter()``. An empty CF reports
        ``Birch.distanceMax - 100`` instead.
        """
        if self.LS is None:
            return Tensor([Birch.distanceMax - 100])
        return Distance.getCosineDistance(self.getCenter(), vector.getCenter())

    def canAbsorbCf(self, cf) -> bool:
        """Tell whether ``cf`` may be merged into this CF.

        An empty CF absorbs anything; otherwise the cosine distance between
        both centers must not exceed ``Birch.threshold``.
        """
        if self.LS is None:
            return True

        distance = Distance.getCosineDistance(self.getCenter(), cf.getCenter())
        return distance <= Birch.threshold
39 changes: 39 additions & 0 deletions flair/models/clustering/birch/model/LeafNode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from torch import Tensor

import birch.Birch
from birch.model.ClusteringFeature import ClusteringFeature
from birch.model.CfNode import CfNode


class LeafNode(CfNode):
    """Leaf of the CF tree: holds CFs and links to its neighbouring leaves."""

    def __init__(self, initCfs: list = None, parent=None):
        """Create a leaf.

        Args:
            initCfs: CFs to start with; ``None`` starts the leaf with a
                single empty ClusteringFeature.
            parent: the node this leaf belongs to.
        """
        super().__init__()
        self.isLeaf = True
        self.parent = parent
        self.cfs = [ClusteringFeature()] if initCfs is None else initCfs
        self.prev = None
        self.next = None

    def addCF(self, cf: ClusteringFeature):
        """Append ``cf`` to this leaf."""
        self.cfs.append(cf)

    def canAddNewCf(self):
        """Tell whether the leaf holds fewer CFs than ``branchingFactorLeaf``."""
        return len(self.cfs) < birch.Birch.branchingFactorLeaf

    def getClosestCF(self, vector: Tensor) -> ClusteringFeature:
        """Return the CF of this leaf with the smallest distance to ``vector``.

        ``vector`` is handed to each CF's ``calcualteDistance``. Returns
        ``None`` when no CF is closer than ``birch.Birch.distanceMax``.
        """
        closest = None
        closestDistance = birch.Birch.distanceMax

        for candidate in self.cfs:
            candidateDistance = candidate.calcualteDistance(vector)
            if candidateDistance < closestDistance:
                closestDistance, closest = candidateDistance, candidate

        return closest


Loading