initial commit

This commit is contained in:
Jordan Knott
2020-04-09 21:40:22 -05:00
commit 9611105364
141 changed files with 29236 additions and 0 deletions

113
api/router/auth.go Normal file
View File

@ -0,0 +1,113 @@
package router
import (
"encoding/json"
"net/http"
"time"
"github.com/go-chi/chi"
"github.com/google/uuid"
"github.com/jordanknott/project-citadel/api/pg"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
)
var jwtKey = []byte("citadel_test_key")
type authResource struct{}
func (h *CitadelHandler) RefreshTokenHandler(w http.ResponseWriter, r *http.Request) {
c, err := r.Cookie("refreshToken")
if err != nil {
if err == http.ErrNoCookie {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusBadRequest)
return
}
refreshTokenID := uuid.MustParse(c.Value)
token, err := h.repo.GetRefreshTokenByID(r.Context(), refreshTokenID)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
refreshCreatedAt := time.Now().UTC()
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), pg.CreateRefreshTokenParams{token.UserID, refreshCreatedAt, refreshExpiresAt})
err = h.repo.DeleteRefreshTokenByID(r.Context(), token.TokenID)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
accessTokenString, err := NewAccessToken("1")
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
w.Header().Set("Content-type", "application/json")
http.SetCookie(w, &http.Cookie{
Name: "refreshToken",
Value: refreshTokenString.TokenID.String(),
Expires: refreshExpiresAt,
HttpOnly: true,
})
json.NewEncoder(w).Encode(LoginResponseData{AccessToken: accessTokenString})
}
func (h *CitadelHandler) LoginHandler(w http.ResponseWriter, r *http.Request) {
var requestData LoginRequestData
err := json.NewDecoder(r.Body).Decode(&requestData)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
log.Debug("bad request body")
return
}
user, err := h.repo.GetUserAccountByUsername(r.Context(), requestData.Username)
if err != nil {
log.WithFields(log.Fields{
"username": requestData.Username,
}).Warn("user account not found")
w.WriteHeader(http.StatusUnauthorized)
return
}
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(requestData.Password))
if err != nil {
log.WithFields(log.Fields{
"password": requestData.Password,
"password_hash": user.PasswordHash,
}).Warn("password incorrect")
w.WriteHeader(http.StatusUnauthorized)
return
}
userID := uuid.MustParse("0183d9ab-d0ed-4c9b-a3df-77a0cdd93dca")
refreshCreatedAt := time.Now().UTC()
refreshExpiresAt := refreshCreatedAt.AddDate(0, 0, 1)
refreshTokenString, err := h.repo.CreateRefreshToken(r.Context(), pg.CreateRefreshTokenParams{userID, refreshCreatedAt, refreshExpiresAt})
accessTokenString, err := NewAccessToken("1")
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
w.Header().Set("Content-type", "application/json")
http.SetCookie(w, &http.Cookie{
Name: "refreshToken",
Value: refreshTokenString.TokenID.String(),
Expires: refreshExpiresAt,
HttpOnly: true,
})
json.NewEncoder(w).Encode(LoginResponseData{accessTokenString})
}
func (rs authResource) Routes(citadelHandler CitadelHandler) chi.Router {
r := chi.NewRouter()
r.Post("/login", citadelHandler.LoginHandler)
r.Post("/refresh_token", citadelHandler.RefreshTokenHandler)
return r
}

13
api/router/errors.go Normal file
View File

@ -0,0 +1,13 @@
package router
type ErrExpiredToken struct{}
func (r *ErrExpiredToken) Error() string {
return "token is expired"
}
type ErrMalformedToken struct{}
func (r *ErrMalformedToken) Error() string {
return "token is malformed"
}

7
api/router/handlers.go Normal file
View File

@ -0,0 +1,7 @@
package router
import "github.com/jordanknott/project-citadel/api/pg"
type CitadelHandler struct {
repo pg.Repository
}

93
api/router/logger.go Normal file
View File

