diff --git a/jsonpath.lua b/jsonpath.lua index 07bf2c9..2953de4 100755 --- a/jsonpath.lua +++ b/jsonpath.lua @@ -100,6 +100,17 @@ ]]-- local M = {} +local codes = { + BAD_REQUEST = 400, + NOT_FOUND = 404, + INTERNAL_ERR = 500, +} + +local errors = require('errors') + +local JsonPathError = errors.new_class("JsonPathError") +local JsonPathNotFoundError = errors.new_class("JsonPathNotFoundError") + local ffi = require('ffi') -- Use Roberto Ierusalimschy's fabulous LulPeg pattern-matching library @@ -286,41 +297,297 @@ local jsonpath_grammer = (function() return jsonpath end)() +local function bad_request_error(err) + local err = JsonPathError:new(err) + err.rc = codes.BAD_REQUEST + return err +end --- Helper: evaluate abstract syntax tree. Called recursively. -local function eval_ast(ast, obj) +local function not_found_error(err) + local err = JsonPathNotFoundError:new(err) + err.rc = codes.NOT_FOUND + return err +end - -- Helper helper: match type of second operand to type of first operand - local function match_type(op1, op2) - if type(op1) == 'boolean' then - if is_null(op2) then - -- null must never be equal to other boolean, invert op1 - return not op1 - else - return (op2 and true or false) +local function internal_error(err) + local err = JsonPathNotFoundError:new(err) + err.rc = codes.INTERNAL_ERR + return err +end + +--- @alias Operator 1|2|3|4|5|6|7|8|9|10|11|12|13 +--- @alias OperatorType 1|2|3|4 + +--- @type Operator[] +local OPERATORS = { + --- Arithmetic operators + ADD = 1, + SUB = 2, + MUL = 3, + DIV = 4, + MOD = 5, + --- Logical operators + AND = 6, + OR = 7, + --- Equality + EQ = 8, + NEQ = 9, + --- Comparison + GT = 10, + GTE = 11, + LT = 12, + LTE = 13, +} + +local OPERATOR_TYPES = { + ARITHMETIC = 1, + LOGICAL = 2, + EQUALITY = 3, + COMPARISON = 4, +} + +--- @param op Operator|number +--- @return OperatorType|0 op_type +local function get_operator_type(op) + if op >= 1 and op <= 5 then + return OPERATOR_TYPES.ARITHMETIC + elseif op == 6 or op == 7 then + return OPERATOR_TYPES.LOGICAL + elseif op == 8 or op == 9 then + return OPERATOR_TYPES.EQUALITY + elseif op >= 10 and op <= 13 then + return OPERATOR_TYPES.COMPARISON + end + return 0 +end + +--- @type { [Operator]: fun(any,any): any } +local OPERATORS_FN = { + --- Arithmetic + function(l, r) + return l + r + end, + function(l, r) + return l - r + end, + function(l, r) + return l * r + end, + function(l, r) + return l / r + end, + function(l, r) + return l % r + end, + --- Logic (boolean operands only) + function(l, r) + return l and r + end, + function(l, r) + return l or r + end, + --- Eq + function (l, r) + return l == r + end, + function (l, r) + return l ~= r + end, + --- Cmp + function (l, r) + return l > r + end, + function (l, r) + return l >= r + end, + function (l, r) + return l < r + end, + function (l, r) + return l <= r + end, +} + +--- @return Operator | 0 +local function parse_operator(op) + if op == '+' then + return OPERATORS.ADD + elseif op == '-' then + return OPERATORS.SUB + elseif op == '*' then + return OPERATORS.MUL + elseif op == '/' then + return OPERATORS.DIV + elseif op == '%' then + return OPERATORS.MOD + elseif op:upper() == 'AND' or op == '&&' then + return OPERATORS.AND + elseif op:upper() == 'OR' or op == '||' then + return OPERATORS.OR + elseif op == '=' or op == '==' then + return OPERATORS.EQ + elseif op == '<>' or op == '!=' then + return OPERATORS.NEQ + elseif op == '>' then + return OPERATORS.GT + elseif op == '>=' then + return OPERATORS.GTE + elseif op == '<' then + return OPERATORS.LT + elseif op == '<=' then + return OPERATORS.LTE + else + return 0 + end +end + +--- Computes type casts and executes binary operator +--- +--- @param op Operator Operator to execute +--- @param lval any Left value of binary operator +--- @param rval any Right value of binary operator +--- @param op_str string String representation of operator, used in error reporting +--- @return any|nil val Result value +--- @return nil|string err Error, if cast has failed +local function exec_binary_op(op, lval, rval, op_str) + local l_type = type(lval) + local r_type = type(rval) + local op_type = get_operator_type(op) + + -- convert these long int numbers to normal numbers + if l_type == 'cdata' and lval ~= NULL and tostring(ffi.typeof(lval)) == 'ctype' then + l_type = "number" + lval = tonumber(lval) + end + if r_type == 'cdata' and lval ~= NULL and tostring(ffi.typeof(rval)) == 'ctype' then + r_type = "number" + rval = tonumber(rval) + end + + if op_type == OPERATOR_TYPES.ARITHMETIC then + -- arithmetic ops allowed only on numbers + if l_type == "string" then + lval = tonumber(lval) + if lval == nil then + return nil, bad_request_error(("can not parse string lvalue as number for operation %s"):format(op_str)) end - elseif type(op1) == 'number' then - return tonumber(op2) - elseif type(op1) == 'cdata' and tostring(ffi.typeof(op1)) == 'ctype' then - return tonumber(op2) - elseif is_null(op1) then - return op2 + elseif l_type ~= "number" then + return nil, bad_request_error(("lvalue is not a number for operation %s"):format(op_str)) + end + if r_type == "string" then + rval = tonumber(rval) + if rval == nil then + return nil, + bad_request_error(("can not parse string rvalue as number for operation %s"):format(op_str)) + end + elseif r_type ~= "number" then + return nil, bad_request_error(("rvalue is not a number for operation %s"):format(op_str)) + end + elseif op_type == OPERATOR_TYPES.LOGICAL then + -- everything which is not null is a true boolean + if l_type ~= "boolean" then + lval = not is_null(lval) + end + if r_type ~= "boolean" then + rval = not is_null(rval) + end + elseif op_type == OPERATOR_TYPES.EQUALITY then + -- cast numbers and booleans to string, if other operand is string + if l_type == "string" and r_type == "number" then + r_type = "string" + rval = tostring(rval) + elseif l_type == "string" and r_type == "boolean" then + r_type = "string" + rval = tostring(rval) + end + if r_type == "string" and l_type == "number" then + l_type = "string" + lval = tostring(lval) + elseif r_type == "string" and l_type == "boolean" then + l_type = "string" + lval = tostring(lval) + end + + -- cast booleans as numbers + if l_type == "number" and r_type == "boolean" then + r_type = "number" + rval = rval and 1 or 0 + end + if r_type == "number" and l_type == "boolean" then + l_type = "number" + lval = lval and 1 or 0 + end + + -- special comparisons when lvalue or rvalue is null + local lval_is_null, rval_is_null = is_null(lval), is_null(rval) + if lval_is_null and rval_is_null then + -- null == null -> true + return op == OPERATORS.EQ + end + if rval_is_null or lval_is_null then + -- something == null -> false + -- something != null -> true + -- null == something -> false + -- null != something -> true + return not (op == OPERATORS.EQ) + end + + -- bypass default operator functions for non-matching types + if l_type ~= r_type and op == OPERATORS.EQ then + -- values of different types are never equal + return false, nil + elseif l_type ~= r_type and op == OPERATORS.NEQ then + -- values of different types are always not equal + return true, nil + end + elseif op_type == OPERATOR_TYPES.COMPARISON then + -- allow to compare numbers with booleans + if l_type == "number" and r_type == "boolean" then + r_type = "number" + rval = rval and 1 or 0 + end + if r_type == "number" and l_type == "boolean" then + l_type = "number" + lval = lval and 1 or 0 end - return tostring(op2 or '') - end - -- Helper helper: convert operand to boolean - local function notempty(op1) - return op1 and true or false + -- try to parse string as number, if other operand is number + if l_type == "number" and r_type == "string" then + local num_rval = tonumber(rval) + if num_rval ~= nil then + r_type = "number" + rval = num_rval + end + end + if r_type == "number" and l_type == "string" then + local num_lval = tonumber(lval) + if num_lval ~= nil then + l_type = "number" + lval = num_lval + end + end + + -- must be the same type + if l_type ~= r_type then + return nil, bad_request_error(("can not apply %s on types %s and %s"):format(op_str, l_type, r_type)) + end + else + return nil, bad_request_error(("unknown operator %s"):format(op_str)) end + return OPERATORS_FN[op](lval, rval), nil +end + +-- Helper: evaluate abstract syntax tree. Called recursively. +local function eval_ast(ast, obj) + -- Helper helper: evaluate variable expression inside abstract syntax tree local function eval_var(expr, obj) if obj == nil then - return nil, 'object is not set' + return nil, bad_request_error('object is not set') end if type(obj) ~= "table" then - return nil, 'object is primitive' + return nil, not_found_error('object is primitive') end for i = 2, #expr do -- [1] is "var" @@ -331,7 +598,7 @@ local function eval_ast(ast, obj) member = type(member) == 'number' and member + 1 or member obj = obj[member] if is_nil(obj) then - return nil, 'object doesn\'t contain an object or attribute "' .. member .. '"' + return nil, not_found_error('object doesn\'t contain an object or attribute "'.. member ..'"') end end return obj @@ -348,7 +615,10 @@ local function eval_ast(ast, obj) local function eval_union(expr, obj) local matches = {} -- [1] is "union" for i = 2, #expr do - local result = eval_ast(expr[i], obj) + local result, err = eval_ast(expr[i], obj) + if err then + return nil, err + end if type(result) == 'table' then for _, j in ipairs(result) do table.insert(matches, j) @@ -362,16 +632,31 @@ local function eval_ast(ast, obj) -- Helper helper: evaluate 'filter' expression inside abstract syntax tree local function eval_filter(expr, obj) - return eval_ast(expr[2], obj) and true or false + local result, err = eval_ast(expr[2], obj) + if err then + if err.rc == codes.NOT_FOUND then + return false + end + return nil, err + end + return result and true or false end -- Helper helper: evaluate 'slice' expression inside abstract syntax tree local function eval_slice(expr, obj) local matches = {} -- [1] is "slice" if #expr == 4 then - local from = tonumber(eval_ast(expr[2], obj)) - local to = tonumber(eval_ast(expr[3], obj)) - local step = tonumber(eval_ast(expr[4], obj)) + local from_result, err = eval_ast(expr[2], obj) + if err then return nil, err end + local to_result, err = eval_ast(expr[3], obj) + if err then return nil, err end + local step_result, err = eval_ast(expr[4], obj) + if err then return nil, err end + + local from = tonumber(from_result) + local to = tonumber(to_result) + local step = tonumber(step_result) + if (from == nil) or (from < 0) or (to == nil) or (to < 0) then local len = eval_var_length(obj) if from == nil then @@ -399,55 +684,24 @@ local function eval_ast(ast, obj) return nil, err end for i = 3, #expr, 2 do - local operator = expr[i] - if operator == nil then - return nil, 'missing expression operator' + local op_str = expr[i] + if op_str == nil then + return nil, bad_request_error('missing expression operator') end - local op2, err = eval_ast(expr[i + 1], obj) + local op2, eval_err = eval_ast(expr[i + 1], obj) if is_nil(op2) then - return nil, err + return nil, eval_err end - if operator == '+' then - op1 = tonumber(op1) + tonumber(op2) - elseif operator == '-' then - op1 = tonumber(op1) - tonumber(op2) - elseif operator == '*' then - op1 = tonumber(op1) * tonumber(op2) - elseif operator == '/' then - op1 = tonumber(op1) / tonumber(op2) - elseif operator == '%' then - op1 = tonumber(op1) % tonumber(op2) - elseif operator:upper() == 'AND' or operator == '&&' then - op1 = notempty(op1) and notempty(op2) - elseif operator:upper() == 'OR' or operator == '||' then - op1 = notempty(op1) or notempty(op2) - elseif operator == '=' or operator == '==' then - op1 = op1 == match_type(op1, op2) - elseif operator == '<>' or operator == '!=' then - op1 = op1 ~= match_type(op1, op2) - elseif operator == '>' then - if is_null(op1) then - return false - end - op1 = op1 > match_type(op1, op2) - elseif operator == '>=' then - if is_null(op1) then - return false - end - op1 = op1 >= match_type(op1, op2) - elseif operator == '<' then - if is_null(op1) then - return false - end - op1 = op1 < match_type(op1, op2) - elseif operator == '<=' then - if is_null(op1) then - return false - end - op1 = op1 <= match_type(op1, op2) - else - return nil, 'unknown expression operator "' .. operator .. '"' + local op = parse_operator(op_str) + if op == 0 then + return nil, bad_request_error("unknown operator") end + --- @cast op Operator + local result, cast_err = exec_binary_op(op, op1, op2, op_str) + if cast_err ~= nil then + return nil, cast_err + end + op1 = result end return op1 end @@ -466,8 +720,7 @@ local function eval_ast(ast, obj) elseif ast[1] == 'filter' then return eval_filter(ast, obj) elseif ast[1] == 'slice' then - local result = eval_slice(ast, obj) - return result + return eval_slice(ast, obj) end return 0 @@ -499,7 +752,10 @@ local function match_path(ast, path, parent, obj) end elseif ast_spec[1] == 'union' or ast_spec[1] == 'slice' then -- match union or slice expression (on parent object) - local matches = eval_ast(ast_spec, parent) + local matches, err = eval_ast(ast_spec, parent) + if err then + return nil, err + end --- @cast matches table[] for _, i in pairs(matches) do match_component = tostring(i) == tostring(component) @@ -509,7 +765,16 @@ local function match_path(ast, path, parent, obj) end elseif ast_spec[1] == 'filter' then -- match filter expression - match_component = eval_ast(ast_spec, obj) and true or false + local filter_result, err = eval_ast(ast_spec, obj) + if err then + if err.rc == codes.NOT_FOUND then + match_component = false + else + return nil, err + end + else + match_component = filter_result and true or false + end end else if ast_spec == '*' then @@ -528,7 +793,16 @@ local function match_path(ast, path, parent, obj) if path_index == #path and ast_spec ~= "array" and match_component then local _, next_ast_spec = next(ast, ast_key) if next_ast_spec ~= nil and next_ast_spec[1] == 'filter' then - match_component = eval_ast(next_ast_spec, obj) and true or false + local filter_result, err = eval_ast(next_ast_spec, obj) + if err then + if err.rc == codes.NOT_FOUND then + match_component = false + else + return nil, err + end + else + match_component = filter_result and true or false + end ast_key, ast_spec = ast_iter(ast, ast_key) end end @@ -563,7 +837,10 @@ end local function match_tree(nodes, ast, path, parent, obj, count) -- Try to match every node against AST - local match = match_path(ast, path, parent, obj) + local match, err = match_path(ast, path, parent, obj) + if err then + return err + end if match == MATCH_ONE or match == MATCH_DESCENDANTS then -- This node matches. Add path and value to result -- (if max result count not yet reached) @@ -586,7 +863,10 @@ local function match_tree(nodes, ast, path, parent, obj, count) table.insert(path1, p) end table.insert(path1, type(key) == 'string' and key or (key - 1)) - match_tree(nodes, ast, path1, obj, child, count) + local err = match_tree(nodes, ast, path1, obj, child, count) + if err then + return err + end end end end @@ -612,15 +892,15 @@ end -- function M.parse(expr) if expr == nil or type(expr) ~= 'string' then - return nil, "missing or invalid 'expr' argument" + return nil, bad_request_error("missing or invalid 'expr' argument") end local ast = Ct(jsonpath_grammer * Cp()):match(expr) if ast == nil or #ast ~= 2 then - return nil, 'invalid expression "' .. expr .. '"' + return nil, bad_request_error('invalid expression "' .. expr .. '"') end if ast[2] ~= #expr + 1 then - return nil, 'invalid expression "' .. expr .. '" near "' .. expr:sub(ast[2]) .. '"' + return nil, bad_request_error('invalid expression "' .. expr .. '" near "' .. expr:sub(ast[2]) .. '"') end return ast[1] end @@ -644,13 +924,13 @@ end -- function M.nodes(obj, expr, count) if obj == nil or type(obj) ~= 'table' then - return nil, "missing or invalid 'obj' argument" + return nil, bad_request_error("missing or invalid 'obj' argument") end if expr == nil or (type(expr) ~= 'string' and type(expr) ~= 'table') then - return nil, "missing or invalid 'expr' argument" + return nil, bad_request_error("missing or invalid 'expr' argument") end if count ~= nil and type(count) ~= 'number' then - return nil, "invalid 'count' argument" + return nil, bad_request_error("invalid 'count' argument") end local ast, err @@ -662,7 +942,10 @@ function M.nodes(obj, expr, count) ast = expr end if ast == nil then - return nil, err or 'internal error' + if not err then + err = internal_error("internal error") + end + return nil, err end if count ~= nil and count == 0 then @@ -679,8 +962,10 @@ function M.nodes(obj, expr, count) end local matches = {} - match_tree(matches, ast, { '$' }, {}, obj, count) - + local err = match_tree(matches, ast, { '$' }, {}, obj, count) + if err then + return nil, err + end -- Sort results by path local sorted = {} for p, v in pairs(matches) do @@ -732,7 +1017,7 @@ function M.value(obj, expr, count) return nodes[1].value end - return nil, 'no element matching expression' + return nil, bad_request_error('no element matching expression') end diff --git a/test/test.lua b/test/test.lua index ff565de..c9f3016 100755 --- a/test/test.lua +++ b/test/test.lua @@ -826,8 +826,8 @@ testQuery = { } local result, err = jp.query(data, "$..photo[?(@.size>'400')]") - lu.assertItemsEquals(result, {}) lu.assertNil(err) + lu.assertItemsEquals(result, {}) end, testFilterNull = function() @@ -946,6 +946,131 @@ testQuery = { lu.assertNil(err) lu.assertItemsEquals(result, { array[2], array[3] }) end, + + testFilterIntBoolComparison = function () + local array = { + { id = 1, value = 0 }, + { id = 2, value = 1 }, + { id = 3, value = 2 }, + } + local result, err = jp.query(array, '$[?(@.value==true)]') + lu.assertNil(err) + lu.assertItemsEquals(result, { array[2] }) + + local result, err = jp.query(array, '$[?(@.value>true)]') + lu.assertNil(err) + lu.assertItemsEquals(result, { array[3] }) + + local result, err = jp.query(array, '$[?(@.value>=true)]') + lu.assertNil(err) + lu.assertItemsEquals(result, { array[2], array[3] }) + + local result, err = jp.query(array, '$[?(@.value1)]') + lu.assertError(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value>=1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, { array[1] }) + + local result, err = jp.query(array, '$[?(@.value<1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, { array[2] }) + + local result, err = jp.query(array, '$[?(@.value<=1)]') + lu.assertNil(err) + lu.assertItemsEquals(result, { array[1], array[2] }) + end, + + testFilterBoolStrComparison = function () + local array = { + { id = 1, value = true }, + { id = 2, value = false }, + } + local result, err = jp.query(array, '$[?(@.value=="1")]') + -- lu.assertError(err) + lu.assertNil(err) + lu.assertItemsEquals(result, {}) + + local result, err = jp.query(array, '$[?(@.value>"1")]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + + local result, err = jp.query(array, '$[?(@.value>="1")]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + + local result, err = jp.query(array, '$[?(@.value<"1")]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + + local result, err = jp.query(array, '$[?(@.value<="1")]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + end, + + testFilterArithmeticOpOnBool = function () + local array = { + { id = 1, value = 0 }, + { id = 1, value = 1 }, + { id = 2, value = 2 }, + } + local result, err = jp.query(array, '$[?(@.value==true+1)]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + + local result, err = jp.query(array, '$[?(@.value==true*1)]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + + local result, err = jp.query(array, '$[?(@.value==true/1)]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + + local result, err = jp.query(array, '$[?(@.value==true%1)]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + + local result, err = jp.query(array, '$[?(@.value<>false+1)]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + end, + + testFilterArithmeticOponStr = function () + local array = { + { id = 1, value = 0 }, + { id = 1, value = "a" }, + } + local result, err = jp.query(array, '$[?(@.value=="a"+"b")]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + + local result, err = jp.query(array, '$[?(@.value=="a"+null)]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + + local result, err = jp.query(array, '$[?(@.value=="a"+1)]') + lu.assertError(err) + lu.assertItemsEquals(result, nil) + end, }