diff --git a/src/classification/main.jl b/src/classification/main.jl index 24f7d344..0f3f6926 100644 --- a/src/classification/main.jl +++ b/src/classification/main.jl @@ -180,7 +180,7 @@ n_labels` matrix of probabilities, each row summing up to 1. (eg. ["versicolor", "virginica", "setosa"]). It specifies the column ordering of the output matrix. """ apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} = - collect(leaf.values ./ leaf.total) + leaf.values ./ leaf.total function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels) where {S, T} if tree.featval === nothing @@ -192,8 +192,13 @@ function apply_tree_proba(tree::Node{S, T}, features::AbstractVector{S}, labels) end end -apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} = - stack_function_results(row->apply_tree_proba(tree, row, labels), features) +function apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} + predictions = Vector{NTuple{length(labels), Float64}}(undef, size(features, 1)) + for i in 1:size(features, 1) + predictions[i] = apply_tree_proba(tree, view(features, i, :), labels) + end + reinterpret(reshape, Float64, predictions) |> transpose |> Matrix +end function build_forest( labels :: AbstractVector{T},