package config

import (
	"os"
	"path/filepath"
	"testing"
	"time"
)

// validConfig returns a minimal valid TOML config string.
//
// The fixture ends inside the [master_key] table, so callers can append
// further tables (for example [[sso.clients]]) to build larger configs.
func validConfig() string {
	return `
[server]
listen_addr = "0.0.0.0:8443"
tls_cert = "/srv/mcias/server.crt"
tls_key = "/srv/mcias/server.key"

[database]
path = "/srv/mcias/mcias.db"

[tokens]
issuer = "https://auth.example.com"
default_expiry = "720h"
admin_expiry = "8h"
service_expiry = "8760h"

[argon2]
time = 3
memory = 65536
threads = 4

[master_key]
passphrase_env = "MCIAS_MASTER_PASSPHRASE"
`
}

// writeTempConfig writes content to a file named mcias.toml inside a fresh
// per-test temporary directory (mode 0600) and returns the file's path. The
// test is aborted if the file cannot be written.
func writeTempConfig(t *testing.T, content string) string {
	t.Helper()
	dir := t.TempDir()
	path := filepath.Join(dir, "mcias.toml")
	if err := os.WriteFile(path, []byte(content), 0600); err != nil {
		t.Fatalf("write temp config: %v", err)
	}
	return path
}

// TestLoadValidConfig checks that Load accepts the minimal fixture and that
// the parsed values are visible through the config fields and the expiry
// accessors.
func TestLoadValidConfig(t *testing.T) {
	path := writeTempConfig(t, validConfig())
	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load returned error: %v", err)
	}

	if cfg.Server.ListenAddr != "0.0.0.0:8443" {
		t.Errorf("ListenAddr = %q, want %q", cfg.Server.ListenAddr, "0.0.0.0:8443")
	}
	if cfg.Tokens.Issuer != "https://auth.example.com" {
		t.Errorf("Issuer = %q, want %q", cfg.Tokens.Issuer, "https://auth.example.com")
	}
	if cfg.DefaultExpiry() != 720*time.Hour {
		t.Errorf("DefaultExpiry = %v, want %v", cfg.DefaultExpiry(), 720*time.Hour)
	}
	if cfg.AdminExpiry() != 8*time.Hour {
		t.Errorf("AdminExpiry = %v, want %v", cfg.AdminExpiry(), 8*time.Hour)
	}
	if cfg.ServiceExpiry() != 8760*time.Hour {
		t.Errorf("ServiceExpiry = %v, want %v", cfg.ServiceExpiry(), 8760*time.Hour)
	}
	if cfg.Argon2.Time != 3 {
		t.Errorf("Argon2.Time = %d, want 3", cfg.Argon2.Time)
	}
	if cfg.Argon2.Memory != 65536 {
		t.Errorf("Argon2.Memory = %d, want 65536", cfg.Argon2.Memory)
	}
	if cfg.MasterKey.PassphraseEnv != "MCIAS_MASTER_PASSPHRASE" {
		t.Errorf("MasterKey.PassphraseEnv = %q", cfg.MasterKey.PassphraseEnv)
	}
}

// TestLoadMissingFile checks that Load reports an error for a path that does
// not exist.
func TestLoadMissingFile(t *testing.T) {
	// A file name inside a fresh temp dir is guaranteed to be absent, which a
	// hard-coded absolute path is not.
	missing := filepath.Join(t.TempDir(), "mcias.toml")
	if _, err := Load(missing); err == nil {
		t.Error("expected error for missing file, got nil")
	}
}

// TestLoadInvalidTOML checks that Load reports an error when the file content
// is not TOML.
func TestLoadInvalidTOML(t *testing.T) {
	path := writeTempConfig(t, "this is not valid TOML {{{{")
	_, err := Load(path)
	if err == nil {
		t.Error("expected error for invalid TOML, got nil")
	}
}

// TestValidateMissingListenAddr checks that Load rejects a config whose
// [server] table has no listen_addr key.
func TestValidateMissingListenAddr(t *testing.T) {
	path := writeTempConfig(t, `
[server]
tls_cert = "/etc/mcias/server.crt"
tls_key = "/etc/mcias/server.key"

[database]
path = "/var/lib/mcias/mcias.db"

[tokens]
issuer = "https://auth.example.com"
default_expiry = "720h"
admin_expiry = "8h"
service_expiry = "8760h"

[argon2]
time = 3
memory = 65536
threads = 4

[master_key]
passphrase_env = "MCIAS_MASTER_PASSPHRASE"
`)
	_, err := Load(path)
	if err == nil {
		t.Error("expected error for missing listen_addr, got nil")
	}
}

