initial commit
This commit is contained in:
113
api/router/auth.go
Normal file
113
api/router/auth.go
Normal file
@ -0,0 +1,113 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jordanknott/project-citadel/api/pg"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var jwtKey = []byte("citadel_test_key")
|
||||
|
||||
type authResource struct{}
|
||||
|
||||
func (h *CitadelHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := r.Cookie("refreshToken")
|
||||
if err != nil {
|
||||
if err == http.ErrNoCookie {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
refreshTokenID := uuid.MustParse(c.Value)
|
||||
token, err := h.repo.GetRefreshTokenByID(r.Context(), refreshTokenID)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
refreshCreatedAt := time.Now().UTC()
|
||||
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
|
||||
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), pg.CreateRefreshTokenParams{token.UserID, refreshCreatedAt, refreshExpiresAt})
|
||||
|
||||
err = h.repo.DeleteRefreshTokenByID(r.Context(), token.TokenID)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
accessTokenString, err := NewAccessToken("1")
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refreshToken",
|
||||
Value: refreshTokenString.TokenID.String(),
|
||||
Expires: refreshExpiresAt,
|
||||
HttpOnly: true,
|
||||
})
|
||||
json.NewEncoder(w).Encode(LoginResponseData{AccessToken: accessTokenString})
|
||||
}
|
||||
|
||||
func (h *CitadelHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var requestData LoginRequestData
|
||||
err := json.NewDecoder(r.Body).Decode(&requestData)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
log.Debug("bad request body")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.repo.GetUserAccountByUsername(r.Context(), requestData.Username)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"username": requestData.Username,
|
||||
}).Warn("user account not found")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(requestData.Password))
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"password": requestData.Password,
|
||||
"password_hash": user.PasswordHash,
|
||||
}).Warn("password incorrect")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userID := uuid.MustParse("0183d9ab-d0ed-4c9b-a3df-77a0cdd93dca")
|
||||
refreshCreatedAt := time.Now().UTC()
|
||||
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
|
||||
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), pg.CreateRefreshTokenParams{userID, refreshCreatedAt, refreshExpiresAt})
|
||||
|
||||
accessTokenString, err := NewAccessToken("1")
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-type", "application/json")
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "refreshToken",
|
||||
Value: refreshTokenString.TokenID.String(),
|
||||
Expires: refreshExpiresAt,
|
||||
HttpOnly: true,
|
||||
})
|
||||
json.NewEncoder(w).Encode(LoginResponseData{accessTokenString})
|
||||
}
|
||||
|
||||
func (rs authResource) Routes(citadelHandler CitadelHandler) chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Post("/login", citadelHandler.LoginHandler)
|
||||
r.Post("/refresh_token", citadelHandler.RefreshTokenHandler)
|
||||
return r
|
||||
}
|
13
api/router/errors.go
Normal file
13
api/router/errors.go
Normal file
@ -0,0 +1,13 @@
|
||||
package router
|
||||
|
||||
type ErrExpiredToken struct{}
|
||||
|
||||
func (r *ErrExpiredToken) Error() string {
|
||||
return "token is expired"
|
||||
}
|
||||
|
||||
type ErrMalformedToken struct{}
|
||||
|
||||
func (r *ErrMalformedToken) Error() string {
|
||||
return "token is malformed"
|
||||
}
|
7
api/router/handlers.go
Normal file
7
api/router/handlers.go
Normal file
@ -0,0 +1,7 @@
|
||||
package router
|
||||
|
||||
import "github.com/jordanknott/project-citadel/api/pg"
|
||||
|
||||
type CitadelHandler struct {
|
||||
repo pg.Repository
|
||||
}
|
93
api/router/logger.go
Normal file
93
api/router/logger.go
Normal file
@ -0,0 +1,93 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// StructuredLogger is a simple, but powerful implementation of a custom structured
|
||||
// logger backed on logrus. I encourage users to copy it, adapt it and make it their
|
||||
// own. Also take a look at https://github.com/pressly/lg for a dedicated pkg based
|
||||
// on this work, designed for context-based http routers.
|
||||
|
||||
func NewStructuredLogger(logger *logrus.Logger) func(next http.Handler) http.Handler {
|
||||
return middleware.RequestLogger(&StructuredLogger{logger})
|
||||
}
|
||||
|
||||
type StructuredLogger struct {
|
||||
Logger *logrus.Logger
|
||||
}
|
||||
|
||||
func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry {
|
||||
entry := &StructuredLoggerEntry{Logger: logrus.NewEntry(l.Logger)}
|
||||
logFields := logrus.Fields{}
|
||||
|
||||
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
|
||||
logFields["req_id"] = reqID
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
logFields["http_scheme"] = scheme
|
||||
logFields["http_proto"] = r.Proto
|
||||
logFields["http_method"] = r.Method
|
||||
|
||||
logFields["remote_addr"] = r.RemoteAddr
|
||||
logFields["user_agent"] = r.UserAgent()
|
||||
|
||||
logFields["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
|
||||
|
||||
entry.Logger = entry.Logger.WithFields(logFields)
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
type StructuredLoggerEntry struct {
|
||||
Logger logrus.FieldLogger
|
||||
}
|
||||
|
||||
func (l *StructuredLoggerEntry) Write(status, bytes int, elapsed time.Duration) {
|
||||
l.Logger = l.Logger.WithFields(logrus.Fields{
|
||||
"resp_status": status, "resp_bytes_length": bytes,
|
||||
"resp_elapsed_ms": float64(elapsed.Nanoseconds()) / 1000000.0,
|
||||
})
|
||||
|
||||
l.Logger.Infoln("request complete")
|
||||
}
|
||||
|
||||
func (l *StructuredLoggerEntry) Panic(v interface{}, stack []byte) {
|
||||
l.Logger = l.Logger.WithFields(logrus.Fields{
|
||||
"stack": string(stack),
|
||||
"panic": fmt.Sprintf("%+v", v),
|
||||
})
|
||||
}
|
||||
|
||||
// Helper methods used by the application to get the request-scoped
|
||||
// logger entry and set additional fields between handlers.
|
||||
//
|
||||
// This is a useful pattern to use to set state on the entry as it
|
||||
// passes through the handler chain, which at any point can be logged
|
||||
// with a call to .Print(), .Info(), etc.
|
||||
|
||||
func GetLogEntry(r *http.Request) logrus.FieldLogger {
|
||||
entry := middleware.GetLogEntry(r).(*StructuredLoggerEntry)
|
||||
return entry.Logger
|
||||
}
|
||||
|
||||
func LogEntrySetField(r *http.Request, key string, value interface{}) {
|
||||
if entry, ok := r.Context().Value(middleware.LogEntryCtxKey).(*StructuredLoggerEntry); ok {
|
||||
entry.Logger = entry.Logger.WithField(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func LogEntrySetFields(r *http.Request, fields map[string]interface{}) {
|
||||
if entry, ok := r.Context().Value(middleware.LogEntryCtxKey).(*StructuredLoggerEntry); ok {
|
||||
entry.Logger = entry.Logger.WithFields(fields)
|
||||
}
|
||||
}
|
44
api/router/middleware.go
Normal file
44
api/router/middleware.go
Normal file
@ -0,0 +1,44 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func AuthenticationMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
bearerTokenRaw := r.Header.Get("Authorization")
|
||||
splitToken := strings.Split(bearerTokenRaw, "Bearer")
|
||||
if len(splitToken) != 2 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
accessTokenString := strings.TrimSpace(splitToken[1])
|
||||
accessClaims, err := ValidateAccessToken(accessTokenString)
|
||||
if err != nil {
|
||||
if _, ok := err.(*ErrExpiredToken); ok {
|
||||
w.Write([]byte(`{
|
||||
"data": {},
|
||||
"errors": [
|
||||
{
|
||||
"extensions": {
|
||||
"code": "UNAUTHENTICATED"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`))
|
||||
return
|
||||
}
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "accessClaims", accessClaims)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
28
api/router/models.go
Normal file
28
api/router/models.go
Normal file
@ -0,0 +1,28 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
type AccessTokenClaims struct {
|
||||
UserID string `json:"userId"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
type RefreshTokenClaims struct {
|
||||
UserID string `json:"userId"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
type LoginRequestData struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type LoginResponseData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
}
|
||||
|
||||
type RefreshTokenResponseData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
}
|
61
api/router/router.go
Normal file
61
api/router/router.go
Normal file
@ -0,0 +1,61 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/jmoiron/sqlx"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jordanknott/project-citadel/api/graph"
|
||||
"github.com/jordanknott/project-citadel/api/pg"
|
||||
)
|
||||
|
||||
func (h *CitadelHandler) PingHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("pong"))
|
||||
}
|
||||
|
||||
func NewRouter(db *sqlx.DB) (chi.Router, error) {
|
||||
formatter := new(log.TextFormatter)
|
||||
formatter.TimestampFormat = "02-01-2006 15:04:05"
|
||||
formatter.FullTimestamp = true
|
||||
|
||||
routerLogger := log.New()
|
||||
routerLogger.SetLevel(log.WarnLevel)
|
||||
routerLogger.Formatter = formatter
|
||||
r := chi.NewRouter()
|
||||
cors := cors.New(cors.Options{
|
||||
// AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts
|
||||
AllowedOrigins: []string{"*"},
|
||||
// AllowOriginFunc: func(r *http.Request, origin string) bool { return true },
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "Cookie"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 300, // Maximum value not ignored by any of major browsers
|
||||
})
|
||||
r.Use(cors.Handler)
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(NewStructuredLogger(routerLogger))
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.Timeout(60 * time.Second))
|
||||
|
||||
repository := pg.NewRepository(db)
|
||||
citadelHandler := CitadelHandler{repository}
|
||||
|
||||
r.Group(func(mux chi.Router) {
|
||||
mux.Mount("/auth", authResource{}.Routes(citadelHandler))
|
||||
mux.Handle("/__graphql", graph.NewPlaygroundHandler("/graphql"))
|
||||
})
|
||||
r.Group(func(mux chi.Router) {
|
||||
mux.Use(AuthenticationMiddleware)
|
||||
mux.Get("/ping", citadelHandler.PingHandler)
|
||||
mux.Handle("/graphql", graph.NewHandler(repository))
|
||||
})
|
||||
|
||||
return r, nil
|
||||
}
|
77
api/router/tokens.go
Normal file
77
api/router/tokens.go
Normal file
@ -0,0 +1,77 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func NewAccessToken(userID string) (string, error) {
|
||||
accessExpirationTime := time.Now().Add(5 * time.Second)
|
||||
accessClaims := &AccessTokenClaims{
|
||||
UserID: userID,
|
||||
StandardClaims: jwt.StandardClaims{ExpiresAt: accessExpirationTime.Unix()},
|
||||
}
|
||||
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessTokenString, err := accessToken.SignedString(jwtKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accessTokenString, nil
|
||||
}
|
||||
|
||||
func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, error) {
|
||||
accessExpirationTime := time.Now().Add(dur)
|
||||
accessClaims := &AccessTokenClaims{
|
||||
UserID: userID,
|
||||
StandardClaims: jwt.StandardClaims{ExpiresAt: accessExpirationTime.Unix()},
|
||||
}
|
||||
|
||||
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
||||
accessTokenString, err := accessToken.SignedString(jwtKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accessTokenString, nil
|
||||
}
|
||||
|
||||
func ValidateAccessToken(accessTokenString string) (AccessTokenClaims, error) {
|
||||
accessClaims := &AccessTokenClaims{}
|
||||
accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) {
|
||||
return jwtKey, nil
|
||||
})
|
||||
|
||||
if accessToken.Valid {
|
||||
log.WithFields(log.Fields{
|
||||
"token": accessTokenString,
|
||||
"timeToExpire": time.Unix(accessClaims.ExpiresAt, 0),
|
||||
}).Info("token is valid")
|
||||
return *accessClaims, nil
|
||||
}
|
||||
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
|
||||
return AccessTokenClaims{}, &ErrMalformedToken{}
|
||||
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
|
||||
return AccessTokenClaims{}, &ErrExpiredToken{}
|
||||
}
|
||||
}
|
||||
return AccessTokenClaims{}, err
|
||||
}
|
||||
|
||||
func NewRefreshToken(userID string) (string, time.Time, error) {
|
||||
refreshExpirationTime := time.Now().Add(24 * time.Hour)
|
||||
refreshClaims := &RefreshTokenClaims{
|
||||
UserID: userID,
|
||||
StandardClaims: jwt.StandardClaims{ExpiresAt: refreshExpirationTime.Unix()},
|
||||
}
|
||||
|
||||
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
||||
refreshTokenString, err := refreshToken.SignedString(jwtKey)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
return refreshTokenString, refreshExpirationTime, nil
|
||||
}
|
Reference in New Issue
Block a user