package graph import ( "context" "errors" "net/http" "os" "reflect" "time" "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler/extension" "github.com/99designs/gqlgen/graphql/handler/lru" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/99designs/gqlgen/graphql/playground" "github.com/google/uuid" "github.com/jordanknott/taskcafe/internal/auth" "github.com/jordanknott/taskcafe/internal/db" "github.com/jordanknott/taskcafe/internal/utils" log "github.com/sirupsen/logrus" "github.com/vektah/gqlparser/v2/gqlerror" ) // NewHandler returns a new graphql endpoint handler. func NewHandler(repo db.Repository) http.Handler { c := Config{ Resolvers: &Resolver{ Repository: repo, }, } c.Directives.HasRole = func(ctx context.Context, obj interface{}, next graphql.Resolver, roles []RoleLevel, level ActionLevel, typeArg ObjectType) (interface{}, error) { role, ok := GetUserRole(ctx) if !ok { return nil, errors.New("user ID is missing") } if role == "admin" { return next(ctx) } else if level == ActionLevelOrg { return nil, errors.New("must be an org admin") } var subjectID uuid.UUID in := graphql.GetResolverContext(ctx).Args["input"] val := reflect.ValueOf(in) // could be any underlying type if val.Kind() == reflect.Ptr { val = reflect.Indirect(val) } var fieldName string switch typeArg { case ObjectTypeTeam: fieldName = "TeamID" case ObjectTypeTask: fieldName = "TaskID" case ObjectTypeTaskGroup: fieldName = "TaskGroupID" default: fieldName = "ProjectID" } log.WithFields(log.Fields{"typeArg": typeArg, "fieldName": fieldName}).Info("getting field by name") subjectID, ok = val.FieldByName(fieldName).Interface().(uuid.UUID) if !ok { return nil, errors.New("error while casting subject uuid") } var err error if level == ActionLevelProject { if typeArg == ObjectTypeTask { log.WithFields(log.Fields{"subjectID": subjectID}).Info("fetching project ID using task ID") subjectID, err = repo.GetProjectIDForTask(ctx, subjectID) if err != nil { return nil, err } } else if typeArg == ObjectTypeTaskGroup { log.WithFields(log.Fields{"subjectID": subjectID}).Info("fetching project ID using task group ID") taskGroup, err := repo.GetTaskGroupByID(ctx, subjectID) if err != nil { return nil, err } subjectID = taskGroup.ProjectID } roles, err := GetProjectRoles(ctx, repo, subjectID) if err != nil { return nil, err } if roles.TeamRole == "admin" || roles.ProjectRole == "admin" { log.WithFields(log.Fields{"teamRole": roles.TeamRole, "projectRole": roles.ProjectRole}).Info("is team or project role") return next(ctx) } return nil, errors.New("must be a team or project admin") } else if level == ActionLevelTeam { userID, ok := GetUserID(ctx) if !ok { return nil, errors.New("user id is missing") } role, err := repo.GetTeamRoleForUserID(ctx, db.GetTeamRoleForUserIDParams{UserID: userID, TeamID: subjectID}) if err != nil { return nil, err } if role.RoleCode == "admin" { return next(ctx) } return nil, errors.New("must be a team admin") } return nil, errors.New("invalid path") } srv := handler.New(NewExecutableSchema(c)) srv.AddTransport(transport.Websocket{ KeepAlivePingInterval: 10 * time.Second, }) srv.AddTransport(transport.Options{}) srv.AddTransport(transport.GET{}) srv.AddTransport(transport.POST{}) srv.AddTransport(transport.MultipartForm{}) srv.SetQueryCache(lru.New(1000)) srv.Use(extension.AutomaticPersistedQuery{ Cache: lru.New(100), }) if isProd := os.Getenv("PRODUCTION") == "true"; isProd { srv.Use(extension.FixedComplexityLimit(10)) } else { srv.Use(extension.Introspection{}) } return srv } // NewPlaygroundHandler returns a new GraphQL Playground handler. func NewPlaygroundHandler(endpoint string) http.Handler { return playground.Handler("GraphQL Playground", endpoint) } // GetUserID retrieves the UserID out of a context func GetUserID(ctx context.Context) (uuid.UUID, bool) { userID, ok := ctx.Value(utils.UserIDKey).(uuid.UUID) return userID, ok } // GetUserRole retrieves the user role out of a context func GetUserRole(ctx context.Context) (auth.Role, bool) { role, ok := ctx.Value(utils.OrgRoleKey).(auth.Role) return role, ok } // GetUser retrieves both the user id & user role out of a context func GetUser(ctx context.Context) (uuid.UUID, auth.Role, bool) { userID, userOK := GetUserID(ctx) role, roleOK := GetUserRole(ctx) return userID, role, userOK && roleOK } // GetRestrictedMode retrieves the restricted mode code out of a context func GetRestrictedMode(ctx context.Context) (auth.RestrictedMode, bool) { restricted, ok := ctx.Value(utils.RestrictedModeKey).(auth.RestrictedMode) return restricted, ok } // GetProjectRoles retrieves the team & project role for the given project ID func GetProjectRoles(ctx context.Context, r db.Repository, projectID uuid.UUID) (db.GetUserRolesForProjectRow, error) { userID, ok := GetUserID(ctx) if !ok { return db.GetUserRolesForProjectRow{}, errors.New("user ID is not found") } return r.GetUserRolesForProject(ctx, db.GetUserRolesForProjectParams{UserID: userID, ProjectID: projectID}) } // ConvertToRoleCode converts a role code string to a RoleCode type func ConvertToRoleCode(r string) RoleCode { if r == RoleCodeAdmin.String() { return RoleCodeAdmin } if r == RoleCodeMember.String() { return RoleCodeMember } return RoleCodeObserver } // GetEntityType converts integer to EntityType enum func GetEntityType(entityType int32) EntityType { switch entityType { case 1: return EntityTypeTask default: panic("Not a valid entity type!") } } // GetActionType converts integer to ActionType enum func GetActionType(actionType int32) ActionType { switch actionType { case 1: return ActionTypeTaskMemberAdded default: panic("Not a valid entity type!") } } // NotFoundError creates a 404 gqlerror func NotFoundError(message string) error { return &gqlerror.Error{ Message: message, Extensions: map[string]interface{}{ "code": "404", }, } }