fix: secret key is no longer hard coded

the secret key for signing JWT tokens is now read from server.secret.

if that does not exist, then a random UUID v4 is generated and used
instead. a log warning is also shown.
This commit is contained in:
Jordan Knott 2020-09-12 18:03:17 -05:00
parent 9fdb3008db
commit 52c60abcd7
7 changed files with 40 additions and 22 deletions

View File

@ -1,4 +1,4 @@
[general] [server]
hostname = '0.0.0.0:3333' hostname = '0.0.0.0:3333'
[email_notifications] [email_notifications]

View File

@ -7,8 +7,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var jwtKey = []byte("taskcafe_test_key")
// RestrictedMode is used restrict JWT access to just the install route // RestrictedMode is used restrict JWT access to just the install route
type RestrictedMode string type RestrictedMode string
@ -54,7 +52,7 @@ func (r *ErrMalformedToken) Error() string {
} }
// NewAccessToken generates a new JWT access token with the correct claims // NewAccessToken generates a new JWT access token with the correct claims
func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string) (string, error) { func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string, jwtKey []byte) (string, error) {
role := RoleMember role := RoleMember
if orgRole == "admin" { if orgRole == "admin" {
role = RoleAdmin role = RoleAdmin
@ -76,7 +74,7 @@ func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string
} }
// NewAccessTokenCustomExpiration creates an access token with a custom duration // NewAccessTokenCustomExpiration creates an access token with a custom duration
func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, error) { func NewAccessTokenCustomExpiration(userID string, dur time.Duration, jwtKey []byte) (string, error) {
accessExpirationTime := time.Now().Add(dur) accessExpirationTime := time.Now().Add(dur)
accessClaims := &AccessTokenClaims{ accessClaims := &AccessTokenClaims{
UserID: userID, UserID: userID,
@ -94,7 +92,7 @@ func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, e
} }
// ValidateAccessToken validates a JWT access token and returns the contained claims or an error if it's invalid // ValidateAccessToken validates a JWT access token and returns the contained claims or an error if it's invalid
func ValidateAccessToken(accessTokenString string) (AccessTokenClaims, error) { func ValidateAccessToken(accessTokenString string, jwtKey []byte) (AccessTokenClaims, error) {
accessClaims := &AccessTokenClaims{} accessClaims := &AccessTokenClaims{}
accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) { accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) {
return jwtKey, nil return jwtKey, nil

View File

@ -1,12 +1,15 @@
package commands package commands
import ( import (
"errors"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/jordanknott/taskcafe/internal/auth" "github.com/jordanknott/taskcafe/internal/auth"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper"
) )
func newTokenCmd() *cobra.Command { func newTokenCmd() *cobra.Command {
@ -15,13 +18,18 @@ func newTokenCmd() *cobra.Command {
Short: "Create a long lived JWT token for dev purposes", Short: "Create a long lived JWT token for dev purposes",
Long: "Create a long lived JWT token for dev purposes", Long: "Create a long lived JWT token for dev purposes",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
token, err := auth.NewAccessTokenCustomExpiration(args[0], time.Hour*24) secret := viper.GetString("server.secret")
if strings.TrimSpace(secret) == "" {
return errors.New("server.secret must be set (TASKCAFE_SERVER_SECRET)")
}
token, err := auth.NewAccessTokenCustomExpiration(args[0], time.Hour*24, []byte(secret))
if err != nil { if err != nil {
log.WithError(err).Error("issue while creating access token") log.WithError(err).Error("issue while creating access token")
return return err
} }
fmt.Println(token) fmt.Println(token)
return nil
}, },
} }
} }

View File

@ -3,11 +3,13 @@ package commands
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres" "github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/source/httpfs" "github.com/golang-migrate/migrate/v4/source/httpfs"
"github.com/google/uuid"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -62,7 +64,12 @@ func newWebCmd() *cobra.Command {
} }
log.WithFields(log.Fields{"url": viper.GetString("server.hostname")}).Info("starting server") log.WithFields(log.Fields{"url": viper.GetString("server.hostname")}).Info("starting server")
r, _ := route.NewRouter(db) secret := viper.GetString("server.secret")
if strings.TrimSpace(secret) == "" {
log.Warn("server.secret is not set, generating a random secret")
secret = uuid.New().String()
}
r, _ := route.NewRouter(db, []byte(secret))
http.ListenAndServe(viper.GetString("server.hostname"), r) http.ListenAndServe(viper.GetString("server.hostname"), r)
return nil return nil
}, },

View File

