Skip to content

Commit d343798

Browse files
authored
GODRIVER-3240 Code hardening. (#1684)
1 parent 59d0717 commit d343798

File tree

9 files changed

+67
-45
lines changed

9 files changed

+67
-45
lines changed

bson/bsoncodec/default_value_decoders.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr bsonrw.ValueReade
330330
case reflect.Int64:
331331
return reflect.ValueOf(i64), nil
332332
case reflect.Int:
333-
if int64(int(i64)) != i64 { // Can we fit this inside of an int
333+
if i64 > math.MaxInt { // Can we fit this inside of an int
334334
return emptyValue, fmt.Errorf("%d overflows int", i64)
335335
}
336336

@@ -434,7 +434,7 @@ func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr bsonrw.Valu
434434
return fmt.Errorf("%d overflows uint64", i64)
435435
}
436436
case reflect.Uint:
437-
if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint
437+
if i64 < 0 || uint64(i64) > uint64(math.MaxUint) { // Can we fit this inside of an uint
438438
return fmt.Errorf("%d overflows uint", i64)
439439
}
440440
default:

bson/bsoncodec/uint_codec.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,15 @@ func (uic *UIntCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t refl
164164

165165
return reflect.ValueOf(uint64(i64)), nil
166166
case reflect.Uint:
167-
if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint
167+
if i64 < 0 {
168+
return emptyValue, fmt.Errorf("%d overflows uint", i64)
169+
}
170+
v := uint64(i64)
171+
if v > math.MaxUint { // Can we fit this inside of an uint
168172
return emptyValue, fmt.Errorf("%d overflows uint", i64)
169173
}
170174

171-
return reflect.ValueOf(uint(i64)), nil
175+
return reflect.ValueOf(uint(v)), nil
172176
default:
173177
return emptyValue, ValueDecoderError{
174178
Name: "UintDecodeValue",

bson/bsonrw/extjson_wrappers.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
9595
return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
9696
}
9797

98-
i, err := strconv.ParseInt(val.v.(string), 16, 64)
98+
i, err := strconv.ParseUint(val.v.(string), 16, 8)
9999
if err != nil {
100-
return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string))
100+
return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err)
101101
}
102102

103103
subType = byte(i)

bson/bsonrw/value_reader.go

+5-7
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ func (vr *valueReader) peekLength() (int32, error) {
842842
}
843843

844844
idx := vr.offset
845-
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
845+
return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil
846846
}
847847

848848
func (vr *valueReader) readLength() (int32, error) { return vr.readi32() }
@@ -854,7 +854,7 @@ func (vr *valueReader) readi32() (int32, error) {
854854

855855
idx := vr.offset
856856
vr.offset += 4
857-
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
857+
return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil
858858
}
859859

