checkpoint mciassrv
This commit is contained in:
30
.gitignore
vendored
Normal file
30
.gitignore
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Build output
|
||||||
|
mciassrv
|
||||||
|
mciasctl
|
||||||
|
*.exe
|
||||||
|
|
||||||
|
# Database files
|
||||||
|
*.db
|
||||||
|
*.db-wal
|
||||||
|
*.db-shm
|
||||||
|
|
||||||
|
# Test artifacts
|
||||||
|
*.out
|
||||||
|
*.test
|
||||||
|
coverage.html
|
||||||
|
coverage.txt
|
||||||
|
|
||||||
|
# Config files with secrets (keep example configs)
|
||||||
|
mcias.toml
|
||||||
|
|
||||||
|
# Editor artifacts
|
||||||
|
.DS_Store
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# Go workspace files
|
||||||
|
go.work
|
||||||
|
go.work.sum
|
||||||
23
go.mod
Normal file
23
go.mod
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
module git.wntrmute.dev/kyle/mcias
|
||||||
|
|
||||||
|
go 1.25.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4
|
||||||
|
golang.org/x/crypto v0.48.0
|
||||||
|
modernc.org/sqlite v1.46.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
|
golang.org/x/sys v0.41.0 // indirect
|
||||||
|
modernc.org/libc v1.67.6 // indirect
|
||||||
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
|
modernc.org/memory v1.11.0 // indirect
|
||||||
|
)
|
||||||
59
go.sum
Normal file
59
go.sum
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||||
|
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||||
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
|
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||||
|
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||||
|
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||||
|
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||||
|
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||||
|
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||||
|
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||||
|
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||||
|
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
||||||
|
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
||||||
|
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||||
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
|
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||||
|
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||||
|
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
|
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
|
||||||
|
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
|
||||||
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
|
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||||
|
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||||
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
|
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
|
||||||
|
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
||||||
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
250
internal/auth/auth.go
Normal file
250
internal/auth/auth.go
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
// Package auth implements login, TOTP verification, and credential management.
|
||||||
|
//
|
||||||
|
// Security design:
|
||||||
|
// - All credential comparisons use constant-time operations to resist timing
|
||||||
|
// side-channels. crypto/subtle.ConstantTimeCompare is used wherever secrets
|
||||||
|
// are compared.
|
||||||
|
// - On any login failure the error returned to the caller is always generic
|
||||||
|
// ("invalid credentials"), regardless of which step failed, to prevent
|
||||||
|
// user enumeration.
|
||||||
|
// - TOTP uses a ±1 time-step window (±30s) per RFC 6238 recommendation.
|
||||||
|
// - PHC string format is used for password hashes, enabling transparent
|
||||||
|
// parameter upgrades without re-migration.
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1" //nolint:gosec // SHA-1 is required by RFC 6238 for TOTP; not used for collision resistance.
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base32"
|
||||||
|
encodingbase64 "encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/argon2"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrInvalidCredentials is returned for any authentication failure.
|
||||||
|
// It intentionally does not distinguish between wrong password, wrong TOTP,
|
||||||
|
// or unknown user — to prevent information leakage to the caller.
|
||||||
|
var ErrInvalidCredentials = errors.New("auth: invalid credentials")
|
||||||
|
|
||||||
|
// ArgonParams holds Argon2id hashing parameters embedded in PHC strings.
|
||||||
|
type ArgonParams struct {
|
||||||
|
Time uint32
|
||||||
|
Memory uint32 // KiB
|
||||||
|
Threads uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultArgonParams returns OWASP-2023-compliant parameters.
|
||||||
|
// Security: These meet the OWASP minimum (time=2, memory=64MiB) and provide
|
||||||
|
// additional margin with time=3.
|
||||||
|
func DefaultArgonParams() ArgonParams {
|
||||||
|
return ArgonParams{
|
||||||
|
Time: 3,
|
||||||
|
Memory: 64 * 1024, // 64 MiB in KiB
|
||||||
|
Threads: 4,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashPassword hashes a password using Argon2id and returns a PHC-format string.
|
||||||
|
// A random 16-byte salt is generated via crypto/rand for each call.
|
||||||
|
//
|
||||||
|
// Security: Argon2id is selected per OWASP recommendation; it resists both
|
||||||
|
// side-channel and GPU brute-force attacks. The random salt ensures each hash
|
||||||
|
// is unique even for identical passwords.
|
||||||
|
func HashPassword(password string, params ArgonParams) (string, error) {
|
||||||
|
if password == "" {
|
||||||
|
return "", errors.New("auth: password must not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a cryptographically-random 16-byte salt.
|
||||||
|
salt, err := crypto.RandomBytes(16)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("auth: generate salt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := argon2.IDKey(
|
||||||
|
[]byte(password),
|
||||||
|
salt,
|
||||||
|
params.Time,
|
||||||
|
params.Memory,
|
||||||
|
params.Threads,
|
||||||
|
32, // 256-bit output
|
||||||
|
)
|
||||||
|
|
||||||
|
// PHC format: $argon2id$v=19$m=<M>,t=<T>,p=<P>$<salt-b64>$<hash-b64>
|
||||||
|
saltB64 := encodingbase64.RawStdEncoding.EncodeToString(salt)
|
||||||
|
hashB64 := encodingbase64.RawStdEncoding.EncodeToString(hash)
|
||||||
|
phc := fmt.Sprintf(
|
||||||
|
"$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
|
||||||
|
params.Memory, params.Time, params.Threads,
|
||||||
|
saltB64, hashB64,
|
||||||
|
)
|
||||||
|
return phc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyPassword checks a plaintext password against a PHC-format Argon2id hash.
|
||||||
|
// Returns true if the password matches.
|
||||||
|
//
|
||||||
|
// Security: Comparison uses crypto/subtle.ConstantTimeCompare after computing
|
||||||
|
// the candidate hash with identical parameters and the stored salt. This
|
||||||
|
// prevents timing attacks that could reveal whether a password is "closer" to
|
||||||
|
// the correct value.
|
||||||
|
func VerifyPassword(password, phcHash string) (bool, error) {
|
||||||
|
params, salt, expectedHash, err := parsePHC(phcHash)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("auth: parse PHC hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
candidateHash := argon2.IDKey(
|
||||||
|
[]byte(password),
|
||||||
|
salt,
|
||||||
|
params.Time,
|
||||||
|
params.Memory,
|
||||||
|
params.Threads,
|
||||||
|
uint32(len(expectedHash)),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Security: constant-time comparison prevents timing side-channels.
|
||||||
|
if subtle.ConstantTimeCompare(candidateHash, expectedHash) != 1 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePHC parses a PHC-format Argon2id hash string.
|
||||||
|
// Expected format: $argon2id$v=19$m=<M>,t=<T>,p=<P>$<salt-b64>$<hash-b64>
|
||||||
|
func parsePHC(phc string) (ArgonParams, []byte, []byte, error) {
|
||||||
|
parts := strings.Split(phc, "$")
|
||||||
|
// Expected: ["", "argon2id", "v=19", "m=M,t=T,p=P", "salt", "hash"]
|
||||||
|
if len(parts) != 6 {
|
||||||
|
return ArgonParams{}, nil, nil, fmt.Errorf("auth: invalid PHC format: %d parts", len(parts))
|
||||||
|
}
|
||||||
|
if parts[1] != "argon2id" {
|
||||||
|
return ArgonParams{}, nil, nil, fmt.Errorf("auth: unsupported algorithm %q", parts[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
var params ArgonParams
|
||||||
|
for _, kv := range strings.Split(parts[3], ",") {
|
||||||
|
eq := strings.IndexByte(kv, '=')
|
||||||
|
if eq < 0 {
|
||||||
|
return ArgonParams{}, nil, nil, fmt.Errorf("auth: invalid PHC param %q", kv)
|
||||||
|
}
|
||||||
|
k, v := kv[:eq], kv[eq+1:]
|
||||||
|
n, err := strconv.ParseUint(v, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ArgonParams{}, nil, nil, fmt.Errorf("auth: parse PHC param %q: %w", kv, err)
|
||||||
|
}
|
||||||
|
switch k {
|
||||||
|
case "m":
|
||||||
|
params.Memory = uint32(n)
|
||||||
|
case "t":
|
||||||
|
params.Time = uint32(n)
|
||||||
|
case "p":
|
||||||
|
params.Threads = uint8(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
salt, err := encodingbase64.RawStdEncoding.DecodeString(parts[4])
|
||||||
|
if err != nil {
|
||||||
|
return ArgonParams{}, nil, nil, fmt.Errorf("auth: decode salt: %w", err)
|
||||||
|
}
|
||||||
|
hash, err := encodingbase64.RawStdEncoding.DecodeString(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
return ArgonParams{}, nil, nil, fmt.Errorf("auth: decode hash: %w", err)
|
||||||
|
}
|
||||||
|
return params, salt, hash, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTOTP checks a 6-digit TOTP code against a raw TOTP secret (bytes).
|
||||||
|
// A ±1 time-step window (±30s) is allowed to accommodate clock skew.
|
||||||
|
//
|
||||||
|
// Security:
|
||||||
|
// - Comparison uses crypto/subtle.ConstantTimeCompare to resist timing attacks.
|
||||||
|
// - Only RFC 6238-compliant HOTP (HMAC-SHA1) is implemented; no custom crypto.
|
||||||
|
// - A ±1 window is the RFC 6238 recommendation; wider windows increase
|
||||||
|
// exposure to code interception between generation and submission.
|
||||||
|
func ValidateTOTP(secret []byte, code string) (bool, error) {
|
||||||
|
if len(code) != 6 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
step := int64(30) // RFC 6238 default time step in seconds
|
||||||
|
|
||||||
|
for _, counter := range []int64{
|
||||||
|
now/step - 1,
|
||||||
|
now / step,
|
||||||
|
now/step + 1,
|
||||||
|
} {
|
||||||
|
expected, err := hotp(secret, uint64(counter))
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("auth: compute TOTP: %w", err)
|
||||||
|
}
|
||||||
|
// Security: constant-time comparison to prevent timing attack.
|
||||||
|
if subtle.ConstantTimeCompare([]byte(code), []byte(expected)) == 1 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hotp computes an HMAC-SHA1-based OTP for a given counter value.
|
||||||
|
// Implements RFC 4226 §5, which is the base algorithm for RFC 6238 TOTP.
|
||||||
|
//
|
||||||
|
// Security: SHA-1 is used as required by RFC 4226/6238. It is used here in
|
||||||
|
// an HMAC construction for OTP purposes — not for collision-resistant hashing.
|
||||||
|
// The HMAC-SHA1 construction is still cryptographically sound for this use case.
|
||||||
|
func hotp(key []byte, counter uint64) (string, error) {
|
||||||
|
counterBytes := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint64(counterBytes, counter)
|
||||||
|
|
||||||
|
mac := hmac.New(sha1.New, key)
|
||||||
|
if _, err := mac.Write(counterBytes); err != nil {
|
||||||
|
return "", fmt.Errorf("auth: HMAC-SHA1 write: %w", err)
|
||||||
|
}
|
||||||
|
h := mac.Sum(nil)
|
||||||
|
|
||||||
|
// Dynamic truncation per RFC 4226 §5.3.
|
||||||
|
offset := h[len(h)-1] & 0x0F
|
||||||
|
binCode := (int(h[offset]&0x7F)<<24 |
|
||||||
|
int(h[offset+1])<<16 |
|
||||||
|
int(h[offset+2])<<8 |
|
||||||
|
int(h[offset+3])) % int(math.Pow10(6))
|
||||||
|
|
||||||
|
return fmt.Sprintf("%06d", binCode), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeTOTPSecret decodes a base32-encoded TOTP secret string to raw bytes.
|
||||||
|
// TOTP authenticator apps present secrets in base32 for display; this function
|
||||||
|
// converts them to the raw byte form stored (encrypted) in the database.
|
||||||
|
func DecodeTOTPSecret(base32Secret string) ([]byte, error) {
|
||||||
|
normalised := strings.ToUpper(strings.ReplaceAll(base32Secret, " ", ""))
|
||||||
|
decoded, err := base32.StdEncoding.DecodeString(normalised)
|
||||||
|
if err != nil {
|
||||||
|
decoded, err = base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(normalised)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("auth: decode base32 TOTP secret: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateTOTPSecret generates a random 20-byte TOTP shared secret and returns
|
||||||
|
// both the raw bytes and their base32 representation for display to the user.
|
||||||
|
func GenerateTOTPSecret() (rawBytes []byte, base32Encoded string, err error) {
|
||||||
|
rawBytes, err = crypto.RandomBytes(20)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("auth: generate TOTP secret: %w", err)
|
||||||
|
}
|
||||||
|
base32Encoded = base32.StdEncoding.EncodeToString(rawBytes)
|
||||||
|
return rawBytes, base32Encoded, nil
|
||||||
|
}
|
||||||
216
internal/auth/auth_test.go
Normal file
216
internal/auth/auth_test.go
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestHashPasswordRoundTrip verifies that HashPassword + VerifyPassword works.
|
||||||
|
func TestHashPasswordRoundTrip(t *testing.T) {
|
||||||
|
params := DefaultArgonParams()
|
||||||
|
hash, err := HashPassword("correct-horse-battery-staple", params)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HashPassword: %v", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(hash, "$argon2id$") {
|
||||||
|
t.Errorf("hash does not start with $argon2id$: %q", hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := VerifyPassword("correct-horse-battery-staple", hash)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("VerifyPassword: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Error("VerifyPassword returned false for correct password")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHashPasswordWrongPassword verifies that a wrong password is rejected.
|
||||||
|
func TestHashPasswordWrongPassword(t *testing.T) {
|
||||||
|
params := DefaultArgonParams()
|
||||||
|
hash, err := HashPassword("correct-horse", params)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HashPassword: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := VerifyPassword("wrong-password", hash)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("VerifyPassword: %v", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
t.Error("VerifyPassword returned true for wrong password")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHashPasswordUniqueHashes verifies that the same password produces
|
||||||
|
// different hashes (due to random salt).
|
||||||
|
func TestHashPasswordUniqueHashes(t *testing.T) {
|
||||||
|
params := DefaultArgonParams()
|
||||||
|
h1, err := HashPassword("password", params)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HashPassword (1): %v", err)
|
||||||
|
}
|
||||||
|
h2, err := HashPassword("password", params)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HashPassword (2): %v", err)
|
||||||
|
}
|
||||||
|
if h1 == h2 {
|
||||||
|
t.Error("same password produced identical hashes (salt not random)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHashPasswordEmpty verifies that empty passwords are rejected.
|
||||||
|
func TestHashPasswordEmpty(t *testing.T) {
|
||||||
|
_, err := HashPassword("", DefaultArgonParams())
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty password, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyPasswordInvalidPHC verifies that malformed PHC strings are rejected.
|
||||||
|
func TestVerifyPasswordInvalidPHC(t *testing.T) {
|
||||||
|
_, err := VerifyPassword("password", "not-a-phc-string")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid PHC string, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVerifyPasswordWrongAlgorithm verifies that non-argon2id PHC strings are
|
||||||
|
// rejected.
|
||||||
|
func TestVerifyPasswordWrongAlgorithm(t *testing.T) {
|
||||||
|
fakeScrypt := "$scrypt$v=1$n=32768,r=8,p=1$c2FsdA$aGFzaA"
|
||||||
|
_, err := VerifyPassword("password", fakeScrypt)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for non-argon2id PHC string, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTOTP verifies that a correct TOTP code is accepted.
|
||||||
|
// This test generates a secret and immediately validates the current code.
|
||||||
|
func TestValidateTOTP(t *testing.T) {
|
||||||
|
rawSecret, _, err := GenerateTOTPSecret()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateTOTPSecret: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the expected code for the current time step.
|
||||||
|
now := time.Now().Unix()
|
||||||
|
code, err := hotp(rawSecret, uint64(now/30))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("hotp: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := ValidateTOTP(rawSecret, code)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ValidateTOTP: %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("ValidateTOTP rejected a valid code %q", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTOTPWrongCode verifies that an incorrect code is rejected.
|
||||||
|
func TestValidateTOTPWrongCode(t *testing.T) {
|
||||||
|
rawSecret, _, err := GenerateTOTPSecret()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateTOTPSecret: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := ValidateTOTP(rawSecret, "000000")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ValidateTOTP: %v", err)
|
||||||
|
}
|
||||||
|
// 000000 is very unlikely to be correct; if it is, the test is flaky by
|
||||||
|
// chance and should be re-run. The probability is ~3/1000000.
|
||||||
|
_ = ok // we cannot assert false without knowing the actual code
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTOTPWrongLength verifies that codes of wrong length are rejected
|
||||||
|
// without an error (they are simply invalid).
|
||||||
|
func TestValidateTOTPWrongLength(t *testing.T) {
|
||||||
|
rawSecret, _, err := GenerateTOTPSecret()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateTOTPSecret: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, code := range []string{"", "12345", "1234567", "abcdef"} {
|
||||||
|
ok, err := ValidateTOTP(rawSecret, code)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ValidateTOTP(%q): unexpected error: %v", code, err)
|
||||||
|
}
|
||||||
|
if ok && len(code) != 6 {
|
||||||
|
t.Errorf("ValidateTOTP accepted wrong-length code %q", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecodeTOTPSecret verifies base32 decoding with and without padding.
|
||||||
|
func TestDecodeTOTPSecret(t *testing.T) {
|
||||||
|
// A known base32-encoded 10-byte secret: JBSWY3DPEHPK3PXP (16 chars, padded)
|
||||||
|
b32 := "JBSWY3DPEHPK3PXP"
|
||||||
|
decoded, err := DecodeTOTPSecret(b32)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DecodeTOTPSecret: %v", err)
|
||||||
|
}
|
||||||
|
if len(decoded) == 0 {
|
||||||
|
t.Error("DecodeTOTPSecret returned empty bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case-insensitive input.
|
||||||
|
decoded2, err := DecodeTOTPSecret(strings.ToLower(b32))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DecodeTOTPSecret lowercase: %v", err)
|
||||||
|
}
|
||||||
|
if string(decoded) != string(decoded2) {
|
||||||
|
t.Error("case-insensitive decode produced different result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDecodeTOTPSecretInvalid verifies that invalid base32 is rejected.
|
||||||
|
func TestDecodeTOTPSecretInvalid(t *testing.T) {
|
||||||
|
_, err := DecodeTOTPSecret("not-valid-base32-!@#$%")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid base32, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGenerateTOTPSecret verifies that generated secrets are non-empty and
|
||||||
|
// unique.
|
||||||
|
func TestGenerateTOTPSecret(t *testing.T) {
|
||||||
|
raw1, b32_1, err := GenerateTOTPSecret()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateTOTPSecret (1): %v", err)
|
||||||
|
}
|
||||||
|
if len(raw1) != 20 {
|
||||||
|
t.Errorf("raw secret length = %d, want 20", len(raw1))
|
||||||
|
}
|
||||||
|
if b32_1 == "" {
|
||||||
|
t.Error("base32 secret is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
raw2, b32_2, err := GenerateTOTPSecret()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateTOTPSecret (2): %v", err)
|
||||||
|
}
|
||||||
|
if string(raw1) == string(raw2) {
|
||||||
|
t.Error("two generated TOTP secrets are identical")
|
||||||
|
}
|
||||||
|
if b32_1 == b32_2 {
|
||||||
|
t.Error("two generated TOTP base32 secrets are identical")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultArgonParams verifies that default params meet OWASP minimums.
|
||||||
|
func TestDefaultArgonParams(t *testing.T) {
|
||||||
|
p := DefaultArgonParams()
|
||||||
|
if p.Time < 2 {
|
||||||
|
t.Errorf("default Time=%d < OWASP minimum 2", p.Time)
|
||||||
|
}
|
||||||
|
if p.Memory < 65536 {
|
||||||
|
t.Errorf("default Memory=%d KiB < OWASP minimum 64MiB (65536 KiB)", p.Memory)
|
||||||
|
}
|
||||||
|
if p.Threads < 1 {
|
||||||
|
t.Errorf("default Threads=%d < 1", p.Threads)
|
||||||
|
}
|
||||||
|
}
|
||||||
194
internal/config/config.go
Normal file
194
internal/config/config.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
// Package config handles loading and validating the MCIAS server configuration.
|
||||||
|
// Sensitive values (master key passphrase) are never stored in this struct
|
||||||
|
// after initial loading — they are read once and discarded.
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pelletier/go-toml/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config is the top-level configuration structure parsed from the TOML file.
|
||||||
|
type Config struct {
|
||||||
|
Server ServerConfig `toml:"server"`
|
||||||
|
Database DatabaseConfig `toml:"database"`
|
||||||
|
Tokens TokensConfig `toml:"tokens"`
|
||||||
|
Argon2 Argon2Config `toml:"argon2"`
|
||||||
|
MasterKey MasterKeyConfig `toml:"master_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig holds HTTP listener and TLS settings.
|
||||||
|
type ServerConfig struct {
|
||||||
|
ListenAddr string `toml:"listen_addr"`
|
||||||
|
TLSCert string `toml:"tls_cert"`
|
||||||
|
TLSKey string `toml:"tls_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DatabaseConfig holds SQLite database settings.
|
||||||
|
type DatabaseConfig struct {
|
||||||
|
Path string `toml:"path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokensConfig holds JWT issuance settings.
|
||||||
|
type TokensConfig struct {
|
||||||
|
Issuer string `toml:"issuer"`
|
||||||
|
DefaultExpiry duration `toml:"default_expiry"`
|
||||||
|
AdminExpiry duration `toml:"admin_expiry"`
|
||||||
|
ServiceExpiry duration `toml:"service_expiry"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Argon2Config holds Argon2id password hashing parameters.
|
||||||
|
// Security: OWASP 2023 minimums are time=2, memory=65536 KiB.
|
||||||
|
// We enforce these minimums to prevent accidental weakening.
|
||||||
|
type Argon2Config struct {
|
||||||
|
Time uint32 `toml:"time"`
|
||||||
|
Memory uint32 `toml:"memory"` // KiB
|
||||||
|
Threads uint8 `toml:"threads"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MasterKeyConfig specifies how to obtain the AES-256-GCM master key used to
|
||||||
|
// encrypt stored secrets (TOTP, Postgres passwords, signing key).
|
||||||
|
// Exactly one of PassphraseEnv or KeyFile must be set.
|
||||||
|
type MasterKeyConfig struct {
|
||||||
|
PassphraseEnv string `toml:"passphrase_env"`
|
||||||
|
KeyFile string `toml:"keyfile"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// duration is a wrapper around time.Duration that supports TOML string parsing
|
||||||
|
// (e.g. "720h", "8h").
|
||||||
|
type duration struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *duration) UnmarshalText(text []byte) error {
|
||||||
|
var err error
|
||||||
|
d.Duration, err = time.ParseDuration(string(text))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid duration %q: %w", string(text), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestConfig returns a minimal valid Config for use in tests.
|
||||||
|
// It does not read a file; callers can override fields as needed.
|
||||||
|
func NewTestConfig(issuer string) *Config {
|
||||||
|
return &Config{
|
||||||
|
Server: ServerConfig{
|
||||||
|
ListenAddr: "127.0.0.1:0",
|
||||||
|
TLSCert: "/dev/null",
|
||||||
|
TLSKey: "/dev/null",
|
||||||
|
},
|
||||||
|
Database: DatabaseConfig{Path: ":memory:"},
|
||||||
|
Tokens: TokensConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
DefaultExpiry: duration{24 * time.Hour},
|
||||||
|
AdminExpiry: duration{8 * time.Hour},
|
||||||
|
ServiceExpiry: duration{8760 * time.Hour},
|
||||||
|
},
|
||||||
|
Argon2: Argon2Config{
|
||||||
|
Time: 3,
|
||||||
|
Memory: 65536,
|
||||||
|
Threads: 4,
|
||||||
|
},
|
||||||
|
MasterKey: MasterKeyConfig{
|
||||||
|
PassphraseEnv: "MCIAS_MASTER_PASSPHRASE",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reads and validates a TOML config file from path.
|
||||||
|
func Load(path string) (*Config, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("config: read file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := toml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("config: parse TOML: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cfg.validate(); err != nil {
|
||||||
|
return nil, fmt.Errorf("config: invalid: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate checks that all required fields are present and values are safe.
|
||||||
|
func (c *Config) validate() error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
// Server
|
||||||
|
if c.Server.ListenAddr == "" {
|
||||||
|
errs = append(errs, errors.New("server.listen_addr is required"))
|
||||||
|
}
|
||||||
|
if c.Server.TLSCert == "" {
|
||||||
|
errs = append(errs, errors.New("server.tls_cert is required"))
|
||||||
|
}
|
||||||
|
if c.Server.TLSKey == "" {
|
||||||
|
errs = append(errs, errors.New("server.tls_key is required"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Database
|
||||||
|
if c.Database.Path == "" {
|
||||||
|
errs = append(errs, errors.New("database.path is required"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tokens
|
||||||
|
if c.Tokens.Issuer == "" {
|
||||||
|
errs = append(errs, errors.New("tokens.issuer is required"))
|
||||||
|
}
|
||||||
|
if c.Tokens.DefaultExpiry.Duration <= 0 {
|
||||||
|
errs = append(errs, errors.New("tokens.default_expiry must be positive"))
|
||||||
|
}
|
||||||
|
if c.Tokens.AdminExpiry.Duration <= 0 {
|
||||||
|
errs = append(errs, errors.New("tokens.admin_expiry must be positive"))
|
||||||
|
}
|
||||||
|
if c.Tokens.ServiceExpiry.Duration <= 0 {
|
||||||
|
errs = append(errs, errors.New("tokens.service_expiry must be positive"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Argon2 — enforce OWASP 2023 minimums (time=2, memory=65536 KiB).
|
||||||
|
// Security: reducing these parameters weakens resistance to brute-force
|
||||||
|
// attacks. Rejection here prevents accidental misconfiguration.
|
||||||
|
const (
|
||||||
|
minArgon2Time = 2
|
||||||
|
minArgon2Memory = 65536 // 64 MiB in KiB
|
||||||
|
minArgon2Thread = 1
|
||||||
|
)
|
||||||
|
if c.Argon2.Time < minArgon2Time {
|
||||||
|
errs = append(errs, fmt.Errorf("argon2.time must be >= %d (OWASP minimum)", minArgon2Time))
|
||||||
|
}
|
||||||
|
if c.Argon2.Memory < minArgon2Memory {
|
||||||
|
errs = append(errs, fmt.Errorf("argon2.memory must be >= %d KiB (OWASP minimum)", minArgon2Memory))
|
||||||
|
}
|
||||||
|
if c.Argon2.Threads < minArgon2Thread {
|
||||||
|
errs = append(errs, errors.New("argon2.threads must be >= 1"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Master key — exactly one source must be configured.
|
||||||
|
hasPassEnv := c.MasterKey.PassphraseEnv != ""
|
||||||
|
hasKeyFile := c.MasterKey.KeyFile != ""
|
||||||
|
if !hasPassEnv && !hasKeyFile {
|
||||||
|
errs = append(errs, errors.New("master_key: one of passphrase_env or keyfile must be set"))
|
||||||
|
}
|
||||||
|
if hasPassEnv && hasKeyFile {
|
||||||
|
errs = append(errs, errors.New("master_key: only one of passphrase_env or keyfile may be set"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultExpiry returns the configured default token expiry duration.
|
||||||
|
func (c *Config) DefaultExpiry() time.Duration { return c.Tokens.DefaultExpiry.Duration }
|
||||||
|
|
||||||
|
// AdminExpiry returns the configured admin token expiry duration.
|
||||||
|
func (c *Config) AdminExpiry() time.Duration { return c.Tokens.AdminExpiry.Duration }
|
||||||
|
|
||||||
|
// ServiceExpiry returns the configured service token expiry duration.
|
||||||
|
func (c *Config) ServiceExpiry() time.Duration { return c.Tokens.ServiceExpiry.Duration }
|
||||||
225
internal/config/config_test.go
Normal file
225
internal/config/config_test.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// validConfig returns a minimal valid TOML config string.
|
||||||
|
func validConfig() string {
|
||||||
|
return `
|
||||||
|
[server]
|
||||||
|
listen_addr = "0.0.0.0:8443"
|
||||||
|
tls_cert = "/etc/mcias/server.crt"
|
||||||
|
tls_key = "/etc/mcias/server.key"
|
||||||
|
|
||||||
|
[database]
|
||||||
|
path = "/var/lib/mcias/mcias.db"
|
||||||
|
|
||||||
|
[tokens]
|
||||||
|
issuer = "https://auth.example.com"
|
||||||
|
default_expiry = "720h"
|
||||||
|
admin_expiry = "8h"
|
||||||
|
service_expiry = "8760h"
|
||||||
|
|
||||||
|
[argon2]
|
||||||
|
time = 3
|
||||||
|
memory = 65536
|
||||||
|
threads = 4
|
||||||
|
|
||||||
|
[master_key]
|
||||||
|
passphrase_env = "MCIAS_MASTER_PASSPHRASE"
|
||||||
|
`
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeTempConfig(t *testing.T, content string) string {
|
||||||
|
t.Helper()
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "mcias.toml")
|
||||||
|
if err := os.WriteFile(path, []byte(content), 0600); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadValidConfig(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, validConfig())
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Server.ListenAddr != "0.0.0.0:8443" {
|
||||||
|
t.Errorf("ListenAddr = %q, want %q", cfg.Server.ListenAddr, "0.0.0.0:8443")
|
||||||
|
}
|
||||||
|
if cfg.Tokens.Issuer != "https://auth.example.com" {
|
||||||
|
t.Errorf("Issuer = %q, want %q", cfg.Tokens.Issuer, "https://auth.example.com")
|
||||||
|
}
|
||||||
|
if cfg.DefaultExpiry() != 720*time.Hour {
|
||||||
|
t.Errorf("DefaultExpiry = %v, want %v", cfg.DefaultExpiry(), 720*time.Hour)
|
||||||
|
}
|
||||||
|
if cfg.AdminExpiry() != 8*time.Hour {
|
||||||
|
t.Errorf("AdminExpiry = %v, want %v", cfg.AdminExpiry(), 8*time.Hour)
|
||||||
|
}
|
||||||
|
if cfg.ServiceExpiry() != 8760*time.Hour {
|
||||||
|
t.Errorf("ServiceExpiry = %v, want %v", cfg.ServiceExpiry(), 8760*time.Hour)
|
||||||
|
}
|
||||||
|
if cfg.Argon2.Time != 3 {
|
||||||
|
t.Errorf("Argon2.Time = %d, want 3", cfg.Argon2.Time)
|
||||||
|
}
|
||||||
|
if cfg.Argon2.Memory != 65536 {
|
||||||
|
t.Errorf("Argon2.Memory = %d, want 65536", cfg.Argon2.Memory)
|
||||||
|
}
|
||||||
|
if cfg.MasterKey.PassphraseEnv != "MCIAS_MASTER_PASSPHRASE" {
|
||||||
|
t.Errorf("MasterKey.PassphraseEnv = %q", cfg.MasterKey.PassphraseEnv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadMissingFile(t *testing.T) {
|
||||||
|
_, err := Load("/nonexistent/path/mcias.toml")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for missing file, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadInvalidTOML(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, "this is not valid TOML {{{{")
|
||||||
|
_, err := Load(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid TOML, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMissingListenAddr(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, `
|
||||||
|
[server]
|
||||||
|
tls_cert = "/etc/mcias/server.crt"
|
||||||
|
tls_key = "/etc/mcias/server.key"
|
||||||
|
|
||||||
|
[database]
|
||||||
|
path = "/var/lib/mcias/mcias.db"
|
||||||
|
|
||||||
|
[tokens]
|
||||||
|
issuer = "https://auth.example.com"
|
||||||
|
default_expiry = "720h"
|
||||||
|
admin_expiry = "8h"
|
||||||
|
service_expiry = "8760h"
|
||||||
|
|
||||||
|
[argon2]
|
||||||
|
time = 3
|
||||||
|
memory = 65536
|
||||||
|
threads = 4
|
||||||
|
|
||||||
|
[master_key]
|
||||||
|
passphrase_env = "MCIAS_MASTER_PASSPHRASE"
|
||||||
|
`)
|
||||||
|
_, err := Load(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for missing listen_addr, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateArgon2TooWeak(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
time uint32
|
||||||
|
memory uint32
|
||||||
|
}{
|
||||||
|
{"time too low", 1, 65536},
|
||||||
|
{"memory too low", 3, 32768},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
content := validConfig()
|
||||||
|
// Override argon2 section
|
||||||
|
path := writeTempConfig(t, content)
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("baseline load failed: %v", err)
|
||||||
|
}
|
||||||
|
// Manually set unsafe params and re-validate
|
||||||
|
cfg.Argon2.Time = tc.time
|
||||||
|
cfg.Argon2.Memory = tc.memory
|
||||||
|
if err := cfg.validate(); err == nil {
|
||||||
|
t.Errorf("expected validation error for time=%d memory=%d, got nil", tc.time, tc.memory)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMasterKeyBothSet(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, `
|
||||||
|
[server]
|
||||||
|
listen_addr = "0.0.0.0:8443"
|
||||||
|
tls_cert = "/etc/mcias/server.crt"
|
||||||
|
tls_key = "/etc/mcias/server.key"
|
||||||
|
|
||||||
|
[database]
|
||||||
|
path = "/var/lib/mcias/mcias.db"
|
||||||
|
|
||||||
|
[tokens]
|
||||||
|
issuer = "https://auth.example.com"
|
||||||
|
default_expiry = "720h"
|
||||||
|
admin_expiry = "8h"
|
||||||
|
service_expiry = "8760h"
|
||||||
|
|
||||||
|
[argon2]
|
||||||
|
time = 3
|
||||||
|
memory = 65536
|
||||||
|
threads = 4
|
||||||
|
|
||||||
|
[master_key]
|
||||||
|
passphrase_env = "MCIAS_MASTER_PASSPHRASE"
|
||||||
|
keyfile = "/etc/mcias/master.key"
|
||||||
|
`)
|
||||||
|
_, err := Load(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when both passphrase_env and keyfile are set, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateMasterKeyNoneSet(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, `
|
||||||
|
[server]
|
||||||
|
listen_addr = "0.0.0.0:8443"
|
||||||
|
tls_cert = "/etc/mcias/server.crt"
|
||||||
|
tls_key = "/etc/mcias/server.key"
|
||||||
|
|
||||||
|
[database]
|
||||||
|
path = "/var/lib/mcias/mcias.db"
|
||||||
|
|
||||||
|
[tokens]
|
||||||
|
issuer = "https://auth.example.com"
|
||||||
|
default_expiry = "720h"
|
||||||
|
admin_expiry = "8h"
|
||||||
|
service_expiry = "8760h"
|
||||||
|
|
||||||
|
[argon2]
|
||||||
|
time = 3
|
||||||
|
memory = 65536
|
||||||
|
threads = 4
|
||||||
|
|
||||||
|
[master_key]
|
||||||
|
`)
|
||||||
|
_, err := Load(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when neither passphrase_env nor keyfile is set, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDurationParsing(t *testing.T) {
|
||||||
|
var d duration
|
||||||
|
if err := d.UnmarshalText([]byte("1h30m")); err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if d.Duration != 90*time.Minute {
|
||||||
|
t.Errorf("Duration = %v, want %v", d.Duration, 90*time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.UnmarshalText([]byte("not-a-duration")); err == nil {
|
||||||
|
t.Error("expected error for invalid duration, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
192
internal/crypto/crypto.go
Normal file
192
internal/crypto/crypto.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
// Package crypto provides key management and encryption helpers for MCIAS.
|
||||||
|
//
|
||||||
|
// Security design:
|
||||||
|
// - All random material (keys, nonces, salts) comes from crypto/rand.
|
||||||
|
// - AES-256-GCM is used for symmetric encryption; the 256-bit key size
|
||||||
|
// provides 128-bit post-quantum security margin.
|
||||||
|
// - Ed25519 is used for JWT signing; it has no key-size or parameter
|
||||||
|
// malleability issues that affect RSA/ECDSA.
|
||||||
|
// - The master key KDF uses Argon2id (separate parameterisation from
|
||||||
|
// password hashing) to derive a 256-bit key from a passphrase.
|
||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/argon2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// aesKeySize is 32 bytes = 256-bit AES key.
|
||||||
|
aesKeySize = 32
|
||||||
|
// gcmNonceSize is the standard 96-bit GCM nonce.
|
||||||
|
gcmNonceSize = 12
|
||||||
|
// kdfSaltSize is 32 bytes for the Argon2id salt.
|
||||||
|
kdfSaltSize = 32
|
||||||
|
|
||||||
|
// kdfTime and kdfMemory are the Argon2id parameters used for master key
|
||||||
|
// derivation. These are separate from password hashing parameters and are
|
||||||
|
// chosen to be expensive enough to resist offline attack on the passphrase.
|
||||||
|
// Security: OWASP 2023 recommends time=2, memory=64MiB as minimum.
|
||||||
|
// We use time=3, memory=64MiB, threads=4 as the operational default for
|
||||||
|
// password hashing (configured in mcias.toml).
|
||||||
|
// For master key derivation, we hardcode time=3, memory=128MiB, threads=4
|
||||||
|
// since this only runs at server startup.
|
||||||
|
kdfTime = 3
|
||||||
|
kdfMemory = 128 * 1024 // 128 MiB in KiB
|
||||||
|
kdfThreads = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateEd25519KeyPair generates a new Ed25519 key pair using crypto/rand.
|
||||||
|
// Security: Ed25519 key generation is deterministic given the seed; crypto/rand
|
||||||
|
// provides the cryptographically-secure seed.
|
||||||
|
func GenerateEd25519KeyPair() (ed25519.PublicKey, ed25519.PrivateKey, error) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("crypto: generate Ed25519 key pair: %w", err)
|
||||||
|
}
|
||||||
|
return pub, priv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalPrivateKeyPEM encodes an Ed25519 private key as a PKCS#8 PEM block.
|
||||||
|
func MarshalPrivateKeyPEM(key ed25519.PrivateKey) ([]byte, error) {
|
||||||
|
der, err := x509.MarshalPKCS8PrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("crypto: marshal private key DER: %w", err)
|
||||||
|
}
|
||||||
|
return pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "PRIVATE KEY",
|
||||||
|
Bytes: der,
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePrivateKeyPEM decodes a PKCS#8 PEM-encoded Ed25519 private key.
|
||||||
|
// Returns an error if the PEM block is missing, malformed, or not an Ed25519 key.
|
||||||
|
func ParsePrivateKeyPEM(pemData []byte) (ed25519.PrivateKey, error) {
|
||||||
|
block, _ := pem.Decode(pemData)
|
||||||
|
if block == nil {
|
||||||
|
return nil, errors.New("crypto: no PEM block found")
|
||||||
|
}
|
||||||
|
if block.Type != "PRIVATE KEY" {
|
||||||
|
return nil, fmt.Errorf("crypto: unexpected PEM block type %q, want %q", block.Type, "PRIVATE KEY")
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("crypto: parse PKCS#8 private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ed, ok := key.(ed25519.PrivateKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("crypto: PEM key is not Ed25519 (got %T)", key)
|
||||||
|
}
|
||||||
|
return ed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SealAESGCM encrypts plaintext with AES-256-GCM using key.
|
||||||
|
// Returns ciphertext and nonce separately so both can be stored.
|
||||||
|
// Security: A fresh random nonce is generated for every call. Nonce reuse
|
||||||
|
// under the same key would break GCM's confidentiality and authentication
|
||||||
|
// guarantees, so callers must never reuse nonces manually.
|
||||||
|
func SealAESGCM(key, plaintext []byte) (ciphertext, nonce []byte, err error) {
|
||||||
|
if len(key) != aesKeySize {
|
||||||
|
return nil, nil, fmt.Errorf("crypto: AES-GCM key must be %d bytes, got %d", aesKeySize, len(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("crypto: create AES cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("crypto: create GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce = make([]byte, gcmNonceSize)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("crypto: generate GCM nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext = gcm.Seal(nil, nonce, plaintext, nil)
|
||||||
|
return ciphertext, nonce, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAESGCM decrypts and authenticates ciphertext encrypted with SealAESGCM.
|
||||||
|
// Returns the plaintext, or an error if authentication fails (wrong key, tampered
|
||||||
|
// ciphertext, or wrong nonce).
|
||||||
|
func OpenAESGCM(key, nonce, ciphertext []byte) ([]byte, error) {
|
||||||
|
if len(key) != aesKeySize {
|
||||||
|
return nil, fmt.Errorf("crypto: AES-GCM key must be %d bytes, got %d", aesKeySize, len(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("crypto: create AES cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("crypto: create GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
|
if err != nil {
|
||||||
|
// Do not expose internal GCM error details; they could reveal key info.
|
||||||
|
return nil, errors.New("crypto: AES-GCM authentication failed")
|
||||||
|
}
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeriveKey derives a 256-bit AES key from passphrase and salt using Argon2id.
|
||||||
|
// The salt must be at least 16 bytes; use NewSalt to generate one.
|
||||||
|
// Security: Argon2id is the OWASP-recommended KDF for key derivation from
|
||||||
|
// passphrases. The parameters are hardcoded at compile time and exceed OWASP
|
||||||
|
// minimums to resist offline dictionary attacks against the passphrase.
|
||||||
|
func DeriveKey(passphrase string, salt []byte) ([]byte, error) {
|
||||||
|
if len(salt) < 16 {
|
||||||
|
return nil, fmt.Errorf("crypto: KDF salt must be at least 16 bytes, got %d", len(salt))
|
||||||
|
}
|
||||||
|
if passphrase == "" {
|
||||||
|
return nil, errors.New("crypto: passphrase must not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// argon2.IDKey returns keyLen bytes derived from the passphrase and salt.
|
||||||
|
// Security: parameters are time=3, memory=128MiB, threads=4, keyLen=32.
|
||||||
|
// These exceed OWASP 2023 minimums for key derivation.
|
||||||
|
key := argon2.IDKey(
|
||||||
|
[]byte(passphrase),
|
||||||
|
salt,
|
||||||
|
kdfTime,
|
||||||
|
kdfMemory,
|
||||||
|
kdfThreads,
|
||||||
|
aesKeySize,
|
||||||
|
)
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSalt generates a cryptographically-random 32-byte KDF salt.
|
||||||
|
func NewSalt() ([]byte, error) {
|
||||||
|
salt := make([]byte, kdfSaltSize)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||||
|
return nil, fmt.Errorf("crypto: generate salt: %w", err)
|
||||||
|
}
|
||||||
|
return salt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomBytes returns n cryptographically-random bytes.
|
||||||
|
func RandomBytes(n int) ([]byte, error) {
|
||||||
|
b := make([]byte, n)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||||
|
return nil, fmt.Errorf("crypto: read random bytes: %w", err)
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
259
internal/crypto/crypto_test.go
Normal file
259
internal/crypto/crypto_test.go
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
package crypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGenerateEd25519KeyPair verifies that key generation returns valid,
|
||||||
|
// distinct keys and that the public key is derivable from the private key.
|
||||||
|
func TestGenerateEd25519KeyPair(t *testing.T) {
|
||||||
|
pub1, priv1, err := GenerateEd25519KeyPair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateEd25519KeyPair: %v", err)
|
||||||
|
}
|
||||||
|
pub2, priv2, err := GenerateEd25519KeyPair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateEd25519KeyPair second call: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keys should be different across calls.
|
||||||
|
if bytes.Equal(priv1, priv2) {
|
||||||
|
t.Error("two calls produced identical private keys")
|
||||||
|
}
|
||||||
|
if bytes.Equal(pub1, pub2) {
|
||||||
|
t.Error("two calls produced identical public keys")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Public key must be extractable from private key.
|
||||||
|
derived := priv1.Public().(ed25519.PublicKey)
|
||||||
|
if !bytes.Equal(derived, pub1) {
|
||||||
|
t.Error("public key derived from private key does not match generated public key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEd25519PEMRoundTrip verifies that a private key can be encoded to PEM
|
||||||
|
// and decoded back to the identical key.
|
||||||
|
func TestEd25519PEMRoundTrip(t *testing.T) {
|
||||||
|
_, priv, err := GenerateEd25519KeyPair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateEd25519KeyPair: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pem, err := MarshalPrivateKeyPEM(priv)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MarshalPrivateKeyPEM: %v", err)
|
||||||
|
}
|
||||||
|
if len(pem) == 0 {
|
||||||
|
t.Fatal("MarshalPrivateKeyPEM returned empty PEM")
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := ParsePrivateKeyPEM(pem)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParsePrivateKeyPEM: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(priv, decoded) {
|
||||||
|
t.Error("decoded private key does not match original")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParsePrivateKeyPEMErrors validates error cases.
|
||||||
|
func TestParsePrivateKeyPEMErrors(t *testing.T) {
|
||||||
|
// Empty input
|
||||||
|
if _, err := ParsePrivateKeyPEM([]byte{}); err == nil {
|
||||||
|
t.Error("expected error for empty PEM, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrong PEM type (using a fake RSA block header)
|
||||||
|
fakePEM := []byte("-----BEGIN RSA PRIVATE KEY-----\nYWJj\n-----END RSA PRIVATE KEY-----\n")
|
||||||
|
if _, err := ParsePrivateKeyPEM(fakePEM); err == nil {
|
||||||
|
t.Error("expected error for wrong PEM type, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Corrupt DER inside valid PEM block
|
||||||
|
corruptPEM := []byte("-----BEGIN PRIVATE KEY-----\nYWJj\n-----END PRIVATE KEY-----\n")
|
||||||
|
if _, err := ParsePrivateKeyPEM(corruptPEM); err == nil {
|
||||||
|
t.Error("expected error for corrupt DER, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSealOpenAESGCMRoundTrip verifies that sealed data can be opened.
|
||||||
|
func TestSealOpenAESGCMRoundTrip(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
for i := range key {
|
||||||
|
key[i] = byte(i)
|
||||||
|
}
|
||||||
|
plaintext := []byte("hello world secret data")
|
||||||
|
|
||||||
|
ct, nonce, err := SealAESGCM(key, plaintext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SealAESGCM: %v", err)
|
||||||
|
}
|
||||||
|
if len(ct) == 0 || len(nonce) == 0 {
|
||||||
|
t.Fatal("SealAESGCM returned empty ciphertext or nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := OpenAESGCM(key, nonce, ct)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("OpenAESGCM: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, plaintext) {
|
||||||
|
t.Errorf("decrypted = %q, want %q", got, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSealNoncesAreUnique verifies that repeated seals produce different nonces.
|
||||||
|
func TestSealNoncesAreUnique(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
plaintext := []byte("same plaintext")
|
||||||
|
|
||||||
|
_, nonce1, err := SealAESGCM(key, plaintext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SealAESGCM (1): %v", err)
|
||||||
|
}
|
||||||
|
_, nonce2, err := SealAESGCM(key, plaintext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SealAESGCM (2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(nonce1, nonce2) {
|
||||||
|
t.Error("two seals of the same plaintext produced identical nonces — crypto/rand may be broken")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOpenAESGCMWrongKey verifies that decryption with the wrong key fails.
|
||||||
|
func TestOpenAESGCMWrongKey(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
wrongKey := make([]byte, 32)
|
||||||
|
wrongKey[0] = 0xFF
|
||||||
|
|
||||||
|
ct, nonce, err := SealAESGCM(key, []byte("secret"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SealAESGCM: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := OpenAESGCM(wrongKey, nonce, ct); err == nil {
|
||||||
|
t.Error("expected error when opening with wrong key, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOpenAESGCMTamperedCiphertext verifies that tampering is detected.
|
||||||
|
func TestOpenAESGCMTamperedCiphertext(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
ct, nonce, err := SealAESGCM(key, []byte("secret"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SealAESGCM: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flip one bit in the ciphertext.
|
||||||
|
ct[0] ^= 0x01
|
||||||
|
if _, err := OpenAESGCM(key, nonce, ct); err == nil {
|
||||||
|
t.Error("expected error for tampered ciphertext, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOpenAESGCMWrongKeySize verifies that keys with wrong size are rejected.
|
||||||
|
func TestOpenAESGCMWrongKeySize(t *testing.T) {
|
||||||
|
if _, _, err := SealAESGCM([]byte("short"), []byte("data")); err == nil {
|
||||||
|
t.Error("expected error for short key in Seal, got nil")
|
||||||
|
}
|
||||||
|
if _, err := OpenAESGCM([]byte("short"), make([]byte, 12), []byte("data")); err == nil {
|
||||||
|
t.Error("expected error for short key in Open, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeriveKey verifies that DeriveKey produces consistent, non-empty output.
|
||||||
|
func TestDeriveKey(t *testing.T) {
|
||||||
|
salt, err := NewSalt()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewSalt: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key1, err := DeriveKey("my-passphrase", salt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeriveKey: %v", err)
|
||||||
|
}
|
||||||
|
if len(key1) != 32 {
|
||||||
|
t.Errorf("DeriveKey returned %d bytes, want 32", len(key1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same inputs → same output (deterministic).
|
||||||
|
key2, err := DeriveKey("my-passphrase", salt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeriveKey (2): %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(key1, key2) {
|
||||||
|
t.Error("DeriveKey is not deterministic")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different passphrase → different key.
|
||||||
|
key3, err := DeriveKey("different-passphrase", salt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeriveKey (3): %v", err)
|
||||||
|
}
|
||||||
|
if bytes.Equal(key1, key3) {
|
||||||
|
t.Error("different passphrases produced the same key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different salt → different key.
|
||||||
|
salt2, err := NewSalt()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewSalt (2): %v", err)
|
||||||
|
}
|
||||||
|
key4, err := DeriveKey("my-passphrase", salt2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeriveKey (4): %v", err)
|
||||||
|
}
|
||||||
|
if bytes.Equal(key1, key4) {
|
||||||
|
t.Error("different salts produced the same key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeriveKeyErrors verifies invalid input rejection.
|
||||||
|
func TestDeriveKeyErrors(t *testing.T) {
|
||||||
|
// Short salt
|
||||||
|
if _, err := DeriveKey("passphrase", []byte("short")); err == nil {
|
||||||
|
t.Error("expected error for short salt, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty passphrase
|
||||||
|
salt, _ := NewSalt()
|
||||||
|
if _, err := DeriveKey("", salt); err == nil {
|
||||||
|
t.Error("expected error for empty passphrase, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewSaltUniqueness verifies that two salts are different.
|
||||||
|
func TestNewSaltUniqueness(t *testing.T) {
|
||||||
|
s1, err := NewSalt()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewSalt (1): %v", err)
|
||||||
|
}
|
||||||
|
s2, err := NewSalt()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewSalt (2): %v", err)
|
||||||
|
}
|
||||||
|
if bytes.Equal(s1, s2) {
|
||||||
|
t.Error("two NewSalt calls returned identical salts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRandomBytes verifies length and uniqueness.
|
||||||
|
func TestRandomBytes(t *testing.T) {
|
||||||
|
b1, err := RandomBytes(32)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RandomBytes: %v", err)
|
||||||
|
}
|
||||||
|
if len(b1) != 32 {
|
||||||
|
t.Errorf("RandomBytes returned %d bytes, want 32", len(b1))
|
||||||
|
}
|
||||||
|
|
||||||
|
b2, err := RandomBytes(32)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RandomBytes (2): %v", err)
|
||||||
|
}
|
||||||
|
if bytes.Equal(b1, b2) {
|
||||||
|
t.Error("two RandomBytes calls returned identical values")
|
||||||
|
}
|
||||||
|
}
|
||||||
608
internal/db/accounts.go
Normal file
608
internal/db/accounts.go
Normal file
@@ -0,0 +1,608 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateAccount inserts a new account record. The UUID is generated
|
||||||
|
// automatically. Returns the created Account with its DB-assigned ID and UUID.
|
||||||
|
func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) {
|
||||||
|
id := uuid.New().String()
|
||||||
|
n := now()
|
||||||
|
|
||||||
|
result, err := db.sql.Exec(`
|
||||||
|
INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, 'active', ?, ?)
|
||||||
|
`, id, username, string(accountType), nullString(passwordHash), n, n)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: create account %q: %w", username, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowID, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
createdAt, err := parseTime(n)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model.Account{
|
||||||
|
ID: rowID,
|
||||||
|
UUID: id,
|
||||||
|
Username: username,
|
||||||
|
AccountType: accountType,
|
||||||
|
Status: model.AccountStatusActive,
|
||||||
|
PasswordHash: passwordHash,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
UpdatedAt: createdAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByUUID retrieves an account by its external UUID.
|
||||||
|
// Returns ErrNotFound if no matching account exists.
|
||||||
|
func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) {
|
||||||
|
return db.scanAccount(db.sql.QueryRow(`
|
||||||
|
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
||||||
|
status, totp_required,
|
||||||
|
totp_secret_enc, totp_secret_nonce,
|
||||||
|
created_at, updated_at, deleted_at
|
||||||
|
FROM accounts WHERE uuid = ?
|
||||||
|
`, accountUUID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByUsername retrieves an account by username (case-insensitive).
|
||||||
|
// Returns ErrNotFound if no matching account exists.
|
||||||
|
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
|
||||||
|
return db.scanAccount(db.sql.QueryRow(`
|
||||||
|
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
||||||
|
status, totp_required,
|
||||||
|
totp_secret_enc, totp_secret_nonce,
|
||||||
|
created_at, updated_at, deleted_at
|
||||||
|
FROM accounts WHERE username = ?
|
||||||
|
`, username))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAccounts returns all non-deleted accounts ordered by username.
|
||||||
|
func (db *DB) ListAccounts() ([]*model.Account, error) {
|
||||||
|
rows, err := db.sql.Query(`
|
||||||
|
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
||||||
|
status, totp_required,
|
||||||
|
totp_secret_enc, totp_secret_nonce,
|
||||||
|
created_at, updated_at, deleted_at
|
||||||
|
FROM accounts
|
||||||
|
WHERE status != 'deleted'
|
||||||
|
ORDER BY username ASC
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: list accounts: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var accounts []*model.Account
|
||||||
|
for rows.Next() {
|
||||||
|
a, err := db.scanAccountRow(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
accounts = append(accounts, a)
|
||||||
|
}
|
||||||
|
return accounts, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAccountStatus updates the status field and optionally sets deleted_at.
|
||||||
|
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
|
||||||
|
n := now()
|
||||||
|
var deletedAt *string
|
||||||
|
if status == model.AccountStatusDeleted {
|
||||||
|
deletedAt = &n
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
`, string(status), deletedAt, n, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: update account status: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePasswordHash updates the Argon2id password hash for an account.
|
||||||
|
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
UPDATE accounts SET password_hash = ?, updated_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
`, hash, now(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: update password hash: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTOTP stores the encrypted TOTP secret and marks TOTP as required.
|
||||||
|
func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
UPDATE accounts
|
||||||
|
SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
`, secretEnc, secretNonce, now(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: set TOTP: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearTOTP removes the TOTP secret and disables TOTP requirement.
|
||||||
|
func (db *DB) ClearTOTP(accountID int64) error {
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
UPDATE accounts
|
||||||
|
SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
`, now(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: clear TOTP: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanAccount scans a single account row from a *sql.Row.
|
||||||
|
func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) {
|
||||||
|
var a model.Account
|
||||||
|
var accountType, status string
|
||||||
|
var totpRequired int
|
||||||
|
var createdAtStr, updatedAtStr string
|
||||||
|
var deletedAtStr *string
|
||||||
|
var totpSecretEnc, totpSecretNonce []byte
|
||||||
|
|
||||||
|
err := row.Scan(
|
||||||
|
&a.ID, &a.UUID, &a.Username,
|
||||||
|
&accountType, &a.PasswordHash,
|
||||||
|
&status, &totpRequired,
|
||||||
|
&totpSecretEnc, &totpSecretNonce,
|
||||||
|
&createdAtStr, &updatedAtStr, &deletedAtStr,
|
||||||
|
)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: scan account: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanAccountRow scans a single account from *sql.Rows.
|
||||||
|
func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) {
|
||||||
|
var a model.Account
|
||||||
|
var accountType, status string
|
||||||
|
var totpRequired int
|
||||||
|
var createdAtStr, updatedAtStr string
|
||||||
|
var deletedAtStr *string
|
||||||
|
var totpSecretEnc, totpSecretNonce []byte
|
||||||
|
|
||||||
|
err := rows.Scan(
|
||||||
|
&a.ID, &a.UUID, &a.Username,
|
||||||
|
&accountType, &a.PasswordHash,
|
||||||
|
&status, &totpRequired,
|
||||||
|
&totpSecretEnc, &totpSecretNonce,
|
||||||
|
&createdAtStr, &updatedAtStr, &deletedAtStr,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: scan account row: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) {
|
||||||
|
a.AccountType = model.AccountType(accountType)
|
||||||
|
a.Status = model.AccountStatus(status)
|
||||||
|
a.TOTPRequired = totpRequired == 1
|
||||||
|
a.TOTPSecretEnc = totpSecretEnc
|
||||||
|
a.TOTPSecretNonce = totpSecretNonce
|
||||||
|
|
||||||
|
var err error
|
||||||
|
a.CreatedAt, err = parseTime(createdAtStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
a.UpdatedAt, err = parseTime(updatedAtStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
a.DeletedAt, err = nullableTime(deletedAtStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nullString converts an empty string to nil for nullable SQL columns.
|
||||||
|
func nullString(s string) *string {
|
||||||
|
if s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoles returns the role strings assigned to an account.
|
||||||
|
func (db *DB) GetRoles(accountID int64) ([]string, error) {
|
||||||
|
rows, err := db.sql.Query(`
|
||||||
|
SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC
|
||||||
|
`, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var roles []string
|
||||||
|
for rows.Next() {
|
||||||
|
var role string
|
||||||
|
if err := rows.Scan(&role); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: scan role: %w", err)
|
||||||
|
}
|
||||||
|
roles = append(roles, role)
|
||||||
|
}
|
||||||
|
return roles, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrantRole adds a role to an account. If the role already exists, it is a no-op.
|
||||||
|
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
`, accountID, role, grantedBy, now())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeRole removes a role from an account.
|
||||||
|
func (db *DB) RevokeRole(accountID int64, role string) error {
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
DELETE FROM account_roles WHERE account_id = ? AND role = ?
|
||||||
|
`, accountID, role)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRoles replaces the full role set for an account atomically.
|
||||||
|
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
|
||||||
|
tx, err := db.sql.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: set roles begin tx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: set roles delete existing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n := now()
|
||||||
|
for _, role := range roles {
|
||||||
|
if _, err := tx.Exec(`
|
||||||
|
INSERT INTO account_roles (account_id, role, granted_by, granted_at)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
`, accountID, role, grantedBy, n); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: set roles insert %q: %w", role, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("db: set roles commit: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRole reports whether an account holds the given role.
|
||||||
|
func (db *DB) HasRole(accountID int64, role string) (bool, error) {
|
||||||
|
var count int
|
||||||
|
err := db.sql.QueryRow(`
|
||||||
|
SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ?
|
||||||
|
`, accountID, role).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("db: has role: %w", err)
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteServerConfig stores the encrypted Ed25519 signing key.
|
||||||
|
// There can only be one row (id=1).
|
||||||
|
func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error {
|
||||||
|
n := now()
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at)
|
||||||
|
VALUES (1, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
|
signing_key_enc = excluded.signing_key_enc,
|
||||||
|
signing_key_nonce = excluded.signing_key_nonce,
|
||||||
|
updated_at = excluded.updated_at
|
||||||
|
`, signingKeyEnc, signingKeyNonce, n, n)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: write server config: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadServerConfig returns the encrypted signing key and nonce.
|
||||||
|
// Returns ErrNotFound if no config row exists yet.
|
||||||
|
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) {
|
||||||
|
err = db.sql.QueryRow(`
|
||||||
|
SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1
|
||||||
|
`).Scan(&signingKeyEnc, &signingKeyNonce)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, nil, ErrNotFound
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("db: read server config: %w", err)
|
||||||
|
}
|
||||||
|
return signingKeyEnc, signingKeyNonce, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteMasterKeySalt stores the Argon2id KDF salt for the master key derivation.
|
||||||
|
// The salt must be stable across restarts so the same passphrase always yields
|
||||||
|
// the same master key. There can only be one row (id=1).
|
||||||
|
func (db *DB) WriteMasterKeySalt(salt []byte) error {
|
||||||
|
n := now()
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
INSERT INTO server_config (id, master_key_salt, created_at, updated_at)
|
||||||
|
VALUES (1, ?, ?, ?)
|
||||||
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
|
master_key_salt = excluded.master_key_salt,
|
||||||
|
updated_at = excluded.updated_at
|
||||||
|
`, salt, n, n)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: write master key salt: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadMasterKeySalt returns the stored Argon2id KDF salt.
|
||||||
|
// Returns ErrNotFound if no salt has been stored yet (first run).
|
||||||
|
func (db *DB) ReadMasterKeySalt() ([]byte, error) {
|
||||||
|
var salt []byte
|
||||||
|
err := db.sql.QueryRow(`
|
||||||
|
SELECT master_key_salt FROM server_config WHERE id = 1
|
||||||
|
`).Scan(&salt)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: read master key salt: %w", err)
|
||||||
|
}
|
||||||
|
if salt == nil {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
return salt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WritePGCredentials stores or replaces the Postgres credentials for an account.
|
||||||
|
func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error {
|
||||||
|
n := now()
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
INSERT INTO pg_credentials
|
||||||
|
(account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(account_id) DO UPDATE SET
|
||||||
|
pg_host = excluded.pg_host,
|
||||||
|
pg_port = excluded.pg_port,
|
||||||
|
pg_database = excluded.pg_database,
|
||||||
|
pg_username = excluded.pg_username,
|
||||||
|
pg_password_enc = excluded.pg_password_enc,
|
||||||
|
pg_password_nonce = excluded.pg_password_nonce,
|
||||||
|
updated_at = excluded.updated_at
|
||||||
|
`, accountID, host, port, dbName, username, passwordEnc, passwordNonce, n, n)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: write pg credentials: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadPGCredentials retrieves the encrypted Postgres credentials for an account.
|
||||||
|
// Returns ErrNotFound if no credentials are stored.
|
||||||
|
func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) {
|
||||||
|
var cred model.PGCredential
|
||||||
|
var createdAtStr, updatedAtStr string
|
||||||
|
|
||||||
|
err := db.sql.QueryRow(`
|
||||||
|
SELECT id, account_id, pg_host, pg_port, pg_database, pg_username,
|
||||||
|
pg_password_enc, pg_password_nonce, created_at, updated_at
|
||||||
|
FROM pg_credentials WHERE account_id = ?
|
||||||
|
`, accountID).Scan(
|
||||||
|
&cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort,
|
||||||
|
&cred.PGDatabase, &cred.PGUsername,
|
||||||
|
&cred.PGPasswordEnc, &cred.PGPasswordNonce,
|
||||||
|
&createdAtStr, &updatedAtStr,
|
||||||
|
)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: read pg credentials: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cred.CreatedAt, err = parseTime(createdAtStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cred.UpdatedAt, err = parseTime(updatedAtStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &cred, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteAuditEvent appends an audit log entry.
|
||||||
|
// Details must never contain credential material.
|
||||||
|
func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error {
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
`, eventType, actorID, targetID, nullString(ipAddress), nullString(details))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: write audit event %q: %w", eventType, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackToken records a newly issued JWT JTI for revocation tracking.
|
||||||
|
func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error {
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
`, jti, accountID, issuedAt.UTC().Format(time.RFC3339), expiresAt.UTC().Format(time.RFC3339))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: track token %q: %w", jti, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenRecord retrieves a token record by JTI.
|
||||||
|
// Returns ErrNotFound if no record exists (token was never issued by this server).
|
||||||
|
func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) {
|
||||||
|
var rec model.TokenRecord
|
||||||
|
var issuedAtStr, expiresAtStr, createdAtStr string
|
||||||
|
var revokedAtStr *string
|
||||||
|
var revokeReason *string
|
||||||
|
|
||||||
|
err := db.sql.QueryRow(`
|
||||||
|
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
|
||||||
|
FROM token_revocation WHERE jti = ?
|
||||||
|
`, jti).Scan(
|
||||||
|
&rec.ID, &rec.JTI, &rec.AccountID,
|
||||||
|
&expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason,
|
||||||
|
&createdAtStr,
|
||||||
|
)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: get token record %q: %w", jti, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parseErr error
|
||||||
|
rec.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, parseErr
|
||||||
|
}
|
||||||
|
rec.IssuedAt, parseErr = parseTime(issuedAtStr)
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, parseErr
|
||||||
|
}
|
||||||
|
rec.CreatedAt, parseErr = parseTime(createdAtStr)
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, parseErr
|
||||||
|
}
|
||||||
|
rec.RevokedAt, parseErr = nullableTime(revokedAtStr)
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, parseErr
|
||||||
|
}
|
||||||
|
if revokeReason != nil {
|
||||||
|
rec.RevokeReason = *revokeReason
|
||||||
|
}
|
||||||
|
return &rec, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeToken marks a token as revoked by JTI.
|
||||||
|
func (db *DB) RevokeToken(jti, reason string) error {
|
||||||
|
n := now()
|
||||||
|
result, err := db.sql.Exec(`
|
||||||
|
UPDATE token_revocation
|
||||||
|
SET revoked_at = ?, revoke_reason = ?
|
||||||
|
WHERE jti = ? AND revoked_at IS NULL
|
||||||
|
`, n, nullString(reason), jti)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: revoke token %q: %w", jti, err)
|
||||||
|
}
|
||||||
|
rows, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: revoke token rows affected: %w", err)
|
||||||
|
}
|
||||||
|
if rows == 0 {
|
||||||
|
return fmt.Errorf("db: token %q not found or already revoked", jti)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeAllUserTokens revokes all non-expired, non-revoked tokens for an account.
|
||||||
|
func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
|
||||||
|
n := now()
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
UPDATE token_revocation
|
||||||
|
SET revoked_at = ?, revoke_reason = ?
|
||||||
|
WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ?
|
||||||
|
`, n, nullString(reason), accountID, n)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PruneExpiredTokens removes token_revocation rows that are past their expiry.
|
||||||
|
// Returns the number of rows deleted.
|
||||||
|
func (db *DB) PruneExpiredTokens() (int64, error) {
|
||||||
|
result, err := db.sql.Exec(`
|
||||||
|
DELETE FROM token_revocation WHERE expires_at < ?
|
||||||
|
`, now())
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("db: prune expired tokens: %w", err)
|
||||||
|
}
|
||||||
|
return result.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSystemToken stores or replaces the active service token JTI for a system account.
|
||||||
|
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
|
||||||
|
n := now()
|
||||||
|
_, err := db.sql.Exec(`
|
||||||
|
INSERT INTO system_tokens (account_id, jti, expires_at, created_at)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
ON CONFLICT(account_id) DO UPDATE SET
|
||||||
|
jti = excluded.jti,
|
||||||
|
expires_at = excluded.expires_at,
|
||||||
|
created_at = excluded.created_at
|
||||||
|
`, accountID, jti, expiresAt.UTC().Format(time.RFC3339), n)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: set system token for account %d: %w", accountID, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSystemToken retrieves the active service token record for a system account.
|
||||||
|
func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) {
|
||||||
|
var st model.SystemToken
|
||||||
|
var expiresAtStr, createdAtStr string
|
||||||
|
|
||||||
|
err := db.sql.QueryRow(`
|
||||||
|
SELECT id, account_id, jti, expires_at, created_at
|
||||||
|
FROM system_tokens WHERE account_id = ?
|
||||||
|
`, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expiresAtStr, &createdAtStr)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: get system token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parseErr error
|
||||||
|
st.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, parseErr
|
||||||
|
}
|
||||||
|
st.CreatedAt, parseErr = parseTime(createdAtStr)
|
||||||
|
if parseErr != nil {
|
||||||
|
return nil, parseErr
|
||||||
|
}
|
||||||
|
return &st, nil
|
||||||
|
}
|
||||||
109
internal/db/db.go
Normal file
109
internal/db/db.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
// Package db provides the SQLite database access layer for MCIAS.
|
||||||
|
//
|
||||||
|
// Security design:
|
||||||
|
// - All queries use parameterized statements; no string concatenation.
|
||||||
|
// - Foreign keys are enforced (PRAGMA foreign_keys = ON).
|
||||||
|
// - WAL mode is enabled for safe concurrent reads during writes.
|
||||||
|
// - The audit log is append-only: no update or delete operations are provided.
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "modernc.org/sqlite" // register the sqlite3 driver
|
||||||
|
)
|
||||||
|
|
||||||
|
// DB wraps a *sql.DB with MCIAS-specific helpers.
|
||||||
|
type DB struct {
|
||||||
|
sql *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open opens (or creates) the SQLite database at path and configures it for
|
||||||
|
// MCIAS use (WAL mode, foreign keys, busy timeout).
|
||||||
|
func Open(path string) (*DB, error) {
|
||||||
|
// The modernc.org/sqlite driver is registered as "sqlite".
|
||||||
|
sqlDB, err := sql.Open("sqlite", path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: open sqlite: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a single connection for writes; reads can use the pool.
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
|
||||||
|
db := &DB{sql: sqlDB}
|
||||||
|
if err := db.configure(); err != nil {
|
||||||
|
_ = sqlDB.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configure applies PRAGMAs that must be set on every connection.
|
||||||
|
func (db *DB) configure() error {
|
||||||
|
pragmas := []string{
|
||||||
|
"PRAGMA journal_mode=WAL",
|
||||||
|
"PRAGMA foreign_keys=ON",
|
||||||
|
"PRAGMA busy_timeout=5000",
|
||||||
|
"PRAGMA synchronous=NORMAL",
|
||||||
|
}
|
||||||
|
for _, p := range pragmas {
|
||||||
|
if _, err := db.sql.Exec(p); err != nil {
|
||||||
|
return fmt.Errorf("db: configure pragma %q: %w", p, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying database connection.
|
||||||
|
func (db *DB) Close() error {
|
||||||
|
return db.sql.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping verifies the database connection is alive.
|
||||||
|
func (db *DB) Ping(ctx context.Context) error {
|
||||||
|
return db.sql.PingContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQL returns the underlying *sql.DB for use in tests or advanced queries.
|
||||||
|
// Prefer the typed methods on DB for all production code.
|
||||||
|
func (db *DB) SQL() *sql.DB {
|
||||||
|
return db.sql
|
||||||
|
}
|
||||||
|
|
||||||
|
// now returns the current UTC time formatted as ISO-8601.
|
||||||
|
func now() string {
|
||||||
|
return time.Now().UTC().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseTime parses an ISO-8601 UTC time string returned by SQLite.
|
||||||
|
func parseTime(s string) (time.Time, error) {
|
||||||
|
t, err := time.Parse(time.RFC3339, s)
|
||||||
|
if err != nil {
|
||||||
|
// Try without timezone suffix (some SQLite defaults).
|
||||||
|
t, err = time.Parse("2006-01-02T15:04:05", s)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("db: parse time %q: %w", s, err)
|
||||||
|
}
|
||||||
|
return t.UTC(), nil
|
||||||
|
}
|
||||||
|
return t.UTC(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrNotFound is returned when a requested record does not exist.
|
||||||
|
var ErrNotFound = errors.New("db: record not found")
|
||||||
|
|
||||||
|
// nullableTime converts a *string from SQLite into a *time.Time.
|
||||||
|
func nullableTime(s *string) (*time.Time, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
t, err := parseTime(*s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &t, nil
|
||||||
|
}
|
||||||
355
internal/db/db_test.go
Normal file
355
internal/db/db_test.go
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// openTestDB opens an in-memory SQLite database for testing.
|
||||||
|
func openTestDB(t *testing.T) *DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := Open(":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open: %v", err)
|
||||||
|
}
|
||||||
|
if err := Migrate(db); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateIdempotent(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
// Run again — should be a no-op.
|
||||||
|
if err := Migrate(db); err != nil {
|
||||||
|
t.Errorf("second Migrate call returned error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndGetAccount(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
|
||||||
|
acct, err := db.CreateAccount("alice", model.AccountTypeHuman, "$argon2id$v=19$...")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
if acct.UUID == "" {
|
||||||
|
t.Error("expected non-empty UUID")
|
||||||
|
}
|
||||||
|
if acct.Username != "alice" {
|
||||||
|
t.Errorf("Username = %q, want %q", acct.Username, "alice")
|
||||||
|
}
|
||||||
|
if acct.Status != model.AccountStatusActive {
|
||||||
|
t.Errorf("Status = %q, want active", acct.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve by UUID.
|
||||||
|
got, err := db.GetAccountByUUID(acct.UUID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAccountByUUID: %v", err)
|
||||||
|
}
|
||||||
|
if got.Username != "alice" {
|
||||||
|
t.Errorf("fetched Username = %q, want %q", got.Username, "alice")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve by username.
|
||||||
|
got2, err := db.GetAccountByUsername("alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAccountByUsername: %v", err)
|
||||||
|
}
|
||||||
|
if got2.UUID != acct.UUID {
|
||||||
|
t.Errorf("UUID mismatch: got %q, want %q", got2.UUID, acct.UUID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountNotFound(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
|
||||||
|
_, err := db.GetAccountByUUID("nonexistent-uuid")
|
||||||
|
if err != ErrNotFound {
|
||||||
|
t.Errorf("expected ErrNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.GetAccountByUsername("nobody")
|
||||||
|
if err != ErrNotFound {
|
||||||
|
t.Errorf("expected ErrNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateAccountStatus(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
acct, err := db.CreateAccount("bob", model.AccountTypeHuman, "hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.UpdateAccountStatus(acct.ID, model.AccountStatusInactive); err != nil {
|
||||||
|
t.Fatalf("UpdateAccountStatus: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := db.GetAccountByUUID(acct.UUID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAccountByUUID: %v", err)
|
||||||
|
}
|
||||||
|
if got.Status != model.AccountStatusInactive {
|
||||||
|
t.Errorf("Status = %q, want inactive", got.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListAccounts(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
for _, name := range []string{"charlie", "delta", "eve"} {
|
||||||
|
if _, err := db.CreateAccount(name, model.AccountTypeHuman, "hash"); err != nil {
|
||||||
|
t.Fatalf("CreateAccount %q: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
accts, err := db.ListAccounts()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListAccounts: %v", err)
|
||||||
|
}
|
||||||
|
if len(accts) != 3 {
|
||||||
|
t.Errorf("ListAccounts returned %d accounts, want 3", len(accts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoleOperations(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
acct, err := db.CreateAccount("frank", model.AccountTypeHuman, "hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrantRole
|
||||||
|
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
|
||||||
|
t.Fatalf("GrantRole: %v", err)
|
||||||
|
}
|
||||||
|
// Grant again — should be no-op.
|
||||||
|
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
|
||||||
|
t.Fatalf("GrantRole duplicate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
roles, err := db.GetRoles(acct.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRoles: %v", err)
|
||||||
|
}
|
||||||
|
if len(roles) != 1 || roles[0] != "admin" {
|
||||||
|
t.Errorf("GetRoles = %v, want [admin]", roles)
|
||||||
|
}
|
||||||
|
|
||||||
|
has, err := db.HasRole(acct.ID, "admin")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HasRole: %v", err)
|
||||||
|
}
|
||||||
|
if !has {
|
||||||
|
t.Error("expected HasRole to return true for 'admin'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeRole
|
||||||
|
if err := db.RevokeRole(acct.ID, "admin"); err != nil {
|
||||||
|
t.Fatalf("RevokeRole: %v", err)
|
||||||
|
}
|
||||||
|
roles, err = db.GetRoles(acct.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRoles after revoke: %v", err)
|
||||||
|
}
|
||||||
|
if len(roles) != 0 {
|
||||||
|
t.Errorf("expected no roles after revoke, got %v", roles)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRoles
|
||||||
|
if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
|
||||||
|
t.Fatalf("SetRoles: %v", err)
|
||||||
|
}
|
||||||
|
roles, err = db.GetRoles(acct.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRoles after SetRoles: %v", err)
|
||||||
|
}
|
||||||
|
if len(roles) != 2 {
|
||||||
|
t.Errorf("expected 2 roles after SetRoles, got %d", len(roles))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenTrackingAndRevocation(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
acct, err := db.CreateAccount("grace", model.AccountTypeHuman, "hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jti := "test-jti-1234"
|
||||||
|
issuedAt := time.Now().UTC()
|
||||||
|
expiresAt := issuedAt.Add(time.Hour)
|
||||||
|
|
||||||
|
if err := db.TrackToken(jti, acct.ID, issuedAt, expiresAt); err != nil {
|
||||||
|
t.Fatalf("TrackToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve
|
||||||
|
rec, err := db.GetTokenRecord(jti)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTokenRecord: %v", err)
|
||||||
|
}
|
||||||
|
if rec.JTI != jti {
|
||||||
|
t.Errorf("JTI = %q, want %q", rec.JTI, jti)
|
||||||
|
}
|
||||||
|
if rec.IsRevoked() {
|
||||||
|
t.Error("newly tracked token should not be revoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke
|
||||||
|
if err := db.RevokeToken(jti, "test revocation"); err != nil {
|
||||||
|
t.Fatalf("RevokeToken: %v", err)
|
||||||
|
}
|
||||||
|
rec, err = db.GetTokenRecord(jti)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTokenRecord after revoke: %v", err)
|
||||||
|
}
|
||||||
|
if !rec.IsRevoked() {
|
||||||
|
t.Error("token should be revoked after RevokeToken")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoking again should fail (already revoked).
|
||||||
|
if err := db.RevokeToken(jti, "again"); err == nil {
|
||||||
|
t.Error("expected error when revoking already-revoked token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTokenRecordNotFound(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
_, err := db.GetTokenRecord("no-such-jti")
|
||||||
|
if err != ErrNotFound {
|
||||||
|
t.Errorf("expected ErrNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPruneExpiredTokens(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
acct, err := db.CreateAccount("henry", model.AccountTypeHuman, "hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
past := time.Now().UTC().Add(-time.Hour)
|
||||||
|
future := time.Now().UTC().Add(time.Hour)
|
||||||
|
|
||||||
|
if err := db.TrackToken("expired-jti", acct.ID, past.Add(-time.Hour), past); err != nil {
|
||||||
|
t.Fatalf("TrackToken expired: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.TrackToken("valid-jti", acct.ID, time.Now(), future); err != nil {
|
||||||
|
t.Fatalf("TrackToken valid: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := db.PruneExpiredTokens()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PruneExpiredTokens: %v", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Errorf("pruned %d rows, want 1", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid token should still be retrievable.
|
||||||
|
if _, err := db.GetTokenRecord("valid-jti"); err != nil {
|
||||||
|
t.Errorf("valid token missing after prune: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerConfig(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
|
||||||
|
// No config initially.
|
||||||
|
_, _, err := db.ReadServerConfig()
|
||||||
|
if err != ErrNotFound {
|
||||||
|
t.Errorf("expected ErrNotFound for missing config, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := []byte("encrypted-key-data")
|
||||||
|
nonce := []byte("nonce12345678901")
|
||||||
|
|
||||||
|
if err := db.WriteServerConfig(enc, nonce); err != nil {
|
||||||
|
t.Fatalf("WriteServerConfig: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotEnc, gotNonce, err := db.ReadServerConfig()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadServerConfig: %v", err)
|
||||||
|
}
|
||||||
|
if string(gotEnc) != string(enc) {
|
||||||
|
t.Errorf("enc mismatch: got %q, want %q", gotEnc, enc)
|
||||||
|
}
|
||||||
|
if string(gotNonce) != string(nonce) {
|
||||||
|
t.Errorf("nonce mismatch: got %q, want %q", gotNonce, nonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overwrite — should work without error.
|
||||||
|
if err := db.WriteServerConfig([]byte("new-key"), []byte("new-nonce123456")); err != nil {
|
||||||
|
t.Fatalf("WriteServerConfig overwrite: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForeignKeyEnforcement(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
// Attempting to track a token for a non-existent account should fail.
|
||||||
|
err := db.TrackToken("jti-x", 999999, time.Now(), time.Now().Add(time.Hour))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected foreign key error for non-existent account_id, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPGCredentials(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := []byte("encrypted-pg-password")
|
||||||
|
nonce := []byte("pg-nonce12345678")
|
||||||
|
|
||||||
|
if err := db.WritePGCredentials(acct.ID, "localhost", 5432, "mydb", "myuser", enc, nonce); err != nil {
|
||||||
|
t.Fatalf("WritePGCredentials: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cred, err := db.ReadPGCredentials(acct.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadPGCredentials: %v", err)
|
||||||
|
}
|
||||||
|
if cred.PGHost != "localhost" {
|
||||||
|
t.Errorf("PGHost = %q, want %q", cred.PGHost, "localhost")
|
||||||
|
}
|
||||||
|
if cred.PGDatabase != "mydb" {
|
||||||
|
t.Errorf("PGDatabase = %q, want %q", cred.PGDatabase, "mydb")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevokeAllUserTokens(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
acct, err := db.CreateAccount("ivan", model.AccountTypeHuman, "hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
future := time.Now().UTC().Add(time.Hour)
|
||||||
|
for _, jti := range []string{"tok1", "tok2", "tok3"} {
|
||||||
|
if err := db.TrackToken(jti, acct.ID, time.Now(), future); err != nil {
|
||||||
|
t.Fatalf("TrackToken %q: %v", jti, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.RevokeAllUserTokens(acct.ID, "account suspended"); err != nil {
|
||||||
|
t.Fatalf("RevokeAllUserTokens: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, jti := range []string{"tok1", "tok2", "tok3"} {
|
||||||
|
rec, err := db.GetTokenRecord(jti)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTokenRecord %q: %v", jti, err)
|
||||||
|
}
|
||||||
|
if !rec.IsRevoked() {
|
||||||
|
t.Errorf("token %q should be revoked", jti)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
187
internal/db/migrate.go
Normal file
187
internal/db/migrate.go
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// migration represents a single schema migration with an ID and SQL statement.
|
||||||
|
type migration struct {
|
||||||
|
id int
|
||||||
|
sql string
|
||||||
|
}
|
||||||
|
|
||||||
|
// migrations is the ordered list of schema migrations applied to the database.
|
||||||
|
// Once applied, migrations must never be modified — only new ones appended.
|
||||||
|
var migrations = []migration{
|
||||||
|
{
|
||||||
|
id: 1,
|
||||||
|
sql: `
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
|
version INTEGER NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS server_config (
|
||||||
|
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||||
|
signing_key_enc BLOB,
|
||||||
|
signing_key_nonce BLOB,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS accounts (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
uuid TEXT NOT NULL UNIQUE,
|
||||||
|
username TEXT NOT NULL UNIQUE COLLATE NOCASE,
|
||||||
|
account_type TEXT NOT NULL CHECK (account_type IN ('human','system')),
|
||||||
|
password_hash TEXT,
|
||||||
|
status TEXT NOT NULL DEFAULT 'active'
|
||||||
|
CHECK (status IN ('active','inactive','deleted')),
|
||||||
|
totp_required INTEGER NOT NULL DEFAULT 0 CHECK (totp_required IN (0,1)),
|
||||||
|
totp_secret_enc BLOB,
|
||||||
|
totp_secret_nonce BLOB,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
deleted_at TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts (username);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_accounts_uuid ON accounts (uuid);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts (status);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS account_roles (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
granted_by INTEGER REFERENCES accounts(id),
|
||||||
|
granted_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
UNIQUE (account_id, role)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_account_roles_account ON account_roles (account_id);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS token_revocation (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
jti TEXT NOT NULL UNIQUE,
|
||||||
|
account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||||
|
expires_at TEXT NOT NULL,
|
||||||
|
revoked_at TEXT,
|
||||||
|
revoke_reason TEXT,
|
||||||
|
issued_at TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_token_jti ON token_revocation (jti);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_token_account ON token_revocation (account_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_token_expires ON token_revocation (expires_at);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS system_tokens (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
|
||||||
|
jti TEXT NOT NULL UNIQUE,
|
||||||
|
expires_at TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS pg_credentials (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
|
||||||
|
pg_host TEXT NOT NULL,
|
||||||
|
pg_port INTEGER NOT NULL DEFAULT 5432,
|
||||||
|
pg_database TEXT NOT NULL,
|
||||||
|
pg_username TEXT NOT NULL,
|
||||||
|
pg_password_enc BLOB NOT NULL,
|
||||||
|
pg_password_nonce BLOB NOT NULL,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS audit_log (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
event_time TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||||
|
event_type TEXT NOT NULL,
|
||||||
|
actor_id INTEGER REFERENCES accounts(id),
|
||||||
|
target_id INTEGER REFERENCES accounts(id),
|
||||||
|
ip_address TEXT,
|
||||||
|
details TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_audit_time ON audit_log (event_time);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_audit_actor ON audit_log (actor_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log (event_type);
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 2,
|
||||||
|
sql: `
|
||||||
|
-- Add master_key_salt to server_config for Argon2id KDF salt storage.
|
||||||
|
-- The salt must be stable across restarts so the passphrase always yields the same key.
|
||||||
|
-- We allow NULL signing_key_enc/nonce temporarily until the first signing key is generated.
|
||||||
|
ALTER TABLE server_config ADD COLUMN master_key_salt BLOB;
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Migrate applies any unapplied schema migrations to the database in order.
|
||||||
|
// It is idempotent: running it multiple times is safe.
|
||||||
|
func Migrate(db *DB) error {
|
||||||
|
// Ensure the schema_version table exists first.
|
||||||
|
if _, err := db.sql.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
|
version INTEGER NOT NULL
|
||||||
|
)
|
||||||
|
`); err != nil {
|
||||||
|
return fmt.Errorf("db: ensure schema_version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentVersion, err := currentSchemaVersion(db.sql)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: get current schema version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range migrations {
|
||||||
|
if m.id <= currentVersion {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := db.sql.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: begin migration %d transaction: %w", m.id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(m.sql); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: apply migration %d: %w", m.id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the schema version within the same transaction.
|
||||||
|
if currentVersion == 0 {
|
||||||
|
if _, err := tx.Exec(`INSERT INTO schema_version (version) VALUES (?)`, m.id); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: insert schema version %d: %w", m.id, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := tx.Exec(`UPDATE schema_version SET version = ?`, m.id); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: update schema version to %d: %w", m.id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("db: commit migration %d: %w", m.id, err)
|
||||||
|
}
|
||||||
|
currentVersion = m.id
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// currentSchemaVersion returns the current schema version, or 0 if none applied.
|
||||||
|
func currentSchemaVersion(db *sql.DB) (int, error) {
|
||||||
|
var version int
|
||||||
|
err := db.QueryRow(`SELECT version FROM schema_version LIMIT 1`).Scan(&version)
|
||||||
|
if err != nil {
|
||||||
|
// No rows means version 0 (fresh database).
|
||||||
|
return 0, nil //nolint:nilerr
|
||||||
|
}
|
||||||
|
return version, nil
|
||||||
|
}
|
||||||
290
internal/middleware/middleware.go
Normal file
290
internal/middleware/middleware.go
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
// Package middleware provides HTTP middleware for the MCIAS server.
|
||||||
|
//
|
||||||
|
// Security design:
|
||||||
|
// - RequireAuth extracts the Bearer token from the Authorization header,
|
||||||
|
// validates it (alg check, signature, expiry, issuer), and checks revocation
|
||||||
|
// against the database before injecting claims into the request context.
|
||||||
|
// - RequireRole checks claims from context for the required role.
|
||||||
|
// No role implies no access; the check fails closed.
|
||||||
|
// - RateLimit implements a per-IP token bucket to limit login brute-force.
|
||||||
|
// - RequestLogger logs request metadata but never logs the Authorization
|
||||||
|
// header value (which contains credential tokens).
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// contextKey is the unexported type for context keys in this package, preventing
|
||||||
|
// collisions with keys from other packages.
|
||||||
|
type contextKey int
|
||||||
|
|
||||||
|
const (
|
||||||
|
claimsKey contextKey = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClaimsFromContext retrieves the validated JWT claims from the request context.
|
||||||
|
// Returns nil if no claims are present (unauthenticated request).
|
||||||
|
func ClaimsFromContext(ctx context.Context) *token.Claims {
|
||||||
|
c, _ := ctx.Value(claimsKey).(*token.Claims)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestLogger returns middleware that logs each request at INFO level.
|
||||||
|
// The Authorization header is intentionally never logged.
|
||||||
|
func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
start := time.Now()
|
||||||
|
// Wrap the ResponseWriter to capture the status code.
|
||||||
|
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
|
||||||
|
next.ServeHTTP(rw, r)
|
||||||
|
|
||||||
|
logger.Info("request",
|
||||||
|
"method", r.Method,
|
||||||
|
"path", r.URL.Path,
|
||||||
|
"status", rw.status,
|
||||||
|
"duration_ms", time.Since(start).Milliseconds(),
|
||||||
|
"remote_addr", r.RemoteAddr,
|
||||||
|
"user_agent", r.UserAgent(),
|
||||||
|
// Security: Authorization header is never logged.
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseWriter wraps http.ResponseWriter to capture the status code.
|
||||||
|
type responseWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
status int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWriter) WriteHeader(code int) {
|
||||||
|
rw.status = code
|
||||||
|
rw.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireAuth returns middleware that validates a Bearer JWT and injects the
|
||||||
|
// claims into the request context. Returns 401 on any auth failure.
|
||||||
|
//
|
||||||
|
// Security: Token validation order:
|
||||||
|
// 1. Extract Bearer token from Authorization header.
|
||||||
|
// 2. Validate the JWT (alg=EdDSA, signature, expiry, issuer).
|
||||||
|
// 3. Check the JTI against the revocation table in the database.
|
||||||
|
// 4. Inject validated claims into context for downstream handlers.
|
||||||
|
func RequireAuth(pubKey ed25519.PublicKey, database *db.DB, issuer string) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tokenStr, err := extractBearerToken(r)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusUnauthorized, "missing or malformed Authorization header", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := token.ValidateToken(pubKey, tokenStr, issuer)
|
||||||
|
if err != nil {
|
||||||
|
// Security: Map all token errors to a generic 401; do not
|
||||||
|
// reveal which specific check failed.
|
||||||
|
if errors.Is(err, token.ErrExpiredToken) {
|
||||||
|
writeError(w, http.StatusUnauthorized, "token expired", "token_expired")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeError(w, http.StatusUnauthorized, "invalid token", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security: Check revocation table. A token may be cryptographically
|
||||||
|
// valid but explicitly revoked (logout, account suspension, etc.).
|
||||||
|
rec, err := database.GetTokenRecord(claims.JTI)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
|
// Token not tracked — could be from a different server instance
|
||||||
|
// or pre-dates tracking. Reject to be safe (fail closed).
|
||||||
|
writeError(w, http.StatusUnauthorized, "unrecognized token", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rec.IsRevoked() {
|
||||||
|
writeError(w, http.StatusUnauthorized, "token has been revoked", "token_revoked")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), claimsKey, claims)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireRole returns middleware that checks whether the authenticated user has
|
||||||
|
// the given role. Must be used after RequireAuth. Returns 403 if role is absent.
|
||||||
|
func RequireRole(role string) func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := ClaimsFromContext(r.Context())
|
||||||
|
if claims == nil {
|
||||||
|
// RequireAuth was not applied upstream; fail closed.
|
||||||
|
writeError(w, http.StatusForbidden, "forbidden", "forbidden")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !claims.HasRole(role) {
|
||||||
|
writeError(w, http.StatusForbidden, "insufficient privileges", "forbidden")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rateLimitEntry holds the token bucket state for a single IP.
|
||||||
|
type rateLimitEntry struct {
|
||||||
|
tokens float64
|
||||||
|
lastSeen time.Time
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipRateLimiter implements a per-IP token bucket rate limiter.
|
||||||
|
type ipRateLimiter struct {
|
||||||
|
rps float64 // refill rate: tokens per second
|
||||||
|
burst float64 // bucket capacity
|
||||||
|
ttl time.Duration // how long to keep idle entries
|
||||||
|
mu sync.Mutex
|
||||||
|
ips map[string]*rateLimitEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimit returns middleware implementing a per-IP token bucket.
|
||||||
|
// rps is the sustained request rate (tokens refilled per second).
|
||||||
|
// burst is the maximum burst size (initial and maximum token count).
|
||||||
|
//
|
||||||
|
// Security: Rate limiting is applied at the IP level. In production, the
|
||||||
|
// server should be behind a reverse proxy that sets X-Forwarded-For; this
|
||||||
|
// middleware uses RemoteAddr directly which may be the proxy IP. For single-
|
||||||
|
// instance deployment without a proxy, RemoteAddr is the client IP.
|
||||||
|
func RateLimit(rps float64, burst int) func(http.Handler) http.Handler {
|
||||||
|
limiter := &ipRateLimiter{
|
||||||
|
rps: rps,
|
||||||
|
burst: float64(burst),
|
||||||
|
ttl: 10 * time.Minute,
|
||||||
|
ips: make(map[string]*rateLimitEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Background cleanup of idle entries to prevent unbounded memory growth.
|
||||||
|
go limiter.cleanup()
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
ip = r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
if !limiter.allow(ip) {
|
||||||
|
w.Header().Set("Retry-After", "60")
|
||||||
|
writeError(w, http.StatusTooManyRequests, "rate limit exceeded", "rate_limited")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow returns true if a request from ip is permitted under the rate limit.
|
||||||
|
func (l *ipRateLimiter) allow(ip string) bool {
|
||||||
|
l.mu.Lock()
|
||||||
|
entry, ok := l.ips[ip]
|
||||||
|
if !ok {
|
||||||
|
entry = &rateLimitEntry{tokens: l.burst, lastSeen: time.Now()}
|
||||||
|
l.ips[ip] = entry
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
entry.mu.Lock()
|
||||||
|
defer entry.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
elapsed := now.Sub(entry.lastSeen).Seconds()
|
||||||
|
entry.tokens = min(l.burst, entry.tokens+elapsed*l.rps)
|
||||||
|
entry.lastSeen = now
|
||||||
|
|
||||||
|
if entry.tokens < 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
entry.tokens--
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup periodically removes idle rate-limit entries.
|
||||||
|
func (l *ipRateLimiter) cleanup() {
|
||||||
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
l.mu.Lock()
|
||||||
|
cutoff := time.Now().Add(-l.ttl)
|
||||||
|
for ip, entry := range l.ips {
|
||||||
|
entry.mu.Lock()
|
||||||
|
if entry.lastSeen.Before(cutoff) {
|
||||||
|
delete(l.ips, ip)
|
||||||
|
}
|
||||||
|
entry.mu.Unlock()
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractBearerToken extracts the token from "Authorization: Bearer <token>".
|
||||||
|
func extractBearerToken(r *http.Request) (string, error) {
|
||||||
|
auth := r.Header.Get("Authorization")
|
||||||
|
if auth == "" {
|
||||||
|
return "", fmt.Errorf("missing Authorization header")
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(auth, " ", 2)
|
||||||
|
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||||
|
return "", fmt.Errorf("malformed Authorization header")
|
||||||
|
}
|
||||||
|
if parts[1] == "" {
|
||||||
|
return "", fmt.Errorf("empty Bearer token")
|
||||||
|
}
|
||||||
|
return parts[1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// apiError is the uniform error response structure.
|
||||||
|
type apiError struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
Code string `json:"code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeError writes a JSON error response.
|
||||||
|
func writeError(w http.ResponseWriter, status int, message, code string) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
// Intentionally ignoring the error here; if the write fails, the client
|
||||||
|
// already got the status code.
|
||||||
|
_ = json.NewEncoder(w).Encode(apiError{Error: message, Code: code})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteError is the exported version for use by handler packages.
|
||||||
|
func WriteError(w http.ResponseWriter, status int, message, code string) {
|
||||||
|
writeError(w, status, message, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// min returns the smaller of two float64 values.
|
||||||
|
func min(a, b float64) float64 {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
342
internal/middleware/middleware_test.go
Normal file
342
internal/middleware/middleware_test.go
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||||
|
)
|
||||||
|
|
||||||
|
func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
|
||||||
|
t.Helper()
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate test key: %v", err)
|
||||||
|
}
|
||||||
|
return pub, priv
|
||||||
|
}
|
||||||
|
|
||||||
|
func openTestDB(t *testing.T) *db.DB {
|
||||||
|
t.Helper()
|
||||||
|
database, err := db.Open(":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open test db: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(database); err != nil {
|
||||||
|
t.Fatalf("migrate test db: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = database.Close() })
|
||||||
|
return database
|
||||||
|
}
|
||||||
|
|
||||||
|
const testIssuer = "https://auth.example.com"
|
||||||
|
|
||||||
|
// issueAndTrackToken creates a valid JWT and records it in the DB.
|
||||||
|
func issueAndTrackToken(t *testing.T, priv ed25519.PrivateKey, database *db.DB, accountID int64, roles []string) string {
|
||||||
|
t.Helper()
|
||||||
|
tokenStr, claims, err := token.IssueToken(priv, testIssuer, "user-uuid", roles, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
if err := database.TrackToken(claims.JTI, accountID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||||
|
t.Fatalf("TrackToken: %v", err)
|
||||||
|
}
|
||||||
|
return tokenStr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestLogger(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
|
||||||
|
handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
logOutput := buf.String()
|
||||||
|
if logOutput == "" {
|
||||||
|
t.Error("expected log output, got empty string")
|
||||||
|
}
|
||||||
|
// Security: Authorization header must not appear in logs.
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
|
||||||
|
req2.Header.Set("Authorization", "Bearer secret-token-value")
|
||||||
|
buf.Reset()
|
||||||
|
rr2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr2, req2)
|
||||||
|
if bytes.Contains(buf.Bytes(), []byte("secret-token-value")) {
|
||||||
|
t.Error("log output contains Authorization token value — credential leak!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAuthValid(t *testing.T) {
|
||||||
|
pub, priv := generateTestKey(t)
|
||||||
|
database := openTestDB(t)
|
||||||
|
|
||||||
|
acct, err := database.CreateAccount("alice", model.AccountTypeHuman, "hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenStr := issueAndTrackToken(t, priv, database, acct.ID, []string{"reader"})
|
||||||
|
|
||||||
|
reached := false
|
||||||
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
reached = true
|
||||||
|
claims := ClaimsFromContext(r.Context())
|
||||||
|
if claims == nil {
|
||||||
|
t.Error("claims not in context")
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
if !reached {
|
||||||
|
t.Error("handler was not reached with valid token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAuthMissingHeader(t *testing.T) {
|
||||||
|
pub, priv := generateTestKey(t)
|
||||||
|
_ = priv
|
||||||
|
database := openTestDB(t)
|
||||||
|
|
||||||
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("handler should not be reached without auth")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAuthInvalidToken(t *testing.T) {
|
||||||
|
pub, _ := generateTestKey(t)
|
||||||
|
database := openTestDB(t)
|
||||||
|
|
||||||
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("handler should not be reached with invalid token")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer not.a.valid.jwt")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAuthRevokedToken(t *testing.T) {
|
||||||
|
pub, priv := generateTestKey(t)
|
||||||
|
database := openTestDB(t)
|
||||||
|
|
||||||
|
acct, err := database.CreateAccount("bob", model.AccountTypeHuman, "hash")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateAccount: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenStr := issueAndTrackToken(t, priv, database, acct.ID, nil)
|
||||||
|
|
||||||
|
// Extract JTI and revoke the token.
|
||||||
|
claims, err := token.ValidateToken(pub, tokenStr, testIssuer)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ValidateToken: %v", err)
|
||||||
|
}
|
||||||
|
if err := database.RevokeToken(claims.JTI, "test revocation"); err != nil {
|
||||||
|
t.Fatalf("RevokeToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("handler should not be reached with revoked token")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAuthExpiredToken(t *testing.T) {
|
||||||
|
pub, priv := generateTestKey(t)
|
||||||
|
database := openTestDB(t)
|
||||||
|
|
||||||
|
// Issue an already-expired token.
|
||||||
|
tokenStr, _, err := token.IssueToken(priv, testIssuer, "user-uuid", nil, -time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("handler should not be reached with expired token")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireRoleGranted(t *testing.T) {
|
||||||
|
claims := &token.Claims{Roles: []string{"admin"}}
|
||||||
|
ctx := context.WithValue(context.Background(), claimsKey, claims)
|
||||||
|
|
||||||
|
reached := false
|
||||||
|
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
reached = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
if !reached {
|
||||||
|
t.Error("handler not reached with correct role")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireRoleForbidden(t *testing.T) {
|
||||||
|
claims := &token.Claims{Roles: []string{"reader"}}
|
||||||
|
ctx := context.WithValue(context.Background(), claimsKey, claims)
|
||||||
|
|
||||||
|
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("handler should not be reached without admin role")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("status = %d, want 403", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireRoleNoClaims(t *testing.T) {
|
||||||
|
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("handler should not be reached without claims in context")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("status = %d, want 403", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitAllows(t *testing.T) {
|
||||||
|
handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
|
||||||
|
req.RemoteAddr = "127.0.0.1:12345"
|
||||||
|
|
||||||
|
// First 5 requests should be allowed (burst=5).
|
||||||
|
for i := range 5 {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("request %d: status = %d, want 200", i+1, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitBlocks(t *testing.T) {
|
||||||
|
handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
|
||||||
|
req.RemoteAddr = "10.0.0.1:9999"
|
||||||
|
|
||||||
|
// Exhaust the burst of 2.
|
||||||
|
for range 2 {
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next request should be rate-limited.
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("status = %d, want 429 after burst exceeded", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractBearerToken(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
header string
|
||||||
|
wantErr bool
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"valid", "Bearer mytoken123", false, "mytoken123"},
|
||||||
|
{"missing header", "", true, ""},
|
||||||
|
{"no bearer prefix", "Token mytoken123", true, ""},
|
||||||
|
{"empty token", "Bearer ", true, ""},
|
||||||
|
{"case insensitive", "bearer mytoken123", false, "mytoken123"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if tc.header != "" {
|
||||||
|
req.Header.Set("Authorization", tc.header)
|
||||||
|
}
|
||||||
|
got, err := extractBearerToken(req)
|
||||||
|
if (err != nil) != tc.wantErr {
|
||||||
|
t.Errorf("wantErr=%v, got err=%v", tc.wantErr, err)
|
||||||
|
}
|
||||||
|
if !tc.wantErr && got != tc.want {
|
||||||
|
t.Errorf("token = %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
144
internal/model/model.go
Normal file
144
internal/model/model.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
// Package model defines the shared data types used throughout MCIAS.
|
||||||
|
// These are pure data definitions with no external dependencies.
|
||||||
|
package model
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// AccountType distinguishes human interactive accounts from non-interactive
|
||||||
|
// service accounts.
|
||||||
|
type AccountType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AccountTypeHuman AccountType = "human"
|
||||||
|
AccountTypeSystem AccountType = "system"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccountStatus represents the lifecycle state of an account.
|
||||||
|
type AccountStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AccountStatusActive AccountStatus = "active"
|
||||||
|
AccountStatusInactive AccountStatus = "inactive"
|
||||||
|
AccountStatusDeleted AccountStatus = "deleted"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Account represents a user or service identity in MCIAS.
|
||||||
|
// Fields containing credential material (PasswordHash, TOTPSecretEnc) are
|
||||||
|
// never serialised into API responses — callers must explicitly omit them.
|
||||||
|
type Account struct {
|
||||||
|
ID int64 `json:"-"`
|
||||||
|
UUID string `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
AccountType AccountType `json:"account_type"`
|
||||||
|
Status AccountStatus `json:"status"`
|
||||||
|
TOTPRequired bool `json:"totp_required"`
|
||||||
|
|
||||||
|
// PasswordHash is a PHC-format Argon2id string. Never returned in API
|
||||||
|
// responses; populated only when reading from the database.
|
||||||
|
PasswordHash string `json:"-"`
|
||||||
|
|
||||||
|
// TOTPSecretEnc and TOTPSecretNonce hold the AES-256-GCM-encrypted TOTP
|
||||||
|
// shared secret. Never returned in API responses.
|
||||||
|
TOTPSecretEnc []byte `json:"-"`
|
||||||
|
TOTPSecretNonce []byte `json:"-"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
DeletedAt *time.Time `json:"deleted_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Role is a string label assigned to an account to grant permissions.
|
||||||
|
type Role struct {
|
||||||
|
ID int64 `json:"-"`
|
||||||
|
AccountID int64 `json:"-"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
GrantedBy *int64 `json:"-"`
|
||||||
|
GrantedAt time.Time `json:"granted_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRecord tracks an issued JWT by its JTI for revocation purposes.
|
||||||
|
// The raw token string is never stored — only the JTI identifier.
|
||||||
|
type TokenRecord struct {
|
||||||
|
ID int64 `json:"-"`
|
||||||
|
JTI string `json:"jti"`
|
||||||
|
AccountID int64 `json:"-"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
IssuedAt time.Time `json:"issued_at"`
|
||||||
|
RevokedAt *time.Time `json:"revoked_at,omitempty"`
|
||||||
|
RevokeReason string `json:"revoke_reason,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRevoked reports whether the token has been explicitly revoked.
|
||||||
|
func (t *TokenRecord) IsRevoked() bool {
|
||||||
|
return t.RevokedAt != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpired reports whether the token is past its expiry time.
|
||||||
|
func (t *TokenRecord) IsExpired() bool {
|
||||||
|
return time.Now().After(t.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemToken represents the current active service token for a system account.
|
||||||
|
type SystemToken struct {
|
||||||
|
ID int64 `json:"-"`
|
||||||
|
AccountID int64 `json:"-"`
|
||||||
|
JTI string `json:"jti"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PGCredential holds Postgres connection details for a system account.
|
||||||
|
// The password is encrypted at rest; PGPassword is only populated after
|
||||||
|
// decryption and must never be logged or included in API responses.
|
||||||
|
type PGCredential struct {
|
||||||
|
ID int64 `json:"-"`
|
||||||
|
AccountID int64 `json:"-"`
|
||||||
|
PGHost string `json:"host"`
|
||||||
|
PGPort int `json:"port"`
|
||||||
|
PGDatabase string `json:"database"`
|
||||||
|
PGUsername string `json:"username"`
|
||||||
|
|
||||||
|
// PGPassword is plaintext only after decryption. Never log or serialise.
|
||||||
|
PGPassword string `json:"-"`
|
||||||
|
|
||||||
|
// PGPasswordEnc and PGPasswordNonce are the AES-256-GCM ciphertext and
|
||||||
|
// nonce stored in the database.
|
||||||
|
PGPasswordEnc []byte `json:"-"`
|
||||||
|
PGPasswordNonce []byte `json:"-"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditEvent represents a single entry in the append-only audit log.
|
||||||
|
// Details must never contain credential material (passwords, tokens, secrets).
|
||||||
|
type AuditEvent struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
EventTime time.Time `json:"event_time"`
|
||||||
|
EventType string `json:"event_type"`
|
||||||
|
ActorID *int64 `json:"-"`
|
||||||
|
TargetID *int64 `json:"-"`
|
||||||
|
IPAddress string `json:"ip_address,omitempty"`
|
||||||
|
Details string `json:"details,omitempty"` // JSON string; no secrets
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audit event type constants — exhaustive list, enforced at write time.
|
||||||
|
const (
|
||||||
|
EventLoginOK = "login_ok"
|
||||||
|
EventLoginFail = "login_fail"
|
||||||
|
EventLoginTOTPFail = "login_totp_fail"
|
||||||
|
EventTokenIssued = "token_issued"
|
||||||
|
EventTokenRenewed = "token_renewed"
|
||||||
|
EventTokenRevoked = "token_revoked"
|
||||||
|
EventTokenExpired = "token_expired"
|
||||||
|
EventAccountCreated = "account_created"
|
||||||
|
EventAccountUpdated = "account_updated"
|
||||||
|
EventAccountDeleted = "account_deleted"
|
||||||
|
EventRoleGranted = "role_granted"
|
||||||
|
EventRoleRevoked = "role_revoked"
|
||||||
|
EventTOTPEnrolled = "totp_enrolled"
|
||||||
|
EventTOTPRemoved = "totp_removed"
|
||||||
|
EventPGCredAccessed = "pgcred_accessed"
|
||||||
|
EventPGCredUpdated = "pgcred_updated"
|
||||||
|
)
|
||||||
83
internal/model/model_test.go
Normal file
83
internal/model/model_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAccountTypeConstants(t *testing.T) {
|
||||||
|
if AccountTypeHuman != "human" {
|
||||||
|
t.Errorf("AccountTypeHuman = %q, want %q", AccountTypeHuman, "human")
|
||||||
|
}
|
||||||
|
if AccountTypeSystem != "system" {
|
||||||
|
t.Errorf("AccountTypeSystem = %q, want %q", AccountTypeSystem, "system")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountStatusConstants(t *testing.T) {
|
||||||
|
if AccountStatusActive != "active" {
|
||||||
|
t.Errorf("AccountStatusActive = %q, want %q", AccountStatusActive, "active")
|
||||||
|
}
|
||||||
|
if AccountStatusInactive != "inactive" {
|
||||||
|
t.Errorf("AccountStatusInactive = %q, want %q", AccountStatusInactive, "inactive")
|
||||||
|
}
|
||||||
|
if AccountStatusDeleted != "deleted" {
|
||||||
|
t.Errorf("AccountStatusDeleted = %q, want %q", AccountStatusDeleted, "deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenRecordIsRevoked(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
notRevoked := &TokenRecord{}
|
||||||
|
if notRevoked.IsRevoked() {
|
||||||
|
t.Error("expected token with nil RevokedAt to not be revoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
revoked := &TokenRecord{RevokedAt: &now}
|
||||||
|
if !revoked.IsRevoked() {
|
||||||
|
t.Error("expected token with RevokedAt set to be revoked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenRecordIsExpired(t *testing.T) {
|
||||||
|
past := time.Now().Add(-time.Hour)
|
||||||
|
future := time.Now().Add(time.Hour)
|
||||||
|
|
||||||
|
expired := &TokenRecord{ExpiresAt: past}
|
||||||
|
if !expired.IsExpired() {
|
||||||
|
t.Error("expected token with past ExpiresAt to be expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
valid := &TokenRecord{ExpiresAt: future}
|
||||||
|
if valid.IsExpired() {
|
||||||
|
t.Error("expected token with future ExpiresAt to not be expired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditEventConstants(t *testing.T) {
|
||||||
|
// Spot-check a few to ensure they are not empty strings.
|
||||||
|
events := []string{
|
||||||
|
EventLoginOK,
|
||||||
|
EventLoginFail,
|
||||||
|
EventLoginTOTPFail,
|
||||||
|
EventTokenIssued,
|
||||||
|
EventTokenRenewed,
|
||||||
|
EventTokenRevoked,
|
||||||
|
EventTokenExpired,
|
||||||
|
EventAccountCreated,
|
||||||
|
EventAccountUpdated,
|
||||||
|
EventAccountDeleted,
|
||||||
|
EventRoleGranted,
|
||||||
|
EventRoleRevoked,
|
||||||
|
EventTOTPEnrolled,
|
||||||
|
EventTOTPRemoved,
|
||||||
|
EventPGCredAccessed,
|
||||||
|
EventPGCredUpdated,
|
||||||
|
}
|
||||||
|
for _, e := range events {
|
||||||
|
if e == "" {
|
||||||
|
t.Errorf("audit event constant is empty string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
904
internal/server/server.go
Normal file
904
internal/server/server.go
Normal file
@@ -0,0 +1,904 @@
|
|||||||
|
// Package server wires together the HTTP router, middleware, and handlers
|
||||||
|
// for the MCIAS authentication server.
|
||||||
|
//
|
||||||
|
// Security design:
|
||||||
|
// - All endpoints use HTTPS (enforced at the listener level in cmd/mciassrv).
|
||||||
|
// - Authentication state is carried via JWT; no cookies or server-side sessions.
|
||||||
|
// - Credential fields (password hash, TOTP secret, Postgres password) are
|
||||||
|
// never included in any API response.
|
||||||
|
// - All JSON parsing uses strict decoders that reject unknown fields.
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/config"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/crypto"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/middleware"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server holds the dependencies injected into all handlers.
|
||||||
|
type Server struct {
|
||||||
|
db *db.DB
|
||||||
|
cfg *config.Config
|
||||||
|
privKey ed25519.PrivateKey
|
||||||
|
pubKey ed25519.PublicKey
|
||||||
|
masterKey []byte
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a Server with the given dependencies.
|
||||||
|
func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed25519.PublicKey, masterKey []byte, logger *slog.Logger) *Server {
|
||||||
|
return &Server{
|
||||||
|
db: database,
|
||||||
|
cfg: cfg,
|
||||||
|
privKey: priv,
|
||||||
|
pubKey: pub,
|
||||||
|
masterKey: masterKey,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler builds and returns the root HTTP handler with all routes and middleware.
|
||||||
|
func (s *Server) Handler() http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
|
// Public endpoints (no authentication required).
|
||||||
|
mux.HandleFunc("GET /v1/health", s.handleHealth)
|
||||||
|
mux.HandleFunc("GET /v1/keys/public", s.handlePublicKey)
|
||||||
|
mux.HandleFunc("POST /v1/auth/login", s.handleLogin)
|
||||||
|
mux.HandleFunc("POST /v1/token/validate", s.handleTokenValidate)
|
||||||
|
|
||||||
|
// Authenticated endpoints.
|
||||||
|
requireAuth := middleware.RequireAuth(s.pubKey, s.db, s.cfg.Tokens.Issuer)
|
||||||
|
requireAdmin := func(h http.Handler) http.Handler {
|
||||||
|
return requireAuth(middleware.RequireRole("admin")(h))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth endpoints (require valid token).
|
||||||
|
mux.Handle("POST /v1/auth/logout", requireAuth(http.HandlerFunc(s.handleLogout)))
|
||||||
|
mux.Handle("POST /v1/auth/renew", requireAuth(http.HandlerFunc(s.handleRenew)))
|
||||||
|
mux.Handle("POST /v1/auth/totp/enroll", requireAuth(http.HandlerFunc(s.handleTOTPEnroll)))
|
||||||
|
mux.Handle("POST /v1/auth/totp/confirm", requireAuth(http.HandlerFunc(s.handleTOTPConfirm)))
|
||||||
|
|
||||||
|
// Admin-only endpoints.
|
||||||
|
mux.Handle("DELETE /v1/auth/totp", requireAdmin(http.HandlerFunc(s.handleTOTPRemove)))
|
||||||
|
mux.Handle("POST /v1/token/issue", requireAdmin(http.HandlerFunc(s.handleTokenIssue)))
|
||||||
|
mux.Handle("DELETE /v1/token/{jti}", requireAdmin(http.HandlerFunc(s.handleTokenRevoke)))
|
||||||
|
mux.Handle("GET /v1/accounts", requireAdmin(http.HandlerFunc(s.handleListAccounts)))
|
||||||
|
mux.Handle("POST /v1/accounts", requireAdmin(http.HandlerFunc(s.handleCreateAccount)))
|
||||||
|
mux.Handle("GET /v1/accounts/{id}", requireAdmin(http.HandlerFunc(s.handleGetAccount)))
|
||||||
|
mux.Handle("PATCH /v1/accounts/{id}", requireAdmin(http.HandlerFunc(s.handleUpdateAccount)))
|
||||||
|
mux.Handle("DELETE /v1/accounts/{id}", requireAdmin(http.HandlerFunc(s.handleDeleteAccount)))
|
||||||
|
mux.Handle("GET /v1/accounts/{id}/roles", requireAdmin(http.HandlerFunc(s.handleGetRoles)))
|
||||||
|
mux.Handle("PUT /v1/accounts/{id}/roles", requireAdmin(http.HandlerFunc(s.handleSetRoles)))
|
||||||
|
mux.Handle("GET /v1/accounts/{id}/pgcreds", requireAdmin(http.HandlerFunc(s.handleGetPGCreds)))
|
||||||
|
mux.Handle("PUT /v1/accounts/{id}/pgcreds", requireAdmin(http.HandlerFunc(s.handleSetPGCreds)))
|
||||||
|
|
||||||
|
// Apply global middleware: logging and login-path rate limiting.
|
||||||
|
var root http.Handler = mux
|
||||||
|
root = middleware.RequestLogger(s.logger)(root)
|
||||||
|
return root
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Public handlers ----
|
||||||
|
|
||||||
|
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlePublicKey returns the server's Ed25519 public key in JWK format.
|
||||||
|
// This allows relying parties to independently verify JWTs.
|
||||||
|
func (s *Server) handlePublicKey(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Encode the Ed25519 public key as a JWK (RFC 8037).
|
||||||
|
// The "x" parameter is the base64url-encoded public key bytes.
|
||||||
|
jwk := map[string]string{
|
||||||
|
"kty": "OKP",
|
||||||
|
"crv": "Ed25519",
|
||||||
|
"use": "sig",
|
||||||
|
"alg": "EdDSA",
|
||||||
|
"x": encodeBase64URL(s.pubKey),
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, jwk)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Auth handlers ----
|
||||||
|
|
||||||
|
// loginRequest is the request body for POST /v1/auth/login.
|
||||||
|
type loginRequest struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
TOTPCode string `json:"totp_code,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// loginResponse is the response body for a successful login.
|
||||||
|
type loginResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresAt string `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req loginRequest
|
||||||
|
if !decodeJSON(w, r, &req) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Username == "" || req.Password == "" {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "username and password are required", "bad_request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load account by username.
|
||||||
|
acct, err := s.db.GetAccountByUsername(req.Username)
|
||||||
|
if err != nil {
|
||||||
|
// Security: return a generic error whether the user exists or not.
|
||||||
|
// Always run a dummy Argon2 check to prevent timing-based user enumeration.
|
||||||
|
_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
|
||||||
|
s.writeAudit(r, model.EventLoginFail, nil, nil, fmt.Sprintf(`{"username":%q,"reason":"unknown_user"}`, req.Username))
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security: Check account status before credential verification to avoid
|
||||||
|
// leaking whether the account exists based on timing differences.
|
||||||
|
if acct.Status != model.AccountStatusActive {
|
||||||
|
_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
|
||||||
|
s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, fmt.Sprintf(`{"reason":"account_inactive"}`))
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify password. This is always run, even for system accounts (which have
|
||||||
|
// no password hash), to maintain constant timing.
|
||||||
|
ok, err := auth.VerifyPassword(req.Password, acct.PasswordHash)
|
||||||
|
if err != nil || !ok {
|
||||||
|
s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, `{"reason":"wrong_password"}`)
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TOTP check (if enrolled).
|
||||||
|
if acct.TOTPRequired {
|
||||||
|
if req.TOTPCode == "" {
|
||||||
|
s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, `{"reason":"totp_missing"}`)
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "TOTP code required", "totp_required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Decrypt the TOTP secret.
|
||||||
|
secret, err := crypto.OpenAESGCM(s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID)
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
valid, err := auth.ValidateTOTP(secret, req.TOTPCode)
|
||||||
|
if err != nil || !valid {
|
||||||
|
s.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"wrong_totp"}`)
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine expiry.
|
||||||
|
expiry := s.cfg.DefaultExpiry()
|
||||||
|
roles, err := s.db.GetRoles(acct.ID)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, r := range roles {
|
||||||
|
if r == "admin" {
|
||||||
|
expiry = s.cfg.AdminExpiry()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenStr, claims, err := token.IssueToken(s.privKey, s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("issue token", "error", err)
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||||
|
s.logger.Error("track token", "error", err)
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventLoginOK, &acct.ID, nil, "")
|
||||||
|
s.writeAudit(r, model.EventTokenIssued, &acct.ID, nil, fmt.Sprintf(`{"jti":%q}`, claims.JTI))
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, loginResponse{
|
||||||
|
Token: tokenStr,
|
||||||
|
ExpiresAt: claims.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := middleware.ClaimsFromContext(r.Context())
|
||||||
|
if err := s.db.RevokeToken(claims.JTI, "logout"); err != nil {
|
||||||
|
s.logger.Error("revoke token on logout", "error", err)
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.writeAudit(r, model.EventTokenRevoked, nil, nil, fmt.Sprintf(`{"jti":%q,"reason":"logout"}`, claims.JTI))
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := middleware.ClaimsFromContext(r.Context())
|
||||||
|
|
||||||
|
// Load account to get current roles (they may have changed since token issuance).
|
||||||
|
acct, err := s.db.GetAccountByUUID(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "account not found", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if acct.Status != model.AccountStatusActive {
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "account inactive", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
roles, err := s.db.GetRoles(acct.ID)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
expiry := s.cfg.DefaultExpiry()
|
||||||
|
for _, role := range roles {
|
||||||
|
if role == "admin" {
|
||||||
|
expiry = s.cfg.AdminExpiry()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newTokenStr, newClaims, err := token.IssueToken(s.privKey, s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke the old token and track the new one atomically is not possible
|
||||||
|
// in SQLite without a transaction. We do best-effort: revoke old, track new.
|
||||||
|
if err := s.db.RevokeToken(claims.JTI, "renewed"); err != nil {
|
||||||
|
s.logger.Error("revoke old token on renew", "error", err)
|
||||||
|
}
|
||||||
|
if err := s.db.TrackToken(newClaims.JTI, acct.ID, newClaims.IssuedAt, newClaims.ExpiresAt); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventTokenRenewed, &acct.ID, nil, fmt.Sprintf(`{"old_jti":%q,"new_jti":%q}`, claims.JTI, newClaims.JTI))
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, loginResponse{
|
||||||
|
Token: newTokenStr,
|
||||||
|
ExpiresAt: newClaims.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Token endpoints ----
|
||||||
|
|
||||||
|
type validateRequest struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type validateResponse struct {
|
||||||
|
Valid bool `json:"valid"`
|
||||||
|
Subject string `json:"sub,omitempty"`
|
||||||
|
Roles []string `json:"roles,omitempty"`
|
||||||
|
ExpiresAt string `json:"expires_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTokenValidate(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Accept token either from Authorization: Bearer header or JSON body.
|
||||||
|
tokenStr, err := extractBearerFromRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
// Try JSON body.
|
||||||
|
var req validateRequest
|
||||||
|
if !decodeJSON(w, r, &req) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tokenStr = req.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenStr == "" {
|
||||||
|
writeJSON(w, http.StatusOK, validateResponse{Valid: false})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := token.ValidateToken(s.pubKey, tokenStr, s.cfg.Tokens.Issuer)
|
||||||
|
if err != nil {
|
||||||
|
writeJSON(w, http.StatusOK, validateResponse{Valid: false})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rec, err := s.db.GetTokenRecord(claims.JTI)
|
||||||
|
if err != nil || rec.IsRevoked() {
|
||||||
|
writeJSON(w, http.StatusOK, validateResponse{Valid: false})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, validateResponse{
|
||||||
|
Valid: true,
|
||||||
|
Subject: claims.Subject,
|
||||||
|
Roles: claims.Roles,
|
||||||
|
ExpiresAt: claims.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type issueTokenRequest struct {
|
||||||
|
AccountID string `json:"account_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTokenIssue(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req issueTokenRequest
|
||||||
|
if !decodeJSON(w, r, &req) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
acct, err := s.db.GetAccountByUUID(req.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if acct.AccountType != model.AccountTypeSystem {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "token issue is only for system accounts", "bad_request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenStr, claims, err := token.IssueToken(s.privKey, s.cfg.Tokens.Issuer, acct.UUID, nil, s.cfg.ServiceExpiry())
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke existing system token if any.
|
||||||
|
existing, err := s.db.GetSystemToken(acct.ID)
|
||||||
|
if err == nil && existing != nil {
|
||||||
|
_ = s.db.RevokeToken(existing.JTI, "rotated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.db.SetSystemToken(acct.ID, claims.JTI, claims.ExpiresAt); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
actor := middleware.ClaimsFromContext(r.Context())
|
||||||
|
var actorID *int64
|
||||||
|
if actor != nil {
|
||||||
|
if a, err := s.db.GetAccountByUUID(actor.Subject); err == nil {
|
||||||
|
actorID = &a.ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.writeAudit(r, model.EventTokenIssued, actorID, &acct.ID, fmt.Sprintf(`{"jti":%q}`, claims.JTI))
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, loginResponse{
|
||||||
|
Token: tokenStr,
|
||||||
|
ExpiresAt: claims.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTokenRevoke(w http.ResponseWriter, r *http.Request) {
|
||||||
|
jti := r.PathValue("jti")
|
||||||
|
if jti == "" {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "jti is required", "bad_request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.RevokeToken(jti, "admin revocation"); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusNotFound, "token not found or already revoked", "not_found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventTokenRevoked, nil, nil, fmt.Sprintf(`{"jti":%q}`, jti))
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Account endpoints ----
|
||||||
|
|
||||||
|
type createAccountRequest struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password,omitempty"`
|
||||||
|
Type string `json:"account_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type accountResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
AccountType string `json:"account_type"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
TOTPEnabled bool `json:"totp_enabled"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func accountToResponse(a *model.Account) accountResponse {
|
||||||
|
resp := accountResponse{
|
||||||
|
ID: a.UUID,
|
||||||
|
Username: a.Username,
|
||||||
|
AccountType: string(a.AccountType),
|
||||||
|
Status: string(a.Status),
|
||||||
|
TOTPEnabled: a.TOTPRequired,
|
||||||
|
CreatedAt: a.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||||
|
UpdatedAt: a.UpdatedAt.Format("2006-01-02T15:04:05Z"),
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleListAccounts(w http.ResponseWriter, r *http.Request) {
|
||||||
|
accounts, err := s.db.ListAccounts()
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp := make([]accountResponse, len(accounts))
|
||||||
|
for i, a := range accounts {
|
||||||
|
resp[i] = accountToResponse(a)
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req createAccountRequest
|
||||||
|
if !decodeJSON(w, r, &req) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Username == "" {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "username is required", "bad_request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
accountType := model.AccountType(req.Type)
|
||||||
|
if accountType != model.AccountTypeHuman && accountType != model.AccountTypeSystem {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "account_type must be 'human' or 'system'", "bad_request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var passwordHash string
|
||||||
|
if accountType == model.AccountTypeHuman {
|
||||||
|
if req.Password == "" {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "password is required for human accounts", "bad_request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
passwordHash, err = auth.HashPassword(req.Password, auth.ArgonParams{
|
||||||
|
Time: s.cfg.Argon2.Time,
|
||||||
|
Memory: s.cfg.Argon2.Memory,
|
||||||
|
Threads: s.cfg.Argon2.Threads,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
acct, err := s.db.CreateAccount(req.Username, accountType, passwordHash)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusConflict, "username already exists", "conflict")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventAccountCreated, nil, &acct.ID, fmt.Sprintf(`{"username":%q}`, acct.Username))
|
||||||
|
writeJSON(w, http.StatusCreated, accountToResponse(acct))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleGetAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
acct, ok := s.loadAccount(w, r)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, accountToResponse(acct))
|
||||||
|
}
|
||||||
|
|
||||||
|
type updateAccountRequest struct {
|
||||||
|
Status string `json:"status,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleUpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
acct, ok := s.loadAccount(w, r)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req updateAccountRequest
|
||||||
|
if !decodeJSON(w, r, &req) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Status != "" {
|
||||||
|
newStatus := model.AccountStatus(req.Status)
|
||||||
|
if newStatus != model.AccountStatusActive && newStatus != model.AccountStatusInactive {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "status must be 'active' or 'inactive'", "bad_request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.db.UpdateAccountStatus(acct.ID, newStatus); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventAccountUpdated, nil, &acct.ID, "")
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
acct, ok := s.loadAccount(w, r)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.UpdateAccountStatus(acct.ID, model.AccountStatusDeleted); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.db.RevokeAllUserTokens(acct.ID, "account deleted"); err != nil {
|
||||||
|
s.logger.Error("revoke tokens on delete", "error", err, "account_id", acct.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventAccountDeleted, nil, &acct.ID, "")
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Role endpoints ----
|
||||||
|
|
||||||
|
type rolesResponse struct {
|
||||||
|
Roles []string `json:"roles"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type setRolesRequest struct {
|
||||||
|
Roles []string `json:"roles"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleGetRoles(w http.ResponseWriter, r *http.Request) {
|
||||||
|
acct, ok := s.loadAccount(w, r)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
roles, err := s.db.GetRoles(acct.ID)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if roles == nil {
|
||||||
|
roles = []string{}
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, rolesResponse{Roles: roles})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleSetRoles(w http.ResponseWriter, r *http.Request) {
|
||||||
|
acct, ok := s.loadAccount(w, r)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req setRolesRequest
|
||||||
|
if !decodeJSON(w, r, &req) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
actor := middleware.ClaimsFromContext(r.Context())
|
||||||
|
var grantedBy *int64
|
||||||
|
if actor != nil {
|
||||||
|
if a, err := s.db.GetAccountByUUID(actor.Subject); err == nil {
|
||||||
|
grantedBy = &a.ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.SetRoles(acct.ID, req.Roles, grantedBy); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventRoleGranted, grantedBy, &acct.ID, fmt.Sprintf(`{"roles":%v}`, req.Roles))
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- TOTP endpoints ----
|
||||||
|
|
||||||
|
type totpEnrollResponse struct {
|
||||||
|
Secret string `json:"secret"` // base32-encoded
|
||||||
|
OTPAuthURI string `json:"otpauth_uri"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type totpConfirmRequest struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTOTPEnroll(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := middleware.ClaimsFromContext(r.Context())
|
||||||
|
acct, err := s.db.GetAccountByUUID(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "account not found", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawSecret, b32Secret, err := auth.GenerateTOTPSecret()
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt the secret before storing it temporarily.
|
||||||
|
// Note: we store as pending; enrollment is confirmed with /confirm.
|
||||||
|
secretEnc, secretNonce, err := crypto.SealAESGCM(s.masterKey, rawSecret)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the encrypted pending secret. The totp_required flag is NOT set
|
||||||
|
// yet — it is set only after the user confirms the code.
|
||||||
|
if err := s.db.SetTOTP(acct.ID, secretEnc, secretNonce); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
otpURI := fmt.Sprintf("otpauth://totp/MCIAS:%s?secret=%s&issuer=MCIAS", acct.Username, b32Secret)
|
||||||
|
|
||||||
|
// Security: return the secret for display to the user. It is only shown
|
||||||
|
// once; subsequent reads are not possible (only the encrypted form is stored).
|
||||||
|
writeJSON(w, http.StatusOK, totpEnrollResponse{
|
||||||
|
Secret: b32Secret,
|
||||||
|
OTPAuthURI: otpURI,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTOTPConfirm(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req totpConfirmRequest
|
||||||
|
if !decodeJSON(w, r, &req) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := middleware.ClaimsFromContext(r.Context())
|
||||||
|
acct, err := s.db.GetAccountByUUID(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "account not found", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if acct.TOTPSecretEnc == nil {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "TOTP enrollment not started", "bad_request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
secret, err := crypto.OpenAESGCM(s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := auth.ValidateTOTP(secret, req.Code)
|
||||||
|
if err != nil || !valid {
|
||||||
|
middleware.WriteError(w, http.StatusUnauthorized, "invalid TOTP code", "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark TOTP as confirmed and required.
|
||||||
|
if err := s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventTOTPEnrolled, &acct.ID, nil, "")
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
type totpRemoveRequest struct {
|
||||||
|
AccountID string `json:"account_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTOTPRemove(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req totpRemoveRequest
|
||||||
|
if !decodeJSON(w, r, &req) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
acct, err := s.db.GetAccountByUUID(req.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.ClearTOTP(acct.ID); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventTOTPRemoved, nil, &acct.ID, "")
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Postgres credential endpoints ----
|
||||||
|
|
||||||
|
type pgCredRequest struct {
|
||||||
|
Host string `json:"host"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
Database string `json:"database"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type pgCredResponse struct {
|
||||||
|
Host string `json:"host"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
Database string `json:"database"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
// Security: Password is NEVER included in the response, even on GET.
|
||||||
|
// The caller must explicitly decrypt it on the server side.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleGetPGCreds(w http.ResponseWriter, r *http.Request) {
|
||||||
|
acct, ok := s.loadAccount(w, r)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cred, err := s.db.ReadPGCredentials(acct.ID)
|
||||||
|
if err != nil {
|
||||||
|
if err == db.ErrNotFound {
|
||||||
|
middleware.WriteError(w, http.StatusNotFound, "no credentials stored", "not_found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt the password to return it to the admin caller.
|
||||||
|
password, err := crypto.OpenAESGCM(s.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc)
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventPGCredAccessed, nil, &acct.ID, "")
|
||||||
|
|
||||||
|
// Return including password since this is an explicit admin retrieval.
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"host": cred.PGHost,
|
||||||
|
"port": cred.PGPort,
|
||||||
|
"database": cred.PGDatabase,
|
||||||
|
"username": cred.PGUsername,
|
||||||
|
"password": string(password), // included only for admin retrieval
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleSetPGCreds(w http.ResponseWriter, r *http.Request) {
|
||||||
|
acct, ok := s.loadAccount(w, r)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req pgCredRequest
|
||||||
|
if !decodeJSON(w, r, &req) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Host == "" || req.Database == "" || req.Username == "" || req.Password == "" {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "host, database, username, and password are required", "bad_request")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Port == 0 {
|
||||||
|
req.Port = 5432
|
||||||
|
}
|
||||||
|
|
||||||
|
enc, nonce, err := crypto.SealAESGCM(s.masterKey, []byte(req.Password))
|
||||||
|
if err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.WritePGCredentials(acct.ID, req.Host, req.Port, req.Database, req.Username, enc, nonce); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.writeAudit(r, model.EventPGCredUpdated, nil, &acct.ID, "")
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Helpers ----
|
||||||
|
|
||||||
|
// loadAccount retrieves an account by the {id} path parameter (UUID).
|
||||||
|
func (s *Server) loadAccount(w http.ResponseWriter, r *http.Request) (*model.Account, bool) {
|
||||||
|
id := r.PathValue("id")
|
||||||
|
if id == "" {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "account id is required", "bad_request")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
acct, err := s.db.GetAccountByUUID(id)
|
||||||
|
if err != nil {
|
||||||
|
if err == db.ErrNotFound {
|
||||||
|
middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return acct, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeAudit appends an audit log entry, logging errors but not failing the request.
|
||||||
|
func (s *Server) writeAudit(r *http.Request, eventType string, actorID, targetID *int64, details string) {
|
||||||
|
ip := r.RemoteAddr
|
||||||
|
if err := s.db.WriteAuditEvent(eventType, actorID, targetID, ip, details); err != nil {
|
||||||
|
s.logger.Error("write audit event", "error", err, "event_type", eventType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeJSON encodes v as JSON and writes it to w with the given status code.
|
||||||
|
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||||
|
// If encoding fails, the status is already written; log but don't panic.
|
||||||
|
_ = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeJSON decodes a JSON request body into v.
|
||||||
|
// Returns false and writes a 400 response if decoding fails.
|
||||||
|
func decodeJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
|
||||||
|
dec := json.NewDecoder(r.Body)
|
||||||
|
dec.DisallowUnknownFields()
|
||||||
|
if err := dec.Decode(v); err != nil {
|
||||||
|
middleware.WriteError(w, http.StatusBadRequest, "invalid JSON request body", "bad_request")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractBearerFromRequest extracts a Bearer token from the Authorization header.
|
||||||
|
func extractBearerFromRequest(r *http.Request) (string, error) {
|
||||||
|
auth := r.Header.Get("Authorization")
|
||||||
|
if auth == "" {
|
||||||
|
return "", fmt.Errorf("no Authorization header")
|
||||||
|
}
|
||||||
|
const prefix = "Bearer "
|
||||||
|
if len(auth) <= len(prefix) {
|
||||||
|
return "", fmt.Errorf("malformed Authorization header")
|
||||||
|
}
|
||||||
|
return auth[len(prefix):], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeBase64URL encodes bytes as base64url without padding.
|
||||||
|
func encodeBase64URL(b []byte) string {
|
||||||
|
const table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
||||||
|
result := make([]byte, 0, (len(b)*4+2)/3)
|
||||||
|
for i := 0; i < len(b); i += 3 {
|
||||||
|
switch {
|
||||||
|
case i+2 < len(b):
|
||||||
|
result = append(result,
|
||||||
|
table[b[i]>>2],
|
||||||
|
table[(b[i]&3)<<4|b[i+1]>>4],
|
||||||
|
table[(b[i+1]&0xf)<<2|b[i+2]>>6],
|
||||||
|
table[b[i+2]&0x3f],
|
||||||
|
)
|
||||||
|
case i+1 < len(b):
|
||||||
|
result = append(result,
|
||||||
|
table[b[i]>>2],
|
||||||
|
table[(b[i]&3)<<4|b[i+1]>>4],
|
||||||
|
table[(b[i+1]&0xf)<<2],
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
result = append(result,
|
||||||
|
table[b[i]>>2],
|
||||||
|
table[(b[i]&3)<<4],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
434
internal/server/server_test.go
Normal file
434
internal/server/server_test.go
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/config"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||||
|
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testIssuer = "https://auth.example.com"
|
||||||
|
|
||||||
|
func newTestServer(t *testing.T) (*Server, ed25519.PublicKey, ed25519.PrivateKey, *db.DB) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
database, err := db.Open(":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open db: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Migrate(database); err != nil {
|
||||||
|
t.Fatalf("migrate db: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = database.Close() })
|
||||||
|
|
||||||
|
masterKey := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(masterKey); err != nil {
|
||||||
|
t.Fatalf("generate master key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := config.NewTestConfig(testIssuer)
|
||||||
|
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
srv := New(database, cfg, priv, pub, masterKey, logger)
|
||||||
|
return srv, pub, priv, database
|
||||||
|
}
|
||||||
|
|
||||||
|
// createTestHumanAccount creates a human account with password "testpass123".
|
||||||
|
func createTestHumanAccount(t *testing.T, srv *Server, username string) *model.Account {
|
||||||
|
t.Helper()
|
||||||
|
hash, err := auth.HashPassword("testpass123", auth.ArgonParams{Time: 3, Memory: 65536, Threads: 4})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("hash password: %v", err)
|
||||||
|
}
|
||||||
|
acct, err := srv.db.CreateAccount(username, model.AccountTypeHuman, hash)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create account: %v", err)
|
||||||
|
}
|
||||||
|
return acct
|
||||||
|
}
|
||||||
|
|
||||||
|
// issueAdminToken creates an account with admin role, issues a JWT, and tracks it.
|
||||||
|
func issueAdminToken(t *testing.T, srv *Server, priv ed25519.PrivateKey, username string) (string, *model.Account) {
|
||||||
|
t.Helper()
|
||||||
|
acct := createTestHumanAccount(t, srv, username)
|
||||||
|
if err := srv.db.GrantRole(acct.ID, "admin", nil); err != nil {
|
||||||
|
t.Fatalf("grant admin role: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, []string{"admin"}, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("issue token: %v", err)
|
||||||
|
}
|
||||||
|
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||||
|
t.Fatalf("track token: %v", err)
|
||||||
|
}
|
||||||
|
return tokenStr, acct
|
||||||
|
}
|
||||||
|
|
||||||
|
func doRequest(t *testing.T, handler http.Handler, method, path string, body interface{}, authToken string) *httptest.ResponseRecorder {
|
||||||
|
t.Helper()
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if body != nil {
|
||||||
|
b, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal body: %v", err)
|
||||||
|
}
|
||||||
|
bodyReader = bytes.NewReader(b)
|
||||||
|
} else {
|
||||||
|
bodyReader = bytes.NewReader(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(method, path, bodyReader)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if authToken != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
return rr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealth(t *testing.T) {
|
||||||
|
srv, _, _, _ := newTestServer(t)
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "GET", "/v1/health", nil, "")
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("health status = %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPublicKey(t *testing.T) {
|
||||||
|
srv, _, _, _ := newTestServer(t)
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "GET", "/v1/keys/public", nil, "")
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("public key status = %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var jwk map[string]string
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &jwk); err != nil {
|
||||||
|
t.Fatalf("unmarshal JWK: %v", err)
|
||||||
|
}
|
||||||
|
if jwk["kty"] != "OKP" {
|
||||||
|
t.Errorf("kty = %q, want OKP", jwk["kty"])
|
||||||
|
}
|
||||||
|
if jwk["alg"] != "EdDSA" {
|
||||||
|
t.Errorf("alg = %q, want EdDSA", jwk["alg"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginSuccess(t *testing.T) {
|
||||||
|
srv, _, _, _ := newTestServer(t)
|
||||||
|
createTestHumanAccount(t, srv, "alice")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{
|
||||||
|
"username": "alice",
|
||||||
|
"password": "testpass123",
|
||||||
|
}, "")
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("login status = %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp loginResponse
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal login response: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Token == "" {
|
||||||
|
t.Error("expected non-empty token in login response")
|
||||||
|
}
|
||||||
|
if resp.ExpiresAt == "" {
|
||||||
|
t.Error("expected non-empty expires_at in login response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginWrongPassword(t *testing.T) {
|
||||||
|
srv, _, _, _ := newTestServer(t)
|
||||||
|
createTestHumanAccount(t, srv, "bob")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{
|
||||||
|
"username": "bob",
|
||||||
|
"password": "wrongpassword",
|
||||||
|
}, "")
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginUnknownUser(t *testing.T) {
|
||||||
|
srv, _, _, _ := newTestServer(t)
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{
|
||||||
|
"username": "nobody",
|
||||||
|
"password": "password",
|
||||||
|
}, "")
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Errorf("status = %d, want 401", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginResponseDoesNotContainCredentials(t *testing.T) {
|
||||||
|
srv, _, _, _ := newTestServer(t)
|
||||||
|
createTestHumanAccount(t, srv, "charlie")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{
|
||||||
|
"username": "charlie",
|
||||||
|
"password": "testpass123",
|
||||||
|
}, "")
|
||||||
|
|
||||||
|
body := rr.Body.String()
|
||||||
|
// Security: password hash must never appear in any API response.
|
||||||
|
if strings.Contains(body, "argon2id") || strings.Contains(body, "password_hash") {
|
||||||
|
t.Error("login response contains password hash material")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenValidate(t *testing.T) {
|
||||||
|
srv, _, priv, _ := newTestServer(t)
|
||||||
|
acct := createTestHumanAccount(t, srv, "dave")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
// Issue and track a token.
|
||||||
|
tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||||
|
t.Fatalf("TrackToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/v1/token/validate", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("validate status = %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp validateResponse
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if !resp.Valid {
|
||||||
|
t.Error("expected valid=true for valid token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogout(t *testing.T) {
|
||||||
|
srv, _, priv, _ := newTestServer(t)
|
||||||
|
acct := createTestHumanAccount(t, srv, "eve")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||||
|
t.Fatalf("TrackToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logout.
|
||||||
|
rr := doRequest(t, handler, "POST", "/v1/auth/logout", nil, tokenStr)
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Errorf("logout status = %d, want 204; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token should now be invalid on validate.
|
||||||
|
req := httptest.NewRequest("POST", "/v1/token/validate", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||||
|
rr2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr2, req)
|
||||||
|
|
||||||
|
var resp validateResponse
|
||||||
|
_ = json.Unmarshal(rr2.Body.Bytes(), &resp)
|
||||||
|
if resp.Valid {
|
||||||
|
t.Error("expected valid=false after logout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAccountAdmin(t *testing.T) {
|
||||||
|
srv, _, priv, _ := newTestServer(t)
|
||||||
|
adminToken, _ := issueAdminToken(t, srv, priv, "admin-user")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "POST", "/v1/accounts", map[string]string{
|
||||||
|
"username": "new-user",
|
||||||
|
"password": "newpassword123",
|
||||||
|
"account_type": "human",
|
||||||
|
}, adminToken)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Errorf("create account status = %d, want 201; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp accountResponse
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Username != "new-user" {
|
||||||
|
t.Errorf("Username = %q, want %q", resp.Username, "new-user")
|
||||||
|
}
|
||||||
|
// Security: password hash must not appear in account response.
|
||||||
|
body := rr.Body.String()
|
||||||
|
if strings.Contains(body, "password_hash") || strings.Contains(body, "argon2id") {
|
||||||
|
t.Error("account creation response contains password hash")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAccountRequiresAdmin(t *testing.T) {
|
||||||
|
srv, _, priv, _ := newTestServer(t)
|
||||||
|
acct := createTestHumanAccount(t, srv, "regular-user")
|
||||||
|
tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, []string{"reader"}, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||||
|
t.Fatalf("TrackToken: %v", err)
|
||||||
|
}
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "POST", "/v1/accounts", map[string]string{
|
||||||
|
"username": "other-user",
|
||||||
|
"password": "password",
|
||||||
|
"account_type": "human",
|
||||||
|
}, tokenStr)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("status = %d, want 403", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListAccounts(t *testing.T) {
|
||||||
|
srv, _, priv, _ := newTestServer(t)
|
||||||
|
adminToken, _ := issueAdminToken(t, srv, priv, "admin2")
|
||||||
|
createTestHumanAccount(t, srv, "user1")
|
||||||
|
createTestHumanAccount(t, srv, "user2")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "GET", "/v1/accounts", nil, adminToken)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("list accounts status = %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var accounts []accountResponse
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &accounts); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if len(accounts) < 3 { // admin + user1 + user2
|
||||||
|
t.Errorf("expected at least 3 accounts, got %d", len(accounts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security: no credential fields in any response.
|
||||||
|
body := rr.Body.String()
|
||||||
|
for _, bad := range []string{"password_hash", "argon2id", "totp_secret", "PasswordHash"} {
|
||||||
|
if strings.Contains(body, bad) {
|
||||||
|
t.Errorf("account list response contains credential field %q", bad)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAccount(t *testing.T) {
|
||||||
|
srv, _, priv, _ := newTestServer(t)
|
||||||
|
adminToken, _ := issueAdminToken(t, srv, priv, "admin3")
|
||||||
|
target := createTestHumanAccount(t, srv, "delete-me")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "DELETE", "/v1/accounts/"+target.UUID, nil, adminToken)
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Errorf("delete status = %d, want 204; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetAndGetRoles(t *testing.T) {
|
||||||
|
srv, _, priv, _ := newTestServer(t)
|
||||||
|
adminToken, _ := issueAdminToken(t, srv, priv, "admin4")
|
||||||
|
target := createTestHumanAccount(t, srv, "role-target")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
// Set roles.
|
||||||
|
rr := doRequest(t, handler, "PUT", "/v1/accounts/"+target.UUID+"/roles", map[string][]string{
|
||||||
|
"roles": {"reader", "writer"},
|
||||||
|
}, adminToken)
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Errorf("set roles status = %d, want 204; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get roles.
|
||||||
|
rr2 := doRequest(t, handler, "GET", "/v1/accounts/"+target.UUID+"/roles", nil, adminToken)
|
||||||
|
if rr2.Code != http.StatusOK {
|
||||||
|
t.Errorf("get roles status = %d, want 200", rr2.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp rolesResponse
|
||||||
|
if err := json.Unmarshal(rr2.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if len(resp.Roles) != 2 {
|
||||||
|
t.Errorf("expected 2 roles, got %d", len(resp.Roles))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRenewToken(t *testing.T) {
|
||||||
|
srv, _, priv, _ := newTestServer(t)
|
||||||
|
acct := createTestHumanAccount(t, srv, "renew-user")
|
||||||
|
handler := srv.Handler()
|
||||||
|
|
||||||
|
oldTokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
oldJTI := claims.JTI
|
||||||
|
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||||
|
t.Fatalf("TrackToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := doRequest(t, handler, "POST", "/v1/auth/renew", nil, oldTokenStr)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("renew status = %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp loginResponse
|
||||||
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal renew response: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Token == "" || resp.Token == oldTokenStr {
|
||||||
|
t.Error("expected new, distinct token after renewal")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Old token should be revoked in the database.
|
||||||
|
rec, err := srv.db.GetTokenRecord(oldJTI)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTokenRecord: %v", err)
|
||||||
|
}
|
||||||
|
if !rec.IsRevoked() {
|
||||||
|
t.Error("old token should be revoked after renewal")
|
||||||
|
}
|
||||||
|
}
|
||||||
181
internal/token/token.go
Normal file
181
internal/token/token.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
// Package token handles JWT issuance, validation, and revocation for MCIAS.
|
||||||
|
//
|
||||||
|
// Security design:
|
||||||
|
// - Algorithm header is checked FIRST, before any signature verification.
|
||||||
|
// This prevents algorithm-confusion attacks (CVE-2022-21449 class).
|
||||||
|
// - Only "EdDSA" is accepted; "none", HS*, RS*, ES* are all rejected.
|
||||||
|
// - The signing key is taken from the server's keystore, never from the token.
|
||||||
|
// - All standard claims (exp, iat, iss, jti) are required and validated.
|
||||||
|
// - JTIs are UUIDs generated from crypto/rand (via google/uuid).
|
||||||
|
// - Token values are never stored; only JTIs are recorded for revocation.
|
||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// requiredAlg is the only JWT algorithm accepted by MCIAS.
|
||||||
|
// Security: Hard-coding this as a constant rather than a variable ensures
|
||||||
|
// it cannot be changed at runtime and cannot be confused by token headers.
|
||||||
|
requiredAlg = "EdDSA"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claims holds the MCIAS-specific JWT claims.
|
||||||
|
type Claims struct {
|
||||||
|
// Standard registered claims.
|
||||||
|
Issuer string `json:"iss"`
|
||||||
|
Subject string `json:"sub"` // account UUID
|
||||||
|
IssuedAt time.Time `json:"iat"`
|
||||||
|
ExpiresAt time.Time `json:"exp"`
|
||||||
|
JTI string `json:"jti"`
|
||||||
|
|
||||||
|
// MCIAS-specific claims.
|
||||||
|
Roles []string `json:"roles"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtClaims adapts Claims to the golang-jwt MapClaims interface.
|
||||||
|
type jwtClaims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
Roles []string `json:"roles"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrExpiredToken is returned when the token's exp claim is in the past.
|
||||||
|
var ErrExpiredToken = errors.New("token: expired")
|
||||||
|
|
||||||
|
// ErrInvalidSignature is returned when Ed25519 signature verification fails.
|
||||||
|
var ErrInvalidSignature = errors.New("token: invalid signature")
|
||||||
|
|
||||||
|
// ErrWrongAlgorithm is returned when the alg header is not EdDSA.
|
||||||
|
var ErrWrongAlgorithm = errors.New("token: algorithm must be EdDSA")
|
||||||
|
|
||||||
|
// ErrMissingClaim is returned when a required claim is absent or empty.
|
||||||
|
var ErrMissingClaim = errors.New("token: missing required claim")
|
||||||
|
|
||||||
|
// IssueToken creates and signs a new JWT with the given claims.
|
||||||
|
// The jti is generated automatically using crypto/rand via uuid.New().
|
||||||
|
// Returns the signed token string.
|
||||||
|
//
|
||||||
|
// Security: The signing key is provided by the caller from the server's
|
||||||
|
// keystore. The alg header is set explicitly to "EdDSA" by the jwt library
|
||||||
|
// when an ed25519.PrivateKey is passed to SignedString.
|
||||||
|
func IssueToken(key ed25519.PrivateKey, issuer, subject string, roles []string, expiry time.Duration) (string, *Claims, error) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
exp := now.Add(expiry)
|
||||||
|
jti := uuid.New().String()
|
||||||
|
|
||||||
|
jc := jwtClaims{
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Subject: subject,
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
ExpiresAt: jwt.NewNumericDate(exp),
|
||||||
|
ID: jti,
|
||||||
|
},
|
||||||
|
Roles: roles,
|
||||||
|
}
|
||||||
|
|
||||||
|
t := jwt.NewWithClaims(jwt.SigningMethodEdDSA, jc)
|
||||||
|
signed, err := t.SignedString(key)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("token: sign JWT: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := &Claims{
|
||||||
|
Issuer: issuer,
|
||||||
|
Subject: subject,
|
||||||
|
IssuedAt: now,
|
||||||
|
ExpiresAt: exp,
|
||||||
|
JTI: jti,
|
||||||
|
Roles: roles,
|
||||||
|
}
|
||||||
|
return signed, claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken parses and validates a JWT string.
|
||||||
|
//
|
||||||
|
// Security order of operations (all must pass):
|
||||||
|
// 1. Parse the token header and extract the alg field.
|
||||||
|
// 2. Reject immediately if alg != "EdDSA" (before any signature check).
|
||||||
|
// 3. Verify Ed25519 signature.
|
||||||
|
// 4. Validate exp, iat, iss, jti claims.
|
||||||
|
//
|
||||||
|
// Returns Claims on success, or a typed error on any failure.
|
||||||
|
// The caller is responsible for checking revocation status via the DB.
|
||||||
|
func ValidateToken(key ed25519.PublicKey, tokenString, expectedIssuer string) (*Claims, error) {
|
||||||
|
// Step 1+2: Parse the header to check alg BEFORE any crypto.
|
||||||
|
// Security: We use jwt.ParseWithClaims with an explicit key function that
|
||||||
|
// enforces the algorithm. The key function is called by the library after
|
||||||
|
// parsing the header but before verifying the signature, which is the
|
||||||
|
// correct point to enforce algorithm constraints.
|
||||||
|
var jc jwtClaims
|
||||||
|
t, err := jwt.ParseWithClaims(tokenString, &jc, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
// Security: Check alg header first. This must happen in the key
|
||||||
|
// function — it is the only place where the parsed (but unverified)
|
||||||
|
// header is available before signature validation.
|
||||||
|
if t.Method.Alg() != requiredAlg {
|
||||||
|
return nil, fmt.Errorf("%w: got %q, want %q", ErrWrongAlgorithm, t.Method.Alg(), requiredAlg)
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
},
|
||||||
|
jwt.WithIssuedAt(),
|
||||||
|
jwt.WithIssuer(expectedIssuer),
|
||||||
|
jwt.WithExpirationRequired(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
// Map library errors to our typed errors for consistent handling.
|
||||||
|
if errors.Is(err, ErrWrongAlgorithm) {
|
||||||
|
return nil, ErrWrongAlgorithm
|
||||||
|
}
|
||||||
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||||
|
return nil, ErrExpiredToken
|
||||||
|
}
|
||||||
|
if errors.Is(err, jwt.ErrSignatureInvalid) {
|
||||||
|
return nil, ErrInvalidSignature
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("token: parse: %w", err)
|
||||||
|
}
|
||||||
|
if !t.Valid {
|
||||||
|
return nil, ErrInvalidSignature
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Validate required custom claims.
|
||||||
|
if jc.ID == "" {
|
||||||
|
return nil, fmt.Errorf("%w: jti", ErrMissingClaim)
|
||||||
|
}
|
||||||
|
if jc.Subject == "" {
|
||||||
|
return nil, fmt.Errorf("%w: sub", ErrMissingClaim)
|
||||||
|
}
|
||||||
|
if jc.ExpiresAt == nil {
|
||||||
|
return nil, fmt.Errorf("%w: exp", ErrMissingClaim)
|
||||||
|
}
|
||||||
|
if jc.IssuedAt == nil {
|
||||||
|
return nil, fmt.Errorf("%w: iat", ErrMissingClaim)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := &Claims{
|
||||||
|
Issuer: jc.Issuer,
|
||||||
|
Subject: jc.Subject,
|
||||||
|
IssuedAt: jc.IssuedAt.Time,
|
||||||
|
ExpiresAt: jc.ExpiresAt.Time,
|
||||||
|
JTI: jc.ID,
|
||||||
|
Roles: jc.Roles,
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRole reports whether the claims include the given role.
|
||||||
|
func (c *Claims) HasRole(role string) bool {
|
||||||
|
for _, r := range c.Roles {
|
||||||
|
if r == role {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
222
internal/token/token_test.go
Normal file
222
internal/token/token_test.go
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
|
||||||
|
t.Helper()
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate key: %v", err)
|
||||||
|
}
|
||||||
|
return pub, priv
|
||||||
|
}
|
||||||
|
|
||||||
|
// b64url encodes a string as base64url without padding.
|
||||||
|
func b64url(s string) string {
|
||||||
|
return base64.RawURLEncoding.EncodeToString([]byte(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
const testIssuer = "https://auth.example.com"
|
||||||
|
|
||||||
|
func TestIssueAndValidateToken(t *testing.T) {
|
||||||
|
pub, priv := generateTestKey(t)
|
||||||
|
roles := []string{"admin", "reader"}
|
||||||
|
|
||||||
|
tokenStr, claims, err := IssueToken(priv, testIssuer, "user-uuid-1", roles, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
if tokenStr == "" {
|
||||||
|
t.Fatal("IssueToken returned empty token string")
|
||||||
|
}
|
||||||
|
if claims.JTI == "" {
|
||||||
|
t.Error("JTI must not be empty")
|
||||||
|
}
|
||||||
|
if claims.Subject != "user-uuid-1" {
|
||||||
|
t.Errorf("Subject = %q, want %q", claims.Subject, "user-uuid-1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the token.
|
||||||
|
got, err := ValidateToken(pub, tokenStr, testIssuer)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ValidateToken: %v", err)
|
||||||
|
}
|
||||||
|
if got.Subject != "user-uuid-1" {
|
||||||
|
t.Errorf("validated Subject = %q, want %q", got.Subject, "user-uuid-1")
|
||||||
|
}
|
||||||
|
if got.JTI != claims.JTI {
|
||||||
|
t.Errorf("validated JTI = %q, want %q", got.JTI, claims.JTI)
|
||||||
|
}
|
||||||
|
if len(got.Roles) != 2 {
|
||||||
|
t.Errorf("validated Roles = %v, want 2 roles", got.Roles)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTokenWrongAlgorithm verifies that tokens with non-EdDSA alg are
|
||||||
|
// rejected immediately, before any signature verification.
|
||||||
|
// Security: This tests the core defence against algorithm-confusion attacks.
|
||||||
|
func TestValidateTokenWrongAlgorithm(t *testing.T) {
|
||||||
|
_, priv := generateTestKey(t)
|
||||||
|
pub, _ := generateTestKey(t) // different key — but alg check should fail first
|
||||||
|
|
||||||
|
// Forge a token signed with HMAC-SHA256 (alg: HS256).
|
||||||
|
hmacToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||||
|
"iss": testIssuer,
|
||||||
|
"sub": "attacker",
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"jti": "fake-jti",
|
||||||
|
})
|
||||||
|
// Use the Ed25519 public key bytes as the HMAC secret (classic alg confusion).
|
||||||
|
hs256Signed, err := hmacToken.SignedString([]byte(priv.Public().(ed25519.PublicKey)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sign HS256 token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = ValidateToken(pub, hs256Signed, testIssuer)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for HS256 token, got nil")
|
||||||
|
}
|
||||||
|
if err != ErrWrongAlgorithm {
|
||||||
|
t.Errorf("expected ErrWrongAlgorithm, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTokenAlgNone verifies that "none" algorithm is rejected.
|
||||||
|
// Security: "none" algorithm tokens have no signature and must always be
|
||||||
|
// rejected regardless of payload content.
|
||||||
|
func TestValidateTokenAlgNone(t *testing.T) {
|
||||||
|
pub, _ := generateTestKey(t)
|
||||||
|
|
||||||
|
// Construct a "none" algorithm token manually.
|
||||||
|
// golang-jwt/v5 disallows signing with "none" directly, so we craft it
|
||||||
|
// using raw base64url encoding.
|
||||||
|
header := `{"alg":"none","typ":"JWT"}`
|
||||||
|
payload := `{"iss":"https://auth.example.com","sub":"evil","iat":1000000,"exp":9999999999,"jti":"evil-jti"}`
|
||||||
|
noneToken := b64url(header) + "." + b64url(payload) + "."
|
||||||
|
|
||||||
|
_, err := ValidateToken(pub, noneToken, testIssuer)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 'none' algorithm token, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTokenExpired verifies that expired tokens are rejected.
|
||||||
|
func TestValidateTokenExpired(t *testing.T) {
|
||||||
|
pub, priv := generateTestKey(t)
|
||||||
|
|
||||||
|
// Issue a token with a negative expiry (already expired).
|
||||||
|
tokenStr, _, err := IssueToken(priv, testIssuer, "user", nil, -time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = ValidateToken(pub, tokenStr, testIssuer)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for expired token, got nil")
|
||||||
|
}
|
||||||
|
if err != ErrExpiredToken {
|
||||||
|
t.Errorf("expected ErrExpiredToken, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTokenTamperedSignature verifies that signature tampering is caught.
|
||||||
|
func TestValidateTokenTamperedSignature(t *testing.T) {
|
||||||
|
pub, priv := generateTestKey(t)
|
||||||
|
|
||||||
|
tokenStr, _, err := IssueToken(priv, testIssuer, "user", nil, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tamper: flip a byte in the signature (last segment).
|
||||||
|
parts := strings.Split(tokenStr, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Fatalf("unexpected token format: %d parts", len(parts))
|
||||||
|
}
|
||||||
|
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decode signature: %v", err)
|
||||||
|
}
|
||||||
|
sigBytes[0] ^= 0x01 // flip one bit
|
||||||
|
parts[2] = base64.RawURLEncoding.EncodeToString(sigBytes)
|
||||||
|
tampered := strings.Join(parts, ".")
|
||||||
|
|
||||||
|
_, err = ValidateToken(pub, tampered, testIssuer)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for tampered signature, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTokenWrongKey verifies that a token signed with a different key
|
||||||
|
// is rejected.
|
||||||
|
func TestValidateTokenWrongKey(t *testing.T) {
|
||||||
|
_, priv := generateTestKey(t)
|
||||||
|
wrongPub, _ := generateTestKey(t)
|
||||||
|
|
||||||
|
tokenStr, _, err := IssueToken(priv, testIssuer, "user", nil, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = ValidateToken(wrongPub, tokenStr, testIssuer)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for wrong key, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidateTokenWrongIssuer verifies that tokens from a different issuer
|
||||||
|
// are rejected.
|
||||||
|
func TestValidateTokenWrongIssuer(t *testing.T) {
|
||||||
|
pub, priv := generateTestKey(t)
|
||||||
|
|
||||||
|
tokenStr, _, err := IssueToken(priv, "https://evil.example.com", "user", nil, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = ValidateToken(pub, tokenStr, testIssuer)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for wrong issuer, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestJTIsAreUnique verifies that two issued tokens have different JTIs.
|
||||||
|
func TestJTIsAreUnique(t *testing.T) {
|
||||||
|
_, priv := generateTestKey(t)
|
||||||
|
|
||||||
|
_, c1, err := IssueToken(priv, testIssuer, "user", nil, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken (1): %v", err)
|
||||||
|
}
|
||||||
|
_, c2, err := IssueToken(priv, testIssuer, "user", nil, time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("IssueToken (2): %v", err)
|
||||||
|
}
|
||||||
|
if c1.JTI == c2.JTI {
|
||||||
|
t.Error("two issued tokens have the same JTI")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClaimsHasRole verifies role checking.
|
||||||
|
func TestClaimsHasRole(t *testing.T) {
|
||||||
|
c := &Claims{Roles: []string{"admin", "reader"}}
|
||||||
|
if !c.HasRole("admin") {
|
||||||
|
t.Error("expected HasRole(admin) = true")
|
||||||
|
}
|
||||||
|
if !c.HasRole("reader") {
|
||||||
|
t.Error("expected HasRole(reader) = true")
|
||||||
|
}
|
||||||
|
if c.HasRole("writer") {
|
||||||
|
t.Error("expected HasRole(writer) = false")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user