Skip to content

Commit

Permalink
Update predict.jl (#244)
Browse files Browse the repository at this point in the history
* Update predict.jl

* fixed the bug urgh
  • Loading branch information
behinger authored Jan 15, 2025
1 parent 778c2db commit 7c7d6ea
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ function _residuals(::Type{T}, yhat, y) where {T<:UnfoldModel}#; ContinuousTimeT
n_y = size(y, 2)
@debug n_yhat n_y
if n_yhat >= n_y
@debug "n_yhat > n_y" size(y) size.(_split_data(yhat, n_y))
@debug "n_yhat > n_y, yhat is longer" size(y) size.(_split_data(yhat, n_y))
return y .- _split_data(yhat, n_y)[1]
else
@debug "n_y < n_yhat"
@debug "n_yhat < n_y, y is longer"
yA, yB = _split_data(y, n_yhat)
@debug size(yA) size(yB)
res = yA .- n_y
res = yA .- yhat
return cat(res, yB; dims = 2)
end

Expand Down
13 changes: 13 additions & 0 deletions test/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,17 @@ pt = Unfold.result_to_table(m, p, repeat([evts], 2))
@test maximum(abs.(data_e .- (resids_e.+predict(m_mul)[1])[1, :, :])) < 0.0000001


##


@test all(Unfold._residuals(UnfoldModel, [1 2 3; 3 4 5], [1 2 3; 3 4 5]) .== 0)

# y longer
res = Unfold._residuals(UnfoldModel, [1 2 3; 3 4 5], [1 2 3 4; 3 4 5 6])
@test all(res[:, 1:3] .== 0)
@test res[:, 4] == [4, 6]

# yhat longer
@test all(Unfold._residuals(UnfoldModel, [1 2 3 4; 3 4 5 6], [1 2 3; 3 4 5]) .== 0)

end

0 comments on commit 7c7d6ea

Please # to comment.