Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions lib/text_splitter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ defmodule LangChain.TextSplitter do
end

defp merge_split_helper(d, acc, text_splitter, separator) do
separator_len = String.length(separator)
len = String.length(d)
separator_len = text_splitter.tokenizer.(separator)
len = text_splitter.tokenizer.(d)

test_separator_length =
if Enum.count(acc.current_doc) > 0, do: separator_len, else: 0
Expand All @@ -30,7 +30,7 @@ defmodule LangChain.TextSplitter do
acc.total -
(acc.current_doc
|> Enum.at(0, "")
|> String.length()) - separator_length
|> text_splitter.tokenizer.()) - separator_length

new_current_doc = acc.current_doc |> Enum.drop(1)

Expand Down Expand Up @@ -59,11 +59,11 @@ defmodule LangChain.TextSplitter do
|> Enum.reduce(
acc,
fn d, acc ->
len = String.length(d)
len = text_splitter.tokenizer.(d)

separator_length =
if Enum.count(acc.current_doc) > 0,
do: String.length(plain_separator),
do: text_splitter.tokenizer.(plain_separator),
else: 0

acc =
Expand All @@ -84,7 +84,7 @@ defmodule LangChain.TextSplitter do

separator_length =
if Enum.count(acc.current_doc) > 1,
do: String.length(plain_separator),
do: text_splitter.tokenizer.(plain_separator),
else: 0

%{acc | total: acc.total + separator_length + len}
Expand Down
13 changes: 8 additions & 5 deletions lib/text_splitter/character_text_splitter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ defmodule LangChain.TextSplitter.CharacterTextSplitter do
This splitter provides consistent chunk sizes.
It operates as follows:

- It splits the text at specified `separator` characters.
- It takes a `chunk_size` parameter that determines the maximum number of characters
- It splits the text at specified `separator` characters.
- It takes a `chunk_size` parameter that determines the maximum number of tokens
in each chunk.
- If no separator is found within the `chunk_size`,
it will create a chunk larger than the specified size.
Expand All @@ -17,10 +17,11 @@ defmodule LangChain.TextSplitter.CharacterTextSplitter do

A `CharacterTextSplitter` is defined using a schema.
* `separator` - String that splits a given text.
* `chunk_size` - Integer number of characters that a chunk should have.
* `chunk_overlap` - Integer number of characters that two consecutive chunks should share.
* `chunk_size` - Integer number of tokens that a chunk should have.
* `chunk_overlap` - Integer number of tokens that two consecutive chunks should share.
* `keep_separator` - Either `:discard_separator`, `:start` or `:end`. If `:discard_separator`, the separator is discarded from the output chunks. `:start` and `:end` keep the separator at the start or end of the output chunks. Defaults to `:discard_separator`.
* `is_separator_regex` - Boolean defaulting to `false`. If `true`, the `separator` string is not escaped. Defaults to `false`
* `tokenizer` - Function that takes a string and returns the number of tokens. Defaults to `&String.length/1`.
"""
use Ecto.Schema
import Ecto.Changeset
Expand All @@ -39,6 +40,7 @@ defmodule LangChain.TextSplitter.CharacterTextSplitter do
default: :discard_separator

field :is_separator_regex, :boolean, default: false
field :tokenizer, :any, virtual: true, default: &String.length/1
end

@type t :: %CharacterTextSplitter{}
Expand All @@ -48,7 +50,8 @@ defmodule LangChain.TextSplitter.CharacterTextSplitter do
:chunk_size,
:chunk_overlap,
:keep_separator,
:is_separator_regex
:is_separator_regex,
:tokenizer
]
@create_fields @update_fields

Expand Down
27 changes: 14 additions & 13 deletions lib/text_splitter/recursive_character_text_splitter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do
A `RecursiveCharacterTextSplitter` is defined using a schema.
* `separators` - List of string that split a given text.
The default list is `["\n\n", "\n", " ", ""]`.
* `chunk_size` - Integer number of characters that a chunk should have.
* `chunk_overlap` - Integer number of characters that two consecutive chunks should share.
* `chunk_size` - Integer number of tokens that a chunk should have.
* `chunk_overlap` - Integer number of tokens that two consecutive chunks should share.
* `keep_separator` - Either `:discard_separator`, `:start` or `:end`. If `nil`, the separator is discarded from the output chunks. `:start` and `:end` keep the separator at the start or end of the output chunks. Defaults to `start`.
* `is_separator_regex` - Boolean defaulting to `false`. If `true`, the `separator` string is not escaped. Defaults to `false`
* `tokenizer` - Function that takes a string and returns the number of tokens. Defaults to `&String.length/1`.
"""
use Ecto.Schema
import Ecto.Changeset
Expand All @@ -48,6 +49,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do
default: :start

field :is_separator_regex, :boolean, default: false
field :tokenizer, :any, virtual: true, default: &String.length/1
end

@type t :: %RecursiveCharacterTextSplitter{}
Expand All @@ -57,7 +59,8 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do
:chunk_size,
:chunk_overlap,
:keep_separator,
:is_separator_regex
:is_separator_regex,
:tokenizer
]
@create_fields @update_fields

