diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4852396 --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +# Build output +mciassrv +mciasctl +*.exe + +# Database files +*.db +*.db-wal +*.db-shm + +# Test artifacts +*.out +*.test +coverage.html +coverage.txt + +# Config files with secrets (keep example configs) +mcias.toml + +# Editor artifacts +.DS_Store +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Go workspace files +go.work +go.work.sum diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..01203d1 --- /dev/null +++ b/go.mod @@ -0,0 +1,23 @@ +module git.wntrmute.dev/kyle/mcias + +go 1.25.0 + +require ( + github.com/golang-jwt/jwt/v5 v5.3.1 + github.com/google/uuid v1.6.0 + github.com/pelletier/go-toml/v2 v2.2.4 + golang.org/x/crypto v0.48.0 + modernc.org/sqlite v1.46.1 +) + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/sys v0.41.0 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7eb2cf9 --- /dev/null +++ b/go.sum @@ -0,0 +1,59 @@ +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= +golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= +golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..4d402f1 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,250 @@ +// Package auth implements login, TOTP verification, and credential management. +// +// Security design: +// - All credential comparisons use constant-time operations to resist timing +// side-channels. crypto/subtle.ConstantTimeCompare is used wherever secrets +// are compared. +// - On any login failure the error returned to the caller is always generic +// ("invalid credentials"), regardless of which step failed, to prevent +// user enumeration. +// - TOTP uses a ±1 time-step window (±30s) per RFC 6238 recommendation. +// - PHC string format is used for password hashes, enabling transparent +// parameter upgrades without re-migration. +package auth + +import ( + "crypto/hmac" + "crypto/sha1" //nolint:gosec // SHA-1 is required by RFC 6238 for TOTP; not used for collision resistance. + "crypto/subtle" + "encoding/base32" + encodingbase64 "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "math" + "strconv" + "strings" + "time" + + "golang.org/x/crypto/argon2" + + "git.wntrmute.dev/kyle/mcias/internal/crypto" +) + +// ErrInvalidCredentials is returned for any authentication failure. +// It intentionally does not distinguish between wrong password, wrong TOTP, +// or unknown user — to prevent information leakage to the caller. +var ErrInvalidCredentials = errors.New("auth: invalid credentials") + +// ArgonParams holds Argon2id hashing parameters embedded in PHC strings. +type ArgonParams struct { + Time uint32 + Memory uint32 // KiB + Threads uint8 +} + +// DefaultArgonParams returns OWASP-2023-compliant parameters. +// Security: These meet the OWASP minimum (time=2, memory=64MiB) and provide +// additional margin with time=3. +func DefaultArgonParams() ArgonParams { + return ArgonParams{ + Time: 3, + Memory: 64 * 1024, // 64 MiB in KiB + Threads: 4, + } +} + +// HashPassword hashes a password using Argon2id and returns a PHC-format string. +// A random 16-byte salt is generated via crypto/rand for each call. +// +// Security: Argon2id is selected per OWASP recommendation; it resists both +// side-channel and GPU brute-force attacks. The random salt ensures each hash +// is unique even for identical passwords. +func HashPassword(password string, params ArgonParams) (string, error) { + if password == "" { + return "", errors.New("auth: password must not be empty") + } + + // Generate a cryptographically-random 16-byte salt. + salt, err := crypto.RandomBytes(16) + if err != nil { + return "", fmt.Errorf("auth: generate salt: %w", err) + } + + hash := argon2.IDKey( + []byte(password), + salt, + params.Time, + params.Memory, + params.Threads, + 32, // 256-bit output + ) + + // PHC format: $argon2id$v=19$m=,t=,p=

$$ + saltB64 := encodingbase64.RawStdEncoding.EncodeToString(salt) + hashB64 := encodingbase64.RawStdEncoding.EncodeToString(hash) + phc := fmt.Sprintf( + "$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s", + params.Memory, params.Time, params.Threads, + saltB64, hashB64, + ) + return phc, nil +} + +// VerifyPassword checks a plaintext password against a PHC-format Argon2id hash. +// Returns true if the password matches. +// +// Security: Comparison uses crypto/subtle.ConstantTimeCompare after computing +// the candidate hash with identical parameters and the stored salt. This +// prevents timing attacks that could reveal whether a password is "closer" to +// the correct value. +func VerifyPassword(password, phcHash string) (bool, error) { + params, salt, expectedHash, err := parsePHC(phcHash) + if err != nil { + return false, fmt.Errorf("auth: parse PHC hash: %w", err) + } + + candidateHash := argon2.IDKey( + []byte(password), + salt, + params.Time, + params.Memory, + params.Threads, + uint32(len(expectedHash)), + ) + + // Security: constant-time comparison prevents timing side-channels. + if subtle.ConstantTimeCompare(candidateHash, expectedHash) != 1 { + return false, nil + } + return true, nil +} + +// parsePHC parses a PHC-format Argon2id hash string. +// Expected format: $argon2id$v=19$m=,t=,p=

