diff --git a/paynt/cli.py b/paynt/cli.py index 4b9acc850..5201233ac 100644 --- a/paynt/cli.py +++ b/paynt/cli.py @@ -1,4 +1,3 @@ -import paynt.quotient.mdp_family from . import version import paynt.utils.timer @@ -10,12 +9,16 @@ import paynt.quotient.posmg import paynt.quotient.storm_pomdp_control import paynt.quotient.mdp +import paynt.quotient.mdp_family import paynt.synthesizer.synthesizer import paynt.synthesizer.synthesizer_cegis import paynt.synthesizer.policy_tree import paynt.synthesizer.decision_tree +import paynt.family.constraints.flexibletree +import paynt.family.constraints.costs + import click import sys import os @@ -130,9 +133,21 @@ def setup_logger(log_path = None): ) @click.option( - "--ce-generator", type=click.Choice(["dtmc", "mdp"]), default="dtmc", show_default=True, + "--ce-generator", type=click.Choice(["dtmc", "mdp", "none"]), default="dtmc", show_default=True, help="counterexample generator", ) + +@click.option("--constraint", + type=click.Choice(['prob1', 'prob0', 'tree', 'costs']), + default=None , show_default=True, + help="constraint type for CEGIS" +) +@click.option("--tree-nodes", default=None, type=int, + help="constraint tree: number of nodes in the decision tree (only for --constraint tree)") +@click.option("--costs-threshold", default=None, type=int, + help="costs constraint: threshold for costs (only for --constraint costs)") + + @click.option("--profiling", is_flag=True, default=False, help="run profiling") @@ -148,7 +163,7 @@ def paynt_run( mdp_discard_unreachable_choices, tree_depth, tree_enumeration, tree_map_scheduler, add_dont_care_action, constraint_bound, - ce_generator, + ce_generator, constraint, tree_nodes, costs_threshold, profiling ): @@ -164,6 +179,7 @@ def paynt_run( paynt.quotient.quotient.Quotient.disable_expected_visits = disable_expected_visits paynt.synthesizer.synthesizer.Synthesizer.export_synthesis_filename_base = export_synthesis paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS.conflict_generator_type = ce_generator + paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS.constraint = constraint paynt.quotient.pomdp.PomdpQuotient.initial_memory_size = fsc_memory_size paynt.quotient.pomdp.PomdpQuotient.posterior_aware = posterior_aware paynt.quotient.decpomdp.DecPomdpQuotient.initial_memory_size = fsc_memory_size @@ -178,6 +194,12 @@ def paynt_run( paynt.synthesizer.decision_tree.SynthesizerDecisionTree.scheduler_path = tree_map_scheduler paynt.quotient.mdp.MdpQuotient.add_dont_care_action = add_dont_care_action + if constraint == "tree": + paynt.family.constraints.flexibletree.DecisionTreeConstraint.tree_nodes = tree_nodes + elif constraint == "costs": + paynt.family.constraints.costs.CostsConstraint.costs_threshold = costs_threshold + paynt.family.constraints.costs.CostsConstraint.model_folder = project + storm_control = None if storm_pomdp: storm_control = paynt.quotient.storm_pomdp_control.StormPOMDPControl() diff --git a/paynt/family/constraints/__init__.py b/paynt/family/constraints/__init__.py new file mode 100644 index 000000000..343a1ca8c --- /dev/null +++ b/paynt/family/constraints/__init__.py @@ -0,0 +1,10 @@ +__version__ = "unknown" + +try: + from .._version import __version__ +except ImportError: + # We're running in a tree that doesn't have a _version.py, so we don't know what our version is. + pass + +def version(): + return __version__ \ No newline at end of file diff --git a/paynt/family/constraints/constraints.py b/paynt/family/constraints/constraints.py new file mode 100644 index 000000000..4fca69f5b --- /dev/null +++ b/paynt/family/constraints/constraints.py @@ -0,0 +1,19 @@ + +from paynt.family.constraints.flexibletree import DecisionTreeConstraint +from paynt.family.constraints.prob_goal import ProbGoalConstraint +from paynt.family.constraints.costs import CostsConstraint + +class Constraints: + + @staticmethod + def create_constraint(constraint_type): + if constraint_type == "prob1": + return ProbGoalConstraint(prob=1) + elif constraint_type == "prob0": + return ProbGoalConstraint(prob=0) + elif constraint_type == "tree": + return DecisionTreeConstraint() + elif constraint_type == "costs": + return CostsConstraint() + else: + raise ValueError(f"Unknown constraint type: {constraint_type}") diff --git a/paynt/family/constraints/costs.py b/paynt/family/constraints/costs.py new file mode 100644 index 000000000..35577330e --- /dev/null +++ b/paynt/family/constraints/costs.py @@ -0,0 +1,56 @@ +import z3 + +import logging +import os +logger = logging.getLogger(__name__) + + +class CostsConstraint(): + + costs_threshold = 0 + model_folder : str + + COSTS_FILE = "sketch.costs" + + def __init__(self): + pass + + def build_constraint( + self, + variables, + quotient + ): + # We build the quotient here + + assertions = [] + + lines = None + costs_path = os.path.join(self.model_folder, self.COSTS_FILE) + with open(costs_path, "r") as f: + lines = f.readlines() + + cost_vars = [] + line_index = 0 + for hole in range(quotient.family.num_holes): + for option in range(quotient.family.hole_num_options(hole)): + hole_name = quotient.family.hole_name(hole) + cost_var = z3.Int(f"cost_{hole_name}_{option}") + cost_vars.append(cost_var) + line = lines[line_index].strip() + line_index += 1 + line_hole, line_option, cost_value = line.split() + assert line_hole == hole_name, f"Expected hole {hole_name}, got {line_hole}" + assert int(line_option) == option, f"Expected option {option}, got {line_option}" + + assertions.append( + z3.If( + variables[hole] == option, + cost_var == int(cost_value), + cost_var == 0 + ) + ) + + # Add constraint: sum of cost_vars <= costs_threshold + assertions.append(z3.Sum(cost_vars) <= self.costs_threshold) + return assertions + \ No newline at end of file diff --git a/paynt/family/constraints/flexibletree.py b/paynt/family/constraints/flexibletree.py new file mode 100644 index 000000000..d1008821d --- /dev/null +++ b/paynt/family/constraints/flexibletree.py @@ -0,0 +1,268 @@ +"""A classic decision tree.""" + +import z3 +from itertools import product, chain +import os + + +def piecewise_select(array, z3_int): + """Select an element of an array based on a z3 integer.""" + return z3.Sum([z3.If(z3_int == i, array[i], 0) for i in range(len(array))]) + + +def get_property_names(variable_name): + return [ + x.strip().split("=")[0].replace("!", "") + for x in variable_name[ + variable_name.find("[") + 1 : variable_name.find("]") + ].split("&") + ] + + +def get_property_values(variable_name): + return [ + int(x.strip().split("=")[1]) if "=" in x else (0 if x.strip()[0] == "!" else 1) + for x in variable_name[ + variable_name.find("[") + 1 : variable_name.find("]") + ].split("&") + ] + + +class DecisionTreeConstraint(): + + tree_nodes: int | None + + def __init__(self): + self.policy_vars = None + self.labels = None + self.label_to_index = None + self.left_child_ranges = None + self.right_child_ranges = None + + + def build_constraint( + self, + variables, + quotient + ): + self.variables = variables + num_nodes = self.tree_nodes + + policy_indices = list(range(len(variables))) + policy_vars = [variables[i] for i in policy_indices] + self.policy_vars = policy_vars + + # Collect all action labels and put them into an order + labels = list( + dict.fromkeys( + chain( + *[quotient.family.hole_to_option_labels[i] for i in policy_indices] + ) + ) + ) + print(labels) + label_to_index = {label: i for i, label in enumerate(labels)} + self.labels = labels + self.label_to_index = label_to_index + hole_to_label_indices = [] + # assert 2**num_bits > len(labels) + + # Check that the available action labels of policy vars are consistent + for i in policy_indices: + hole_to_label_indices.append( + [ + label_to_index[label] + for label in quotient.family.hole_to_option_labels[i] + ] + ) + + # variables have names of the form + # A([picked0=1 & picked1=0 & picked2=1 & picked3=1 & picked4=0 & picked5=1 & picked6=1 & x=3 & y=2],0 + first_variable_name = str(policy_vars[0]) + if "A([" not in first_variable_name: + raise ValueError( + "Variables must have properties (e.g., generated from POMDPs.)." + ) + property_names = get_property_names(first_variable_name) + num_properties = len(property_names) + + property_ranges = [(1e6, -1e6) for _ in range(num_properties)] + for variable in policy_vars: + property_values = get_property_values(str(variable)) + for i in range(num_properties): + property_ranges[i] = ( + min(property_ranges[i][0], property_values[i]), + max(property_ranges[i][1], property_values[i]), + ) + + constraints = [] + + # create a function for each node + decision_functions = [] + for i in range(num_nodes): + decision_functions.append( + z3.Function( + f"decision_{i}", + *[z3.IntSort()] * num_properties, + z3.IntSort(), + ) + ) + + # Left child is in range even([i+1, min(2i, num_nodes-1)]) + # Right child is in range odd([i+2, min(2i+1, num_nodes)]) + self.left_child_ranges = [ + [j for j in range(i + 1, min(2 * (i + 1), num_nodes)) if j % 2 == 1] + for i in range(num_nodes) + ] + self.right_child_ranges = [ + [j for j in range(i + 2, min(2 * (i + 1) + 1, num_nodes)) if j % 2 == 0] + for i in range(num_nodes) + ] + + # make weight nodes for constraints + node_constants = [] + property_indices = [] + node_is_leaf = [] + left_children = [] + right_children = [] + + for i in range(num_nodes): + # Is this node a leaf? + is_leaf = z3.Bool(f"leaf_{i}") + node_is_leaf.append(is_leaf) + + # The constant of a node + constant_var = z3.Int(f"const_{i}") + node_constants.append(constant_var) + + # The property index of a node + prop_index = z3.Int(f"prop_index_{i}") + # Must be in range + constraints.append(prop_index >= 0) + # print(num_properties) + constraints.append(prop_index < num_properties) + property_indices.append(prop_index) + + constraints.append(constant_var >= 0) + constraints.append( + z3.If( + is_leaf, + constant_var < len(labels), + constant_var <= piecewise_select( + [z3.IntVal(x[1]) for x in property_ranges], + prop_index, + ), + ) + ) + + left_child = z3.Int(f"left_{i}") + left_children.append(left_child) + right_child = z3.Int(f"right_{i}") + right_children.append(right_child) + # If this node is a leaf, the left and right children must be 0 + + constraints.append( + z3.If( + is_leaf, + left_child == 0, + left_child <= len(self.left_child_ranges[i]), + ) + ) + constraints.append( + z3.If( + is_leaf, + right_child == 0, + right_child <= len(self.right_child_ranges[i]), + ) + ) + constraints.append(z3.Implies(is_leaf, prop_index == 0)) + + all_property_values = [ + get_property_values(str(variable)) + for variable in enumerate(policy_vars) + ] + + for values in all_property_values: + prop_vals = [z3.IntVal(v) for v in values] + constraints.append( + z3.If( + is_leaf, + decision_functions[i](*prop_vals) == constant_var, + z3.If( + piecewise_select(prop_vals, prop_index) >= constant_var, + z3.Or( + *[ + z3.And( + left_child == j, + decision_functions[i](*prop_vals) + == decision_functions[ + self.left_child_ranges[i][j] + ](*prop_vals), + ) + for j in range(len(self.left_child_ranges[i])) + ] + ), + z3.Or( + *[ + z3.And( + right_child == j, + decision_functions[i](*prop_vals) + == decision_functions[ + self.right_child_ranges[i][j] + ](*prop_vals), + ) + for j in range(len(self.right_child_ranges[i])) + ] + ), + ), + ) + ) + + # each tree has (num_nodes+1) / 2 leaves + constraints.append(z3.Sum(node_is_leaf) == (num_nodes + 1) // 2) + # each node, except 0, must have a parent, that is before it + + for i in range(1, num_nodes): + # identify the nodes that have i in left_child_ranges or right_child_ranges + left_children_ranges = [ + j for j in range(num_nodes) if i in self.left_child_ranges[j] + ] + right_children_ranges = [ + j for j in range(num_nodes) if i in self.right_child_ranges[j] + ] + # i is left_child of one of the left_children or right_child of one of the right_children + parent_constraint = z3.Or( + *[ + z3.And( + left_children[x] == self.left_child_ranges[x].index(i), + z3.Not(node_is_leaf[x]), + ) + for x in left_children_ranges + if i in self.left_child_ranges[x] + ] + + [ + z3.And( + right_children[x] == self.right_child_ranges[x].index(i), + z3.Not(node_is_leaf[x]), + ) + for x in right_children_ranges + if i in self.right_child_ranges[x] + ] + ) + constraints.append(parent_constraint) + + for i, variable in enumerate(policy_vars): + label_range = quotient.family.hole_to_option_labels[policy_indices[i]] + if label_range == labels: + # The semantics of the variable is the same as the decision tree's + property_values = get_property_values(str(variable)) + constraints.append(variable == decision_functions[0](*property_values)) + else: + # We need to map the decision tree's value to the label index + label_indices = [label_to_index[label] for label in label_range] + property_values = get_property_values(str(variable)) + x = decision_functions[0](*property_values) + for index, label_index in enumerate(label_indices): + constraints.append((variable == index) == (x == label_index)) + + return constraints diff --git a/paynt/family/constraints/prob_goal.py b/paynt/family/constraints/prob_goal.py new file mode 100644 index 000000000..0853b2a41 --- /dev/null +++ b/paynt/family/constraints/prob_goal.py @@ -0,0 +1,137 @@ +"""Reach goal with prob>0 or prob=1.""" + +import z3 +import math +from stormpy import model_checking + +import logging +logger = logging.getLogger(__name__) + +class ProbGoalConstraint(): + def __init__(self, prob: int = 0): + assert prob in [0, 1], "ProbGoal requires prob to be either 0 or 1." + self.prob0 = (prob == 0) + + def build_constraint( + self, + variables, + quotient + ): + + # We build the quotient here + quotient.build(quotient.family) + + transition_matrix = quotient.family.mdp.model.transition_matrix + + choice_to_assignment = quotient.coloring.getChoiceToAssignment() + + target_states = model_checking(quotient.family.mdp.model, quotient.specification.all_properties()[0].formula.subformula.subformula).get_truth_values() + + assertions = [] + + initial_state = quotient.family.mdp.model.initial_states[0] + assert len(quotient.family.mdp.model.initial_states) == 1, "ProbGoal only supports single initial states." + + reachability_vars = [] + for state in range(transition_matrix.nr_columns): + reach_var = z3.Bool(f"reach_{state}") + reachability_vars.append(reach_var) + + if not self.prob0: + min_step_vars = [] + for state in range(transition_matrix.nr_columns): + min_step_var = z3.Int(f"min_step_{state}") + min_step_vars.append(min_step_var) + assertions.append(min_step_var >= 0) + + for state in range(transition_matrix.nr_columns): + if target_states.get(state): + assertions.append(reachability_vars[state]) + assertions.append(min_step_vars[state] == 0) + continue + + statement_for_state = [] + + rows = transition_matrix.get_rows_for_group(state) + for row in rows: + assignment = choice_to_assignment[row] + assignment_as_z3 = z3.And([ + variables[var] == z3.IntVal(x) + for var, x in assignment + ]) + + reachability_vars_of_row = [] + min_step_vars_of_row = [] + + for entry in transition_matrix.get_row(row): + value = entry.value() + if value == 0: + continue + assert value > 0, "Transition probabilities must be positive." + to_state = entry.column + if to_state == state: + continue + reachability_vars_of_row.append(reachability_vars[to_state]) + min_step_vars_of_row.append(min_step_vars[to_state]) + statement_for_state.append(z3.Implies(assignment_as_z3, z3.And(reachability_vars_of_row))) + assertions.append( + z3.Implies( + z3.And(reachability_vars[state], assignment_as_z3), + z3.And( + z3.Or([min_step_vars[state] == x + 1 for x in min_step_vars_of_row]), + z3.And([min_step_vars[state] <= x + 1 for x in min_step_vars_of_row]) + ) + ) + ) + assertions.append(z3.Implies(reachability_vars[state], z3.And(statement_for_state))) + else: + min_step_vars = [] + for state in range(transition_matrix.nr_columns): + min_step_var = z3.Int(f"min_step_{state}") + min_step_vars.append(min_step_var) + assertions.append(min_step_var >= 0) + + for state in range(transition_matrix.nr_columns): + if target_states.get(state): + assertions.append(reachability_vars[state]) + assertions.append(min_step_vars[state] == 0) + continue + + statement_for_state = [] + + rows = transition_matrix.get_rows_for_group(state) + for row in rows: + assignment = choice_to_assignment[row] + assignment_as_z3 = z3.And([ + variables[var] == z3.IntVal(x) + for var, x in assignment + ]) + + reachability_vars_of_row = [] + min_step_vars_of_row = [] + + for entry in transition_matrix.get_row(row): + value = entry.value() + if value == 0: + continue + assert value > 0, "Transition probabilities must be positive." + to_state = entry.column + if to_state == state: + continue + reachability_vars_of_row.append(reachability_vars[to_state]) + min_step_vars_of_row.append(min_step_vars[to_state]) + statement_for_state.append(z3.Implies(assignment_as_z3, z3.Or(reachability_vars_of_row))) + assertions.append( + z3.Implies( + z3.And(reachability_vars[state], assignment_as_z3), + z3.And( + z3.Or([min_step_vars[state] == x + 1 for x in min_step_vars_of_row]), + z3.And([min_step_vars[state] <= x + 1 for x in min_step_vars_of_row]) + ) + ) + ) + assertions.append(z3.Implies(reachability_vars[state], z3.Or(statement_for_state))) + + assertions.append(reachability_vars[initial_state]) + logger.info("Done building assertions for ProbGoal.") + return assertions diff --git a/paynt/family/constraints/tree.py b/paynt/family/constraints/tree.py new file mode 100644 index 000000000..e579facfd --- /dev/null +++ b/paynt/family/constraints/tree.py @@ -0,0 +1,152 @@ +"""A classic decision tree.""" + +import z3 + + +def piecewise_select(array, z3_int): + """Select an element of an array based on a z3 integer.""" + return z3.Sum([z3.If(z3_int == i, array[i], 0) for i in range(len(array))]) + + +def get_property_names(variable_name): + return [ + x.strip().split("=")[0].replace("!", "") + for x in variable_name[ + variable_name.find("[") + 1 : variable_name.find("]") + ].split("&") + ] + + +def get_property_values(variable_name): + return [ + int(x.strip().split("=")[1]) if "=" in x else (0 if x.strip()[0] == "!" else 1) + for x in variable_name[ + variable_name.find("[") + 1 : variable_name.find("]") + ].split("&") + ] + + +class DecisionTreeConstraintOld(): + + tree_depth: int + + tree_nodes: int | None + + def __init__(self): + pass + + def build_constraint(self, variables, quotient): + tree_depth = self.tree_depth + self.variables = variables + num_enabled_nodes = self.tree_nodes + + # variables have names of the form + # A([picked0=1 & picked1=0 & picked2=1 & picked3=1 & picked4=0 & picked5=1 & picked6=1 & x=3 & y=2],0 + first_variable_name = str(variables[0]) + if "A([" not in first_variable_name: + raise ValueError( + "Variables must have properties (e.g., generated from POMDPs.)." + ) + property_names = get_property_names(first_variable_name) + num_properties = len(property_names) + + property_ranges = [(1e6, -1e6) for _ in range(num_properties)] + for variable in variables: + property_values = get_property_values(str(variable)) + for i in range(num_properties): + property_ranges[i] = ( + min(property_ranges[i][0], property_values[i]), + max(property_ranges[i][1], property_values[i]), + ) + + # create a function + max_action_size = max([len(quotient.family.hole_options(hole)) for hole in range(len(variables))]) + decision_func = z3.Function( + "decision", *[z3.IntSort()] * num_properties, z3.IntSort() + ) + + # decision_func_int = z3.Function( + # "decision", *[z3.IntSort()] * num_properties, z3.IntSort() + # ) + + constraints = [] + + # tree is structured as follows + # 0 + # 1 2 + # 3 4 5 6 + # 7 8 9 10 11 12 13 14 + + num_nodes = 2**tree_depth - 1 + leaf_values = [ + z3.Int(f"leaf_{i}") for i in range(num_nodes + 1) + ] + + # make weight nodes for constraints + node_property = [] + node_constants = [] + for i in range(num_nodes): + # weight per variable + prop_index = z3.Int(f"node_{i}") + + # prop index must be in range + constraints.append(prop_index >= 0) + constraints.append(prop_index < num_properties) + + node_property.append(prop_index) + constant_var = z3.Int(f"const_{i}") + node_constants.append(constant_var) + constraints.append(constant_var >= 0) + + # if the constant of this node is > 0, this is also true for the parent + # this breaks symmetry for disabled nodes + if i > 0: + constraints.append( + z3.Implies(node_constants[i] > 0, node_constants[(i - 1) // 2] > 0) + ) + # if the constant is 0, the property must be 0 + constraints.append( + z3.Implies(node_constants[i] == 0, node_property[i] == 0) + ) + + # only num_enabled_nodes nodes can have constant > 0 + if num_enabled_nodes is not None: + constraints.append( + z3.Sum([z3.If(node_constants[i] > 0, 1, 0) for i in range(num_nodes)]) + == num_enabled_nodes + ) + + def decision_at_node(node: int, properties): + return z3.Or( + node_constants[node] == 0, + z3.Sum( + [ + z3.If(node_property[node] == i, properties[i], 0) + for i in range(num_properties) + ] + ) + >= node_constants[node], + ) + + def traverse_tree(node: int, properties: list[z3.Int]): + if node >= num_nodes: + return leaf_values[node - num_nodes] + else: + left = traverse_tree(2 * node + 1, properties) + right = traverse_tree(2 * node + 2, properties) + return z3.If(decision_at_node(node, properties), left, right) + + decision_variables = [z3.Int(f"decision_{i}") for i in range(num_properties)] + constraints.append( + z3.ForAll( + decision_variables, + traverse_tree(0, decision_variables) + == decision_func(*decision_variables), + ) + ) + + for variable in variables: + property_values = get_property_values(str(variable)) + constraints.append(variable == decision_func(*property_values)) + return constraints + diff --git a/paynt/family/smt.py b/paynt/family/smt.py index e1feacae5..e2df29678 100644 --- a/paynt/family/smt.py +++ b/paynt/family/smt.py @@ -1,6 +1,8 @@ import sys import z3 +from paynt.family.constraints.constraints import Constraints + # import pycvc5 if installed import importlib if importlib.util.find_spec('pycvc5') is not None: @@ -49,6 +51,18 @@ def __init__(self, smt_solver, family): else: pass + if smt_solver.constraint is not None: + + logger.info(f"Adding constraint {smt_solver.constraint} to the encoding.") + constraint = Constraints.create_constraint(smt_solver.constraint) + + constraint_smt_clauses = constraint.build_constraint( + self.smt_solver.solver_vars, + self.smt_solver.quotient + ) + + encoding = z3.And(encoding, *constraint_smt_clauses) + self.hole_clauses = hole_clauses self.encoding = encoding @@ -86,7 +100,7 @@ def pick_assignment(self): class SmtSolver(): - def __init__(self, family): + def __init__(self, quotient, constraint=None): # SMT solver containing description of the unexplored design space self.solver = None @@ -103,8 +117,14 @@ def __init__(self, family): # current depth of push/pop solving self.solver_depth = 0 + # initial constraint for the design space + self.constraint = constraint + + self.quotient = quotient + family = quotient.family + # choose solver - if "pycvc5" in sys.modules: + if "pycvc5" in sys.modules and self.constraint is None: logger.debug("using CVC5 for SMT solving.") self.use_cvc = True else: @@ -115,7 +135,7 @@ def __init__(self, family): self.solver_clauses = [] if self.use_python_z3: self.solver = z3.Solver() - self.solver_vars = [z3.Int(hole) for hole in range(family.num_holes)] + self.solver_vars = [z3.Int(family.hole_name(hole)) for hole in range(family.num_holes)] elif self.use_cvc: self.solver = pycvc5.Solver() self.solver.setOption("produce-models", "true") diff --git a/paynt/synthesizer/synthesizer_cegis.py b/paynt/synthesizer/synthesizer_cegis.py index 95ab4cbc4..0c51f92c0 100644 --- a/paynt/synthesizer/synthesizer_cegis.py +++ b/paynt/synthesizer/synthesizer_cegis.py @@ -12,6 +12,9 @@ class SynthesizerCEGIS(paynt.synthesizer.synthesizer.Synthesizer): # CLI argument selecting conflict generator conflict_generator_type = None + # CLI argument for setting initial constraint on the design space + constraint = None + def __init__(self, quotient): super().__init__(quotient) @@ -25,15 +28,19 @@ def __init__(self, quotient): def choose_conflict_generator(self, quotient): if SynthesizerCEGIS.conflict_generator_type == "mdp": conflict_generator = paynt.synthesizer.conflict_generator.mdp.ConflictGeneratorMdp(quotient) - else: + elif SynthesizerCEGIS.conflict_generator_type == "dtmc": # default conflict generator conflict_generator = paynt.synthesizer.conflict_generator.dtmc.ConflictGeneratorDtmc(quotient) + elif SynthesizerCEGIS.conflict_generator_type == "none": + conflict_generator = None + else: + raise ValueError(f"Unknown conflict generator type: {SynthesizerCEGIS.conflict_generator_type}") return conflict_generator @property def method_name(self): - return "CEGIS " + self.conflict_generator.name + return "CEGIS " + (self.conflict_generator.name if self.conflict_generator else "no CEs") def collect_conflict_requests(self, family, mc_result): @@ -79,8 +86,11 @@ def analyze_family_assignment_cegis(self, family, assignment): if accepting and not self.quotient.specification.can_be_improved(): return [], accepting_assignment - conflict_requests = self.collect_conflict_requests(family, result) - conflicts = self.conflict_generator.construct_conflicts(family, assignment, dtmc, conflict_requests) + if self.conflict_generator is not None: + conflict_requests = self.collect_conflict_requests(family, result) + conflicts = self.conflict_generator.construct_conflicts(family, assignment, dtmc, conflict_requests) + else: + conflicts = [[hole for hole in range(family.num_holes)]] return conflicts, accepting_assignment @@ -89,10 +99,11 @@ def synthesize_one(self, family): # build the quotient, map mdp states to hole indices self.quotient.build(family) - self.conflict_generator.initialize() + if self.conflict_generator is not None: + self.conflict_generator.initialize() # use sketch design space as a SAT baseline (TODO why?) - smt_solver = paynt.family.smt.SmtSolver(self.quotient.family) + smt_solver = paynt.family.smt.SmtSolver(self.quotient, self.constraint) # CEGIS loop assignment = smt_solver.pick_assignment(family) diff --git a/paynt/synthesizer/synthesizer_hybrid.py b/paynt/synthesizer/synthesizer_hybrid.py index da0485865..ee9c0cd83 100644 --- a/paynt/synthesizer/synthesizer_hybrid.py +++ b/paynt/synthesizer/synthesizer_hybrid.py @@ -94,7 +94,7 @@ def method_name(self): def synthesize_one(self, family): self.conflict_generator.initialize() - smt_solver = paynt.family.smt.SmtSolver(self.quotient.family) + smt_solver = paynt.family.smt.SmtSolver(self.quotient) # AR-CEGIS loop families = [family]