Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/umem_ai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ aws-sdk-bedrockruntime = "1.120.0"
aws-smithy-types = {version = "1.3.5", features = ["serde-deserialize", "serde-serialize", "rt-tokio"]}
serde-saphyr = "0.0.14"
aws-sdk-bedrockagentruntime = "1.119.0"
futures.workspace = true
27 changes: 27 additions & 0 deletions crates/umem_ai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use thiserror::Error;
use tokio::sync::OnceCell;
use umem_config::CONFIG;

use crate::response_generators::embed::{EmbeddingRequest, EmbeddingResponse};

pub type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;

lazy_static! {
Expand Down Expand Up @@ -101,6 +103,12 @@ impl RerankingModel {
}
}

/// Pairs an AI provider with the name of the embedding model to use on it.
#[derive(Debug, Clone)]
pub struct EmbeddingModel {
/// Shared handle to the provider that serves this model.
pub provider: Arc<AIProvider>,
/// Name of the model as understood by the provider.
pub model_name: String,
}

#[derive(Debug)]
pub enum AIProvider {
OpenAI(OpenAIProvider),
Expand Down Expand Up @@ -164,6 +172,17 @@ impl AIProvider {
_ => unimplemented!(),
}
}

/// Forwards an embedding request to the provider wrapped by this variant.
///
/// Only `AIProvider::AmazonBedrock` is handled here.
///
/// # Panics
///
/// Panics via `unimplemented!()` for every other provider variant.
pub(crate) async fn do_embed(
    &self,
    request: EmbeddingRequest,
) -> Result<EmbeddingResponse, ResponseGeneratorError> {
    match self {
        AIProvider::AmazonBedrock(bedrock) => bedrock.embed(request).await,
        _ => unimplemented!(),
    }
}
}

#[async_trait]
Expand Down Expand Up @@ -200,6 +219,14 @@ pub trait ReranksStructuredData {
T: Serialize + Clone + Send + Sync;
}

/// Implemented by providers that can serve an [`EmbeddingRequest`].
#[async_trait]
pub trait Embeds {
/// Runs `request` against this provider.
///
/// Returns the resulting [`EmbeddingResponse`], or a
/// [`ResponseGeneratorError`] if the request could not be fulfilled.
async fn embed(
&self,
request: EmbeddingRequest,
) -> Result<EmbeddingResponse, ResponseGeneratorError>;
}

