Skip to content

Commit

Permalink
Merge pull request #163 from unfoldtoolbox/julia19_invfix
Browse files Browse the repository at this point in the history
fix for SE calculation breaking CI
  • Loading branch information
behinger authored Feb 2, 2024
2 parents 458f50a + 3373071 commit d1f0fef
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 35 deletions.
69 changes: 35 additions & 34 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,52 @@ using StatsBase: var
function solver_default(
X,
data::AbstractArray{T,2};
stderror = false,
multithreading = true,
showprogress = true,
stderror=false,
multithreading=true,
showprogress=true,
) where {T<:Union{Missing,<:Number}}
minfo = Array{IterativeSolvers.ConvergenceHistory,1}(undef,size(data,1))
beta = zeros(size(data,1),size(X,2)) # had issues with undef
minfo = Array{IterativeSolvers.ConvergenceHistory,1}(undef, size(data, 1))

beta = zeros(size(data, 1), size(X, 2)) # had issues with undef

p = Progress(size(data,1);enabled=showprogress)
p = Progress(size(data, 1); enabled=showprogress)
@maybe_threads multithreading for ch = 1:size(data, 1)
dd = view(data, ch, :)
dd = view(data, ch, :)
ix = @. !ismissing(dd)
# use the previous channel as a starting point
ch == 1 || copyto!(view(beta, ch, :), view(beta, ch-1, :))
beta[ch,:],h = lsmr!(@view(beta[ch, :]), (X[ix,:]), @view(data[ch, ix]),log=true)
# use the previous channel as a starting point
ch == 1 || copyto!(view(beta, ch, :), view(beta, ch - 1, :))

beta[ch, :], h = lsmr!(@view(beta[ch, :]), (X[ix, :]), @view(data[ch, ix]), log=true)

minfo[ch] = h
next!(p)
end
finish!(p)

if stderror
stderror = calculate_stderror(X, data, beta)
modelfit = Unfold.LinearModelFit(beta, ["lsmr",minfo], stderror)
stderror = calculate_stderror(X, data, beta)
modelfit = Unfold.LinearModelFit(beta, ["lsmr", minfo], stderror)
else
modelfit = Unfold.LinearModelFit(beta, ["lsmr",minfo])
modelfit = Unfold.LinearModelFit(beta, ["lsmr", minfo])
end
return modelfit
end

function calculate_stderror(Xdc, data::Matrix{T}, beta) where {T<:Union{Missing,<:Number}}

# remove missings
ix = any(.!ismissing.(data), dims = 1)[1, :]
ix = any(.!ismissing.(data), dims=1)[1, :]
if length(ix) != size(data, 2)
@warn(
"Limitation: Missing data are calculated over all channels for standard error"
)
end

data = data[:, ix]
Xdc = Xdc[ix, :]

