Skip to content

Commit

Permalink
Allow skipping an entry expansion in tree.Walk() (#838)
Browse files Browse the repository at this point in the history
It is now possible to skip expanding an entry in `tree.Walk()` by
returning `TreeWalkSkip`, in addition to stopping altogether by
returning a non-nil error.

Fixes: #837
  • Loading branch information
lhchavez authored Sep 30, 2021
1 parent c6da3b9 commit 9b15518
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 2 deletions.
18 changes: 17 additions & 1 deletion tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ extern int _go_git_treewalk(git_tree *tree, git_treewalk_mode mode, void *ptr);
import "C"

import (
"errors"
"runtime"
"unsafe"
)
Expand Down Expand Up @@ -134,14 +135,29 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer
}

err := data.callback(C.GoString(_root), newTreeEntry(entry))
if err != nil {
if err == TreeWalkSkip {
return C.int(1)
} else if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser)
}

return C.int(ErrorCodeOK)
}

// TreeWalkSkip is an error that can be returned form TreeWalkCallback to skip
// a subtree from being expanded.
var TreeWalkSkip = errors.New("skip")

// Walk traverses the entries in a tree and its subtrees in pre order.
//
// The entries will be traversed in the pre order, children subtrees will be
// automatically loaded as required, and the callback will be called once per
// entry with the current (relative) root for the entry and the entry data
// itself.
//
// If the callback returns TreeWalkSkip, the passed entry will be skipped on
// the traversal. Any other non-nil error stops the walk.
func (t *Tree) Walk(callback TreeWalkCallback) error {
var err error
data := treeWalkCallbackData{
Expand Down
74 changes: 73 additions & 1 deletion tree_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package git

import "testing"
import (
"errors"
"testing"
)

func TestTreeEntryById(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -63,3 +66,72 @@ func TestTreeBuilderInsert(t *testing.T) {
t.Fatalf("got oid %v, want %v", entry.Id, blobId)
}
}

func TestTreeWalk(t *testing.T) {
t.Parallel()
repo, err := OpenRepository("testdata/TestGitRepository.git")
checkFatal(t, err)
treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
checkFatal(t, err)

tree, err := repo.LookupTree(treeID)
checkFatal(t, err)

var callCount int
err = tree.Walk(func(name string, entry *TreeEntry) error {
callCount++

return nil
})
checkFatal(t, err)
if callCount != 11 {
t.Fatalf("got called %v times, want %v", callCount, 11)
}
}

func TestTreeWalkSkip(t *testing.T) {
t.Parallel()
repo, err := OpenRepository("testdata/TestGitRepository.git")
checkFatal(t, err)
treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
checkFatal(t, err)

tree, err := repo.LookupTree(treeID)
checkFatal(t, err)

var callCount int
err = tree.Walk(func(name string, entry *TreeEntry) error {
callCount++

return TreeWalkSkip
})
checkFatal(t, err)
if callCount != 4 {
t.Fatalf("got called %v times, want %v", callCount, 4)
}
}

func TestTreeWalkStop(t *testing.T) {
t.Parallel()
repo, err := OpenRepository("testdata/TestGitRepository.git")
checkFatal(t, err)
treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
checkFatal(t, err)

tree, err := repo.LookupTree(treeID)
checkFatal(t, err)

var callCount int
stopError := errors.New("stop")
err = tree.Walk(func(name string, entry *TreeEntry) error {
callCount++

return stopError
})
if err != stopError {
t.Fatalf("got error %v, want %v", err, stopError)
}
if callCount != 1 {
t.Fatalf("got called %v times, want %v", callCount, 1)
}
}

0 comments on commit 9b15518

Please # to comment.