Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Allow skipping an entry expansion in tree.Walk() #838

Merged
merged 1 commit into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ extern int _go_git_treewalk(git_tree *tree, git_treewalk_mode mode, void *ptr);
import "C"

import (
"errors"
"runtime"
"unsafe"
)
Expand Down Expand Up @@ -134,14 +135,29 @@ func treeWalkCallback(_root *C.char, entry *C.git_tree_entry, ptr unsafe.Pointer
}

err := data.callback(C.GoString(_root), newTreeEntry(entry))
if err != nil {
if err == TreeWalkSkip {
return C.int(1)
} else if err != nil {
*data.errorTarget = err
return C.int(ErrorCodeUser)
}

return C.int(ErrorCodeOK)
}

// TreeWalkSkip is an error that can be returned form TreeWalkCallback to skip
// a subtree from being expanded.
var TreeWalkSkip = errors.New("skip")

// Walk traverses the entries in a tree and its subtrees in pre order.
//
// The entries will be traversed in the pre order, children subtrees will be
// automatically loaded as required, and the callback will be called once per
// entry with the current (relative) root for the entry and the entry data
// itself.
//
// If the callback returns TreeWalkSkip, the passed entry will be skipped on
// the traversal. Any other non-nil error stops the walk.
func (t *Tree) Walk(callback TreeWalkCallback) error {
var err error
data := treeWalkCallbackData{
Expand Down
74 changes: 73 additions & 1 deletion tree_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package git

import "testing"
import (
"errors"
"testing"
)

func TestTreeEntryById(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -63,3 +66,72 @@ func TestTreeBuilderInsert(t *testing.T) {
t.Fatalf("got oid %v, want %v", entry.Id, blobId)
}
}

func TestTreeWalk(t *testing.T) {
t.Parallel()
repo, err := OpenRepository("testdata/TestGitRepository.git")
checkFatal(t, err)
treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
checkFatal(t, err)

tree, err := repo.LookupTree(treeID)
checkFatal(t, err)

var callCount int
err = tree.Walk(func(name string, entry *TreeEntry) error {
callCount++

return nil
})
checkFatal(t, err)
if callCount != 11 {
t.Fatalf("got called %v times, want %v", callCount, 11)
}
}

func TestTreeWalkSkip(t *testing.T) {
t.Parallel()
repo, err := OpenRepository("testdata/TestGitRepository.git")
checkFatal(t, err)
treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
checkFatal(t, err)

tree, err := repo.LookupTree(treeID)
checkFatal(t, err)

var callCount int
err = tree.Walk(func(name string, entry *TreeEntry) error {
callCount++

return TreeWalkSkip
})
checkFatal(t, err)
if callCount != 4 {
t.Fatalf("got called %v times, want %v", callCount, 4)
}
}

func TestTreeWalkStop(t *testing.T) {
t.Parallel()
repo, err := OpenRepository("testdata/TestGitRepository.git")
checkFatal(t, err)
treeID, err := NewOid("6020a3b8d5d636e549ccbd0c53e2764684bb3125")
checkFatal(t, err)

tree, err := repo.LookupTree(treeID)
checkFatal(t, err)

var callCount int
stopError := errors.New("stop")
err = tree.Walk(func(name string, entry *TreeEntry) error {
callCount++

return stopError
})
if err != stopError {
t.Fatalf("got error %v, want %v", err, stopError)
}
if callCount != 1 {
t.Fatalf("got called %v times, want %v", callCount, 1)
}
}