Review notes (free text placed before the first "diff --git" header):
- Line structure was restored from a flattened copy of this patch. Blank context
  lines could not be recovered; hunk headers are kept exactly as recorded.
- Only added ("+") lines were edited, and the number of added lines per hunk is
  unchanged. Removed and context lines are reproduced as found.
- clusterdepth: `permfun` / `sidefun` are now real keywords, a missing permutation
  function raises an ArgumentError, and `sidefun` is applied to the observed
  statistic as well as to every permutation.
- studentt_test! / std!: the standard deviation is no longer square-rooted twice
  and the input data is no longer overwritten.
diff --git a/src/cluster.jl b/src/cluster.jl
index 59e0709..58c8702 100755
--- a/src/cluster.jl
+++ b/src/cluster.jl
@@ -1,44 +1,79 @@
 """
-clusterdepth(rng,data::AbstractArray;τ=2.3, statFun=x->abs.(studentt(x)),permFun=sign_permute!,nperm=5000,pval_type=:troendle)
+clusterdepth(rng,data::AbstractArray;τ=2.3,stat_type=:onesample_ttest,perm_type=:sign,side_type=:abs,nperm=5000,pval_type=:troendle,statfun! = nothing,statfun=nothing,permfun=nothing,sidefun=nothing)
 calculate clusterdepth of given datamatrix.
-- `data`: `statFun` will be applied on second dimension of data (typically this will be subjects)
+- `data`: `statfun` will be applied on the last dimension of data (typically this will be subjects)
 Optional
 - `τ`: Cluster-forming threshold
-- `statFun`: default the one-sample `studenttest`, can be any custom function on a Matrix returning a Vector
-- `permFun`: default to sign-flip (for one-sample case)
 - `nperm`: number of permutations, default 5000
+- `stat_type`: default `:onesample_ttest`, the one-sample t-test. For any other value the custom functions passed as `statfun!` / `statfun` are used
+- `side_type`: default `:abs` - what function should be applied after the `statfun`? could be `:abs`, `:square`, `:positive` to test positive clusters, `:negative` to test negative clusters. Pass `side_type=nothing` to use a custom function, see `sidefun`
+- `perm_type`: default `:sign` for one-sample data (e.g. differences), performs sign flips. For any other value a custom function has to be provided, see `permfun`
 - `pval_type`: how to calculate pvalues within each cluster, default `:troendle`, see `?pvals`
+- `statfun` / `statfun!`: a function that takes either one or two arguments and aggregates over the last dimension. In the two-argument case the first argument is modified in place; a suitable Vector/Matrix is provided.
+- `sidefun`: function applied to each element of the output of `statfun`, for the permutations and for the observed data alike. Only used when `side_type=nothing`.
+- `permfun`: function to permute the data, should accept an RNG-object and the data. Can be in place: the data is copied once, but the same array is shared between permutations
 """
 clusterdepth(data::AbstractArray,args...;kwargs...) = clusterdepth(MersenneTwister(1),data,args...;kwargs...)
-function clusterdepth(rng,data::AbstractArray;τ=2.3, statFun=x->abs.(studentt(x)),permFun=sign_permute!,nperm=5000,pval_type=:troendle)
-    cdmTuple = perm_clusterdepths_both(rng,data,statFun,permFun,τ;nₚ=nperm)
-    return pvals(statFun(data),cdmTuple,τ;type=pval_type)
+function clusterdepth(rng,data::AbstractArray;τ=2.3,stat_type=:onesample_ttest,perm_type=:sign,side_type=:abs,nperm=5000,pval_type=:troendle,statfun! = nothing,statfun=nothing,permfun=nothing,sidefun=nothing)
+    if stat_type == :onesample_ttest
+        statfun! = studentt!
+        statfun = studentt
+    end
+    if perm_type == :sign
+        permfun = sign_permute!
+    end
+    if side_type == :abs
+        sidefun = abs
+    elseif side_type == :square
+        sidefun = x->x^2
+    elseif side_type == :negative
+        sidefun = x->-x
+    elseif side_type == :positive
+        sidefun = nothing # the default :)
+    else
+        @assert isnothing(side_type) "unknown side_type ($side_type) specified. Check your spelling and ?clusterdepth"
+    end
+    isnothing(permfun) && throw(ArgumentError("perm_type ($perm_type) is not :sign and no permfun was provided. Check your spelling and ?clusterdepth"))
+    cdmTuple = perm_clusterdepths_both(rng,data,permfun,τ;nₚ=nperm,statfun! = statfun!,statfun=statfun,sidefun = sidefun)
+    obs = isnothing(sidefun) ? statfun(data) : sidefun.(statfun(data)) # same side transform as applied to every permutation
+    return pvals(obs,cdmTuple,τ;type=pval_type)
 end