860860
func (vr *valueReader) readu32() (uint32, error) {
@@ -864,7 +864,7 @@ func (vr *valueReader) readu32() (uint32, error) {
864864

865865
idx := vr.offset
866866
vr.offset += 4
867-
return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil
867+
return binary.LittleEndian.Uint32(vr.d[idx:]), nil
868868
}
869869

870870
func (vr *valueReader) readi64() (int64, error) {
@@ -874,8 +874,7 @@ func (vr *valueReader) readi64() (int64, error) {
874874

875875
idx := vr.offset
876876
vr.offset += 8
877-
return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 |
878-
int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil
877+
return int64(binary.LittleEndian.Uint64(vr.d[idx:])), nil
879878
}
880879

881880
func (vr *valueReader) readu64() (uint64, error) {
@@ -885,6 +884,5 @@ func (vr *valueReader) readu64() (uint64, error) {
885884

886885
idx := vr.offset
887886
vr.offset += 8
888-
return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 |
889-
uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil
887+
return binary.LittleEndian.Uint64(vr.d[idx:]), nil
890888
}

cmd/testatlas/main.go cmd/testatlas/atlas_test.go

+12-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"errors"
1212
"flag"
1313
"fmt"
14+
"os"
15+
"testing"
1416
"time"
1517

1618
"go.mongodb.org/mongo-driver/bson"
@@ -19,15 +21,19 @@ import (
1921
"go.mongodb.org/mongo-driver/mongo/options"
2022
)
2123

22-
func main() {
24+
func TestMain(m *testing.M) {
2325
flag.Parse()
26+
os.Exit(m.Run())
27+
}
28+
29+
func TestAtlas(t *testing.T) {
2430
uris := flag.Args()
2531
ctx := context.Background()
2632

27-
fmt.Printf("Running atlas tests for %d uris\n", len(uris))
33+
t.Logf("Running atlas tests for %d uris\n", len(uris))
2834

2935
for idx, uri := range uris {
30-
fmt.Printf("Running test %d\n", idx)
36+
t.Logf("Running test %d\n", idx)
3137

3238
// Set a low server selection timeout so we fail fast if there are errors.
3339
clientOpts := options.Client().
@@ -36,18 +42,18 @@ func main() {
3642

3743
// Run basic connectivity test.
3844
if err := runTest(ctx, clientOpts); err != nil {
39-
panic(fmt.Sprintf("error running test with TLS at index %d: %v", idx, err))
45+
t.Fatalf("error running test with TLS at index %d: %v", idx, err)
4046
}
4147

4248
// Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is
4349
// disabled.
4450
clientOpts.TLSConfig.InsecureSkipVerify = true
4551
if err := runTest(ctx, clientOpts); err != nil {
46-
panic(fmt.Sprintf("error running test with tlsInsecure at index %d: %v", idx, err))
52+
t.Fatalf("error running test with tlsInsecure at index %d: %v", idx, err)
4753
}
4854
}
4955

50-
fmt.Println("Finished!")
56+
t.Logf("Finished!")
5157
}
5258

5359
func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {

etc/run-atlas-test.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ set +x
77
# Get the atlas secrets.
88
. ${DRIVERS_TOOLS}/.evergreen/secrets_handling/setup-secrets.sh drivers/atlas_connect
99

10-
echo "Running cmd/testatlas/main.go"
11-
go run ./cmd/testatlas/main.go "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite
10+
echo "Running cmd/testatlas"
11+
go test -v -run ^TestAtlas$ go.mongodb.org/mongo-driver/cmd/testatlas -args "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite

internal/logger/io_sink.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package logger
99
import (
1010
"encoding/json"
1111
"io"
12+
"math"
1213
"sync"
1314
"time"
1415
)
@@ -36,7 +37,11 @@ func NewIOSink(out io.Writer) *IOSink {
3637

3738
// Info will write a JSON-encoded message to the io.Writer.
3839
func (sink *IOSink) Info(_ int, msg string, keysAndValues ...interface{}) {
39-
kvMap := make(map[string]interface{}, len(keysAndValues)/2+2)
40+
mapSize := len(keysAndValues) / 2
41+
if math.MaxInt-mapSize >= 2 {
42+
mapSize += 2
43+
}
44+
kvMap := make(map[string]interface{}, mapSize)
4045

4146
kvMap[KeyTimestamp] = time.Now().UnixNano()
4247
kvMap[KeyMessage] = msg

mongo/options/clientoptions.go

+14-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"errors"
1616
"fmt"
1717
"io/ioutil"
18+
"math"
1819
"net"
1920
"net/http"
2021
"strings"
@@ -1177,7 +1178,19 @@ func addClientCertFromSeparateFiles(cfg *tls.Config, keyFile, certFile, keyPassw
11771178
return "", err
11781179
}
11791180

1180-
data := make([]byte, 0, len(keyData)+len(certData)+1)
1181+
keySize := len(keyData)
1182+
if keySize > 64*1024*1024 {
1183+
return "", errors.New("X.509 key must be less than 64 MiB")
1184+
}
1185+
certSize := len(certData)
1186+
if certSize > 64*1024*1024 {
1187+
return "", errors.New("X.509 certificate must be less than 64 MiB")
1188+
}
1189+
dataSize := keySize + certSize + 1
1190+
if dataSize > math.MaxInt {
1191+
return "", errors.New("size overflow")
1192+
}
1193+
data := make([]byte, 0, dataSize)
11811194
data = append(data, keyData...)
11821195
data = append(data, '\n')
11831196
data = append(data, certData...)

x/bsonx/bsoncore/bsoncore.go

+18-22
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package bsoncore // import "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
88

99
import (
1010
"bytes"
11+
"encoding/binary"
1112
"fmt"
1213
"math"
1314
"strconv"
@@ -706,17 +707,16 @@ func ReserveLength(dst []byte) (int32, []byte) {
706707

707708
// UpdateLength updates the length at index with length and returns the []byte.
708709
func UpdateLength(dst []byte, index, length int32) []byte {
709-
dst[index] = byte(length)
710-
dst[index+1] = byte(length >> 8)
711-
dst[index+2] = byte(length >> 16)
712-
dst[index+3] = byte(length >> 24)
710+
binary.LittleEndian.PutUint32(dst[index:], uint32(length))
713711
return dst
714712
}
715713

716714
func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) }
717715

718716
func appendi32(dst []byte, i32 int32) []byte {
719-
return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24))
717+
b := []byte{0, 0, 0, 0}
718+
binary.LittleEndian.PutUint32(b, uint32(i32))
719+
return append(dst, b...)
720720
}
721721

722722
// ReadLength reads an int32 length from src and returns the length and the remaining bytes. If
@@ -734,51 +734,47 @@ func readi32(src []byte) (int32, []byte, bool) {
734734
if len(src) < 4 {
735735
return 0, src, false
736736
}
737-
return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true
737+
return int32(binary.LittleEndian.Uint32(src)), src[4:], true
738738
}
739739

740740
func appendi64(dst []byte, i64 int64) []byte {
741-
return append(dst,
742-
byte(i64), byte(i64>>8), byte(i64>>16), byte(i64>>24),
743-
byte(i64>>32), byte(i64>>40), byte(i64>>48), byte(i64>>56),
744-
)
741+
b := []byte{0, 0, 0, 0, 0, 0, 0, 0}
742+
binary.LittleEndian.PutUint64(b, uint64(i64))
743+
return append(dst, b...)
745744
}
746745

747746
func readi64(src []byte) (int64, []byte, bool) {
748747
if len(src) < 8 {
749748
return 0, src, false
750749
}
751-
i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 |
752-
int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56)
753-
return i64, src[8:], true
750+
return int64(binary.LittleEndian.Uint64(src)), src[8:], true
754751
}
755752

756753
func appendu32(dst []byte, u32 uint32) []byte {
757-
return append(dst, byte(u32), byte(u32>>8), byte(u32>>16), byte(u32>>24))
754+
b := []byte{0, 0, 0, 0}
755+
binary.LittleEndian.PutUint32(b, u32)
756+
return append(dst, b...)
758757
}
759758

760759
func readu32(src []byte) (uint32, []byte, bool) {
761760
if len(src) < 4 {
762761
return 0, src, false
763762
}
764763

765-
return (uint32(src[0]) | uint32(src[1])<<8 | uint32(src[2])<<16 | uint32(src[3])<<24), src[4:], true
764+
return binary.LittleEndian.Uint32(src), src[4:], true
766765
}
767766

768767
func appendu64(dst []byte, u64 uint64) []byte {
769-
return append(dst,
770-
byte(u64), byte(u64>>8), byte(u64>>16), byte(u64>>24),
771-
byte(u64>>32), byte(u64>>40), byte(u64>>48), byte(u64>>56),
772-
)
768+
b := []byte{0, 0, 0, 0, 0, 0, 0, 0}
769+
binary.LittleEndian.PutUint64(b, u64)
770+
return append(dst, b...)
773771
}
774772

775773
func readu64(src []byte) (uint64, []byte, bool) {
776774
if len(src) < 8 {
777775
return 0, src, false
778776
}
779-
u64 := (uint64(src[0]) | uint64(src[1])<<8 | uint64(src[2])<<16 | uint64(src[3])<<24 |
780-
uint64(src[4])<<32 | uint64(src[5])<<40 | uint64(src[6])<<48 | uint64(src[7])<<56)
781-
return u64, src[8:], true
777+
return binary.LittleEndian.Uint64(src), src[8:], true
782778
}
783779

784780
// keep in sync with readcstringbytes

0 commit comments

Comments
 (0)