Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions src/blocking/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,40 +360,38 @@ impl std::fmt::Debug for BlockingUploadFileBuilder<'_> {
}

impl BlockingUploadFileBuilder<'_> {
/// Set the peer that owns the uploaded file (required).
#[must_use]
pub fn peer(self, id: impl Into<String>) -> Self {
fn with_inner(
self,
f: impl FnOnce(crate::UploadFileBuilder<'_>) -> crate::UploadFileBuilder<'_>,
) -> Self {
Self {
inner: self.inner.peer(id),
inner: f(self.inner),
reader_handle: self.reader_handle,
}
}

/// Set the peer that owns the uploaded file (required).
#[must_use]
pub fn peer(self, id: impl Into<String>) -> Self {
self.with_inner(|b| b.peer(id))
}

/// Attach arbitrary JSON metadata to the created message(s).
#[must_use]
pub fn metadata(self, value: Value) -> Self {
Self {
inner: self.inner.metadata(value),
reader_handle: self.reader_handle,
}
self.with_inner(|b| b.metadata(value))
}

/// Attach configuration to the created message(s).
#[must_use]
pub fn configuration(self, value: Value) -> Self {
Self {
inner: self.inner.configuration(value),
reader_handle: self.reader_handle,
}
self.with_inner(|b| b.configuration(value))
}

/// Override the creation timestamp (ISO 3339).
#[must_use]
pub fn created_at(self, dt: DateTime<Utc>) -> Self {
Self {
inner: self.inner.created_at(dt),
reader_handle: self.reader_handle,
}
self.with_inner(|b| b.created_at(dt))
}

/// Send the upload request and return the created messages.
Expand Down
51 changes: 29 additions & 22 deletions src/conclusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::http::routes;
use crate::types::conclusion::Conclusion as ConclusionData;
use crate::types::conclusion::ConclusionPage;
use crate::types::conclusion::{ConclusionBatchCreate, ConclusionCreate};
use crate::types::conclusion::{ConclusionFilters, ConclusionGet, ConclusionQuery};
use crate::types::dialectic::RepresentationResponse;
use crate::types::pagination::paginate_post;

Expand Down Expand Up @@ -690,14 +691,19 @@ impl ListConclusionsBuilder {
/// # }
/// ```
pub async fn send(self) -> Result<ConclusionPage> {
let mut filters = serde_json::json!({
"observer_id": self.scope.inner.observer,
"observed_id": self.scope.inner.observed,
});
if let Some(ref sid) = self.session_id {
filters["session_id"] = serde_json::Value::String(sid.clone());
}
let body = serde_json::json!({"filters": filters});
let body = ConclusionGet::builder()
.filters(
ConclusionFilters::builder()
.observer_id(self.scope.inner.observer.clone())
.observed_id(self.scope.inner.observed.clone())
.maybe_session_id(self.session_id)
.build(),
)
.build();
let body = serde_json::to_value(&body).map_err(|e| HonchoError::Decode {
path: "ConclusionGet".to_owned(),
source: e,
})?;
let route = routes::conclusions_list(&self.scope.inner.workspace_id)?;
paginate_post(
&self.scope.inner.http,
Expand Down Expand Up @@ -781,18 +787,21 @@ impl QueryConclusionsBuilder {
"distance must be between 0.0 and 1.0, got {d}"
)));
}
let filters = serde_json::json!({
"observer_id": self.scope.inner.observer,
"observed_id": self.scope.inner.observed,
});
let mut body = serde_json::json!({
"query": self.query,
"top_k": self.top_k,
"filters": filters,
});
if let Some(d) = self.distance {
body["distance"] = serde_json::Value::from(d);
}
let body = ConclusionQuery::builder()
.query(self.query)
.top_k(self.top_k)
.maybe_distance(self.distance)
.filters(
ConclusionFilters::builder()
.observer_id(self.scope.inner.observer.clone())
.observed_id(self.scope.inner.observed.clone())
.build(),
)
.build();
let body = serde_json::to_value(&body).map_err(|e| HonchoError::Decode {
path: "ConclusionQuery".to_owned(),
source: e,
})?;
let route = routes::conclusions_query(&self.scope.inner.workspace_id)?;
let data: Vec<ConclusionData> =
self.scope.inner.http.post(&route, Some(&body), &[]).await?;
Expand Down Expand Up @@ -1249,7 +1258,6 @@ mod tests {

let expected_body = serde_json::json!({
"query": "preferences",
"top_k": 10,
"filters": {
"observer_id": "alice",
"observed_id": "bob",
Expand Down Expand Up @@ -1403,7 +1411,6 @@ mod tests {
// Step 3: Query
let query_body = serde_json::json!({
"query": "preferences",
"top_k": 10,
"filters": {
"observer_id": "alice",
"observed_id": "bob",
Expand Down
186 changes: 78 additions & 108 deletions src/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,25 +150,7 @@ impl Peer {
/// # }
/// ```
pub async fn refresh(&self) -> Result<()> {
let mut body_map = serde_json::Map::new();
body_map.insert("id".into(), Value::String(self.inner.id.clone()));
let body = Value::Object(body_map);
let resp: PeerResponse = self
.inner
.http
.post(&routes::peers(&self.inner.workspace_id)?, Some(&body), &[])
.await?;
*self
.inner
.metadata
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner) = Some(resp.metadata);
*self
.inner
.configuration
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner) =
map_to_peer_config(&resp.configuration)?;
self.fetch_and_update_cache().await?;
Ok(())
}

Expand Down Expand Up @@ -295,6 +277,11 @@ impl Peer {
/// [`PeerConfig`].
#[cfg_attr(feature = "tracing", tracing::instrument(skip(self), fields(peer_id = self.inner.id.as_str())))]
pub async fn get_configuration_raw(&self) -> Result<HashMap<String, Value>> {
let resp = self.fetch_and_update_cache().await?;
Ok(resp.configuration)
}

async fn fetch_and_update_cache(&self) -> Result<PeerResponse> {
let mut body_map = serde_json::Map::new();
body_map.insert("id".into(), Value::String(self.inner.id.clone()));
let body = Value::Object(body_map);
Expand All @@ -314,7 +301,7 @@ impl Peer {
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner) =
map_to_peer_config(&resp.configuration)?;
Ok(resp.configuration)
Ok(resp)
}

/// Set the peer's configuration on the server from a raw JSON map.
Expand Down Expand Up @@ -607,31 +594,26 @@ impl Peer {
&self,
options: &crate::types::peer::PeerContextOptions,
) -> Result<PeerContext> {
let route = routes::peer_context(&self.inner.workspace_id, &self.inner.id)?;
let mut params: Vec<(&str, String)> = Vec::new();
let mut builder = self.context_builder();
if let Some(ref v) = options.target {
params.push(("target", v.clone()));
builder = builder.target(v.clone());
}
if let Some(ref v) = options.search_query {
params.push(("search_query", v.clone()));
builder = builder.search_query(v.clone());
}
if let Some(ref v) = options.search_top_k {
params.push(("search_top_k", v.to_string()));
if let Some(v) = options.search_top_k {
builder = builder.search_top_k(v);
}
if let Some(ref v) = options.search_max_distance {
params.push(("search_max_distance", v.to_string()));
if let Some(v) = options.search_max_distance {
builder = builder.search_max_distance(v);
}
if let Some(ref v) = options.include_most_frequent {
params.push((
"include_most_frequent",
if *v { "true" } else { "false" }.to_string(),
));
if let Some(v) = options.include_most_frequent {
builder = builder.include_most_frequent(v);
}
if let Some(ref v) = options.max_conclusions {
params.push(("max_conclusions", v.to_string()));
if let Some(v) = options.max_conclusions {
builder = builder.max_conclusions(v);
}
let refs: Vec<(&str, &str)> = params.iter().map(|(k, v)| (*k, v.as_str())).collect();
self.inner.http.get(&route, &refs).await
builder.send().await
}

// ── Sessions ───────────────────────────────────────────────────────
Expand Down Expand Up @@ -1195,27 +1177,11 @@ impl RepresentationBuilder {
/// or `max_conclusions` are out of range.
#[cfg_attr(feature = "tracing", tracing::instrument(skip(self), fields(peer_id = self.peer_id.as_str())))]
pub async fn send(self) -> Result<String> {
if let Some(k) = self.search_top_k
&& !(1..=100).contains(&k)
{
return Err(HonchoError::Validation(format!(
"search_top_k must be between 1 and 100, got {k}"
)));
}
if let Some(d) = self.search_max_distance
&& !(0.0..=1.0).contains(&d)
{
return Err(HonchoError::Validation(format!(
"search_max_distance must be between 0.0 and 1.0, got {d}"
)));
}
if let Some(c) = self.max_conclusions
&& !(1..=100).contains(&c)
{
return Err(HonchoError::Validation(format!(
"max_conclusions must be between 1 and 100, got {c}"
)));
}
validate_search_params(
self.search_top_k,
self.search_max_distance,
self.max_conclusions,
)?;

let params = crate::types::peer::PeerRepresentationGet {
session_id: self.session_id,
Expand Down Expand Up @@ -1332,60 +1298,34 @@ impl ContextBuilder {
/// or `max_conclusions` are out of range.
#[cfg_attr(feature = "tracing", tracing::instrument(skip(self), fields(peer_id = self.peer_id.as_str())))]
pub async fn send(self) -> Result<PeerContext> {
if let Some(k) = self.search_top_k
&& !(1..=100).contains(&k)
{
return Err(HonchoError::Validation(format!(
"search_top_k must be between 1 and 100, got {k}"
)));
}
if let Some(d) = self.search_max_distance
&& !(0.0..=1.0).contains(&d)
{
return Err(HonchoError::Validation(format!(
"search_max_distance must be between 0.0 and 1.0, got {d}"
)));
}
if let Some(c) = self.max_conclusions
&& !(1..=100).contains(&c)
{
return Err(HonchoError::Validation(format!(
"max_conclusions must be between 1 and 100, got {c}"
)));
validate_search_params(
self.search_top_k,
self.search_max_distance,
self.max_conclusions,
)?;

macro_rules! push_param {
($params:expr, $key:expr, $val:expr) => {
if let Some(v) = $val {
$params.push(($key, v.to_string()));
}
};
}

let route = routes::peer_context(&self.workspace_id, &self.peer_id)?;
let mut params: Vec<(&str, String)> = Vec::new();
if let Some(ref v) = self.target {
params.push(("target", v.clone()));
}
if let Some(v) = self.summary {
params.push(("summary", if v { "true" } else { "false" }.to_string()));
if let Some(v) = self.target {
params.push(("target", v));
}
if let Some(v) = self.limit_to_session {
params.push((
"limit_to_session",
if v { "true" } else { "false" }.to_string(),
));
}
if let Some(ref v) = self.search_query {
params.push(("search_query", v.clone()));
}
if let Some(v) = self.search_top_k {
params.push(("search_top_k", v.to_string()));
}
if let Some(v) = self.search_max_distance {
params.push(("search_max_distance", v.to_string()));
}
if let Some(v) = self.include_most_frequent {
params.push((
"include_most_frequent",
if v { "true" } else { "false" }.to_string(),
));
}
if let Some(v) = self.max_conclusions {
params.push(("max_conclusions", v.to_string()));
push_param!(params, "summary", self.summary);
push_param!(params, "limit_to_session", self.limit_to_session);
if let Some(v) = self.search_query {
params.push(("search_query", v));
}
push_param!(params, "search_top_k", self.search_top_k);
push_param!(params, "search_max_distance", self.search_max_distance);
push_param!(params, "include_most_frequent", self.include_most_frequent);
push_param!(params, "max_conclusions", self.max_conclusions);
let refs: Vec<(&str, &str)> = params.iter().map(|(k, v)| (*k, v.as_str())).collect();
self.http.get(&route, &refs).await
}
Expand Down Expand Up @@ -1489,9 +1429,39 @@ impl MessageBuilder {
}
}

fn validate_search_params(
search_top_k: Option<u32>,
search_max_distance: Option<f64>,
max_conclusions: Option<u32>,
) -> Result<()> {
if let Some(k) = search_top_k
&& !(1..=100).contains(&k)
{
return Err(HonchoError::Validation(format!(
"search_top_k must be between 1 and 100, got {k}"
)));
}
if let Some(d) = search_max_distance
&& !(0.0..=1.0).contains(&d)
{
return Err(HonchoError::Validation(format!(
"search_max_distance must be between 0.0 and 1.0, got {d}"
)));
}
if let Some(c) = max_conclusions
&& !(1..=100).contains(&c)
{
return Err(HonchoError::Validation(format!(
"max_conclusions must be between 1 and 100, got {c}"
)));
}
Ok(())
}

fn map_to_peer_config(map: &HashMap<String, Value>) -> Result<Option<PeerConfig>> {
let val = serde_json::to_value(map).map_err(|e| HonchoError::Configuration(e.to_string()))?;
serde_json::from_value(val)
let obj: serde_json::Map<String, Value> =
map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
serde_json::from_value(Value::Object(obj))
.map(Some)
.map_err(|e| HonchoError::Configuration(e.to_string()))
}
Expand Down
Loading
Loading