Skip to content

Commit 6511a75

Browse files
committed
Added decision preferences of unseen terms
1 parent 3041611 commit 6511a75

File tree

3 files changed

+37
-21
lines changed

3 files changed

+37
-21
lines changed

src/api/MainSolver.cc

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -162,23 +162,35 @@ void MainSolver::resetDecisionPreferences() {
162162
smt_solver->clearUserBranchLits();
163163
}
164164

165-
void MainSolver::propagateDecisionPreferenceToSMTSolver(PTRef pref) {
165+
void MainSolver::propagateDecisionPreferenceToSMTSolver(PTRef pref, FrameId frameId) {
166166
assert(logic.getSortRef(pref) == logic.getSort_bool());
167167
assert(not logic.isConstant(pref));
168168

169-
Lit l;
170-
if (logic.isBoolVarLit(pref)) {
171-
l = term_mapper->getOrCreateLit(pref);
172-
Var v = var(l);
173-
smt_solver->addVar(v);
174-
} else {
175-
// Ignores substitutions ..
176-
assert(term_mapper->hasLit(pref));
177-
l = term_mapper->getLit(pref);
178-
}
179-
assert(term_mapper->getLit(pref) == l);
180-
assert(term_mapper->getVar(pref) == var(l));
181-
smt_solver->pushUserBranchLit(l);
169+
// Ignores substitutions ..
170+
Lit lit = [&] {
171+
if (term_mapper->hasLit(pref)) { return term_mapper->getLit(pref); }
172+
173+
if (logic.isBoolVarLit(pref)) {
174+
Lit l = term_mapper->getOrCreateLit(pref);
175+
Var v = var(l);
176+
smt_solver->addVar(v);
177+
assert(term_mapper->getLit(pref) == l);
178+
assert(term_mapper->getVar(pref) == var(l));
179+
return l;
180+
}
181+
182+
auto name = std::string{"pref"} + std::to_string(smt_solver->userBranchLitsSize());
183+
PTRef decisionVarTerm = logic.mkBoolVar(name.c_str());
184+
Lit l = term_mapper->getOrCreateLit(decisionVarTerm);
185+
PTRef condTerm = logic.mkImpl(decisionVarTerm, pref);
186+
sstat status = giveToSolver(condTerm, frameId);
187+
assert(status == s_Undef);
188+
assert(term_mapper->getLit(decisionVarTerm) == l);
189+
assert(term_mapper->getVar(decisionVarTerm) == var(l));
190+
return l;
191+
}();
192+
193+
smt_solver->pushUserBranchLit(lit);
182194
}
183195

184196
sstat MainSolver::simplifyFormulas() {
@@ -235,7 +247,7 @@ sstat MainSolver::simplifyFormulas() {
235247
}
236248

237249
for (PTRef pref : decisionPreferences.scope(i)) {
238-
propagateDecisionPreferenceToSMTSolver(pref);
250+
propagateDecisionPreferenceToSMTSolver(pref, frames[i].getId());
239251
}
240252
}
241253
if (status == s_False) {
@@ -355,16 +367,17 @@ std::unique_ptr<InterpolationContext> MainSolver::getInterpolationContext() {
355367
}
356368

357369
sstat MainSolver::giveToSolver(PTRef root, FrameId push_id) {
358-
359370
struct ClauseCallBack : public Cnfizer::ClauseCallBack {
360371
std::vector<vec<Lit>> clauses;
361372
void operator()(vec<Lit> && c) override { clauses.push_back(std::move(c)); }
362373
};
374+
363375
ClauseCallBack callBack;
364376
ts.setClauseCallBack(&callBack);
365377
ts.Cnfizer::cnfize(root, push_id);
366378
bool const keepPartitionsSeparate = trackPartitions();
367-
Lit frameLit = push_id == 0 ? Lit{} : term_mapper->getOrCreateLit(frameTerms[push_id]);
379+
Lit frameLit;
380+
if (push_id != 0) { frameLit = term_mapper->getOrCreateLit(frameTerms[push_id]); }
368381
int partitionIndex = keepPartitionsSeparate ? pmanager.getPartitionIndex(root) : -1;
369382
for (auto & clause : callBack.clauses) {
370383
if (push_id != 0) { clause.push(frameLit); }

src/api/MainSolver.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ class MainSolver {
295295

296296
inline bool trackPartitions() const;
297297

298-
void propagateDecisionPreferenceToSMTSolver(PTRef);
298+
void propagateDecisionPreferenceToSMTSolver(PTRef, FrameId);
299299

300300
PTRef rewriteMaxArity(PTRef root);
301301

test/unit/test_DecisionPreference.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,21 @@ class DecisionPreferenceTest : public ::testing::Test {
4747
auto model = mainSolver.getModel();
4848
PTRef val_a = model->evaluate(a);
4949
PTRef val_b = model->evaluate(b);
50+
assert(val_a != PTRef_Undef);
51+
assert(val_b != PTRef_Undef);
5052
PTRef exp_val_a = lboolValToPTRef(expValA);
5153
PTRef exp_val_b = lboolValToPTRef(expValB);
52-
EXPECT_EQ(val_a, exp_val_a);
53-
EXPECT_EQ(val_b, exp_val_b);
54+
if (exp_val_a != PTRef_Undef) { EXPECT_EQ(val_a, exp_val_a); }
55+
if (exp_val_b != PTRef_Undef) { EXPECT_EQ(val_b, exp_val_b); }
5456
}
5557

5658
protected:
5759
PTRef lboolValToPTRef(lbool val) const {
5860
auto & logic = osmt->getLogic();
5961
if (val == l_False) { return logic.getTerm_false(); }
6062
if (val == l_True) { return logic.getTerm_true(); }
61-
return logic.getDefaultValuePTRef(logic.getSort_bool());
63+
// return logic.getDefaultValuePTRef(logic.getSort_bool());
64+
return PTRef_Undef;
6265
}
6366

6467
std::shared_ptr<Opensmt> osmt{};

0 commit comments

Comments
 (0)