Skip to content

Commit

Permalink
Allow strings and numbers as labels in cut (#393)
Browse files Browse the repository at this point in the history
  • Loading branch information
skleinbo authored May 23, 2022
1 parent 7c686fb commit a8e4787
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 19 deletions.
43 changes: 28 additions & 15 deletions src/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ default_formatter(from, to, i; leftclosed, rightclosed) =

@doc raw"""
cut(x::AbstractArray, breaks::AbstractVector;
labels::Union{AbstractVector{<:AbstractString},Function},
labels::Union{AbstractVector,Function},
extend::Union{Bool,Missing}=false, allowempty::Bool=false)
Cut a numeric array into intervals at values `breaks`
Expand All @@ -52,7 +52,8 @@ also accept them.
in `x` fall outside of the breaks; when `true`, breaks are automatically added to include
all values in `x`, and the upper bound is included in the last interval; when `missing`,
values outside of the breaks generate `missing` entries.
* `labels::Union{AbstractVector,Function}`: a vector of strings giving the names to use for
* `labels::Union{AbstractVector, Function}`: a vector of strings, characters
or numbers giving the names to use for
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
the labels from the left and right interval boundaries and the group index. Defaults to
`"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`).
Expand Down Expand Up @@ -87,7 +88,15 @@ julia> cut(-1:0.5:1, 2, labels=["A", "B"])
"A"
"B"
"B"
"B"
"B"
julia> cut(-1:0.5:1, 2, labels=[-0.5, +0.5])
5-element CategoricalArray{Float64,1,UInt32}:
-0.5
-0.5
0.5
0.5
0.5
julia> fmt(from, to, i; leftclosed, rightclosed) = "grp $i ($from//$to)"
fmt (generic function with 1 method)
Expand All @@ -98,12 +107,12 @@ julia> cut(-1:0.5:1, 3, labels=fmt)
"grp 1 (-1.0//-0.3333333333333335)"
"grp 2 (-0.3333333333333335//0.33333333333333326)"
"grp 3 (0.33333333333333326//1.0)"
"grp 3 (0.33333333333333326//1.0)"
"grp 3 (0.33333333333333326//1.0)"
```
"""
@inline function cut(x::AbstractArray, breaks::AbstractVector;
extend::Union{Bool, Missing}=false,
labels::Union{AbstractVector{<:AbstractString},Function}=default_formatter,
labels::Union{AbstractVector{<:SupportedTypes},Function}=default_formatter,
allowmissing::Union{Bool, Nothing}=nothing,
allow_missing::Union{Bool, Nothing}=nothing,
allowempty::Bool=false)
Expand All @@ -123,7 +132,7 @@ end
# Separate function for inferability (thanks to inlining of cut)
function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
extend::Union{Bool, Missing},
labels::Union{AbstractVector{<:AbstractString},Function},
labels::Union{AbstractVector{<:SupportedTypes},Function},
allowempty::Bool=false) where {T, N}
if !allowempty && !allunique(breaks)
throw(ArgumentError("all breaks must be unique unless `allowempty=true`"))
Expand Down Expand Up @@ -152,10 +161,11 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
end
end
if !ismissing(min_x) && breaks[1] > min_x
breaks = [min_x; breaks]
# this type annotation is needed on Julia<1.7 for stable inference
breaks = [min_x::nonmissingtype(eltype(x)); breaks]
end
if !ismissing(max_x) && breaks[end] < max_x
breaks = [breaks; max_x]
breaks = [breaks; max_x::nonmissingtype(eltype(x))]
end
length(breaks) > 1 ||
throw(ArgumentError("could not extend breaks as all values are equal: " *
Expand All @@ -180,8 +190,11 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
if labels isa Function
from = breaks[1:n-1]
to = breaks[2:n]
levs = Vector{String}(undef, n-1)
for i in 1:n-2
firstlevel = labels(from[1], to[1], 1,
leftclosed=breaks[1] != breaks[2], rightclosed=false)
levs = Vector{typeof(firstlevel)}(undef, n-1)
levs[1] = firstlevel
for i in 2:n-2
levs[i] = labels(from[i], to[i], i,
leftclosed=breaks[i] != breaks[i+1], rightclosed=false)
end
Expand All @@ -191,8 +204,7 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
else
length(labels) == n-1 ||
throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))"))
# Levels must have element type String for type stability of the result
levs::Vector{String} = copy(labels)
levs = copy(labels)
end
if !allunique(levs)
if labels === default_formatter
Expand All @@ -204,7 +216,7 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector,
end

