Skip to content

Commit

Permalink
Improve LabeledValue and LabeledArray
Browse files Browse the repository at this point in the history
  • Loading branch information
junyuan-chen committed Apr 1, 2024
1 parent d6f9f1c commit b6906d6
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 27 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Expand All @@ -24,6 +25,7 @@ DataAPI = "1.13"
DataFrames = "1"
InlineStrings = "1.1"
MappedArrays = "0.4"
Missings = "1"
PooledArrays = "1"
PrecompileTools = "1"
PrettyTables = "1, 2"
Expand Down
89 changes: 70 additions & 19 deletions src/LabeledArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ false
julia> v1 == 1
true
julia> isnan(v1)
false
julia> isequal(vm, missing)
true
Expand Down Expand Up @@ -87,8 +90,15 @@ Base.isless(x::Missing, y::LabeledValue) = isless(x, y.value)
# Approximate comparison looks only at the underlying value of whichever argument
# is a `LabeledValue`; all keyword arguments are forwarded unchanged.
function Base.isapprox(x::LabeledValue, y; kwargs...)
    return isapprox(x.value, y; kwargs...)
end

function Base.isapprox(x, y::LabeledValue; kwargs...)
    return isapprox(x, y.value; kwargs...)
end

# Numeric predicates are answered by the underlying value; the labels play no role.
function Base.iszero(x::LabeledValue)
    return iszero(x.value)
end

function Base.isnan(x::LabeledValue)
    return isnan(x.value)
end

function Base.isinf(x::LabeledValue)
    return isinf(x.value)
end

function Base.isfinite(x::LabeledValue)
    return isfinite(x.value)
end

# Hash only the underlying value, so the labels do not affect the hash.
function Base.hash(x::LabeledValue, h::UInt=zero(UInt))
    return hash(x.value, h)
end

# The length of a `LabeledValue` is the length of the value it wraps.
function Base.length(x::LabeledValue)
    return length(x.value)
end

"""
unwrap(x::LabeledValue)
Expand All @@ -115,8 +125,8 @@ Base.show(io::IO, x::LabeledValue) = print(io, _getlabel(x))
# text/plain display: the underlying value, then " => ", then its label.
function Base.show(io::IO, ::MIME"text/plain", x::LabeledValue)
    return print(io, x.value, " => ", _getlabel(x))
end

Base.convert(::Type{<:LabeledValue{T1}}, x::LabeledValue{T2}) where {T1,T2} =
LabeledValue(convert(T1, x.value), x.labels)
Base.convert(::Type{<:LabeledValue{T1,K}}, x::LabeledValue{T2,K}) where {T1,T2,K} =
LabeledValue{T1,K}(convert(T1, x.value), x.labels)
# Converting to a string type yields the label (not the underlying value).
function Base.convert(::Type{T}, x::LabeledValue) where {T<:AbstractString}
    return convert(T, _getlabel(x))
end

"""
Expand Down Expand Up @@ -152,6 +162,10 @@ are supported for [`LabeledVector`](@ref).
They are applied on the underlying array of values retrieved via [`refarray`](@ref)
and do not modify the dictionary of value labels.
For convenience, `LabeledArray(x::AbstractArray{<:AbstractString}, ::Type{T}=Int32)`
converts a string array to a `LabeledArray`
by encoding the string values with integers of the specified type (`Int32` by default).
# Examples
```jldoctest
julia> lbls1 = Dict(1=>"a", 2=>"b");
Expand Down Expand Up @@ -197,6 +211,14 @@ julia> push!(x, 2)
2 => b
2 => b
julia> push!(x, 3 => "c")
5-element LabeledVector{Int64, Vector{Int64}, Int64}:
0 => 0
1 => a
2 => b
2 => b
3 => c
julia> deleteat!(x, 4)
3-element LabeledVector{Int64, Vector{Int64}, Int64}:
0 => 0
Expand All @@ -211,6 +233,14 @@ julia> append!(x, [0, 1, 2])
0 => 0
1 => a
2 => b
julia> v = ["a", "b", "c"];
julia> LabeledArray(v, Int16)
3-element LabeledVector{Int16, Vector{Int16}, Union{Char, Int32}}:
1 => a
2 => b
3 => c
```
"""
struct LabeledArray{V, N, A<:AbstractArray{V, N}, K} <: AbstractArray{LabeledValue{V, K}, N}
Expand All @@ -222,6 +252,13 @@ struct LabeledArray{V, N, A<:AbstractArray{V, N}, K} <: AbstractArray{LabeledVal
new{V, N, A, K}(A(undef, dims), Dict{K, String}())
end

# Convenience method for encoding string arrays:
# each string is replaced by an integer code of type `T` produced by `_label`,
# and the code-to-string mapping becomes the dictionary of value labels.
function LabeledArray(x::AbstractArray{<:AbstractString}, ::Type{T}=Int32) where T
    codes, code_of, _ = _label(x, eltype(x), T)
    # Label keys are always stored as Int32, whatever code type T was requested
    labels = Dict{Union{Int32,Char},String}()
    for (str, code) in pairs(code_of)
        labels[Int32(code)] = string(str)
    end
    return LabeledArray(codes, labels)
end

"""
LabeledVector{V, A, K} <: AbstractVector{LabeledValue{V, K}}
Expand All @@ -239,16 +276,16 @@ const LabeledMatrix{V, A, K} = LabeledArray{V, 2, A, K}
# The default container for `LabeledValue{V,K}` elements is a `LabeledArray`
# backed by whatever `defaultarray` selects for the underlying value type `V`.
function defaultarray(::Type{LabeledValue{V,K}}, N) where {V,K}
    return LabeledArray{V, N, defaultarray(V, N), K}
