From 17bea8d817f042790c10f43a41be36637b1cbb3c Mon Sep 17 00:00:00 2001 From: Mathieu Ripert Date: Sun, 25 May 2025 16:16:21 +0200 Subject: [PATCH] feat: add configurable tokenizer to text splitters - Add tokenizer field to CharacterTextSplitter and RecursiveCharacterTextSplitter - Replace hardcoded String.length() calls with configurable tokenizer function - Default tokenizer remains String.length/1 for backward compatibility - Update documentation to reflect token-based chunk sizing - Add tests demonstrating custom tokenizer functionality --- lib/text_splitter.ex | 12 ++--- lib/text_splitter/character_text_splitter.ex | 13 +++-- .../recursive_character_text_splitter.ex | 27 +++++----- test/text_splitter_test.exs | 49 +++++++++++++++++-- 4 files changed, 74 insertions(+), 27 deletions(-) diff --git a/lib/text_splitter.ex b/lib/text_splitter.ex index ca2c4e71..06989f60 100644 --- a/lib/text_splitter.ex +++ b/lib/text_splitter.ex @@ -11,8 +11,8 @@ defmodule LangChain.TextSplitter do end defp merge_split_helper(d, acc, text_splitter, separator) do - separator_len = String.length(separator) - len = String.length(d) + separator_len = text_splitter.tokenizer.(separator) + len = text_splitter.tokenizer.(d) test_separator_length = if Enum.count(acc.current_doc) > 0, do: separator_len, else: 0 @@ -30,7 +30,7 @@ defmodule LangChain.TextSplitter do acc.total - (acc.current_doc |> Enum.at(0, "") - |> String.length()) - separator_length + |> text_splitter.tokenizer.()) - separator_length new_current_doc = acc.current_doc |> Enum.drop(1) @@ -59,11 +59,11 @@ defmodule LangChain.TextSplitter do |> Enum.reduce( acc, fn d, acc -> - len = String.length(d) + len = text_splitter.tokenizer.(d) separator_length = if Enum.count(acc.current_doc) > 0, - do: String.length(plain_separator), + do: text_splitter.tokenizer.(plain_separator), else: 0 acc = @@ -84,7 +84,7 @@ defmodule LangChain.TextSplitter do separator_length = if Enum.count(acc.current_doc) > 1, - do: String.length(plain_separator), + do: text_splitter.tokenizer.(plain_separator), else: 0 %{acc | total: acc.total + separator_length + len} diff --git a/lib/text_splitter/character_text_splitter.ex b/lib/text_splitter/character_text_splitter.ex index 0feb3f42..c9a8799c 100644 --- a/lib/text_splitter/character_text_splitter.ex +++ b/lib/text_splitter/character_text_splitter.ex @@ -5,8 +5,8 @@ defmodule LangChain.TextSplitter.CharacterTextSplitter do This splitter provides consistent chunk sizes. It operates as follows: - - It splits the text at specified `separator` characters. - - It takes a `chunk_size` parameter that determines the maximum number of characters + - It splits the text at specified `separator` characters. + - It takes a `chunk_size` parameter that determines the maximum number of tokens in each chunk. - If no separator is found within the `chunk_size`, it will create a chunk larger than the specified size. @@ -17,10 +17,11 @@ defmodule LangChain.TextSplitter.CharacterTextSplitter do A `CharacterTextSplitter` is defined using a schema. * `separator` - String that splits a given text. - * `chunk_size` - Integer number of characters that a chunk should have. - * `chunk_overlap` - Integer number of characters that two consecutive chunks should share. + * `chunk_size` - Integer number of tokens that a chunk should have. + * `chunk_overlap` - Integer number of tokens that two consecutive chunks should share. * `keep_separator` - Either `:discard_separator`, `:start` or `:end`. If `:discard_separator`, the separator is discarded from the output chunks. `:start` and `:end` keep the separator at the start or end of the output chunks. Defaults to `:discard_separator`. * `is_separator_regex` - Boolean defaulting to `false`. If `true`, the `separator` string is not escaped. Defaults to `false` + * `tokenizer` - Function that takes a string and returns the number of tokens. Defaults to `&String.length/1`. """ use Ecto.Schema import Ecto.Changeset @@ -39,6 +40,7 @@ defmodule LangChain.TextSplitter.CharacterTextSplitter do default: :discard_separator field :is_separator_regex, :boolean, default: false + field :tokenizer, :any, virtual: true, default: &String.length/1 end @type t :: %CharacterTextSplitter{} @@ -48,7 +50,8 @@ defmodule LangChain.TextSplitter.CharacterTextSplitter do :chunk_size, :chunk_overlap, :keep_separator, - :is_separator_regex + :is_separator_regex, + :tokenizer ] @create_fields @update_fields diff --git a/lib/text_splitter/recursive_character_text_splitter.ex b/lib/text_splitter/recursive_character_text_splitter.ex index 488dbf1c..27f18a85 100644 --- a/lib/text_splitter/recursive_character_text_splitter.ex +++ b/lib/text_splitter/recursive_character_text_splitter.ex @@ -25,10 +25,11 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do A `RecursiveCharacterTextSplitter` is defined using a schema. * `separators` - List of string that split a given text. The default list is `["\n\n", "\n", " ", ""]`. - * `chunk_size` - Integer number of characters that a chunk should have. - * `chunk_overlap` - Integer number of characters that two consecutive chunks should share. + * `chunk_size` - Integer number of tokens that a chunk should have. + * `chunk_overlap` - Integer number of tokens that two consecutive chunks should share. * `keep_separator` - Either `:discard_separator`, `:start` or `:end`. If `nil`, the separator is discarded from the output chunks. `:start` and `:end` keep the separator at the start or end of the output chunks. Defaults to `start`. * `is_separator_regex` - Boolean defaulting to `false`. If `true`, the `separator` string is not escaped. Defaults to `false` + * `tokenizer` - Function that takes a string and returns the number of tokens. Defaults to `&String.length/1`. """ use Ecto.Schema import Ecto.Changeset @@ -48,6 +49,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do default: :start field :is_separator_regex, :boolean, default: false + field :tokenizer, :any, virtual: true, default: &String.length/1 end @type t :: %RecursiveCharacterTextSplitter{} @@ -57,7 +59,8 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do :chunk_size, :chunk_overlap, :keep_separator, - :is_separator_regex + :is_separator_regex, + :tokenizer ] @create_fields @update_fields @@ -89,7 +92,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do iex> split_tags = [",", "."] iex> base_params = %{chunk_size: 10, chunk_overlap: 0, separators: split_tags} iex> query = "Apple,banana,orange and tomato." - iex> splitter = RecursiveCharacterTextSplitter.new!(base_params) + iex> splitter = RecursiveCharacterTextSplitter.new!(base_params) iex> splitter |> RecursiveCharacterTextSplitter.split_text(query) ["Apple", ",banana", ",orange and tomato", "."] @@ -98,7 +101,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do iex> split_tags = [",", "."] iex> base_params = %{chunk_size: 10, chunk_overlap: 0, separators: split_tags, keep_separator: :end} iex> query = "Apple,banana,orange and tomato." - iex> splitter = RecursiveCharacterTextSplitter.new!(base_params) + iex> splitter = RecursiveCharacterTextSplitter.new!(base_params) iex> splitter |> RecursiveCharacterTextSplitter.split_text(query) ["Apple,", "banana,", "orange and tomato."] @@ -112,7 +115,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do ...>def hello_world(): ...> print('Hello, World') ...> - ...> + ...> ...># Call the function ...>hello_world()" iex> splitter = @@ -166,7 +169,7 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do recursive_splits = splits |> Enum.reduce(acc, fn split, acc -> - if String.length(split) < text_splitter.chunk_size do + if text_splitter.tokenizer.(split) < text_splitter.chunk_size do %{acc | good_splits: acc.good_splits ++ [split]} else acc = @@ -187,12 +190,10 @@ defmodule LangChain.TextSplitter.RecursiveCharacterTextSplitter do %{acc | final_chunks: acc.final_chunks ++ [split]} else new_recursive_splitter = - %{ - (text_splitter - |> Map.from_struct()) - | separators: new_separators |> Enum.drop(1), - is_separator_regex: true - } + text_splitter + |> Map.from_struct() + |> Map.put(:separators, new_separators |> Enum.drop(1)) + |> Map.put(:is_separator_regex, true) |> RecursiveCharacterTextSplitter.new!() other_info = diff --git a/test/text_splitter_test.exs b/test/text_splitter_test.exs index 2b8577a2..6f9f5c5f 100644 --- a/test/text_splitter_test.exs +++ b/test/text_splitter_test.exs @@ -201,6 +201,28 @@ defmodule TextSplitterTest do assert output == expected_output end end + + test "Custom tokenizer function" do + text = "foo bar baz" + # Custom tokenizer that counts words instead of characters + word_tokenizer = fn text -> text |> String.split() |> length() end + + expected_output = ["foo bar", "baz"] + + character_splitter = + CharacterTextSplitter.new!(%{ + separator: " ", + chunk_size: 2, + chunk_overlap: 0, + tokenizer: word_tokenizer + }) + + output = + character_splitter + |> CharacterTextSplitter.split_text(text) + + assert output == expected_output + end end describe "RecursiveCharacterTextSplitter" do @@ -306,6 +328,27 @@ Bye!\n\n-I." end end + describe "Custom tokenizer functionality" do + test "RecursiveCharacterTextSplitter with word tokenizer" do + text = "Hello world. This is a test. Another sentence here." + # Custom tokenizer that counts words instead of characters + word_tokenizer = fn text -> text |> String.split() |> length() end + + expected_output = ["Hello world", ". This is", "a test", ". Another sentence", "here."] + + splitter = + RecursiveCharacterTextSplitter.new!(%{ + separators: [". ", " "], + chunk_size: 3, + chunk_overlap: 0, + tokenizer: word_tokenizer + }) + + output = splitter |> RecursiveCharacterTextSplitter.split_text(text) + assert output == expected_output + end + end + describe "Programming languages splitters" do test "Python test splitter" do fake_python_text = """ @@ -497,7 +540,7 @@ message Person { string name = 1; int32 age = 2; repeated string hobbies = 3; -} +} " expected_splits = [ @@ -539,7 +582,7 @@ function helloWorld() { } // Call the function -helloWorld(); +helloWorld(); " expected_splits = [ @@ -1264,7 +1307,7 @@ end -- Some sample functions add :: Int -> Int -> Int - add x y = x + y + add x y = x + y " expected_splits = [