fix: secret key is no longer hard coded
the secret key for signing JWT tokens is now read from server.secret. if that does not exist, then a random UUID v4 is generated and used instead. a log warning is also shown.
This commit is contained in:
		@@ -1,4 +1,4 @@
 | 
				
			|||||||
[general]
 | 
					[server]
 | 
				
			||||||
hostname = '0.0.0.0:3333'
 | 
					hostname = '0.0.0.0:3333'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[email_notifications]
 | 
					[email_notifications]
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,8 +7,6 @@ import (
 | 
				
			|||||||
	log "github.com/sirupsen/logrus"
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var jwtKey = []byte("taskcafe_test_key")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// RestrictedMode is used restrict JWT access to just the install route
 | 
					// RestrictedMode is used restrict JWT access to just the install route
 | 
				
			||||||
type RestrictedMode string
 | 
					type RestrictedMode string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -54,7 +52,7 @@ func (r *ErrMalformedToken) Error() string {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewAccessToken generates a new JWT access token with the correct claims
 | 
					// NewAccessToken generates a new JWT access token with the correct claims
 | 
				
			||||||
func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string) (string, error) {
 | 
					func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string, jwtKey []byte) (string, error) {
 | 
				
			||||||
	role := RoleMember
 | 
						role := RoleMember
 | 
				
			||||||
	if orgRole == "admin" {
 | 
						if orgRole == "admin" {
 | 
				
			||||||
		role = RoleAdmin
 | 
							role = RoleAdmin
 | 
				
			||||||
@@ -76,7 +74,7 @@ func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewAccessTokenCustomExpiration creates an access token with a custom duration
 | 
					// NewAccessTokenCustomExpiration creates an access token with a custom duration
 | 
				
			||||||
func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, error) {
 | 
					func NewAccessTokenCustomExpiration(userID string, dur time.Duration, jwtKey []byte) (string, error) {
 | 
				
			||||||
	accessExpirationTime := time.Now().Add(dur)
 | 
						accessExpirationTime := time.Now().Add(dur)
 | 
				
			||||||
	accessClaims := &AccessTokenClaims{
 | 
						accessClaims := &AccessTokenClaims{
 | 
				
			||||||
		UserID:         userID,
 | 
							UserID:         userID,
 | 
				
			||||||
@@ -94,7 +92,7 @@ func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, e
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ValidateAccessToken validates a JWT access token and returns the contained claims or an error if it's invalid
 | 
					// ValidateAccessToken validates a JWT access token and returns the contained claims or an error if it's invalid
 | 
				
			||||||
func ValidateAccessToken(accessTokenString string) (AccessTokenClaims, error) {
 | 
					func ValidateAccessToken(accessTokenString string, jwtKey []byte) (AccessTokenClaims, error) {
 | 
				
			||||||
	accessClaims := &AccessTokenClaims{}
 | 
						accessClaims := &AccessTokenClaims{}
 | 
				
			||||||
	accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) {
 | 
						accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) {
 | 
				
			||||||
		return jwtKey, nil
 | 
							return jwtKey, nil
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,12 +1,15 @@
 | 
				
			|||||||
package commands
 | 
					package commands
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/jordanknott/taskcafe/internal/auth"
 | 
						"github.com/jordanknott/taskcafe/internal/auth"
 | 
				
			||||||
	log "github.com/sirupsen/logrus"
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
	"github.com/spf13/cobra"
 | 
						"github.com/spf13/cobra"
 | 
				
			||||||
 | 
						"github.com/spf13/viper"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func newTokenCmd() *cobra.Command {
 | 
					func newTokenCmd() *cobra.Command {
 | 
				
			||||||
@@ -15,13 +18,18 @@ func newTokenCmd() *cobra.Command {
 | 
				
			|||||||
		Short: "Create a long lived JWT token for dev purposes",
 | 
							Short: "Create a long lived JWT token for dev purposes",
 | 
				
			||||||
		Long:  "Create a long lived JWT token for dev purposes",
 | 
							Long:  "Create a long lived JWT token for dev purposes",
 | 
				
			||||||
		Args:  cobra.ExactArgs(1),
 | 
							Args:  cobra.ExactArgs(1),
 | 
				
			||||||
		Run: func(cmd *cobra.Command, args []string) {
 | 
							RunE: func(cmd *cobra.Command, args []string) error {
 | 
				
			||||||
			token, err := auth.NewAccessTokenCustomExpiration(args[0], time.Hour*24)
 | 
								secret := viper.GetString("server.secret")
 | 
				
			||||||
 | 
								if strings.TrimSpace(secret) == "" {
 | 
				
			||||||
 | 
									return errors.New("server.secret must be set (TASKCAFE_SERVER_SECRET)")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								token, err := auth.NewAccessTokenCustomExpiration(args[0], time.Hour*24, []byte(secret))
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				log.WithError(err).Error("issue while creating access token")
 | 
									log.WithError(err).Error("issue while creating access token")
 | 
				
			||||||
				return
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			fmt.Println(token)
 | 
								fmt.Println(token)
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,11 +3,13 @@ package commands
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/golang-migrate/migrate/v4"
 | 
						"github.com/golang-migrate/migrate/v4"
 | 
				
			||||||
	"github.com/golang-migrate/migrate/v4/database/postgres"
 | 
						"github.com/golang-migrate/migrate/v4/database/postgres"
 | 
				
			||||||
	"github.com/golang-migrate/migrate/v4/source/httpfs"
 | 
						"github.com/golang-migrate/migrate/v4/source/httpfs"
 | 
				
			||||||
 | 
						"github.com/google/uuid"
 | 
				
			||||||
	"github.com/spf13/cobra"
 | 
						"github.com/spf13/cobra"
 | 
				
			||||||
	"github.com/spf13/viper"
 | 
						"github.com/spf13/viper"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -62,7 +64,12 @@ func newWebCmd() *cobra.Command {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			log.WithFields(log.Fields{"url": viper.GetString("server.hostname")}).Info("starting server")
 | 
								log.WithFields(log.Fields{"url": viper.GetString("server.hostname")}).Info("starting server")
 | 
				
			||||||
			r, _ := route.NewRouter(db)
 | 
								secret := viper.GetString("server.secret")
 | 
				
			||||||
 | 
								if strings.TrimSpace(secret) == "" {
 | 
				
			||||||
 | 
									log.Warn("server.secret is not set, generating a random secret")
 | 
				
			||||||
 | 
									secret = uuid.New().String()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								r, _ := route.NewRouter(db, []byte(secret))
 | 
				
			||||||
			http.ListenAndServe(viper.GetString("server.hostname"), r)
 | 
								http.ListenAndServe(viper.GetString("server.hostname"), r)
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,8 +14,6 @@ import (
 | 
				
			|||||||
	"golang.org/x/crypto/bcrypt"
 | 
						"golang.org/x/crypto/bcrypt"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var jwtKey = []byte("taskcafe_test_key")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type authResource struct{}
 | 
					type authResource struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// LoginRequestData is the request data when a user logs in
 | 
					// LoginRequestData is the request data when a user logs in
 | 
				
			||||||
@@ -69,7 +67,7 @@ func (h *TaskcafeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Req
 | 
				
			|||||||
			w.WriteHeader(http.StatusInternalServerError)
 | 
								w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.InstallOnly, user.RoleCode)
 | 
							accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.InstallOnly, user.RoleCode, h.jwtKey)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			w.WriteHeader(http.StatusInternalServerError)
 | 
								w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -123,7 +121,7 @@ func (h *TaskcafeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Req
 | 
				
			|||||||
		w.WriteHeader(http.StatusInternalServerError)
 | 
							w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	accessTokenString, err := auth.NewAccessToken(token.UserID.String(), auth.Unrestricted, user.RoleCode)
 | 
						accessTokenString, err := auth.NewAccessToken(token.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		w.WriteHeader(http.StatusInternalServerError)
 | 
							w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -190,7 +188,7 @@ func (h *TaskcafeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
 | 
						refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
 | 
				
			||||||
	refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
 | 
						refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode)
 | 
						accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		w.WriteHeader(http.StatusInternalServerError)
 | 
							w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -252,7 +250,7 @@ func (h *TaskcafeHandler) InstallHandler(w http.ResponseWriter, r *http.Request)
 | 
				
			|||||||
	refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
 | 
						refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.WithField("userID", user.UserID.String()).Info("creating install access token")
 | 
						log.WithField("userID", user.UserID.String()).Info("creating install access token")
 | 
				
			||||||
	accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode)
 | 
						accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		w.WriteHeader(http.StatusInternalServerError)
 | 
							w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,12 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// AuthenticationMiddleware is a middleware that requires a valid JWT token to be passed via the Authorization header
 | 
					// AuthenticationMiddleware is a middleware that requires a valid JWT token to be passed via the Authorization header
 | 
				
			||||||
func AuthenticationMiddleware(next http.Handler) http.Handler {
 | 
					type AuthenticationMiddleware struct {
 | 
				
			||||||
 | 
						jwtKey []byte
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Middleware returns the middleware handler
 | 
				
			||||||
 | 
					func (m *AuthenticationMiddleware) Middleware(next http.Handler) http.Handler {
 | 
				
			||||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
						return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
		bearerTokenRaw := r.Header.Get("Authorization")
 | 
							bearerTokenRaw := r.Header.Get("Authorization")
 | 
				
			||||||
		splitToken := strings.Split(bearerTokenRaw, "Bearer")
 | 
							splitToken := strings.Split(bearerTokenRaw, "Bearer")
 | 
				
			||||||
@@ -21,7 +26,7 @@ func AuthenticationMiddleware(next http.Handler) http.Handler {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		accessTokenString := strings.TrimSpace(splitToken[1])
 | 
							accessTokenString := strings.TrimSpace(splitToken[1])
 | 
				
			||||||
		accessClaims, err := auth.ValidateAccessToken(accessTokenString)
 | 
							accessClaims, err := auth.ValidateAccessToken(accessTokenString, m.jwtKey)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			if _, ok := err.(*auth.ErrExpiredToken); ok {
 | 
								if _, ok := err.(*auth.ErrExpiredToken); ok {
 | 
				
			||||||
				w.WriteHeader(http.StatusUnauthorized)
 | 
									w.WriteHeader(http.StatusUnauthorized)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -60,10 +60,11 @@ func (h FrontendHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
// TaskcafeHandler contains all the route handlers
 | 
					// TaskcafeHandler contains all the route handlers
 | 
				
			||||||
type TaskcafeHandler struct {
 | 
					type TaskcafeHandler struct {
 | 
				
			||||||
	repo   db.Repository
 | 
						repo   db.Repository
 | 
				
			||||||
 | 
						jwtKey []byte
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewRouter creates a new router for chi
 | 
					// NewRouter creates a new router for chi
 | 
				
			||||||
func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) {
 | 
					func NewRouter(dbConnection *sqlx.DB, jwtKey []byte) (chi.Router, error) {
 | 
				
			||||||
	formatter := new(log.TextFormatter)
 | 
						formatter := new(log.TextFormatter)
 | 
				
			||||||
	formatter.TimestampFormat = "02-01-2006 15:04:05"
 | 
						formatter.TimestampFormat = "02-01-2006 15:04:05"
 | 
				
			||||||
	formatter.FullTimestamp = true
 | 
						formatter.FullTimestamp = true
 | 
				
			||||||
@@ -79,7 +80,7 @@ func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) {
 | 
				
			|||||||
	r.Use(middleware.Timeout(60 * time.Second))
 | 
						r.Use(middleware.Timeout(60 * time.Second))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	repository := db.NewRepository(dbConnection)
 | 
						repository := db.NewRepository(dbConnection)
 | 
				
			||||||
	taskcafeHandler := TaskcafeHandler{*repository}
 | 
						taskcafeHandler := TaskcafeHandler{*repository, jwtKey}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var imgServer = http.FileServer(http.Dir("./uploads/"))
 | 
						var imgServer = http.FileServer(http.Dir("./uploads/"))
 | 
				
			||||||
	r.Group(func(mux chi.Router) {
 | 
						r.Group(func(mux chi.Router) {
 | 
				
			||||||
@@ -88,8 +89,9 @@ func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) {
 | 
				
			|||||||
		mux.Mount("/uploads/", http.StripPrefix("/uploads/", imgServer))
 | 
							mux.Mount("/uploads/", http.StripPrefix("/uploads/", imgServer))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
						auth := AuthenticationMiddleware{jwtKey}
 | 
				
			||||||
	r.Group(func(mux chi.Router) {
 | 
						r.Group(func(mux chi.Router) {
 | 
				
			||||||
		mux.Use(AuthenticationMiddleware)
 | 
							mux.Use(auth.Middleware)
 | 
				
			||||||
		mux.Post("/users/me/avatar", taskcafeHandler.ProfileImageUpload)
 | 
							mux.Post("/users/me/avatar", taskcafeHandler.ProfileImageUpload)
 | 
				
			||||||
		mux.Post("/auth/install", taskcafeHandler.InstallHandler)
 | 
							mux.Post("/auth/install", taskcafeHandler.InstallHandler)
 | 
				
			||||||
		mux.Handle("/graphql", graph.NewHandler(*repository))
 | 
							mux.Handle("/graphql", graph.NewHandler(*repository))
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user