-function perm_clusterdepths_both(rng,data,statFun,permFun,τ;nₚ=1000)
+function perm_clusterdepths_both(rng,data,permfun,τ;statfun = nothing,statfun! = nothing,nₚ=1000,sidefun=nothing)
+    @assert !(isnothing(statfun) && isnothing(statfun!)) "either statfun or statfun! has to be defined"
-    #Jₖ_head = ExtendableSparseMatrix(size(data,2),nₚ)
-    #Jₖ_tail = ExtendableSparseMatrix(size(data,2),nₚ)
     data_perm = deepcopy(data)
-    rows_h = Int[]
-    cols_h = Int[]
-    vals_h = Float64[]
-    rows_t = Int[]
-    cols_t = Int[]
-    vals_t = Float64[]
+    rows_h = Int[]; cols_h = Int[]; vals_h = Float64[]
+    rows_t = Int[]; cols_t = Int[]; vals_t = Float64[]
+
+    if ndims(data_perm) == 2
+        d0 = Array{Float64}(undef,size(data_perm,1))
+    else
+        d0 = Array{Float64}(undef,size(data_perm)[[1,2]])
+    end
+    #@debug size(d0)
+    #@debug size(data_perm)
     for i = 1:nₚ
         # permute
-        d0 = permFun(rng,data_perm,statFun)
-
+        d_perm = permfun(rng,data_perm)
+        if isnothing(statfun!)
+            d0 = statfun(d_perm)
+        else
+            # inplace!
+            statfun!(d0,d_perm)
+        end
+        if !isnothing(sidefun)
+            d0 .= sidefun.(d0)
+        end
         # get clusterdepth
         (fromTo,head,tail) = calc_clusterdepth(d0,τ)
diff --git a/src/utils.jl b/src/utils.jl
index 262a0ed..4455c3a 100755
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,21 +1,73 @@
+function studentt!(out::AbstractMatrix,x::AbstractArray{<:Real,3};kwargs...)
+
+    for (x_ch,o_ch) = zip(eachslice(x,dims=1),eachslice(out,dims=1))
+        #@debug size(x_ch),size(o_ch)
+        studentt!(o_ch,x_ch;kwargs...)
+    end
+    return out
+end
+
+
+"""
+    studentt_test!(out,x)
-studentt(x::AbstractMatrix) = (mean(x,dims=2)[:,1])./(std(x,dims=2)[:,1]./sqrt(size(x,2)-1))
+strongly optimized one-sample t-test function.
-studentt(x::AbstractArray{<:Real,3}) = return dropdims(mapslices(studentt,x,dims=(2,3)),dims=3)
+
+Implements: t = mean(x) / (sqrt(var(x)) / sqrt(size(x,2)-1))
+
+Accepts a 2D matrix and aggregates over its second dimension; `out` is overwritten, `x` is left untouched
+
+"""
+function studentt_test!(out,x::AbstractMatrix)
+    mean!(out,x)
+    df = 1 ./ sqrt(size(x,2)-1)
+    # `tmp` is a one-element scratch buffer, reused for the standard deviation of every row
+    tmp = [1.]
+    for k = eachindex(out)
+        std!(tmp,@view(x[k,:]),out[k])
+        out[k] /= (tmp[1] * df) # std! already returns a standard deviation, no second sqrt
+    end
+    return out
+end
+
+function std!(tmp,x_slice,μ)
+    # Corrected sample standard deviation of `x_slice` around the given mean `μ`, written to `tmp[1]`.
+    # Squared deviations are summed on the fly, so the caller's data is not overwritten.
+    tmp[1] = sum(v -> abs2(v - μ), x_slice)
+    tmp .= sqrt.(tmp ./ (length(x_slice)-1))
+    return tmp
+
+end
+
+
+function studentt!(out,x)
+    # in-place one-sample t-statistic: aggregates `x` over dims=2 and writes the result into `out`
+    mean!(out,x)
+    out .= out ./ (std(x,mean=out, dims=2)[:,1]./sqrt(size(x,2)-1))
+end
+function studentt(x::AbstractMatrix)
+    # more efficient than this one liner
+    # studentt(x::AbstractMatrix) = (mean(x,dims=2)[:,1])./(std(x,dims=2)[:,1]./sqrt(size(x,2)-1))
+    μ = mean(x,dims=2)[:,1]
+    μ .= μ ./ (std(x,mean=μ, dims=2)[:,1]./sqrt(size(x,2)-1))
+end
+
+studentt(x::AbstractArray{<:Real,3}) = dropdims(mapslices(studentt,x,dims=(2,3)),dims=3)
 """
 Permutation via random sign-flip
 Flips signs along the last dimension
 """
