Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1969,6 +1969,48 @@ func TestCostLimit(t *testing.T) {
}
}

func TestCostTrackingConsistentAcrossEvals(t *testing.T) {
env := testEnv(t,
Variable("val1", IntType),
Variable("val2", IntType),
)
ast, iss := env.Compile(`val1 + val2`)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
checkedAst, iss := env.Check(ast)
if iss.Err() != nil {
t.Fatalf("env.Check() failed: %v", iss.Err())
}
program, err := env.Program(checkedAst, CostTracking(nil))
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
input := map[string]any{"val1": 1, "val2": 2}

_, det1, err := program.Eval(input)
if err != nil {
t.Fatalf("first Eval() failed: %v", err)
}
cost1 := det1.ActualCost()
if cost1 == nil {
t.Fatal("first Eval() returned nil ActualCost")
}

_, det2, err := program.Eval(input)
if err != nil {
t.Fatalf("second Eval() failed: %v", err)
}
cost2 := det2.ActualCost()
if cost2 == nil {
t.Fatal("second Eval() returned nil ActualCost")
}

if *cost1 != *cost2 {
t.Errorf("ActualCost mismatch across evaluations: first=%d, second=%d", *cost1, *cost2)
}
}

func TestPartialVars(t *testing.T) {
env := testEnv(t,
Variable("x", StringType),
Expand Down
10 changes: 9 additions & 1 deletion cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,16 @@ func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) {
if p.costLimit != nil {
costOpts = append(costOpts, interpreter.CostTrackerLimit(*p.costLimit))
}
// Creating a new cost tracker for each evaluation causes significant work that
// needs to be repeated for each evaluation even though the cost tracker is
// mostly read-only once constructed. Therefore it gets constructed
// once now and later a cheap clone is used for each evaluation.
tracker, err := interpreter.NewCostTracker(p.callCostEstimator, costOpts...)
if err != nil {
return nil, fmt.Errorf("construct cost tracker: %w", err)
}
trackerFactory := func() (*interpreter.CostTracker, error) {
return interpreter.NewCostTracker(p.callCostEstimator, costOpts...)
return tracker.Clone()
}
var observers []interpreter.PlannerOption
if p.evalOpts&(OptExhaustiveEval|OptTrackState) != 0 {
Expand Down
13 changes: 13 additions & 0 deletions interpreter/runtimecost.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,19 @@ type CostTracker struct {
stack refValStack
}

// Clone makes a shallow copy of the tracker.
// The different clones can be used independently from
// each other.
func (c *CostTracker) Clone() (*CostTracker, error) {
tracker := &CostTracker{
Estimator: c.Estimator,
overloadTrackers: c.overloadTrackers,
Limit: c.Limit,
presenceTestHasCost: c.presenceTestHasCost,
}
return tracker, nil
}

// ActualCost returns the runtime cost
func (c *CostTracker) ActualCost() uint64 {
return c.cost
Expand Down
4 changes: 4 additions & 0 deletions interpreter/runtimecost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ func computeCost(t *testing.T, expr string, vars []*decls.VariableDecl, ctx Acti
if err != nil {
t.Fatalf("NewCostTracker() failed: %v", err)
}
costTracker, err = costTracker.Clone()
if err != nil {
t.Fatalf("checker.Clone() failed: %v", err)
}
checked, errs := checker.Check(parsed, s, env)
if len(errs.GetErrors()) != 0 {
t.Fatalf(`Failed to check expression "%s", error: %v`, expr, errs.GetErrors())
Expand Down