diff --git a/parser/parse_select.go b/parser/parse_select.go index 1447de1..34925c0 100644 --- a/parser/parse_select.go +++ b/parser/parse_select.go @@ -16,6 +16,8 @@ package parser import ( "errors" + "fmt" + "strings" ) // Determines is the proxy handles the select statement. @@ -77,33 +79,48 @@ func isHandledUseStmt(l *lexer) (handled bool, stmt Statement, err error) { // selectors: selector ( ',' selector )* // selector: unaliasedSelector ( 'AS' identifier ) // unaliasedSelector: -// identifier -// 'COUNT(*)' -// term -// 'CAST' '(' unaliasedSelector 'AS' primitiveType ')' +// +// identifier +// 'COUNT(*)' | 'COUNT' '(' identifier ')' | NOW()' +// term +// 'CAST' '(' unaliasedSelector 'AS' primitiveType ')' // // Note: Doesn't handle term or cast func parseSelector(l *lexer, t token) (selector Selector, next token, err error) { switch t { case tkIdentifier: - if isUnreservedKeyword(l, t, "count") { - countText := l.identifierStr() - if tkLparen != l.next() { - return nil, tkInvalid, errors.New("expected '(' after 'COUNT' in select statement") + name := l.identifierStr() + l.mark() + if tkLparen == l.next() { + var args []string + for t = l.next(); tkRparen != t && tkEOF != t; t = skipToken(l, l.next(), tkComma) { + if tkStar == t { + args = append(args, "*") + } else if tkIdentifier == t { + args = append(args, l.identifierStr()) + } else { + return nil, tkInvalid, fmt.Errorf("unexpected argument type for function call '%s(...)' in select statement", name) + } } - if t = l.next(); tkStar == t { - selector = &CountStarSelector{Name: countText + "(*)"} - } else if tkIdentifier == t { - selector = &CountStarSelector{Name: countText + "(" + l.identifierStr() + ")"} - } else { - - return nil, tkInvalid, errors.New("expected * or identifier in argument 'COUNT(...)' in select statement") + if tkRparen != t { + return nil, tkInvalid, fmt.Errorf("expected closing ')' for function call '%s' in select statement", name) } - if tkRparen != l.next() { - return nil, tkInvalid, errors.New("expected closing ')' for 'COUNT' in select statement") + if strings.EqualFold(name, "count") { + if len(args) == 0 { + return nil, tkInvalid, fmt.Errorf("expected * or identifier in argument 'COUNT(...)' in select statement") + } + return &CountFuncSelector{Arg: args[0]}, l.next(), nil + } else if strings.EqualFold(name, "now") { + if len(args) != 0 { + return nil, tkInvalid, fmt.Errorf("unexpected argument for 'NOW()' function call in select statement") + } + return &NowFuncSelector{}, l.next(), nil + } else { + return nil, tkInvalid, fmt.Errorf("unsupported function call '%s' in select statement", name) } } else { - selector = &IDSelector{Name: l.identifierStr()} + l.rewind() + selector = &IDSelector{Name: name} } case tkStar: return &StarSelector{}, l.next(), nil @@ -119,4 +136,4 @@ func parseSelector(l *lexer, t token) (selector Selector, next token, err error) } return selector, t, nil -} +} \ No newline at end of file diff --git a/parser/parser.go b/parser/parser.go index 3ca33ff..b815259 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -17,10 +17,19 @@ package parser -import "errors" +import ( + "errors" + "fmt" + "strings" + + "github.com/datastax/go-cassandra-native-protocol/datatype" + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/google/uuid" +) type Selector interface { - isSelector() + Values(columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) + Columns(columns []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) } type AliasSelector struct { @@ -28,24 +37,107 @@ type AliasSelector struct { Alias string } -func (a AliasSelector) isSelector() {} +func (a AliasSelector) Values(columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { + return a.Selector.Values(columns, valueFunc) +} + +func (a AliasSelector) Columns(columns []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) { + cols, err := a.Selector.Columns(columns, stmt) + if err != nil { + return + } + for _, column := range cols { + alias := *column // Make a copy so we can modify the name + alias.Name = a.Alias + filtered = append(filtered, &alias) + } + return +} type IDSelector struct { Name string } -func (I IDSelector) isSelector() {} +func (i IDSelector) Values(_ []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { + value, err := valueFunc(i.Name) + if err != nil { + return + } + return []message.Column{value}, err +} + +func (i IDSelector) Columns(columns []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) { + if column := FindColumnMetadata(columns, i.Name); column != nil { + return []*message.ColumnMetadata{column}, nil + } else { + return nil, fmt.Errorf("invalid column %s", i.Name) + } +} type StarSelector struct{} -func (s StarSelector) isSelector() {} +func (s StarSelector) Values(columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { + for _, column := range columns { + var val message.Column + val, err = valueFunc(column.Name) + if err != nil { + return + } + filtered = append(filtered, val) + } + return +} + +func (s StarSelector) Columns(columns []*message.ColumnMetadata, _ *SelectStatement) (filtered []*message.ColumnMetadata, err error) { + filtered = columns + return +} -type CountStarSelector struct { - Name string +type CountFuncSelector struct { + Arg string } -func (c CountStarSelector) isSelector() {} +func (s CountFuncSelector) Values(_ []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { + val, err := valueFunc(CountValueName) + if err != nil { + return + } + filtered = append(filtered, val) + return +} +func (s CountFuncSelector) Columns(_ []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) { + name := "count" + if s.Arg != "*" { + name = fmt.Sprintf("system.count(%s)", strings.ToLower(s.Arg)) + } + return []*message.ColumnMetadata{{ + Keyspace: stmt.Keyspace, + Table: stmt.Table, + Name: name, + Type: datatype.Int, + }}, nil +} + +type NowFuncSelector struct{} + +func (s NowFuncSelector) Values(_ []*message.ColumnMetadata, _ ValueLookupFunc) (filtered []message.Column, err error) { + u, err := uuid.NewUUID() + if err != nil { + return + } + filtered = append(filtered, u[:]) + return +} + +func (s NowFuncSelector) Columns(_ []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) { + return []*message.ColumnMetadata{{ + Keyspace: stmt.Keyspace, + Table: stmt.Table, + Name: "system.now()", + Type: datatype.Timeuuid, + }}, nil +} type Statement interface { isStatement() } diff --git a/parser/parser_test.go b/parser/parser_test.go index 8f94e8f..d0a9f9a 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -17,7 +17,11 @@ package parser import ( "testing" + "github.com/datastax/go-cassandra-native-protocol/datacodec" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParser(t *testing.T) { @@ -35,56 +39,56 @@ func TestParser(t *testing.T) { Selectors: []Selector{ &IDSelector{Name: "key"}, &AliasSelector{Alias: "address", Selector: &IDSelector{Name: "rpc_address"}}, - &CountStarSelector{Name: "count(*)"}, + &CountFuncSelector{Arg: "*"}, }, }, false}, {"system", "SELECT count(*) FROM local", true, true, &SelectStatement{ Keyspace: "system", Table: "local", Selectors: []Selector{ - &CountStarSelector{Name: "count(*)"}, + &CountFuncSelector{Arg: "*"}, }, }, false}, {"system", "SELECT count(*) FROM \"local\"", true, true, &SelectStatement{ Keyspace: "system", Table: "local", Selectors: []Selector{ - &CountStarSelector{Name: "count(*)"}, + &CountFuncSelector{Arg: "*"}, }, }, false}, {"", "SELECT count(*) FROM system.peers", true, true, &SelectStatement{ Keyspace: "system", Table: "peers", Selectors: []Selector{ - &CountStarSelector{Name: "count(*)"}, + &CountFuncSelector{Arg: "*"}, }, }, false}, {"", "SELECT count(*) FROM \"system\".\"peers\"", true, true, &SelectStatement{ Keyspace: "system", Table: "peers", Selectors: []Selector{ - &CountStarSelector{Name: "count(*)"}, + &CountFuncSelector{Arg: "*"}, }, }, false}, {"system", "SELECT count(*) FROM peers", true, true, &SelectStatement{ Keyspace: "system", Table: "peers", Selectors: []Selector{ - &CountStarSelector{Name: "count(*)"}, + &CountFuncSelector{Arg: "*"}, }, }, false}, {"", "SELECT count(*) FROM system.peers_v2", true, true, &SelectStatement{ Keyspace: "system", Table: "peers_v2", Selectors: []Selector{ - &CountStarSelector{Name: "count(*)"}, + &CountFuncSelector{Arg: "*"}, }, }, false}, {"system", "SELECT count(*) FROM peers_v2", true, true, &SelectStatement{ Keyspace: "system", Table: "peers_v2", Selectors: []Selector{ - &CountStarSelector{Name: "count(*)"}, + &CountFuncSelector{Arg: "*"}, }, }, false}, {"", "SELECT func(key) FROM system.local", true, true, nil, true}, @@ -176,3 +180,57 @@ func TestParser(t *testing.T) { assert.Equal(t, tt.stmt, stmt, "invalid parsed statement", tt.query) } } + +func TestParserSystemNowFunction(t *testing.T) { + var tests = []struct { + query string + table string + }{ + {"SELECT now() FROM system.local", "local"}, + {"SELECT now() FROM system.peers", "peers"}, + } + + start, _, err := uuid.GetTime() + require.NoError(t, err) + + for _, tt := range tests { + t.Run(tt.query, func(t *testing.T) { + handled, stmt, err := IsQueryHandled(IdentifierFromString(""), tt.query) + assert.NoError(t, err, tt.query) + assert.True(t, handled, tt.query) + require.NotNil(t, stmt) + if selectStmt, ok := stmt.(*SelectStatement); ok { + require.Len(t, selectStmt.Selectors, 1) + selector := selectStmt.Selectors[0] + require.NotNil(t, selector) + if nowSelector, ok := selector.(*NowFuncSelector); ok { + values, err := nowSelector.Values(nil, nil) + assert.NoError(t, err) + assert.Len(t, values, 1) + + var u primitive.UUID + wasNull, err := datacodec.Timeuuid.Decode(values[0], &u, primitive.ProtocolVersion4) + assert.False(t, wasNull) + assert.NoError(t, err) + + v, err := uuid.FromBytes(u[:]) + assert.NoError(t, err) + assert.Equal(t, uuid.RFC4122, v.Variant()) + assert.Equal(t, uuid.Version(1), v.Version()) + assert.GreaterOrEqual(t, v.Time(), start) + + columns, err := nowSelector.Columns(nil, selectStmt) + assert.NoError(t, err) + assert.Len(t, columns, 1) + assert.Equal(t, "system.now()", columns[0].Name) + assert.Equal(t, "system", columns[0].Keyspace) + assert.Equal(t, tt.table, columns[0].Table) + } else { + assert.Fail(t, "expected now function selector") + } + } else { + assert.Fail(t, "expected select statement") + } + }) + } +} diff --git a/parser/parser_utils.go b/parser/parser_utils.go index 9c81b74..32ba6e2 100644 --- a/parser/parser_utils.go +++ b/parser/parser_utils.go @@ -16,9 +16,7 @@ package parser import ( "errors" - "fmt" - "github.com/datastax/go-cassandra-native-protocol/datatype" "github.com/datastax/go-cassandra-native-protocol/message" ) @@ -33,99 +31,27 @@ var nonIdempotentFuncs = []string{"uuid", "now"} type ValueLookupFunc func(name string) (value message.Column, err error) func FilterValues(stmt *SelectStatement, columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) { - if _, ok := stmt.Selectors[0].(*StarSelector); ok { - for _, column := range columns { - var val message.Column - val, err = valueFunc(column.Name) - if err != nil { - return nil, err - } - filtered = append(filtered, val) - } - } else { - for _, selector := range stmt.Selectors { - var val message.Column - val, err = valueFromSelector(selector, valueFunc) - if err != nil { - return nil, err - } - filtered = append(filtered, val) + for _, selector := range stmt.Selectors { + var vals []message.Column + vals, err = selector.Values(columns, valueFunc) + if err != nil { + return nil, err } + filtered = append(filtered, vals...) } return filtered, nil } -func valueFromSelector(selector Selector, valueFunc ValueLookupFunc) (val message.Column, err error) { - switch s := selector.(type) { - case *CountStarSelector: - return valueFunc(CountValueName) - case *IDSelector: - return valueFunc(s.Name) - case *AliasSelector: - return valueFromSelector(s.Selector, valueFunc) - default: - return nil, errors.New("unhandled selector type") - } -} - func FilterColumns(stmt *SelectStatement, columns []*message.ColumnMetadata) (filtered []*message.ColumnMetadata, err error) { - if _, ok := stmt.Selectors[0].(*StarSelector); ok { - filtered = columns - } else { - for _, selector := range stmt.Selectors { - var column *message.ColumnMetadata - column, err = columnFromSelector(selector, columns, stmt.Keyspace, stmt.Table) - if err != nil { - return nil, err - } - filtered = append(filtered, column) - } - } - return filtered, nil -} - -func isCountSelector(selector Selector) bool { - _, ok := selector.(*CountStarSelector) - return ok -} - -func IsCountStarQuery(stmt *SelectStatement) bool { - if len(stmt.Selectors) == 1 { - if isCountSelector(stmt.Selectors[0]) { - return true - } else if alias, ok := stmt.Selectors[0].(*AliasSelector); ok { - return isCountSelector(alias.Selector) - } - } - return false -} - -func columnFromSelector(selector Selector, columns []*message.ColumnMetadata, keyspace string, table string) (column *message.ColumnMetadata, err error) { - switch s := selector.(type) { - case *CountStarSelector: - return &message.ColumnMetadata{ - Keyspace: keyspace, - Table: table, - Name: s.Name, - Type: datatype.Int, - }, nil - case *IDSelector: - if column = FindColumnMetadata(columns, s.Name); column != nil { - return column, nil - } else { - return nil, fmt.Errorf("invalid column %s", s.Name) - } - case *AliasSelector: - column, err = columnFromSelector(s.Selector, columns, keyspace, table) + for _, selector := range stmt.Selectors { + var cols []*message.ColumnMetadata + cols, err = selector.Columns(columns, stmt) if err != nil { return nil, err } - alias := *column // Make a copy so we can modify the name - alias.Name = s.Alias - return &alias, nil - default: - return nil, errors.New("unhandled selector type") + filtered = append(filtered, cols...) } + return filtered, nil } func isSystemTable(name Identifier) bool {