Skip to content

Commit

Permalink
move one more horrible comparison function into cmpx
Browse files Browse the repository at this point in the history
  • Loading branch information
mcy committed Feb 13, 2025
1 parent 12fca6c commit 9283ced
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 65 deletions.
4 changes: 1 addition & 3 deletions experimental/report/renderer.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,7 @@ func (r *renderer) window(w *window) {
// in a whole diagnostic, much less five snippets that share a line, so
// this shouldn't be an issue.
restSorted := slices.Clone(rest[idx:])
slices.SortFunc(restSorted, func(a, b *underline) int {
return a.start - b.start
})
slices.SortFunc(restSorted, cmpx.Key(func(u *underline) int { return u.start }))

var nonColorLen int
for _, ul := range restSorted {
Expand Down
132 changes: 131 additions & 1 deletion internal/ext/cmpx/cmpx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
// package cmpx contains extensions to Go's package cmp.
package cmpx

import "cmp"
import (
"cmp"
"fmt"
"math"
"reflect"
"unsafe"
)

// Result is the type returned by an [Ordering], and in particular
// [cmp.Compare].
Expand All @@ -28,6 +34,11 @@ const (
Greater Result = 1
)

// Ordered is like [cmp.Ordered], but includes additional types.
type Ordered interface {
~bool | cmp.Ordered
}

// Ordering is an ordering for the type T, which is any function with the same
// signature as [Compare].
type Ordering[T any] func(T, T) Result
Expand All @@ -50,3 +61,122 @@ func Join[T any](cmps ...Ordering[T]) Ordering[T] {
return Equal
}
}

// Bool compares two bools, where false < true.
//
// This works around a bug where bool does not satisfy [cmp.Ordered].
func Bool[B ~bool](a, b B) Result {
var ai, bi byte
if a {
ai = 1
}
if b {
bi = 1
}
return cmp.Compare(ai, bi)
}

// Address compares two pointers by address.
//
// The result will be unstable if either pointer points to the stack, since
// stack resizing may cause their relative addresses to change.
func Address[P ~*E, E any](a, b P) Result {
return cmp.Compare(uintptr(unsafe.Pointer(a)), uintptr(unsafe.Pointer(b)))
}

// Any compares any two [cmp.Ordered] types, according to the following criteria:
//
// 1. any(nil) is least of all.
//
// 2. If the values are not mutually comparable, their [reflect.Kind]s are
// compared.
//
// 3. If either value is not of a [cmp.Ordered] type, this function panics.
//
// 4. Otherwise, the arguments are compared as-if by [cmp.Compare].
//
// For the purposes of this function, bool is treated as satisfying [cmp.Compare].
func Any(a, b any) Result {
if a == nil || b == nil {
return Bool(a != nil, b != nil)
}

ra := reflect.ValueOf(a)
rb := reflect.ValueOf(b)

type kind int
const (
kBool kind = 1 << iota
kInt
kUint
kFloat
kString
)

which := func(r reflect.Value) kind {
switch r.Kind() {
case reflect.Bool:
return kBool
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return kInt
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Uintptr:
return kUint
case reflect.Float32, reflect.Float64:
return kFloat
case reflect.String:
return kString
default:
panic(fmt.Sprintf("cmpx.Any: incomparable value %v (type %[1]T)", r.Interface()))
}
}

switch which(ra) | which(rb) {
case kBool:
return Bool(ra.Bool(), rb.Bool())

case kInt:
return cmp.Compare(ra.Int(), rb.Int())

case kUint:
return cmp.Compare(ra.Uint(), rb.Uint())

case kInt | kUint:
if rb.CanUint() {
v := rb.Uint()
if v > math.MaxInt64 {
return Less
}
return cmp.Compare(ra.Int(), int64(v))
} else {
v := ra.Uint()
if v > math.MaxInt64 {
return Greater
}
return cmp.Compare(int64(v), rb.Int())
}

case kFloat:
return cmp.Compare(ra.Float(), rb.Float())

case kFloat | kInt:
if ra.CanFloat() {
return cmp.Compare(ra.Float(), float64(rb.Int()))
} else {
return cmp.Compare(float64(ra.Int()), rb.Float())
}

case kFloat | kUint:
if ra.CanFloat() {
return cmp.Compare(ra.Float(), float64(rb.Uint()))
} else {
return cmp.Compare(float64(ra.Uint()), rb.Float())
}

case kString:
return cmp.Compare(ra.String(), rb.String())

default:
return cmp.Compare(ra.Kind(), rb.Kind())
}
}
62 changes: 62 additions & 0 deletions internal/ext/cmpx/cmpx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2020-2025 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cmpx_test

import (
"math"
"testing"

"github.com/bufbuild/protocompile/internal/ext/cmpx"
"github.com/stretchr/testify/assert"
)

func TestAny(t *testing.T) {
t.Parallel()

tests := []struct {
a, b any
r cmpx.Result
}{
{a: nil, b: nil, r: cmpx.Equal},
{a: nil, b: 1, r: cmpx.Less},
{a: 1, b: nil, r: cmpx.Greater},

{a: false, b: true, r: cmpx.Less},
{a: true, b: true, r: cmpx.Equal},

{a: true, b: 0, r: cmpx.Less},

{a: byte(0), b: int(-1), r: cmpx.Greater},
{a: byte(0), b: int(0), r: cmpx.Equal},
{a: byte(0), b: int(1), r: cmpx.Less},

{a: int(2), b: uint(1), r: cmpx.Greater},
{a: int(2), b: uint(2), r: cmpx.Equal},
{a: int(2), b: uint(3), r: cmpx.Less},

{a: int(math.MaxInt), b: uint(math.MaxUint), r: cmpx.Less},

{a: 1.5, b: 2, r: cmpx.Less},
{a: 2, b: 1.5, r: cmpx.Greater},
{a: 1.5, b: 2.0, r: cmpx.Less},

{a: "foo", b: "bar", r: cmpx.Greater},
{a: 1, b: "1", r: cmpx.Less},
}

for _, tt := range tests {
assert.Equal(t, tt.r, cmpx.Any(tt.a, tt.b), "cmpx.Any(%v, %v)", tt.a, tt.b)
}
}
65 changes: 4 additions & 61 deletions internal/prototest/yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
package prototest

import (
"cmp"
"fmt"
"math"
"slices"
"strings"

"github.com/bufbuild/protocompile/internal/ext/cmpx"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
Expand Down Expand Up @@ -253,7 +252,9 @@ func (d *doc) push(k, v any) {
// appropriate.
func (d *doc) prepare() {
if d.needsSort {
slices.SortFunc(d.pairs, cmpMapKeys)
slices.SortFunc(d.pairs, func(a, b [2]any) int {
return cmpx.Any(a[0], b[0])
})
}

if d.isArray || len(d.pairs) == 0 {
Expand Down Expand Up @@ -292,61 +293,3 @@ func (d *doc) prepare() {
}
}
}

func cmpMapKeys(a, b [2]any) int {
// key is a concrete comparable type that is useful for forcing relative
// order between types.
type key struct {
which int
bool
int64
uint64
string
}

any2key := func(v any) key {
switch v := v.(type) {
case bool:
return key{which: 0, bool: v}
case int32:
return key{which: 1, int64: int64(v)}
case int64:
return key{which: 1, int64: v}
case uint32:
return key{which: 1, int64: int64(v)}
case uint64:
if v <= math.MaxInt64 {
return key{which: 1, int64: int64(v)}
}
return key{which: 2, uint64: v}
case protoreflect.Name:
return key{which: 3, string: string(v)}
case string:
return key{which: 3, string: v}
default:
return key{}
}
}

ak := any2key(b)
bk := any2key(a)
if n := cmp.Compare(ak.which, bk.which); n != 0 {
return n
}

switch ak.which {
case 0:
if !ak.bool {
return -1
}
return 1
case 1:
return cmp.Compare(ak.int64, bk.int64)
case 2:
return cmp.Compare(ak.uint64, bk.uint64)
case 3:
return cmp.Compare(ak.string, bk.string)
default:
return 0
}
}

0 comments on commit 9283ced

Please # to comment.