-function sign_permute!(rng,x::AbstractArray,fun)
+function sign_permute!(rng,x::AbstractArray)
     n = ndims(x)
     @assert n > 1 "vectors cannot be permuted"
     fl = rand(rng,[-1,1],size(x,n))
-    #flipped = map((x,y)->x.*y,eachslice(x;dims=n),fl)
-    for (f,k) = zip(fl,eachslice(x;dims=n))
-        k .*= f
+
+    for (flip,xslice) = zip(fl,eachslice(x;dims=n))
+        xslice .= xslice .* flip
     end
-    return fun(x)
+    return x
 end
diff --git a/test/utils.jl b/test/utils.jl
index 9bc5c1a..d610a66 100755
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -1,30 +1,23 @@
 @testset "sign_permute" begin
     m = [1 1 1;2 2 2;3 3 3;4 4 4]
-    p = ClusterDepth.sign_permute!(StableRNG(2),deepcopy(m),x->x)
+    p = ClusterDepth.sign_permute!(StableRNG(2),deepcopy(m))
     @test p[1,:] == [1, -1, 1]
     # different seeds are different
-    @test p!= ClusterDepth.sign_permute!(StableRNG(3),deepcopy(m),x->x)
+    @test p!= ClusterDepth.sign_permute!(StableRNG(3),deepcopy(m))
     # same seeds are the same
-    @test p == ClusterDepth.sign_permute!(StableRNG(2),deepcopy(m),x->x)
+    @test p == ClusterDepth.sign_permute!(StableRNG(2),deepcopy(m))
     m = ones(1,1000000)
-    @test abs.(ClusterDepth.sign_permute!(StableRNG(1),deepcopy(m),mean))<0.001
+    @test abs(mean(ClusterDepth.sign_permute!(StableRNG(1),deepcopy(m))))<0.001
     m = ones(1,2,3,4,5,6,7,100)
-    o = ClusterDepth.sign_permute!(StableRNG(1),deepcopy(m),x->x)
+    o = ClusterDepth.sign_permute!(StableRNG(1),deepcopy(m))
     @test sort(unique(mean(o,dims=1:ndims(o)-1))) == [-1.,1.]
-    #2D input data
-    data = randn(StableRNG(1),4,5)
-    @test size(ClusterDepth.sign_permute!(StableRNG(1),data,ClusterDepth.studentt)) == (4,)
-    #3D input data
-    data = randn(StableRNG(1),3,4,5);
-    @test size(ClusterDepth.sign_permute!(StableRNG(1),data,ClusterDepth.studentt)) == (3,4)
-
 end
 @testset "studentt" begin
@@ -41,4 +34,22 @@ end
     #3D input data
     data = randn(StableRNG(1),3,4,5);
     @test size(ClusterDepth.studentt(data)) == (3,4)
+
+    # the in-place variant must agree with the allocating one
+    t = rand(10000)
+    ClusterDepth.studentt!(t,x)
+    @test t ≈ ClusterDepth.studentt(x)
+    @test length(t) == 10000
+    @test maximum(abs.(t))<10 # we'd need to be super lucky ;)
+    @test mean(abs.(t).>2) < 0.06
+
+    #2D input data
+    data = randn(StableRNG(1),4,5)
+    t = rand(4)
+    ClusterDepth.studentt!(t,data)
+    @test size(t) == (4,)
+
+    #3D input data, in-place variant
+    data = randn(StableRNG(1),3,4,5);
+    @test size(ClusterDepth.studentt!(zeros(3,4),data)) == (3,4)
 end