Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Context #75

Merged
merged 1 commit into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.8'
- '1'
os:
- ubuntu-latest
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ name = "Decimals"
uuid = "abce61dc-4473-55a0-ba07-351d65e31d42"
version = "0.4.1"

[deps]
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"

[compat]
julia = "1"
ScopedValues = "1"
julia = "1.8"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
148 changes: 91 additions & 57 deletions scripts/dectest.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,48 @@
function _precision(line)
m = match(r"^precision:\s*(\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _rounding(line)
m = match(r"^rounding:\s*(\w+)$", line)
return Symbol(m[1])
isnothing(m) && throw(ArgumentError(line))
r = m[1]
if r == "ceiling"
return "RoundUp"
elseif r == "down"
return "RoundToZero"
elseif r == "floor"
return "RoundDown"
elseif r == "half_even"
return "RoundNearest"
elseif r == "half_up"
return "RoundNearestTiesAway"
elseif r == "up"
return "RoundFromZero"
elseif r == "half_down"
return "RoundHalfDownUnsupported"
elseif r == "05up"
return "Round05UpUnsupported"
else
throw(ArgumentError(r))
end
end

function _maxexponent(line)
m = match(r"^maxexponent:\s*(\d+)$", line)
m = match(r"^maxexponent:\s*\+?(\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _minexponent(line)
m = match(r"^minexponent:\s*(-\d+)$", line)
isnothing(m) && throw(ArgumentError(line))
return parse(Int, m[1])
end

function _test(line)
occursin("->", line) || throw(ArgumentError(line))
lhs, rhs = split(line, "->")
id, operation, operands... = split(lhs)
result, conditions... = split(rhs)
Expand All @@ -31,47 +55,55 @@ function decimal(x)
return "dec\"$x\""
end

print_precision(io, p::Int) = println(io, " setprecision(Decimal, $p)")
print_maxexponent(io, e::Int) = println(io, " Decimals.CONTEXT.Emax = $e")
print_minexponent(io, e::Int) = println(io, " Decimals.CONTEXT.Emin = $e")
function print_rounding(io, r::Symbol)
modes = Dict(:ceiling => "RoundUp",
:down => "RoundToZero",
:floor => "RoundDown",
:half_even => "RoundNearest",
:half_up => "RoundNearestTiesAway",
:up => "RoundFromZero",
:half_down => "RoundHalfDownUnsupported",
Symbol("05up") => "Round05UpUnsupported")
haskey(modes, r) || throw(ArgumentError(r))
rmod = modes[r]
println(io, " setrounding(Decimal, $rmod)")
end

function print_operation(io, operation, operands)
if operation == "plus"
print_plus(io, operands...)
elseif operation == "minus"
print_minus(io, operands...)
if operation == "abs"
print_abs(io, operands...)
elseif operation == "add"
print_add(io, operands...)
elseif operation == "apply"
print_apply(io, operands...)
elseif operation == "compare"
print_compare(io, operands...)
elseif operation == "divide"
print_divide(io, operands...)
elseif operation == "minus"
print_minus(io, operands...)
elseif operation == "multiply"
print_multiply(io, operands...)
elseif operation == "plus"
print_plus(io, operands...)
elseif operation == "reduce"
print_reduce(io, operands...)
elseif operation == "subtract"
print_subtract(io, operands...)
else
throw(ArgumentError(operation))
end
end
print_abs(io, x) = print(io, "abs(", decimal(x), ")")
print_add(io, x, y) = print(io, decimal(x), " + ", decimal(y))
print_apply(io, x) = print(io, decimal(x))
print_compare(io, x, y) = print(io, "cmp(", decimal(x), ", ", decimal(y), ")")
print_divide(io, x, y) = print(io, decimal(x), " / ", decimal(y))
print_minus(io, x) = print(io, "-(", decimal(x), ")")
print_multiply(io, x, y) = print(io, decimal(x), " * ", decimal(y))
print_plus(io, x) = print(io, "+(", decimal(x), ")")
print_reduce(io, x) = print(io, "reduce(", decimal(x), ")")
print_subtract(io, x, y) = print(io, decimal(x), " - ", decimal(y))

function print_test(io, test)
function print_test(io, test, directives)
println(io, " # $(test.id)")

names = sort!(collect(keys(directives)))
params = join(("$k=$(directives[k])" for k in names), ", ")
print(io, " @with_context ($params) ")

if :overflow ∈ test.conditions
print(io, " @test_throws OverflowError ")
print(io, "@test_throws OverflowError ")
print_operation(io, test.operation, test.operands)
println(io)
else
print(io, " @test ")
print(io, "@test ")
print_operation(io, test.operation, test.operands)
print(io, " == ")
println(io, decimal(test.result))
Expand All @@ -83,34 +115,36 @@ function isspecial(value)
return occursin(r"(inf|nan|#)", value)
end

function translate(io, line)
isempty(line) && return
startswith(line, "--") && return

line = lowercase(line)

if startswith(line, "version:")
# ...
elseif startswith(line, "extended:")
# ...
elseif startswith(line, "clamp:")
# ...
elseif startswith(line, "precision:")
precision = _precision(line)
print_precision(io, precision)
elseif startswith(line, "rounding:")
rounding = _rounding(line)
print_rounding(io, rounding)
elseif startswith(line, "maxexponent:")
maxexponent = _maxexponent(line)
print_maxexponent(io, maxexponent)
elseif startswith(line, "minexponent:")
minexponent = _minexponent(line)
print_minexponent(io, minexponent)
else
test = _test(line)
any(isspecial, test.operands) && return
print_test(io, test)
function translate(io, dectest_path)
directives = Dict{String, Any}()

for line in eachline(dectest_path)
line = strip(line)

isempty(line) && continue
startswith(line, "--") && continue

line = lowercase(line)

if startswith(line, "version:")
# ...
elseif startswith(line, "extended:")
# ...
elseif startswith(line, "clamp:")
# ...
elseif startswith(line, "precision:")
directives["precision"] = _precision(line)
elseif startswith(line, "rounding:")
directives["rounding"] = _rounding(line)
elseif startswith(line, "maxexponent:")
directives["Emax"] = _maxexponent(line)
elseif startswith(line, "minexponent:")
directives["Emin"] = _minexponent(line)
else
test = _test(line)
any(isspecial, test.operands) && continue
print_test(io, test, directives)
end
end
end

Expand All @@ -120,13 +154,13 @@ function (@main)(args=ARGS)
open(output_path, "w") do io
println(io, """
using Decimals
using ScopedValues
using Test
using Decimals: @with_context

@testset \"$name\" begin""")

for line in eachline(dectest_path)
translate(io, line)
end
translate(io, dectest_path)

println(io, "end")
end
Expand Down
1 change: 1 addition & 0 deletions src/Decimals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct Decimal <: AbstractFloat
end

include("bigint.jl")
include("context.jl")

# Convert between Decimal objects, numbers, and strings
include("decimal.jl")
Expand Down
6 changes: 2 additions & 4 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ Base.promote_rule(::Type{Decimal}, ::Type{<:Real}) = Decimal
Base.promote_rule(::Type{BigFloat}, ::Type{Decimal}) = Decimal
Base.promote_rule(::Type{BigInt}, ::Type{Decimal}) = Decimal

const BigTen = BigInt(10)
Base.:(+)(x::Decimal) = fix(x)
Base.:(-)(x::Decimal) = fix(Decimal(!x.s, x.c, x.q))

# Addition
# To add, convert both decimals to the same exponent.
Expand All @@ -24,9 +25,6 @@ function Base.:(+)(x::Decimal, y::Decimal)
return normalize(Decimal(s, abs(c), y.q))
end

# Negation
Base.:(-)(x::Decimal) = Decimal(!x.s, x.c, x.q)

# Subtraction
Base.:(-)(x::Decimal, y::Decimal) = +(x, -y)

Expand Down
134 changes: 134 additions & 0 deletions src/context.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
using ScopedValues

Base.@kwdef struct Context
precision::Int=28
rounding::RoundingMode=RoundNearest
Emax::Int=999999
Emin::Int=-999999
end

const CONTEXT = ScopedValue(Context())

Base.precision(::Type{Decimal}) = CONTEXT[].precision
Base.rounding(::Type{Decimal}) = CONTEXT[].rounding

"""
with_context(f; kwargs...)

Run `f` with [`Context`](@ref) parametrized by `kwargs`.

# Examples

```jldoctest
julia> Decimals.with_context(precision=42) do
precision(Decimal)
end
42
```

See also [`@with_context`](@ref).
"""
function with_context(f; kwargs...)
with(f, CONTEXT => Context(;kwargs...))
end

"""
@with_context params expr

Run `expr` with [`Context`](@ref) parametrized by a named tuple `params`.

# Examples

```jldoctest
julia> Decimals.@with_context (precision=42, ) precision(Decimal)
42
```

See also [`with_context`](@ref).
"""
macro with_context(params, expr)
return quote
@with Decimals.CONTEXT => Decimals.Context(;$params...) $(esc(expr))
end
end

"""
fix(x)

Round and fix the exponent of `x` to keep it within the precision and exponent
limits as given by the current `CONTEXT`.
"""
function fix(x::Decimal)
prec = precision(Decimal)
rmod = rounding(Decimal)

Emin, Emax = CONTEXT[].Emin, CONTEXT[].Emax
Etiny = Emin - prec + 1
Etop = Emax - prec + 1

if iszero(x)
return Decimal(x.s, x.c, clamp(x.q, Etiny, Etop))
end

clen = ndigits(x.c)
exp_min = clen + x.q - prec

# Equivalent to `clen + x.q - 1 > Emax`
if exp_min > Etop
throw(OverflowError("Exponent limit ($Emax) exceeded: $x"))
end

subnormal = exp_min < Etiny
if subnormal
exp_min = Etiny
end

# Number of digits and exponent within bounds
if x.q ≥ exp_min
return x
end

# Number of digits of the resulting coefficient
digits = clen + x.q - exp_min
if digits < 0
x = Decimal(x.s, BigOne, exp_min - 1)
digits = 0
end

# Number of least significant digits to remove from `c`
trun_len = clen - digits

# Signed coefficient for rounding modes like RoundToZero
c = (-1)^x.s * x.c

# Split `c` into `digits` most significant digits and `trun_len` least
# significant digits
# This is like round(c, rmod, sigdigits=digits), except here we can
# tell from `rem` if the rounding was lossless
c, rem = divrem(c, BigTen ^ trun_len, rmod)

# Rounding is exact if the truncated digits were zero
exact = iszero(rem)

# If the number of digits exceeded `digits` after rounding,
# it means that `c` was like 99...9 and was rounded up,
# becoming 100...0, so `c` is divisible by 10
if ndigits(c) > prec
c = exactdiv(c, 10)
exp_min += 1
end

# Exponent might have exceeded due to rounding
if exp_min > Etop
throw(OverflowError("Exponent limit ($Emax) exceeded: $x"))
end

x = Decimal(x.s, abs(c), exp_min)

#if subnormal && !exact
# throw(ErrorException("Underflow"))
#end

return x
end

Loading