Expand Down Expand Up @@ -89,7 +92,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do
iex> split_tags = [",", "."]
iex> base_params = %{chunk_size: 10, chunk_overlap: 0, separators: split_tags}
iex> query = "Apple,banana,orange and tomato."
iex> splitter = RecursiveCharacterTextSplitter.new!(base_params)
iex> splitter = RecursiveCharacterTextSplitter.new!(base_params)
iex> splitter |> RecursiveCharacterTextSplitter.split_text(query)
["Apple", ",banana", ",orange and tomato", "."]

Expand All @@ -98,7 +101,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do
iex> split_tags = [",", "."]
iex> base_params = %{chunk_size: 10, chunk_overlap: 0, separators: split_tags, keep_separator: :end}
iex> query = "Apple,banana,orange and tomato."
iex> splitter = RecursiveCharacterTextSplitter.new!(base_params)
iex> splitter = RecursiveCharacterTextSplitter.new!(base_params)
iex> splitter |> RecursiveCharacterTextSplitter.split_text(query)
["Apple,", "banana,", "orange and tomato."]

Expand All @@ -112,7 +115,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do
...>def hello_world():
...> print('Hello, World')
...>
...>
...>
...># Call the function
...>hello_world()"
iex> splitter =
Expand Down Expand Up @@ -166,7 +169,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do
recursive_splits =
splits
|> Enum.reduce(acc, fn split, acc ->
if String.length(split) < text_splitter.chunk_size do
if text_splitter.tokenizer.(split) < text_splitter.chunk_size do
%{acc | good_splits: acc.good_splits ++ [split]}
else
acc =
Expand All @@ -187,12 +190,10 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do
%{acc | final_chunks: acc.final_chunks ++ [split]}
else
new_recursive_splitter =
%{
(text_splitter
|> Map.from_struct())
| separators: new_separators |> Enum.drop(1),
is_separator_regex: true
}
text_splitter
|> Map.from_struct()
|> Map.put(:separators, new_separators |> Enum.drop(1))
|> Map.put(:is_separator_regex, true)
|> RecursiveCharacterTextSplitter.new!()

other_info =
Expand Down
49 changes: 46 additions & 3 deletions test/text_splitter_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,28 @@ defmodule TextSplitterTest do
assert output == expected_output
end
end

test "Custom tokenizer function" do
text = "foo bar baz"
# Custom tokenizer that counts words instead of characters
word_tokenizer = fn text -> text |> String.split() |> length() end

expected_output = ["foo bar", "baz"]

character_splitter =
CharacterTextSplitter.new!(%{
separator: " ",
chunk_size: 2,
chunk_overlap: 0,
tokenizer: word_tokenizer
})

output =
character_splitter
|> CharacterTextSplitter.split_text(text)

assert output == expected_output
end
end

describe "RecursiveCharacterTextSplitter" do
Expand Down Expand Up @@ -306,6 +328,27 @@ Bye!\n\n-I."
end
end

describe "Custom tokenizer functionality" do
test "RecursiveCharacterTextSplitter with word tokenizer" do
text = "Hello world. This is a test. Another sentence here."
# Custom tokenizer that counts words instead of characters
word_tokenizer = fn text -> text |> String.split() |> length() end

expected_output = ["Hello world", ". This is", "a test", ". Another sentence", "here."]

splitter =
RecursiveCharacterTextSplitter.new!(%{
separators: [". ", " "],
chunk_size: 3,
chunk_overlap: 0,
tokenizer: word_tokenizer
})

output = splitter |> RecursiveCharacterTextSplitter.split_text(text)
assert output == expected_output
end
end

describe "Programming languages splitters" do
test "Python test splitter" do
fake_python_text = """
Expand Down Expand Up @@ -497,7 +540,7 @@ message Person {
string name = 1;
int32 age = 2;
repeated string hobbies = 3;
}
}
"

expected_splits = [
Expand Down Expand Up @@ -539,7 +582,7 @@ function helloWorld() {
}

// Call the function
helloWorld();
helloWorld();
"

expected_splits = [
Expand Down Expand Up @@ -1264,7 +1307,7 @@ end

-- Some sample functions
add :: Int -> Int -> Int
add x y = x + y
add x y = x + y
"

expected_splits = [
Expand Down
Loading