diff --git a/safecast.go b/safecast.go index e71cd1a..3d53cad 100644 --- a/safecast.go +++ b/safecast.go @@ -27,17 +27,24 @@ type Number interface { var ErrOutOfRange = errors.New("out of range") -func Negative[T Number](t T) bool { - return t < 0 -} - -func SameSign[T1, T2 Number](a T1, b T2) bool { - return Negative(a) == Negative(b) -} +const all64bitsOne = ^uint64(0) // same as uint64(math.MaxUint64) +// Convert converts a number from one type to another, +// returning an error if the conversion would result in a loss of precision, +// range or sign (overflow). In other words if the converted number is not +// equal to the original number. +// Do not use for identity (same type in and out) but in particular this +// will error for Convert[uint64](uint64(math.MaxUint64)) because it needs to +// when converting to any float. func Convert[NumOut Number, NumIn Number](orig NumIn) (converted NumOut, err error) { + origPositive := orig > 0 + // all bits set on uint64 is the only special case not detected by roundtrip (afaik). + if origPositive && (uint64(orig) == all64bitsOne) { + err = ErrOutOfRange + return + } converted = NumOut(orig) - if !SameSign(orig, converted) { + if origPositive != (converted > 0) { err = ErrOutOfRange return } @@ -47,6 +54,7 @@ func Convert[NumOut Number, NumIn Number](orig NumIn) (converted NumOut, err err return } +// Same as Convert but panics if there is an error. func MustConvert[NumOut Number, NumIn Number](orig NumIn) NumOut { converted, err := Convert[NumOut, NumIn](orig) if err != nil { @@ -55,14 +63,19 @@ func MustConvert[NumOut Number, NumIn Number](orig NumIn) NumOut { return converted } +// Converts a float to an integer by truncating the fractional part. +// Returns an error if the conversion would result in a loss of precision. func Truncate[NumOut Number, NumIn Float](orig NumIn) (converted NumOut, err error) { return Convert[NumOut](math.Trunc(float64(orig))) } +// Converts a float to an integer by rounding to the nearest integer. +// Returns an error if the conversion would result in a loss of precision. func Round[NumOut Number, NumIn Float](orig NumIn) (converted NumOut, err error) { return Convert[NumOut](math.Round(float64(orig))) } +// Same as Truncate but panics if there is an error. func MustTruncate[NumOut Number, NumIn Float](orig NumIn) NumOut { converted, err := Truncate[NumOut, NumIn](orig) if err != nil { @@ -71,6 +84,7 @@ func MustTruncate[NumOut Number, NumIn Float](orig NumIn) NumOut { return converted } +// Same as Round but panics if there is an error. func MustRound[NumOut Number, NumIn Float](orig NumIn) NumOut { converted, err := Round[NumOut, NumIn](orig) if err != nil { diff --git a/safecast_test.go b/safecast_test.go index 8758402..65a507d 100644 --- a/safecast_test.go +++ b/safecast_test.go @@ -2,6 +2,7 @@ package safecast_test import ( "fmt" + "math" "testing" "fortio.org/safecast" @@ -9,6 +10,99 @@ import ( // TODO: steal the tests from https://github.com/ccoVeille/go-safecast +const all64bitsOne = ^uint64(0) + +// Interesting part is the "true" for the first line, which is why we have to change the +// code in Convert to handle that 1 special case. +// safecast_test.go:22: bits 64: 1111111111111111111111111111111111111111111111111111111111111111 +// : 18446744073709551615 -> float64 18446744073709551616 true. +func FindNumIntBits[T safecast.Float](t *testing.T) int { + var v T + for i := 0; i < 64; i++ { + bits := (all64bitsOne >> i) + v = T(bits) + t.Logf("bits %02d: %b : %d -> %T %.0f %t", 64-i, bits, bits, v, v, uint64(v) == bits) + if v != v-1 { + return 64 - i + } + } + panic("bug... didn't fine num bits") +} + +func TestFloat32Bounds(t *testing.T) { + float32bits := FindNumIntBits[float32](t) + t.Logf("float32: %d bits", float32bits) + float32int := uint64(1<<(float32bits) - 1) // 24 bits + for i := 0; i <= 64-float32bits; i++ { + t.Logf("float32int %b %d", float32int, float32int) + f := safecast.MustConvert[float32](float32int) + t.Logf("float32int -> %.0f", f) + float32int <<= 1 + } +} + +func TestFloat64Bounds(t *testing.T) { + float64bits := FindNumIntBits[float64](t) + t.Logf("float64: %d bits", float64bits) + float64int := uint64(1<<(float64bits) - 1) // 53 bits + for i := 0; i <= 64-float64bits; i++ { + t.Logf("float64int %b %d", float64int, float64int) + f := safecast.MustConvert[float64](float64int) + t.Logf("float64int -> %.0f", f) + float64int <<= 1 + } +} + +func TestNonIntegerFloat(t *testing.T) { + _, err := safecast.Convert[int](math.Pi) + if err == nil { + t.Errorf("expected error") + } + truncPi := math.Trunc(math.Pi) // math.Trunc returns a float64 + i, err := safecast.Convert[int](truncPi) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if i != 3 { + t.Errorf("unexpected value: %v", i) + } + i, err = safecast.Truncate[int](math.Pi) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if i != 3 { + t.Errorf("unexpected value: %v", i) + } + i, err = safecast.Round[int](math.Phi) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if i != 2 { + t.Errorf("unexpected value: %v", i) + } +} + +// MaxUint64 special case and also MaxInt64+1. +func TestMaxInt64(t *testing.T) { + f32, err := safecast.Convert[float32](all64bitsOne) + if err == nil { + t.Errorf("expected error, got %d -> %.0f", all64bitsOne, f32) + } + f64, err := safecast.Convert[float64](all64bitsOne) + if err == nil { + t.Errorf("expected error, got %d -> %.0f", all64bitsOne, f64) + } + minInt64p1 := int64(math.MinInt64) + 1 // not a power of 2 + t.Logf("minInt64p1 %b %d", minInt64p1, minInt64p1) + _, err = safecast.Convert[float64](minInt64p1) + f64 = float64(minInt64p1) + int2 := int64(f64) + t.Logf("minInt64p1 -> %.0f %d", f64, int2) + if err == nil { + t.Errorf("expected error, got %d -> %.0f", minInt64p1, f64) + } +} + func TestConvert(t *testing.T) { var inp uint32 = 42 out, err := safecast.Convert[int8](inp) @@ -25,7 +119,7 @@ func TestConvert(t *testing.T) { if err == nil { t.Errorf("expected error") } - inp2 := int32(-42) + inp2 := int32(-1) _, err = safecast.Convert[uint8](inp2) t.Logf("Got err: %v", err) if err == nil { @@ -36,7 +130,7 @@ func TestConvert(t *testing.T) { if err != nil { t.Errorf("unexpected error: %v", err) } - if out != -42 { + if out != -1 { t.Errorf("unexpected value: %v", out) } inp2 = -129