package server import ( "context" "log/slog" "net/http" "strings" "time" mcdslauth "git.wntrmute.dev/mc/mcdsl/auth" ) type contextKey string const tokenInfoKey contextKey = "tokenInfo" // requireAuth returns middleware that validates Bearer tokens via MCIAS. func requireAuth(auth *mcdslauth.Authenticator) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { token := extractBearerToken(r) if token == "" { writeError(w, http.StatusUnauthorized, "authentication required") return } info, err := auth.ValidateToken(token) if err != nil { writeError(w, http.StatusUnauthorized, "invalid or expired token") return } ctx := context.WithValue(r.Context(), tokenInfoKey, info) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // requireAdmin is middleware that checks the caller has the admin role. func requireAdmin(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { info := tokenInfoFromContext(r.Context()) if info == nil || !info.IsAdmin { writeError(w, http.StatusForbidden, "admin role required") return } next.ServeHTTP(w, r) }) } // tokenInfoFromContext extracts the TokenInfo from the request context. func tokenInfoFromContext(ctx context.Context) *mcdslauth.TokenInfo { info, _ := ctx.Value(tokenInfoKey).(*mcdslauth.TokenInfo) return info } // extractBearerToken extracts a bearer token from the Authorization header. func extractBearerToken(r *http.Request) string { h := r.Header.Get("Authorization") if h == "" { return "" } const prefix = "Bearer " if !strings.HasPrefix(h, prefix) { return "" } return strings.TrimSpace(h[len(prefix):]) } // loggingMiddleware logs HTTP requests. 
//
// One Info-level "http" record is written per request after the wrapped
// handler returns. It carries the method, URL path, the status actually
// sent to the client, the elapsed time and the remote address.
func loggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			// Default to 200: net/http sends an implicit 200 OK when the
			// handler never calls WriteHeader.
			sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
			next.ServeHTTP(sw, r)
			logger.Info("http",
				"method", r.Method,
				"path", r.URL.Path,
				"status", sw.status,
				"duration", time.Since(start),
				"remote", r.RemoteAddr,
			)
		})
	}
}

// statusWriter wraps an http.ResponseWriter to capture, for logging, the
// final status code that was sent to the client.
type statusWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // a final status (or body) has already been sent
}

// WriteHeader forwards code to the underlying writer and records it as the
// response status. Only the first final status is recorded. net/http ignores
// superfluous WriteHeader calls once the header has been sent, and 1xx
// informational codes (other than 101 Switching Protocols) may precede the
// real status, so neither of those is recorded.
func (w *statusWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !w.wroteHeader && !informational {
		w.status = code
		w.wroteHeader = true
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the underlying writer. The first Write commits the
// response header (with an implicit 200 if WriteHeader was not called), so
// any later WriteHeader call no longer changes the recorded status.
func (w *statusWriter) Write(b []byte) (int, error) {
	w.wroteHeader = true
	return w.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped ResponseWriter so that http.ResponseController
// can reach optional capabilities (flushing, hijacking, deadlines) that
// statusWriter does not implement itself.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}