Skip to content

Commit

Permalink
revert pattern
Browse files Browse the repository at this point in the history
When learnPattern is set, use learnPattern.
inferPattern is set, use inferPattern
  • Loading branch information
Lupino committed May 6, 2020
1 parent 27aec54 commit 2f47541
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/htm/regions/ClassifierRegion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ namespace htm {
inputs: {
bucket: { description: "The quantized value of the current sample, one from each encoder if more than one, for the learn step",
type: Real64, count: 0},
pattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: activeCells from TM",
type: SDR, count: 0},
inferPattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: predictiveCells from TM",
type: SDR, count: 0},
learnPattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: activeCells from TM",
Expand Down Expand Up @@ -140,8 +142,8 @@ Dimensions ClassifierRegion::askImplForOutputDimensions(const std::string &name)


void ClassifierRegion::compute() {
SDR &pattern = getInput("pattern")->getData().getSDR();
if (learn_) {
SDR &learnPattern = getInput("learnPattern")->getData().getSDR();
Array &b = getInput("bucket")->getData();
// 'bucket' is a list of quantized samples being processed for this iteration.
// There are one of these for each encoder (or value being encoded).
Expand All @@ -165,14 +167,25 @@ void ClassifierRegion::compute() {
}
categoryIdxList.push_back(c);
}
classifier_->learn(learnPattern, categoryIdxList);

SDR &learnPattern = getInput("learnPattern")->getData().getSDR();
if (learnPattern.size == 0) {
classifier_->learn(pattern, categoryIdxList);
} else {
classifier_->learn(learnPattern, categoryIdxList);
}
}

SDR &inferPattern = getInput("inferPattern")->getData().getSDR();
// Note: if there is no link to 'inferPattern' input, the 'inferPattern' SDR length is 0
// and SDRClassifier::infer() will throw an exception.

PDF pdf = classifier_->infer(inferPattern);
//
PDF pdf;
if (inferPattern.size == 0) {
pdf = classifier_->infer(pattern);
} else {
pdf = classifier_->infer(inferPattern);
}

// Adjust the buffer size to match the pdf.
if (getOutput("pdf")->getData().getCount() < pdf.size()) {
Expand Down

0 comments on commit 2f47541

Please # to comment.