Skip to content

Commit

Permalink
Merge pull request #164 from unfoldtoolbox/bsplinekit17
Browse files Browse the repository at this point in the history
fix #124, assert min number of spines; compat 0.17 BSplineKit
  • Loading branch information
behinger authored Feb 2, 2024
2 parents 2191dd8 + b4822a6 commit 4c9accc
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ UnfoldRobustModelsExt = "RobustModels"
UnfoldKrylovExt = ["Krylov","CUDA"]

[compat]
BSplineKit = "0.16"
BSplineKit = "0.16,0.17"
CUDA = "4,5"
DSP = "0.7"
DataFrames = "1"
Expand Down
101 changes: 51 additions & 50 deletions ext/UnfoldBSplineKitExt/splinepredictors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,103 +23,104 @@ minus one due to the intercept.
Note: that due to the boundary condition (`natural`) spline, we repeat the boundary knots to each side `order` times, enforcing smoothness there - this is done within BSplineKit
"""
function genSpl_breakpoints(p::AbstractSplineTerm,x)
p = range(0.0, length = p.df-2, stop = 1.0)
function genSpl_breakpoints(p::AbstractSplineTerm, x)
p = range(0.0, length=p.df - 2, stop=1.0)
breakpoints = quantile(x, p)
return breakpoints
end

"""
In the circular case, we do not use quantiles, (circular quantiles are difficult)
"""
function genSpl_breakpoints(p::PeriodicBSplineTerm,x)
function genSpl_breakpoints(p::PeriodicBSplineTerm, x)
# periodic case -
return range(p.low,p.high,length=p.df+2)
return range(p.low, p.high, length=p.df + 2)
end

"""
function that fills in an Matrix `large` according to the evaluated values in `x` with the designmatrix values for the spline.
Two separate functions are needed here, as the periodicBSplineBasis implemented in BSplineKits is a bit weird, that it requires to evalute "negative" knots + knots above the top-boundary and fold them down
"""
function spl_fillMat!(bs::PeriodicBSplineBasis,large::Matrix,x::AbstractVector)
function spl_fillMat!(bs::PeriodicBSplineBasis, large::Matrix, x::AbstractVector)
# wrap values around the boundaries
bnds = boundaries(bs)
x = deepcopy(x)
x = mod.(x .- bnds[1],period(bs)) .+ bnds[1]
x = mod.(x .- bnds[1], period(bs)) .+ bnds[1]

for k = -1:length(bs)+2
ix = basis_to_array_index(bs,axes(large,2),k)
large[:,ix] .+= bs[k](x)
ix = basis_to_array_index(bs, axes(large, 2), k)
large[:, ix] .+= bs[k](x)
end
end
function spl_fillMat!(bs::BSplineBasis,large::Matrix,x::AbstractVector)
function spl_fillMat!(bs::BSplineBasis, large::Matrix, x::AbstractVector)
for k = 1:length(bs)
large[:,k] .+= bs[k](x)

large[:, k] .+= bs[k](x)
end

bnds = boundaries(bs)
ix = x .< bnds[1] .|| x .>bnds[2]
ix = x .< bnds[1] .|| x .> bnds[2]

if sum(ix) != 0
@warn("spline prediction outside of possible range putting those values to missing.\n `findfirst(Out-Of-Bound-value)` is x=$(x[findfirst(ix)]), with bounds: $bnds")
large[ix,:] .= missing
large[ix, :] .= missing
end

end

"""
evaluate a spline basisset `basis` at `x`
returns `Missing` if x is outside of the basis set
"""
function splFunction(x, bs)
# init array
large = zeros(Union{Missing,Float64},length(x), length(bs))
large = zeros(Union{Missing,Float64}, length(x), length(bs))

# fill it with spline values
spl_fillMat!(bs,large,x)
spl_fillMat!(bs, large, x)

return large
end

function splFunction(x,spl::PeriodicBSplineTerm)
basis = PeriodicBSplineBasis(BSplineOrder(spl.order),deepcopy(spl.breakpoints))
splFunction(x,basis)
function splFunction(x, spl::PeriodicBSplineTerm)
basis = PeriodicBSplineBasis(BSplineOrder(spl.order), deepcopy(spl.breakpoints))
splFunction(x, basis)
end

function splFunction(x,spl::BSplineTerm)
basis = BSplineKit.BSplineBasis(BSplineOrder(spl.order),deepcopy(spl.breakpoints))
splFunction(x,basis)
function splFunction(x, spl::BSplineTerm)
basis = BSplineKit.BSplineBasis(BSplineOrder(spl.order), deepcopy(spl.breakpoints))
splFunction(x, basis)
end
#spl(x,df) = Splines2.bs(x,df=df,intercept=true) # assumes intercept
Unfold.spl(x, df) = 0 # fallback

# make a nice call if the function is called via REPL
Unfold.spl(t::Symbol, d::Int) = BSplineTerm(term(t), d, 4,[])
Unfold.circspl(t::Symbol, d::Int,low,high) = PeriodicBSplineTerm(term(t), term(d),4,low,high)
Unfold.spl(t::Symbol, d::Int) = BSplineTerm(term(t), d, 4, [])
Unfold.circspl(t::Symbol, d::Int, low, high) = PeriodicBSplineTerm(term(t), term(d), 4, low, high)

"""
Construct a BSplineTerm, if breakpoints/basis are not defined yet, put to `nothing`
"""
function BSplineTerm(term, df,order=4)
BSplineTerm(term, df,order,[])
function BSplineTerm(term, df, order=4)
@assert df > 3 "Minimal degrees of freedom has to be 4"
BSplineTerm(term, df, order, [])
end

function BSplineTerm(term, df::ConstantTerm,order=4)
BSplineTerm(term, df.n,order,[])
function BSplineTerm(term, df::ConstantTerm, order=4)
BSplineTerm(term, df.n, order, [])
end


function PeriodicBSplineTerm(term, df,low,high)
PeriodicBSplineTerm(term, df,4,low,high)
function PeriodicBSplineTerm(term, df, low, high)
PeriodicBSplineTerm(term, df, 4, low, high)
end
function PeriodicBSplineTerm(term::AbstractTerm, df::ConstantTerm,order,low::ConstantTerm,high::ConstantTerm,breakvec)
PeriodicBSplineTerm(term, df.n,order,low.n,high.n,breakvec)
function PeriodicBSplineTerm(term::AbstractTerm, df::ConstantTerm, order, low::ConstantTerm, high::ConstantTerm, breakvec)
PeriodicBSplineTerm(term, df.n, order, low.n, high.n, breakvec)
end
function PeriodicBSplineTerm(term, df,order,low,high)
PeriodicBSplineTerm(term, df,order,low,high,[])
function PeriodicBSplineTerm(term, df, order, low, high)
PeriodicBSplineTerm(term, df, order, low, high, [])
end

Base.show(io::IO, p::BSplineTerm) = print(io, "spl($(p.term), $(p.df))")
Expand All @@ -145,7 +146,7 @@ function StatsModels.apply_schema(
sch::StatsModels.Schema,
Mod::Type{<:bsPLINE_CONTEXT},
)
ar = nothing
ar = nothing
try
ar = t.args
catch
Expand All @@ -158,7 +159,7 @@ function StatsModels.apply_schema(
sch::StatsModels.Schema,
Mod::Type{<:bsPLINE_CONTEXT},
)
@debug "BSpline Inner schema"
@debug "BSpline Inner schema"
term = apply_schema(t.term, sch, Mod)
isa(term, ContinuousTerm) ||
throw(ArgumentError("BSplineTerm only works with continuous terms (got $term)"))
Expand All @@ -168,34 +169,34 @@ function StatsModels.apply_schema(
# in case of ConstantTerm of Èffects.jl``
t.df.n
catch
throw(ArgumentError("BSplineTerm df must be a number (got $(t.df))"))
throw(ArgumentError("BSplineTerm df must be a number (got $(t.df))"))
end
end
return construct_spline(t,term)
end
construct_spline(t::BSplineTerm,term)=BSplineTerm(term, t.df,t.order)
construct_spline(t::PeriodicBSplineTerm,term)=PeriodicBSplineTerm(term, t.df,t.order,t.low,t.high)
return construct_spline(t, term)
end
construct_spline(t::BSplineTerm, term) = BSplineTerm(term, t.df, t.order)
construct_spline(t::PeriodicBSplineTerm, term) = PeriodicBSplineTerm(term, t.df, t.order, t.low, t.high)

function StatsModels.modelcols(p::AbstractSplineTerm, d::NamedTuple)

col = modelcols(p.term, d)

if isempty(p.breakpoints)
p.breakpoints = genSpl_breakpoints(p,col)
p.breakpoints = genSpl_breakpoints(p, col)
end

#basis = genSpl_basis(pp.breakpoints,p.order)#Splines2.bs_(col,df=p.df+1,intercept=true)

#X = Splines2.bs(col, df=p.df+1,intercept=true)
X = splFunction(col,p)
X = splFunction(col, p)

# remove middle X to negate intercept = true, generating a pseudo effect code
return X[:, Not(Int(ceil(end / 2)))]
end

#StatsModels.terms(p::BSplineTerm) = terms(p.term)
StatsModels.termvars(p::AbstractSplineTerm) = StatsModels.termvars(p.term)
StatsModels.width(p::AbstractSplineTerm) = p.df-1
StatsModels.width(p::AbstractSplineTerm) = p.df - 1
StatsModels.coefnames(p::BSplineTerm) =
"spl(" .* coefnames(p.term) .* "," .* string.(1:p.df-1) .* ")"
StatsModels.coefnames(p::PeriodicBSplineTerm) =
Expand Down
57 changes: 31 additions & 26 deletions test/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ data, evts = loadtestdata("test_case_3a") #
f_spl = @formula 0 ~ 1 + conditionA + spl(continuousA, 4) # 1
f = @formula 0 ~ 1 + conditionA + continuousA # 1
data_r = reshape(data, (1, :))
data_e, times = Unfold.epoch(data = data_r, tbl = evts, τ = (-1.0, 1.0), sfreq = 10) # cut the data into epochs
data_e, times = Unfold.epoch(data=data_r, tbl=evts, τ=(-1.0, 1.0), sfreq=10) # cut the data into epochs

m_mul = coeftable(fit(UnfoldModel, f, evts, data_e, times))
m_mul_spl = coeftable(fit(UnfoldModel, f_spl, evts, data_e, times))
Expand All @@ -18,23 +18,23 @@ s = Unfold.formula(fit(UnfoldModel, f_spl, evts, data_e, times)).rhs.terms[3]
@test s.df == 4

@testset "outside bounds" begin
# test safe prediction
m = fit(UnfoldModel, f_spl, evts, data_e, times)
r = predict(m,DataFrame(conditionA=[0,0],continuousA=[0.9,1.9]))
@test all(ismissing.(r.yhat[r.continuousA.==1.9]))
@test !any(ismissing.(r.yhat[r.continuousA.==0.9]))
# test safe prediction
m = fit(UnfoldModel, f_spl, evts, data_e, times)
r = predict(m, DataFrame(conditionA=[0, 0], continuousA=[0.9, 1.9]))
@test all(ismissing.(r.yhat[r.continuousA.==1.9]))
@test !any(ismissing.(r.yhat[r.continuousA.==0.9]))
end

basisfunction = firbasis = (-1, 1), sfreq = 10, name = "A")
basisfunction = firbasis=(-1, 1), sfreq=10, name="A")
@testset "timeexpanded" begin
# test time expanded
m_tul = coeftable(fit(UnfoldModel, f, evts, data_r, basisfunction))
m_tul_spl = coeftable(fit(UnfoldModel, f_spl, evts, data_r, basisfunction))
# test time expanded
m_tul = coeftable(fit(UnfoldModel, f, evts, data_r, basisfunction))
m_tul_spl = coeftable(fit(UnfoldModel, f_spl, evts, data_r, basisfunction))
end
@testset "safe prediction outside bounds" begin
# test safe predict
m = fit(UnfoldModel, f_spl, evts, data_r, basisfunction)
@test_broken predict(m,DataFrame(conditionA=[0,0,0],continuousA=[0.9,0.9,1.9]))
# test safe predict
m = fit(UnfoldModel, f_spl, evts, data_r, basisfunction)
@test_broken predict(m, DataFrame(conditionA=[0, 0, 0], continuousA=[0.9, 0.9, 1.9]))
end
#@test_broken all(ismissing.)

Expand All @@ -46,27 +46,32 @@ if 1 == 0
using AlgebraOfGraphics
yhat_mul.conditionA = categorical(yhat_mul.conditionA)
yhat_mul.continuousA = categorical(yhat_mul.continuousA)
m = mapping(:times, :yhat, color = :continuousA, linestyle = :conditionA)
m = mapping(:times, :yhat, color=:continuousA, linestyle=:conditionA)
df = yhat_mul
AlgebraOfGraphics.data(df) * visual(Lines) * m |> draw
end
@testset "many splines" begin
# test much higher number of splines
f_spl_many = @formula 0 ~ 1 + spl(continuousA, 131) # 1
m_mul_spl_many = coeftable(fit(UnfoldModel, f_spl_many, evts, data_e, times))
@test length(unique(m_mul_spl_many.coefname)) == 131
# test much higher number of splines
f_spl_many = @formula 0 ~ 1 + spl(continuousA, 131) # 1
m_mul_spl_many = coeftable(fit(UnfoldModel, f_spl_many, evts, data_e, times))
@test length(unique(m_mul_spl_many.coefname)) == 131
end

@testset "PeriodicSplines" begin
f_circspl = @formula 0 ~ 1 + circspl(continuousA, 10,-1,1) # 1
f_circspl = @formula 0 ~ 1 + circspl(continuousA, 10, -1, 1) # 1
m = fit(UnfoldModel, f_circspl, evts, data_e, times)
f_evaluated = Unfold.formula(m)
effValues = [-1,-0.99,0,0.99,1]
effValues = range(-1.1,1.1,step=0.1)

effValues = [-1, -0.99, 0, 0.99, 1]
effValues = range(-1.1, 1.1, step=0.1)
effSingle = effects(Dict(:continuousA => effValues), m)
tmp = subset(effSingle,:time =>x->x.== -1.0)
@test tmp.yhat[tmp.continuousA .== -1.1] tmp.yhat[tmp.continuousA .== 0.9]
@test tmp.yhat[tmp.continuousA .== -1.0] tmp.yhat[tmp.continuousA .== 1]
@test tmp.yhat[tmp.continuousA .== -0.9] tmp.yhat[tmp.continuousA .== 1.1]
tmp = subset(effSingle, :time => x -> x .== -1.0)
@test tmp.yhat[tmp.continuousA.==-1.1] tmp.yhat[tmp.continuousA.==0.9]
@test tmp.yhat[tmp.continuousA.==-1.0] tmp.yhat[tmp.continuousA.==1]
@test tmp.yhat[tmp.continuousA.==-0.9] tmp.yhat[tmp.continuousA.==1.1]
end

@testset "minimal number of splines" begin
f_spl = @formula 0 ~ 1 + conditionA + spl(continuousA, 3) # 1
@test_throws AssertionError fit(UnfoldModel, f_spl, evts, data_e, times)
end

0 comments on commit 4c9accc

Please # to comment.