diff --git a/tree.go b/tree.go index b1aeaa7e..1410b965 100644 --- a/tree.go +++ b/tree.go @@ -8,6 +8,7 @@ extern int _go_git_treewalk(git_tree *tree, git_treewalk_mode mode, void *ptr); import "C" import ( + "errors" "runtime" "unsafe" ) @@ -134,7 +135,9 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer } err := data.callback(C.GoString(_root), newTreeEntry(entry)) - if err != nil { + if err == TreeWalkSkip { + return C.int(1) + } else if err != nil { *data.errorTarget = err return C.int(ErrorCodeUser) } @@ -142,6 +145,19 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer return C.int(ErrorCodeOK) } +// TreeWalkSkip is an error that can be returned form TreeWalkCallback to skip +// a subtree from being expanded. +var TreeWalkSkip = errors.New("skip") + +// Walk traverses the entries in a tree and its subtrees in pre order. +// +// The entries will be traversed in the pre order, children subtrees will be +// automatically loaded as required, and the callback will be called once per +// entry with the current (relative) root for the entry and the entry data +// itself. +// +// If the callback returns TreeWalkSkip, the passed entry will be skipped on +// the traversal. Any other non-nil error stops the walk. func (t *Tree) Walk(callback TreeWalkCallback) error { var err error data := treeWalkCallbackData{ diff --git a/tree_test.go b/tree_test.go index f5b68221..0c0c0048 100644 --- a/tree_test.go +++ b/tree_test.go @@ -1,6 +1,9 @@ package git -import "testing" +import ( + "errors" + "testing" +) func TestTreeEntryById(t *testing.T) { t.Parallel() @@ -63,3 +66,72 @@ func TestTreeBuilderInsert(t *testing.T) { t.Fatalf("got oid %v, want %v", entry.Id, blobId) } } + +func TestTreeWalk(t *testing.T) { + t.Parallel() + repo, err := OpenRepository("testdata/TestGitRepository.git") + checkFatal(t, err) + treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125") + checkFatal(t, err) + + tree, err := repo.LookupTree(treeID) + checkFatal(t, err) + + var callCount int + err = tree.Walk(func(name string, entry *TreeEntry) error { + callCount++ + + return nil + }) + checkFatal(t, err) + if callCount != 11 { + t.Fatalf("got called %v times, want %v", callCount, 11) + } +} + +func TestTreeWalkSkip(t *testing.T) { + t.Parallel() + repo, err := OpenRepository("testdata/TestGitRepository.git") + checkFatal(t, err) + treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125") + checkFatal(t, err) + + tree, err := repo.LookupTree(treeID) + checkFatal(t, err) + + var callCount int + err = tree.Walk(func(name string, entry *TreeEntry) error { + callCount++ + + return TreeWalkSkip + }) + checkFatal(t, err) + if callCount != 4 { + t.Fatalf("got called %v times, want %v", callCount, 4) + } +} + +func TestTreeWalkStop(t *testing.T) { + t.Parallel() + repo, err := OpenRepository("testdata/TestGitRepository.git") + checkFatal(t, err) + treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125") + checkFatal(t, err) + + tree, err := repo.LookupTree(treeID) + checkFatal(t, err) + + var callCount int + stopError := errors.New("stop") + err = tree.Walk(func(name string, entry *TreeEntry) error { + callCount++ + + return stopError + }) + if err != stopError { + t.Fatalf("got error %v, want %v", err, stopError) + } + if callCount != 1 { + t.Fatalf("got called %v times, want %v", callCount, 1) + } +}