package webserver import ( "crypto/rand" "errors" "fmt" "io/fs" "log/slog" "net/http" "github.com/go-chi/chi/v5" "git.wntrmute.dev/mc/mcdsl/auth" "git.wntrmute.dev/mc/mcdsl/csrf" "git.wntrmute.dev/mc/mcdsl/httpserver" "git.wntrmute.dev/mc/mcdsl/web" mcatweb "git.wntrmute.dev/mc/mcat/web" ) const ( sessionCookieName = "mcat_token" csrfCookieName = "mcat_csrf" csrfFieldName = "csrf_token" ) // Config holds the webserver configuration. Extracted from the service // config so the webserver doesn't depend on the full config package. type Config struct { ServiceName string Tags []string } // Server is the mcat web UI server. type Server struct { wsCfg Config auth *auth.Authenticator logger *slog.Logger csrf *csrf.Protect staticFS fs.FS handler http.Handler } // New creates a new web UI server. func New(wsCfg Config, authenticator *auth.Authenticator, logger *slog.Logger) (*Server, error) { staticFS, err := fs.Sub(mcatweb.FS, "static") if err != nil { return nil, fmt.Errorf("webserver: static fs: %w", err) } secret := make([]byte, 32) if _, err := rand.Read(secret); err != nil { return nil, fmt.Errorf("webserver: generate CSRF secret: %w", err) } s := &Server{ wsCfg: wsCfg, auth: authenticator, logger: logger, csrf: csrf.New(secret, csrfCookieName, csrfFieldName), staticFS: staticFS, } r := chi.NewRouter() r.Use(s.loggingMiddleware) r.Use(s.csrf.Middleware) s.registerRoutes(r) s.handler = r return s, nil } // Handler returns the HTTP handler for the web server. func (s *Server) Handler() http.Handler { return s.handler } func (s *Server) registerRoutes(r chi.Router) { r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.FS(s.staticFS)))) r.Get("/", s.handleRoot) r.Get("/login", s.handleLogin) r.Post("/login", s.handleLogin) r.Post("/logout", s.requireAuth(s.handleLogout)) r.Get("/dashboard", s.requireAuth(s.handleDashboard)) } func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) { token := web.GetSessionToken(r, sessionCookieName) if token != "" { if _, err := s.auth.ValidateToken(token); err == nil { http.Redirect(w, r, "/dashboard", http.StatusFound) return } } http.Redirect(w, r, "/login", http.StatusFound) } func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet { s.renderTemplate(w, "login.html", map[string]interface{}{}) return } if err := r.ParseForm(); err != nil { //nolint:gosec // form size bounded by http.Server ReadTimeout s.renderTemplate(w, "login.html", map[string]interface{}{ "Error": "Invalid form data.", }) return } username := r.FormValue("username") //nolint:gosec // parsed above password := r.FormValue("password") //nolint:gosec // parsed above totpCode := r.FormValue("totp_code") //nolint:gosec // parsed above token, _, err := s.auth.Login(username, password, totpCode) if err != nil { msg := "Login failed." if errors.Is(err, auth.ErrInvalidCredentials) { msg = "Invalid username or password." } else if errors.Is(err, auth.ErrForbidden) { msg = "Login denied by policy." } s.renderTemplate(w, "login.html", map[string]interface{}{ "Error": msg, }) return } web.SetSessionCookie(w, sessionCookieName, token) http.Redirect(w, r, "/dashboard", http.StatusFound) } func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { token := web.GetSessionToken(r, sessionCookieName) if token != "" { if err := s.auth.Logout(token); err != nil { s.logger.Warn("logout failed", "error", err) } } web.ClearSessionCookie(w, sessionCookieName) http.Redirect(w, r, "/login", http.StatusFound) } func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) { info := auth.TokenInfoFromContext(r.Context()) s.renderTemplate(w, "dashboard.html", map[string]interface{}{ "Username": info.Username, "Roles": info.Roles, "IsAdmin": info.IsAdmin, "ServiceName": s.wsCfg.ServiceName, "Tags": s.wsCfg.Tags, }) } // requireAuth wraps a handler that requires a valid MCIAS session. func (s *Server) requireAuth(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { token := web.GetSessionToken(r, sessionCookieName) if token == "" { http.Redirect(w, r, "/login", http.StatusFound) return } info, err := s.auth.ValidateToken(token) if err != nil { http.Redirect(w, r, "/login", http.StatusFound) return } ctx := auth.ContextWithTokenInfo(r.Context(), info) next(w, r.WithContext(ctx)) } } func (s *Server) renderTemplate(w http.ResponseWriter, name string, data interface{}) { web.RenderTemplate(w, mcatweb.FS, name, data, s.csrf.TemplateFunc(w)) } func (s *Server) loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sw := &httpserver.StatusWriter{ResponseWriter: w, Status: http.StatusOK} next.ServeHTTP(sw, r) s.logger.Info("http", "method", r.Method, "path", r.URL.Path, "status", sw.Status, "remote", r.RemoteAddr, ) }) }