diff --git a/.gitignore b/.gitignore index efcbc24..4acaf5e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ srv/ -mcat +/mcat *.db *.db-wal *.db-shm diff --git a/cmd/mcat/main.go b/cmd/mcat/main.go new file mode 100644 index 0000000..16e6492 --- /dev/null +++ b/cmd/mcat/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +var version = "dev" + +func main() { + root := &cobra.Command{ + Use: "mcat", + Short: "MCIAS login policy tester", + } + + root.AddCommand(serverCmd()) + root.AddCommand(versionCmd()) + + if err := root.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func versionCmd() *cobra.Command { + return &cobra.Command{ + Use: "version", + Short: "Print mcat version", + Run: func(cmd *cobra.Command, args []string) { + fmt.Println(version) + }, + } +} diff --git a/cmd/mcat/server.go b/cmd/mcat/server.go new file mode 100644 index 0000000..22a2d4f --- /dev/null +++ b/cmd/mcat/server.go @@ -0,0 +1,116 @@ +package main + +import ( + "context" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/spf13/cobra" + + "git.wntrmute.dev/mc/mcat/internal/webserver" + "git.wntrmute.dev/mc/mcdsl/auth" + "git.wntrmute.dev/mc/mcdsl/config" + "git.wntrmute.dev/mc/mcdsl/httpserver" + "git.wntrmute.dev/mc/mcdsl/sso" +) + +// mcatConfig is the mcat-specific configuration. It embeds config.Base +// for the standard sections (server, mcias, log). mcat has no database +// or additional sections. +type mcatConfig struct { + config.Base + SSO ssoConfig `toml:"sso"` +} + +type ssoConfig struct { + RedirectURI string `toml:"redirect_uri"` +} + +func serverCmd() *cobra.Command { + var configPath string + + cmd := &cobra.Command{ + Use: "server", + Short: "Start the mcat web server", + RunE: func(_ *cobra.Command, _ []string) error { + return runServer(configPath) + }, + } + + cmd.Flags().StringVarP(&configPath, "config", "c", "mcat.toml", "path to config file") + + return cmd +} + +func runServer(configPath string) error { + cfg, err := config.Load[mcatConfig](configPath, "MCAT") + if err != nil { + return err + } + + var logLevel slog.Level + switch cfg.Log.Level { + case "debug": + logLevel = slog.LevelDebug + case "warn": + logLevel = slog.LevelWarn + case "error": + logLevel = slog.LevelError + default: + logLevel = slog.LevelInfo + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel})) + + authenticator, err := auth.New(cfg.MCIAS, logger) + if err != nil { + return err + } + + // Create SSO client if configured. + var ssoClient *sso.Client + if cfg.SSO.RedirectURI != "" { + ssoClient, err = sso.New(sso.Config{ + MciasURL: cfg.MCIAS.ServerURL, + ClientID: cfg.MCIAS.ServiceName, + RedirectURI: cfg.SSO.RedirectURI, + CACert: cfg.MCIAS.CACert, + }) + if err != nil { + return err + } + logger.Info("SSO enabled", "mcias", cfg.MCIAS.ServerURL) + } + + wsCfg := webserver.Config{ + ServiceName: cfg.MCIAS.ServiceName, + Tags: cfg.MCIAS.Tags, + } + srv, err := webserver.New(wsCfg, authenticator, logger, ssoClient) + if err != nil { + return err + } + + httpSrv := httpserver.New(cfg.Server, logger) + httpSrv.Router.Mount("/", srv.Handler()) + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + go func() { + logger.Info("starting mcat", "addr", cfg.Server.ListenAddr, "version", version) + if err := httpSrv.ListenAndServeTLS(); err != nil { + logger.Error("server error", "error", err) + os.Exit(1) + } + }() + + <-ctx.Done() + logger.Info("shutting down") + + shutdownCtx, cancel := context.WithTimeout(context.Background(), cfg.Server.ShutdownTimeout.Duration) + defer cancel() + + return httpSrv.Shutdown(shutdownCtx) +} diff --git a/go.mod b/go.mod index d69e000..4f70602 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module git.wntrmute.dev/mc/mcat go 1.25.7 require ( - git.wntrmute.dev/mc/mcdsl v1.2.0 + git.wntrmute.dev/mc/mcdsl v1.5.0 github.com/go-chi/chi/v5 v5.2.5 github.com/spf13/cobra v1.10.2 ) diff --git a/go.sum b/go.sum index 33672bf..e891815 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -git.wntrmute.dev/mc/mcdsl v1.2.0 h1:41hep7/PNZJfN0SN/nM+rQpyF1GSZcvNNjyVG81DI7U= -git.wntrmute.dev/mc/mcdsl v1.2.0/go.mod h1:lXYrAt74ZUix6rx9oVN8d2zH1YJoyp4uxPVKQ+SSxuM= +git.wntrmute.dev/mc/mcdsl v1.5.0 h1:JUlSYuvETRCycf+cZ56Gxp/1XZn0T7fOfWezM3m89qE= +git.wntrmute.dev/mc/mcdsl v1.5.0/go.mod h1:MhYahIu7Sg53lE2zpQ20nlrsoNRjQzOJBAlCmom2wJc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= diff --git a/internal/webserver/server.go b/internal/webserver/server.go index d845273..7393e14 100644 --- a/internal/webserver/server.go +++ b/internal/webserver/server.go @@ -13,6 +13,7 @@ import ( "git.wntrmute.dev/mc/mcdsl/auth" "git.wntrmute.dev/mc/mcdsl/csrf" "git.wntrmute.dev/mc/mcdsl/httpserver" + "git.wntrmute.dev/mc/mcdsl/sso" "git.wntrmute.dev/mc/mcdsl/web" mcatweb "git.wntrmute.dev/mc/mcat/web" @@ -33,16 +34,17 @@ type Config struct { // Server is the mcat web UI server. type Server struct { - wsCfg Config - auth *auth.Authenticator - logger *slog.Logger - csrf *csrf.Protect - staticFS fs.FS - handler http.Handler + wsCfg Config + auth *auth.Authenticator + logger *slog.Logger + csrf *csrf.Protect + staticFS fs.FS + handler http.Handler + ssoClient *sso.Client } // New creates a new web UI server. -func New(wsCfg Config, authenticator *auth.Authenticator, logger *slog.Logger) (*Server, error) { +func New(wsCfg Config, authenticator *auth.Authenticator, logger *slog.Logger, ssoClient *sso.Client) (*Server, error) { staticFS, err := fs.Sub(mcatweb.FS, "static") if err != nil { return nil, fmt.Errorf("webserver: static fs: %w", err) @@ -54,11 +56,12 @@ func New(wsCfg Config, authenticator *auth.Authenticator, logger *slog.Logger) ( } s := &Server{ - wsCfg: wsCfg, - auth: authenticator, - logger: logger, - csrf: csrf.New(secret, csrfCookieName, csrfFieldName), - staticFS: staticFS, + wsCfg: wsCfg, + auth: authenticator, + logger: logger, + csrf: csrf.New(secret, csrfCookieName, csrfFieldName), + staticFS: staticFS, + ssoClient: ssoClient, } r := chi.NewRouter() @@ -79,8 +82,14 @@ func (s *Server) registerRoutes(r chi.Router) { r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.FS(s.staticFS)))) r.Get("/", s.handleRoot) - r.Get("/login", s.handleLogin) - r.Post("/login", s.handleLogin) + if s.ssoClient != nil { + r.Get("/login", s.handleSSOLogin) + r.Get("/sso/redirect", s.handleSSORedirect) + r.Get("/sso/callback", s.handleSSOCallback) + } else { + r.Get("/login", s.handleLogin) + r.Post("/login", s.handleLogin) + } r.Post("/logout", s.requireAuth(s.handleLogout)) r.Get("/dashboard", s.requireAuth(s.handleDashboard)) } @@ -131,6 +140,37 @@ func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/dashboard", http.StatusFound) } +// handleSSOLogin renders a landing page with a "Sign in with MCIAS" button. +func (s *Server) handleSSOLogin(w http.ResponseWriter, r *http.Request) { + s.renderTemplate(w, "login.html", map[string]interface{}{ + "SSO": true, + }) +} + +// handleSSORedirect initiates the SSO redirect to MCIAS. +func (s *Server) handleSSORedirect(w http.ResponseWriter, r *http.Request) { + if err := sso.RedirectToLogin(w, r, s.ssoClient, "mcat"); err != nil { + s.logger.Error("sso: redirect to login", "error", err) + http.Error(w, "internal error", http.StatusInternalServerError) + } +} + +// handleSSOCallback exchanges the authorization code for a JWT and sets the session. +func (s *Server) handleSSOCallback(w http.ResponseWriter, r *http.Request) { + token, returnTo, err := sso.HandleCallback(w, r, s.ssoClient, "mcat") + if err != nil { + s.logger.Error("sso: callback", "error", err) + s.renderTemplate(w, "login.html", map[string]interface{}{ + "SSO": s.ssoClient != nil, + "Error": "Login failed. Please try again.", + }) + return + } + + web.SetSessionCookie(w, sessionCookieName, token) + http.Redirect(w, r, returnTo, http.StatusFound) +} + func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { token := web.GetSessionToken(r, sessionCookieName) if token != "" { diff --git a/web/templates/login.html b/web/templates/login.html index ad3c2d7..a7728cb 100644 --- a/web/templates/login.html +++ b/web/templates/login.html @@ -8,6 +8,12 @@
Sign In
{{if .Error}}
{{.Error}}
{{end}} + {{if .SSO}} +

Sign in to test MCIAS login policies.

+
+ Sign in with MCIAS +
+ {{else}}
{{csrfField}}
@@ -26,5 +32,6 @@
+ {{end}}
{{end}}