Skip to content

Commit

Permalink
Update Keyspace/Table name in prepared Query statement
Browse files Browse the repository at this point in the history
Previously TokenAwarePolicy always used Keyspace explicitly set in
cluster.Keyspace regardless of the keyspace in the Query. Now after
preparing statement Keyspace and Table names are transferred to the
Query and it can make use of that.

Fixes: #1621
  • Loading branch information
sylwiaszunejko committed Jul 14, 2023
1 parent 642e867 commit f420b60
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 29 deletions.
2 changes: 1 addition & 1 deletion cass1batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestShouldPrepareFunction(t *testing.T) {
}

for _, test := range shouldPrepareTests {
q := &Query{stmt: test.Stmt}
q := &Query{stmt: test.Stmt, routingInfo: &queryRoutingInfo{}}
if got := q.shouldPrepare(); got != test.Result {
t.Fatalf("%q: got %v, expected %v\n", test.Stmt, got, test.Result)
}
Expand Down
6 changes: 6 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,12 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter {
params: params,
customPayload: qry.customPayload,
}

// Set "keyspace" and "table" property in the query if it is present in preparedMetadata
qry.routingInfo.mu.Lock()
qry.routingInfo.keyspace = info.request.keyspace
qry.routingInfo.table = info.request.table
qry.routingInfo.mu.Unlock()
} else {
frame = &writeQueryFrame{
statement: qry.stmt,
Expand Down
13 changes: 8 additions & 5 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,10 @@ type preparedMetadata struct {

// proto v4+
pkeyColumns []int

keyspace string

table string
}

func (r preparedMetadata) String() string {
Expand Down Expand Up @@ -952,26 +956,25 @@ func (f *framer) parsePreparedMetadata() preparedMetadata {
return meta
}

var keyspace, table string
globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec
if globalSpec {
keyspace = f.readString()
table = f.readString()
meta.keyspace = f.readString()
meta.table = f.readString()
}

var cols []ColumnInfo
if meta.colCount < 1000 {
// preallocate columninfo to avoid excess copying
cols = make([]ColumnInfo, meta.colCount)
for i := 0; i < meta.colCount; i++ {
f.readCol(&cols[i], &meta.resultMetadata, globalSpec, keyspace, table)
f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table)
}
} else {
// use append, huge number of columns usually indicates a corrupt frame or
// just a huge row.
for i := 0; i < meta.colCount; i++ {
var col ColumnInfo
f.readCol(&col, &meta.resultMetadata, globalSpec, keyspace, table)
f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table)
cols = append(cols, col)
}
}
Expand Down
62 changes: 62 additions & 0 deletions keyspace_table_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package gocql

import (
"context"
"fmt"
"log"
"testing"
)

// Keyspace_table checks if Query.Keyspace() is updated based on prepared statement
func TestKeyspaceTable(t *testing.T) {
cluster := createCluster()

fallback := RoundRobinHostPolicy()
cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback)

session, err := cluster.CreateSession()
if err != nil {
t.Fatal("createSession:", err)
}

cluster.Keyspace = "wrong_keyspace"

keyspace := "test"
table := "table1"

createKeyspace(t, cluster, keyspace)

err = createTable(session, fmt.Sprintf(`CREATE TABLE %s.%s (pk int, ck int, v int, PRIMARY KEY (pk, ck));
`, keyspace, table))

if err != nil {
panic(fmt.Sprintf("unable to create table: %v", err))
}

if err := session.control.awaitSchemaAgreement(); err != nil {
t.Fatal(err)
}

ctx := context.Background()

// insert a row
if err := session.Query(`INSERT INTO test.table1(pk, ck, v) VALUES (?, ?, ?)`,
1, 2, 3).WithContext(ctx).Exec(); err != nil {
log.Fatal(err)
}

var pk int

/* Search for a specific set of records whose 'pk' column matches
* the value of inserted row. */
qry := session.Query(`SELECT pk FROM test.table1 WHERE pk = ? LIMIT 1`,
1).WithContext(ctx).Consistency(One)
if err := qry.Scan(&pk); err != nil {
log.Fatal(err)
}

// cluster.Keyspace was set to "wrong_keyspace", but during prepering statement
// Keyspace in Query should be changed to "test" and Table should be changed to table1
assertEqual(t, "qry.Keyspace()", "test", qry.Keyspace())
assertEqual(t, "qry.Table()", "table1", qry.Table())
}
14 changes: 7 additions & 7 deletions policies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) {
return nil, errors.New("not initalized")
}

