Skip to content

Commit

Permalink
Allow skipping an entry expansion in tree.Walk()
Browse files Browse the repository at this point in the history
It is now possible to skip expanding an entry in `tree.Walk()` by
returning `TreeWalkSkip`, in addition to stopping altogether by
returning a non-nil error.

Fixes: libgit2#837
  • Loading branch information
lhchavez committed Sep 30, 2021
1 parent c6da3b9 commit 6b2c144
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 2 deletions.
18 changes: 17 additions & 1 deletion tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ extern int _go_git_treewalk(git_tree *tree, git_treewalk_mode mode, void *ptr);
import "C"

import (
"errors"
"runtime"
"unsafe"
)
Expand Down Expand Up @@ -134,14 +135,29 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer
}

err := data.callback(C.GoString(_root), newTreeEntry(entry))
if err != nil {
if err == TreeWalkSkip {
return C.int(1)
} else if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser)
}

return C.int(ErrorCodeOK)
}

// TreeWalkSkip is an error that can be returned form TreeWalkCallback to skip
// a subtree from being expanded.
var TreeWalkSkip = errors.New("skip")

// Walk traverses the entries in a tree and its subtrees in pre order.
//
// The entries will be traversed in the pre order, children subtrees will be
// automatically loaded as required, and the callback will be called once per
// entry with the current (relative) root for the entry and the entry data
// itself.
//
// If the callback returns TreeWalkSkip, the passed entry will be skipped on
// the traversal. Any other non-nil error stops the walk.
func (t *Tree) Walk(callback TreeWalkCallback) error {
var err error
data := treeWalkCallbackData{
Expand Down
74 changes: 73 additions & 1 deletion tree_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package git

import "testing"
import (
"errors"
"testing"
)

func TestTreeEntryById(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -63,3 +66,72 @@ func TestTreeBuilderInsert(t *testing.T) {
t.Fatalf("got oid %v, want %v", entry.Id, blobId)
}
}

func TestTreeWalk(t *testing.T) {
t.Parallel()
repo, err := OpenRepository("testdata/TestGitRepository.git")
checkFatal(t, err)
treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
checkFatal(t, err)

tree, err := repo.LookupTree(treeID)
checkFatal(t, err)

var callCount int
err = tree.Walk(func(name string, entry *TreeEntry) error {
callCount++

return nil
})
checkFatal(t, err)
if callCount != 11 {
t.Fatalf("got called %v times, want %v", callCount, 11)
}
}

func TestTreeWalkSkip(t *testing.T) {
t.Parallel()
repo, err := OpenRepository("testdata/TestGitRepository.git")
checkFatal(t, err)
treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
checkFatal(t, err)

tree, err := repo.LookupTree(treeID)
checkFatal(t, err)

var callCount int
err = tree.Walk(func(name string, entry *TreeEntry) error {
callCount++

return TreeWalkSkip
})
checkFatal(t, err)
if callCount != 4 {
t.Fatalf("got called %v times, want %v", callCount, 4)
}
}

func TestTreeWalkStop(t *testing.T) {
t.Parallel()
repo, err := OpenRepository("testdata/TestGitRepository.git")
checkFatal(t, err)
treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
checkFatal(t, err)

tree, err := repo.LookupTree(treeID)
checkFatal(t, err)

var callCount int
stopError := errors.New("stop")
err = tree.Walk(func(name string, entry *TreeEntry) error {
callCount++

return stopError
})
if err != stopError {
t.Fatalf("got error %v, want %v", err, stopError)
}
if callCount != 1 {
t.Fatalf("got called %v times, want %v", callCount, 1)
}
}

0 comments on commit 6b2c144

Please # to comment.