-
Notifications
You must be signed in to change notification settings - Fork 0
/
embedding.py
124 lines (89 loc) · 3.25 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
__author__ = "Md. Ahsan Ayub"
__license__ = "GPL"
__credits__ = ["Ayub, Md. Ahsan", "Majumdar, Subho"]
__email__ = "md.ahsanayub@outlook.com"
__status__ = "Prototype"
import pandas as pd
from openai import OpenAI
from sentence_transformers import SentenceTransformer
class OpenAi:
def __init__(self, api_key):
self.api_key = api_key
self.client = OpenAI(
api_key = self.api_key,
)
def get_embedding(self, text, model="text-embedding-3-small"):
text = text.replace("\n", " ")
# 8196 tokens are equal to approx 6000 words
word_count = len(text.split())
if(word_count > 6000):
print(word_count)
return None
try:
return self.client.embeddings.create(input = [text], model=model).data[0].embedding
except:
print(word_count)
return None
class OctoAi:
def __init__(self, api_key):
self.api_key = api_key
self.base_url = "https://text.octoai.run/v1"
self.client = OpenAI(
api_key = self.api_key,
base_url = self.base_url,
)
def get_embedding(self, text, model="thenlper/gte-large"):
text = text.replace("\n", " ")
try:
return self.client.embeddings.create(input = [text], model=model).data[0].embedding
except:
print(len(text))
return None
class MiniLm:
""" https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 """
def __init__(self):
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
def get_embedding(self, text):
text = text.replace("\n", " ")
try:
return self.model.encode(text).tolist()
except:
print(len(text))
return None
def data_processing(obj, dataset):
# split dataset into 1k rows for embeddings
start = 1
row_count = 1000
max_count = len(dataset)
while start < max_count:
finish = start + row_count
if finish > max_count:
finish = max_count + 1
# Split the dataset into row_count
df = dataset[dataset['id'] >= start]
df = df[df['id'] < finish]
prompts = df[["text"]]
prompts["text_embedding"] = prompts["text"].astype(str).apply(obj.get_embedding)
prompts["label"] = df["label"].values
prompts["id"] = df["id"].values
prompts = prompts.drop(columns=['text'])
prompts = prompts.reindex(columns=['id','text_embedding','label'])
if isinstance(obj,OpenAi):
prompts.to_csv('embeddings/openai/'+str(start)+'-'+str(finish-1)+'.csv', index=False)
if isinstance(obj,OctoAi):
prompts.to_csv('embeddings/octoai/'+str(start)+'-'+str(finish-1)+'.csv', index=False)
if isinstance(obj,MiniLm):
prompts.to_csv('embeddings/minilm/'+str(start)+'-'+str(finish-1)+'.csv', index=False)
print(str(finish - 1) + ' done out of ' + str(max_count))
start = finish
# Driver program
if __name__ == '__main__':
dataset = pd.read_pickle("dataset/hf_dataset.pkl")
octoai = OctoAi("<your_api_key>")
data_processing(octoai, dataset)
openai = OpenAi("<your_api_key>")
data_processing(openai, dataset)
minilm = MiniLm()
data_processing(minilm, dataset)
del dataset