From e10101f3c51fed379d00e0bd807fae881fc95986 Mon Sep 17 00:00:00 2001
From: Ben Campbell
Date: Tue, 17 Oct 2023 18:13:17 -0600
Subject: [PATCH] Use Scanning Replace for Parameter Tokenization (#24)

* scan replace param place holders

* account for error paths for 100% coverage
---
 query_test.go |  24 ++++++
 utils.go      | 144 +++++++++++++++++++++++++++++++++++---
 utils_test.go | 189 ++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 347 insertions(+), 10 deletions(-)
 create mode 100644 utils_test.go

diff --git a/query_test.go b/query_test.go
index 340a8cf..3d538aa 100644
--- a/query_test.go
+++ b/query_test.go
@@ -669,3 +669,27 @@ func TestValuerError(t *testing.T) {
 		t.Errorf("got: %q, want: %q", err, wantError)
 	}
 }
+
+// Benchmark_ToPgsql_Params measures ToPgsql on a statement made of 10,000 "?"
+// markers, each with its own argument.
+func Benchmark_ToPgsql_Params(b *testing.B) {
+	parts := []string{}
+	args := []any{}
+	for j := 0; j < 10; j++ {
+		for k := 0; k < 1000; k++ {
+			parts = append(parts, "?")
+			args = append(args, 1)
+		}
+	}
+
+	query := fmt.Sprintf("(%s)", strings.Join(parts, ","))
+
+	// Exclude the fixture construction above from the measured time.
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _, err := New(query, args...).ToPgsql()
+		if err != nil {
+			b.Fatalf("failed to make benchmark sql: %v", err)
+		}
+	}
+}
diff --git a/utils.go b/utils.go
index f6af97c..d7c0d29 100644
--- a/utils.go
+++ b/utils.go
@@ -1,29 +1,153 @@
 package bqb
 
 import (
+	"bufio"
+	"bytes"
 	"database/sql/driver"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"strings"
 )
 
