Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

refactored - breaking - improved ttest performance, inplace ttest #9

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions src/cluster.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,79 @@
"""

clusterdepth(rng,data::AbstractArray;τ=2.3, statFun=x->abs.(studentt(x)),permFun=sign_permute!,nperm=5000,pval_type=:troendle)
clusterdepth(rng,data::AbstractArray;τ=2.3, statfun=x->abs.(studentt(x)),permfun=sign_permute!,nperm=5000,pval_type=:troendle)

calculate clusterdepth of given datamatrix.


- `data`: `statFun` will be applied on second dimension of data (typically this will be subjects)
- `data`: `statfun` will be applied on second dimension of data (typically this will be subjects)

Optional
- `τ`: Cluster-forming threshold
- `statFun`: default the one-sample `studenttest`, can be any custom function on a Matrix returning a Vector
- `permFun`: default to sign-flip (for one-sample case)
- `nperm`: number of permutations, default 5000
- `stat_type`: default the one-sample `t-test`, custom function can be specified (see `statfun!` and `statfun`)
- `side_type`: default: `:abs` - what function should be applied after the `statfun`? could be `:abs`, `:square`, `:positive` to test positive clusters, `:negative` to test negative clusters. Custom function can be provided, see `sidefun`
- `perm_type`: default `:sign` for one-sample data (e.g. differences), performs sign flips. custom function can be provided, see `permfun`
- `pval_type`: how to calculate pvalues within each cluster, default `:troendle`, see `?pvals`
- `statfun` / `statfun!` a function that either takes one or two arguments and aggregates over last dimension. in the two argument case we expect the first argument to be modified inplace and provide a suitable Vector/Matrix.
- `sidefun`: default `abs`. Provide a function to be applied on each element of the output of `statfun`.
- `permfun` function to permute the data, should accept an RNG-object and the data. can be inplace, the data is copied, but the same array is shared between permutations

"""
function clusterdepth(data::AbstractArray,args...;kwargs...)
    # no RNG was supplied: use a MersenneTwister with the fixed seed 1 and
    # forward every other argument untouched to the RNG-taking method
    default_rng = MersenneTwister(1)
    return clusterdepth(default_rng,data,args...;kwargs...)
end
function clusterdepth(rng,data::AbstractArray;τ=2.3, statFun=x->abs.(studentt(x)),permFun=sign_permute!,nperm=5000,pval_type=:troendle)
cdmTuple = perm_clusterdepths_both(rng,data,statFun,permFun,τ;nₚ=nperm)
return pvals(statFun(data),cdmTuple,τ;type=pval_type)
# Cluster-depth p-values for `data`, with permutations driven by `rng`.
#
# Keywords
# - `τ`: cluster-forming threshold, forwarded to `perm_clusterdepths_both` and `pvals`.
# - `stat_type`: built-in statistic used when neither `statfun` nor `statfun!` is given;
#   only `:onesample_ttest` (`studentt` / `studentt!`) is known.
# - `perm_type`: built-in permutation used when `permfun` is not given; only `:sign`
#   (`sign_permute!`) is known.
# - `side_type`: built-in element-wise transform used when `sidefun` is not given:
#   `:abs`, `:square`, `:negative`, `:positive` or `nothing` (the last two apply nothing).
# - `nperm`: number of permutations (passed on as `nₚ`).
# - `pval_type`: forwarded to `pvals` as `type`.
# - `statfun` / `statfun!`, `permfun`, `sidefun`: custom functions; when supplied they
#   take precedence over the corresponding `*_type` default.
#
# Throws `ArgumentError` for an unknown `stat_type` / `perm_type` without a matching
# custom function, and `AssertionError` for an unknown `side_type`.
function clusterdepth(rng,data::AbstractArray;τ=2.3,stat_type=:onesample_ttest,perm_type=:sign,side_type=:abs,nperm=5000,pval_type=:troendle,statfun! = nothing,statfun=nothing,permfun=nothing,sidefun=nothing)
    # The built-in statistic is only a fallback, so that a custom `statfun`/`statfun!`
    # is not silently replaced while `stat_type` is left at its default.
    if isnothing(statfun) && isnothing(statfun!)
        if stat_type == :onesample_ttest
            statfun! = studentt!
            statfun = studentt
        else
            throw(ArgumentError("unknown stat_type ($stat_type) specified and neither `statfun` nor `statfun!` provided. Check your spelling and ?clusterdepth"))
        end
    end

    if isnothing(permfun)
        if perm_type == :sign
            permfun = sign_permute!
        else
            # previously this path ended in an UndefVarError further down
            throw(ArgumentError("unknown perm_type ($perm_type) specified and no `permfun` provided. Check your spelling and ?clusterdepth"))
        end
    end

    if isnothing(sidefun)
        if side_type == :abs
            sidefun = abs
        elseif side_type == :square
            sidefun = x->x^2
        elseif side_type == :negative
            sidefun = x->-x
        elseif side_type == :positive || isnothing(side_type)
            sidefun = nothing # the default :) - statistic is used as is
        else
            # thrown explicitly (instead of `@assert`) so the check can never be compiled out
            throw(AssertionError("unknown side_type ($side_type) specified. Check your spelling and ?clusterdepth"))
        end
    end

    cdmTuple = perm_clusterdepths_both(rng,data,permfun,τ;nₚ=nperm,statfun! = statfun!,statfun=statfun,sidefun = sidefun)

    # observed statistic; fall back to the in-place variant if that is all we have
    if !isnothing(statfun)
        stat_obs = statfun(data)
    else
        # output buffer spans all dimensions of `data` except the last (aggregated) one
        stat_obs = Array{Float64}(undef,size(data)[1:(ndims(data)-1)])
        statfun!(stat_obs,data)
    end
    # the observed statistic has to go through the same transform as the permuted
    # ones, otherwise it is compared against a null distribution of a different quantity
    if !isnothing(sidefun)
        stat_obs = sidefun.(stat_obs)
    end

    return pvals(stat_obs,cdmTuple,τ;type=pval_type)
