diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index e8aa4511d8..eb31d2e1ea 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -143,67 +143,64 @@ void DiscreteConditional::print(const string& s, } } cout << "):\n"; - ADT::print(""); + ADT::print("", formatter); cout << endl; } /* ******************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, - double tol) const { - if (!dynamic_cast(&other)) + double tol) const { + if (!dynamic_cast(&other)) { return false; - else { - const DecisionTreeFactor& f( - static_cast(other)); + } else { + const DecisionTreeFactor& f(static_cast(other)); return DecisionTreeFactor::equals(f, tol); } } -/* ******************************************************************************** */ +/* ************************************************************************** */ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, - const DiscreteValues& parentsValues) { + const DiscreteValues& given, + bool forceComplete = true) { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the parent variables. DiscreteConditional::ADT adt(conditional); size_t value; for (Key j : conditional.parents()) { try { - value = parentsValues.at(j); + value = given.at(j); adt = adt.choose(j, value); // ADT keeps getting smaller. } catch (std::out_of_range&) { - parentsValues.print("parentsValues: "); - throw runtime_error("DiscreteConditional::choose: parent value missing"); - }; + if (forceComplete) { + given.print("parentsValues: "); + throw runtime_error( + "DiscreteConditional::Choose: parent value missing"); + } + } } return adt; } -/* ******************************************************************************** */ -DecisionTreeFactor::shared_ptr DiscreteConditional::choose( - const DiscreteValues& parentsValues) const { - // Get the big decision tree with all the levels, and then go down the - // branches based on the value of the parent variables. - ADT adt(*this); - size_t value; - for (Key j : parents()) { - try { - value = parentsValues.at(j); - adt = adt.choose(j, value); // ADT keeps getting smaller. - } catch (exception&) { - parentsValues.print("parentsValues: "); - throw runtime_error("DiscreteConditional::choose: parent value missing"); - }; - } +/* ************************************************************************** */ +DiscreteConditional::shared_ptr DiscreteConditional::choose( + const DiscreteValues& given) const { + ADT adt = Choose(*this, given, false); // P(F|S=given) - // Convert ADT to factor. - DiscreteKeys discreteKeys; + // Collect all keys not in given. + DiscreteKeys dKeys; for (Key j : frontals()) { - discreteKeys.emplace_back(j, this->cardinality(j)); + dKeys.emplace_back(j, this->cardinality(j)); } - return boost::make_shared(discreteKeys, adt); + for (size_t i = nrFrontals(); i < size(); i++) { + Key j = keys_[i]; + if (given.count(j) == 0) { + dKeys.emplace_back(j, this->cardinality(j)); + } + } + return boost::make_shared(nrFrontals(), dKeys, adt); } -/* ******************************************************************************** */ +/* ************************************************************************** */ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( const DiscreteValues& frontalValues) const { // Get the big decision tree with all the levels, and then go down the @@ -217,7 +214,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( } catch (exception&) { frontalValues.print("frontalValues: "); throw runtime_error("DiscreteConditional::choose: frontal value missing"); - }; + } } // Convert ADT to factor. @@ -242,7 +239,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( /* ************************************************************************** */ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { - // TODO(Abhijit): is this really the fastest way? He thinks it is. ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) // Initialize @@ -276,11 +272,9 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { (*values)[j] = sampled; // store result in partial solution } -/* ******************************************************************************** */ +/* ************************************************************************** */ size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { - - // TODO: is this really the fastest way? I think it is. - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // Then, find the max over all remaining // TODO, only works for one key now, seems horribly slow this way diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index c3c8a66def..5908cc782e 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - /** Restrict to given parent values, returns DecisionTreeFactor */ - DecisionTreeFactor::shared_ptr choose( - const DiscreteValues& parentsValues) const; + /** + * @brief restrict to given *parent* values. + * + * Note: does not need be complete set. Examples: + * + * P(C|D,E) + . -> P(C|D,E) + * P(C|D,E) + E -> P(C|D) + * P(C|D,E) + D -> P(C|E) + * P(C|D,E) + D,E -> P(C) + * P(C|D,E) + C -> error! + * + * @return a shared_ptr to a new DiscreteConditional + */ + shared_ptr choose(const DiscreteValues& given) const; /** Convert to a likelihood factor by providing value before bar. */ DecisionTreeFactor::shared_ptr likelihood( diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 08c3d893d9..1da840eb8e 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -64,33 +64,35 @@ template<> struct EliminationTraits * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * Factor == DiscreteFactor */ -class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph, -public EliminateableFactorGraph { -public: +class GTSAM_EXPORT DiscreteFactorGraph + : public FactorGraph, + public EliminateableFactorGraph { + public: + using This = DiscreteFactorGraph; ///< this class + using Base = FactorGraph; ///< base factor graph type + using BaseEliminateable = + EliminateableFactorGraph; ///< for elimination + using shared_ptr = boost::shared_ptr; ///< shared_ptr to This - typedef DiscreteFactorGraph This; ///< Typedef to this class - typedef FactorGraph Base; ///< Typedef to base factor graph type - typedef EliminateableFactorGraph BaseEliminateable; ///< Typedef to base elimination class - typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class + using Values = DiscreteValues; ///< backwards compatibility - using Values = DiscreteValues; ///< backwards compatibility - - /** A map from keys to values */ - typedef KeyVector Indices; + using Indices = KeyVector; ///> map from keys to values /** Default constructor */ DiscreteFactorGraph() {} /** Construct from iterator over factors */ - template - DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {} + template + DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) + : Base(firstFactor, lastFactor) {} /** Construct from container of factors (shared_ptr or plain objects) */ - template + template explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {} - /** Implicit copy/downcast constructor to override explicit template container constructor */ - template + /** Implicit copy/downcast constructor to override explicit template container + * constructor */ + template DiscreteFactorGraph(const FactorGraph& graph) : Base(graph) {} /// Destructor @@ -108,7 +110,7 @@ public EliminateableFactorGraph { void add(Args&&... args) { emplace_shared(std::forward(args)...); } - + /** Return the set of variables involved in the factors (set union) */ KeySet keys() const; @@ -163,9 +165,10 @@ public EliminateableFactorGraph { const DiscreteFactor::Names& names = {}) const; /// @} -}; // \ DiscreteFactorGraph +}; // \ DiscreteFactorGraph /// traits -template<> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; -} // \ namespace gtsam +} // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index e4af27eb19..e2310f4344 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -107,8 +107,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { void printSignature( string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; - gtsam::DecisionTreeFactor* choose( - const gtsam::DiscreteValues& parentsValues) const; + gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const; gtsam::DecisionTreeFactor* likelihood( const gtsam::DiscreteValues& frontalValues) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const; @@ -230,11 +229,16 @@ class DiscreteFactorGraph { DiscreteFactorGraph(); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); - void add(const gtsam::DiscreteKey& j, string table); + // Building the graph + void push_back(const gtsam::DiscreteFactor* factor); + void push_back(const gtsam::DiscreteConditional* conditional); + void push_back(const gtsam::DiscreteFactorGraph& graph); + void push_back(const gtsam::DiscreteBayesNet& bayesNet); + void push_back(const gtsam::DiscreteBayesTree& bayesTree); + void add(const gtsam::DiscreteKey& j, string spec); void add(const gtsam::DiscreteKey& j, const std::vector& spec); - - void add(const gtsam::DiscreteKeys& keys, string table); - void add(const std::vector& keys, string table); + void add(const gtsam::DiscreteKeys& keys, string spec); + void add(const std::vector& keys, string spec); bool empty() const; size_t size() const; @@ -258,8 +262,12 @@ class DiscreteFactorGraph { gtsam::DiscreteBayesNet eliminateSequential(); gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); + std::pair + eliminatePartialSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesTree eliminateMultifrontal(); gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering); + std::pair + eliminatePartialMultifrontal(const gtsam::Ordering& ordering); string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 1256595170..c2d941eaa7 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -221,6 +221,34 @@ TEST(DiscreteConditional, likelihood) { EXPECT(assert_equal(expected1, *actual1, 1e-9)); } +/* ************************************************************************* */ +// Check choose on P(C|D,E) +TEST(DiscreteConditional, choose) { + DiscreteKey C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // Case 1: no given values: no-op + DiscreteValues given; + auto actual1 = C_given_DE.choose(given); + EXPECT(assert_equal(C_given_DE, *actual1, 1e-9)); + + // Case 2: 1 given value + given[D.first] = 1; + auto actual2 = C_given_DE.choose(given); + EXPECT_LONGS_EQUAL(1, actual2->nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual2->nrParents()); + DiscreteConditional expected2(C | E = "1/1 1/4"); + EXPECT(assert_equal(expected2, *actual2, 1e-9)); + + // Case 2: 2 given values + given[E.first] = 0; + auto actual3 = C_given_DE.choose(given); + EXPECT_LONGS_EQUAL(1, actual3->nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual3->nrParents()); + DiscreteConditional expected3(C % "1/1"); + EXPECT(assert_equal(expected3, *actual3, 1e-9)); +} + /* ************************************************************************* */ // Check markdown representation looks as expected, no parents. TEST(DiscreteConditional, markdown_prior) { diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index ef9efbe026..579244c57f 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -376,8 +376,12 @@ TEST(DiscreteFactorGraph, Dot) { " var1[label=\"1\"];\n" " var2[label=\"2\"];\n" "\n" - " var0--var1;\n" - " var0--var2;\n" + " factor0[label=\"\", shape=point];\n" + " var0--factor0;\n" + " var1--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " var0--factor1;\n" + " var2--factor1;\n" "}\n"; EXPECT(actual == expected); } @@ -397,12 +401,16 @@ TEST(DiscreteFactorGraph, DotWithNames) { "graph {\n" " size=\"5,5\";\n" "\n" - " var0[label=\"C\"];\n" - " var1[label=\"A\"];\n" - " var2[label=\"B\"];\n" + " varC[label=\"C\"];\n" + " varA[label=\"A\"];\n" + " varB[label=\"B\"];\n" "\n" - " var0--var1;\n" - " var0--var2;\n" + " factor0[label=\"\", shape=point];\n" + " varC--factor0;\n" + " varA--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " varC--factor1;\n" + " varB--factor1;\n" "}\n"; EXPECT(actual == expected); } diff --git a/gtsam/inference/DotWriter.cpp b/gtsam/inference/DotWriter.cpp index fb3ea05054..18130c35d7 100644 --- a/gtsam/inference/DotWriter.cpp +++ b/gtsam/inference/DotWriter.cpp @@ -35,7 +35,8 @@ void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter, const boost::optional& position, ostream* os) { // Label the node with the label from the KeyFormatter - *os << " var" << key << "[label=\"" << keyFormatter(key) << "\""; + *os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key) + << "\""; if (position) { *os << ", pos=\"" << position->x() << "," << position->y() << "!\""; } @@ -51,22 +52,26 @@ void DotWriter::DrawFactor(size_t i, const boost::optional& position, *os << "];\n"; } -void DotWriter::ConnectVariables(Key key1, Key key2, ostream* os) { - *os << " var" << key1 << "--" - << "var" << key2 << ";\n"; +static void ConnectVariables(Key key1, Key key2, + const KeyFormatter& keyFormatter, + ostream* os) { + *os << " var" << keyFormatter(key1) << "--" + << "var" << keyFormatter(key2) << ";\n"; } -void DotWriter::ConnectVariableFactor(Key key, size_t i, ostream* os) { - *os << " var" << key << "--" +static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter, + size_t i, ostream* os) { + *os << " var" << keyFormatter(key) << "--" << "factor" << i << ";\n"; } void DotWriter::processFactor(size_t i, const KeyVector& keys, + const KeyFormatter& keyFormatter, const boost::optional& position, ostream* os) const { if (plotFactorPoints) { if (binaryEdges && keys.size() == 2) { - ConnectVariables(keys[0], keys[1], os); + ConnectVariables(keys[0], keys[1], keyFormatter, os); } else { // Create dot for the factor. DrawFactor(i, position, os); @@ -74,7 +79,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys, // Make factor-variable connections if (connectKeysToFactor) { for (Key key : keys) { - ConnectVariableFactor(key, i, os); + ConnectVariableFactor(key, keyFormatter, i, os); } } } @@ -83,7 +88,7 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys, for (Key key1 : keys) { for (Key key2 : keys) { if (key2 > key1) { - ConnectVariables(key1, key2, os); + ConnectVariables(key1, key2, keyFormatter, os); } } } diff --git a/gtsam/inference/DotWriter.h b/gtsam/inference/DotWriter.h index bd36da496c..93c229c2b1 100644 --- a/gtsam/inference/DotWriter.h +++ b/gtsam/inference/DotWriter.h @@ -38,7 +38,7 @@ struct GTSAM_EXPORT DotWriter { explicit DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, bool plotFactorPoints = true, - bool connectKeysToFactor = true, bool binaryEdges = true) + bool connectKeysToFactor = true, bool binaryEdges = false) : figureWidthInches(figureWidthInches), figureHeightInches(figureHeightInches), plotFactorPoints(plotFactorPoints), @@ -57,14 +57,9 @@ struct GTSAM_EXPORT DotWriter { static void DrawFactor(size_t i, const boost::optional& position, std::ostream* os); - /// Connect two variables. - static void ConnectVariables(Key key1, Key key2, std::ostream* os); - - /// Connect variable and factor. - static void ConnectVariableFactor(Key key, size_t i, std::ostream* os); - /// Draw a single factor, specified by its index i and its variable keys. void processFactor(size_t i, const KeyVector& keys, + const KeyFormatter& keyFormatter, const boost::optional& position, std::ostream* os) const; }; diff --git a/gtsam/inference/FactorGraph-inst.h b/gtsam/inference/FactorGraph-inst.h index 058075f2d5..3ea17fc7ff 100644 --- a/gtsam/inference/FactorGraph-inst.h +++ b/gtsam/inference/FactorGraph-inst.h @@ -144,7 +144,7 @@ void FactorGraph::dot(std::ostream& os, const auto& factor = at(i); if (factor) { const KeyVector& factorKeys = factor->keys(); - writer.processFactor(i, factorKeys, boost::none, &os); + writer.processFactor(i, factorKeys, keyFormatter, boost::none, &os); } } diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index 89236ea878..da8935d5fc 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -33,8 +33,10 @@ # include #endif +#include #include #include +#include using namespace std; @@ -127,7 +129,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, // Create factors and variable connections size_t i = 0; for (const KeyVector& factorKeys : structure) { - writer.processFactor(i++, factorKeys, boost::none, &os); + writer.processFactor(i++, factorKeys, keyFormatter, boost::none, &os); } } else { // Create factors and variable connections @@ -135,7 +137,8 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, const NonlinearFactor::shared_ptr& factor = at(i); if (factor) { const KeyVector& factorKeys = factor->keys(); - writer.processFactor(i, factorKeys, writer.factorPos(min, i), &os); + writer.processFactor(i, factorKeys, keyFormatter, + writer.factorPos(min, i), &os); } } } diff --git a/tests/testNonlinearFactorGraph.cpp b/tests/testNonlinearFactorGraph.cpp index 8a360e4542..05a6e7f45e 100644 --- a/tests/testNonlinearFactorGraph.cpp +++ b/tests/testNonlinearFactorGraph.cpp @@ -335,15 +335,21 @@ TEST(NonlinearFactorGraph, dot) { "graph {\n" " size=\"5,5\";\n" "\n" - " var7782220156096217089[label=\"l1\"];\n" - " var8646911284551352321[label=\"x1\"];\n" - " var8646911284551352322[label=\"x2\"];\n" + " varl1[label=\"l1\"];\n" + " varx1[label=\"x1\"];\n" + " varx2[label=\"x2\"];\n" "\n" " factor0[label=\"\", shape=point];\n" - " var8646911284551352321--factor0;\n" - " var8646911284551352321--var8646911284551352322;\n" - " var8646911284551352321--var7782220156096217089;\n" - " var8646911284551352322--var7782220156096217089;\n" + " varx1--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " varx1--factor1;\n" + " varx2--factor1;\n" + " factor2[label=\"\", shape=point];\n" + " varx1--factor2;\n" + " varl1--factor2;\n" + " factor3[label=\"\", shape=point];\n" + " varx2--factor3;\n" + " varl1--factor3;\n" "}\n"; const NonlinearFactorGraph fg = createNonlinearFactorGraph(); @@ -357,15 +363,21 @@ TEST(NonlinearFactorGraph, dot_extra) { "graph {\n" " size=\"5,5\";\n" "\n" - " var7782220156096217089[label=\"l1\", pos=\"0,0!\"];\n" - " var8646911284551352321[label=\"x1\", pos=\"1,0!\"];\n" - " var8646911284551352322[label=\"x2\", pos=\"1,1.5!\"];\n" + " varl1[label=\"l1\", pos=\"0,0!\"];\n" + " varx1[label=\"x1\", pos=\"1,0!\"];\n" + " varx2[label=\"x2\", pos=\"1,1.5!\"];\n" "\n" " factor0[label=\"\", shape=point];\n" - " var8646911284551352321--factor0;\n" - " var8646911284551352321--var8646911284551352322;\n" - " var8646911284551352321--var7782220156096217089;\n" - " var8646911284551352322--var7782220156096217089;\n" + " varx1--factor0;\n" + " factor1[label=\"\", shape=point];\n" + " varx1--factor1;\n" + " varx2--factor1;\n" + " factor2[label=\"\", shape=point];\n" + " varx1--factor2;\n" + " varl1--factor2;\n" + " factor3[label=\"\", shape=point];\n" + " varx2--factor3;\n" + " varl1--factor3;\n" "}\n"; const NonlinearFactorGraph fg = createNonlinearFactorGraph();