Skip to content

Commit

Permalink
BearerTokenPolicy rewinds bodies before retrying (#23597)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Oct 17, 2024
1 parent 546e099 commit 32f5e82
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 10 deletions.
6 changes: 2 additions & 4 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# Release History

## 1.15.1 (Unreleased)
## 1.16.0 (2024-10-17)

### Features Added

* Added field `Kind` to `runtime.StartSpanOptions` to allow a kind to be set when starting a span.

### Breaking Changes

### Bugs Fixed

### Other Changes
* `BearerTokenPolicy` now rewinds request bodies before retrying

## 1.15.0 (2024-10-14)

Expand Down
2 changes: 1 addition & 1 deletion sdk/azcore/internal/shared/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,5 @@ const (
Module = "azcore"

// Version is the semantic version (see http://semver.org) of this module.
Version = "v1.15.1"
Version = "v1.16.0"
)
13 changes: 8 additions & 5 deletions sdk/azcore/runtime/policy_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,17 @@ func (b *BearerTokenPolicy) handleChallenge(req *policy.Request, res *http.Respo
tro.Claims = caeChallenge.params["claims"]
return b.authenticateAndAuthorize(req)(tro)
}
err = b.authzHandler.OnRequest(req, authNZ)
if err == nil {
res, err = req.Next()
if err = b.authzHandler.OnRequest(req, authNZ); err == nil {
if err = req.RewindBody(); err == nil {
res, err = req.Next()
}
}
case b.authzHandler.OnChallenge != nil && !recursed:
if err = b.authzHandler.OnChallenge(req, res, b.authenticateAndAuthorize(req)); err == nil {
if res, err = req.Next(); err == nil {
res, err = b.handleChallenge(req, res, true)
if err = req.RewindBody(); err == nil {
if res, err = req.Next(); err == nil {
res, err = b.handleChallenge(req, res, true)
}
}
} else {
// don't retry challenge handling errors
Expand Down
57 changes: 57 additions & 0 deletions sdk/azcore/runtime/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"encoding/base64"
"fmt"
"io"
"strings"

"errors"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -531,6 +533,61 @@ func TestBearerTokenPolicy_RequiresHTTPS(t *testing.T) {
require.ErrorAs(t, err, &nre)
}

func TestBearerTokenPolicy_RewindsBeforeRetry(t *testing.T) {
const expected = "expected"
for _, test := range []struct {
challenge, desc string
onChallenge bool
}{
{
desc: "CAE challenge",
challenge: `Bearer error="insufficient_claims", claims="ey=="`,
},
{
desc: "non-CAE challenge",
challenge: `Bearer authorization_uri="https://login.windows.net/", error="invalid_token"`,
onChallenge: true,
},
} {
t.Run(test.desc, func(t *testing.T) {
read := func(r *http.Request) bool {
actual, err := io.ReadAll(r.Body)
require.NoError(t, err, "request should have body content")
require.EqualValues(t, expected, actual)
return true
}
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(
mock.WithHeader(shared.HeaderWWWAuthenticate, test.challenge),
mock.WithPredicate(read),
mock.WithStatusCode(http.StatusUnauthorized),
)
srv.AppendResponse()
srv.AppendResponse(mock.WithPredicate(read))
srv.AppendResponse()

called := false
o := &policy.BearerTokenOptions{}
if test.onChallenge {
o.AuthorizationHandler.OnChallenge = func(*policy.Request, *http.Response, func(policy.TokenRequestOptions) error) error {
called = true
return nil
}
}
b := NewBearerTokenPolicy(mockCredential{}, []string{scope}, o)
pl := newTestPipeline(&policy.ClientOptions{PerRetryPolicies: []policy.Policy{b}, Transport: srv})
req, err := NewRequest(context.Background(), http.MethodPost, srv.URL())
require.NoError(t, err)
require.NoError(t, req.SetBody(streaming.NopCloser(strings.NewReader(expected)), "text/plain"))

_, err = pl.Do(req)
require.NoError(t, err)
require.Equal(t, test.onChallenge, called, "policy should call OnChallenge when set")
})
}
}

func TestCheckHTTPSForAuth(t *testing.T) {
req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
require.NoError(t, err)
Expand Down

0 comments on commit 32f5e82

Please sign in to comment.