end




function perm_clusterdepths_both(rng,data,statFun,permFun,τ;nₚ=1000)
function perm_clusterdepths_both(rng,data,permfun,τ;statfun = nothing,statfun! = nothing,nₚ=1000,sidefun=nothing)
@assert !(isnothing(statfun) && isnothing(statfun!)) "either statfun or statfun! has to be defined"

#Jₖ_head = ExtendableSparseMatrix(size(data,2),nₚ)
#Jₖ_tail = ExtendableSparseMatrix(size(data,2),nₚ)
data_perm = deepcopy(data)
rows_h = Int[]
cols_h = Int[]
vals_h = Float64[]
rows_t = Int[]
cols_t = Int[]
vals_t = Float64[]
rows_h = Int[]; cols_h = Int[]; vals_h = Float64[]
rows_t = Int[]; cols_t = Int[]; vals_t = Float64[]

if ndims(data_perm) == 2
d0 = Array{Float64}(undef,size(data_perm,1))
else
d0 = Array{Float64}(undef,size(data_perm)[[1,2]])
end
#@debug size(d0)
#@debug size(data_perm)
for i = 1:nₚ
# permute
d0 = permFun(rng,data_perm,statFun)

d_perm = permfun(rng,data_perm)
if isnothing(statfun!)
d0 = statfun(d_perm)
else
# inplace!
statfun!(d0,d_perm)
end
if !isnothing(sidefun)
d0 .= sidefun.(d0)
end
# get clusterdepth
(fromTo,head,tail) = calc_clusterdepth(d0,τ)

66 changes: 59 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,73 @@
"""
    studentt!(out::AbstractMatrix, x::AbstractArray{<:Real,3}; kwargs...)

Fill `out` in place, one slice along the first dimension at a time: each slice of
`out` receives the result of `studentt!` applied to the matching slice of `x`.
`kwargs` are forwarded unchanged. Returns `out`.
"""
function studentt!(out::AbstractMatrix, x::AbstractArray{<:Real,3}; kwargs...)
    # `eachslice` yields views, so the inner call writes straight into `out`;
    # pairing stops at the shorter of the two slice sequences
    foreach(eachslice(out; dims=1), eachslice(x; dims=1)) do o_ch, x_ch
        studentt!(o_ch, x_ch; kwargs...)
    end
    return out
end


"""
studentt_test!(out,x)

studentt(x::AbstractMatrix) = (mean(x,dims=2)[:,1])./(std(x,dims=2)[:,1]./sqrt(size(x,2)-1))
strongly optimized one-sample t-test function.

studentt(x::AbstractArray{<:Real,3}) = return dropdims(mapslices(studentt,x,dims=(2,3)),dims=3)

Implements: t = mean(x) / (sqrt(var(x)) / sqrt(size(x,2)-1))

Accepts 2D or 3D matrices, always aggregates over the last dimension

"""
function studentt_test!(out,x::AbstractMatrix)
    # In-place one-sample t-test over the second dimension of `x`:
    #   out[k] = mean(x[k,:]) / (std(x[k,:]) / sqrt(size(x,2)-1))
    # `out` needs one entry per row of `x`; it is overwritten and returned.
    # `x` is only read, never modified.
    mean!(out,x)
    n = size(x,2)
    scale = sqrt(n-1)
    for k in eachindex(out)
        μ = out[k]
        # sum of squared deviations without a temporary array and without
        # writing into `x`
        ss = sum(v -> (v - μ)^2, @view(x[k,:]))
        σ = sqrt(ss / (n-1)) # standard deviation - the square root is taken exactly once
        out[k] = μ / (σ / scale)
    end
    return out
end

