From 7c7d6ead95a42c6855902fab1bb4aab20649b955 Mon Sep 17 00:00:00 2001 From: Benedikt Ehinger Date: Wed, 15 Jan 2025 14:15:54 +0100 Subject: [PATCH] Update predict.jl (#244) * Update predict.jl * fixed the bug urgh --- src/predict.jl | 6 +++--- test/predict.jl | 13 +++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/predict.jl b/src/predict.jl index 48cdcfc2..2bf947b4 100644 --- a/src/predict.jl +++ b/src/predict.jl @@ -127,13 +127,13 @@ function _residuals(::Type{T}, yhat, y) where {T<:UnfoldModel}#; ContinuousTimeT n_y = size(y, 2) @debug n_yhat n_y if n_yhat >= n_y - @debug "n_yhat > n_y" size(y) size.(_split_data(yhat, n_y)) + @debug "n_yhat > n_y, yhat is longer" size(y) size.(_split_data(yhat, n_y)) return y .- _split_data(yhat, n_y)[1] else - @debug "n_y < n_yhat" + @debug "n_yhat < n_y, y is longer" yA, yB = _split_data(y, n_yhat) @debug size(yA) size(yB) - res = yA .- n_y + res = yA .- yhat return cat(res, yB; dims = 2) end diff --git a/test/predict.jl b/test/predict.jl index d6ccb4d5..f5c67e2b 100644 --- a/test/predict.jl +++ b/test/predict.jl @@ -150,4 +150,17 @@ pt = Unfold.result_to_table(m, p, repeat([evts], 2)) @test maximum(abs.(data_e .- (resids_e.+predict(m_mul)[1])[1, :, :])) < 0.0000001 + ## + + + @test all(Unfold._residuals(UnfoldModel, [1 2 3; 3 4 5], [1 2 3; 3 4 5]) .== 0) + + # y longer + res = Unfold._residuals(UnfoldModel, [1 2 3; 3 4 5], [1 2 3 4; 3 4 5 6]) + @test all(res[:, 1:3] .== 0) + @test res[:, 4] == [4, 6] + + # yhat longer + @test all(Unfold._residuals(UnfoldModel, [1 2 3 4; 3 4 5 6], [1 2 3; 3 4 5]) .== 0) + end