# Hat matrix only once
hat_prime = inv(Matrix(Xdc' * Xdc))
hat_prime = inv(disallowmissing(Matrix(Xdc' * Xdc)))
# Calculate residual variance
@warn(
"Autocorrelation was NOT taken into account. Therefore SE are UNRELIABLE. Use at your own discretion"
Expand Down Expand Up @@ -92,21 +93,21 @@ end
function solver_default(
X,
data::AbstractArray{T,3};
stderror = false,
stderror=false,
multithreading=true,
showprogress = true,
showprogress=true,
) where {T<:Union{Missing,<:Number}}
#beta = Array{Union{Missing,Number}}(undef, size(data, 1), size(data, 2), size(X, 2))
beta = zeros(Union{Missing,Number},size(data, 1), size(data, 2), size(X, 2))
p = Progress(size(data,1);enabled=showprogress)
beta = zeros(Union{Missing,Number}, size(data, 1), size(data, 2), size(X, 2))
p = Progress(size(data, 1); enabled=showprogress)
@maybe_threads multithreading for ch = 1:size(data, 1)
for t = 1:size(data, 2)
# @debug("$(ndims(data,)),$t,$ch")
dd = view(data, ch, t,:)
# @debug("$(ndims(data,)),$t,$ch")

dd = view(data, ch, t, :)
ix = @. !ismissing(dd)
beta[ch,t,:] = @view(X[ix,:]) \ @view(data[ch,t,ix])

beta[ch, t, :] = @view(X[ix, :]) \ @view(data[ch, t, ix])
# qr(X) was slower on Februar 2022
end
next!(p)
Expand All @@ -121,12 +122,12 @@ function solver_default(
return modelfit
end

solver_b2b(X, data, cross_val_reps) = solver_b2b(X, data, cross_val_reps = cross_val_reps)
solver_b2b(X, data, cross_val_reps) = solver_b2b(X, data, cross_val_reps=cross_val_reps)
function solver_b2b(
X,
data::AbstractArray{T,3};
cross_val_reps = 10,
multithreading = true,
cross_val_reps=10,
multithreading=true,
showprogress=true,
) where {T<:Union{Missing,<:Number}}

Expand All @@ -136,12 +137,12 @@ function solver_b2b(
E = zeros(size(data, 2), size(X, 2), size(X, 2))
W = Array{Float64}(undef, size(data, 2), size(X, 2), size(data, 1))

prog = Progress(size(data, 2) * cross_val_reps, 0.1;enabled=showprogress)
prog = Progress(size(data, 2) * cross_val_reps, 0.1; enabled=showprogress)
@maybe_threads multithreading for m = 1:cross_val_reps
k_ix = collect(Kfold(size(data, 3), 2))
X1 = @view X[k_ix[1], :]
X2 = @view X[k_ix[2], :]

for t = 1:size(data, 2)

Y1 = @view data[:, t, k_ix[1]]
Expand All @@ -151,16 +152,16 @@ function solver_b2b(
G = (Y1' \ X1)
H = X2 \ (Y2' * G)

E[t, :, :] += Diagonal(H[diagind(H)])
ProgressMeter.next!(prog; showvalues = [(:time, t), (:cross_val_rep, m)])
E[t, :, :] += Diagonal(H[diagind(H)])
ProgressMeter.next!(prog; showvalues=[(:time, t), (:cross_val_rep, m)])
end
E[t, :, :] = E[t, :, :] ./ cross_val_reps
W[t, :, :] = (X * E[t, :, :])' / data[:, t, :]

end

# extract diagonal
beta = mapslices(diag, E, dims = [2, 3])
beta = mapslices(diag, E, dims=[2, 3])
# reshape to conform to ch x time x pred
beta = permutedims(beta, [3 1 2])
modelinfo = Dict("W" => W, "E" => E, "cross_val_reps" => cross_val_reps) # no history implemented (yet?)
Expand Down
7 changes: 6 additions & 1 deletion test/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ end
@testset "safe prediction outside bounds" begin
# test safe predict
m = fit(UnfoldModel, f_spl, evts, data_r, basisfunction)
@test_broken predict(m, DataFrame(conditionA=[0, 0, 0], continuousA=[0.9, 0.9, 1.9]))

p = predict(m, DataFrame(conditionA=[0, 0, 0], continuousA=[0.9, 0.9, 1.9]))
@test all(ismissing.(p[p.continuousA.==1.9, :yhat]))

end
#@test_broken all(ismissing.)

Expand Down Expand Up @@ -69,9 +72,11 @@ end
@test tmp.yhat[tmp.continuousA.==-1.1] tmp.yhat[tmp.continuousA.==0.9]
@test tmp.yhat[tmp.continuousA.==-1.0] tmp.yhat[tmp.continuousA.==1]
@test tmp.yhat[tmp.continuousA.==-0.9] tmp.yhat[tmp.continuousA.==1.1]

end

@testset "minimal number of splines" begin
f_spl = @formula 0 ~ 1 + conditionA + spl(continuousA, 3) # 1
@test_throws AssertionError fit(UnfoldModel, f_spl, evts, data_e, times)

end

0 comments on commit d1f0fef

Please # to comment.