impl From<OpenAIProvider> for AIProvider {
fn from(config: OpenAIProvider) -> Self {
AIProvider::OpenAI(config)
Expand Down
164 changes: 150 additions & 14 deletions crates/umem_ai/src/providers/amazon_bedrock.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{
GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText, OpenAIProvider,
Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData, SerializationMode,
StructuredRanking, StructuredRerankRequest, StructuredRerankResponse,
Embeds, GenerateObjectRequest, GenerateObjectResponse, GeneratesObject, GeneratesText,
OpenAIProvider, Ranking, RerankRequest, RerankResponse, Reranks, ReranksStructuredData,
SerializationMode, StructuredRanking, StructuredRerankRequest, StructuredRerankResponse,
embed::{EmbeddingRequest, EmbeddingResponse, embed as embedFn},
messages::{FilePart, UserModelMessage},
response_generators::{
self, GenerateTextRequest, GenerateTextResponse, ResponseGeneratorError,
Expand All @@ -17,19 +18,21 @@ use aws_sdk_bedrockagentruntime::types::{
RerankTextDocument, RerankingConfiguration, RerankingConfigurationType,
};
use aws_sdk_bedrockruntime::{
error::BuildError,
operation::converse::builders::ConverseFluentBuilder,
error::{BuildError, ProvideErrorMetadata},
operation::{converse::builders::ConverseFluentBuilder, invoke_model::InvokeModelOutput},
types::{
AnyToolChoice, ContentBlock, ConverseOutput, ImageBlock, InferenceConfiguration, Message,
Tool, ToolChoice, ToolConfiguration, ToolInputSchema, ToolSpecification,
},
};
use base64::Engine;
use futures::future::join_all;
use schemars::JsonSchema;
use serde::{Serialize, de::DeserializeOwned};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::Map;
use std::sync::Arc;
use std::{error::Error, sync::Arc};
use thiserror::Error;
use tokio::sync::Semaphore;

#[derive(Clone, Debug)]
pub struct AmazonBedrockProvider {
Expand All @@ -38,6 +41,20 @@ pub struct AmazonBedrockProvider {
bedrockagentruntime_client: Arc<aws_sdk_bedrockagentruntime::Client>,
}

impl AmazonBedrockProvider {
    /// Builds a provider from the `AWS_REGION`, `AWS_ACCESS_KEY_ID` and
    /// `AWS_SECRET_ACCESS_KEY` environment variables.
    ///
    /// # Panics
    ///
    /// Panics if any of the three variables is missing or unreadable, or if
    /// the builder fails to produce a provider.
    async fn default() -> Self {
        let region = std::env::var("AWS_REGION").expect("AWS_REGION not set");
        let access_key_id =
            std::env::var("AWS_ACCESS_KEY_ID").expect("AWS_ACCESS_KEY_ID not set");
        let secret_access_key =
            std::env::var("AWS_SECRET_ACCESS_KEY").expect("AWS_SECRET_ACCESS_KEY not set");

        let built = Self::builder()
            .region(region)
            .access_key_id(access_key_id)
            .secret_access_key(secret_access_key)
            .build()
            .await;

        built.expect("Failed to build AmazonBedrockProvider based on default environment variables")
    }
}

#[async_trait]
impl GeneratesText for AmazonBedrockProvider {
async fn generate_text(
Expand Down Expand Up @@ -72,7 +89,7 @@ impl GeneratesText for AmazonBedrockProvider {
.await
.map_err(|e| {
tracing::error!("{}", e);
ResponseGeneratorError::BedrockConverseError(format!("{:?}", e))
ResponseGeneratorError::BedrockConverseError(format!("{:?}", e.meta()))
})?;

let converse_output = match converse_response.output {
Expand Down Expand Up @@ -137,7 +154,7 @@ impl GeneratesObject for AmazonBedrockProvider {
))
.send()
.await
.map_err(|e| ResponseGeneratorError::BedrockConverseError(e.to_string()))?;
.map_err(|e| ResponseGeneratorError::BedrockConverseError(e.meta().to_string()))?;

let converse_output = match converse_response.output {
Some(output) => output,
Expand Down Expand Up @@ -247,7 +264,7 @@ impl Reranks for AmazonBedrockProvider {
)
.send()
.await
.map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.to_string()))?;
.map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.meta().to_string()))?;

let results = response.results();

Expand Down Expand Up @@ -406,7 +423,7 @@ impl ReranksStructuredData for AmazonBedrockProvider {
)
.send()
.await
.map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.to_string()))?;
.map_err(|e| ResponseGeneratorError::BedrockAgentRerankCommandSendError(e.meta().to_string()))?;

let results = response.results();

Expand Down Expand Up @@ -439,6 +456,95 @@ impl ReranksStructuredData for AmazonBedrockProvider {
}
}

#[async_trait]
impl Embeds for AmazonBedrockProvider {
    /// Embeds every string in `request.input` by invoking the Bedrock model
    /// named by `request.model.model_name` once per input.
    ///
    /// Invocations run as spawned Tokio tasks. A semaphore sized from
    /// `request.max_parallels` limits how many are in flight at once; a value
    /// of `0` is treated as `1` so the call cannot stall. The returned
    /// embeddings are collected in the order the tasks were spawned, which is
    /// the order of `request.input`.
    ///
    /// # Errors
    ///
    /// * `InvalidArgumentsProvided` if `request.input` is empty.
    /// * `BedrockInvokeError` if any model invocation fails.
    /// * `InternalServerError` if a permit cannot be acquired or a spawned
    ///   task fails to join.
    /// * `Deserialization` if a response body cannot be parsed.
    async fn embed(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingResponse, ResponseGeneratorError> {
        if request.input.is_empty() {
            return Err(ResponseGeneratorError::InvalidArgumentsProvided(
                "Embedding input cannot be empty".to_string(),
            ));
        }

        // A semaphore with zero permits never hands one out, so the first
        // `acquire_owned` below would wait forever. `Semaphore::new` also
        // panics above `MAX_PERMITS`. Clamp to keep both cases well-defined.
        let permits = request.max_parallels.clamp(1, Semaphore::MAX_PERMITS);
        let semaphore = Arc::new(Semaphore::new(permits));
        let mut embedding_invoke_handles = Vec::with_capacity(request.input.len());

        for data in request.input.into_iter() {
            // Acquire before spawning so that at most `permits` tasks exist
            // at any time; the permit travels into the task and is released
            // when the invocation finishes.
            let permit = semaphore.clone().acquire_owned().await.map_err(|e| {
                ResponseGeneratorError::InternalServerError(format!(
                    "Failed to acquire thread lock while making multiple requests. Details: {e}"
                ))
            })?;

            let bedrockruntime_client = Arc::clone(&self.bedrockruntime_client);
            let model_name = request.model.model_name.clone();

            // Serialize the payload here so the task only needs to own the
            // finished body rather than pieces of `request`.
            let body = serde_json::json!({
                "inputText": data,
                "dimensions": request.dimensions,
                "normalize": request.normalize,
            })
            .to_string();

            let handle = tokio::spawn(async move {
                let invoke_res = bedrockruntime_client
                    .invoke_model()
                    .model_id(model_name)
                    .body(aws_smithy_types::Blob::new(body))
                    .send()
                    .await
                    .map_err(|e| {
                        ResponseGeneratorError::BedrockInvokeError(format!(
                            "Failed to invoke Bedrock embedding model: {}",
                            e.meta()
                        ))
                    });
                drop(permit);
                invoke_res
            });

            embedding_invoke_handles.push(handle);
        }

        // Flatten "task failed to join" and "invocation failed" into a single
        // `Result`, stopping at the first error in input order.
        let outputs = join_all(embedding_invoke_handles)
            .await
            .into_iter()
            .map(|joined| {
                joined
                    .map_err(|e| {
                        ResponseGeneratorError::InternalServerError(format!(
                            "Failed to acquire result from the relevant thread. Details: {e}"
                        ))
                    })
                    .and_then(|invoke_res| invoke_res)
            })
            .collect::<Result<Vec<InvokeModelOutput>, ResponseGeneratorError>>()?;

        let embeddings = outputs
            .into_iter()
            .map(|output| {
                serde_json::from_slice::<AmazonBedrockEmbeddingInvokeModelResponse>(
                    &output.body.into_inner(),
                )
            })
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| {
                ResponseGeneratorError::Deserialization(
                    e,
                    "Failed to deserialize Bedrock embedding response".to_string(),
                )
            })?;

        Ok(EmbeddingResponse {
            embeddings: embeddings.into_iter().map(|e| e.embedding).collect(),
        })
    }
}

/// JSON body returned by a Bedrock embedding `invoke_model` call, as parsed
/// by the `Embeds` implementation in this file.
///
/// Field names are mapped to camelCase on the wire, i.e. `embedding` and
/// `inputTextTokenCount`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct AmazonBedrockEmbeddingInvokeModelResponse {
/// The embedding vector for the submitted input text.
embedding: Vec<f32>,
/// Deserialized from `inputTextTokenCount`; presumably the number of tokens
/// in the input text (confirm against the Bedrock model documentation).
input_text_token_count: usize,
}

impl AmazonBedrockProvider {
fn normalize_generate_object_request<
T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned,
Expand Down Expand Up @@ -663,9 +769,9 @@ impl AmazonBedrockProviderBuilder {
mod tests {
use super::*;
use crate::{
AIProvider, GenerateObjectRequestBuilder, GenerateTextRequestBuilder, LanguageModel,
RerankingModel, SerializationFormat, generate_object, generate_text, rerank,
structured_rerank,
AIProvider, EmbeddingModel, GenerateObjectRequestBuilder, GenerateTextRequestBuilder,
LanguageModel, RerankingModel, SerializationFormat, embed, generate_object, generate_text,
rerank, structured_rerank,
};
use serde::Deserialize;
use std::sync::Arc;
Expand Down Expand Up @@ -889,4 +995,34 @@ mod tests {
let rerank_response = structured_rerank(request).await.unwrap();
dbg!(&rerank_response);
}

/// Sends three sentences through the Bedrock embedding model and dumps the
/// response. The region and credential strings are placeholders.
#[tokio::test]
async fn embedding_test() {
    let bedrock = AmazonBedrockProviderBuilder::default()
        .region("REGION")
        .access_key_id("ACESS_KEY_ID")
        .secret_access_key("SECRET_ACCESS_KEY")
        .build()
        .await
        .unwrap();
    let provider = Arc::new(AIProvider::from(bedrock));

    let model = Arc::new(EmbeddingModel {
        provider,
        model_name: "amazon.titan-embed-text-v2:0".to_string(),
    });

    let sentences = vec![
        "The quick brown fox jumps over the lazy dog.".to_string(),
        "To be or not to be, that is the question.".to_string(),
        "All that glitters is not gold.".to_string(),
    ];

    let request = EmbeddingRequest::builder()
        .model(model)
        .input(sentences)
        .build();

    let embedding_response = embedFn(request).await.unwrap();
    dbg!(&embedding_response);
}
}
12 changes: 12 additions & 0 deletions crates/umem_ai/src/providers/cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ pub struct CohereProvider {
headers: HeaderMap,
}

impl Default for CohereProvider {
    /// Creates a provider for `https://api.cohere.com/v2` with no extra
    /// headers, reading the API key from `COHERE_API_KEY`.
    ///
    /// # Panics
    ///
    /// Panics if `COHERE_API_KEY` is missing or unreadable.
    fn default() -> Self {
        let api_key = std::env::var("COHERE_API_KEY").expect(
            "COHERE_API_KEY must be set to get a default implementation of cohere provider",
        );

        Self {
            base_url: "https://api.cohere.com/v2".into(),
            api_key,
            headers: HeaderMap::new(),
        }
    }
}

#[async_trait]
impl Reranks for CohereProvider {
async fn rerank(
Expand Down
12 changes: 12 additions & 0 deletions crates/umem_ai/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ pub struct OpenAIProvider {
pub project: Option<String>,
}

impl Default for OpenAIProvider {
    /// Creates a provider for `https://api.openai.com/v1` with no default
    /// headers, organization or project, reading the API key from
    /// `OPENAI_API_KEY`.
    ///
    /// # Panics
    ///
    /// Panics if `OPENAI_API_KEY` is missing or unreadable.
    fn default() -> Self {
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");

        Self {
            api_key,
            base_url: "https://api.openai.com/v1".into(),
            default_headers: HeaderMap::new(),
            organization: None,
            project: None,
        }
    }
}

impl OpenAIProvider {
pub fn normalize_generate_object_request<
T: Clone + JsonSchema + Serialize + Send + Sync + DeserializeOwned,
Expand Down
Loading