// TestValidateArgon2TooWeak checks that validate rejects a time cost of 1 and
// a memory cost of 32768.
func TestValidateArgon2TooWeak(t *testing.T) {
	tests := []struct {
		name   string
		time   uint32
		memory uint32
	}{
		{"time too low", 1, 65536},
		{"memory too low", 3, 32768},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Start from a config that is known to load, so that a failure
			// below can only come from the parameters changed here.
			path := writeTempConfig(t, validConfig())
			cfg, err := Load(path)
			if err != nil {
				t.Fatalf("baseline load failed: %v", err)
			}

			// Weaken the parameters on the loaded struct and re-validate.
			cfg.Argon2.Time = tc.time
			cfg.Argon2.Memory = tc.memory
			if err := cfg.validate(); err == nil {
				t.Errorf("expected validation error for time=%d memory=%d, got nil", tc.time, tc.memory)
			}
		})
	}
}

// TestValidateMasterKeyBothSet checks that Load rejects a [master_key] table
// that sets both passphrase_env and keyfile.
func TestValidateMasterKeyBothSet(t *testing.T) {
	path := writeTempConfig(t, `
[server]
listen_addr = "0.0.0.0:8443"
tls_cert = "/srv/mcias/server.crt"
tls_key = "/srv/mcias/server.key"

[database]
path = "/srv/mcias/mcias.db"

[tokens]
issuer = "https://auth.example.com"
default_expiry = "720h"
admin_expiry = "8h"
service_expiry = "8760h"

[argon2]
time = 3
memory = 65536
threads = 4

[master_key]
passphrase_env = "MCIAS_MASTER_PASSPHRASE"
keyfile = "/srv/mcias/master.key"
`)
	_, err := Load(path)
	if err == nil {
		t.Error("expected error when both passphrase_env and keyfile are set, got nil")
	}
}

// TestValidateMasterKeyNoneSet checks that Load rejects a [master_key] table
// that sets neither passphrase_env nor keyfile.
func TestValidateMasterKeyNoneSet(t *testing.T) {
	path := writeTempConfig(t, `
[server]
listen_addr = "0.0.0.0:8443"
tls_cert = "/srv/mcias/server.crt"
tls_key = "/srv/mcias/server.key"

[database]
path = "/srv/mcias/mcias.db"

[tokens]
issuer = "https://auth.example.com"
default_expiry = "720h"
admin_expiry = "8h"
service_expiry = "8760h"

[argon2]
time = 3
memory = 65536
threads = 4

[master_key]
`)
	_, err := Load(path)
	if err == nil {
		t.Error("expected error when neither passphrase_env nor keyfile is set, got nil")
	}
}

// TestTrustedProxyValidation verifies that trusted_proxy must be a valid IP.
// An empty value is accepted; hostnames, CIDR ranges and arbitrary strings
// are expected to fail validation.
func TestTrustedProxyValidation(t *testing.T) {
	tests := []struct {
		name    string
		proxy   string
		wantErr bool
	}{
		{"empty is valid (disabled)", "", false},
		{"valid IPv4", "127.0.0.1", false},
		{"valid IPv6 loopback", "::1", false},
		{"valid private IPv4", "10.0.0.1", false},
		{"hostname rejected", "proxy.example.com", true},
		{"CIDR rejected", "10.0.0.0/8", true},
		{"garbage rejected", "not-an-ip", true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Report the load error instead of discarding it, so a broken
			// baseline is diagnosable.
			cfg, err := Load(writeTempConfig(t, validConfig()))
			if err != nil {
				t.Fatalf("baseline config load failed: %v", err)
			}
			if cfg == nil {
				t.Fatal("baseline config load failed")
			}

			cfg.Server.TrustedProxy = tc.proxy
			err = cfg.validate()
			if tc.wantErr && err == nil {
				t.Errorf("expected validation error for proxy=%q, got nil", tc.proxy)
			}
			if !tc.wantErr && err != nil {
				t.Errorf("unexpected error for proxy=%q: %v", tc.proxy, err)
			}
		})
	}
}

