Skip to content

Commit

Permalink
Add CategoricalValue(x, source) and disallow mixed isless and < (
Browse files Browse the repository at this point in the history
…#346)

`isless` and `<` between `CategoricalValue` and other types can breaks
transitivity. Require converting the value to a `CategoricalValue`,
attaching it to a pool, using
`CategoricalValue(x, source::Union{CategoricalValue, CategoricalArray)`.
Deprecate `CategoricalValue(i::Integer, pool::CategoricalPool)` in favor
of `pool[i]` or `CategoricalValue(pool, i)`, as the order is more logical
and it avoids a potential ambiguity if `source::CategoricalPool` is allowed
in the future.
  • Loading branch information
nalimilan authored Apr 23, 2021
1 parent fe7bed1 commit 4abf776
Show file tree
Hide file tree
Showing 16 changed files with 345 additions and 322 deletions.
10 changes: 5 additions & 5 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ end
end

Base.fill(v::CategoricalValue{T}, dims::NTuple{N, Integer}) where {T, N} =
CategoricalArray{T, N}(fill(level(v), dims), copy(pool(v)))
CategoricalArray{T, N}(fill(refcode(v), dims), copy(pool(v)))

# to avoid ambiguity
Base.fill(v::CategoricalValue, dims::Tuple{}) =
Expand Down Expand Up @@ -953,9 +953,9 @@ end

function in(x::CategoricalValue, y::CategoricalArray{T, N, R}) where {T, N, R}
if x.pool === y.pool
return x.level in y.refs
return refcode(x) in y.refs
else
ref = get(y.pool, levels(x.pool)[x.level], zero(R))
ref = get(y.pool, levels(x.pool)[refcode(x)], zero(R))
return ref != 0 ? ref in y.refs : false
end
end
Expand Down Expand Up @@ -1029,8 +1029,8 @@ function Base.sort!(v::CategoricalVector;
seen = counts .> 0
anymissing = eltype(v) >: Missing && seen[1]
levs = eltype(v) >: Missing ?
eltype(v)[i == 0 ? missing : CategoricalValue(i, v.pool) for i in 0:length(v.pool)] :
eltype(v)[CategoricalValue(i, v.pool) for i in 1:length(v.pool)]
eltype(v)[i == 0 ? missing : CategoricalValue(v.pool, i) for i in 0:length(v.pool)] :
eltype(v)[CategoricalValue(v.pool, i) for i in 1:length(v.pool)]
sortedlevs = sort!(Vector(view(levs, seen)), order=ord)
levelsmap = something.(indexin(sortedlevs, levs))
j = 0
Expand Down
1 change: 1 addition & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ end
import Base: get

@deprecate get(x::CategoricalValue) DataAPI.unwrap(x)
@deprecate CategoricalValue(i::Integer, pool::CategoricalPool) pool[i]
8 changes: 4 additions & 4 deletions src/missingarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ Base.fill!(A::CategoricalArray{>:Missing}, ::Missing) = (fill!(A.refs, 0); A)
in(x::Missing, y::CategoricalArray) = false
in(x::Missing, y::CategoricalArray{>:Missing}) = !all(v -> v > 0, y.refs)

function Missings.replace(a::CategoricalArray{S, N, R, V, C}, replacement::V) where {S, N, R, V, C}
function Missings.replace(a::CategoricalArray{T, N, R, V, C}, replacement::V) where {T, N, R, V, C}
pool = copy(a.pool)
v = C(get!(pool, replacement), pool)
v = C(pool, get!(pool, replacement))
Missings.replace(a, v)
end

function collect(r::Missings.EachReplaceMissing{<:CategoricalArray{S, N, R, C}}) where {S, N, R, C}
CategoricalArray{C,N}(R[v.level for v in r], r.replacement.pool)
function collect(r::Missings.EachReplaceMissing{<:CategoricalArray{T, N, R, V}}) where {T, N, R, V}
CategoricalArray{V,N}(R[refcode(v) for v in r], r.replacement.pool)
end
4 changes: 2 additions & 2 deletions src/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ end

Base.length(pool::CategoricalPool) = length(pool.levels)

Base.getindex(pool::CategoricalPool, i::Integer) = CategoricalValue(i, pool)
Base.getindex(pool::CategoricalPool, i::Integer) = CategoricalValue(pool, i)
Base.get(pool::CategoricalPool, level::Any) = pool.invindex[level]
Base.get(pool::CategoricalPool, level::Any, default::Any) = get(pool.invindex, level, default)

Expand Down Expand Up @@ -148,7 +148,7 @@ end

@inline function Base.get!(pool::CategoricalPool, level::CategoricalValue)
if pool === level.pool || pool == level.pool
return level.level
return refcode(level)
end
if level.pool pool
if isordered(pool)
Expand Down
2 changes: 1 addition & 1 deletion src/typedefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ the order of the pool's [`levels`](@ref DataAPI.levels) is used rather than the
ordering of values of type `T`.
"""
struct CategoricalValue{T <: SupportedTypes, R <: Integer}
level::R
pool::CategoricalPool{T, R, CategoricalValue{T, R}}
ref::R
end

## Arrays
Expand Down
54 changes: 30 additions & 24 deletions src/value.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
CategoricalValue(level::Integer, pool::CategoricalPool{T, R}) where {T, R} =
CategoricalValue(convert(R, level), pool)
CategoricalValue(pool::CategoricalPool{T, R}, level::Integer) where {T, R} =
CategoricalValue(pool, convert(R, level))

"""
CategoricalValue(value, source::Union{CategoricalValue, CategoricalArray})
Return a `CategoricalValue` object wrapping `value` and attached to
the [`CategoricalPool`](@ref) of `source`.
"""
function CategoricalValue(value, source::Union{CategoricalValue, CatArrOrSub})
p = pool(source)
i = get(p, value, nothing)
i === nothing && throw(ArgumentError("level $value not found in source pool"))
return CategoricalValue(p, i)
end

leveltype(::Type{<:CategoricalValue{T}}) where {T} = T
leveltype(::Type{T}) where {T} = T
Expand All @@ -12,7 +25,7 @@ reftype(::Type{<:CategoricalValue{<:Any, R}}) where {R} = R
reftype(x::Any) = reftype(typeof(x))

pool(x::CategoricalValue) = x.pool
level(x::CategoricalValue) = x.level
refcode(x::CategoricalValue) = x.ref
isordered(x::CategoricalValue) = isordered(x.pool)

# extract the type of the original value from array eltype `T`
Expand All @@ -29,15 +42,15 @@ unwrap_catvaluetype(::Type{T}) where {T <: CategoricalValue} = leveltype(T)
Get the value wrapped by categorical value `x`. If `x` is `Missing` return `missing`.
"""
DataAPI.unwrap(x::CategoricalValue) = levels(x)[level(x)]
DataAPI.unwrap(x::CategoricalValue) = levels(x)[refcode(x)]

"""
levelcode(x::CategoricalValue)
Get the code of categorical value `x`, i.e. its index in the set
of possible values returned by [`levels(x)`](@ref DataAPI.levels).
"""
levelcode(x::CategoricalValue) = Signed(widen(level(x)))
levelcode(x::CategoricalValue) = Signed(widen(refcode(x)))

"""
levelcode(x::Missing)
Expand Down Expand Up @@ -107,7 +120,7 @@ Base.String(x::CategoricalValue{<:AbstractString}) = String(unwrap(x))

@inline function Base.:(==)(x::CategoricalValue, y::CategoricalValue)
if pool(x) === pool(y) || pool(x) == pool(y)
return level(x) == level(y)
return refcode(x) == refcode(y)
else
return unwrap(x) == unwrap(y)
end
Expand All @@ -118,7 +131,7 @@ Base.:(==)(x::SupportedTypes, y::CategoricalValue) = x == unwrap(y)

@inline function Base.isequal(x::CategoricalValue, y::CategoricalValue)
if pool(x) === pool(y) || pool(x) == pool(y)
return level(x) == level(y)
return refcode(x) == refcode(y)
else
return isequal(unwrap(x), unwrap(y))
end
Expand All @@ -140,8 +153,11 @@ function Base.isless(x::CategoricalValue, y::CategoricalValue)
end
end

Base.isless(x::CategoricalValue, y::SupportedTypes) = levelcode(x) < levelcode(x.pool[get(x.pool, y)])
Base.isless(y::SupportedTypes, x::CategoricalValue) = levelcode(x.pool[get(x.pool, y)]) < levelcode(x)
Base.isless(x::CategoricalValue, y::SupportedTypes) =
throw(ArgumentError("cannot compare a `CategoricalValue` to value `v` of type " *
"`$(typeof(x))`: wrap `v` using `CategoricalValue(v, catvalue)` " *
"or `CategoricalValue(v, catarray)` first"))
Base.isless(y::SupportedTypes, x::CategoricalValue) = isless(x, y)

function Base.:<(x::CategoricalValue, y::CategoricalValue)
poolx = pool(x)
Expand All @@ -157,21 +173,11 @@ function Base.:<(x::CategoricalValue, y::CategoricalValue)
end
end

function Base.:<(x::CategoricalValue, y::SupportedTypes)
if !isordered(pool(x))
throw(ArgumentError("Unordered CategoricalValue objects cannot be tested for order using <. Use isless instead, or call the ordered! function on the parent array to change this"))
else
return levelcode(x) < levelcode(x.pool[get(x.pool, y)])
end
end

function Base.:<(y::SupportedTypes, x::CategoricalValue)
if !isordered(pool(x))
throw(ArgumentError("Unordered CategoricalValue objects cannot be tested for order using <. Use isless instead, or call the ordered! function on the parent array to change this"))
else
return levelcode(x.pool[get(x.pool, y)]) < levelcode(x)
end
end
Base.:<(x::CategoricalValue, y::SupportedTypes) =
throw(ArgumentError("cannot compare a `CategoricalValue` to value `v` of type " *
"`$(typeof(x))`: wrap `v` using `CategoricalValue(v, catvalue)` " *
"or `CategoricalValue(v, catarray)` first"))
Base.:<(y::SupportedTypes, x::CategoricalValue) = x < y

# JSON of CategoricalValue is JSON of the value it refers to
JSON.lower(x::CategoricalValue) = JSON.lower(unwrap(x))
Expand Down
30 changes: 23 additions & 7 deletions test/01_value.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module TestValue
using Test
using CategoricalArrays
using CategoricalArrays: DefaultRefType, level, reftype, leveltype
using CategoricalArrays: DefaultRefType, refcode, reftype, leveltype

@testset "leveltype on non CategoricalValue types" begin
@test leveltype("abc") === String
Expand All @@ -20,19 +20,19 @@ end
)

for i in 1:3
x = CategoricalValue(i, pool)
x = CategoricalValue(pool, i)

@test leveltype(x) === String
@test leveltype(typeof(x)) === String
@test reftype(x) === DefaultRefType
@test reftype(typeof(x)) === DefaultRefType
@test x isa CategoricalValue{String, DefaultRefType}

@test level(x) === DefaultRefType(i)
@test refcode(x) === DefaultRefType(i)
@test CategoricalArrays.pool(x) === pool

@test typeof(x)(x) === x
@test CategoricalValue(UInt8(i), pool) == x
@test CategoricalValue(pool, UInt8(i)) == x
end
end

Expand All @@ -46,19 +46,35 @@ end
)

for i in 1:3
x = CategoricalValue(i, pool)
x = CategoricalValue(pool, i)

@test leveltype(x) === String
@test leveltype(typeof(x)) === String
@test reftype(x) === UInt8
@test reftype(typeof(x)) === UInt8
@test x isa CategoricalValue{String, UInt8}

@test level(x) === UInt8(i)
@test refcode(x) === UInt8(i)
@test CategoricalArrays.pool(x) === pool

@test typeof(x)(x) === x
@test CategoricalValue(UInt32(i), pool) == x
@test CategoricalValue(pool, UInt32(i)) == x
end
end

@testset "constructor from other value" begin
pool = CategoricalPool([2, 3, 1])
arr = CategoricalVector{Int}(DefaultRefType[2, 1, 3], pool)
for x in (CategoricalValue(pool, 1), arr, view(arr, 2:3))
for (i, v) in enumerate(levels(pool))
@test CategoricalValue(v, x) ===
CategoricalValue(float(v), x) ===
CategoricalValue(CategoricalValue(pool, i), x) ===
CategoricalValue(pool, i)
end

@test_throws ArgumentError CategoricalValue(4, x)
@test_throws ArgumentError CategoricalValue(missing, x)
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/04_constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ end
pool = CategoricalPool{Float64, UInt8}([1.0, 2.0, 3.0])

@test isa(pool, CategoricalPool{Float64, UInt8, CategoricalValue{Float64, UInt8}})
@test CategoricalValue(1, pool) isa CategoricalValue{Float64, UInt8}
@test CategoricalValue(pool, 1) isa CategoricalValue{Float64, UInt8}
end

@testset "Invalid arguments" begin
Expand Down
10 changes: 5 additions & 5 deletions test/05_convert.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module TestConvert
using Test
using CategoricalArrays
using CategoricalArrays: DefaultRefType, level, reftype, leveltype
using CategoricalArrays: DefaultRefType, refcode, reftype, leveltype

@testset "convert() for CategoricalPool{Int, DefaultRefType} and values" begin
pool = CategoricalPool([1, 2, 3])
Expand All @@ -12,9 +12,9 @@ using CategoricalArrays: DefaultRefType, level, reftype, leveltype
convert(CategoricalPool{Float64}, pool)
convert(CategoricalPool, pool)

v1 = CategoricalValue(1, pool)
v2 = CategoricalValue(2, pool)
v3 = CategoricalValue(3, pool)
v1 = CategoricalValue(pool, 1)
v2 = CategoricalValue(pool, 2)
v3 = CategoricalValue(pool, 3)
@test eltype(v1) === Any
@test eltype(typeof(v1)) === Any
@test leveltype(v1) === Int
Expand Down Expand Up @@ -153,7 +153,7 @@ end
@testset "levelcode" begin
pool = CategoricalPool{Int,UInt8}([2, 1, 3])
for i in 1:3
v = CategoricalValue(i, pool)
v = CategoricalValue(pool, i)
@test levelcode(v) isa Int16
@test levels(pool)[levelcode(v)] == v
end
Expand Down
24 changes: 12 additions & 12 deletions test/06_show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ using CategoricalArrays
pool = CategoricalPool(["c", "b", "a"])
opool = CategoricalPool(["c", "b", "a"], true)

nv1 = CategoricalValue(1, pool)
nv2 = CategoricalValue(2, pool)
nv3 = CategoricalValue(3, pool)
nv1 = CategoricalValue(pool, 1)
nv2 = CategoricalValue(pool, 2)
nv3 = CategoricalValue(pool, 3)

ov1 = CategoricalValue(1, opool)
ov2 = CategoricalValue(2, opool)
ov3 = CategoricalValue(3, opool)
ov1 = CategoricalValue(opool, 1)
ov2 = CategoricalValue(opool, 2)
ov3 = CategoricalValue(opool, 3)

if VERSION >= v"1.6.0"
@test sprint(show, pool) == "$CategoricalPool{String, UInt32}([\"c\", \"b\", \"a\"])"
Expand Down Expand Up @@ -78,30 +78,30 @@ using JSON
@testset "JSON.lower" for pool in (CategoricalPool(["a"]),
CategoricalPool([1]),
CategoricalPool([1.0]))
v = CategoricalValue(1, pool)
v = CategoricalValue(pool, 1)
@test JSON.lower(v) == JSON.lower(unwrap(v))
@test typeof(JSON.lower(v)) == typeof(JSON.lower(unwrap(v)))
end

using JSON3
using StructTypes
@testset "JSON3.write" begin
v = CategoricalValue(1, CategoricalPool(["a"]))
v = CategoricalValue(CategoricalPool(["a"]), 1)
@test JSON3.write(v) === "\"a\""

v = CategoricalValue(1, CategoricalPool([1]))
v = CategoricalValue(CategoricalPool([1]), 1)
@test JSON3.write(v) === "1"
@test StructTypes.numbertype(typeof(v)) === Int

v = CategoricalValue(1, CategoricalPool([2.0]))
v = CategoricalValue(CategoricalPool([2.0]), 1)
@test JSON3.write(v) === "2.0"
@test StructTypes.numbertype(typeof(v)) === Float64

v = CategoricalValue(1, CategoricalPool([BigFloat(3.0,10)]))
v = CategoricalValue(CategoricalPool([BigFloat(3.0,10)]), 1)
@test JSON3.write(v) === "3.0"
@test StructTypes.numbertype(typeof(v)) === BigFloat

v = CategoricalValue(2, CategoricalPool([true,false]))
v = CategoricalValue(CategoricalPool([true,false]), 2)
@test JSON3.write(v) == "false"
@test StructTypes.numbertype(typeof(v)) === Bool
end
Expand Down
Loading

0 comments on commit 4abf776

Please # to comment.