diff --git a/api/jsonrpc/client.go b/api/jsonrpc/client.go index d189525700..ea3b29acaa 100644 --- a/api/jsonrpc/client.go +++ b/api/jsonrpc/client.go @@ -170,8 +170,8 @@ func (cli *JSONRPCClient) GenerateTransactionManual( // Build transaction actionCodec, authCodec := parser.ActionCodec(), parser.AuthCodec() - tx := chain.NewTx(base, actions) - tx, err := tx.Sign(authFactory, actionCodec, authCodec) + unsignedTx := chain.NewTxData(base, actions) + tx, err := unsignedTx.Sign(authFactory, actionCodec, authCodec) if err != nil { return nil, nil, fmt.Errorf("%w: failed to sign transaction", err) } diff --git a/api/jsonrpc/server.go b/api/jsonrpc/server.go index 364b82185f..361ca9513d 100644 --- a/api/jsonrpc/server.go +++ b/api/jsonrpc/server.go @@ -105,7 +105,7 @@ func (j *JSONRPCServer) SubmitTx( if !rtx.Empty() { return errTransactionExtraBytes } - if err := tx.Verify(ctx); err != nil { + if err := tx.VerifyAuth(ctx); err != nil { return err } txID := tx.ID() diff --git a/api/ws/server.go b/api/ws/server.go index d583420411..b98103659d 100644 --- a/api/ws/server.go +++ b/api/ws/server.go @@ -264,7 +264,7 @@ func (w *WebSocketServer) MessageCallback() pubsub.Callback { // Verify tx if w.vm.GetVerifyAuth() { - if err := tx.Verify(ctx); err != nil { + if err := tx.VerifyAuth(ctx); err != nil { w.logger.Error("failed to verify sig", zap.Error(err), ) diff --git a/chain/transaction.go b/chain/transaction.go index 059e0d55cb..1bb3eb1480 100644 --- a/chain/transaction.go +++ b/chain/transaction.go @@ -28,28 +28,23 @@ var ( _ mempool.Item = (*Transaction)(nil) ) -type Transaction struct { +type TransactionData struct { Base *Base `json:"base"` Actions Actions `json:"actions"` - Auth Auth `json:"auth"` unsignedBytes []byte - bytes []byte - size int - id ids.ID - stateKeys state.Keys } -func NewTx(base *Base, actions Actions) *Transaction { - return &Transaction{ +func NewTxData(base *Base, actions Actions) *TransactionData { + return &TransactionData{ Base: base, Actions: actions, } } -// UnsignedBytes returns the byte slice representation of the unsigned tx -func (t *Transaction) UnsignedBytes() ([]byte, error) { +// UnsignedBytes returns the byte slice representation of the tx +func (t *TransactionData) UnsignedBytes() ([]byte, error) { if len(t.unsignedBytes) > 0 { return t.unsignedBytes, nil } @@ -62,16 +57,16 @@ func (t *Transaction) UnsignedBytes() ([]byte, error) { size += actionsSize p := codec.NewWriter(size, consts.NetworkSizeLimit) - if err := t.marshal(p, false); err != nil { + if err := t.marshal(p); err != nil { return nil, err } - - return p.Bytes(), p.Err() + t.unsignedBytes = p.Bytes() + return t.unsignedBytes, p.Err() } // Sign returns a new signed transaction with the unsigned tx copied from // the original and a signature provided by the authFactory -func (t *Transaction) Sign( +func (t *TransactionData) Sign( factory AuthFactory, actionCodec *codec.TypeParser[Action], authCodec *codec.TypeParser[Auth], @@ -86,9 +81,11 @@ func (t *Transaction) Sign( } signedTransaction := Transaction{ - Base: t.Base, - Actions: t.Actions, - Auth: auth, + TransactionData: TransactionData{ + Base: t.Base, + Actions: t.Actions, + }, + Auth: auth, } // Ensure transaction is fully initialized and correct by reloading it from @@ -105,14 +102,58 @@ func (t *Transaction) Sign( return UnmarshalTx(p, actionCodec, authCodec) } -// Verify that the transaction was signed correctly. -func (t *Transaction) Verify(ctx context.Context) error { - msg, err := t.UnsignedBytes() - if err != nil { - // Should never occur because populated during unmarshal - return err +func (t *TransactionData) Expiry() int64 { return t.Base.Timestamp } + +func (t *TransactionData) MaxFee() uint64 { return t.Base.MaxFee } + +func (t *TransactionData) Marshal(p *codec.Packer) error { + if len(t.unsignedBytes) > 0 { + p.PackFixedBytes(t.unsignedBytes) + return p.Err() } - return t.Auth.Verify(ctx, msg) + return t.marshal(p) +} + +func (t *TransactionData) marshal(p *codec.Packer) error { + t.Base.Marshal(p) + return t.Actions.marshalInto(p) +} + +type Actions []Action + +func (a Actions) Size() (int, error) { + var size int + for _, action := range a { + actionSize, err := GetSize(action) + if err != nil { + return 0, err + } + size += consts.ByteLen + actionSize + } + return size, nil +} + +func (a Actions) marshalInto(p *codec.Packer) error { + p.PackByte(uint8(len(a))) + for _, action := range a { + p.PackByte(action.GetTypeID()) + err := marshalInto(action, p) + if err != nil { + return err + } + } + return nil +} + +type Transaction struct { + TransactionData + + Auth Auth `json:"auth"` + + bytes []byte + size int + id ids.ID + stateKeys state.Keys } func (t *Transaction) Bytes() []byte { return t.bytes } @@ -121,10 +162,6 @@ func (t *Transaction) Size() int { return t.size } func (t *Transaction) ID() ids.ID { return t.id } -func (t *Transaction) Expiry() int64 { return t.Base.Timestamp } - -func (t *Transaction) MaxFee() uint64 { return t.Base.MaxFee } - func (t *Transaction) StateKeys(sm StateManager) (state.Keys, error) { if t.stateKeys != nil { return t.stateKeys, nil @@ -150,9 +187,6 @@ func (t *Transaction) StateKeys(sm StateManager) (state.Keys, error) { return stateKeys, nil } -// Sponsor is the [codec.Address] that pays fees for this transaction. -func (t *Transaction) Sponsor() codec.Address { return t.Auth.Sponsor() } - // Units is charged whether or not a transaction is successful. func (t *Transaction) Units(sm StateManager, r Rules) (fees.Dimensions, error) { // Calculate compute usage @@ -204,81 +238,10 @@ func (t *Transaction) Units(sm StateManager, r Rules) (fees.Dimensions, error) { return fees.Dimensions{uint64(t.Size()), maxComputeUnits, reads, allocates, writes}, nil } -// EstimateUnits provides a pessimistic estimate (some key accesses may be duplicates) of the cost -// to execute a transaction. -// -// This is typically used during transaction construction. -func EstimateUnits(r Rules, actions Actions, authFactory AuthFactory) (fees.Dimensions, error) { - var ( - bandwidth = uint64(BaseSize) - stateKeysMaxChunks = []uint16{} // TODO: preallocate - computeOp = math.NewUint64Operator(r.GetBaseComputeUnits()) - readsOp = math.NewUint64Operator(0) - allocatesOp = math.NewUint64Operator(0) - writesOp = math.NewUint64Operator(0) - ) - - // Calculate over action/auth - bandwidth += consts.Uint8Len - for _, action := range actions { - actionSize, err := GetSize(action) - if err != nil { - return fees.Dimensions{}, err - } - - actor := authFactory.Address() - stateKeys := action.StateKeys(actor) - actionStateKeysMaxChunks, ok := stateKeys.ChunkSizes() - if !ok { - return fees.Dimensions{}, ErrInvalidKeyValue - } - bandwidth += consts.ByteLen + uint64(actionSize) - stateKeysMaxChunks = append(stateKeysMaxChunks, actionStateKeysMaxChunks...) - computeOp.Add(action.ComputeUnits(r)) - } - authBandwidth, authCompute := authFactory.MaxUnits() - bandwidth += consts.ByteLen + authBandwidth - sponsorStateKeyMaxChunks := r.GetSponsorStateKeysMaxChunks() - stateKeysMaxChunks = append(stateKeysMaxChunks, sponsorStateKeyMaxChunks...) - computeOp.Add(authCompute) - - // Estimate compute costs - compute, err := computeOp.Value() - if err != nil { - return fees.Dimensions{}, err - } - - // Estimate storage costs - for _, maxChunks := range stateKeysMaxChunks { - // Compute key costs - readsOp.Add(r.GetStorageKeyReadUnits()) - allocatesOp.Add(r.GetStorageKeyAllocateUnits()) - writesOp.Add(r.GetStorageKeyWriteUnits()) - - // Compute value costs - readsOp.MulAdd(uint64(maxChunks), r.GetStorageValueReadUnits()) - allocatesOp.MulAdd(uint64(maxChunks), r.GetStorageValueAllocateUnits()) - writesOp.MulAdd(uint64(maxChunks), r.GetStorageValueWriteUnits()) - } - reads, err := readsOp.Value() - if err != nil { - return fees.Dimensions{}, err - } - allocates, err := allocatesOp.Value() - if err != nil { - return fees.Dimensions{}, err - } - writes, err := writesOp.Value() - if err != nil { - return fees.Dimensions{}, err - } - return fees.Dimensions{bandwidth, compute, reads, allocates, writes}, nil -} - func (t *Transaction) PreExecute( ctx context.Context, feeManager *internalfees.Manager, - s StateManager, + sm StateManager, r Rules, im state.Immutable, timestamp int64, @@ -305,7 +268,7 @@ func (t *Transaction) PreExecute( if end >= 0 && timestamp > end { return ErrAuthNotActivated } - units, err := t.Units(s, r) + units, err := t.Units(sm, r) if err != nil { return err } @@ -313,7 +276,7 @@ func (t *Transaction) PreExecute( if err != nil { return err } - return s.CanDeduct(ctx, t.Auth.Sponsor(), im, fee) + return sm.CanDeduct(ctx, t.Auth.Sponsor(), im, fee) } // Execute after knowing a transaction can pay a fee. Attempt @@ -323,13 +286,13 @@ func (t *Transaction) PreExecute( func (t *Transaction) Execute( ctx context.Context, feeManager *internalfees.Manager, - s StateManager, + sm StateManager, r Rules, ts *tstate.TStateView, timestamp int64, ) (*Result, error) { // Always charge fee first - units, err := t.Units(s, r) + units, err := t.Units(sm, r) if err != nil { // Should never happen return nil, err @@ -339,7 +302,7 @@ func (t *Transaction) Execute( // Should never happen return nil, err } - if err := s.Deduct(ctx, t.Auth.Sponsor(), ts, fee); err != nil { + if err := sm.Deduct(ctx, t.Auth.Sponsor(), ts, fee); err != nil { // This should never fail for low balance (as we check [CanDeductFee] // immediately before). return nil, err @@ -385,81 +348,95 @@ func (t *Transaction) Execute( }, nil } +// Sponsor is the [codec.Address] that pays fees for this transaction. +func (t *Transaction) Sponsor() codec.Address { return t.Auth.Sponsor() } + func (t *Transaction) Marshal(p *codec.Packer) error { if len(t.bytes) > 0 { p.PackFixedBytes(t.bytes) return p.Err() } - return t.marshal(p, true) + return t.marshal(p) } -func (t *Transaction) marshal(p *codec.Packer, marshalSignature bool) error { - t.Base.Marshal(p) - if err := t.Actions.marshalInto(p); err != nil { +func (t *Transaction) marshal(p *codec.Packer) error { + if err := t.TransactionData.marshal(p); err != nil { return err } - if marshalSignature { - authID := t.Auth.GetTypeID() - p.PackByte(authID) - t.Auth.Marshal(p) - } + authID := t.Auth.GetTypeID() + p.PackByte(authID) + t.Auth.Marshal(p) + return p.Err() } -type Actions []Action - -func (a Actions) Size() (int, error) { - var size int - for _, action := range a { - actionSize, err := GetSize(action) - if err != nil { - return 0, err - } - size += consts.ByteLen + actionSize +// VerifyAuth verifies that the transaction was signed correctly. +func (t *Transaction) VerifyAuth(ctx context.Context) error { + msg, err := t.UnsignedBytes() + if err != nil { + // Should never occur because populated during unmarshal + return err } - return size, nil + return t.Auth.Verify(ctx, msg) } -func (a Actions) marshalInto(p *codec.Packer) error { - p.PackByte(uint8(len(a))) - for _, action := range a { - p.PackByte(action.GetTypeID()) - err := marshalInto(action, p) - if err != nil { - return err - } +func UnmarshalTxData( + p *codec.Packer, + actionRegistry *codec.TypeParser[Action], +) (*TransactionData, error) { + start := p.Offset() + base, err := UnmarshalBase(p) + if err != nil { + return nil, fmt.Errorf("%w: could not unmarshal base", err) } - return nil + actions, err := UnmarshalActions(p, actionRegistry) + if err != nil { + return nil, fmt.Errorf("%w: could not unmarshal actions", err) + } + + var tx TransactionData + tx.Base = base + tx.Actions = actions + if err := p.Err(); err != nil { + return nil, p.Err() + } + codecBytes := p.Bytes() + tx.unsignedBytes = codecBytes[start:p.Offset()] // ensure errors handled before grabbing memory + return &tx, nil } -func MarshalTxs(txs []*Transaction) ([]byte, error) { - if len(txs) == 0 { - return nil, ErrNoTxs +func UnmarshalActions( + p *codec.Packer, + actionRegistry *codec.TypeParser[Action], +) (Actions, error) { + actionCount := p.UnpackByte() + if actionCount == 0 { + return nil, fmt.Errorf("%w: no actions", ErrInvalidObject) } - size := consts.IntLen + codec.CummSize(txs) - p := codec.NewWriter(size, consts.NetworkSizeLimit) - p.PackInt(uint32(len(txs))) - for _, tx := range txs { - if err := tx.Marshal(p); err != nil { - return nil, err + actions := Actions{} + for i := uint8(0); i < actionCount; i++ { + action, err := actionRegistry.Unmarshal(p) + if err != nil { + return nil, fmt.Errorf("%w: could not unmarshal action", err) } + actions = append(actions, action) } - return p.Bytes(), p.Err() + return actions, nil } func UnmarshalTxs( raw []byte, initialCapacity int, - actionCodec *codec.TypeParser[Action], - authCodec *codec.TypeParser[Auth], + actionRegistry *codec.TypeParser[Action], + authRegistry *codec.TypeParser[Auth], ) (map[uint8]int, []*Transaction, error) { p := codec.NewReader(raw, consts.NetworkSizeLimit) txCount := p.UnpackInt(true) authCounts := map[uint8]int{} txs := make([]*Transaction, 0, initialCapacity) // DoS to set size to txCount for i := uint32(0); i < txCount; i++ { - tx, err := UnmarshalTx(p, actionCodec, authCodec) + tx, err := UnmarshalTx(p, actionRegistry, authRegistry) if err != nil { return nil, nil, err } @@ -475,20 +452,15 @@ func UnmarshalTxs( func UnmarshalTx( p *codec.Packer, - actionCodec *codec.TypeParser[Action], - authCodec *codec.TypeParser[Auth], + actionRegistry *codec.TypeParser[Action], + authRegistry *codec.TypeParser[Auth], ) (*Transaction, error) { start := p.Offset() - base, err := UnmarshalBase(p) + unsignedTransaction, err := UnmarshalTxData(p, actionRegistry) if err != nil { - return nil, fmt.Errorf("%w: could not unmarshal base", err) - } - actions, err := UnmarshalActions(p, actionCodec) - if err != nil { - return nil, fmt.Errorf("%w: could not unmarshal actions", err) + return nil, err } - digest := p.Offset() - auth, err := authCodec.Unmarshal(p) + auth, err := authRegistry.Unmarshal(p) if err != nil { return nil, fmt.Errorf("%w: could not unmarshal auth", err) } @@ -502,35 +474,100 @@ func UnmarshalTx( } var tx Transaction - tx.Base = base - tx.Actions = actions + tx.TransactionData = *unsignedTransaction tx.Auth = auth if err := p.Err(); err != nil { return nil, p.Err() } codecBytes := p.Bytes() - tx.unsignedBytes = codecBytes[start:digest] tx.bytes = codecBytes[start:p.Offset()] // ensure errors handled before grabbing memory tx.size = len(tx.bytes) tx.id = utils.ToID(tx.bytes) return &tx, nil } -func UnmarshalActions( - p *codec.Packer, - actionCodec *codec.TypeParser[Action], -) (Actions, error) { - actionCount := p.UnpackByte() - if actionCount == 0 { - return nil, fmt.Errorf("%w: no actions", ErrInvalidObject) +func MarshalTxs(txs []*Transaction) ([]byte, error) { + if len(txs) == 0 { + return nil, ErrNoTxs } - actions := Actions{} - for i := uint8(0); i < actionCount; i++ { - action, err := actionCodec.Unmarshal(p) + size := consts.IntLen + codec.CummSize(txs) + p := codec.NewWriter(size, consts.NetworkSizeLimit) + p.PackInt(uint32(len(txs))) + for _, tx := range txs { + if err := tx.Marshal(p); err != nil { + return nil, err + } + } + return p.Bytes(), p.Err() +} + +// EstimateUnits provides a pessimistic estimate (some key accesses may be duplicates) of the cost +// to execute a transaction. +// +// This is typically used during transaction construction. +func EstimateUnits(r Rules, actions Actions, authFactory AuthFactory) (fees.Dimensions, error) { + var ( + bandwidth = uint64(BaseSize) + stateKeysMaxChunks = []uint16{} // TODO: preallocate + computeOp = math.NewUint64Operator(r.GetBaseComputeUnits()) + readsOp = math.NewUint64Operator(0) + allocatesOp = math.NewUint64Operator(0) + writesOp = math.NewUint64Operator(0) + ) + + // Calculate over action/auth + bandwidth += consts.Uint8Len + for _, action := range actions { + actionSize, err := GetSize(action) if err != nil { - return nil, fmt.Errorf("%w: could not unmarshal action", err) + return fees.Dimensions{}, err } - actions = append(actions, action) + + actor := authFactory.Address() + stateKeys := action.StateKeys(actor) + actionStateKeysMaxChunks, ok := stateKeys.ChunkSizes() + if !ok { + return fees.Dimensions{}, ErrInvalidKeyValue + } + bandwidth += consts.ByteLen + uint64(actionSize) + stateKeysMaxChunks = append(stateKeysMaxChunks, actionStateKeysMaxChunks...) + computeOp.Add(action.ComputeUnits(r)) } - return actions, nil + authBandwidth, authCompute := authFactory.MaxUnits() + bandwidth += consts.ByteLen + authBandwidth + sponsorStateKeyMaxChunks := r.GetSponsorStateKeysMaxChunks() + stateKeysMaxChunks = append(stateKeysMaxChunks, sponsorStateKeyMaxChunks...) + computeOp.Add(authCompute) + + // Estimate compute costs + compute, err := computeOp.Value() + if err != nil { + return fees.Dimensions{}, err + } + + // Estimate storage costs + for _, maxChunks := range stateKeysMaxChunks { + // Compute key costs + readsOp.Add(r.GetStorageKeyReadUnits()) + allocatesOp.Add(r.GetStorageKeyAllocateUnits()) + writesOp.Add(r.GetStorageKeyWriteUnits()) + + // Compute value costs + readsOp.MulAdd(uint64(maxChunks), r.GetStorageValueReadUnits()) + allocatesOp.MulAdd(uint64(maxChunks), r.GetStorageValueAllocateUnits()) + writesOp.MulAdd(uint64(maxChunks), r.GetStorageValueWriteUnits()) + } + reads, err := readsOp.Value() + if err != nil { + return fees.Dimensions{}, err + } + allocates, err := allocatesOp.Value() + if err != nil { + return fees.Dimensions{}, err + } + writes, err := writesOp.Value() + if err != nil { + return fees.Dimensions{}, err + } + return fees.Dimensions{bandwidth, compute, reads, allocates, writes}, nil } diff --git a/chain/transaction_test.go b/chain/transaction_test.go index 0b8a6996ff..9c12ec8cc5 100644 --- a/chain/transaction_test.go +++ b/chain/transaction_test.go @@ -77,7 +77,7 @@ func unmarshalAction2(p *codec.Packer) (chain.Action, error) { func TestMarshalUnmarshal(t *testing.T) { require := require.New(t) - tx := chain.Transaction{ + tx := chain.TransactionData{ Base: &chain.Base{ Timestamp: 1724315246000, ChainID: [32]byte{1, 2, 3, 4, 5, 6, 7}, @@ -115,7 +115,7 @@ func TestMarshalUnmarshal(t *testing.T) { err = actionCodec.Register(&action2{}, unmarshalAction2) require.NoError(err) - txBeforeSign := chain.Transaction{ + txBeforeSign := chain.TransactionData{ Base: &chain.Base{ Timestamp: 1724315246000, ChainID: [32]byte{1, 2, 3, 4, 5, 6, 7}, @@ -138,8 +138,10 @@ func TestMarshalUnmarshal(t *testing.T) { }, }, } + // call UnsignedBytes so that the "unsignedBytes" field would get populated. + _, err = txBeforeSign.UnsignedBytes() + require.NoError(err) - require.Nil(tx.Auth) signedTx, err := tx.Sign(factory, actionCodec, authCodec) require.NoError(err) require.Equal(txBeforeSign, tx) diff --git a/tests/integration/integration.go b/tests/integration/integration.go index 80bc79c76d..0640b20a3f 100644 --- a/tests/integration/integration.go +++ b/tests/integration/integration.go @@ -346,7 +346,7 @@ var _ = ginkgo.Describe("[Tx Processing]", ginkgo.Serial, func() { }) ginkgo.By("skip invalid time", func() { - tx := chain.NewTx( + tx := chain.NewTxData( &chain.Base{ ChainID: instances[0].chainID, Timestamp: 1, @@ -360,9 +360,12 @@ var _ = ginkgo.Describe("[Tx Processing]", ginkgo.Serial, func() { require.NoError(err) auth, err := authFactory.Sign(unsignedTxBytes) require.NoError(err) - tx.Auth = auth + signedTx := chain.Transaction{ + TransactionData: *tx, + Auth: auth, + } p := codec.NewWriter(0, consts.MaxInt) // test codec growth - require.NoError(tx.Marshal(p)) + require.NoError(signedTx.Marshal(p)) require.NoError(p.Err()) _, err = instances[0].cli.SubmitTx( context.Background(), diff --git a/vm/vm.go b/vm/vm.go index 71df707431..bd9873c833 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -925,13 +925,7 @@ func (vm *VM) Submit( // Verify auth if not already verified by caller if verifyAuth && vm.config.VerifyAuth { - unsignedTxBytes, err := tx.UnsignedBytes() - if err != nil { - // Should never fail - errs = append(errs, err) - continue - } - if err := tx.Auth.Verify(ctx, unsignedTxBytes); err != nil { + if err := tx.VerifyAuth(ctx); err != nil { // Failed signature verification is the only safe place to remove // a transaction in listeners. Every other case may still end up with // the transaction in a block.