Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Some discrete improvements #1045

Merged
merged 8 commits into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 33 additions & 39 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,67 +143,64 @@ void DiscreteConditional::print(const string& s,
}
}
cout << "):\n";
ADT::print("");
ADT::print("", formatter);
cout << endl;
}

/* ******************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false;
else {
const DecisionTreeFactor& f(
static_cast<const DecisionTreeFactor&>(other));
} else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol);
}
}

/* ******************************************************************************** */
/* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
const DiscreteValues& parentsValues) {
const DiscreteValues& given,
bool forceComplete = true) {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional);
size_t value;
for (Key j : conditional.parents()) {
try {
value = parentsValues.at(j);
value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
if (forceComplete) {
given.print("parentsValues: ");
throw runtime_error(
"DiscreteConditional::Choose: parent value missing");
}
}
}
return adt;
}

/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
const DiscreteValues& parentsValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
ADT adt(*this);
size_t value;
for (Key j : parents()) {
try {
value = parentsValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
/* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const {
ADT adt = Choose(*this, given, false); // P(F|S=given)

// Convert ADT to factor.
DiscreteKeys discreteKeys;
// Collect all keys not in given.
DiscreteKeys dKeys;
for (Key j : frontals()) {
discreteKeys.emplace_back(j, this->cardinality(j));
dKeys.emplace_back(j, this->cardinality(j));
}
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
for (size_t i = nrFrontals(); i < size(); i++) {
Key j = keys_[i];
if (given.count(j) == 0) {
dKeys.emplace_back(j, this->cardinality(j));
}
}
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
}

/* ******************************************************************************** */
/* ************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
Expand All @@ -217,7 +214,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} catch (exception&) {
frontalValues.print("frontalValues: ");
throw runtime_error("DiscreteConditional::choose: frontal value missing");
};
}
}

// Convert ADT to factor.
Expand All @@ -242,7 +239,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(

/* ************************************************************************** */
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO(Abhijit): is this really the fastest way? He thinks it is.
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)

// Initialize
Expand Down Expand Up @@ -276,11 +272,9 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
(*values)[j] = sampled; // store result in partial solution
}

/* ******************************************************************************** */
/* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {

// TODO: is this really the fastest way? I think it is.
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)

// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
Expand Down
17 changes: 14 additions & 3 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values);
}

/** Restrict to given parent values, returns DecisionTreeFactor */
DecisionTreeFactor::shared_ptr choose(
const DiscreteValues& parentsValues) const;
/**
* @brief restrict to given *parent* values.
*
* Note: does not need be complete set. Examples:
*
* P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error!
*
* @return a shared_ptr to a new DiscreteConditional
*/
shared_ptr choose(const DiscreteValues& given) const;

/** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood(
Expand Down
43 changes: 23 additions & 20 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,33 +64,35 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
*/
class GTSAM_EXPORT DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
class GTSAM_EXPORT DiscreteFactorGraph
: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
using This = DiscreteFactorGraph; ///< this class
using Base = FactorGraph<DiscreteFactor>; ///< base factor graph type
using BaseEliminateable =
EliminateableFactorGraph<This>; ///< for elimination
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This

typedef DiscreteFactorGraph This; ///< Typedef to this class
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
using Values = DiscreteValues; ///< backwards compatibility

using Values = DiscreteValues; ///< backwards compatibility

/** A map from keys to values */
typedef KeyVector Indices;
using Indices = KeyVector; ///> map from keys to values

/** Default constructor */
DiscreteFactorGraph() {}

/** Construct from iterator over factors */
template<typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
template <typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor)
: Base(firstFactor, lastFactor) {}

/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
template <class CONTAINER>
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}

/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDFACTOR>
/** Implicit copy/downcast constructor to override explicit template container
* constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}

/// Destructor
Expand All @@ -108,7 +110,7 @@ public EliminateableFactorGraph<DiscreteFactorGraph> {
void add(Args&&... args) {
emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...);
}

/** Return the set of variables involved in the factors (set union) */
KeySet keys() const;

