fix: secret key is no longer hard coded
the secret key for signing JWT tokens is now read from server.secret. if that does not exist, then a random UUID v4 is generated and used instead. a log warning is also shown.
This commit is contained in:
parent
9fdb3008db
commit
52c60abcd7
@ -1,4 +1,4 @@
|
|||||||
[general]
|
[server]
|
||||||
hostname = '0.0.0.0:3333'
|
hostname = '0.0.0.0:3333'
|
||||||
|
|
||||||
[email_notifications]
|
[email_notifications]
|
||||||
|
@ -7,8 +7,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var jwtKey = []byte("taskcafe_test_key")
|
|
||||||
|
|
||||||
// RestrictedMode is used restrict JWT access to just the install route
|
// RestrictedMode is used restrict JWT access to just the install route
|
||||||
type RestrictedMode string
|
type RestrictedMode string
|
||||||
|
|
||||||
@ -54,7 +52,7 @@ func (r *ErrMalformedToken) Error() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewAccessToken generates a new JWT access token with the correct claims
|
// NewAccessToken generates a new JWT access token with the correct claims
|
||||||
func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string) (string, error) {
|
func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string, jwtKey []byte) (string, error) {
|
||||||
role := RoleMember
|
role := RoleMember
|
||||||
if orgRole == "admin" {
|
if orgRole == "admin" {
|
||||||
role = RoleAdmin
|
role = RoleAdmin
|
||||||
@ -76,7 +74,7 @@ func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewAccessTokenCustomExpiration creates an access token with a custom duration
|
// NewAccessTokenCustomExpiration creates an access token with a custom duration
|
||||||
func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, error) {
|
func NewAccessTokenCustomExpiration(userID string, dur time.Duration, jwtKey []byte) (string, error) {
|
||||||
accessExpirationTime := time.Now().Add(dur)
|
accessExpirationTime := time.Now().Add(dur)
|
||||||
accessClaims := &AccessTokenClaims{
|
accessClaims := &AccessTokenClaims{
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
@ -94,7 +92,7 @@ func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateAccessToken validates a JWT access token and returns the contained claims or an error if it's invalid
|
// ValidateAccessToken validates a JWT access token and returns the contained claims or an error if it's invalid
|
||||||
func ValidateAccessToken(accessTokenString string) (AccessTokenClaims, error) {
|
func ValidateAccessToken(accessTokenString string, jwtKey []byte) (AccessTokenClaims, error) {
|
||||||
accessClaims := &AccessTokenClaims{}
|
accessClaims := &AccessTokenClaims{}
|
||||||
accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) {
|
accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) {
|
||||||
return jwtKey, nil
|
return jwtKey, nil
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
package commands
|
package commands
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jordanknott/taskcafe/internal/auth"
|
"github.com/jordanknott/taskcafe/internal/auth"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTokenCmd() *cobra.Command {
|
func newTokenCmd() *cobra.Command {
|
||||||
@ -15,13 +18,18 @@ func newTokenCmd() *cobra.Command {
|
|||||||
Short: "Create a long lived JWT token for dev purposes",
|
Short: "Create a long lived JWT token for dev purposes",
|
||||||
Long: "Create a long lived JWT token for dev purposes",
|
Long: "Create a long lived JWT token for dev purposes",
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
token, err := auth.NewAccessTokenCustomExpiration(args[0], time.Hour*24)
|
secret := viper.GetString("server.secret")
|
||||||
|
if strings.TrimSpace(secret) == "" {
|
||||||
|
return errors.New("server.secret must be set (TASKCAFE_SERVER_SECRET)")
|
||||||
|
}
|
||||||
|
token, err := auth.NewAccessTokenCustomExpiration(args[0], time.Hour*24, []byte(secret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("issue while creating access token")
|
log.WithError(err).Error("issue while creating access token")
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
fmt.Println(token)
|
fmt.Println(token)
|
||||||
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,11 +3,13 @@ package commands
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-migrate/migrate/v4"
|
"github.com/golang-migrate/migrate/v4"
|
||||||
"github.com/golang-migrate/migrate/v4/database/postgres"
|
"github.com/golang-migrate/migrate/v4/database/postgres"
|
||||||
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
|
||||||
@ -62,7 +64,12 @@ func newWebCmd() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.WithFields(log.Fields{"url": viper.GetString("server.hostname")}).Info("starting server")
|
log.WithFields(log.Fields{"url": viper.GetString("server.hostname")}).Info("starting server")
|
||||||
r, _ := route.NewRouter(db)
|
secret := viper.GetString("server.secret")
|
||||||
|
if strings.TrimSpace(secret) == "" {
|
||||||
|
log.Warn("server.secret is not set, generating a random secret")
|
||||||
|
secret = uuid.New().String()
|
||||||
|
}
|
||||||
|
r, _ := route.NewRouter(db, []byte(secret))
|
||||||
http.ListenAndServe(viper.GetString("server.hostname"), r)
|
http.ListenAndServe(viper.GetString("server.hostname"), r)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -14,8 +14,6 @@ import (
|
|||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var jwtKey = []byte("taskcafe_test_key")
|
|
||||||
|
|
||||||
type authResource struct{}
|
type authResource struct{}
|
||||||
|
|
||||||
// LoginRequestData is the request data when a user logs in
|
// LoginRequestData is the request data when a user logs in
|
||||||
@ -69,7 +67,7 @@ func (h *TaskcafeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Req
|
|||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.InstallOnly, user.RoleCode)
|
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.InstallOnly, user.RoleCode, h.jwtKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -123,7 +121,7 @@ func (h *TaskcafeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Req
|
|||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
accessTokenString, err := auth.NewAccessToken(token.UserID.String(), auth.Unrestricted, user.RoleCode)
|
accessTokenString, err := auth.NewAccessToken(token.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -190,7 +188,7 @@ func (h *TaskcafeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
|
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
|
||||||
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
|
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
|
||||||
|
|
||||||
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode)
|
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
@ -252,7 +250,7 @@ func (h *TaskcafeHandler) InstallHandler(w http.ResponseWriter, r *http.Request)
|
|||||||
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
|
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
|
||||||
|
|
||||||
log.WithField("userID", user.UserID.String()).Info("creating install access token")
|
log.WithField("userID", user.UserID.String()).Info("creating install access token")
|
||||||
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode)
|
accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AuthenticationMiddleware is a middleware that requires a valid JWT token to be passed via the Authorization header
|
// AuthenticationMiddleware is a middleware that requires a valid JWT token to be passed via the Authorization header
|
||||||
func AuthenticationMiddleware(next http.Handler) http.Handler {
|
type AuthenticationMiddleware struct {
|
||||||
|
jwtKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware returns the middleware handler
|
||||||
|
func (m *AuthenticationMiddleware) Middleware(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
bearerTokenRaw := r.Header.Get("Authorization")
|
bearerTokenRaw := r.Header.Get("Authorization")
|
||||||
splitToken := strings.Split(bearerTokenRaw, "Bearer")
|
splitToken := strings.Split(bearerTokenRaw, "Bearer")
|
||||||
@ -21,7 +26,7 @@ func AuthenticationMiddleware(next http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
accessTokenString := strings.TrimSpace(splitToken[1])
|
accessTokenString := strings.TrimSpace(splitToken[1])
|
||||||
accessClaims, err := auth.ValidateAccessToken(accessTokenString)
|
accessClaims, err := auth.ValidateAccessToken(accessTokenString, m.jwtKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(*auth.ErrExpiredToken); ok {
|
if _, ok := err.(*auth.ErrExpiredToken); ok {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
@ -60,10 +60,11 @@ func (h FrontendHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
// TaskcafeHandler contains all the route handlers
|
// TaskcafeHandler contains all the route handlers
|
||||||
type TaskcafeHandler struct {
|
type TaskcafeHandler struct {
|
||||||
repo db.Repository
|
repo db.Repository
|
||||||
|
jwtKey []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRouter creates a new router for chi
|
// NewRouter creates a new router for chi
|
||||||
func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) {
|
func NewRouter(dbConnection *sqlx.DB, jwtKey []byte) (chi.Router, error) {
|
||||||
formatter := new(log.TextFormatter)
|
formatter := new(log.TextFormatter)
|
||||||
formatter.TimestampFormat = "02-01-2006 15:04:05"
|
formatter.TimestampFormat = "02-01-2006 15:04:05"
|
||||||
formatter.FullTimestamp = true
|
formatter.FullTimestamp = true
|
||||||
@ -79,7 +80,7 @@ func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) {
|
|||||||
r.Use(middleware.Timeout(60 * time.Second))
|
r.Use(middleware.Timeout(60 * time.Second))
|
||||||
|
|
||||||
repository := db.NewRepository(dbConnection)
|
repository := db.NewRepository(dbConnection)
|
||||||
taskcafeHandler := TaskcafeHandler{*repository}
|
taskcafeHandler := TaskcafeHandler{*repository, jwtKey}
|
||||||
|
|
||||||
var imgServer = http.FileServer(http.Dir("./uploads/"))
|
var imgServer = http.FileServer(http.Dir("./uploads/"))
|
||||||
r.Group(func(mux chi.Router) {
|
r.Group(func(mux chi.Router) {
|
||||||
@ -88,8 +89,9 @@ func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) {
|
|||||||
mux.Mount("/uploads/", http.StripPrefix("/uploads/", imgServer))
|
mux.Mount("/uploads/", http.StripPrefix("/uploads/", imgServer))
|
||||||
|
|
||||||
})
|
})
|
||||||
|
auth := AuthenticationMiddleware{jwtKey}
|
||||||
r.Group(func(mux chi.Router) {
|
r.Group(func(mux chi.Router) {
|
||||||
mux.Use(AuthenticationMiddleware)
|
mux.Use(auth.Middleware)
|
||||||
mux.Post("/users/me/avatar", taskcafeHandler.ProfileImageUpload)
|
mux.Post("/users/me/avatar", taskcafeHandler.ProfileImageUpload)
|
||||||
mux.Post("/auth/install", taskcafeHandler.InstallHandler)
|
mux.Post("/auth/install", taskcafeHandler.InstallHandler)
|
||||||
mux.Handle("/graphql", graph.NewHandler(*repository))
|
mux.Handle("/graphql", graph.NewHandler(*repository))
|
||||||
|
Loading…
Reference in New Issue
Block a user