Skip to content

Add support for "select now() from system.local" #138

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 2 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions parser/parse_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package parser

import (
"errors"
"fmt"
"strings"
)

// Determines is the proxy handles the select statement.
Expand Down Expand Up @@ -77,33 +79,48 @@ func isHandledUseStmt(l *lexer) (handled bool, stmt Statement, err error) {
// selectors: selector ( ',' selector )*
// selector: unaliasedSelector ( 'AS' identifier )
// unaliasedSelector:
// identifier
// 'COUNT(*)'
// term
// 'CAST' '(' unaliasedSelector 'AS' primitiveType ')'
//
// identifier
// 'COUNT(*)' | 'COUNT' '(' identifier ')' | NOW()'
// term
// 'CAST' '(' unaliasedSelector 'AS' primitiveType ')'
//
// Note: Doesn't handle term or cast
func parseSelector(l *lexer, t token) (selector Selector, next token, err error) {
switch t {
case tkIdentifier:
if isUnreservedKeyword(l, t, "count") {
countText := l.identifierStr()
if tkLparen != l.next() {
return nil, tkInvalid, errors.New("expected '(' after 'COUNT' in select statement")
name := l.identifierStr()
l.mark()
if tkLparen == l.next() {
var args []string
for t = l.next(); tkRparen != t && tkEOF != t; t = skipToken(l, l.next(), tkComma) {
if tkStar == t {
args = append(args, "*")
} else if tkIdentifier == t {
args = append(args, l.identifierStr())
} else {
return nil, tkInvalid, fmt.Errorf("unexpected argument type for function call '%s(...)' in select statement", name)
}
}
if t = l.next(); tkStar == t {
selector = &CountStarSelector{Name: countText + "(*)"}
} else if tkIdentifier == t {
selector = &CountStarSelector{Name: countText + "(" + l.identifierStr() + ")"}
} else {

return nil, tkInvalid, errors.New("expected * or identifier in argument 'COUNT(...)' in select statement")
if tkRparen != t {
return nil, tkInvalid, fmt.Errorf("expected closing ')' for function call '%s' in select statement", name)
}
if tkRparen != l.next() {
return nil, tkInvalid, errors.New("expected closing ')' for 'COUNT' in select statement")
if strings.EqualFold(name, "count") {
if len(args) == 0 {
return nil, tkInvalid, fmt.Errorf("expected * or identifier in argument 'COUNT(...)' in select statement")
}
return &CountFuncSelector{Arg: args[0]}, l.next(), nil
} else if strings.EqualFold(name, "now") {
if len(args) != 0 {
return nil, tkInvalid, fmt.Errorf("unexpected argument for 'NOW()' function call in select statement")
}
return &NowFuncSelector{}, l.next(), nil
} else {
return nil, tkInvalid, fmt.Errorf("unsupported function call '%s' in select statement", name)
}
} else {
selector = &IDSelector{Name: l.identifierStr()}
l.rewind()
selector = &IDSelector{Name: name}
}
case tkStar:
return &StarSelector{}, l.next(), nil
Expand All @@ -119,4 +136,4 @@ func parseSelector(l *lexer, t token) (selector Selector, next token, err error)
}

return selector, t, nil
}
}
108 changes: 100 additions & 8 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,127 @@

package parser

import "errors"
import (
"errors"
"fmt"
"strings"

"github.com/datastax/go-cassandra-native-protocol/datatype"
"github.com/datastax/go-cassandra-native-protocol/message"
"github.com/google/uuid"
)

type Selector interface {
isSelector()
Values(columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error)
Columns(columns []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if this weren't fixing the underlying problem the conversion to managing values and columns within selectors rather than as external util functions is worth the cost of admission by itself.

}

type AliasSelector struct {
Selector Selector
Alias string
}

func (a AliasSelector) isSelector() {}
func (a AliasSelector) Values(columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) {
return a.Selector.Values(columns, valueFunc)
}

func (a AliasSelector) Columns(columns []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) {
cols, err := a.Selector.Columns(columns, stmt)
if err != nil {
return
}
for _, column := range cols {
alias := *column // Make a copy so we can modify the name
alias.Name = a.Alias
filtered = append(filtered, &alias)
}
return
}

type IDSelector struct {
Name string
}

func (I IDSelector) isSelector() {}
func (i IDSelector) Values(_ []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) {
value, err := valueFunc(i.Name)
if err != nil {
return
}
return []message.Column{value}, err
}

func (i IDSelector) Columns(columns []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) {
if column := FindColumnMetadata(columns, i.Name); column != nil {
return []*message.ColumnMetadata{column}, nil
} else {
return nil, fmt.Errorf("invalid column %s", i.Name)
}
}

type StarSelector struct{}

func (s StarSelector) isSelector() {}
func (s StarSelector) Values(columns []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) {
for _, column := range columns {
var val message.Column
val, err = valueFunc(column.Name)
if err != nil {
return
}
filtered = append(filtered, val)
}
return
}

func (s StarSelector) Columns(columns []*message.ColumnMetadata, _ *SelectStatement) (filtered []*message.ColumnMetadata, err error) {
filtered = columns
return
}

type CountStarSelector struct {
Name string
type CountFuncSelector struct {
Arg string
}

func (c CountStarSelector) isSelector() {}
func (s CountFuncSelector) Values(_ []*message.ColumnMetadata, valueFunc ValueLookupFunc) (filtered []message.Column, err error) {
val, err := valueFunc(CountValueName)
if err != nil {
return
}
filtered = append(filtered, val)
return
}

func (s CountFuncSelector) Columns(_ []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) {
name := "count"
if s.Arg != "*" {
name = fmt.Sprintf("system.count(%s)", strings.ToLower(s.Arg))
}
return []*message.ColumnMetadata{{
Keyspace: stmt.Keyspace,
Table: stmt.Table,
Name: name,
Type: datatype.Int,
}}, nil
}

type NowFuncSelector struct{}

func (s NowFuncSelector) Values(_ []*message.ColumnMetadata, _ ValueLookupFunc) (filtered []message.Column, err error) {
u, err := uuid.NewUUID()
if err != nil {
return
}
filtered = append(filtered, u[:])
return
}

func (s NowFuncSelector) Columns(_ []*message.ColumnMetadata, stmt *SelectStatement) (filtered []*message.ColumnMetadata, err error) {
return []*message.ColumnMetadata{{
Keyspace: stmt.Keyspace,
Table: stmt.Table,
Name: "system.now()",
Type: datatype.Timeuuid,
}}, nil
}
type Statement interface {
isStatement()
}
Expand Down
74 changes: 66 additions & 8 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ package parser
import (
"testing"

"github.com/datastax/go-cassandra-native-protocol/datacodec"
"github.com/datastax/go-cassandra-native-protocol/primitive"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestParser(t *testing.T) {
Expand All @@ -35,56 +39,56 @@ func TestParser(t *testing.T) {
Selectors: []Selector{
&IDSelector{Name: "key"},
&AliasSelector{Alias: "address", Selector: &IDSelector{Name: "rpc_address"}},
&CountStarSelector{Name: "count(*)"},
&CountFuncSelector{Arg: "*"},
},
}, false},
{"system", "SELECT count(*) FROM local", true, true, &SelectStatement{
Keyspace: "system",
Table: "local",
Selectors: []Selector{
&CountStarSelector{Name: "count(*)"},
&CountFuncSelector{Arg: "*"},
},
}, false},
{"system", "SELECT count(*) FROM \"local\"", true, true, &SelectStatement{
Keyspace: "system",
Table: "local",
Selectors: []Selector{
&CountStarSelector{Name: "count(*)"},
&CountFuncSelector{Arg: "*"},
},
}, false},
{"", "SELECT count(*) FROM system.peers", true, true, &SelectStatement{
Keyspace: "system",
Table: "peers",
Selectors: []Selector{
&CountStarSelector{Name: "count(*)"},
&CountFuncSelector{Arg: "*"},
},
}, false},
{"", "SELECT count(*) FROM \"system\".\"peers\"", true, true, &SelectStatement{
Keyspace: "system",
Table: "peers",
Selectors: []Selector{
&CountStarSelector{Name: "count(*)"},
&CountFuncSelector{Arg: "*"},
},
}, false},
{"system", "SELECT count(*) FROM peers", true, true, &SelectStatement{
Keyspace: "system",
Table: "peers",
Selectors: []Selector{
&CountStarSelector{Name: "count(*)"},
&CountFuncSelector{Arg: "*"},
},
}, false},
{"", "SELECT count(*) FROM system.peers_v2", true, true, &SelectStatement{
Keyspace: "system",
Table: "peers_v2",
Selectors: []Selector{
&CountStarSelector{Name: "count(*)"},
&CountFuncSelector{Arg: "*"},
},
}, false},
{"system", "SELECT count(*) FROM peers_v2", true, true, &SelectStatement{
Keyspace: "system",
Table: "peers_v2",
Selectors: []Selector{
&CountStarSelector{Name: "count(*)"},
&CountFuncSelector{Arg: "*"},
},
}, false},
{"", "SELECT func(key) FROM system.local", true, true, nil, true},
Expand Down Expand Up @@ -176,3 +180,57 @@ func TestParser(t *testing.T) {
assert.Equal(t, tt.stmt, stmt, "invalid parsed statement", tt.query)
}
}

func TestParserSystemNowFunction(t *testing.T) {
var tests = []struct {
query string
table string
}{
{"SELECT now() FROM system.local", "local"},
{"SELECT now() FROM system.peers", "peers"},
}

start, _, err := uuid.GetTime()
require.NoError(t, err)

for _, tt := range tests {
t.Run(tt.query, func(t *testing.T) {
handled, stmt, err := IsQueryHandled(IdentifierFromString(""), tt.query)
assert.NoError(t, err, tt.query)
assert.True(t, handled, tt.query)
require.NotNil(t, stmt)
if selectStmt, ok := stmt.(*SelectStatement); ok {
require.Len(t, selectStmt.Selectors, 1)
selector := selectStmt.Selectors[0]
require.NotNil(t, selector)
if nowSelector, ok := selector.(*NowFuncSelector); ok {
values, err := nowSelector.Values(nil, nil)
assert.NoError(t, err)
assert.Len(t, values, 1)

var u primitive.UUID
wasNull, err := datacodec.Timeuuid.Decode(values[0], &u, primitive.ProtocolVersion4)
assert.False(t, wasNull)
assert.NoError(t, err)

v, err := uuid.FromBytes(u[:])
assert.NoError(t, err)
assert.Equal(t, uuid.RFC4122, v.Variant())
assert.Equal(t, uuid.Version(1), v.Version())
assert.GreaterOrEqual(t, v.Time(), start)

columns, err := nowSelector.Columns(nil, selectStmt)
assert.NoError(t, err)
assert.Len(t, columns, 1)
assert.Equal(t, "system.now()", columns[0].Name)
assert.Equal(t, "system", columns[0].Keyspace)
assert.Equal(t, tt.table, columns[0].Table)
} else {
assert.Fail(t, "expected now function selector")
}
} else {
assert.Fail(t, "expected select statement")
}
})
}
}
Loading