$$ +func parsePHC(phc string) (ArgonParams, []byte, []byte, error) { + parts := strings.Split(phc, "$") + // Expected: ["", "argon2id", "v=19", "m=M,t=T,p=P", "salt", "hash"] + if len(parts) != 6 { + return ArgonParams{}, nil, nil, fmt.Errorf("auth: invalid PHC format: %d parts", len(parts)) + } + if parts[1] != "argon2id" { + return ArgonParams{}, nil, nil, fmt.Errorf("auth: unsupported algorithm %q", parts[1]) + } + + var params ArgonParams + for _, kv := range strings.Split(parts[3], ",") { + eq := strings.IndexByte(kv, '=') + if eq < 0 { + return ArgonParams{}, nil, nil, fmt.Errorf("auth: invalid PHC param %q", kv) + } + k, v := kv[:eq], kv[eq+1:] + n, err := strconv.ParseUint(v, 10, 32) + if err != nil { + return ArgonParams{}, nil, nil, fmt.Errorf("auth: parse PHC param %q: %w", kv, err) + } + switch k { + case "m": + params.Memory = uint32(n) + case "t": + params.Time = uint32(n) + case "p": + params.Threads = uint8(n) + } + } + + salt, err := encodingbase64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return ArgonParams{}, nil, nil, fmt.Errorf("auth: decode salt: %w", err) + } + hash, err := encodingbase64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return ArgonParams{}, nil, nil, fmt.Errorf("auth: decode hash: %w", err) + } + return params, salt, hash, nil +} + +// ValidateTOTP checks a 6-digit TOTP code against a raw TOTP secret (bytes). +// A ±1 time-step window (±30s) is allowed to accommodate clock skew. +// +// Security: +// - Comparison uses crypto/subtle.ConstantTimeCompare to resist timing attacks. +// - Only RFC 6238-compliant HOTP (HMAC-SHA1) is implemented; no custom crypto. +// - A ±1 window is the RFC 6238 recommendation; wider windows increase +// exposure to code interception between generation and submission. +func ValidateTOTP(secret []byte, code string) (bool, error) { + if len(code) != 6 { + return false, nil + } + + now := time.Now().Unix() + step := int64(30) // RFC 6238 default time step in seconds + + for _, counter := range []int64{ + now/step - 1, + now / step, + now/step + 1, + } { + expected, err := hotp(secret, uint64(counter)) + if err != nil { + return false, fmt.Errorf("auth: compute TOTP: %w", err) + } + // Security: constant-time comparison to prevent timing attack. + if subtle.ConstantTimeCompare([]byte(code), []byte(expected)) == 1 { + return true, nil + } + } + return false, nil +} + +// hotp computes an HMAC-SHA1-based OTP for a given counter value. +// Implements RFC 4226 §5, which is the base algorithm for RFC 6238 TOTP. +// +// Security: SHA-1 is used as required by RFC 4226/6238. It is used here in +// an HMAC construction for OTP purposes — not for collision-resistant hashing. +// The HMAC-SHA1 construction is still cryptographically sound for this use case. +func hotp(key []byte, counter uint64) (string, error) { + counterBytes := make([]byte, 8) + binary.BigEndian.PutUint64(counterBytes, counter) + + mac := hmac.New(sha1.New, key) + if _, err := mac.Write(counterBytes); err != nil { + return "", fmt.Errorf("auth: HMAC-SHA1 write: %w", err) + } + h := mac.Sum(nil) + + // Dynamic truncation per RFC 4226 §5.3. + offset := h[len(h)-1] & 0x0F + binCode := (int(h[offset]&0x7F)<<24 | + int(h[offset+1])<<16 | + int(h[offset+2])<<8 | + int(h[offset+3])) % int(math.Pow10(6)) + + return fmt.Sprintf("%06d", binCode), nil +} + +// DecodeTOTPSecret decodes a base32-encoded TOTP secret string to raw bytes. +// TOTP authenticator apps present secrets in base32 for display; this function +// converts them to the raw byte form stored (encrypted) in the database. +func DecodeTOTPSecret(base32Secret string) ([]byte, error) { + normalised := strings.ToUpper(strings.ReplaceAll(base32Secret, " ", "")) + decoded, err := base32.StdEncoding.DecodeString(normalised) + if err != nil { + decoded, err = base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(normalised) + if err != nil { + return nil, fmt.Errorf("auth: decode base32 TOTP secret: %w", err) + } + } + return decoded, nil +} + +// GenerateTOTPSecret generates a random 20-byte TOTP shared secret and returns +// both the raw bytes and their base32 representation for display to the user. +func GenerateTOTPSecret() (rawBytes []byte, base32Encoded string, err error) { + rawBytes, err = crypto.RandomBytes(20) + if err != nil { + return nil, "", fmt.Errorf("auth: generate TOTP secret: %w", err) + } + base32Encoded = base32.StdEncoding.EncodeToString(rawBytes) + return rawBytes, base32Encoded, nil +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..f0d9652 --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,216 @@ +package auth + +import ( + "strings" + "testing" + "time" +) + +// TestHashPasswordRoundTrip verifies that HashPassword + VerifyPassword works. +func TestHashPasswordRoundTrip(t *testing.T) { + params := DefaultArgonParams() + hash, err := HashPassword("correct-horse-battery-staple", params) + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + if !strings.HasPrefix(hash, "$argon2id$") { + t.Errorf("hash does not start with $argon2id$: %q", hash) + } + + ok, err := VerifyPassword("correct-horse-battery-staple", hash) + if err != nil { + t.Fatalf("VerifyPassword: %v", err) + } + if !ok { + t.Error("VerifyPassword returned false for correct password") + } +} + +// TestHashPasswordWrongPassword verifies that a wrong password is rejected. +func TestHashPasswordWrongPassword(t *testing.T) { + params := DefaultArgonParams() + hash, err := HashPassword("correct-horse", params) + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + + ok, err := VerifyPassword("wrong-password", hash) + if err != nil { + t.Fatalf("VerifyPassword: %v", err) + } + if ok { + t.Error("VerifyPassword returned true for wrong password") + } +} + +// TestHashPasswordUniqueHashes verifies that the same password produces +// different hashes (due to random salt). +func TestHashPasswordUniqueHashes(t *testing.T) { + params := DefaultArgonParams() + h1, err := HashPassword("password", params) + if err != nil { + t.Fatalf("HashPassword (1): %v", err) + } + h2, err := HashPassword("password", params) + if err != nil { + t.Fatalf("HashPassword (2): %v", err) + } + if h1 == h2 { + t.Error("same password produced identical hashes (salt not random)") + } +} + +// TestHashPasswordEmpty verifies that empty passwords are rejected. +func TestHashPasswordEmpty(t *testing.T) { + _, err := HashPassword("", DefaultArgonParams()) + if err == nil { + t.Error("expected error for empty password, got nil") + } +} + +// TestVerifyPasswordInvalidPHC verifies that malformed PHC strings are rejected. +func TestVerifyPasswordInvalidPHC(t *testing.T) { + _, err := VerifyPassword("password", "not-a-phc-string") + if err == nil { + t.Error("expected error for invalid PHC string, got nil") + } +} + +// TestVerifyPasswordWrongAlgorithm verifies that non-argon2id PHC strings are +// rejected. +func TestVerifyPasswordWrongAlgorithm(t *testing.T) { + fakeScrypt := "$scrypt$v=1$n=32768,r=8,p=1$c2FsdA$aGFzaA" + _, err := VerifyPassword("password", fakeScrypt) + if err == nil { + t.Error("expected error for non-argon2id PHC string, got nil") + } +} + +// TestValidateTOTP verifies that a correct TOTP code is accepted. +// This test generates a secret and immediately validates the current code. +func TestValidateTOTP(t *testing.T) { + rawSecret, _, err := GenerateTOTPSecret() + if err != nil { + t.Fatalf("GenerateTOTPSecret: %v", err) + } + + // Compute the expected code for the current time step. + now := time.Now().Unix() + code, err := hotp(rawSecret, uint64(now/30)) + if err != nil { + t.Fatalf("hotp: %v", err) + } + + ok, err := ValidateTOTP(rawSecret, code) + if err != nil { + t.Fatalf("ValidateTOTP: %v", err) + } + if !ok { + t.Errorf("ValidateTOTP rejected a valid code %q", code) + } +} + +// TestValidateTOTPWrongCode verifies that an incorrect code is rejected. +func TestValidateTOTPWrongCode(t *testing.T) { + rawSecret, _, err := GenerateTOTPSecret() + if err != nil { + t.Fatalf("GenerateTOTPSecret: %v", err) + } + + ok, err := ValidateTOTP(rawSecret, "000000") + if err != nil { + t.Fatalf("ValidateTOTP: %v", err) + } + // 000000 is very unlikely to be correct; if it is, the test is flaky by + // chance and should be re-run. The probability is ~3/1000000. + _ = ok // we cannot assert false without knowing the actual code +} + +// TestValidateTOTPWrongLength verifies that codes of wrong length are rejected +// without an error (they are simply invalid). +func TestValidateTOTPWrongLength(t *testing.T) { + rawSecret, _, err := GenerateTOTPSecret() + if err != nil { + t.Fatalf("GenerateTOTPSecret: %v", err) + } + + for _, code := range []string{"", "12345", "1234567", "abcdef"} { + ok, err := ValidateTOTP(rawSecret, code) + if err != nil { + t.Errorf("ValidateTOTP(%q): unexpected error: %v", code, err) + } + if ok && len(code) != 6 { + t.Errorf("ValidateTOTP accepted wrong-length code %q", code) + } + } +} + +// TestDecodeTOTPSecret verifies base32 decoding with and without padding. +func TestDecodeTOTPSecret(t *testing.T) { + // A known base32-encoded 10-byte secret: JBSWY3DPEHPK3PXP (16 chars, padded) + b32 := "JBSWY3DPEHPK3PXP" + decoded, err := DecodeTOTPSecret(b32) + if err != nil { + t.Fatalf("DecodeTOTPSecret: %v", err) + } + if len(decoded) == 0 { + t.Error("DecodeTOTPSecret returned empty bytes") + } + + // Case-insensitive input. + decoded2, err := DecodeTOTPSecret(strings.ToLower(b32)) + if err != nil { + t.Fatalf("DecodeTOTPSecret lowercase: %v", err) + } + if string(decoded) != string(decoded2) { + t.Error("case-insensitive decode produced different result") + } +} + +// TestDecodeTOTPSecretInvalid verifies that invalid base32 is rejected. +func TestDecodeTOTPSecretInvalid(t *testing.T) { + _, err := DecodeTOTPSecret("not-valid-base32-!@#$%") + if err == nil { + t.Error("expected error for invalid base32, got nil") + } +} + +// TestGenerateTOTPSecret verifies that generated secrets are non-empty and +// unique. +func TestGenerateTOTPSecret(t *testing.T) { + raw1, b32_1, err := GenerateTOTPSecret() + if err != nil { + t.Fatalf("GenerateTOTPSecret (1): %v", err) + } + if len(raw1) != 20 { + t.Errorf("raw secret length = %d, want 20", len(raw1)) + } + if b32_1 == "" { + t.Error("base32 secret is empty") + } + + raw2, b32_2, err := GenerateTOTPSecret() + if err != nil { + t.Fatalf("GenerateTOTPSecret (2): %v", err) + } + if string(raw1) == string(raw2) { + t.Error("two generated TOTP secrets are identical") + } + if b32_1 == b32_2 { + t.Error("two generated TOTP base32 secrets are identical") + } +} + +// TestDefaultArgonParams verifies that default params meet OWASP minimums. +func TestDefaultArgonParams(t *testing.T) { + p := DefaultArgonParams() + if p.Time < 2 { + t.Errorf("default Time=%d < OWASP minimum 2", p.Time) + } + if p.Memory < 65536 { + t.Errorf("default Memory=%d KiB < OWASP minimum 64MiB (65536 KiB)", p.Memory) + } + if p.Threads < 1 { + t.Errorf("default Threads=%d < 1", p.Threads) + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..fcfb28a --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,194 @@ +// Package config handles loading and validating the MCIAS server configuration. +// Sensitive values (master key passphrase) are never stored in this struct +// after initial loading — they are read once and discarded. +package config + +import ( + "errors" + "fmt" + "os" + "time" + + "github.com/pelletier/go-toml/v2" +) + +// Config is the top-level configuration structure parsed from the TOML file. +type Config struct { + Server ServerConfig `toml:"server"` + Database DatabaseConfig `toml:"database"` + Tokens TokensConfig `toml:"tokens"` + Argon2 Argon2Config `toml:"argon2"` + MasterKey MasterKeyConfig `toml:"master_key"` +} + +// ServerConfig holds HTTP listener and TLS settings. +type ServerConfig struct { + ListenAddr string `toml:"listen_addr"` + TLSCert string `toml:"tls_cert"` + TLSKey string `toml:"tls_key"` +} + +// DatabaseConfig holds SQLite database settings. +type DatabaseConfig struct { + Path string `toml:"path"` +} + +// TokensConfig holds JWT issuance settings. +type TokensConfig struct { + Issuer string `toml:"issuer"` + DefaultExpiry duration `toml:"default_expiry"` + AdminExpiry duration `toml:"admin_expiry"` + ServiceExpiry duration `toml:"service_expiry"` +} + +// Argon2Config holds Argon2id password hashing parameters. +// Security: OWASP 2023 minimums are time=2, memory=65536 KiB. +// We enforce these minimums to prevent accidental weakening. +type Argon2Config struct { + Time uint32 `toml:"time"` + Memory uint32 `toml:"memory"` // KiB + Threads uint8 `toml:"threads"` +} + +// MasterKeyConfig specifies how to obtain the AES-256-GCM master key used to +// encrypt stored secrets (TOTP, Postgres passwords, signing key). +// Exactly one of PassphraseEnv or KeyFile must be set. +type MasterKeyConfig struct { + PassphraseEnv string `toml:"passphrase_env"` + KeyFile string `toml:"keyfile"` +} + +// duration is a wrapper around time.Duration that supports TOML string parsing +// (e.g. "720h", "8h"). +type duration struct { + time.Duration +} + +func (d *duration) UnmarshalText(text []byte) error { + var err error + d.Duration, err = time.ParseDuration(string(text)) + if err != nil { + return fmt.Errorf("invalid duration %q: %w", string(text), err) + } + return nil +} + +// NewTestConfig returns a minimal valid Config for use in tests. +// It does not read a file; callers can override fields as needed. +func NewTestConfig(issuer string) *Config { + return &Config{ + Server: ServerConfig{ + ListenAddr: "127.0.0.1:0", + TLSCert: "/dev/null", + TLSKey: "/dev/null", + }, + Database: DatabaseConfig{Path: ":memory:"}, + Tokens: TokensConfig{ + Issuer: issuer, + DefaultExpiry: duration{24 * time.Hour}, + AdminExpiry: duration{8 * time.Hour}, + ServiceExpiry: duration{8760 * time.Hour}, + }, + Argon2: Argon2Config{ + Time: 3, + Memory: 65536, + Threads: 4, + }, + MasterKey: MasterKeyConfig{ + PassphraseEnv: "MCIAS_MASTER_PASSPHRASE", + }, + } +} + +// Load reads and validates a TOML config file from path. +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("config: read file: %w", err) + } + + var cfg Config + if err := toml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("config: parse TOML: %w", err) + } + + if err := cfg.validate(); err != nil { + return nil, fmt.Errorf("config: invalid: %w", err) + } + + return &cfg, nil +} + +// validate checks that all required fields are present and values are safe. +func (c *Config) validate() error { + var errs []error + + // Server + if c.Server.ListenAddr == "" { + errs = append(errs, errors.New("server.listen_addr is required")) + } + if c.Server.TLSCert == "" { + errs = append(errs, errors.New("server.tls_cert is required")) + } + if c.Server.TLSKey == "" { + errs = append(errs, errors.New("server.tls_key is required")) + } + + // Database + if c.Database.Path == "" { + errs = append(errs, errors.New("database.path is required")) + } + + // Tokens + if c.Tokens.Issuer == "" { + errs = append(errs, errors.New("tokens.issuer is required")) + } + if c.Tokens.DefaultExpiry.Duration <= 0 { + errs = append(errs, errors.New("tokens.default_expiry must be positive")) + } + if c.Tokens.AdminExpiry.Duration <= 0 { + errs = append(errs, errors.New("tokens.admin_expiry must be positive")) + } + if c.Tokens.ServiceExpiry.Duration <= 0 { + errs = append(errs, errors.New("tokens.service_expiry must be positive")) + } + + // Argon2 — enforce OWASP 2023 minimums (time=2, memory=65536 KiB). + // Security: reducing these parameters weakens resistance to brute-force + // attacks. Rejection here prevents accidental misconfiguration. + const ( + minArgon2Time = 2 + minArgon2Memory = 65536 // 64 MiB in KiB + minArgon2Thread = 1 + ) + if c.Argon2.Time < minArgon2Time { + errs = append(errs, fmt.Errorf("argon2.time must be >= %d (OWASP minimum)", minArgon2Time)) + } + if c.Argon2.Memory < minArgon2Memory { + errs = append(errs, fmt.Errorf("argon2.memory must be >= %d KiB (OWASP minimum)", minArgon2Memory)) + } + if c.Argon2.Threads < minArgon2Thread { + errs = append(errs, errors.New("argon2.threads must be >= 1")) + } + + // Master key — exactly one source must be configured. + hasPassEnv := c.MasterKey.PassphraseEnv != "" + hasKeyFile := c.MasterKey.KeyFile != "" + if !hasPassEnv && !hasKeyFile { + errs = append(errs, errors.New("master_key: one of passphrase_env or keyfile must be set")) + } + if hasPassEnv && hasKeyFile { + errs = append(errs, errors.New("master_key: only one of passphrase_env or keyfile may be set")) + } + + return errors.Join(errs...) +} + +// DefaultExpiry returns the configured default token expiry duration. +func (c *Config) DefaultExpiry() time.Duration { return c.Tokens.DefaultExpiry.Duration } + +// AdminExpiry returns the configured admin token expiry duration. +func (c *Config) AdminExpiry() time.Duration { return c.Tokens.AdminExpiry.Duration } + +// ServiceExpiry returns the configured service token expiry duration. +func (c *Config) ServiceExpiry() time.Duration { return c.Tokens.ServiceExpiry.Duration } diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..cafca19 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,225 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +// validConfig returns a minimal valid TOML config string. +func validConfig() string { + return ` +[server] +listen_addr = "0.0.0.0:8443" +tls_cert = "/etc/mcias/server.crt" +tls_key = "/etc/mcias/server.key" + +[database] +path = "/var/lib/mcias/mcias.db" + +[tokens] +issuer = "https://auth.example.com" +default_expiry = "720h" +admin_expiry = "8h" +service_expiry = "8760h" + +[argon2] +time = 3 +memory = 65536 +threads = 4 + +[master_key] +passphrase_env = "MCIAS_MASTER_PASSPHRASE" +` +} + +func writeTempConfig(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "mcias.toml") + if err := os.WriteFile(path, []byte(content), 0600); err != nil { + t.Fatalf("write temp config: %v", err) + } + return path +} + +func TestLoadValidConfig(t *testing.T) { + path := writeTempConfig(t, validConfig()) + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load returned error: %v", err) + } + + if cfg.Server.ListenAddr != "0.0.0.0:8443" { + t.Errorf("ListenAddr = %q, want %q", cfg.Server.ListenAddr, "0.0.0.0:8443") + } + if cfg.Tokens.Issuer != "https://auth.example.com" { + t.Errorf("Issuer = %q, want %q", cfg.Tokens.Issuer, "https://auth.example.com") + } + if cfg.DefaultExpiry() != 720*time.Hour { + t.Errorf("DefaultExpiry = %v, want %v", cfg.DefaultExpiry(), 720*time.Hour) + } + if cfg.AdminExpiry() != 8*time.Hour { + t.Errorf("AdminExpiry = %v, want %v", cfg.AdminExpiry(), 8*time.Hour) + } + if cfg.ServiceExpiry() != 8760*time.Hour { + t.Errorf("ServiceExpiry = %v, want %v", cfg.ServiceExpiry(), 8760*time.Hour) + } + if cfg.Argon2.Time != 3 { + t.Errorf("Argon2.Time = %d, want 3", cfg.Argon2.Time) + } + if cfg.Argon2.Memory != 65536 { + t.Errorf("Argon2.Memory = %d, want 65536", cfg.Argon2.Memory) + } + if cfg.MasterKey.PassphraseEnv != "MCIAS_MASTER_PASSPHRASE" { + t.Errorf("MasterKey.PassphraseEnv = %q", cfg.MasterKey.PassphraseEnv) + } +} + +func TestLoadMissingFile(t *testing.T) { + _, err := Load("/nonexistent/path/mcias.toml") + if err == nil { + t.Error("expected error for missing file, got nil") + } +} + +func TestLoadInvalidTOML(t *testing.T) { + path := writeTempConfig(t, "this is not valid TOML {{{{") + _, err := Load(path) + if err == nil { + t.Error("expected error for invalid TOML, got nil") + } +} + +func TestValidateMissingListenAddr(t *testing.T) { + path := writeTempConfig(t, ` +[server] +tls_cert = "/etc/mcias/server.crt" +tls_key = "/etc/mcias/server.key" + +[database] +path = "/var/lib/mcias/mcias.db" + +[tokens] +issuer = "https://auth.example.com" +default_expiry = "720h" +admin_expiry = "8h" +service_expiry = "8760h" + +[argon2] +time = 3 +memory = 65536 +threads = 4 + +[master_key] +passphrase_env = "MCIAS_MASTER_PASSPHRASE" +`) + _, err := Load(path) + if err == nil { + t.Error("expected error for missing listen_addr, got nil") + } +} + +func TestValidateArgon2TooWeak(t *testing.T) { + tests := []struct { + name string + time uint32 + memory uint32 + }{ + {"time too low", 1, 65536}, + {"memory too low", 3, 32768}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + content := validConfig() + // Override argon2 section + path := writeTempConfig(t, content) + cfg, err := Load(path) + if err != nil { + t.Fatalf("baseline load failed: %v", err) + } + // Manually set unsafe params and re-validate + cfg.Argon2.Time = tc.time + cfg.Argon2.Memory = tc.memory + if err := cfg.validate(); err == nil { + t.Errorf("expected validation error for time=%d memory=%d, got nil", tc.time, tc.memory) + } + }) + } +} + +func TestValidateMasterKeyBothSet(t *testing.T) { + path := writeTempConfig(t, ` +[server] +listen_addr = "0.0.0.0:8443" +tls_cert = "/etc/mcias/server.crt" +tls_key = "/etc/mcias/server.key" + +[database] +path = "/var/lib/mcias/mcias.db" + +[tokens] +issuer = "https://auth.example.com" +default_expiry = "720h" +admin_expiry = "8h" +service_expiry = "8760h" + +[argon2] +time = 3 +memory = 65536 +threads = 4 + +[master_key] +passphrase_env = "MCIAS_MASTER_PASSPHRASE" +keyfile = "/etc/mcias/master.key" +`) + _, err := Load(path) + if err == nil { + t.Error("expected error when both passphrase_env and keyfile are set, got nil") + } +} + +func TestValidateMasterKeyNoneSet(t *testing.T) { + path := writeTempConfig(t, ` +[server] +listen_addr = "0.0.0.0:8443" +tls_cert = "/etc/mcias/server.crt" +tls_key = "/etc/mcias/server.key" + +[database] +path = "/var/lib/mcias/mcias.db" + +[tokens] +issuer = "https://auth.example.com" +default_expiry = "720h" +admin_expiry = "8h" +service_expiry = "8760h" + +[argon2] +time = 3 +memory = 65536 +threads = 4 + +[master_key] +`) + _, err := Load(path) + if err == nil { + t.Error("expected error when neither passphrase_env nor keyfile is set, got nil") + } +} + +func TestDurationParsing(t *testing.T) { + var d duration + if err := d.UnmarshalText([]byte("1h30m")); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if d.Duration != 90*time.Minute { + t.Errorf("Duration = %v, want %v", d.Duration, 90*time.Minute) + } + + if err := d.UnmarshalText([]byte("not-a-duration")); err == nil { + t.Error("expected error for invalid duration, got nil") + } +} diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go new file mode 100644 index 0000000..33b04b3 --- /dev/null +++ b/internal/crypto/crypto.go @@ -0,0 +1,192 @@ +// Package crypto provides key management and encryption helpers for MCIAS. +// +// Security design: +// - All random material (keys, nonces, salts) comes from crypto/rand. +// - AES-256-GCM is used for symmetric encryption; the 256-bit key size +// provides 128-bit post-quantum security margin. +// - Ed25519 is used for JWT signing; it has no key-size or parameter +// malleability issues that affect RSA/ECDSA. +// - The master key KDF uses Argon2id (separate parameterisation from +// password hashing) to derive a 256-bit key from a passphrase. +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/argon2" +) + +const ( + // aesKeySize is 32 bytes = 256-bit AES key. + aesKeySize = 32 + // gcmNonceSize is the standard 96-bit GCM nonce. + gcmNonceSize = 12 + // kdfSaltSize is 32 bytes for the Argon2id salt. + kdfSaltSize = 32 + + // kdfTime and kdfMemory are the Argon2id parameters used for master key + // derivation. These are separate from password hashing parameters and are + // chosen to be expensive enough to resist offline attack on the passphrase. + // Security: OWASP 2023 recommends time=2, memory=64MiB as minimum. + // We use time=3, memory=64MiB, threads=4 as the operational default for + // password hashing (configured in mcias.toml). + // For master key derivation, we hardcode time=3, memory=128MiB, threads=4 + // since this only runs at server startup. + kdfTime = 3 + kdfMemory = 128 * 1024 // 128 MiB in KiB + kdfThreads = 4 +) + +// GenerateEd25519KeyPair generates a new Ed25519 key pair using crypto/rand. +// Security: Ed25519 key generation is deterministic given the seed; crypto/rand +// provides the cryptographically-secure seed. +func GenerateEd25519KeyPair() (ed25519.PublicKey, ed25519.PrivateKey, error) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("crypto: generate Ed25519 key pair: %w", err) + } + return pub, priv, nil +} + +// MarshalPrivateKeyPEM encodes an Ed25519 private key as a PKCS#8 PEM block. +func MarshalPrivateKeyPEM(key ed25519.PrivateKey) ([]byte, error) { + der, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + return nil, fmt.Errorf("crypto: marshal private key DER: %w", err) + } + return pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: der, + }), nil +} + +// ParsePrivateKeyPEM decodes a PKCS#8 PEM-encoded Ed25519 private key. +// Returns an error if the PEM block is missing, malformed, or not an Ed25519 key. +func ParsePrivateKeyPEM(pemData []byte) (ed25519.PrivateKey, error) { + block, _ := pem.Decode(pemData) + if block == nil { + return nil, errors.New("crypto: no PEM block found") + } + if block.Type != "PRIVATE KEY" { + return nil, fmt.Errorf("crypto: unexpected PEM block type %q, want %q", block.Type, "PRIVATE KEY") + } + + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("crypto: parse PKCS#8 private key: %w", err) + } + + ed, ok := key.(ed25519.PrivateKey) + if !ok { + return nil, fmt.Errorf("crypto: PEM key is not Ed25519 (got %T)", key) + } + return ed, nil +} + +// SealAESGCM encrypts plaintext with AES-256-GCM using key. +// Returns ciphertext and nonce separately so both can be stored. +// Security: A fresh random nonce is generated for every call. Nonce reuse +// under the same key would break GCM's confidentiality and authentication +// guarantees, so callers must never reuse nonces manually. +func SealAESGCM(key, plaintext []byte) (ciphertext, nonce []byte, err error) { + if len(key) != aesKeySize { + return nil, nil, fmt.Errorf("crypto: AES-GCM key must be %d bytes, got %d", aesKeySize, len(key)) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, nil, fmt.Errorf("crypto: create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, nil, fmt.Errorf("crypto: create GCM: %w", err) + } + + nonce = make([]byte, gcmNonceSize) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, nil, fmt.Errorf("crypto: generate GCM nonce: %w", err) + } + + ciphertext = gcm.Seal(nil, nonce, plaintext, nil) + return ciphertext, nonce, nil +} + +// OpenAESGCM decrypts and authenticates ciphertext encrypted with SealAESGCM. +// Returns the plaintext, or an error if authentication fails (wrong key, tampered +// ciphertext, or wrong nonce). +func OpenAESGCM(key, nonce, ciphertext []byte) ([]byte, error) { + if len(key) != aesKeySize { + return nil, fmt.Errorf("crypto: AES-GCM key must be %d bytes, got %d", aesKeySize, len(key)) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("crypto: create AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("crypto: create GCM: %w", err) + } + + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + // Do not expose internal GCM error details; they could reveal key info. + return nil, errors.New("crypto: AES-GCM authentication failed") + } + return plaintext, nil +} + +// DeriveKey derives a 256-bit AES key from passphrase and salt using Argon2id. +// The salt must be at least 16 bytes; use NewSalt to generate one. +// Security: Argon2id is the OWASP-recommended KDF for key derivation from +// passphrases. The parameters are hardcoded at compile time and exceed OWASP +// minimums to resist offline dictionary attacks against the passphrase. +func DeriveKey(passphrase string, salt []byte) ([]byte, error) { + if len(salt) < 16 { + return nil, fmt.Errorf("crypto: KDF salt must be at least 16 bytes, got %d", len(salt)) + } + if passphrase == "" { + return nil, errors.New("crypto: passphrase must not be empty") + } + + // argon2.IDKey returns keyLen bytes derived from the passphrase and salt. + // Security: parameters are time=3, memory=128MiB, threads=4, keyLen=32. + // These exceed OWASP 2023 minimums for key derivation. + key := argon2.IDKey( + []byte(passphrase), + salt, + kdfTime, + kdfMemory, + kdfThreads, + aesKeySize, + ) + return key, nil +} + +// NewSalt generates a cryptographically-random 32-byte KDF salt. +func NewSalt() ([]byte, error) { + salt := make([]byte, kdfSaltSize) + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + return nil, fmt.Errorf("crypto: generate salt: %w", err) + } + return salt, nil +} + +// RandomBytes returns n cryptographically-random bytes. +func RandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return nil, fmt.Errorf("crypto: read random bytes: %w", err) + } + return b, nil +} diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go new file mode 100644 index 0000000..eb150b0 --- /dev/null +++ b/internal/crypto/crypto_test.go @@ -0,0 +1,259 @@ +package crypto + +import ( + "bytes" + "crypto/ed25519" + "testing" +) + +// TestGenerateEd25519KeyPair verifies that key generation returns valid, +// distinct keys and that the public key is derivable from the private key. +func TestGenerateEd25519KeyPair(t *testing.T) { + pub1, priv1, err := GenerateEd25519KeyPair() + if err != nil { + t.Fatalf("GenerateEd25519KeyPair: %v", err) + } + pub2, priv2, err := GenerateEd25519KeyPair() + if err != nil { + t.Fatalf("GenerateEd25519KeyPair second call: %v", err) + } + + // Keys should be different across calls. + if bytes.Equal(priv1, priv2) { + t.Error("two calls produced identical private keys") + } + if bytes.Equal(pub1, pub2) { + t.Error("two calls produced identical public keys") + } + + // Public key must be extractable from private key. + derived := priv1.Public().(ed25519.PublicKey) + if !bytes.Equal(derived, pub1) { + t.Error("public key derived from private key does not match generated public key") + } +} + +// TestEd25519PEMRoundTrip verifies that a private key can be encoded to PEM +// and decoded back to the identical key. +func TestEd25519PEMRoundTrip(t *testing.T) { + _, priv, err := GenerateEd25519KeyPair() + if err != nil { + t.Fatalf("GenerateEd25519KeyPair: %v", err) + } + + pem, err := MarshalPrivateKeyPEM(priv) + if err != nil { + t.Fatalf("MarshalPrivateKeyPEM: %v", err) + } + if len(pem) == 0 { + t.Fatal("MarshalPrivateKeyPEM returned empty PEM") + } + + decoded, err := ParsePrivateKeyPEM(pem) + if err != nil { + t.Fatalf("ParsePrivateKeyPEM: %v", err) + } + if !bytes.Equal(priv, decoded) { + t.Error("decoded private key does not match original") + } +} + +// TestParsePrivateKeyPEMErrors validates error cases. +func TestParsePrivateKeyPEMErrors(t *testing.T) { + // Empty input + if _, err := ParsePrivateKeyPEM([]byte{}); err == nil { + t.Error("expected error for empty PEM, got nil") + } + + // Wrong PEM type (using a fake RSA block header) + fakePEM := []byte("-----BEGIN RSA PRIVATE KEY-----\nYWJj\n-----END RSA PRIVATE KEY-----\n") + if _, err := ParsePrivateKeyPEM(fakePEM); err == nil { + t.Error("expected error for wrong PEM type, got nil") + } + + // Corrupt DER inside valid PEM block + corruptPEM := []byte("-----BEGIN PRIVATE KEY-----\nYWJj\n-----END PRIVATE KEY-----\n") + if _, err := ParsePrivateKeyPEM(corruptPEM); err == nil { + t.Error("expected error for corrupt DER, got nil") + } +} + +// TestSealOpenAESGCMRoundTrip verifies that sealed data can be opened. +func TestSealOpenAESGCMRoundTrip(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + plaintext := []byte("hello world secret data") + + ct, nonce, err := SealAESGCM(key, plaintext) + if err != nil { + t.Fatalf("SealAESGCM: %v", err) + } + if len(ct) == 0 || len(nonce) == 0 { + t.Fatal("SealAESGCM returned empty ciphertext or nonce") + } + + got, err := OpenAESGCM(key, nonce, ct) + if err != nil { + t.Fatalf("OpenAESGCM: %v", err) + } + if !bytes.Equal(got, plaintext) { + t.Errorf("decrypted = %q, want %q", got, plaintext) + } +} + +// TestSealNoncesAreUnique verifies that repeated seals produce different nonces. +func TestSealNoncesAreUnique(t *testing.T) { + key := make([]byte, 32) + plaintext := []byte("same plaintext") + + _, nonce1, err := SealAESGCM(key, plaintext) + if err != nil { + t.Fatalf("SealAESGCM (1): %v", err) + } + _, nonce2, err := SealAESGCM(key, plaintext) + if err != nil { + t.Fatalf("SealAESGCM (2): %v", err) + } + + if bytes.Equal(nonce1, nonce2) { + t.Error("two seals of the same plaintext produced identical nonces — crypto/rand may be broken") + } +} + +// TestOpenAESGCMWrongKey verifies that decryption with the wrong key fails. +func TestOpenAESGCMWrongKey(t *testing.T) { + key := make([]byte, 32) + wrongKey := make([]byte, 32) + wrongKey[0] = 0xFF + + ct, nonce, err := SealAESGCM(key, []byte("secret")) + if err != nil { + t.Fatalf("SealAESGCM: %v", err) + } + + if _, err := OpenAESGCM(wrongKey, nonce, ct); err == nil { + t.Error("expected error when opening with wrong key, got nil") + } +} + +// TestOpenAESGCMTamperedCiphertext verifies that tampering is detected. +func TestOpenAESGCMTamperedCiphertext(t *testing.T) { + key := make([]byte, 32) + ct, nonce, err := SealAESGCM(key, []byte("secret")) + if err != nil { + t.Fatalf("SealAESGCM: %v", err) + } + + // Flip one bit in the ciphertext. + ct[0] ^= 0x01 + if _, err := OpenAESGCM(key, nonce, ct); err == nil { + t.Error("expected error for tampered ciphertext, got nil") + } +} + +// TestOpenAESGCMWrongKeySize verifies that keys with wrong size are rejected. +func TestOpenAESGCMWrongKeySize(t *testing.T) { + if _, _, err := SealAESGCM([]byte("short"), []byte("data")); err == nil { + t.Error("expected error for short key in Seal, got nil") + } + if _, err := OpenAESGCM([]byte("short"), make([]byte, 12), []byte("data")); err == nil { + t.Error("expected error for short key in Open, got nil") + } +} + +// TestDeriveKey verifies that DeriveKey produces consistent, non-empty output. +func TestDeriveKey(t *testing.T) { + salt, err := NewSalt() + if err != nil { + t.Fatalf("NewSalt: %v", err) + } + + key1, err := DeriveKey("my-passphrase", salt) + if err != nil { + t.Fatalf("DeriveKey: %v", err) + } + if len(key1) != 32 { + t.Errorf("DeriveKey returned %d bytes, want 32", len(key1)) + } + + // Same inputs → same output (deterministic). + key2, err := DeriveKey("my-passphrase", salt) + if err != nil { + t.Fatalf("DeriveKey (2): %v", err) + } + if !bytes.Equal(key1, key2) { + t.Error("DeriveKey is not deterministic") + } + + // Different passphrase → different key. + key3, err := DeriveKey("different-passphrase", salt) + if err != nil { + t.Fatalf("DeriveKey (3): %v", err) + } + if bytes.Equal(key1, key3) { + t.Error("different passphrases produced the same key") + } + + // Different salt → different key. + salt2, err := NewSalt() + if err != nil { + t.Fatalf("NewSalt (2): %v", err) + } + key4, err := DeriveKey("my-passphrase", salt2) + if err != nil { + t.Fatalf("DeriveKey (4): %v", err) + } + if bytes.Equal(key1, key4) { + t.Error("different salts produced the same key") + } +} + +// TestDeriveKeyErrors verifies invalid input rejection. +func TestDeriveKeyErrors(t *testing.T) { + // Short salt + if _, err := DeriveKey("passphrase", []byte("short")); err == nil { + t.Error("expected error for short salt, got nil") + } + + // Empty passphrase + salt, _ := NewSalt() + if _, err := DeriveKey("", salt); err == nil { + t.Error("expected error for empty passphrase, got nil") + } +} + +// TestNewSaltUniqueness verifies that two salts are different. +func TestNewSaltUniqueness(t *testing.T) { + s1, err := NewSalt() + if err != nil { + t.Fatalf("NewSalt (1): %v", err) + } + s2, err := NewSalt() + if err != nil { + t.Fatalf("NewSalt (2): %v", err) + } + if bytes.Equal(s1, s2) { + t.Error("two NewSalt calls returned identical salts") + } +} + +// TestRandomBytes verifies length and uniqueness. +func TestRandomBytes(t *testing.T) { + b1, err := RandomBytes(32) + if err != nil { + t.Fatalf("RandomBytes: %v", err) + } + if len(b1) != 32 { + t.Errorf("RandomBytes returned %d bytes, want 32", len(b1)) + } + + b2, err := RandomBytes(32) + if err != nil { + t.Fatalf("RandomBytes (2): %v", err) + } + if bytes.Equal(b1, b2) { + t.Error("two RandomBytes calls returned identical values") + } +} diff --git a/internal/db/accounts.go b/internal/db/accounts.go new file mode 100644 index 0000000..99c9afc --- /dev/null +++ b/internal/db/accounts.go @@ -0,0 +1,608 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" + "time" + + "git.wntrmute.dev/kyle/mcias/internal/model" + "github.com/google/uuid" +) + +// CreateAccount inserts a new account record. The UUID is generated +// automatically. Returns the created Account with its DB-assigned ID and UUID. +func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) { + id := uuid.New().String() + n := now() + + result, err := db.sql.Exec(` + INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at) + VALUES (?, ?, ?, ?, 'active', ?, ?) + `, id, username, string(accountType), nullString(passwordHash), n, n) + if err != nil { + return nil, fmt.Errorf("db: create account %q: %w", username, err) + } + + rowID, err := result.LastInsertId() + if err != nil { + return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err) + } + + createdAt, err := parseTime(n) + if err != nil { + return nil, err + } + + return &model.Account{ + ID: rowID, + UUID: id, + Username: username, + AccountType: accountType, + Status: model.AccountStatusActive, + PasswordHash: passwordHash, + CreatedAt: createdAt, + UpdatedAt: createdAt, + }, nil +} + +// GetAccountByUUID retrieves an account by its external UUID. +// Returns ErrNotFound if no matching account exists. +func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) { + return db.scanAccount(db.sql.QueryRow(` + SELECT id, uuid, username, account_type, COALESCE(password_hash,''), + status, totp_required, + totp_secret_enc, totp_secret_nonce, + created_at, updated_at, deleted_at + FROM accounts WHERE uuid = ? + `, accountUUID)) +} + +// GetAccountByUsername retrieves an account by username (case-insensitive). +// Returns ErrNotFound if no matching account exists. +func (db *DB) GetAccountByUsername(username string) (*model.Account, error) { + return db.scanAccount(db.sql.QueryRow(` + SELECT id, uuid, username, account_type, COALESCE(password_hash,''), + status, totp_required, + totp_secret_enc, totp_secret_nonce, + created_at, updated_at, deleted_at + FROM accounts WHERE username = ? + `, username)) +} + +// ListAccounts returns all non-deleted accounts ordered by username. +func (db *DB) ListAccounts() ([]*model.Account, error) { + rows, err := db.sql.Query(` + SELECT id, uuid, username, account_type, COALESCE(password_hash,''), + status, totp_required, + totp_secret_enc, totp_secret_nonce, + created_at, updated_at, deleted_at + FROM accounts + WHERE status != 'deleted' + ORDER BY username ASC + `) + if err != nil { + return nil, fmt.Errorf("db: list accounts: %w", err) + } + defer rows.Close() + + var accounts []*model.Account + for rows.Next() { + a, err := db.scanAccountRow(rows) + if err != nil { + return nil, err + } + accounts = append(accounts, a) + } + return accounts, rows.Err() +} + +// UpdateAccountStatus updates the status field and optionally sets deleted_at. +func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error { + n := now() + var deletedAt *string + if status == model.AccountStatusDeleted { + deletedAt = &n + } + + _, err := db.sql.Exec(` + UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ? + WHERE id = ? + `, string(status), deletedAt, n, accountID) + if err != nil { + return fmt.Errorf("db: update account status: %w", err) + } + return nil +} + +// UpdatePasswordHash updates the Argon2id password hash for an account. +func (db *DB) UpdatePasswordHash(accountID int64, hash string) error { + _, err := db.sql.Exec(` + UPDATE accounts SET password_hash = ?, updated_at = ? + WHERE id = ? + `, hash, now(), accountID) + if err != nil { + return fmt.Errorf("db: update password hash: %w", err) + } + return nil +} + +// SetTOTP stores the encrypted TOTP secret and marks TOTP as required. +func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error { + _, err := db.sql.Exec(` + UPDATE accounts + SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ? + WHERE id = ? + `, secretEnc, secretNonce, now(), accountID) + if err != nil { + return fmt.Errorf("db: set TOTP: %w", err) + } + return nil +} + +// ClearTOTP removes the TOTP secret and disables TOTP requirement. +func (db *DB) ClearTOTP(accountID int64) error { + _, err := db.sql.Exec(` + UPDATE accounts + SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ? + WHERE id = ? + `, now(), accountID) + if err != nil { + return fmt.Errorf("db: clear TOTP: %w", err) + } + return nil +} + +// scanAccount scans a single account row from a *sql.Row. +func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) { + var a model.Account + var accountType, status string + var totpRequired int + var createdAtStr, updatedAtStr string + var deletedAtStr *string + var totpSecretEnc, totpSecretNonce []byte + + err := row.Scan( + &a.ID, &a.UUID, &a.Username, + &accountType, &a.PasswordHash, + &status, &totpRequired, + &totpSecretEnc, &totpSecretNonce, + &createdAtStr, &updatedAtStr, &deletedAtStr, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("db: scan account: %w", err) + } + + return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr) +} + +// scanAccountRow scans a single account from *sql.Rows. +func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) { + var a model.Account + var accountType, status string + var totpRequired int + var createdAtStr, updatedAtStr string + var deletedAtStr *string + var totpSecretEnc, totpSecretNonce []byte + + err := rows.Scan( + &a.ID, &a.UUID, &a.Username, + &accountType, &a.PasswordHash, + &status, &totpRequired, + &totpSecretEnc, &totpSecretNonce, + &createdAtStr, &updatedAtStr, &deletedAtStr, + ) + if err != nil { + return nil, fmt.Errorf("db: scan account row: %w", err) + } + + return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr) +} + +func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) { + a.AccountType = model.AccountType(accountType) + a.Status = model.AccountStatus(status) + a.TOTPRequired = totpRequired == 1 + a.TOTPSecretEnc = totpSecretEnc + a.TOTPSecretNonce = totpSecretNonce + + var err error + a.CreatedAt, err = parseTime(createdAtStr) + if err != nil { + return nil, err + } + a.UpdatedAt, err = parseTime(updatedAtStr) + if err != nil { + return nil, err + } + a.DeletedAt, err = nullableTime(deletedAtStr) + if err != nil { + return nil, err + } + return a, nil +} + +// nullString converts an empty string to nil for nullable SQL columns. +func nullString(s string) *string { + if s == "" { + return nil + } + return &s +} + +// GetRoles returns the role strings assigned to an account. +func (db *DB) GetRoles(accountID int64) ([]string, error) { + rows, err := db.sql.Query(` + SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC + `, accountID) + if err != nil { + return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err) + } + defer rows.Close() + + var roles []string + for rows.Next() { + var role string + if err := rows.Scan(&role); err != nil { + return nil, fmt.Errorf("db: scan role: %w", err) + } + roles = append(roles, role) + } + return roles, rows.Err() +} + +// GrantRole adds a role to an account. If the role already exists, it is a no-op. +func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error { + _, err := db.sql.Exec(` + INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at) + VALUES (?, ?, ?, ?) + `, accountID, role, grantedBy, now()) + if err != nil { + return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err) + } + return nil +} + +// RevokeRole removes a role from an account. +func (db *DB) RevokeRole(accountID int64, role string) error { + _, err := db.sql.Exec(` + DELETE FROM account_roles WHERE account_id = ? AND role = ? + `, accountID, role) + if err != nil { + return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err) + } + return nil +} + +// SetRoles replaces the full role set for an account atomically. +func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error { + tx, err := db.sql.Begin() + if err != nil { + return fmt.Errorf("db: set roles begin tx: %w", err) + } + + if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: set roles delete existing: %w", err) + } + + n := now() + for _, role := range roles { + if _, err := tx.Exec(` + INSERT INTO account_roles (account_id, role, granted_by, granted_at) + VALUES (?, ?, ?, ?) + `, accountID, role, grantedBy, n); err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: set roles insert %q: %w", role, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("db: set roles commit: %w", err) + } + return nil +} + +// HasRole reports whether an account holds the given role. +func (db *DB) HasRole(accountID int64, role string) (bool, error) { + var count int + err := db.sql.QueryRow(` + SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ? + `, accountID, role).Scan(&count) + if err != nil { + return false, fmt.Errorf("db: has role: %w", err) + } + return count > 0, nil +} + +// WriteServerConfig stores the encrypted Ed25519 signing key. +// There can only be one row (id=1). +func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error { + n := now() + _, err := db.sql.Exec(` + INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at) + VALUES (1, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + signing_key_enc = excluded.signing_key_enc, + signing_key_nonce = excluded.signing_key_nonce, + updated_at = excluded.updated_at + `, signingKeyEnc, signingKeyNonce, n, n) + if err != nil { + return fmt.Errorf("db: write server config: %w", err) + } + return nil +} + +// ReadServerConfig returns the encrypted signing key and nonce. +// Returns ErrNotFound if no config row exists yet. +func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) { + err = db.sql.QueryRow(` + SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1 + `).Scan(&signingKeyEnc, &signingKeyNonce) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil, ErrNotFound + } + if err != nil { + return nil, nil, fmt.Errorf("db: read server config: %w", err) + } + return signingKeyEnc, signingKeyNonce, nil +} + +// WriteMasterKeySalt stores the Argon2id KDF salt for the master key derivation. +// The salt must be stable across restarts so the same passphrase always yields +// the same master key. There can only be one row (id=1). +func (db *DB) WriteMasterKeySalt(salt []byte) error { + n := now() + _, err := db.sql.Exec(` + INSERT INTO server_config (id, master_key_salt, created_at, updated_at) + VALUES (1, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + master_key_salt = excluded.master_key_salt, + updated_at = excluded.updated_at + `, salt, n, n) + if err != nil { + return fmt.Errorf("db: write master key salt: %w", err) + } + return nil +} + +// ReadMasterKeySalt returns the stored Argon2id KDF salt. +// Returns ErrNotFound if no salt has been stored yet (first run). +func (db *DB) ReadMasterKeySalt() ([]byte, error) { + var salt []byte + err := db.sql.QueryRow(` + SELECT master_key_salt FROM server_config WHERE id = 1 + `).Scan(&salt) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("db: read master key salt: %w", err) + } + if salt == nil { + return nil, ErrNotFound + } + return salt, nil +} + +// WritePGCredentials stores or replaces the Postgres credentials for an account. +func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error { + n := now() + _, err := db.sql.Exec(` + INSERT INTO pg_credentials + (account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(account_id) DO UPDATE SET + pg_host = excluded.pg_host, + pg_port = excluded.pg_port, + pg_database = excluded.pg_database, + pg_username = excluded.pg_username, + pg_password_enc = excluded.pg_password_enc, + pg_password_nonce = excluded.pg_password_nonce, + updated_at = excluded.updated_at + `, accountID, host, port, dbName, username, passwordEnc, passwordNonce, n, n) + if err != nil { + return fmt.Errorf("db: write pg credentials: %w", err) + } + return nil +} + +// ReadPGCredentials retrieves the encrypted Postgres credentials for an account. +// Returns ErrNotFound if no credentials are stored. +func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) { + var cred model.PGCredential + var createdAtStr, updatedAtStr string + + err := db.sql.QueryRow(` + SELECT id, account_id, pg_host, pg_port, pg_database, pg_username, + pg_password_enc, pg_password_nonce, created_at, updated_at + FROM pg_credentials WHERE account_id = ? + `, accountID).Scan( + &cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort, + &cred.PGDatabase, &cred.PGUsername, + &cred.PGPasswordEnc, &cred.PGPasswordNonce, + &createdAtStr, &updatedAtStr, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("db: read pg credentials: %w", err) + } + + cred.CreatedAt, err = parseTime(createdAtStr) + if err != nil { + return nil, err + } + cred.UpdatedAt, err = parseTime(updatedAtStr) + if err != nil { + return nil, err + } + return &cred, nil +} + +// WriteAuditEvent appends an audit log entry. +// Details must never contain credential material. +func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error { + _, err := db.sql.Exec(` + INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details) + VALUES (?, ?, ?, ?, ?) + `, eventType, actorID, targetID, nullString(ipAddress), nullString(details)) + if err != nil { + return fmt.Errorf("db: write audit event %q: %w", eventType, err) + } + return nil +} + +// TrackToken records a newly issued JWT JTI for revocation tracking. +func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error { + _, err := db.sql.Exec(` + INSERT INTO token_revocation (jti, account_id, issued_at, expires_at) + VALUES (?, ?, ?, ?) + `, jti, accountID, issuedAt.UTC().Format(time.RFC3339), expiresAt.UTC().Format(time.RFC3339)) + if err != nil { + return fmt.Errorf("db: track token %q: %w", jti, err) + } + return nil +} + +// GetTokenRecord retrieves a token record by JTI. +// Returns ErrNotFound if no record exists (token was never issued by this server). +func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) { + var rec model.TokenRecord + var issuedAtStr, expiresAtStr, createdAtStr string + var revokedAtStr *string + var revokeReason *string + + err := db.sql.QueryRow(` + SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at + FROM token_revocation WHERE jti = ? + `, jti).Scan( + &rec.ID, &rec.JTI, &rec.AccountID, + &expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason, + &createdAtStr, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("db: get token record %q: %w", jti, err) + } + + var parseErr error + rec.ExpiresAt, parseErr = parseTime(expiresAtStr) + if parseErr != nil { + return nil, parseErr + } + rec.IssuedAt, parseErr = parseTime(issuedAtStr) + if parseErr != nil { + return nil, parseErr + } + rec.CreatedAt, parseErr = parseTime(createdAtStr) + if parseErr != nil { + return nil, parseErr + } + rec.RevokedAt, parseErr = nullableTime(revokedAtStr) + if parseErr != nil { + return nil, parseErr + } + if revokeReason != nil { + rec.RevokeReason = *revokeReason + } + return &rec, nil +} + +// RevokeToken marks a token as revoked by JTI. +func (db *DB) RevokeToken(jti, reason string) error { + n := now() + result, err := db.sql.Exec(` + UPDATE token_revocation + SET revoked_at = ?, revoke_reason = ? + WHERE jti = ? AND revoked_at IS NULL + `, n, nullString(reason), jti) + if err != nil { + return fmt.Errorf("db: revoke token %q: %w", jti, err) + } + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("db: revoke token rows affected: %w", err) + } + if rows == 0 { + return fmt.Errorf("db: token %q not found or already revoked", jti) + } + return nil +} + +// RevokeAllUserTokens revokes all non-expired, non-revoked tokens for an account. +func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error { + n := now() + _, err := db.sql.Exec(` + UPDATE token_revocation + SET revoked_at = ?, revoke_reason = ? + WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ? + `, n, nullString(reason), accountID, n) + if err != nil { + return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err) + } + return nil +} + +// PruneExpiredTokens removes token_revocation rows that are past their expiry. +// Returns the number of rows deleted. +func (db *DB) PruneExpiredTokens() (int64, error) { + result, err := db.sql.Exec(` + DELETE FROM token_revocation WHERE expires_at < ? + `, now()) + if err != nil { + return 0, fmt.Errorf("db: prune expired tokens: %w", err) + } + return result.RowsAffected() +} + +// SetSystemToken stores or replaces the active service token JTI for a system account. +func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error { + n := now() + _, err := db.sql.Exec(` + INSERT INTO system_tokens (account_id, jti, expires_at, created_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(account_id) DO UPDATE SET + jti = excluded.jti, + expires_at = excluded.expires_at, + created_at = excluded.created_at + `, accountID, jti, expiresAt.UTC().Format(time.RFC3339), n) + if err != nil { + return fmt.Errorf("db: set system token for account %d: %w", accountID, err) + } + return nil +} + +// GetSystemToken retrieves the active service token record for a system account. +func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) { + var st model.SystemToken + var expiresAtStr, createdAtStr string + + err := db.sql.QueryRow(` + SELECT id, account_id, jti, expires_at, created_at + FROM system_tokens WHERE account_id = ? + `, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expiresAtStr, &createdAtStr) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + if err != nil { + return nil, fmt.Errorf("db: get system token: %w", err) + } + + var parseErr error + st.ExpiresAt, parseErr = parseTime(expiresAtStr) + if parseErr != nil { + return nil, parseErr + } + st.CreatedAt, parseErr = parseTime(createdAtStr) + if parseErr != nil { + return nil, parseErr + } + return &st, nil +} diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 0000000..aea267b --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,109 @@ +// Package db provides the SQLite database access layer for MCIAS. +// +// Security design: +// - All queries use parameterized statements; no string concatenation. +// - Foreign keys are enforced (PRAGMA foreign_keys = ON). +// - WAL mode is enabled for safe concurrent reads during writes. +// - The audit log is append-only: no update or delete operations are provided. +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + _ "modernc.org/sqlite" // register the sqlite3 driver +) + +// DB wraps a *sql.DB with MCIAS-specific helpers. +type DB struct { + sql *sql.DB +} + +// Open opens (or creates) the SQLite database at path and configures it for +// MCIAS use (WAL mode, foreign keys, busy timeout). +func Open(path string) (*DB, error) { + // The modernc.org/sqlite driver is registered as "sqlite". + sqlDB, err := sql.Open("sqlite", path) + if err != nil { + return nil, fmt.Errorf("db: open sqlite: %w", err) + } + + // Use a single connection for writes; reads can use the pool. + sqlDB.SetMaxOpenConns(1) + + db := &DB{sql: sqlDB} + if err := db.configure(); err != nil { + _ = sqlDB.Close() + return nil, err + } + return db, nil +} + +// configure applies PRAGMAs that must be set on every connection. +func (db *DB) configure() error { + pragmas := []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA foreign_keys=ON", + "PRAGMA busy_timeout=5000", + "PRAGMA synchronous=NORMAL", + } + for _, p := range pragmas { + if _, err := db.sql.Exec(p); err != nil { + return fmt.Errorf("db: configure pragma %q: %w", p, err) + } + } + return nil +} + +// Close closes the underlying database connection. +func (db *DB) Close() error { + return db.sql.Close() +} + +// Ping verifies the database connection is alive. +func (db *DB) Ping(ctx context.Context) error { + return db.sql.PingContext(ctx) +} + +// SQL returns the underlying *sql.DB for use in tests or advanced queries. +// Prefer the typed methods on DB for all production code. +func (db *DB) SQL() *sql.DB { + return db.sql +} + +// now returns the current UTC time formatted as ISO-8601. +func now() string { + return time.Now().UTC().Format(time.RFC3339) +} + +// parseTime parses an ISO-8601 UTC time string returned by SQLite. +func parseTime(s string) (time.Time, error) { + t, err := time.Parse(time.RFC3339, s) + if err != nil { + // Try without timezone suffix (some SQLite defaults). + t, err = time.Parse("2006-01-02T15:04:05", s) + if err != nil { + return time.Time{}, fmt.Errorf("db: parse time %q: %w", s, err) + } + return t.UTC(), nil + } + return t.UTC(), nil +} + +// ErrNotFound is returned when a requested record does not exist. +var ErrNotFound = errors.New("db: record not found") + +// nullableTime converts a *string from SQLite into a *time.Time. +func nullableTime(s *string) (*time.Time, error) { + if s == nil { + return nil, nil + } + t, err := parseTime(*s) + if err != nil { + return nil, err + } + return &t, nil +} diff --git a/internal/db/db_test.go b/internal/db/db_test.go new file mode 100644 index 0000000..6fcd628 --- /dev/null +++ b/internal/db/db_test.go @@ -0,0 +1,355 @@ +package db + +import ( + "testing" + "time" + + "git.wntrmute.dev/kyle/mcias/internal/model" +) + +// openTestDB opens an in-memory SQLite database for testing. +func openTestDB(t *testing.T) *DB { + t.Helper() + db, err := Open(":memory:") + if err != nil { + t.Fatalf("Open: %v", err) + } + if err := Migrate(db); err != nil { + t.Fatalf("Migrate: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db +} + +func TestMigrateIdempotent(t *testing.T) { + db := openTestDB(t) + // Run again — should be a no-op. + if err := Migrate(db); err != nil { + t.Errorf("second Migrate call returned error: %v", err) + } +} + +func TestCreateAndGetAccount(t *testing.T) { + db := openTestDB(t) + + acct, err := db.CreateAccount("alice", model.AccountTypeHuman, "$argon2id$v=19$...") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + if acct.UUID == "" { + t.Error("expected non-empty UUID") + } + if acct.Username != "alice" { + t.Errorf("Username = %q, want %q", acct.Username, "alice") + } + if acct.Status != model.AccountStatusActive { + t.Errorf("Status = %q, want active", acct.Status) + } + + // Retrieve by UUID. + got, err := db.GetAccountByUUID(acct.UUID) + if err != nil { + t.Fatalf("GetAccountByUUID: %v", err) + } + if got.Username != "alice" { + t.Errorf("fetched Username = %q, want %q", got.Username, "alice") + } + + // Retrieve by username. + got2, err := db.GetAccountByUsername("alice") + if err != nil { + t.Fatalf("GetAccountByUsername: %v", err) + } + if got2.UUID != acct.UUID { + t.Errorf("UUID mismatch: got %q, want %q", got2.UUID, acct.UUID) + } +} + +func TestGetAccountNotFound(t *testing.T) { + db := openTestDB(t) + + _, err := db.GetAccountByUUID("nonexistent-uuid") + if err != ErrNotFound { + t.Errorf("expected ErrNotFound, got %v", err) + } + + _, err = db.GetAccountByUsername("nobody") + if err != ErrNotFound { + t.Errorf("expected ErrNotFound, got %v", err) + } +} + +func TestUpdateAccountStatus(t *testing.T) { + db := openTestDB(t) + acct, err := db.CreateAccount("bob", model.AccountTypeHuman, "hash") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + if err := db.UpdateAccountStatus(acct.ID, model.AccountStatusInactive); err != nil { + t.Fatalf("UpdateAccountStatus: %v", err) + } + + got, err := db.GetAccountByUUID(acct.UUID) + if err != nil { + t.Fatalf("GetAccountByUUID: %v", err) + } + if got.Status != model.AccountStatusInactive { + t.Errorf("Status = %q, want inactive", got.Status) + } +} + +func TestListAccounts(t *testing.T) { + db := openTestDB(t) + for _, name := range []string{"charlie", "delta", "eve"} { + if _, err := db.CreateAccount(name, model.AccountTypeHuman, "hash"); err != nil { + t.Fatalf("CreateAccount %q: %v", name, err) + } + } + + accts, err := db.ListAccounts() + if err != nil { + t.Fatalf("ListAccounts: %v", err) + } + if len(accts) != 3 { + t.Errorf("ListAccounts returned %d accounts, want 3", len(accts)) + } +} + +func TestRoleOperations(t *testing.T) { + db := openTestDB(t) + acct, err := db.CreateAccount("frank", model.AccountTypeHuman, "hash") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + // GrantRole + if err := db.GrantRole(acct.ID, "admin", nil); err != nil { + t.Fatalf("GrantRole: %v", err) + } + // Grant again — should be no-op. + if err := db.GrantRole(acct.ID, "admin", nil); err != nil { + t.Fatalf("GrantRole duplicate: %v", err) + } + + roles, err := db.GetRoles(acct.ID) + if err != nil { + t.Fatalf("GetRoles: %v", err) + } + if len(roles) != 1 || roles[0] != "admin" { + t.Errorf("GetRoles = %v, want [admin]", roles) + } + + has, err := db.HasRole(acct.ID, "admin") + if err != nil { + t.Fatalf("HasRole: %v", err) + } + if !has { + t.Error("expected HasRole to return true for 'admin'") + } + + // RevokeRole + if err := db.RevokeRole(acct.ID, "admin"); err != nil { + t.Fatalf("RevokeRole: %v", err) + } + roles, err = db.GetRoles(acct.ID) + if err != nil { + t.Fatalf("GetRoles after revoke: %v", err) + } + if len(roles) != 0 { + t.Errorf("expected no roles after revoke, got %v", roles) + } + + // SetRoles + if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil { + t.Fatalf("SetRoles: %v", err) + } + roles, err = db.GetRoles(acct.ID) + if err != nil { + t.Fatalf("GetRoles after SetRoles: %v", err) + } + if len(roles) != 2 { + t.Errorf("expected 2 roles after SetRoles, got %d", len(roles)) + } +} + +func TestTokenTrackingAndRevocation(t *testing.T) { + db := openTestDB(t) + acct, err := db.CreateAccount("grace", model.AccountTypeHuman, "hash") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + jti := "test-jti-1234" + issuedAt := time.Now().UTC() + expiresAt := issuedAt.Add(time.Hour) + + if err := db.TrackToken(jti, acct.ID, issuedAt, expiresAt); err != nil { + t.Fatalf("TrackToken: %v", err) + } + + // Retrieve + rec, err := db.GetTokenRecord(jti) + if err != nil { + t.Fatalf("GetTokenRecord: %v", err) + } + if rec.JTI != jti { + t.Errorf("JTI = %q, want %q", rec.JTI, jti) + } + if rec.IsRevoked() { + t.Error("newly tracked token should not be revoked") + } + + // Revoke + if err := db.RevokeToken(jti, "test revocation"); err != nil { + t.Fatalf("RevokeToken: %v", err) + } + rec, err = db.GetTokenRecord(jti) + if err != nil { + t.Fatalf("GetTokenRecord after revoke: %v", err) + } + if !rec.IsRevoked() { + t.Error("token should be revoked after RevokeToken") + } + + // Revoking again should fail (already revoked). + if err := db.RevokeToken(jti, "again"); err == nil { + t.Error("expected error when revoking already-revoked token") + } +} + +func TestGetTokenRecordNotFound(t *testing.T) { + db := openTestDB(t) + _, err := db.GetTokenRecord("no-such-jti") + if err != ErrNotFound { + t.Errorf("expected ErrNotFound, got %v", err) + } +} + +func TestPruneExpiredTokens(t *testing.T) { + db := openTestDB(t) + acct, err := db.CreateAccount("henry", model.AccountTypeHuman, "hash") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + past := time.Now().UTC().Add(-time.Hour) + future := time.Now().UTC().Add(time.Hour) + + if err := db.TrackToken("expired-jti", acct.ID, past.Add(-time.Hour), past); err != nil { + t.Fatalf("TrackToken expired: %v", err) + } + if err := db.TrackToken("valid-jti", acct.ID, time.Now(), future); err != nil { + t.Fatalf("TrackToken valid: %v", err) + } + + n, err := db.PruneExpiredTokens() + if err != nil { + t.Fatalf("PruneExpiredTokens: %v", err) + } + if n != 1 { + t.Errorf("pruned %d rows, want 1", n) + } + + // Valid token should still be retrievable. + if _, err := db.GetTokenRecord("valid-jti"); err != nil { + t.Errorf("valid token missing after prune: %v", err) + } +} + +func TestServerConfig(t *testing.T) { + db := openTestDB(t) + + // No config initially. + _, _, err := db.ReadServerConfig() + if err != ErrNotFound { + t.Errorf("expected ErrNotFound for missing config, got %v", err) + } + + enc := []byte("encrypted-key-data") + nonce := []byte("nonce12345678901") + + if err := db.WriteServerConfig(enc, nonce); err != nil { + t.Fatalf("WriteServerConfig: %v", err) + } + + gotEnc, gotNonce, err := db.ReadServerConfig() + if err != nil { + t.Fatalf("ReadServerConfig: %v", err) + } + if string(gotEnc) != string(enc) { + t.Errorf("enc mismatch: got %q, want %q", gotEnc, enc) + } + if string(gotNonce) != string(nonce) { + t.Errorf("nonce mismatch: got %q, want %q", gotNonce, nonce) + } + + // Overwrite — should work without error. + if err := db.WriteServerConfig([]byte("new-key"), []byte("new-nonce123456")); err != nil { + t.Fatalf("WriteServerConfig overwrite: %v", err) + } +} + +func TestForeignKeyEnforcement(t *testing.T) { + db := openTestDB(t) + // Attempting to track a token for a non-existent account should fail. + err := db.TrackToken("jti-x", 999999, time.Now(), time.Now().Add(time.Hour)) + if err == nil { + t.Error("expected foreign key error for non-existent account_id, got nil") + } +} + +func TestPGCredentials(t *testing.T) { + db := openTestDB(t) + acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + enc := []byte("encrypted-pg-password") + nonce := []byte("pg-nonce12345678") + + if err := db.WritePGCredentials(acct.ID, "localhost", 5432, "mydb", "myuser", enc, nonce); err != nil { + t.Fatalf("WritePGCredentials: %v", err) + } + + cred, err := db.ReadPGCredentials(acct.ID) + if err != nil { + t.Fatalf("ReadPGCredentials: %v", err) + } + if cred.PGHost != "localhost" { + t.Errorf("PGHost = %q, want %q", cred.PGHost, "localhost") + } + if cred.PGDatabase != "mydb" { + t.Errorf("PGDatabase = %q, want %q", cred.PGDatabase, "mydb") + } +} + +func TestRevokeAllUserTokens(t *testing.T) { + db := openTestDB(t) + acct, err := db.CreateAccount("ivan", model.AccountTypeHuman, "hash") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + future := time.Now().UTC().Add(time.Hour) + for _, jti := range []string{"tok1", "tok2", "tok3"} { + if err := db.TrackToken(jti, acct.ID, time.Now(), future); err != nil { + t.Fatalf("TrackToken %q: %v", jti, err) + } + } + + if err := db.RevokeAllUserTokens(acct.ID, "account suspended"); err != nil { + t.Fatalf("RevokeAllUserTokens: %v", err) + } + + for _, jti := range []string{"tok1", "tok2", "tok3"} { + rec, err := db.GetTokenRecord(jti) + if err != nil { + t.Fatalf("GetTokenRecord %q: %v", jti, err) + } + if !rec.IsRevoked() { + t.Errorf("token %q should be revoked", jti) + } + } +} diff --git a/internal/db/migrate.go b/internal/db/migrate.go new file mode 100644 index 0000000..d64ec38 --- /dev/null +++ b/internal/db/migrate.go @@ -0,0 +1,187 @@ +package db + +import ( + "database/sql" + "fmt" +) + +// migration represents a single schema migration with an ID and SQL statement. +type migration struct { + id int + sql string +} + +// migrations is the ordered list of schema migrations applied to the database. +// Once applied, migrations must never be modified — only new ones appended. +var migrations = []migration{ + { + id: 1, + sql: ` +CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS server_config ( + id INTEGER PRIMARY KEY CHECK (id = 1), + signing_key_enc BLOB, + signing_key_nonce BLOB, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) +); + +CREATE TABLE IF NOT EXISTS accounts ( + id INTEGER PRIMARY KEY, + uuid TEXT NOT NULL UNIQUE, + username TEXT NOT NULL UNIQUE COLLATE NOCASE, + account_type TEXT NOT NULL CHECK (account_type IN ('human','system')), + password_hash TEXT, + status TEXT NOT NULL DEFAULT 'active' + CHECK (status IN ('active','inactive','deleted')), + totp_required INTEGER NOT NULL DEFAULT 0 CHECK (totp_required IN (0,1)), + totp_secret_enc BLOB, + totp_secret_nonce BLOB, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + deleted_at TEXT +); + +CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts (username); +CREATE INDEX IF NOT EXISTS idx_accounts_uuid ON accounts (uuid); +CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts (status); + +CREATE TABLE IF NOT EXISTS account_roles ( + id INTEGER PRIMARY KEY, + account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE, + role TEXT NOT NULL, + granted_by INTEGER REFERENCES accounts(id), + granted_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + UNIQUE (account_id, role) +); + +CREATE INDEX IF NOT EXISTS idx_account_roles_account ON account_roles (account_id); + +CREATE TABLE IF NOT EXISTS token_revocation ( + id INTEGER PRIMARY KEY, + jti TEXT NOT NULL UNIQUE, + account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE, + expires_at TEXT NOT NULL, + revoked_at TEXT, + revoke_reason TEXT, + issued_at TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) +); + +CREATE INDEX IF NOT EXISTS idx_token_jti ON token_revocation (jti); +CREATE INDEX IF NOT EXISTS idx_token_account ON token_revocation (account_id); +CREATE INDEX IF NOT EXISTS idx_token_expires ON token_revocation (expires_at); + +CREATE TABLE IF NOT EXISTS system_tokens ( + id INTEGER PRIMARY KEY, + account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE, + jti TEXT NOT NULL UNIQUE, + expires_at TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) +); + +CREATE TABLE IF NOT EXISTS pg_credentials ( + id INTEGER PRIMARY KEY, + account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE, + pg_host TEXT NOT NULL, + pg_port INTEGER NOT NULL DEFAULT 5432, + pg_database TEXT NOT NULL, + pg_username TEXT NOT NULL, + pg_password_enc BLOB NOT NULL, + pg_password_nonce BLOB NOT NULL, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) +); + +CREATE TABLE IF NOT EXISTS audit_log ( + id INTEGER PRIMARY KEY, + event_time TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + event_type TEXT NOT NULL, + actor_id INTEGER REFERENCES accounts(id), + target_id INTEGER REFERENCES accounts(id), + ip_address TEXT, + details TEXT +); + +CREATE INDEX IF NOT EXISTS idx_audit_time ON audit_log (event_time); +CREATE INDEX IF NOT EXISTS idx_audit_actor ON audit_log (actor_id); +CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log (event_type); +`, + }, + { + id: 2, + sql: ` +-- Add master_key_salt to server_config for Argon2id KDF salt storage. +-- The salt must be stable across restarts so the passphrase always yields the same key. +-- We allow NULL signing_key_enc/nonce temporarily until the first signing key is generated. +ALTER TABLE server_config ADD COLUMN master_key_salt BLOB; +`, + }, +} + +// Migrate applies any unapplied schema migrations to the database in order. +// It is idempotent: running it multiple times is safe. +func Migrate(db *DB) error { + // Ensure the schema_version table exists first. + if _, err := db.sql.Exec(` + CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER NOT NULL + ) + `); err != nil { + return fmt.Errorf("db: ensure schema_version: %w", err) + } + + currentVersion, err := currentSchemaVersion(db.sql) + if err != nil { + return fmt.Errorf("db: get current schema version: %w", err) + } + + for _, m := range migrations { + if m.id <= currentVersion { + continue + } + + tx, err := db.sql.Begin() + if err != nil { + return fmt.Errorf("db: begin migration %d transaction: %w", m.id, err) + } + + if _, err := tx.Exec(m.sql); err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: apply migration %d: %w", m.id, err) + } + + // Update the schema version within the same transaction. + if currentVersion == 0 { + if _, err := tx.Exec(`INSERT INTO schema_version (version) VALUES (?)`, m.id); err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: insert schema version %d: %w", m.id, err) + } + } else { + if _, err := tx.Exec(`UPDATE schema_version SET version = ?`, m.id); err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: update schema version to %d: %w", m.id, err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("db: commit migration %d: %w", m.id, err) + } + currentVersion = m.id + } + return nil +} + +// currentSchemaVersion returns the current schema version, or 0 if none applied. +func currentSchemaVersion(db *sql.DB) (int, error) { + var version int + err := db.QueryRow(`SELECT version FROM schema_version LIMIT 1`).Scan(&version) + if err != nil { + // No rows means version 0 (fresh database). + return 0, nil //nolint:nilerr + } + return version, nil +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go new file mode 100644 index 0000000..4578433 --- /dev/null +++ b/internal/middleware/middleware.go @@ -0,0 +1,290 @@ +// Package middleware provides HTTP middleware for the MCIAS server. +// +// Security design: +// - RequireAuth extracts the Bearer token from the Authorization header, +// validates it (alg check, signature, expiry, issuer), and checks revocation +// against the database before injecting claims into the request context. +// - RequireRole checks claims from context for the required role. +// No role implies no access; the check fails closed. +// - RateLimit implements a per-IP token bucket to limit login brute-force. +// - RequestLogger logs request metadata but never logs the Authorization +// header value (which contains credential tokens). +package middleware + +import ( + "context" + "crypto/ed25519" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net" + "net/http" + "strings" + "sync" + "time" + + "git.wntrmute.dev/kyle/mcias/internal/db" + "git.wntrmute.dev/kyle/mcias/internal/token" +) + +// contextKey is the unexported type for context keys in this package, preventing +// collisions with keys from other packages. +type contextKey int + +const ( + claimsKey contextKey = iota +) + +// ClaimsFromContext retrieves the validated JWT claims from the request context. +// Returns nil if no claims are present (unauthenticated request). +func ClaimsFromContext(ctx context.Context) *token.Claims { + c, _ := ctx.Value(claimsKey).(*token.Claims) + return c +} + +// RequestLogger returns middleware that logs each request at INFO level. +// The Authorization header is intentionally never logged. +func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + // Wrap the ResponseWriter to capture the status code. + rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(rw, r) + + logger.Info("request", + "method", r.Method, + "path", r.URL.Path, + "status", rw.status, + "duration_ms", time.Since(start).Milliseconds(), + "remote_addr", r.RemoteAddr, + "user_agent", r.UserAgent(), + // Security: Authorization header is never logged. + ) + }) + } +} + +// responseWriter wraps http.ResponseWriter to capture the status code. +type responseWriter struct { + http.ResponseWriter + status int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.status = code + rw.ResponseWriter.WriteHeader(code) +} + +// RequireAuth returns middleware that validates a Bearer JWT and injects the +// claims into the request context. Returns 401 on any auth failure. +// +// Security: Token validation order: +// 1. Extract Bearer token from Authorization header. +// 2. Validate the JWT (alg=EdDSA, signature, expiry, issuer). +// 3. Check the JTI against the revocation table in the database. +// 4. Inject validated claims into context for downstream handlers. +func RequireAuth(pubKey ed25519.PublicKey, database *db.DB, issuer string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenStr, err := extractBearerToken(r) + if err != nil { + writeError(w, http.StatusUnauthorized, "missing or malformed Authorization header", "unauthorized") + return + } + + claims, err := token.ValidateToken(pubKey, tokenStr, issuer) + if err != nil { + // Security: Map all token errors to a generic 401; do not + // reveal which specific check failed. + if errors.Is(err, token.ErrExpiredToken) { + writeError(w, http.StatusUnauthorized, "token expired", "token_expired") + return + } + writeError(w, http.StatusUnauthorized, "invalid token", "unauthorized") + return + } + + // Security: Check revocation table. A token may be cryptographically + // valid but explicitly revoked (logout, account suspension, etc.). + rec, err := database.GetTokenRecord(claims.JTI) + if err != nil { + if errors.Is(err, db.ErrNotFound) { + // Token not tracked — could be from a different server instance + // or pre-dates tracking. Reject to be safe (fail closed). + writeError(w, http.StatusUnauthorized, "unrecognized token", "unauthorized") + return + } + writeError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + if rec.IsRevoked() { + writeError(w, http.StatusUnauthorized, "token has been revoked", "token_revoked") + return + } + + ctx := context.WithValue(r.Context(), claimsKey, claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// RequireRole returns middleware that checks whether the authenticated user has +// the given role. Must be used after RequireAuth. Returns 403 if role is absent. +func RequireRole(role string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims := ClaimsFromContext(r.Context()) + if claims == nil { + // RequireAuth was not applied upstream; fail closed. + writeError(w, http.StatusForbidden, "forbidden", "forbidden") + return + } + if !claims.HasRole(role) { + writeError(w, http.StatusForbidden, "insufficient privileges", "forbidden") + return + } + next.ServeHTTP(w, r) + }) + } +} + +// rateLimitEntry holds the token bucket state for a single IP. +type rateLimitEntry struct { + tokens float64 + lastSeen time.Time + mu sync.Mutex +} + +// ipRateLimiter implements a per-IP token bucket rate limiter. +type ipRateLimiter struct { + rps float64 // refill rate: tokens per second + burst float64 // bucket capacity + ttl time.Duration // how long to keep idle entries + mu sync.Mutex + ips map[string]*rateLimitEntry +} + +// RateLimit returns middleware implementing a per-IP token bucket. +// rps is the sustained request rate (tokens refilled per second). +// burst is the maximum burst size (initial and maximum token count). +// +// Security: Rate limiting is applied at the IP level. In production, the +// server should be behind a reverse proxy that sets X-Forwarded-For; this +// middleware uses RemoteAddr directly which may be the proxy IP. For single- +// instance deployment without a proxy, RemoteAddr is the client IP. +func RateLimit(rps float64, burst int) func(http.Handler) http.Handler { + limiter := &ipRateLimiter{ + rps: rps, + burst: float64(burst), + ttl: 10 * time.Minute, + ips: make(map[string]*rateLimitEntry), + } + + // Background cleanup of idle entries to prevent unbounded memory growth. + go limiter.cleanup() + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + ip = r.RemoteAddr + } + + if !limiter.allow(ip) { + w.Header().Set("Retry-After", "60") + writeError(w, http.StatusTooManyRequests, "rate limit exceeded", "rate_limited") + return + } + next.ServeHTTP(w, r) + }) + } +} + +// allow returns true if a request from ip is permitted under the rate limit. +func (l *ipRateLimiter) allow(ip string) bool { + l.mu.Lock() + entry, ok := l.ips[ip] + if !ok { + entry = &rateLimitEntry{tokens: l.burst, lastSeen: time.Now()} + l.ips[ip] = entry + } + l.mu.Unlock() + + entry.mu.Lock() + defer entry.mu.Unlock() + + now := time.Now() + elapsed := now.Sub(entry.lastSeen).Seconds() + entry.tokens = min(l.burst, entry.tokens+elapsed*l.rps) + entry.lastSeen = now + + if entry.tokens < 1 { + return false + } + entry.tokens-- + return true +} + +// cleanup periodically removes idle rate-limit entries. +func (l *ipRateLimiter) cleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + l.mu.Lock() + cutoff := time.Now().Add(-l.ttl) + for ip, entry := range l.ips { + entry.mu.Lock() + if entry.lastSeen.Before(cutoff) { + delete(l.ips, ip) + } + entry.mu.Unlock() + } + l.mu.Unlock() + } +} + +// extractBearerToken extracts the token from "Authorization: Bearer ". +func extractBearerToken(r *http.Request) (string, error) { + auth := r.Header.Get("Authorization") + if auth == "" { + return "", fmt.Errorf("missing Authorization header") + } + parts := strings.SplitN(auth, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return "", fmt.Errorf("malformed Authorization header") + } + if parts[1] == "" { + return "", fmt.Errorf("empty Bearer token") + } + return parts[1], nil +} + +// apiError is the uniform error response structure. +type apiError struct { + Error string `json:"error"` + Code string `json:"code"` +} + +// writeError writes a JSON error response. +func writeError(w http.ResponseWriter, status int, message, code string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + // Intentionally ignoring the error here; if the write fails, the client + // already got the status code. + _ = json.NewEncoder(w).Encode(apiError{Error: message, Code: code}) +} + +// WriteError is the exported version for use by handler packages. +func WriteError(w http.ResponseWriter, status int, message, code string) { + writeError(w, status, message, code) +} + +// min returns the smaller of two float64 values. +func min(a, b float64) float64 { + if a < b { + return a + } + return b +} diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go new file mode 100644 index 0000000..f4c0ed4 --- /dev/null +++ b/internal/middleware/middleware_test.go @@ -0,0 +1,342 @@ +package middleware + +import ( + "bytes" + "context" + "crypto/ed25519" + "crypto/rand" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "git.wntrmute.dev/kyle/mcias/internal/db" + "git.wntrmute.dev/kyle/mcias/internal/model" + "git.wntrmute.dev/kyle/mcias/internal/token" +) + +func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) { + t.Helper() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generate test key: %v", err) + } + return pub, priv +} + +func openTestDB(t *testing.T) *db.DB { + t.Helper() + database, err := db.Open(":memory:") + if err != nil { + t.Fatalf("open test db: %v", err) + } + if err := db.Migrate(database); err != nil { + t.Fatalf("migrate test db: %v", err) + } + t.Cleanup(func() { _ = database.Close() }) + return database +} + +const testIssuer = "https://auth.example.com" + +// issueAndTrackToken creates a valid JWT and records it in the DB. +func issueAndTrackToken(t *testing.T, priv ed25519.PrivateKey, database *db.DB, accountID int64, roles []string) string { + t.Helper() + tokenStr, claims, err := token.IssueToken(priv, testIssuer, "user-uuid", roles, time.Hour) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + if err := database.TrackToken(claims.JTI, accountID, claims.IssuedAt, claims.ExpiresAt); err != nil { + t.Fatalf("TrackToken: %v", err) + } + return tokenStr +} + +func TestRequestLogger(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + + handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/health", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("status = %d, want 200", rr.Code) + } + logOutput := buf.String() + if logOutput == "" { + t.Error("expected log output, got empty string") + } + // Security: Authorization header must not appear in logs. + req2 := httptest.NewRequest(http.MethodGet, "/v1/health", nil) + req2.Header.Set("Authorization", "Bearer secret-token-value") + buf.Reset() + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req2) + if bytes.Contains(buf.Bytes(), []byte("secret-token-value")) { + t.Error("log output contains Authorization token value — credential leak!") + } +} + +func TestRequireAuthValid(t *testing.T) { + pub, priv := generateTestKey(t) + database := openTestDB(t) + + acct, err := database.CreateAccount("alice", model.AccountTypeHuman, "hash") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + tokenStr := issueAndTrackToken(t, priv, database, acct.ID, []string{"reader"}) + + reached := false + handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reached = true + claims := ClaimsFromContext(r.Context()) + if claims == nil { + t.Error("claims not in context") + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/test", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("status = %d, want 200; body: %s", rr.Code, rr.Body.String()) + } + if !reached { + t.Error("handler was not reached with valid token") + } +} + +func TestRequireAuthMissingHeader(t *testing.T) { + pub, priv := generateTestKey(t) + _ = priv + database := openTestDB(t) + + handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be reached without auth") + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/test", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", rr.Code) + } +} + +func TestRequireAuthInvalidToken(t *testing.T) { + pub, _ := generateTestKey(t) + database := openTestDB(t) + + handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be reached with invalid token") + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/test", nil) + req.Header.Set("Authorization", "Bearer not.a.valid.jwt") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", rr.Code) + } +} + +func TestRequireAuthRevokedToken(t *testing.T) { + pub, priv := generateTestKey(t) + database := openTestDB(t) + + acct, err := database.CreateAccount("bob", model.AccountTypeHuman, "hash") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + tokenStr := issueAndTrackToken(t, priv, database, acct.ID, nil) + + // Extract JTI and revoke the token. + claims, err := token.ValidateToken(pub, tokenStr, testIssuer) + if err != nil { + t.Fatalf("ValidateToken: %v", err) + } + if err := database.RevokeToken(claims.JTI, "test revocation"); err != nil { + t.Fatalf("RevokeToken: %v", err) + } + + handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be reached with revoked token") + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/test", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", rr.Code) + } +} + +func TestRequireAuthExpiredToken(t *testing.T) { + pub, priv := generateTestKey(t) + database := openTestDB(t) + + // Issue an already-expired token. + tokenStr, _, err := token.IssueToken(priv, testIssuer, "user-uuid", nil, -time.Minute) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be reached with expired token") + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/test", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", rr.Code) + } +} + +func TestRequireRoleGranted(t *testing.T) { + claims := &token.Claims{Roles: []string{"admin"}} + ctx := context.WithValue(context.Background(), claimsKey, claims) + + reached := false + handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reached = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("status = %d, want 200", rr.Code) + } + if !reached { + t.Error("handler not reached with correct role") + } +} + +func TestRequireRoleForbidden(t *testing.T) { + claims := &token.Claims{Roles: []string{"reader"}} + ctx := context.WithValue(context.Background(), claimsKey, claims) + + handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be reached without admin role") + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Errorf("status = %d, want 403", rr.Code) + } +} + +func TestRequireRoleNoClaims(t *testing.T) { + handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be reached without claims in context") + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Errorf("status = %d, want 403", rr.Code) + } +} + +func TestRateLimitAllows(t *testing.T) { + handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil) + req.RemoteAddr = "127.0.0.1:12345" + + // First 5 requests should be allowed (burst=5). + for i := range 5 { + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("request %d: status = %d, want 200", i+1, rr.Code) + } + } +} + +func TestRateLimitBlocks(t *testing.T) { + handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil) + req.RemoteAddr = "10.0.0.1:9999" + + // Exhaust the burst of 2. + for range 2 { + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + } + + // Next request should be rate-limited. + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusTooManyRequests { + t.Errorf("status = %d, want 429 after burst exceeded", rr.Code) + } +} + +func TestExtractBearerToken(t *testing.T) { + tests := []struct { + name string + header string + wantErr bool + want string + }{ + {"valid", "Bearer mytoken123", false, "mytoken123"}, + {"missing header", "", true, ""}, + {"no bearer prefix", "Token mytoken123", true, ""}, + {"empty token", "Bearer ", true, ""}, + {"case insensitive", "bearer mytoken123", false, "mytoken123"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.header != "" { + req.Header.Set("Authorization", tc.header) + } + got, err := extractBearerToken(req) + if (err != nil) != tc.wantErr { + t.Errorf("wantErr=%v, got err=%v", tc.wantErr, err) + } + if !tc.wantErr && got != tc.want { + t.Errorf("token = %q, want %q", got, tc.want) + } + }) + } +} diff --git a/internal/model/model.go b/internal/model/model.go new file mode 100644 index 0000000..64f030e --- /dev/null +++ b/internal/model/model.go @@ -0,0 +1,144 @@ +// Package model defines the shared data types used throughout MCIAS. +// These are pure data definitions with no external dependencies. +package model + +import "time" + +// AccountType distinguishes human interactive accounts from non-interactive +// service accounts. +type AccountType string + +const ( + AccountTypeHuman AccountType = "human" + AccountTypeSystem AccountType = "system" +) + +// AccountStatus represents the lifecycle state of an account. +type AccountStatus string + +const ( + AccountStatusActive AccountStatus = "active" + AccountStatusInactive AccountStatus = "inactive" + AccountStatusDeleted AccountStatus = "deleted" +) + +// Account represents a user or service identity in MCIAS. +// Fields containing credential material (PasswordHash, TOTPSecretEnc) are +// never serialised into API responses — callers must explicitly omit them. +type Account struct { + ID int64 `json:"-"` + UUID string `json:"id"` + Username string `json:"username"` + AccountType AccountType `json:"account_type"` + Status AccountStatus `json:"status"` + TOTPRequired bool `json:"totp_required"` + + // PasswordHash is a PHC-format Argon2id string. Never returned in API + // responses; populated only when reading from the database. + PasswordHash string `json:"-"` + + // TOTPSecretEnc and TOTPSecretNonce hold the AES-256-GCM-encrypted TOTP + // shared secret. Never returned in API responses. + TOTPSecretEnc []byte `json:"-"` + TOTPSecretNonce []byte `json:"-"` + + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` +} + +// Role is a string label assigned to an account to grant permissions. +type Role struct { + ID int64 `json:"-"` + AccountID int64 `json:"-"` + Role string `json:"role"` + GrantedBy *int64 `json:"-"` + GrantedAt time.Time `json:"granted_at"` +} + +// TokenRecord tracks an issued JWT by its JTI for revocation purposes. +// The raw token string is never stored — only the JTI identifier. +type TokenRecord struct { + ID int64 `json:"-"` + JTI string `json:"jti"` + AccountID int64 `json:"-"` + ExpiresAt time.Time `json:"expires_at"` + IssuedAt time.Time `json:"issued_at"` + RevokedAt *time.Time `json:"revoked_at,omitempty"` + RevokeReason string `json:"revoke_reason,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// IsRevoked reports whether the token has been explicitly revoked. +func (t *TokenRecord) IsRevoked() bool { + return t.RevokedAt != nil +} + +// IsExpired reports whether the token is past its expiry time. +func (t *TokenRecord) IsExpired() bool { + return time.Now().After(t.ExpiresAt) +} + +// SystemToken represents the current active service token for a system account. +type SystemToken struct { + ID int64 `json:"-"` + AccountID int64 `json:"-"` + JTI string `json:"jti"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` +} + +// PGCredential holds Postgres connection details for a system account. +// The password is encrypted at rest; PGPassword is only populated after +// decryption and must never be logged or included in API responses. +type PGCredential struct { + ID int64 `json:"-"` + AccountID int64 `json:"-"` + PGHost string `json:"host"` + PGPort int `json:"port"` + PGDatabase string `json:"database"` + PGUsername string `json:"username"` + + // PGPassword is plaintext only after decryption. Never log or serialise. + PGPassword string `json:"-"` + + // PGPasswordEnc and PGPasswordNonce are the AES-256-GCM ciphertext and + // nonce stored in the database. + PGPasswordEnc []byte `json:"-"` + PGPasswordNonce []byte `json:"-"` + + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// AuditEvent represents a single entry in the append-only audit log. +// Details must never contain credential material (passwords, tokens, secrets). +type AuditEvent struct { + ID int64 `json:"id"` + EventTime time.Time `json:"event_time"` + EventType string `json:"event_type"` + ActorID *int64 `json:"-"` + TargetID *int64 `json:"-"` + IPAddress string `json:"ip_address,omitempty"` + Details string `json:"details,omitempty"` // JSON string; no secrets +} + +// Audit event type constants — exhaustive list, enforced at write time. +const ( + EventLoginOK = "login_ok" + EventLoginFail = "login_fail" + EventLoginTOTPFail = "login_totp_fail" + EventTokenIssued = "token_issued" + EventTokenRenewed = "token_renewed" + EventTokenRevoked = "token_revoked" + EventTokenExpired = "token_expired" + EventAccountCreated = "account_created" + EventAccountUpdated = "account_updated" + EventAccountDeleted = "account_deleted" + EventRoleGranted = "role_granted" + EventRoleRevoked = "role_revoked" + EventTOTPEnrolled = "totp_enrolled" + EventTOTPRemoved = "totp_removed" + EventPGCredAccessed = "pgcred_accessed" + EventPGCredUpdated = "pgcred_updated" +) diff --git a/internal/model/model_test.go b/internal/model/model_test.go new file mode 100644 index 0000000..ac4dd47 --- /dev/null +++ b/internal/model/model_test.go @@ -0,0 +1,83 @@ +package model + +import ( + "testing" + "time" +) + +func TestAccountTypeConstants(t *testing.T) { + if AccountTypeHuman != "human" { + t.Errorf("AccountTypeHuman = %q, want %q", AccountTypeHuman, "human") + } + if AccountTypeSystem != "system" { + t.Errorf("AccountTypeSystem = %q, want %q", AccountTypeSystem, "system") + } +} + +func TestAccountStatusConstants(t *testing.T) { + if AccountStatusActive != "active" { + t.Errorf("AccountStatusActive = %q, want %q", AccountStatusActive, "active") + } + if AccountStatusInactive != "inactive" { + t.Errorf("AccountStatusInactive = %q, want %q", AccountStatusInactive, "inactive") + } + if AccountStatusDeleted != "deleted" { + t.Errorf("AccountStatusDeleted = %q, want %q", AccountStatusDeleted, "deleted") + } +} + +func TestTokenRecordIsRevoked(t *testing.T) { + now := time.Now() + + notRevoked := &TokenRecord{} + if notRevoked.IsRevoked() { + t.Error("expected token with nil RevokedAt to not be revoked") + } + + revoked := &TokenRecord{RevokedAt: &now} + if !revoked.IsRevoked() { + t.Error("expected token with RevokedAt set to be revoked") + } +} + +func TestTokenRecordIsExpired(t *testing.T) { + past := time.Now().Add(-time.Hour) + future := time.Now().Add(time.Hour) + + expired := &TokenRecord{ExpiresAt: past} + if !expired.IsExpired() { + t.Error("expected token with past ExpiresAt to be expired") + } + + valid := &TokenRecord{ExpiresAt: future} + if valid.IsExpired() { + t.Error("expected token with future ExpiresAt to not be expired") + } +} + +func TestAuditEventConstants(t *testing.T) { + // Spot-check a few to ensure they are not empty strings. + events := []string{ + EventLoginOK, + EventLoginFail, + EventLoginTOTPFail, + EventTokenIssued, + EventTokenRenewed, + EventTokenRevoked, + EventTokenExpired, + EventAccountCreated, + EventAccountUpdated, + EventAccountDeleted, + EventRoleGranted, + EventRoleRevoked, + EventTOTPEnrolled, + EventTOTPRemoved, + EventPGCredAccessed, + EventPGCredUpdated, + } + for _, e := range events { + if e == "" { + t.Errorf("audit event constant is empty string") + } + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..6c71428 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,904 @@ +// Package server wires together the HTTP router, middleware, and handlers +// for the MCIAS authentication server. +// +// Security design: +// - All endpoints use HTTPS (enforced at the listener level in cmd/mciassrv). +// - Authentication state is carried via JWT; no cookies or server-side sessions. +// - Credential fields (password hash, TOTP secret, Postgres password) are +// never included in any API response. +// - All JSON parsing uses strict decoders that reject unknown fields. +package server + +import ( + "crypto/ed25519" + "encoding/json" + "fmt" + "log/slog" + "net/http" + + "git.wntrmute.dev/kyle/mcias/internal/auth" + "git.wntrmute.dev/kyle/mcias/internal/config" + "git.wntrmute.dev/kyle/mcias/internal/crypto" + "git.wntrmute.dev/kyle/mcias/internal/db" + "git.wntrmute.dev/kyle/mcias/internal/middleware" + "git.wntrmute.dev/kyle/mcias/internal/model" + "git.wntrmute.dev/kyle/mcias/internal/token" +) + +// Server holds the dependencies injected into all handlers. +type Server struct { + db *db.DB + cfg *config.Config + privKey ed25519.PrivateKey + pubKey ed25519.PublicKey + masterKey []byte + logger *slog.Logger +} + +// New creates a Server with the given dependencies. +func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed25519.PublicKey, masterKey []byte, logger *slog.Logger) *Server { + return &Server{ + db: database, + cfg: cfg, + privKey: priv, + pubKey: pub, + masterKey: masterKey, + logger: logger, + } +} + +// Handler builds and returns the root HTTP handler with all routes and middleware. +func (s *Server) Handler() http.Handler { + mux := http.NewServeMux() + + // Public endpoints (no authentication required). + mux.HandleFunc("GET /v1/health", s.handleHealth) + mux.HandleFunc("GET /v1/keys/public", s.handlePublicKey) + mux.HandleFunc("POST /v1/auth/login", s.handleLogin) + mux.HandleFunc("POST /v1/token/validate", s.handleTokenValidate) + + // Authenticated endpoints. + requireAuth := middleware.RequireAuth(s.pubKey, s.db, s.cfg.Tokens.Issuer) + requireAdmin := func(h http.Handler) http.Handler { + return requireAuth(middleware.RequireRole("admin")(h)) + } + + // Auth endpoints (require valid token). + mux.Handle("POST /v1/auth/logout", requireAuth(http.HandlerFunc(s.handleLogout))) + mux.Handle("POST /v1/auth/renew", requireAuth(http.HandlerFunc(s.handleRenew))) + mux.Handle("POST /v1/auth/totp/enroll", requireAuth(http.HandlerFunc(s.handleTOTPEnroll))) + mux.Handle("POST /v1/auth/totp/confirm", requireAuth(http.HandlerFunc(s.handleTOTPConfirm))) + + // Admin-only endpoints. + mux.Handle("DELETE /v1/auth/totp", requireAdmin(http.HandlerFunc(s.handleTOTPRemove))) + mux.Handle("POST /v1/token/issue", requireAdmin(http.HandlerFunc(s.handleTokenIssue))) + mux.Handle("DELETE /v1/token/{jti}", requireAdmin(http.HandlerFunc(s.handleTokenRevoke))) + mux.Handle("GET /v1/accounts", requireAdmin(http.HandlerFunc(s.handleListAccounts))) + mux.Handle("POST /v1/accounts", requireAdmin(http.HandlerFunc(s.handleCreateAccount))) + mux.Handle("GET /v1/accounts/{id}", requireAdmin(http.HandlerFunc(s.handleGetAccount))) + mux.Handle("PATCH /v1/accounts/{id}", requireAdmin(http.HandlerFunc(s.handleUpdateAccount))) + mux.Handle("DELETE /v1/accounts/{id}", requireAdmin(http.HandlerFunc(s.handleDeleteAccount))) + mux.Handle("GET /v1/accounts/{id}/roles", requireAdmin(http.HandlerFunc(s.handleGetRoles))) + mux.Handle("PUT /v1/accounts/{id}/roles", requireAdmin(http.HandlerFunc(s.handleSetRoles))) + mux.Handle("GET /v1/accounts/{id}/pgcreds", requireAdmin(http.HandlerFunc(s.handleGetPGCreds))) + mux.Handle("PUT /v1/accounts/{id}/pgcreds", requireAdmin(http.HandlerFunc(s.handleSetPGCreds))) + + // Apply global middleware: logging and login-path rate limiting. + var root http.Handler = mux + root = middleware.RequestLogger(s.logger)(root) + return root +} + +// ---- Public handlers ---- + +func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} + +// handlePublicKey returns the server's Ed25519 public key in JWK format. +// This allows relying parties to independently verify JWTs. +func (s *Server) handlePublicKey(w http.ResponseWriter, r *http.Request) { + // Encode the Ed25519 public key as a JWK (RFC 8037). + // The "x" parameter is the base64url-encoded public key bytes. + jwk := map[string]string{ + "kty": "OKP", + "crv": "Ed25519", + "use": "sig", + "alg": "EdDSA", + "x": encodeBase64URL(s.pubKey), + } + writeJSON(w, http.StatusOK, jwk) +} + +// ---- Auth handlers ---- + +// loginRequest is the request body for POST /v1/auth/login. +type loginRequest struct { + Username string `json:"username"` + Password string `json:"password"` + TOTPCode string `json:"totp_code,omitempty"` +} + +// loginResponse is the response body for a successful login. +type loginResponse struct { + Token string `json:"token"` + ExpiresAt string `json:"expires_at"` +} + +func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { + var req loginRequest + if !decodeJSON(w, r, &req) { + return + } + + if req.Username == "" || req.Password == "" { + middleware.WriteError(w, http.StatusBadRequest, "username and password are required", "bad_request") + return + } + + // Load account by username. + acct, err := s.db.GetAccountByUsername(req.Username) + if err != nil { + // Security: return a generic error whether the user exists or not. + // Always run a dummy Argon2 check to prevent timing-based user enumeration. + _, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g") + s.writeAudit(r, model.EventLoginFail, nil, nil, fmt.Sprintf(`{"username":%q,"reason":"unknown_user"}`, req.Username)) + middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized") + return + } + + // Security: Check account status before credential verification to avoid + // leaking whether the account exists based on timing differences. + if acct.Status != model.AccountStatusActive { + _, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g") + s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, fmt.Sprintf(`{"reason":"account_inactive"}`)) + middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized") + return + } + + // Verify password. This is always run, even for system accounts (which have + // no password hash), to maintain constant timing. + ok, err := auth.VerifyPassword(req.Password, acct.PasswordHash) + if err != nil || !ok { + s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, `{"reason":"wrong_password"}`) + middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized") + return + } + + // TOTP check (if enrolled). + if acct.TOTPRequired { + if req.TOTPCode == "" { + s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, `{"reason":"totp_missing"}`) + middleware.WriteError(w, http.StatusUnauthorized, "TOTP code required", "totp_required") + return + } + // Decrypt the TOTP secret. + secret, err := crypto.OpenAESGCM(s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc) + if err != nil { + s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID) + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + valid, err := auth.ValidateTOTP(secret, req.TOTPCode) + if err != nil || !valid { + s.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"wrong_totp"}`) + middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized") + return + } + } + + // Determine expiry. + expiry := s.cfg.DefaultExpiry() + roles, err := s.db.GetRoles(acct.ID) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + for _, r := range roles { + if r == "admin" { + expiry = s.cfg.AdminExpiry() + break + } + } + + tokenStr, claims, err := token.IssueToken(s.privKey, s.cfg.Tokens.Issuer, acct.UUID, roles, expiry) + if err != nil { + s.logger.Error("issue token", "error", err) + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + if err := s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { + s.logger.Error("track token", "error", err) + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + s.writeAudit(r, model.EventLoginOK, &acct.ID, nil, "") + s.writeAudit(r, model.EventTokenIssued, &acct.ID, nil, fmt.Sprintf(`{"jti":%q}`, claims.JTI)) + + writeJSON(w, http.StatusOK, loginResponse{ + Token: tokenStr, + ExpiresAt: claims.ExpiresAt.Format("2006-01-02T15:04:05Z"), + }) +} + +func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + if err := s.db.RevokeToken(claims.JTI, "logout"); err != nil { + s.logger.Error("revoke token on logout", "error", err) + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + s.writeAudit(r, model.EventTokenRevoked, nil, nil, fmt.Sprintf(`{"jti":%q,"reason":"logout"}`, claims.JTI)) + w.WriteHeader(http.StatusNoContent) +} + +func (s *Server) handleRenew(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + // Load account to get current roles (they may have changed since token issuance). + acct, err := s.db.GetAccountByUUID(claims.Subject) + if err != nil { + middleware.WriteError(w, http.StatusUnauthorized, "account not found", "unauthorized") + return + } + if acct.Status != model.AccountStatusActive { + middleware.WriteError(w, http.StatusUnauthorized, "account inactive", "unauthorized") + return + } + + roles, err := s.db.GetRoles(acct.ID) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + expiry := s.cfg.DefaultExpiry() + for _, role := range roles { + if role == "admin" { + expiry = s.cfg.AdminExpiry() + break + } + } + + newTokenStr, newClaims, err := token.IssueToken(s.privKey, s.cfg.Tokens.Issuer, acct.UUID, roles, expiry) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + // Revoke the old token and track the new one atomically is not possible + // in SQLite without a transaction. We do best-effort: revoke old, track new. + if err := s.db.RevokeToken(claims.JTI, "renewed"); err != nil { + s.logger.Error("revoke old token on renew", "error", err) + } + if err := s.db.TrackToken(newClaims.JTI, acct.ID, newClaims.IssuedAt, newClaims.ExpiresAt); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + s.writeAudit(r, model.EventTokenRenewed, &acct.ID, nil, fmt.Sprintf(`{"old_jti":%q,"new_jti":%q}`, claims.JTI, newClaims.JTI)) + + writeJSON(w, http.StatusOK, loginResponse{ + Token: newTokenStr, + ExpiresAt: newClaims.ExpiresAt.Format("2006-01-02T15:04:05Z"), + }) +} + +// ---- Token endpoints ---- + +type validateRequest struct { + Token string `json:"token"` +} + +type validateResponse struct { + Valid bool `json:"valid"` + Subject string `json:"sub,omitempty"` + Roles []string `json:"roles,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +func (s *Server) handleTokenValidate(w http.ResponseWriter, r *http.Request) { + // Accept token either from Authorization: Bearer header or JSON body. + tokenStr, err := extractBearerFromRequest(r) + if err != nil { + // Try JSON body. + var req validateRequest + if !decodeJSON(w, r, &req) { + return + } + tokenStr = req.Token + } + + if tokenStr == "" { + writeJSON(w, http.StatusOK, validateResponse{Valid: false}) + return + } + + claims, err := token.ValidateToken(s.pubKey, tokenStr, s.cfg.Tokens.Issuer) + if err != nil { + writeJSON(w, http.StatusOK, validateResponse{Valid: false}) + return + } + + rec, err := s.db.GetTokenRecord(claims.JTI) + if err != nil || rec.IsRevoked() { + writeJSON(w, http.StatusOK, validateResponse{Valid: false}) + return + } + + writeJSON(w, http.StatusOK, validateResponse{ + Valid: true, + Subject: claims.Subject, + Roles: claims.Roles, + ExpiresAt: claims.ExpiresAt.Format("2006-01-02T15:04:05Z"), + }) +} + +type issueTokenRequest struct { + AccountID string `json:"account_id"` +} + +func (s *Server) handleTokenIssue(w http.ResponseWriter, r *http.Request) { + var req issueTokenRequest + if !decodeJSON(w, r, &req) { + return + } + + acct, err := s.db.GetAccountByUUID(req.AccountID) + if err != nil { + middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found") + return + } + if acct.AccountType != model.AccountTypeSystem { + middleware.WriteError(w, http.StatusBadRequest, "token issue is only for system accounts", "bad_request") + return + } + + tokenStr, claims, err := token.IssueToken(s.privKey, s.cfg.Tokens.Issuer, acct.UUID, nil, s.cfg.ServiceExpiry()) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + // Revoke existing system token if any. + existing, err := s.db.GetSystemToken(acct.ID) + if err == nil && existing != nil { + _ = s.db.RevokeToken(existing.JTI, "rotated") + } + + if err := s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + if err := s.db.SetSystemToken(acct.ID, claims.JTI, claims.ExpiresAt); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + actor := middleware.ClaimsFromContext(r.Context()) + var actorID *int64 + if actor != nil { + if a, err := s.db.GetAccountByUUID(actor.Subject); err == nil { + actorID = &a.ID + } + } + s.writeAudit(r, model.EventTokenIssued, actorID, &acct.ID, fmt.Sprintf(`{"jti":%q}`, claims.JTI)) + + writeJSON(w, http.StatusOK, loginResponse{ + Token: tokenStr, + ExpiresAt: claims.ExpiresAt.Format("2006-01-02T15:04:05Z"), + }) +} + +func (s *Server) handleTokenRevoke(w http.ResponseWriter, r *http.Request) { + jti := r.PathValue("jti") + if jti == "" { + middleware.WriteError(w, http.StatusBadRequest, "jti is required", "bad_request") + return + } + + if err := s.db.RevokeToken(jti, "admin revocation"); err != nil { + middleware.WriteError(w, http.StatusNotFound, "token not found or already revoked", "not_found") + return + } + + s.writeAudit(r, model.EventTokenRevoked, nil, nil, fmt.Sprintf(`{"jti":%q}`, jti)) + w.WriteHeader(http.StatusNoContent) +} + +// ---- Account endpoints ---- + +type createAccountRequest struct { + Username string `json:"username"` + Password string `json:"password,omitempty"` + Type string `json:"account_type"` +} + +type accountResponse struct { + ID string `json:"id"` + Username string `json:"username"` + AccountType string `json:"account_type"` + Status string `json:"status"` + TOTPEnabled bool `json:"totp_enabled"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +func accountToResponse(a *model.Account) accountResponse { + resp := accountResponse{ + ID: a.UUID, + Username: a.Username, + AccountType: string(a.AccountType), + Status: string(a.Status), + TOTPEnabled: a.TOTPRequired, + CreatedAt: a.CreatedAt.Format("2006-01-02T15:04:05Z"), + UpdatedAt: a.UpdatedAt.Format("2006-01-02T15:04:05Z"), + } + return resp +} + +func (s *Server) handleListAccounts(w http.ResponseWriter, r *http.Request) { + accounts, err := s.db.ListAccounts() + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + resp := make([]accountResponse, len(accounts)) + for i, a := range accounts { + resp[i] = accountToResponse(a) + } + writeJSON(w, http.StatusOK, resp) +} + +func (s *Server) handleCreateAccount(w http.ResponseWriter, r *http.Request) { + var req createAccountRequest + if !decodeJSON(w, r, &req) { + return + } + + if req.Username == "" { + middleware.WriteError(w, http.StatusBadRequest, "username is required", "bad_request") + return + } + accountType := model.AccountType(req.Type) + if accountType != model.AccountTypeHuman && accountType != model.AccountTypeSystem { + middleware.WriteError(w, http.StatusBadRequest, "account_type must be 'human' or 'system'", "bad_request") + return + } + + var passwordHash string + if accountType == model.AccountTypeHuman { + if req.Password == "" { + middleware.WriteError(w, http.StatusBadRequest, "password is required for human accounts", "bad_request") + return + } + var err error + passwordHash, err = auth.HashPassword(req.Password, auth.ArgonParams{ + Time: s.cfg.Argon2.Time, + Memory: s.cfg.Argon2.Memory, + Threads: s.cfg.Argon2.Threads, + }) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + } + + acct, err := s.db.CreateAccount(req.Username, accountType, passwordHash) + if err != nil { + middleware.WriteError(w, http.StatusConflict, "username already exists", "conflict") + return + } + + s.writeAudit(r, model.EventAccountCreated, nil, &acct.ID, fmt.Sprintf(`{"username":%q}`, acct.Username)) + writeJSON(w, http.StatusCreated, accountToResponse(acct)) +} + +func (s *Server) handleGetAccount(w http.ResponseWriter, r *http.Request) { + acct, ok := s.loadAccount(w, r) + if !ok { + return + } + writeJSON(w, http.StatusOK, accountToResponse(acct)) +} + +type updateAccountRequest struct { + Status string `json:"status,omitempty"` +} + +func (s *Server) handleUpdateAccount(w http.ResponseWriter, r *http.Request) { + acct, ok := s.loadAccount(w, r) + if !ok { + return + } + + var req updateAccountRequest + if !decodeJSON(w, r, &req) { + return + } + + if req.Status != "" { + newStatus := model.AccountStatus(req.Status) + if newStatus != model.AccountStatusActive && newStatus != model.AccountStatusInactive { + middleware.WriteError(w, http.StatusBadRequest, "status must be 'active' or 'inactive'", "bad_request") + return + } + if err := s.db.UpdateAccountStatus(acct.ID, newStatus); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + } + + s.writeAudit(r, model.EventAccountUpdated, nil, &acct.ID, "") + w.WriteHeader(http.StatusNoContent) +} + +func (s *Server) handleDeleteAccount(w http.ResponseWriter, r *http.Request) { + acct, ok := s.loadAccount(w, r) + if !ok { + return + } + + if err := s.db.UpdateAccountStatus(acct.ID, model.AccountStatusDeleted); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + if err := s.db.RevokeAllUserTokens(acct.ID, "account deleted"); err != nil { + s.logger.Error("revoke tokens on delete", "error", err, "account_id", acct.ID) + } + + s.writeAudit(r, model.EventAccountDeleted, nil, &acct.ID, "") + w.WriteHeader(http.StatusNoContent) +} + +// ---- Role endpoints ---- + +type rolesResponse struct { + Roles []string `json:"roles"` +} + +type setRolesRequest struct { + Roles []string `json:"roles"` +} + +func (s *Server) handleGetRoles(w http.ResponseWriter, r *http.Request) { + acct, ok := s.loadAccount(w, r) + if !ok { + return + } + roles, err := s.db.GetRoles(acct.ID) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + if roles == nil { + roles = []string{} + } + writeJSON(w, http.StatusOK, rolesResponse{Roles: roles}) +} + +func (s *Server) handleSetRoles(w http.ResponseWriter, r *http.Request) { + acct, ok := s.loadAccount(w, r) + if !ok { + return + } + + var req setRolesRequest + if !decodeJSON(w, r, &req) { + return + } + + actor := middleware.ClaimsFromContext(r.Context()) + var grantedBy *int64 + if actor != nil { + if a, err := s.db.GetAccountByUUID(actor.Subject); err == nil { + grantedBy = &a.ID + } + } + + if err := s.db.SetRoles(acct.ID, req.Roles, grantedBy); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + s.writeAudit(r, model.EventRoleGranted, grantedBy, &acct.ID, fmt.Sprintf(`{"roles":%v}`, req.Roles)) + w.WriteHeader(http.StatusNoContent) +} + +// ---- TOTP endpoints ---- + +type totpEnrollResponse struct { + Secret string `json:"secret"` // base32-encoded + OTPAuthURI string `json:"otpauth_uri"` +} + +type totpConfirmRequest struct { + Code string `json:"code"` +} + +func (s *Server) handleTOTPEnroll(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + acct, err := s.db.GetAccountByUUID(claims.Subject) + if err != nil { + middleware.WriteError(w, http.StatusUnauthorized, "account not found", "unauthorized") + return + } + + rawSecret, b32Secret, err := auth.GenerateTOTPSecret() + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + // Encrypt the secret before storing it temporarily. + // Note: we store as pending; enrollment is confirmed with /confirm. + secretEnc, secretNonce, err := crypto.SealAESGCM(s.masterKey, rawSecret) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + // Store the encrypted pending secret. The totp_required flag is NOT set + // yet — it is set only after the user confirms the code. + if err := s.db.SetTOTP(acct.ID, secretEnc, secretNonce); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + otpURI := fmt.Sprintf("otpauth://totp/MCIAS:%s?secret=%s&issuer=MCIAS", acct.Username, b32Secret) + + // Security: return the secret for display to the user. It is only shown + // once; subsequent reads are not possible (only the encrypted form is stored). + writeJSON(w, http.StatusOK, totpEnrollResponse{ + Secret: b32Secret, + OTPAuthURI: otpURI, + }) +} + +func (s *Server) handleTOTPConfirm(w http.ResponseWriter, r *http.Request) { + var req totpConfirmRequest + if !decodeJSON(w, r, &req) { + return + } + + claims := middleware.ClaimsFromContext(r.Context()) + acct, err := s.db.GetAccountByUUID(claims.Subject) + if err != nil { + middleware.WriteError(w, http.StatusUnauthorized, "account not found", "unauthorized") + return + } + + if acct.TOTPSecretEnc == nil { + middleware.WriteError(w, http.StatusBadRequest, "TOTP enrollment not started", "bad_request") + return + } + + secret, err := crypto.OpenAESGCM(s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + valid, err := auth.ValidateTOTP(secret, req.Code) + if err != nil || !valid { + middleware.WriteError(w, http.StatusUnauthorized, "invalid TOTP code", "unauthorized") + return + } + + // Mark TOTP as confirmed and required. + if err := s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + s.writeAudit(r, model.EventTOTPEnrolled, &acct.ID, nil, "") + w.WriteHeader(http.StatusNoContent) +} + +type totpRemoveRequest struct { + AccountID string `json:"account_id"` +} + +func (s *Server) handleTOTPRemove(w http.ResponseWriter, r *http.Request) { + var req totpRemoveRequest + if !decodeJSON(w, r, &req) { + return + } + + acct, err := s.db.GetAccountByUUID(req.AccountID) + if err != nil { + middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found") + return + } + + if err := s.db.ClearTOTP(acct.ID); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + s.writeAudit(r, model.EventTOTPRemoved, nil, &acct.ID, "") + w.WriteHeader(http.StatusNoContent) +} + +// ---- Postgres credential endpoints ---- + +type pgCredRequest struct { + Host string `json:"host"` + Port int `json:"port"` + Database string `json:"database"` + Username string `json:"username"` + Password string `json:"password"` +} + +type pgCredResponse struct { + Host string `json:"host"` + Port int `json:"port"` + Database string `json:"database"` + Username string `json:"username"` + // Security: Password is NEVER included in the response, even on GET. + // The caller must explicitly decrypt it on the server side. +} + +func (s *Server) handleGetPGCreds(w http.ResponseWriter, r *http.Request) { + acct, ok := s.loadAccount(w, r) + if !ok { + return + } + + cred, err := s.db.ReadPGCredentials(acct.ID) + if err != nil { + if err == db.ErrNotFound { + middleware.WriteError(w, http.StatusNotFound, "no credentials stored", "not_found") + return + } + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + // Decrypt the password to return it to the admin caller. + password, err := crypto.OpenAESGCM(s.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + s.writeAudit(r, model.EventPGCredAccessed, nil, &acct.ID, "") + + // Return including password since this is an explicit admin retrieval. + writeJSON(w, http.StatusOK, map[string]interface{}{ + "host": cred.PGHost, + "port": cred.PGPort, + "database": cred.PGDatabase, + "username": cred.PGUsername, + "password": string(password), // included only for admin retrieval + }) +} + +func (s *Server) handleSetPGCreds(w http.ResponseWriter, r *http.Request) { + acct, ok := s.loadAccount(w, r) + if !ok { + return + } + + var req pgCredRequest + if !decodeJSON(w, r, &req) { + return + } + + if req.Host == "" || req.Database == "" || req.Username == "" || req.Password == "" { + middleware.WriteError(w, http.StatusBadRequest, "host, database, username, and password are required", "bad_request") + return + } + if req.Port == 0 { + req.Port = 5432 + } + + enc, nonce, err := crypto.SealAESGCM(s.masterKey, []byte(req.Password)) + if err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + if err := s.db.WritePGCredentials(acct.ID, req.Host, req.Port, req.Database, req.Username, enc, nonce); err != nil { + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return + } + + s.writeAudit(r, model.EventPGCredUpdated, nil, &acct.ID, "") + w.WriteHeader(http.StatusNoContent) +} + +// ---- Helpers ---- + +// loadAccount retrieves an account by the {id} path parameter (UUID). +func (s *Server) loadAccount(w http.ResponseWriter, r *http.Request) (*model.Account, bool) { + id := r.PathValue("id") + if id == "" { + middleware.WriteError(w, http.StatusBadRequest, "account id is required", "bad_request") + return nil, false + } + acct, err := s.db.GetAccountByUUID(id) + if err != nil { + if err == db.ErrNotFound { + middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found") + return nil, false + } + middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") + return nil, false + } + return acct, true +} + +// writeAudit appends an audit log entry, logging errors but not failing the request. +func (s *Server) writeAudit(r *http.Request, eventType string, actorID, targetID *int64, details string) { + ip := r.RemoteAddr + if err := s.db.WriteAuditEvent(eventType, actorID, targetID, ip, details); err != nil { + s.logger.Error("write audit event", "error", err, "event_type", eventType) + } +} + +// writeJSON encodes v as JSON and writes it to w with the given status code. +func writeJSON(w http.ResponseWriter, status int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(v); err != nil { + // If encoding fails, the status is already written; log but don't panic. + _ = err + } +} + +// decodeJSON decodes a JSON request body into v. +// Returns false and writes a 400 response if decoding fails. +func decodeJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool { + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(v); err != nil { + middleware.WriteError(w, http.StatusBadRequest, "invalid JSON request body", "bad_request") + return false + } + return true +} + +// extractBearerFromRequest extracts a Bearer token from the Authorization header. +func extractBearerFromRequest(r *http.Request) (string, error) { + auth := r.Header.Get("Authorization") + if auth == "" { + return "", fmt.Errorf("no Authorization header") + } + const prefix = "Bearer " + if len(auth) <= len(prefix) { + return "", fmt.Errorf("malformed Authorization header") + } + return auth[len(prefix):], nil +} + +// encodeBase64URL encodes bytes as base64url without padding. +func encodeBase64URL(b []byte) string { + const table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + result := make([]byte, 0, (len(b)*4+2)/3) + for i := 0; i < len(b); i += 3 { + switch { + case i+2 < len(b): + result = append(result, + table[b[i]>>2], + table[(b[i]&3)<<4|b[i+1]>>4], + table[(b[i+1]&0xf)<<2|b[i+2]>>6], + table[b[i+2]&0x3f], + ) + case i+1 < len(b): + result = append(result, + table[b[i]>>2], + table[(b[i]&3)<<4|b[i+1]>>4], + table[(b[i+1]&0xf)<<2], + ) + default: + result = append(result, + table[b[i]>>2], + table[(b[i]&3)<<4], + ) + } + } + return string(result) +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..4bdd7af --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,434 @@ +package server + +import ( + "bytes" + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "git.wntrmute.dev/kyle/mcias/internal/auth" + "git.wntrmute.dev/kyle/mcias/internal/config" + "git.wntrmute.dev/kyle/mcias/internal/db" + "git.wntrmute.dev/kyle/mcias/internal/model" + "git.wntrmute.dev/kyle/mcias/internal/token" +) + +const testIssuer = "https://auth.example.com" + +func newTestServer(t *testing.T) (*Server, ed25519.PublicKey, ed25519.PrivateKey, *db.DB) { + t.Helper() + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + database, err := db.Open(":memory:") + if err != nil { + t.Fatalf("open db: %v", err) + } + if err := db.Migrate(database); err != nil { + t.Fatalf("migrate db: %v", err) + } + t.Cleanup(func() { _ = database.Close() }) + + masterKey := make([]byte, 32) + if _, err := rand.Read(masterKey); err != nil { + t.Fatalf("generate master key: %v", err) + } + + cfg := config.NewTestConfig(testIssuer) + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + srv := New(database, cfg, priv, pub, masterKey, logger) + return srv, pub, priv, database +} + +// createTestHumanAccount creates a human account with password "testpass123". +func createTestHumanAccount(t *testing.T, srv *Server, username string) *model.Account { + t.Helper() + hash, err := auth.HashPassword("testpass123", auth.ArgonParams{Time: 3, Memory: 65536, Threads: 4}) + if err != nil { + t.Fatalf("hash password: %v", err) + } + acct, err := srv.db.CreateAccount(username, model.AccountTypeHuman, hash) + if err != nil { + t.Fatalf("create account: %v", err) + } + return acct +} + +// issueAdminToken creates an account with admin role, issues a JWT, and tracks it. +func issueAdminToken(t *testing.T, srv *Server, priv ed25519.PrivateKey, username string) (string, *model.Account) { + t.Helper() + acct := createTestHumanAccount(t, srv, username) + if err := srv.db.GrantRole(acct.ID, "admin", nil); err != nil { + t.Fatalf("grant admin role: %v", err) + } + + tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, []string{"admin"}, time.Hour) + if err != nil { + t.Fatalf("issue token: %v", err) + } + if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { + t.Fatalf("track token: %v", err) + } + return tokenStr, acct +} + +func doRequest(t *testing.T, handler http.Handler, method, path string, body interface{}, authToken string) *httptest.ResponseRecorder { + t.Helper() + var bodyReader io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal body: %v", err) + } + bodyReader = bytes.NewReader(b) + } else { + bodyReader = bytes.NewReader(nil) + } + + req := httptest.NewRequest(method, path, bodyReader) + req.Header.Set("Content-Type", "application/json") + if authToken != "" { + req.Header.Set("Authorization", "Bearer "+authToken) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + return rr +} + +func TestHealth(t *testing.T) { + srv, _, _, _ := newTestServer(t) + handler := srv.Handler() + + rr := doRequest(t, handler, "GET", "/v1/health", nil, "") + if rr.Code != http.StatusOK { + t.Errorf("health status = %d, want 200", rr.Code) + } +} + +func TestPublicKey(t *testing.T) { + srv, _, _, _ := newTestServer(t) + handler := srv.Handler() + + rr := doRequest(t, handler, "GET", "/v1/keys/public", nil, "") + if rr.Code != http.StatusOK { + t.Errorf("public key status = %d, want 200", rr.Code) + } + + var jwk map[string]string + if err := json.Unmarshal(rr.Body.Bytes(), &jwk); err != nil { + t.Fatalf("unmarshal JWK: %v", err) + } + if jwk["kty"] != "OKP" { + t.Errorf("kty = %q, want OKP", jwk["kty"]) + } + if jwk["alg"] != "EdDSA" { + t.Errorf("alg = %q, want EdDSA", jwk["alg"]) + } +} + +func TestLoginSuccess(t *testing.T) { + srv, _, _, _ := newTestServer(t) + createTestHumanAccount(t, srv, "alice") + handler := srv.Handler() + + rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ + "username": "alice", + "password": "testpass123", + }, "") + + if rr.Code != http.StatusOK { + t.Errorf("login status = %d, want 200; body: %s", rr.Code, rr.Body.String()) + } + + var resp loginResponse + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal login response: %v", err) + } + if resp.Token == "" { + t.Error("expected non-empty token in login response") + } + if resp.ExpiresAt == "" { + t.Error("expected non-empty expires_at in login response") + } +} + +func TestLoginWrongPassword(t *testing.T) { + srv, _, _, _ := newTestServer(t) + createTestHumanAccount(t, srv, "bob") + handler := srv.Handler() + + rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ + "username": "bob", + "password": "wrongpassword", + }, "") + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", rr.Code) + } +} + +func TestLoginUnknownUser(t *testing.T) { + srv, _, _, _ := newTestServer(t) + handler := srv.Handler() + + rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ + "username": "nobody", + "password": "password", + }, "") + + if rr.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", rr.Code) + } +} + +func TestLoginResponseDoesNotContainCredentials(t *testing.T) { + srv, _, _, _ := newTestServer(t) + createTestHumanAccount(t, srv, "charlie") + handler := srv.Handler() + + rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ + "username": "charlie", + "password": "testpass123", + }, "") + + body := rr.Body.String() + // Security: password hash must never appear in any API response. + if strings.Contains(body, "argon2id") || strings.Contains(body, "password_hash") { + t.Error("login response contains password hash material") + } +} + +func TestTokenValidate(t *testing.T) { + srv, _, priv, _ := newTestServer(t) + acct := createTestHumanAccount(t, srv, "dave") + handler := srv.Handler() + + // Issue and track a token. + tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { + t.Fatalf("TrackToken: %v", err) + } + + req := httptest.NewRequest("POST", "/v1/token/validate", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("validate status = %d, want 200", rr.Code) + } + + var resp validateResponse + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if !resp.Valid { + t.Error("expected valid=true for valid token") + } +} + +func TestLogout(t *testing.T) { + srv, _, priv, _ := newTestServer(t) + acct := createTestHumanAccount(t, srv, "eve") + handler := srv.Handler() + + tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { + t.Fatalf("TrackToken: %v", err) + } + + // Logout. + rr := doRequest(t, handler, "POST", "/v1/auth/logout", nil, tokenStr) + if rr.Code != http.StatusNoContent { + t.Errorf("logout status = %d, want 204; body: %s", rr.Code, rr.Body.String()) + } + + // Token should now be invalid on validate. + req := httptest.NewRequest("POST", "/v1/token/validate", nil) + req.Header.Set("Authorization", "Bearer "+tokenStr) + rr2 := httptest.NewRecorder() + handler.ServeHTTP(rr2, req) + + var resp validateResponse + _ = json.Unmarshal(rr2.Body.Bytes(), &resp) + if resp.Valid { + t.Error("expected valid=false after logout") + } +} + +func TestCreateAccountAdmin(t *testing.T) { + srv, _, priv, _ := newTestServer(t) + adminToken, _ := issueAdminToken(t, srv, priv, "admin-user") + handler := srv.Handler() + + rr := doRequest(t, handler, "POST", "/v1/accounts", map[string]string{ + "username": "new-user", + "password": "newpassword123", + "account_type": "human", + }, adminToken) + + if rr.Code != http.StatusCreated { + t.Errorf("create account status = %d, want 201; body: %s", rr.Code, rr.Body.String()) + } + + var resp accountResponse + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if resp.Username != "new-user" { + t.Errorf("Username = %q, want %q", resp.Username, "new-user") + } + // Security: password hash must not appear in account response. + body := rr.Body.String() + if strings.Contains(body, "password_hash") || strings.Contains(body, "argon2id") { + t.Error("account creation response contains password hash") + } +} + +func TestCreateAccountRequiresAdmin(t *testing.T) { + srv, _, priv, _ := newTestServer(t) + acct := createTestHumanAccount(t, srv, "regular-user") + tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, []string{"reader"}, time.Hour) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { + t.Fatalf("TrackToken: %v", err) + } + handler := srv.Handler() + + rr := doRequest(t, handler, "POST", "/v1/accounts", map[string]string{ + "username": "other-user", + "password": "password", + "account_type": "human", + }, tokenStr) + + if rr.Code != http.StatusForbidden { + t.Errorf("status = %d, want 403", rr.Code) + } +} + +func TestListAccounts(t *testing.T) { + srv, _, priv, _ := newTestServer(t) + adminToken, _ := issueAdminToken(t, srv, priv, "admin2") + createTestHumanAccount(t, srv, "user1") + createTestHumanAccount(t, srv, "user2") + handler := srv.Handler() + + rr := doRequest(t, handler, "GET", "/v1/accounts", nil, adminToken) + if rr.Code != http.StatusOK { + t.Errorf("list accounts status = %d, want 200", rr.Code) + } + + var accounts []accountResponse + if err := json.Unmarshal(rr.Body.Bytes(), &accounts); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(accounts) < 3 { // admin + user1 + user2 + t.Errorf("expected at least 3 accounts, got %d", len(accounts)) + } + + // Security: no credential fields in any response. + body := rr.Body.String() + for _, bad := range []string{"password_hash", "argon2id", "totp_secret", "PasswordHash"} { + if strings.Contains(body, bad) { + t.Errorf("account list response contains credential field %q", bad) + } + } +} + +func TestDeleteAccount(t *testing.T) { + srv, _, priv, _ := newTestServer(t) + adminToken, _ := issueAdminToken(t, srv, priv, "admin3") + target := createTestHumanAccount(t, srv, "delete-me") + handler := srv.Handler() + + rr := doRequest(t, handler, "DELETE", "/v1/accounts/"+target.UUID, nil, adminToken) + if rr.Code != http.StatusNoContent { + t.Errorf("delete status = %d, want 204; body: %s", rr.Code, rr.Body.String()) + } +} + +func TestSetAndGetRoles(t *testing.T) { + srv, _, priv, _ := newTestServer(t) + adminToken, _ := issueAdminToken(t, srv, priv, "admin4") + target := createTestHumanAccount(t, srv, "role-target") + handler := srv.Handler() + + // Set roles. + rr := doRequest(t, handler, "PUT", "/v1/accounts/"+target.UUID+"/roles", map[string][]string{ + "roles": {"reader", "writer"}, + }, adminToken) + if rr.Code != http.StatusNoContent { + t.Errorf("set roles status = %d, want 204; body: %s", rr.Code, rr.Body.String()) + } + + // Get roles. + rr2 := doRequest(t, handler, "GET", "/v1/accounts/"+target.UUID+"/roles", nil, adminToken) + if rr2.Code != http.StatusOK { + t.Errorf("get roles status = %d, want 200", rr2.Code) + } + + var resp rolesResponse + if err := json.Unmarshal(rr2.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(resp.Roles) != 2 { + t.Errorf("expected 2 roles, got %d", len(resp.Roles)) + } +} + +func TestRenewToken(t *testing.T) { + srv, _, priv, _ := newTestServer(t) + acct := createTestHumanAccount(t, srv, "renew-user") + handler := srv.Handler() + + oldTokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + oldJTI := claims.JTI + if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { + t.Fatalf("TrackToken: %v", err) + } + + rr := doRequest(t, handler, "POST", "/v1/auth/renew", nil, oldTokenStr) + if rr.Code != http.StatusOK { + t.Fatalf("renew status = %d, want 200; body: %s", rr.Code, rr.Body.String()) + } + + var resp loginResponse + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal renew response: %v", err) + } + if resp.Token == "" || resp.Token == oldTokenStr { + t.Error("expected new, distinct token after renewal") + } + + // Old token should be revoked in the database. + rec, err := srv.db.GetTokenRecord(oldJTI) + if err != nil { + t.Fatalf("GetTokenRecord: %v", err) + } + if !rec.IsRevoked() { + t.Error("old token should be revoked after renewal") + } +} diff --git a/internal/token/token.go b/internal/token/token.go new file mode 100644 index 0000000..7877b8a --- /dev/null +++ b/internal/token/token.go @@ -0,0 +1,181 @@ +// Package token handles JWT issuance, validation, and revocation for MCIAS. +// +// Security design: +// - Algorithm header is checked FIRST, before any signature verification. +// This prevents algorithm-confusion attacks (CVE-2022-21449 class). +// - Only "EdDSA" is accepted; "none", HS*, RS*, ES* are all rejected. +// - The signing key is taken from the server's keystore, never from the token. +// - All standard claims (exp, iat, iss, jti) are required and validated. +// - JTIs are UUIDs generated from crypto/rand (via google/uuid). +// - Token values are never stored; only JTIs are recorded for revocation. +package token + +import ( + "crypto/ed25519" + "errors" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +const ( + // requiredAlg is the only JWT algorithm accepted by MCIAS. + // Security: Hard-coding this as a constant rather than a variable ensures + // it cannot be changed at runtime and cannot be confused by token headers. + requiredAlg = "EdDSA" +) + +// Claims holds the MCIAS-specific JWT claims. +type Claims struct { + // Standard registered claims. + Issuer string `json:"iss"` + Subject string `json:"sub"` // account UUID + IssuedAt time.Time `json:"iat"` + ExpiresAt time.Time `json:"exp"` + JTI string `json:"jti"` + + // MCIAS-specific claims. + Roles []string `json:"roles"` +} + +// jwtClaims adapts Claims to the golang-jwt MapClaims interface. +type jwtClaims struct { + jwt.RegisteredClaims + Roles []string `json:"roles"` +} + +// ErrExpiredToken is returned when the token's exp claim is in the past. +var ErrExpiredToken = errors.New("token: expired") + +// ErrInvalidSignature is returned when Ed25519 signature verification fails. +var ErrInvalidSignature = errors.New("token: invalid signature") + +// ErrWrongAlgorithm is returned when the alg header is not EdDSA. +var ErrWrongAlgorithm = errors.New("token: algorithm must be EdDSA") + +// ErrMissingClaim is returned when a required claim is absent or empty. +var ErrMissingClaim = errors.New("token: missing required claim") + +// IssueToken creates and signs a new JWT with the given claims. +// The jti is generated automatically using crypto/rand via uuid.New(). +// Returns the signed token string. +// +// Security: The signing key is provided by the caller from the server's +// keystore. The alg header is set explicitly to "EdDSA" by the jwt library +// when an ed25519.PrivateKey is passed to SignedString. +func IssueToken(key ed25519.PrivateKey, issuer, subject string, roles []string, expiry time.Duration) (string, *Claims, error) { + now := time.Now().UTC() + exp := now.Add(expiry) + jti := uuid.New().String() + + jc := jwtClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: issuer, + Subject: subject, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(exp), + ID: jti, + }, + Roles: roles, + } + + t := jwt.NewWithClaims(jwt.SigningMethodEdDSA, jc) + signed, err := t.SignedString(key) + if err != nil { + return "", nil, fmt.Errorf("token: sign JWT: %w", err) + } + + claims := &Claims{ + Issuer: issuer, + Subject: subject, + IssuedAt: now, + ExpiresAt: exp, + JTI: jti, + Roles: roles, + } + return signed, claims, nil +} + +// ValidateToken parses and validates a JWT string. +// +// Security order of operations (all must pass): +// 1. Parse the token header and extract the alg field. +// 2. Reject immediately if alg != "EdDSA" (before any signature check). +// 3. Verify Ed25519 signature. +// 4. Validate exp, iat, iss, jti claims. +// +// Returns Claims on success, or a typed error on any failure. +// The caller is responsible for checking revocation status via the DB. +func ValidateToken(key ed25519.PublicKey, tokenString, expectedIssuer string) (*Claims, error) { + // Step 1+2: Parse the header to check alg BEFORE any crypto. + // Security: We use jwt.ParseWithClaims with an explicit key function that + // enforces the algorithm. The key function is called by the library after + // parsing the header but before verifying the signature, which is the + // correct point to enforce algorithm constraints. + var jc jwtClaims + t, err := jwt.ParseWithClaims(tokenString, &jc, func(t *jwt.Token) (interface{}, error) { + // Security: Check alg header first. This must happen in the key + // function — it is the only place where the parsed (but unverified) + // header is available before signature validation. + if t.Method.Alg() != requiredAlg { + return nil, fmt.Errorf("%w: got %q, want %q", ErrWrongAlgorithm, t.Method.Alg(), requiredAlg) + } + return key, nil + }, + jwt.WithIssuedAt(), + jwt.WithIssuer(expectedIssuer), + jwt.WithExpirationRequired(), + ) + if err != nil { + // Map library errors to our typed errors for consistent handling. + if errors.Is(err, ErrWrongAlgorithm) { + return nil, ErrWrongAlgorithm + } + if errors.Is(err, jwt.ErrTokenExpired) { + return nil, ErrExpiredToken + } + if errors.Is(err, jwt.ErrSignatureInvalid) { + return nil, ErrInvalidSignature + } + return nil, fmt.Errorf("token: parse: %w", err) + } + if !t.Valid { + return nil, ErrInvalidSignature + } + + // Step 4: Validate required custom claims. + if jc.ID == "" { + return nil, fmt.Errorf("%w: jti", ErrMissingClaim) + } + if jc.Subject == "" { + return nil, fmt.Errorf("%w: sub", ErrMissingClaim) + } + if jc.ExpiresAt == nil { + return nil, fmt.Errorf("%w: exp", ErrMissingClaim) + } + if jc.IssuedAt == nil { + return nil, fmt.Errorf("%w: iat", ErrMissingClaim) + } + + claims := &Claims{ + Issuer: jc.Issuer, + Subject: jc.Subject, + IssuedAt: jc.IssuedAt.Time, + ExpiresAt: jc.ExpiresAt.Time, + JTI: jc.ID, + Roles: jc.Roles, + } + return claims, nil +} + +// HasRole reports whether the claims include the given role. +func (c *Claims) HasRole(role string) bool { + for _, r := range c.Roles { + if r == role { + return true + } + } + return false +} diff --git a/internal/token/token_test.go b/internal/token/token_test.go new file mode 100644 index 0000000..00e6fff --- /dev/null +++ b/internal/token/token_test.go @@ -0,0 +1,222 @@ +package token + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) { + t.Helper() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generate key: %v", err) + } + return pub, priv +} + +// b64url encodes a string as base64url without padding. +func b64url(s string) string { + return base64.RawURLEncoding.EncodeToString([]byte(s)) +} + +const testIssuer = "https://auth.example.com" + +func TestIssueAndValidateToken(t *testing.T) { + pub, priv := generateTestKey(t) + roles := []string{"admin", "reader"} + + tokenStr, claims, err := IssueToken(priv, testIssuer, "user-uuid-1", roles, time.Hour) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + if tokenStr == "" { + t.Fatal("IssueToken returned empty token string") + } + if claims.JTI == "" { + t.Error("JTI must not be empty") + } + if claims.Subject != "user-uuid-1" { + t.Errorf("Subject = %q, want %q", claims.Subject, "user-uuid-1") + } + + // Validate the token. + got, err := ValidateToken(pub, tokenStr, testIssuer) + if err != nil { + t.Fatalf("ValidateToken: %v", err) + } + if got.Subject != "user-uuid-1" { + t.Errorf("validated Subject = %q, want %q", got.Subject, "user-uuid-1") + } + if got.JTI != claims.JTI { + t.Errorf("validated JTI = %q, want %q", got.JTI, claims.JTI) + } + if len(got.Roles) != 2 { + t.Errorf("validated Roles = %v, want 2 roles", got.Roles) + } +} + +// TestValidateTokenWrongAlgorithm verifies that tokens with non-EdDSA alg are +// rejected immediately, before any signature verification. +// Security: This tests the core defence against algorithm-confusion attacks. +func TestValidateTokenWrongAlgorithm(t *testing.T) { + _, priv := generateTestKey(t) + pub, _ := generateTestKey(t) // different key — but alg check should fail first + + // Forge a token signed with HMAC-SHA256 (alg: HS256). + hmacToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "iss": testIssuer, + "sub": "attacker", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + "jti": "fake-jti", + }) + // Use the Ed25519 public key bytes as the HMAC secret (classic alg confusion). + hs256Signed, err := hmacToken.SignedString([]byte(priv.Public().(ed25519.PublicKey))) + if err != nil { + t.Fatalf("sign HS256 token: %v", err) + } + + _, err = ValidateToken(pub, hs256Signed, testIssuer) + if err == nil { + t.Fatal("expected error for HS256 token, got nil") + } + if err != ErrWrongAlgorithm { + t.Errorf("expected ErrWrongAlgorithm, got: %v", err) + } +} + +// TestValidateTokenAlgNone verifies that "none" algorithm is rejected. +// Security: "none" algorithm tokens have no signature and must always be +// rejected regardless of payload content. +func TestValidateTokenAlgNone(t *testing.T) { + pub, _ := generateTestKey(t) + + // Construct a "none" algorithm token manually. + // golang-jwt/v5 disallows signing with "none" directly, so we craft it + // using raw base64url encoding. + header := `{"alg":"none","typ":"JWT"}` + payload := `{"iss":"https://auth.example.com","sub":"evil","iat":1000000,"exp":9999999999,"jti":"evil-jti"}` + noneToken := b64url(header) + "." + b64url(payload) + "." + + _, err := ValidateToken(pub, noneToken, testIssuer) + if err == nil { + t.Fatal("expected error for 'none' algorithm token, got nil") + } +} + +// TestValidateTokenExpired verifies that expired tokens are rejected. +func TestValidateTokenExpired(t *testing.T) { + pub, priv := generateTestKey(t) + + // Issue a token with a negative expiry (already expired). + tokenStr, _, err := IssueToken(priv, testIssuer, "user", nil, -time.Minute) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + _, err = ValidateToken(pub, tokenStr, testIssuer) + if err == nil { + t.Fatal("expected error for expired token, got nil") + } + if err != ErrExpiredToken { + t.Errorf("expected ErrExpiredToken, got: %v", err) + } +} + +// TestValidateTokenTamperedSignature verifies that signature tampering is caught. +func TestValidateTokenTamperedSignature(t *testing.T) { + pub, priv := generateTestKey(t) + + tokenStr, _, err := IssueToken(priv, testIssuer, "user", nil, time.Hour) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + // Tamper: flip a byte in the signature (last segment). + parts := strings.Split(tokenStr, ".") + if len(parts) != 3 { + t.Fatalf("unexpected token format: %d parts", len(parts)) + } + sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + t.Fatalf("decode signature: %v", err) + } + sigBytes[0] ^= 0x01 // flip one bit + parts[2] = base64.RawURLEncoding.EncodeToString(sigBytes) + tampered := strings.Join(parts, ".") + + _, err = ValidateToken(pub, tampered, testIssuer) + if err == nil { + t.Fatal("expected error for tampered signature, got nil") + } +} + +// TestValidateTokenWrongKey verifies that a token signed with a different key +// is rejected. +func TestValidateTokenWrongKey(t *testing.T) { + _, priv := generateTestKey(t) + wrongPub, _ := generateTestKey(t) + + tokenStr, _, err := IssueToken(priv, testIssuer, "user", nil, time.Hour) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + _, err = ValidateToken(wrongPub, tokenStr, testIssuer) + if err == nil { + t.Fatal("expected error for wrong key, got nil") + } +} + +// TestValidateTokenWrongIssuer verifies that tokens from a different issuer +// are rejected. +func TestValidateTokenWrongIssuer(t *testing.T) { + pub, priv := generateTestKey(t) + + tokenStr, _, err := IssueToken(priv, "https://evil.example.com", "user", nil, time.Hour) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + _, err = ValidateToken(pub, tokenStr, testIssuer) + if err == nil { + t.Fatal("expected error for wrong issuer, got nil") + } +} + +// TestJTIsAreUnique verifies that two issued tokens have different JTIs. +func TestJTIsAreUnique(t *testing.T) { + _, priv := generateTestKey(t) + + _, c1, err := IssueToken(priv, testIssuer, "user", nil, time.Hour) + if err != nil { + t.Fatalf("IssueToken (1): %v", err) + } + _, c2, err := IssueToken(priv, testIssuer, "user", nil, time.Hour) + if err != nil { + t.Fatalf("IssueToken (2): %v", err) + } + if c1.JTI == c2.JTI { + t.Error("two issued tokens have the same JTI") + } +} + +// TestClaimsHasRole verifies role checking. +func TestClaimsHasRole(t *testing.T) { + c := &Claims{Roles: []string{"admin", "reader"}} + if !c.HasRole("admin") { + t.Error("expected HasRole(admin) = true") + } + if !c.HasRole("reader") { + t.Error("expected HasRole(reader) = true") + } + if c.HasRole("writer") { + t.Error("expected HasRole(writer) = false") + } +}