diff --git a/bindings/typescript/src/converters.ts b/bindings/typescript/src/converters.ts index b9086c93..5b7bfec3 100644 --- a/bindings/typescript/src/converters.ts +++ b/bindings/typescript/src/converters.ts @@ -52,35 +52,6 @@ export class ConversionError extends Error { // Generic converter factory // ============================================================================ -/** - * Convert Map objects to plain objects recursively. - * This is needed because serde-wasm-bindgen serializes serde_json::Map to JS Map - * instead of plain objects. - */ -function convertMapsToObjects(value: unknown): unknown { - if (value instanceof Map) { - const obj: Record = {}; - for (const [key, val] of value.entries()) { - obj[key] = convertMapsToObjects(val); - } - return obj; - } - - if (Array.isArray(value)) { - return value.map((item) => convertMapsToObjects(item)); - } - - if (value !== null && typeof value === "object") { - const obj: Record = {}; - for (const [key, val] of Object.entries(value)) { - obj[key] = convertMapsToObjects(val); - } - return obj; - } - - return value; -} - /** * Creates a converter function that transforms provider format to Lingua * @param wasmFn - The WASM function to call @@ -93,9 +64,7 @@ function createToLinguaConverter( ): (input: unknown) => TOutput { return (input: unknown): TOutput => { try { - const result = wasmFn()(input); - // Convert any Map objects to plain objects - return convertMapsToObjects(result) as TOutput; + return wasmFn()(input) as TOutput; } catch (error: unknown) { throw new ConversionError( `Failed to convert ${provider} message to Lingua`, @@ -119,9 +88,7 @@ function createFromLinguaConverter( ): (input: TInput) => T { return (input: TInput): T => { try { - const result = wasmFn()(input); - // Convert any Map objects to plain objects - return convertMapsToObjects(result) as T; + return wasmFn()(input) as T; } catch (error: unknown) { throw new ConversionError( `Failed to convert Lingua to ${provider} format`, @@ -338,9 +305,7 @@ export const linguaToGoogleContents = createFromLinguaConverter< */ export function deduplicateMessages(messages: Message[]): Message[] { try { - const result = getWasm().deduplicate_messages(messages); - // Convert any Map objects to plain objects - return convertMapsToObjects(result) as Message[]; + return getWasm().deduplicate_messages(messages) as Message[]; } catch (error: unknown) { throw new ConversionError( "Failed to deduplicate messages", @@ -369,9 +334,7 @@ export function importMessagesFromSpans( spans: ImportSpan[] ): Message[] { try { - const result = getWasm().import_messages_from_spans(spans); - // Convert any Map objects to plain objects - return convertMapsToObjects(result) as Message[]; + return getWasm().import_messages_from_spans(spans) as Message[]; } catch (error: unknown) { throw new ConversionError( "Failed to import messages from spans", @@ -396,9 +359,7 @@ export function importAndDeduplicateMessages( spans: ImportSpan[] ): Message[] { try { - const result = getWasm().import_and_deduplicate_messages(spans); - // Convert any Map objects to plain objects - return convertMapsToObjects(result) as Message[]; + return getWasm().import_and_deduplicate_messages(spans) as Message[]; } catch (error: unknown) { throw new ConversionError( "Failed to import and deduplicate messages from spans", diff --git a/bindings/typescript/tests/node-exports.test.ts b/bindings/typescript/tests/node-exports.test.ts index ae63879f..6c76c0f6 100644 --- a/bindings/typescript/tests/node-exports.test.ts +++ b/bindings/typescript/tests/node-exports.test.ts @@ -179,6 +179,27 @@ describe("Node.js exports", () => { expect(value).toBe(9007199254740993n); }); + test("should deduplicate messages across spans and return plain-object results", async () => { + const { importAndDeduplicateMessages } = await import("../src/index"); + + const sharedTurn = [ + { role: "user", content: "what is 2+2?" }, + { role: "assistant", content: "4" }, + ]; + + const messages = importAndDeduplicateMessages([ + { input: sharedTurn }, + { input: sharedTurn, output: { role: "assistant", content: "4" } }, + { input: [{ role: "user", content: "and 3+3?" }] }, + ]); + + expect(messages).toEqual([ + { role: "user", content: "what is 2+2?" }, + { role: "assistant", id: null, content: "4" }, + { role: "user", content: "and 3+3?" }, + ]); + }); + test("should NOT export browser-specific init function", async () => { const exports = await import("../src/index"); diff --git a/crates/lingua/src/processing/import.rs b/crates/lingua/src/processing/import.rs index 0921e85a..92330596 100644 --- a/crates/lingua/src/processing/import.rs +++ b/crates/lingua/src/processing/import.rs @@ -606,17 +606,19 @@ fn try_choices_array_parsing(data: &Value) -> Option> { pub fn import_messages_from_spans(spans: Vec) -> Vec { let mut messages = Vec::new(); - for span in spans { + for mut span in spans { let mut span_messages = Vec::new(); - // Try to extract messages from input - if let Some(Value::String(input_text)) = &span.input { - span_messages.push(Message::User { - content: UserContent::String(input_text.clone()), - }); - } else if let Some(input) = &span.input { - let input_messages = try_converting_to_messages(input); - span_messages.extend(input_messages); + match span.input.take() { + Some(Value::String(input_text)) => { + span_messages.push(Message::User { + content: UserContent::String(input_text), + }); + } + Some(input) => { + span_messages.extend(try_converting_to_messages(&input)); + } + None => {} } #[cfg(feature = "openai")] @@ -633,17 +635,17 @@ pub fn import_messages_from_spans(spans: Vec) -> Vec { messages.extend(span_messages); - // Try to extract messages from output - if let Some(Value::String(output_text)) = &span.output { - if !output_text.is_empty() { + match span.output.take() { + Some(Value::String(output_text)) if !output_text.is_empty() => { messages.push(Message::Assistant { - content: AssistantContent::String(output_text.clone()), + content: AssistantContent::String(output_text), id: None, }); } - } else if let Some(output) = &span.output { - let output_messages = try_converting_to_messages(output); - messages.extend(output_messages); + Some(output) => { + messages.extend(try_converting_to_messages(&output)); + } + None => {} } }