end

const LabeledArrOrSubOrReshape{V, N} = Union{LabeledArray{V, N},
SubArray{<:Any, N, <:LabeledArray{V}}, Base.ReshapedArray{<:Any, N, <:LabeledArray{V}},
SubArray{<:Any, N, <:Base.ReshapedArray{<:Any, <:Any, <:LabeledArray{V}}}}
const LabeledArrOrSubOrReshape{V, K, N} = Union{LabeledArray{V, N, <:Any, K},
SubArray{<:Any, N, <:LabeledArray{V, <:Any, <:Any, K}}, Base.ReshapedArray{<:Any, N, <:LabeledArray{V, <:Any, <:Any, K}},
SubArray{<:Any, N, <:Base.ReshapedArray{<:Any, <:Any, <:LabeledArray{V, <:Any, <:Any, K}}}}

# A `LabeledArray` has the size of its underlying array of values.
function Base.size(x::LabeledArray)
    return size(refarray(x))
end
# Indexing style is inherited from the underlying array type `A`.
function Base.IndexStyle(::Type{<:LabeledArray{V,N,A}}) where {V,N,A}
    return IndexStyle(A)
end

Base.@propagate_inbounds function Base.getindex(x::LabeledArray, i::Int)
val = refarray(x)[i]
return LabeledValue(val, getvaluelabels(x))
Base.@propagate_inbounds function Base.getindex(x::LabeledArray{V,N,A,K}, i::Int) where {V,N,A,K}
val = refarray(x)[i]::V
return LabeledValue{V,K}(val, getvaluelabels(x))
end

Base.@propagate_inbounds function Base.setindex!(x::LabeledArray, v, i::Int)
Expand Down Expand Up @@ -286,9 +323,10 @@ getvaluelabels(x::Base.ReshapedArray{<:Any, <:Any, <:LabeledArray}) = parent(x).
# View of a reshaped `LabeledArray`: the labels are held two `parent` levels up.
function getvaluelabels(
        x::SubArray{<:Any, <:Any, <:Base.ReshapedArray{<:Any, <:Any, <:LabeledArray}})
    reshaped = parent(x)
    return parent(reshaped).labels
end

Base.@propagate_inbounds function Base.getindex(x::LabeledArrOrSubOrReshape, i::Integer)
val = refarray(x)[i]
return LabeledValue(val, getvaluelabels(x))
# The type annotation ::V and LabeledValue{V,K} avoid an allocation
Base.@propagate_inbounds function Base.getindex(x::LabeledArrOrSubOrReshape{V,K}, i::Integer) where {V,K}
val = refarray(x)[i]::V
return LabeledValue{V,K}(val, getvaluelabels(x))
end

