text-splitters: bug fix for CharacterTextSplitter replacing original separator with regex pattern when merging #23519

Open · wants to merge 14 commits into base: master
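
The symptom this PR fixes, roughly: with is_separator_regex=True and keep_separator=False, chunks that are merged back together are rejoined with the regex pattern string itself rather than with the text the pattern actually matched. A minimal sketch of the behavior; the example text and parameters are illustrative, not taken from the PR:

from langchain_text_splitters import CharacterTextSplitter

splitter = CharacterTextSplitter(
    separator=r"\s+",
    is_separator_regex=True,
    keep_separator=False,
    chunk_size=20,
    chunk_overlap=0,
)

chunks = splitter.split_text("Hello   world")
# Before the fix the merged chunk could come back as "Hello\s+world",
# i.e. the literal pattern replaces the matched whitespace.
# With the fix the original whitespace is preserved: ["Hello   world"]
print(chunks)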
67 changes: 59 additions & 8 deletions libs/text-splitters/langchain_text_splitters/base.py
@@ -95,25 +95,68 @@ def split_documents(self, documents: Iterable[Document]) -> List[Document]:
metadatas.append(doc.metadata)
return self.create_documents(texts, metadatas=metadatas)

def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
text = separator.join(docs)
def _join_docs(
self, docs: List[str], separator: Union[str, Iterable[str]]
) -> Optional[str]:
if isinstance(separator, str):
# If separator is a single string, join docs using this single separator
text = separator.join(docs)
else:
# If separator is an iterable, use each separator at its
# respective position
if len(docs) == 0:
return None
separator = list(separator)

if len(docs) - 1 != len(separator):
raise ValueError(
f"Number of separators ({len(separator)}) should be equal to "
f"number of docs minus 1 ({len(docs) - 1})."
)

text = docs[0]
for doc, sep in zip(docs[1:], separator):
text += sep + doc

if self._strip_whitespace:
text = text.strip()

if text == "":
return None
else:
return text

def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
def _merge_splits(
self, splits: Iterable[str], separator: Union[str, Iterable[str]]
) -> List[str]:
splits = list(splits)
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)
if isinstance(separator, str):
separator = [separator] * (len(list(splits)) - 1)

separator = list(separator)
if len(splits) - 1 != len(separator):
raise ValueError(
f"Number of separators ({len(separator)}) should be equal to "
f"number of splits minus 1 ({len(splits) - 1})."
)

separator_lens = [self._length_function(sep) for sep in separator]

docs = []
current_doc: List[str] = []
# current_doc_start and current_doc_end keep track of the indices of the
# splits currently held in current_doc
current_doc_start, current_doc_end = 0, 0
total = 0
for d in splits:
for i, d in enumerate(splits):
_len = self._length_function(d)
# separator_len is not applicable when we are adding the first element
separator_len = (
separator_lens[i - 1] if i > 0 and i - 1 < len(separator_lens) else 0
)

if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
@@ -124,7 +167,9 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
f"which is longer than the specified {self._chunk_size}"
)
if len(current_doc) > 0:
doc = self._join_docs(current_doc, separator)
doc = self._join_docs(
current_doc, separator[current_doc_start : current_doc_end - 1]
)
if doc is not None:
docs.append(doc)
# Keep on popping if:
@@ -136,12 +181,18 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
and total > 0
):
total -= self._length_function(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
separator_lens[current_doc_start]
if len(current_doc) > 1
else 0
)
current_doc = current_doc[1:]
current_doc_start += 1
current_doc.append(d)
current_doc_end += 1
total += _len + (separator_len if len(current_doc) > 1 else 0)
doc = self._join_docs(current_doc, separator)
doc = self._join_docs(
current_doc, separator[current_doc_start : current_doc_end - 1]
)
if doc is not None:
docs.append(doc)
return docs
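
For reference, a standalone sketch of the positional-join behavior that _join_docs implements above when it receives a list of separators; the helper name and example values are illustrative only:

from typing import List, Optional


def join_with_separators(docs: List[str], separators: List[str]) -> Optional[str]:
    # Each separator sits between one adjacent pair of docs,
    # so exactly len(docs) - 1 separators are expected.
    if not docs:
        return None
    if len(docs) - 1 != len(separators):
        raise ValueError("expected one separator per adjacent pair of docs")
    text = docs[0]
    for doc, sep in zip(docs[1:], separators):
        text += sep + doc
    return text


# join_with_separators(["Hello", "world!"], [" "])     -> "Hello world!"
# join_with_separators(["a", "b", "c"], [", ", "; "])  -> "a, b; c"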
78 changes: 58 additions & 20 deletions libs/text-splitters/langchain_text_splitters/character.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import Any, List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Tuple, Union

from langchain_text_splitters.base import Language, TextSplitter

@@ -23,36 +23,60 @@ def split_text(self, text: str) -> List[str]:
separator = (
self._separator if self._is_separator_regex else re.escape(self._separator)
)
splits = _split_text_with_regex(text, separator, self._keep_separator)
_separator = "" if self._keep_separator else self._separator
splits, actual_separators = _split_text_with_regex(
text, separator, self._keep_separator
)
_separator = "" if self._keep_separator else actual_separators
return self._merge_splits(splits, _separator)


def _split_text_with_regex(
text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
) -> List[str]:
) -> Tuple[List[str], List[str]]:
# Now that we have the separator, split the text
actual_separators = []

if separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
# take the elements at odd indexes as the actual separators
actual_separators = [_splits[i] for i in range(1, len(_splits), 2)]

if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = (
([_splits[i] + _splits[i + 1] for i in range(0, len(_splits) - 1, 2)])
if keep_separator == "end"
else ([_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)])
)
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = (
(splits + [_splits[-1]])
if keep_separator == "end"
else ([_splits[0]] + splits)
)
else:
splits = re.split(separator, text)
# skip the separators
splits = [_splits[i + 1] for i in range(1, len(_splits), 2)]

if len(_splits) % 2 == 0:
splits += _splits[-1:]

splits = (
(splits + [_splits[-1]])
if keep_separator == "end"
else ([_splits[0]] + splits)
)
else:
splits = list(text)
return [s for s in splits if s != ""]
# in this case, splits is a list of characters, so we set actual_separators
# to a list of empty strings
actual_separators = [""] * (len(splits) - 1)

# remove empty strings as well as their corresponding separators
new_splits = []
new_actual_separators = []
for i, s in enumerate(splits):
if s != "":
new_splits.append(s)
# Only append to actual_separators if we are not at the last split
if i < len(actual_separators):
new_actual_separators.append(actual_separators[i])

return new_splits, new_actual_separators
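
To make the new return value concrete, a small sketch of what the capture-group split produces for a simple input; the example text is illustrative:

import re

text = "Hello   world foo"
pieces = re.split(r"(\s+)", text)   # parentheses keep the matched separators
# pieces            -> ["Hello", "   ", "world", " ", "foo"]
splits = pieces[0::2]               # ["Hello", "world", "foo"]
actual_separators = pieces[1::2]    # ["   ", " "]
# _split_text_with_regex now returns both lists, so _merge_splits can rejoin
# adjacent splits with the separators that were actually matched instead of
# the pattern string.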


class RecursiveCharacterTextSplitter(TextSplitter):
@@ -91,26 +115,40 @@ def _split_text(self, text: str, separators: List[str]) -> List[str]:
break

_separator = separator if self._is_separator_regex else re.escape(separator)
splits = _split_text_with_regex(text, _separator, self._keep_separator)
splits, actual_separators = _split_text_with_regex(
text, _separator, self._keep_separator
)

# Now go merging things, recursively splitting longer texts.
_good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits:
_good_separators = []
for s, sep in zip(splits, actual_separators + [""]):
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
_good_separators.append(sep)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
# merge using the actual separators instead of the regex pattern.
# If self._keep_separator is True, the matched separators are already
# attached to the splits, so we pass in ""
merged_text = self._merge_splits(
_good_splits,
"" if self._keep_separator else _good_separators[:-1],
)
final_chunks.extend(merged_text)
_good_splits = []
_good_separators = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
# if self._keep_separator is True, the matched separators are already
# attached to the splits, so we pass in ""
merged_text = self._merge_splits(
_good_splits, "" if self._keep_separator else _good_separators[:-1]
)
final_chunks.extend(merged_text)
return final_chunks
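
Roughly, the [:-1] slice drops the separator that trails the last buffered split: joining N buffered splits only needs the N - 1 separators between them. A small sketch of the pairing, with made-up values:

splits = ["Hello", "world", "X" * 100]          # last piece is too long to keep
actual_separators = [" ", "  "]                 # separator following each split
pairs = list(zip(splits, actual_separators + [""]))
# pairs -> [("Hello", " "), ("world", "  "), ("XXX...", "")]
# Buffered good splits: ["Hello", "world"]; their trailing separators: [" ", "  "].
# Joining the buffer only needs the separator between them:
# [" ", "  "][:-1] == [" "]  ->  "Hello world"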

37 changes: 37 additions & 0 deletions libs/text-splitters/tests/unit_tests/test_text_splitters.py
@@ -376,6 +376,43 @@
assert output == expected_output


@pytest.mark.parametrize(
"separators",
[
r"\s",
[r"\s"],
[r" "],
[r"\s+"],
]
)
def test_recursive_text_splitter_regex_not_keep_separator(
    separators: str | list[str],
) -> None:
"""Test Recursive Text Splitter using regex to split but not keeping separators."""

splitter = RecursiveCharacterTextSplitter(
separators=separators,
keep_separator=False,
is_separator_regex=True,
chunk_size=15,
chunk_overlap=0,
strip_whitespace=False,
)
output = splitter.split_text("Hello world")

# here we expect the original space between "Hello" and "world" to be retained,
# rather than being replaced by the regex pattern "\s"
assert output == [
"Hello world",
]

# more cases
assert splitter.split_text("Hello world!") == ["Hello world!"]
assert splitter.split_text("Hello world! ") == ["Hello world!"]
assert splitter.split_text("Hello world! ") == ["Hello world! "]
assert splitter.split_text(" Hello world!") == [" Hello world!"]
assert splitter.split_text(" Hello world!") == [" Hello world!"]



def test_split_documents() -> None:
"""Test split_documents."""
splitter = CharacterTextSplitter(separator="", chunk_size=1, chunk_overlap=0)
Expand Down