Skip to content

Commit

Permalink
Merge pull request #1045 from borglab/feature/discrete_wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Jan 20, 2022
2 parents ee7d32d + 640a3b8 commit d8abdc2
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 108 deletions.
72 changes: 33 additions & 39 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,67 +143,64 @@ void DiscreteConditional::print(const string& s,
}
}
cout << "):\n";
ADT::print("");
ADT::print("", formatter);
cout << endl;
}

/* ******************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false;
else {
const DecisionTreeFactor& f(
static_cast<const DecisionTreeFactor&>(other));
} else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol);
}
}

/* ******************************************************************************** */
/* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
const DiscreteValues& parentsValues) {
const DiscreteValues& given,
bool forceComplete = true) {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional);
size_t value;
for (Key j : conditional.parents()) {
try {
value = parentsValues.at(j);
value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
if (forceComplete) {
given.print("parentsValues: ");
throw runtime_error(
"DiscreteConditional::Choose: parent value missing");
}
}
}
return adt;
}

/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
const DiscreteValues& parentsValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
ADT adt(*this);
size_t value;
for (Key j : parents()) {
try {
value = parentsValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
/* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const {
ADT adt = Choose(*this, given, false); // P(F|S=given)

// Convert ADT to factor.
DiscreteKeys discreteKeys;
// Collect all keys not in given.
DiscreteKeys dKeys;
for (Key j : frontals()) {
discreteKeys.emplace_back(j, this->cardinality(j));
dKeys.emplace_back(j, this->cardinality(j));
}
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
for (size_t i = nrFrontals(); i < size(); i++) {
Key j = keys_[i];
if (given.count(j) == 0) {
dKeys.emplace_back(j, this->cardinality(j));
}
}
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
}

/* ******************************************************************************** */
/* ************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
Expand All @@ -217,7 +214,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} catch (exception&) {
frontalValues.print("frontalValues: ");
throw runtime_error("DiscreteConditional::choose: frontal value missing");
};
}
}

// Convert ADT to factor.
Expand All @@ -242,7 +239,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(

/* ************************************************************************** */
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO(Abhijit): is this really the fastest way? He thinks it is.
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)

// Initialize
Expand Down Expand Up @@ -276,11 +272,9 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
(*values)[j] = sampled; // store result in partial solution
}

/* ******************************************************************************** */
/* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {

// TODO: is this really the fastest way? I think it is.
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)

// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
Expand Down
17 changes: 14 additions & 3 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values);
}

/** Restrict to given parent values, returns DecisionTreeFactor */
DecisionTreeFactor::shared_ptr choose(
const DiscreteValues& parentsValues) const;
/**
* @brief restrict to given *parent* values.
*
* Note: does not need be complete set. Examples:
*
* P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error!
*
* @return a shared_ptr to a new DiscreteConditional
*/
shared_ptr choose(const DiscreteValues& given) const;

/** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood(
Expand Down
43 changes: 23 additions & 20 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,33 +64,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
*/
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
class GTSAM_EXPORT DiscreteFactorGraph
: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
using This = DiscreteFactorGraph; ///< this class
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
using BaseEliminateable =
EliminateableFactorGraph<This>; ///< for elimination
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This

typedef DiscreteFactorGraph This; ///< Typedef to this class
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
using Values = DiscreteValues; ///< backwards compatibility

using Values = DiscreteValues; ///< backwards compatibility

/** A map from keys to values */
typedef KeyVector Indices;
using Indices = KeyVector; ///> map from keys to values

/** Default constructor */
DiscreteFactorGraph() {}

/** Construct from iterator over factors */
template<typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
template <typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
: Base(firstFactor, lastFactor) {}

/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
template <class CONTAINER>
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}

/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDFACTOR>
/** Implicit copy/downcast constructor to override explicit template container
* constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}

/// Destructor
Expand All @@ -108,7 +110,7 @@ public EliminateableFactorGraph<DiscreteFactorGraph> {
void add(Args&&... args) {
emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...);
}

/** Return the set of variables involved in the factors (set union) */
KeySet keys() const;

Expand Down Expand Up @@ -163,9 +165,10 @@ public EliminateableFactorGraph<DiscreteFactorGraph> {
const DiscreteFactor::Names& names = {}) const;

/// @}
}; // \ DiscreteFactorGraph
}; // \ DiscreteFactorGraph

/// traits
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};

} // \ namespace gtsam
} // namespace gtsam
20 changes: 14 additions & 6 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
void printSignature(
string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* choose(
const gtsam::DiscreteValues& parentsValues) const;
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
Expand Down Expand Up @@ -230,11 +229,16 @@ class DiscreteFactorGraph {
DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);

void add(const gtsam::DiscreteKey& j, string table);
// Building the graph
void push_back(const gtsam::DiscreteFactor* factor);
void push_back(const gtsam::DiscreteConditional* conditional);
void push_back(const gtsam::DiscreteFactorGraph& graph);
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
void add(const gtsam::DiscreteKey& j, string spec);
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);

void add(const gtsam::DiscreteKeys& keys, string table);
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);

bool empty() const;
size_t size() const;
Expand All @@ -258,8 +262,12 @@ class DiscreteFactorGraph {

gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
eliminatePartialSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal();
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);

string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down
28 changes: 28 additions & 0 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,34 @@ TEST(DiscreteConditional, likelihood) {
EXPECT(assert_equal(expected1, *actual1, 1e-9));
}

/* ************************************************************************* */
// Check choose on P(C|D,E)
TEST(DiscreteConditional, choose) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");

// Case 1: no given values: no-op
DiscreteValues given;
auto actual1 = C_given_DE.choose(given);
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));

// Case 2: 1 given value
given[D.first] = 1;
auto actual2 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
DiscreteConditional expected2(C | E = "1/1 1/4");
EXPECT(assert_equal(expected2, *actual2, 1e-9));

// Case 2: 2 given values
given[E.first] = 0;
auto actual3 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
DiscreteConditional expected3(C % "1/1");
EXPECT(assert_equal(expected3, *actual3, 1e-9));
}

/* ************************************************************************* */
// Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) {
Expand Down
22 changes: 15 additions & 7 deletions gtsam/discrete/tests/testDiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,12 @@ TEST(DiscreteFactorGraph, Dot) {
" var1[label=\"1\"];\n"
" var2[label=\"2\"];\n"
"\n"
" var0--var1;\n"
" var0--var2;\n"
" factor0[label=\"\", shape=point];\n"
" var0--factor0;\n"
" var1--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" var0--factor1;\n"
" var2--factor1;\n"
"}\n";
EXPECT(actual == expected);
}
Expand All @@ -397,12 +401,16 @@ TEST(DiscreteFactorGraph, DotWithNames) {
"graph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"C\"];\n"
" var1[label=\"A\"];\n"
" var2[label=\"B\"];\n"
" varC[label=\"C\"];\n"
" varA[label=\"A\"];\n"
" varB[label=\"B\"];\n"
"\n"
" var0--var1;\n"
" var0--var2;\n"
" factor0[label=\"\", shape=point];\n"
" varC--factor0;\n"
" varA--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" varC--factor1;\n"
" varB--factor1;\n"
"}\n";
EXPECT(actual == expected);
}
Expand Down
Loading

0 comments on commit d8abdc2

Please # to comment.