+// dialectReplace rewrites every internal parameter placeholder (paramPh) in
+// sql into the form the given dialect expects:
+//
+//   - RAW: each placeholder becomes the paramToRaw rendering of its parameter.
+//   - MYSQL, SQL: each placeholder becomes "?".
+//   - PGSQL: "??" is collapsed to "?" and the n-th placeholder becomes "$n".
+//
+// Any other dialect returns sql unchanged. An error is returned when
+// paramToRaw fails for a parameter or when a replacement pass fails.
 func dialectReplace(dialect Dialect, sql string, params []any) (string, error) {
-	if dialect == MYSQL || dialect == SQL {
-		sql = strings.ReplaceAll(sql, paramPh, "?")
-	}
-	for i, param := range params {
-		if dialect == RAW {
+	const (
+		questionMark                = "?"
+		doubleQuestionMarkDelimiter = "??"
+		parameterPlaceholder        = paramPh
+	)
+
+	switch dialect {
+	case RAW:
+		raws := make([]string, len(params))
+		for i, param := range params {
 			p, err := paramToRaw(param)
 			if err != nil {
 				return "", err
 			}
-			sql = strings.Replace(sql, paramPh, p, 1)
-		} else if dialect == PGSQL {
-			sql = strings.ReplaceAll(sql, "??", "?")
-			sql = strings.Replace(sql, paramPh, fmt.Sprintf("$%d", i+1), 1)
+			raws[i] = p
+		}
+
+		return replaceWithScans(
+			sql,
+			scan{pattern: parameterPlaceholder, fn: func(i int) string {
+				// Leave surplus placeholders untouched instead of
+				// panicking when sql holds more of them than params.
+				if i >= len(raws) {
+					return parameterPlaceholder
+				}
+				return raws[i]
+			}},
+		)
+	case MYSQL, SQL:
+		return replaceWithScans(
+			sql,
+			scan{pattern: parameterPlaceholder, fn: func(int) string { return questionMark }},
+		)
+	case PGSQL:
+		return replaceWithScans(
+			sql,
+			scan{pattern: doubleQuestionMarkDelimiter, fn: func(int) string { return questionMark }},
+			scan{pattern: parameterPlaceholder, fn: func(i int) string { return fmt.Sprintf("$%d", i+1) }},
+		)
+	default:
+		// No replacement defined for dialect
+		return sql, nil
+	}
+}
+
+// replaceFn returns the replacement text for the i-th (zero-based) occurrence
+// of a pattern.
+type replaceFn func(int) string
+
+// scan pairs a literal pattern with the function that produces the
+// replacement for each of its occurrences.
+type scan struct {
+	pattern string
+	fn      replaceFn
+}
+
+// replaceWithScans applies each scan in order, feeding the output of one pass
+// into the next, and returns the final text together with the joined errors
+// of every pass.
+func replaceWithScans(in string, ss ...scan) (string, error) {
+	errs := []error{}
+	for _, s := range ss {
+		out, err := scanReplace(in, s.pattern, s.fn)
+		errs = append(errs, err)
+		in = out
+	}
+	return in, errors.Join(errs...)
+}
+
+// scanReplace returns stmt with every occurrence of replace substituted by
+// fn(i), where i counts the occurrences from zero. An empty replace pattern
+// matches nothing, so stmt is returned unchanged.
+func scanReplace(stmt string, replace string, fn replaceFn) (string, error) {
+	// An empty pattern is found at offset 0 of any input, so the split function
+	// below would emit empty tokens forever without advancing.
+	if replace == "" {
+		return stmt, nil
+	}
+
+	// Build a scanner that will iterate based on the replace token
+	ph := []byte(replace)
+	scanner := bufio.NewScanner(bytes.NewBufferString(stmt))
+	// Give the scanner room for the whole statement; with the default limit
+	// (bufio.MaxScanTokenSize) a long stretch of text without a placeholder
+	// fails with bufio.ErrTooLong.
+	scanner.Buffer(make([]byte, 0, len(stmt)+1), len(stmt)+1)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		// Return nothing if at end of file and no data passed
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+
+		switch i := bytes.Index(data, ph); {
+		case i == 0:
+			return len(ph), data[0:len(ph)], nil
+		case i > 0:
+			return i, data[0:i], nil
+		}
+
+		// If at end of file with data return the data
+		if atEOF {
+			return len(data), data, nil
+		}
+
+		// No placeholder in view yet: ask the scanner for more data.
+		return 0, nil, nil
+	})
+
+	i := 0
+
+	// Scan replacing tokens with the value returned from delegate
+	sb := strings.Builder{}
+	for scanner.Scan() {
+		switch txt := scanner.Text(); txt {
+		case replace:
+			// String builder will always return nil for an err so it is thrown
+			// away.
+			_, _ = sb.WriteString(fn(i))
+			i++
+
+		default:
+			// String builder will always return nil for an err so it is thrown
+			// away.
+			_, _ = sb.WriteString(txt)
 		}
 	}
-	return sql, nil
+
+	return sb.String(), scanner.Err()
 }
 
 func convertArg(text string, arg any) (string, []any, []error) {
diff --git a/utils_test.go b/utils_test.go
new file mode 100644
index 0000000..dae5d21
--- /dev/null
+++ b/utils_test.go
@@ -0,0 +1,189 @@
+package bqb
+
+import (
+	"fmt"
+	"strings"
+	"testing"
+)
+
+// Test_scanReplace checks scanReplace against statements holding zero, one and
+// several tokens, as well as an empty statement and an empty token.
+func Test_scanReplace(t *testing.T) {
+	const (
+		testReplace = "{{TEST_TOKEN}}"
+	)
+	type args struct {
+		stmt    string
+		replace string
+		fn      replaceFn
+	}
+	type want struct {
+		str string
+		err error
+	}
+
+	for _, tt := range []struct {
+		name string
+		args args
+		want want
+	}{
+		{
+			name: "empty statement",
+			args: args{
+				stmt:    "",
+				replace: testReplace,
+				fn: func(i int) string {
+					return fmt.Sprintf("%v", i)
+				},
+			},
+			want: want{
+				str: "",
+				err: nil,
+			},
+		},
+		{
+			name: "empty token statement",
+			args: args{
+				stmt:    "this tests an empty pattern token",
+				replace: "",
+				fn: func(i int) string {
+					return fmt.Sprintf("%v", i)
+				},
+			},
+			want: want{
+				str: "this tests an empty pattern token",
+				err: nil,
+			},
+		},
+		{
+			name: "no tokens",
+			args: args{
+				stmt:    "this tests no tokens",
+				replace: testReplace,
+				fn: func(i int) string {
+					return fmt.Sprintf("%v", i)
+				},
+			},
+			want: want{
+				str: "this tests no tokens",
+				err: nil,
+			},
+		},
+		{
+			name: "one front token",
+			args: args{
+				stmt:    fmt.Sprintf("%s this tests one token", testReplace),
+				replace: testReplace,
+				fn: func(i int) string {
+					return fmt.Sprintf("%v", i)
+				},
+			},
+			want: want{
+				str: "0 this tests one token",
+				err: nil,
+			},
+		},
+		{
+			name: "one boundary token",
+			args: args{
+				stmt:    fmt.Sprintf("this tests one%s token", testReplace),
+				replace: testReplace,
+				fn: func(i int) string {
+					return fmt.Sprintf("%v", i)
+				},
+			},
+			want: want{
+				str: "this tests one0 token",
+				err: nil,
+			},
+		},
+		{
+			name: "one end token",
+			args: args{
+				stmt:    fmt.Sprintf("this tests one token%s", testReplace),
+				replace: testReplace,
+				fn: func(i int) string {
+					return fmt.Sprintf("%v", i)
+				},
+			},
+			want: want{
+				str: "this tests one token0",
+				err: nil,
+			},
+		},
+		{
+			name: "several tokens",
+			args: args{
+				stmt:    fmt.Sprintf("this tests %s the token %s", testReplace, testReplace),
+				replace: testReplace,
+				fn: func(i int) string {
+					return fmt.Sprintf("%v", i)
+				},
+			},
+			want: want{
+				str: "this tests 0 the token 1",
+				err: nil,
+			},
+		},
+		{
+			name: "several adjacent tokens",
+			args: args{
+				stmt:    fmt.Sprintf("%s%sthis tests %s%s the token %s%s", testReplace, testReplace, testReplace, testReplace, testReplace, testReplace),
+				replace: testReplace,
+				fn: func(i int) string {
+					return fmt.Sprintf("%v", i)
+				},
+			},
+			want: want{
+				str: "01this tests 23 the token 45",
+				err: nil,
+			},
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			str, err := scanReplace(tt.args.stmt, tt.args.replace, tt.args.fn)
+			if tt.want.str != str {
+				t.Errorf("unexpected str: want '%s' got '%s'", tt.want.str, str)
+			}
+
+			if tt.want.err != err {
+				t.Errorf("unexpected err: want '%v' got '%v'", tt.want.err, err)
+			}
+
+		})
+	}
+}
+
+// Test_scanReplace_long_statement checks that a stretch of text longer than
+// the default bufio.Scanner token limit (64 KiB) does not fail the scan.
+func Test_scanReplace_long_statement(t *testing.T) {
+	const testReplace = "{{TEST_TOKEN}}"
+	filler := strings.Repeat("x", 128*1024)
+
+	str, err := scanReplace(filler+testReplace+filler, testReplace, func(int) string { return "?" })
+	if err != nil {
+		t.Fatalf("unexpected err: %v", err)
+	}
+
+	if want := filler + "?" + filler; str != want {
+		t.Errorf("unexpected str: want %d bytes got %d bytes", len(want), len(str))
+	}
+}
+
+// Test_dialectReplace_unknown_dialect checks that an unrecognised dialect
+// leaves the statement untouched and reports no error.
+func Test_dialectReplace_unknown_dialect(t *testing.T) {
+	const (
+		testSql = "test-sql"
+	)
+	params := []any{1, 2, "a", "c"}
+	sql, err := dialectReplace(Dialect("unknown"), testSql, params)
+
+	if sql != "test-sql" {
+		t.Errorf("unexpected sql statement: want %s got %s", testSql, sql)
+	}
+
+	if err != nil {
+		t.Error("unknown dialect should not return an error")
+	}
+}