@ -0,0 +1,93 @@
package router
import (
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/middleware"
"github.com/sirupsen/logrus"
)
// StructuredLogger is a simple, but powerful implementation of a custom structured
// logger backed on logrus. I encourage users to copy it, adapt it and make it their
// own. Also take a look at https://github.com/pressly/lg for a dedicated pkg based
// on this work, designed for context-based http routers.
func NewStructuredLogger(logger *logrus.Logger) func(next http.Handler) http.Handler {
return middleware.RequestLogger(&StructuredLogger{logger})
}
type StructuredLogger struct {
Logger *logrus.Logger
}
func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry {
entry := &StructuredLoggerEntry{Logger: logrus.NewEntry(l.Logger)}
logFields := logrus.Fields{}
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
logFields["req_id"] = reqID
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
logFields["http_scheme"] = scheme
logFields["http_proto"] = r.Proto
logFields["http_method"] = r.Method
logFields["remote_addr"] = r.RemoteAddr
logFields["user_agent"] = r.UserAgent()
logFields["uri"] = fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI)
entry.Logger = entry.Logger.WithFields(logFields)
return entry
}
type StructuredLoggerEntry struct {
Logger logrus.FieldLogger
}
func (l *StructuredLoggerEntry) Write(status, bytes int, elapsed time.Duration) {
l.Logger = l.Logger.WithFields(logrus.Fields{
"resp_status": status, "resp_bytes_length": bytes,
"resp_elapsed_ms": float64(elapsed.Nanoseconds()) / 1000000.0,
})
l.Logger.Infoln("request complete")
}
func (l *StructuredLoggerEntry) Panic(v interface{}, stack []byte) {
l.Logger = l.Logger.WithFields(logrus.Fields{
"stack": string(stack),
"panic": fmt.Sprintf("%+v", v),
})
}
// Helper methods used by the application to get the request-scoped
// logger entry and set additional fields between handlers.
//
// This is a useful pattern to use to set state on the entry as it
// passes through the handler chain, which at any point can be logged
// with a call to .Print(), .Info(), etc.
func GetLogEntry(r *http.Request) logrus.FieldLogger {
entry := middleware.GetLogEntry(r).(*StructuredLoggerEntry)
return entry.Logger
}
func LogEntrySetField(r *http.Request, key string, value interface{}) {
if entry, ok := r.Context().Value(middleware.LogEntryCtxKey).(*StructuredLoggerEntry); ok {
entry.Logger = entry.Logger.WithField(key, value)
}
}
func LogEntrySetFields(r *http.Request, fields map[string]interface{}) {
if entry, ok := r.Context().Value(middleware.LogEntryCtxKey).(*StructuredLoggerEntry); ok {
entry.Logger = entry.Logger.WithFields(fields)
}
}

44
api/router/middleware.go Normal file
View File