Expand Down Expand Up @@ -163,9 +165,10 @@ public EliminateableFactorGraph<DiscreteFactorGraph> {
const DiscreteFactor::Names& names = {}) const;

/// @}
}; // \ DiscreteFactorGraph
}; // \ DiscreteFactorGraph

/// traits
template<> struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};

} // \ namespace gtsam
} // namespace gtsam
20 changes: 14 additions & 6 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
void printSignature(
string s = "Discrete Conditional: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
gtsam::DecisionTreeFactor* choose(
const gtsam::DiscreteValues& parentsValues) const;
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
Expand Down Expand Up @@ -230,11 +229,16 @@ class DiscreteFactorGraph {
DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);

void add(const gtsam::DiscreteKey& j, string table);
// Building the graph
void push_back(const gtsam::DiscreteFactor* factor);
void push_back(const gtsam::DiscreteConditional* conditional);
void push_back(const gtsam::DiscreteFactorGraph& graph);
void push_back(const gtsam::DiscreteBayesNet& bayesNet);
void push_back(const gtsam::DiscreteBayesTree& bayesTree);
void add(const gtsam::DiscreteKey& j, string spec);
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);

void add(const gtsam::DiscreteKeys& keys, string table);
void add(const std::vector<gtsam::DiscreteKey>& keys, string table);
void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);

bool empty() const;
size_t size() const;
Expand All @@ -258,8 +262,12 @@ class DiscreteFactorGraph {

gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
eliminatePartialSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal();
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);

string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down
28 changes: 28 additions & 0 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,34 @@ TEST(DiscreteConditional, likelihood) {
EXPECT(assert_equal(expected1, *actual1, 1e-9));
}

/* ************************************************************************* */
// Check choose on P(C|D,E)
TEST(DiscreteConditional, choose) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");

// Case 1: no given values: no-op
DiscreteValues given;
auto actual1 = C_given_DE.choose(given);
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));

// Case 2: 1 given value
given[D.first] = 1;
auto actual2 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
DiscreteConditional expected2(C | E = "1/1 1/4");
EXPECT(assert_equal(expected2, *actual2, 1e-9));

// Case 2: 2 given values
given[E.first] = 0;
auto actual3 = C_given_DE.choose(given);
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
DiscreteConditional expected3(C % "1/1");
EXPECT(assert_equal(expected3, *actual3, 1e-9));
}

/* ************************************************************************* */
// Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) {
Expand Down
22 changes: 15 additions & 7 deletions gtsam/discrete/tests/testDiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,12 @@ TEST(DiscreteFactorGraph, Dot) {
" var1[label=\"1\"];\n"
" var2[label=\"2\"];\n"
"\n"
" var0--var1;\n"
" var0--var2;\n"
" factor0[label=\"\", shape=point];\n"
" var0--factor0;\n"
" var1--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" var0--factor1;\n"
" var2--factor1;\n"
"}\n";
EXPECT(actual == expected);
}
Expand All @@ -397,12 +401,16 @@ TEST(DiscreteFactorGraph, DotWithNames) {
"graph {\n"
" size=\"5,5\";\n"
"\n"
" var0[label=\"C\"];\n"
" var1[label=\"A\"];\n"
" var2[label=\"B\"];\n"
" varC[label=\"C\"];\n"
" varA[label=\"A\"];\n"
" varB[label=\"B\"];\n"
"\n"
" var0--var1;\n"
" var0--var2;\n"
" factor0[label=\"\", shape=point];\n"
" varC--factor0;\n"
" varA--factor0;\n"
" factor1[label=\"\", shape=point];\n"
" varC--factor1;\n"
" varB--factor1;\n"
"}\n";
EXPECT(actual == expected);
}
Expand Down
Loading