Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Refactor whitening for closer integration with StatsBase types #144

Merged
merged 8 commits into from
Jun 1, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added dims parameter
  • Loading branch information
wildart committed Jun 1, 2021
commit 9236dc5281aedae694a1848cd1999e143962d0c1
4 changes: 2 additions & 2 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -22,8 +22,8 @@ decentralize(x::AbstractMatrix, m::AbstractVector) = (isempty(m) ? x : x .+ m)

"""
    fullmean(d, mv)

Return `mv` itself when it is non-empty; otherwise return a zero vector of length `d`
whose element type is that of `mv`.
"""
function fullmean(d::Int, mv::AbstractVector{T}) where T
    if isempty(mv)
        return zeros(T, d)
    end
    return mv
end

"""
    preprocess_mean(X, m; dims=2)

Resolve the user-supplied mean option `m` into the vector used for centering `X`.

- `m === nothing`: compute the mean of `X` along dimension `dims` and flatten it
  to a vector.
- `m == 0`: return an empty `T[]`, the marker for "no mean to subtract".
- otherwise: return `m` unchanged.

`dims` is the dimension that is averaged away and is only used when `m === nothing`.
"""
function preprocess_mean(X::AbstractMatrix{T}, m; dims=2) where T<:Real
    # A single definition: a second method with the same positional signature
    # would overwrite the first one.
    m === nothing && return vec(mean(X, dims=dims))
    return m == 0 ? T[] : m
end

# choose the first k values and columns
#
42 changes: 36 additions & 6 deletions src/whiten.jl
Original file line number Diff line number Diff line change
@@ -80,18 +80,36 @@ mean(f::Whitening) = fullmean(indim(f), f.mean)

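Apply the whitening transform `f` to a vector or a matrix `x`, as ``\\mathbf{W}^T (\\mathbf{x} - \\boldsymbol{\\mu})``. A matrix may hold its samples in columns or in rows; the orientation is inferred by matching the size of `x` against the dimension of the model.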
"""
function transform(f::Whitening, x::AbstractVecOrMat{<:Real})
    # An empty `f.mean` marks data that was declared already centered
    # (`preprocess_mean` yields `T[]` for `mean=0`). In that case there is nothing
    # to subtract, and the feature dimension is taken from `W` instead of the mean.
    nomean = isempty(f.mean)
    d = nomean ? size(f.W, 1) : length(f.mean)

    # Vector input: a single sample.
    if ndims(x) == 1
        length(x) == d || throw(DimensionMismatch("Inconsistent dimensions."))
        z = nomean ? x : x - f.mean
        return transpose(f.W) * z
    end

    # Matrix input: infer the sample orientation from whichever axis has length `d`.
    # NOTE: a square d-by-d input is always read as samples-in-columns, because the
    # row count is tested first.
    nrow, ncol = size(x)
    if nrow == d
        # samples in columns
        z = nomean ? x : x .- f.mean
        return transpose(f.W) * z
    elseif ncol == d
        # samples in rows
        z = nomean ? x : x .- transpose(f.mean)
        return z * f.W
    else
        throw(DimensionMismatch("Inconsistent dimensions."))
    end
end

"""
fit(::Type{Whitening}, X::AbstractMatrix{T}; kwargs...)

Estimate a whitening transform from the data given in `X`. Here, `X` should be a matrix, whose columns give the samples.
Estimate a whitening transform from the data given in `X`.

This function returns an instance of [`Whitening`](@ref)

**Keyword Arguments:**
- `regcoef`: The regularization coefficient. When `regcoef` is positive, the covariance is regularized as follows: `C + (eigmax(C) * regcoef) * eye(d)`. The default value is `zero(T)`.

- `dims`: if `1`, the transformation is calculated from samples stored in the rows of `X`;
if `2`, the transformation is calculated from samples stored in the columns of `X`. The default is `nothing`, which is equivalent to `dims=2` with a deprecation warning.

- `mean`: The mean vector, which can be either of:
- `0`: the input data has already been centralized
- `nothing`: this function will compute the mean (**default**)
@@ -100,11 +118,23 @@ This function returns an instance of [`Whitening`](@ref)
**Note:** This function internally relies on [`cov_whitening`](@ref) to derive the transformation `W`.
"""
function fit(::Type{Whitening}, X::AbstractMatrix{T};
             dims::Union{Integer,Nothing}=nothing,
             mean=nothing, regcoef::Real=zero(T)) where {T<:Real}
    # Resolve `dims` before touching the data: the sample count, the mean and the
    # centering all depend on which axis holds the samples.
    if dims === nothing
        Base.depwarn(
            "fit(Whitening, x) is deprecated: use fit(Whitening, x, dims=2) instead", :fit
        )
        dims = 2
    end

    # `n` is the number of samples along the chosen axis; at least two are needed
    # for the `n - 1` normalization below.
    if dims == 1
        n = size(X, 1)
        n >= 2 || error("X must contain at least two rows.")
    elseif dims == 2
        n = size(X, 2)
        n >= 2 || error("X must contain at least two columns.")
    else
        throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
    end

    mv = preprocess_mean(X, mean; dims=dims)
    # Bring the data into samples-in-columns layout before centering, so the
    # covariance below is always formed as Z * Z'.
    Z = centralize((dims == 1 ? transpose(X) : X), mv)
    C = rmul!(Z * transpose(Z), one(T) / (n - 1))
    return Whitening(mv, cov_whitening!(C, regcoef))
end
15 changes: 15 additions & 0 deletions test/whiten.jl
Original file line number Diff line number Diff line change
@@ -101,4 +101,19 @@ import Random
SM = fit(Whitening, SX; mean=sprand(Float32, 3, 0.75))
Y = transform(SM, SX)
@test eltype(Y) == Float32

# different dimensions
@test_throws DomainError fit(Whitening, X'; dims=3)
M1 = fit(Whitening, X'; dims=1)
M2 = fit(Whitening, X; dims=2)
@test M1.W == M2.W
@test_throws DimensionMismatch transform(M1, rand(6,4))
@test_throws DimensionMismatch transform(M2, rand(4,6))
Y1 = transform(M1,X')
Y2 = transform(M2,X)
@test Y1' == Y2
@test_throws DimensionMismatch transform(M1, rand(7))
V1 = transform(M1,X[:,1])
V2 = transform(M2,X[:,1])
@test V1 == V2
end