diff --git a/recipes/natural_language_processing/summarizer/app/summarizer.py b/recipes/natural_language_processing/summarizer/app/summarizer.py index 50d68d89..0a359ed1 100644 --- a/recipes/natural_language_processing/summarizer/app/summarizer.py +++ b/recipes/natural_language_processing/summarizer/app/summarizer.py @@ -3,6 +3,7 @@ from langchain.prompts import PromptTemplate from langchain_community.callbacks import StreamlitCallbackHandler from langchain_community.document_loaders import PyMuPDFLoader +from langchain_text_splitters import RecursiveCharacterTextSplitter from rouge_score import rouge_scorer import streamlit as st import tempfile @@ -34,22 +35,32 @@ def checking_model_service(): with st.spinner("Checking Model Service Availability..."): checking_model_service() +def split_append_chunk(chunk, list): + chunk_length = len(chunk) + chunk1 = " ".join(chunk.split()[:chunk_length]) + chunk2 = " ".join(chunk.split()[chunk_length:]) + list.extend([chunk1, chunk2]) def chunk_text(text): chunks = [] - chunk_size = 1024 - tokens = requests.post(f"{model_service[:-2]}extras/tokenize/", - json={"input":text}).content - tokens = json.loads(tokens)["tokens"] - num_tokens = len(tokens) - num_chunks = (num_tokens//chunk_size)+1 - for i in range(num_chunks): - chunk = tokens[:chunk_size] - chunk = requests.post(f"{model_service[:-2]}extras/detokenize/", - json={"tokens":chunk}).content - chunk = json.loads(chunk)["text"] - chunks.append(chunk) - tokens = tokens[chunk_size:] + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=3048, + chunk_overlap=0, + length_function=len, + is_separator_regex=False + ) + + text_chunks = text_splitter.create_documents([text]) + for chunk in text_chunks: + chunk = chunk.page_content + count = requests.post(f"{model_service[:-2]}extras/tokenize/count", + json={"input":chunk}).content + count = json.loads(count)["count"] + if count >= 2048: + split_append_chunk(chunk, chunks) + else: + chunks.append(chunk) + return chunks def read_file(file):