@ -14,8 +14,6 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
var jwtKey = []byte("taskcafe_test_key")
type authResource struct{} type authResource struct{}
// LoginRequestData is the request data when a user logs in // LoginRequestData is the request data when a user logs in
@ -69,7 +67,7 @@ func (h *TaskcafeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Req
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.InstallOnly, user.RoleCode) accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.InstallOnly, user.RoleCode, h.jwtKey)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }
@ -123,7 +121,7 @@ func (h *TaskcafeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Req
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }
accessTokenString, err := auth.NewAccessToken(token.UserID.String(), auth.Unrestricted, user.RoleCode) accessTokenString, err := auth.NewAccessToken(token.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }
@ -190,7 +188,7 @@ func (h *TaskcafeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1) refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt}) refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode) accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }
@ -252,7 +250,7 @@ func (h *TaskcafeHandler) InstallHandler(w http.ResponseWriter, r *http.Request)
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt}) refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
log.WithField("userID", user.UserID.String()).Info("creating install access token") log.WithField("userID", user.UserID.String()).Info("creating install access token")
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode) accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }

View File

@ -12,7 +12,12 @@ import (
) )
// AuthenticationMiddleware is a middleware that requires a valid JWT token to be passed via the Authorization header // AuthenticationMiddleware is a middleware that requires a valid JWT token to be passed via the Authorization header
func AuthenticationMiddleware(next http.Handler) http.Handler { type AuthenticationMiddleware struct {
jwtKey []byte
}
// Middleware returns the middleware handler
func (m *AuthenticationMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bearerTokenRaw := r.Header.Get("Authorization") bearerTokenRaw := r.Header.Get("Authorization")
splitToken := strings.Split(bearerTokenRaw, "Bearer") splitToken := strings.Split(bearerTokenRaw, "Bearer")
@ -21,7 +26,7 @@ func AuthenticationMiddleware(next http.Handler) http.Handler {
return return
} }
accessTokenString := strings.TrimSpace(splitToken[1]) accessTokenString := strings.TrimSpace(splitToken[1])
accessClaims, err := auth.ValidateAccessToken(accessTokenString) accessClaims, err := auth.ValidateAccessToken(accessTokenString, m.jwtKey)
if err != nil { if err != nil {
if _, ok := err.(*auth.ErrExpiredToken); ok { if _, ok := err.(*auth.ErrExpiredToken); ok {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)

View File

@ -59,11 +59,12 @@ func (h FrontendHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// TaskcafeHandler contains all the route handlers // TaskcafeHandler contains all the route handlers
type TaskcafeHandler struct { type TaskcafeHandler struct {
repo db.Repository repo db.Repository
jwtKey []byte
} }
// NewRouter creates a new router for chi // NewRouter creates a new router for chi
func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) { func NewRouter(dbConnection *sqlx.DB, jwtKey []byte) (chi.Router, error) {
formatter := new(log.TextFormatter) formatter := new(log.TextFormatter)
formatter.TimestampFormat = "02-01-2006 15:04:05" formatter.TimestampFormat = "02-01-2006 15:04:05"
formatter.FullTimestamp = true formatter.FullTimestamp = true
@ -79,7 +80,7 @@ func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) {
r.Use(middleware.Timeout(60 * time.Second)) r.Use(middleware.Timeout(60 * time.Second))
repository := db.NewRepository(dbConnection) repository := db.NewRepository(dbConnection)
taskcafeHandler := TaskcafeHandler{*repository} taskcafeHandler := TaskcafeHandler{*repository, jwtKey}
var imgServer = http.FileServer(http.Dir("./uploads/")) var imgServer = http.FileServer(http.Dir("./uploads/"))
r.Group(func(mux chi.Router) { r.Group(func(mux chi.Router) {
@ -88,8 +89,9 @@ func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) {
mux.Mount("/uploads/", http.StripPrefix("/uploads/", imgServer)) mux.Mount("/uploads/", http.StripPrefix("/uploads/", imgServer))
}) })
auth := AuthenticationMiddleware{jwtKey}
r.Group(func(mux chi.Router) { r.Group(func(mux chi.Router) {
mux.Use(AuthenticationMiddleware) mux.Use(auth.Middleware)
mux.Post("/users/me/avatar", taskcafeHandler.ProfileImageUpload) mux.Post("/users/me/avatar", taskcafeHandler.ProfileImageUpload)
mux.Post("/auth/install", taskcafeHandler.InstallHandler) mux.Post("/auth/install", taskcafeHandler.InstallHandler)
mux.Handle("/graphql", graph.NewHandler(*repository)) mux.Handle("/graphql", graph.NewHandler(*repository))