diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 75ea34a..149b339 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -98,10 +98,6 @@ func ValidateAccessToken(accessTokenString string, jwtKey []byte) (AccessTokenCl return jwtKey, nil }) - if err != nil { - return *accessClaims, nil - } - if accessToken.Valid { log.WithFields(log.Fields{ "token": accessTokenString, @@ -111,7 +107,7 @@ func ValidateAccessToken(accessTokenString string, jwtKey []byte) (AccessTokenCl } if ve, ok := err.(*jwt.ValidationError); ok { - if ve.Errors&jwt.ValidationErrorMalformed != 0 { + if ve.Errors&(jwt.ValidationErrorMalformed|jwt.ValidationErrorSignatureInvalid) != 0 { return AccessTokenClaims{}, &ErrMalformedToken{} } else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { return AccessTokenClaims{}, &ErrExpiredToken{} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..866e54d --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,56 @@ +package auth + +import ( + "testing" + "time" + + "github.com/dgrijalva/jwt-go" +) + +// Override time value for jwt tests. Restore default value after. +func at(t time.Time, f func()) { + jwt.TimeFunc = func() time.Time { + return t + } + f() + jwt.TimeFunc = time.Now +} + +func TestAuth_ValidateAccessToken(t *testing.T) { + expectedToken := AccessTokenClaims{ + UserID: "1234", + Restricted: "unrestricted", + OrgRole: "member", + StandardClaims: jwt.StandardClaims{ExpiresAt: 1000}, + } + // jwt with the claims of expectedToken signed by secretKey + jwtString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiIxMjM0IiwicmVzdHJpY3RlZCI6InVucmVzdHJpY3RlZCIsIm9yZ1JvbGUiOiJtZW1iZXIiLCJleHAiOjEwMDB9.Zc4mrnogDccYffA7dWogdWsZMELftQluh2X5xDyzOpA" + secretKey := []byte("secret") + + // Check that decrypt failure is detected + token, err := ValidateAccessToken(jwtString, []byte("incorrectSecret")) + if err == nil { + t.Errorf("[IncorrectKey] Expected an error when validating a token with the incorrect key, instead got token %v", token) + } else if _, ok := err.(*ErrMalformedToken); !ok { + t.Errorf("[IncorrectKey] Expected an ErrMalformedToken error when validating a token with the incorrect key, instead got error %T:%v", err, err) + } + + // Check that token expiration check works + token, err = ValidateAccessToken(jwtString, secretKey) + if err == nil { + t.Errorf("[TokenExpired] Expected an error when validating an expired token, instead got token %v", token) + } else if _, ok := err.(*ErrExpiredToken); !ok { + t.Errorf("[TokenExpired] Expected an ErrExpiredToken error when validating an expired token, instead got error %T:%v", err, err) + } + + // Check that token validation works with a valid token + // Set the time to be valid for the token expiration + at(time.Unix(500, 0), func() { + token, err = ValidateAccessToken(jwtString, secretKey) + if err != nil { + t.Errorf("[TokenValid] Expected no errors when validating token, instead got err %v", err) + } else if token != expectedToken { + t.Errorf("[TokenValid] Expected token with claims %v but instead had claims %v", expectedToken, token) + } + }) +} \ No newline at end of file diff --git a/magefile.go b/magefile.go index c3efec4..24845db 100644 --- a/magefile.go +++ b/magefile.go @@ -118,6 +118,11 @@ func (Backend) Schema() error { return sh.Run("gqlgen") } +func (Backend) Test() error { + fmt.Println("running taskcafe backend unit tests") + return sh.RunV("go", "test", "./...") +} + // Install runs frontend:install func Install() { mg.SerialDeps(Frontend.Install)