From 908701fe4188ba3b656b774a77aefb566273fc30 Mon Sep 17 00:00:00 2001 From: TEC Date: Fri, 24 Jun 2022 11:02:05 +0800 Subject: [PATCH] Directly operate on leaf tuples --- src/classification/main.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/classification/main.jl b/src/classification/main.jl index 24f7d344..0f3f6926 100644 --- a/src/classification/main.jl +++ b/src/classification/main.jl @@ -180,7 +180,7 @@ n_labels` matrix of probabilities, each row summing up to 1. (eg. ["versicolor", "virginica", "setosa"]). It specifies the column ordering of the output matrix. """ apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} = - collect(leaf.values ./ leaf.total) + leaf.values ./ leaf.total function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels) where {S, T} if tree.featval === nothing @@ -192,8 +192,13 @@ function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels) end end -apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} = - stack_function_results(row->apply_tree_proba(tree, row, labels), features) +function apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} + predictions = Vector{NTuple{length(labels), Float64}}(undef, size(features, 1)) + for i in 1:size(features, 1) + predictions[i] = apply_tree_proba(tree, view(features, i, :), labels) + end + reinterpret(reshape, Float64, predictions) |> transpose |> Matrix +end function build_forest( labels :: AbstractVector{T},