Skip to content

Commit

Permalink
Merge pull request #166 from praekeltfoundation/put-embedding-model-behind-flag
Browse files Browse the repository at this point in the history

Put embedding model behind flag
  • Loading branch information
KaitCrawford committed Aug 28, 2023
2 parents 984e19a + dba6d79 commit dc53f87
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 1 deletion.
2 changes: 2 additions & 0 deletions contentrepo/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,5 @@
# Optional SMTP SSL credentials and timeout, read from the environment;
# each defaults to None when the variable is unset.
EMAIL_SSL_CERTFILE = env.str("EMAIL_SSL_CERTFILE", None)
EMAIL_SSL_KEYFILE = env.str("EMAIL_SSL_KEYFILE", None)
EMAIL_TIMEOUT = env.int("EMAIL_TIMEOUT", None)

# Feature flag (default False): gates loading of the sentence-transformer
# embedding model in home/constants.py, so deployments that don't need
# embeddings skip the heavy model download/initialisation.
LOAD_TRANSFORMER_MODEL = env.bool("LOAD_TRANSFORMER_MODEL", False)
5 changes: 4 additions & 1 deletion home/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Define constants for use throughout the application
from django.conf import settings
from sentence_transformers import SentenceTransformer

GENDER_CHOICES = [
Expand All @@ -21,4 +22,6 @@
("empty", "Empty"),
]

model = SentenceTransformer("all-mpnet-base-v2")
# Module-level embedding model, gated behind the LOAD_TRANSFORMER_MODEL
# setting. When the flag is off, `model` stays None — consumers (e.g. the
# embedding-update signal and content retrieval helpers) must check for
# None before calling encode().
model = None
if settings.LOAD_TRANSFORMER_MODEL:
    model = SentenceTransformer("all-mpnet-base-v2")
3 changes: 3 additions & 0 deletions home/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,9 @@ def save_revision(
def update_embedding(sender, instance, *args, **kwargs):
from .utils import preprocess_content_for_embedding

if not model:
return

embedding = {}
if instance.enable_web:
content = []
Expand Down
3 changes: 3 additions & 0 deletions home/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def cosine_similarity(A, B):
def retrieve_top_n_content_pieces(
user_input, queryset, n=5, content_type=None, platform="web"
):
if not model:
return []

# similar_embeddings = [{'faq_name':, 'faq_content':, 'embedding':}, ...] # We need to filter by content type and then retrieve their embeddings
# Generate embedding for user text
user_embedding = model.encode([user_input])
Expand Down

0 comments on commit dc53f87

Please sign in to comment.