@ -0,0 +1,44 @@
package router
import (
"context"
"net/http"
"strings"
log "github.com/sirupsen/logrus"
)
func AuthenticationMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bearerTokenRaw := r.Header.Get("Authorization")
splitToken := strings.Split(bearerTokenRaw, "Bearer")
if len(splitToken) != 2 {
w.WriteHeader(http.StatusBadRequest)
return
}
accessTokenString := strings.TrimSpace(splitToken[1])
accessClaims, err := ValidateAccessToken(accessTokenString)
if err != nil {
if _, ok := err.(*ErrExpiredToken); ok {
w.Write([]byte(`{
"data": {},
"errors": [
{
"extensions": {
"code": "UNAUTHENTICATED"
}
}
]
}`))
return
}
log.Error(err)
w.WriteHeader(http.StatusBadRequest)
return
}
ctx := context.WithValue(r.Context(), "accessClaims", accessClaims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

28
api/router/models.go Normal file
View File

@ -0,0 +1,28 @@
package router
import (
"github.com/dgrijalva/jwt-go"
)
type AccessTokenClaims struct {
UserID string `json:"userId"`
jwt.StandardClaims
}
type RefreshTokenClaims struct {
UserID string `json:"userId"`
jwt.StandardClaims
}
type LoginRequestData struct {
Username string
Password string
}
type LoginResponseData struct {
AccessToken string `json:"accessToken"`
}
type RefreshTokenResponseData struct {
AccessToken string `json:"accessToken"`
}

61
api/router/router.go Normal file
View File

@ -0,0 +1,61 @@
package router
import (
"net/http"
"time"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/cors"
"github.com/jmoiron/sqlx"
log "github.com/sirupsen/logrus"
"github.com/jordanknott/project-citadel/api/graph"
"github.com/jordanknott/project-citadel/api/pg"
)
func (h *CitadelHandler) PingHandler(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("pong"))
}
func NewRouter(db *sqlx.DB) (chi.Router, error) {
formatter := new(log.TextFormatter)
formatter.TimestampFormat = "02-01-2006 15:04:05"
formatter.FullTimestamp = true
routerLogger := log.New()
routerLogger.SetLevel(log.WarnLevel)
routerLogger.Formatter = formatter
r := chi.NewRouter()
cors := cors.New(cors.Options{
// AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts
AllowedOrigins: []string{"*"},
// AllowOriginFunc: func(r *http.Request, origin string) bool { return true },
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "Cookie"},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
MaxAge: 300, // Maximum value not ignored by any of major browsers
})
r.Use(cors.Handler)
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(NewStructuredLogger(routerLogger))
r.Use(middleware.Recoverer)
r.Use(middleware.Timeout(60 * time.Second))
repository := pg.NewRepository(db)
citadelHandler := CitadelHandler{repository}
r.Group(func(mux chi.Router) {
mux.Mount("/auth", authResource{}.Routes(citadelHandler))
mux.Handle("/__graphql", graph.NewPlaygroundHandler("/graphql"))
})
r.Group(func(mux chi.Router) {
mux.Use(AuthenticationMiddleware)
mux.Get("/ping", citadelHandler.PingHandler)
mux.Handle("/graphql", graph.NewHandler(repository))
})
return r, nil
}

77
api/router/tokens.go Normal file
View File

@ -0,0 +1,77 @@
package router
import (
"time"
"github.com/dgrijalva/jwt-go"
log "github.com/sirupsen/logrus"
)
func NewAccessToken(userID string) (string, error) {
accessExpirationTime := time.Now().Add(5 * time.Second)
accessClaims := &AccessTokenClaims{
UserID: userID,
StandardClaims: jwt.StandardClaims{ExpiresAt: accessExpirationTime.Unix()},
}
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
accessTokenString, err := accessToken.SignedString(jwtKey)
if err != nil {
return "", err
}
return accessTokenString, nil
}
func NewAccessTokenCustomExpiration(userID string, dur time.Duration) (string, error) {
accessExpirationTime := time.Now().Add(dur)
accessClaims := &AccessTokenClaims{
UserID: userID,
StandardClaims: jwt.StandardClaims{ExpiresAt: accessExpirationTime.Unix()},
}
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
accessTokenString, err := accessToken.SignedString(jwtKey)
if err != nil {
return "", err
}
return accessTokenString, nil
}
func ValidateAccessToken(accessTokenString string) (AccessTokenClaims, error) {
accessClaims := &AccessTokenClaims{}
accessToken, err := jwt.ParseWithClaims(accessTokenString, accessClaims, func(token *jwt.Token) (interface{}, error) {
return jwtKey, nil
})
if accessToken.Valid {
log.WithFields(log.Fields{
"token": accessTokenString,
"timeToExpire": time.Unix(accessClaims.ExpiresAt, 0),
}).Info("token is valid")
return *accessClaims, nil
}
if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
return AccessTokenClaims{}, &ErrMalformedToken{}
} else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 {
return AccessTokenClaims{}, &ErrExpiredToken{}
}
}
return AccessTokenClaims{}, err
}
func NewRefreshToken(userID string) (string, time.Time, error) {
refreshExpirationTime := time.Now().Add(24 * time.Hour)
refreshClaims := &RefreshTokenClaims{
UserID: userID,
StandardClaims: jwt.StandardClaims{ExpiresAt: refreshExpirationTime.Unix()},
}
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
refreshTokenString, err := refreshToken.SignedString(jwtKey)
if err != nil {
return "", time.Time{}, err
}
return refreshTokenString, refreshExpirationTime, nil
}