From 2f47541b186d2f0c23d7de216c46b1d2970929b9 Mon Sep 17 00:00:00 2001 From: Lupino Date: Wed, 6 May 2020 16:05:28 +0800 Subject: [PATCH] revert pattern When learnPattern is set, use learnPattern. inferPattern is set, use inferPattern --- src/htm/regions/ClassifierRegion.cpp | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/htm/regions/ClassifierRegion.cpp b/src/htm/regions/ClassifierRegion.cpp index ea3a0acc1f..0db388e99b 100644 --- a/src/htm/regions/ClassifierRegion.cpp +++ b/src/htm/regions/ClassifierRegion.cpp @@ -79,6 +79,8 @@ namespace htm { inputs: { bucket: { description: "The quantized value of the current sample, one from each encoder if more than one, for the learn step", type: Real64, count: 0}, + pattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: activeCells from TM", + type: SDR, count: 0}, inferPattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: predictiveCells from TM", type: SDR, count: 0}, learnPattern: { description: "An SDR output bit pattern for a sample. Usually the output of the SP or TM. For example: activeCells from TM", @@ -140,8 +142,8 @@ Dimensions ClassifierRegion::askImplForOutputDimensions(const std::string &name) void ClassifierRegion::compute() { + SDR &pattern = getInput("pattern")->getData().getSDR(); if (learn_) { - SDR &learnPattern = getInput("learnPattern")->getData().getSDR(); Array &b = getInput("bucket")->getData(); // 'bucket' is a list of quantized samples being processed for this iteration. // There are one of these for each encoder (or value being encoded). @@ -165,14 +167,25 @@ void ClassifierRegion::compute() { } categoryIdxList.push_back(c); } - classifier_->learn(learnPattern, categoryIdxList); + + SDR &learnPattern = getInput("learnPattern")->getData().getSDR(); + if (learnPattern.size == 0) { + classifier_->learn(pattern, categoryIdxList); + } else { + classifier_->learn(learnPattern, categoryIdxList); + } } SDR &inferPattern = getInput("inferPattern")->getData().getSDR(); // Note: if there is no link to 'inferPattern' input, the 'inferPattern' SDR length is 0 // and SDRClassifier::infer() will throw an exception. - - PDF pdf = classifier_->infer(inferPattern); + // + PDF pdf; + if (inferPattern.size == 0) { + pdf = classifier_->infer(pattern); + } else { + pdf = classifier_->infer(inferPattern); + } // Adjust the buffer size to match the pdf. if (getOutput("pdf")->getData().getCount() < pdf.size()) {