pool = CategoricalPool(levs, true)
S = T >: Missing || extend isa Missing ? Union{String, Missing} : String
S = T >: Missing || extend isa Missing ? Union{eltype(levs), Missing} : eltype(levs)
CategoricalArray{S, N}(refs, pool)
end

Expand All @@ -227,7 +239,8 @@ If `x` contains `missing` values, they are automatically skipped when computing
quantiles.
# Keyword arguments
* `labels::Union{AbstractVector,Function}`: a vector of strings giving the names to use for
* `labels::Union{AbstractVector, Function}`: a vector of strings, characters
or numbers giving the names to use for
the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates
the labels from the left and right interval boundaries and the group index. Defaults to
`"Qi: [from, to)"` (or `"Qi: [from, to]"` for the rightmost interval).
Expand All @@ -237,7 +250,7 @@ quantiles.
(but duplicate labels are not allowed).
"""
function cut(x::AbstractArray, ngroups::Integer;
labels::Union{AbstractVector{<:AbstractString},Function}=quantile_formatter,
labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter,
allowempty::Bool=false)
xnm = eltype(x) >: Missing ? skipmissing(x) : x
breaks = Statistics.quantile(xnm, (1:ngroups-1)/ngroups)
Expand Down
73 changes: 69 additions & 4 deletions test/15_extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,23 @@ const ≅ = isequal
@test isa(x, CategoricalMatrix{Union{String, T}})
@test isordered(x)
@test levels(x) == ["[-2.134, 3.0)", "[3.0, 12.5)"]

labels = 0:2:8
x = @inferred cut(Vector{Union{T, Int}}(1:8), 0:2:10, labels=labels)
@test x == [0,2,2,4,4,6,6,8]
@test isa(x, CategoricalVector{Union{Int, T}})
@test isordered(x)
@test levels(x) == [0, 2, 4, 6, 8]

labels = Union{Int, String}[0, "2", 4, "6", 8]
x = @inferred cut(Vector{Union{T, Int}}(1:8), 10:-2:0, labels=labels)
@test x == [0, "2", "2", 4, 4, "6", "6", 8]
@test isa(x, CategoricalVector{Union{Int, String, T}})
@test isordered(x)
@test levels(x) == [0, "2", 4, "6", 8]

@test_throws ArgumentError cut([-0.0, 0.0], 2)
@test_throws ArgumentError cut([-0.0, 0.0], 2, labels=[-0.0, 0.0])
end

@testset "cut with missing values in input" begin
Expand All @@ -95,6 +112,10 @@ end
y = cut(x, [1, 5])
y[1] = missing
@test all(ismissing, y)

y = cut(x, [1, 5], labels=[1])
y[1] = missing
@test all(ismissing, y)
end

@testset "cut([5, 4, 3, 2], 2)" begin
Expand All @@ -119,13 +140,35 @@ end
x = 0.15:0.20:0.95
p = [0, 0.4, 0.8, 1.0]

@test cut(x, p, labels=my_formatter) ==
["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"]
a = @inferred cut(x, p, labels=my_formatter)
@test a == ["1: 0.0 -- 0.4", "1: 0.0 -- 0.4", "2: 0.4 -- 0.8", "2: 0.4 -- 0.8", "3: 0.8 -- 1.0"]

# GH 274
my_formatter_2(from, to, i; leftclosed, rightclosed) = "$i: $(from+1) -- $(to+1)"
@test cut(x, p, labels=my_formatter_2) ==
["1: 1.0 -- 1.4", "1: 1.0 -- 1.4", "2: 1.4 -- 1.8", "2: 1.4 -- 1.8", "3: 1.8 -- 2.0"]
a = @inferred cut(x, p, labels=my_formatter_2)
@test a == ["1: 1.0 -- 1.4", "1: 1.0 -- 1.4", "2: 1.4 -- 1.8", "2: 1.4 -- 1.8", "3: 1.8 -- 2.0"]

for T in (Union{}, Missing)
labels = (from, to, i; leftclosed, rightclosed) -> (to+from)/2
a = @inferred cut(Vector{Union{T, Int}}(1:8), 0:2:10, labels=labels)
@test a == [1.0, 3.0, 3.0, 5.0, 5.0, 7.0, 7.0, 9.0]
@test isa(a, CategoricalVector{Union{Float64, T}})
@test isordered(a)
@test levels(a) == [1.0, 3.0, 5.0, 7.0, 9.0]

labels = (from, to, i; leftclosed, rightclosed) -> "$((to+from)/2)"
a = @inferred cut(Vector{Union{T, Int}}(1:8), 0:2:10, labels=labels)
@test a == string.([1.0, 3.0, 3.0, 5.0, 5.0, 7.0, 7.0, 9.0])
@test isa(a, CategoricalVector{Union{String, T}})
@test isordered(a)
@test levels(a) == string.([1.0, 3.0, 5.0, 7.0, 9.0])
end

@test cut(0.0:8.0, 3, labels=[-0.0, 0.0, 1.0]) ==
[-0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]

@test cut([-0.0, 0.0, 1.0, 2.0, 3.0, 4.0], [-0.0, 0.0, 5.0], labels=[-0.0, 0.0]) ==
[-0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
end

@testset "cut with duplicated breaks" begin
Expand Down Expand Up @@ -183,6 +226,12 @@ end
labels=["1", "2", "2", "3"])
@test_throws ArgumentError cut(1:10, [1, 3, 3, 5, 5, 11], allowempty=true,
labels=["1", "2", "3", "2", "4"])

@test_throws ArgumentError cut(1:8, 0:2:10, labels=[0, 1, 1, 2, 3])
@test_throws ArgumentError cut(1:8, [0, 2, 2, 6, 8, 10], labels=[0, 1, 1, 2, 3], allowempty=true)

fmt = (from, to, i; leftclosed, rightclosed) -> (i % 2 == 0 ? to : 0.0)
@test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt)
end

@testset "cut with extend=true" begin
Expand All @@ -199,6 +248,22 @@ end
@test err.value.msg == "could not extend breaks as all values are missing: please specify at least two breaks manually"

@test cut([missing], [1, 2], extend=true) ≅ [missing]

@test cut([-0.0, 0.0, 1.0, 2.0, 3.0, 4.0], [-0.0, 0.0, 3.0],
labels=[-0.0, 0.0, 3.0], extend=true) ==
[-0.0, 0.0, 0.0, 0.0, 3.0, 3.0]
end

@testset "cut with extend=missing" begin
x = @inferred cut([-0.0, 0.0, 1.0, 2.0, 3.0, 4.0], [-0.0, 0.0, 3.0],
labels=[-0.0, 0.0], extend=missing)
@test x ≅ [-0.0, 0.0, 0.0, 0.0, missing, missing]
@test x isa CategoricalArray{Union{Missing, Float64},1,UInt32}
@test isordered(x)
@test levels(x) == [-0.0, 0.0]

x = @inferred cut(-1:0.5:1, [0, 1], extend=true)
@test x == ["[-1.0, 0.0)", "[-1.0, 0.0)", "[0.0, 1.0]", "[0.0, 1.0]", "[0.0, 1.0]"]
end

end

0 comments on commit a8e4787

Please sign in to comment.