diff --git a/src/query/compound.rs b/src/query/compound.rs index c29e425..11a3681 100644 --- a/src/query/compound.rs +++ b/src/query/compound.rs @@ -20,7 +20,9 @@ use serde::{Serialize, Serializer}; use crate::{json::ShouldSkip, units::OneOrMany}; -use super::{functions::Function, MinimumShouldMatch, Query, ScoreMode}; +use super::{ + functions::FilteredFunction, functions::Function, MinimumShouldMatch, Query, ScoreMode, +}; /// BoostMode #[derive(Debug, Copy, Clone)] @@ -153,7 +155,7 @@ pub struct FunctionScoreQuery { query: Option, #[serde(skip_serializing_if = "ShouldSkip::should_skip")] boost: Option, - functions: Vec, + functions: Vec, #[serde(skip_serializing_if = "ShouldSkip::should_skip")] max_boost: Option, #[serde(skip_serializing_if = "ShouldSkip::should_skip")] @@ -178,12 +180,12 @@ impl FunctionScoreQuery { add_field!(with_boost_mode, boost_mode, BoostMode); add_field!(with_min_score, min_score, f64); - pub fn with_functions>>(mut self, functions: A) -> Self { + pub fn with_functions>>(mut self, functions: A) -> Self { self.functions = functions.into(); self } - pub fn with_function>(mut self, function: A) -> Self { + pub fn with_function>(mut self, function: A) -> Self { self.functions = vec![function.into()]; self } diff --git a/src/query/functions.rs b/src/query/functions.rs index 11c5ea5..ec05885 100644 --- a/src/query/functions.rs +++ b/src/query/functions.rs @@ -25,6 +25,29 @@ use crate::{ units::{Distance, Duration, JsonVal, Location}, }; +use super::Query; + +/// FilteredFunction +#[derive(Debug, Serialize)] +pub struct FilteredFunction { + #[serde(skip_serializing_if = "ShouldSkip::should_skip")] + pub filter: Option, + #[serde(flatten)] + pub function: Function, +} + +impl FilteredFunction { + pub fn build_filtered_function>>( + filter: A, + function: Function, + ) -> FilteredFunction { + FilteredFunction { + filter: filter.into(), + function, + } + } +} + /// Function #[derive(Debug, Serialize)] pub enum Function { @@ -239,11 +262,7 @@ impl DecayOptions { } pub fn build>(self, field: A) -> Decay { - Decay(FieldBased::new( - field.into(), - self, - NoOuter, - )) + Decay(FieldBased::new(field.into(), self, NoOuter)) } } diff --git a/src/query/mod.rs b/src/query/mod.rs index 579df61..9104e26 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -409,7 +409,7 @@ mod tests { extern crate serde_json; use super::full_text::SimpleQueryStringFlags; - use super::functions::Function; + use super::functions::{FilteredFunction, Function}; use super::term::TermsQueryLookup; use super::{Flags, Query}; @@ -449,12 +449,13 @@ mod tests { #[test] fn test_function_score_query() { let function_score_query = Query::build_function_score() - .with_function( - Function::build_script_score("this_is_a_script") + .with_function(FilteredFunction { + filter: None, + function: Function::build_script_score("this_is_a_script") .with_lang("made_up") .add_param("A", 12) .build(), - ) + }) .build(); assert_eq!("{\"function_score\":{\"functions\":[{\"script_score\":{\"lang\":\"made_up\",\"params\":{\"A\":12},\"inline\":\"this_is_a_script\"}}]}}", serde_json::to_string(&function_score_query).unwrap());