diff --git a/conn.go b/conn.go index 31da6fc..cea8935 100644 --- a/conn.go +++ b/conn.go @@ -17,7 +17,8 @@ type managedConn struct { killed bool mu sync.RWMutex - execQueryCounter int + execStmtsCounter int // count the number of exec calls in a transaction + queryStmtsCounter int // count the number of query calls in a transaction } // BeginTx calls the underlying BeginTx method unless the supervising context @@ -78,6 +79,7 @@ func (c *managedConn) Exec(query string, args []driver.Value) (driver.Result, er if !ok { return nil, driver.ErrSkip } + c.incExecStmtsCounter() //increment the exec counter to keep track of the number of exec calls return conn.Exec(query, args) } @@ -86,7 +88,7 @@ func (c *managedConn) ExecContext(ctx context.Context, query string, args []driv if !ok { return nil, driver.ErrSkip } - c.incExecQueryCounter() //increment the exec counter to keep track of the number of exec calls + c.incExecStmtsCounter() //increment the exec counter to keep track of the number of exec calls return conn.ExecContext(ctx, query, args) } @@ -103,6 +105,7 @@ func (c *managedConn) Query(query string, args []driver.Value) (driver.Rows, err if !ok { return nil, driver.ErrSkip } + c.incQueryStmtsCounter() //increment the query counter to keep track of the number of query calls return conn.Query(query, args) } @@ -111,6 +114,7 @@ func (c *managedConn) QueryContext(ctx context.Context, query string, args []dri if !ok { return nil, driver.ErrSkip } + c.incQueryStmtsCounter() //increment the query counter to keep track of the number of query calls return conn.QueryContext(ctx, query, args) } @@ -193,14 +197,26 @@ func (c *managedConn) GetKill() bool { return c.killed } -func (c *managedConn) incExecQueryCounter() { +func (c *managedConn) incExecStmtsCounter() { c.mu.Lock() defer c.mu.Unlock() - c.execQueryCounter++ + c.execStmtsCounter++ } -func (c *managedConn) resetExecQueryCounter() { +func (c *managedConn) resetExecStmtsCounter() { c.mu.Lock() defer c.mu.Unlock() - c.execQueryCounter = 0 + c.execStmtsCounter = 0 +} + +func (c *managedConn) incQueryStmtsCounter() { + c.mu.Lock() + defer c.mu.Unlock() + c.queryStmtsCounter++ +} + +func (c *managedConn) resetQueryStmtsCounter() { + c.mu.Lock() + defer c.mu.Unlock() + c.queryStmtsCounter = 0 } diff --git a/conn_test.go b/conn_test.go index d17a278..aeaa0cc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1,9 +1,15 @@ package hotload import ( + "context" + "database/sql/driver" + "io" + "strings" + "sync" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "sync" + "github.com/prometheus/client_golang/prometheus/testutil" ) var _ = Describe("managedConn", func() { @@ -34,3 +40,197 @@ var _ = Describe("managedConn", func() { Consistently(readLockAcquired).Should(BeFalse()) }) }) + +/**** Mocks for Prometheus Metrics ****/ + +type mockDriverConn struct{} + +type mockTx struct{} + +func (mockTx) Commit() error { + return nil +} + +func (mockTx) Rollback() error { + return nil +} + +func (mockDriverConn) Prepare(query string) (driver.Stmt, error) { + return nil, nil +} + +func (mockDriverConn) Begin() (driver.Tx, error) { + return mockTx{}, nil +} + +func (mockDriverConn) Close() error { + return nil +} + +func (mockDriverConn) IsValid() bool { + return true +} + +func (mockDriverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return mockTx{}, nil +} + +func (mockDriverConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return nil, nil +} + +func (mockDriverConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return nil, nil +} + +func (mockDriverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return nil, nil +} + +func (mockDriverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return nil, nil +} + +/**** End Mocks for Prometheus Metrics ****/ + +var _ = Describe("PrometheusMetrics", func() { + const help = ` + # HELP transaction_sql_stmts_total The number of sql stmts called in a transaction by statement type per grpc service and method + # TYPE transaction_sql_stmts_total summary + ` + + var service1Metrics = ` + transaction_sql_stmts_total_sum{grpc_method="service_1",grpc_service="method_1",stmt="exec"} 3 + transaction_sql_stmts_total_count{grpc_method="service_1",grpc_service="method_1",stmt="exec"} 1 + transaction_sql_stmts_total_sum{grpc_method="service_1",grpc_service="method_1",stmt="query"} 3 + transaction_sql_stmts_total_count{grpc_method="service_1",grpc_service="method_1",stmt="query"} 1 + ` + + var service2Metrics = ` + transaction_sql_stmts_total_sum{grpc_method="service_2",grpc_service="method_2",stmt="exec"} 4 + transaction_sql_stmts_total_count{grpc_method="service_2",grpc_service="method_2",stmt="exec"} 1 + transaction_sql_stmts_total_sum{grpc_method="service_2",grpc_service="method_2",stmt="query"} 4 + transaction_sql_stmts_total_count{grpc_method="service_2",grpc_service="method_2",stmt="query"} 1 + ` + + var service1RerunMetrics = ` + transaction_sql_stmts_total_sum{grpc_method="service_1",grpc_service="method_1",stmt="exec"} 4 + transaction_sql_stmts_total_count{grpc_method="service_1",grpc_service="method_1",stmt="exec"} 2 + transaction_sql_stmts_total_sum{grpc_method="service_1",grpc_service="method_1",stmt="query"} 4 + transaction_sql_stmts_total_count{grpc_method="service_1",grpc_service="method_1",stmt="query"} 2 + ` + + var noMethodMetrics = ` + transaction_sql_stmts_total_sum{grpc_method="",grpc_service="",stmt="exec"} 1 + transaction_sql_stmts_total_count{grpc_method="",grpc_service="",stmt="exec"} 1 + transaction_sql_stmts_total_sum{grpc_method="",grpc_service="",stmt="query"} 1 + transaction_sql_stmts_total_count{grpc_method="",grpc_service="",stmt="query"} 1 + ` + + It("Should emit the correct metrics", func() { + mc := newManagedConn(context.Background(), mockDriverConn{}) + + ctx := ContextWithExecLabels(context.Background(), map[string]string{"grpc_method": "service_1", "grpc_service": "method_1"}) + + // begin a transaction + tx, err := mc.BeginTx(ctx, driver.TxOptions{}) + Expect(err).ShouldNot(HaveOccurred()) + + // exec a statement + mc.Exec("INSERT INTO table (column) VALUES (?)", []driver.Value{"value"}) + + // query a statement + mc.Query("SELECT * FROM table WHERE column = ?", []driver.Value{"value"}) + mc.Query("SELECT * FROM table WHERE column = ?", []driver.Value{"value"}) + + // exec a statement with context + mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}}) + mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}}) + + // query a statement with context + mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}}) + + // commit the transaction + err = tx.Commit() + Expect(err).ShouldNot(HaveOccurred()) + + // collect and compare metrics + err = testutil.CollectAndCompare(sqlStmtsSummary, strings.NewReader(help+service1Metrics)) + Expect(err).ShouldNot(HaveOccurred()) + + // reset the metrics + // new context + ctx = ContextWithExecLabels(context.Background(), map[string]string{"grpc_method": "service_2", "grpc_service": "method_2"}) + // begin a transaction + tx, err = mc.BeginTx(ctx, driver.TxOptions{}) + Expect(err).ShouldNot(HaveOccurred()) + + // exec a statement + mc.Exec("INSERT INTO table (column) VALUES (?)", []driver.Value{"value"}) + mc.Exec("INSERT INTO table (column) VALUES (?)", []driver.Value{"value"}) + + // query a statement + mc.Query("SELECT * FROM table WHERE column = ?", []driver.Value{"value"}) + mc.Query("SELECT * FROM table WHERE column = ?", []driver.Value{"value"}) + + // exec a statement with context + mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}}) + mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}}) + + // query a statement with context + mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}}) + mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}}) + + // commit the transaction + err = tx.Commit() + Expect(err).ShouldNot(HaveOccurred()) + + // collect and compare metrics + err = testutil.CollectAndCompare(sqlStmtsSummary, strings.NewReader(help+service1Metrics+service2Metrics)) + Expect(err).ShouldNot(HaveOccurred()) + + // rerun with initial metrics + ctx = ContextWithExecLabels(context.Background(), map[string]string{"grpc_method": "service_1", "grpc_service": "method_1"}) + // begin a transaction + tx, err = mc.BeginTx(ctx, driver.TxOptions{}) + Expect(err).ShouldNot(HaveOccurred()) + + // exec a statement with context + mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}}) + + // query a statement with context + mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}}) + + // rollback the transaction + err = tx.Rollback() + Expect(err).ShouldNot(HaveOccurred()) + + // collect and compare metrics + err = testutil.CollectAndCompare(sqlStmtsSummary, strings.NewReader(help+service1RerunMetrics+service2Metrics)) + Expect(err).ShouldNot(HaveOccurred()) + + // non labeled context + ctx = context.Background() + // begin a transaction + tx, err = mc.BeginTx(ctx, driver.TxOptions{}) + Expect(err).ShouldNot(HaveOccurred()) + + // exec query context + mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}}) + + // query a statement with context + mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}}) + + // commit the transaction + err = tx.Commit() + Expect(err).ShouldNot(HaveOccurred()) + + // collect and compare metrics + err = testutil.CollectAndCompare(sqlStmtsSummary, strings.NewReader(help+noMethodMetrics+service1RerunMetrics+service2Metrics)) + Expect(err).ShouldNot(HaveOccurred()) + }) +}) + +func CollectAndCompareMetrics(r io.Reader) error { + return testutil.CollectAndCompare(sqlStmtsSummary, r) +} diff --git a/go.mod b/go.mod index ccb707e..917c54c 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/nxadm/tail v1.4.8 // indirect diff --git a/integrationtests/hotload_test.go b/integrationtests/hotload_test.go index cd91db1..398a9be 100644 --- a/integrationtests/hotload_test.go +++ b/integrationtests/hotload_test.go @@ -3,15 +3,16 @@ package integrationtests import ( "database/sql" "fmt" + "io/ioutil" + "log" + "time" + "github.com/infobloxopen/hotload" _ "github.com/infobloxopen/hotload/fsnotify" "github.com/lib/pq" _ "github.com/lib/pq" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "io/ioutil" - "log" - "time" ) const ( diff --git a/prometheus.go b/prometheus.go index 30f38bd..bd0be0b 100644 --- a/prometheus.go +++ b/prometheus.go @@ -11,17 +11,20 @@ import ( const ( GRPCMethodKey = "grpc_method" GRPCServiceKey = "grpc_service" + StatementKey = "stmt" // either exec or query + ExecStatement = "exec" + QueryStatement = "query" ) -// execQuerySummary is a prometheus metric to keep track of the number of times -// exec query is called in a transaction -var execQuerySummary = prometheus.NewSummaryVec(prometheus.SummaryOpts{ - Name: "transaction_exec_query_total", - Help: "The number of times exec query is called in a transaction", -}, []string{GRPCServiceKey, GRPCMethodKey}) +// sqlStmtsSummary is a prometheus metric to keep track of the number of times +// a sql statement is called in a transaction by statement type per grpc service +var sqlStmtsSummary = prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "transaction_sql_stmts_total", + Help: "The number of sql stmts called in a transaction by statement type per grpc service and method", +}, []string{GRPCServiceKey, GRPCMethodKey, StatementKey}) func init() { - prometheus.MustRegister(execQuerySummary) + prometheus.MustRegister(sqlStmtsSummary) } // PromUnaryServerInterceptor returns a unary server interceptor that sets the diff --git a/prometheus_test.go b/prometheus_test.go new file mode 100644 index 0000000..9fff8ef --- /dev/null +++ b/prometheus_test.go @@ -0,0 +1,39 @@ +package hotload + +import ( + "context" + "errors" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/prometheus/client_golang/prometheus" + "google.golang.org/grpc" +) + +var _ = Describe("PrometheusMetric", func() { + It("Should register a prometheus metric", func() { + // This test is a placeholder for a real test + err := prometheus.Register(sqlStmtsSummary) + Expect(err).Should(HaveOccurred()) + Expect(errors.As(err, &prometheus.AlreadyRegisteredError{})).Should(BeTrue()) + }) +}) + +var _ = Describe("PromUnaryServerInterceptor", func() { + It("Should return a unary server interceptor", func() { + validationHandler := func(ctx context.Context, req interface{}) (interface{}, error) { + labels := GetExecLabelsFromContext(ctx) + + Expect(labels).ShouldNot(BeNil()) + Expect(labels[GRPCMethodKey]).Should(Equal("List")) + Expect(labels[GRPCServiceKey]).Should(Equal("infoblox.service.SampleService")) + + return nil, nil + } + + promUnaryServerInterceptor := PromUnaryServerInterceptor() + promUnaryServerInterceptor(context.Background(), struct{}{}, &grpc.UnaryServerInfo{ + FullMethod: "/infoblox.service.SampleService/List", + }, validationHandler) + }) +}) diff --git a/transaction.go b/transaction.go index 0ff062c..7d695bd 100644 --- a/transaction.go +++ b/transaction.go @@ -6,7 +6,7 @@ import ( ) // managedTx wraps a sql/driver.Tx so that it can store the context of the -// transaction and clean up the execQueryCounter on Commit or Rollback. +// transaction and clean up the execqueryCallsCounter on Commit or Rollback. type managedTx struct { tx driver.Tx conn *managedConn @@ -25,14 +25,19 @@ func (t *managedTx) Rollback() error { return err } -func observeExecQuerySummary(ctx context.Context, counter int) { +func observeSQLStmtsSummary(ctx context.Context, execStmtsCounter, queryStmtsCounter int) { labels := GetExecLabelsFromContext(ctx) - execQuerySummary.WithLabelValues(labels[GRPCServiceKey], labels[GRPCMethodKey]).Observe(float64(counter)) + service := labels[GRPCServiceKey] + method := labels[GRPCMethodKey] + + sqlStmtsSummary.WithLabelValues(service, method, ExecStatement).Observe(float64(execStmtsCounter)) + sqlStmtsSummary.WithLabelValues(service, method, QueryStatement).Observe(float64(queryStmtsCounter)) } func (t *managedTx) cleanup() { - observeExecQuerySummary(t.ctx, t.conn.execQueryCounter) - t.conn.resetExecQueryCounter() + observeSQLStmtsSummary(t.ctx, t.conn.execStmtsCounter, t.conn.queryStmtsCounter) + t.conn.resetExecStmtsCounter() + t.conn.resetQueryStmtsCounter() } var promLabelKey = struct{}{}