# This avoids method ambiguity on Julia v1.11 with
Expand All @@ -303,21 +341,27 @@ Base.@propagate_inbounds function Base.getindex(x::SubArray{<:Any, N,
return LabeledArray(val, getvaluelabels(x))
end

# Needed for repeat(x, inner=2) to work
Base.@propagate_inbounds function Base.getindex(
        x::LabeledArrOrSubOrReshape{V,K}, i::CartesianIndex) where {V,K}
    # The element is asserted to be a V and wrapped with explicit type parameters
    return LabeledValue{V,K}(refarray(x)[i]::V, getvaluelabels(x))
end

# Fallback for index types without a more specific method: index the underlying
# values and wrap the result in a `LabeledArray` with the same label dictionary.
Base.@propagate_inbounds function Base.getindex(x::LabeledArrOrSubOrReshape, i)
    return LabeledArray(refarray(x)[i], getvaluelabels(x))
end

Base.@propagate_inbounds function Base.getindex(x::LabeledArrOrSubOrReshape{V,N},
I::Vararg{Int,N}) where {V,N}
val = refarray(x)[I...]
return LabeledValue(val, getvaluelabels(x))
Base.@propagate_inbounds function Base.getindex(x::LabeledArrOrSubOrReshape{V,K,N},
I::Vararg{Int,N}) where {V,K,N}
val = refarray(x)[I...]::V
return LabeledValue{V,K}(val, getvaluelabels(x))
end

Base.@propagate_inbounds function Base.getindex(x::LabeledArrOrSubOrReshape{V,N},
I::Vararg{Integer,N}) where {V,N}
val = refarray(x)[I...]
return LabeledValue(val, getvaluelabels(x))
Base.@propagate_inbounds function Base.getindex(x::LabeledArrOrSubOrReshape{V,K,N},
I::Vararg{Integer,N}) where {V,K,N}
val = refarray(x)[I...]::V
return LabeledValue{V,K}(val, getvaluelabels(x))
end

Base.@propagate_inbounds function Base.getindex(x::LabeledArrOrSubOrReshape{V,N},
Expand All @@ -330,7 +374,11 @@ Base.fill!(x::LabeledArrOrSubOrReshape, v) = (fill!(refarray(x), unwrap(v)); x)

# Resize the underlying vector of values in place; the labels are left untouched.
function Base.resize!(x::LabeledVector, n::Integer)
    resize!(refarray(x), n)
    return x
end
# Append the unwrapped value to the underlying vector.
function Base.push!(x::LabeledVector, v)
    push!(refarray(x), unwrap(v))
    return x
end

# Pushing `value => label` first records the label, then appends the value.
function Base.push!(x::LabeledVector, p::Pair)
    getvaluelabels(x)[p.first] = p.second
    push!(refarray(x), p.first)
    return x
end
# Prepend the unwrapped value to the underlying vector.
function Base.pushfirst!(x::LabeledVector, v)
    pushfirst!(refarray(x), unwrap(v))
    return x
end

# Prepending `value => label` first records the label, then inserts the value.
function Base.pushfirst!(x::LabeledVector, p::Pair)
    getvaluelabels(x)[p.first] = p.second
    pushfirst!(refarray(x), p.first)
    return x
end
# Insert the unwrapped value at position `i` of the underlying vector.
function Base.insert!(x::LabeledVector, i, v)
    insert!(refarray(x), i, unwrap(v))
    return x
end
# Delete from the underlying vector; the label dictionary is not modified.
function Base.deleteat!(x::LabeledVector, i)
    deleteat!(refarray(x), i)
    return x
end
# Append the values of `v` (as returned by `refarray`) to the underlying vector;
# any labels carried by `v` are not merged into `x`.
function Base.append!(x::LabeledVector, v)
    append!(refarray(x), refarray(v))
    return x
end
Expand Down Expand Up @@ -427,6 +475,9 @@ Base.collect(x::LabeledArrOrSubOrReshape) =
# Collect with a requested underlying value type `T`: the values are collected as `T`
# and wrapped together with the existing label dictionary.
function Base.collect(::Type{<:LabeledValue{T}}, x::LabeledArrOrSubOrReshape) where {T}
    return LabeledArray(collect(T, refarray(x)), getvaluelabels(x))
end

# Only the underlying values are passed through `disallowmissing`;
# the label dictionary is handed over as is.
function disallowmissing(x::LabeledArrOrSubOrReshape)
    return LabeledArray(disallowmissing(refarray(x)), getvaluelabels(x))
end

# Assume VERSION >= v"1.3.0"
# Define abbreviated element type name for printing with PrettyTables.jl
function compact_type_str(::Type{<:LabeledValue{V}}) where V
Expand Down
4 changes: 3 additions & 1 deletion src/ReadStatTables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Dates
using Dates: unix2datetime
using InlineStrings
using MappedArrays: MappedArray, mappedarray
using PooledArrays: PooledArray, PooledVector, RefArray
using PooledArrays: PooledArray, PooledVector, RefArray, _label
using PrettyTables: pretty_table
using ReadStat_jll
using SentinelArrays: SentinelVector, ChainedVector
Expand All @@ -16,13 +16,15 @@ using Tables
import DataAPI: defaultarray, refarray, unwrap, nrow, ncol,
metadatasupport, colmetadatasupport,
metadata, metadatakeys, metadata!, colmetadata, colmetadatakeys, colmetadata!
import Missings: disallowmissing
import PrettyTables: compact_type_str
import Tables: columnnames

export refarray, unwrap, nrow, ncol, metadata, metadatakeys, metadata!,
colmetadata, colmetadatakeys, colmetadata!
export Date, DateTime # Needed for avoiding the "Dates." qualifier when printing tables
export String3, String7, String15, String31, String63, String127, String255
export disallowmissing
export columnnames

export LabeledValue,
Expand Down
9 changes: 7 additions & 2 deletions src/writestat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ function _set_vallabels!(colmetavec, vallabels, lblname, refpoolaslabel, names,
colmetavec.vallabel[i] = lblname
end
end
return get(vallabels, lblname, nothing) # Return the labels
end

"""
Expand Down Expand Up @@ -157,7 +158,7 @@ function ReadStatTable(table, ext::AbstractString;
colmeta.type[i] = colmetadata(table, i, "type", type)
lblname = colmetadata(table, i, "vallabel", Symbol())
colmeta.vallabel[i] = lblname
_set_vallabels!(colmeta, vallabels, lblname, refpoolaslabel, names, col, i)
lbls = _set_vallabels!(colmeta, vallabels, lblname, refpoolaslabel, names, col, i)
# type may have been modified based on refarray
type = colmeta.type[i]
if type === READSTAT_TYPE_STRING
Expand All @@ -169,7 +170,11 @@ function ReadStatTable(table, ext::AbstractString;
width = Csize_t(0)
end
colmeta.storage_width[i] = width
colmeta.display_width[i] = max(Cint(width), Cint(9))
if lbls === nothing
colmeta.display_width[i] = max(Cint(width), Cint(9))
else
colmeta.display_width[i] = max(Cint(maximum(length, values(lbls))), Cint(9))
end
colmeta.measure[i] = READSTAT_MEASURE_UNKNOWN
colmeta.alignment[i] = READSTAT_ALIGNMENT_UNKNOWN
if copycols
Expand Down
45 changes: 40 additions & 5 deletions test/LabeledArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
@test isless(v1, v2)
@test isapprox(v1, v3)

@test !iszero(v1)
@test iszero(LabeledValue(0, lbls1))
@test !isnan(v1)
@test isnan(LabeledValue(NaN, lbls1))
@test !isinf(v1)
@test isinf(LabeledValue(Inf, lbls1))
@test isfinite(v1)
@test !isfinite(LabeledValue(Inf, lbls1))

@test v1 == 1
@test 1 == v1
@test ismissing(v1 == missing)
Expand All @@ -37,6 +46,8 @@
d = Dict{LabeledValue, Int}(v1 => 1)
@test haskey(d, v3)

@test length(v1) == 1

@test unwrap(v1) === 1
@test valuelabel(v1) == "a"
@test valuelabel(v4) == "missing"
Expand All @@ -46,7 +57,7 @@
@test sprint(show, v4) == "missing"
@test sprint(show, MIME("text/plain"), v1) == "1 => a"

v5 = convert(LabeledValue{Int16}, v1)
v5 = convert(LabeledValue{Int16, Int32}, v1)
@test v5.value isa Int16
@test v5.labels === v1.labels
@test convert(String, v1) == "a"
Expand All @@ -60,10 +71,20 @@ end
@test size(x) == (6,)
@test IndexStyle(typeof(x)) == IndexStyle(typeof(vals))
@test DataAPI.defaultarray(eltype(x), 1) == typeof(x)
@test length(unique(x)) == 3
@test typeof(repeat(x, 2)) == typeof(x)
@test repeat(x, 2) == LabeledArray(repeat(vals, 2), lbls)
@test typeof(repeat(x, inner=2, outer=2)) == typeof(x)
@test repeat(x, inner=2, outer=2) == LabeledArray(repeat(vals, inner=2, outer=2), lbls)
x0 = typeof(x)(undef, 10)
@test length(x0) == 10
@test isempty(getvaluelabels(x0))

v = ["a", "b", "c"]
l = LabeledArray(v, Int16)
@test l == 1:3
@test valuelabels(l) == v

@test x[1] === LabeledValue(1, lbls)
@test x[Int16(1)] === LabeledValue(1, lbls)
s = x[2:3]
Expand Down Expand Up @@ -133,12 +154,20 @@ end
x1 = LabeledArray(vals1, lbls)
@test isequal(x1[3], missing)
@test isequal(x1[3], LabeledValue(missing, Dict{Any,String}()))
@test length(unique(vals1)) == 3

@test refarray(x) === x.values
@test refarray(view(x, [1, 3])) == 1:2
@test refarray(reshape(x, 3, 2)) == reshape(x.values, 3, 2)
@test refarray(view(x2, 1:3)) == x2.values[1:3]

@test_throws MethodError disallowmissing(x1)
@test typeof(disallowmissing(x1[1:2])) ==
LabeledVector{Int64, Vector{Int64}, Union{Missing, Int64}}

@test typeof(repeat(x1, 2)) == typeof(x1)
@test typeof(repeat(x1, inner=2)) == typeof(x1)

x3 = LabeledArray(copy(vals1), lbls)
fill!(x3, 1)
@test all(x3 .== 1)
Expand All @@ -151,18 +180,24 @@ end
@test length(v) == 7
push!(v, 1)
@test v[8] == 1
push!(v, 5=>"0")
@test v[end] == 5
@test getvaluelabels(v)[5] == "0"
pushfirst!(v, 2)
@test v[1] == 2
pushfirst!(v, 6=>"-1")
@test v[1] == 6
@test getvaluelabels(v)[6] == "-1"
insert!(v, 3, 4)
@test v[3] == 4
deleteat!(v, 3)
@test length(v) == 9
@test length(v) == 11
@test v[3] == 1
append!(v, 1:3)
@test length(v) == 12
@test v[10:12] == 1:3
@test length(v) == 14
@test v[12:14] == 1:3
prepend!(v, 1:3)
@test length(v) == 15
@test length(v) == 17
@test v[1:3] == 1:3
empty!(v)
@test isempty(v)
Expand Down
1 change: 1 addition & 0 deletions test/writestat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# Date/Time columns are converted to numbers
@test eltype(getfield(tb, :columns)[8]) >: Float64
@test eltype(getfield(tb, :columns)[9]) >: Float64
@test colmetadata(tb, :vstrL, "display_width") == length(df.vstrL[1])

df = DataFrame(readstat(alltypes))
emptycolmetadata!(df)
Expand Down

0 comments on commit b6906d6

Please sign in to comment.