// TestSSOClientValidation appends [[sso.clients]] tables to the minimal
// fixture and checks which combinations Load accepts: complete clients pass,
// while a missing client_id, redirect_uri or service_name, an http://
// redirect_uri, or a duplicated client_id must be rejected.
func TestSSOClientValidation(t *testing.T) {
	tests := []struct {
		name    string
		extra   string
		wantErr bool
	}{
		{
			name: "valid single client",
			extra: `
[[sso.clients]]
client_id = "mcr"
redirect_uri = "https://mcr.example.com/sso/callback"
service_name = "mcr"
tags = ["env:restricted"]
`,
			wantErr: false,
		},
		{
			name: "valid multiple clients",
			extra: `
[[sso.clients]]
client_id = "mcr"
redirect_uri = "https://mcr.example.com/sso/callback"
service_name = "mcr"

[[sso.clients]]
client_id = "mcat"
redirect_uri = "https://mcat.example.com/sso/callback"
service_name = "mcat"
`,
			wantErr: false,
		},
		{
			name: "missing client_id",
			extra: `
[[sso.clients]]
redirect_uri = "https://mcr.example.com/sso/callback"
service_name = "mcr"
`,
			wantErr: true,
		},
		{
			name: "missing redirect_uri",
			extra: `
[[sso.clients]]
client_id = "mcr"
service_name = "mcr"
`,
			wantErr: true,
		},
		{
			name: "http redirect_uri rejected",
			extra: `
[[sso.clients]]
client_id = "mcr"
redirect_uri = "http://mcr.example.com/sso/callback"
service_name = "mcr"
`,
			wantErr: true,
		},
		{
			name: "missing service_name",
			extra: `
[[sso.clients]]
client_id = "mcr"
redirect_uri = "https://mcr.example.com/sso/callback"
`,
			wantErr: true,
		},
		{
			name: "duplicate client_id",
			extra: `
[[sso.clients]]
client_id = "mcr"
redirect_uri = "https://mcr.example.com/sso/callback"
service_name = "mcr"

[[sso.clients]]
client_id = "mcr"
redirect_uri = "https://other.example.com/sso/callback"
service_name = "mcr2"
`,
			wantErr: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			path := writeTempConfig(t, validConfig()+tc.extra)
			_, err := Load(path)
			if tc.wantErr && err == nil {
				t.Error("expected validation error, got nil")
			}
			if !tc.wantErr && err != nil {
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}

// TestSSOClientLookup loads a config with one SSO client and checks that
// SSOClient returns it with the configured fields, returns nil for an unknown
// client ID, and that SSOEnabled reports true.
func TestSSOClientLookup(t *testing.T) {
	path := writeTempConfig(t, validConfig()+`
[[sso.clients]]
client_id = "mcr"
redirect_uri = "https://mcr.example.com/sso/callback"
service_name = "mcr"
tags = ["env:restricted"]
`)
	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load: %v", err)
	}

	cl := cfg.SSOClient("mcr")
	if cl == nil {
		t.Fatal("SSOClient(mcr) returned nil")
	}
	if cl.RedirectURI != "https://mcr.example.com/sso/callback" {
		t.Errorf("RedirectURI = %q", cl.RedirectURI)
	}
	if cl.ServiceName != "mcr" {
		t.Errorf("ServiceName = %q", cl.ServiceName)
	}
	if len(cl.Tags) != 1 || cl.Tags[0] != "env:restricted" {
		t.Errorf("Tags = %v", cl.Tags)
	}

	if cfg.SSOClient("nonexistent") != nil {
		t.Error("SSOClient(nonexistent) should return nil")
	}
	if !cfg.SSOEnabled() {
		t.Error("SSOEnabled() should return true")
	}
}
func TestSSODisabledByDefault(t *testing.T) { path := writeTempConfig(t, validConfig()) cfg, err := Load(path) if err != nil { t.Fatalf("Load: %v", err) } if cfg.SSOEnabled() { t.Error("SSOEnabled() should return false with no clients") } } func TestDurationParsing(t *testing.T) { var d duration if err := d.UnmarshalText([]byte("1h30m")); err != nil { t.Fatalf("unexpected error: %v", err) } if d.Duration != 90*time.Minute { t.Errorf("Duration = %v, want %v", d.Duration, 90*time.Minute) } if err := d.UnmarshalText([]byte("not-a-duration")); err == nil { t.Error("expected error for invalid duration, got nil") } }