diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index cb05f54fc6f7..5eab1b07ee99 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -149,17 +149,17 @@ func (mr *MockStateMockRecorder) AddUTXO(utxo any) *gomock.Call { } // ApplyValidatorPublicKeyDiffs mocks base method. -func (m *MockState) ApplyValidatorPublicKeyDiffs(ctx context.Context, validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight, endHeight uint64) error { +func (m *MockState) ApplyValidatorPublicKeyDiffs(ctx context.Context, validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight, endHeight uint64, subnetID ids.ID) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ApplyValidatorPublicKeyDiffs", ctx, validators, startHeight, endHeight) + ret := m.ctrl.Call(m, "ApplyValidatorPublicKeyDiffs", ctx, validators, startHeight, endHeight, subnetID) ret0, _ := ret[0].(error) return ret0 } // ApplyValidatorPublicKeyDiffs indicates an expected call of ApplyValidatorPublicKeyDiffs. -func (mr *MockStateMockRecorder) ApplyValidatorPublicKeyDiffs(ctx, validators, startHeight, endHeight any) *gomock.Call { +func (mr *MockStateMockRecorder) ApplyValidatorPublicKeyDiffs(ctx, validators, startHeight, endHeight, subnetID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyValidatorPublicKeyDiffs", reflect.TypeOf((*MockState)(nil).ApplyValidatorPublicKeyDiffs), ctx, validators, startHeight, endHeight) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyValidatorPublicKeyDiffs", reflect.TypeOf((*MockState)(nil).ApplyValidatorPublicKeyDiffs), ctx, validators, startHeight, endHeight, subnetID) } // ApplyValidatorWeightDiffs mocks base method. diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 53b109c23249..c70547265cf1 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -59,8 +59,9 @@ const ( var ( _ State = (*state)(nil) - errValidatorSetAlreadyPopulated = errors.New("validator set already populated") - errIsNotSubnet = errors.New("is not a subnet") + errValidatorSetAlreadyPopulated = errors.New("validator set already populated") + errIsNotSubnet = errors.New("is not a subnet") + errMissingPrimaryNetworkValidator = errors.New("missing primary network validator") BlockIDPrefix = []byte("blockID") BlockPrefix = []byte("block") @@ -193,6 +194,7 @@ type State interface { validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight uint64, endHeight uint64, + subnetID ids.ID, ) error SetHeight(height uint64) @@ -1257,10 +1259,11 @@ func (s *state) ApplyValidatorPublicKeyDiffs( validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight uint64, endHeight uint64, + subnetID ids.ID, ) error { diffIter := s.validatorPublicKeyDiffsDB.NewIteratorWithStartAndPrefix( - marshalStartDiffKey(constants.PrimaryNetworkID, startHeight), - constants.PrimaryNetworkID[:], + marshalStartDiffKey(subnetID, startHeight), + subnetID[:], ) defer diffIter.Release() @@ -1736,19 +1739,30 @@ func (s *state) loadPendingValidators() error { // Invariant: initValidatorSets requires loadCurrentValidators to have already // been called. func (s *state) initValidatorSets() error { - for subnetID, validators := range s.currentStakers.validators { + primaryNetworkValidators := s.currentStakers.validators[constants.PrimaryNetworkID] + for subnetID, subnetValidators := range s.currentStakers.validators { if s.validators.Count(subnetID) != 0 { // Enforce the invariant that the validator set is empty here. return fmt.Errorf("%w: %s", errValidatorSetAlreadyPopulated, subnetID) } - for nodeID, validator := range validators { - validatorStaker := validator.validator - if err := s.validators.AddStaker(subnetID, nodeID, validatorStaker.PublicKey, validatorStaker.TxID, validatorStaker.Weight); err != nil { + for nodeID, subnetValidator := range subnetValidators { + // The subnet validator's Public Key is inherited from the + // corresponding primary network validator. + primaryValidator, ok := primaryNetworkValidators[nodeID] + if !ok { + return fmt.Errorf("%w: %s", errMissingPrimaryNetworkValidator, nodeID) + } + + var ( + primaryStaker = primaryValidator.validator + subnetStaker = subnetValidator.validator + ) + if err := s.validators.AddStaker(subnetID, nodeID, primaryStaker.PublicKey, subnetStaker.TxID, subnetStaker.Weight); err != nil { return err } - delegatorIterator := iterator.FromTree(validator.delegators) + delegatorIterator := iterator.FromTree(subnetValidator.delegators) for delegatorIterator.Next() { delegatorStaker := delegatorIterator.Value() if err := s.validators.AddWeight(subnetID, nodeID, delegatorStaker.Weight); err != nil { @@ -2028,164 +2042,219 @@ func (s *state) writeExpiry() error { func (s *state) writeCurrentStakers(updateValidators bool, height uint64, codecVersion uint16) error { for subnetID, validatorDiffs := range s.currentStakers.validatorDiffs { + // We must write the primary network stakers last because writing subnet + // validator diffs may depend on the primary network validator diffs to + // inherit the public keys. + if subnetID == constants.PrimaryNetworkID { + continue + } + delete(s.currentStakers.validatorDiffs, subnetID) - // Select db to write to - validatorDB := s.currentSubnetValidatorList - delegatorDB := s.currentSubnetDelegatorList - if subnetID == constants.PrimaryNetworkID { - validatorDB = s.currentValidatorList - delegatorDB = s.currentDelegatorList + err := s.writeCurrentStakersSubnetDiff( + subnetID, + validatorDiffs, + updateValidators, + height, + codecVersion, + ) + if err != nil { + return err } + } - // Record the change in weight and/or public key for each validator. - for nodeID, validatorDiff := range validatorDiffs { - // Copy [nodeID] so it doesn't get overwritten next iteration. - nodeID := nodeID + if validatorDiffs, ok := s.currentStakers.validatorDiffs[constants.PrimaryNetworkID]; ok { + delete(s.currentStakers.validatorDiffs, constants.PrimaryNetworkID) - weightDiff := &ValidatorWeightDiff{ - Decrease: validatorDiff.validatorStatus == deleted, - } - switch validatorDiff.validatorStatus { - case added: - staker := validatorDiff.validator - weightDiff.Amount = staker.Weight - - // Invariant: Only the Primary Network contains non-nil public - // keys. - if staker.PublicKey != nil { - // Record that the public key for the validator is being - // added. This means the prior value for the public key was - // nil. - err := s.validatorPublicKeyDiffsDB.Put( - marshalDiffKey(constants.PrimaryNetworkID, height, nodeID), - nil, - ) - if err != nil { - return err - } - } + err := s.writeCurrentStakersSubnetDiff( + constants.PrimaryNetworkID, + validatorDiffs, + updateValidators, + height, + codecVersion, + ) + if err != nil { + return err + } + } - // The validator is being added. - // - // Invariant: It's impossible for a delegator to have been - // rewarded in the same block that the validator was added. - startTime := uint64(staker.StartTime.Unix()) - metadata := &validatorMetadata{ - txID: staker.TxID, - lastUpdated: staker.StartTime, - - UpDuration: 0, - LastUpdated: startTime, - StakerStartTime: startTime, - PotentialReward: staker.PotentialReward, - PotentialDelegateeReward: 0, - } + // TODO: Move validator set management out of the state package + if !updateValidators { + return nil + } - metadataBytes, err := MetadataCodec.Marshal(codecVersion, metadata) - if err != nil { - return fmt.Errorf("failed to serialize current validator: %w", err) - } + // Update the stake metrics + totalWeight, err := s.validators.TotalWeight(constants.PrimaryNetworkID) + if err != nil { + return fmt.Errorf("failed to get total weight of primary network: %w", err) + } - if err = validatorDB.Put(staker.TxID[:], metadataBytes); err != nil { - return fmt.Errorf("failed to write current validator to list: %w", err) - } + s.metrics.SetLocalStake(s.validators.GetWeight(constants.PrimaryNetworkID, s.ctx.NodeID)) + s.metrics.SetTotalStake(totalWeight) + return nil +} - s.validatorState.LoadValidatorMetadata(nodeID, subnetID, metadata) - case deleted: - staker := validatorDiff.validator - weightDiff.Amount = staker.Weight - - // Invariant: Only the Primary Network contains non-nil public - // keys. - if staker.PublicKey != nil { - // Record that the public key for the validator is being - // removed. This means we must record the prior value of the - // public key. - // - // Note: We store the uncompressed public key here as it is - // significantly more efficient to parse when applying - // diffs. - err := s.validatorPublicKeyDiffsDB.Put( - marshalDiffKey(constants.PrimaryNetworkID, height, nodeID), - bls.PublicKeyToUncompressedBytes(staker.PublicKey), - ) - if err != nil { - return err - } - } +func (s *state) writeCurrentStakersSubnetDiff( + subnetID ids.ID, + validatorDiffs map[ids.NodeID]*diffValidator, + updateValidators bool, + height uint64, + codecVersion uint16, +) error { + // Select db to write to + validatorDB := s.currentSubnetValidatorList + delegatorDB := s.currentSubnetDelegatorList + if subnetID == constants.PrimaryNetworkID { + validatorDB = s.currentValidatorList + delegatorDB = s.currentDelegatorList + } - if err := validatorDB.Delete(staker.TxID[:]); err != nil { - return fmt.Errorf("failed to delete current staker: %w", err) + // Record the change in weight and/or public key for each validator. + for nodeID, validatorDiff := range validatorDiffs { + var ( + staker *Staker + pk *bls.PublicKey + weightDiff = &ValidatorWeightDiff{ + Decrease: validatorDiff.validatorStatus == deleted, + } + ) + if validatorDiff.validatorStatus != unmodified { + staker = validatorDiff.validator + + pk = staker.PublicKey + // For non-primary network validators, the public key is inherited + // from the primary network. + if subnetID != constants.PrimaryNetworkID { + if vdr, ok := s.currentStakers.validators[constants.PrimaryNetworkID][nodeID]; ok && vdr.validator != nil { + // The primary network validator is still present after + // writing. + pk = vdr.validator.PublicKey + } else if vdr, ok := s.currentStakers.validatorDiffs[constants.PrimaryNetworkID][nodeID]; ok && vdr.validator != nil { + // The primary network validator is being removed during + // writing. + pk = vdr.validator.PublicKey + } else { + // This should never happen as the primary network diffs are + // written last and subnet validator times must be a subset + // of the primary network validator times. + return fmt.Errorf("%w: %s", errMissingPrimaryNetworkValidator, nodeID) } - - s.validatorState.DeleteValidatorMetadata(nodeID, subnetID) } - err := writeCurrentDelegatorDiff( - delegatorDB, - weightDiff, - validatorDiff, - codecVersion, - ) - if err != nil { - return err + weightDiff.Amount = staker.Weight + } + + switch validatorDiff.validatorStatus { + case added: + if pk != nil { + // Record that the public key for the validator is being added. + // This means the prior value for the public key was nil. + err := s.validatorPublicKeyDiffsDB.Put( + marshalDiffKey(subnetID, height, nodeID), + nil, + ) + if err != nil { + return err + } } - if weightDiff.Amount == 0 { - // No weight change to record; go to next validator. - continue + // The validator is being added. + // + // Invariant: It's impossible for a delegator to have been rewarded + // in the same block that the validator was added. + startTime := uint64(staker.StartTime.Unix()) + metadata := &validatorMetadata{ + txID: staker.TxID, + lastUpdated: staker.StartTime, + + UpDuration: 0, + LastUpdated: startTime, + StakerStartTime: startTime, + PotentialReward: staker.PotentialReward, + PotentialDelegateeReward: 0, } - err = s.validatorWeightDiffsDB.Put( - marshalDiffKey(subnetID, height, nodeID), - marshalWeightDiff(weightDiff), - ) + metadataBytes, err := MetadataCodec.Marshal(codecVersion, metadata) if err != nil { - return err + return fmt.Errorf("failed to serialize current validator: %w", err) } - // TODO: Move the validator set management out of the state package - if !updateValidators { - continue + if err = validatorDB.Put(staker.TxID[:], metadataBytes); err != nil { + return fmt.Errorf("failed to write current validator to list: %w", err) } - if weightDiff.Decrease { - err = s.validators.RemoveWeight(subnetID, nodeID, weightDiff.Amount) - } else { - if validatorDiff.validatorStatus == added { - staker := validatorDiff.validator - err = s.validators.AddStaker( - subnetID, - nodeID, - staker.PublicKey, - staker.TxID, - weightDiff.Amount, - ) - } else { - err = s.validators.AddWeight(subnetID, nodeID, weightDiff.Amount) + s.validatorState.LoadValidatorMetadata(nodeID, subnetID, metadata) + case deleted: + if pk != nil { + // Record that the public key for the validator is being + // removed. This means we must record the prior value of the + // public key. + // + // Note: We store the uncompressed public key here as it is + // significantly more efficient to parse when applying diffs. + err := s.validatorPublicKeyDiffsDB.Put( + marshalDiffKey(subnetID, height, nodeID), + bls.PublicKeyToUncompressedBytes(pk), + ) + if err != nil { + return err } } - if err != nil { - return fmt.Errorf("failed to update validator weight: %w", err) + + if err := validatorDB.Delete(staker.TxID[:]); err != nil { + return fmt.Errorf("failed to delete current staker: %w", err) } + + s.validatorState.DeleteValidatorMetadata(nodeID, subnetID) } - } - // TODO: Move validator set management out of the state package - // - // Attempt to update the stake metrics - if !updateValidators { - return nil - } + err := writeCurrentDelegatorDiff( + delegatorDB, + weightDiff, + validatorDiff, + codecVersion, + ) + if err != nil { + return err + } - totalWeight, err := s.validators.TotalWeight(constants.PrimaryNetworkID) - if err != nil { - return fmt.Errorf("failed to get total weight of primary network: %w", err) - } + if weightDiff.Amount == 0 { + // No weight change to record; go to next validator. + continue + } - s.metrics.SetLocalStake(s.validators.GetWeight(constants.PrimaryNetworkID, s.ctx.NodeID)) - s.metrics.SetTotalStake(totalWeight) + err = s.validatorWeightDiffsDB.Put( + marshalDiffKey(subnetID, height, nodeID), + marshalWeightDiff(weightDiff), + ) + if err != nil { + return err + } + + // TODO: Move the validator set management out of the state package + if !updateValidators { + continue + } + + if weightDiff.Decrease { + err = s.validators.RemoveWeight(subnetID, nodeID, weightDiff.Amount) + } else { + if validatorDiff.validatorStatus == added { + err = s.validators.AddStaker( + subnetID, + nodeID, + pk, + staker.TxID, + weightDiff.Amount, + ) + } else { + err = s.validators.AddWeight(subnetID, nodeID, weightDiff.Amount) + } + } + if err != nil { + return fmt.Errorf("failed to update validator weight: %w", err) + } + } return nil } diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 12214c2060b0..b965af1531eb 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -278,13 +278,15 @@ func TestState_writeStakers(t *testing.T) { addStakerTx: addSubnetValidator, expectedCurrentValidator: subnetCurrentValidatorStaker, expectedValidatorSetOutput: &validators.GetValidatorOutput{ - NodeID: subnetCurrentValidatorStaker.NodeID, - Weight: subnetCurrentValidatorStaker.Weight, + NodeID: subnetCurrentValidatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: subnetCurrentValidatorStaker.Weight, }, expectedWeightDiff: &ValidatorWeightDiff{ Decrease: false, Amount: subnetCurrentValidatorStaker.Weight, }, + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](nil), }, "delete current primary network validator": { initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, @@ -342,6 +344,7 @@ func TestState_writeStakers(t *testing.T) { Decrease: true, Amount: subnetCurrentValidatorStaker.Weight, }, + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](primaryNetworkCurrentValidatorStaker.PublicKey), }, } @@ -832,8 +835,9 @@ func TestState_ApplyValidatorDiffs(t *testing.T) { }, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ subnetStakers[0].NodeID: { - NodeID: subnetStakers[0].NodeID, - Weight: subnetStakers[0].Weight, + NodeID: subnetStakers[0].NodeID, + PublicKey: primaryStakers[0].PublicKey, + Weight: subnetStakers[0].Weight, }, }, }, @@ -877,8 +881,9 @@ func TestState_ApplyValidatorDiffs(t *testing.T) { }, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ subnetStakers[2].NodeID: { - NodeID: subnetStakers[2].NodeID, - Weight: subnetStakers[2].Weight, + NodeID: subnetStakers[2].NodeID, + PublicKey: primaryStakers[2].PublicKey, + Weight: subnetStakers[2].Weight, }, }, }, @@ -904,16 +909,19 @@ func TestState_ApplyValidatorDiffs(t *testing.T) { }, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ subnetStakers[2].NodeID: { - NodeID: subnetStakers[2].NodeID, - Weight: subnetStakers[2].Weight, + NodeID: subnetStakers[2].NodeID, + PublicKey: primaryStakers[2].PublicKey, + Weight: subnetStakers[2].Weight, }, subnetStakers[3].NodeID: { - NodeID: subnetStakers[3].NodeID, - Weight: subnetStakers[3].Weight, + NodeID: subnetStakers[3].NodeID, + PublicKey: primaryStakers[3].PublicKey, + Weight: subnetStakers[3].Weight, }, subnetStakers[4].NodeID: { - NodeID: subnetStakers[4].NodeID, - Weight: subnetStakers[4].Weight, + NodeID: subnetStakers[4].NodeID, + PublicKey: primaryStakers[4].PublicKey, + Weight: subnetStakers[4].Weight, }, }, }, @@ -1011,10 +1019,41 @@ func TestState_ApplyValidatorDiffs(t *testing.T) { primaryValidatorSet, currentHeight, prevHeight+1, + constants.PrimaryNetworkID, )) require.Equal(prevDiff.expectedPrimaryValidatorSet, primaryValidatorSet) } + { + legacySubnetValidatorSet := copyValidatorSet(diff.expectedSubnetValidatorSet) + require.NoError(state.ApplyValidatorWeightDiffs( + context.Background(), + legacySubnetValidatorSet, + currentHeight, + prevHeight+1, + subnetID, + )) + + // Update the public keys of the subnet validators with the current + // primary network validator public keys + for nodeID, vdr := range legacySubnetValidatorSet { + if primaryVdr, ok := diff.expectedPrimaryValidatorSet[nodeID]; ok { + vdr.PublicKey = primaryVdr.PublicKey + } else { + vdr.PublicKey = nil + } + } + + require.NoError(state.ApplyValidatorPublicKeyDiffs( + context.Background(), + legacySubnetValidatorSet, + currentHeight, + prevHeight+1, + constants.PrimaryNetworkID, + )) + require.Equal(prevDiff.expectedSubnetValidatorSet, legacySubnetValidatorSet) + } + { subnetValidatorSet := copyValidatorSet(diff.expectedSubnetValidatorSet) require.NoError(state.ApplyValidatorWeightDiffs( @@ -1024,6 +1063,14 @@ func TestState_ApplyValidatorDiffs(t *testing.T) { prevHeight+1, subnetID, )) + + require.NoError(state.ApplyValidatorPublicKeyDiffs( + context.Background(), + subnetValidatorSet, + currentHeight, + prevHeight+1, + subnetID, + )) require.Equal(prevDiff.expectedSubnetValidatorSet, subnetValidatorSet) } } diff --git a/vms/platformvm/validators/manager.go b/vms/platformvm/validators/manager.go index 7f1ea5ea6407..142db3e7635c 100644 --- a/vms/platformvm/validators/manager.go +++ b/vms/platformvm/validators/manager.go @@ -85,6 +85,7 @@ type State interface { validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight uint64, endHeight uint64, + subnetID ids.ID, ) error } @@ -271,7 +272,7 @@ func (m *manager) makePrimaryNetworkValidatorSet( validatorSet, currentHeight, lastDiffHeight, - constants.PlatformChainID, + constants.PrimaryNetworkID, ) if err != nil { return nil, 0, err @@ -282,6 +283,7 @@ func (m *manager) makePrimaryNetworkValidatorSet( validatorSet, currentHeight, lastDiffHeight, + constants.PrimaryNetworkID, ) return validatorSet, currentHeight, err } @@ -348,6 +350,10 @@ func (m *manager) makeSubnetValidatorSet( subnetValidatorSet, currentHeight, lastDiffHeight, + // TODO: Etna introduces L1s whose validators specify their own public + // keys, rather than inheriting them from the primary network. + // Therefore, this will need to use the subnetID after Etna. + constants.PrimaryNetworkID, ) return subnetValidatorSet, currentHeight, err }