diff --git a/conf/taskcafe.example.toml b/conf/taskcafe.example.toml index 99ecd4a..62154d3 100644 --- a/conf/taskcafe.example.toml +++ b/conf/taskcafe.example.toml @@ -1,4 +1,4 @@ -[general] +[server] hostname = '0.0.0.0:3333' [email_notifications] diff --git a/internal/auth/auth.go b/internal/auth/auth.go index dc8a13c..75ea34a 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -7,8 +7,6 @@ import ( log "github.com/sirupsen/logrus" ) -var jwtKey = []byte("taskcafe_test_key") - // RestrictedMode is used restrict JWT access to just the install route type RestrictedMode string @@ -54,7 +52,7 @@ func (r *ErrMalformedToken) Error() string { } // NewAccessToken generates a new JWT access token with the correct claims -func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string) (string, error) { +func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string, jwtKey []byte) (string, error) { role := RoleMember if orgRole == "admin" { role = RoleAdmin @@ -76,7 +74,7 @@ func NewAccessToken(userID string, restrictedMode RestrictedMode, orgRole string } // NewAccessTokenCustomExpiration creates an access token with a custom duration -func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, error) { +func NewAccessTokenCustomExpiration(userID string, dur time.Duration, jwtKey []byte) (string, error) { accessExpirationTime := time.Now().Add(dur) accessClaims := &AccessTokenClaims{ UserID: userID, @@ -94,7 +92,7 @@ func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, e } // ValidateAccessToken validates a JWT access token and returns the contained claims or an error if it's invalid -func ValidateAccessToken(accessTokenString string) (AccessTokenClaims, error) { +func ValidateAccessToken(accessTokenString string, jwtKey []byte) (AccessTokenClaims, error) { accessClaims := &AccessTokenClaims{} accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) { return jwtKey, nil diff --git a/internal/commands/token.go b/internal/commands/token.go index e0dbd0c..9477a9a 100644 --- a/internal/commands/token.go +++ b/internal/commands/token.go @@ -1,12 +1,15 @@ package commands import ( + "errors" "fmt" + "strings" "time" "github.com/jordanknott/taskcafe/internal/auth" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/spf13/viper" ) func newTokenCmd() *cobra.Command { @@ -15,13 +18,18 @@ func newTokenCmd() *cobra.Command { Short: "Create a long lived JWT token for dev purposes", Long: "Create a long lived JWT token for dev purposes", Args: cobra.ExactArgs(1), - Run: func(cmd *cobra.Command, args []string) { - token, err := auth.NewAccessTokenCustomExpiration(args[0], time.Hour*24) + RunE: func(cmd *cobra.Command, args []string) error { + secret := viper.GetString("server.secret") + if strings.TrimSpace(secret) == "" { + return errors.New("server.secret must be set (TASKCAFE_SERVER_SECRET)") + } + token, err := auth.NewAccessTokenCustomExpiration(args[0], time.Hour*24, []byte(secret)) if err != nil { log.WithError(err).Error("issue while creating access token") - return + return err } fmt.Println(token) + return nil }, } } diff --git a/internal/commands/web.go b/internal/commands/web.go index 0adc14b..cf64bc6 100644 --- a/internal/commands/web.go +++ b/internal/commands/web.go @@ -3,11 +3,13 @@ package commands import ( "fmt" "net/http" + "strings" "time" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/postgres" "github.com/golang-migrate/migrate/v4/source/httpfs" + "github.com/google/uuid" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -62,7 +64,12 @@ func newWebCmd() *cobra.Command { } log.WithFields(log.Fields{"url": viper.GetString("server.hostname")}).Info("starting server") - r, _ := route.NewRouter(db) + secret := viper.GetString("server.secret") + if strings.TrimSpace(secret) == "" { + log.Warn("server.secret is not set, generating a random secret") + secret = uuid.New().String() + } + r, _ := route.NewRouter(db, []byte(secret)) http.ListenAndServe(viper.GetString("server.hostname"), r) return nil }, diff --git a/internal/route/auth.go b/internal/route/auth.go index 4b9e73c..14c332b 100644 --- a/internal/route/auth.go +++ b/internal/route/auth.go @@ -14,8 +14,6 @@ import ( "golang.org/x/crypto/bcrypt" ) -var jwtKey = []byte("taskcafe_test_key") - type authResource struct{} // LoginRequestData is the request data when a user logs in @@ -69,7 +67,7 @@ func (h *TaskcafeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Req w.WriteHeader(http.StatusInternalServerError) return } - accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.InstallOnly, user.RoleCode) + accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.InstallOnly, user.RoleCode, h.jwtKey) if err != nil { w.WriteHeader(http.StatusInternalServerError) } @@ -123,7 +121,7 @@ func (h *TaskcafeHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Req w.WriteHeader(http.StatusInternalServerError) } - accessTokenString, err := auth.NewAccessToken(token.UserID.String(), auth.Unrestricted, user.RoleCode) + accessTokenString, err := auth.NewAccessToken(token.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey) if err != nil { w.WriteHeader(http.StatusInternalServerError) } @@ -190,7 +188,7 @@ func (h *TaskcafeHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1) refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt}) - accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode) + accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey) if err != nil { w.WriteHeader(http.StatusInternalServerError) } @@ -252,7 +250,7 @@ func (h *TaskcafeHandler) InstallHandler(w http.ResponseWriter, r *http.Request) refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), db.CreateRefreshTokenParams{user.UserID, refreshCreatedAt, refreshExpiresAt}) log.WithField("userID", user.UserID.String()).Info("creating install access token") - accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode) + accessTokenString, err := auth.NewAccessToken(user.UserID.String(), auth.Unrestricted, user.RoleCode, h.jwtKey) if err != nil { w.WriteHeader(http.StatusInternalServerError) } diff --git a/internal/route/middleware.go b/internal/route/middleware.go index 7e1b4dc..4a1336f 100644 --- a/internal/route/middleware.go +++ b/internal/route/middleware.go @@ -12,7 +12,12 @@ import ( ) // AuthenticationMiddleware is a middleware that requires a valid JWT token to be passed via the Authorization header -func AuthenticationMiddleware(next http.Handler) http.Handler { +type AuthenticationMiddleware struct { + jwtKey []byte +} + +// Middleware returns the middleware handler +func (m *AuthenticationMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { bearerTokenRaw := r.Header.Get("Authorization") splitToken := strings.Split(bearerTokenRaw, "Bearer") @@ -21,7 +26,7 @@ func AuthenticationMiddleware(next http.Handler) http.Handler { return } accessTokenString := strings.TrimSpace(splitToken[1]) - accessClaims, err := auth.ValidateAccessToken(accessTokenString) + accessClaims, err := auth.ValidateAccessToken(accessTokenString, m.jwtKey) if err != nil { if _, ok := err.(*auth.ErrExpiredToken); ok { w.WriteHeader(http.StatusUnauthorized) diff --git a/internal/route/route.go b/internal/route/route.go index 6566736..0607542 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -59,11 +59,12 @@ func (h FrontendHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // TaskcafeHandler contains all the route handlers type TaskcafeHandler struct { - repo db.Repository + repo db.Repository + jwtKey []byte } // NewRouter creates a new router for chi -func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) { +func NewRouter(dbConnection *sqlx.DB, jwtKey []byte) (chi.Router, error) { formatter := new(log.TextFormatter) formatter.TimestampFormat = "02-01-2006 15:04:05" formatter.FullTimestamp = true @@ -79,7 +80,7 @@ func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) { r.Use(middleware.Timeout(60 * time.Second)) repository := db.NewRepository(dbConnection) - taskcafeHandler := TaskcafeHandler{*repository} + taskcafeHandler := TaskcafeHandler{*repository, jwtKey} var imgServer = http.FileServer(http.Dir("./uploads/")) r.Group(func(mux chi.Router) { @@ -88,8 +89,9 @@ func NewRouter(dbConnection *sqlx.DB) (chi.Router, error) { mux.Mount("/uploads/", http.StripPrefix("/uploads/", imgServer)) }) + auth := AuthenticationMiddleware{jwtKey} r.Group(func(mux chi.Router) { - mux.Use(AuthenticationMiddleware) + mux.Use(auth.Middleware) mux.Post("/users/me/avatar", taskcafeHandler.ProfileImageUpload) mux.Post("/auth/install", taskcafeHandler.InstallHandler) mux.Handle("/graphql", graph.NewHandler(*repository))