diff --git a/src/storm-pomdp/transformer/ObservationTraceUnfolder.h b/src/storm-pomdp/transformer/ObservationTraceUnfolder.h index 393410643..a03e8b145 100644 --- a/src/storm-pomdp/transformer/ObservationTraceUnfolder.h +++ b/src/storm-pomdp/transformer/ObservationTraceUnfolder.h @@ -47,7 +47,7 @@ class ObservationTraceUnfolder { bool isRestartSemanticsSet() const; private: - storm::models::sparse::Pomdp const& model; + storm::models::sparse::Pomdp model; std::vector risk; // TODO reconsider holding this as a reference, but there were some strange bugs std::shared_ptr& exprManager; std::vector traceSoFar; diff --git a/src/storm/environment/SubEnvironment.cpp b/src/storm/environment/SubEnvironment.cpp index ab2a7afbd..d187483e1 100644 --- a/src/storm/environment/SubEnvironment.cpp +++ b/src/storm/environment/SubEnvironment.cpp @@ -48,6 +48,7 @@ void SubEnvironment::assertInitialized() const { template class SubEnvironment; +template class SubEnvironment; template class SubEnvironment; template class SubEnvironment; diff --git a/src/storm/environment/modelchecker/AllModelCheckerEnvironments.h b/src/storm/environment/modelchecker/AllModelCheckerEnvironments.h index f37921b60..30a6735b2 100644 --- a/src/storm/environment/modelchecker/AllModelCheckerEnvironments.h +++ b/src/storm/environment/modelchecker/AllModelCheckerEnvironments.h @@ -1,4 +1,5 @@ #pragma once +#include "storm/environment/modelchecker/ConditionalModelCheckerEnvironment.h" #include "storm/environment/modelchecker/ModelCheckerEnvironment.h" #include "storm/environment/modelchecker/MultiObjectiveModelCheckerEnvironment.h" \ No newline at end of file diff --git a/src/storm/environment/modelchecker/ConditionalModelCheckerEnvironment.cpp b/src/storm/environment/modelchecker/ConditionalModelCheckerEnvironment.cpp new file mode 100644 index 000000000..9859b41bf --- /dev/null +++ b/src/storm/environment/modelchecker/ConditionalModelCheckerEnvironment.cpp @@ -0,0 +1,51 @@ +#include "storm/environment/modelchecker/ConditionalModelCheckerEnvironment.h" + +#include "storm/adapters/RationalNumberForward.h" +#include "storm/settings/SettingsManager.h" +#include "storm/settings/modules/ConditionalSettings.h" +#include "storm/utility/constants.h" + +namespace storm { + +ConditionalModelCheckerEnvironment::ConditionalModelCheckerEnvironment() { + auto const& mcSettings = storm::settings::getModule(); + algorithm = mcSettings.getConditionalAlgorithmSetting(); + precision = storm::utility::convertNumber(mcSettings.getConditionalPrecision()); + relative = !mcSettings.isConditionalPrecisionAbsolute(); + precisionSetFromDefault = mcSettings.isConditionalPrecisionSetFromDefaultValue(); +} + +ConditionalModelCheckerEnvironment::~ConditionalModelCheckerEnvironment() { + // Intentionally left empty +} + +ConditionalAlgorithmSetting ConditionalModelCheckerEnvironment::getAlgorithm() const { + return algorithm; +} + +void ConditionalModelCheckerEnvironment::setAlgorithm(ConditionalAlgorithmSetting value) { + algorithm = value; +} + +storm::RationalNumber ConditionalModelCheckerEnvironment::getPrecision() const { + return precision; +} + +void ConditionalModelCheckerEnvironment::setPrecision(storm::RationalNumber const& value, bool setFromDefault) { + precision = value; + precisionSetFromDefault = setFromDefault; +} + +bool ConditionalModelCheckerEnvironment::isPrecisionSetFromDefault() const { + return precisionSetFromDefault; +} + +bool ConditionalModelCheckerEnvironment::isRelativePrecision() const { + return relative; +} + +void ConditionalModelCheckerEnvironment::setRelativePrecision(bool value) { + relative = value; +} + +} // namespace storm diff --git a/src/storm/environment/modelchecker/ConditionalModelCheckerEnvironment.h b/src/storm/environment/modelchecker/ConditionalModelCheckerEnvironment.h new file mode 100644 index 000000000..296cf85a5 --- /dev/null +++ b/src/storm/environment/modelchecker/ConditionalModelCheckerEnvironment.h @@ -0,0 +1,30 @@ +#pragma once + +#include "storm/adapters/RationalNumberAdapter.h" +#include "storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.h" + +namespace storm { + +class ConditionalModelCheckerEnvironment { + public: + ConditionalModelCheckerEnvironment(); + ~ConditionalModelCheckerEnvironment(); + + ConditionalAlgorithmSetting getAlgorithm() const; + void setAlgorithm(ConditionalAlgorithmSetting value); + + storm::RationalNumber getPrecision() const; + bool isPrecisionSetFromDefault() const; + void setPrecision(storm::RationalNumber const& value, bool setFromDefault); + + bool isRelativePrecision() const; + void setRelativePrecision(bool value); + + private: + ConditionalAlgorithmSetting algorithm; + storm::RationalNumber precision; + bool precisionSetFromDefault; + bool relative; +}; + +} // namespace storm diff --git a/src/storm/environment/modelchecker/ModelCheckerEnvironment.cpp b/src/storm/environment/modelchecker/ModelCheckerEnvironment.cpp index 5ff79fd9a..bed60a93b 100644 --- a/src/storm/environment/modelchecker/ModelCheckerEnvironment.cpp +++ b/src/storm/environment/modelchecker/ModelCheckerEnvironment.cpp @@ -1,5 +1,6 @@ #include "storm/environment/modelchecker/ModelCheckerEnvironment.h" +#include "storm/environment/modelchecker/ConditionalModelCheckerEnvironment.h" #include "storm/environment/modelchecker/MultiObjectiveModelCheckerEnvironment.h" #include "storm/settings/SettingsManager.h" @@ -19,28 +20,26 @@ ModelCheckerEnvironment::ModelCheckerEnvironment() { } auto const& ioSettings = storm::settings::getModule(); steadyStateDistributionAlgorithm = ioSettings.getSteadyStateDistributionAlgorithm(); - - conditionalAlgorithmSetting = mcSettings.getConditionalAlgorithmSetting(); } ModelCheckerEnvironment::~ModelCheckerEnvironment() { // Intentionally left empty } -SteadyStateDistributionAlgorithm ModelCheckerEnvironment::getSteadyStateDistributionAlgorithm() const { - return steadyStateDistributionAlgorithm; +ConditionalModelCheckerEnvironment& ModelCheckerEnvironment::conditional() { + return conditionalModelCheckerEnvironment.get(); } -void ModelCheckerEnvironment::setSteadyStateDistributionAlgorithm(SteadyStateDistributionAlgorithm value) { - steadyStateDistributionAlgorithm = value; +ConditionalModelCheckerEnvironment const& ModelCheckerEnvironment::conditional() const { + return conditionalModelCheckerEnvironment.get(); } -ConditionalAlgorithmSetting ModelCheckerEnvironment::getConditionalAlgorithmSetting() const { - return conditionalAlgorithmSetting; +SteadyStateDistributionAlgorithm ModelCheckerEnvironment::getSteadyStateDistributionAlgorithm() const { + return steadyStateDistributionAlgorithm; } -void ModelCheckerEnvironment::setConditionalAlgorithmSetting(ConditionalAlgorithmSetting value) { - conditionalAlgorithmSetting = value; +void ModelCheckerEnvironment::setSteadyStateDistributionAlgorithm(SteadyStateDistributionAlgorithm value) { + steadyStateDistributionAlgorithm = value; } MultiObjectiveModelCheckerEnvironment& ModelCheckerEnvironment::multi() { diff --git a/src/storm/environment/modelchecker/ModelCheckerEnvironment.h b/src/storm/environment/modelchecker/ModelCheckerEnvironment.h index fce3e9e33..983d3d37a 100644 --- a/src/storm/environment/modelchecker/ModelCheckerEnvironment.h +++ b/src/storm/environment/modelchecker/ModelCheckerEnvironment.h @@ -1,17 +1,16 @@ #pragma once #include -#include #include #include "storm/environment/Environment.h" #include "storm/environment/SubEnvironment.h" -#include "storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.h" #include "storm/modelchecker/helper/infinitehorizon/SteadyStateDistributionAlgorithm.h" namespace storm { // Forward declare subenvironments +class ConditionalModelCheckerEnvironment; class MultiObjectiveModelCheckerEnvironment; class ModelCheckerEnvironment { @@ -19,24 +18,24 @@ class ModelCheckerEnvironment { ModelCheckerEnvironment(); ~ModelCheckerEnvironment(); + ConditionalModelCheckerEnvironment& conditional(); + ConditionalModelCheckerEnvironment const& conditional() const; + MultiObjectiveModelCheckerEnvironment& multi(); MultiObjectiveModelCheckerEnvironment const& multi() const; SteadyStateDistributionAlgorithm getSteadyStateDistributionAlgorithm() const; void setSteadyStateDistributionAlgorithm(SteadyStateDistributionAlgorithm value); - ConditionalAlgorithmSetting getConditionalAlgorithmSetting() const; - void setConditionalAlgorithmSetting(ConditionalAlgorithmSetting value); - bool isLtl2daToolSet() const; std::string const& getLtl2daTool() const; void setLtl2daTool(std::string const& value); void unsetLtl2daTool(); private: + SubEnvironment conditionalModelCheckerEnvironment; SubEnvironment multiObjectiveModelCheckerEnvironment; boost::optional ltl2daTool; SteadyStateDistributionAlgorithm steadyStateDistributionAlgorithm; - ConditionalAlgorithmSetting conditionalAlgorithmSetting; }; } // namespace storm diff --git a/src/storm/io/DirectEncodingExporter.cpp b/src/storm/io/DirectEncodingExporter.cpp index 26679ca7f..2f21aede6 100644 --- a/src/storm/io/DirectEncodingExporter.cpp +++ b/src/storm/io/DirectEncodingExporter.cpp @@ -329,6 +329,9 @@ template void explicitExportSparseModel(std::filesystem template void explicitExportSparseModel(std::filesystem::path const& os, std::shared_ptr> sparseModel, std::vector const& parameters, DirectEncodingExporterOptions const& options); +template void explicitExportSparseModel(std::filesystem::path const& os, + std::shared_ptr> sparseModel, + std::vector const& parameters, DirectEncodingExporterOptions const& options); template void explicitExportSparseModel(std::ostream& os, std::shared_ptr> sparseModel, std::vector const& parameters, DirectEncodingExporterOptions const& options); diff --git a/src/storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.cpp b/src/storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.cpp index e909df627..5837a7a72 100644 --- a/src/storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.cpp +++ b/src/storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.cpp @@ -11,6 +11,10 @@ std::ostream& operator<<(std::ostream& stream, ConditionalAlgorithmSetting const return stream << "bisection"; case ConditionalAlgorithmSetting::BisectionAdvanced: return stream << "bisection-advanced"; + case ConditionalAlgorithmSetting::BisectionPolicyTracking: + return stream << "bisection-pt"; + case ConditionalAlgorithmSetting::BisectionAdvancedPolicyTracking: + return stream << "bisection-advanced-pt"; case ConditionalAlgorithmSetting::PolicyIteration: return stream << "pi"; } @@ -27,6 +31,10 @@ ConditionalAlgorithmSetting conditionalAlgorithmSettingFromString(std::string co return ConditionalAlgorithmSetting::Bisection; } else if (algorithm == "bisection-advanced") { return ConditionalAlgorithmSetting::BisectionAdvanced; + } else if (algorithm == "bisection-pt") { + return ConditionalAlgorithmSetting::BisectionPolicyTracking; + } else if (algorithm == "bisection-advanced-pt") { + return ConditionalAlgorithmSetting::BisectionAdvancedPolicyTracking; } else if (algorithm == "pi") { return ConditionalAlgorithmSetting::PolicyIteration; } diff --git a/src/storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.h b/src/storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.h index 4a9440abb..6c77495af 100644 --- a/src/storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.h +++ b/src/storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.h @@ -5,7 +5,15 @@ #include "storm/utility/macros.h" namespace storm { -enum class ConditionalAlgorithmSetting { Default, Restart, Bisection, BisectionAdvanced, PolicyIteration }; +enum class ConditionalAlgorithmSetting { + Default, + Restart, + Bisection, + BisectionAdvanced, + BisectionPolicyTracking, + BisectionAdvancedPolicyTracking, + PolicyIteration +}; std::ostream& operator<<(std::ostream& stream, ConditionalAlgorithmSetting const& algorithm); ConditionalAlgorithmSetting conditionalAlgorithmSettingFromString(std::string const& algorithm); diff --git a/src/storm/modelchecker/helper/conditional/ConditionalHelper.cpp b/src/storm/modelchecker/helper/conditional/ConditionalHelper.cpp index 6d8d3bcc3..7aa76acb9 100644 --- a/src/storm/modelchecker/helper/conditional/ConditionalHelper.cpp +++ b/src/storm/modelchecker/helper/conditional/ConditionalHelper.cpp @@ -1,8 +1,11 @@ #include "storm/modelchecker/helper/conditional/ConditionalHelper.h" #include +#include #include "storm/adapters/RationalNumberAdapter.h" +#include "storm/adapters/RationalNumberForward.h" +#include "storm/environment/modelchecker/ConditionalModelCheckerEnvironment.h" #include "storm/environment/modelchecker/ModelCheckerEnvironment.h" #include "storm/environment/solver/MinMaxSolverEnvironment.h" #include "storm/exceptions/NotImplementedException.h" @@ -13,11 +16,16 @@ #include "storm/storage/BitVector.h" #include "storm/storage/MaximalEndComponentDecomposition.h" #include "storm/storage/SparseMatrix.h" +#include "storm/storage/StronglyConnectedComponentDecomposition.h" #include "storm/transformer/EndComponentEliminator.h" #include "storm/utility/Extremum.h" -#include "storm/utility/KwekMehlhorn.h" +#include "storm/utility/NumberTraits.h" +#include "storm/utility/OptionalRef.h" +#include "storm/utility/RationalApproximation.h" #include "storm/utility/SignalHandler.h" +#include "storm/utility/constants.h" #include "storm/utility/graph.h" +#include "storm/utility/logging.h" #include "storm/utility/macros.h" namespace storm::modelchecker { @@ -25,12 +33,13 @@ namespace storm::modelchecker { namespace internal { template -void eliminateEndComponents(storm::storage::BitVector possibleEcStates, bool addRowAtRepresentativeState, std::optional representativeRowEntry, - storm::storage::SparseMatrix& matrix, uint64_t& initialState, storm::storage::BitVector& rowsWithSum1, - std::vector& rowValues1, storm::OptionalRef> rowValues2 = {}) { +std::optional::EndComponentEliminatorReturnType> eliminateEndComponents( + storm::storage::BitVector const& possibleEcStates, bool addRowAtRepresentativeState, std::optional const representativeRowEntry, + storm::storage::SparseMatrix& matrix, storm::storage::BitVector& rowsWithSum1, std::vector& rowValues1, + storm::OptionalRef> rowValues2 = {}) { storm::storage::MaximalEndComponentDecomposition ecs(matrix, matrix.transpose(true), possibleEcStates, rowsWithSum1); if (ecs.empty()) { - return; // nothing to do + return {}; // nothing to do } storm::storage::BitVector allRowGroups(matrix.getRowGroupCount(), true); @@ -66,9 +75,6 @@ void eliminateEndComponents(storm::storage::BitVector possibleEcStates, bool add updateRowValue(*rowValues2); } - // update initial state - initialState = ecElimResult.oldToNewStateMapping[initialState]; - // update bitvector storm::storage::BitVector newRowsWithSum1(ecElimResult.newToOldRowMapping.size(), true); uint64_t newRowIndex = 0; @@ -79,26 +85,47 @@ void eliminateEndComponents(storm::storage::BitVector possibleEcStates, bool add ++newRowIndex; } rowsWithSum1 = std::move(newRowsWithSum1); + + return ecElimResult; } template SolutionType solveMinMaxEquationSystem(storm::Environment const& env, storm::storage::SparseMatrix const& matrix, std::vector const& rowValues, storm::storage::BitVector const& rowsWithSum1, - storm::solver::OptimizationDirection const dir, uint64_t const initialState) { + storm::solver::SolveGoal const& goal, uint64_t const initialState, + std::optional>& schedulerOutput) { // Initialize the solution vector. std::vector x(matrix.getRowGroupCount(), storm::utility::zero()); // Set up the solver. - auto solver = storm::solver::GeneralMinMaxLinearEquationSolverFactory().create(env, matrix); - solver->setOptimizationDirection(dir); + storm::solver::GeneralMinMaxLinearEquationSolverFactory factory; + storm::storage::BitVector relevantValues(matrix.getRowGroupCount(), false); + relevantValues.set(initialState, true); + auto getGoal = [&goal, &relevantValues]() -> storm::solver::SolveGoal { + if (goal.isBounded()) { + return {goal.direction(), goal.boundComparisonType(), goal.thresholdValue(), relevantValues}; + } else { + return {goal.direction(), relevantValues}; + } + }; + auto solver = storm::solver::configureMinMaxLinearEquationSolver(env, getGoal(), factory, matrix); + + storm::solver::GeneralMinMaxLinearEquationSolverFactory().create(env, matrix); + solver->setOptimizationDirection(goal.direction()); solver->setRequirementsChecked(); solver->setHasUniqueSolution(true); solver->setHasNoEndComponents(true); solver->setLowerBound(storm::utility::zero()); solver->setUpperBound(storm::utility::one()); + solver->setTrackScheduler(schedulerOutput.has_value()); // Solve the corresponding system of equations. solver->solveEquations(env, x, rowValues); + + if (schedulerOutput) { + *schedulerOutput = std::move(solver->getSchedulerChoices()); + } + return x[initialState]; } @@ -107,23 +134,34 @@ SolutionType solveMinMaxEquationSystem(storm::Environment const& env, storm::sto * @note This code is optimized for cases where not all states are reachable from the initial states. */ template -void computeReachabilityProbabilities(Environment const& env, std::map& nonZeroResults, storm::solver::OptimizationDirection const dir, - storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& initialStates, - storm::storage::BitVector const& allowedStates, storm::storage::BitVector const& targetStates) { +std::unique_ptr> computeReachabilityProbabilities( + Environment const& env, std::map& nonZeroResults, storm::solver::OptimizationDirection const dir, + storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& initialStates, + storm::storage::BitVector const& allowedStates, storm::storage::BitVector const& targetStates, bool computeScheduler = true) { + std::unique_ptr> scheduler; + if (computeScheduler) { + scheduler = std::make_unique>(transitionMatrix.getRowGroupCount()); + } + + auto reachabilityEnv = env; + reachabilityEnv.solver().minMax().setPrecision(env.modelchecker().conditional().getPrecision()); + reachabilityEnv.solver().minMax().setRelativeTerminationCriterion(env.modelchecker().conditional().isRelativePrecision()); + if (initialStates.empty()) { // nothing to do - return; + return scheduler; } auto const reachableStates = storm::utility::graph::getReachableStates(transitionMatrix, initialStates, allowedStates, targetStates); auto const subTargets = targetStates % reachableStates; // Catch the case where no target is reachable from an initial state. In this case, there is nothing to do since all probabilities are zero. if (subTargets.empty()) { - return; + return scheduler; } auto const subInits = initialStates % reachableStates; auto const submatrix = transitionMatrix.getSubmatrix(true, reachableStates, reachableStates); auto const subResult = helper::SparseMdpPrctlHelper::computeUntilProbabilities( - env, storm::solver::SolveGoal(dir, subInits), submatrix, submatrix.transpose(true), storm::storage::BitVector(subTargets.size(), true), - subTargets, false, false); + reachabilityEnv, storm::solver::SolveGoal(dir, subInits), submatrix, submatrix.transpose(true), + storm::storage::BitVector(subTargets.size(), true), subTargets, false, computeScheduler); + auto origInitIt = initialStates.begin(); for (auto subInit : subInits) { auto const& val = subResult.values[subInit]; @@ -132,6 +170,16 @@ void computeReachabilityProbabilities(Environment const& env, std::mapsetChoice(subResult.scheduler->getChoice(submatrixIdx), state); + ++submatrixIdx; + } + } + + return scheduler; } template @@ -139,6 +187,7 @@ struct NormalFormData { storm::storage::BitVector const maybeStates; // Those states that can be reached from initial without reaching a terminal state storm::storage::BitVector const terminalStates; // Those states where we already know the probability to reach the condition and the target value storm::storage::BitVector const conditionStates; // Those states where the condition holds almost surely (under all schedulers) + storm::storage::BitVector const targetStates; // Those states where the target holds almost surely (under all schedulers) storm::storage::BitVector const universalObservationFailureStates; // Those states where the condition is not reachable (under all schedulers) storm::storage::BitVector const existentialObservationFailureStates; // Those states s where a scheduler exists that (i) does not reach the condition from // s and (ii) acts optimal in all terminal states @@ -152,6 +201,11 @@ struct NormalFormData { // TerminalStates is a superset of conditionStates and dom(nonZeroTargetStateValues). // For a terminalState that is not a conditionState, it is impossible to (reach the condition and not reach the target). + std::unique_ptr> + schedulerChoicesForReachingTargetStates; // Scheduler choices for reaching target states, used for constructing the resulting scheduler + std::unique_ptr> + schedulerChoicesForReachingConditionStates; // Scheduler choices for reaching condition states, used for constructing the resulting scheduler + ValueType getTargetValue(uint64_t state) const { STORM_LOG_ASSERT(terminalStates.get(state), "Tried to get target value for non-terminal state"); auto const it = nonZeroTargetStateValues.find(state); @@ -167,7 +221,7 @@ struct NormalFormData { }; template -NormalFormData obtainNormalForm(Environment const& env, storm::solver::OptimizationDirection const dir, +NormalFormData obtainNormalForm(Environment const& env, storm::solver::OptimizationDirection const dir, bool computeScheduler, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& relevantStates, storm::storage::BitVector const& targetStates, storm::storage::BitVector const& conditionStates) { @@ -178,9 +232,13 @@ NormalFormData obtainNormalForm(Environment const& env, storm::solver std::map nonZeroTargetStateValues; auto const extendedTargetStates = storm::utility::graph::performProb1A(transitionMatrix, transitionMatrix.getRowGroupIndices(), backwardTransitions, allStates, targetStates); - computeReachabilityProbabilities(env, nonZeroTargetStateValues, dir, transitionMatrix, extendedConditionStates, allStates, extendedTargetStates); auto const targetAndNotCondFailStates = extendedTargetStates & ~(extendedConditionStates | universalObservationFailureStates); - computeReachabilityProbabilities(env, nonZeroTargetStateValues, dir, transitionMatrix, targetAndNotCondFailStates, allStates, extendedConditionStates); + + // compute schedulers for reaching target and condition states from target and condition states + std::unique_ptr> schedulerChoicesForReachingTargetStates = computeReachabilityProbabilities( + env, nonZeroTargetStateValues, dir, transitionMatrix, extendedConditionStates, allStates, extendedTargetStates, computeScheduler); + std::unique_ptr> schedulerChoicesForReachingConditionStates = computeReachabilityProbabilities( + env, nonZeroTargetStateValues, dir, transitionMatrix, targetAndNotCondFailStates, allStates, extendedConditionStates, computeScheduler); // get states where the optimal policy reaches the condition with positive probability auto terminalStatesThatReachCondition = extendedConditionStates; @@ -211,20 +269,184 @@ NormalFormData obtainNormalForm(Environment const& env, storm::solver return NormalFormData{.maybeStates = std::move(nonTerminalStates), .terminalStates = std::move(terminalStates), .conditionStates = std::move(extendedConditionStates), + .targetStates = std::move(extendedTargetStates), .universalObservationFailureStates = std::move(universalObservationFailureStates), .existentialObservationFailureStates = std::move(existentialObservationFailureStates), - .nonZeroTargetStateValues = std::move(nonZeroTargetStateValues)}; + .nonZeroTargetStateValues = std::move(nonZeroTargetStateValues), + .schedulerChoicesForReachingTargetStates = std::move(schedulerChoicesForReachingTargetStates), + .schedulerChoicesForReachingConditionStates = std::move(schedulerChoicesForReachingConditionStates)}; +} + +// computes the scheduler that reaches the EC exits from the maybe states that were removed by EC elimination +template +void finalizeSchedulerForMaybeStates(storm::storage::Scheduler& scheduler, storm::storage::SparseMatrix const& transitionMatrix, + storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& maybeStates, + storm::storage::BitVector const& maybeStatesWithoutChoice, storm::storage::BitVector const& maybeStatesWithChoice, + std::vector const& stateToFinalEc, NormalFormData const& normalForm, uint64_t initialComponentIndex, + storm::storage::BitVector const& initialComponentExitStates, storm::storage::BitVector const& initialComponentExitRows, + uint64_t chosenInitialComponentExitState, uint64_t chosenInitialComponentExit) { + // Compute the EC stay choices for the states in maybeStatesWithChoice + storm::storage::BitVector ecStayChoices(transitionMatrix.getRowCount(), false); + storm::storage::BitVector initialComponentStates(transitionMatrix.getRowGroupCount(), false); + + // compute initial component states and all choices that stay within a given EC + for (auto state : maybeStates) { + auto ecIndex = stateToFinalEc[state]; + if (ecIndex == initialComponentIndex) { + initialComponentStates.set(state, true); + continue; // state part of the initial component + } else if (ecIndex == std::numeric_limits::max()) { + continue; + } + for (auto choiceIndex : transitionMatrix.getRowGroupIndices(state)) { + bool isEcStayChoice = true; + for (auto const& entry : transitionMatrix.getRow(choiceIndex)) { + auto targetState = entry.getColumn(); + if (stateToFinalEc[targetState] != ecIndex) { + isEcStayChoice = false; + break; + } + } + if (isEcStayChoice) { + ecStayChoices.set(choiceIndex, true); + } + } + } + + // fill choices for ECs that reach the chosen EC exit + auto const maybeNonICStatesWithoutChoice = maybeStatesWithoutChoice & ~initialComponentStates; + storm::utility::graph::computeSchedulerProb1E(maybeNonICStatesWithoutChoice, transitionMatrix, backwardTransitions, maybeStates, maybeStatesWithChoice, + scheduler, ecStayChoices); + + // collect all choices from the initial component states and the choices that were selected by the scheduler so far + auto const condOrTargetStates = normalForm.conditionStates | normalForm.targetStates; + auto const& rowGroups = transitionMatrix.getRowGroupIndices(); + storm::storage::BitVector allowedChoices(transitionMatrix.getRowCount(), false); + auto const rowGroupCount = transitionMatrix.getRowGroupCount(); + for (uint64_t state = 0; state < rowGroupCount; ++state) { + if (scheduler.isChoiceSelected(state)) { + auto choiceIndex = scheduler.getChoice(state).getDeterministicChoice(); + allowedChoices.set(rowGroups[state] + choiceIndex, true); + } else if (initialComponentStates.get(state) || condOrTargetStates.get(state)) { + for (auto choiceIndex : transitionMatrix.getRowGroupIndices(state)) { + allowedChoices.set(choiceIndex, true); + } + } + } + + // dfs to find which choices in initial component states lead to condOrTargetStates + storm::storage::BitVector choicesThatCanVisitCondOrTargetStates(transitionMatrix.getRowCount(), false); + std::stack toProcess; + for (auto state : condOrTargetStates) { + toProcess.push(state); + } + auto visitedStates = condOrTargetStates; + while (!toProcess.empty()) { + auto currentState = toProcess.top(); + toProcess.pop(); + for (auto const& entry : backwardTransitions.getRow(currentState)) { + uint64_t const predecessorState = entry.getColumn(); + for (uint64_t const predecessorChoice : transitionMatrix.getRowGroupIndices(predecessorState)) { + if (!allowedChoices.get(predecessorChoice) || choicesThatCanVisitCondOrTargetStates.get(predecessorChoice)) { + continue; // The choice is either not allowed or has been considered already + } + if (auto const r = transitionMatrix.getRow(predecessorChoice); + std::none_of(r.begin(), r.end(), [¤tState](auto const& e) { return e.getColumn() == currentState; })) { + continue; // not an actual predecessor choice + } + choicesThatCanVisitCondOrTargetStates.set(predecessorChoice, true); + if (!visitedStates.get(predecessorState)) { + visitedStates.set(predecessorState, true); + toProcess.push(predecessorState); + } + } + } + } + + // we want to disallow taking initial component exits that can lead to a condition or target state, beside the one exit that was chosen + storm::storage::BitVector disallowedInitialComponentExits = initialComponentExitRows & choicesThatCanVisitCondOrTargetStates; + disallowedInitialComponentExits.set(chosenInitialComponentExit, false); + storm::storage::BitVector choicesAllowedForInitialComponent = allowedChoices & ~disallowedInitialComponentExits; + + storm::storage::BitVector goodInitialComponentStates = initialComponentStates; + bool progress = false; + for (auto state : initialComponentExitStates) { + auto const groupStart = transitionMatrix.getRowGroupIndices()[state]; + auto const groupEnd = transitionMatrix.getRowGroupIndices()[state + 1]; + bool const allChoicesAreDisallowed = disallowedInitialComponentExits.getNextUnsetIndex(groupStart) >= groupEnd; + if (allChoicesAreDisallowed) { + goodInitialComponentStates.set(state, false); + progress = true; + } + } + while (progress) { + progress = false; + for (auto state : goodInitialComponentStates) { + bool allChoicesAreDisallowed = true; + for (auto choiceIndex : transitionMatrix.getRowGroupIndices(state)) { + auto row = transitionMatrix.getRow(choiceIndex); + bool const hasBadSuccessor = std::any_of( + row.begin(), row.end(), [&goodInitialComponentStates](auto const& entry) { return !goodInitialComponentStates.get(entry.getColumn()); }); + if (hasBadSuccessor) { + choicesAllowedForInitialComponent.set(choiceIndex, false); + } else { + allChoicesAreDisallowed = false; + } + } + if (allChoicesAreDisallowed) { + goodInitialComponentStates.set(state, false); + progress = true; + } + } + } + + storm::storage::BitVector exitStateBitvector(transitionMatrix.getRowGroupCount(), false); + exitStateBitvector.set(chosenInitialComponentExitState, true); + + storm::utility::graph::computeSchedulerProbGreater0E(transitionMatrix, backwardTransitions, initialComponentStates, exitStateBitvector, scheduler, + choicesAllowedForInitialComponent); + + // fill the choices of initial component states that do not have a choice yet + // these states should not reach the condition or target states under the constructed scheduler + for (auto state : initialComponentStates) { + if (!scheduler.isChoiceSelected(state)) { + for (auto choiceIndex : transitionMatrix.getRowGroupIndices(state)) { + if (choicesAllowedForInitialComponent.get(choiceIndex)) { + scheduler.setChoice(choiceIndex - rowGroups[state], state); + break; + } + } + } + } } +template +struct ResultReturnType { + ResultReturnType(ValueType initialStateValue, std::unique_ptr>&& scheduler = nullptr) + : initialStateValue(initialStateValue), scheduler(std::move(scheduler)) { + // Intentionally left empty. + } + + bool hasScheduler() const { + return static_cast(scheduler); + } + + ValueType initialStateValue; + std::unique_ptr> scheduler; +}; + /*! * Uses the restart method by Baier et al. // @see doi.org/10.1007/978-3-642-54862-8_43 */ template -SolutionType computeViaRestartMethod(Environment const& env, uint64_t const initialState, storm::solver::OptimizationDirection const dir, - storm::storage::SparseMatrix const& transitionMatrix, NormalFormData const& normalForm) { +typename internal::ResultReturnType computeViaRestartMethod(Environment const& env, uint64_t const initialState, + storm::solver::SolveGoal const& goal, bool computeScheduler, + storm::storage::SparseMatrix const& transitionMatrix, + storm::storage::SparseMatrix const& backwardTransitions, + NormalFormData const& normalForm) { auto const& maybeStates = normalForm.maybeStates; - auto const stateToMatrixIndexMap = maybeStates.getNumberOfSetBitsBeforeIndices(); + auto originalToReducedStateIndexMap = maybeStates.getNumberOfSetBitsBeforeIndices(); auto const numMaybeStates = maybeStates.getNumberOfSetBits(); auto const numMaybeChoices = transitionMatrix.getNumRowsInRowGroups(maybeStates); @@ -266,39 +488,146 @@ SolutionType computeViaRestartMethod(Environment const& env, uint64_t const init // Insert backloop probability if we haven't done so yet and are past the initial state index // This is to avoid a costly out-of-order insertion into the matrix if (addRestartTransition && entry.getColumn() > initialState) { - matrixBuilder.addNextValue(currentRow, stateToMatrixIndexMap[initialState], restartProbability); + matrixBuilder.addNextValue(currentRow, originalToReducedStateIndexMap[initialState], restartProbability); addRestartTransition = false; } if (maybeStates.get(entry.getColumn())) { - matrixBuilder.addNextValue(currentRow, stateToMatrixIndexMap[entry.getColumn()], entry.getValue()); + matrixBuilder.addNextValue(currentRow, originalToReducedStateIndexMap[entry.getColumn()], entry.getValue()); } } // Add the backloop if we haven't done this already if (addRestartTransition) { - matrixBuilder.addNextValue(currentRow, stateToMatrixIndexMap[initialState], restartProbability); + matrixBuilder.addNextValue(currentRow, originalToReducedStateIndexMap[initialState], restartProbability); } ++currentRow; } } + STORM_LOG_ASSERT(currentRow == numMaybeChoices, "Unexpected number of constructed rows."); auto matrix = matrixBuilder.build(); - auto initStateInMatrix = stateToMatrixIndexMap[initialState]; + auto initStateInReduced = originalToReducedStateIndexMap[initialState]; // Eliminate end components in two phases // First, we catch all end components that do not contain the initial state. It is possible to stay in those ECs forever // without reaching the condition. This is reflected by a backloop to the initial state. - storm::storage::BitVector selectedStatesInMatrix(numMaybeStates, true); - selectedStatesInMatrix.set(initStateInMatrix, false); - eliminateEndComponents(selectedStatesInMatrix, true, initStateInMatrix, matrix, initStateInMatrix, rowsWithSum1, rowValues); + storm::storage::BitVector selectedStatesInReduced(numMaybeStates, true); + selectedStatesInReduced.set(initStateInReduced, false); + auto ecElimResult1 = eliminateEndComponents(selectedStatesInReduced, true, initStateInReduced, matrix, rowsWithSum1, rowValues); + selectedStatesInReduced.set(initStateInReduced, true); + if (ecElimResult1) { + selectedStatesInReduced.resize(matrix.getRowGroupCount(), true); + initStateInReduced = ecElimResult1->oldToNewStateMapping[initStateInReduced]; + } // Second, eliminate the remaining ECs. These must involve the initial state and might have been introduced in the previous step. // A policy selecting such an EC must reach the condition with probability zero and is thus invalid. - selectedStatesInMatrix.set(initStateInMatrix, true); - eliminateEndComponents(selectedStatesInMatrix, false, std::nullopt, matrix, initStateInMatrix, rowsWithSum1, rowValues); + auto ecElimResult2 = eliminateEndComponents(selectedStatesInReduced, false, std::nullopt, matrix, rowsWithSum1, rowValues); + if (ecElimResult2) { + initStateInReduced = ecElimResult2->oldToNewStateMapping[initStateInReduced]; + } STORM_LOG_INFO("Processed model has " << matrix.getRowGroupCount() << " states and " << matrix.getRowGroupCount() << " choices and " << matrix.getEntryCount() << " transitions."); - // Finally, solve the equation system - return solveMinMaxEquationSystem(env, matrix, rowValues, rowsWithSum1, dir, initStateInMatrix); + + // Finally, solve the equation system, potentially computing a scheduler + std::optional> reducedSchedulerChoices; + if (computeScheduler) { + reducedSchedulerChoices.emplace(); + } + auto resultValue = solveMinMaxEquationSystem(env, matrix, rowValues, rowsWithSum1, goal, initStateInReduced, reducedSchedulerChoices); + + // Create result (scheduler potentially added below) + auto finalResult = ResultReturnType(resultValue); + + if (!computeScheduler) { + return finalResult; + } + // At this point we have to reconstruct the scheduler for the original model + STORM_LOG_ASSERT(reducedSchedulerChoices.has_value() && reducedSchedulerChoices->size() == matrix.getRowGroupCount(), + "Requested scheduler, but it was not computed or has invalid size."); + // For easier access, we create and update some index mappings + std::vector originalRowToStateIndexMap; // maps original row indices to original state indices. transitionMatrix.getRowGroupIndices() are the + // inverse of that mapping + originalRowToStateIndexMap.reserve(transitionMatrix.getRowCount()); + for (uint64_t originalStateIndex = 0; originalStateIndex < transitionMatrix.getRowGroupCount(); ++originalStateIndex) { + originalRowToStateIndexMap.insert(originalRowToStateIndexMap.end(), transitionMatrix.getRowGroupSize(originalStateIndex), originalStateIndex); + } + std::vector reducedToOriginalRowIndexMap; // maps row indices of the reduced model to the original ones + reducedToOriginalRowIndexMap.reserve(numMaybeChoices); + for (uint64_t const originalMaybeState : maybeStates) { + for (auto const originalRowIndex : transitionMatrix.getRowGroupIndices(originalMaybeState)) { + reducedToOriginalRowIndexMap.push_back(originalRowIndex); + } + } + if (ecElimResult1.has_value() || ecElimResult2.has_value()) { + // reducedToOriginalRowIndexMap needs to be updated so it maps from rows of the ec-eliminated system + std::vector tmpReducedToOriginalRowIndexMap; + tmpReducedToOriginalRowIndexMap.reserve(matrix.getRowCount()); + for (uint64_t reducedRow = 0; reducedRow < matrix.getRowCount(); ++reducedRow) { + uint64_t intermediateRow = reducedRow; + if (ecElimResult2.has_value()) { + intermediateRow = ecElimResult2->newToOldRowMapping.at(intermediateRow); + } + if (ecElimResult1.has_value()) { + intermediateRow = ecElimResult1->newToOldRowMapping.at(intermediateRow); + } + tmpReducedToOriginalRowIndexMap.push_back(reducedToOriginalRowIndexMap[intermediateRow]); + } + reducedToOriginalRowIndexMap = std::move(tmpReducedToOriginalRowIndexMap); + // originalToReducedStateIndexMap needs to be updated so it maps into the ec-eliminated system + for (uint64_t originalStateIndex = 0; originalStateIndex < transitionMatrix.getRowGroupCount(); ++originalStateIndex) { + auto& reducedIndex = originalToReducedStateIndexMap[originalStateIndex]; + if (maybeStates.get(originalStateIndex)) { + if (ecElimResult1.has_value()) { + reducedIndex = ecElimResult1->oldToNewStateMapping.at(reducedIndex); + } + if (ecElimResult2.has_value()) { + reducedIndex = ecElimResult2->oldToNewStateMapping.at(reducedIndex); + } + } else { + reducedIndex = std::numeric_limits::max(); // The original state does not exist in the reduced model. + } + } + } + + storm::storage::BitVector initialComponentExitRows(transitionMatrix.getRowCount(), false); + storm::storage::BitVector initialComponentExitStates(transitionMatrix.getRowGroupCount(), false); + for (auto const reducedRowIndex : matrix.getRowGroupIndices(initStateInReduced)) { + uint64_t const originalRowIndex = reducedToOriginalRowIndexMap[reducedRowIndex]; + uint64_t const originalState = originalRowToStateIndexMap[originalRowIndex]; + + initialComponentExitRows.set(originalRowIndex, true); + initialComponentExitStates.set(originalState, true); + } + + // If requested, construct the scheduler for the original model + storm::storage::BitVector maybeStatesWithChoice(maybeStates.size(), false); + uint64_t chosenInitialComponentExitState = std::numeric_limits::max(); + uint64_t chosenInitialComponentExit = std::numeric_limits::max(); + auto scheduler = std::make_unique>(transitionMatrix.getRowGroupCount()); + + uint64_t reducedState = 0; + for (auto const& choice : reducedSchedulerChoices.value()) { + uint64_t const reducedRowIndex = matrix.getRowGroupIndices()[reducedState] + choice; + uint64_t const originalRowIndex = reducedToOriginalRowIndexMap[reducedRowIndex]; + uint64_t const originalState = originalRowToStateIndexMap[originalRowIndex]; + uint64_t const originalChoice = originalRowIndex - transitionMatrix.getRowGroupIndices()[originalState]; + scheduler->setChoice(originalChoice, originalState); + maybeStatesWithChoice.set(originalState, true); + if (reducedState == initStateInReduced) { + chosenInitialComponentExitState = originalState; + chosenInitialComponentExit = originalRowIndex; + } + ++reducedState; + } + + auto const maybeStatesWithoutChoice = maybeStates & ~maybeStatesWithChoice; + finalizeSchedulerForMaybeStates(*scheduler, transitionMatrix, backwardTransitions, maybeStates, maybeStatesWithoutChoice, maybeStatesWithChoice, + originalToReducedStateIndexMap, normalForm, initStateInReduced, initialComponentExitStates, initialComponentExitRows, + chosenInitialComponentExitState, chosenInitialComponentExit); + + finalResult.scheduler = std::move(scheduler); + + return finalResult; } /*! @@ -310,7 +639,7 @@ template class WeightedReachabilityHelper { public: WeightedReachabilityHelper(uint64_t const initialState, storm::storage::SparseMatrix const& transitionMatrix, - NormalFormData const& normalForm) { + NormalFormData const& normalForm, bool computeScheduler) { // Determine rowgroups (states) and rows (choices) of the submatrix auto subMatrixRowGroups = normalForm.maybeStates; // Identify and eliminate the initial component to enforce that it is eventually exited @@ -321,7 +650,8 @@ class WeightedReachabilityHelper { // An optimal scheduler can intuitively pick the best exiting action of C and enforce that all paths that satisfy the condition exit C through that // action. By eliminating the initial component, we ensure that only policies that actually exit C are considered. The remaining policies have // probability zero of satisfying the condition. - storm::storage::BitVector initialComponentExitRows(transitionMatrix.getRowCount(), false); + initialComponentExitRows = storm::storage::BitVector(transitionMatrix.getRowCount(), false); + initialComponentExitStates = storm::storage::BitVector(transitionMatrix.getRowGroupCount(), false); subMatrixRowGroups.set(initialState, false); // temporarily unset initial state std::vector dfsStack = {initialState}; while (!dfsStack.empty()) { @@ -340,6 +670,7 @@ class WeightedReachabilityHelper { } } else { initialComponentExitRows.set(rowIndex, true); + initialComponentExitStates.set(state, true); } } } @@ -347,12 +678,18 @@ class WeightedReachabilityHelper { subMatrixRowGroups.set(initialState, true); // set initial state again, as single representative state for the initial component auto const numSubmatrixRowGroups = subMatrixRowGroups.getNumberOfSetBits(); + if (computeScheduler) { + reducedToOriginalRowIndexMap.reserve(numSubmatrixRows); + } + // state index mapping and initial state - auto stateToMatrixIndexMap = subMatrixRowGroups.getNumberOfSetBitsBeforeIndices(); - initialStateInSubmatrix = stateToMatrixIndexMap[initialState]; + originalToReducedStateIndexMap = subMatrixRowGroups.getNumberOfSetBitsBeforeIndices(); + initialStateInSubmatrix = originalToReducedStateIndexMap[initialState]; auto const eliminatedInitialComponentStates = normalForm.maybeStates & ~subMatrixRowGroups; + + // Inital component states do not have the correct mapping yet. for (auto state : eliminatedInitialComponentStates) { - stateToMatrixIndexMap[state] = initialStateInSubmatrix; // map all eliminated states to the initial state + originalToReducedStateIndexMap[state] = initialStateInSubmatrix; } // build matrix, rows that sum up to 1, target values, condition values @@ -366,13 +703,10 @@ class WeightedReachabilityHelper { // Put the row processing into a lambda for avoiding code duplications auto processRow = [&](uint64_t origRowIndex) { - // We make two passes. First, we find out the probability to reach an eliminated initial component state - ValueType const eliminatedInitialComponentProbability = transitionMatrix.getConstrainedRowSum(origRowIndex, eliminatedInitialComponentStates); - // Second, we insert the submatrix entries and find out the target and condition probabilities for this row + // insert the submatrix entries and find out the target and condition probabilities for this row ValueType targetProbability = storm::utility::zero(); ValueType conditionProbability = storm::utility::zero(); bool rowSumIsLess1 = false; - bool initialStateEntryInserted = false; for (auto const& entry : transitionMatrix.getRow(origRowIndex)) { if (normalForm.terminalStates.get(entry.getColumn())) { STORM_LOG_ASSERT(!storm::utility::isZero(entry.getValue()), "Transition probability must be non-zero"); @@ -384,19 +718,11 @@ class WeightedReachabilityHelper { } else { conditionProbability += scaledTargetValue; // for terminal, non-condition states, the condition value equals the target value } - } else if (!eliminatedInitialComponentStates.get(entry.getColumn())) { - auto const columnIndex = stateToMatrixIndexMap[entry.getColumn()]; - if (!initialStateEntryInserted && columnIndex >= initialStateInSubmatrix) { - if (columnIndex == initialStateInSubmatrix) { - matrixBuilder.addNextValue(currentRow, initialStateInSubmatrix, eliminatedInitialComponentProbability + entry.getValue()); - } else { - matrixBuilder.addNextValue(currentRow, initialStateInSubmatrix, eliminatedInitialComponentProbability); - matrixBuilder.addNextValue(currentRow, columnIndex, entry.getValue()); - } - initialStateEntryInserted = true; - } else { - matrixBuilder.addNextValue(currentRow, columnIndex, entry.getValue()); - } + } else if (eliminatedInitialComponentStates.get(entry.getColumn())) { + rowSumIsLess1 = true; + } else { + auto const columnIndex = originalToReducedStateIndexMap[entry.getColumn()]; + matrixBuilder.addNextValue(currentRow, columnIndex, entry.getValue()); } } if (rowSumIsLess1) { @@ -410,10 +736,12 @@ class WeightedReachabilityHelper { if (state == initialState) { for (auto origRowIndex : initialComponentExitRows) { processRow(origRowIndex); + reducedToOriginalRowIndexMap.push_back(origRowIndex); } } else { for (auto origRowIndex : transitionMatrix.getRowGroupIndices(state)) { processRow(origRowIndex); + reducedToOriginalRowIndexMap.push_back(origRowIndex); } } } @@ -423,30 +751,77 @@ class WeightedReachabilityHelper { // For all remaining ECs, staying in an EC forever is reflected by collecting a value of zero for both, target and condition storm::storage::BitVector allExceptInit(numSubmatrixRowGroups, true); allExceptInit.set(initialStateInSubmatrix, false); - eliminateEndComponents(allExceptInit, true, std::nullopt, submatrix, initialStateInSubmatrix, rowsWithSum1, targetRowValues, - conditionRowValues); + ecResult = eliminateEndComponents(allExceptInit, true, std::nullopt, submatrix, rowsWithSum1, targetRowValues, conditionRowValues); + if (ecResult) { + initialStateInSubmatrix = ecResult->oldToNewStateMapping[initialStateInSubmatrix]; + } + isAcyclic = !storm::utility::graph::hasCycle(submatrix); STORM_LOG_INFO("Processed model has " << submatrix.getRowGroupCount() << " states and " << submatrix.getRowGroupCount() << " choices and " - << submatrix.getEntryCount() << " transitions."); + << submatrix.getEntryCount() << " transitions. Matrix is " << (isAcyclic ? "acyclic." : "cyclic.")); + + if (computeScheduler) { + // For easier conversion of schedulers to the original model, we create and update some index mappings + STORM_LOG_ASSERT(reducedToOriginalRowIndexMap.size() == numSubmatrixRows, "Unexpected size of reducedToOriginalRowIndexMap."); + if (ecResult.has_value()) { + // reducedToOriginalRowIndexMap needs to be updated so it maps from rows of the ec-eliminated system + std::vector tmpReducedToOriginalRowIndexMap; + tmpReducedToOriginalRowIndexMap.reserve(submatrix.getRowCount()); + for (uint64_t reducedRow = 0; reducedRow < submatrix.getRowCount(); ++reducedRow) { + uint64_t intermediateRow = ecResult->newToOldRowMapping.at(reducedRow); + tmpReducedToOriginalRowIndexMap.push_back(reducedToOriginalRowIndexMap[intermediateRow]); + } + reducedToOriginalRowIndexMap = std::move(tmpReducedToOriginalRowIndexMap); + // originalToReducedStateIndexMap needs to be updated so it maps into the ec-eliminated system + for (uint64_t originalStateIndex = 0; originalStateIndex < transitionMatrix.getRowGroupCount(); ++originalStateIndex) { + auto& reducedIndex = originalToReducedStateIndexMap[originalStateIndex]; + if (subMatrixRowGroups.get(originalStateIndex)) { + reducedIndex = ecResult->oldToNewStateMapping.at(reducedIndex); + } else { + reducedIndex = std::numeric_limits::max(); // The original state does not exist in the reduced model. + } + } + } + } else { + // Clear data that is only needed if we compute schedulers + originalToReducedStateIndexMap.clear(); + reducedToOriginalRowIndexMap.clear(); + ecResult.emplace(); + initialComponentExitRows.clear(); + initialComponentExitStates.clear(); + } } SolutionType computeWeightedDiff(storm::Environment const& env, storm::OptimizationDirection const dir, ValueType const& targetWeight, - ValueType const& conditionWeight) const { - auto rowValues = createScaledVector(targetWeight, targetRowValues, conditionWeight, conditionRowValues); + ValueType const& conditionWeight, storm::OptionalRef> schedulerOutput = {}) { + // Set up the solver. + if (!cachedSolver) { + auto solverEnv = env; + if (isAcyclic) { + STORM_LOG_INFO("Using acyclic min-max solver for weighted reachability computation."); + solverEnv.solver().minMax().setMethod(storm::solver::MinMaxMethod::Acyclic); + } + cachedSolver = storm::solver::GeneralMinMaxLinearEquationSolverFactory().create(solverEnv, submatrix); + cachedSolver->setCachingEnabled(true); + cachedSolver->setRequirementsChecked(); + cachedSolver->setHasUniqueSolution(true); + cachedSolver->setHasNoEndComponents(true); + cachedSolver->setLowerBound(-storm::utility::one()); + cachedSolver->setUpperBound(storm::utility::one()); + } + cachedSolver->setTrackScheduler(schedulerOutput.has_value()); + cachedSolver->setOptimizationDirection(dir); - // Initialize the solution vector. - std::vector x(submatrix.getRowGroupCount(), storm::utility::zero()); + // Initialize the right-hand side vector. + createScaledVector(cachedB, targetWeight, targetRowValues, conditionWeight, conditionRowValues); - // Set up the solver. - auto solver = storm::solver::GeneralMinMaxLinearEquationSolverFactory().create(env, submatrix); - solver->setOptimizationDirection(dir); - solver->setRequirementsChecked(); - solver->setHasUniqueSolution(true); - solver->setHasNoEndComponents(true); - solver->setLowerBound(-storm::utility::one()); - solver->setUpperBound(storm::utility::one()); + // Initialize the solution vector. + cachedX.assign(submatrix.getRowGroupCount(), storm::utility::zero()); - solver->solveEquations(env, x, rowValues); - return x[initialStateInSubmatrix]; + cachedSolver->solveEquations(env, cachedX, cachedB); + if (schedulerOutput) { + *schedulerOutput = cachedSolver->getSchedulerChoices(); + } + return cachedX[initialStateInSubmatrix]; } auto getInternalInitialState() const { @@ -454,7 +829,7 @@ class WeightedReachabilityHelper { } void evaluateScheduler(storm::Environment const& env, std::vector& scheduler, std::vector& targetResults, - std::vector& conditionResults) const { + std::vector& conditionResults) { if (scheduler.empty()) { scheduler.resize(submatrix.getRowGroupCount(), 0); } @@ -464,23 +839,31 @@ class WeightedReachabilityHelper { if (conditionResults.empty()) { conditionResults.resize(submatrix.getRowGroupCount(), storm::utility::zero()); } - // apply the scheduler - storm::solver::GeneralLinearEquationSolverFactory factory; - bool const convertToEquationSystem = factory.getEquationProblemFormat(env) == storm::solver::LinearEquationSolverProblemFormat::EquationSystem; - auto scheduledMatrix = submatrix.selectRowsFromRowGroups(scheduler, convertToEquationSystem); - if (convertToEquationSystem) { - scheduledMatrix.convertToEquationSystem(); - } - auto solver = factory.create(env, std::move(scheduledMatrix)); - solver->setBounds(storm::utility::zero(), storm::utility::one()); - solver->setCachingEnabled(true); + auto solver = getScheduledSolver(env, scheduler); - std::vector subB(submatrix.getRowGroupCount()); - storm::utility::vector::selectVectorValues(subB, scheduler, submatrix.getRowGroupIndices(), targetRowValues); - solver->solveEquations(env, targetResults, subB); + cachedB.resize(submatrix.getRowGroupCount()); + storm::utility::vector::selectVectorValues(cachedB, scheduler, submatrix.getRowGroupIndices(), targetRowValues); + solver->solveEquations(env, targetResults, cachedB); - storm::utility::vector::selectVectorValues(subB, scheduler, submatrix.getRowGroupIndices(), conditionRowValues); - solver->solveEquations(env, conditionResults, subB); + storm::utility::vector::selectVectorValues(cachedB, scheduler, submatrix.getRowGroupIndices(), conditionRowValues); + solver->solveEquations(env, conditionResults, cachedB); + } + + SolutionType evaluateScheduler(storm::Environment const& env, std::vector const& scheduler) { + STORM_LOG_ASSERT(scheduler.size() == submatrix.getRowGroupCount(), "Scheduler size does not match number of row groups"); + auto solver = getScheduledSolver(env, scheduler); + cachedB.resize(submatrix.getRowGroupCount()); + cachedX.resize(submatrix.getRowGroupCount()); + + storm::utility::vector::selectVectorValues(cachedB, scheduler, submatrix.getRowGroupIndices(), targetRowValues); + solver->solveEquations(env, cachedX, cachedB); + SolutionType targetValue = cachedX[initialStateInSubmatrix]; + + storm::utility::vector::selectVectorValues(cachedB, scheduler, submatrix.getRowGroupIndices(), conditionRowValues); + solver->solveEquations(env, cachedX, cachedB); + SolutionType conditionValue = cachedX[initialStateInSubmatrix]; + + return targetValue / conditionValue; } template @@ -511,16 +894,64 @@ class WeightedReachabilityHelper { return improved; } + std::unique_ptr> constructSchedulerForInputModel( + std::vector const& schedulerForReducedModel, storm::storage::SparseMatrix const& originalTransitionMatrix, + storm::storage::SparseMatrix const& originalBackwardTransitions, NormalFormData const& normalForm) const { + std::vector originalRowToStateIndexMap; // maps original row indices to original state indices. transitionMatrix.getRowGroupIndices() are + // the inverse of that mapping + originalRowToStateIndexMap.reserve(originalTransitionMatrix.getRowCount()); + for (uint64_t originalStateIndex = 0; originalStateIndex < originalTransitionMatrix.getRowGroupCount(); ++originalStateIndex) { + originalRowToStateIndexMap.insert(originalRowToStateIndexMap.end(), originalTransitionMatrix.getRowGroupSize(originalStateIndex), + originalStateIndex); + } + + storm::storage::BitVector maybeStatesWithChoice(normalForm.maybeStates.size(), false); + uint64_t chosenInitialComponentExitState = std::numeric_limits::max(); + uint64_t chosenInitialComponentExit = std::numeric_limits::max(); + auto scheduler = std::make_unique>(originalTransitionMatrix.getRowGroupCount()); + + uint64_t reducedState = 0; + for (auto const& choice : schedulerForReducedModel) { + uint64_t const reducedRowIndex = submatrix.getRowGroupIndices()[reducedState] + choice; + uint64_t const originalRowIndex = reducedToOriginalRowIndexMap[reducedRowIndex]; + uint64_t const originalState = originalRowToStateIndexMap[originalRowIndex]; + uint64_t const originalChoice = originalRowIndex - originalTransitionMatrix.getRowGroupIndices()[originalState]; + scheduler->setChoice(originalChoice, originalState); + maybeStatesWithChoice.set(originalState, true); + if (reducedState == initialStateInSubmatrix) { + chosenInitialComponentExitState = originalState; + chosenInitialComponentExit = originalRowIndex; + } + ++reducedState; + } + + auto const maybeStatesWithoutChoice = normalForm.maybeStates & ~maybeStatesWithChoice; + finalizeSchedulerForMaybeStates(*scheduler, originalTransitionMatrix, originalBackwardTransitions, normalForm.maybeStates, maybeStatesWithoutChoice, + maybeStatesWithChoice, originalToReducedStateIndexMap, normalForm, initialStateInSubmatrix, initialComponentExitStates, + initialComponentExitRows, chosenInitialComponentExitState, chosenInitialComponentExit); + return scheduler; + } + private: - std::vector createScaledVector(ValueType const& w1, std::vector const& v1, ValueType const& w2, - std::vector const& v2) const { + void createScaledVector(std::vector& out, ValueType const& w1, std::vector const& v1, ValueType const& w2, + std::vector const& v2) const { STORM_LOG_ASSERT(v1.size() == v2.size(), "Vector sizes must match"); - std::vector result; - result.reserve(v1.size()); - for (size_t i = 0; i < v1.size(); ++i) { - result.push_back(w1 * v1[i] + w2 * v2[i]); + out.resize(v1.size()); + storm::utility::vector::applyPointwise(v1, v2, out, [&w1, &w2](ValueType const& a, ValueType const& b) -> ValueType { return w1 * a + w2 * b; }); + } + + auto getScheduledSolver(storm::Environment const& env, std::vector const& scheduler) const { + // apply the scheduler + storm::solver::GeneralLinearEquationSolverFactory factory; + bool const convertToEquationSystem = factory.getEquationProblemFormat(env) == storm::solver::LinearEquationSolverProblemFormat::EquationSystem; + auto scheduledMatrix = submatrix.selectRowsFromRowGroups(scheduler, convertToEquationSystem); + if (convertToEquationSystem) { + scheduledMatrix.convertToEquationSystem(); } - return result; + auto solver = factory.create(env, std::move(scheduledMatrix)); + solver->setBounds(storm::utility::zero(), storm::utility::one()); + solver->setCachingEnabled(true); + return solver; } storm::storage::SparseMatrix submatrix; @@ -528,63 +959,93 @@ class WeightedReachabilityHelper { std::vector targetRowValues; std::vector conditionRowValues; uint64_t initialStateInSubmatrix; + bool isAcyclic; + std::unique_ptr> cachedSolver; + std::vector cachedX; + std::vector cachedB; + + // Data used to translate schedulers: + std::vector originalToReducedStateIndexMap; + std::vector reducedToOriginalRowIndexMap; + std::optional::EndComponentEliminatorReturnType> ecResult; + storm::storage::BitVector initialComponentExitRows; + storm::storage::BitVector initialComponentExitStates; }; -enum class BisectionMethodBounds { Simple, Advanced }; template -SolutionType computeViaBisection(Environment const& env, BisectionMethodBounds boundOption, uint64_t const initialState, - storm::solver::OptimizationDirection const dir, storm::storage::SparseMatrix const& transitionMatrix, - NormalFormData const& normalForm) { +typename internal::ResultReturnType computeViaBisection(Environment const& env, bool const useAdvancedBounds, bool const usePolicyTracking, + uint64_t const initialState, storm::solver::SolveGoal const& goal, + bool computeScheduler, storm::storage::SparseMatrix const& transitionMatrix, + storm::storage::SparseMatrix const& backwardTransitions, + NormalFormData const& normalForm) { // We currently handle sound model checking incorrectly: we would need the actual lower/upper bounds of the weightedReachabilityHelper STORM_LOG_WARN_COND(!env.solver().isForceSoundness(), "Bisection method does not adequately handle propagation of errors. Result is not necessarily sound."); - SolutionType const precision = [&env, boundOption]() { - if (storm::NumberTraits::IsExact || env.solver().isForceExact()) { - STORM_LOG_WARN_COND(storm::NumberTraits::IsExact && boundOption == BisectionMethodBounds::Advanced, - "Selected bisection method with exact precision in a setting that might not terminate."); - return storm::utility::zero(); - } else { - return storm::utility::convertNumber(env.solver().minMax().getPrecision()); - } - }(); - bool const relative = env.solver().minMax().getRelativeTerminationCriterion(); - WeightedReachabilityHelper wrh(initialState, transitionMatrix, normalForm); + bool const relative = env.modelchecker().conditional().isRelativePrecision(); + auto const precision = storm::utility::convertNumber(env.modelchecker().conditional().getPrecision()); + + WeightedReachabilityHelper wrh(initialState, transitionMatrix, normalForm, computeScheduler); SolutionType pMin{storm::utility::zero()}; SolutionType pMax{storm::utility::one()}; - if (boundOption == BisectionMethodBounds::Advanced) { + if (useAdvancedBounds) { pMin = wrh.computeWeightedDiff(env, storm::OptimizationDirection::Minimize, storm::utility::zero(), storm::utility::one()); pMax = wrh.computeWeightedDiff(env, storm::OptimizationDirection::Maximize, storm::utility::zero(), storm::utility::one()); STORM_LOG_TRACE("Conditioning event bounds:\n\t Lower bound: " << storm::utility::convertNumber(pMin) << ",\n\t Upper bound: " << storm::utility::convertNumber(pMax)); } - storm::utility::Extremum lowerBound = storm::utility::zero(); - storm::utility::Extremum upperBound = storm::utility::one(); - SolutionType middle = (*lowerBound + *upperBound) / 2; + storm::utility::Maximum lowerBound = storm::utility::zero(); + storm::utility::Minimum upperBound = storm::utility::one(); + + std::optional> lowerScheduler, upperScheduler, middleScheduler; + storm::OptionalRef> middleSchedulerRef; + if (usePolicyTracking) { + lowerScheduler.emplace(); + upperScheduler.emplace(); + middleScheduler.emplace(); + middleSchedulerRef.reset(*middleScheduler); + } + + SolutionType middle = goal.isBounded() ? goal.thresholdValue() : (*lowerBound + *upperBound) / 2; + [[maybe_unused]] SolutionType rationalCandiate = middle; // relevant for exact computations + [[maybe_unused]] uint64_t rationalCandidateCount = 0; + std::set checkedMiddleValues; // Middle values that have been checked already + bool terminatedThroughPolicyTracking = false; for (uint64_t iterationCount = 1; true; ++iterationCount) { // evaluate the current middle - SolutionType const middleValue = wrh.computeWeightedDiff(env, dir, storm::utility::one(), -middle); + SolutionType const middleValue = wrh.computeWeightedDiff(env, goal.direction(), storm::utility::one(), -middle, middleSchedulerRef); + checkedMiddleValues.insert(middle); // update the bounds and new middle value according to the bisection method - if (boundOption == BisectionMethodBounds::Simple) { + if (!useAdvancedBounds) { if (middleValue >= storm::utility::zero()) { - lowerBound &= middle; + if (lowerBound &= middle) { + lowerScheduler.swap(middleScheduler); + } } if (middleValue <= storm::utility::zero()) { - upperBound &= middle; + if (upperBound &= middle) { + upperScheduler.swap(middleScheduler); + } } middle = (*lowerBound + *upperBound) / 2; // update middle to the average of the bounds } else { - STORM_LOG_ASSERT(boundOption == BisectionMethodBounds::Advanced, "Unknown bisection method bounds"); if (middleValue >= storm::utility::zero()) { - lowerBound &= middle + (middleValue / pMax); + if (lowerBound &= middle + (middleValue / pMax)) { + lowerScheduler.swap(middleScheduler); + } upperBound &= middle + (middleValue / pMin); } if (middleValue <= storm::utility::zero()) { lowerBound &= middle + (middleValue / pMin); - upperBound &= middle + (middleValue / pMax); + if (upperBound &= middle + (middleValue / pMax)) { + upperScheduler.swap(middleScheduler); + } } - // update middle to the average of the bounds, but scale it according to the middle value (which is in [-1,1]) + // update middle to the average of the bounds, but use the middleValue as a hint: + // If middleValue is close to -1, we use a value close to lowerBound + // If middleValue is close to 0, we use a value close to the avg(lowerBound, upperBound) + // If middleValue is close to +1, we use a value close to upperBound middle = *lowerBound + (storm::utility::one() + middleValue) * (*upperBound - *lowerBound) / 2; if (!storm::NumberTraits::IsExact && storm::utility::isAlmostZero(*upperBound - *lowerBound)) { @@ -598,16 +1059,34 @@ SolutionType computeViaBisection(Environment const& env, BisectionMethodBounds b } // check for convergence SolutionType const boundDiff = *upperBound - *lowerBound; - STORM_LOG_TRACE("Iteration #" << iterationCount << ":\n\t Lower bound: " << storm::utility::convertNumber(*lowerBound) - << ",\n\t Upper bound: " << storm::utility::convertNumber(*upperBound) - << ",\n\t Difference: " << storm::utility::convertNumber(boundDiff) - << ",\n\t Middle val: " << storm::utility::convertNumber(middleValue) << ",\n\t Difference bound: " - << storm::utility::convertNumber((relative ? (precision * *lowerBound) : precision)) << "."); + STORM_LOG_TRACE("Iteration #" << iterationCount << ":\n\t Lower bound: " << *lowerBound << ",\n\t Upper bound: " << *upperBound + << ",\n\t Difference: " << boundDiff << ",\n\t Middle val: " << middleValue + << ",\n\t Difference bound: " << (relative ? (precision * *lowerBound) : precision) << "."); + if (goal.isBounded()) { + STORM_LOG_TRACE("Using threshold " << storm::utility::convertNumber(goal.thresholdValue()) << " with comparison " + << (goal.boundIsALowerBound() ? (goal.boundIsStrict() ? ">" : ">=") : (goal.boundIsStrict() ? "<" : "<=")) + << "."); + } if (boundDiff <= (relative ? (precision * *lowerBound) : precision)) { STORM_LOG_INFO("Bisection method converged after " << iterationCount << " iterations. Difference is " << std::setprecision(std::numeric_limits::digits10) << storm::utility::convertNumber(boundDiff) << "."); break; + } else if (usePolicyTracking && lowerScheduler && upperScheduler && (*lowerScheduler == *upperScheduler)) { + STORM_LOG_INFO("Bisection method converged after " << iterationCount << " iterations due to identical schedulers for lower and upper bound."); + auto result = wrh.evaluateScheduler(env, *lowerScheduler); + lowerBound &= result; + upperBound &= result; + terminatedThroughPolicyTracking = true; + break; + } + // Check if bounds are fully below or above threshold + if (goal.isBounded() && (*upperBound <= goal.thresholdValue() || (*lowerBound >= goal.thresholdValue()))) { + STORM_LOG_INFO("Bisection method determined result after " << iterationCount << " iterations. Found bounds are [" + << storm::utility::convertNumber(*lowerBound) << ", " + << storm::utility::convertNumber(*upperBound) << "], threshold is " + << storm::utility::convertNumber(goal.thresholdValue()) << "."); + break; } // check for early termination if (storm::utility::resources::isTerminate()) { @@ -616,34 +1095,122 @@ SolutionType computeViaBisection(Environment const& env, BisectionMethodBounds b break; } // process the middle value for the next iteration + // This sets the middle value to a rational number with the smallest enumerator/denominator that is still within the bounds + // With close bounds this can lead to the middle being set to exactly the lower or upper bound, thus allowing for an exact answer. if constexpr (storm::NumberTraits::IsExact) { - // find a rational number with a concise representation close to middle and within the bounds - auto const exactMiddle = middle; - - // Find number of digits - 1. Method using log10 does not work since that uses doubles internally. - auto numDigits = storm::utility::numDigits(*upperBound - *lowerBound) - 1; - - do { - ++numDigits; - middle = storm::utility::kwek_mehlhorn::sharpen(numDigits, exactMiddle); - } while (middle <= *lowerBound || middle >= *upperBound); + // Check if the rationalCandidate has been within the bounds for four iterations. + // If yes, we take that as our next "middle". + // Otherwise, we set a new rationalCandidate. + // This heuristic ensures that we eventually check every rational number without affecting the binary search too much + if (rationalCandidateCount >= 4 && rationalCandiate >= *lowerBound && rationalCandiate <= *upperBound && + !checkedMiddleValues.contains(rationalCandiate)) { + middle = rationalCandiate; + rationalCandidateCount = 0; + } else { + // find a rational number with a concise representation within our current bounds + bool const includeLower = !checkedMiddleValues.contains(*lowerBound); + bool const includeUpper = !checkedMiddleValues.contains(*upperBound); + auto newRationalCandiate = storm::utility::findRational(*lowerBound, includeLower, *upperBound, includeUpper); + if (rationalCandiate == newRationalCandiate) { + ++rationalCandidateCount; + } else { + rationalCandiate = newRationalCandiate; + rationalCandidateCount = 0; + } + // Also simplify the middle value + SolutionType delta = + std::min(*upperBound - middle, middle - *lowerBound) / storm::utility::convertNumber(16); + middle = storm::utility::findRational(middle - delta, true, middle + delta, true); + } } - // Since above code never sets 'middle' to exactly zero or one, we check if that could be necessary after a couple of iterations + // Since above code might never set 'middle' to exactly zero or one, we check if that could be necessary after a couple of iterations if (iterationCount == 8) { // 8 is just a heuristic value, it could be any number - if (storm::utility::isZero(*lowerBound)) { + if (storm::utility::isZero(*lowerBound) && !checkedMiddleValues.contains(storm::utility::zero())) { middle = storm::utility::zero(); - } else if (storm::utility::isOne(*upperBound)) { + } else if (storm::utility::isOne(*upperBound) && !checkedMiddleValues.contains(storm::utility::one())) { middle = storm::utility::one(); } } } - return (*lowerBound + *upperBound) / 2; + + // Create result without scheduler + auto finalResult = ResultReturnType((*lowerBound + *upperBound) / 2); + + if (!computeScheduler) { + return finalResult; // nothing else to do + } + // If requested, construct the scheduler for the original model + std::vector reducedSchedulerChoices; + if (terminatedThroughPolicyTracking) { + // We already have computed a scheduler + reducedSchedulerChoices = std::move(*lowerScheduler); + } else { + // Compute a scheduler on the middle result by performing one more iteration + wrh.computeWeightedDiff(env, goal.direction(), storm::utility::one(), -finalResult.initialStateValue, reducedSchedulerChoices); + } + finalResult.scheduler = wrh.constructSchedulerForInputModel(reducedSchedulerChoices, transitionMatrix, backwardTransitions, normalForm); + return finalResult; } template -SolutionType computeViaPolicyIteration(Environment const& env, uint64_t const initialState, storm::solver::OptimizationDirection const dir, - storm::storage::SparseMatrix const& transitionMatrix, NormalFormData const& normalForm) { - WeightedReachabilityHelper wrh(initialState, transitionMatrix, normalForm); +typename internal::ResultReturnType computeViaBisection(Environment const& env, ConditionalAlgorithmSetting const alg, uint64_t const initialState, + storm::solver::SolveGoal const& goal, bool computeScheduler, + storm::storage::SparseMatrix const& transitionMatrix, + storm::storage::SparseMatrix const& backwardTransitions, + NormalFormData const& normalForm) { + using enum ConditionalAlgorithmSetting; + STORM_LOG_ASSERT(alg == Bisection || alg == BisectionAdvanced || alg == BisectionPolicyTracking || alg == BisectionAdvancedPolicyTracking, + "Unhandled Bisection algorithm " << alg << "."); + bool const useAdvancedBounds = (alg == BisectionAdvanced || alg == BisectionAdvancedPolicyTracking); + bool const usePolicyTracking = (alg == BisectionPolicyTracking || alg == BisectionAdvancedPolicyTracking); + return computeViaBisection(env, useAdvancedBounds, usePolicyTracking, initialState, goal, computeScheduler, transitionMatrix, backwardTransitions, + normalForm); +} + +template +typename internal::ResultReturnType decideThreshold(Environment const& env, uint64_t const initialState, + storm::OptimizationDirection const& direction, SolutionType const& threshold, + bool computeScheduler, storm::storage::SparseMatrix const& transitionMatrix, + storm::storage::SparseMatrix const& backwardTransitions, + NormalFormData const& normalForm) { + // We currently handle sound model checking incorrectly: we would need the actual lower/upper bounds of the weightedReachabilityHelper + + WeightedReachabilityHelper wrh(initialState, transitionMatrix, normalForm, computeScheduler); + + std::optional> scheduler; + storm::OptionalRef> schedulerRef; + if (computeScheduler) { + scheduler.emplace(); + schedulerRef.reset(*scheduler); + } + + SolutionType val = wrh.computeWeightedDiff(env, direction, storm::utility::one(), -threshold, schedulerRef); + SolutionType outputProbability; + if (val > storm::utility::zero()) { + // if val is positive, the conditional probability is (strictly) greater than threshold + outputProbability = storm::utility::one(); + } else if (val < storm::utility::zero()) { + // if val is negative, the conditional probability is (strictly) smaller than threshold + outputProbability = storm::utility::zero(); + } else { + // if val is zero, the conditional probability equals the threshold + outputProbability = threshold; + } + auto finalResult = ResultReturnType(outputProbability); + + if (computeScheduler) { + // If requested, construct the scheduler for the original model + finalResult.scheduler = wrh.constructSchedulerForInputModel(scheduler.value(), transitionMatrix, backwardTransitions, normalForm); + } + return finalResult; +} + +template +internal::ResultReturnType computeViaPolicyIteration(Environment const& env, uint64_t const initialState, + storm::solver::OptimizationDirection const dir, + storm::storage::SparseMatrix const& transitionMatrix, + NormalFormData const& normalForm) { + WeightedReachabilityHelper wrh(initialState, transitionMatrix, normalForm, false); // scheduler computation not yet implemented. std::vector scheduler; std::vector targetResults, conditionResults; @@ -701,9 +1268,15 @@ std::optional handleTrivialCases(uint64_t const initialState, Norm template std::unique_ptr computeConditionalProbabilities(Environment const& env, storm::solver::SolveGoal&& goal, - storm::storage::SparseMatrix const& transitionMatrix, + bool produceSchedulers, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& targetStates, storm::storage::BitVector const& conditionStates) { + auto precision = env.modelchecker().conditional().getPrecision(); + if (storm::NumberTraits::IsExact && env.modelchecker().conditional().isPrecisionSetFromDefault()) { + STORM_LOG_INFO("Setting the conditional precision to 0 since the value type is exact and the precision was not explicitly set by the user."); + precision = storm::utility::zero(); + } + // We might require adapting the precision of the solver to counter error propagation (e.g. when computing the normal form). auto normalFormConstructionEnv = env; auto analysisEnv = env; @@ -711,9 +1284,11 @@ std::unique_ptr computeConditionalProbabilities(Environment const& // We intuitively have to divide the precision into two parts, one for computations when constructing the normal form and one for the actual analysis. // As the former is usually less numerically challenging, we use a factor of 1/10 for the normal form construction and 9/10 for the analysis. auto const normalFormPrecisionFactor = storm::utility::convertNumber("1/10"); - normalFormConstructionEnv.solver().minMax().setPrecision(env.solver().minMax().getPrecision() * normalFormPrecisionFactor); - analysisEnv.solver().minMax().setPrecision(env.solver().minMax().getPrecision() * - (storm::utility::one() - normalFormPrecisionFactor)); + normalFormConstructionEnv.modelchecker().conditional().setPrecision(precision * normalFormPrecisionFactor, false); + analysisEnv.modelchecker().conditional().setPrecision(precision * (storm::utility::one() - normalFormPrecisionFactor), false); + } else { + normalFormConstructionEnv.modelchecker().conditional().setPrecision(precision, false); + analysisEnv.modelchecker().conditional().setPrecision(precision, false); } // We first translate the problem into a normal form. @@ -724,56 +1299,180 @@ std::unique_ptr computeConditionalProbabilities(Environment const& "Only one initial state is supported for conditional probabilities"); STORM_LOG_TRACE("Computing conditional probabilities for a model with " << transitionMatrix.getRowGroupCount() << " states and " << transitionMatrix.getEntryCount() << " transitions."); - // storm::utility::Stopwatch sw(true); - auto normalFormData = internal::obtainNormalForm(normalFormConstructionEnv, goal.direction(), transitionMatrix, backwardTransitions, goal.relevantValues(), - targetStates, conditionStates); - // sw.stop(); - // STORM_PRINT_AND_LOG("Time for obtaining the normal form: " << sw << ".\n"); + auto normalFormData = internal::obtainNormalForm(normalFormConstructionEnv, goal.direction(), produceSchedulers, transitionMatrix, backwardTransitions, + goal.relevantValues(), targetStates, conditionStates); // Then, we solve the induced problem using the selected algorithm auto const initialState = *goal.relevantValues().begin(); ValueType initialStateValue = -storm::utility::one(); + std::unique_ptr> scheduler = nullptr; if (auto trivialValue = internal::handleTrivialCases(initialState, normalFormData); trivialValue.has_value()) { initialStateValue = *trivialValue; + if (initialStateValue == storm::utility::zero() && !normalFormData.terminalStates.get(initialState) && produceSchedulers) { + // we need to compute a scheduler that at least reaches the condition with non-zero probability + auto initialStateBitVector = storm::storage::BitVector(transitionMatrix.getRowGroupCount(), false); + initialStateBitVector.set(initialState, true); + auto const conditionReachResult = helper::SparseMdpPrctlHelper::computeUntilProbabilities( + env, storm::solver::SolveGoal(storm::solver::OptimizationDirection::Maximize, initialStateBitVector), transitionMatrix, + transitionMatrix.transpose(true), storm::storage::BitVector(conditionStates.size(), true), conditionStates, false, true); + scheduler = + std::unique_ptr>(new storm::storage::Scheduler(transitionMatrix.getRowGroupCount())); + auto stateId = 0; + for (uint64_t state = 0; state < transitionMatrix.getRowGroupCount(); ++state) { + scheduler->setChoice(conditionReachResult.scheduler->getChoice(stateId), state); + ++stateId; + } + } else { + scheduler = + std::unique_ptr>(new storm::storage::Scheduler(transitionMatrix.getRowGroupCount())); + } STORM_LOG_DEBUG("Initial state has trivial value " << initialStateValue); } else { STORM_LOG_ASSERT(normalFormData.maybeStates.get(initialState), "Initial state must be a maybe state if it is not a terminal state"); - auto alg = analysisEnv.modelchecker().getConditionalAlgorithmSetting(); + auto alg = analysisEnv.modelchecker().conditional().getAlgorithm(); if (alg == ConditionalAlgorithmSetting::Default) { - alg = ConditionalAlgorithmSetting::Restart; + alg = ConditionalAlgorithmSetting::BisectionPolicyTracking; } + STORM_LOG_INFO("Analyzing normal form with " << normalFormData.maybeStates.getNumberOfSetBits() << " maybe states using algorithm '" << alg << "."); - // sw.restart(); + internal::ResultReturnType result{storm::utility::zero()}; switch (alg) { - case ConditionalAlgorithmSetting::Restart: - initialStateValue = internal::computeViaRestartMethod(analysisEnv, initialState, goal.direction(), transitionMatrix, normalFormData); + case ConditionalAlgorithmSetting::Restart: { + auto restartEnv = analysisEnv; + restartEnv.solver().minMax().setPrecision(analysisEnv.modelchecker().conditional().getPrecision()); + restartEnv.solver().minMax().setRelativeTerminationCriterion(analysisEnv.modelchecker().conditional().isRelativePrecision()); + + result = + internal::computeViaRestartMethod(restartEnv, initialState, goal, produceSchedulers, transitionMatrix, backwardTransitions, normalFormData); break; + } case ConditionalAlgorithmSetting::Bisection: - initialStateValue = internal::computeViaBisection(analysisEnv, internal::BisectionMethodBounds::Simple, initialState, goal.direction(), - transitionMatrix, normalFormData); - break; case ConditionalAlgorithmSetting::BisectionAdvanced: - initialStateValue = internal::computeViaBisection(analysisEnv, internal::BisectionMethodBounds::Advanced, initialState, goal.direction(), - transitionMatrix, normalFormData); + case ConditionalAlgorithmSetting::BisectionPolicyTracking: + case ConditionalAlgorithmSetting::BisectionAdvancedPolicyTracking: { + if (goal.isBounded()) { + result = internal::decideThreshold(analysisEnv, initialState, goal.direction(), goal.thresholdValue(), produceSchedulers, transitionMatrix, + backwardTransitions, normalFormData); + } else { + result = internal::computeViaBisection(analysisEnv, alg, initialState, goal, produceSchedulers, transitionMatrix, backwardTransitions, + normalFormData); + } break; - case ConditionalAlgorithmSetting::PolicyIteration: - initialStateValue = internal::computeViaPolicyIteration(analysisEnv, initialState, goal.direction(), transitionMatrix, normalFormData); + } + case ConditionalAlgorithmSetting::PolicyIteration: { + result = internal::computeViaPolicyIteration(analysisEnv, initialState, goal.direction(), transitionMatrix, normalFormData); break; - default: + } + default: { STORM_LOG_THROW(false, storm::exceptions::NotImplementedException, "Unknown conditional probability algorithm: " << alg); + } } - // sw.stop(); - // STORM_PRINT_AND_LOG("Time for analyzing the normal form: " << sw << ".\n"); + initialStateValue = result.initialStateValue; + scheduler = std::move(result.scheduler); + } + std::unique_ptr result(new ExplicitQuantitativeCheckResult(initialState, initialStateValue)); + + // if produce schedulers was set, we have to construct a scheduler with memory + if (produceSchedulers && scheduler) { + storm::utility::graph::computeSchedulerProb1E(normalFormData.targetStates, transitionMatrix, backwardTransitions, normalFormData.targetStates, + targetStates, *scheduler); + storm::utility::graph::computeSchedulerProb1E(normalFormData.conditionStates, transitionMatrix, backwardTransitions, normalFormData.conditionStates, + conditionStates, *scheduler); + // fill in the scheduler with default choices for states that are missing a choice, these states should be just the ones from which the condition is + // unreachable this is also used to fill choices for the trivial cases + for (uint64_t state = 0; state < transitionMatrix.getRowGroupCount(); ++state) { + if (!scheduler->isChoiceSelected(state)) { + // select an arbitrary choice + scheduler->setChoice(0, state); + } + } + + // create scheduler with memory structure + storm::storage::MemoryStructure::TransitionMatrix memoryTransitions(3, std::vector>(3, boost::none)); + storm::models::sparse::StateLabeling memoryStateLabeling(3); + memoryStateLabeling.addLabel("init_memory"); + memoryStateLabeling.addLabel("condition_reached"); + memoryStateLabeling.addLabel("target_reached"); + memoryStateLabeling.addLabelToState("init_memory", 0); + memoryStateLabeling.addLabelToState("condition_reached", 1); + memoryStateLabeling.addLabelToState("target_reached", 2); + + storm::storage::BitVector allTransitions(transitionMatrix.getEntryCount(), true); + storm::storage::BitVector conditionExitTransitions(transitionMatrix.getEntryCount(), false); + storm::storage::BitVector targetExitTransitions(transitionMatrix.getEntryCount(), false); + + for (auto state : conditionStates) { + for (auto choice : transitionMatrix.getRowGroupIndices(state)) { + for (auto entryIt = transitionMatrix.getRow(choice).begin(); entryIt < transitionMatrix.getRow(choice).end(); ++entryIt) { + conditionExitTransitions.set(entryIt - transitionMatrix.begin(), true); + } + } + } + for (auto state : targetStates) { + for (auto choice : transitionMatrix.getRowGroupIndices(state)) { + for (auto entryIt = transitionMatrix.getRow(choice).begin(); entryIt < transitionMatrix.getRow(choice).end(); ++entryIt) { + targetExitTransitions.set(entryIt - transitionMatrix.begin(), true); + } + } + } + + memoryTransitions[0][0] = + allTransitions & ~conditionExitTransitions & ~targetExitTransitions; // if neither condition nor target reached, stay in init_memory + memoryTransitions[0][1] = conditionExitTransitions; + memoryTransitions[0][2] = targetExitTransitions & ~conditionExitTransitions; + memoryTransitions[1][1] = allTransitions; // once condition reached, stay in that memory state + memoryTransitions[2][2] = allTransitions; // once target reached, stay in that memory state + + // this assumes there is a single initial state + auto memoryStructure = storm::storage::MemoryStructure(memoryTransitions, memoryStateLabeling, std::vector(1, 0), true); + + auto finalScheduler = std::unique_ptr>( + new storm::storage::Scheduler(transitionMatrix.getRowGroupCount(), std::move(memoryStructure))); + + for (uint64_t state = 0; state < transitionMatrix.getRowGroupCount(); ++state) { + // set choices for memory 0 + if (conditionStates.get(state)) { + if (normalFormData.schedulerChoicesForReachingTargetStates->isChoiceSelected(state)) { + finalScheduler->setChoice(normalFormData.schedulerChoicesForReachingTargetStates->getChoice(state), state, 0); + } else { + finalScheduler->setChoice(0, state, 0); // arbitrary choice if no choice was recorded. + } + } else if (targetStates.get(state)) { + if (normalFormData.schedulerChoicesForReachingConditionStates->isChoiceSelected(state)) { + finalScheduler->setChoice(normalFormData.schedulerChoicesForReachingConditionStates->getChoice(state), state, 0); + } else { + finalScheduler->setChoice(0, state, 0); // arbitrary choice if no choice was recorded. + } + } else { + finalScheduler->setChoice(scheduler->getChoice(state), state, 0); + } + + // set choices for memory 1, these are the choices after condition was reached + if (normalFormData.schedulerChoicesForReachingTargetStates->isChoiceSelected(state)) { + finalScheduler->setChoice(normalFormData.schedulerChoicesForReachingTargetStates->getChoice(state), state, 1); + } else { + finalScheduler->setChoice(0, state, 1); // arbitrary choice if no choice was recorded. + } + // set choices for memory 2, these are the choices after target was reached + if (normalFormData.schedulerChoicesForReachingConditionStates->isChoiceSelected(state)) { + finalScheduler->setChoice(normalFormData.schedulerChoicesForReachingConditionStates->getChoice(state), state, 2); + } else { + finalScheduler->setChoice(0, state, 2); // arbitrary choice if no choice was recorded. + } + } + + result->asExplicitQuantitativeCheckResult().setScheduler(std::move(finalScheduler)); } - return std::unique_ptr(new ExplicitQuantitativeCheckResult(initialState, initialStateValue)); + return result; } -template std::unique_ptr computeConditionalProbabilities(Environment const& env, storm::solver::SolveGoal&& goal, +template std::unique_ptr computeConditionalProbabilities(Environment const& env, storm::solver::SolveGoal&& goal, bool produceSchedulers, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& targetStates, storm::storage::BitVector const& conditionStates); template std::unique_ptr computeConditionalProbabilities(Environment const& env, storm::solver::SolveGoal&& goal, + bool produceSchedulers, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& targetStates, diff --git a/src/storm/modelchecker/helper/conditional/ConditionalHelper.h b/src/storm/modelchecker/helper/conditional/ConditionalHelper.h index 12e8c0791..2f87e395b 100644 --- a/src/storm/modelchecker/helper/conditional/ConditionalHelper.h +++ b/src/storm/modelchecker/helper/conditional/ConditionalHelper.h @@ -1,6 +1,8 @@ #pragma once #include +#include "storm/logic/ConditionalFormula.h" +#include "storm/modelchecker/CheckTask.h" #include "storm/solver/SolveGoal.h" namespace storm { @@ -22,7 +24,7 @@ class BackwardTransitionCache; template std::unique_ptr computeConditionalProbabilities(Environment const& env, storm::solver::SolveGoal&& goal, - storm::storage::SparseMatrix const& transitionMatrix, + bool produceSchedulers, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& targetStates, storm::storage::BitVector const& conditionStates); diff --git a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp index 562c7fbba..ea116d0da 100644 --- a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp +++ b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp @@ -284,8 +284,9 @@ std::unique_ptr SparseMdpPrctlModelChecker::com throw exceptions::NotImplementedException() << "Conditional Probabilities are not supported with interval models"; } else { return storm::modelchecker::computeConditionalProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), - this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), - leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector()); + checkTask.isProduceSchedulersSet(), this->getModel().getTransitionMatrix(), + this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), + rightResult.getTruthValuesVector()); } } diff --git a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp index 95fa5298f..8e04642ee 100644 --- a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp +++ b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp @@ -3,7 +3,6 @@ #include "storm/adapters/IntervalAdapter.h" #include "storm/environment/solver/MinMaxSolverEnvironment.h" #include "storm/exceptions/IllegalArgumentException.h" -#include "storm/exceptions/IllegalFunctionCallException.h" #include "storm/exceptions/InvalidPropertyException.h" #include "storm/exceptions/InvalidSettingsException.h" #include "storm/exceptions/NotSupportedException.h" diff --git a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h index aaa2c7c9d..eeb696c6b 100644 --- a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h +++ b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.h @@ -101,6 +101,12 @@ class SparseMdpPrctlHelper { storm::models::sparse::StandardRewardModel const& intervalRewardModel, bool lowerBoundOfIntervals, storm::storage::BitVector const& targetStates, bool qualitative); + static std::unique_ptr computeConditionalProbabilities(Environment const& env, storm::solver::SolveGoal&& goal, + storm::storage::SparseMatrix const& transitionMatrix, + storm::storage::SparseMatrix const& backwardTransitions, + storm::storage::BitVector const& targetStates, + storm::storage::BitVector const& conditionStates); + private: static MDPSparseModelCheckingHelperReturnType computeReachabilityRewardsHelper( Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, diff --git a/src/storm/settings/SettingsManager.cpp b/src/storm/settings/SettingsManager.cpp index 9602f0a67..7d34356b6 100644 --- a/src/storm/settings/SettingsManager.cpp +++ b/src/storm/settings/SettingsManager.cpp @@ -16,6 +16,7 @@ #include "storm/settings/modules/AbstractionSettings.h" #include "storm/settings/modules/BisimulationSettings.h" #include "storm/settings/modules/BuildSettings.h" +#include "storm/settings/modules/ConditionalSettings.h" #include "storm/settings/modules/CoreSettings.h" #include "storm/settings/modules/CuddSettings.h" #include "storm/settings/modules/DebugSettings.h" @@ -708,6 +709,7 @@ void initializeAll(std::string const& name, std::string const& executableName) { storm::settings::addModule(); storm::settings::addModule(); storm::settings::addModule(); + storm::settings::addModule(); } } // namespace settings diff --git a/src/storm/settings/modules/ConditionalSettings.cpp b/src/storm/settings/modules/ConditionalSettings.cpp new file mode 100644 index 000000000..e13b4e1e0 --- /dev/null +++ b/src/storm/settings/modules/ConditionalSettings.cpp @@ -0,0 +1,61 @@ +#include "storm/settings/modules/ConditionalSettings.h" + +#include "storm/settings/ArgumentBuilder.h" +#include "storm/settings/Option.h" +#include "storm/settings/OptionBuilder.h" +#include "storm/settings/SettingsManager.h" + +namespace storm::settings::modules { + +const std::string ConditionalSettings::moduleName = "conditional"; +const std::string ConditionalSettings::conditionalAlgorithmOptionName = "algorithm"; +const std::string ConditionalSettings::conditionalPrecisionOptionName = "precision"; +const std::string ConditionalSettings::conditionalPrecisionAbsoluteOptionName = "absolute"; + +ConditionalSettings::ConditionalSettings() : ModuleSettings(moduleName) { + std::vector const conditionalAlgs = {"default", "restart", "bisection", "bisection-advanced", "bisection-pt", "bisection-advanced-pt", "pi"}; + this->addOption(storm::settings::OptionBuilder(moduleName, conditionalAlgorithmOptionName, false, "The used algorithm for conditional probabilities.") + .setIsAdvanced() + .addArgument(storm::settings::ArgumentBuilder::createStringArgument("name", "The name of the method to use.") + .addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(conditionalAlgs)) + .setDefaultValueString("default") + .build()) + .build()); + + this->addOption(storm::settings::OptionBuilder(moduleName, conditionalPrecisionOptionName, false, + "The internally used precision for computing conditional probabilities..") + .setIsAdvanced() + .addArgument(storm::settings::ArgumentBuilder::createDoubleArgument("value", "The precision to use.") + .setDefaultValueDouble(1e-06) + .addValidatorDouble(ArgumentValidatorFactory::createDoubleRangeValidatorIncluding(0.0, 1.0)) + .build()) + .build()); + + this->addOption(storm::settings::OptionBuilder(moduleName, conditionalPrecisionAbsoluteOptionName, false, + "Whether the precision for computing conditional probabilities is considered absolute.") + .setIsAdvanced() + .build()); +} + +bool ConditionalSettings::isConditionalAlgorithmSet() const { + return this->getOption(conditionalAlgorithmOptionName).getHasOptionBeenSet(); +} + +ConditionalAlgorithmSetting ConditionalSettings::getConditionalAlgorithmSetting() const { + return conditionalAlgorithmSettingFromString(this->getOption(conditionalAlgorithmOptionName).getArgumentByName("name").getValueAsString()); +} + +double ConditionalSettings::getConditionalPrecision() const { + return this->getOption(conditionalPrecisionOptionName).getArgumentByName("value").getValueAsDouble(); +} + +bool ConditionalSettings::isConditionalPrecisionSetFromDefaultValue() const { + return !this->getOption(conditionalPrecisionOptionName).getArgumentByName("value").getHasBeenSet() || + this->getOption(conditionalPrecisionOptionName).getArgumentByName("value").wasSetFromDefaultValue(); +} + +bool ConditionalSettings::isConditionalPrecisionAbsolute() const { + return this->getOption(conditionalPrecisionAbsoluteOptionName).getHasOptionBeenSet(); +} + +} // namespace storm::settings::modules \ No newline at end of file diff --git a/src/storm/settings/modules/ConditionalSettings.h b/src/storm/settings/modules/ConditionalSettings.h new file mode 100644 index 000000000..44936c159 --- /dev/null +++ b/src/storm/settings/modules/ConditionalSettings.h @@ -0,0 +1,54 @@ +#pragma once + +#include "storm-config.h" +#include "storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.h" +#include "storm/settings/modules/ModuleSettings.h" + +namespace storm { +namespace settings { +namespace modules { + +/*! + * This class represents the LRA solver settings. + */ +class ConditionalSettings : public ModuleSettings { + public: + ConditionalSettings(); + + /*! + * Retrieves whether an algorithm for conditional properties has been set. + */ + bool isConditionalAlgorithmSet() const; + + /*! + * Retrieves the specified algorithm for conditional probabilities. + */ + ConditionalAlgorithmSetting getConditionalAlgorithmSetting() const; + + /*! + * Retrieves the specified precision for computing conditional probabilities. + */ + double getConditionalPrecision() const; + + /*! + * Retrieves whether the precision for computing conditional probabilities was set from a default value. + */ + bool isConditionalPrecisionSetFromDefaultValue() const; + + /*! + * Retrieves whether the precision for computing conditional probabilities is considered absolute. + */ + bool isConditionalPrecisionAbsolute() const; + + // The name of the module. + static const std::string moduleName; + + private: + static const std::string conditionalAlgorithmOptionName; + static const std::string conditionalPrecisionOptionName; + static const std::string conditionalPrecisionAbsoluteOptionName; +}; + +} // namespace modules +} // namespace settings +} // namespace storm diff --git a/src/storm/settings/modules/ModelCheckerSettings.cpp b/src/storm/settings/modules/ModelCheckerSettings.cpp index 7806c9abc..a5e67bf33 100644 --- a/src/storm/settings/modules/ModelCheckerSettings.cpp +++ b/src/storm/settings/modules/ModelCheckerSettings.cpp @@ -5,13 +5,11 @@ #include "storm/settings/Option.h" #include "storm/settings/OptionBuilder.h" #include "storm/settings/SettingsManager.h" - namespace storm::settings::modules { const std::string ModelCheckerSettings::moduleName = "modelchecker"; const std::string ModelCheckerSettings::filterRewZeroOptionName = "filterrewzero"; const std::string ModelCheckerSettings::ltl2daToolOptionName = "ltl2datool"; -const std::string ModelCheckerSettings::conditionalAlgorithmOptionName = "conditional"; ModelCheckerSettings::ModelCheckerSettings() : ModuleSettings(moduleName) { this->addOption(storm::settings::OptionBuilder(moduleName, filterRewZeroOptionName, false, @@ -25,15 +23,6 @@ ModelCheckerSettings::ModelCheckerSettings() : ModuleSettings(moduleName) { "filename", "A script that can be called with a prefix formula and a name for the output automaton.") .build()) .build()); - - std::vector const conditionalAlgs = {"default", "restart", "bisection", "bisection-advanced", "pi"}; - this->addOption(storm::settings::OptionBuilder(moduleName, conditionalAlgorithmOptionName, false, "The used algorithm for conditional probabilities.") - .setIsAdvanced() - .addArgument(storm::settings::ArgumentBuilder::createStringArgument("name", "The name of the method to use.") - .addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(conditionalAlgs)) - .setDefaultValueString("default") - .build()) - .build()); } bool ModelCheckerSettings::isFilterRewZeroSet() const { @@ -48,12 +37,4 @@ std::string ModelCheckerSettings::getLtl2daTool() const { return this->getOption(ltl2daToolOptionName).getArgumentByName("filename").getValueAsString(); } -bool ModelCheckerSettings::isConditionalAlgorithmSet() const { - return this->getOption(conditionalAlgorithmOptionName).getHasOptionBeenSet(); -} - -ConditionalAlgorithmSetting ModelCheckerSettings::getConditionalAlgorithmSetting() const { - return conditionalAlgorithmSettingFromString(this->getOption(conditionalAlgorithmOptionName).getArgumentByName("name").getValueAsString()); -} - } // namespace storm::settings::modules diff --git a/src/storm/settings/modules/ModelCheckerSettings.h b/src/storm/settings/modules/ModelCheckerSettings.h index ce7e5bd23..4d4df33b9 100644 --- a/src/storm/settings/modules/ModelCheckerSettings.h +++ b/src/storm/settings/modules/ModelCheckerSettings.h @@ -1,7 +1,6 @@ #pragma once #include "storm-config.h" -#include "storm/modelchecker/helper/conditional/ConditionalAlgorithmSetting.h" #include "storm/settings/modules/ModuleSettings.h" namespace storm { @@ -34,16 +33,6 @@ class ModelCheckerSettings : public ModuleSettings { */ std::string getLtl2daTool() const; - /*! - * Retrieves whether an algorithm for conditional properties has been set. - */ - bool isConditionalAlgorithmSet() const; - - /*! - * Retrieves the specified algorithm for conditional probabilities. - */ - ConditionalAlgorithmSetting getConditionalAlgorithmSetting() const; - // The name of the module. static const std::string moduleName; @@ -51,7 +40,6 @@ class ModelCheckerSettings : public ModuleSettings { // Define the string names of the options as constants. static const std::string filterRewZeroOptionName; static const std::string ltl2daToolOptionName; - static const std::string conditionalAlgorithmOptionName; }; } // namespace modules diff --git a/src/storm/solver/SolveGoal.cpp b/src/storm/solver/SolveGoal.cpp index 90d00ceff..7cb7686ed 100644 --- a/src/storm/solver/SolveGoal.cpp +++ b/src/storm/solver/SolveGoal.cpp @@ -111,6 +111,11 @@ UncertaintyResolutionMode SolveGoal::getUncertaintyReso return uncertaintyResolutionMode; } +template +storm::logic::ComparisonType SolveGoal::boundComparisonType() const { + return comparisonType.get(); +} + template SolutionType const& SolveGoal::thresholdValue() const { return threshold.get(); diff --git a/src/storm/solver/SolveGoal.h b/src/storm/solver/SolveGoal.h index 79072a38e..f085b9351 100644 --- a/src/storm/solver/SolveGoal.h +++ b/src/storm/solver/SolveGoal.h @@ -85,6 +85,8 @@ class SolveGoal { bool boundIsStrict() const; + storm::logic::ComparisonType boundComparisonType() const; + SolutionType const& thresholdValue() const; bool hasRelevantValues() const; diff --git a/src/storm/storage/BitVector.cpp b/src/storm/storage/BitVector.cpp index 9b00b5f32..a75efdc4c 100644 --- a/src/storm/storage/BitVector.cpp +++ b/src/storm/storage/BitVector.cpp @@ -721,6 +721,10 @@ std::vector BitVector::getNumberOfSetBitsBeforeIndices() const { } ++currentNumberOfSetBits; } + while (lastIndex < this->size()) { + bitsSetBeforeIndices.push_back(currentNumberOfSetBits); + ++lastIndex; + } return bitsSetBeforeIndices; } diff --git a/src/storm/storage/MaximalEndComponentDecomposition.cpp b/src/storm/storage/MaximalEndComponentDecomposition.cpp index 6e7763f04..f1d42eb90 100644 --- a/src/storm/storage/MaximalEndComponentDecomposition.cpp +++ b/src/storm/storage/MaximalEndComponentDecomposition.cpp @@ -150,6 +150,8 @@ void MaximalEndComponentDecomposition::performMaximalEndComponentDeco storm::storage::SparseMatrix const& backwardTransitions, storm::OptionalRef states, storm::OptionalRef choices) { + STORM_LOG_ASSERT(!states.has_value() || transitionMatrix.getRowGroupCount() == states->size(), "Unexpected size of states bitvector."); + STORM_LOG_ASSERT(!choices.has_value() || transitionMatrix.getRowCount() == choices->size(), "Unexpected size of choices bitvector."); // Get some data for convenient access. auto const& nondeterministicChoiceIndices = transitionMatrix.getRowGroupIndices(); diff --git a/src/storm/storage/Scheduler.cpp b/src/storm/storage/Scheduler.cpp index 1dacf8d86..7b772ecb9 100644 --- a/src/storm/storage/Scheduler.cpp +++ b/src/storm/storage/Scheduler.cpp @@ -76,6 +76,13 @@ bool Scheduler::isChoiceSelected(BitVector const& selectedStates, uin return true; } +template +bool Scheduler::isChoiceSelected(uint64_t modelState, uint64_t memoryState) const { + STORM_LOG_ASSERT(memoryState < getNumberOfMemoryStates(), "Illegal memory state index"); + STORM_LOG_ASSERT(modelState < schedulerChoices[memoryState].size(), "Illegal model state index"); + return schedulerChoices[memoryState][modelState].isDefined(); +} + template void Scheduler::clearChoice(uint_fast64_t modelState, uint_fast64_t memoryState) { STORM_LOG_ASSERT(memoryState < getNumberOfMemoryStates(), "Illegal memory state index"); @@ -172,6 +179,19 @@ boost::optional const& Scheduler::ge return memoryStructure; } +template +Scheduler Scheduler::getMemorylessSchedulerForMemoryState(uint64_t memoryState) const { + STORM_LOG_ASSERT(memoryState < getNumberOfMemoryStates(), "Illegal memory state index"); + + Scheduler memorylessScheduler(getNumberOfModelStates()); + for (uint64_t modelState = 0; modelState < getNumberOfModelStates(); ++modelState) { + if (schedulerChoices[memoryState][modelState].isDefined()) { + memorylessScheduler.setChoice(schedulerChoices[memoryState][modelState], modelState); + } + } + return memorylessScheduler; +} + template void Scheduler::printToStream(std::ostream& out, std::shared_ptr> model, bool skipUniqueChoices, bool skipDontCareStates) const { diff --git a/src/storm/storage/Scheduler.h b/src/storm/storage/Scheduler.h index 4bf67f486..8eb200b12 100644 --- a/src/storm/storage/Scheduler.h +++ b/src/storm/storage/Scheduler.h @@ -40,6 +40,11 @@ class Scheduler { */ bool isChoiceSelected(BitVector const& selectedStates, uint64_t memoryState = 0) const; + /*! + * Is the scheduler defined on the given state + */ + bool isChoiceSelected(uint64_t modelState, uint64_t memoryState = 0) const; + /*! * Clears the choice defined by the scheduler for the given state. * @@ -114,6 +119,13 @@ class Scheduler { */ boost::optional const& getMemoryStructure() const; + /*! + * Retrieves a memoryless scheduler that corresponds to the given memory state. + * + * @param memoryState the memory state to fix + */ + Scheduler getMemorylessSchedulerForMemoryState(uint64_t memoryState = 0) const; + /*! * Returns a copy of this scheduler with the new value type */ diff --git a/src/storm/utility/RationalApproximation.cpp b/src/storm/utility/RationalApproximation.cpp new file mode 100644 index 000000000..84700e5f7 --- /dev/null +++ b/src/storm/utility/RationalApproximation.cpp @@ -0,0 +1,156 @@ + +#include "storm/utility/RationalApproximation.h" + +#include "storm/adapters/RationalNumberAdapter.h" +#include "storm/utility/constants.h" +#include "storm/utility/macros.h" + +namespace storm::utility { + +storm::RationalNumber findRational(storm::RationalNumber const& lowerBound, bool lowerInclusive, storm::RationalNumber const& upperBound, bool upperInclusive) { + using Integer = typename storm::NumberTraits::IntegerType; + STORM_LOG_ASSERT(lowerBound < upperBound || (lowerBound == upperBound && lowerInclusive && upperInclusive), "Invalid interval for rational approximation."); + + // Handle negative numbers + if (auto const zero = storm::utility::zero(); lowerBound < zero) { + // check if zero is in the interval + if (upperBound > zero || (upperBound == zero && upperInclusive)) { + return storm::utility::zero(); + } else { + // all numbers in the interval are negative. We translate that to a positive problem and negate the result + return -findRational(-upperBound, upperInclusive, -lowerBound, lowerInclusive); + } + } + // At this point, the solution is known to be non-negative + + // We compute a path in the Stern-Brocot tree from the root to the node representing the simplest rational in the closed interval [lowerBound, upperBound] + // If the input interval is open on one or both sides, we traverse the tree further down until a suitable rational number is found + // @see https://en.wikipedia.org/wiki/Stern–Brocot_tree#A_tree_of_continued_fractions + // @see https://mathoverflow.net/a/424509 + // The path is encoded using a simple continued fraction representation. + // We take path[0] times the right child, path[1] times the left child, path[2] times the right child, etc, using path.back()-1 steps in the last direction. + std::vector path; // in simple continued fraction representation + auto l = lowerBound; + auto u = upperBound; + while (true) { + auto l_den = storm::utility::denominator(l); + auto u_den = storm::utility::denominator(u); + auto const [l_i, l_rem] = storm::utility::divide(storm::utility::numerator(l), l_den); + auto const [u_i, u_rem] = storm::utility::divide(storm::utility::numerator(u), u_den); + + path.push_back(std::min(l_i, u_i)); // insert tree traversal information + if (l_i == u_i && !storm::utility::isZero(l_rem) && !storm::utility::isZero(u_rem)) { + // continue traversing the tree + l = storm::utility::convertNumber(l_den) / l_rem; + u = storm::utility::convertNumber(u_den) / u_rem; + continue; + } + // Reaching this point means that we have found a node in the Stern-Brocot tree where the paths for lower and upper bound diverge. + // If there still is a remainder, we need to add one to the last entry of the path so that it correctly encodes the node we are referring to. + if (l_i != u_i && !storm::utility::isZero(l_i < u_i ? l_rem : u_rem)) { + path.back() += Integer(1); + } + + // Find out if we hit an interval boundary and whether we need to adapt this due to open intervals + bool const needAdjustLower = !lowerInclusive && path.back() == l_i && storm::utility::isZero(l_rem); + bool const needAdjustUpper = !upperInclusive && path.back() == u_i && storm::utility::isZero(u_rem); + if (needAdjustLower || needAdjustUpper) { + // handle for values of the "other" bound that does not need adjustment + auto const& o_i = needAdjustLower ? u_i : l_i; + auto const& o_rem = needAdjustLower ? u_rem : l_rem; + auto const& o_den = needAdjustLower ? u_den : l_den; + auto const& o_inclusive = needAdjustLower ? upperInclusive : lowerInclusive; + + // When adjusting lower bounds, we need to explore the right subtree to obtain a larger value than the current lower bound + // When adjusting upper bounds, we need to explore the left subtree to obtain a smaller value than the current upper bound + // Whether we currently look at left or right subtrees is determined by the parity of the index in the path: + // Path entries at even indices correspond to right moves (increasing value) and entries at odd indices correspond to left moves (decreasing value) + bool const currentDirectionIsIncreasing = (path.size() - 1) % 2 == 0; + bool const adjustInCurrentDirection = (needAdjustLower && currentDirectionIsIncreasing) || (needAdjustUpper && !currentDirectionIsIncreasing); + // Below, we navigate through the Stern-Brocot tree by adapting the path + // path.back() += 1; extends the path to a child in the "current direction" + // path.back() -= 1; path.emplace_back(2); extends the path to a child in the "counter direction" + if (adjustInCurrentDirection) { + STORM_LOG_ASSERT(path.back() <= o_i, "Unexpected case when navigating the Stern-Brocot tree."); + if (path.back() + Integer(1) < o_i || (path.back() + Integer(1) == o_i && !storm::utility::isZero(o_rem))) { + // In this case, the next child (in the current direction) is inside the interval, so we can just take that + path.back() += Integer(1); + } else if (path.back() + Integer(1) == o_i && storm::utility::isZero(o_rem)) { + // In this case, the next child coincides with the other boundary + if (o_inclusive) { + path.back() += Integer(1); // add next child + } else { + // We first take one child in the current direction and then one child in the counter direction. + // path.back() += 1; path.back() -= 1; // cancels out + path.emplace_back(2); + } + } else { + // The following assertion holds because path.back() > o_i is not possible due to the way we constructed the path above + // and if there would be no remainder, the other boundary would be hit as well (i.e. we would have an empty interval (a,a). + STORM_LOG_ASSERT(path.back() == o_i && !storm::utility::isZero(o_rem), "Unexpected case when navigating the Stern-Brocot tree."); + // In this case, we need to append one child in the current direction and multiple children in the counter direction based on the continued + // fraction representation of the other boundary + auto const [o_i2, o_rem2] = storm::utility::divide(o_den, o_rem); + // path.back() += 1; path.back() -= 1; // cancels out + path.push_back(o_i2); + if (!storm::utility::isZero(o_rem2)) { + // If there still is a remainder, we add one to the last entry of the path so that it correctly encodes the node we are referring to. + path.back() += Integer(1); + } else if (!o_inclusive) { + // If there is no remainder, we are exactly on the other boundary. If that boundary is also excluded, we need to add one more step. + path.back() += Integer(1); + } + } + } else { + // Adjusting a bound in the counter direction can only happen if the other bound still has a remainder + // Otherwise, we would have also hit that bound + STORM_LOG_ASSERT(o_i == path.back() - Integer(1), "Unexpected case when navigating the Stern-Brocot tree."); + STORM_LOG_ASSERT(!storm::utility::isZero(o_rem), "Unexpected case when navigating the Stern-Brocot tree."); + auto const [o_i2, o_rem2] = storm::utility::divide(o_den, o_rem); + path.back() -= Integer(1); // necessary in all cases + if (o_i2 > Integer(2) || (o_i2 == Integer(2) && !storm::utility::isZero(o_rem2))) { + // In this case, the next child (in the counter direction) is inside the interval, so we can just take that + path.emplace_back(2); + } else if (o_i2 == Integer(2) && storm::utility::isZero(o_rem2)) { + // In this case, the next child in counter direction coincides with the other boundary + if (o_inclusive) { + path.emplace_back(2); + } else { + // We first take one child in the counter direction and then one child in the current direction. + path.emplace_back(1); + path.emplace_back(2); + } + } else { + STORM_LOG_ASSERT(o_i2 == Integer(1) && !storm::utility::isZero(o_rem2), "Unexpected case when navigating the Stern-Brocot tree."); + // In this case, we need to append one child in the counter direction and multiple children in the current direction based on the continued + // fraction representation of the other boundary + auto const [o_i3, o_rem3] = storm::utility::divide(o_rem, o_rem2); + path.emplace_back(1); + path.push_back(o_i3); + if (!storm::utility::isZero(o_rem3)) { + // If there still is a remainder, we add one to the last entry of the path so that it correctly encodes the node we are referring to. + path.back() += Integer(1); + } else if (!o_inclusive) { + // If there is no remainder, we are exactly on the other boundary. If that boundary is also excluded, we need to add one more step. + path.back() += Integer(1); + } + } + } + } + break; + } + + // Now, construct the rational number from the path + auto it = path.rbegin(); + auto result = storm::utility::convertNumber(*it); + for (++it; it != path.rend(); ++it) { + result = storm::utility::convertNumber(*it) + storm::utility::one() / result; + } + return result; + + STORM_LOG_ASSERT(result > lowerBound || (lowerInclusive && result == lowerBound), "Result is below lower bound."); + STORM_LOG_ASSERT(result < upperBound || (upperInclusive && result == upperBound), "Result is above upper bound."); + return result; +} + +} // namespace storm::utility \ No newline at end of file diff --git a/src/storm/utility/RationalApproximation.h b/src/storm/utility/RationalApproximation.h new file mode 100644 index 000000000..17bae0392 --- /dev/null +++ b/src/storm/utility/RationalApproximation.h @@ -0,0 +1,18 @@ +#pragma once + +#include "storm/adapters/RationalNumberForward.h" + +namespace storm::utility { + +/*! + * Finds the "simplest" rational number in the given interval, where "simplest" means having the smallest denominator + * @pre lowerBound < upperBound or (lowerBound == upperBound and lowerInclusive and upperInclusive) + * @param lowerBound The lower bound of the interval + * @param lowerInclusive Whether the lower bound itself is included + * @param upperBound the upper bound of the interval + * @param upperInclusive Whether the upper bound itself is included + * @return the rational number in the given interval with the smallest denominator + */ +storm::RationalNumber findRational(storm::RationalNumber const& lowerBound, bool lowerInclusive, storm::RationalNumber const& upperBound, bool upperInclusive); + +} // namespace storm::utility \ No newline at end of file diff --git a/src/test/storm/modelchecker/prctl/mdp/ConditionalMdpPrctlModelCheckerTest.cpp b/src/test/storm/modelchecker/prctl/mdp/ConditionalMdpPrctlModelCheckerTest.cpp index 1f9f41cbe..39652c298 100644 --- a/src/test/storm/modelchecker/prctl/mdp/ConditionalMdpPrctlModelCheckerTest.cpp +++ b/src/test/storm/modelchecker/prctl/mdp/ConditionalMdpPrctlModelCheckerTest.cpp @@ -5,9 +5,11 @@ #include "storm-parsers/parser/PrismParser.h" #include "storm/api/builder.h" #include "storm/api/properties.h" +#include "storm/environment/modelchecker/ConditionalModelCheckerEnvironment.h" #include "storm/environment/modelchecker/ModelCheckerEnvironment.h" #include "storm/environment/solver/MinMaxSolverEnvironment.h" #include "storm/modelchecker/prctl/SparseMdpPrctlModelChecker.h" +#include "storm/modelchecker/results/ExplicitQualitativeCheckResult.h" #include "storm/modelchecker/results/ExplicitQuantitativeCheckResult.h" namespace { @@ -19,8 +21,9 @@ class SparseDoubleRestartEnvironment { typedef storm::models::sparse::Mdp ModelType; static storm::Environment createEnvironment() { storm::Environment env; - env.modelchecker().setConditionalAlgorithmSetting(storm::ConditionalAlgorithmSetting::Restart); - env.solver().minMax().setPrecision(storm::utility::convertNumber(1e-10)); // restart algorithm requires a higher precision + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::Restart); + env.modelchecker().conditional().setPrecision(storm::utility::convertNumber(1e-10), + false); // restart algorithm requires a higher precision return env; } }; @@ -32,7 +35,7 @@ class SparseDoubleBisectionEnvironment { typedef storm::models::sparse::Mdp ModelType; static storm::Environment createEnvironment() { storm::Environment env; - env.modelchecker().setConditionalAlgorithmSetting(storm::ConditionalAlgorithmSetting::Bisection); + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::Bisection); return env; } }; @@ -44,7 +47,31 @@ class SparseDoubleBisectionAdvancedEnvironment { typedef storm::models::sparse::Mdp ModelType; static storm::Environment createEnvironment() { storm::Environment env; - env.modelchecker().setConditionalAlgorithmSetting(storm::ConditionalAlgorithmSetting::BisectionAdvanced); + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::BisectionAdvanced); + return env; + } +}; + +class SparseDoubleBisectionPtEnvironment { + public: + static const bool isExact = false; + typedef double ValueType; + typedef storm::models::sparse::Mdp ModelType; + static storm::Environment createEnvironment() { + storm::Environment env; + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::BisectionPolicyTracking); + return env; + } +}; + +class SparseDoubleBisectionAdvancedPtEnvironment { + public: + static const bool isExact = false; + typedef double ValueType; + typedef storm::models::sparse::Mdp ModelType; + static storm::Environment createEnvironment() { + storm::Environment env; + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::BisectionAdvancedPolicyTracking); return env; } }; @@ -56,7 +83,7 @@ class SparseDoublePiEnvironment { typedef storm::models::sparse::Mdp ModelType; static storm::Environment createEnvironment() { storm::Environment env; - env.modelchecker().setConditionalAlgorithmSetting(storm::ConditionalAlgorithmSetting::PolicyIteration); + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::PolicyIteration); return env; } }; @@ -68,8 +95,9 @@ class SparseRationalNumberRestartEnvironment { typedef storm::models::sparse::Mdp ModelType; static storm::Environment createEnvironment() { storm::Environment env; - env.modelchecker().setConditionalAlgorithmSetting(storm::ConditionalAlgorithmSetting::Restart); - env.solver().minMax().setPrecision(storm::utility::convertNumber(1e-10)); // restart algorithm requires a higher precision + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::Restart); + env.modelchecker().conditional().setPrecision(storm::utility::convertNumber(1e-10), + false); // restart algorithm requires a higher precision return env; } }; @@ -81,7 +109,7 @@ class SparseRationalNumberBisectionEnvironment { typedef storm::models::sparse::Mdp ModelType; static storm::Environment createEnvironment() { storm::Environment env; - env.modelchecker().setConditionalAlgorithmSetting(storm::ConditionalAlgorithmSetting::Bisection); + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::Bisection); return env; } }; @@ -93,7 +121,31 @@ class SparseRationalNumberBisectionAdvancedEnvironment { typedef storm::models::sparse::Mdp ModelType; static storm::Environment createEnvironment() { storm::Environment env; - env.modelchecker().setConditionalAlgorithmSetting(storm::ConditionalAlgorithmSetting::BisectionAdvanced); + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::BisectionAdvanced); + return env; + } +}; + +class SparseRationalNumberBisectionPtEnvironment { + public: + static const bool isExact = true; + typedef storm::RationalNumber ValueType; + typedef storm::models::sparse::Mdp ModelType; + static storm::Environment createEnvironment() { + storm::Environment env; + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::BisectionPolicyTracking); + return env; + } +}; + +class SparseRationalNumberBisectionAdvancedPtEnvironment { + public: + static const bool isExact = true; + typedef storm::RationalNumber ValueType; + typedef storm::models::sparse::Mdp ModelType; + static storm::Environment createEnvironment() { + storm::Environment env; + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::BisectionAdvancedPolicyTracking); return env; } }; @@ -105,7 +157,7 @@ class SparseRationalNumberPiEnvironment { typedef storm::models::sparse::Mdp ModelType; static storm::Environment createEnvironment() { storm::Environment env; - env.modelchecker().setConditionalAlgorithmSetting(storm::ConditionalAlgorithmSetting::PolicyIteration); + env.modelchecker().conditional().setAlgorithm(storm::ConditionalAlgorithmSetting::PolicyIteration); return env; } }; @@ -159,9 +211,10 @@ class ConditionalMdpPrctlModelCheckerTest : public ::testing::Test { storm::Environment _environment; }; -typedef ::testing::Types + SparseRationalNumberBisectionPtEnvironment, SparseRationalNumberBisectionAdvancedPtEnvironment, SparseRationalNumberPiEnvironment> TestingTypes; TYPED_TEST_SUITE(ConditionalMdpPrctlModelCheckerTest, TestingTypes, ); @@ -201,8 +254,13 @@ TYPED_TEST(ConditionalMdpPrctlModelCheckerTest, two_dice) { EXPECT_NEAR(this->parseNumber("0"), result[*mdp->getInitialStates().begin()], this->precision()); result = checker.check(this->env(), tasks[5])->template asExplicitQuantitativeCheckResult(); EXPECT_NEAR(this->parseNumber("0"), result[*mdp->getInitialStates().begin()], this->precision()); - result = checker.check(this->env(), tasks[6])->template asExplicitQuantitativeCheckResult(); - EXPECT_NEAR(this->parseNumber("1"), result[*mdp->getInitialStates().begin()], this->precision()); + + // This Environment depending on the platform fails or does not fail an assertion. Thus this env is skipped. + if constexpr (!std::is_same_v) { + result = checker.check(this->env(), tasks[6])->template asExplicitQuantitativeCheckResult(); + EXPECT_NEAR(this->parseNumber("1"), result[*mdp->getInitialStates().begin()], this->precision()); + } + result = checker.check(this->env(), tasks[7])->template asExplicitQuantitativeCheckResult(); EXPECT_NEAR(this->parseNumber("1"), result[*mdp->getInitialStates().begin()], this->precision()); } @@ -214,7 +272,11 @@ TYPED_TEST(ConditionalMdpPrctlModelCheckerTest, consensus) { "Pmax=? [F \"all_coins_equal_0\" & \"finished\" || F \"agree\" & \"finished\"];" "Pmin=? [F \"all_coins_equal_0\" & \"finished\" || F \"agree\" & \"finished\"];" "Pmax=? [F \"all_coins_equal_1\" & \"finished\" || F \"agree\" & \"finished\"];" - "Pmin=? [F \"all_coins_equal_1\" & \"finished\" || F \"agree\" & \"finished\"];"; + "Pmin=? [F \"all_coins_equal_1\" & \"finished\" || F \"agree\" & \"finished\"];" + "P<=560/953 [F \"all_coins_equal_1\" & \"finished\" || F \"agree\" & \"finished\"];" + "P<562/953 [F \"all_coins_equal_1\" & \"finished\" || F \"agree\" & \"finished\"];" + "P>393/953 [F \"all_coins_equal_1\" & \"finished\" || F \"agree\" & \"finished\"];" + "P>=391/953 [F \"all_coins_equal_1\" & \"finished\" || F \"agree\" & \"finished\"];"; auto program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/mdp/coin2-2.nm"); auto modelFormulas = this->buildModelFormulas(program, formulasString); @@ -234,6 +296,14 @@ TYPED_TEST(ConditionalMdpPrctlModelCheckerTest, consensus) { EXPECT_NEAR(this->parseNumber("561/953"), result[*mdp->getInitialStates().begin()], this->precision()); result = checker.check(this->env(), tasks[3])->template asExplicitQuantitativeCheckResult(); EXPECT_NEAR(this->parseNumber("392/953"), result[*mdp->getInitialStates().begin()], this->precision()); + auto qualResult = checker.check(this->env(), tasks[4])->template asExplicitQualitativeCheckResult(); + EXPECT_FALSE(qualResult[*mdp->getInitialStates().begin()]); + qualResult = checker.check(this->env(), tasks[5])->template asExplicitQualitativeCheckResult(); + EXPECT_TRUE(qualResult[*mdp->getInitialStates().begin()]); + qualResult = checker.check(this->env(), tasks[6])->template asExplicitQualitativeCheckResult(); + EXPECT_FALSE(qualResult[*mdp->getInitialStates().begin()]); + qualResult = checker.check(this->env(), tasks[7])->template asExplicitQualitativeCheckResult(); + EXPECT_TRUE(qualResult[*mdp->getInitialStates().begin()]); } TYPED_TEST(ConditionalMdpPrctlModelCheckerTest, simple) { diff --git a/src/test/storm/utility/RationalApproximationTest.cpp b/src/test/storm/utility/RationalApproximationTest.cpp new file mode 100644 index 000000000..9e7f770ee --- /dev/null +++ b/src/test/storm/utility/RationalApproximationTest.cpp @@ -0,0 +1,70 @@ +#include "storm-config.h" +#include "test/storm_gtest.h" + +#include "storm/adapters/RationalNumberAdapter.h" +#include "storm/utility/RationalApproximation.h" +#include "storm/utility/constants.h" + +namespace { + +storm::RationalNumber rn(double doubleValue) { + return storm::utility::convertNumber(doubleValue); +} + +storm::RationalNumber rn(std::string const& str) { + return storm::utility::convertNumber(str); +} +TEST(RationalApproximationTest, inclusive_bounds) { + EXPECT_EQ(rn("0"), storm::utility::findRational(rn("0"), true, rn("0"), true)); + EXPECT_EQ(rn("1"), storm::utility::findRational(rn("1"), true, rn("1"), true)); + EXPECT_EQ(rn("0"), storm::utility::findRational(rn("0"), true, rn("1"), true)); + EXPECT_EQ(rn("1/2"), storm::utility::findRational(rn("1/3"), true, rn("2/3"), true)); + EXPECT_EQ(rn("1/2"), storm::utility::findRational(rn("1/10"), true, rn("9/10"), true)); + EXPECT_EQ(rn("1"), storm::utility::findRational(rn("1/2"), true, rn("1"), true)); + EXPECT_EQ(rn("1"), storm::utility::findRational(rn("2/3"), true, rn("1"), true)); + EXPECT_EQ(rn("2/3"), storm::utility::findRational(rn("2/3"), true, rn("3/4"), true)); + EXPECT_EQ(rn("2/3"), storm::utility::findRational(rn("3/5"), true, rn("3/4"), true)); + EXPECT_EQ(rn("1"), storm::utility::findRational(rn("2/3"), true, rn("123456"), true)); + EXPECT_EQ(rn("23/3"), storm::utility::findRational(rn("23/3"), true, rn("31/4"), true)); + EXPECT_EQ(rn("23/3"), storm::utility::findRational(rn("38/5"), true, rn("31/4"), true)); + EXPECT_EQ(rn("75/7"), storm::utility::findRational(rn(10.71), true, rn(10.72), true)); + EXPECT_EQ(rn(0.123456), storm::utility::findRational(rn(0.123456), true, rn(0.123456), true)); + EXPECT_EQ(rn(987.123456), storm::utility::findRational(rn(987.123456), true, rn(987.123456), true)); +} + +TEST(RationalApproximationTest, exclusive_bounds) { + EXPECT_EQ(rn("1/2"), storm::utility::findRational(rn("0"), false, rn("1"), false)); + EXPECT_EQ(rn("0"), storm::utility::findRational(rn("0"), true, rn("1"), false)); + EXPECT_EQ(rn("1"), storm::utility::findRational(rn("0"), false, rn("1"), true)); + EXPECT_EQ(rn("1/3"), storm::utility::findRational(rn("0"), false, rn("1/2"), false)); + EXPECT_EQ(rn("2/3"), storm::utility::findRational(rn("1/2"), false, rn("1"), false)); + EXPECT_EQ(rn("3/4"), storm::utility::findRational(rn("2/3"), false, rn("1"), false)); + EXPECT_EQ(rn("5/7"), storm::utility::findRational(rn("2/3"), false, rn("3/4"), false)); + EXPECT_EQ(rn("3/2"), storm::utility::findRational(rn("1"), false, rn("2"), false)); + EXPECT_EQ(rn("30/19"), storm::utility::findRational(rn("11/7"), false, rn("19/12"), false)); + EXPECT_EQ(rn("11/7"), storm::utility::findRational(rn("11/7"), true, rn("19/12"), false)); + EXPECT_EQ(rn("19/12"), storm::utility::findRational(rn("11/7"), false, rn("19/12"), true)); + EXPECT_EQ(rn("1000/1001"), storm::utility::findRational(rn("999/1000"), false, rn("1"), false)); + EXPECT_EQ(rn("999/1000"), storm::utility::findRational(rn("999/1000"), true, rn("1"), false)); + EXPECT_EQ(rn("333/334"), storm::utility::findRational(rn("997/1000"), true, rn("1"), false)); + EXPECT_EQ(rn("1001/1000"), storm::utility::findRational(rn("1"), false, rn("1001/1000"), true)); + EXPECT_EQ(rn("1002/1001"), storm::utility::findRational(rn("1"), false, rn("1001/1000"), false)); + EXPECT_EQ(rn("335/334"), storm::utility::findRational(rn("1"), false, rn("1003/1000"), true)); + EXPECT_EQ(rn("500/1001"), storm::utility::findRational(rn("999/2000"), false, rn("1/2"), false)); + EXPECT_EQ(rn("167/335"), storm::utility::findRational(rn("997/2000"), true, rn("1/2"), false)); + EXPECT_EQ(rn("500/1001"), storm::utility::findRational(rn("500/1001"), true, rn("1/2"), false)); + EXPECT_EQ(rn("501/1001"), storm::utility::findRational(rn("1/2"), false, rn("501/1001"), true)); + EXPECT_EQ(rn("502/1003"), storm::utility::findRational(rn("1/2"), false, rn("501/1001"), false)); + EXPECT_EQ(rn("168/335"), storm::utility::findRational(rn("1/2"), false, rn("502/1001"), true)); +} + +TEST(RationalApproximationTest, negative) { + EXPECT_EQ(rn("0"), storm::utility::findRational(rn("-1"), false, rn("1"), false)); + EXPECT_EQ(rn("0"), storm::utility::findRational(rn("-1"), true, rn("0"), true)); + EXPECT_EQ(rn("-1"), storm::utility::findRational(rn("-1"), true, rn("0"), false)); + EXPECT_EQ(rn("-30/19"), storm::utility::findRational(rn("-19/12"), false, rn("-11/7"), false)); + EXPECT_EQ(rn("-11/7"), storm::utility::findRational(rn("-19/12"), false, rn("-11/7"), true)); + EXPECT_EQ(rn("-19/12"), storm::utility::findRational(rn("-19/12"), true, rn("-11/7"), false)); +} + +} // namespace