Skip to content

Commit 0ef5444

Browse files
committed
feat: improved account validation
1 parent bcef409 commit 0ef5444

File tree

5 files changed

+80
-3
lines changed

5 files changed

+80
-3
lines changed

internal/interpreter/evaluate_expr.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ func (st *programState) evaluateExpr(expr parser.ValueExpr) (Value, InterpreterE
3636
}
3737
}
3838
name := strings.Join(parts, ":")
39-
// TODO validate valid names
40-
return AccountAddress(name), nil
39+
return NewAccountAddress(name)
4140

4241
case *parser.StringLiteral:
4342
return String(expr.String), nil

internal/interpreter/interpreter.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func parseVar(type_ string, rawValue string, r parser.Range) (Value, Interpreter
8989
case analysis.TypeMonetary:
9090
return parseMonetary(rawValue)
9191
case analysis.TypeAccount:
92-
return AccountAddress(rawValue), nil
92+
return NewAccountAddress(rawValue)
9393
case analysis.TypePortion:
9494
bi, err := ParsePortionSpecific(rawValue)
9595
if err != nil {

internal/interpreter/interpreter_error.go

+9
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,12 @@ type CannotCastToString struct {
211211
func (e CannotCastToString) Error() string {
212212
return fmt.Sprintf("Cannot cast this value to string: %s", e.Value)
213213
}
214+
215+
type InvalidAccountName struct {
216+
parser.Range
217+
Name string
218+
}
219+
220+
func (e InvalidAccountName) Error() string {
221+
return fmt.Sprintf("Invalid account name: @%s", e.Name)
222+
}

internal/interpreter/interpreter_test.go

+52
Original file line numberDiff line numberDiff line change
@@ -3836,6 +3836,58 @@ func TestOneofDestinationRemainingClause(t *testing.T) {
38363836
testWithFeatureFlag(t, tc, machine.ExperimentalOneofFeatureFlag)
38373837
}
38383838

3839+
func TestInvalidAccount(t *testing.T) {
3840+
script := `
3841+
vars {
3842+
account $acc
3843+
}
3844+
set_tx_meta("k", $acc)
3845+
`
3846+
3847+
tc := NewTestCase()
3848+
tc.setVarsFromJSON(t, `
3849+
{
3850+
"acc": "!invalid acc.."
3851+
}
3852+
`)
3853+
3854+
tc.compile(t, script)
3855+
3856+
tc.expected = CaseResult{
3857+
Postings: []Posting{},
3858+
Error: machine.InvalidAccountName{
3859+
Name: "!invalid acc..",
3860+
},
3861+
}
3862+
test(t, tc)
3863+
}
3864+
3865+
func TestInvalidInterpAccount(t *testing.T) {
3866+
script := `
3867+
vars {
3868+
string $status
3869+
}
3870+
set_tx_meta("k", @user:$status)
3871+
`
3872+
3873+
tc := NewTestCase()
3874+
tc.setVarsFromJSON(t, `
3875+
{
3876+
"status": "!invalid acc.."
3877+
}
3878+
`)
3879+
3880+
tc.compile(t, script)
3881+
3882+
tc.expected = CaseResult{
3883+
Postings: []Posting{},
3884+
Error: machine.InvalidAccountName{
3885+
Name: "user:!invalid acc..",
3886+
},
3887+
}
3888+
testWithFeatureFlag(t, tc, machine.ExperimentalAccountInterpolationFlag)
3889+
}
3890+
38393891
func TestAccountInterp(t *testing.T) {
38403892
script := `
38413893
vars {

internal/interpreter/value.go

+17
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package interpreter
33
import (
44
"fmt"
55
"math/big"
6+
"regexp"
67

78
"github.com/formancehq/numscript/internal/analysis"
89
"github.com/formancehq/numscript/internal/parser"
@@ -30,6 +31,13 @@ func (Monetary) value() {}
3031
func (Portion) value() {}
3132
func (Asset) value() {}
3233

34+
func NewAccountAddress(src string) (AccountAddress, InterpreterError) {
35+
if !validateAddress(src) {
36+
return AccountAddress(""), InvalidAccountName{Name: src}
37+
}
38+
return AccountAddress(src), nil
39+
}
40+
3341
func (v MonetaryInt) MarshalJSON() ([]byte, error) {
3442
bigInt := big.Int(v)
3543
s := fmt.Sprintf(`"%s"`, bigInt.String())
@@ -227,3 +235,12 @@ func (m MonetaryInt) Sub(other MonetaryInt) MonetaryInt {
227235
sum := new(big.Int).Sub(&bi, &otherBi)
228236
return MonetaryInt(*sum)
229237
}
238+
239+
const segmentRegex = "[a-zA-Z0-9_-]+"
240+
const accountPattern = "^" + segmentRegex + "(:" + segmentRegex + ")*$"
241+
242+
var Regexp = regexp.MustCompile(accountPattern)
243+
244+
func validateAddress(addr string) bool {
245+
return Regexp.Match([]byte(addr))
246+
}

0 commit comments

Comments
 (0)