From 0526972d709a27c6e76e5e775c9928b71005324c Mon Sep 17 00:00:00 2001 From: Ashton Kinslow Date: Mon, 7 Aug 2023 15:52:33 -0500 Subject: [PATCH 1/7] support custom ops --- config.go | 57 +++++++++++++--- rql.go | 183 ++++++++++++++++++++++++++++++++++++---------------- rql_test.go | 41 ++++++++++++ 3 files changed, 216 insertions(+), 65 deletions(-) diff --git a/config.go b/config.go index ae4b229..803a7ae 100644 --- a/config.go +++ b/config.go @@ -8,14 +8,12 @@ import ( // Op is a filter operator used by rql. type Op string - -// SQL returns the SQL representation of the operator. -func (o Op) SQL() string { - return opFormat[o] -} +type Direction byte // Operators that support by rql. const ( + ASC = Direction('+') + DESC = Direction('-') EQ = Op("eq") // = NEQ = Op("neq") // <> LT = Op("lt") // < @@ -43,9 +41,9 @@ var ( // A sorting expression can be optionally prefixed with + or - to control the // sorting direction, ascending or descending. For example, '+field' or '-field'. // If the predicate is missing or empty then it defaults to '+' - sortDirection = map[byte]string{ - '+': "asc", - '-': "desc", + sortDirection = map[Direction]string{ + ASC: "asc", + DESC: "desc", } opFormat = map[Op]string{ EQ: "=", @@ -60,6 +58,20 @@ var ( } ) +func GetAllOps() []Op { + return []Op{ + EQ, + NEQ, + LT, + GT, + LTE, + GTE, + LIKE, + OR, + AND, + } +} + // Config is the configuration for the parser. type Config struct { // TagName is an optional tag name for configuration. t defaults to "rql". @@ -136,6 +148,16 @@ type Config struct { // DefaultSort is the default value for the 'Sort' field that returns when no sort expression is supplied by the caller. // It defaults to an empty string slice. DefaultSort []string + // Lets the user define how a rql op is translated to a db op. + GetDBOp func(Op) string + // Lets the user define how a rql dir ('+','-') is translated to a db direction. + GetDBDir func(Direction) string + // Sets the validator function based on the type + GetValidateFn func(reflect.Type) func(interface{}) error + // Sets the convertor function based on the type + GetConverter func(reflect.Type) func(interface{}) interface{} + // Sets the supported operations for that type + GetSupportedOps func(reflect.Type) []Op } // defaults sets the default configuration of Config. @@ -152,6 +174,25 @@ func (c *Config) defaults() error { if c.ColumnFn == nil { c.ColumnFn = Column } + if c.GetDBOp == nil { + c.GetDBOp = func(o Op) string { + return opFormat[o] + } + } + if c.GetDBDir == nil { + c.GetDBDir = func(d Direction) string { + return sortDirection[d] + } + } + if c.GetConverter == nil { + c.GetConverter = GetConverterFn + } + if c.GetValidateFn == nil { + c.GetValidateFn = GetValidateFn + } + if c.GetSupportedOps == nil { + c.GetSupportedOps = GetSupportedOps + } defaultString(&c.TagName, DefaultTagName) defaultString(&c.OpPrefix, DefaultOpPrefix) defaultString(&c.FieldSep, DefaultFieldSep) diff --git a/rql.go b/rql.go index 7775487..cc0a1fc 100644 --- a/rql.go +++ b/rql.go @@ -17,6 +17,7 @@ import ( //go:generate easyjson -omit_empty -disallow_unknown_fields -snake_case rql.go // Query is the decoded result of the user input. +// //easyjson:json type Query struct { // Limit must be > 0 and <= to `LimitMaxValue`. @@ -73,7 +74,6 @@ type Query struct { // return nil, err // } // return users, nil -// type Params struct { // Limit represents the number of rows returned by the SELECT statement. Limit int @@ -104,6 +104,9 @@ func (p ParseError) Error() string { return p.msg } +type Validator func(interface{}) error +type Converter func(interface{}) interface{} + // field is a configuration of a struct field. type field struct { // Name of the column. @@ -115,9 +118,9 @@ type field struct { // All supported operators for this field. FilterOps map[string]bool // Validation for the type. for example, unit8 greater than or equal to 0. - ValidateFn func(interface{}) error + ValidateFn Validator // ConvertFn converts the given value to the type value. - CovertFn func(interface{}) interface{} + CovertFn Converter } // A Parser parses various types. The result from the Parse method is a Param object. @@ -203,7 +206,6 @@ func (p *Parser) ParseQuery(q *Query) (pr *Params, err error) { // Username => username // FullName => full_name // HTTPCode => http_code -// func Column(s string) string { var b strings.Builder for i := 0; i < len(s); i++ { @@ -221,6 +223,111 @@ func Column(s string) string { return b.String() } +func GetSupportedOps(t reflect.Type) []Op { + switch t.Kind() { + case reflect.Bool: + return []Op{EQ, NEQ} + case reflect.String: + return []Op{EQ, NEQ, LT, LTE, GT, GTE, LIKE} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + case reflect.Float32, reflect.Float64: + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + case reflect.Struct: + switch v := reflect.Zero(t); v.Interface().(type) { + case sql.NullBool: + return []Op{EQ, NEQ} + case sql.NullString: + return []Op{EQ, NEQ} + case sql.NullInt64: + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + case sql.NullFloat64: + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + case time.Time: + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + default: + if v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + } + return []Op{} + } + default: + return []Op{} + } +} + +func GetConverterFn(t reflect.Type) func(interface{}) interface{} { + layout := "" + switch t.Kind() { + case reflect.Bool: + return valueFn + case reflect.String: + return valueFn + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return convertInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return convertInt + case reflect.Float32, reflect.Float64: + return valueFn + case reflect.Struct: + switch v := reflect.Zero(t); v.Interface().(type) { + case sql.NullBool: + return valueFn + case sql.NullString: + return valueFn + case sql.NullInt64: + return convertInt + case sql.NullFloat64: + return valueFn + case time.Time: + return convertTime(layout) + default: + if v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + return convertTime(layout) + } + } + } + return valueFn +} + +func GetValidateFn(t reflect.Type) Validator { + layout := "" + switch t.Kind() { + case reflect.Bool: + return validateBool + case reflect.String: + return validateString + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return validateInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return validateUInt + case reflect.Float32, reflect.Float64: + return validateFloat + case reflect.Struct: + switch v := reflect.Zero(t); v.Interface().(type) { + case sql.NullBool: + return validateBool + case sql.NullString: + return validateString + case sql.NullInt64: + return validateInt + case sql.NullFloat64: + return validateFloat + case time.Time: + return validateTime(layout) + default: + if !v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + return nil + } + return validateTime(layout) + } + default: + return nil + } +} + // init initializes the parser parsing state. it scans the fields // in a breath-first-search order and for each one of the field calls parseField. func (p *Parser) init() error { @@ -289,55 +396,15 @@ func (p *Parser) parseField(sf reflect.StructField) error { p.Log("Ignoring unknown option %q in struct tag", opt) } } - var filterOps []Op - switch typ := indirect(sf.Type); typ.Kind() { - case reflect.Bool: - f.ValidateFn = validateBool - filterOps = append(filterOps, EQ, NEQ) - case reflect.String: - f.ValidateFn = validateString - filterOps = append(filterOps, EQ, NEQ, LT, LTE, GT, GTE, LIKE) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - f.ValidateFn = validateInt - f.CovertFn = convertInt - filterOps = append(filterOps, EQ, NEQ, LT, LTE, GT, GTE) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - f.ValidateFn = validateUInt - f.CovertFn = convertInt - filterOps = append(filterOps, EQ, NEQ, LT, LTE, GT, GTE) - case reflect.Float32, reflect.Float64: - f.ValidateFn = validateFloat - filterOps = append(filterOps, EQ, NEQ, LT, LTE, GT, GTE) - case reflect.Struct: - switch v := reflect.Zero(typ); v.Interface().(type) { - case sql.NullBool: - f.ValidateFn = validateBool - filterOps = append(filterOps, EQ, NEQ) - case sql.NullString: - f.ValidateFn = validateString - filterOps = append(filterOps, EQ, NEQ) - case sql.NullInt64: - f.ValidateFn = validateInt - f.CovertFn = convertInt - filterOps = append(filterOps, EQ, NEQ, LT, LTE, GT, GTE) - case sql.NullFloat64: - f.ValidateFn = validateFloat - filterOps = append(filterOps, EQ, NEQ, LT, LTE, GT, GTE) - case time.Time: - f.ValidateFn = validateTime(layout) - f.CovertFn = convertTime(layout) - filterOps = append(filterOps, EQ, NEQ, LT, LTE, GT, GTE) - default: - if !v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { - return fmt.Errorf("rql: field type for %q is not supported", sf.Name) - } - f.ValidateFn = validateTime(layout) - f.CovertFn = convertTime(layout) - filterOps = append(filterOps, EQ, NEQ, LT, LTE, GT, GTE) - } - default: + t := indirect(sf.Type) + + filterOps := GetSupportedOps(t) + if len(filterOps) == 0 { return fmt.Errorf("rql: field type for %q is not supported", sf.Name) } + f.CovertFn = GetConverterFn(t) + f.ValidateFn = GetValidateFn(t) + for _, op := range filterOps { f.FilterOps[p.op(op)] = true } @@ -375,12 +442,14 @@ func (p *Parser) sort(fields []string) string { sortParams := make([]string, len(fields)) for i, field := range fields { expect(field != "", "sort field can not be empty") + var orderBy string - // if the sort field prefixed by an order indicator. - if order, ok := sortDirection[field[0]]; ok { - orderBy = order + f0 := field[0] + if f0 == byte(ASC) || f0 == byte(DESC) { + orderBy = p.GetDBDir(Direction(f0)) field = field[1:] } + expect(p.fields[field] != nil, "unrecognized key %q for sorting", field) expect(p.fields[field].Sortable, "field %q is not sortable", field) colName := p.colName(field) @@ -425,7 +494,7 @@ func (p *parseState) relOp(op Op, terms []interface{}) { for _, t := range terms { if i > 0 { p.WriteByte(' ') - p.WriteString(op.SQL()) + p.WriteString(p.GetDBOp(op)) p.WriteByte(' ') } mt, ok := t.(map[string]interface{}) @@ -469,7 +538,7 @@ func (p *parseState) field(f *field, v interface{}) { // for example: "name = ?", or "age >= ?". func (p *Parser) fmtOp(field string, op Op) string { colName := p.colName(field) - return colName + " " + op.SQL() + " ?" + return colName + " " + p.GetDBOp(op) + " ?" } // colName formats the query field to database column name in cases the user configured a custom @@ -567,7 +636,7 @@ func validateUInt(v interface{}) error { } // validate that the underlined element of this interface is a "datetime" string. -func validateTime(layout string) func(interface{}) error { +func validateTime(layout string) Validator { return func(v interface{}) error { s, ok := v.(string) if !ok { diff --git a/rql_test.go b/rql_test.go index f691219..6723f29 100644 --- a/rql_test.go +++ b/rql_test.go @@ -912,6 +912,47 @@ func TestParse(t *testing.T) { }`), wantErr: true, }, + { + name: "custom db symbols", + conf: Config{ + Model: struct { + ID string `rql:"filter"` + FullName string `rql:"filter"` + HTTPUrl string `rql:"filter"` + NestedStruct struct { + UUID string `rql:"filter"` + } + }{}, + FieldSep: ".", + GetDBOp: func(o Op) string { + if o == EQ { + return "eq" + } + return opFormat[o] + + }, + GetDBDir: func(d Direction) string { + if d == ASC { + return "ASC" + } + return "DESC" + }, + }, + input: []byte(`{ + "filter": { + "id": "id", + "full_name": "full_name", + "http_url": "http_url", + "nested_struct.uuid": "uuid" + } + }`), + wantOut: &Params{ + Limit: 25, + FilterExp: "id eq ? AND full_name eq ? AND http_url eq ? AND nested_struct_uuid eq ?", + FilterArgs: []interface{}{"id", "full_name", "http_url", "uuid"}, + Sort: "", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From 45c3e33fce893da71fa52a7cb7f3ef937707a156 Mon Sep 17 00:00:00 2001 From: Ashton Kinslow Date: Fri, 11 Aug 2023 16:48:31 -0500 Subject: [PATCH 2/7] add additional context to func signatures to support reusing operator with different db op --- config.go | 12 +- rql.go | 83 +++++---- rql_custom_ops_test.go | 376 +++++++++++++++++++++++++++++++++++++++++ rql_test.go | 38 ++++- 4 files changed, 464 insertions(+), 45 deletions(-) create mode 100644 rql_custom_ops_test.go diff --git a/config.go b/config.go index 803a7ae..e7d9864 100644 --- a/config.go +++ b/config.go @@ -149,13 +149,13 @@ type Config struct { // It defaults to an empty string slice. DefaultSort []string // Lets the user define how a rql op is translated to a db op. - GetDBOp func(Op) string + GetDBOp func(Op, *Field) string // Lets the user define how a rql dir ('+','-') is translated to a db direction. GetDBDir func(Direction) string // Sets the validator function based on the type - GetValidateFn func(reflect.Type) func(interface{}) error + GetValidator func(reflect.Type) Validator // Sets the convertor function based on the type - GetConverter func(reflect.Type) func(interface{}) interface{} + GetConverter func(reflect.Type) Converter // Sets the supported operations for that type GetSupportedOps func(reflect.Type) []Op } @@ -175,7 +175,7 @@ func (c *Config) defaults() error { c.ColumnFn = Column } if c.GetDBOp == nil { - c.GetDBOp = func(o Op) string { + c.GetDBOp = func(o Op, _ *Field) string { return opFormat[o] } } @@ -187,8 +187,8 @@ func (c *Config) defaults() error { if c.GetConverter == nil { c.GetConverter = GetConverterFn } - if c.GetValidateFn == nil { - c.GetValidateFn = GetValidateFn + if c.GetValidator == nil { + c.GetValidator = GetValidateFn } if c.GetSupportedOps == nil { c.GetSupportedOps = GetSupportedOps diff --git a/rql.go b/rql.go index cc0a1fc..621eb7f 100644 --- a/rql.go +++ b/rql.go @@ -104,11 +104,11 @@ func (p ParseError) Error() string { return p.msg } -type Validator func(interface{}) error -type Converter func(interface{}) interface{} +type Validator func(Op, reflect.Type, interface{}) error +type Converter func(Op, reflect.Type, interface{}) interface{} // field is a configuration of a struct field. -type field struct { +type Field struct { // Name of the column. Name string // Has a "sort" option in the tag. @@ -121,13 +121,15 @@ type field struct { ValidateFn Validator // ConvertFn converts the given value to the type value. CovertFn Converter + // Type of the field + Type reflect.Type } // A Parser parses various types. The result from the Parse method is a Param object. // It is safe for concurrent use by multiple goroutines except for configuration changes. type Parser struct { Config - fields map[string]*field + fields map[string]*Field } // NewParser creates a new Parser. it fails if the configuration is invalid. @@ -137,7 +139,7 @@ func NewParser(c Config) (*Parser, error) { } p := &Parser{ Config: c, - fields: make(map[string]*field), + fields: make(map[string]*Field), } if err := p.init(); err != nil { return nil, err @@ -224,6 +226,7 @@ func Column(s string) string { } func GetSupportedOps(t reflect.Type) []Op { + println(t.Kind().String()) switch t.Kind() { case reflect.Bool: return []Op{EQ, NEQ} @@ -258,7 +261,7 @@ func GetSupportedOps(t reflect.Type) []Op { } } -func GetConverterFn(t reflect.Type) func(interface{}) interface{} { +func GetConverterFn(t reflect.Type) Converter { layout := "" switch t.Kind() { case reflect.Bool: @@ -331,7 +334,8 @@ func GetValidateFn(t reflect.Type) Validator { // init initializes the parser parsing state. it scans the fields // in a breath-first-search order and for each one of the field calls parseField. func (p *Parser) init() error { - t := indirect(reflect.TypeOf(p.Model)) + t := reflect.TypeOf(p.Model) + t = indirect(t) l := list.New() for i := 0; i < t.NumField(); i++ { l.PushFront(t.Field(i)) @@ -364,7 +368,7 @@ func (p *Parser) init() error { // parseField parses the given struct field tag, and add a rule // in the parser according to its type and the options that were set on the tag. func (p *Parser) parseField(sf reflect.StructField) error { - f := &field{ + f := &Field{ Name: p.ColumnFn(sf.Name), CovertFn: valueFn, FilterOps: make(map[string]bool), @@ -396,14 +400,16 @@ func (p *Parser) parseField(sf reflect.StructField) error { p.Log("Ignoring unknown option %q in struct tag", opt) } } - t := indirect(sf.Type) - filterOps := GetSupportedOps(t) + // t := indirect(sf.Type) + t := sf.Type + f.Type = t + filterOps := p.Config.GetSupportedOps(f.Type) if len(filterOps) == 0 { return fmt.Errorf("rql: field type for %q is not supported", sf.Name) } - f.CovertFn = GetConverterFn(t) - f.ValidateFn = GetValidateFn(t) + f.CovertFn = p.Config.GetConverter(f.Type) + f.ValidateFn = p.Config.GetValidator(f.Type) for _, op := range filterOps { f.FilterOps[p.op(op)] = true @@ -477,8 +483,9 @@ func (p *parseState) and(f map[string]interface{}) { expect(ok, "$and must be type array") p.relOp(AND, terms) case p.fields[k] != nil: - expect(p.fields[k].Filterable, "field %q is not filterable", k) - p.field(p.fields[k], v) + f := p.fields[k] + expect(f.Filterable, "field %q is not filterable", k) + p.field(f, v) default: expect(false, "unrecognized key %q for filtering", k) } @@ -494,7 +501,7 @@ func (p *parseState) relOp(op Op, terms []interface{}) { for _, t := range terms { if i > 0 { p.WriteByte(' ') - p.WriteString(p.GetDBOp(op)) + p.WriteString(p.GetDBOp(op, nil)) p.WriteByte(' ') } mt, ok := t.(map[string]interface{}) @@ -507,13 +514,16 @@ func (p *parseState) relOp(op Op, terms []interface{}) { } } -func (p *parseState) field(f *field, v interface{}) { +func (p *parseState) field(f *Field, v interface{}) { terms, ok := v.(map[string]interface{}) // default equality check. if !ok { - must(f.ValidateFn(v), "invalid datatype for field %q", f.Name) - p.WriteString(p.fmtOp(f.Name, EQ)) - p.values = append(p.values, f.CovertFn(v)) + op := EQ + err := f.ValidateFn(op, f.Type, v) + must(err, "invalid datatype for field %q", f.Name) + p.WriteString(p.fmtOp(f, op)) + arg := f.CovertFn(op, f.Type, v) + p.values = append(p.values, arg) } var i int if len(terms) > 1 { @@ -523,10 +533,12 @@ func (p *parseState) field(f *field, v interface{}) { if i > 0 { p.WriteString(" AND ") } + op := Op(opName[1:]) expect(f.FilterOps[opName], "can not apply op %q on field %q", opName, f.Name) - must(f.ValidateFn(opVal), "invalid datatype or format for field %q", f.Name) - p.WriteString(p.fmtOp(f.Name, Op(opName[1:]))) - p.values = append(p.values, f.CovertFn(opVal)) + must(f.ValidateFn(op, f.Type, opVal), "invalid datatype or format for field %q", f.Name) + p.WriteString(p.fmtOp(f, op)) + arg := f.CovertFn(op, f.Type, opVal) + p.values = append(p.values, arg) i++ } if len(terms) > 1 { @@ -536,9 +548,8 @@ func (p *parseState) field(f *field, v interface{}) { // fmtOp create a string for the operation with a placeholder. // for example: "name = ?", or "age >= ?". -func (p *Parser) fmtOp(field string, op Op) string { - colName := p.colName(field) - return colName + " " + p.GetDBOp(op) + " ?" +func (p *Parser) fmtOp(f *Field, op Op) string { + return f.Name + " " + p.GetDBOp(op, f) + " ?" } // colName formats the query field to database column name in cases the user configured a custom @@ -589,7 +600,7 @@ func errorType(v interface{}, expected string) error { } // validate that the underlined element of given interface is a boolean. -func validateBool(v interface{}) error { +func validateBool(op Op, t reflect.Type, v interface{}) error { if _, ok := v.(bool); !ok { return errorType(v, "bool") } @@ -597,7 +608,7 @@ func validateBool(v interface{}) error { } // validate that the underlined element of given interface is a string. -func validateString(v interface{}) error { +func validateString(op Op, t reflect.Type, v interface{}) error { if _, ok := v.(string); !ok { return errorType(v, "string") } @@ -605,7 +616,7 @@ func validateString(v interface{}) error { } // validate that the underlined element of given interface is a float. -func validateFloat(v interface{}) error { +func validateFloat(op Op, t reflect.Type, v interface{}) error { if _, ok := v.(float64); !ok { return errorType(v, "float64") } @@ -613,7 +624,7 @@ func validateFloat(v interface{}) error { } // validate that the underlined element of given interface is an int. -func validateInt(v interface{}) error { +func validateInt(op Op, t reflect.Type, v interface{}) error { n, ok := v.(float64) if !ok { return errorType(v, "int") @@ -625,8 +636,8 @@ func validateInt(v interface{}) error { } // validate that the underlined element of given interface is an int and greater than 0. -func validateUInt(v interface{}) error { - if err := validateInt(v); err != nil { +func validateUInt(op Op, t reflect.Type, v interface{}) error { + if err := validateInt(op, t, v); err != nil { return err } if v.(float64) < 0 { @@ -637,7 +648,7 @@ func validateUInt(v interface{}) error { // validate that the underlined element of this interface is a "datetime" string. func validateTime(layout string) Validator { - return func(v interface{}) error { + return func(_ Op, _ reflect.Type, v interface{}) error { s, ok := v.(string) if !ok { return errorType(v, "string") @@ -648,20 +659,20 @@ func validateTime(layout string) Validator { } // convert float to int. -func convertInt(v interface{}) interface{} { +func convertInt(op Op, t reflect.Type, v interface{}) interface{} { return int(v.(float64)) } // convert string to time object. -func convertTime(layout string) func(interface{}) interface{} { - return func(v interface{}) interface{} { +func convertTime(layout string) func(Op, reflect.Type, interface{}) interface{} { + return func(_ Op, _ reflect.Type, v interface{}) interface{} { t, _ := time.Parse(layout, v.(string)) return t } } // nop converter. -func valueFn(v interface{}) interface{} { +func valueFn(op Op, t reflect.Type, v interface{}) interface{} { return v } diff --git a/rql_custom_ops_test.go b/rql_custom_ops_test.go new file mode 100644 index 0000000..6a4cb93 --- /dev/null +++ b/rql_custom_ops_test.go @@ -0,0 +1,376 @@ +package rql + +import ( + "database/sql" + "errors" + "fmt" + "math" + "reflect" + "testing" + "time" +) + +var customOpFormat = map[Op]string{ + EQ: "=", + NEQ: "<>", + LT: "<", + GT: ">", + LTE: "<=", + GTE: ">=", + LIKE: "ILIKE", + OR: "OR", + AND: "AND", + IN: "IN", + NIN: "NOT IN", + ALL: "@>", + OVERLAP: "&&", + CONTAINS: "@>", + EXISTS: "?|", +} + +type StructAlias map[string]interface{} + +func TestParse2(t *testing.T) { + tests := []struct { + name string + conf Config + input []byte + wantErr bool + wantOut *Params + }{ + { + + name: "custom conv/val func", + conf: Config{ + Model: struct { + IDs []string `rql:"filter,column=ids"` + StrSl []string `rql:"filter,column=str_sl"` + Inty int `rql:"filter,column=inty"` + IntSl []int `rql:"filter,column=int_sl"` + Floats []float64 `rql:"filter,column=floats"` + Map map[string]interface{} `rql:"filter,column=map"` + AliasMap StructAlias `rql:"filter,column=alias_map"` + }{}, + FieldSep: ".", + GetDBOp: func(o Op, f *Field) string { + return customOpFormat[o] + }, + GetSupportedOps: CustomGetSupportedOps, + GetValidator: CustomGetValidateFn, + GetConverter: CustomGetConverterFn, + }, + input: []byte(`{ + "filter": { + "floats" :{"$overlap":[1.2,3.2,1]}, + "map" : {"$contains": {"key":{"someobject":"fdf"}}}, + "alias_map" : {"$exists": "str"}, + "ids": ["1"], + "inty": {"$in":[2]}, + "int_sl": {"$all":[1,2]} + } + }`), + wantOut: &Params{ + Limit: 25, + FilterExp: "map @> ? AND alias_map ?| ? AND ids = ? AND inty IN ? AND int_sl @> ? AND floats && ?", + FilterArgs: []interface{}{ + []interface{}{1.2, 3.2, float64(1)}, + map[string]interface{}{"key": map[string]interface{}{"someobject": "fdf"}}, + "str", + []interface{}{"1"}, + []interface{}{2}, + []interface{}{1, 2}}, + Sort: "", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.conf.Log = t.Logf + p, err := NewParser(tt.conf) + if err != nil { + t.Fatalf("failed to build parser: %v", err) + } + out, err := p.Parse(tt.input) + if tt.wantErr != (err != nil) { + t.Fatalf("want: %v\ngot:%v\nerr: %v", tt.wantErr, err != nil, err) + } + assertParams(t, out, tt.wantOut) + }) + } +} + +var ( + IN = Op("in") + NIN = Op("nin") + OVERLAP = Op("overlap") + ALL = Op("all") + CONTAINS = Op("contains") + EXISTS = Op("exists") +) + +func CustomGetSupportedOps(t reflect.Type) []Op { + t = indirect(t) + switch t.Kind() { + case reflect.Bool: + return []Op{EQ, NEQ} + case reflect.String: + return []Op{EQ, NEQ, LT, LTE, GT, GTE, LIKE, IN, NIN} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return []Op{EQ, NEQ, LT, LTE, GT, GTE, IN, NIN} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return []Op{EQ, NEQ, LT, LTE, GT, GTE, IN, NIN} + case reflect.Float32, reflect.Float64: + return []Op{EQ, NEQ, LT, LTE, GT, GTE, IN, NIN} + case reflect.Slice: + return []Op{EQ, NEQ, OVERLAP, ALL} + case reflect.Map: + return []Op{CONTAINS, EXISTS} + case reflect.Struct: + switch v := reflect.Zero(t); v.Interface().(type) { + case sql.NullBool: + return []Op{EQ, NEQ} + case sql.NullString: + return []Op{EQ, NEQ} + case sql.NullInt64: + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + case sql.NullFloat64: + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + case time.Time: + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + default: + if v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + return []Op{EQ, NEQ, LT, LTE, GT, GTE} + } + return []Op{} + } + default: + return []Op{} + } +} + +func CustomGetConverterFn(t reflect.Type) Converter { + layout := "" + switch t.Kind() { + case reflect.Bool: + return valueFn + case reflect.String: + return valueFn + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return customConvertInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return customConvertInt + case reflect.Float32, reflect.Float64: + return valueFn + case reflect.Slice: + return convertSlice + case reflect.Map: + return valueFn + case reflect.Struct: + switch v := reflect.Zero(t); v.Interface().(type) { + case sql.NullBool: + return valueFn + case sql.NullString: + return valueFn + case sql.NullInt64: + return customConvertInt + case sql.NullFloat64: + return valueFn + case time.Time: + return convertTime(layout) + default: + if v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + return convertTime(layout) + } + } + } + return valueFn +} + +func CustomGetValidateFn(t reflect.Type) Validator { + layout := "" + switch t.Kind() { + case reflect.Bool: + return validateBool + case reflect.String: + return customValidateString + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return customValidateInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return customValidateUInt + case reflect.Float32, reflect.Float64: + return customValidateFloat + case reflect.Slice: + return validateSliceOp + case reflect.Map: + return validateMapOp + case reflect.Struct: + switch v := reflect.Zero(t); v.Interface().(type) { + case sql.NullBool: + return validateBool + case sql.NullString: + return customValidateString + case sql.NullInt64: + return customValidateInt + case sql.NullFloat64: + return customValidateFloat + case time.Time: + return validateTime(layout) + default: + if !v.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + return nil + } + return validateTime(layout) + } + default: + return nil + } +} + +func validateSliceElem(v interface{}, expectedElemType reflect.Type) error { + slice, ok := v.([]interface{}) + if !ok { + return errorType(v, "array") + } + for _, item := range slice { + it := reflect.TypeOf(item) + c := isComparable(expectedElemType, it) + + if !c { + t := expectedElemType.Kind().String() + return errorType(item, t) + } + } + return nil +} + +// validate that the underlined element of given interface is a string. +func customValidateString(op Op, t reflect.Type, v interface{}) error { + if op == IN || op == NIN { + validateSliceElem(v, reflect.TypeOf("")) + } + if _, ok := v.(string); !ok { + return errorType(v, "string") + } + return nil +} + +// validate that the underlined element of given interface is a float. +func customValidateFloat(op Op, t reflect.Type, v interface{}) error { + if op == IN || op == NIN { + return validateSliceElem(v, reflect.TypeOf(1.1)) + } + if _, ok := v.(float64); !ok { + return errorType(v, "float64") + } + return nil +} + +// validate that the underlined element of given interface is an int. +func customValidateInt(op Op, t reflect.Type, v interface{}) error { + if op == IN || op == NIN { + return validateSliceElem(v, reflect.TypeOf(1.1)) + } + n, ok := v.(float64) + if !ok { + return errorType(v, "int") + } + if math.Trunc(n) != n { + return errors.New("not an integer") + } + return nil +} + +// validate that the underlined element of given interface is an int and greater than 0. +func customValidateUInt(op Op, t reflect.Type, v interface{}) error { + if op == IN || op == NIN { + return validateSliceElem(v, reflect.TypeOf(1.1)) + } + if err := validateInt(op, t, v); err != nil { + return err + } + if v.(float64) < 0 { + return errors.New("not an unsigned integer") + } + return nil +} + +// convert float to int. +func customConvertInt(op Op, t reflect.Type, v interface{}) interface{} { + if op == IN || op == NIN { + sl, ok := v.([]interface{}) + if !ok { + return v + } + for i, f := range sl { + newInt := int(f.(float64)) + sl[i] = newInt + } + return sl + } + return int(v.(float64)) +} + +func convertSlice(op Op, t reflect.Type, v interface{}) interface{} { + if isNumeric(t.Elem()) { + switch t.Elem().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + + sl, ok := v.([]interface{}) + if !ok { + return v + } + for i, f := range sl { + newInt := int(f.(float64)) + sl[i] = newInt + } + return sl + } + } + return v +} + +func isNumeric(t reflect.Type) bool { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, + reflect.Complex64, reflect.Complex128: + return true + default: + return false + } +} + +// does not handle int IN []int +func isComparable(t reflect.Type, t2 reflect.Type) bool { + return t == t2 || (isNumeric(t) && isNumeric(t2)) +} + +func validateSliceOp(op Op, t reflect.Type, v interface{}) error { + if t.Kind() != reflect.Slice { + return fmt.Errorf("t is not a slice, wrong validate func") + } + + vType := reflect.TypeOf(v) + if vType.Kind() != reflect.Slice { + return fmt.Errorf("not a slice") + } + return validateSliceElem(v, t.Elem()) +} + +func validateMapOp(op Op, t reflect.Type, v interface{}) error { + if op == EXISTS { + _, ok := v.(string) + if !ok { + return fmt.Errorf("exists expects a string arg") + } + } + if op == ALL { + _, ok := v.(map[string]interface{}) + if !ok { + return fmt.Errorf("exists expects a string arg") + } + } + return nil +} diff --git a/rql_test.go b/rql_test.go index 6723f29..c5b7d59 100644 --- a/rql_test.go +++ b/rql_test.go @@ -2,7 +2,9 @@ package rql import ( "database/sql" + "fmt" "reflect" + "sort" "strings" "testing" "time" @@ -924,7 +926,7 @@ func TestParse(t *testing.T) { } }{}, FieldSep: ".", - GetDBOp: func(o Op) string { + GetDBOp: func(o Op, _ *Field) string { if o == EQ { return "eq" } @@ -991,8 +993,9 @@ func assertParams(t *testing.T, got *Params, want *Params) { if !equalExp(got.FilterExp, want.FilterExp) || !equalExp(want.FilterExp, got.FilterExp) { t.Fatalf("filter expr:\n\tgot: %q\n\twant %q", got.FilterExp, want.FilterExp) } - if !equalArgs(got.FilterArgs, got.FilterArgs) || !equalArgs(want.FilterArgs, got.FilterArgs) { - t.Fatalf("filter args:\n\tgot: %v\n\twant %v", got.FilterArgs, want.FilterArgs) + err := deepEqualIgnoreOrder(got.FilterArgs, want.FilterArgs) + if err != nil { + t.Fatalf("filter args:\n\tgot: %v\n\twant %v %v", got.FilterArgs, want.FilterArgs, err.Error()) } } @@ -1062,3 +1065,32 @@ func mustParseTime(layout, s string) time.Time { t, _ := time.Parse(layout, s) return t } + +func deepSort(i interface{}) interface{} { + switch reflect.TypeOf(i).Kind() { + case reflect.Slice: + s := reflect.ValueOf(i) + if s.Len() == 0 { + return i + } + newSlice := make([]interface{}, s.Len()) + for j := 0; j < s.Len(); j++ { + newSlice[j] = deepSort(s.Index(j).Interface()) + } + sort.SliceStable(newSlice, func(i, j int) bool { + return fmt.Sprint(newSlice[i]) < fmt.Sprint(newSlice[j]) + }) + return newSlice + default: + return i + } +} + +func deepEqualIgnoreOrder(a, b interface{}) error { + sortedA := deepSort(a) + sortedB := deepSort(b) + if !reflect.DeepEqual(sortedA, sortedB) { + return fmt.Errorf("differences found: A=%v, B=%v", sortedA, sortedB) + } + return nil +} From 5f6c8743cd12535cfca285564e81329e97e3322e Mon Sep 17 00:00:00 2001 From: Ashton Kinslow Date: Fri, 11 Aug 2023 17:51:16 -0500 Subject: [PATCH 3/7] fix nested obj --- rql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rql.go b/rql.go index 621eb7f..00044ce 100644 --- a/rql.go +++ b/rql.go @@ -549,7 +549,7 @@ func (p *parseState) field(f *Field, v interface{}) { // fmtOp create a string for the operation with a placeholder. // for example: "name = ?", or "age >= ?". func (p *Parser) fmtOp(f *Field, op Op) string { - return f.Name + " " + p.GetDBOp(op, f) + " ?" + return p.colName(f.Name) + " " + p.GetDBOp(op, f) + " ?" } // colName formats the query field to database column name in cases the user configured a custom From 6d10a8918e13d4f72d26ddc4cb66b57c8fd89332 Mon Sep 17 00:00:00 2001 From: Ashton Kinslow Date: Fri, 11 Aug 2023 18:18:48 -0500 Subject: [PATCH 4/7] update func signatures to support time ops --- config.go | 6 +++--- rql.go | 27 ++++++++++++++------------ rql_custom_ops_test.go | 10 ++++++---- rql_test.go | 43 ++++++++++++++++++++++++++++++++---------- 4 files changed, 57 insertions(+), 29 deletions(-) diff --git a/config.go b/config.go index e7d9864..7eaf6d5 100644 --- a/config.go +++ b/config.go @@ -153,11 +153,11 @@ type Config struct { // Lets the user define how a rql dir ('+','-') is translated to a db direction. GetDBDir func(Direction) string // Sets the validator function based on the type - GetValidator func(reflect.Type) Validator + GetValidator func(f *Field) Validator // Sets the convertor function based on the type - GetConverter func(reflect.Type) Converter + GetConverter func(f *Field) Converter // Sets the supported operations for that type - GetSupportedOps func(reflect.Type) []Op + GetSupportedOps func(f *Field) []Op } // defaults sets the default configuration of Config. diff --git a/rql.go b/rql.go index 00044ce..8d3583a 100644 --- a/rql.go +++ b/rql.go @@ -123,6 +123,8 @@ type Field struct { CovertFn Converter // Type of the field Type reflect.Type + // Time Layout + Layout string } // A Parser parses various types. The result from the Parse method is a Param object. @@ -225,8 +227,8 @@ func Column(s string) string { return b.String() } -func GetSupportedOps(t reflect.Type) []Op { - println(t.Kind().String()) +func GetSupportedOps(f *Field) []Op { + t := f.Type switch t.Kind() { case reflect.Bool: return []Op{EQ, NEQ} @@ -261,8 +263,9 @@ func GetSupportedOps(t reflect.Type) []Op { } } -func GetConverterFn(t reflect.Type) Converter { - layout := "" +func GetConverterFn(f *Field) Converter { + layout := f.Layout + t := f.Type switch t.Kind() { case reflect.Bool: return valueFn @@ -295,8 +298,9 @@ func GetConverterFn(t reflect.Type) Converter { return valueFn } -func GetValidateFn(t reflect.Type) Validator { - layout := "" +func GetValidateFn(f *Field) Validator { + t := f.Type + layout := f.Layout switch t.Kind() { case reflect.Bool: return validateBool @@ -400,16 +404,15 @@ func (p *Parser) parseField(sf reflect.StructField) error { p.Log("Ignoring unknown option %q in struct tag", opt) } } + f.Layout = layout - // t := indirect(sf.Type) - t := sf.Type - f.Type = t - filterOps := p.Config.GetSupportedOps(f.Type) + f.Type = indirect(sf.Type) + filterOps := p.Config.GetSupportedOps(f) if len(filterOps) == 0 { return fmt.Errorf("rql: field type for %q is not supported", sf.Name) } - f.CovertFn = p.Config.GetConverter(f.Type) - f.ValidateFn = p.Config.GetValidator(f.Type) + f.CovertFn = p.Config.GetConverter(f) + f.ValidateFn = p.Config.GetValidator(f) for _, op := range filterOps { f.FilterOps[p.op(op)] = true diff --git a/rql_custom_ops_test.go b/rql_custom_ops_test.go index 6a4cb93..7144533 100644 --- a/rql_custom_ops_test.go +++ b/rql_custom_ops_test.go @@ -108,8 +108,8 @@ var ( EXISTS = Op("exists") ) -func CustomGetSupportedOps(t reflect.Type) []Op { - t = indirect(t) +func CustomGetSupportedOps(f *Field) []Op { + t := f.Type switch t.Kind() { case reflect.Bool: return []Op{EQ, NEQ} @@ -148,8 +148,9 @@ func CustomGetSupportedOps(t reflect.Type) []Op { } } -func CustomGetConverterFn(t reflect.Type) Converter { +func CustomGetConverterFn(f *Field) Converter { layout := "" + t := f.Type switch t.Kind() { case reflect.Bool: return valueFn @@ -186,7 +187,8 @@ func CustomGetConverterFn(t reflect.Type) Converter { return valueFn } -func CustomGetValidateFn(t reflect.Type) Validator { +func CustomGetValidateFn(f *Field) Validator { + t := f.Type layout := "" switch t.Kind() { case reflect.Bool: diff --git a/rql_test.go b/rql_test.go index c5b7d59..4a6894f 100644 --- a/rql_test.go +++ b/rql_test.go @@ -1063,22 +1063,40 @@ func split(e string) []string { func mustParseTime(layout, s string) time.Time { t, _ := time.Parse(layout, s) + return t } +func compareInterface(a, b interface{}) bool { + // If either of the values is nil, handle them first. + if a == nil && b != nil { + return true // consider nil as the smallest value + } + if a != nil && b == nil { + return false + } + if a == nil && b == nil { + return false // doesn't matter which one comes first if both are nil + } + + // If they are slices, compare their sorted string representations. + if reflect.TypeOf(a).Kind() == reflect.Slice && reflect.TypeOf(b).Kind() == reflect.Slice { + return fmt.Sprint(a) < fmt.Sprint(b) + } + + // Otherwise, use the regular string representation for comparison. + return fmt.Sprint(a) < fmt.Sprint(b) +} + func deepSort(i interface{}) interface{} { - switch reflect.TypeOf(i).Kind() { - case reflect.Slice: - s := reflect.ValueOf(i) - if s.Len() == 0 { - return i - } - newSlice := make([]interface{}, s.Len()) - for j := 0; j < s.Len(); j++ { - newSlice[j] = deepSort(s.Index(j).Interface()) + switch v := i.(type) { + case []interface{}: + newSlice := make([]interface{}, len(v)) + for j, item := range v { + newSlice[j] = deepSort(item) } sort.SliceStable(newSlice, func(i, j int) bool { - return fmt.Sprint(newSlice[i]) < fmt.Sprint(newSlice[j]) + return compareInterface(newSlice[i], newSlice[j]) }) return newSlice default: @@ -1087,6 +1105,11 @@ func deepSort(i interface{}) interface{} { } func deepEqualIgnoreOrder(a, b interface{}) error { + // Explicitly handle nil cases + if (a == nil && b != nil) || (a != nil && b == nil) { + return fmt.Errorf("differences found: A=%v, B=%v", a, b) + } + sortedA := deepSort(a) sortedB := deepSort(b) if !reflect.DeepEqual(sortedA, sortedB) { From e52d0d9eef55bf85ae17fa5f4ffbc464d6d9df95 Mon Sep 17 00:00:00 2001 From: Ashton Kinslow Date: Fri, 11 Aug 2023 18:34:55 -0500 Subject: [PATCH 5/7] seperate out FieldMeta to provide more context and fix tz --- config.go | 6 ++-- rql.go | 66 +++++++++++++++++++++++------------------- rql_custom_ops_test.go | 30 ++++++++++--------- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/config.go b/config.go index 7eaf6d5..ff21979 100644 --- a/config.go +++ b/config.go @@ -153,11 +153,11 @@ type Config struct { // Lets the user define how a rql dir ('+','-') is translated to a db direction. GetDBDir func(Direction) string // Sets the validator function based on the type - GetValidator func(f *Field) Validator + GetValidator func(f *FieldMeta) Validator // Sets the convertor function based on the type - GetConverter func(f *Field) Converter + GetConverter func(f *FieldMeta) Converter // Sets the supported operations for that type - GetSupportedOps func(f *Field) []Op + GetSupportedOps func(f *FieldMeta) []Op } // defaults sets the default configuration of Config. diff --git a/rql.go b/rql.go index 8d3583a..5b7b44f 100644 --- a/rql.go +++ b/rql.go @@ -104,11 +104,19 @@ func (p ParseError) Error() string { return p.msg } -type Validator func(Op, reflect.Type, interface{}) error -type Converter func(Op, reflect.Type, interface{}) interface{} +type Validator func(Op, FieldMeta, interface{}) error +type Converter func(Op, FieldMeta, interface{}) interface{} // field is a configuration of a struct field. type Field struct { + *FieldMeta + // Validation for the type. for example, unit8 greater than or equal to 0. + ValidateFn Validator + // ConvertFn converts the given value to the type value. + CovertFn Converter +} + +type FieldMeta struct { // Name of the column. Name string // Has a "sort" option in the tag. @@ -117,10 +125,6 @@ type Field struct { Filterable bool // All supported operators for this field. FilterOps map[string]bool - // Validation for the type. for example, unit8 greater than or equal to 0. - ValidateFn Validator - // ConvertFn converts the given value to the type value. - CovertFn Converter // Type of the field Type reflect.Type // Time Layout @@ -227,7 +231,7 @@ func Column(s string) string { return b.String() } -func GetSupportedOps(f *Field) []Op { +func GetSupportedOps(f *FieldMeta) []Op { t := f.Type switch t.Kind() { case reflect.Bool: @@ -263,7 +267,7 @@ func GetSupportedOps(f *Field) []Op { } } -func GetConverterFn(f *Field) Converter { +func GetConverterFn(f *FieldMeta) Converter { layout := f.Layout t := f.Type switch t.Kind() { @@ -298,7 +302,7 @@ func GetConverterFn(f *Field) Converter { return valueFn } -func GetValidateFn(f *Field) Validator { +func GetValidateFn(f *FieldMeta) Validator { t := f.Type layout := f.Layout switch t.Kind() { @@ -373,9 +377,11 @@ func (p *Parser) init() error { // in the parser according to its type and the options that were set on the tag. func (p *Parser) parseField(sf reflect.StructField) error { f := &Field{ - Name: p.ColumnFn(sf.Name), - CovertFn: valueFn, - FilterOps: make(map[string]bool), + FieldMeta: &FieldMeta{ + Name: p.ColumnFn(sf.Name), + FilterOps: make(map[string]bool), + }, + CovertFn: valueFn, } layout := time.RFC3339 opts := strings.Split(sf.Tag.Get(p.TagName), ",") @@ -407,12 +413,12 @@ func (p *Parser) parseField(sf reflect.StructField) error { f.Layout = layout f.Type = indirect(sf.Type) - filterOps := p.Config.GetSupportedOps(f) + filterOps := p.Config.GetSupportedOps(f.FieldMeta) if len(filterOps) == 0 { return fmt.Errorf("rql: field type for %q is not supported", sf.Name) } - f.CovertFn = p.Config.GetConverter(f) - f.ValidateFn = p.Config.GetValidator(f) + f.CovertFn = p.Config.GetConverter(f.FieldMeta) + f.ValidateFn = p.Config.GetValidator(f.FieldMeta) for _, op := range filterOps { f.FilterOps[p.op(op)] = true @@ -522,10 +528,10 @@ func (p *parseState) field(f *Field, v interface{}) { // default equality check. if !ok { op := EQ - err := f.ValidateFn(op, f.Type, v) + err := f.ValidateFn(op, *f.FieldMeta, v) must(err, "invalid datatype for field %q", f.Name) p.WriteString(p.fmtOp(f, op)) - arg := f.CovertFn(op, f.Type, v) + arg := f.CovertFn(op, *f.FieldMeta, v) p.values = append(p.values, arg) } var i int @@ -538,9 +544,9 @@ func (p *parseState) field(f *Field, v interface{}) { } op := Op(opName[1:]) expect(f.FilterOps[opName], "can not apply op %q on field %q", opName, f.Name) - must(f.ValidateFn(op, f.Type, opVal), "invalid datatype or format for field %q", f.Name) + must(f.ValidateFn(op, *f.FieldMeta, opVal), "invalid datatype or format for field %q", f.Name) p.WriteString(p.fmtOp(f, op)) - arg := f.CovertFn(op, f.Type, opVal) + arg := f.CovertFn(op, *f.FieldMeta, opVal) p.values = append(p.values, arg) i++ } @@ -603,7 +609,7 @@ func errorType(v interface{}, expected string) error { } // validate that the underlined element of given interface is a boolean. -func validateBool(op Op, t reflect.Type, v interface{}) error { +func validateBool(op Op, f FieldMeta, v interface{}) error { if _, ok := v.(bool); !ok { return errorType(v, "bool") } @@ -611,7 +617,7 @@ func validateBool(op Op, t reflect.Type, v interface{}) error { } // validate that the underlined element of given interface is a string. -func validateString(op Op, t reflect.Type, v interface{}) error { +func validateString(op Op, f FieldMeta, v interface{}) error { if _, ok := v.(string); !ok { return errorType(v, "string") } @@ -619,7 +625,7 @@ func validateString(op Op, t reflect.Type, v interface{}) error { } // validate that the underlined element of given interface is a float. -func validateFloat(op Op, t reflect.Type, v interface{}) error { +func validateFloat(op Op, f FieldMeta, v interface{}) error { if _, ok := v.(float64); !ok { return errorType(v, "float64") } @@ -627,7 +633,7 @@ func validateFloat(op Op, t reflect.Type, v interface{}) error { } // validate that the underlined element of given interface is an int. -func validateInt(op Op, t reflect.Type, v interface{}) error { +func validateInt(op Op, f FieldMeta, v interface{}) error { n, ok := v.(float64) if !ok { return errorType(v, "int") @@ -639,8 +645,8 @@ func validateInt(op Op, t reflect.Type, v interface{}) error { } // validate that the underlined element of given interface is an int and greater than 0. -func validateUInt(op Op, t reflect.Type, v interface{}) error { - if err := validateInt(op, t, v); err != nil { +func validateUInt(op Op, f FieldMeta, v interface{}) error { + if err := validateInt(op, f, v); err != nil { return err } if v.(float64) < 0 { @@ -651,7 +657,7 @@ func validateUInt(op Op, t reflect.Type, v interface{}) error { // validate that the underlined element of this interface is a "datetime" string. func validateTime(layout string) Validator { - return func(_ Op, _ reflect.Type, v interface{}) error { + return func(_ Op, _ FieldMeta, v interface{}) error { s, ok := v.(string) if !ok { return errorType(v, "string") @@ -662,20 +668,20 @@ func validateTime(layout string) Validator { } // convert float to int. -func convertInt(op Op, t reflect.Type, v interface{}) interface{} { +func convertInt(op Op, f FieldMeta, v interface{}) interface{} { return int(v.(float64)) } // convert string to time object. -func convertTime(layout string) func(Op, reflect.Type, interface{}) interface{} { - return func(_ Op, _ reflect.Type, v interface{}) interface{} { +func convertTime(layout string) func(Op, FieldMeta, interface{}) interface{} { + return func(_ Op, _ FieldMeta, v interface{}) interface{} { t, _ := time.Parse(layout, v.(string)) return t } } // nop converter. -func valueFn(op Op, t reflect.Type, v interface{}) interface{} { +func valueFn(op Op, f FieldMeta, v interface{}) interface{} { return v } diff --git a/rql_custom_ops_test.go b/rql_custom_ops_test.go index 7144533..e657d15 100644 --- a/rql_custom_ops_test.go +++ b/rql_custom_ops_test.go @@ -108,7 +108,7 @@ var ( EXISTS = Op("exists") ) -func CustomGetSupportedOps(f *Field) []Op { +func CustomGetSupportedOps(f *FieldMeta) []Op { t := f.Type switch t.Kind() { case reflect.Bool: @@ -148,8 +148,8 @@ func CustomGetSupportedOps(f *Field) []Op { } } -func CustomGetConverterFn(f *Field) Converter { - layout := "" +func CustomGetConverterFn(f *FieldMeta) Converter { + layout := f.Layout t := f.Type switch t.Kind() { case reflect.Bool: @@ -187,9 +187,9 @@ func CustomGetConverterFn(f *Field) Converter { return valueFn } -func CustomGetValidateFn(f *Field) Validator { +func CustomGetValidateFn(f *FieldMeta) Validator { t := f.Type - layout := "" + layout := f.Layout switch t.Kind() { case reflect.Bool: return validateBool @@ -246,7 +246,7 @@ func validateSliceElem(v interface{}, expectedElemType reflect.Type) error { } // validate that the underlined element of given interface is a string. -func customValidateString(op Op, t reflect.Type, v interface{}) error { +func customValidateString(op Op, f FieldMeta, v interface{}) error { if op == IN || op == NIN { validateSliceElem(v, reflect.TypeOf("")) } @@ -257,7 +257,7 @@ func customValidateString(op Op, t reflect.Type, v interface{}) error { } // validate that the underlined element of given interface is a float. -func customValidateFloat(op Op, t reflect.Type, v interface{}) error { +func customValidateFloat(op Op, f FieldMeta, v interface{}) error { if op == IN || op == NIN { return validateSliceElem(v, reflect.TypeOf(1.1)) } @@ -268,7 +268,7 @@ func customValidateFloat(op Op, t reflect.Type, v interface{}) error { } // validate that the underlined element of given interface is an int. -func customValidateInt(op Op, t reflect.Type, v interface{}) error { +func customValidateInt(op Op, f FieldMeta, v interface{}) error { if op == IN || op == NIN { return validateSliceElem(v, reflect.TypeOf(1.1)) } @@ -283,11 +283,11 @@ func customValidateInt(op Op, t reflect.Type, v interface{}) error { } // validate that the underlined element of given interface is an int and greater than 0. -func customValidateUInt(op Op, t reflect.Type, v interface{}) error { +func customValidateUInt(op Op, f FieldMeta, v interface{}) error { if op == IN || op == NIN { return validateSliceElem(v, reflect.TypeOf(1.1)) } - if err := validateInt(op, t, v); err != nil { + if err := validateInt(op, f, v); err != nil { return err } if v.(float64) < 0 { @@ -297,7 +297,7 @@ func customValidateUInt(op Op, t reflect.Type, v interface{}) error { } // convert float to int. -func customConvertInt(op Op, t reflect.Type, v interface{}) interface{} { +func customConvertInt(op Op, f FieldMeta, v interface{}) interface{} { if op == IN || op == NIN { sl, ok := v.([]interface{}) if !ok { @@ -312,7 +312,8 @@ func customConvertInt(op Op, t reflect.Type, v interface{}) interface{} { return int(v.(float64)) } -func convertSlice(op Op, t reflect.Type, v interface{}) interface{} { +func convertSlice(op Op, f FieldMeta, v interface{}) interface{} { + t := f.Type if isNumeric(t.Elem()) { switch t.Elem().Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, @@ -349,7 +350,8 @@ func isComparable(t reflect.Type, t2 reflect.Type) bool { return t == t2 || (isNumeric(t) && isNumeric(t2)) } -func validateSliceOp(op Op, t reflect.Type, v interface{}) error { +func validateSliceOp(op Op, f FieldMeta, v interface{}) error { + t := f.Type if t.Kind() != reflect.Slice { return fmt.Errorf("t is not a slice, wrong validate func") } @@ -361,7 +363,7 @@ func validateSliceOp(op Op, t reflect.Type, v interface{}) error { return validateSliceElem(v, t.Elem()) } -func validateMapOp(op Op, t reflect.Type, v interface{}) error { +func validateMapOp(op Op, f FieldMeta, v interface{}) error { if op == EXISTS { _, ok := v.(string) if !ok { From fb792500be4aae203186eaf23a9be7c284a293cd Mon Sep 17 00:00:00 2001 From: Ashton Kinslow Date: Fri, 11 Aug 2023 18:48:13 -0500 Subject: [PATCH 6/7] update to use field metadata --- config.go | 4 ++-- rql.go | 2 +- rql_custom_ops_test.go | 2 +- rql_test.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/config.go b/config.go index ff21979..450b096 100644 --- a/config.go +++ b/config.go @@ -149,7 +149,7 @@ type Config struct { // It defaults to an empty string slice. DefaultSort []string // Lets the user define how a rql op is translated to a db op. - GetDBOp func(Op, *Field) string + GetDBOp func(Op, *FieldMeta) string // Lets the user define how a rql dir ('+','-') is translated to a db direction. GetDBDir func(Direction) string // Sets the validator function based on the type @@ -175,7 +175,7 @@ func (c *Config) defaults() error { c.ColumnFn = Column } if c.GetDBOp == nil { - c.GetDBOp = func(o Op, _ *Field) string { + c.GetDBOp = func(o Op, _ *FieldMeta) string { return opFormat[o] } } diff --git a/rql.go b/rql.go index 5b7b44f..3148d16 100644 --- a/rql.go +++ b/rql.go @@ -558,7 +558,7 @@ func (p *parseState) field(f *Field, v interface{}) { // fmtOp create a string for the operation with a placeholder. // for example: "name = ?", or "age >= ?". func (p *Parser) fmtOp(f *Field, op Op) string { - return p.colName(f.Name) + " " + p.GetDBOp(op, f) + " ?" + return p.colName(f.Name) + " " + p.GetDBOp(op, f.FieldMeta) + " ?" } // colName formats the query field to database column name in cases the user configured a custom diff --git a/rql_custom_ops_test.go b/rql_custom_ops_test.go index e657d15..014ecd6 100644 --- a/rql_custom_ops_test.go +++ b/rql_custom_ops_test.go @@ -52,7 +52,7 @@ func TestParse2(t *testing.T) { AliasMap StructAlias `rql:"filter,column=alias_map"` }{}, FieldSep: ".", - GetDBOp: func(o Op, f *Field) string { + GetDBOp: func(o Op, f *FieldMeta) string { return customOpFormat[o] }, GetSupportedOps: CustomGetSupportedOps, diff --git a/rql_test.go b/rql_test.go index 4a6894f..477a69f 100644 --- a/rql_test.go +++ b/rql_test.go @@ -926,7 +926,7 @@ func TestParse(t *testing.T) { } }{}, FieldSep: ".", - GetDBOp: func(o Op, _ *Field) string { + GetDBOp: func(o Op, f *FieldMeta) string { if o == EQ { return "eq" } From c48b84b57580a64cac26cc0e55a3727526ce30b3 Mon Sep 17 00:00:00 2001 From: Ashton Kinslow Date: Wed, 16 Aug 2023 15:08:52 -0500 Subject: [PATCH 7/7] update op to handle any better --- config.go | 16 +++++++++++----- rql.go | 6 ++++-- rql_custom_ops_test.go | 7 +++++-- rql_test.go | 7 +++---- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/config.go b/config.go index 450b096..241cd31 100644 --- a/config.go +++ b/config.go @@ -148,8 +148,11 @@ type Config struct { // DefaultSort is the default value for the 'Sort' field that returns when no sort expression is supplied by the caller. // It defaults to an empty string slice. DefaultSort []string - // Lets the user define how a rql op is translated to a db op. - GetDBOp func(Op, *FieldMeta) string + // Lets the user define how a rql op is translated to a db op. // Returns db operator and statement format string. + // TODO: I think this interface can be improved, I'm not sure exactly yet, need more use cases. + // Current edge case requiring format string is the `= any (?)` op. Any expects `()` around ? for casting over. + // Providing a format string fixes that, but is not very flexible, a template would be better. + GetDBStatement func(Op, *FieldMeta) (string, string) // Lets the user define how a rql dir ('+','-') is translated to a db direction. GetDBDir func(Direction) string // Sets the validator function based on the type @@ -174,9 +177,12 @@ func (c *Config) defaults() error { if c.ColumnFn == nil { c.ColumnFn = Column } - if c.GetDBOp == nil { - c.GetDBOp = func(o Op, _ *FieldMeta) string { - return opFormat[o] + if c.GetDBStatement == nil { + c.GetDBStatement = func(o Op, _ *FieldMeta) (string, string) { + if o == Op("any") { + return opFormat[o], "%v %v (%v)" + } + return opFormat[o], "%v %v %v" } } if c.GetDBDir == nil { diff --git a/rql.go b/rql.go index 3148d16..4cc8828 100644 --- a/rql.go +++ b/rql.go @@ -510,7 +510,8 @@ func (p *parseState) relOp(op Op, terms []interface{}) { for _, t := range terms { if i > 0 { p.WriteByte(' ') - p.WriteString(p.GetDBOp(op, nil)) + op, _ := p.GetDBStatement(op, nil) // AND + p.WriteString(op) p.WriteByte(' ') } mt, ok := t.(map[string]interface{}) @@ -558,7 +559,8 @@ func (p *parseState) field(f *Field, v interface{}) { // fmtOp create a string for the operation with a placeholder. // for example: "name = ?", or "age >= ?". func (p *Parser) fmtOp(f *Field, op Op) string { - return p.colName(f.Name) + " " + p.GetDBOp(op, f.FieldMeta) + " ?" + dbOp, fmtStr := p.Config.GetDBStatement(op, f.FieldMeta) + return fmt.Sprintf(fmtStr, p.colName(f.Name), dbOp, "?") } // colName formats the query field to database column name in cases the user configured a custom diff --git a/rql_custom_ops_test.go b/rql_custom_ops_test.go index 014ecd6..d70a302 100644 --- a/rql_custom_ops_test.go +++ b/rql_custom_ops_test.go @@ -52,8 +52,11 @@ func TestParse2(t *testing.T) { AliasMap StructAlias `rql:"filter,column=alias_map"` }{}, FieldSep: ".", - GetDBOp: func(o Op, f *FieldMeta) string { - return customOpFormat[o] + GetDBStatement: func(o Op, f *FieldMeta) (string, string) { + if o == Op("any") { + return customOpFormat[o], "%v %v (%v)" + } + return customOpFormat[o], "%v %v %v" }, GetSupportedOps: CustomGetSupportedOps, GetValidator: CustomGetValidateFn, diff --git a/rql_test.go b/rql_test.go index 477a69f..f760618 100644 --- a/rql_test.go +++ b/rql_test.go @@ -926,12 +926,11 @@ func TestParse(t *testing.T) { } }{}, FieldSep: ".", - GetDBOp: func(o Op, f *FieldMeta) string { + GetDBStatement: func(o Op, f *FieldMeta) (string, string) { if o == EQ { - return "eq" + return "eq", "%v %v %v" } - return opFormat[o] - + return opFormat[o], "%v %v %v" }, GetDBDir: func(d Direction) string { if d == ASC {