diff --git a/common/callback/callback.go b/common/callback/callback.go new file mode 100644 index 00000000..dd167594 --- /dev/null +++ b/common/callback/callback.go @@ -0,0 +1,38 @@ +// Copyright 2025 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package callback + +// Callback defines an interface for handling the completion of an asynchronous operation. +// It allows the caller to notify the system when an operation has completed successfully +// or when it has failed with an error. The interface provides two methods for handling +// both success and error cases. +// +// The generic type T represents the result type of the operation, which can vary depending +// on the specific use case. +// +// Methods: +// +// - Complete(t T): This method is called when the asynchronous operation completes successfully. +// It accepts a result of type T, which is the outcome of the operation. +// +// - CompleteError(err error): This method is called when the asynchronous operation fails. +// It accepts an error, which indicates the reason for the failure. +type Callback[T any] interface { + // Complete is invoked when the operation completes successfully with the result 't' of type T. + Complete(t T) + + // CompleteError is invoked when the operation fails, providing an error 'err' indicating the failure reason. + CompleteError(err error) +} diff --git a/common/callback/once.go b/common/callback/once.go new file mode 100644 index 00000000..ade3b92b --- /dev/null +++ b/common/callback/once.go @@ -0,0 +1,63 @@ +// Copyright 2025 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package callback + +import "sync/atomic" + +// Once ensures that a specific callback is executed only once, either on completion or on error. +// It prevents further execution of the callbacks after the first call and ensures atomicity +// of the operation using the atomic package. +// +// The generic type T represents the result type of the operation, and the callbacks +// provide the behavior for handling success or failure of the operation. +// +// Fields: +// - OnComplete: A function that gets called with the result of type T when the operation completes successfully. +// - OnCompleteError: A function that gets called with an error if the operation fails. +// - completed: An atomic boolean used to track if the operation has already completed, ensuring only one callback is executed. + +type Once[T any] struct { + OnComplete func(t T) // Callback function called on successful completion + OnCompleteError func(err error) // Callback function called when an error occurs + completed atomic.Bool // Atomic flag to track completion status +} + +// Complete is called to notify that the operation has completed successfully with the result 't'. +// It ensures that the 'OnComplete' callback is only called once. +func (c *Once[T]) Complete(t T) { + if !c.completed.CompareAndSwap(false, true) { + return + } + c.OnComplete(t) +} + +// CompleteError is called to notify that the operation has failed with an error 'err'. +// It ensures that the 'OnCompleteError' callback is only called once. +func (c *Once[T]) CompleteError(err error) { + if !c.completed.CompareAndSwap(false, true) { + return + } + c.OnCompleteError(err) +} + +// NewOnce creates a new instance of Once with the provided success and error callbacks. +// It ensures that the callbacks are invoked only once, either for success or failure. +func NewOnce[T any](onComplete func(t T), onError func(err error)) Callback[T] { + return &Once[T]{ + onComplete, + onError, + atomic.Bool{}, + } +} diff --git a/common/callback/once_test.go b/common/callback/once_test.go new file mode 100644 index 00000000..91553f87 --- /dev/null +++ b/common/callback/once_test.go @@ -0,0 +1,100 @@ +// Copyright 2025 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2025 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package callback + +import ( + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "sync" + "sync/atomic" + "testing" +) + +func Test_Once_Complete_Concurrent(t *testing.T) { + callbackCounter := atomic.Int32{} + onceCallback := NewOnce[any]( + func(t any) { + callbackCounter.Add(1) + }, + func(err error) { + callbackCounter.Add(1) + }) + + group := sync.WaitGroup{} + for i := 0; i < 5; i++ { + group.Add(1) + go func() { + if i%2 == 0 { + onceCallback.Complete(nil) + } else { + onceCallback.CompleteError(errors.New("error")) + } + group.Done() + }() + } + + group.Wait() + assert.Equal(t, int32(1), callbackCounter.Load()) +} + +func Test_Once_Complete(t *testing.T) { + var callbackError error + var callbackValue int32 + + onceCallback := NewOnce[int32]( + func(t int32) { + callbackValue = t + }, + func(err error) { + callbackError = err + }) + + onceCallback.Complete(1) + + assert.Nil(t, callbackError) + assert.Equal(t, int32(1), callbackValue) +} + +func Test_Once_Complete_Error(t *testing.T) { + var callbackError error + var callbackValue *int32 + + onceCallback := NewOnce[int32]( + func(t int32) { + callbackValue = &t + }, + func(err error) { + callbackError = err + }) + + e1 := errors.New("error") + onceCallback.CompleteError(e1) + + assert.Equal(t, e1, callbackError) + assert.Nil(t, callbackValue) +} diff --git a/common/error_codes.go b/common/error_codes.go index a8af6d20..a7423258 100644 --- a/common/error_codes.go +++ b/common/error_codes.go @@ -40,7 +40,7 @@ var ( ErrorInvalidTerm = status.Error(CodeInvalidTerm, "oxia: invalid term") ErrorInvalidStatus = status.Error(CodeInvalidStatus, "oxia: invalid status") ErrorLeaderAlreadyConnected = status.Error(CodeLeaderAlreadyConnected, "oxia: leader is already connected") - ErrorAlreadyClosed = status.Error(CodeAlreadyClosed, "oxia: node is shutting down") + ErrorAlreadyClosed = status.Error(CodeAlreadyClosed, "oxia: resource is already closed") ErrorNodeIsNotLeader = status.Error(CodeNodeIsNotLeader, "oxia: node is not leader for shard") ErrorNodeIsNotFollower = status.Error(CodeNodeIsNotFollower, "oxia: node is not follower for shard") ErrorInvalidSession = status.Error(CodeInvalidSession, "oxia: session not found") diff --git a/server/leader_controller.go b/server/leader_controller.go index 19bbc445..53b18053 100644 --- a/server/leader_controller.go +++ b/server/leader_controller.go @@ -17,6 +17,7 @@ package server import ( "context" "fmt" + "github.com/streamnative/oxia/common/callback" "io" "log/slog" "sync" @@ -349,7 +350,7 @@ func (lc *leaderController) BecomeLeader(ctx context.Context, req *proto.BecomeL // committed in the quorum, to avoid missing any entries in the DB // by the moment we make the leader controller accepting new write/read // requests - if _, err = lc.quorumAckTracker.WaitForCommitOffset(ctx, lc.leaderElectionHeadEntryId.Offset, nil); err != nil { + if err = lc.quorumAckTracker.WaitForCommitOffset(ctx, lc.leaderElectionHeadEntryId.Offset); err != nil { return nil, err } @@ -796,10 +797,11 @@ func (lc *leaderController) write(ctx context.Context, request func(int64) *prot return wal.InvalidOffset, nil, err } - resp, err := lc.quorumAckTracker.WaitForCommitOffset(ctx, newOffset, func() (*proto.WriteResponse, error) { - return lc.db.ProcessWrite(actualRequest, newOffset, timestamp, WrapperUpdateOperationCallback) - }) - return newOffset, resp, err + if err := lc.quorumAckTracker.WaitForCommitOffset(ctx, newOffset); err != nil { + return wal.InvalidOffset, nil, err + } + writeResponse, err := lc.db.ProcessWrite(actualRequest, newOffset, timestamp, WrapperUpdateOperationCallback) + return newOffset, writeResponse, err } func (lc *leaderController) appendToWal(ctx context.Context, request func(int64) *proto.WriteRequest) (actualRequest *proto.WriteRequest, offset int64, timestamp uint64, err error) { @@ -914,31 +916,33 @@ func (lc *leaderController) handleWalSynced(stream proto.OxiaClient_WriteStreamS return } - lc.quorumAckTracker.WaitForCommitOffsetAsync(offset, func() (*proto.WriteResponse, error) { - return lc.db.ProcessWrite(req, offset, timestamp, WrapperUpdateOperationCallback) - }, func(response *proto.WriteResponse, err error) { - if err != nil { - timer.Done() - sendNonBlocking(closeCh, err) - return - } - - if err = stream.Send(response); err != nil { - timer.Done() + lc.quorumAckTracker.WaitForCommitOffsetAsync(context.Background(), offset, callback.NewOnce[any]( + func(_ any) { + defer timer.Done() + localResponse, err := lc.db.ProcessWrite(req, offset, timestamp, WrapperUpdateOperationCallback) + if err != nil { + sendNonBlocking(closeCh, err) + return + } + if err = stream.Send(localResponse); err != nil { + sendNonBlocking(closeCh, err) + return + } + }, + func(err error) { + defer timer.Done() sendNonBlocking(closeCh, err) - return - } - timer.Done() - }) + }, + )) } func (lc *leaderController) appendToWalStreamRequest(request *proto.WriteRequest, - callback func(offset int64, timestamp uint64, err error)) { + cb func(offset int64, timestamp uint64, err error)) { lc.Lock() if err := checkStatusIsLeader(lc.status); err != nil { lc.Unlock() - callback(wal.InvalidOffset, 0, err) + cb(wal.InvalidOffset, 0, err) return } @@ -961,7 +965,7 @@ func (lc *leaderController) appendToWalStreamRequest(request *proto.WriteRequest value, err := logEntryValue.MarshalVT() if err != nil { lc.Unlock() - callback(wal.InvalidOffset, timestamp, err) + cb(wal.InvalidOffset, timestamp, err) return } logEntry := &proto.LogEntry{ @@ -973,10 +977,10 @@ func (lc *leaderController) appendToWalStreamRequest(request *proto.WriteRequest lc.wal.AppendAndSync(logEntry, func(err error) { if err != nil { - callback(wal.InvalidOffset, timestamp, errors.Wrap(err, "oxia: failed to append to wal")) + cb(wal.InvalidOffset, timestamp, errors.Wrap(err, "oxia: failed to append to wal")) } else { lc.quorumAckTracker.AdvanceHeadOffset(newOffset) - callback(newOffset, timestamp, nil) + cb(newOffset, timestamp, nil) } }) lc.Unlock() diff --git a/server/quorum_ack_tracker.go b/server/quorum_ack_tracker.go index 59306bf0..f99ca87f 100644 --- a/server/quorum_ack_tracker.go +++ b/server/quorum_ack_tracker.go @@ -17,12 +17,12 @@ package server import ( "context" "errors" + "github.com/streamnative/oxia/common/callback" "io" "sync" "sync/atomic" "github.com/streamnative/oxia/common" - "github.com/streamnative/oxia/proto" "github.com/streamnative/oxia/server/util" ) @@ -44,13 +44,40 @@ type QuorumAckTracker interface { CommitOffset() int64 // WaitForCommitOffset - // Waits for the specific entry id to be fully committed. - // After that, invokes the function f - WaitForCommitOffset(ctx context.Context, offset int64, f func() (*proto.WriteResponse, error)) (*proto.WriteResponse, error) + // Waits for the specific entry, identified by its offset, to be fully committed. + // Once the commit is confirmed, the function will return without error. + // + // Parameters: + // - ctx: The context used for managing cancellation and deadlines for the operation. + // - offset: The unique identifier (offset) of the entry to wait for. + // + // Returns: + // - error: Returns an error if the operation is unsuccessful, otherwise nil. + // + // Note: + // This method blocks until the commit is confirmed. + + WaitForCommitOffset(ctx context.Context, offset int64) error + + // WaitForCommitOffsetAsync + // Asynchronously waits for the specific entry, identified by its offset, to be fully committed. + // Once the commit is confirmed, the provided callback function (cb) is invoked. + // + // Parameters: + // - ctx: The context used for managing cancellation and deadlines for the operation. + // - offset: The unique identifier (offset) of the entry to wait for. + // - cb: The callback function to invoke after the commit is confirmed. The callback + // will receive the result or error from the operation. + // + // Returns: + // - This method does not return anything immediately. The callback will handle + // the result or error asynchronously. + // + // Note: + // This method returns immediately and does not block the caller, allowing other + // operations to continue while waiting for the commit. + WaitForCommitOffsetAsync(ctx context.Context, offset int64, cb callback.Callback[any]) // NextOffset returns the offset for the next entry to write - WaitForCommitOffsetAsync(offset int64, f func() (*proto.WriteResponse, error), callback func(*proto.WriteResponse, error)) - - // NextOffset returns the offset for the next entry to write // Note this can go ahead of the head-offset as there can be multiple operations in flight. NextOffset() int64 @@ -97,7 +124,7 @@ type cursorAcker struct { type waitingRequest struct { minOffset int64 - callback func() + callback callback.Callback[any] } func NewQuorumAckTracker(replicationFactor uint32, headOffset int64, commitOffset int64) QuorumAckTracker { @@ -166,59 +193,37 @@ func (q *quorumAckTracker) WaitForHeadOffset(ctx context.Context, offset int64) return nil } -func (q *quorumAckTracker) WaitForCommitOffset(ctx context.Context, offset int64, f func() (*proto.WriteResponse, error)) (*proto.WriteResponse, error) { - ch := make(chan struct { - *proto.WriteResponse - error - }, 1) - q.WaitForCommitOffsetAsync(offset, f, func(response *proto.WriteResponse, err error) { - ch <- struct { - *proto.WriteResponse - error - }{response, err} - }) +func (q *quorumAckTracker) WaitForCommitOffset(ctx context.Context, offset int64) error { + ch := make(chan error, 1) + q.WaitForCommitOffsetAsync(ctx, offset, callback.NewOnce( + func(_ any) { ch <- nil }, + func(err error) { ch <- err }, + )) select { - case s := <-ch: - return s.WriteResponse, s.error + case err := <-ch: + return err case <-ctx.Done(): - return nil, ctx.Err() + return ctx.Err() } } -func (q *quorumAckTracker) WaitForCommitOffsetAsync(offset int64, f func() (*proto.WriteResponse, error), - callback func(*proto.WriteResponse, error)) { +func (q *quorumAckTracker) WaitForCommitOffsetAsync(_ context.Context, offset int64, cb callback.Callback[any]) { q.Lock() if q.closed { q.Unlock() - callback(nil, common.ErrorAlreadyClosed) + cb.CompleteError(common.ErrorAlreadyClosed) return } if q.requiredAcks == 0 || q.commitOffset.Load() >= offset { q.Unlock() - - var res *proto.WriteResponse - var err error - if f != nil { - res, err = f() - } - - callback(res, err) + cb.Complete(nil) return } - q.waitingRequests = append(q.waitingRequests, waitingRequest{offset, func() { - var res *proto.WriteResponse - var err error - if f != nil { - res, err = f() - } - - callback(res, err) - }}) - + q.waitingRequests = append(q.waitingRequests, waitingRequest{offset, cb}) q.Unlock() } @@ -231,16 +236,19 @@ func (q *quorumAckTracker) notifyCommitOffsetAdvanced(commitOffset int64) { } q.waitingRequests = q.waitingRequests[1:] - r.callback() + r.callback.Complete(nil) } } func (q *quorumAckTracker) Close() error { q.Lock() - defer q.Unlock() - q.closed = true q.waitForHeadOffset.Broadcast() + q.Unlock() + // unblock waiting request + for _, r := range q.waitingRequests { + r.callback.CompleteError(common.ErrorAlreadyClosed) + } return nil } diff --git a/server/quorum_ack_tracker_test.go b/server/quorum_ack_tracker_test.go index bfaa63d6..388a792a 100644 --- a/server/quorum_ack_tracker_test.go +++ b/server/quorum_ack_tracker_test.go @@ -16,12 +16,12 @@ package server import ( "context" + "github.com/streamnative/oxia/common" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/streamnative/oxia/proto" "github.com/streamnative/oxia/server/wal" ) @@ -188,10 +188,7 @@ func TestQuorumAckTracker_WaitForCommitOffset(t *testing.T) { ch := make(chan error) go func() { - _, err := at.WaitForCommitOffset(context.Background(), 2, func() (*proto.WriteResponse, error) { - return nil, nil //nolint:nilnil - }) - ch <- err + ch <- at.WaitForCommitOffset(context.Background(), 2) }() time.Sleep(100 * time.Millisecond) @@ -270,3 +267,24 @@ func TestQuorumAckTracker_AddingCursors_RF5(t *testing.T) { assert.EqualValues(t, 10, at.HeadOffset()) assert.EqualValues(t, 7, at.CommitOffset()) } + +func TestQuorumAckTracker_ClearPending(t *testing.T) { + at := NewQuorumAckTracker(5, 10, 5) + asyncRes := make(chan error, 1) + go func() { + asyncRes <- at.WaitForCommitOffset(context.Background(), 6) + }() + + time.Sleep(100 * time.Millisecond) + err := at.Close() + assert.NoError(t, err) + + // Wait for the result from asyncRes + select { + case resErr := <-asyncRes: + // Ensure that we received the expected result (in this case, error should be nil) + assert.ErrorIs(t, resErr, common.ErrorAlreadyClosed) + case <-time.After(2 * time.Second): // Adding a timeout for safety + t.Fatal("Timed out waiting for async result") + } +}