diff --git a/Cargo.lock b/Cargo.lock index ab4b2d3..6bf145c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5168,6 +5168,7 @@ dependencies = [ "aws-smithy-types", "backon", "base64", + "futures", "lazy_static", "mime", "reqwest", diff --git a/crates/umem_ai/Cargo.toml b/crates/umem_ai/Cargo.toml index 1392093..b877ef1 100644 --- a/crates/umem_ai/Cargo.toml +++ b/crates/umem_ai/Cargo.toml @@ -25,3 +25,4 @@ aws-sdk-bedrockruntime = "1.120.0" aws-smithy-types = {version = "1.3.5", features = ["serde-deserialize", "serde-serialize", "rt-tokio"]} serde-saphyr = "0.0.14" aws-sdk-bedrockagentruntime = "1.119.0" +futures.workspace = true diff --git a/crates/umem_ai/src/lib.rs b/crates/umem_ai/src/lib.rs index 1393ead..2e9c6b3 100644 --- a/crates/umem_ai/src/lib.rs +++ b/crates/umem_ai/src/lib.rs @@ -16,6 +16,8 @@ use thiserror::Error; use tokio::sync::OnceCell; use umem_config::CONFIG; +use crate::response_generators::embed::{EmbeddingRequest, EmbeddingResponse}; + pub type HashMap = rustc_hash::FxHashMap; lazy_static! { @@ -101,6 +103,12 @@ impl RerankingModel { } } +#[derive(Debug, Clone)] +pub struct EmbeddingModel { + pub provider: Arc, + pub model_name: String, +} + #[derive(Debug)] pub enum AIProvider { OpenAI(OpenAIProvider), @@ -164,6 +172,17 @@ impl AIProvider { _ => unimplemented!(), } } + + pub(crate) async fn do_embed( + &self, + request: EmbeddingRequest, + ) -> Result { + match self { + AIProvider::AmazonBedrock(provider) => provider.embed(request), + _ => unimplemented!(), + } + .await + } } #[async_trait] @@ -200,6 +219,14 @@ pub trait ReranksStructuredData { T: Serialize + Clone + Send + Sync; } +#[async_trait] +pub trait Embeds { + async fn embed( + &self, + request: EmbeddingRequest, + ) -> Result; +} + impl From for AIProvider { fn from(config: OpenAIProvider) -> Self { AIProvider::OpenAI(config) diff --git a/crates/umem_ai/src/providers/amazon_bedrock.rs b/crates/umem_ai/src/providers/amazon_bedrock.rs index 6918c8a..ea8d46a 100644 --- a/crates/umem_ai/src/providers/amazon_bedrock.rs +++ b/crates/umem_ai/src/providers/amazon_bedrock.rs @@ -1,7 +1,8 @@ use crate::{ - GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText, OpenAIProvider, - Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData, SerializationMode, - StructuredRanking, StructuredRerankRequest, StructuredRerankResponse, + Embeds, GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText, + OpenAIProvider, Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData, + SerializationMode, StructuredRanking, StructuredRerankRequest, StructuredRerankResponse, + embed::{EmbeddingRequest, EmbeddingResponse, embed as embedFn}, messages::{FilePart, UserModelMessage}, response_generators::{ self, GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError, @@ -17,19 +18,21 @@ use aws_sdk_bedrockagentruntime::types::{ RerankTextDocument, RerankingConfiguration, RerankingConfigurationType, }; use aws_sdk_bedrockruntime::{ - error::BuildError, - operation::converse::builders::ConverseFluentBuilder, + error::{BuildError, ProvideErrorMetadata}, + operation::{converse::builders::ConverseFluentBuilder, invoke_model::InvokeModelOutput}, types::{ AnyToolChoice, ContentBlock, ConverseOutput, ImageBlock, InferenceConfiguration, Message, Tool, ToolChoice, ToolConfiguration, ToolInputSchema, ToolSpecification, }, }; use base64::Engine; +use futures::future::join_all; use schemars::JsonSchema; -use serde::{Serialize, de::DeserializeOwned}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::Map; -use std::sync::Arc; +use std::{error::Error, sync::Arc}; use thiserror::Error; +use tokio::sync::Semaphore; #[derive(Clone, Debug)] pub struct AmazonBedrockProvider { @@ -38,6 +41,20 @@ pub struct AmazonBedrockProvider { bedrockagentruntime_client: Arc, } +impl AmazonBedrockProvider { + async fn default() -> Self { + Self::builder() + .region(std::env::var("AWS_REGION").expect("AWS_REGION not set")) + .access_key_id(std::env::var("AWS_ACCESS_KEY_ID").expect("AWS_ACCESS_KEY_ID not set")) + .secret_access_key( + std::env::var("AWS_SECRET_ACCESS_KEY").expect("AWS_SECRET_ACCESS_KEY not set"), + ) + .build() + .await + .expect("Failed to build AmazonBedrockProvider based on default environment variables") + } +} + #[async_trait] impl GeneratesText for AmazonBedrockProvider { async fn generate_text( @@ -72,7 +89,7 @@ impl GeneratesText for AmazonBedrockProvider { .await .map_err(|e| { tracing::error!("{}", e); - ResponseGeneratorError::BedrockConverseError(format!("{:?}", e)) + ResponseGeneratorError::BedrockConverseError(format!("{:?}", e.meta())) })?; let converse_output = match converse_response.output { @@ -137,7 +154,7 @@ impl GeneratesObject for AmazonBedrockProvider { )) .send() .await - .map_err(|e| ResponseGeneratorError::BedrockConverseError(e.to_string()))?; + .map_err(|e| ResponseGeneratorError::BedrockConverseError(e.meta().to_string()))?; let converse_output = match converse_response.output { Some(output) => output, @@ -247,7 +264,7 @@ impl Reranks for AmazonBedrockProvider { ) .send() .await - .map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.to_string()))?; + .map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.meta().to_string()))?; let results = response.results(); @@ -406,7 +423,7 @@ impl ReranksStructuredData for AmazonBedrockProvider { ) .send() .await - .map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.to_string()))?; + .map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.meta().to_string()))?; let results = response.results(); @@ -439,6 +456,95 @@ impl ReranksStructuredData for AmazonBedrockProvider { } } +#[async_trait] +impl Embeds for AmazonBedrockProvider { + async fn embed( + &self, + request: EmbeddingRequest, + ) -> Result { + if request.input.is_empty() { + return Err(ResponseGeneratorError::InvalidArgumentsProvided( + "Embedding input cannot be empty".to_string(), + )); + } + + let semaphore = Arc::new(Semaphore::new(request.max_parallels)); + let mut embedding_invoke_handles = Vec::with_capacity(request.input.len()); + + for data in request.input.into_iter() { + let permit = semaphore.clone().acquire_owned().await.map_err(|e| { + ResponseGeneratorError::InternalServerError(format!( + "Failed to acquire thread lock while making multiple requests. Details: {e}" + )) + })?; + + let bedrockruntime_client = Arc::clone(&self.bedrockruntime_client); + let model_name = request.model.model_name.clone(); + + let handle = tokio::spawn(async move { + let invoke_res = bedrockruntime_client + .invoke_model() + .model_id(&model_name) + .body(aws_smithy_types::Blob::new( + serde_json::json!({"inputText":data, "dimensions": request.dimensions, "normalize": request.normalize}).to_string(), + )) + .send() + .await + .map_err(|e| { + ResponseGeneratorError::BedrockInvokeError(format!( + "Failed to invoke Bedrock embedding model: {}", + e.meta() + )) + }); + drop(permit); + invoke_res + }); + + embedding_invoke_handles.push(handle) + } + + let results: Result, ResponseGeneratorError> = + join_all(embedding_invoke_handles) + .await + .into_iter() + .map(|r| { + r.map_err(|e| { + ResponseGeneratorError::InternalServerError(format!( + "Failed to acquire result from the relevant thread. Details: {e}" + )) + }) + }) + .map(|r| r.and_then(|inner| inner)) + .collect(); + + let embeddings = results? + .into_iter() + .map(|r| { + serde_json::from_slice::( + &r.body.into_inner(), + ) + }) + .collect::, _>>() + .map_err(|e| { + ResponseGeneratorError::Deserialization( + e, + "Failed to deserialize Bedrock embedding response".to_string(), + ) + })?; + + Ok(EmbeddingResponse { + embeddings: embeddings.into_iter().map(|e| e.embedding).collect(), + }) + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct AmazonBedrockEmbeddingInvokeModelResponse { + embedding: Vec, + input_text_token_count: usize, +} + impl AmazonBedrockProvider { fn normalize_generate_object_request< T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned, @@ -663,9 +769,9 @@ impl AmazonBedrockProviderBuilder { mod tests { use super::*; use crate::{ - AIProvider, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, LanguageModel, - RerankingModel, SerializationFormat, generate_object, generate_text, rerank, - structured_rerank, + AIProvider, EmbeddingModel, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, + LanguageModel, RerankingModel, SerializationFormat, embed, generate_object, generate_text, + rerank, structured_rerank, }; use serde::Deserialize; use std::sync::Arc; @@ -889,4 +995,34 @@ mod tests { let rerank_response = structured_rerank(request).await.unwrap(); dbg!(&rerank_response); } + + #[tokio::test] + async fn embedding_test() { + let provider = Arc::new(AIProvider::from( + AmazonBedrockProviderBuilder::default() + .region("REGION") + .access_key_id("ACESS_KEY_ID") + .secret_access_key("SECRET_ACCESS_KEY") + .build() + .await + .unwrap(), + )); + + let model = Arc::new(EmbeddingModel { + provider, + model_name: "amazon.titan-embed-text-v2:0".to_string(), + }); + + let request = EmbeddingRequest::builder() + .model(model) + .input(vec![ + "The quick brown fox jumps over the lazy dog.".to_string(), + "To be or not to be, that is the question.".to_string(), + "All that glitters is not gold.".to_string(), + ]) + .build(); + + let embedding_response = embedFn(request).await.unwrap(); + dbg!(&embedding_response); + } } diff --git a/crates/umem_ai/src/providers/cohere.rs b/crates/umem_ai/src/providers/cohere.rs index bc96b80..5539362 100644 --- a/crates/umem_ai/src/providers/cohere.rs +++ b/crates/umem_ai/src/providers/cohere.rs @@ -36,6 +36,18 @@ pub struct CohereProvider { headers: HeaderMap, } +impl Default for CohereProvider { + fn default() -> Self { + Self { + base_url: "https://api.cohere.com/v2".into(), + api_key: std::env::var("COHERE_API_KEY").expect( + "COHERE_API_KEY must be set to get a default implementation of cohere provider", + ), + headers: HeaderMap::new(), + } + } +} + #[async_trait] impl Reranks for CohereProvider { async fn rerank( diff --git a/crates/umem_ai/src/providers/openai.rs b/crates/umem_ai/src/providers/openai.rs index b23497d..dcd7d9a 100644 --- a/crates/umem_ai/src/providers/openai.rs +++ b/crates/umem_ai/src/providers/openai.rs @@ -35,6 +35,18 @@ pub struct OpenAIProvider { pub project: Option, } +impl Default for OpenAIProvider { + fn default() -> Self { + Self { + api_key: std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set"), + base_url: "https://api.openai.com/v1".into(), + default_headers: HeaderMap::new(), + organization: None, + project: None, + } + } +} + impl OpenAIProvider { pub fn normalize_generate_object_request< T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned, diff --git a/crates/umem_ai/src/response_generators/embed.rs b/crates/umem_ai/src/response_generators/embed.rs new file mode 100644 index 0000000..872e7e8 --- /dev/null +++ b/crates/umem_ai/src/response_generators/embed.rs @@ -0,0 +1,76 @@ +use crate::{ + EmbeddingModel, ResponseGeneratorError, + utils::{self, is_retryable_error}, +}; +use backon::{ExponentialBuilder, Retryable}; +use reqwest::header::HeaderMap; +use std::{sync::Arc, time::Duration}; + +pub async fn embed( + mut request: EmbeddingRequest, +) -> Result { + let per_request_timeout = request.timeout; + let max_retries = request.max_retries; + let total_delay = per_request_timeout.mul_f32(max_retries as f32 / 2.0); + + let generation = || { + let model = Arc::clone(&request.model); + let provider = Arc::clone(&model.provider); + let request = request.clone(); + + async move { + tokio::time::timeout(per_request_timeout, provider.do_embed(request)) + .await + .map_err(ResponseGeneratorError::TimeoutError) + .flatten() + } + }; + + generation + .retry( + ExponentialBuilder::default() + .with_max_times(max_retries) + .with_total_delay(Some(total_delay)), + ) + .sleep(tokio::time::sleep) + .when(is_retryable_error) + .notify(|err, dur| { + tracing::debug!("retrying {:?} after {:?}", err, dur); + }) + .await +} + +#[derive(Clone, Debug, typed_builder::TypedBuilder)] +pub struct EmbeddingRequest { + pub model: Arc, + + #[builder(setter(transform = |value: impl IntoIterator| + value.into_iter().collect()) + )] + pub input: Vec, + + #[builder(default = 3_usize)] + pub max_retries: usize, + + #[builder(default = 1000_usize)] + pub max_parallels: usize, + + #[builder(default, setter(transform = |value: Vec<(String, String)>| + utils::build_header_map(value.as_slice()).unwrap_or_default() + ))] + pub custom_headers: HeaderMap, + + #[builder(default = Duration::from_secs(60))] + pub timeout: Duration, + + #[builder(default = 1024)] + pub dimensions: usize, + + #[builder(default = true)] + pub normalize: bool, +} + +#[derive(Debug, Clone)] +pub struct EmbeddingResponse { + pub embeddings: Vec>, +} diff --git a/crates/umem_ai/src/response_generators/mod.rs b/crates/umem_ai/src/response_generators/mod.rs index b967ef0..a7c4b9a 100644 --- a/crates/umem_ai/src/response_generators/mod.rs +++ b/crates/umem_ai/src/response_generators/mod.rs @@ -1,3 +1,4 @@ +pub mod embed; pub mod generate_object; pub mod generate_text; pub mod messages; @@ -9,8 +10,8 @@ pub use generate_text::*; pub use messages::*; pub use rerank::*; pub use structured_rerank::*; - use thiserror::Error; + #[derive(Error, Debug)] pub enum ResponseGeneratorError { #[error(transparent)] @@ -23,12 +24,16 @@ pub enum ResponseGeneratorError { BedrockConverseError(String), #[error("BedrockAgentRuntime Rerank Command error, Details: {0}")] BedrockAgentRerankCommandSendError(String), + #[error("BedrockRuntime Invoke API error, Details: {0}")] + BedrockInvokeError(String), #[error("empty response from AI provider")] EmptyProviderResponse, #[error("invalid response from AI provider, Details: {0}")] InvalidProviderResponse(String), #[error("invalid arguments provided: {0}")] InvalidArgumentsProvided(String), + #[error("Internal Server Error: {0}")] + InternalServerError(String), #[error(transparent)] Transient(#[from] anyhow::Error), #[error("yaml serialization error: {0}")] diff --git a/crates/umem_ai/src/utils.rs b/crates/umem_ai/src/utils.rs index 8edd6d1..a71ad91 100644 --- a/crates/umem_ai/src/utils.rs +++ b/crates/umem_ai/src/utils.rs @@ -64,6 +64,14 @@ pub fn is_retryable_error(e: &ResponseGeneratorError) -> bool { } ResponseGeneratorError::BedrockAgentRerankCommandSendError(e) => { tracing::error!("Bedrock agent rerank command send error: {}", e); + false + } + ResponseGeneratorError::BedrockInvokeError(e) => { + tracing::error!("Bedrock agent embed invoke command error: {}", e); + false + } + ResponseGeneratorError::InternalServerError(e) => { + tracing::error!("Internal Server Error: {}", e); true } }