Skip to content
This repository has been archived by the owner on Jan 24, 2019. It is now read-only.

Commit

Permalink
Merge pull request #81 from 18F/access-token-refactor
Browse files Browse the repository at this point in the history
Refactor pass_access_token changes from #80
  • Loading branch information
jehiah committed Apr 7, 2015
2 parents b0f0409 + 83ad43a commit 9534808
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 51 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ _testmain.go
*.exe
dist
.godeps

# Editor swap/temp files
.*.swp
31 changes: 31 additions & 0 deletions cookies.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,34 @@ func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (st

return string(encrypted_access_token), nil
}

func buildCookieValue(email string, aes_cipher cipher.Block,
access_token string) (string, error) {
if aes_cipher == nil {
return email, nil
}

encoded_token, err := encodeAccessToken(aes_cipher, access_token)
if err != nil {
return email, fmt.Errorf(
"error encoding access token for %s: %s", email, err)
}
return email + "|" + encoded_token, nil
}

func parseCookieValue(value string, aes_cipher cipher.Block) (email, user,
access_token string, err error) {
components := strings.Split(value, "|")
email = components[0]
user = strings.Split(email, "@")[0]

if aes_cipher != nil && len(components) == 2 {
access_token, err = decodeAccessToken(aes_cipher, components[1])
if err != nil {
err = fmt.Errorf(
"error decoding access token for %s: %s",
email, err)
}
}
return email, user, access_token, err
}
52 changes: 52 additions & 0 deletions cookies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"crypto/aes"
"github.com/bmizerany/assert"
"strings"
"testing"
)

Expand All @@ -21,3 +22,54 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
assert.NotEqual(t, access_token, encoded_token)
assert.Equal(t, access_token, decoded_token)
}

func TestBuildCookieValueWithoutAccessToken(t *testing.T) {
value, err := buildCookieValue("[email protected]", nil, "")
assert.Equal(t, nil, err)
assert.Equal(t, "[email protected]", value)
}

func TestBuildCookieValueWithAccessTokenAndNilCipher(t *testing.T) {
value, err := buildCookieValue("[email protected]", nil,
"access token")
assert.Equal(t, nil, err)
assert.Equal(t, "[email protected]", value)
}

func TestParseCookieValueWithoutAccessToken(t *testing.T) {
email, user, access_token, err := parseCookieValue(
"[email protected]", nil)
assert.Equal(t, nil, err)
assert.Equal(t, "[email protected]", email)
assert.Equal(t, "michael.bland", user)
assert.Equal(t, "", access_token)
}

func TestParseCookieValueWithAccessTokenAndNilCipher(t *testing.T) {
email, user, access_token, err := parseCookieValue(
"[email protected]|access_token", nil)
assert.Equal(t, nil, err)
assert.Equal(t, "[email protected]", email)
assert.Equal(t, "michael.bland", user)
assert.Equal(t, "", access_token)
}

func TestBuildAndParseCookieValueWithAccessToken(t *testing.T) {
aes_cipher, err := aes.NewCipher([]byte("0123456789abcdef"))
assert.Equal(t, nil, err)
value, err := buildCookieValue("[email protected]", aes_cipher,
"access_token")
assert.Equal(t, nil, err)

prefix := "[email protected]|"
if !strings.HasPrefix(value, prefix) {
t.Fatal("cookie value does not start with \"%s\": %s",
prefix, value)
}

email, user, access_token, err := parseCookieValue(value, aes_cipher)
assert.Equal(t, nil, err)
assert.Equal(t, "[email protected]", email)
assert.Equal(t, "michael.bland", user)
assert.Equal(t, "access_token", access_token)
}
45 changes: 10 additions & 35 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ type OauthProxy struct {
DisplayHtpasswdForm bool
serveMux http.Handler
PassBasicAuth bool
PassAccessToken bool
AesCipher cipher.Block
skipAuthRegex []string
compiledRegex []*regexp.Regexp
Expand Down Expand Up @@ -121,20 +120,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain)

var aes_cipher cipher.Block

