diff --git a/cel/cel_test.go b/cel/cel_test.go index 05adeb2f..09e54751 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1969,6 +1969,48 @@ func TestCostLimit(t *testing.T) { } } +func TestCostTrackingConsistentAcrossEvals(t *testing.T) { + env := testEnv(t, + Variable("val1", IntType), + Variable("val2", IntType), + ) + ast, iss := env.Compile(`val1 + val2`) + if iss.Err() != nil { + t.Fatalf("env.Compile() failed: %v", iss.Err()) + } + checkedAst, iss := env.Check(ast) + if iss.Err() != nil { + t.Fatalf("env.Check() failed: %v", iss.Err()) + } + program, err := env.Program(checkedAst, CostTracking(nil)) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + input := map[string]any{"val1": 1, "val2": 2} + + _, det1, err := program.Eval(input) + if err != nil { + t.Fatalf("first Eval() failed: %v", err) + } + cost1 := det1.ActualCost() + if cost1 == nil { + t.Fatal("first Eval() returned nil ActualCost") + } + + _, det2, err := program.Eval(input) + if err != nil { + t.Fatalf("second Eval() failed: %v", err) + } + cost2 := det2.ActualCost() + if cost2 == nil { + t.Fatal("second Eval() returned nil ActualCost") + } + + if *cost1 != *cost2 { + t.Errorf("ActualCost mismatch across evaluations: first=%d, second=%d", *cost1, *cost2) + } +} + func TestPartialVars(t *testing.T) { env := testEnv(t, Variable("x", StringType), diff --git a/cel/program.go b/cel/program.go index 79df0374..1fb65877 100644 --- a/cel/program.go +++ b/cel/program.go @@ -261,8 +261,16 @@ func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) { if p.costLimit != nil { costOpts = append(costOpts, interpreter.CostTrackerLimit(*p.costLimit)) } + // Creating a new cost tracker for each evaluation causes significant work that + // needs to be repeated for each evaluation even though the cost tracker is + // mostly read-only once constructed. Therefore it gets constructed + // once now and later a cheap clone is used for each evaluation. + tracker, err := interpreter.NewCostTracker(p.callCostEstimator, costOpts...) + if err != nil { + return nil, fmt.Errorf("construct cost tracker: %w", err) + } trackerFactory := func() (*interpreter.CostTracker, error) { - return interpreter.NewCostTracker(p.callCostEstimator, costOpts...) + return tracker.Clone() } var observers []interpreter.PlannerOption if p.evalOpts&(OptExhaustiveEval|OptTrackState) != 0 { diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index 0233e9a2..68e43101 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -233,6 +233,19 @@ type CostTracker struct { stack refValStack } +// Clone makes a shallow copy of the tracker. +// The different clones can be used independently from +// each other. +func (c *CostTracker) Clone() (*CostTracker, error) { + tracker := &CostTracker{ + Estimator: c.Estimator, + overloadTrackers: c.overloadTrackers, + Limit: c.Limit, + presenceTestHasCost: c.presenceTestHasCost, + } + return tracker, nil +} + // ActualCost returns the runtime cost func (c *CostTracker) ActualCost() uint64 { return c.cost diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index a55db6f5..597d73ef 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -134,6 +134,10 @@ func computeCost(t *testing.T, expr string, vars []*decls.VariableDecl, ctx Acti if err != nil { t.Fatalf("NewCostTracker() failed: %v", err) } + costTracker, err = costTracker.Clone() + if err != nil { + t.Fatalf("checker.Clone() failed: %v", err) + } checked, errs := checker.Check(parsed, s, env) if len(errs.GetErrors()) != 0 { t.Fatalf(`Failed to check expression "%s", error: %v`, expr, errs.GetErrors())