refactor: replace refresh & access token with auth token only
changes authentication to no longer use a refresh token & access token for accessing protected endpoints. Instead only an auth token is used. Before the login flow was: Login -> get refresh (stored as HttpOnly cookie) + access token (stored in memory) -> protected endpoint request (attach access token as Authorization header) -> access token expires in 15 minutes, so use refresh token to obtain new one when that happens now it looks like this: Login -> get auth token (stored as HttpOnly cookie) -> make protected endpont request (token sent) the reasoning for using the refresh + access token was to reduce DB calls, but in the end I don't think its worth the hassle.
This commit is contained in:
@ -1,117 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// RestrictedMode is used restrict JWT access to just the install route
|
||||
type RestrictedMode string
|
||||
|
||||
const (
|
||||
// Unrestricted is the code to allow access to all routes
|
||||
Unrestricted RestrictedMode = "unrestricted"
|
||||
// InstallOnly is the code to restrict access ONLY to install route
|
||||
InstallOnly = "install_only"
|
||||
)
|
||||
|
||||
// Role is the role code for the user
|
||||
type Role string
|
||||
|
||||
const (
|
||||
// RoleAdmin is the code for the admin role
|
||||
RoleAdmin Role = "admin"
|
||||
// RoleMember is the code for the member role
|
||||
RoleMember Role = "member"
|
||||
)
|
||||
|
||||
// AccessTokenClaims is the claims the access JWT token contains
|
||||
type AccessTokenClaims struct {
|
||||
UserID string `json:"userId"`
|
||||
Restricted RestrictedMode `json:"restricted"`
|
||||
OrgRole Role `json:"orgRole"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
// ErrExpiredToken is the error returned if the token has expired
|
||||
type ErrExpiredToken struct{}
|
||||
|
||||
// Error returns the error message for ErrExpiredToken
|
||||
func (r *ErrExpiredToken) Error() string {
|
||||
return "token is expired"
|
||||
}
|
||||
|
||||
// ErrMalformedToken is the error returned if the token has malformed
|
||||
type ErrMalformedToken struct{}
|
||||
|
||||
// Error returns the error message for ErrMalformedToken
|
||||
func (r *ErrMalformedToken) Error() string {
|
||||
return "token is malformed"
|
||||
}
|
||||
|
||||
// NewAccessToken generates a new JWT access token with the correct claims
|
||||
func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string, jwtKey []byte, expirationTime time.Duration) (string, error) {
|
||||
role := RoleMember
|
||||
if orgRole == "admin" {
|
||||
role = RoleAdmin
|
||||
}
|
||||
accessExpirationTime := time.Now().Add(expirationTime)
|
||||
accessClaims := &AccessTokenClaims{
|
||||
UserID: userID,
|
||||
Restricted: restrictedMode,
|
||||
OrgRole: role,
|
||||
StandardClaims: jwt.StandardClaims{ExpiresAt: accessExpirationTime.Unix()},
|
||||
}
|
||||
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessTokenString, err := accessToken.SignedString(jwtKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accessTokenString, nil
|
||||
}
|
||||
|
||||
// NewAccessTokenCustomExpiration creates an access token with a custom duration
|
||||
func NewAccessTokenCustomExpiration(userID string, dur time.Duration, jwtKey []byte) (string, error) {
|
||||
accessExpirationTime := time.Now().Add(dur)
|
||||
accessClaims := &AccessTokenClaims{
|
||||
UserID: userID,
|
||||
Restricted: Unrestricted,
|
||||
OrgRole: RoleMember,
|
||||
StandardClaims: jwt.StandardClaims{ExpiresAt: accessExpirationTime.Unix()},
|
||||
}
|
||||
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessTokenString, err := accessToken.SignedString(jwtKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accessTokenString, nil
|
||||
}
|
||||
|
||||
// ValidateAccessToken validates a JWT access token and returns the contained claims or an error if it's invalid
|
||||
func ValidateAccessToken(accessTokenString string, jwtKey []byte) (AccessTokenClaims, error) {
|
||||
accessClaims := &AccessTokenClaims{}
|
||||
accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) {
|
||||
return jwtKey, nil
|
||||
})
|
||||
|
||||
if accessToken.Valid {
|
||||
log.WithFields(log.Fields{
|
||||
"token": accessTokenString,
|
||||
"timeToExpire": time.Unix(accessClaims.ExpiresAt, 0),
|
||||
}).Debug("token is valid")
|
||||
return *accessClaims, nil
|
||||
}
|
||||
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
if ve.Errors&(jwt.ValidationErrorMalformed|jwt.ValidationErrorSignatureInvalid) != 0 {
|
||||
return AccessTokenClaims{}, &ErrMalformedToken{}
|
||||
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
|
||||
return AccessTokenClaims{}, &ErrExpiredToken{}
|
||||
}
|
||||
}
|
||||
return AccessTokenClaims{}, err
|
||||
}
|
@ -1,56 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
// Override time value for jwt tests. Restore default value after.
|
||||
func at(t time.Time, f func()) {
|
||||
jwt.TimeFunc = func() time.Time {
|
||||
return t
|
||||
}
|
||||
f()
|
||||
jwt.TimeFunc = time.Now
|
||||
}
|
||||
|
||||
func TestAuth_ValidateAccessToken(t *testing.T) {
|
||||
expectedToken := AccessTokenClaims{
|
||||
UserID: "1234",
|
||||
Restricted: "unrestricted",
|
||||
OrgRole: "member",
|
||||
StandardClaims: jwt.StandardClaims{ExpiresAt: 1000},
|
||||
}
|
||||
// jwt with the claims of expectedToken signed by secretKey
|
||||
jwtString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiIxMjM0IiwicmVzdHJpY3RlZCI6InVucmVzdHJpY3RlZCIsIm9yZ1JvbGUiOiJtZW1iZXIiLCJleHAiOjEwMDB9.Zc4mrnogDccYffA7dWogdWsZMELftQluh2X5xDyzOpA"
|
||||
secretKey := []byte("secret")
|
||||
|
||||
// Check that decrypt failure is detected
|
||||
token, err := ValidateAccessToken(jwtString, []byte("incorrectSecret"))
|
||||
if err == nil {
|
||||
t.Errorf("[IncorrectKey] Expected an error when validating a token with the incorrect key, instead got token %v", token)
|
||||
} else if _, ok := err.(*ErrMalformedToken); !ok {
|
||||
t.Errorf("[IncorrectKey] Expected an ErrMalformedToken error when validating a token with the incorrect key, instead got error %T:%v", err, err)
|
||||
}
|
||||
|
||||
// Check that token expiration check works
|
||||
token, err = ValidateAccessToken(jwtString, secretKey)
|
||||
if err == nil {
|
||||
t.Errorf("[TokenExpired] Expected an error when validating an expired token, instead got token %v", token)
|
||||
} else if _, ok := err.(*ErrExpiredToken); !ok {
|
||||
t.Errorf("[TokenExpired] Expected an ErrExpiredToken error when validating an expired token, instead got error %T:%v", err, err)
|
||||
}
|
||||
|
||||
// Check that token validation works with a valid token
|
||||
// Set the time to be valid for the token expiration
|
||||
at(time.Unix(500, 0), func() {
|
||||
token, err = ValidateAccessToken(jwtString, secretKey)
|
||||
if err != nil {
|
||||
t.Errorf("[TokenValid] Expected no errors when validating token, instead got err %v", err)
|
||||
} else if token != expectedToken {
|
||||
t.Errorf("[TokenValid] Expected token with claims %v but instead had claims %v", expectedToken, token)
|
||||
}
|
||||
})
|
||||
}
|
@ -86,6 +86,6 @@ func Execute() {
|
||||
viper.SetDefault("queue.store", "memcache://localhost:11211")
|
||||
|
||||
rootCmd.SetVersionTemplate(VersionTemplate())
|
||||
rootCmd.AddCommand(newWebCmd(), newMigrateCmd(), newTokenCmd(), newWorkerCmd(), newResetPasswordCmd(), newSeedCmd())
|
||||
rootCmd.AddCommand(newWebCmd(), newMigrateCmd(), newWorkerCmd(), newResetPasswordCmd(), newSeedCmd())
|
||||
rootCmd.Execute()
|
||||
}
|
||||
|
@ -1,35 +0,0 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jordanknott/taskcafe/internal/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func newTokenCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "token",
|
||||
Short: "Create a long lived JWT token for dev purposes",
|
||||
Long: "Create a long lived JWT token for dev purposes",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
secret := viper.GetString("server.secret")
|
||||
if strings.TrimSpace(secret) == "" {
|
||||
return errors.New("server.secret must be set (TASKCAFE_SERVER_SECRET)")
|
||||
}
|
||||
token, err := auth.NewAccessTokenCustomExpiration(args[0], time.Hour*24, []byte(secret))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("issue while creating access token")
|
||||
return err
|
||||
}
|
||||
fmt.Println(token)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
@ -10,6 +10,13 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuthToken struct {
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type LabelColor struct {
|
||||
LabelColorID uuid.UUID `json:"label_color_id"`
|
||||
ColorHex string `json:"color_hex"`
|
||||
@ -74,13 +81,6 @@ type ProjectMemberInvited struct {
|
||||
UserAccountInvitedID uuid.UUID `json:"user_account_invited_id"`
|
||||
}
|
||||
|
||||
type RefreshToken struct {
|
||||
TokenID uuid.UUID `json:"token_id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type Role struct {
|
||||
Code string `json:"code"`
|
||||
Name string `json:"name"`
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
)
|
||||
|
||||
type Querier interface {
|
||||
CreateAuthToken(ctx context.Context, arg CreateAuthTokenParams) (AuthToken, error)
|
||||
CreateConfirmToken(ctx context.Context, email string) (UserAccountConfirmToken, error)
|
||||
CreateInvitedProjectMember(ctx context.Context, arg CreateInvitedProjectMemberParams) (ProjectMemberInvited, error)
|
||||
CreateInvitedUser(ctx context.Context, email string) (UserAccountInvited, error)
|
||||
@ -20,7 +21,6 @@ type Querier interface {
|
||||
CreatePersonalProjectLink(ctx context.Context, arg CreatePersonalProjectLinkParams) (PersonalProject, error)
|
||||
CreateProjectLabel(ctx context.Context, arg CreateProjectLabelParams) (ProjectLabel, error)
|
||||
CreateProjectMember(ctx context.Context, arg CreateProjectMemberParams) (ProjectMember, error)
|
||||
CreateRefreshToken(ctx context.Context, arg CreateRefreshTokenParams) (RefreshToken, error)
|
||||
CreateSystemOption(ctx context.Context, arg CreateSystemOptionParams) (SystemOption, error)
|
||||
CreateTask(ctx context.Context, arg CreateTaskParams) (Task, error)
|
||||
CreateTaskActivity(ctx context.Context, arg CreateTaskActivityParams) (TaskActivity, error)
|
||||
@ -35,6 +35,8 @@ type Querier interface {
|
||||
CreateTeamMember(ctx context.Context, arg CreateTeamMemberParams) (TeamMember, error)
|
||||
CreateTeamProject(ctx context.Context, arg CreateTeamProjectParams) (Project, error)
|
||||
CreateUserAccount(ctx context.Context, arg CreateUserAccountParams) (UserAccount, error)
|
||||
DeleteAuthTokenByID(ctx context.Context, tokenID uuid.UUID) error
|
||||
DeleteAuthTokenByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteConfirmTokenForEmail(ctx context.Context, email string) error
|
||||
DeleteExpiredTokens(ctx context.Context) error
|
||||
DeleteInvitedProjectMemberByID(ctx context.Context, projectMemberInvitedID uuid.UUID) error
|
||||
@ -43,8 +45,6 @@ type Querier interface {
|
||||
DeleteProjectLabelByID(ctx context.Context, projectLabelID uuid.UUID) error
|
||||
DeleteProjectMember(ctx context.Context, arg DeleteProjectMemberParams) error
|
||||
DeleteProjectMemberInvitedForEmail(ctx context.Context, email string) error
|
||||
DeleteRefreshTokenByID(ctx context.Context, tokenID uuid.UUID) error
|
||||
DeleteRefreshTokenByUserID(ctx context.Context, userID uuid.UUID) error
|
||||
DeleteTaskAssignedByID(ctx context.Context, arg DeleteTaskAssignedByIDParams) (TaskAssigned, error)
|
||||
DeleteTaskByID(ctx context.Context, taskID uuid.UUID) error
|
||||
DeleteTaskChecklistByID(ctx context.Context, taskChecklistID uuid.UUID) error
|
||||
@ -71,6 +71,7 @@ type Querier interface {
|
||||
GetAssignedMembersForTask(ctx context.Context, taskID uuid.UUID) ([]TaskAssigned, error)
|
||||
GetAssignedTasksDueDateForUserID(ctx context.Context, arg GetAssignedTasksDueDateForUserIDParams) ([]Task, error)
|
||||
GetAssignedTasksProjectForUserID(ctx context.Context, arg GetAssignedTasksProjectForUserIDParams) ([]Task, error)
|
||||
GetAuthTokenByID(ctx context.Context, tokenID uuid.UUID) (AuthToken, error)
|
||||
GetCommentsForTaskID(ctx context.Context, taskID uuid.UUID) ([]TaskComment, error)
|
||||
GetConfirmTokenByEmail(ctx context.Context, email string) (UserAccountConfirmToken, error)
|
||||
GetConfirmTokenByID(ctx context.Context, confirmTokenID uuid.UUID) (UserAccountConfirmToken, error)
|
||||
@ -100,7 +101,6 @@ type Querier interface {
|
||||
GetProjectRolesForUserID(ctx context.Context, userID uuid.UUID) ([]GetProjectRolesForUserIDRow, error)
|
||||
GetProjectsForInvitedMember(ctx context.Context, email string) ([]uuid.UUID, error)
|
||||
GetRecentlyAssignedTaskForUserID(ctx context.Context, arg GetRecentlyAssignedTaskForUserIDParams) ([]Task, error)
|
||||
GetRefreshTokenByID(ctx context.Context, tokenID uuid.UUID) (RefreshToken, error)
|
||||
GetRoleForProjectMemberByUserID(ctx context.Context, arg GetRoleForProjectMemberByUserIDParams) (Role, error)
|
||||
GetRoleForTeamMember(ctx context.Context, arg GetRoleForTeamMemberParams) (Role, error)
|
||||
GetRoleForUserID(ctx context.Context, userID uuid.UUID) (GetRoleForUserIDRow, error)
|
||||
|
@ -1,14 +1,14 @@
|
||||
-- name: GetRefreshTokenByID :one
|
||||
SELECT * FROM refresh_token WHERE token_id = $1;
|
||||
-- name: GetAuthTokenByID :one
|
||||
SELECT * FROM auth_token WHERE token_id = $1;
|
||||
|
||||
-- name: CreateRefreshToken :one
|
||||
INSERT INTO refresh_token (user_id, created_at, expires_at) VALUES ($1, $2, $3) RETURNING *;
|
||||
-- name: CreateAuthToken :one
|
||||
INSERT INTO auth_token (user_id, created_at, expires_at) VALUES ($1, $2, $3) RETURNING *;
|
||||
|
||||
-- name: DeleteRefreshTokenByID :exec
|
||||
DELETE FROM refresh_token WHERE token_id = $1;
|
||||
-- name: DeleteAuthTokenByID :exec
|
||||
DELETE FROM auth_token WHERE token_id = $1;
|
||||
|
||||
-- name: DeleteRefreshTokenByUserID :exec
|
||||
DELETE FROM refresh_token WHERE user_id = $1;
|
||||
-- name: DeleteAuthTokenByUserID :exec
|
||||
DELETE FROM auth_token WHERE user_id = $1;
|
||||
|
||||
-- name: DeleteExpiredTokens :exec
|
||||
DELETE FROM refresh_token WHERE expires_at <= NOW();
|
||||
DELETE FROM auth_token WHERE expires_at <= NOW();
|
||||
|
@ -10,19 +10,19 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const createRefreshToken = `-- name: CreateRefreshToken :one
|
||||
INSERT INTO refresh_token (user_id, created_at, expires_at) VALUES ($1, $2, $3) RETURNING token_id, user_id, created_at, expires_at
|
||||
const createAuthToken = `-- name: CreateAuthToken :one
|
||||
INSERT INTO auth_token (user_id, created_at, expires_at) VALUES ($1, $2, $3) RETURNING token_id, user_id, created_at, expires_at
|
||||
`
|
||||
|
||||
type CreateRefreshTokenParams struct {
|
||||
type CreateAuthTokenParams struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateRefreshToken(ctx context.Context, arg CreateRefreshTokenParams) (RefreshToken, error) {
|
||||
row := q.db.QueryRowContext(ctx, createRefreshToken, arg.UserID, arg.CreatedAt, arg.ExpiresAt)
|
||||
var i RefreshToken
|
||||
func (q *Queries) CreateAuthToken(ctx context.Context, arg CreateAuthTokenParams) (AuthToken, error) {
|
||||
row := q.db.QueryRowContext(ctx, createAuthToken, arg.UserID, arg.CreatedAt, arg.ExpiresAt)
|
||||
var i AuthToken
|
||||
err := row.Scan(
|
||||
&i.TokenID,
|
||||
&i.UserID,
|
||||
@ -32,8 +32,26 @@ func (q *Queries) CreateRefreshToken(ctx context.Context, arg CreateRefreshToken
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteAuthTokenByID = `-- name: DeleteAuthTokenByID :exec
|
||||
DELETE FROM auth_token WHERE token_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteAuthTokenByID(ctx context.Context, tokenID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteAuthTokenByID, tokenID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteAuthTokenByUserID = `-- name: DeleteAuthTokenByUserID :exec
|
||||
DELETE FROM auth_token WHERE user_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteAuthTokenByUserID(ctx context.Context, userID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteAuthTokenByUserID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteExpiredTokens = `-- name: DeleteExpiredTokens :exec
|
||||
DELETE FROM refresh_token WHERE expires_at <= NOW()
|
||||
DELETE FROM auth_token WHERE expires_at <= NOW()
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteExpiredTokens(ctx context.Context) error {
|
||||
@ -41,31 +59,13 @@ func (q *Queries) DeleteExpiredTokens(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteRefreshTokenByID = `-- name: DeleteRefreshTokenByID :exec
|
||||
DELETE FROM refresh_token WHERE token_id = $1
|
||||
const getAuthTokenByID = `-- name: GetAuthTokenByID :one
|
||||
SELECT token_id, user_id, created_at, expires_at FROM auth_token WHERE token_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteRefreshTokenByID(ctx context.Context, tokenID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteRefreshTokenByID, tokenID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteRefreshTokenByUserID = `-- name: DeleteRefreshTokenByUserID :exec
|
||||
DELETE FROM refresh_token WHERE user_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteRefreshTokenByUserID(ctx context.Context, userID uuid.UUID) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteRefreshTokenByUserID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
const getRefreshTokenByID = `-- name: GetRefreshTokenByID :one
|
||||
SELECT token_id, user_id, created_at, expires_at FROM refresh_token WHERE token_id = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetRefreshTokenByID(ctx context.Context, tokenID uuid.UUID) (RefreshToken, error) {
|
||||
row := q.db.QueryRowContext(ctx, getRefreshTokenByID, tokenID)
|
||||
var i RefreshToken
|
||||
func (q *Queries) GetAuthTokenByID(ctx context.Context, tokenID uuid.UUID) (AuthToken, error) {
|
||||
row := q.db.QueryRowContext(ctx, getAuthTokenByID, tokenID)
|
||||
var i AuthToken
|
||||
err := row.Scan(
|
||||
&i.TokenID,
|
||||
&i.UserID,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -17,7 +17,6 @@ import (
|
||||
"github.com/99designs/gqlgen/graphql/handler/transport"
|
||||
"github.com/99designs/gqlgen/graphql/playground"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jordanknott/taskcafe/internal/auth"
|
||||
"github.com/jordanknott/taskcafe/internal/db"
|
||||
"github.com/jordanknott/taskcafe/internal/logger"
|
||||
"github.com/jordanknott/taskcafe/internal/utils"
|
||||
@ -34,15 +33,18 @@ func NewHandler(repo db.Repository, emailConfig utils.EmailConfig) http.Handler
|
||||
},
|
||||
}
|
||||
c.Directives.HasRole = func(ctx context.Context, obj interface{}, next graphql.Resolver, roles []RoleLevel, level ActionLevel, typeArg ObjectType) (interface{}, error) {
|
||||
role, ok := GetUserRole(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("user ID is missing")
|
||||
}
|
||||
if role == "admin" {
|
||||
return next(ctx)
|
||||
} else if level == ActionLevelOrg {
|
||||
return nil, errors.New("must be an org admin")
|
||||
}
|
||||
/*
|
||||
TODO: add permission check
|
||||
role, ok := GetUserRole(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("user ID is missing")
|
||||
}
|
||||
if role == "admin" {
|
||||
return next(ctx)
|
||||
} else if level == ActionLevelOrg {
|
||||
return nil, errors.New("must be an org admin")
|
||||
}
|
||||
*/
|
||||
|
||||
var subjectID uuid.UUID
|
||||
in := graphql.GetFieldContext(ctx).Args["input"]
|
||||
@ -76,7 +78,7 @@ func NewHandler(repo db.Repository, emailConfig utils.EmailConfig) http.Handler
|
||||
// TODO: add config setting to disable personal projects
|
||||
return next(ctx)
|
||||
}
|
||||
subjectID, ok = subjectField.Interface().(uuid.UUID)
|
||||
subjectID, ok := subjectField.Interface().(uuid.UUID)
|
||||
if !ok {
|
||||
logger.New(ctx).Error("error while casting subject UUID")
|
||||
return nil, errors.New("error while casting subject uuid")
|
||||
@ -190,23 +192,10 @@ func GetUserID(ctx context.Context) (uuid.UUID, bool) {
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// GetUserRole retrieves the user role out of a context
|
||||
func GetUserRole(ctx context.Context) (auth.Role, bool) {
|
||||
role, ok := ctx.Value(utils.OrgRoleKey).(auth.Role)
|
||||
return role, ok
|
||||
}
|
||||
|
||||
// GetUser retrieves both the user id & user role out of a context
|
||||
func GetUser(ctx context.Context) (uuid.UUID, auth.Role, bool) {
|
||||
func GetUser(ctx context.Context) (uuid.UUID, bool) {
|
||||
userID, userOK := GetUserID(ctx)
|
||||
role, roleOK := GetUserRole(ctx)
|
||||
return userID, role, userOK && roleOK
|
||||
}
|
||||
|
||||
// GetRestrictedMode retrieves the restricted mode code out of a context
|
||||
func GetRestrictedMode(ctx context.Context) (auth.RestrictedMode, bool) {
|
||||
restricted, ok := ctx.Value(utils.RestrictedModeKey).(auth.RestrictedMode)
|
||||
return restricted, ok
|
||||
return userID, userOK
|
||||
}
|
||||
|
||||
// GetProjectRoles retrieves the team & project role for the given project ID
|
||||
|
@ -255,6 +255,7 @@ type LogoutUser struct {
|
||||
|
||||
type MePayload struct {
|
||||
User *db.UserAccount `json:"user"`
|
||||
Organization *RoleCode `json:"organization"`
|
||||
TeamRoles []TeamRole `json:"teamRoles"`
|
||||
ProjectRoles []ProjectRole `json:"projectRoles"`
|
||||
}
|
||||
@ -312,10 +313,6 @@ type NewProjectLabel struct {
|
||||
Name *string `json:"name"`
|
||||
}
|
||||
|
||||
type NewRefreshToken struct {
|
||||
UserID uuid.UUID `json:"userID"`
|
||||
}
|
||||
|
||||
type NewTask struct {
|
||||
TaskGroupID uuid.UUID `json:"taskGroupID"`
|
||||
Name string `json:"name"`
|
||||
@ -382,6 +379,12 @@ type ProfileIcon struct {
|
||||
BgColor *string `json:"bgColor"`
|
||||
}
|
||||
|
||||
type ProjectPermission struct {
|
||||
Team RoleCode `json:"team"`
|
||||
Project RoleCode `json:"project"`
|
||||
Org RoleCode `json:"org"`
|
||||
}
|
||||
|
||||
type ProjectRole struct {
|
||||
ProjectID uuid.UUID `json:"projectID"`
|
||||
RoleCode RoleCode `json:"roleCode"`
|
||||
@ -435,6 +438,11 @@ type TaskPositionUpdate struct {
|
||||
Position float64 `json:"position"`
|
||||
}
|
||||
|
||||
type TeamPermission struct {
|
||||
Team RoleCode `json:"team"`
|
||||
Org RoleCode `json:"org"`
|
||||
}
|
||||
|
||||
type TeamRole struct {
|
||||
TeamID uuid.UUID `json:"teamID"`
|
||||
RoleCode RoleCode `json:"roleCode"`
|
||||
|
@ -50,13 +50,6 @@ type Member {
|
||||
member: MemberList!
|
||||
}
|
||||
|
||||
type RefreshToken {
|
||||
id: ID!
|
||||
userId: UUID!
|
||||
expiresAt: Time!
|
||||
createdAt: Time!
|
||||
}
|
||||
|
||||
type Role {
|
||||
code: String!
|
||||
name: String!
|
||||
@ -97,6 +90,7 @@ type Team {
|
||||
id: ID!
|
||||
createdAt: Time!
|
||||
name: String!
|
||||
permission: TeamPermission!
|
||||
members: [Member!]!
|
||||
}
|
||||
|
||||
@ -106,6 +100,17 @@ type InvitedMember {
|
||||
invitedOn: Time!
|
||||
}
|
||||
|
||||
type TeamPermission {
|
||||
team: RoleCode!
|
||||
org: RoleCode!
|
||||
}
|
||||
|
||||
type ProjectPermission {
|
||||
team: RoleCode!
|
||||
project: RoleCode!
|
||||
org: RoleCode!
|
||||
}
|
||||
|
||||
type Project {
|
||||
id: ID!
|
||||
createdAt: Time!
|
||||
@ -114,6 +119,7 @@ type Project {
|
||||
taskGroups: [TaskGroup!]!
|
||||
members: [Member!]!
|
||||
invitedMembers: [InvitedMember!]!
|
||||
permission: ProjectPermission!
|
||||
labels: [ProjectLabel!]!
|
||||
}
|
||||
|
||||
@ -314,6 +320,7 @@ type ProjectRole {
|
||||
|
||||
type MePayload {
|
||||
user: UserAccount!
|
||||
organization: RoleCode
|
||||
teamRoles: [TeamRole!]!
|
||||
projectRoles: [ProjectRole!]!
|
||||
}
|
||||
@ -881,7 +888,6 @@ type UpdateTeamMemberRolePayload {
|
||||
}
|
||||
|
||||
extend type Mutation {
|
||||
createRefreshToken(input: NewRefreshToken!): RefreshToken!
|
||||
createUserAccount(input: NewUserAccount!):
|
||||
UserAccount! @hasRole(roles: [ADMIN], level: ORG, type: ORG)
|
||||
deleteUserAccount(input: DeleteUserAccount!):
|
||||
@ -954,10 +960,6 @@ type UpdateUserRolePayload {
|
||||
user: UserAccount!
|
||||
}
|
||||
|
||||
input NewRefreshToken {
|
||||
userID: UUID!
|
||||
}
|
||||
|
||||
input NewUserAccount {
|
||||
username: String!
|
||||
email: String!
|
||||
|
@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jinzhu/now"
|
||||
"github.com/jordanknott/taskcafe/internal/auth"
|
||||
"github.com/jordanknott/taskcafe/internal/db"
|
||||
"github.com/jordanknott/taskcafe/internal/logger"
|
||||
"github.com/jordanknott/taskcafe/internal/utils"
|
||||
@ -864,11 +863,12 @@ func (r *mutationResolver) DeleteTeam(ctx context.Context, input DeleteTeam) (*D
|
||||
}
|
||||
|
||||
func (r *mutationResolver) CreateTeam(ctx context.Context, input NewTeam) (*db.Team, error) {
|
||||
_, role, ok := GetUser(ctx)
|
||||
_, ok := GetUser(ctx)
|
||||
if !ok {
|
||||
return &db.Team{}, nil
|
||||
}
|
||||
if role == auth.RoleAdmin {
|
||||
// if role == auth.RoleAdmin { // TODO: add permision check
|
||||
if true {
|
||||
createdAt := time.Now().UTC()
|
||||
team, err := r.Repository.CreateTeam(ctx, db.CreateTeamParams{OrganizationID: input.OrganizationID, CreatedAt: createdAt, Name: input.Name})
|
||||
return &team, err
|
||||
@ -944,20 +944,13 @@ func (r *mutationResolver) DeleteTeamMember(ctx context.Context, input DeleteTea
|
||||
return &DeleteTeamMemberPayload{TeamID: input.TeamID, UserID: input.UserID}, err
|
||||
}
|
||||
|
||||
func (r *mutationResolver) CreateRefreshToken(ctx context.Context, input NewRefreshToken) (*db.RefreshToken, error) {
|
||||
userID := uuid.MustParse("0183d9ab-d0ed-4c9b-a3df-77a0cdd93dca")
|
||||
refreshCreatedAt := time.Now().UTC()
|
||||
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
|
||||
refreshToken, err := r.Repository.CreateRefreshToken(ctx, db.CreateRefreshTokenParams{userID, refreshCreatedAt, refreshExpiresAt})
|
||||
return &refreshToken, err
|
||||
}
|
||||
|
||||
func (r *mutationResolver) CreateUserAccount(ctx context.Context, input NewUserAccount) (*db.UserAccount, error) {
|
||||
_, role, ok := GetUser(ctx)
|
||||
_, ok := GetUser(ctx)
|
||||
if !ok {
|
||||
return &db.UserAccount{}, nil
|
||||
}
|
||||
if role != auth.RoleAdmin {
|
||||
// if role != auth.RoleAdmin { TODO: add permsion check
|
||||
if true {
|
||||
return &db.UserAccount{}, &gqlerror.Error{
|
||||
Message: "Must be an organization admin",
|
||||
Extensions: map[string]interface{}{
|
||||
@ -984,11 +977,12 @@ func (r *mutationResolver) CreateUserAccount(ctx context.Context, input NewUserA
|
||||
}
|
||||
|
||||
func (r *mutationResolver) DeleteUserAccount(ctx context.Context, input DeleteUserAccount) (*DeleteUserAccountPayload, error) {
|
||||
_, role, ok := GetUser(ctx)
|
||||
_, ok := GetUser(ctx)
|
||||
if !ok {
|
||||
return &DeleteUserAccountPayload{Ok: false}, nil
|
||||
}
|
||||
if role != auth.RoleAdmin {
|
||||
// if role != auth.RoleAdmin { TODO: add permision check
|
||||
if true {
|
||||
return &DeleteUserAccountPayload{Ok: false}, &gqlerror.Error{
|
||||
Message: "User not found",
|
||||
Extensions: map[string]interface{}{
|
||||
@ -1030,7 +1024,7 @@ func (r *mutationResolver) DeleteInvitedUserAccount(ctx context.Context, input D
|
||||
}
|
||||
|
||||
func (r *mutationResolver) LogoutUser(ctx context.Context, input LogoutUser) (bool, error) {
|
||||
err := r.Repository.DeleteRefreshTokenByUserID(ctx, input.UserID)
|
||||
err := r.Repository.DeleteAuthTokenByUserID(ctx, input.UserID)
|
||||
return true, err
|
||||
}
|
||||
|
||||
@ -1059,11 +1053,12 @@ func (r *mutationResolver) UpdateUserPassword(ctx context.Context, input UpdateU
|
||||
}
|
||||
|
||||
func (r *mutationResolver) UpdateUserRole(ctx context.Context, input UpdateUserRole) (*UpdateUserRolePayload, error) {
|
||||
_, role, ok := GetUser(ctx)
|
||||
_, ok := GetUser(ctx)
|
||||
if !ok {
|
||||
return &UpdateUserRolePayload{}, nil
|
||||
}
|
||||
if role != auth.RoleAdmin {
|
||||
// if role != auth.RoleAdmin { TODO: add permision check
|
||||
if true {
|
||||
return &UpdateUserRolePayload{}, &gqlerror.Error{
|
||||
Message: "User not found",
|
||||
Extensions: map[string]interface{}{
|
||||
@ -1211,6 +1206,10 @@ func (r *projectResolver) InvitedMembers(ctx context.Context, obj *db.Project) (
|
||||
return invited, err
|
||||
}
|
||||
|
||||
func (r *projectResolver) Permission(ctx context.Context, obj *db.Project) (*ProjectPermission, error) {
|
||||
panic(fmt.Errorf("not implemented"))
|
||||
}
|
||||
|
||||
func (r *projectResolver) Labels(ctx context.Context, obj *db.Project) ([]db.ProjectLabel, error) {
|
||||
labels, err := r.Repository.GetProjectLabelsForProject(ctx, obj.ProjectID)
|
||||
return labels, err
|
||||
@ -1296,7 +1295,7 @@ func (r *queryResolver) FindTask(ctx context.Context, input FindTask) (*db.Task,
|
||||
}
|
||||
|
||||
func (r *queryResolver) Projects(ctx context.Context, input *ProjectsFilter) ([]db.Project, error) {
|
||||
userID, orgRole, ok := GetUser(ctx)
|
||||
userID, ok := GetUser(ctx)
|
||||
if !ok {
|
||||
logger.New(ctx).Info("user id was not found from middleware")
|
||||
return []db.Project{}, nil
|
||||
@ -1309,11 +1308,14 @@ func (r *queryResolver) Projects(ctx context.Context, input *ProjectsFilter) ([]
|
||||
|
||||
var teams []db.Team
|
||||
var err error
|
||||
/* TODO: add permsion check
|
||||
if orgRole == "admin" {
|
||||
teams, err = r.Repository.GetAllTeams(ctx)
|
||||
} else {
|
||||
teams, err = r.Repository.GetTeamsForUserIDWhereAdmin(ctx, userID)
|
||||
}
|
||||
*/
|
||||
teams, err = r.Repository.GetAllTeams(ctx)
|
||||
|
||||
projects := make(map[string]db.Project)
|
||||
for _, team := range teams {
|
||||
@ -1359,15 +1361,18 @@ func (r *queryResolver) FindTeam(ctx context.Context, input FindTeam) (*db.Team,
|
||||
}
|
||||
|
||||
func (r *queryResolver) Teams(ctx context.Context) ([]db.Team, error) {
|
||||
userID, orgRole, ok := GetUser(ctx)
|
||||
userID, ok := GetUser(ctx)
|
||||
if !ok {
|
||||
logger.New(ctx).Error("userID or org role does not exist")
|
||||
return []db.Team{}, errors.New("internal error")
|
||||
}
|
||||
if orgRole == "admin" {
|
||||
|
||||
return r.Repository.GetAllTeams(ctx)
|
||||
}
|
||||
/*
|
||||
TODO: add permision check
|
||||
if orgRole == "admin" {
|
||||
return r.Repository.GetAllTeams(ctx)
|
||||
}
|
||||
*/
|
||||
|
||||
teams := make(map[string]db.Team)
|
||||
adminTeams, err := r.Repository.GetTeamsForUserIDWhereAdmin(ctx, userID)
|
||||
@ -1596,10 +1601,6 @@ func (r *queryResolver) SearchMembers(ctx context.Context, input MemberSearchFil
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *refreshTokenResolver) ID(ctx context.Context, obj *db.RefreshToken) (uuid.UUID, error) {
|
||||
return obj.TokenID, nil
|
||||
}
|
||||
|
||||
func (r *taskResolver) ID(ctx context.Context, obj *db.Task) (uuid.UUID, error) {
|
||||
return obj.TaskID, nil
|
||||
}
|
||||
@ -1848,6 +1849,10 @@ func (r *teamResolver) ID(ctx context.Context, obj *db.Team) (uuid.UUID, error)
|
||||
return obj.TeamID, nil
|
||||
}
|
||||
|
||||
func (r *teamResolver) Permission(ctx context.Context, obj *db.Team) (*TeamPermission, error) {
|
||||
panic(fmt.Errorf("not implemented"))
|
||||
}
|
||||
|
||||
func (r *teamResolver) Members(ctx context.Context, obj *db.Team) ([]Member, error) {
|
||||
members := []Member{}
|
||||
|
||||
@ -1966,9 +1971,6 @@ func (r *Resolver) ProjectLabel() ProjectLabelResolver { return &projectLabelRes
|
||||
// Query returns QueryResolver implementation.
|
||||
func (r *Resolver) Query() QueryResolver { return &queryResolver{r} }
|
||||
|
||||
// RefreshToken returns RefreshTokenResolver implementation.
|
||||
func (r *Resolver) RefreshToken() RefreshTokenResolver { return &refreshTokenResolver{r} }
|
||||
|
||||
// Task returns TaskResolver implementation.
|
||||
func (r *Resolver) Task() TaskResolver { return &taskResolver{r} }
|
||||
|
||||
@ -2005,7 +2007,6 @@ type organizationResolver struct{ *Resolver }
|
||||
type projectResolver struct{ *Resolver }
|
||||
type projectLabelResolver struct{ *Resolver }
|
||||
type queryResolver struct{ *Resolver }
|
||||
type refreshTokenResolver struct{ *Resolver }
|
||||
type taskResolver struct{ *Resolver }
|
||||
type taskActivityResolver struct{ *Resolver }
|
||||
type taskChecklistResolver struct{ *Resolver }
|
||||
|
@ -50,13 +50,6 @@ type Member {
|
||||
member: MemberList!
|
||||
}
|
||||
|
||||
type RefreshToken {
|
||||
id: ID!
|
||||
userId: UUID!
|
||||
expiresAt: Time!
|
||||
createdAt: Time!
|
||||
}
|
||||
|
||||
type Role {
|
||||
code: String!
|
||||
name: String!
|
||||
@ -97,6 +90,7 @@ type Team {
|
||||
id: ID!
|
||||
createdAt: Time!
|
||||
name: String!
|
||||
permission: TeamPermission!
|
||||
members: [Member!]!
|
||||
}
|
||||
|
||||
@ -106,6 +100,17 @@ type InvitedMember {
|
||||
invitedOn: Time!
|
||||
}
|
||||
|
||||
type TeamPermission {
|
||||
team: RoleCode!
|
||||
org: RoleCode!
|
||||
}
|
||||
|
||||
type ProjectPermission {
|
||||
team: RoleCode!
|
||||
project: RoleCode!
|
||||
org: RoleCode!
|
||||
}
|
||||
|
||||
type Project {
|
||||
id: ID!
|
||||
createdAt: Time!
|
||||
@ -114,6 +119,7 @@ type Project {
|
||||
taskGroups: [TaskGroup!]!
|
||||
members: [Member!]!
|
||||
invitedMembers: [InvitedMember!]!
|
||||
permission: ProjectPermission!
|
||||
labels: [ProjectLabel!]!
|
||||
}
|
||||
|
||||
|
@ -90,6 +90,7 @@ type ProjectRole {
|
||||
|
||||
type MePayload {
|
||||
user: UserAccount!
|
||||
organization: RoleCode
|
||||
teamRoles: [TeamRole!]!
|
||||
projectRoles: [ProjectRole!]!
|
||||
}
|
||||
|
@ -1,5 +1,4 @@
|
||||
extend type Mutation {
|
||||
createRefreshToken(input: NewRefreshToken!): RefreshToken!
|
||||
createUserAccount(input: NewUserAccount!):
|
||||
UserAccount! @hasRole(roles: [ADMIN], level: ORG, type: ORG)
|
||||
deleteUserAccount(input: DeleteUserAccount!):
|
||||
@ -72,10 +71,6 @@ type UpdateUserRolePayload {
|
||||
user: UserAccount!
|
||||
}
|
||||
|
||||
input NewRefreshToken {
|
||||
userID: UUID!
|
||||
}
|
||||
|
||||
input NewUserAccount {
|
||||
username: String!
|
||||
email: String!
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jordanknott/taskcafe/internal/auth"
|
||||
"github.com/jordanknott/taskcafe/internal/db"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@ -54,10 +53,15 @@ type Setup struct {
|
||||
ConfirmToken string `json:"confirmToken"`
|
||||
}
|
||||
|
||||
type ValidateAuthTokenResponse struct {
|
||||
Valid bool `json:"valid"`
|
||||
UserID string `json:"userID"`
|
||||
}
|
||||
|
||||
// LoginResponseData is the response data for when a user logs in
|
||||
type LoginResponseData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
Setup bool `json:"setup"`
|
||||
UserID string `json:"userID"`
|
||||
Complete bool `json:"complete"`
|
||||
}
|
||||
|
||||
// LogoutResponseData is the response data for when a user logs out
|
||||
@ -65,8 +69,8 @@ type LogoutResponseData struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// RefreshTokenResponseData is the response data for when an access token is refreshed
|
||||
type RefreshTokenResponseData struct {
|
||||
// AuthTokenResponseData is the response data for when an access token is refreshed
|
||||
type AuthTokenResponseData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
}
|
||||
|
||||
@ -76,93 +80,9 @@ type AvatarUploadResponseData struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// RefreshTokenHandler handles when a user attempts to refresh token
|
||||
func (h *TaskcafeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userExists, err := h.repo.HasAnyUser(r.Context())
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
log.WithError(err).Error("issue while fetching if user accounts exist")
|
||||
return
|
||||
}
|
||||
|
||||
log.WithField("userExists", userExists).Info("checking if setup")
|
||||
if !userExists {
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
json.NewEncoder(w).Encode(LoginResponseData{AccessToken: "", Setup: true})
|
||||
return
|
||||
}
|
||||
|
||||
c, err := r.Cookie("refreshToken")
|
||||
if err != nil {
|
||||
if err == http.ErrNoCookie {
|
||||
log.Warn("no cookie")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
log.WithError(err).Error("unknown error")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
refreshTokenID := uuid.MustParse(c.Value)
|
||||
token, err := h.repo.GetRefreshTokenByID(r.Context(), refreshTokenID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
|
||||
log.WithError(err).WithFields(log.Fields{"refreshTokenID": refreshTokenID.String()}).Error("no tokens found")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
log.WithError(err).Error("token retrieve failure")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.repo.GetUserAccountByID(r.Context(), token.UserID)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("user retrieve failure")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.Active {
|
||||
log.WithFields(log.Fields{
|
||||
"username": user.Username,
|
||||
}).Warn("attempt to refresh token with inactive user")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
refreshCreatedAt := time.Now().UTC()
|
||||
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
|
||||
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{token.UserID, refreshCreatedAt, refreshExpiresAt})
|
||||
|
||||
err = h.repo.DeleteRefreshTokenByID(r.Context(), token.TokenID)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("here 1")
|
||||
accessTokenString, err := auth.NewAccessToken(token.UserID.String(), auth.Unrestricted, user.RoleCode, h.SecurityConfig.Secret, h.SecurityConfig.AccessTokenExpiration)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("here 2")
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refreshToken",
|
||||
Value: refreshTokenString.TokenID.String(),
|
||||
Expires: refreshExpiresAt,
|
||||
HttpOnly: true,
|
||||
})
|
||||
json.NewEncoder(w).Encode(LoginResponseData{AccessToken: accessTokenString, Setup: false})
|
||||
}
|
||||
|
||||
// LogoutHandler removes all refresh tokens to log out user
|
||||
func (h *TaskcafeHandler) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := r.Cookie("refreshToken")
|
||||
c, err := r.Cookie("authToken")
|
||||
if err != nil {
|
||||
if err == http.ErrNoCookie {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@ -172,7 +92,7 @@ func (h *TaskcafeHandler) LogoutHandler(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
refreshTokenID := uuid.MustParse(c.Value)
|
||||
err = h.repo.DeleteRefreshTokenByID(r.Context(), refreshTokenID)
|
||||
err = h.repo.DeleteAuthTokenByID(r.Context(), refreshTokenID)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
@ -216,87 +136,23 @@ func (h *TaskcafeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
refreshCreatedAt := time.Now().UTC()
|
||||
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
|
||||
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
|
||||
authCreatedAt := time.Now().UTC()
|
||||
authExpiresAt := authCreatedAt.AddDate(0, 0, 1)
|
||||
authToken, err := h.repo.CreateAuthToken(r.Context(), db.CreateAuthTokenParams{user.UserID, authCreatedAt, authExpiresAt})
|
||||
|
||||
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.SecurityConfig.Secret, h.SecurityConfig.AccessTokenExpiration)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refreshToken",
|
||||
Value: refreshTokenString.TokenID.String(),
|
||||
Expires: refreshExpiresAt,
|
||||
Name: "authToken",
|
||||
Value: authToken.TokenID.String(),
|
||||
Expires: authExpiresAt,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
})
|
||||
json.NewEncoder(w).Encode(LoginResponseData{accessTokenString, false})
|
||||
}
|
||||
|
||||
// TODO: remove
|
||||
// InstallHandler creates first user on fresh install
|
||||
func (h *TaskcafeHandler) InstallHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if restricted, ok := r.Context().Value("restricted_mode").(auth.RestrictedMode); ok {
|
||||
if restricted != auth.InstallOnly {
|
||||
log.Warning("attempted to install without install only restriction")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, err := h.repo.GetSystemOptionByKey(r.Context(), "is_installed")
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("install handler called even though system is installed")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var requestData InstallRequestData
|
||||
err = json.NewDecoder(r.Body).Decode(&requestData)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
createdAt := time.Now().UTC()
|
||||
hashedPwd, err := bcrypt.GenerateFromPassword([]byte(requestData.User.Password), 14)
|
||||
user, err := h.repo.CreateUserAccount(r.Context(), db.CreateUserAccountParams{
|
||||
FullName: requestData.User.FullName,
|
||||
Username: requestData.User.Username,
|
||||
Initials: requestData.User.Initials,
|
||||
Email: requestData.User.Email,
|
||||
PasswordHash: string(hashedPwd),
|
||||
CreatedAt: createdAt,
|
||||
RoleCode: "admin",
|
||||
})
|
||||
|
||||
_, err = h.repo.CreateSystemOption(r.Context(), db.CreateSystemOptionParams{Key: "is_installed", Value: sql.NullString{Valid: true, String: "true"}})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
refreshCreatedAt := time.Now().UTC()
|
||||
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
|
||||
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
|
||||
|
||||
log.WithField("userID", user.UserID.String()).Info("creating install access token")
|
||||
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.SecurityConfig.Secret, h.SecurityConfig.AccessTokenExpiration)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
log.Info(accessTokenString)
|
||||
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refreshToken",
|
||||
Value: refreshTokenString.TokenID.String(),
|
||||
Expires: refreshExpiresAt,
|
||||
HttpOnly: true,
|
||||
})
|
||||
json.NewEncoder(w).Encode(LoginResponseData{accessTokenString, false})
|
||||
json.NewEncoder(w).Encode(LoginResponseData{Complete: true, UserID: authToken.UserID.String()})
|
||||
}
|
||||
|
||||
func (h *TaskcafeHandler) ConfirmUser(w http.ResponseWriter, r *http.Request) {
|
||||
@ -382,23 +238,43 @@ func (h *TaskcafeHandler) ConfirmUser(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
}
|
||||
|
||||
refreshCreatedAt := time.Now().UTC()
|
||||
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
|
||||
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
|
||||
authCreatedAt := time.Now().UTC()
|
||||
authExpiresAt := authCreatedAt.AddDate(0, 0, 1)
|
||||
authToken, err := h.repo.CreateAuthToken(r.Context(), db.CreateAuthTokenParams{user.UserID, authCreatedAt, authExpiresAt})
|
||||
|
||||
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.SecurityConfig.Secret, h.SecurityConfig.AccessTokenExpiration)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refreshToken",
|
||||
Value: refreshTokenString.TokenID.String(),
|
||||
Expires: refreshExpiresAt,
|
||||
Name: "authToken",
|
||||
Value: authToken.TokenID.String(),
|
||||
Path: "/",
|
||||
Expires: authExpiresAt,
|
||||
HttpOnly: true,
|
||||
})
|
||||
json.NewEncoder(w).Encode(LoginResponseData{accessTokenString, false})
|
||||
json.NewEncoder(w).Encode(LoginResponseData{Complete: true, UserID: authToken.UserID.String()})
|
||||
}
|
||||
func (h *TaskcafeHandler) ValidateAuthTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := r.Cookie("authToken")
|
||||
if err != nil {
|
||||
if err == http.ErrNoCookie {
|
||||
json.NewEncoder(w).Encode(ValidateAuthTokenResponse{Valid: false, UserID: ""})
|
||||
return
|
||||
}
|
||||
log.WithError(err).Error("unknown error")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
authTokenID := uuid.MustParse(c.Value)
|
||||
token, err := h.repo.GetAuthTokenByID(r.Context(), authTokenID)
|
||||
if err != nil {
|
||||
json.NewEncoder(w).Encode(ValidateAuthTokenResponse{Valid: false, UserID: ""})
|
||||
} else {
|
||||
json.NewEncoder(w).Encode(ValidateAuthTokenResponse{Valid: true, UserID: token.UserID.String()})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TaskcafeHandler) RegisterUser(w http.ResponseWriter, r *http.Request) {
|
||||
@ -425,6 +301,7 @@ func (h *TaskcafeHandler) RegisterUser(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error checking for active user")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !hasActiveUser {
|
||||
json.NewEncoder(w).Encode(RegisteredUserResponseData{Setup: true})
|
||||
@ -469,7 +346,7 @@ func (h *TaskcafeHandler) RegisterUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (rs authResource) Routes(taskcafeHandler TaskcafeHandler) chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Post("/login", taskcafeHandler.LoginHandler)
|
||||
r.Post("/refresh_token", taskcafeHandler.RefreshTokenHandler)
|
||||
r.Post("/logout", taskcafeHandler.LogoutHandler)
|
||||
r.Post("/validate", taskcafeHandler.ValidateAuthTokenHandler)
|
||||
return r
|
||||
}
|
||||
|
@ -3,35 +3,51 @@ package route
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jordanknott/taskcafe/internal/auth"
|
||||
"github.com/jordanknott/taskcafe/internal/db"
|
||||
"github.com/jordanknott/taskcafe/internal/utils"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// AuthenticationMiddleware is a middleware that requires a valid JWT token to be passed via the Authorization header
|
||||
type AuthenticationMiddleware struct {
|
||||
jwtKey []byte
|
||||
repo db.Repository
|
||||
}
|
||||
|
||||
// Middleware returns the middleware handler
|
||||
func (m *AuthenticationMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := uuid.New()
|
||||
bearerTokenRaw := r.Header.Get("Authorization")
|
||||
splitToken := strings.Split(bearerTokenRaw, "Bearer")
|
||||
if len(splitToken) != 2 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
accessTokenString := strings.TrimSpace(splitToken[1])
|
||||
accessClaims, err := auth.ValidateAccessToken(accessTokenString, m.jwtKey)
|
||||
foundToken := true
|
||||
tokenRaw := ""
|
||||
c, err := r.Cookie("authToken")
|
||||
if err != nil {
|
||||
if _, ok := err.(*auth.ErrExpiredToken); ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{
|
||||
if err == http.ErrNoCookie {
|
||||
foundToken = false
|
||||
} else {
|
||||
log.WithError(err).Error("unknown error")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !foundToken {
|
||||
token := r.Header.Get("Authorization")
|
||||
if token == "" {
|
||||
log.WithError(err).Error("no auth token found in cookie or authorization header")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
tokenRaw = token
|
||||
} else {
|
||||
tokenRaw = c.Value
|
||||
}
|
||||
authTokenID := uuid.MustParse(tokenRaw)
|
||||
token, err := m.repo.GetAuthTokenByID(r.Context(), authTokenID)
|
||||
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{
|
||||
"data": {},
|
||||
"errors": [
|
||||
{
|
||||
@ -41,27 +57,12 @@ func (m *AuthenticationMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
}
|
||||
]
|
||||
}`))
|
||||
return
|
||||
}
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var userID uuid.UUID
|
||||
if accessClaims.Restricted == auth.InstallOnly {
|
||||
userID = uuid.New()
|
||||
} else {
|
||||
userID, err = uuid.Parse(accessClaims.UserID)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("middleware access token userID parse")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), utils.UserIDKey, userID)
|
||||
ctx = context.WithValue(ctx, utils.RestrictedModeKey, accessClaims.Restricted)
|
||||
ctx = context.WithValue(ctx, utils.OrgRoleKey, accessClaims.OrgRole)
|
||||
ctx := context.WithValue(r.Context(), utils.UserIDKey, token.UserID)
|
||||
// ctx = context.WithValue(ctx, utils.RestrictedModeKey, accessClaims.Restricted)
|
||||
// ctx = context.WithValue(ctx, utils.OrgRoleKey, accessClaims.OrgRole)
|
||||
ctx = context.WithValue(ctx, utils.ReqIDKey, requestID)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/jmoiron/sqlx"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@ -80,6 +81,17 @@ func NewRouter(dbConnection *sqlx.DB, emailConfig utils.EmailConfig, securityCon
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.Timeout(60 * time.Second))
|
||||
|
||||
r.Use(cors.Handler(cors.Options{
|
||||
// AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts
|
||||
AllowedOrigins: []string{"https://*", "http://*"},
|
||||
// AllowOriginFunc: func(r *http.Request, origin string) bool { return true },
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Cookie", "Content-Type", "X-CSRF-Token"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 300, // Maximum value not ignored by any of major browsers
|
||||
}))
|
||||
|
||||
repository := db.NewRepository(dbConnection)
|
||||
taskcafeHandler := TaskcafeHandler{*repository, securityConfig}
|
||||
|
||||
@ -91,7 +103,7 @@ func NewRouter(dbConnection *sqlx.DB, emailConfig utils.EmailConfig, securityCon
|
||||
mux.Post("/auth/confirm", taskcafeHandler.ConfirmUser)
|
||||
mux.Post("/auth/register", taskcafeHandler.RegisterUser)
|
||||
})
|
||||
auth := AuthenticationMiddleware{securityConfig.Secret}
|
||||
auth := AuthenticationMiddleware{*repository}
|
||||
r.Group(func(mux chi.Router) {
|
||||
mux.Use(auth.Middleware)
|
||||
mux.Post("/users/me/avatar", taskcafeHandler.ProfileImageUpload)
|
||||
|
Reference in New Issue
Block a user