From 047b5ae1ee93775c15ab545c9612d057b99debe0 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Sat, 7 Sep 2024 17:54:41 -0400 Subject: [PATCH] Add AccruedFees to state.Chain interface --- .../block/executor/proposal_block_test.go | 2 ++ .../block/executor/standard_block_test.go | 2 ++ .../block/executor/verifier_test.go | 5 +++ vms/platformvm/state/diff.go | 15 ++++++-- vms/platformvm/state/diff_test.go | 25 +++++++++++++ vms/platformvm/state/mock_chain.go | 26 ++++++++++++++ vms/platformvm/state/mock_diff.go | 26 ++++++++++++++ vms/platformvm/state/mock_state.go | 26 ++++++++++++++ vms/platformvm/state/state.go | 35 +++++++++++++++++++ vms/platformvm/state/state_test.go | 16 +++++++++ 10 files changed, 176 insertions(+), 2 deletions(-) diff --git a/vms/platformvm/block/executor/proposal_block_test.go b/vms/platformvm/block/executor/proposal_block_test.go index a1c5151044d5..66a6c7604e23 100644 --- a/vms/platformvm/block/executor/proposal_block_test.go +++ b/vms/platformvm/block/executor/proposal_block_test.go @@ -90,6 +90,7 @@ func TestApricotProposalBlockTimeVerification(t *testing.T) { // setup state to validate proposal block transaction onParentAccept.EXPECT().GetTimestamp().Return(chainTime).AnyTimes() onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() + onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() currentStakersIt := iteratormock.NewIterator[*state.Staker](ctrl) currentStakersIt.EXPECT().Next().Return(true) @@ -161,6 +162,7 @@ func TestBanffProposalBlockTimeVerification(t *testing.T) { onParentAccept := state.NewMockDiff(ctrl) onParentAccept.EXPECT().GetTimestamp().Return(parentTime).AnyTimes() onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() + onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() onParentAccept.EXPECT().GetCurrentSupply(constants.PrimaryNetworkID).Return(uint64(1000), nil).AnyTimes() env.blkManager.(*manager).blkIDToState[parentID] = &blockState{ diff --git a/vms/platformvm/block/executor/standard_block_test.go b/vms/platformvm/block/executor/standard_block_test.go index 7c31ccb7f253..fa64eee74697 100644 --- a/vms/platformvm/block/executor/standard_block_test.go +++ b/vms/platformvm/block/executor/standard_block_test.go @@ -59,6 +59,7 @@ func TestApricotStandardBlockTimeVerification(t *testing.T) { chainTime := env.clk.Time().Truncate(time.Second) onParentAccept.EXPECT().GetTimestamp().Return(chainTime).AnyTimes() onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() + onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() // wrong height apricotChildBlk, err := block.NewApricotStandardBlock( @@ -136,6 +137,7 @@ func TestBanffStandardBlockTimeVerification(t *testing.T) { onParentAccept.EXPECT().GetTimestamp().Return(chainTime).AnyTimes() onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() + onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() txID := ids.GenerateTestID() utxo := &avax.UTXO{ diff --git a/vms/platformvm/block/executor/verifier_test.go b/vms/platformvm/block/executor/verifier_test.go index 6508d4c88624..5b786b0e33d8 100644 --- a/vms/platformvm/block/executor/verifier_test.go +++ b/vms/platformvm/block/executor/verifier_test.go @@ -103,6 +103,7 @@ func TestVerifierVisitProposalBlock(t *testing.T) { // One call for each of onCommitState and onAbortState. parentOnAcceptState.EXPECT().GetTimestamp().Return(timestamp).Times(2) parentOnAcceptState.EXPECT().GetFeeState().Return(gas.State{}).Times(2) + parentOnAcceptState.EXPECT().GetAccruedFees().Return(uint64(0)).Times(2) backend := &backend{ lastAccepted: parentID, @@ -334,6 +335,7 @@ func TestVerifierVisitStandardBlock(t *testing.T) { timestamp := time.Now() parentState.EXPECT().GetTimestamp().Return(timestamp).Times(1) parentState.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + parentState.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) parentStatelessBlk.EXPECT().Height().Return(uint64(1)).Times(1) mempool.EXPECT().Remove(apricotBlk.Txs()).Times(1) @@ -595,6 +597,7 @@ func TestBanffAbortBlockTimestampChecks(t *testing.T) { s.EXPECT().GetLastAccepted().Return(parentID).Times(3) s.EXPECT().GetTimestamp().Return(parentTime).Times(3) s.EXPECT().GetFeeState().Return(gas.State{}).Times(3) + s.EXPECT().GetAccruedFees().Return(uint64(0)).Times(3) onDecisionState, err := state.NewDiff(parentID, backend) require.NoError(err) @@ -692,6 +695,7 @@ func TestBanffCommitBlockTimestampChecks(t *testing.T) { s.EXPECT().GetLastAccepted().Return(parentID).Times(3) s.EXPECT().GetTimestamp().Return(parentTime).Times(3) s.EXPECT().GetFeeState().Return(gas.State{}).Times(3) + s.EXPECT().GetAccruedFees().Return(uint64(0)).Times(3) onDecisionState, err := state.NewDiff(parentID, backend) require.NoError(err) @@ -807,6 +811,7 @@ func TestVerifierVisitStandardBlockWithDuplicateInputs(t *testing.T) { parentStatelessBlk.EXPECT().Height().Return(uint64(1)).Times(1) parentState.EXPECT().GetTimestamp().Return(timestamp).Times(1) parentState.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + parentState.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) parentStatelessBlk.EXPECT().Parent().Return(grandParentID).Times(1) err = verifier.ApricotStandardBlock(blk) diff --git a/vms/platformvm/state/diff.go b/vms/platformvm/state/diff.go index 16f7edf4435b..aceecc47ad56 100644 --- a/vms/platformvm/state/diff.go +++ b/vms/platformvm/state/diff.go @@ -35,8 +35,9 @@ type diff struct { parentID ids.ID stateVersions Versions - timestamp time.Time - feeState gas.State + timestamp time.Time + feeState gas.State + accruedFees uint64 // Subnet ID --> supply of native asset of the subnet currentSupply map[ids.ID]uint64 @@ -77,6 +78,7 @@ func NewDiff( stateVersions: stateVersions, timestamp: parentState.GetTimestamp(), feeState: parentState.GetFeeState(), + accruedFees: parentState.GetAccruedFees(), subnetOwners: make(map[ids.ID]fx.Owner), subnetManagers: make(map[ids.ID]chainIDAndAddr), }, nil @@ -112,6 +114,14 @@ func (d *diff) SetFeeState(feeState gas.State) { d.feeState = feeState } +func (d *diff) GetAccruedFees() uint64 { + return d.accruedFees +} + +func (d *diff) SetAccruedFees(accruedFees uint64) { + d.accruedFees = accruedFees +} + func (d *diff) GetCurrentSupply(subnetID ids.ID) (uint64, error) { supply, ok := d.currentSupply[subnetID] if ok { @@ -437,6 +447,7 @@ func (d *diff) DeleteUTXO(utxoID ids.ID) { func (d *diff) Apply(baseState Chain) error { baseState.SetTimestamp(d.timestamp) baseState.SetFeeState(d.feeState) + baseState.SetAccruedFees(d.accruedFees) for subnetID, supply := range d.currentSupply { baseState.SetCurrentSupply(subnetID, supply) } diff --git a/vms/platformvm/state/diff_test.go b/vms/platformvm/state/diff_test.go index 56ff3bc8ac8b..a7eec42364b1 100644 --- a/vms/platformvm/state/diff_test.go +++ b/vms/platformvm/state/diff_test.go @@ -67,6 +67,24 @@ func TestDiffFeeState(t *testing.T) { assertChainsEqual(t, state, d) } +func TestDiffAccruedFees(t *testing.T) { + require := require.New(t) + + state := newTestState(t, memdb.New()) + + d, err := NewDiffOn(state) + require.NoError(err) + + initialAccruedFees := state.GetAccruedFees() + newAccruedFees := initialAccruedFees + 1 + d.SetAccruedFees(newAccruedFees) + require.Equal(newAccruedFees, d.GetAccruedFees()) + require.Equal(initialAccruedFees, state.GetAccruedFees()) + + require.NoError(d.Apply(state)) + assertChainsEqual(t, state, d) +} + func TestDiffCurrentSupply(t *testing.T) { require := require.New(t) @@ -101,6 +119,7 @@ func TestDiffCurrentValidator(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -135,6 +154,7 @@ func TestDiffPendingValidator(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -175,6 +195,7 @@ func TestDiffCurrentDelegator(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -221,6 +242,7 @@ func TestDiffPendingDelegator(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -361,6 +383,7 @@ func TestDiffTx(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -458,6 +481,7 @@ func TestDiffUTXO(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -518,6 +542,7 @@ func assertChainsEqual(t *testing.T, expected, actual Chain) { require.Equal(expected.GetTimestamp(), actual.GetTimestamp()) require.Equal(expected.GetFeeState(), actual.GetFeeState()) + require.Equal(expected.GetAccruedFees(), actual.GetAccruedFees()) expectedCurrentSupply, err := expected.GetCurrentSupply(constants.PrimaryNetworkID) require.NoError(err) diff --git a/vms/platformvm/state/mock_chain.go b/vms/platformvm/state/mock_chain.go index 727847a7c07f..1891b74099d8 100644 --- a/vms/platformvm/state/mock_chain.go +++ b/vms/platformvm/state/mock_chain.go @@ -178,6 +178,20 @@ func (mr *MockChainMockRecorder) DeleteUTXO(utxoID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUTXO", reflect.TypeOf((*MockChain)(nil).DeleteUTXO), utxoID) } +// GetAccruedFees mocks base method. +func (m *MockChain) GetAccruedFees() uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccruedFees") + ret0, _ := ret[0].(uint64) + return ret0 +} + +// GetAccruedFees indicates an expected call of GetAccruedFees. +func (mr *MockChainMockRecorder) GetAccruedFees() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccruedFees", reflect.TypeOf((*MockChain)(nil).GetAccruedFees)) +} + // GetCurrentDelegatorIterator mocks base method. func (m *MockChain) GetCurrentDelegatorIterator(subnetID ids.ID, nodeID ids.NodeID) (iterator.Iterator[*Staker], error) { m.ctrl.T.Helper() @@ -455,6 +469,18 @@ func (mr *MockChainMockRecorder) PutPendingValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutPendingValidator", reflect.TypeOf((*MockChain)(nil).PutPendingValidator), staker) } +// SetAccruedFees mocks base method. +func (m *MockChain) SetAccruedFees(f uint64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccruedFees", f) +} + +// SetAccruedFees indicates an expected call of SetAccruedFees. +func (mr *MockChainMockRecorder) SetAccruedFees(f any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccruedFees", reflect.TypeOf((*MockChain)(nil).SetAccruedFees), f) +} + // SetCurrentSupply mocks base method. func (m *MockChain) SetCurrentSupply(subnetID ids.ID, cs uint64) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/mock_diff.go b/vms/platformvm/state/mock_diff.go index ccf6619b5b24..557c59faf58c 100644 --- a/vms/platformvm/state/mock_diff.go +++ b/vms/platformvm/state/mock_diff.go @@ -192,6 +192,20 @@ func (mr *MockDiffMockRecorder) DeleteUTXO(utxoID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUTXO", reflect.TypeOf((*MockDiff)(nil).DeleteUTXO), utxoID) } +// GetAccruedFees mocks base method. +func (m *MockDiff) GetAccruedFees() uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccruedFees") + ret0, _ := ret[0].(uint64) + return ret0 +} + +// GetAccruedFees indicates an expected call of GetAccruedFees. +func (mr *MockDiffMockRecorder) GetAccruedFees() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccruedFees", reflect.TypeOf((*MockDiff)(nil).GetAccruedFees)) +} + // GetCurrentDelegatorIterator mocks base method. func (m *MockDiff) GetCurrentDelegatorIterator(subnetID ids.ID, nodeID ids.NodeID) (iterator.Iterator[*Staker], error) { m.ctrl.T.Helper() @@ -469,6 +483,18 @@ func (mr *MockDiffMockRecorder) PutPendingValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutPendingValidator", reflect.TypeOf((*MockDiff)(nil).PutPendingValidator), staker) } +// SetAccruedFees mocks base method. +func (m *MockDiff) SetAccruedFees(f uint64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccruedFees", f) +} + +// SetAccruedFees indicates an expected call of SetAccruedFees. +func (mr *MockDiffMockRecorder) SetAccruedFees(f any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccruedFees", reflect.TypeOf((*MockDiff)(nil).SetAccruedFees), f) +} + // SetCurrentSupply mocks base method. func (m *MockDiff) SetCurrentSupply(subnetID ids.ID, cs uint64) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index 527db5cf8a53..b6f0c7bd6b8a 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -293,6 +293,20 @@ func (mr *MockStateMockRecorder) DeleteUTXO(utxoID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUTXO", reflect.TypeOf((*MockState)(nil).DeleteUTXO), utxoID) } +// GetAccruedFees mocks base method. +func (m *MockState) GetAccruedFees() uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccruedFees") + ret0, _ := ret[0].(uint64) + return ret0 +} + +// GetAccruedFees indicates an expected call of GetAccruedFees. +func (mr *MockStateMockRecorder) GetAccruedFees() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccruedFees", reflect.TypeOf((*MockState)(nil).GetAccruedFees)) +} + // GetBlockIDAtHeight mocks base method. func (m *MockState) GetBlockIDAtHeight(height uint64) (ids.ID, error) { m.ctrl.T.Helper() @@ -704,6 +718,18 @@ func (mr *MockStateMockRecorder) ReindexBlocks(lock, log any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReindexBlocks", reflect.TypeOf((*MockState)(nil).ReindexBlocks), lock, log) } +// SetAccruedFees mocks base method. +func (m *MockState) SetAccruedFees(f uint64) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccruedFees", f) +} + +// SetAccruedFees indicates an expected call of SetAccruedFees. +func (mr *MockStateMockRecorder) SetAccruedFees(f any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccruedFees", reflect.TypeOf((*MockState)(nil).SetAccruedFees), f) +} + // SetCurrentSupply mocks base method. func (m *MockState) SetCurrentSupply(subnetID ids.ID, cs uint64) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index b5288df75140..849a1bb1f110 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -86,6 +86,7 @@ var ( TimestampKey = []byte("timestamp") FeeStateKey = []byte("fee state") + AccruedFeesKey = []byte("accrued fees") CurrentSupplyKey = []byte("current supply") LastAcceptedKey = []byte("last accepted") HeightsIndexedKey = []byte("heights indexed") @@ -107,6 +108,9 @@ type Chain interface { GetFeeState() gas.State SetFeeState(f gas.State) + GetAccruedFees() uint64 + SetAccruedFees(f uint64) + GetCurrentSupply(subnetID ids.ID) (uint64, error) SetCurrentSupply(subnetID ids.ID, cs uint64) @@ -279,6 +283,7 @@ type stateBlk struct { * |-- blocksReindexedKey -> nil * |-- timestampKey -> timestamp * |-- feeStateKey -> feeState + * |-- accruedFeesKey -> accruedFees * |-- currentSupplyKey -> currentSupply * |-- lastAcceptedKey -> lastAccepted * '-- heightsIndexKey -> startIndexHeight + endIndexHeight @@ -371,6 +376,7 @@ type state struct { // The persisted fields represent the current database value timestamp, persistedTimestamp time.Time feeState, persistedFeeState gas.State + accruedFees, persistedAccruedFees uint64 currentSupply, persistedCurrentSupply uint64 // [lastAccepted] is the most recently accepted block. lastAccepted, persistedLastAccepted ids.ID @@ -1052,6 +1058,14 @@ func (s *state) SetFeeState(feeState gas.State) { s.feeState = feeState } +func (s *state) GetAccruedFees() uint64 { + return s.accruedFees +} + +func (s *state) SetAccruedFees(accruedFees uint64) { + s.accruedFees = accruedFees +} + func (s *state) GetLastAccepted() ids.ID { return s.lastAccepted } @@ -1343,6 +1357,13 @@ func (s *state) loadMetadata() error { s.persistedFeeState = feeState s.SetFeeState(feeState) + accruedFees, err := getAccruedFees(s.singletonDB) + if err != nil { + return err + } + s.persistedAccruedFees = accruedFees + s.SetAccruedFees(accruedFees) + currentSupply, err := database.GetUInt64(s.singletonDB, CurrentSupplyKey) if err != nil { return err @@ -2344,6 +2365,12 @@ func (s *state) writeMetadata() error { } s.persistedFeeState = s.feeState } + if s.accruedFees != s.persistedAccruedFees { + if err := database.PutUInt64(s.singletonDB, AccruedFeesKey, s.accruedFees); err != nil { + return fmt.Errorf("failed to write accrued fees: %w", err) + } + s.persistedAccruedFees = s.accruedFees + } if s.persistedCurrentSupply != s.currentSupply { if err := database.PutUInt64(s.singletonDB, CurrentSupplyKey, s.currentSupply); err != nil { return fmt.Errorf("failed to write current supply: %w", err) @@ -2555,3 +2582,11 @@ func getFeeState(db database.KeyValueReader) (gas.State, error) { } return feeState, nil } + +func getAccruedFees(db database.KeyValueReader) (uint64, error) { + accruedFees, err := database.GetUInt64(db, AccruedFeesKey) + if err == database.ErrNotFound { + return 0, nil + } + return accruedFees, err +} diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 08b53eeb4fe8..3be526333e3e 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -1468,6 +1468,22 @@ func TestStateFeeStateCommitAndLoad(t *testing.T) { require.Equal(expectedFeeState, s.GetFeeState()) } +// Verify that committing the state writes the accrued fees to the database and +// that loading the state fetches the accrued fees from the database. +func TestStateAccruedFeesCommitAndLoad(t *testing.T) { + require := require.New(t) + + db := memdb.New() + s := newTestState(t, db) + + expectedAccruedFees := uint64(1) + s.SetAccruedFees(expectedAccruedFees) + require.NoError(s.Commit()) + + s = newTestState(t, db) + require.Equal(expectedAccruedFees, s.GetAccruedFees()) +} + func TestMarkAndIsInitialized(t *testing.T) { require := require.New(t)