From 5b8998703a904dddfd0208573fae9026360f9bdd Mon Sep 17 00:00:00 2001 From: ringsaturn Date: Sun, 25 Aug 2024 22:23:45 +0800 Subject: [PATCH] Support shuffle for iter --- cities.go | 15 ++++++++++----- iter.go | 25 ++++++++++++++++++------- iter_test.go | 4 ++-- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/cities.go b/cities.go index c99b9e1..d0a0365 100644 --- a/cities.go +++ b/cities.go @@ -7,8 +7,14 @@ import ( "strconv" ) -//go:embed upstream/cities.json -var citiesbytes []byte +var ( + //go:embed upstream/cities.json + citiesbytes []byte + + Cities []*City + + idx []int +) type City struct { Country string `json:"country"` @@ -19,8 +25,6 @@ type City struct { Admin2 string `json:"admin2"` } -var Cities []*City - func init() { type RawCity struct { Country string `json:"country"` @@ -35,7 +39,7 @@ func init() { if err := json.Unmarshal(citiesbytes, &rawcities); err != nil { panic(err) } - for _, rawcity := range rawcities { + for i, rawcity := range rawcities { lat, err := strconv.ParseFloat(rawcity.Lat, 64) if err != nil { panic(err) @@ -52,5 +56,6 @@ func init() { Admin1: rawcity.Admin1, Admin2: rawcity.Admin2, }) + idx = append(idx, i) } } diff --git a/iter.go b/iter.go index f5950c8..7f580b7 100644 --- a/iter.go +++ b/iter.go @@ -3,22 +3,33 @@ package gocitiesjson -import "iter" +import ( + "iter" + "math/rand" +) -func All() iter.Seq[*City] { +func All(shuffle bool) iter.Seq[*City] { return func(yield func(*City) bool) { - for _, city := range Cities { - if !yield(city) { + var idxes []int + if shuffle { + idxes = make([]int, len(idx)) + copy(idxes, idx) + rand.Shuffle(len(idxes), func(i, j int) { idxes[i], idxes[j] = idxes[j], idxes[i] }) + } else { + idxes = idx + } + for _, i := range idxes { + if !yield(Cities[i]) { return } } } } -func Filter(filter func(*City) bool) iter.Seq[*City] { +func Filter(fn func(*City) bool, shuffle bool) iter.Seq[*City] { return func(yield func(*City) bool) { - for _, city := range Cities { - if filter(city) { + for city := range All(shuffle) { + if fn(city) { if !yield(city) { return } diff --git a/iter_test.go b/iter_test.go index c8e4670..f4989c9 100644 --- a/iter_test.go +++ b/iter_test.go @@ -11,7 +11,7 @@ import ( func ExampleAll() { c := 0 - for city := range gocitiesjson.All() { + for city := range gocitiesjson.All(false) { _ = city c++ } @@ -25,7 +25,7 @@ func ExampleFilter() { bboxFilter := func(city *gocitiesjson.City) bool { return city.Lng > 0 && city.Lng < 10 && city.Lat > 0 && city.Lat < 10 } - for city := range gocitiesjson.Filter(bboxFilter) { + for city := range gocitiesjson.Filter(bboxFilter, false) { _ = city c++ }