query := &Query{}
query := &Query{routingInfo: &queryRoutingInfo{}}
query.getKeyspace = func() string { return keyspace }

iter := policy.Pick(nil)
Expand Down Expand Up @@ -201,7 +201,7 @@ func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) {
}
policy.SetPartitioner("OrderedPartitioner")

query := &Query{}
query := &Query{routingInfo: &queryRoutingInfo{}}
query.getKeyspace = func() string { return "myKeyspace" }
query.RoutingKey([]byte("20"))

Expand Down Expand Up @@ -259,7 +259,7 @@ func TestCOWList_Add(t *testing.T) {

// TestSimpleRetryPolicy makes sure that we only allow 1 + numRetries attempts
func TestSimpleRetryPolicy(t *testing.T) {
q := &Query{}
q := &Query{routingInfo: &queryRoutingInfo{}}

// this should allow a total of 3 tries.
rt := &SimpleRetryPolicy{NumRetries: 2}
Expand Down Expand Up @@ -317,7 +317,7 @@ func TestExponentialBackoffPolicy(t *testing.T) {

func TestDowngradingConsistencyRetryPolicy(t *testing.T) {

q := &Query{cons: LocalQuorum}
q := &Query{cons: LocalQuorum, routingInfo: &queryRoutingInfo{}}

rewt0 := &RequestErrWriteTimeout{
Received: 0,
Expand Down Expand Up @@ -478,7 +478,7 @@ func TestHostPolicy_TokenAware(t *testing.T) {
return nil, errors.New("not initialized")
}

query := &Query{}
query := &Query{routingInfo: &queryRoutingInfo{}}
query.getKeyspace = func() string { return keyspace }

iter := policy.Pick(nil)
Expand Down Expand Up @@ -580,7 +580,7 @@ func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) {
return nil, errors.New("not initialized")
}

query := &Query{}
query := &Query{routingInfo: &queryRoutingInfo{}}
query.getKeyspace = func() string { return keyspace }

iter := policy.Pick(nil)
Expand Down Expand Up @@ -707,7 +707,7 @@ func TestHostPolicy_TokenAware_RackAware(t *testing.T) {
policyWithFallbackInternal.getKeyspaceName = policyInternal.getKeyspaceName
policyWithFallbackInternal.getKeyspaceMetadata = policyInternal.getKeyspaceMetadata

query := &Query{}
query := &Query{routingInfo: &queryRoutingInfo{}}
query.getKeyspace = func() string { return keyspace }

iter := policy.Pick(nil)
Expand Down
1 change: 1 addition & 0 deletions query_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type ExecutableQuery interface {
speculativeExecutionPolicy() SpeculativeExecutionPolicy
GetRoutingKey() ([]byte, error)
Keyspace() string
Table() string
IsIdempotent() bool

withContext(context.Context) ExecutableQuery
Expand Down
71 changes: 57 additions & 14 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ type Session struct {

var queryPool = &sync.Pool{
New: func() interface{} {
return &Query{refCount: 1}
return &Query{routingInfo: &queryRoutingInfo{}, refCount: 1}
},
}

Expand Down Expand Up @@ -630,6 +630,9 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
return nil, nil
}

table := info.request.table
keyspace := info.request.keyspace

if len(info.request.pkeyColumns) > 0 {
// proto v4 dont need to calculate primary key columns
types := make([]TypeInfo, len(info.request.pkeyColumns))
Expand All @@ -638,17 +641,16 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI
}

routingKeyInfo := &routingKeyInfo{
indexes: info.request.pkeyColumns,
types: types,
indexes: info.request.pkeyColumns,
types: types,
keyspace: keyspace,
table: table,
}

inflight.value = routingKeyInfo
return routingKeyInfo, nil
}

// get the table metadata
table := info.request.columns[0].Table

var keyspaceMetadata *KeyspaceMetadata
keyspaceMetadata, inflight.err = s.KeyspaceMetadata(info.request.columns[0].Keyspace)
if inflight.err != nil {
Expand All @@ -672,8 +674,10 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI

size := len(partitionKey)
routingKeyInfo := &routingKeyInfo{
indexes: make([]int, size),
types: make([]TypeInfo, size),
indexes: make([]int, size),
types: make([]TypeInfo, size),
keyspace: keyspace,
table: table,
}

for keyIndex, keyColumn := range partitionKey {
Expand Down Expand Up @@ -909,6 +913,18 @@ type Query struct {
// used by control conn queries to prevent triggering a write to systems
// tables in AWS MCS see
skipPrepare bool

// routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex.
routingInfo *queryRoutingInfo
}

type queryRoutingInfo struct {
// mu protects contents of queryRoutingInfo.
mu sync.RWMutex

keyspace string

table string
}

func (q *Query) defaultsFromSession() {
Expand Down Expand Up @@ -1104,6 +1120,10 @@ func (q *Query) Keyspace() string {
if q.getKeyspace != nil {
return q.getKeyspace()
}
if q.routingInfo.keyspace != "" {
return q.routingInfo.keyspace
}

if q.session == nil {
return ""
}
Expand All @@ -1112,6 +1132,11 @@ func (q *Query) Keyspace() string {
return q.session.cfg.Keyspace
}

// Table returns name of the table the query will be executed against.
func (q *Query) Table() string {
return q.routingInfo.table
}

// GetRoutingKey gets the routing key to use for routing this query. If
// a routing key has not been explicitly set, then the routing key will
// be constructed if possible using the keyspace's schema and the query
Expand All @@ -1134,6 +1159,12 @@ func (q *Query) GetRoutingKey() ([]byte, error) {
return nil, err
}

if routingKeyInfo != nil {
q.routingInfo.mu.Lock()
q.routingInfo.keyspace = routingKeyInfo.keyspace
q.routingInfo.table = routingKeyInfo.table
q.routingInfo.mu.Unlock()
}
return createRoutingKey(routingKeyInfo, q.values)
}

Expand Down Expand Up @@ -1349,7 +1380,7 @@ func (q *Query) Release() {

// reset zeroes out all fields of a query so that it can be safely pooled.
func (q *Query) reset() {
*q = Query{refCount: 1}
*q = Query{routingInfo: &queryRoutingInfo{}, refCount: 1}
}

func (q *Query) incRefCount() {
Expand Down Expand Up @@ -1691,16 +1722,20 @@ type Batch struct {
cancelBatch func()
keyspace string
metrics *queryMetrics

// routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex.
routingInfo *queryRoutingInfo
}

// NewBatch creates a new batch operation without defaults from the cluster
//
// Deprecated: use session.NewBatch instead
func NewBatch(typ BatchType) *Batch {
return &Batch{
Type: typ,
metrics: &queryMetrics{m: make(map[string]*hostMetrics)},
spec: &NonSpeculativeExecution{},
Type: typ,
metrics: &queryMetrics{m: make(map[string]*hostMetrics)},
spec: &NonSpeculativeExecution{},
routingInfo: &queryRoutingInfo{},
}
}

Expand All @@ -1719,6 +1754,7 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
keyspace: s.cfg.Keyspace,
metrics: &queryMetrics{m: make(map[string]*hostMetrics)},
spec: &NonSpeculativeExecution{},
routingInfo: &queryRoutingInfo{},
}

s.mu.RUnlock()
Expand All @@ -1743,6 +1779,11 @@ func (b *Batch) Keyspace() string {
return b.keyspace
}

// Batch has no reasonable eqivalent of Query.Table().
func (b *Batch) Table() string {
return b.routingInfo.table
}

// Attempts returns the number of attempts made to execute the batch.
func (b *Batch) Attempts() int {
return b.metrics.attempts()
Expand Down Expand Up @@ -2014,8 +2055,10 @@ type routingKeyInfoLRU struct {
}

type routingKeyInfo struct {
indexes []int
types []TypeInfo
indexes []int
types []TypeInfo
keyspace string
table string
}

func (r *routingKeyInfo) String() string {
Expand Down
4 changes: 2 additions & 2 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (f funcQueryObserver) ObserveQuery(ctx context.Context, o ObservedQuery) {
}

func TestQueryBasicAPI(t *testing.T) {
qry := &Query{}
qry := &Query{routingInfo: &queryRoutingInfo{}}

// Initiate host
ip := "127.0.0.1"
Expand Down Expand Up @@ -164,7 +164,7 @@ func TestQueryBasicAPI(t *testing.T) {
func TestQueryShouldPrepare(t *testing.T) {
toPrepare := []string{"select * ", "INSERT INTO", "update table", "delete from", "begin batch"}
cantPrepare := []string{"create table", "USE table", "LIST keyspaces", "alter table", "drop table", "grant user", "revoke user"}
q := &Query{}
q := &Query{routingInfo: &queryRoutingInfo{}}

for i := 0; i < len(toPrepare); i++ {
q.stmt = toPrepare[i]
Expand Down

0 comments on commit f420b60

Please sign in to comment.