if opts.PassAccessToken == true {
valid_cookie_secret_size := false
for _, i := range []int{16, 24, 32} {
if len(opts.CookieSecret) == i {
valid_cookie_secret_size = true
}
}
if valid_cookie_secret_size == false {
log.Fatal("cookie_secret must be 16, 24, or 32 bytes " +
"to create an AES cipher when " +
"pass_access_token == true")
}

if opts.PassAccessToken {
var err error
aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret))
if err != nil {
Expand Down Expand Up @@ -163,7 +149,6 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
skipAuthRegex: opts.SkipAuthRegex,
compiledRegex: opts.CompiledRegex,
PassBasicAuth: opts.PassBasicAuth,
PassAccessToken: opts.PassAccessToken,
AesCipher: aes_cipher,
templates: loadTemplates(opts.CustomTemplatesDir),
}
Expand Down Expand Up @@ -440,20 +425,12 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// set cookie, or deny
if p.Validator(email) {
log.Printf("%s authenticating %s completed", remoteAddr, email)
encoded_token := ""
if p.PassAccessToken {
encoded_token, err = encodeAccessToken(p.AesCipher, access_token)
if err != nil {
log.Printf("error encoding access token: %s", err)
}
}
access_token = ""

if encoded_token != "" {
p.SetCookie(rw, req, email+"|"+encoded_token)
} else {
p.SetCookie(rw, req, email)
value, err := buildCookieValue(
email, p.AesCipher, access_token)
if err != nil {
log.Printf(err.Error())
}
p.SetCookie(rw, req, value)
http.Redirect(rw, req, redirect, 302)
return
} else {
Expand All @@ -467,15 +444,13 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if err == nil {
var value string
value, ok = validateCookie(cookie, p.CookieSeed)
components := strings.Split(value, "|")
email = components[0]
if len(components) == 2 {
access_token, err = decodeAccessToken(p.AesCipher, components[1])
if ok {
email, user, access_token, err = parseCookieValue(
value, p.AesCipher)
if err != nil {
log.Printf("error decoding access token: %s", err)
log.Printf(err.Error())
}
}
user = strings.Split(email, "@")[0]
}
}

Expand Down
33 changes: 17 additions & 16 deletions oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,24 +152,25 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
return t
}

func Close(t *PassAccessTokenTest) {
t.provider_server.Close()
func (pat_test *PassAccessTokenTest) Close() {
pat_test.provider_server.Close()
}

func getCallbackEndpoint(pac_test *PassAccessTokenTest) (http_code int, cookie string) {
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
cookie string) {
rw := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
strings.NewReader(""))
if err != nil {
return 0, ""
}
pac_test.proxy.ServeHTTP(rw, req)
pat_test.proxy.ServeHTTP(rw, req)
return rw.Code, rw.HeaderMap["Set-Cookie"][0]
}

func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code int,
access_token string) {
cookie_key := pac_test.proxy.CookieKey
func (pat_test *PassAccessTokenTest) getRootEndpoint(
cookie string) (http_code int, access_token string) {
cookie_key := pat_test.proxy.CookieKey
var value string
key_prefix := cookie_key + "="

Expand Down Expand Up @@ -198,43 +199,43 @@ func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code in
})

rw := httptest.NewRecorder()
pac_test.proxy.ServeHTTP(rw, req)
pat_test.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String()
}

func TestForwardAccessTokenUpstream(t *testing.T) {
pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: true,
})
defer Close(pac_test)
defer pat_test.Close()

// A successful validation will redirect and set the auth cookie.
code, cookie := getCallbackEndpoint(pac_test)
code, cookie := pat_test.getCallbackEndpoint()
assert.Equal(t, 302, code)
assert.NotEqual(t, nil, cookie)

// Now we make a regular request; the access_token from the cookie is
// forwarded as the "X-Forwarded-Access-Token" header. The token is
// read by the test provider server and written in the response body.
code, payload := getRootEndpoint(pac_test, cookie)
code, payload := pat_test.getRootEndpoint(cookie)
assert.Equal(t, 200, code)
assert.Equal(t, "my_auth_token", payload)
}

func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: false,
})
defer Close(pac_test)
defer pat_test.Close()

// A successful validation will redirect and set the auth cookie.
code, cookie := getCallbackEndpoint(pac_test)
code, cookie := pat_test.getCallbackEndpoint()
assert.Equal(t, 302, code)
assert.NotEqual(t, nil, cookie)

// Now we make a regular request, but the access token header should
// not be present.
code, payload := getRootEndpoint(pac_test, cookie)
code, payload := pat_test.getRootEndpoint(cookie)
assert.Equal(t, 200, code)
assert.Equal(t, "No access token found.", payload)
}
Expand Down
17 changes: 17 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,23 @@ func (o *Options) Validate() error {
}
msgs = parseProviderInfo(o, msgs)

if o.PassAccessToken {
valid_cookie_secret_size := false
for _, i := range []int{16, 24, 32} {
if len(o.CookieSecret) == i {
valid_cookie_secret_size = true
}
}
if valid_cookie_secret_size == false {
msgs = append(msgs, fmt.Sprintf(
"cookie_secret must be 16, 24, or 32 bytes "+
"to create an AES cipher when "+
"pass_access_token == true, "+
"but is %d bytes",
len(o.CookieSecret)))
}
}

if len(msgs) != 0 {
return fmt.Errorf("Invalid configuration:\n %s",
strings.Join(msgs, "\n "))
Expand Down
19 changes: 19 additions & 0 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,22 @@ func TestDefaultProviderApiSettings(t *testing.T) {
assert.Equal(t, "", p.ProfileUrl.String())
assert.Equal(t, "profile email", p.Scope)
}

func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) {
o := testOptions()
assert.Equal(t, nil, o.Validate())

assert.Equal(t, false, o.PassAccessToken)
o.PassAccessToken = true
o.CookieSecret = "cookie of invalid length-"
assert.NotEqual(t, nil, o.Validate())

o.CookieSecret = "16 bytes AES-128"
assert.Equal(t, nil, o.Validate())

o.CookieSecret = "24 byte secret AES-192--"
assert.Equal(t, nil, o.Validate())

o.CookieSecret = "32 byte secret for AES-256------"
assert.Equal(t, nil, o.Validate())
}

0 comments on commit 9534808

Please sign in to comment.