diff --git a/src/recode.jl b/src/recode.jl index dbec74f3..25854670 100644 --- a/src/recode.jl +++ b/src/recode.jl @@ -36,22 +36,19 @@ recode!(dest::CategoricalArray, src::AbstractArray, pairs::Pair...) = recode!(dest::CategoricalArray, src::CategoricalArray, pairs::Pair...) = recode!(dest, src, nothing, pairs...) - """ recode_in(x, collection) Helper function to test if `x` is a member of `collection`. The default method is to test if any element in the `collection` `isequal` to -`x`. For sets using `in` should be faster than the default method. +`x`. For `Set`s `in` is used as it is faster than the default method and equivalent to it. A user defined type could override this method to define an appropriate test function. """ @inline recode_in(x, ::Missing) = false -@inline recode_in(x, ::AbstractString) = false @inline recode_in(x, collection::Set) = x in collection @inline recode_in(x, collection) = any(x ≅ y for y in collection) - function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T} if length(dest) != length(src) throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))")) @@ -63,7 +60,7 @@ function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs for j in 1:length(pairs) p = pairs[j] # we use isequal and recode_in because we cannot really distinguish scalars from collections - if (x ≅ p.first || recode_in(x, p.first)) + if x ≅ p.first || recode_in(x, p.first) dest[i] = p.second @goto nextitem end @@ -116,7 +113,7 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa for j in 1:length(pairs) p = pairs[j] # we use isequal and recode_in because we cannot really distinguish scalars from collections - if (x ≅ p.first || recode_in(x, p.first)) + if x ≅ p.first || recode_in(x, p.first) drefs[i] = dupvals ? pairmap[j] : j @goto nextitem end @@ -229,7 +226,7 @@ function recode!(dest::CategoricalArray{T, N, R}, src::CategoricalArray, @inbounds for (i, l) in enumerate(srclevels) for j in 1:length(pairs) p = pairs[j] - if (l ≅ p.first || recode_in(l, p.first)) + if l ≅ p.first || recode_in(l, p.first) levelsmap[i+1] = pairmap[j] @goto nextitem end diff --git a/test/16_recode.jl b/test/16_recode.jl index db3d2396..757d744b 100644 --- a/test/16_recode.jl +++ b/test/16_recode.jl @@ -12,6 +12,8 @@ const ≅ = isequal @testset "recode_in" begin @testset "collection is a string" begin @test !CategoricalArrays.recode_in("a", "ab") + @test CategoricalArrays.recode_in('a', "ab") + @test !CategoricalArrays.recode_in('c', "ab") @test !CategoricalArrays.recode_in(missing, "b") end @testset "collection without missing" begin