Skip to content

Commit

Permalink
Use Scanning Replace for Parameter Tokenization (#24)
Browse files Browse the repository at this point in the history
* scan replace param place holders

* account for error paths for 100% coverage
  • Loading branch information
benjic authored Oct 18, 2023
1 parent 43651a0 commit e10101f
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 10 deletions.
20 changes: 20 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,23 @@ func TestValuerError(t *testing.T) {
t.Errorf("got: %q, want: %q", err, wantError)
}
}

func Benchmark_ToPgsql_Params(b *testing.B) {
parts := []string{}
args := []any{}
for j := 0; j < 10; j++ {
for j := 0; j < 1000; j++ {
parts = append(parts, "?")
args = append(args, 1)
}
}

query := fmt.Sprintf("(%s)", strings.Join(parts, ","))

for i := 0; i < b.N; i++ {
_, _, err := New(query, args...).ToPgsql()
if err != nil {
b.Fatalf("failed to make benchmark sql: %v", err)
}
}
}
110 changes: 100 additions & 10 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,119 @@
package bqb

import (
"bufio"
"bytes"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
)

func dialectReplace(dialect Dialect, sql string, params []any) (string, error) {
if dialect == MYSQL || dialect == SQL {
sql = strings.ReplaceAll(sql, paramPh, "?")
}
for i, param := range params {
if dialect == RAW {
const (
questionMark = "?"
doubleQuestionMarkDelimiter = "??"
parameterPlaceholder = paramPh
)

switch dialect {
case RAW:
raws := make([]string, len(params))
for i, param := range params {
p, err := paramToRaw(param)
if err != nil {
return "", err
}
sql = strings.Replace(sql, paramPh, p, 1)
} else if dialect == PGSQL {
sql = strings.ReplaceAll(sql, "??", "?")
sql = strings.Replace(sql, paramPh, fmt.Sprintf("$%d", i+1), 1)
raws[i] = p
}

return replaceWithScans(
sql,
scan{pattern: parameterPlaceholder, fn: func(i int) string { return raws[i] }},
)
case MYSQL, SQL:
return replaceWithScans(
sql,
scan{pattern: parameterPlaceholder, fn: func(int) string { return questionMark }},
)
case PGSQL:
return replaceWithScans(
sql,
scan{pattern: doubleQuestionMarkDelimiter, fn: func(int) string { return questionMark }},
scan{pattern: parameterPlaceholder, fn: func(i int) string { return fmt.Sprintf("$%d", i+1) }},
)
default:
// No replacement defined for dialect
return sql, nil
}
}

type replaceFn func(int) string

type scan struct {
pattern string
fn replaceFn
}

// replaceWithScans applies the given set of scanning arguments and joins their
// errors together.
func replaceWithScans(in string, ss ...scan) (string, error) {
errs := []error{}
for _, s := range ss {
out, err := scanReplace(in, s.pattern, s.fn)
errs = append(errs, err)
in = out
}
return in, errors.Join(errs...)
}

func scanReplace(stmt string, replace string, fn replaceFn) (string, error) {
// Build a scanner that will iterate based on the replace token
ph := []byte(replace)
scanner := bufio.NewScanner(bytes.NewBufferString(stmt))
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
// Return nothing if at end of file and no data passed
if atEOF && len(data) == 0 {
return 0, nil, nil
}

switch i := bytes.Index(data, ph); {
case i == 0:
return len(ph), data[0:len(ph)], nil
case i > 0:
return i, data[0:i], nil

}

// If at end of file with data return the data
if atEOF {
return len(data), data, nil
}

return
})

i := 0

// Scan replacing tokens with the value returned from delegate
sb := strings.Builder{}
for scanner.Scan() {
switch txt := scanner.Text(); txt {
case replace:
// String builder will always return nil for an err so it is thrown
// away.
_, _ = sb.WriteString(fn(i))
i++

default:
// String builder will always return nil for an err so it is thrown
// away.
_, _ = sb.WriteString(txt)
}
}
return sql, nil

return sb.String(), scanner.Err()
}

func convertArg(text string, arg any) (string, []any, []error) {
Expand Down
168 changes: 168 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package bqb

import (
"fmt"
"testing"
)

func Test_scanReplace(t *testing.T) {
const (
testReplace = "{{TEST_TOKEN}}"
)
type args struct {
stmt string
replace string
fn replaceFn
}
type want struct {
str string
err error
}

for _, tt := range []struct {
name string
args args
want want
}{
{
name: "empty statement",
args: args{
stmt: "",
replace: testReplace,
fn: func(i int) string {
return fmt.Sprintf("%v", i)
},
},
want: want{
str: "",
err: nil,
},
},
{
name: "empty token statement",
args: args{
stmt: "this tests an empty pattern token",
replace: "",
fn: func(i int) string {
return fmt.Sprintf("%v", i)
},
},
want: want{
str: "this tests an empty pattern token",
err: nil,
},
},
{
name: "no tokens",
args: args{
stmt: "this tests no tokens",
replace: testReplace,
fn: func(i int) string {
return fmt.Sprintf("%v", i)
},
},
want: want{
str: "this tests no tokens",
err: nil,
},
},
{
name: "one front token",
args: args{
stmt: fmt.Sprintf("%s this tests one token", testReplace),
replace: testReplace,
fn: func(i int) string {
return fmt.Sprintf("%v", i)
},
},
want: want{
str: "0 this tests one token",
err: nil,
},
},
{
name: "one boundary token",
args: args{
stmt: fmt.Sprintf("this tests one%s token", testReplace),
replace: testReplace,
fn: func(i int) string {
return fmt.Sprintf("%v", i)
},
},
want: want{
str: "this tests one0 token",
err: nil,
},
},
{
name: "one end token",
args: args{
stmt: fmt.Sprintf("this tests one token%s", testReplace),
replace: testReplace,
fn: func(i int) string {
return fmt.Sprintf("%v", i)
},
},
want: want{
str: "this tests one token0",
err: nil,
},
},
{
name: "several tokens",
args: args{
stmt: fmt.Sprintf("this tests %s the token %s", testReplace, testReplace),
replace: testReplace,
fn: func(i int) string {
return fmt.Sprintf("%v", i)
},
},
want: want{
str: "this tests 0 the token 1",
err: nil,
},
},
{
name: "several tokens",
args: args{
stmt: fmt.Sprintf("%s%sthis tests %s%s the token %s%s", testReplace, testReplace, testReplace, testReplace, testReplace, testReplace),
replace: testReplace,
fn: func(i int) string {
return fmt.Sprintf("%v", i)
},
},
want: want{
str: "01this tests 23 the token 45",
err: nil,
},
},
} {
t.Run(tt.name, func(t *testing.T) {
str, err := scanReplace(tt.args.stmt, testReplace, tt.args.fn)
if tt.want.str != str {
t.Errorf("unexpected str: want '%s' got '%s'", tt.want.str, str)
}

if tt.want.err != err {
t.Errorf("unexpected err: want '%s' got '%s'", tt.want.err, err)
}

})
}
}

func Test_dialectReplace_unknown_dialect(t *testing.T) {
const (
testSql = "test-sql"
)
params := []any{1, 2, "a", "c"}
sql, err := dialectReplace(Dialect("unknown"), testSql, params)

if sql != "test-sql" {
t.Errorf("unexpected sql statement: want %s got %s", testSql, sql)
}

if err != nil {
t.Error("unknown dialect should not return an error")
}
}

0 comments on commit e10101f

Please sign in to comment.