diff --git a/sdk/messaging/azeventhubs/CHANGELOG.md b/sdk/messaging/azeventhubs/CHANGELOG.md index bf684e595a66..dd125f150611 100644 --- a/sdk/messaging/azeventhubs/CHANGELOG.md +++ b/sdk/messaging/azeventhubs/CHANGELOG.md @@ -1,10 +1,12 @@ # Release History -## 0.3.1 (2023-01-10) +## 0.4.0 (2023-01-10) ### Bugs Fixed - User-Agent was incorrectly formatted in our AMQP-based clients. (PR#19712) +- Connection recovery has been improved, removing some unnecessasry retries as well as adding a bound around + some operations (Close) that could potentially block recovery for a long time. (PR#19683) ## 0.3.0 (2022-11-10) diff --git a/sdk/messaging/azeventhubs/checkpoints/doc.go b/sdk/messaging/azeventhubs/checkpoints/doc.go index 347512dc3747..beb299882dab 100644 --- a/sdk/messaging/azeventhubs/checkpoints/doc.go +++ b/sdk/messaging/azeventhubs/checkpoints/doc.go @@ -1,9 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + //go:build go1.16 // +build go1.16 -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - // Package checkpoints provides a CheckpointStore using Azure Blob Storage. // // CheckpointStore's are generally not used on their own and will be created so they @@ -14,4 +14,5 @@ // // [Processor]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs#Processor // [example_processor_test.go]: https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/messaging/azeventhubs/example_processor_test.go + package checkpoints diff --git a/sdk/messaging/azeventhubs/consumer_client_internal_test.go b/sdk/messaging/azeventhubs/consumer_client_internal_test.go index e2e095c98ead..9f77de774b30 100644 --- a/sdk/messaging/azeventhubs/consumer_client_internal_test.go +++ b/sdk/messaging/azeventhubs/consumer_client_internal_test.go @@ -23,7 +23,7 @@ func TestConsumerClient_Recovery(t *testing.T) { testParams := test.GetConnectionParamsForTest(t) // Uncomment to see the entire recovery playbook run. - // test.EnableStdoutLogging() + test.EnableStdoutLogging() dac, err := azidentity.NewDefaultAzureCredential(nil) require.NoError(t, err) @@ -122,7 +122,7 @@ func TestConsumerClient_Recovery(t *testing.T) { defer test.RequireClose(t, consumerClient) - log.Printf("3. closing connection, which will force recovery for each partition client so they can read the next event") + log.Printf("3. closing internal connection (non-permanently), which will force recovery for each partition client so they can read the next event") // now we'll close the internal connection, simulating a connection break require.NoError(t, consumerClient.namespace.Close(context.Background(), false)) @@ -158,7 +158,7 @@ func TestConsumerClient_RecoveryLink(t *testing.T) { testParams := test.GetConnectionParamsForTest(t) // Uncomment to see the entire recovery playbook run. - // test.EnableStdoutLogging() + test.EnableStdoutLogging() dac, err := azidentity.NewDefaultAzureCredential(nil) require.NoError(t, err) diff --git a/sdk/messaging/azeventhubs/consumer_client_test.go b/sdk/messaging/azeventhubs/consumer_client_test.go index fbba131e8e0a..9845677c88bd 100644 --- a/sdk/messaging/azeventhubs/consumer_client_test.go +++ b/sdk/messaging/azeventhubs/consumer_client_test.go @@ -4,6 +4,7 @@ package azeventhubs_test import ( "context" + "strings" "sync" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/test" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/eventhub/armeventhub" "github.com/stretchr/testify/require" ) @@ -427,6 +429,77 @@ func TestConsumerClient_ReceiveEvents(t *testing.T) { } } +func TestConsumerClient_Detaches(t *testing.T) { + testParams := test.GetConnectionParamsForTest(t) + + test.EnableStdoutLogging() + + dac, err := azidentity.NewDefaultAzureCredential(nil) + require.NoError(t, err) + + // create our event hub + producerClient, err := azeventhubs.NewProducerClientFromConnectionString(testParams.ConnectionString, testParams.EventHubName, nil) + require.NoError(t, err) + + defer producerClient.Close(context.Background()) + + enableOrDisableEventHub(t, testParams, dac, true) + t.Logf("Sending events, connection should be fine") + err = sendEvent(t, producerClient) + require.NoError(t, err) + + enableOrDisableEventHub(t, testParams, dac, false) + t.Logf("Sending events, expected to fail since entity is disabled") + err = sendEvent(t, producerClient) + require.Error(t, err, "fails, entity has become disabled") + + enableOrDisableEventHub(t, testParams, dac, true) + t.Logf("Sending events, should reconnect") + err = sendEvent(t, producerClient) + require.NoError(t, err, "reattach happens") +} + +func sendEvent(t *testing.T, producerClient *azeventhubs.ProducerClient) error { + batch, err := producerClient.NewEventDataBatch(context.Background(), nil) + require.NoError(t, err) + + err = batch.AddEventData(&azeventhubs.EventData{ + Body: []byte("hello world"), + }, nil) + require.NoError(t, err) + + return producerClient.SendEventDataBatch(context.Background(), batch, nil) +} + +// enableOrDisableEventHub sets an eventhub to active if active is true, or disables it if active is false. +// +// This is useful when testing attach/detach type scenarios where you want the service to force links +// to detach. +func enableOrDisableEventHub(t *testing.T, testParams test.ConnectionParamsForTest, dac *azidentity.DefaultAzureCredential, active bool) { + client, err := armeventhub.NewEventHubsClient(testParams.SubscriptionID, dac, nil) + require.NoError(t, err) + + ns := strings.Split(testParams.EventHubNamespace, ".")[0] + + resp, err := client.Get(context.Background(), testParams.ResourceGroup, ns, testParams.EventHubName, nil) + require.NoError(t, err) + + if active { + resp.Properties.Status = to.Ptr(armeventhub.EntityStatusActive) + } else { + resp.Properties.Status = to.Ptr(armeventhub.EntityStatusDisabled) + } + + t.Logf("Setting entity status to %s", *resp.Properties.Status) + _, err = client.CreateOrUpdate(context.Background(), testParams.ResourceGroup, ns, testParams.EventHubName, armeventhub.Eventhub{ + Properties: resp.Properties, + }, nil) + require.NoError(t, err) + + // give a little time for the change to take effect + time.Sleep(5 * time.Second) +} + func newPartitionClientForTest(t *testing.T, partitionID string, subscribeOptions azeventhubs.PartitionClientOptions) (*azeventhubs.PartitionClient, func()) { testParams := test.GetConnectionParamsForTest(t) diff --git a/sdk/messaging/azeventhubs/doc.go b/sdk/messaging/azeventhubs/doc.go index ede39de548c9..a9fc2e7e5697 100644 --- a/sdk/messaging/azeventhubs/doc.go +++ b/sdk/messaging/azeventhubs/doc.go @@ -11,4 +11,5 @@ // There are two clients for consuming events: // - [azeventhubs.Processor], which handles checkpointing and load balancing using durable storage. // - [azeventhubs.ConsumerClient], which is fully manual, but provides full control. + package azeventhubs diff --git a/sdk/messaging/azeventhubs/go.mod b/sdk/messaging/azeventhubs/go.mod index bbd65e23783d..a41f7c245990 100644 --- a/sdk/messaging/azeventhubs/go.mod +++ b/sdk/messaging/azeventhubs/go.mod @@ -6,6 +6,8 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0 github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/eventhub/armeventhub v1.0.0 + github.com/golang/mock v1.6.0 github.com/joho/godotenv v1.4.0 github.com/stretchr/testify v1.7.1 ) diff --git a/sdk/messaging/azeventhubs/go.sum b/sdk/messaging/azeventhubs/go.sum index fbaf0f294dd9..7be7720366da 100644 --- a/sdk/messaging/azeventhubs/go.sum +++ b/sdk/messaging/azeventhubs/go.sum @@ -6,6 +6,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0 h1:QkAcEIAKbNL4KoFr4Sath github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0/go.mod h1:bhXu1AjYL+wutSL/kpSq6s7733q2Rb0yuot9Zgfqa/0= github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8= github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/eventhub/armeventhub v1.0.0 h1:BWeAAEzkCnL0ABVJqs+4mYudNch7oFGPtTlSmIWL8ms= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/eventhub/armeventhub v1.0.0/go.mod h1:Y3gnVwfaz8h6L1YHar+NfWORtBoVUSB5h4GlGkdeF7Q= github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 h1:BWe8a+f/t+7KY7zH2mqygeUD0t8hNFXe08p1Pb3/jKE= github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1/go.mod h1:Vt9sXTKwMyGcOxSmLDMnGPgqsUg7m8pe215qMLrDXw4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -18,6 +20,8 @@ github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -45,19 +49,41 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tedsuo/ifrit v0.0.0-20180802180643-bea94bb476cc/go.mod h1:eyZnKCc955uh98WQvzOm0dgAeLnf2O0Rz0LPoC5ze+0= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88 h1:Tgea0cVUD0ivh5ADBX4WwuI12DUd2to3nCYe2eayMIw= golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/sdk/messaging/azeventhubs/internal/amqpInterfaces.go b/sdk/messaging/azeventhubs/internal/amqpInterfaces.go index 8b3ced56c3f5..f6ea7f0cc377 100644 --- a/sdk/messaging/azeventhubs/internal/amqpInterfaces.go +++ b/sdk/messaging/azeventhubs/internal/amqpInterfaces.go @@ -7,7 +7,6 @@ import ( "context" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" - "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp" ) type AMQPReceiver = amqpwrap.AMQPReceiver @@ -15,13 +14,6 @@ type AMQPReceiverCloser = amqpwrap.AMQPReceiverCloser type AMQPSender = amqpwrap.AMQPSender type AMQPSenderCloser = amqpwrap.AMQPSenderCloser -// RPCLink is implemented by *rpc.Link -type RPCLink interface { - Close(ctx context.Context) error - RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, error) - LinkName() string -} - // Closeable is implemented by pretty much any AMQP link/client // including our own higher level Receiver/Sender. type Closeable interface { diff --git a/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go b/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go index 78c9ef37d2a3..c83eea731e0d 100644 --- a/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go +++ b/sdk/messaging/azeventhubs/internal/amqpwrap/amqpwrap.go @@ -64,6 +64,22 @@ type AMQPClient interface { NewSession(ctx context.Context, opts *amqp.SessionOptions) (AMQPSession, error) } +// RPCLink is implemented by *rpc.Link +type RPCLink interface { + Close(ctx context.Context) error + RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, error) + LinkName() string +} + +// RPCResponse is the simplified response structure from an RPC like call +type RPCResponse struct { + // Code is the response code - these originate from Service Bus. Some + // common values are called out below, with the RPCResponseCode* constants. + Code int + Description string + Message *amqp.Message +} + // AMQPClientWrapper is a simple interface, implemented by *AMQPClientWrapper // It exists only so we can return AMQPSession, which itself only exists so we can // return interfaces for AMQPSender and AMQPReceiver from AMQPSession. diff --git a/sdk/messaging/azeventhubs/internal/cbs.go b/sdk/messaging/azeventhubs/internal/cbs.go index b3462071b4f8..ae3aa6c359b6 100644 --- a/sdk/messaging/azeventhubs/internal/cbs.go +++ b/sdk/messaging/azeventhubs/internal/cbs.go @@ -5,6 +5,7 @@ package internal import ( "context" + "errors" azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" @@ -23,7 +24,10 @@ const ( ) // NegotiateClaim attempts to put a token to the $cbs management endpoint to negotiate auth for the given audience -func NegotiateClaim(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error { +// +// contextWithTimeoutFn is intended to be context.WithTimeout in production code, but can be stubbed out when writing +// unit tests to keep timeouts reasonable. +func NegotiateClaim(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider, contextWithTimeoutFn contextWithTimeoutFn) error { link, err := NewRPCLink(ctx, RPCLinkArgs{ Client: conn, Address: cbsAddress, @@ -34,15 +38,27 @@ func NegotiateClaim(ctx context.Context, audience string, conn amqpwrap.AMQPClie return err } - defer func() { + closeLink := func(ctx context.Context, origErr error) error { + ctx, cancel := contextWithTimeoutFn(ctx, defaultCloseTimeout) + defer cancel() + if err := link.Close(ctx); err != nil { + if IsCancelError(err) { + azlog.Writef(exported.EventAuth, "Failed closing claim link because it was cancelled. Connection will need to be reset") + return errConnResetNeeded + } + azlog.Writef(exported.EventAuth, "Failed closing claim link: %s", err.Error()) + return err } - }() + + return origErr + } token, err := provider.GetToken(audience) if err != nil { - return err + azlog.Writef(exported.EventAuth, "Failed to get token from provider") + return closeLink(ctx, err) } azlog.Writef(exported.EventAuth, "negotiating claim for audience %s with token type %s and expiry of %s", audience, token.TokenType, token.Expiry) @@ -58,8 +74,11 @@ func NegotiateClaim(ctx context.Context, audience string, conn amqpwrap.AMQPClie } if _, err := link.RPC(ctx, msg); err != nil { - return err + azlog.Writef(exported.EventAuth, "Failed to send/receive RPC message") + return closeLink(ctx, err) } - return nil + return closeLink(ctx, nil) } + +var errConnResetNeeded = errors.New("connection must be reset, link/connection state may be inconsistent") diff --git a/sdk/messaging/azeventhubs/internal/cbs_test.go b/sdk/messaging/azeventhubs/internal/cbs_test.go new file mode 100644 index 000000000000..cafabdda041c --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/cbs_test.go @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package internal + +import ( + "context" + "fmt" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/mock" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +func TestNegotiateClaimWithCloseTimeout(t *testing.T) { + for _, errToReturn := range []error{context.Canceled, context.DeadlineExceeded} { + t.Run(fmt.Sprintf("Close() cancels with error %v", errToReturn), func(t *testing.T) { + ctrl := gomock.NewController(t) + + tp := mock.NewMockTokenProvider(ctrl) + receiver := mock.NewMockAMQPReceiverCloser(ctrl) + sender := mock.NewMockAMQPSenderCloser(ctrl) + session := mock.NewMockAMQPSession(ctrl) + client := mock.NewMockAMQPClient(ctrl) + + client.EXPECT().NewSession(mock.NotCancelled, gomock.Any()).Return(session, nil) + session.EXPECT().NewReceiver(mock.NotCancelled, gomock.Any(), gomock.Any()).Return(receiver, nil) + session.EXPECT().NewSender(mock.NotCancelled, gomock.Any(), gomock.Any()).Return(sender, nil) + tp.EXPECT().GetToken(gomock.Any()).Return(&auth.Token{}, nil) + + mock.SetupRPC(sender, receiver, 1, func(sent, response *amqp.Message) { + response.ApplicationProperties = map[string]interface{}{ + "status-code": int32(200), + } + }) + + // the context passed to these calls are already cancelled since the parent + // context was cancelled. This basically just falls through the error handling + // but it's okay - each resource should close any local state they can before + // returning and we're going to end up abandoning ship on the connection. + session.EXPECT().Close(mock.CancelledAndHasTimeout) + sender.EXPECT().Close(mock.CancelledAndHasTimeout) + + // When links fail to close in a timely manner it's either because the connection is (somehow) + // no longer valid _or_ conditions are preventing us from closing the link. In either case we + // have to be careful since it means that some resources (for instance, singleton links like $cbs) + // might "leak" since they can't be closed. + // + // Rather than attempt to do some complicated piecemeal recovery, we instead invalidate the entire + // connection, which is the only safe way to ensure the client and service agree on what is open and + // active. + receiver.EXPECT().Close(mock.NotCancelledAndHasTimeout).DoAndReturn(func(ctx context.Context) error { + <-ctx.Done() + return errToReturn + }) + + err := NegotiateClaim(context.Background(), "audience", client, tp, mock.NewContextWithTimeoutForTests) + require.EqualError(t, err, "connection must be reset, link/connection state may be inconsistent") + require.Equal(t, GetRecoveryKind(err), RecoveryKindConn) + }) + } +} + +func TestNegotiateClaimWithAuthFailure(t *testing.T) { + ctrl := gomock.NewController(t) + + tp := mock.NewMockTokenProvider(ctrl) + receiver := mock.NewMockAMQPReceiverCloser(ctrl) + sender := mock.NewMockAMQPSenderCloser(ctrl) + session := mock.NewMockAMQPSession(ctrl) + client := mock.NewMockAMQPClient(ctrl) + + client.EXPECT().NewSession(mock.NotCancelled, gomock.Any()).Return(session, nil) + session.EXPECT().NewReceiver(mock.NotCancelled, gomock.Any(), gomock.Any()).Return(receiver, nil) + session.EXPECT().NewSender(mock.NotCancelled, gomock.Any(), gomock.Any()).Return(sender, nil) + tp.EXPECT().GetToken(gomock.Any()).Return(&auth.Token{}, nil) + + session.EXPECT().Close(mock.NotCancelledAndHasTimeout) + sender.EXPECT().Close(mock.NotCancelledAndHasTimeout) + receiver.EXPECT().Close(mock.NotCancelledAndHasTimeout) + + mock.SetupRPC(sender, receiver, 1, func(sent, response *amqp.Message) { + // this is the kind of error you get if your connection string is inconsistent + // (ie, you tamper with the shared key, etc..) + response.ApplicationProperties = map[string]interface{}{ + "status-code": int32(401), + "status-description": "InvalidSignature: The token has an invalid signature.", + "error-condition": "com.microsoft:auth-failed", + } + }) + + err := NegotiateClaim(context.Background(), "audience", client, tp, mock.NewContextWithTimeoutForTests) + + require.EqualError(t, err, "rpc: failed, status code 401 and description: InvalidSignature: The token has an invalid signature.") + require.Equal(t, GetRecoveryKind(err), RecoveryKindLink) +} + +func TestNegotiateClaimSuccess(t *testing.T) { + ctrl := gomock.NewController(t) + + tp := mock.NewMockTokenProvider(ctrl) + receiver := mock.NewMockAMQPReceiverCloser(ctrl) + sender := mock.NewMockAMQPSenderCloser(ctrl) + session := mock.NewMockAMQPSession(ctrl) + client := mock.NewMockAMQPClient(ctrl) + + client.EXPECT().NewSession(mock.NotCancelled, gomock.Any()).Return(session, nil) + session.EXPECT().NewReceiver(mock.NotCancelled, gomock.Any(), gomock.Any()).Return(receiver, nil) + session.EXPECT().NewSender(mock.NotCancelled, gomock.Any(), gomock.Any()).Return(sender, nil) + tp.EXPECT().GetToken(gomock.Any()).Return(&auth.Token{}, nil) + + session.EXPECT().Close(mock.NotCancelledAndHasTimeout) + sender.EXPECT().Close(mock.NotCancelledAndHasTimeout) + receiver.EXPECT().Close(mock.NotCancelledAndHasTimeout) + + mock.SetupRPC(sender, receiver, 1, func(sent, response *amqp.Message) { + response.ApplicationProperties = map[string]interface{}{ + "status-code": int32(200), + } + }) + + err := NegotiateClaim(context.Background(), "audience", client, tp, mock.NewContextWithTimeoutForTests) + require.NoError(t, err) +} diff --git a/sdk/messaging/azeventhubs/internal/constants.go b/sdk/messaging/azeventhubs/internal/constants.go index cc439f8100e9..a3b8e975381e 100644 --- a/sdk/messaging/azeventhubs/internal/constants.go +++ b/sdk/messaging/azeventhubs/internal/constants.go @@ -4,4 +4,4 @@ package internal // Version is the semantic version number -const Version = "v0.3.1" +const Version = "v0.4.0" diff --git a/sdk/messaging/azeventhubs/internal/eh.go b/sdk/messaging/azeventhubs/internal/eh.go index 272e367e6a04..3d827c406711 100644 --- a/sdk/messaging/azeventhubs/internal/eh.go +++ b/sdk/messaging/azeventhubs/internal/eh.go @@ -5,6 +5,7 @@ package internal import ( "context" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" ) @@ -13,7 +14,7 @@ func (l *rpcLink) LinkName() string { return l.sender.LinkName() } -func (ns *Namespace) NewRPCLink(ctx context.Context, managementPath string) (RPCLink, uint64, error) { +func (ns *Namespace) NewRPCLink(ctx context.Context, managementPath string) (amqpwrap.RPCLink, uint64, error) { client, connID, err := ns.GetAMQPClientImpl(ctx) if err != nil { diff --git a/sdk/messaging/azeventhubs/internal/errors.go b/sdk/messaging/azeventhubs/internal/errors.go index b7d263dac78c..465b16165189 100644 --- a/sdk/messaging/azeventhubs/internal/errors.go +++ b/sdk/messaging/azeventhubs/internal/errors.go @@ -13,6 +13,7 @@ import ( "reflect" "strings" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp" ) @@ -156,6 +157,10 @@ func GetRecoveryKind(err error) RecoveryKind { return RecoveryKindFatal } + if errors.Is(err, errConnResetNeeded) { + return RecoveryKindConn + } + var netErr net.Error // these are errors that can flow from the go-amqp connection to @@ -259,7 +264,7 @@ type ( } // ErrAMQP indicates that the server communicated an AMQP error with a particular - ErrAMQP RPCResponse + ErrAMQP amqpwrap.RPCResponse // ErrNoMessages is returned when an operation returned no messages. It is not indicative that there will not be // more messages in the future. diff --git a/sdk/messaging/azeventhubs/internal/errors_test.go b/sdk/messaging/azeventhubs/internal/errors_test.go index 8903723876c3..ebe448514e65 100644 --- a/sdk/messaging/azeventhubs/internal/errors_test.go +++ b/sdk/messaging/azeventhubs/internal/errors_test.go @@ -4,6 +4,7 @@ package internal import ( + "context" "errors" "testing" @@ -32,3 +33,10 @@ func TestOwnershipLost(t *testing.T) { require.False(t, IsOwnershipLostError(&amqp.ConnectionError{})) require.False(t, IsOwnershipLostError(errors.New("definitely not an ownership lost error"))) } + +func TestGetRecoveryKind(t *testing.T) { + require.Equal(t, GetRecoveryKind(nil), RecoveryKindNone) + require.Equal(t, GetRecoveryKind(errConnResetNeeded), RecoveryKindConn) + require.Equal(t, GetRecoveryKind(&amqp.DetachError{}), RecoveryKindLink) + require.Equal(t, GetRecoveryKind(context.Canceled), RecoveryKindFatal) +} diff --git a/sdk/messaging/azeventhubs/internal/links.go b/sdk/messaging/azeventhubs/internal/links.go index b7a9765d4d00..a06c011309d0 100644 --- a/sdk/messaging/azeventhubs/internal/links.go +++ b/sdk/messaging/azeventhubs/internal/links.go @@ -5,7 +5,9 @@ package internal import ( "context" + "fmt" "sync" + "time" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" @@ -27,6 +29,17 @@ type LinkWithID[LinkT AMQPLink] struct { // Link will be an amqp.Receiver or amqp.Sender link. Link LinkT + + // PartitionID, if available. + PartitionID string +} + +func (lwid *LinkWithID[LinkT]) String() string { + if lwid == nil { + return "none" + } + + return fmt.Sprintf("c:%d,l:%.5s,p:%s", lwid.ConnID, lwid.Link.LinkName(), lwid.PartitionID) } // LinksForPartitionClient are the functions that the PartitionClient uses within Links[T] @@ -44,11 +57,13 @@ type Links[LinkT AMQPLink] struct { links map[string]*linkState[LinkT] managementLinkMu *sync.RWMutex - managementLink *linkState[RPCLink] + managementLink *linkState[amqpwrap.RPCLink] managementPath string newLinkFn func(ctx context.Context, session amqpwrap.AMQPSession, partitionID string) (LinkT, error) entityPathFn func(partitionID string) string + + contextWithTimeoutFn contextWithTimeoutFn // stubbable version of context.WithTimeout } type NewLinksFn[LinkT AMQPLink] func(ctx context.Context, session amqpwrap.AMQPSession, entityPath string) (LinkT, error) @@ -59,9 +74,11 @@ func NewLinks[LinkT AMQPLink](ns NamespaceForAMQPLinks, managementPath string, e linksMu: &sync.RWMutex{}, links: map[string]*linkState[LinkT]{}, managementLinkMu: &sync.RWMutex{}, - newLinkFn: newLinkFn, - entityPathFn: entityPathFn, managementPath: managementPath, + + newLinkFn: newLinkFn, + entityPathFn: entityPathFn, + contextWithTimeoutFn: context.WithTimeout, } } @@ -76,8 +93,41 @@ func (l *Links[LinkT]) RecoverIfNeeded(ctx context.Context, partitionID string, case RecoveryKindNone: return nil case RecoveryKindLink: - return l.closePartitionLinkIfMatch(ctx, partitionID, lwid.Link.LinkName()) + ctx, cancel := l.contextWithTimeoutFn(ctx, defaultCloseTimeout) + defer cancel() + + err := l.closePartitionLinkIfMatch(ctx, partitionID, lwid.Link.LinkName()) + + if err != nil { + if IsCancelError(err) { + azlog.Writef(exported.EventConn, "(%s) Link close was cancelled, connection will reset on next recovery", lwid.String()) + // if we failed to close a link then something odd is going on with + // our connection or the user has cancelled. Let the next attempt to use + // the connection recover it. + return errConnResetNeeded + } + + // we don't need to propagate this error - it'll just be the link detach error or whatever + // caused the link to detach (for instance, if the Event Hub itself has been Disabled). + azlog.Writef(exported.EventConn, "(%s) Error when cleaning up old link for link recovery: %s", lwid.String(), err) + } + + return nil case RecoveryKindConn: + // We only close _this_ partition's link. Other partitions will also get an error, and will recover. + // We used to close _all_ the links, but no longer do that since it's possible (when we do receiver + // redirect) to have more than one active connection at a time which means not all links would be + // affected when a single connection goes down. + ctx, cancel := l.contextWithTimeoutFn(ctx, defaultCloseTimeout) + defer cancel() + + if err := l.closePartitionLinkIfMatch(ctx, partitionID, lwid.Link.LinkName()); err != nil { + azlog.Writef(exported.EventConn, "(%s) Error when cleaning up old link: %s", lwid.String(), err) + + // NOTE: this is best effort - it's probable the connection is dead anyways so we'll log + // but ignore the error for recovery purposes. + } + // There are two possibilities here: // // 1. (stale) The caller got this error but the `lwid` they're passing us is 'stale' - ie, ' @@ -92,17 +142,18 @@ func (l *Links[LinkT]) RecoverIfNeeded(ctx context.Context, partitionID string, // not match the current link. // // For #2, we may recreate the connection. It's possible we won't if the connection itself - // has already been recovered by another goroutine. After that we'll recycle the link if - // it matches - we don't care about what happened with the connection because the link ID is - // unique - it wouldn't match unless it really was the same one that got the error. - if err := l.ns.Recover(ctx, lwid.ConnID); err != nil { + // has already been recovered by another goroutine. + ctx, cancel = l.contextWithTimeoutFn(ctx, defaultCloseTimeout) + defer cancel() + + err := l.ns.Recover(ctx, lwid.ConnID) + + if err != nil { + azlog.Writef(exported.EventConn, "(%s) Failure recovering connection for link: %s", lwid.String(), err) return err } - // We only close _this_ partition's link. Other partitions will also get an error, and will recover. - // We used to close _all_ the links, but no longer do that since it's possible (when we do receiver - // redirect) to have more than one active connection at a time. - return l.closePartitionLinkIfMatch(ctx, partitionID, lwid.Link.LinkName()) + return nil default: return err } @@ -117,7 +168,11 @@ func (l *Links[LinkT]) Retry(ctx context.Context, eventName log.Event, operation return GetRecoveryKind(err) == RecoveryKindFatal } - return utils.Retry(ctx, eventName, operation, retryOptions, func(ctx context.Context, args *utils.RetryFnArgs) error { + prefix := func() string { + return prevLinkWithID.String() + } + + return utils.Retry(ctx, eventName, prefix, retryOptions, func(ctx context.Context, args *utils.RetryFnArgs) error { if err := l.RecoverIfNeeded(ctx, partitionID, prevLinkWithID, args.LastErr); err != nil { return err } @@ -150,7 +205,7 @@ func (l *Links[LinkT]) Retry(ctx context.Context, eventName log.Event, operation // Whereas normally you'd do (for non-detach errors): // 0th attempt // (actual retries) - azlog.Writef(exported.EventConn, "(%s) Link was previously detached. Attempting quick reconnect to recover from error: %s", operation, err.Error()) + azlog.Writef(exported.EventConn, "(%s, %s) Link was previously detached. Attempting quick reconnect to recover from error: %s", linkWithID.String(), operation, err.Error()) didQuickRetry = true args.ResetAttempts() } @@ -197,8 +252,9 @@ func (l *Links[LinkT]) GetLink(ctx context.Context, partitionID string) (*LinkWi if current != nil { return &LinkWithID[LinkT]{ - ConnID: l.links[partitionID].ConnID, - Link: *l.links[partitionID].Link, + ConnID: l.links[partitionID].ConnID, + Link: *l.links[partitionID].Link, + PartitionID: partitionID, }, nil } @@ -220,14 +276,15 @@ func (l *Links[LinkT]) GetLink(ctx context.Context, partitionID string) (*LinkWi } return &LinkWithID[LinkT]{ - ConnID: l.links[partitionID].ConnID, - Link: *l.links[partitionID].Link, + ConnID: l.links[partitionID].ConnID, + Link: *l.links[partitionID].Link, + PartitionID: partitionID, }, nil } -func (l *Links[LinkT]) GetManagementLink(ctx context.Context) (LinkWithID[RPCLink], error) { +func (l *Links[LinkT]) GetManagementLink(ctx context.Context) (LinkWithID[amqpwrap.RPCLink], error) { if err := l.checkOpen(); err != nil { - return LinkWithID[RPCLink]{}, err + return LinkWithID[amqpwrap.RPCLink]{}, err } l.managementLinkMu.Lock() @@ -237,25 +294,30 @@ func (l *Links[LinkT]) GetManagementLink(ctx context.Context) (LinkWithID[RPCLin ls, err := l.newManagementLinkState(ctx) if err != nil { - return LinkWithID[RPCLink]{}, err + return LinkWithID[amqpwrap.RPCLink]{}, err } l.managementLink = ls } - return LinkWithID[RPCLink]{ + return LinkWithID[amqpwrap.RPCLink]{ ConnID: l.managementLink.ConnID, Link: *l.managementLink.Link, }, nil } func (l *Links[LinkT]) newLinkState(ctx context.Context, partitionID string) (*linkState[LinkT], error) { + azlog.Writef(exported.EventConn, "Creating link for partition ID '%s'", partitionID) + // check again now that we have the write lock - ls := &linkState[LinkT]{} + ls := &linkState[LinkT]{ + PartitionID: partitionID, + } cancelAuth, _, err := l.ns.NegotiateClaim(ctx, l.entityPathFn(partitionID)) if err != nil { + azlog.Writef(exported.EventConn, "(%s): Failed to negotiate claim for partition ID '%s': %s", ls.String(), partitionID, err) return nil, err } @@ -264,6 +326,7 @@ func (l *Links[LinkT]) newLinkState(ctx context.Context, partitionID string) (*l session, connID, err := l.ns.NewAMQPSession(ctx) if err != nil { + azlog.Writef(exported.EventConn, "(%s): Failed to create AMQP session for partition ID '%s': %s", ls.String(), partitionID, err) _ = ls.Close(ctx) return nil, err } @@ -274,16 +337,18 @@ func (l *Links[LinkT]) newLinkState(ctx context.Context, partitionID string) (*l tmpLink, err := l.newLinkFn(ctx, session, l.entityPathFn(partitionID)) if err != nil { + azlog.Writef(exported.EventConn, "(%s): Failed to create link for partition ID '%s': %s", ls.String(), partitionID, err) _ = ls.Close(ctx) return nil, err } ls.Link = &tmpLink + azlog.Writef(exported.EventConn, "(%s): Succesfully created link for partition ID '%s'", ls.String(), partitionID) return ls, nil } -func (l *Links[LinkT]) newManagementLinkState(ctx context.Context) (*linkState[RPCLink], error) { - ls := &linkState[RPCLink]{} +func (l *Links[LinkT]) newManagementLinkState(ctx context.Context) (*linkState[amqpwrap.RPCLink], error) { + ls := &linkState[amqpwrap.RPCLink]{} cancelAuth, _, err := l.ns.NegotiateClaim(ctx, l.managementPath) @@ -310,13 +375,15 @@ func (l *Links[LinkT]) Close(ctx context.Context) error { return l.closeLinks(ctx, true) } -func (l *Links[LinkT]) closeLinks(_ context.Context, permanent bool) error { - // we're finding, in practice, that allowing cancellations when cleaning up state - // just results in inconsistencies. We'll cut cancellation off here for now. - ctx := context.Background() +func (l *Links[LinkT]) closeLinks(ctx context.Context, permanent bool) error { + cancelled := false if err := l.closeManagementLink(ctx); err != nil { azlog.Writef(exported.EventConn, "Error while cleaning up management link while doing connection recovery: %s", err.Error()) + + if IsCancelError(err) { + cancelled = true + } } l.linksMu.Lock() @@ -328,12 +395,23 @@ func (l *Links[LinkT]) closeLinks(_ context.Context, permanent bool) error { for partitionID, link := range tmpLinks { if err := link.Close(ctx); err != nil { azlog.Writef(exported.EventConn, "Error while cleaning up link for partition ID '%s' while doing connection recovery: %s", partitionID, err.Error()) + + if IsCancelError(err) { + cancelled = true + } } } if !permanent { l.links = map[string]*linkState[LinkT]{} } + + if cancelled { + // this is the only kind of error I'd consider usable from Close() - it'll indicate + // that some of the links haven't been cleanly closed. + return ctx.Err() + } + return nil } @@ -351,6 +429,9 @@ func (l *Links[LinkT]) checkOpen() error { // closePartitionLinkIfMatch will close the link in the cache if it matches the passed in linkName. // This is similar to how an etag works - we'll only close it if you are working with the latest link - // if not, it's a no-op since somebody else has already 'saved' (recovered) before you. +// +// Note that the only error that can be returned here will come from go-amqp. Cleanup of _our_ internal state +// will always happen, if needed. func (l *Links[LinkT]) closePartitionLinkIfMatch(ctx context.Context, partitionID string, linkName string) error { l.linksMu.RLock() current, exists := l.links[partitionID] @@ -371,9 +452,7 @@ func (l *Links[LinkT]) closePartitionLinkIfMatch(ctx context.Context, partitionI return nil } - current.cancelAuth() delete(l.links, partitionID) - return current.Close(ctx) } @@ -399,6 +478,9 @@ type linkState[LinkT AMQPLink] struct { // Link will be an amqp.Receiver, an amqp.Sender link, or an RPCLink. Link *LinkT + // PartitionID, if available. + PartitionID string + // cancelAuth cancels the backround claim negotation for this link. cancelAuth func() @@ -407,6 +489,28 @@ type linkState[LinkT AMQPLink] struct { session amqpwrap.AMQPSession } +// String returns a string that can be used for logging, of the format: +// (c:,l:<5 characters of link id>) +// +// It can also handle nil and partial initialization. +func (ls *linkState[LinkT]) String() string { + if ls == nil { + return "none" + } + + linkName := "" + + if ls.Link != nil { + linkName = (*ls.Link).LinkName() + } + + return fmt.Sprintf("c:%d,l:%.5s,p:%s", ls.ConnID, linkName, ls.PartitionID) +} + +// Close cancels the background authentication loop for this link and +// then closes the AMQP links. +// NOTE: this avoids any issues where closing fails on the broker-side or +// locally and we leak a goroutine. func (ls *linkState[LinkT]) Close(ctx context.Context) error { if ls.cancelAuth != nil { ls.cancelAuth() @@ -418,3 +522,9 @@ func (ls *linkState[LinkT]) Close(ctx context.Context) error { return nil } + +const defaultCloseTimeout = time.Minute + +// contextWithTimeoutFn matches the signature for `context.WithTimeout` and is used when we want to +// stub things out for tests. +type contextWithTimeoutFn func(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) diff --git a/sdk/messaging/azeventhubs/internal/links_unit_test.go b/sdk/messaging/azeventhubs/internal/links_unit_test.go index 36a842274cd8..cce5e7b5ea14 100644 --- a/sdk/messaging/azeventhubs/internal/links_unit_test.go +++ b/sdk/messaging/azeventhubs/internal/links_unit_test.go @@ -10,6 +10,8 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/mock" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" ) @@ -120,56 +122,130 @@ func TestLinks_LinkRecoveryOnly(t *testing.T) { } func TestLinks_ConnectionRecovery(t *testing.T) { - recoverClientCalled := 0 - - fakeNS := &FakeNSForPartClient{ - RecoverFn: func(ctx context.Context, clientRevision uint64) error { - // we'll just always recover for our test. - recoverClientCalled++ - return nil - }, - } + ctrl := gomock.NewController(t) + ns := mock.NewMockNamespaceForAMQPLinks(ctrl) + receiver := mock.NewMockAMQPReceiverCloser(ctrl) + session := mock.NewMockAMQPSession(ctrl) - var nextID int - var receivers []*FakeAMQPReceiver + negotiateClaimCtx, cancelNegotiateClaim := context.WithCancel(context.Background()) - links := NewLinks(fakeNS, "managementPath", func(partitionID string) string { + ns.EXPECT().NegotiateClaim(mock.NotCancelled, gomock.Any()).Return(cancelNegotiateClaim, negotiateClaimCtx.Done(), nil) + ns.EXPECT().NewAMQPSession(mock.NotCancelled).Return(session, uint64(1), nil) + + receiver.EXPECT().LinkName().Return("link1").AnyTimes() + + links := NewLinks(ns, "managementPath", func(partitionID string) string { return fmt.Sprintf("part:%s", partitionID) - }, - func(ctx context.Context, session amqpwrap.AMQPSession, entityPath string) (*FakeAMQPReceiver, error) { - nextID++ - receivers = append(receivers, &FakeAMQPReceiver{ - NameForLink: fmt.Sprintf("Link%d", nextID), - }) - return receivers[len(receivers)-1], nil - }) + }, func(ctx context.Context, session amqpwrap.AMQPSession, entityPath string) (amqpwrap.AMQPReceiverCloser, error) { + return receiver, nil + }) + + require.NotNil(t, links.contextWithTimeoutFn, "sanity check, we are setting the context.WithTimeout func") + links.contextWithTimeoutFn = mock.NewContextWithTimeoutForTests lwid, err := links.GetLink(context.Background(), "0") require.NoError(t, err) - require.NotNil(t, lwid) - require.NotNil(t, links.links["0"], "cache contains the newly created link for partition 0") + require.NotNil(t, links.links["0"]) + require.Equal(t, 1, len(links.links)) - require.Equal(t, recoverClientCalled, 0) + // if the connection has closed in response to an error then it'll propagate it's error to + // the children, including receivers. Which means closing the receiver here will _also_ return + // a connection error. + receiver.EXPECT().Close(mock.NotCancelledAndHasTimeout).Return(&amqp.ConnectionError{}) + ns.EXPECT().Recover(mock.NotCancelledAndHasTimeout, gomock.Any()).Return(nil) + + // initiate a connection level recovery err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.ConnectionError{}) require.NoError(t, err) - require.Nil(t, links.links["0"], "cache will no longer a link for partition 0") - require.Equal(t, recoverClientCalled, 1, "client was recovered") + // we still cleanup what we can (including cancelling our background negotiate claim loop) + require.ErrorIs(t, context.Canceled, negotiateClaimCtx.Err()) + require.Empty(t, links.links, "link is removed") +} - // no new links are create - we'll need to do something that requires a link - // to cause it to come back. - require.Equal(t, 1, len(receivers)) - require.Equal(t, 1, receivers[0].CloseCalled) +func TestLinks_closeWithTimeout(t *testing.T) { + for _, errToReturn := range []error{context.DeadlineExceeded, context.Canceled} { + t.Run(fmt.Sprintf("Close() cancels with error %v", errToReturn), func(t *testing.T) { + ctrl := gomock.NewController(t) + ns := mock.NewMockNamespaceForAMQPLinks(ctrl) + receiver := mock.NewMockAMQPReceiverCloser(ctrl) + session := mock.NewMockAMQPSession(ctrl) - // cause a new link to get created to replace the old one. - receivers = nil + negotiateClaimCtx, cancelNegotiateClaim := context.WithCancel(context.Background()) - newLWID, err := links.GetLink(context.Background(), "0") + ns.EXPECT().NegotiateClaim(mock.NotCancelled, gomock.Any()).Return(cancelNegotiateClaim, negotiateClaimCtx.Done(), nil) + ns.EXPECT().NewAMQPSession(mock.NotCancelled).Return(session, uint64(1), nil) + + receiver.EXPECT().LinkName().Return("link1").AnyTimes() + + links := NewLinks(ns, "managementPath", func(partitionID string) string { + return fmt.Sprintf("part:%s", partitionID) + }, func(ctx context.Context, session amqpwrap.AMQPSession, entityPath string) (amqpwrap.AMQPReceiverCloser, error) { + return receiver, nil + }) + + require.NotNil(t, links.contextWithTimeoutFn, "sanity check, we are setting the context.WithTimeout func") + links.contextWithTimeoutFn = mock.NewContextWithTimeoutForTests + + lwid, err := links.GetLink(context.Background(), "0") + require.NoError(t, err) + + // now set ourselves up so Close() is "slow" and we end up timing out, or + // the user "cancels" + receiver.EXPECT().Close(mock.NotCancelledAndHasTimeout).DoAndReturn(func(ctx context.Context) error { + <-ctx.Done() + return errToReturn + }) + + // purposefully recover with what should be a link level recovery. However, the Close() failing + // means we end up "upgrading" to a connection reset instead. + err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.DetachError{}) + require.ErrorIs(t, err, errConnResetNeeded) + + // the error that comes back when the link times out being closed can only + // be fixed by a connection reset. + require.Equal(t, RecoveryKindConn, GetRecoveryKind(errConnResetNeeded)) + + // we still cleanup what we can (including cancelling our background negotiate claim loop) + require.ErrorIs(t, context.Canceled, negotiateClaimCtx.Err()) + }) + } +} + +func TestLinks_linkRecoveryOnly(t *testing.T) { + ctrl := gomock.NewController(t) + fakeNS := mock.NewMockNamespaceForAMQPLinks(ctrl) + fakeReceiver := mock.NewMockAMQPReceiverCloser(ctrl) + session := mock.NewMockAMQPSession(ctrl) + + negotiateClaimCtx, cancelNegotiateClaim := context.WithCancel(context.Background()) + + fakeNS.EXPECT().NegotiateClaim(mock.NotCancelled, gomock.Any()).Return( + cancelNegotiateClaim, negotiateClaimCtx.Done(), nil, + ) + fakeNS.EXPECT().NewAMQPSession(mock.NotCancelled).Return(session, uint64(1), nil) + + fakeReceiver.EXPECT().LinkName().Return("link1").AnyTimes() + + // super important that when we close we're given a context that properly times out. + // (in this test the Close(ctx) call doesn't time out) + fakeReceiver.EXPECT().Close(mock.NotCancelledAndHasTimeout).Return(nil) + + links := NewLinks(fakeNS, "managementPath", func(partitionID string) string { + return fmt.Sprintf("part:%s", partitionID) + }, func(ctx context.Context, session amqpwrap.AMQPSession, entityPath string) (amqpwrap.AMQPReceiverCloser, error) { + return fakeReceiver, nil + }) + + links.contextWithTimeoutFn = mock.NewContextWithTimeoutForTests + + lwid, err := links.GetLink(context.Background(), "0") require.NoError(t, err) - require.NotEqual(t, lwid, newLWID, "new link gets a new ID") - require.NotNil(t, links.links["0"], "cache contains the newly created link for partition 0") - require.Equal(t, 1, len(receivers)) - require.Equal(t, 0, receivers[0].CloseCalled) + err = links.RecoverIfNeeded(context.Background(), "0", lwid, &amqp.DetachError{}) + require.NoError(t, err) + + // we still cleanup what we can (including cancelling our background negotiate claim loop) + require.ErrorIs(t, context.Canceled, negotiateClaimCtx.Err()) } diff --git a/sdk/messaging/azeventhubs/internal/mock/mock_amqp.go b/sdk/messaging/azeventhubs/internal/mock/mock_amqp.go new file mode 100644 index 000000000000..f71582e3a32d --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/mock/mock_amqp.go @@ -0,0 +1,689 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: ../amqpwrap/amqpwrap.go + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + amqpwrap "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + amqp "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp" + gomock "github.com/golang/mock/gomock" +) + +// MockAMQPReceiver is a mock of AMQPReceiver interface. +type MockAMQPReceiver struct { + ctrl *gomock.Controller + recorder *MockAMQPReceiverMockRecorder +} + +// MockAMQPReceiverMockRecorder is the mock recorder for MockAMQPReceiver. +type MockAMQPReceiverMockRecorder struct { + mock *MockAMQPReceiver +} + +// NewMockAMQPReceiver creates a new mock instance. +func NewMockAMQPReceiver(ctrl *gomock.Controller) *MockAMQPReceiver { + mock := &MockAMQPReceiver{ctrl: ctrl} + mock.recorder = &MockAMQPReceiverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAMQPReceiver) EXPECT() *MockAMQPReceiverMockRecorder { + return m.recorder +} + +// AcceptMessage mocks base method. +func (m *MockAMQPReceiver) AcceptMessage(ctx context.Context, msg *amqp.Message) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptMessage", ctx, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// AcceptMessage indicates an expected call of AcceptMessage. +func (mr *MockAMQPReceiverMockRecorder) AcceptMessage(ctx, msg interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptMessage", reflect.TypeOf((*MockAMQPReceiver)(nil).AcceptMessage), ctx, msg) +} + +// Credits mocks base method. +func (m *MockAMQPReceiver) Credits() uint32 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Credits") + ret0, _ := ret[0].(uint32) + return ret0 +} + +// Credits indicates an expected call of Credits. +func (mr *MockAMQPReceiverMockRecorder) Credits() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Credits", reflect.TypeOf((*MockAMQPReceiver)(nil).Credits)) +} + +// IssueCredit mocks base method. +func (m *MockAMQPReceiver) IssueCredit(credit uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IssueCredit", credit) + ret0, _ := ret[0].(error) + return ret0 +} + +// IssueCredit indicates an expected call of IssueCredit. +func (mr *MockAMQPReceiverMockRecorder) IssueCredit(credit interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IssueCredit", reflect.TypeOf((*MockAMQPReceiver)(nil).IssueCredit), credit) +} + +// LinkName mocks base method. +func (m *MockAMQPReceiver) LinkName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkName") + ret0, _ := ret[0].(string) + return ret0 +} + +// LinkName indicates an expected call of LinkName. +func (mr *MockAMQPReceiverMockRecorder) LinkName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkName", reflect.TypeOf((*MockAMQPReceiver)(nil).LinkName)) +} + +// LinkSourceFilterValue mocks base method. +func (m *MockAMQPReceiver) LinkSourceFilterValue(name string) interface{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSourceFilterValue", name) + ret0, _ := ret[0].(interface{}) + return ret0 +} + +// LinkSourceFilterValue indicates an expected call of LinkSourceFilterValue. +func (mr *MockAMQPReceiverMockRecorder) LinkSourceFilterValue(name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSourceFilterValue", reflect.TypeOf((*MockAMQPReceiver)(nil).LinkSourceFilterValue), name) +} + +// ModifyMessage mocks base method. +func (m *MockAMQPReceiver) ModifyMessage(ctx context.Context, msg *amqp.Message, options *amqp.ModifyMessageOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ModifyMessage", ctx, msg, options) + ret0, _ := ret[0].(error) + return ret0 +} + +// ModifyMessage indicates an expected call of ModifyMessage. +func (mr *MockAMQPReceiverMockRecorder) ModifyMessage(ctx, msg, options interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ModifyMessage", reflect.TypeOf((*MockAMQPReceiver)(nil).ModifyMessage), ctx, msg, options) +} + +// Prefetched mocks base method. +func (m *MockAMQPReceiver) Prefetched() *amqp.Message { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prefetched") + ret0, _ := ret[0].(*amqp.Message) + return ret0 +} + +// Prefetched indicates an expected call of Prefetched. +func (mr *MockAMQPReceiverMockRecorder) Prefetched() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prefetched", reflect.TypeOf((*MockAMQPReceiver)(nil).Prefetched)) +} + +// Receive mocks base method. +func (m *MockAMQPReceiver) Receive(ctx context.Context) (*amqp.Message, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Receive", ctx) + ret0, _ := ret[0].(*amqp.Message) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Receive indicates an expected call of Receive. +func (mr *MockAMQPReceiverMockRecorder) Receive(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockAMQPReceiver)(nil).Receive), ctx) +} + +// RejectMessage mocks base method. +func (m *MockAMQPReceiver) RejectMessage(ctx context.Context, msg *amqp.Message, e *amqp.Error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RejectMessage", ctx, msg, e) + ret0, _ := ret[0].(error) + return ret0 +} + +// RejectMessage indicates an expected call of RejectMessage. +func (mr *MockAMQPReceiverMockRecorder) RejectMessage(ctx, msg, e interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RejectMessage", reflect.TypeOf((*MockAMQPReceiver)(nil).RejectMessage), ctx, msg, e) +} + +// ReleaseMessage mocks base method. +func (m *MockAMQPReceiver) ReleaseMessage(ctx context.Context, msg *amqp.Message) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReleaseMessage", ctx, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReleaseMessage indicates an expected call of ReleaseMessage. +func (mr *MockAMQPReceiverMockRecorder) ReleaseMessage(ctx, msg interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseMessage", reflect.TypeOf((*MockAMQPReceiver)(nil).ReleaseMessage), ctx, msg) +} + +// MockAMQPReceiverCloser is a mock of AMQPReceiverCloser interface. +type MockAMQPReceiverCloser struct { + ctrl *gomock.Controller + recorder *MockAMQPReceiverCloserMockRecorder +} + +// MockAMQPReceiverCloserMockRecorder is the mock recorder for MockAMQPReceiverCloser. +type MockAMQPReceiverCloserMockRecorder struct { + mock *MockAMQPReceiverCloser +} + +// NewMockAMQPReceiverCloser creates a new mock instance. +func NewMockAMQPReceiverCloser(ctrl *gomock.Controller) *MockAMQPReceiverCloser { + mock := &MockAMQPReceiverCloser{ctrl: ctrl} + mock.recorder = &MockAMQPReceiverCloserMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAMQPReceiverCloser) EXPECT() *MockAMQPReceiverCloserMockRecorder { + return m.recorder +} + +// AcceptMessage mocks base method. +func (m *MockAMQPReceiverCloser) AcceptMessage(ctx context.Context, msg *amqp.Message) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcceptMessage", ctx, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// AcceptMessage indicates an expected call of AcceptMessage. +func (mr *MockAMQPReceiverCloserMockRecorder) AcceptMessage(ctx, msg interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptMessage", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).AcceptMessage), ctx, msg) +} + +// Close mocks base method. +func (m *MockAMQPReceiverCloser) Close(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockAMQPReceiverCloserMockRecorder) Close(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).Close), ctx) +} + +// Credits mocks base method. +func (m *MockAMQPReceiverCloser) Credits() uint32 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Credits") + ret0, _ := ret[0].(uint32) + return ret0 +} + +// Credits indicates an expected call of Credits. +func (mr *MockAMQPReceiverCloserMockRecorder) Credits() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Credits", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).Credits)) +} + +// IssueCredit mocks base method. +func (m *MockAMQPReceiverCloser) IssueCredit(credit uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IssueCredit", credit) + ret0, _ := ret[0].(error) + return ret0 +} + +// IssueCredit indicates an expected call of IssueCredit. +func (mr *MockAMQPReceiverCloserMockRecorder) IssueCredit(credit interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IssueCredit", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).IssueCredit), credit) +} + +// LinkName mocks base method. +func (m *MockAMQPReceiverCloser) LinkName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkName") + ret0, _ := ret[0].(string) + return ret0 +} + +// LinkName indicates an expected call of LinkName. +func (mr *MockAMQPReceiverCloserMockRecorder) LinkName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkName", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).LinkName)) +} + +// LinkSourceFilterValue mocks base method. +func (m *MockAMQPReceiverCloser) LinkSourceFilterValue(name string) interface{} { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSourceFilterValue", name) + ret0, _ := ret[0].(interface{}) + return ret0 +} + +// LinkSourceFilterValue indicates an expected call of LinkSourceFilterValue. +func (mr *MockAMQPReceiverCloserMockRecorder) LinkSourceFilterValue(name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSourceFilterValue", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).LinkSourceFilterValue), name) +} + +// ModifyMessage mocks base method. +func (m *MockAMQPReceiverCloser) ModifyMessage(ctx context.Context, msg *amqp.Message, options *amqp.ModifyMessageOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ModifyMessage", ctx, msg, options) + ret0, _ := ret[0].(error) + return ret0 +} + +// ModifyMessage indicates an expected call of ModifyMessage. +func (mr *MockAMQPReceiverCloserMockRecorder) ModifyMessage(ctx, msg, options interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ModifyMessage", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).ModifyMessage), ctx, msg, options) +} + +// Prefetched mocks base method. +func (m *MockAMQPReceiverCloser) Prefetched() *amqp.Message { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prefetched") + ret0, _ := ret[0].(*amqp.Message) + return ret0 +} + +// Prefetched indicates an expected call of Prefetched. +func (mr *MockAMQPReceiverCloserMockRecorder) Prefetched() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prefetched", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).Prefetched)) +} + +// Receive mocks base method. +func (m *MockAMQPReceiverCloser) Receive(ctx context.Context) (*amqp.Message, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Receive", ctx) + ret0, _ := ret[0].(*amqp.Message) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Receive indicates an expected call of Receive. +func (mr *MockAMQPReceiverCloserMockRecorder) Receive(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Receive", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).Receive), ctx) +} + +// RejectMessage mocks base method. +func (m *MockAMQPReceiverCloser) RejectMessage(ctx context.Context, msg *amqp.Message, e *amqp.Error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RejectMessage", ctx, msg, e) + ret0, _ := ret[0].(error) + return ret0 +} + +// RejectMessage indicates an expected call of RejectMessage. +func (mr *MockAMQPReceiverCloserMockRecorder) RejectMessage(ctx, msg, e interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RejectMessage", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).RejectMessage), ctx, msg, e) +} + +// ReleaseMessage mocks base method. +func (m *MockAMQPReceiverCloser) ReleaseMessage(ctx context.Context, msg *amqp.Message) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReleaseMessage", ctx, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// ReleaseMessage indicates an expected call of ReleaseMessage. +func (mr *MockAMQPReceiverCloserMockRecorder) ReleaseMessage(ctx, msg interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseMessage", reflect.TypeOf((*MockAMQPReceiverCloser)(nil).ReleaseMessage), ctx, msg) +} + +// MockAMQPSender is a mock of AMQPSender interface. +type MockAMQPSender struct { + ctrl *gomock.Controller + recorder *MockAMQPSenderMockRecorder +} + +// MockAMQPSenderMockRecorder is the mock recorder for MockAMQPSender. +type MockAMQPSenderMockRecorder struct { + mock *MockAMQPSender +} + +// NewMockAMQPSender creates a new mock instance. +func NewMockAMQPSender(ctrl *gomock.Controller) *MockAMQPSender { + mock := &MockAMQPSender{ctrl: ctrl} + mock.recorder = &MockAMQPSenderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAMQPSender) EXPECT() *MockAMQPSenderMockRecorder { + return m.recorder +} + +// LinkName mocks base method. +func (m *MockAMQPSender) LinkName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkName") + ret0, _ := ret[0].(string) + return ret0 +} + +// LinkName indicates an expected call of LinkName. +func (mr *MockAMQPSenderMockRecorder) LinkName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkName", reflect.TypeOf((*MockAMQPSender)(nil).LinkName)) +} + +// MaxMessageSize mocks base method. +func (m *MockAMQPSender) MaxMessageSize() uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MaxMessageSize") + ret0, _ := ret[0].(uint64) + return ret0 +} + +// MaxMessageSize indicates an expected call of MaxMessageSize. +func (mr *MockAMQPSenderMockRecorder) MaxMessageSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaxMessageSize", reflect.TypeOf((*MockAMQPSender)(nil).MaxMessageSize)) +} + +// Send mocks base method. +func (m *MockAMQPSender) Send(ctx context.Context, msg *amqp.Message) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", ctx, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send. +func (mr *MockAMQPSenderMockRecorder) Send(ctx, msg interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockAMQPSender)(nil).Send), ctx, msg) +} + +// MockAMQPSenderCloser is a mock of AMQPSenderCloser interface. +type MockAMQPSenderCloser struct { + ctrl *gomock.Controller + recorder *MockAMQPSenderCloserMockRecorder +} + +// MockAMQPSenderCloserMockRecorder is the mock recorder for MockAMQPSenderCloser. +type MockAMQPSenderCloserMockRecorder struct { + mock *MockAMQPSenderCloser +} + +// NewMockAMQPSenderCloser creates a new mock instance. +func NewMockAMQPSenderCloser(ctrl *gomock.Controller) *MockAMQPSenderCloser { + mock := &MockAMQPSenderCloser{ctrl: ctrl} + mock.recorder = &MockAMQPSenderCloserMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAMQPSenderCloser) EXPECT() *MockAMQPSenderCloserMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockAMQPSenderCloser) Close(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockAMQPSenderCloserMockRecorder) Close(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAMQPSenderCloser)(nil).Close), ctx) +} + +// LinkName mocks base method. +func (m *MockAMQPSenderCloser) LinkName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkName") + ret0, _ := ret[0].(string) + return ret0 +} + +// LinkName indicates an expected call of LinkName. +func (mr *MockAMQPSenderCloserMockRecorder) LinkName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkName", reflect.TypeOf((*MockAMQPSenderCloser)(nil).LinkName)) +} + +// MaxMessageSize mocks base method. +func (m *MockAMQPSenderCloser) MaxMessageSize() uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MaxMessageSize") + ret0, _ := ret[0].(uint64) + return ret0 +} + +// MaxMessageSize indicates an expected call of MaxMessageSize. +func (mr *MockAMQPSenderCloserMockRecorder) MaxMessageSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaxMessageSize", reflect.TypeOf((*MockAMQPSenderCloser)(nil).MaxMessageSize)) +} + +// Send mocks base method. +func (m *MockAMQPSenderCloser) Send(ctx context.Context, msg *amqp.Message) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", ctx, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send. +func (mr *MockAMQPSenderCloserMockRecorder) Send(ctx, msg interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockAMQPSenderCloser)(nil).Send), ctx, msg) +} + +// MockAMQPSession is a mock of AMQPSession interface. +type MockAMQPSession struct { + ctrl *gomock.Controller + recorder *MockAMQPSessionMockRecorder +} + +// MockAMQPSessionMockRecorder is the mock recorder for MockAMQPSession. +type MockAMQPSessionMockRecorder struct { + mock *MockAMQPSession +} + +// NewMockAMQPSession creates a new mock instance. +func NewMockAMQPSession(ctrl *gomock.Controller) *MockAMQPSession { + mock := &MockAMQPSession{ctrl: ctrl} + mock.recorder = &MockAMQPSessionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAMQPSession) EXPECT() *MockAMQPSessionMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockAMQPSession) Close(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockAMQPSessionMockRecorder) Close(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAMQPSession)(nil).Close), ctx) +} + +// NewReceiver mocks base method. +func (m *MockAMQPSession) NewReceiver(ctx context.Context, source string, opts *amqp.ReceiverOptions) (amqpwrap.AMQPReceiverCloser, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewReceiver", ctx, source, opts) + ret0, _ := ret[0].(amqpwrap.AMQPReceiverCloser) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewReceiver indicates an expected call of NewReceiver. +func (mr *MockAMQPSessionMockRecorder) NewReceiver(ctx, source, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewReceiver", reflect.TypeOf((*MockAMQPSession)(nil).NewReceiver), ctx, source, opts) +} + +// NewSender mocks base method. +func (m *MockAMQPSession) NewSender(ctx context.Context, target string, opts *amqp.SenderOptions) (amqpwrap.AMQPSenderCloser, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSender", ctx, target, opts) + ret0, _ := ret[0].(amqpwrap.AMQPSenderCloser) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewSender indicates an expected call of NewSender. +func (mr *MockAMQPSessionMockRecorder) NewSender(ctx, target, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSender", reflect.TypeOf((*MockAMQPSession)(nil).NewSender), ctx, target, opts) +} + +// MockAMQPClient is a mock of AMQPClient interface. +type MockAMQPClient struct { + ctrl *gomock.Controller + recorder *MockAMQPClientMockRecorder +} + +// MockAMQPClientMockRecorder is the mock recorder for MockAMQPClient. +type MockAMQPClientMockRecorder struct { + mock *MockAMQPClient +} + +// NewMockAMQPClient creates a new mock instance. +func NewMockAMQPClient(ctrl *gomock.Controller) *MockAMQPClient { + mock := &MockAMQPClient{ctrl: ctrl} + mock.recorder = &MockAMQPClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAMQPClient) EXPECT() *MockAMQPClientMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockAMQPClient) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockAMQPClientMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAMQPClient)(nil).Close)) +} + +// NewSession mocks base method. +func (m *MockAMQPClient) NewSession(ctx context.Context, opts *amqp.SessionOptions) (amqpwrap.AMQPSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSession", ctx, opts) + ret0, _ := ret[0].(amqpwrap.AMQPSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewSession indicates an expected call of NewSession. +func (mr *MockAMQPClientMockRecorder) NewSession(ctx, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSession", reflect.TypeOf((*MockAMQPClient)(nil).NewSession), ctx, opts) +} + +// MockRPCLink is a mock of RPCLink interface. +type MockRPCLink struct { + ctrl *gomock.Controller + recorder *MockRPCLinkMockRecorder +} + +// MockRPCLinkMockRecorder is the mock recorder for MockRPCLink. +type MockRPCLinkMockRecorder struct { + mock *MockRPCLink +} + +// NewMockRPCLink creates a new mock instance. +func NewMockRPCLink(ctrl *gomock.Controller) *MockRPCLink { + mock := &MockRPCLink{ctrl: ctrl} + mock.recorder = &MockRPCLinkMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRPCLink) EXPECT() *MockRPCLinkMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRPCLink) Close(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRPCLinkMockRecorder) Close(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRPCLink)(nil).Close), ctx) +} + +// LinkName mocks base method. +func (m *MockRPCLink) LinkName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkName") + ret0, _ := ret[0].(string) + return ret0 +} + +// LinkName indicates an expected call of LinkName. +func (mr *MockRPCLinkMockRecorder) LinkName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkName", reflect.TypeOf((*MockRPCLink)(nil).LinkName)) +} + +// RPC mocks base method. +func (m *MockRPCLink) RPC(ctx context.Context, msg *amqp.Message) (*amqpwrap.RPCResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RPC", ctx, msg) + ret0, _ := ret[0].(*amqpwrap.RPCResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RPC indicates an expected call of RPC. +func (mr *MockRPCLinkMockRecorder) RPC(ctx, msg interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPC", reflect.TypeOf((*MockRPCLink)(nil).RPC), ctx, msg) +} diff --git a/sdk/messaging/azeventhubs/internal/mock/mock_generate.go b/sdk/messaging/azeventhubs/internal/mock/mock_generate.go new file mode 100644 index 000000000000..0352beb720ad --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/mock/mock_generate.go @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//go:generate mockgen -source ../namespace.go -package mock -copyright_file ./testdata/copyright.txt -destination mock_namespace.go NamespaceWithNewAMQPLinks,NamespaceForAMQPLinks + +//go:generate mockgen -source ../amqpwrap/amqpwrap.go -package mock -copyright_file ./testdata/copyright.txt -destination mock_amqp.go + +package mock diff --git a/sdk/messaging/azeventhubs/internal/mock/mock_helpers.go b/sdk/messaging/azeventhubs/internal/mock/mock_helpers.go new file mode 100644 index 000000000000..0f38298fdb68 --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/mock/mock_helpers.go @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package mock + +import ( + context "context" + "fmt" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp" + gomock "github.com/golang/mock/gomock" +) + +func SetupRPC(sender *MockAMQPSenderCloser, receiver *MockAMQPReceiverCloser, expectedCount int, handler func(sent *amqp.Message, response *amqp.Message)) { + // this is an RPC pattern - when we send a message we give it a message ID, and the + // response comes back with a correlation ID filled out, so you can match requests + // to responses. + ch := make(chan *amqp.Message, 1000) + + for i := 0; i < expectedCount; i++ { + sender.EXPECT().Send(gomock.Any(), gomock.Any()).Do(func(ctx context.Context, msg *amqp.Message) error { + ch <- msg + return nil + }) + } + + // RPC loops forever. We get one extra Receive() call here (the one that waits on the ctx.Done()) + for i := 0; i < expectedCount+1; i++ { + receiver.EXPECT().Receive(gomock.Any()).DoAndReturn(func(ctx context.Context) (*amqp.Message, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case sentMessage := <-ch: + response := &amqp.Message{ + // this is how RPC responses are correlated with their + // sent messages. + Properties: &amqp.MessageProperties{ + CorrelationID: sentMessage.Properties.MessageID, + }, + } + receiver.EXPECT().AcceptMessage(gomock.Any(), gomock.Any()).Return(nil) + + // let the caller fill in the blanks of whatever needs to happen here. + handler(sentMessage, response) + return response, nil + } + }) + } +} + +// Cancelled matches context.Context instances that are cancelled. +var Cancelled gomock.Matcher = ContextCancelledMatcher{true} + +// NotCancelled matches context.Context instances that are not cancelled. +var NotCancelled gomock.Matcher = ContextCancelledMatcher{false} + +// NotCancelledAndHasTimeout matches context.Context instances that are not cancelled +// AND were also created from NewContextForTest. +var NotCancelledAndHasTimeout gomock.Matcher = gomock.All(ContextCancelledMatcher{false}, ContextHasTestValueMatcher{}) + +// CancelledAndHasTimeout matches context.Context instances that are cancelled +// AND were also created from NewContextForTest. +var CancelledAndHasTimeout gomock.Matcher = gomock.All(ContextCancelledMatcher{true}, ContextHasTestValueMatcher{}) + +type ContextCancelledMatcher struct { + // WantCancelled should be set if we expect the context should + // be cancelled. If true, we check if Err() != nil, if false we check + // that it's nil. + WantCancelled bool +} + +// Matches returns whether x is a match. +func (m ContextCancelledMatcher) Matches(x interface{}) bool { + ctx := x.(context.Context) + + if m.WantCancelled { + return ctx.Err() != nil + } else { + return ctx.Err() == nil + } +} + +// String describes what the matcher matches. +func (m ContextCancelledMatcher) String() string { + return fmt.Sprintf("want cancelled:%v", m.WantCancelled) +} + +type ContextHasTestValueMatcher struct{} + +func (m ContextHasTestValueMatcher) Matches(x interface{}) bool { + ctx := x.(context.Context) + return ctx.Value(testContextKey(0)) == "correctContextWasUsed" +} + +func (m ContextHasTestValueMatcher) String() string { + return "has test context value" +} + +type testContextKey int + +// NewContextWithTimeoutForTests creates a context with a lower timeout than requested just to keep +// unit test times reasonable. +// +// It validates that the passed in timeout is the actual defaultCloseTimeout and also +// adds in a testContextKey(0) as a value, which can be used to verify that the context +// has been properly propagated. +func NewContextWithTimeoutForTests(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + // (we're in the wrong package to share the value, but this is meant to match defaultCloseTimeout) + if timeout != time.Minute { + // panic'ing instead of require.Equal() otherwise I would need to take a 't' and not be signature + // compatible with context.WithTimeout. + panic(fmt.Sprintf("Incorrect close timeout: expected %s, actual %s", time.Minute, timeout)) + } + + parentWithValue := context.WithValue(parent, testContextKey(0), "correctContextWasUsed") + + // NOTE: if you're debugging then you might need to bump up this + // value so you can single step. + return context.WithTimeout(parentWithValue, time.Second) +} diff --git a/sdk/messaging/azeventhubs/internal/mock/mock_namespace.go b/sdk/messaging/azeventhubs/internal/mock/mock_namespace.go new file mode 100644 index 000000000000..5db363568f4a --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/mock/mock_namespace.go @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: ../namespace.go + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + amqpwrap "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" + gomock "github.com/golang/mock/gomock" +) + +// MockNamespaceWithNewAMQPLinks is a mock of NamespaceWithNewAMQPLinks interface. +type MockNamespaceWithNewAMQPLinks struct { + ctrl *gomock.Controller + recorder *MockNamespaceWithNewAMQPLinksMockRecorder +} + +// MockNamespaceWithNewAMQPLinksMockRecorder is the mock recorder for MockNamespaceWithNewAMQPLinks. +type MockNamespaceWithNewAMQPLinksMockRecorder struct { + mock *MockNamespaceWithNewAMQPLinks +} + +// NewMockNamespaceWithNewAMQPLinks creates a new mock instance. +func NewMockNamespaceWithNewAMQPLinks(ctrl *gomock.Controller) *MockNamespaceWithNewAMQPLinks { + mock := &MockNamespaceWithNewAMQPLinks{ctrl: ctrl} + mock.recorder = &MockNamespaceWithNewAMQPLinksMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNamespaceWithNewAMQPLinks) EXPECT() *MockNamespaceWithNewAMQPLinksMockRecorder { + return m.recorder +} + +// Check mocks base method. +func (m *MockNamespaceWithNewAMQPLinks) Check() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Check") + ret0, _ := ret[0].(error) + return ret0 +} + +// Check indicates an expected call of Check. +func (mr *MockNamespaceWithNewAMQPLinksMockRecorder) Check() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Check", reflect.TypeOf((*MockNamespaceWithNewAMQPLinks)(nil).Check)) +} + +// MockNamespaceForAMQPLinks is a mock of NamespaceForAMQPLinks interface. +type MockNamespaceForAMQPLinks struct { + ctrl *gomock.Controller + recorder *MockNamespaceForAMQPLinksMockRecorder +} + +// MockNamespaceForAMQPLinksMockRecorder is the mock recorder for MockNamespaceForAMQPLinks. +type MockNamespaceForAMQPLinksMockRecorder struct { + mock *MockNamespaceForAMQPLinks +} + +// NewMockNamespaceForAMQPLinks creates a new mock instance. +func NewMockNamespaceForAMQPLinks(ctrl *gomock.Controller) *MockNamespaceForAMQPLinks { + mock := &MockNamespaceForAMQPLinks{ctrl: ctrl} + mock.recorder = &MockNamespaceForAMQPLinksMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNamespaceForAMQPLinks) EXPECT() *MockNamespaceForAMQPLinksMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockNamespaceForAMQPLinks) Close(ctx context.Context, permanently bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", ctx, permanently) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockNamespaceForAMQPLinksMockRecorder) Close(ctx, permanently interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNamespaceForAMQPLinks)(nil).Close), ctx, permanently) +} + +// GetEntityAudience mocks base method. +func (m *MockNamespaceForAMQPLinks) GetEntityAudience(entityPath string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEntityAudience", entityPath) + ret0, _ := ret[0].(string) + return ret0 +} + +// GetEntityAudience indicates an expected call of GetEntityAudience. +func (mr *MockNamespaceForAMQPLinksMockRecorder) GetEntityAudience(entityPath interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEntityAudience", reflect.TypeOf((*MockNamespaceForAMQPLinks)(nil).GetEntityAudience), entityPath) +} + +// NegotiateClaim mocks base method. +func (m *MockNamespaceForAMQPLinks) NegotiateClaim(ctx context.Context, entityPath string) (context.CancelFunc, <-chan struct{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NegotiateClaim", ctx, entityPath) + ret0, _ := ret[0].(context.CancelFunc) + ret1, _ := ret[1].(<-chan struct{}) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// NegotiateClaim indicates an expected call of NegotiateClaim. +func (mr *MockNamespaceForAMQPLinksMockRecorder) NegotiateClaim(ctx, entityPath interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NegotiateClaim", reflect.TypeOf((*MockNamespaceForAMQPLinks)(nil).NegotiateClaim), ctx, entityPath) +} + +// NewAMQPSession mocks base method. +func (m *MockNamespaceForAMQPLinks) NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewAMQPSession", ctx) + ret0, _ := ret[0].(amqpwrap.AMQPSession) + ret1, _ := ret[1].(uint64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// NewAMQPSession indicates an expected call of NewAMQPSession. +func (mr *MockNamespaceForAMQPLinksMockRecorder) NewAMQPSession(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewAMQPSession", reflect.TypeOf((*MockNamespaceForAMQPLinks)(nil).NewAMQPSession), ctx) +} + +// NewRPCLink mocks base method. +func (m *MockNamespaceForAMQPLinks) NewRPCLink(ctx context.Context, managementPath string) (amqpwrap.RPCLink, uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewRPCLink", ctx, managementPath) + ret0, _ := ret[0].(amqpwrap.RPCLink) + ret1, _ := ret[1].(uint64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// NewRPCLink indicates an expected call of NewRPCLink. +func (mr *MockNamespaceForAMQPLinksMockRecorder) NewRPCLink(ctx, managementPath interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewRPCLink", reflect.TypeOf((*MockNamespaceForAMQPLinks)(nil).NewRPCLink), ctx, managementPath) +} + +// Recover mocks base method. +func (m *MockNamespaceForAMQPLinks) Recover(ctx context.Context, clientRevision uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recover", ctx, clientRevision) + ret0, _ := ret[0].(error) + return ret0 +} + +// Recover indicates an expected call of Recover. +func (mr *MockNamespaceForAMQPLinksMockRecorder) Recover(ctx, clientRevision interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recover", reflect.TypeOf((*MockNamespaceForAMQPLinks)(nil).Recover), ctx, clientRevision) +} diff --git a/sdk/messaging/azeventhubs/internal/mock/mock_token.go b/sdk/messaging/azeventhubs/internal/mock/mock_token.go new file mode 100644 index 000000000000..d2de1871c308 --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/mock/mock_token.go @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Code generated by MockGen. DO NOT EDIT. +// Source: ./auth/token.go + +package mock + +import ( + reflect "reflect" + + auth "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/auth" + gomock "github.com/golang/mock/gomock" +) + +// MockTokenProvider is a mock of TokenProvider interface. +type MockTokenProvider struct { + ctrl *gomock.Controller + recorder *MockTokenProviderMockRecorder +} + +// MockTokenProviderMockRecorder is the mock recorder for MockTokenProvider. +type MockTokenProviderMockRecorder struct { + mock *MockTokenProvider +} + +// NewMockTokenProvider creates a new mock instance. +func NewMockTokenProvider(ctrl *gomock.Controller) *MockTokenProvider { + mock := &MockTokenProvider{ctrl: ctrl} + mock.recorder = &MockTokenProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTokenProvider) EXPECT() *MockTokenProviderMockRecorder { + return m.recorder +} + +// GetToken mocks base method. +func (m *MockTokenProvider) GetToken(uri string) (*auth.Token, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetToken", uri) + ret0, _ := ret[0].(*auth.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetToken indicates an expected call of GetToken. +func (mr *MockTokenProviderMockRecorder) GetToken(uri interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetToken", reflect.TypeOf((*MockTokenProvider)(nil).GetToken), uri) +} diff --git a/sdk/messaging/azeventhubs/internal/mock/mock_token_credential.go b/sdk/messaging/azeventhubs/internal/mock/mock_token_credential.go new file mode 100644 index 000000000000..e2c28c065154 --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/mock/mock_token_credential.go @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Code generated by MockGen. DO NOT EDIT. +// Source: ../../../azcore/internal/exported/exported.go + +package mock + +import ( + context "context" + reflect "reflect" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + gomock "github.com/golang/mock/gomock" +) + +// MockTokenCredential is a mock of TokenCredential interface. +type MockTokenCredential struct { + ctrl *gomock.Controller + recorder *MockTokenCredentialMockRecorder +} + +// MockTokenCredentialMockRecorder is the mock recorder for MockTokenCredential. +type MockTokenCredentialMockRecorder struct { + mock *MockTokenCredential +} + +// NewMockTokenCredential creates a new mock instance. +func NewMockTokenCredential(ctrl *gomock.Controller) *MockTokenCredential { + mock := &MockTokenCredential{ctrl: ctrl} + mock.recorder = &MockTokenCredentialMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTokenCredential) EXPECT() *MockTokenCredentialMockRecorder { + return m.recorder +} + +// GetToken mocks base method. +func (m *MockTokenCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetToken", ctx, options) + ret0, _ := ret[0].(azcore.AccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetToken indicates an expected call of GetToken. +func (mr *MockTokenCredentialMockRecorder) GetToken(ctx, options interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetToken", reflect.TypeOf((*MockTokenCredential)(nil).GetToken), ctx, options) +} diff --git a/sdk/messaging/azeventhubs/internal/mock/testdata/copyright.txt b/sdk/messaging/azeventhubs/internal/mock/testdata/copyright.txt new file mode 100644 index 000000000000..679520bee0e3 --- /dev/null +++ b/sdk/messaging/azeventhubs/internal/mock/testdata/copyright.txt @@ -0,0 +1,2 @@ +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. diff --git a/sdk/messaging/azeventhubs/internal/namespace.go b/sdk/messaging/azeventhubs/internal/namespace.go index cbdac020f2c1..65fc694dd912 100644 --- a/sdk/messaging/azeventhubs/internal/namespace.go +++ b/sdk/messaging/azeventhubs/internal/namespace.go @@ -71,7 +71,7 @@ type NamespaceWithNewAMQPLinks interface { type NamespaceForAMQPLinks interface { NegotiateClaim(ctx context.Context, entityPath string) (context.CancelFunc, <-chan struct{}, error) NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession, uint64, error) - NewRPCLink(ctx context.Context, managementPath string) (RPCLink, uint64, error) + NewRPCLink(ctx context.Context, managementPath string) (amqpwrap.RPCLink, uint64, error) GetEntityAudience(entityPath string) string Recover(ctx context.Context, clientRevision uint64) error Close(ctx context.Context, permanently bool) error @@ -263,8 +263,11 @@ func (ns *Namespace) Recover(ctx context.Context, theirConnID uint64) error { oldClient := ns.client ns.client = nil - // the error on close isn't critical - _ = oldClient.Close() + if err := oldClient.Close(); err != nil { + // the error on close isn't critical, we don't need to exit or + // return it. + log.Writef(exported.EventConn, "Error closing old client: %s", err.Error()) + } } log.Writef(exported.EventConn, "Creating a new client (rev:%d)", ns.connID) @@ -276,6 +279,11 @@ func (ns *Namespace) Recover(ctx context.Context, theirConnID uint64) error { return nil } +// negotiateClaimFn matches the signature for NegotiateClaim, and is used when we want to stub things out for tests. +type negotiateClaimFn func( + ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider, + contextWithTimeoutFn contextWithTimeoutFn) error + // negotiateClaim performs initial authentication and starts periodic refresh of credentials. // the returned func is to cancel() the refresh goroutine. func (ns *Namespace) NegotiateClaim(ctx context.Context, entityPath string) (context.CancelFunc, <-chan struct{}, error) { @@ -291,7 +299,7 @@ func (ns *Namespace) NegotiateClaim(ctx context.Context, entityPath string) (con // when the background renewal stops or an error. func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context, entityPath string, - cbsNegotiateClaim func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error, + cbsNegotiateClaim negotiateClaimFn, nextClaimRefreshDurationFn func(expirationTime time.Time, currentTime time.Time) time.Duration) (func(), <-chan struct{}, error) { audience := ns.GetEntityAudience(entityPath) @@ -317,7 +325,7 @@ func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context, // The current cbs.NegotiateClaim implementation automatically creates and shuts // down it's own link so we have to guard against that here. ns.negotiateClaimMu.Lock() - err = cbsNegotiateClaim(ctx, audience, amqpClient, token) + err = cbsNegotiateClaim(ctx, audience, amqpClient, token, context.WithTimeout) ns.negotiateClaimMu.Unlock() if err != nil { @@ -370,7 +378,7 @@ func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context, return case <-time.After(nextClaimAt): for { - err := utils.Retry(refreshCtx, exported.EventAuth, "NegotiateClaimRefresh", ns.RetryOptions, func(ctx context.Context, args *utils.RetryFnArgs) error { + err := utils.Retry(refreshCtx, exported.EventAuth, func() string { return "NegotiateClaimRefresh" }, ns.RetryOptions, func(ctx context.Context, args *utils.RetryFnArgs) error { tmpExpiresOn, err := refreshClaim(ctx) if err != nil { diff --git a/sdk/messaging/azeventhubs/internal/namespace_test.go b/sdk/messaging/azeventhubs/internal/namespace_test.go index 3d6fe071c22e..0f76710faeb8 100644 --- a/sdk/messaging/azeventhubs/internal/namespace_test.go +++ b/sdk/messaging/azeventhubs/internal/namespace_test.go @@ -67,7 +67,7 @@ func TestNamespaceNegotiateClaim(t *testing.T) { cbsNegotiateClaimCalled := 0 - cbsNegotiateClaim := func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error { + cbsNegotiateClaim := func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider, contextWithTimeoutFn contextWithTimeoutFn) error { cbsNegotiateClaimCalled++ return nil } @@ -111,7 +111,7 @@ func TestNamespaceNegotiateClaimRenewal(t *testing.T) { cbsNegotiateClaimCalled := 0 - cbsNegotiateClaim := func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error { + cbsNegotiateClaim := func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider, contextWithTimeoutFn contextWithTimeoutFn) error { cbsNegotiateClaimCalled++ return nil } @@ -160,7 +160,7 @@ func TestNamespaceNegotiateClaimFailsToGetClient(t *testing.T) { cancel, _, err := ns.startNegotiateClaimRenewer( context.Background(), "entity path", - func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error { + func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider, contextWithTimeoutFn contextWithTimeoutFn) error { return errors.New("NegotiateClaim amqp.Client failed") }, func(expirationTime, currentTime time.Time) time.Duration { // refresh immediately since we're in a unit test. @@ -182,7 +182,7 @@ func TestNamespaceNegotiateClaimNonRenewableToken(t *testing.T) { cbsNegotiateClaimCalled := 0 - cbsNegotiateClaim := func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error { + cbsNegotiateClaim := func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider, contextWithTimeoutFn contextWithTimeoutFn) error { cbsNegotiateClaimCalled++ return nil } @@ -222,7 +222,7 @@ func TestNamespaceNegotiateClaimFails(t *testing.T) { cancel, _, err := ns.startNegotiateClaimRenewer( context.Background(), "entity path", - func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error { + func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider, contextWithTimeoutFn contextWithTimeoutFn) error { return errors.New("NegotiateClaim amqp.Client failed") }, func(expirationTime, currentTime time.Time) time.Duration { // not even used. @@ -240,7 +240,7 @@ func TestNamespaceNegotiateClaimFatalErrors(t *testing.T) { cbsNegotiateClaimCalled := 0 - cbsNegotiateClaim := func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider) error { + cbsNegotiateClaim := func(ctx context.Context, audience string, conn amqpwrap.AMQPClient, provider auth.TokenProvider, contextWithTimeoutFn contextWithTimeoutFn) error { cbsNegotiateClaimCalled++ // work the first time, fail on renewals. diff --git a/sdk/messaging/azeventhubs/internal/rpc.go b/sdk/messaging/azeventhubs/internal/rpc.go index b0c165490356..daee2d2b2787 100644 --- a/sdk/messaging/azeventhubs/internal/rpc.go +++ b/sdk/messaging/azeventhubs/internal/rpc.go @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + package internal import ( @@ -50,15 +51,6 @@ type ( uuidNewV4 func() (uuid.UUID, error) } - // RPCResponse is the simplified response structure from an RPC like call - RPCResponse struct { - // Code is the response code - these originate from Service Bus. Some - // common values are called out below, with the RPCResponseCode* constants. - Code int - Description string - Message *amqp.Message - } - // RPCLinkOption provides a way to customize the construction of a Link RPCLinkOption func(link *rpcLink) error @@ -79,7 +71,7 @@ const ( // RPCError is an error from an RPCLink. // RPCLinks are used for communication with the $management and $cbs links. type RPCError struct { - Resp *RPCResponse + Resp *amqpwrap.RPCResponse Message string } @@ -226,7 +218,7 @@ func (l *rpcLink) startResponseRouter() { } // RPC sends a request and waits on a response for that request -func (l *rpcLink) RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, error) { +func (l *rpcLink) RPC(ctx context.Context, msg *amqp.Message) (*amqpwrap.RPCResponse, error) { l.startResponseRouterOnce.Do(func() { go l.startResponseRouter() }) @@ -311,7 +303,7 @@ func (l *rpcLink) RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, err } } - response := &RPCResponse{ + response := &amqpwrap.RPCResponse{ Code: int(statusCode), Description: description, Message: res, @@ -331,11 +323,7 @@ func (l *rpcLink) RPC(ctx context.Context, msg *amqp.Message) (*RPCResponse, err } // Close the link receiver, sender and session -func (l *rpcLink) Close(_ context.Context) error { - // we're finding, in practice, that allowing cancellations when cleaning up state - // just results in inconsistencies. We'll cut cancellation off here for now. - ctx := context.Background() - +func (l *rpcLink) Close(ctx context.Context) error { l.rpcLinkCtxCancel() if err := l.closeReceiver(ctx); err != nil { @@ -460,7 +448,7 @@ func addMessageID(message *amqp.Message, uuidNewV4 func() (uuid.UUID, error)) (* // asRPCError checks to see if the res is actually a failed request // (where failed means the status code was non-2xx). If so, // it returns true and updates the struct pointed to by err. -func asRPCError(res *RPCResponse, err *RPCError) bool { +func asRPCError(res *amqpwrap.RPCResponse, err *RPCError) bool { if res == nil { return false } diff --git a/sdk/messaging/azeventhubs/internal/test/test_helpers.go b/sdk/messaging/azeventhubs/internal/test/test_helpers.go index 1158cb1328f7..fefbab02d2b0 100644 --- a/sdk/messaging/azeventhubs/internal/test/test_helpers.go +++ b/sdk/messaging/azeventhubs/internal/test/test_helpers.go @@ -13,9 +13,11 @@ import ( "strings" "sync" "testing" + "time" azlog "github.com/Azure/azure-sdk-for-go/sdk/internal/log" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/conn" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/exported" "github.com/joho/godotenv" "github.com/stretchr/testify/require" ) @@ -64,6 +66,7 @@ func CaptureLogsForTestWithChannel(messagesCh chan string) func() []string { // EnableStdoutLogging turns on logging to stdout for diagnostics. func EnableStdoutLogging() { + azlog.SetEvents(exported.EventAuth, exported.EventConn, exported.EventConsumer, exported.EventProducer) setAzLogListener(func(e azlog.Event, s string) { log.Printf("%s %s", e, s) }) @@ -90,20 +93,23 @@ func RandomString(prefix string, length int) string { } type ConnectionParamsForTest struct { - ConnectionString string - EventHubName string - EventHubNamespace string - + ConnectionString string + EventHubName string + EventHubNamespace string + ResourceGroup string StorageConnectionString string + SubscriptionID string } func GetConnectionParamsForTest(t *testing.T) ConnectionParamsForTest { _ = godotenv.Load() envVars := mustGetEnvironmentVars(t, []string{ + "AZURE_SUBSCRIPTION_ID", + "CHECKPOINTSTORE_STORAGE_CONNECTION_STRING", "EVENTHUB_CONNECTION_STRING", "EVENTHUB_NAME", - "CHECKPOINTSTORE_STORAGE_CONNECTION_STRING", + "RESOURCE_GROUP", }) parsedConn, err := conn.ParsedConnectionFromStr(envVars["EVENTHUB_CONNECTION_STRING"]) @@ -113,7 +119,9 @@ func GetConnectionParamsForTest(t *testing.T) ConnectionParamsForTest { ConnectionString: envVars["EVENTHUB_CONNECTION_STRING"], EventHubName: envVars["EVENTHUB_NAME"], EventHubNamespace: parsedConn.Namespace, + ResourceGroup: envVars["RESOURCE_GROUP"], StorageConnectionString: envVars["CHECKPOINTSTORE_STORAGE_CONNECTION_STRING"], + SubscriptionID: envVars["AZURE_SUBSCRIPTION_ID"], } } @@ -145,3 +153,16 @@ func RequireClose(t *testing.T, closeable interface { }) { require.NoError(t, closeable.Close(context.Background())) } + +// RequireContextHasDefaultTimeout checks that the context has a deadline set, and that it's +// using the right timeout. +// NOTE: There's some wiggle room since some time will expire before this is called. +func RequireContextHasDefaultTimeout(t *testing.T, ctx context.Context, timeout time.Duration) { + tm, hasDeadline := ctx.Deadline() + + require.True(t, hasDeadline, "deadline must exist, we always set an operation timeout") + duration := time.Until(tm) + + require.Greater(t, duration, time.Duration(0)) + require.LessOrEqual(t, duration, timeout) +} diff --git a/sdk/messaging/azeventhubs/internal/utils/retrier.go b/sdk/messaging/azeventhubs/internal/utils/retrier.go index 62f353dab0ba..a61eb134934c 100644 --- a/sdk/messaging/azeventhubs/internal/utils/retrier.go +++ b/sdk/messaging/azeventhubs/internal/utils/retrier.go @@ -34,7 +34,7 @@ func (rf *RetryFnArgs) ResetAttempts() { // Retry runs a standard retry loop. It executes your passed in fn as the body of the loop. // It returns if it exceeds the number of configured retry options or if 'isFatal' returns true. -func Retry(ctx context.Context, eventName log.Event, operation string, o exported.RetryOptions, fn func(ctx context.Context, callbackArgs *RetryFnArgs) error, isFatalFn func(err error) bool) error { +func Retry(ctx context.Context, eventName log.Event, prefix func() string, o exported.RetryOptions, fn func(ctx context.Context, callbackArgs *RetryFnArgs) error, isFatalFn func(err error) bool) error { if isFatalFn == nil { panic("isFatalFn is nil, errors would panic") } @@ -47,7 +47,7 @@ func Retry(ctx context.Context, eventName log.Event, operation string, o exporte for i := int32(0); i <= ro.MaxRetries; i++ { if i > 0 { sleep := calcDelay(ro, i) - log.Writef(eventName, "(%s) Retry attempt %d sleeping for %s", operation, i, sleep) + log.Writef(eventName, "(%s) Retry attempt %d sleeping for %s", prefix(), i, sleep) select { case <-ctx.Done(): @@ -63,7 +63,7 @@ func Retry(ctx context.Context, eventName log.Event, operation string, o exporte err = fn(ctx, &args) if args.resetAttempts { - log.Writef(eventName, "(%s) Resetting retry attempts", operation) + log.Writef(eventName, "(%s) Resetting retry attempts", prefix()) // it looks weird, but we're doing -1 here because the post-increment // will set it back to 0, which is what we want - go back to the 0th @@ -76,13 +76,13 @@ func Retry(ctx context.Context, eventName log.Event, operation string, o exporte if err != nil { if isFatalFn(err) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - log.Writef(eventName, "(%s) Retry attempt %d was cancelled, stopping: %s", operation, i, err.Error()) + log.Writef(eventName, "(%s) Retry attempt %d was cancelled, stopping: %s", prefix(), i, err.Error()) } else { - log.Writef(eventName, "(%s) Retry attempt %d returned non-retryable error: %s", operation, i, err.Error()) + log.Writef(eventName, "(%s) Retry attempt %d returned non-retryable error: %s", prefix(), i, err.Error()) } return err } else { - log.Writef(eventName, "(%s) Retry attempt %d returned retryable error: %s", operation, i, err.Error()) + log.Writef(eventName, "(%s) Retry attempt %d returned retryable error: %s", prefix(), i, err.Error()) } continue diff --git a/sdk/messaging/azeventhubs/internal/utils/retrier_test.go b/sdk/messaging/azeventhubs/internal/utils/retrier_test.go index 6d39f3136549..240c4a9e0cfb 100644 --- a/sdk/messaging/azeventhubs/internal/utils/retrier_test.go +++ b/sdk/messaging/azeventhubs/internal/utils/retrier_test.go @@ -26,7 +26,7 @@ func TestRetrier(t *testing.T) { called := 0 - err := Retry(ctx, testLogEvent, "notused", exported.RetryOptions{}, func(ctx context.Context, args *RetryFnArgs) error { + err := Retry(ctx, testLogEvent, func() string { return "notused" }, exported.RetryOptions{}, func(ctx context.Context, args *RetryFnArgs) error { called++ return nil }, func(err error) bool { @@ -50,7 +50,7 @@ func TestRetrier(t *testing.T) { return false } - err := Retry(ctx, testLogEvent, "notused", fastRetryOptions, func(ctx context.Context, args *RetryFnArgs) error { + err := Retry(ctx, testLogEvent, func() string { return "notused" }, fastRetryOptions, func(ctx context.Context, args *RetryFnArgs) error { called++ if args.I == 3 { @@ -77,7 +77,7 @@ func TestRetrier(t *testing.T) { return true } - err := Retry(ctx, testLogEvent, "notused", exported.RetryOptions{}, func(ctx context.Context, args *RetryFnArgs) error { + err := Retry(ctx, testLogEvent, func() string { return "notused" }, exported.RetryOptions{}, func(ctx context.Context, args *RetryFnArgs) error { called++ return errors.New("isFatalFn says this is a fatal error") }, isFatalFn) @@ -98,7 +98,7 @@ func TestRetrier(t *testing.T) { maxRetries := int32(2) - err := Retry(context.Background(), testLogEvent, "notused", exported.RetryOptions{ + err := Retry(context.Background(), testLogEvent, func() string { return "notused" }, exported.RetryOptions{ MaxRetries: maxRetries, RetryDelay: time.Millisecond, MaxRetryDelay: time.Millisecond, @@ -131,7 +131,7 @@ func TestRetrier(t *testing.T) { called := 0 - err := Retry(context.Background(), testLogEvent, "notused", customRetryOptions, func(ctx context.Context, args *RetryFnArgs) error { + err := Retry(context.Background(), testLogEvent, func() string { return "notused" }, customRetryOptions, func(ctx context.Context, args *RetryFnArgs) error { called++ return errors.New("whatever") }, isFatalFn) @@ -151,7 +151,7 @@ func TestCancellationCancelsSleep(t *testing.T) { called := 0 - err := Retry(ctx, testLogEvent, "notused", exported.RetryOptions{ + err := Retry(ctx, testLogEvent, func() string { return "notused" }, exported.RetryOptions{ RetryDelay: time.Hour, }, func(ctx context.Context, args *RetryFnArgs) error { called++ @@ -175,7 +175,7 @@ func TestCancellationFromUserFunc(t *testing.T) { called := 0 - err := Retry(alreadyCancelledCtx, testLogEvent, "notused", exported.RetryOptions{}, func(ctx context.Context, args *RetryFnArgs) error { + err := Retry(alreadyCancelledCtx, testLogEvent, func() string { return "notused" }, exported.RetryOptions{}, func(ctx context.Context, args *RetryFnArgs) error { called++ select { @@ -199,7 +199,7 @@ func TestCancellationTimeoutsArentPropagatedToUser(t *testing.T) { tryAgainErr := errors.New("try again") called := 0 - err := Retry(context.Background(), testLogEvent, "notused", exported.RetryOptions{ + err := Retry(context.Background(), testLogEvent, func() string { return "notused" }, exported.RetryOptions{ RetryDelay: time.Millisecond, }, func(ctx context.Context, args *RetryFnArgs) error { called++ @@ -301,7 +301,7 @@ func TestRetryLogging(t *testing.T) { t.Run("normal error", func(t *testing.T) { logs = nil - err := Retry(context.Background(), testLogEvent, "my_operation", exported.RetryOptions{ + err := Retry(context.Background(), testLogEvent, func() string { return "my_operation" }, exported.RetryOptions{ RetryDelay: time.Microsecond, }, func(ctx context.Context, args *RetryFnArgs) error { azlog.Writef("TestFunc", "Attempt %d, within test func, returning error hello", args.I) @@ -332,7 +332,7 @@ func TestRetryLogging(t *testing.T) { t.Run("cancellation error", func(t *testing.T) { logs = nil - err := Retry(context.Background(), testLogEvent, "test_operation", exported.RetryOptions{ + err := Retry(context.Background(), testLogEvent, func() string { return "test_operation" }, exported.RetryOptions{ RetryDelay: time.Microsecond, }, func(ctx context.Context, args *RetryFnArgs) error { azlog.Writef("TestFunc", @@ -352,7 +352,7 @@ func TestRetryLogging(t *testing.T) { t.Run("custom fatal error", func(t *testing.T) { logs = nil - err := Retry(context.Background(), testLogEvent, "test_operation", exported.RetryOptions{ + err := Retry(context.Background(), testLogEvent, func() string { return "test_operation" }, exported.RetryOptions{ RetryDelay: time.Microsecond, }, func(ctx context.Context, args *RetryFnArgs) error { azlog.Writef("TestFunc", @@ -374,7 +374,7 @@ func TestRetryLogging(t *testing.T) { reset := false - err := Retry(context.Background(), testLogEvent, "test_operation", exported.RetryOptions{ + err := Retry(context.Background(), testLogEvent, func() string { return "test_operation" }, exported.RetryOptions{ RetryDelay: time.Microsecond, }, func(ctx context.Context, args *RetryFnArgs) error { azlog.Writef("TestFunc", "Attempt %d, within test func", args.I) diff --git a/sdk/messaging/azeventhubs/mgmt.go b/sdk/messaging/azeventhubs/mgmt.go index 84399a10785b..333bf717d2a7 100644 --- a/sdk/messaging/azeventhubs/mgmt.go +++ b/sdk/messaging/azeventhubs/mgmt.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal" + "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/amqpwrap" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/eh" "github.com/Azure/azure-sdk-for-go/sdk/messaging/azeventhubs/internal/go-amqp" ) @@ -28,7 +29,7 @@ type GetEventHubPropertiesOptions struct { } // getEventHubProperties gets event hub properties, like the available partition IDs and when the Event Hub was created. -func getEventHubProperties(ctx context.Context, ns internal.NamespaceForManagementOps, rpcLink internal.RPCLink, eventHub string, options *GetEventHubPropertiesOptions) (EventHubProperties, error) { +func getEventHubProperties(ctx context.Context, ns internal.NamespaceForManagementOps, rpcLink amqpwrap.RPCLink, eventHub string, options *GetEventHubPropertiesOptions) (EventHubProperties, error) { token, err := ns.GetTokenForEntity(eventHub) if err != nil { @@ -87,7 +88,7 @@ type GetPartitionPropertiesOptions struct { // getPartitionProperties gets properties for a specific partition. This includes data like the last enqueued sequence number, the first sequence // number and when an event was last enqueued to the partition. -func getPartitionProperties(ctx context.Context, ns internal.NamespaceForManagementOps, rpcLink internal.RPCLink, eventHub string, partitionID string, options *GetPartitionPropertiesOptions) (PartitionProperties, error) { +func getPartitionProperties(ctx context.Context, ns internal.NamespaceForManagementOps, rpcLink amqpwrap.RPCLink, eventHub string, partitionID string, options *GetPartitionPropertiesOptions) (PartitionProperties, error) { token, err := ns.GetTokenForEntity(eventHub) if err != nil { diff --git a/sdk/messaging/azeventhubs/partition_client.go b/sdk/messaging/azeventhubs/partition_client.go index 3b03801fec33..eec90d5615e1 100644 --- a/sdk/messaging/azeventhubs/partition_client.go +++ b/sdk/messaging/azeventhubs/partition_client.go @@ -109,9 +109,10 @@ func (pc *PartitionClient) ReceiveEvents(ctx context.Context, count int, options if count > int(remainingCredits) { newCredits := uint32(count) - remainingCredits - log.Writef(EventConsumer, "Have %d outstanding credit, only issuing %d credits", remainingCredits, newCredits) + log.Writef(EventConsumer, "(%s) Have %d outstanding credit, only issuing %d credits", lwid.String(), remainingCredits, newCredits) if err := lwid.Link.IssueCredit(newCredits); err != nil { + log.Writef(EventConsumer, "(%s) Error when issuing credits: %s", lwid.String(), err) return err } } @@ -121,6 +122,7 @@ func (pc *PartitionClient) ReceiveEvents(ctx context.Context, count int, options amqpMessage, err := lwid.Link.Receive(ctx) if internal.IsOwnershipLostError(err) { + log.Writef(EventConsumer, "(%s) Error, link ownership lost: %s", lwid.String(), err) events = nil return err } @@ -132,6 +134,7 @@ func (pc *PartitionClient) ReceiveEvents(ctx context.Context, count int, options re, err := newReceivedEventData(amqpMsg) if err != nil { + log.Writef(EventConsumer, "(%s) Failed converting AMQP message to EventData: %s", lwid.String(), err) return err } @@ -149,6 +152,7 @@ func (pc *PartitionClient) ReceiveEvents(ctx context.Context, count int, options receivedEvent, err := newReceivedEventData(amqpMessage) if err != nil { + log.Writef(EventConsumer, "(%s) Failed converting AMQP message to EventData: %s", lwid.String(), err) return err } diff --git a/sdk/messaging/azeventhubs/test-resources.json b/sdk/messaging/azeventhubs/test-resources.json index c2e28b54a8f3..54f10236d968 100644 --- a/sdk/messaging/azeventhubs/test-resources.json +++ b/sdk/messaging/azeventhubs/test-resources.json @@ -212,6 +212,14 @@ "CHECKPOINTSTORE_STORAGE_CONNECTION_STRING": { "type": "string", "value": "[concat('DefaultEndpointsProtocol=https;AccountName=', variables('storageAccountName'), ';AccountKey=', listKeys(resourceId('Microsoft.Storage/storageAccounts', variables('storageAccountName')), variables('storageApiVersion')).keys[0].value, ';EndpointSuffix=', parameters('storageEndpointSuffix'))]" + }, + "RESOURCE_GROUP": { + "type": "string", + "value": "[resourceGroup().name]" + }, + "AZURE_SUBSCRIPTION_ID": { + "type": "string", + "value": "[subscription().subscriptionId]" } } }