diff --git a/benchmarks_test.go b/benchmarks_test.go index c6dbd73..78bd480 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -944,3 +944,19 @@ func BenchmarkSet(bench *testing.B) { bench.Run("single/uint256", benchmarkUint256) bench.Run("single/big", benchmarkBig) } + +func BenchmarkByte(bench *testing.B) { + var ( + a = new(Int).SetBytes(hex2Bytes("f123456789abcdeffedcba9876543210f2f3f4f5f6f7f8f9fff3f4f5f6f7f8f9")) + num = NewInt(0) + result = new(Int) + ) + bench.ResetTimer() + for i := 0; i < bench.N; i++ { + for ii := uint64(0); ii < 32; ii++ { + result.Set(a) + num.SetUint64(ii) + _ = result.Byte(num) + } + } +} diff --git a/uint256.go b/uint256.go index 726f7c0..f3caf7b 100644 --- a/uint256.go +++ b/uint256.go @@ -1236,22 +1236,20 @@ func (z *Int) Xor(x, y *Int) *Int { } // Byte sets z to the value of the byte at position n, -// with 'z' considered as a big-endian 32-byte integer -// if 'n' > 32, f is set to 0 -// Example: f = '5', n=31 => 5 +// with z considered as a big-endian 32-byte integer. +// if n >= 32, z is set to 0 +// Example: z=5, n=31 => 5 func (z *Int) Byte(n *Int) *Int { - // in z, z[0] is the least significant - // - if number, overflow := n.Uint64WithOverflow(); !overflow { - if number < 32 { - number := z[4-1-number/8] - offset := (n[0] & 0x7) << 3 // 8*(n.d % 8) - z[0] = (number & (0xff00000000000000 >> offset)) >> (56 - offset) - z[3], z[2], z[1] = 0, 0, 0 - return z - } + index, overflow := n.Uint64WithOverflow() + if overflow || index >= 32 { + return z.Clear() } - return z.Clear() + // in z, z[0] is the least significant + number := z[4-1-index/8] + offset := (index & 0x7) << 3 // 8 * (index % 8) + z[0] = (number >> (56 - offset)) & 0xff + z[3], z[2], z[1] = 0, 0, 0 + return z } // Exp sets z = base**exponent mod 2**256, and returns z. diff --git a/uint256_test.go b/uint256_test.go index c739926..0d21890 100644 --- a/uint256_test.go +++ b/uint256_test.go @@ -881,32 +881,27 @@ func TestSRsh(t *testing.T) { } func TestByte(t *testing.T) { - z := new(Int).SetBytes(hex2Bytes("ABCDEF09080706050403020100000000000000000000000000000000000000ef")) - actual := z.Byte(NewInt(0)) - expected := new(Int).SetBytes(hex2Bytes("00000000000000000000000000000000000000000000000000000000000000ab")) - if !actual.Eq(expected) { - t.Fatalf("Expected %x, got %x", expected, actual) - } - - z = new(Int).SetBytes(hex2Bytes("ABCDEF09080706050403020100000000000000000000000000000000000000ef")) - actual = z.Byte(NewInt(31)) - expected = new(Int).SetBytes(hex2Bytes("00000000000000000000000000000000000000000000000000000000000000ef")) - if !actual.Eq(expected) { - t.Fatalf("Expected %x, got %x", expected, actual) - } - - z = new(Int).SetBytes(hex2Bytes("ABCDEF09080706050403020100000000000000000000000000000000000000ef")) - actual = z.Byte(NewInt(32)) - expected = new(Int).SetBytes(hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")) - if !actual.Eq(expected) { - t.Fatalf("Expected %x, got %x", expected, actual) - } - - z = new(Int).SetBytes(hex2Bytes("ABCDEF0908070605040302011111111111111111111111111111111111111111")) - actual = z.Byte(new(Int).SetBytes(hex2Bytes("f000000000000000000000000000000000000000000000000000000000000001"))) - expected = new(Int).SetBytes(hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")) - if !actual.Eq(expected) { - t.Fatalf("Expected %x, got %x", expected, actual) + input, err := FromHex("0x102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") + if err != nil { + t.Fatal(err) + } + for i := uint64(0); i < 35; i++ { + var ( + z = input.Clone() + index = NewInt(i) + have = z.Byte(index) + want = NewInt(i) + ) + if i >= 32 { + want.Clear() + } + if !have.Eq(want) { + t.Fatalf("index %d: have %#x want %#x", i, have, want) + } + // Also check that we indeed modified the z + if z != have { + t.Fatalf("index %d: should return self %v %v", i, z, have) + } } }