"""
    std!(tmp, x_slice, μ)

Store the standard deviation of `x_slice` around the given mean `μ`
(normalised by `length(x_slice) - 1`) in `tmp` and return `tmp`.

`tmp` is expected to be a one-element buffer. `x_slice` is only read, so it is safe
to pass a view into data that is still needed afterwards.
"""
function std!(tmp,x_slice,μ)
    # accumulate the squared deviations with a reduction instead of overwriting
    # `x_slice` with them
    ss = sum(v -> (v - μ)^2, x_slice)
    tmp .= sqrt(ss / (length(x_slice)-1))
    return tmp
end


"""
    studentt!(out, x)

In-place one-sample t-value over the second dimension of `x`: `out` is first filled
with the means via `mean!` and then divided by `std ./ sqrt(size(x, 2) - 1)`.
Returns `out`.
"""
function studentt!(out, x)
    mean!(out, x)
    # the freshly computed means are handed to `std` via its `mean` keyword
    spread = std(x; mean=out, dims=2)[:, 1]
    out ./= spread ./ sqrt(size(x, 2) - 1)
    return out
end
"""
    studentt(x::AbstractMatrix)

One-sample t-value for every row of `x`, aggregating over the second dimension:
`mean ./ (std ./ sqrt(size(x, 2) - 1))`. Returns a vector with one entry per row.
"""
function studentt(x::AbstractMatrix)
    # the row means are computed once, passed to `std` through its `mean` keyword,
    # and the same vector is then reused to hold the result
    tvals = mean(x; dims=2)[:, 1]
    rowstd = std(x; mean=tvals, dims=2)[:, 1]
    tvals ./= rowstd ./ sqrt(size(x, 2) - 1)
    return tvals
end

"""
    studentt(x::AbstractArray{<:Real,3})

Apply `studentt` to every slice of `x` spanning dimensions 2 and 3 (via `mapslices`)
and drop the resulting singleton third dimension.
"""
function studentt(x::AbstractArray{<:Real,3})
    per_slice = mapslices(studentt, x; dims=(2, 3))
    return dropdims(per_slice; dims=3)
end

"""
Permutation via random sign-flip
Flips signs along the last dimension
"""
function sign_permute!(rng,x::AbstractArray,fun)
function sign_permute!(rng,x::AbstractArray)
n = ndims(x)
@assert n > 1 "vectors cannot be permuted"

fl = rand(rng,[-1,1],size(x,n))
#flipped = map((x,y)->x.*y,eachslice(x;dims=n),fl)
for (f,k) = zip(fl,eachslice(x;dims=n))
k .*= f

for (flip,xslice) = zip(fl,eachslice(x;dims=n))
xslice .= xslice .* flip
end

return fun(x)
return x
end
35 changes: 23 additions & 12 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,23 @@
@testset "sign_permute" begin

m = [1 1 1;2 2 2;3 3 3;4 4 4]
p = ClusterDepth.sign_permute!(StableRNG(2),deepcopy(m),x->x)
p = ClusterDepth.sign_permute!(StableRNG(2),deepcopy(m))

@test p[1,:] == [1, -1, 1]

# different seeds are different
@test p!= ClusterDepth.sign_permute!(StableRNG(3),deepcopy(m),x->x)
@test p!= ClusterDepth.sign_permute!(StableRNG(3),deepcopy(m))
# same seeds are the same
@test p == ClusterDepth.sign_permute!(StableRNG(2),deepcopy(m),x->x)
@test p == ClusterDepth.sign_permute!(StableRNG(2),deepcopy(m))

m = ones(1,1000000)
@test abs.(ClusterDepth.sign_permute!(StableRNG(1),deepcopy(m),mean))<0.001
@test abs(mean(ClusterDepth.sign_permute!(StableRNG(1),deepcopy(m))))<0.001

m = ones(1,2,3,4,5,6,7,100)
o = ClusterDepth.sign_permute!(StableRNG(1),deepcopy(m),x->x)
o = ClusterDepth.sign_permute!(StableRNG(1),deepcopy(m))
@test sort(unique(mean(o,dims=1:ndims(o)-1))) == [-1.,1.]

#2D input data
data = randn(StableRNG(1),4,5)
@test size(ClusterDepth.sign_permute!(StableRNG(1),data,ClusterDepth.studentt)) == (4,)

#3D input data
data = randn(StableRNG(1),3,4,5);
@test size(ClusterDepth.sign_permute!(StableRNG(1),data,ClusterDepth.studentt)) == (3,4)

end

@testset "studentt" begin
@@ -41,4 +34,22 @@ end
#3D input data
data = randn(StableRNG(1),3,4,5);
@test size(ClusterDepth.studentt(data)) == (3,4)

#
t = rand(10000)
ClusterDepth.studentt!(t,x)
@test t ≈ ClusterDepth.studentt(x)
@test length(t) == 10000
@test maximum(abs.(t))<10 # we'd need to be super lucky ;)
@test mean(abs.(t).>2) < 0.06

#2D input data
data = randn(StableRNG(1),4,5)
t = rand(4)
ClusterDepth.studentt!(t,data)
@test size(t) == (4,)

#3D input data
data = randn(StableRNG(1),3,4,5);
@test size(ClusterDepth.studentt(data)) == (3,4)
end