Skip to content
Draft
28 changes: 25 additions & 3 deletions paynt/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import paynt.quotient.mdp_family
from . import version

import paynt.utils.timer
Expand All @@ -10,12 +9,16 @@
import paynt.quotient.posmg
import paynt.quotient.storm_pomdp_control
import paynt.quotient.mdp
import paynt.quotient.mdp_family

import paynt.synthesizer.synthesizer
import paynt.synthesizer.synthesizer_cegis
import paynt.synthesizer.policy_tree
import paynt.synthesizer.decision_tree

import paynt.family.constraints.flexibletree
import paynt.family.constraints.costs

import click
import sys
import os
Expand Down Expand Up @@ -130,9 +133,21 @@ def setup_logger(log_path = None):
)

@click.option(
"--ce-generator", type=click.Choice(["dtmc", "mdp"]), default="dtmc", show_default=True,
"--ce-generator", type=click.Choice(["dtmc", "mdp", "none"]), default="dtmc", show_default=True,
help="counterexample generator",
)

@click.option("--constraint",
type=click.Choice(['prob1', 'prob0', 'tree', 'costs']),
default=None , show_default=True,
help="constraint type for CEGIS"
)
@click.option("--tree-nodes", default=None, type=int,
help="constraint tree: number of nodes in the decision tree (only for --constraint tree)")
@click.option("--costs-threshold", default=None, type=int,
help="costs constraint: threshold for costs (only for --constraint costs)")


@click.option("--profiling", is_flag=True, default=False,
help="run profiling")

Expand All @@ -148,7 +163,7 @@ def paynt_run(
mdp_discard_unreachable_choices,
tree_depth, tree_enumeration, tree_map_scheduler, add_dont_care_action,
constraint_bound,
ce_generator,
ce_generator, constraint, tree_nodes, costs_threshold,
profiling
):

Expand All @@ -164,6 +179,7 @@ def paynt_run(
paynt.quotient.quotient.Quotient.disable_expected_visits = disable_expected_visits
paynt.synthesizer.synthesizer.Synthesizer.export_synthesis_filename_base = export_synthesis
paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS.conflict_generator_type = ce_generator
paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS.constraint = constraint
paynt.quotient.pomdp.PomdpQuotient.initial_memory_size = fsc_memory_size
paynt.quotient.pomdp.PomdpQuotient.posterior_aware = posterior_aware
paynt.quotient.decpomdp.DecPomdpQuotient.initial_memory_size = fsc_memory_size
Expand All @@ -178,6 +194,12 @@ def paynt_run(
paynt.synthesizer.decision_tree.SynthesizerDecisionTree.scheduler_path = tree_map_scheduler
paynt.quotient.mdp.MdpQuotient.add_dont_care_action = add_dont_care_action

if constraint == "tree":
paynt.family.constraints.flexibletree.DecisionTreeConstraint.tree_nodes = tree_nodes
elif constraint == "costs":
paynt.family.constraints.costs.CostsConstraint.costs_threshold = costs_threshold
paynt.family.constraints.costs.CostsConstraint.model_folder = project

storm_control = None
if storm_pomdp:
storm_control = paynt.quotient.storm_pomdp_control.StormPOMDPControl()
Expand Down
10 changes: 10 additions & 0 deletions paynt/family/constraints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
__version__ = "unknown"

try:
from .._version import __version__
except ImportError:
# We're running in a tree that doesn't have a _version.py, so we don't know what our version is.
pass

def version():
return __version__
19 changes: 19 additions & 0 deletions paynt/family/constraints/constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

from paynt.family.constraints.flexibletree import DecisionTreeConstraint
from paynt.family.constraints.prob_goal import ProbGoalConstraint
from paynt.family.constraints.costs import CostsConstraint

class Constraints:

@staticmethod
def create_constraint(constraint_type):
if constraint_type == "prob1":
return ProbGoalConstraint(prob=1)
elif constraint_type == "prob0":
return ProbGoalConstraint(prob=0)
elif constraint_type == "tree":
return DecisionTreeConstraint()
elif constraint_type == "costs":
return CostsConstraint()
else:
raise ValueError(f"Unknown constraint type: {constraint_type}")
56 changes: 56 additions & 0 deletions paynt/family/constraints/costs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import z3

import logging
import os
logger = logging.getLogger(__name__)


class CostsConstraint():

costs_threshold = 0
model_folder : str

COSTS_FILE = "sketch.costs"

def __init__(self):
pass

def build_constraint(
self,
variables,
quotient
):
# We build the quotient here

assertions = []

lines = None
costs_path = os.path.join(self.model_folder, self.COSTS_FILE)
with open(costs_path, "r") as f:
lines = f.readlines()

cost_vars = []
line_index = 0
for hole in range(quotient.family.num_holes):
for option in range(quotient.family.hole_num_options(hole)):
hole_name = quotient.family.hole_name(hole)
cost_var = z3.Int(f"cost_{hole_name}_{option}")
cost_vars.append(cost_var)
line = lines[line_index].strip()
line_index += 1
line_hole, line_option, cost_value = line.split()
assert line_hole == hole_name, f"Expected hole {hole_name}, got {line_hole}"
assert int(line_option) == option, f"Expected option {option}, got {line_option}"

assertions.append(
z3.If(
variables[hole] == option,
cost_var == int(cost_value),
cost_var == 0
)
)

# Add constraint: sum of cost_vars <= costs_threshold
assertions.append(z3.Sum(cost_vars) <= self.costs_threshold)
return assertions

Loading
Loading