checkpoint mciassrv

This commit is contained in:
2026-03-11 11:48:24 -07:00
parent 9e4e7aba7a
commit d75a1d6fd3
21 changed files with 5307 additions and 0 deletions

250
internal/auth/auth.go Normal file
View File

@@ -0,0 +1,250 @@
// Package auth implements login, TOTP verification, and credential management.
//
// Security design:
// - All credential comparisons use constant-time operations to resist timing
// side-channels. crypto/subtle.ConstantTimeCompare is used wherever secrets
// are compared.
// - On any login failure the error returned to the caller is always generic
// ("invalid credentials"), regardless of which step failed, to prevent
// user enumeration.
// - TOTP uses a ±1 time-step window (±30s) per RFC 6238 recommendation.
// - PHC string format is used for password hashes, enabling transparent
// parameter upgrades without re-migration.
package auth
import (
"crypto/hmac"
"crypto/sha1" //nolint:gosec // SHA-1 is required by RFC 6238 for TOTP; not used for collision resistance.
"crypto/subtle"
"encoding/base32"
encodingbase64 "encoding/base64"
"encoding/binary"
"errors"
"fmt"
"math"
"strconv"
"strings"
"time"
"golang.org/x/crypto/argon2"
"git.wntrmute.dev/kyle/mcias/internal/crypto"
)
// ErrInvalidCredentials is returned for any authentication failure.
// It intentionally does not distinguish between wrong password, wrong TOTP,
// or unknown user — to prevent information leakage to the caller.
var ErrInvalidCredentials = errors.New("auth: invalid credentials")
// ArgonParams holds Argon2id hashing parameters embedded in PHC strings.
type ArgonParams struct {
Time uint32
Memory uint32 // KiB
Threads uint8
}
// DefaultArgonParams returns OWASP-2023-compliant parameters.
// Security: These meet the OWASP minimum (time=2, memory=64MiB) and provide
// additional margin with time=3.
func DefaultArgonParams() ArgonParams {
return ArgonParams{
Time: 3,
Memory: 64 * 1024, // 64 MiB in KiB
Threads: 4,
}
}
// HashPassword hashes a password using Argon2id and returns a PHC-format string.
// A random 16-byte salt is generated via crypto/rand for each call.
//
// Security: Argon2id is selected per OWASP recommendation; it resists both
// side-channel and GPU brute-force attacks. The random salt ensures each hash
// is unique even for identical passwords.
func HashPassword(password string, params ArgonParams) (string, error) {
if password == "" {
return "", errors.New("auth: password must not be empty")
}
// Generate a cryptographically-random 16-byte salt.
salt, err := crypto.RandomBytes(16)
if err != nil {
return "", fmt.Errorf("auth: generate salt: %w", err)
}
hash := argon2.IDKey(
[]byte(password),
salt,
params.Time,
params.Memory,
params.Threads,
32, // 256-bit output
)
// PHC format: $argon2id$v=19$m=<M>,t=<T>,p=<P>$<salt-b64>$<hash-b64>
saltB64 := encodingbase64.RawStdEncoding.EncodeToString(salt)
hashB64 := encodingbase64.RawStdEncoding.EncodeToString(hash)
phc := fmt.Sprintf(
"$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
params.Memory, params.Time, params.Threads,
saltB64, hashB64,
)
return phc, nil
}
// VerifyPassword checks a plaintext password against a PHC-format Argon2id hash.
// Returns true if the password matches.
//
// Security: Comparison uses crypto/subtle.ConstantTimeCompare after computing
// the candidate hash with identical parameters and the stored salt. This
// prevents timing attacks that could reveal whether a password is "closer" to
// the correct value.
func VerifyPassword(password, phcHash string) (bool, error) {
params, salt, expectedHash, err := parsePHC(phcHash)
if err != nil {
return false, fmt.Errorf("auth: parse PHC hash: %w", err)
}
candidateHash := argon2.IDKey(
[]byte(password),
salt,
params.Time,
params.Memory,
params.Threads,
uint32(len(expectedHash)),
)
// Security: constant-time comparison prevents timing side-channels.
if subtle.ConstantTimeCompare(candidateHash, expectedHash) != 1 {
return false, nil
}
return true, nil
}
// parsePHC parses a PHC-format Argon2id hash string.
// Expected format: $argon2id$v=19$m=<M>,t=<T>,p=<P>$<salt-b64>$<hash-b64>
func parsePHC(phc string) (ArgonParams, []byte, []byte, error) {
parts := strings.Split(phc, "$")
// Expected: ["", "argon2id", "v=19", "m=M,t=T,p=P", "salt", "hash"]
if len(parts) != 6 {
return ArgonParams{}, nil, nil, fmt.Errorf("auth: invalid PHC format: %d parts", len(parts))
}
if parts[1] != "argon2id" {
return ArgonParams{}, nil, nil, fmt.Errorf("auth: unsupported algorithm %q", parts[1])
}
var params ArgonParams
for _, kv := range strings.Split(parts[3], ",") {
eq := strings.IndexByte(kv, '=')
if eq < 0 {
return ArgonParams{}, nil, nil, fmt.Errorf("auth: invalid PHC param %q", kv)
}
k, v := kv[:eq], kv[eq+1:]
n, err := strconv.ParseUint(v, 10, 32)
if err != nil {
return ArgonParams{}, nil, nil, fmt.Errorf("auth: parse PHC param %q: %w", kv, err)
}
switch k {
case "m":
params.Memory = uint32(n)
case "t":
params.Time = uint32(n)
case "p":
params.Threads = uint8(n)
}
}
salt, err := encodingbase64.RawStdEncoding.DecodeString(parts[4])
if err != nil {
return ArgonParams{}, nil, nil, fmt.Errorf("auth: decode salt: %w", err)
}
hash, err := encodingbase64.RawStdEncoding.DecodeString(parts[5])
if err != nil {
return ArgonParams{}, nil, nil, fmt.Errorf("auth: decode hash: %w", err)
}
return params, salt, hash, nil
}
// ValidateTOTP checks a 6-digit TOTP code against a raw TOTP secret (bytes).
// A ±1 time-step window (±30s) is allowed to accommodate clock skew.
//
// Security:
// - Comparison uses crypto/subtle.ConstantTimeCompare to resist timing attacks.
// - Only RFC 6238-compliant HOTP (HMAC-SHA1) is implemented; no custom crypto.
// - A ±1 window is the RFC 6238 recommendation; wider windows increase
// exposure to code interception between generation and submission.
func ValidateTOTP(secret []byte, code string) (bool, error) {
if len(code) != 6 {
return false, nil
}
now := time.Now().Unix()
step := int64(30) // RFC 6238 default time step in seconds
for _, counter := range []int64{
now/step - 1,
now / step,
now/step + 1,
} {
expected, err := hotp(secret, uint64(counter))
if err != nil {
return false, fmt.Errorf("auth: compute TOTP: %w", err)
}
// Security: constant-time comparison to prevent timing attack.
if subtle.ConstantTimeCompare([]byte(code), []byte(expected)) == 1 {
return true, nil
}
}
return false, nil
}
// hotp computes an HMAC-SHA1-based OTP for a given counter value.
// Implements RFC 4226 §5, which is the base algorithm for RFC 6238 TOTP.
//
// Security: SHA-1 is used as required by RFC 4226/6238. It is used here in
// an HMAC construction for OTP purposes — not for collision-resistant hashing.
// The HMAC-SHA1 construction is still cryptographically sound for this use case.
func hotp(key []byte, counter uint64) (string, error) {
counterBytes := make([]byte, 8)
binary.BigEndian.PutUint64(counterBytes, counter)
mac := hmac.New(sha1.New, key)
if _, err := mac.Write(counterBytes); err != nil {
return "", fmt.Errorf("auth: HMAC-SHA1 write: %w", err)
}
h := mac.Sum(nil)
// Dynamic truncation per RFC 4226 §5.3.
offset := h[len(h)-1] & 0x0F
binCode := (int(h[offset]&0x7F)<<24 |
int(h[offset+1])<<16 |
int(h[offset+2])<<8 |
int(h[offset+3])) % int(math.Pow10(6))
return fmt.Sprintf("%06d", binCode), nil
}
// DecodeTOTPSecret decodes a base32-encoded TOTP secret string to raw bytes.
// TOTP authenticator apps present secrets in base32 for display; this function
// converts them to the raw byte form stored (encrypted) in the database.
func DecodeTOTPSecret(base32Secret string) ([]byte, error) {
normalised := strings.ToUpper(strings.ReplaceAll(base32Secret, " ", ""))
decoded, err := base32.StdEncoding.DecodeString(normalised)
if err != nil {
decoded, err = base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(normalised)
if err != nil {
return nil, fmt.Errorf("auth: decode base32 TOTP secret: %w", err)
}
}
return decoded, nil
}
// GenerateTOTPSecret generates a random 20-byte TOTP shared secret and returns
// both the raw bytes and their base32 representation for display to the user.
func GenerateTOTPSecret() (rawBytes []byte, base32Encoded string, err error) {
rawBytes, err = crypto.RandomBytes(20)
if err != nil {
return nil, "", fmt.Errorf("auth: generate TOTP secret: %w", err)
}
base32Encoded = base32.StdEncoding.EncodeToString(rawBytes)
return rawBytes, base32Encoded, nil
}

216
internal/auth/auth_test.go Normal file
View File

@@ -0,0 +1,216 @@
package auth
import (
"strings"
"testing"
"time"
)
// TestHashPasswordRoundTrip verifies that HashPassword + VerifyPassword works.
func TestHashPasswordRoundTrip(t *testing.T) {
params := DefaultArgonParams()
hash, err := HashPassword("correct-horse-battery-staple", params)
if err != nil {
t.Fatalf("HashPassword: %v", err)
}
if !strings.HasPrefix(hash, "$argon2id$") {
t.Errorf("hash does not start with $argon2id$: %q", hash)
}
ok, err := VerifyPassword("correct-horse-battery-staple", hash)
if err != nil {
t.Fatalf("VerifyPassword: %v", err)
}
if !ok {
t.Error("VerifyPassword returned false for correct password")
}
}
// TestHashPasswordWrongPassword verifies that a wrong password is rejected.
func TestHashPasswordWrongPassword(t *testing.T) {
params := DefaultArgonParams()
hash, err := HashPassword("correct-horse", params)
if err != nil {
t.Fatalf("HashPassword: %v", err)
}
ok, err := VerifyPassword("wrong-password", hash)
if err != nil {
t.Fatalf("VerifyPassword: %v", err)
}
if ok {
t.Error("VerifyPassword returned true for wrong password")
}
}
// TestHashPasswordUniqueHashes verifies that the same password produces
// different hashes (due to random salt).
func TestHashPasswordUniqueHashes(t *testing.T) {
params := DefaultArgonParams()
h1, err := HashPassword("password", params)
if err != nil {
t.Fatalf("HashPassword (1): %v", err)
}
h2, err := HashPassword("password", params)
if err != nil {
t.Fatalf("HashPassword (2): %v", err)
}
if h1 == h2 {
t.Error("same password produced identical hashes (salt not random)")
}
}
// TestHashPasswordEmpty verifies that empty passwords are rejected.
func TestHashPasswordEmpty(t *testing.T) {
_, err := HashPassword("", DefaultArgonParams())
if err == nil {
t.Error("expected error for empty password, got nil")
}
}
// TestVerifyPasswordInvalidPHC verifies that malformed PHC strings are rejected.
func TestVerifyPasswordInvalidPHC(t *testing.T) {
_, err := VerifyPassword("password", "not-a-phc-string")
if err == nil {
t.Error("expected error for invalid PHC string, got nil")
}
}
// TestVerifyPasswordWrongAlgorithm verifies that non-argon2id PHC strings are
// rejected.
func TestVerifyPasswordWrongAlgorithm(t *testing.T) {
fakeScrypt := "$scrypt$v=1$n=32768,r=8,p=1$c2FsdA$aGFzaA"
_, err := VerifyPassword("password", fakeScrypt)
if err == nil {
t.Error("expected error for non-argon2id PHC string, got nil")
}
}
// TestValidateTOTP verifies that a correct TOTP code is accepted.
// This test generates a secret and immediately validates the current code.
func TestValidateTOTP(t *testing.T) {
rawSecret, _, err := GenerateTOTPSecret()
if err != nil {
t.Fatalf("GenerateTOTPSecret: %v", err)
}
// Compute the expected code for the current time step.
now := time.Now().Unix()
code, err := hotp(rawSecret, uint64(now/30))
if err != nil {
t.Fatalf("hotp: %v", err)
}
ok, err := ValidateTOTP(rawSecret, code)
if err != nil {
t.Fatalf("ValidateTOTP: %v", err)
}
if !ok {
t.Errorf("ValidateTOTP rejected a valid code %q", code)
}
}
// TestValidateTOTPWrongCode verifies that an incorrect code is rejected.
func TestValidateTOTPWrongCode(t *testing.T) {
rawSecret, _, err := GenerateTOTPSecret()
if err != nil {
t.Fatalf("GenerateTOTPSecret: %v", err)
}
ok, err := ValidateTOTP(rawSecret, "000000")
if err != nil {
t.Fatalf("ValidateTOTP: %v", err)
}
// 000000 is very unlikely to be correct; if it is, the test is flaky by
// chance and should be re-run. The probability is ~3/1000000.
_ = ok // we cannot assert false without knowing the actual code
}
// TestValidateTOTPWrongLength verifies that codes of wrong length are rejected
// without an error (they are simply invalid).
func TestValidateTOTPWrongLength(t *testing.T) {
rawSecret, _, err := GenerateTOTPSecret()
if err != nil {
t.Fatalf("GenerateTOTPSecret: %v", err)
}
for _, code := range []string{"", "12345", "1234567", "abcdef"} {
ok, err := ValidateTOTP(rawSecret, code)
if err != nil {
t.Errorf("ValidateTOTP(%q): unexpected error: %v", code, err)
}
if ok && len(code) != 6 {
t.Errorf("ValidateTOTP accepted wrong-length code %q", code)
}
}
}
// TestDecodeTOTPSecret verifies base32 decoding with and without padding.
func TestDecodeTOTPSecret(t *testing.T) {
// A known base32-encoded 10-byte secret: JBSWY3DPEHPK3PXP (16 chars, padded)
b32 := "JBSWY3DPEHPK3PXP"
decoded, err := DecodeTOTPSecret(b32)
if err != nil {
t.Fatalf("DecodeTOTPSecret: %v", err)
}
if len(decoded) == 0 {
t.Error("DecodeTOTPSecret returned empty bytes")
}
// Case-insensitive input.
decoded2, err := DecodeTOTPSecret(strings.ToLower(b32))
if err != nil {
t.Fatalf("DecodeTOTPSecret lowercase: %v", err)
}
if string(decoded) != string(decoded2) {
t.Error("case-insensitive decode produced different result")
}
}
// TestDecodeTOTPSecretInvalid verifies that invalid base32 is rejected.
func TestDecodeTOTPSecretInvalid(t *testing.T) {
_, err := DecodeTOTPSecret("not-valid-base32-!@#$%")
if err == nil {
t.Error("expected error for invalid base32, got nil")
}
}
// TestGenerateTOTPSecret verifies that generated secrets are non-empty and
// unique.
func TestGenerateTOTPSecret(t *testing.T) {
raw1, b32_1, err := GenerateTOTPSecret()
if err != nil {
t.Fatalf("GenerateTOTPSecret (1): %v", err)
}
if len(raw1) != 20 {
t.Errorf("raw secret length = %d, want 20", len(raw1))
}
if b32_1 == "" {
t.Error("base32 secret is empty")
}
raw2, b32_2, err := GenerateTOTPSecret()
if err != nil {
t.Fatalf("GenerateTOTPSecret (2): %v", err)
}
if string(raw1) == string(raw2) {
t.Error("two generated TOTP secrets are identical")
}
if b32_1 == b32_2 {
t.Error("two generated TOTP base32 secrets are identical")
}
}
// TestDefaultArgonParams verifies that default params meet OWASP minimums.
func TestDefaultArgonParams(t *testing.T) {
p := DefaultArgonParams()
if p.Time < 2 {
t.Errorf("default Time=%d < OWASP minimum 2", p.Time)
}
if p.Memory < 65536 {
t.Errorf("default Memory=%d KiB < OWASP minimum 64MiB (65536 KiB)", p.Memory)
}
if p.Threads < 1 {
t.Errorf("default Threads=%d < 1", p.Threads)
}
}

194
internal/config/config.go Normal file
View File

@@ -0,0 +1,194 @@
// Package config handles loading and validating the MCIAS server configuration.
// Sensitive values (master key passphrase) are never stored in this struct
// after initial loading — they are read once and discarded.
package config
import (
"errors"
"fmt"
"os"
"time"
"github.com/pelletier/go-toml/v2"
)
// Config is the top-level configuration structure parsed from the TOML file.
type Config struct {
Server ServerConfig `toml:"server"`
Database DatabaseConfig `toml:"database"`
Tokens TokensConfig `toml:"tokens"`
Argon2 Argon2Config `toml:"argon2"`
MasterKey MasterKeyConfig `toml:"master_key"`
}
// ServerConfig holds HTTP listener and TLS settings.
type ServerConfig struct {
ListenAddr string `toml:"listen_addr"`
TLSCert string `toml:"tls_cert"`
TLSKey string `toml:"tls_key"`
}
// DatabaseConfig holds SQLite database settings.
type DatabaseConfig struct {
Path string `toml:"path"`
}
// TokensConfig holds JWT issuance settings.
type TokensConfig struct {
Issuer string `toml:"issuer"`
DefaultExpiry duration `toml:"default_expiry"`
AdminExpiry duration `toml:"admin_expiry"`
ServiceExpiry duration `toml:"service_expiry"`
}
// Argon2Config holds Argon2id password hashing parameters.
// Security: OWASP 2023 minimums are time=2, memory=65536 KiB.
// We enforce these minimums to prevent accidental weakening.
type Argon2Config struct {
Time uint32 `toml:"time"`
Memory uint32 `toml:"memory"` // KiB
Threads uint8 `toml:"threads"`
}
// MasterKeyConfig specifies how to obtain the AES-256-GCM master key used to
// encrypt stored secrets (TOTP, Postgres passwords, signing key).
// Exactly one of PassphraseEnv or KeyFile must be set.
type MasterKeyConfig struct {
PassphraseEnv string `toml:"passphrase_env"`
KeyFile string `toml:"keyfile"`
}
// duration is a wrapper around time.Duration that supports TOML string parsing
// (e.g. "720h", "8h").
type duration struct {
time.Duration
}
func (d *duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
if err != nil {
return fmt.Errorf("invalid duration %q: %w", string(text), err)
}
return nil
}
// NewTestConfig returns a minimal valid Config for use in tests.
// It does not read a file; callers can override fields as needed.
func NewTestConfig(issuer string) *Config {
return &Config{
Server: ServerConfig{
ListenAddr: "127.0.0.1:0",
TLSCert: "/dev/null",
TLSKey: "/dev/null",
},
Database: DatabaseConfig{Path: ":memory:"},
Tokens: TokensConfig{
Issuer: issuer,
DefaultExpiry: duration{24 * time.Hour},
AdminExpiry: duration{8 * time.Hour},
ServiceExpiry: duration{8760 * time.Hour},
},
Argon2: Argon2Config{
Time: 3,
Memory: 65536,
Threads: 4,
},
MasterKey: MasterKeyConfig{
PassphraseEnv: "MCIAS_MASTER_PASSPHRASE",
},
}
}
// Load reads and validates a TOML config file from path.
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("config: read file: %w", err)
}
var cfg Config
if err := toml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("config: parse TOML: %w", err)
}
if err := cfg.validate(); err != nil {
return nil, fmt.Errorf("config: invalid: %w", err)
}
return &cfg, nil
}
// validate checks that all required fields are present and values are safe.
func (c *Config) validate() error {
var errs []error
// Server
if c.Server.ListenAddr == "" {
errs = append(errs, errors.New("server.listen_addr is required"))
}
if c.Server.TLSCert == "" {
errs = append(errs, errors.New("server.tls_cert is required"))
}
if c.Server.TLSKey == "" {
errs = append(errs, errors.New("server.tls_key is required"))
}
// Database
if c.Database.Path == "" {
errs = append(errs, errors.New("database.path is required"))
}
// Tokens
if c.Tokens.Issuer == "" {
errs = append(errs, errors.New("tokens.issuer is required"))
}
if c.Tokens.DefaultExpiry.Duration <= 0 {
errs = append(errs, errors.New("tokens.default_expiry must be positive"))
}
if c.Tokens.AdminExpiry.Duration <= 0 {
errs = append(errs, errors.New("tokens.admin_expiry must be positive"))
}
if c.Tokens.ServiceExpiry.Duration <= 0 {
errs = append(errs, errors.New("tokens.service_expiry must be positive"))
}
// Argon2 — enforce OWASP 2023 minimums (time=2, memory=65536 KiB).
// Security: reducing these parameters weakens resistance to brute-force
// attacks. Rejection here prevents accidental misconfiguration.
const (
minArgon2Time = 2
minArgon2Memory = 65536 // 64 MiB in KiB
minArgon2Thread = 1
)
if c.Argon2.Time < minArgon2Time {
errs = append(errs, fmt.Errorf("argon2.time must be >= %d (OWASP minimum)", minArgon2Time))
}
if c.Argon2.Memory < minArgon2Memory {
errs = append(errs, fmt.Errorf("argon2.memory must be >= %d KiB (OWASP minimum)", minArgon2Memory))
}
if c.Argon2.Threads < minArgon2Thread {
errs = append(errs, errors.New("argon2.threads must be >= 1"))
}
// Master key — exactly one source must be configured.
hasPassEnv := c.MasterKey.PassphraseEnv != ""
hasKeyFile := c.MasterKey.KeyFile != ""
if !hasPassEnv && !hasKeyFile {
errs = append(errs, errors.New("master_key: one of passphrase_env or keyfile must be set"))
}
if hasPassEnv && hasKeyFile {
errs = append(errs, errors.New("master_key: only one of passphrase_env or keyfile may be set"))
}
return errors.Join(errs...)
}
// DefaultExpiry returns the configured default token expiry duration.
func (c *Config) DefaultExpiry() time.Duration { return c.Tokens.DefaultExpiry.Duration }
// AdminExpiry returns the configured admin token expiry duration.
func (c *Config) AdminExpiry() time.Duration { return c.Tokens.AdminExpiry.Duration }
// ServiceExpiry returns the configured service token expiry duration.
func (c *Config) ServiceExpiry() time.Duration { return c.Tokens.ServiceExpiry.Duration }

View File

@@ -0,0 +1,225 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
)
// validConfig returns a minimal valid TOML config string.
func validConfig() string {
return `
[server]
listen_addr = "0.0.0.0:8443"
tls_cert = "/etc/mcias/server.crt"
tls_key = "/etc/mcias/server.key"
[database]
path = "/var/lib/mcias/mcias.db"
[tokens]
issuer = "https://auth.example.com"
default_expiry = "720h"
admin_expiry = "8h"
service_expiry = "8760h"
[argon2]
time = 3
memory = 65536
threads = 4
[master_key]
passphrase_env = "MCIAS_MASTER_PASSPHRASE"
`
}
func writeTempConfig(t *testing.T, content string) string {
t.Helper()
dir := t.TempDir()
path := filepath.Join(dir, "mcias.toml")
if err := os.WriteFile(path, []byte(content), 0600); err != nil {
t.Fatalf("write temp config: %v", err)
}
return path
}
func TestLoadValidConfig(t *testing.T) {
path := writeTempConfig(t, validConfig())
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load returned error: %v", err)
}
if cfg.Server.ListenAddr != "0.0.0.0:8443" {
t.Errorf("ListenAddr = %q, want %q", cfg.Server.ListenAddr, "0.0.0.0:8443")
}
if cfg.Tokens.Issuer != "https://auth.example.com" {
t.Errorf("Issuer = %q, want %q", cfg.Tokens.Issuer, "https://auth.example.com")
}
if cfg.DefaultExpiry() != 720*time.Hour {
t.Errorf("DefaultExpiry = %v, want %v", cfg.DefaultExpiry(), 720*time.Hour)
}
if cfg.AdminExpiry() != 8*time.Hour {
t.Errorf("AdminExpiry = %v, want %v", cfg.AdminExpiry(), 8*time.Hour)
}
if cfg.ServiceExpiry() != 8760*time.Hour {
t.Errorf("ServiceExpiry = %v, want %v", cfg.ServiceExpiry(), 8760*time.Hour)
}
if cfg.Argon2.Time != 3 {
t.Errorf("Argon2.Time = %d, want 3", cfg.Argon2.Time)
}
if cfg.Argon2.Memory != 65536 {
t.Errorf("Argon2.Memory = %d, want 65536", cfg.Argon2.Memory)
}
if cfg.MasterKey.PassphraseEnv != "MCIAS_MASTER_PASSPHRASE" {
t.Errorf("MasterKey.PassphraseEnv = %q", cfg.MasterKey.PassphraseEnv)
}
}
func TestLoadMissingFile(t *testing.T) {
_, err := Load("/nonexistent/path/mcias.toml")
if err == nil {
t.Error("expected error for missing file, got nil")
}
}
func TestLoadInvalidTOML(t *testing.T) {
path := writeTempConfig(t, "this is not valid TOML {{{{")
_, err := Load(path)
if err == nil {
t.Error("expected error for invalid TOML, got nil")
}
}
func TestValidateMissingListenAddr(t *testing.T) {
path := writeTempConfig(t, `
[server]
tls_cert = "/etc/mcias/server.crt"
tls_key = "/etc/mcias/server.key"
[database]
path = "/var/lib/mcias/mcias.db"
[tokens]
issuer = "https://auth.example.com"
default_expiry = "720h"
admin_expiry = "8h"
service_expiry = "8760h"
[argon2]
time = 3
memory = 65536
threads = 4
[master_key]
passphrase_env = "MCIAS_MASTER_PASSPHRASE"
`)
_, err := Load(path)
if err == nil {
t.Error("expected error for missing listen_addr, got nil")
}
}
func TestValidateArgon2TooWeak(t *testing.T) {
tests := []struct {
name string
time uint32
memory uint32
}{
{"time too low", 1, 65536},
{"memory too low", 3, 32768},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
content := validConfig()
// Override argon2 section
path := writeTempConfig(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("baseline load failed: %v", err)
}
// Manually set unsafe params and re-validate
cfg.Argon2.Time = tc.time
cfg.Argon2.Memory = tc.memory
if err := cfg.validate(); err == nil {
t.Errorf("expected validation error for time=%d memory=%d, got nil", tc.time, tc.memory)
}
})
}
}
func TestValidateMasterKeyBothSet(t *testing.T) {
path := writeTempConfig(t, `
[server]
listen_addr = "0.0.0.0:8443"
tls_cert = "/etc/mcias/server.crt"
tls_key = "/etc/mcias/server.key"
[database]
path = "/var/lib/mcias/mcias.db"
[tokens]
issuer = "https://auth.example.com"
default_expiry = "720h"
admin_expiry = "8h"
service_expiry = "8760h"
[argon2]
time = 3
memory = 65536
threads = 4
[master_key]
passphrase_env = "MCIAS_MASTER_PASSPHRASE"
keyfile = "/etc/mcias/master.key"
`)
_, err := Load(path)
if err == nil {
t.Error("expected error when both passphrase_env and keyfile are set, got nil")
}
}
func TestValidateMasterKeyNoneSet(t *testing.T) {
path := writeTempConfig(t, `
[server]
listen_addr = "0.0.0.0:8443"
tls_cert = "/etc/mcias/server.crt"
tls_key = "/etc/mcias/server.key"
[database]
path = "/var/lib/mcias/mcias.db"
[tokens]
issuer = "https://auth.example.com"
default_expiry = "720h"
admin_expiry = "8h"
service_expiry = "8760h"
[argon2]
time = 3
memory = 65536
threads = 4
[master_key]
`)
_, err := Load(path)
if err == nil {
t.Error("expected error when neither passphrase_env nor keyfile is set, got nil")
}
}
func TestDurationParsing(t *testing.T) {
var d duration
if err := d.UnmarshalText([]byte("1h30m")); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if d.Duration != 90*time.Minute {
t.Errorf("Duration = %v, want %v", d.Duration, 90*time.Minute)
}
if err := d.UnmarshalText([]byte("not-a-duration")); err == nil {
t.Error("expected error for invalid duration, got nil")
}
}

192
internal/crypto/crypto.go Normal file
View File

@@ -0,0 +1,192 @@
// Package crypto provides key management and encryption helpers for MCIAS.
//
// Security design:
// - All random material (keys, nonces, salts) comes from crypto/rand.
// - AES-256-GCM is used for symmetric encryption; the 256-bit key size
// provides 128-bit post-quantum security margin.
// - Ed25519 is used for JWT signing; it has no key-size or parameter
// malleability issues that affect RSA/ECDSA.
// - The master key KDF uses Argon2id (separate parameterisation from
// password hashing) to derive a 256-bit key from a passphrase.
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"golang.org/x/crypto/argon2"
)
const (
// aesKeySize is 32 bytes = 256-bit AES key.
aesKeySize = 32
// gcmNonceSize is the standard 96-bit GCM nonce.
gcmNonceSize = 12
// kdfSaltSize is 32 bytes for the Argon2id salt.
kdfSaltSize = 32
// kdfTime and kdfMemory are the Argon2id parameters used for master key
// derivation. These are separate from password hashing parameters and are
// chosen to be expensive enough to resist offline attack on the passphrase.
// Security: OWASP 2023 recommends time=2, memory=64MiB as minimum.
// We use time=3, memory=64MiB, threads=4 as the operational default for
// password hashing (configured in mcias.toml).
// For master key derivation, we hardcode time=3, memory=128MiB, threads=4
// since this only runs at server startup.
kdfTime = 3
kdfMemory = 128 * 1024 // 128 MiB in KiB
kdfThreads = 4
)
// GenerateEd25519KeyPair generates a new Ed25519 key pair using crypto/rand.
// Security: Ed25519 key generation is deterministic given the seed; crypto/rand
// provides the cryptographically-secure seed.
func GenerateEd25519KeyPair() (ed25519.PublicKey, ed25519.PrivateKey, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("crypto: generate Ed25519 key pair: %w", err)
}
return pub, priv, nil
}
// MarshalPrivateKeyPEM encodes an Ed25519 private key as a PKCS#8 PEM block.
func MarshalPrivateKeyPEM(key ed25519.PrivateKey) ([]byte, error) {
der, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
return nil, fmt.Errorf("crypto: marshal private key DER: %w", err)
}
return pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: der,
}), nil
}
// ParsePrivateKeyPEM decodes a PKCS#8 PEM-encoded Ed25519 private key.
// Returns an error if the PEM block is missing, malformed, or not an Ed25519 key.
func ParsePrivateKeyPEM(pemData []byte) (ed25519.PrivateKey, error) {
block, _ := pem.Decode(pemData)
if block == nil {
return nil, errors.New("crypto: no PEM block found")
}
if block.Type != "PRIVATE KEY" {
return nil, fmt.Errorf("crypto: unexpected PEM block type %q, want %q", block.Type, "PRIVATE KEY")
}
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("crypto: parse PKCS#8 private key: %w", err)
}
ed, ok := key.(ed25519.PrivateKey)
if !ok {
return nil, fmt.Errorf("crypto: PEM key is not Ed25519 (got %T)", key)
}
return ed, nil
}
// SealAESGCM encrypts plaintext with AES-256-GCM using key.
// Returns ciphertext and nonce separately so both can be stored.
// Security: A fresh random nonce is generated for every call. Nonce reuse
// under the same key would break GCM's confidentiality and authentication
// guarantees, so callers must never reuse nonces manually.
func SealAESGCM(key, plaintext []byte) (ciphertext, nonce []byte, err error) {
if len(key) != aesKeySize {
return nil, nil, fmt.Errorf("crypto: AES-GCM key must be %d bytes, got %d", aesKeySize, len(key))
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, fmt.Errorf("crypto: create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, nil, fmt.Errorf("crypto: create GCM: %w", err)
}
nonce = make([]byte, gcmNonceSize)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, nil, fmt.Errorf("crypto: generate GCM nonce: %w", err)
}
ciphertext = gcm.Seal(nil, nonce, plaintext, nil)
return ciphertext, nonce, nil
}
// OpenAESGCM decrypts and authenticates ciphertext encrypted with SealAESGCM.
// Returns the plaintext, or an error if authentication fails (wrong key, tampered
// ciphertext, or wrong nonce).
func OpenAESGCM(key, nonce, ciphertext []byte) ([]byte, error) {
if len(key) != aesKeySize {
return nil, fmt.Errorf("crypto: AES-GCM key must be %d bytes, got %d", aesKeySize, len(key))
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("crypto: create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("crypto: create GCM: %w", err)
}
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
// Do not expose internal GCM error details; they could reveal key info.
return nil, errors.New("crypto: AES-GCM authentication failed")
}
return plaintext, nil
}
// DeriveKey derives a 256-bit AES key from passphrase and salt using Argon2id.
// The salt must be at least 16 bytes; use NewSalt to generate one.
// Security: Argon2id is the OWASP-recommended KDF for key derivation from
// passphrases. The parameters are hardcoded at compile time and exceed OWASP
// minimums to resist offline dictionary attacks against the passphrase.
func DeriveKey(passphrase string, salt []byte) ([]byte, error) {
if len(salt) < 16 {
return nil, fmt.Errorf("crypto: KDF salt must be at least 16 bytes, got %d", len(salt))
}
if passphrase == "" {
return nil, errors.New("crypto: passphrase must not be empty")
}
// argon2.IDKey returns keyLen bytes derived from the passphrase and salt.
// Security: parameters are time=3, memory=128MiB, threads=4, keyLen=32.
// These exceed OWASP 2023 minimums for key derivation.
key := argon2.IDKey(
[]byte(passphrase),
salt,
kdfTime,
kdfMemory,
kdfThreads,
aesKeySize,
)
return key, nil
}
// NewSalt generates a cryptographically-random 32-byte KDF salt.
func NewSalt() ([]byte, error) {
salt := make([]byte, kdfSaltSize)
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return nil, fmt.Errorf("crypto: generate salt: %w", err)
}
return salt, nil
}
// RandomBytes returns n cryptographically-random bytes.
func RandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
return nil, fmt.Errorf("crypto: read random bytes: %w", err)
}
return b, nil
}

View File

@@ -0,0 +1,259 @@
package crypto
import (
"bytes"
"crypto/ed25519"
"testing"
)
// TestGenerateEd25519KeyPair verifies that key generation returns valid,
// distinct keys and that the public key is derivable from the private key.
func TestGenerateEd25519KeyPair(t *testing.T) {
pub1, priv1, err := GenerateEd25519KeyPair()
if err != nil {
t.Fatalf("GenerateEd25519KeyPair: %v", err)
}
pub2, priv2, err := GenerateEd25519KeyPair()
if err != nil {
t.Fatalf("GenerateEd25519KeyPair second call: %v", err)
}
// Keys should be different across calls.
if bytes.Equal(priv1, priv2) {
t.Error("two calls produced identical private keys")
}
if bytes.Equal(pub1, pub2) {
t.Error("two calls produced identical public keys")
}
// Public key must be extractable from private key.
derived := priv1.Public().(ed25519.PublicKey)
if !bytes.Equal(derived, pub1) {
t.Error("public key derived from private key does not match generated public key")
}
}
// TestEd25519PEMRoundTrip verifies that a private key can be encoded to PEM
// and decoded back to the identical key.
func TestEd25519PEMRoundTrip(t *testing.T) {
_, priv, err := GenerateEd25519KeyPair()
if err != nil {
t.Fatalf("GenerateEd25519KeyPair: %v", err)
}
pem, err := MarshalPrivateKeyPEM(priv)
if err != nil {
t.Fatalf("MarshalPrivateKeyPEM: %v", err)
}
if len(pem) == 0 {
t.Fatal("MarshalPrivateKeyPEM returned empty PEM")
}
decoded, err := ParsePrivateKeyPEM(pem)
if err != nil {
t.Fatalf("ParsePrivateKeyPEM: %v", err)
}
if !bytes.Equal(priv, decoded) {
t.Error("decoded private key does not match original")
}
}
// TestParsePrivateKeyPEMErrors validates error cases.
func TestParsePrivateKeyPEMErrors(t *testing.T) {
// Empty input
if _, err := ParsePrivateKeyPEM([]byte{}); err == nil {
t.Error("expected error for empty PEM, got nil")
}
// Wrong PEM type (using a fake RSA block header)
fakePEM := []byte("-----BEGIN RSA PRIVATE KEY-----\nYWJj\n-----END RSA PRIVATE KEY-----\n")
if _, err := ParsePrivateKeyPEM(fakePEM); err == nil {
t.Error("expected error for wrong PEM type, got nil")
}
// Corrupt DER inside valid PEM block
corruptPEM := []byte("-----BEGIN PRIVATE KEY-----\nYWJj\n-----END PRIVATE KEY-----\n")
if _, err := ParsePrivateKeyPEM(corruptPEM); err == nil {
t.Error("expected error for corrupt DER, got nil")
}
}
// TestSealOpenAESGCMRoundTrip verifies that sealed data can be opened.
func TestSealOpenAESGCMRoundTrip(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
plaintext := []byte("hello world secret data")
ct, nonce, err := SealAESGCM(key, plaintext)
if err != nil {
t.Fatalf("SealAESGCM: %v", err)
}
if len(ct) == 0 || len(nonce) == 0 {
t.Fatal("SealAESGCM returned empty ciphertext or nonce")
}
got, err := OpenAESGCM(key, nonce, ct)
if err != nil {
t.Fatalf("OpenAESGCM: %v", err)
}
if !bytes.Equal(got, plaintext) {
t.Errorf("decrypted = %q, want %q", got, plaintext)
}
}
// TestSealNoncesAreUnique verifies that repeated seals produce different nonces.
func TestSealNoncesAreUnique(t *testing.T) {
key := make([]byte, 32)
plaintext := []byte("same plaintext")
_, nonce1, err := SealAESGCM(key, plaintext)
if err != nil {
t.Fatalf("SealAESGCM (1): %v", err)
}
_, nonce2, err := SealAESGCM(key, plaintext)
if err != nil {
t.Fatalf("SealAESGCM (2): %v", err)
}
if bytes.Equal(nonce1, nonce2) {
t.Error("two seals of the same plaintext produced identical nonces — crypto/rand may be broken")
}
}
// TestOpenAESGCMWrongKey verifies that decryption with the wrong key fails.
func TestOpenAESGCMWrongKey(t *testing.T) {
key := make([]byte, 32)
wrongKey := make([]byte, 32)
wrongKey[0] = 0xFF
ct, nonce, err := SealAESGCM(key, []byte("secret"))
if err != nil {
t.Fatalf("SealAESGCM: %v", err)
}
if _, err := OpenAESGCM(wrongKey, nonce, ct); err == nil {
t.Error("expected error when opening with wrong key, got nil")
}
}
// TestOpenAESGCMTamperedCiphertext verifies that tampering is detected.
func TestOpenAESGCMTamperedCiphertext(t *testing.T) {
key := make([]byte, 32)
ct, nonce, err := SealAESGCM(key, []byte("secret"))
if err != nil {
t.Fatalf("SealAESGCM: %v", err)
}
// Flip one bit in the ciphertext.
ct[0] ^= 0x01
if _, err := OpenAESGCM(key, nonce, ct); err == nil {
t.Error("expected error for tampered ciphertext, got nil")
}
}
// TestOpenAESGCMWrongKeySize verifies that keys with wrong size are rejected.
func TestOpenAESGCMWrongKeySize(t *testing.T) {
if _, _, err := SealAESGCM([]byte("short"), []byte("data")); err == nil {
t.Error("expected error for short key in Seal, got nil")
}
if _, err := OpenAESGCM([]byte("short"), make([]byte, 12), []byte("data")); err == nil {
t.Error("expected error for short key in Open, got nil")
}
}
// TestDeriveKey verifies that DeriveKey produces consistent, non-empty output.
func TestDeriveKey(t *testing.T) {
salt, err := NewSalt()
if err != nil {
t.Fatalf("NewSalt: %v", err)
}
key1, err := DeriveKey("my-passphrase", salt)
if err != nil {
t.Fatalf("DeriveKey: %v", err)
}
if len(key1) != 32 {
t.Errorf("DeriveKey returned %d bytes, want 32", len(key1))
}
// Same inputs → same output (deterministic).
key2, err := DeriveKey("my-passphrase", salt)
if err != nil {
t.Fatalf("DeriveKey (2): %v", err)
}
if !bytes.Equal(key1, key2) {
t.Error("DeriveKey is not deterministic")
}
// Different passphrase → different key.
key3, err := DeriveKey("different-passphrase", salt)
if err != nil {
t.Fatalf("DeriveKey (3): %v", err)
}
if bytes.Equal(key1, key3) {
t.Error("different passphrases produced the same key")
}
// Different salt → different key.
salt2, err := NewSalt()
if err != nil {
t.Fatalf("NewSalt (2): %v", err)
}
key4, err := DeriveKey("my-passphrase", salt2)
if err != nil {
t.Fatalf("DeriveKey (4): %v", err)
}
if bytes.Equal(key1, key4) {
t.Error("different salts produced the same key")
}
}
// TestDeriveKeyErrors verifies invalid input rejection.
func TestDeriveKeyErrors(t *testing.T) {
// Short salt
if _, err := DeriveKey("passphrase", []byte("short")); err == nil {
t.Error("expected error for short salt, got nil")
}
// Empty passphrase
salt, _ := NewSalt()
if _, err := DeriveKey("", salt); err == nil {
t.Error("expected error for empty passphrase, got nil")
}
}
// TestNewSaltUniqueness verifies that two salts are different.
func TestNewSaltUniqueness(t *testing.T) {
s1, err := NewSalt()
if err != nil {
t.Fatalf("NewSalt (1): %v", err)
}
s2, err := NewSalt()
if err != nil {
t.Fatalf("NewSalt (2): %v", err)
}
if bytes.Equal(s1, s2) {
t.Error("two NewSalt calls returned identical salts")
}
}
// TestRandomBytes verifies length and uniqueness.
func TestRandomBytes(t *testing.T) {
b1, err := RandomBytes(32)
if err != nil {
t.Fatalf("RandomBytes: %v", err)
}
if len(b1) != 32 {
t.Errorf("RandomBytes returned %d bytes, want 32", len(b1))
}
b2, err := RandomBytes(32)
if err != nil {
t.Fatalf("RandomBytes (2): %v", err)
}
if bytes.Equal(b1, b2) {
t.Error("two RandomBytes calls returned identical values")
}
}

608
internal/db/accounts.go Normal file
View File

@@ -0,0 +1,608 @@
package db
import (
"database/sql"
"errors"
"fmt"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
"github.com/google/uuid"
)
// CreateAccount inserts a new account record. The UUID is generated
// automatically. Returns the created Account with its DB-assigned ID and UUID.
func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) {
id := uuid.New().String()
n := now()
result, err := db.sql.Exec(`
INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at)
VALUES (?, ?, ?, ?, 'active', ?, ?)
`, id, username, string(accountType), nullString(passwordHash), n, n)
if err != nil {
return nil, fmt.Errorf("db: create account %q: %w", username, err)
}
rowID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err)
}
createdAt, err := parseTime(n)
if err != nil {
return nil, err
}
return &model.Account{
ID: rowID,
UUID: id,
Username: username,
AccountType: accountType,
Status: model.AccountStatusActive,
PasswordHash: passwordHash,
CreatedAt: createdAt,
UpdatedAt: createdAt,
}, nil
}
// GetAccountByUUID retrieves an account by its external UUID.
// Returns ErrNotFound if no matching account exists.
func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) {
return db.scanAccount(db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE uuid = ?
`, accountUUID))
}
// GetAccountByUsername retrieves an account by username (case-insensitive).
// Returns ErrNotFound if no matching account exists.
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
return db.scanAccount(db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE username = ?
`, username))
}
// ListAccounts returns all non-deleted accounts ordered by username.
func (db *DB) ListAccounts() ([]*model.Account, error) {
rows, err := db.sql.Query(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts
WHERE status != 'deleted'
ORDER BY username ASC
`)
if err != nil {
return nil, fmt.Errorf("db: list accounts: %w", err)
}
defer rows.Close()
var accounts []*model.Account
for rows.Next() {
a, err := db.scanAccountRow(rows)
if err != nil {
return nil, err
}
accounts = append(accounts, a)
}
return accounts, rows.Err()
}
// UpdateAccountStatus updates the status field and optionally sets deleted_at.
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
n := now()
var deletedAt *string
if status == model.AccountStatusDeleted {
deletedAt = &n
}
_, err := db.sql.Exec(`
UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
WHERE id = ?
`, string(status), deletedAt, n, accountID)
if err != nil {
return fmt.Errorf("db: update account status: %w", err)
}
return nil
}
// UpdatePasswordHash updates the Argon2id password hash for an account.
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
_, err := db.sql.Exec(`
UPDATE accounts SET password_hash = ?, updated_at = ?
WHERE id = ?
`, hash, now(), accountID)
if err != nil {
return fmt.Errorf("db: update password hash: %w", err)
}
return nil
}
// SetTOTP stores the encrypted TOTP secret and marks TOTP as required.
func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
_, err := db.sql.Exec(`
UPDATE accounts
SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
WHERE id = ?
`, secretEnc, secretNonce, now(), accountID)
if err != nil {
return fmt.Errorf("db: set TOTP: %w", err)
}
return nil
}
// ClearTOTP removes the TOTP secret and disables TOTP requirement.
func (db *DB) ClearTOTP(accountID int64) error {
_, err := db.sql.Exec(`
UPDATE accounts
SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ?
WHERE id = ?
`, now(), accountID)
if err != nil {
return fmt.Errorf("db: clear TOTP: %w", err)
}
return nil
}
// scanAccount scans a single account row from a *sql.Row.
func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) {
var a model.Account
var accountType, status string
var totpRequired int
var createdAtStr, updatedAtStr string
var deletedAtStr *string
var totpSecretEnc, totpSecretNonce []byte
err := row.Scan(
&a.ID, &a.UUID, &a.Username,
&accountType, &a.PasswordHash,
&status, &totpRequired,
&totpSecretEnc, &totpSecretNonce,
&createdAtStr, &updatedAtStr, &deletedAtStr,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: scan account: %w", err)
}
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
}
// scanAccountRow scans a single account from *sql.Rows.
func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) {
var a model.Account
var accountType, status string
var totpRequired int
var createdAtStr, updatedAtStr string
var deletedAtStr *string
var totpSecretEnc, totpSecretNonce []byte
err := rows.Scan(
&a.ID, &a.UUID, &a.Username,
&accountType, &a.PasswordHash,
&status, &totpRequired,
&totpSecretEnc, &totpSecretNonce,
&createdAtStr, &updatedAtStr, &deletedAtStr,
)
if err != nil {
return nil, fmt.Errorf("db: scan account row: %w", err)
}
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
}
func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) {
a.AccountType = model.AccountType(accountType)
a.Status = model.AccountStatus(status)
a.TOTPRequired = totpRequired == 1
a.TOTPSecretEnc = totpSecretEnc
a.TOTPSecretNonce = totpSecretNonce
var err error
a.CreatedAt, err = parseTime(createdAtStr)
if err != nil {
return nil, err
}
a.UpdatedAt, err = parseTime(updatedAtStr)
if err != nil {
return nil, err
}
a.DeletedAt, err = nullableTime(deletedAtStr)
if err != nil {
return nil, err
}
return a, nil
}
// nullString converts an empty string to nil for nullable SQL columns.
func nullString(s string) *string {
if s == "" {
return nil
}
return &s
}
// GetRoles returns the role strings assigned to an account.
func (db *DB) GetRoles(accountID int64) ([]string, error) {
rows, err := db.sql.Query(`
SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC
`, accountID)
if err != nil {
return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
}
defer rows.Close()
var roles []string
for rows.Next() {
var role string
if err := rows.Scan(&role); err != nil {
return nil, fmt.Errorf("db: scan role: %w", err)
}
roles = append(roles, role)
}
return roles, rows.Err()
}
// GrantRole adds a role to an account. If the role already exists, it is a no-op.
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
_, err := db.sql.Exec(`
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`, accountID, role, grantedBy, now())
if err != nil {
return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err)
}
return nil
}
// RevokeRole removes a role from an account.
func (db *DB) RevokeRole(accountID int64, role string) error {
_, err := db.sql.Exec(`
DELETE FROM account_roles WHERE account_id = ? AND role = ?
`, accountID, role)
if err != nil {
return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err)
}
return nil
}
// SetRoles replaces the full role set for an account atomically.
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
tx, err := db.sql.Begin()
if err != nil {
return fmt.Errorf("db: set roles begin tx: %w", err)
}
if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: set roles delete existing: %w", err)
}
n := now()
for _, role := range roles {
if _, err := tx.Exec(`
INSERT INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`, accountID, role, grantedBy, n); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: set roles insert %q: %w", role, err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("db: set roles commit: %w", err)
}
return nil
}
// HasRole reports whether an account holds the given role.
func (db *DB) HasRole(accountID int64, role string) (bool, error) {
var count int
err := db.sql.QueryRow(`
SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ?
`, accountID, role).Scan(&count)
if err != nil {
return false, fmt.Errorf("db: has role: %w", err)
}
return count > 0, nil
}
// WriteServerConfig stores the encrypted Ed25519 signing key.
// There can only be one row (id=1).
func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error {
n := now()
_, err := db.sql.Exec(`
INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at)
VALUES (1, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
signing_key_enc = excluded.signing_key_enc,
signing_key_nonce = excluded.signing_key_nonce,
updated_at = excluded.updated_at
`, signingKeyEnc, signingKeyNonce, n, n)
if err != nil {
return fmt.Errorf("db: write server config: %w", err)
}
return nil
}
// ReadServerConfig returns the encrypted signing key and nonce.
// Returns ErrNotFound if no config row exists yet.
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) {
err = db.sql.QueryRow(`
SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1
`).Scan(&signingKeyEnc, &signingKeyNonce)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, ErrNotFound
}
if err != nil {
return nil, nil, fmt.Errorf("db: read server config: %w", err)
}
return signingKeyEnc, signingKeyNonce, nil
}
// WriteMasterKeySalt stores the Argon2id KDF salt for the master key derivation.
// The salt must be stable across restarts so the same passphrase always yields
// the same master key. There can only be one row (id=1).
func (db *DB) WriteMasterKeySalt(salt []byte) error {
n := now()
_, err := db.sql.Exec(`
INSERT INTO server_config (id, master_key_salt, created_at, updated_at)
VALUES (1, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
master_key_salt = excluded.master_key_salt,
updated_at = excluded.updated_at
`, salt, n, n)
if err != nil {
return fmt.Errorf("db: write master key salt: %w", err)
}
return nil
}
// ReadMasterKeySalt returns the stored Argon2id KDF salt.
// Returns ErrNotFound if no salt has been stored yet (first run).
func (db *DB) ReadMasterKeySalt() ([]byte, error) {
var salt []byte
err := db.sql.QueryRow(`
SELECT master_key_salt FROM server_config WHERE id = 1
`).Scan(&salt)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: read master key salt: %w", err)
}
if salt == nil {
return nil, ErrNotFound
}
return salt, nil
}
// WritePGCredentials stores or replaces the Postgres credentials for an account.
func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error {
n := now()
_, err := db.sql.Exec(`
INSERT INTO pg_credentials
(account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
pg_host = excluded.pg_host,
pg_port = excluded.pg_port,
pg_database = excluded.pg_database,
pg_username = excluded.pg_username,
pg_password_enc = excluded.pg_password_enc,
pg_password_nonce = excluded.pg_password_nonce,
updated_at = excluded.updated_at
`, accountID, host, port, dbName, username, passwordEnc, passwordNonce, n, n)
if err != nil {
return fmt.Errorf("db: write pg credentials: %w", err)
}
return nil
}
// ReadPGCredentials retrieves the encrypted Postgres credentials for an account.
// Returns ErrNotFound if no credentials are stored.
func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) {
var cred model.PGCredential
var createdAtStr, updatedAtStr string
err := db.sql.QueryRow(`
SELECT id, account_id, pg_host, pg_port, pg_database, pg_username,
pg_password_enc, pg_password_nonce, created_at, updated_at
FROM pg_credentials WHERE account_id = ?
`, accountID).Scan(
&cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort,
&cred.PGDatabase, &cred.PGUsername,
&cred.PGPasswordEnc, &cred.PGPasswordNonce,
&createdAtStr, &updatedAtStr,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: read pg credentials: %w", err)
}
cred.CreatedAt, err = parseTime(createdAtStr)
if err != nil {
return nil, err
}
cred.UpdatedAt, err = parseTime(updatedAtStr)
if err != nil {
return nil, err
}
return &cred, nil
}
// WriteAuditEvent appends an audit log entry.
// Details must never contain credential material.
func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error {
_, err := db.sql.Exec(`
INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details)
VALUES (?, ?, ?, ?, ?)
`, eventType, actorID, targetID, nullString(ipAddress), nullString(details))
if err != nil {
return fmt.Errorf("db: write audit event %q: %w", eventType, err)
}
return nil
}
// TrackToken records a newly issued JWT JTI for revocation tracking.
func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error {
_, err := db.sql.Exec(`
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
VALUES (?, ?, ?, ?)
`, jti, accountID, issuedAt.UTC().Format(time.RFC3339), expiresAt.UTC().Format(time.RFC3339))
if err != nil {
return fmt.Errorf("db: track token %q: %w", jti, err)
}
return nil
}
// GetTokenRecord retrieves a token record by JTI.
// Returns ErrNotFound if no record exists (token was never issued by this server).
func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) {
var rec model.TokenRecord
var issuedAtStr, expiresAtStr, createdAtStr string
var revokedAtStr *string
var revokeReason *string
err := db.sql.QueryRow(`
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
FROM token_revocation WHERE jti = ?
`, jti).Scan(
&rec.ID, &rec.JTI, &rec.AccountID,
&expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason,
&createdAtStr,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: get token record %q: %w", jti, err)
}
var parseErr error
rec.ExpiresAt, parseErr = parseTime(expiresAtStr)
if parseErr != nil {
return nil, parseErr
}
rec.IssuedAt, parseErr = parseTime(issuedAtStr)
if parseErr != nil {
return nil, parseErr
}
rec.CreatedAt, parseErr = parseTime(createdAtStr)
if parseErr != nil {
return nil, parseErr
}
rec.RevokedAt, parseErr = nullableTime(revokedAtStr)
if parseErr != nil {
return nil, parseErr
}
if revokeReason != nil {
rec.RevokeReason = *revokeReason
}
return &rec, nil
}
// RevokeToken marks a token as revoked by JTI.
func (db *DB) RevokeToken(jti, reason string) error {
n := now()
result, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE jti = ? AND revoked_at IS NULL
`, n, nullString(reason), jti)
if err != nil {
return fmt.Errorf("db: revoke token %q: %w", jti, err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("db: revoke token rows affected: %w", err)
}
if rows == 0 {
return fmt.Errorf("db: token %q not found or already revoked", jti)
}
return nil
}
// RevokeAllUserTokens revokes all non-expired, non-revoked tokens for an account.
func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
n := now()
_, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ?
`, n, nullString(reason), accountID, n)
if err != nil {
return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err)
}
return nil
}
// PruneExpiredTokens removes token_revocation rows that are past their expiry.
// Returns the number of rows deleted.
func (db *DB) PruneExpiredTokens() (int64, error) {
result, err := db.sql.Exec(`
DELETE FROM token_revocation WHERE expires_at < ?
`, now())
if err != nil {
return 0, fmt.Errorf("db: prune expired tokens: %w", err)
}
return result.RowsAffected()
}
// SetSystemToken stores or replaces the active service token JTI for a system account.
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
n := now()
_, err := db.sql.Exec(`
INSERT INTO system_tokens (account_id, jti, expires_at, created_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
jti = excluded.jti,
expires_at = excluded.expires_at,
created_at = excluded.created_at
`, accountID, jti, expiresAt.UTC().Format(time.RFC3339), n)
if err != nil {
return fmt.Errorf("db: set system token for account %d: %w", accountID, err)
}
return nil
}
// GetSystemToken retrieves the active service token record for a system account.
func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) {
var st model.SystemToken
var expiresAtStr, createdAtStr string
err := db.sql.QueryRow(`
SELECT id, account_id, jti, expires_at, created_at
FROM system_tokens WHERE account_id = ?
`, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expiresAtStr, &createdAtStr)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: get system token: %w", err)
}
var parseErr error
st.ExpiresAt, parseErr = parseTime(expiresAtStr)
if parseErr != nil {
return nil, parseErr
}
st.CreatedAt, parseErr = parseTime(createdAtStr)
if parseErr != nil {
return nil, parseErr
}
return &st, nil
}

109
internal/db/db.go Normal file
View File

@@ -0,0 +1,109 @@
// Package db provides the SQLite database access layer for MCIAS.
//
// Security design:
// - All queries use parameterized statements; no string concatenation.
// - Foreign keys are enforced (PRAGMA foreign_keys = ON).
// - WAL mode is enabled for safe concurrent reads during writes.
// - The audit log is append-only: no update or delete operations are provided.
package db
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
_ "modernc.org/sqlite" // register the sqlite3 driver
)
// DB wraps a *sql.DB with MCIAS-specific helpers.
type DB struct {
sql *sql.DB
}
// Open opens (or creates) the SQLite database at path and configures it for
// MCIAS use (WAL mode, foreign keys, busy timeout).
func Open(path string) (*DB, error) {
// The modernc.org/sqlite driver is registered as "sqlite".
sqlDB, err := sql.Open("sqlite", path)
if err != nil {
return nil, fmt.Errorf("db: open sqlite: %w", err)
}
// Use a single connection for writes; reads can use the pool.
sqlDB.SetMaxOpenConns(1)
db := &DB{sql: sqlDB}
if err := db.configure(); err != nil {
_ = sqlDB.Close()
return nil, err
}
return db, nil
}
// configure applies PRAGMAs that must be set on every connection.
func (db *DB) configure() error {
pragmas := []string{
"PRAGMA journal_mode=WAL",
"PRAGMA foreign_keys=ON",
"PRAGMA busy_timeout=5000",
"PRAGMA synchronous=NORMAL",
}
for _, p := range pragmas {
if _, err := db.sql.Exec(p); err != nil {
return fmt.Errorf("db: configure pragma %q: %w", p, err)
}
}
return nil
}
// Close closes the underlying database connection.
func (db *DB) Close() error {
return db.sql.Close()
}
// Ping verifies the database connection is alive.
func (db *DB) Ping(ctx context.Context) error {
return db.sql.PingContext(ctx)
}
// SQL returns the underlying *sql.DB for use in tests or advanced queries.
// Prefer the typed methods on DB for all production code.
func (db *DB) SQL() *sql.DB {
return db.sql
}
// now returns the current UTC time formatted as ISO-8601.
func now() string {
return time.Now().UTC().Format(time.RFC3339)
}
// parseTime parses an ISO-8601 UTC time string returned by SQLite.
func parseTime(s string) (time.Time, error) {
t, err := time.Parse(time.RFC3339, s)
if err != nil {
// Try without timezone suffix (some SQLite defaults).
t, err = time.Parse("2006-01-02T15:04:05", s)
if err != nil {
return time.Time{}, fmt.Errorf("db: parse time %q: %w", s, err)
}
return t.UTC(), nil
}
return t.UTC(), nil
}
// ErrNotFound is returned when a requested record does not exist.
var ErrNotFound = errors.New("db: record not found")
// nullableTime converts a *string from SQLite into a *time.Time.
func nullableTime(s *string) (*time.Time, error) {
if s == nil {
return nil, nil
}
t, err := parseTime(*s)
if err != nil {
return nil, err
}
return &t, nil
}

355
internal/db/db_test.go Normal file
View File

@@ -0,0 +1,355 @@
package db
import (
"testing"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
)
// openTestDB opens an in-memory SQLite database for testing.
func openTestDB(t *testing.T) *DB {
t.Helper()
db, err := Open(":memory:")
if err != nil {
t.Fatalf("Open: %v", err)
}
if err := Migrate(db); err != nil {
t.Fatalf("Migrate: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
return db
}
func TestMigrateIdempotent(t *testing.T) {
db := openTestDB(t)
// Run again — should be a no-op.
if err := Migrate(db); err != nil {
t.Errorf("second Migrate call returned error: %v", err)
}
}
func TestCreateAndGetAccount(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("alice", model.AccountTypeHuman, "$argon2id$v=19$...")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
if acct.UUID == "" {
t.Error("expected non-empty UUID")
}
if acct.Username != "alice" {
t.Errorf("Username = %q, want %q", acct.Username, "alice")
}
if acct.Status != model.AccountStatusActive {
t.Errorf("Status = %q, want active", acct.Status)
}
// Retrieve by UUID.
got, err := db.GetAccountByUUID(acct.UUID)
if err != nil {
t.Fatalf("GetAccountByUUID: %v", err)
}
if got.Username != "alice" {
t.Errorf("fetched Username = %q, want %q", got.Username, "alice")
}
// Retrieve by username.
got2, err := db.GetAccountByUsername("alice")
if err != nil {
t.Fatalf("GetAccountByUsername: %v", err)
}
if got2.UUID != acct.UUID {
t.Errorf("UUID mismatch: got %q, want %q", got2.UUID, acct.UUID)
}
}
func TestGetAccountNotFound(t *testing.T) {
db := openTestDB(t)
_, err := db.GetAccountByUUID("nonexistent-uuid")
if err != ErrNotFound {
t.Errorf("expected ErrNotFound, got %v", err)
}
_, err = db.GetAccountByUsername("nobody")
if err != ErrNotFound {
t.Errorf("expected ErrNotFound, got %v", err)
}
}
func TestUpdateAccountStatus(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("bob", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
if err := db.UpdateAccountStatus(acct.ID, model.AccountStatusInactive); err != nil {
t.Fatalf("UpdateAccountStatus: %v", err)
}
got, err := db.GetAccountByUUID(acct.UUID)
if err != nil {
t.Fatalf("GetAccountByUUID: %v", err)
}
if got.Status != model.AccountStatusInactive {
t.Errorf("Status = %q, want inactive", got.Status)
}
}
func TestListAccounts(t *testing.T) {
db := openTestDB(t)
for _, name := range []string{"charlie", "delta", "eve"} {
if _, err := db.CreateAccount(name, model.AccountTypeHuman, "hash"); err != nil {
t.Fatalf("CreateAccount %q: %v", name, err)
}
}
accts, err := db.ListAccounts()
if err != nil {
t.Fatalf("ListAccounts: %v", err)
}
if len(accts) != 3 {
t.Errorf("ListAccounts returned %d accounts, want 3", len(accts))
}
}
func TestRoleOperations(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("frank", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
// GrantRole
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
t.Fatalf("GrantRole: %v", err)
}
// Grant again — should be no-op.
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
t.Fatalf("GrantRole duplicate: %v", err)
}
roles, err := db.GetRoles(acct.ID)
if err != nil {
t.Fatalf("GetRoles: %v", err)
}
if len(roles) != 1 || roles[0] != "admin" {
t.Errorf("GetRoles = %v, want [admin]", roles)
}
has, err := db.HasRole(acct.ID, "admin")
if err != nil {
t.Fatalf("HasRole: %v", err)
}
if !has {
t.Error("expected HasRole to return true for 'admin'")
}
// RevokeRole
if err := db.RevokeRole(acct.ID, "admin"); err != nil {
t.Fatalf("RevokeRole: %v", err)
}
roles, err = db.GetRoles(acct.ID)
if err != nil {
t.Fatalf("GetRoles after revoke: %v", err)
}
if len(roles) != 0 {
t.Errorf("expected no roles after revoke, got %v", roles)
}
// SetRoles
if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
t.Fatalf("SetRoles: %v", err)
}
roles, err = db.GetRoles(acct.ID)
if err != nil {
t.Fatalf("GetRoles after SetRoles: %v", err)
}
if len(roles) != 2 {
t.Errorf("expected 2 roles after SetRoles, got %d", len(roles))
}
}
func TestTokenTrackingAndRevocation(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("grace", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
jti := "test-jti-1234"
issuedAt := time.Now().UTC()
expiresAt := issuedAt.Add(time.Hour)
if err := db.TrackToken(jti, acct.ID, issuedAt, expiresAt); err != nil {
t.Fatalf("TrackToken: %v", err)
}
// Retrieve
rec, err := db.GetTokenRecord(jti)
if err != nil {
t.Fatalf("GetTokenRecord: %v", err)
}
if rec.JTI != jti {
t.Errorf("JTI = %q, want %q", rec.JTI, jti)
}
if rec.IsRevoked() {
t.Error("newly tracked token should not be revoked")
}
// Revoke
if err := db.RevokeToken(jti, "test revocation"); err != nil {
t.Fatalf("RevokeToken: %v", err)
}
rec, err = db.GetTokenRecord(jti)
if err != nil {
t.Fatalf("GetTokenRecord after revoke: %v", err)
}
if !rec.IsRevoked() {
t.Error("token should be revoked after RevokeToken")
}
// Revoking again should fail (already revoked).
if err := db.RevokeToken(jti, "again"); err == nil {
t.Error("expected error when revoking already-revoked token")
}
}
func TestGetTokenRecordNotFound(t *testing.T) {
db := openTestDB(t)
_, err := db.GetTokenRecord("no-such-jti")
if err != ErrNotFound {
t.Errorf("expected ErrNotFound, got %v", err)
}
}
func TestPruneExpiredTokens(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("henry", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
past := time.Now().UTC().Add(-time.Hour)
future := time.Now().UTC().Add(time.Hour)
if err := db.TrackToken("expired-jti", acct.ID, past.Add(-time.Hour), past); err != nil {
t.Fatalf("TrackToken expired: %v", err)
}
if err := db.TrackToken("valid-jti", acct.ID, time.Now(), future); err != nil {
t.Fatalf("TrackToken valid: %v", err)
}
n, err := db.PruneExpiredTokens()
if err != nil {
t.Fatalf("PruneExpiredTokens: %v", err)
}
if n != 1 {
t.Errorf("pruned %d rows, want 1", n)
}
// Valid token should still be retrievable.
if _, err := db.GetTokenRecord("valid-jti"); err != nil {
t.Errorf("valid token missing after prune: %v", err)
}
}
func TestServerConfig(t *testing.T) {
db := openTestDB(t)
// No config initially.
_, _, err := db.ReadServerConfig()
if err != ErrNotFound {
t.Errorf("expected ErrNotFound for missing config, got %v", err)
}
enc := []byte("encrypted-key-data")
nonce := []byte("nonce12345678901")
if err := db.WriteServerConfig(enc, nonce); err != nil {
t.Fatalf("WriteServerConfig: %v", err)
}
gotEnc, gotNonce, err := db.ReadServerConfig()
if err != nil {
t.Fatalf("ReadServerConfig: %v", err)
}
if string(gotEnc) != string(enc) {
t.Errorf("enc mismatch: got %q, want %q", gotEnc, enc)
}
if string(gotNonce) != string(nonce) {
t.Errorf("nonce mismatch: got %q, want %q", gotNonce, nonce)
}
// Overwrite — should work without error.
if err := db.WriteServerConfig([]byte("new-key"), []byte("new-nonce123456")); err != nil {
t.Fatalf("WriteServerConfig overwrite: %v", err)
}
}
func TestForeignKeyEnforcement(t *testing.T) {
db := openTestDB(t)
// Attempting to track a token for a non-existent account should fail.
err := db.TrackToken("jti-x", 999999, time.Now(), time.Now().Add(time.Hour))
if err == nil {
t.Error("expected foreign key error for non-existent account_id, got nil")
}
}
func TestPGCredentials(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
enc := []byte("encrypted-pg-password")
nonce := []byte("pg-nonce12345678")
if err := db.WritePGCredentials(acct.ID, "localhost", 5432, "mydb", "myuser", enc, nonce); err != nil {
t.Fatalf("WritePGCredentials: %v", err)
}
cred, err := db.ReadPGCredentials(acct.ID)
if err != nil {
t.Fatalf("ReadPGCredentials: %v", err)
}
if cred.PGHost != "localhost" {
t.Errorf("PGHost = %q, want %q", cred.PGHost, "localhost")
}
if cred.PGDatabase != "mydb" {
t.Errorf("PGDatabase = %q, want %q", cred.PGDatabase, "mydb")
}
}
func TestRevokeAllUserTokens(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("ivan", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
future := time.Now().UTC().Add(time.Hour)
for _, jti := range []string{"tok1", "tok2", "tok3"} {
if err := db.TrackToken(jti, acct.ID, time.Now(), future); err != nil {
t.Fatalf("TrackToken %q: %v", jti, err)
}
}
if err := db.RevokeAllUserTokens(acct.ID, "account suspended"); err != nil {
t.Fatalf("RevokeAllUserTokens: %v", err)
}
for _, jti := range []string{"tok1", "tok2", "tok3"} {
rec, err := db.GetTokenRecord(jti)
if err != nil {
t.Fatalf("GetTokenRecord %q: %v", jti, err)
}
if !rec.IsRevoked() {
t.Errorf("token %q should be revoked", jti)
}
}
}

187
internal/db/migrate.go Normal file
View File

@@ -0,0 +1,187 @@
package db
import (
"database/sql"
"fmt"
)
// migration represents a single schema migration with an ID and SQL statement.
type migration struct {
id int
sql string
}
// migrations is the ordered list of schema migrations applied to the database.
// Once applied, migrations must never be modified — only new ones appended.
var migrations = []migration{
{
id: 1,
sql: `
CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS server_config (
id INTEGER PRIMARY KEY CHECK (id = 1),
signing_key_enc BLOB,
signing_key_nonce BLOB,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
);
CREATE TABLE IF NOT EXISTS accounts (
id INTEGER PRIMARY KEY,
uuid TEXT NOT NULL UNIQUE,
username TEXT NOT NULL UNIQUE COLLATE NOCASE,
account_type TEXT NOT NULL CHECK (account_type IN ('human','system')),
password_hash TEXT,
status TEXT NOT NULL DEFAULT 'active'
CHECK (status IN ('active','inactive','deleted')),
totp_required INTEGER NOT NULL DEFAULT 0 CHECK (totp_required IN (0,1)),
totp_secret_enc BLOB,
totp_secret_nonce BLOB,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
deleted_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts (username);
CREATE INDEX IF NOT EXISTS idx_accounts_uuid ON accounts (uuid);
CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts (status);
CREATE TABLE IF NOT EXISTS account_roles (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
role TEXT NOT NULL,
granted_by INTEGER REFERENCES accounts(id),
granted_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
UNIQUE (account_id, role)
);
CREATE INDEX IF NOT EXISTS idx_account_roles_account ON account_roles (account_id);
CREATE TABLE IF NOT EXISTS token_revocation (
id INTEGER PRIMARY KEY,
jti TEXT NOT NULL UNIQUE,
account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
expires_at TEXT NOT NULL,
revoked_at TEXT,
revoke_reason TEXT,
issued_at TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
);
CREATE INDEX IF NOT EXISTS idx_token_jti ON token_revocation (jti);
CREATE INDEX IF NOT EXISTS idx_token_account ON token_revocation (account_id);
CREATE INDEX IF NOT EXISTS idx_token_expires ON token_revocation (expires_at);
CREATE TABLE IF NOT EXISTS system_tokens (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
jti TEXT NOT NULL UNIQUE,
expires_at TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
);
CREATE TABLE IF NOT EXISTS pg_credentials (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
pg_host TEXT NOT NULL,
pg_port INTEGER NOT NULL DEFAULT 5432,
pg_database TEXT NOT NULL,
pg_username TEXT NOT NULL,
pg_password_enc BLOB NOT NULL,
pg_password_nonce BLOB NOT NULL,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
);
CREATE TABLE IF NOT EXISTS audit_log (
id INTEGER PRIMARY KEY,
event_time TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
event_type TEXT NOT NULL,
actor_id INTEGER REFERENCES accounts(id),
target_id INTEGER REFERENCES accounts(id),
ip_address TEXT,
details TEXT
);
CREATE INDEX IF NOT EXISTS idx_audit_time ON audit_log (event_time);
CREATE INDEX IF NOT EXISTS idx_audit_actor ON audit_log (actor_id);
CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log (event_type);
`,
},
{
id: 2,
sql: `
-- Add master_key_salt to server_config for Argon2id KDF salt storage.
-- The salt must be stable across restarts so the passphrase always yields the same key.
-- We allow NULL signing_key_enc/nonce temporarily until the first signing key is generated.
ALTER TABLE server_config ADD COLUMN master_key_salt BLOB;
`,
},
}
// Migrate applies any unapplied schema migrations to the database in order.
// It is idempotent: running it multiple times is safe.
func Migrate(db *DB) error {
// Ensure the schema_version table exists first.
if _, err := db.sql.Exec(`
CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER NOT NULL
)
`); err != nil {
return fmt.Errorf("db: ensure schema_version: %w", err)
}
currentVersion, err := currentSchemaVersion(db.sql)
if err != nil {
return fmt.Errorf("db: get current schema version: %w", err)
}
for _, m := range migrations {
if m.id <= currentVersion {
continue
}
tx, err := db.sql.Begin()
if err != nil {
return fmt.Errorf("db: begin migration %d transaction: %w", m.id, err)
}
if _, err := tx.Exec(m.sql); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: apply migration %d: %w", m.id, err)
}
// Update the schema version within the same transaction.
if currentVersion == 0 {
if _, err := tx.Exec(`INSERT INTO schema_version (version) VALUES (?)`, m.id); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: insert schema version %d: %w", m.id, err)
}
} else {
if _, err := tx.Exec(`UPDATE schema_version SET version = ?`, m.id); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: update schema version to %d: %w", m.id, err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("db: commit migration %d: %w", m.id, err)
}
currentVersion = m.id
}
return nil
}
// currentSchemaVersion returns the current schema version, or 0 if none applied.
func currentSchemaVersion(db *sql.DB) (int, error) {
var version int
err := db.QueryRow(`SELECT version FROM schema_version LIMIT 1`).Scan(&version)
if err != nil {
// No rows means version 0 (fresh database).
return 0, nil //nolint:nilerr
}
return version, nil
}

View File

@@ -0,0 +1,290 @@
// Package middleware provides HTTP middleware for the MCIAS server.
//
// Security design:
// - RequireAuth extracts the Bearer token from the Authorization header,
// validates it (alg check, signature, expiry, issuer), and checks revocation
// against the database before injecting claims into the request context.
// - RequireRole checks claims from context for the required role.
// No role implies no access; the check fails closed.
// - RateLimit implements a per-IP token bucket to limit login brute-force.
// - RequestLogger logs request metadata but never logs the Authorization
// header value (which contains credential tokens).
package middleware
import (
"context"
"crypto/ed25519"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"strings"
"sync"
"time"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/token"
)
// contextKey is the unexported type for context keys in this package, preventing
// collisions with keys from other packages.
type contextKey int
const (
claimsKey contextKey = iota
)
// ClaimsFromContext retrieves the validated JWT claims from the request context.
// Returns nil if no claims are present (unauthenticated request).
func ClaimsFromContext(ctx context.Context) *token.Claims {
c, _ := ctx.Value(claimsKey).(*token.Claims)
return c
}
// RequestLogger returns middleware that logs each request at INFO level.
// The Authorization header is intentionally never logged.
func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Wrap the ResponseWriter to capture the status code.
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(rw, r)
logger.Info("request",
"method", r.Method,
"path", r.URL.Path,
"status", rw.status,
"duration_ms", time.Since(start).Milliseconds(),
"remote_addr", r.RemoteAddr,
"user_agent", r.UserAgent(),
// Security: Authorization header is never logged.
)
})
}
}
// responseWriter wraps http.ResponseWriter to capture the status code.
type responseWriter struct {
http.ResponseWriter
status int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.status = code
rw.ResponseWriter.WriteHeader(code)
}
// RequireAuth returns middleware that validates a Bearer JWT and injects the
// claims into the request context. Returns 401 on any auth failure.
//
// Security: Token validation order:
// 1. Extract Bearer token from Authorization header.
// 2. Validate the JWT (alg=EdDSA, signature, expiry, issuer).
// 3. Check the JTI against the revocation table in the database.
// 4. Inject validated claims into context for downstream handlers.
func RequireAuth(pubKey ed25519.PublicKey, database *db.DB, issuer string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenStr, err := extractBearerToken(r)
if err != nil {
writeError(w, http.StatusUnauthorized, "missing or malformed Authorization header", "unauthorized")
return
}
claims, err := token.ValidateToken(pubKey, tokenStr, issuer)
if err != nil {
// Security: Map all token errors to a generic 401; do not
// reveal which specific check failed.
if errors.Is(err, token.ErrExpiredToken) {
writeError(w, http.StatusUnauthorized, "token expired", "token_expired")
return
}
writeError(w, http.StatusUnauthorized, "invalid token", "unauthorized")
return
}
// Security: Check revocation table. A token may be cryptographically
// valid but explicitly revoked (logout, account suspension, etc.).
rec, err := database.GetTokenRecord(claims.JTI)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
// Token not tracked — could be from a different server instance
// or pre-dates tracking. Reject to be safe (fail closed).
writeError(w, http.StatusUnauthorized, "unrecognized token", "unauthorized")
return
}
writeError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
if rec.IsRevoked() {
writeError(w, http.StatusUnauthorized, "token has been revoked", "token_revoked")
return
}
ctx := context.WithValue(r.Context(), claimsKey, claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RequireRole returns middleware that checks whether the authenticated user has
// the given role. Must be used after RequireAuth. Returns 403 if role is absent.
func RequireRole(role string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims := ClaimsFromContext(r.Context())
if claims == nil {
// RequireAuth was not applied upstream; fail closed.
writeError(w, http.StatusForbidden, "forbidden", "forbidden")
return
}
if !claims.HasRole(role) {
writeError(w, http.StatusForbidden, "insufficient privileges", "forbidden")
return
}
next.ServeHTTP(w, r)
})
}
}
// rateLimitEntry holds the token bucket state for a single IP.
type rateLimitEntry struct {
tokens float64
lastSeen time.Time
mu sync.Mutex
}
// ipRateLimiter implements a per-IP token bucket rate limiter.
type ipRateLimiter struct {
rps float64 // refill rate: tokens per second
burst float64 // bucket capacity
ttl time.Duration // how long to keep idle entries
mu sync.Mutex
ips map[string]*rateLimitEntry
}
// RateLimit returns middleware implementing a per-IP token bucket.
// rps is the sustained request rate (tokens refilled per second).
// burst is the maximum burst size (initial and maximum token count).
//
// Security: Rate limiting is applied at the IP level. In production, the
// server should be behind a reverse proxy that sets X-Forwarded-For; this
// middleware uses RemoteAddr directly which may be the proxy IP. For single-
// instance deployment without a proxy, RemoteAddr is the client IP.
func RateLimit(rps float64, burst int) func(http.Handler) http.Handler {
limiter := &ipRateLimiter{
rps: rps,
burst: float64(burst),
ttl: 10 * time.Minute,
ips: make(map[string]*rateLimitEntry),
}
// Background cleanup of idle entries to prevent unbounded memory growth.
go limiter.cleanup()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ip = r.RemoteAddr
}
if !limiter.allow(ip) {
w.Header().Set("Retry-After", "60")
writeError(w, http.StatusTooManyRequests, "rate limit exceeded", "rate_limited")
return
}
next.ServeHTTP(w, r)
})
}
}
// allow returns true if a request from ip is permitted under the rate limit.
func (l *ipRateLimiter) allow(ip string) bool {
l.mu.Lock()
entry, ok := l.ips[ip]
if !ok {
entry = &rateLimitEntry{tokens: l.burst, lastSeen: time.Now()}
l.ips[ip] = entry
}
l.mu.Unlock()
entry.mu.Lock()
defer entry.mu.Unlock()
now := time.Now()
elapsed := now.Sub(entry.lastSeen).Seconds()
entry.tokens = min(l.burst, entry.tokens+elapsed*l.rps)
entry.lastSeen = now
if entry.tokens < 1 {
return false
}
entry.tokens--
return true
}
// cleanup periodically removes idle rate-limit entries.
func (l *ipRateLimiter) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
l.mu.Lock()
cutoff := time.Now().Add(-l.ttl)
for ip, entry := range l.ips {
entry.mu.Lock()
if entry.lastSeen.Before(cutoff) {
delete(l.ips, ip)
}
entry.mu.Unlock()
}
l.mu.Unlock()
}
}
// extractBearerToken extracts the token from "Authorization: Bearer <token>".
func extractBearerToken(r *http.Request) (string, error) {
auth := r.Header.Get("Authorization")
if auth == "" {
return "", fmt.Errorf("missing Authorization header")
}
parts := strings.SplitN(auth, " ", 2)
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
return "", fmt.Errorf("malformed Authorization header")
}
if parts[1] == "" {
return "", fmt.Errorf("empty Bearer token")
}
return parts[1], nil
}
// apiError is the uniform error response structure.
type apiError struct {
Error string `json:"error"`
Code string `json:"code"`
}
// writeError writes a JSON error response.
func writeError(w http.ResponseWriter, status int, message, code string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
// Intentionally ignoring the error here; if the write fails, the client
// already got the status code.
_ = json.NewEncoder(w).Encode(apiError{Error: message, Code: code})
}
// WriteError is the exported version for use by handler packages.
func WriteError(w http.ResponseWriter, status int, message, code string) {
writeError(w, status, message, code)
}
// min returns the smaller of two float64 values.
func min(a, b float64) float64 {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,342 @@
package middleware
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/model"
"git.wntrmute.dev/kyle/mcias/internal/token"
)
func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate test key: %v", err)
}
return pub, priv
}
func openTestDB(t *testing.T) *db.DB {
t.Helper()
database, err := db.Open(":memory:")
if err != nil {
t.Fatalf("open test db: %v", err)
}
if err := db.Migrate(database); err != nil {
t.Fatalf("migrate test db: %v", err)
}
t.Cleanup(func() { _ = database.Close() })
return database
}
const testIssuer = "https://auth.example.com"
// issueAndTrackToken creates a valid JWT and records it in the DB.
func issueAndTrackToken(t *testing.T, priv ed25519.PrivateKey, database *db.DB, accountID int64, roles []string) string {
t.Helper()
tokenStr, claims, err := token.IssueToken(priv, testIssuer, "user-uuid", roles, time.Hour)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
if err := database.TrackToken(claims.JTI, accountID, claims.IssuedAt, claims.ExpiresAt); err != nil {
t.Fatalf("TrackToken: %v", err)
}
return tokenStr
}
func TestRequestLogger(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("status = %d, want 200", rr.Code)
}
logOutput := buf.String()
if logOutput == "" {
t.Error("expected log output, got empty string")
}
// Security: Authorization header must not appear in logs.
req2 := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
req2.Header.Set("Authorization", "Bearer secret-token-value")
buf.Reset()
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
if bytes.Contains(buf.Bytes(), []byte("secret-token-value")) {
t.Error("log output contains Authorization token value — credential leak!")
}
}
func TestRequireAuthValid(t *testing.T) {
pub, priv := generateTestKey(t)
database := openTestDB(t)
acct, err := database.CreateAccount("alice", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
tokenStr := issueAndTrackToken(t, priv, database, acct.ID, []string{"reader"})
reached := false
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reached = true
claims := ClaimsFromContext(r.Context())
if claims == nil {
t.Error("claims not in context")
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
req.Header.Set("Authorization", "Bearer "+tokenStr)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("status = %d, want 200; body: %s", rr.Code, rr.Body.String())
}
if !reached {
t.Error("handler was not reached with valid token")
}
}
func TestRequireAuthMissingHeader(t *testing.T) {
pub, priv := generateTestKey(t)
_ = priv
database := openTestDB(t)
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be reached without auth")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", rr.Code)
}
}
func TestRequireAuthInvalidToken(t *testing.T) {
pub, _ := generateTestKey(t)
database := openTestDB(t)
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be reached with invalid token")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
req.Header.Set("Authorization", "Bearer not.a.valid.jwt")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", rr.Code)
}
}
func TestRequireAuthRevokedToken(t *testing.T) {
pub, priv := generateTestKey(t)
database := openTestDB(t)
acct, err := database.CreateAccount("bob", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
tokenStr := issueAndTrackToken(t, priv, database, acct.ID, nil)
// Extract JTI and revoke the token.
claims, err := token.ValidateToken(pub, tokenStr, testIssuer)
if err != nil {
t.Fatalf("ValidateToken: %v", err)
}
if err := database.RevokeToken(claims.JTI, "test revocation"); err != nil {
t.Fatalf("RevokeToken: %v", err)
}
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be reached with revoked token")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
req.Header.Set("Authorization", "Bearer "+tokenStr)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", rr.Code)
}
}
func TestRequireAuthExpiredToken(t *testing.T) {
pub, priv := generateTestKey(t)
database := openTestDB(t)
// Issue an already-expired token.
tokenStr, _, err := token.IssueToken(priv, testIssuer, "user-uuid", nil, -time.Minute)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be reached with expired token")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
req.Header.Set("Authorization", "Bearer "+tokenStr)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", rr.Code)
}
}
func TestRequireRoleGranted(t *testing.T) {
claims := &token.Claims{Roles: []string{"admin"}}
ctx := context.WithValue(context.Background(), claimsKey, claims)
reached := false
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reached = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("status = %d, want 200", rr.Code)
}
if !reached {
t.Error("handler not reached with correct role")
}
}
func TestRequireRoleForbidden(t *testing.T) {
claims := &token.Claims{Roles: []string{"reader"}}
ctx := context.WithValue(context.Background(), claimsKey, claims)
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be reached without admin role")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("status = %d, want 403", rr.Code)
}
}
func TestRequireRoleNoClaims(t *testing.T) {
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("handler should not be reached without claims in context")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("status = %d, want 403", rr.Code)
}
}
func TestRateLimitAllows(t *testing.T) {
handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
req.RemoteAddr = "127.0.0.1:12345"
// First 5 requests should be allowed (burst=5).
for i := range 5 {
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("request %d: status = %d, want 200", i+1, rr.Code)
}
}
}
func TestRateLimitBlocks(t *testing.T) {
handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
req.RemoteAddr = "10.0.0.1:9999"
// Exhaust the burst of 2.
for range 2 {
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
}
// Next request should be rate-limited.
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusTooManyRequests {
t.Errorf("status = %d, want 429 after burst exceeded", rr.Code)
}
}
func TestExtractBearerToken(t *testing.T) {
tests := []struct {
name string
header string
wantErr bool
want string
}{
{"valid", "Bearer mytoken123", false, "mytoken123"},
{"missing header", "", true, ""},
{"no bearer prefix", "Token mytoken123", true, ""},
{"empty token", "Bearer ", true, ""},
{"case insensitive", "bearer mytoken123", false, "mytoken123"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.header != "" {
req.Header.Set("Authorization", tc.header)
}
got, err := extractBearerToken(req)
if (err != nil) != tc.wantErr {
t.Errorf("wantErr=%v, got err=%v", tc.wantErr, err)
}
if !tc.wantErr && got != tc.want {
t.Errorf("token = %q, want %q", got, tc.want)
}
})
}
}

144
internal/model/model.go Normal file
View File

@@ -0,0 +1,144 @@
// Package model defines the shared data types used throughout MCIAS.
// These are pure data definitions with no external dependencies.
package model
import "time"
// AccountType distinguishes human interactive accounts from non-interactive
// service accounts.
type AccountType string
const (
AccountTypeHuman AccountType = "human"
AccountTypeSystem AccountType = "system"
)
// AccountStatus represents the lifecycle state of an account.
type AccountStatus string
const (
AccountStatusActive AccountStatus = "active"
AccountStatusInactive AccountStatus = "inactive"
AccountStatusDeleted AccountStatus = "deleted"
)
// Account represents a user or service identity in MCIAS.
// Fields containing credential material (PasswordHash, TOTPSecretEnc) are
// never serialised into API responses — callers must explicitly omit them.
type Account struct {
ID int64 `json:"-"`
UUID string `json:"id"`
Username string `json:"username"`
AccountType AccountType `json:"account_type"`
Status AccountStatus `json:"status"`
TOTPRequired bool `json:"totp_required"`
// PasswordHash is a PHC-format Argon2id string. Never returned in API
// responses; populated only when reading from the database.
PasswordHash string `json:"-"`
// TOTPSecretEnc and TOTPSecretNonce hold the AES-256-GCM-encrypted TOTP
// shared secret. Never returned in API responses.
TOTPSecretEnc []byte `json:"-"`
TOTPSecretNonce []byte `json:"-"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt *time.Time `json:"deleted_at,omitempty"`
}
// Role is a string label assigned to an account to grant permissions.
type Role struct {
ID int64 `json:"-"`
AccountID int64 `json:"-"`
Role string `json:"role"`
GrantedBy *int64 `json:"-"`
GrantedAt time.Time `json:"granted_at"`
}
// TokenRecord tracks an issued JWT by its JTI for revocation purposes.
// The raw token string is never stored — only the JTI identifier.
type TokenRecord struct {
ID int64 `json:"-"`
JTI string `json:"jti"`
AccountID int64 `json:"-"`
ExpiresAt time.Time `json:"expires_at"`
IssuedAt time.Time `json:"issued_at"`
RevokedAt *time.Time `json:"revoked_at,omitempty"`
RevokeReason string `json:"revoke_reason,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// IsRevoked reports whether the token has been explicitly revoked.
func (t *TokenRecord) IsRevoked() bool {
return t.RevokedAt != nil
}
// IsExpired reports whether the token is past its expiry time.
func (t *TokenRecord) IsExpired() bool {
return time.Now().After(t.ExpiresAt)
}
// SystemToken represents the current active service token for a system account.
type SystemToken struct {
ID int64 `json:"-"`
AccountID int64 `json:"-"`
JTI string `json:"jti"`
ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
}
// PGCredential holds Postgres connection details for a system account.
// The password is encrypted at rest; PGPassword is only populated after
// decryption and must never be logged or included in API responses.
type PGCredential struct {
ID int64 `json:"-"`
AccountID int64 `json:"-"`
PGHost string `json:"host"`
PGPort int `json:"port"`
PGDatabase string `json:"database"`
PGUsername string `json:"username"`
// PGPassword is plaintext only after decryption. Never log or serialise.
PGPassword string `json:"-"`
// PGPasswordEnc and PGPasswordNonce are the AES-256-GCM ciphertext and
// nonce stored in the database.
PGPasswordEnc []byte `json:"-"`
PGPasswordNonce []byte `json:"-"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AuditEvent represents a single entry in the append-only audit log.
// Details must never contain credential material (passwords, tokens, secrets).
type AuditEvent struct {
ID int64 `json:"id"`
EventTime time.Time `json:"event_time"`
EventType string `json:"event_type"`
ActorID *int64 `json:"-"`
TargetID *int64 `json:"-"`
IPAddress string `json:"ip_address,omitempty"`
Details string `json:"details,omitempty"` // JSON string; no secrets
}
// Audit event type constants — exhaustive list, enforced at write time.
const (
EventLoginOK = "login_ok"
EventLoginFail = "login_fail"
EventLoginTOTPFail = "login_totp_fail"
EventTokenIssued = "token_issued"
EventTokenRenewed = "token_renewed"
EventTokenRevoked = "token_revoked"
EventTokenExpired = "token_expired"
EventAccountCreated = "account_created"
EventAccountUpdated = "account_updated"
EventAccountDeleted = "account_deleted"
EventRoleGranted = "role_granted"
EventRoleRevoked = "role_revoked"
EventTOTPEnrolled = "totp_enrolled"
EventTOTPRemoved = "totp_removed"
EventPGCredAccessed = "pgcred_accessed"
EventPGCredUpdated = "pgcred_updated"
)

View File

@@ -0,0 +1,83 @@
package model
import (
"testing"
"time"
)
func TestAccountTypeConstants(t *testing.T) {
if AccountTypeHuman != "human" {
t.Errorf("AccountTypeHuman = %q, want %q", AccountTypeHuman, "human")
}
if AccountTypeSystem != "system" {
t.Errorf("AccountTypeSystem = %q, want %q", AccountTypeSystem, "system")
}
}
func TestAccountStatusConstants(t *testing.T) {
if AccountStatusActive != "active" {
t.Errorf("AccountStatusActive = %q, want %q", AccountStatusActive, "active")
}
if AccountStatusInactive != "inactive" {
t.Errorf("AccountStatusInactive = %q, want %q", AccountStatusInactive, "inactive")
}
if AccountStatusDeleted != "deleted" {
t.Errorf("AccountStatusDeleted = %q, want %q", AccountStatusDeleted, "deleted")
}
}
func TestTokenRecordIsRevoked(t *testing.T) {
now := time.Now()
notRevoked := &TokenRecord{}
if notRevoked.IsRevoked() {
t.Error("expected token with nil RevokedAt to not be revoked")
}
revoked := &TokenRecord{RevokedAt: &now}
if !revoked.IsRevoked() {
t.Error("expected token with RevokedAt set to be revoked")
}
}
func TestTokenRecordIsExpired(t *testing.T) {
past := time.Now().Add(-time.Hour)
future := time.Now().Add(time.Hour)
expired := &TokenRecord{ExpiresAt: past}
if !expired.IsExpired() {
t.Error("expected token with past ExpiresAt to be expired")
}
valid := &TokenRecord{ExpiresAt: future}
if valid.IsExpired() {
t.Error("expected token with future ExpiresAt to not be expired")
}
}
func TestAuditEventConstants(t *testing.T) {
// Spot-check a few to ensure they are not empty strings.
events := []string{
EventLoginOK,
EventLoginFail,
EventLoginTOTPFail,
EventTokenIssued,
EventTokenRenewed,
EventTokenRevoked,
EventTokenExpired,
EventAccountCreated,
EventAccountUpdated,
EventAccountDeleted,
EventRoleGranted,
EventRoleRevoked,
EventTOTPEnrolled,
EventTOTPRemoved,
EventPGCredAccessed,
EventPGCredUpdated,
}
for _, e := range events {
if e == "" {
t.Errorf("audit event constant is empty string")
}
}
}

904
internal/server/server.go Normal file
View File

@@ -0,0 +1,904 @@
// Package server wires together the HTTP router, middleware, and handlers
// for the MCIAS authentication server.
//
// Security design:
// - All endpoints use HTTPS (enforced at the listener level in cmd/mciassrv).
// - Authentication state is carried via JWT; no cookies or server-side sessions.
// - Credential fields (password hash, TOTP secret, Postgres password) are
// never included in any API response.
// - All JSON parsing uses strict decoders that reject unknown fields.
package server
import (
"crypto/ed25519"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"git.wntrmute.dev/kyle/mcias/internal/auth"
"git.wntrmute.dev/kyle/mcias/internal/config"
"git.wntrmute.dev/kyle/mcias/internal/crypto"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/middleware"
"git.wntrmute.dev/kyle/mcias/internal/model"
"git.wntrmute.dev/kyle/mcias/internal/token"
)
// Server holds the dependencies injected into all handlers.
type Server struct {
db *db.DB
cfg *config.Config
privKey ed25519.PrivateKey
pubKey ed25519.PublicKey
masterKey []byte
logger *slog.Logger
}
// New creates a Server with the given dependencies.
func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed25519.PublicKey, masterKey []byte, logger *slog.Logger) *Server {
return &Server{
db: database,
cfg: cfg,
privKey: priv,
pubKey: pub,
masterKey: masterKey,
logger: logger,
}
}
// Handler builds and returns the root HTTP handler with all routes and middleware.
func (s *Server) Handler() http.Handler {
mux := http.NewServeMux()
// Public endpoints (no authentication required).
mux.HandleFunc("GET /v1/health", s.handleHealth)
mux.HandleFunc("GET /v1/keys/public", s.handlePublicKey)
mux.HandleFunc("POST /v1/auth/login", s.handleLogin)
mux.HandleFunc("POST /v1/token/validate", s.handleTokenValidate)
// Authenticated endpoints.
requireAuth := middleware.RequireAuth(s.pubKey, s.db, s.cfg.Tokens.Issuer)
requireAdmin := func(h http.Handler) http.Handler {
return requireAuth(middleware.RequireRole("admin")(h))
}
// Auth endpoints (require valid token).
mux.Handle("POST /v1/auth/logout", requireAuth(http.HandlerFunc(s.handleLogout)))
mux.Handle("POST /v1/auth/renew", requireAuth(http.HandlerFunc(s.handleRenew)))
mux.Handle("POST /v1/auth/totp/enroll", requireAuth(http.HandlerFunc(s.handleTOTPEnroll)))
mux.Handle("POST /v1/auth/totp/confirm", requireAuth(http.HandlerFunc(s.handleTOTPConfirm)))
// Admin-only endpoints.
mux.Handle("DELETE /v1/auth/totp", requireAdmin(http.HandlerFunc(s.handleTOTPRemove)))
mux.Handle("POST /v1/token/issue", requireAdmin(http.HandlerFunc(s.handleTokenIssue)))
mux.Handle("DELETE /v1/token/{jti}", requireAdmin(http.HandlerFunc(s.handleTokenRevoke)))
mux.Handle("GET /v1/accounts", requireAdmin(http.HandlerFunc(s.handleListAccounts)))
mux.Handle("POST /v1/accounts", requireAdmin(http.HandlerFunc(s.handleCreateAccount)))
mux.Handle("GET /v1/accounts/{id}", requireAdmin(http.HandlerFunc(s.handleGetAccount)))
mux.Handle("PATCH /v1/accounts/{id}", requireAdmin(http.HandlerFunc(s.handleUpdateAccount)))
mux.Handle("DELETE /v1/accounts/{id}", requireAdmin(http.HandlerFunc(s.handleDeleteAccount)))
mux.Handle("GET /v1/accounts/{id}/roles", requireAdmin(http.HandlerFunc(s.handleGetRoles)))
mux.Handle("PUT /v1/accounts/{id}/roles", requireAdmin(http.HandlerFunc(s.handleSetRoles)))
mux.Handle("GET /v1/accounts/{id}/pgcreds", requireAdmin(http.HandlerFunc(s.handleGetPGCreds)))
mux.Handle("PUT /v1/accounts/{id}/pgcreds", requireAdmin(http.HandlerFunc(s.handleSetPGCreds)))
// Apply global middleware: logging and login-path rate limiting.
var root http.Handler = mux
root = middleware.RequestLogger(s.logger)(root)
return root
}
// ---- Public handlers ----
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
// handlePublicKey returns the server's Ed25519 public key in JWK format.
// This allows relying parties to independently verify JWTs.
func (s *Server) handlePublicKey(w http.ResponseWriter, r *http.Request) {
// Encode the Ed25519 public key as a JWK (RFC 8037).
// The "x" parameter is the base64url-encoded public key bytes.
jwk := map[string]string{
"kty": "OKP",
"crv": "Ed25519",
"use": "sig",
"alg": "EdDSA",
"x": encodeBase64URL(s.pubKey),
}
writeJSON(w, http.StatusOK, jwk)
}
// ---- Auth handlers ----
// loginRequest is the request body for POST /v1/auth/login.
type loginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
TOTPCode string `json:"totp_code,omitempty"`
}
// loginResponse is the response body for a successful login.
type loginResponse struct {
Token string `json:"token"`
ExpiresAt string `json:"expires_at"`
}
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
var req loginRequest
if !decodeJSON(w, r, &req) {
return
}
if req.Username == "" || req.Password == "" {
middleware.WriteError(w, http.StatusBadRequest, "username and password are required", "bad_request")
return
}
// Load account by username.
acct, err := s.db.GetAccountByUsername(req.Username)
if err != nil {
// Security: return a generic error whether the user exists or not.
// Always run a dummy Argon2 check to prevent timing-based user enumeration.
_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
s.writeAudit(r, model.EventLoginFail, nil, nil, fmt.Sprintf(`{"username":%q,"reason":"unknown_user"}`, req.Username))
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
return
}
// Security: Check account status before credential verification to avoid
// leaking whether the account exists based on timing differences.
if acct.Status != model.AccountStatusActive {
_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, fmt.Sprintf(`{"reason":"account_inactive"}`))
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
return
}
// Verify password. This is always run, even for system accounts (which have
// no password hash), to maintain constant timing.
ok, err := auth.VerifyPassword(req.Password, acct.PasswordHash)
if err != nil || !ok {
s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, `{"reason":"wrong_password"}`)
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
return
}
// TOTP check (if enrolled).
if acct.TOTPRequired {
if req.TOTPCode == "" {
s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, `{"reason":"totp_missing"}`)
middleware.WriteError(w, http.StatusUnauthorized, "TOTP code required", "totp_required")
return
}
// Decrypt the TOTP secret.
secret, err := crypto.OpenAESGCM(s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
if err != nil {
s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID)
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
valid, err := auth.ValidateTOTP(secret, req.TOTPCode)
if err != nil || !valid {
s.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"wrong_totp"}`)
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
return
}
}
// Determine expiry.
expiry := s.cfg.DefaultExpiry()
roles, err := s.db.GetRoles(acct.ID)
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
for _, r := range roles {
if r == "admin" {
expiry = s.cfg.AdminExpiry()
break
}
}
tokenStr, claims, err := token.IssueToken(s.privKey, s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
if err != nil {
s.logger.Error("issue token", "error", err)
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
if err := s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
s.logger.Error("track token", "error", err)
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
s.writeAudit(r, model.EventLoginOK, &acct.ID, nil, "")
s.writeAudit(r, model.EventTokenIssued, &acct.ID, nil, fmt.Sprintf(`{"jti":%q}`, claims.JTI))
writeJSON(w, http.StatusOK, loginResponse{
Token: tokenStr,
ExpiresAt: claims.ExpiresAt.Format("2006-01-02T15:04:05Z"),
})
}
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context())
if err := s.db.RevokeToken(claims.JTI, "logout"); err != nil {
s.logger.Error("revoke token on logout", "error", err)
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
s.writeAudit(r, model.EventTokenRevoked, nil, nil, fmt.Sprintf(`{"jti":%q,"reason":"logout"}`, claims.JTI))
w.WriteHeader(http.StatusNoContent)
}
func (s *Server) handleRenew(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context())
// Load account to get current roles (they may have changed since token issuance).
acct, err := s.db.GetAccountByUUID(claims.Subject)
if err != nil {
middleware.WriteError(w, http.StatusUnauthorized, "account not found", "unauthorized")
return
}
if acct.Status != model.AccountStatusActive {
middleware.WriteError(w, http.StatusUnauthorized, "account inactive", "unauthorized")
return
}
roles, err := s.db.GetRoles(acct.ID)
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
expiry := s.cfg.DefaultExpiry()
for _, role := range roles {
if role == "admin" {
expiry = s.cfg.AdminExpiry()
break
}
}
newTokenStr, newClaims, err := token.IssueToken(s.privKey, s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
// Revoke the old token and track the new one atomically is not possible
// in SQLite without a transaction. We do best-effort: revoke old, track new.
if err := s.db.RevokeToken(claims.JTI, "renewed"); err != nil {
s.logger.Error("revoke old token on renew", "error", err)
}
if err := s.db.TrackToken(newClaims.JTI, acct.ID, newClaims.IssuedAt, newClaims.ExpiresAt); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
s.writeAudit(r, model.EventTokenRenewed, &acct.ID, nil, fmt.Sprintf(`{"old_jti":%q,"new_jti":%q}`, claims.JTI, newClaims.JTI))
writeJSON(w, http.StatusOK, loginResponse{
Token: newTokenStr,
ExpiresAt: newClaims.ExpiresAt.Format("2006-01-02T15:04:05Z"),
})
}
// ---- Token endpoints ----
type validateRequest struct {
Token string `json:"token"`
}
type validateResponse struct {
Valid bool `json:"valid"`
Subject string `json:"sub,omitempty"`
Roles []string `json:"roles,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
}
func (s *Server) handleTokenValidate(w http.ResponseWriter, r *http.Request) {
// Accept token either from Authorization: Bearer header or JSON body.
tokenStr, err := extractBearerFromRequest(r)
if err != nil {
// Try JSON body.
var req validateRequest
if !decodeJSON(w, r, &req) {
return
}
tokenStr = req.Token
}
if tokenStr == "" {
writeJSON(w, http.StatusOK, validateResponse{Valid: false})
return
}
claims, err := token.ValidateToken(s.pubKey, tokenStr, s.cfg.Tokens.Issuer)
if err != nil {
writeJSON(w, http.StatusOK, validateResponse{Valid: false})
return
}
rec, err := s.db.GetTokenRecord(claims.JTI)
if err != nil || rec.IsRevoked() {
writeJSON(w, http.StatusOK, validateResponse{Valid: false})
return
}
writeJSON(w, http.StatusOK, validateResponse{
Valid: true,
Subject: claims.Subject,
Roles: claims.Roles,
ExpiresAt: claims.ExpiresAt.Format("2006-01-02T15:04:05Z"),
})
}
type issueTokenRequest struct {
AccountID string `json:"account_id"`
}
func (s *Server) handleTokenIssue(w http.ResponseWriter, r *http.Request) {
var req issueTokenRequest
if !decodeJSON(w, r, &req) {
return
}
acct, err := s.db.GetAccountByUUID(req.AccountID)
if err != nil {
middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found")
return
}
if acct.AccountType != model.AccountTypeSystem {
middleware.WriteError(w, http.StatusBadRequest, "token issue is only for system accounts", "bad_request")
return
}
tokenStr, claims, err := token.IssueToken(s.privKey, s.cfg.Tokens.Issuer, acct.UUID, nil, s.cfg.ServiceExpiry())
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
// Revoke existing system token if any.
existing, err := s.db.GetSystemToken(acct.ID)
if err == nil && existing != nil {
_ = s.db.RevokeToken(existing.JTI, "rotated")
}
if err := s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
if err := s.db.SetSystemToken(acct.ID, claims.JTI, claims.ExpiresAt); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
actor := middleware.ClaimsFromContext(r.Context())
var actorID *int64
if actor != nil {
if a, err := s.db.GetAccountByUUID(actor.Subject); err == nil {
actorID = &a.ID
}
}
s.writeAudit(r, model.EventTokenIssued, actorID, &acct.ID, fmt.Sprintf(`{"jti":%q}`, claims.JTI))
writeJSON(w, http.StatusOK, loginResponse{
Token: tokenStr,
ExpiresAt: claims.ExpiresAt.Format("2006-01-02T15:04:05Z"),
})
}
func (s *Server) handleTokenRevoke(w http.ResponseWriter, r *http.Request) {
jti := r.PathValue("jti")
if jti == "" {
middleware.WriteError(w, http.StatusBadRequest, "jti is required", "bad_request")
return
}
if err := s.db.RevokeToken(jti, "admin revocation"); err != nil {
middleware.WriteError(w, http.StatusNotFound, "token not found or already revoked", "not_found")
return
}
s.writeAudit(r, model.EventTokenRevoked, nil, nil, fmt.Sprintf(`{"jti":%q}`, jti))
w.WriteHeader(http.StatusNoContent)
}
// ---- Account endpoints ----
type createAccountRequest struct {
Username string `json:"username"`
Password string `json:"password,omitempty"`
Type string `json:"account_type"`
}
type accountResponse struct {
ID string `json:"id"`
Username string `json:"username"`
AccountType string `json:"account_type"`
Status string `json:"status"`
TOTPEnabled bool `json:"totp_enabled"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
func accountToResponse(a *model.Account) accountResponse {
resp := accountResponse{
ID: a.UUID,
Username: a.Username,
AccountType: string(a.AccountType),
Status: string(a.Status),
TOTPEnabled: a.TOTPRequired,
CreatedAt: a.CreatedAt.Format("2006-01-02T15:04:05Z"),
UpdatedAt: a.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
return resp
}
func (s *Server) handleListAccounts(w http.ResponseWriter, r *http.Request) {
accounts, err := s.db.ListAccounts()
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
resp := make([]accountResponse, len(accounts))
for i, a := range accounts {
resp[i] = accountToResponse(a)
}
writeJSON(w, http.StatusOK, resp)
}
func (s *Server) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
var req createAccountRequest
if !decodeJSON(w, r, &req) {
return
}
if req.Username == "" {
middleware.WriteError(w, http.StatusBadRequest, "username is required", "bad_request")
return
}
accountType := model.AccountType(req.Type)
if accountType != model.AccountTypeHuman && accountType != model.AccountTypeSystem {
middleware.WriteError(w, http.StatusBadRequest, "account_type must be 'human' or 'system'", "bad_request")
return
}
var passwordHash string
if accountType == model.AccountTypeHuman {
if req.Password == "" {
middleware.WriteError(w, http.StatusBadRequest, "password is required for human accounts", "bad_request")
return
}
var err error
passwordHash, err = auth.HashPassword(req.Password, auth.ArgonParams{
Time: s.cfg.Argon2.Time,
Memory: s.cfg.Argon2.Memory,
Threads: s.cfg.Argon2.Threads,
})
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
}
acct, err := s.db.CreateAccount(req.Username, accountType, passwordHash)
if err != nil {
middleware.WriteError(w, http.StatusConflict, "username already exists", "conflict")
return
}
s.writeAudit(r, model.EventAccountCreated, nil, &acct.ID, fmt.Sprintf(`{"username":%q}`, acct.Username))
writeJSON(w, http.StatusCreated, accountToResponse(acct))
}
func (s *Server) handleGetAccount(w http.ResponseWriter, r *http.Request) {
acct, ok := s.loadAccount(w, r)
if !ok {
return
}
writeJSON(w, http.StatusOK, accountToResponse(acct))
}
type updateAccountRequest struct {
Status string `json:"status,omitempty"`
}
func (s *Server) handleUpdateAccount(w http.ResponseWriter, r *http.Request) {
acct, ok := s.loadAccount(w, r)
if !ok {
return
}
var req updateAccountRequest
if !decodeJSON(w, r, &req) {
return
}
if req.Status != "" {
newStatus := model.AccountStatus(req.Status)
if newStatus != model.AccountStatusActive && newStatus != model.AccountStatusInactive {
middleware.WriteError(w, http.StatusBadRequest, "status must be 'active' or 'inactive'", "bad_request")
return
}
if err := s.db.UpdateAccountStatus(acct.ID, newStatus); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
}
s.writeAudit(r, model.EventAccountUpdated, nil, &acct.ID, "")
w.WriteHeader(http.StatusNoContent)
}
func (s *Server) handleDeleteAccount(w http.ResponseWriter, r *http.Request) {
acct, ok := s.loadAccount(w, r)
if !ok {
return
}
if err := s.db.UpdateAccountStatus(acct.ID, model.AccountStatusDeleted); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
if err := s.db.RevokeAllUserTokens(acct.ID, "account deleted"); err != nil {
s.logger.Error("revoke tokens on delete", "error", err, "account_id", acct.ID)
}
s.writeAudit(r, model.EventAccountDeleted, nil, &acct.ID, "")
w.WriteHeader(http.StatusNoContent)
}
// ---- Role endpoints ----
type rolesResponse struct {
Roles []string `json:"roles"`
}
type setRolesRequest struct {
Roles []string `json:"roles"`
}
func (s *Server) handleGetRoles(w http.ResponseWriter, r *http.Request) {
acct, ok := s.loadAccount(w, r)
if !ok {
return
}
roles, err := s.db.GetRoles(acct.ID)
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
if roles == nil {
roles = []string{}
}
writeJSON(w, http.StatusOK, rolesResponse{Roles: roles})
}
func (s *Server) handleSetRoles(w http.ResponseWriter, r *http.Request) {
acct, ok := s.loadAccount(w, r)
if !ok {
return
}
var req setRolesRequest
if !decodeJSON(w, r, &req) {
return
}
actor := middleware.ClaimsFromContext(r.Context())
var grantedBy *int64
if actor != nil {
if a, err := s.db.GetAccountByUUID(actor.Subject); err == nil {
grantedBy = &a.ID
}
}
if err := s.db.SetRoles(acct.ID, req.Roles, grantedBy); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
s.writeAudit(r, model.EventRoleGranted, grantedBy, &acct.ID, fmt.Sprintf(`{"roles":%v}`, req.Roles))
w.WriteHeader(http.StatusNoContent)
}
// ---- TOTP endpoints ----
type totpEnrollResponse struct {
Secret string `json:"secret"` // base32-encoded
OTPAuthURI string `json:"otpauth_uri"`
}
type totpConfirmRequest struct {
Code string `json:"code"`
}
func (s *Server) handleTOTPEnroll(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context())
acct, err := s.db.GetAccountByUUID(claims.Subject)
if err != nil {
middleware.WriteError(w, http.StatusUnauthorized, "account not found", "unauthorized")
return
}
rawSecret, b32Secret, err := auth.GenerateTOTPSecret()
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
// Encrypt the secret before storing it temporarily.
// Note: we store as pending; enrollment is confirmed with /confirm.
secretEnc, secretNonce, err := crypto.SealAESGCM(s.masterKey, rawSecret)
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
// Store the encrypted pending secret. The totp_required flag is NOT set
// yet — it is set only after the user confirms the code.
if err := s.db.SetTOTP(acct.ID, secretEnc, secretNonce); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
otpURI := fmt.Sprintf("otpauth://totp/MCIAS:%s?secret=%s&issuer=MCIAS", acct.Username, b32Secret)
// Security: return the secret for display to the user. It is only shown
// once; subsequent reads are not possible (only the encrypted form is stored).
writeJSON(w, http.StatusOK, totpEnrollResponse{
Secret: b32Secret,
OTPAuthURI: otpURI,
})
}
func (s *Server) handleTOTPConfirm(w http.ResponseWriter, r *http.Request) {
var req totpConfirmRequest
if !decodeJSON(w, r, &req) {
return
}
claims := middleware.ClaimsFromContext(r.Context())
acct, err := s.db.GetAccountByUUID(claims.Subject)
if err != nil {
middleware.WriteError(w, http.StatusUnauthorized, "account not found", "unauthorized")
return
}
if acct.TOTPSecretEnc == nil {
middleware.WriteError(w, http.StatusBadRequest, "TOTP enrollment not started", "bad_request")
return
}
secret, err := crypto.OpenAESGCM(s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
valid, err := auth.ValidateTOTP(secret, req.Code)
if err != nil || !valid {
middleware.WriteError(w, http.StatusUnauthorized, "invalid TOTP code", "unauthorized")
return
}
// Mark TOTP as confirmed and required.
if err := s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
s.writeAudit(r, model.EventTOTPEnrolled, &acct.ID, nil, "")
w.WriteHeader(http.StatusNoContent)
}
type totpRemoveRequest struct {
AccountID string `json:"account_id"`
}
func (s *Server) handleTOTPRemove(w http.ResponseWriter, r *http.Request) {
var req totpRemoveRequest
if !decodeJSON(w, r, &req) {
return
}
acct, err := s.db.GetAccountByUUID(req.AccountID)
if err != nil {
middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found")
return
}
if err := s.db.ClearTOTP(acct.ID); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
s.writeAudit(r, model.EventTOTPRemoved, nil, &acct.ID, "")
w.WriteHeader(http.StatusNoContent)
}
// ---- Postgres credential endpoints ----
type pgCredRequest struct {
Host string `json:"host"`
Port int `json:"port"`
Database string `json:"database"`
Username string `json:"username"`
Password string `json:"password"`
}
type pgCredResponse struct {
Host string `json:"host"`
Port int `json:"port"`
Database string `json:"database"`
Username string `json:"username"`
// Security: Password is NEVER included in the response, even on GET.
// The caller must explicitly decrypt it on the server side.
}
func (s *Server) handleGetPGCreds(w http.ResponseWriter, r *http.Request) {
acct, ok := s.loadAccount(w, r)
if !ok {
return
}
cred, err := s.db.ReadPGCredentials(acct.ID)
if err != nil {
if err == db.ErrNotFound {
middleware.WriteError(w, http.StatusNotFound, "no credentials stored", "not_found")
return
}
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
// Decrypt the password to return it to the admin caller.
password, err := crypto.OpenAESGCM(s.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc)
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
s.writeAudit(r, model.EventPGCredAccessed, nil, &acct.ID, "")
// Return including password since this is an explicit admin retrieval.
writeJSON(w, http.StatusOK, map[string]interface{}{
"host": cred.PGHost,
"port": cred.PGPort,
"database": cred.PGDatabase,
"username": cred.PGUsername,
"password": string(password), // included only for admin retrieval
})
}
func (s *Server) handleSetPGCreds(w http.ResponseWriter, r *http.Request) {
acct, ok := s.loadAccount(w, r)
if !ok {
return
}
var req pgCredRequest
if !decodeJSON(w, r, &req) {
return
}
if req.Host == "" || req.Database == "" || req.Username == "" || req.Password == "" {
middleware.WriteError(w, http.StatusBadRequest, "host, database, username, and password are required", "bad_request")
return
}
if req.Port == 0 {
req.Port = 5432
}
enc, nonce, err := crypto.SealAESGCM(s.masterKey, []byte(req.Password))
if err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
if err := s.db.WritePGCredentials(acct.ID, req.Host, req.Port, req.Database, req.Username, enc, nonce); err != nil {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
s.writeAudit(r, model.EventPGCredUpdated, nil, &acct.ID, "")
w.WriteHeader(http.StatusNoContent)
}
// ---- Helpers ----
// loadAccount retrieves an account by the {id} path parameter (UUID).
func (s *Server) loadAccount(w http.ResponseWriter, r *http.Request) (*model.Account, bool) {
id := r.PathValue("id")
if id == "" {
middleware.WriteError(w, http.StatusBadRequest, "account id is required", "bad_request")
return nil, false
}
acct, err := s.db.GetAccountByUUID(id)
if err != nil {
if err == db.ErrNotFound {
middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found")
return nil, false
}
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return nil, false
}
return acct, true
}
// writeAudit appends an audit log entry, logging errors but not failing the request.
func (s *Server) writeAudit(r *http.Request, eventType string, actorID, targetID *int64, details string) {
ip := r.RemoteAddr
if err := s.db.WriteAuditEvent(eventType, actorID, targetID, ip, details); err != nil {
s.logger.Error("write audit event", "error", err, "event_type", eventType)
}
}
// writeJSON encodes v as JSON and writes it to w with the given status code.
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil {
// If encoding fails, the status is already written; log but don't panic.
_ = err
}
}
// decodeJSON decodes a JSON request body into v.
// Returns false and writes a 400 response if decoding fails.
func decodeJSON(w http.ResponseWriter, r *http.Request, v interface{}) bool {
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(v); err != nil {
middleware.WriteError(w, http.StatusBadRequest, "invalid JSON request body", "bad_request")
return false
}
return true
}
// extractBearerFromRequest extracts a Bearer token from the Authorization header.
func extractBearerFromRequest(r *http.Request) (string, error) {
auth := r.Header.Get("Authorization")
if auth == "" {
return "", fmt.Errorf("no Authorization header")
}
const prefix = "Bearer "
if len(auth) <= len(prefix) {
return "", fmt.Errorf("malformed Authorization header")
}
return auth[len(prefix):], nil
}
// encodeBase64URL encodes bytes as base64url without padding.
func encodeBase64URL(b []byte) string {
const table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
result := make([]byte, 0, (len(b)*4+2)/3)
for i := 0; i < len(b); i += 3 {
switch {
case i+2 < len(b):
result = append(result,
table[b[i]>>2],
table[(b[i]&3)<<4|b[i+1]>>4],
table[(b[i+1]&0xf)<<2|b[i+2]>>6],
table[b[i+2]&0x3f],
)
case i+1 < len(b):
result = append(result,
table[b[i]>>2],
table[(b[i]&3)<<4|b[i+1]>>4],
table[(b[i+1]&0xf)<<2],
)
default:
result = append(result,
table[b[i]>>2],
table[(b[i]&3)<<4],
)
}
}
return string(result)
}

View File

@@ -0,0 +1,434 @@
package server
import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"git.wntrmute.dev/kyle/mcias/internal/auth"
"git.wntrmute.dev/kyle/mcias/internal/config"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/model"
"git.wntrmute.dev/kyle/mcias/internal/token"
)
const testIssuer = "https://auth.example.com"
func newTestServer(t *testing.T) (*Server, ed25519.PublicKey, ed25519.PrivateKey, *db.DB) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate key: %v", err)
}
database, err := db.Open(":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.Migrate(database); err != nil {
t.Fatalf("migrate db: %v", err)
}
t.Cleanup(func() { _ = database.Close() })
masterKey := make([]byte, 32)
if _, err := rand.Read(masterKey); err != nil {
t.Fatalf("generate master key: %v", err)
}
cfg := config.NewTestConfig(testIssuer)
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
srv := New(database, cfg, priv, pub, masterKey, logger)
return srv, pub, priv, database
}
// createTestHumanAccount creates a human account with password "testpass123".
func createTestHumanAccount(t *testing.T, srv *Server, username string) *model.Account {
t.Helper()
hash, err := auth.HashPassword("testpass123", auth.ArgonParams{Time: 3, Memory: 65536, Threads: 4})
if err != nil {
t.Fatalf("hash password: %v", err)
}
acct, err := srv.db.CreateAccount(username, model.AccountTypeHuman, hash)
if err != nil {
t.Fatalf("create account: %v", err)
}
return acct
}
// issueAdminToken creates an account with admin role, issues a JWT, and tracks it.
func issueAdminToken(t *testing.T, srv *Server, priv ed25519.PrivateKey, username string) (string, *model.Account) {
t.Helper()
acct := createTestHumanAccount(t, srv, username)
if err := srv.db.GrantRole(acct.ID, "admin", nil); err != nil {
t.Fatalf("grant admin role: %v", err)
}
tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, []string{"admin"}, time.Hour)
if err != nil {
t.Fatalf("issue token: %v", err)
}
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
t.Fatalf("track token: %v", err)
}
return tokenStr, acct
}
func doRequest(t *testing.T, handler http.Handler, method, path string, body interface{}, authToken string) *httptest.ResponseRecorder {
t.Helper()
var bodyReader io.Reader
if body != nil {
b, err := json.Marshal(body)
if err != nil {
t.Fatalf("marshal body: %v", err)
}
bodyReader = bytes.NewReader(b)
} else {
bodyReader = bytes.NewReader(nil)
}
req := httptest.NewRequest(method, path, bodyReader)
req.Header.Set("Content-Type", "application/json")
if authToken != "" {
req.Header.Set("Authorization", "Bearer "+authToken)
}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
return rr
}
func TestHealth(t *testing.T) {
srv, _, _, _ := newTestServer(t)
handler := srv.Handler()
rr := doRequest(t, handler, "GET", "/v1/health", nil, "")
if rr.Code != http.StatusOK {
t.Errorf("health status = %d, want 200", rr.Code)
}
}
func TestPublicKey(t *testing.T) {
srv, _, _, _ := newTestServer(t)
handler := srv.Handler()
rr := doRequest(t, handler, "GET", "/v1/keys/public", nil, "")
if rr.Code != http.StatusOK {
t.Errorf("public key status = %d, want 200", rr.Code)
}
var jwk map[string]string
if err := json.Unmarshal(rr.Body.Bytes(), &jwk); err != nil {
t.Fatalf("unmarshal JWK: %v", err)
}
if jwk["kty"] != "OKP" {
t.Errorf("kty = %q, want OKP", jwk["kty"])
}
if jwk["alg"] != "EdDSA" {
t.Errorf("alg = %q, want EdDSA", jwk["alg"])
}
}
func TestLoginSuccess(t *testing.T) {
srv, _, _, _ := newTestServer(t)
createTestHumanAccount(t, srv, "alice")
handler := srv.Handler()
rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{
"username": "alice",
"password": "testpass123",
}, "")
if rr.Code != http.StatusOK {
t.Errorf("login status = %d, want 200; body: %s", rr.Code, rr.Body.String())
}
var resp loginResponse
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal login response: %v", err)
}
if resp.Token == "" {
t.Error("expected non-empty token in login response")
}
if resp.ExpiresAt == "" {
t.Error("expected non-empty expires_at in login response")
}
}
func TestLoginWrongPassword(t *testing.T) {
srv, _, _, _ := newTestServer(t)
createTestHumanAccount(t, srv, "bob")
handler := srv.Handler()
rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{
"username": "bob",
"password": "wrongpassword",
}, "")
if rr.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", rr.Code)
}
}
func TestLoginUnknownUser(t *testing.T) {
srv, _, _, _ := newTestServer(t)
handler := srv.Handler()
rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{
"username": "nobody",
"password": "password",
}, "")
if rr.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", rr.Code)
}
}
func TestLoginResponseDoesNotContainCredentials(t *testing.T) {
srv, _, _, _ := newTestServer(t)
createTestHumanAccount(t, srv, "charlie")
handler := srv.Handler()
rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{
"username": "charlie",
"password": "testpass123",
}, "")
body := rr.Body.String()
// Security: password hash must never appear in any API response.
if strings.Contains(body, "argon2id") || strings.Contains(body, "password_hash") {
t.Error("login response contains password hash material")
}
}
func TestTokenValidate(t *testing.T) {
srv, _, priv, _ := newTestServer(t)
acct := createTestHumanAccount(t, srv, "dave")
handler := srv.Handler()
// Issue and track a token.
tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
t.Fatalf("TrackToken: %v", err)
}
req := httptest.NewRequest("POST", "/v1/token/validate", nil)
req.Header.Set("Authorization", "Bearer "+tokenStr)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("validate status = %d, want 200", rr.Code)
}
var resp validateResponse
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if !resp.Valid {
t.Error("expected valid=true for valid token")
}
}
func TestLogout(t *testing.T) {
srv, _, priv, _ := newTestServer(t)
acct := createTestHumanAccount(t, srv, "eve")
handler := srv.Handler()
tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
t.Fatalf("TrackToken: %v", err)
}
// Logout.
rr := doRequest(t, handler, "POST", "/v1/auth/logout", nil, tokenStr)
if rr.Code != http.StatusNoContent {
t.Errorf("logout status = %d, want 204; body: %s", rr.Code, rr.Body.String())
}
// Token should now be invalid on validate.
req := httptest.NewRequest("POST", "/v1/token/validate", nil)
req.Header.Set("Authorization", "Bearer "+tokenStr)
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req)
var resp validateResponse
_ = json.Unmarshal(rr2.Body.Bytes(), &resp)
if resp.Valid {
t.Error("expected valid=false after logout")
}
}
func TestCreateAccountAdmin(t *testing.T) {
srv, _, priv, _ := newTestServer(t)
adminToken, _ := issueAdminToken(t, srv, priv, "admin-user")
handler := srv.Handler()
rr := doRequest(t, handler, "POST", "/v1/accounts", map[string]string{
"username": "new-user",
"password": "newpassword123",
"account_type": "human",
}, adminToken)
if rr.Code != http.StatusCreated {
t.Errorf("create account status = %d, want 201; body: %s", rr.Code, rr.Body.String())
}
var resp accountResponse
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if resp.Username != "new-user" {
t.Errorf("Username = %q, want %q", resp.Username, "new-user")
}
// Security: password hash must not appear in account response.
body := rr.Body.String()
if strings.Contains(body, "password_hash") || strings.Contains(body, "argon2id") {
t.Error("account creation response contains password hash")
}
}
func TestCreateAccountRequiresAdmin(t *testing.T) {
srv, _, priv, _ := newTestServer(t)
acct := createTestHumanAccount(t, srv, "regular-user")
tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, []string{"reader"}, time.Hour)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
t.Fatalf("TrackToken: %v", err)
}
handler := srv.Handler()
rr := doRequest(t, handler, "POST", "/v1/accounts", map[string]string{
"username": "other-user",
"password": "password",
"account_type": "human",
}, tokenStr)
if rr.Code != http.StatusForbidden {
t.Errorf("status = %d, want 403", rr.Code)
}
}
func TestListAccounts(t *testing.T) {
srv, _, priv, _ := newTestServer(t)
adminToken, _ := issueAdminToken(t, srv, priv, "admin2")
createTestHumanAccount(t, srv, "user1")
createTestHumanAccount(t, srv, "user2")
handler := srv.Handler()
rr := doRequest(t, handler, "GET", "/v1/accounts", nil, adminToken)
if rr.Code != http.StatusOK {
t.Errorf("list accounts status = %d, want 200", rr.Code)
}
var accounts []accountResponse
if err := json.Unmarshal(rr.Body.Bytes(), &accounts); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if len(accounts) < 3 { // admin + user1 + user2
t.Errorf("expected at least 3 accounts, got %d", len(accounts))
}
// Security: no credential fields in any response.
body := rr.Body.String()
for _, bad := range []string{"password_hash", "argon2id", "totp_secret", "PasswordHash"} {
if strings.Contains(body, bad) {
t.Errorf("account list response contains credential field %q", bad)
}
}
}
func TestDeleteAccount(t *testing.T) {
srv, _, priv, _ := newTestServer(t)
adminToken, _ := issueAdminToken(t, srv, priv, "admin3")
target := createTestHumanAccount(t, srv, "delete-me")
handler := srv.Handler()
rr := doRequest(t, handler, "DELETE", "/v1/accounts/"+target.UUID, nil, adminToken)
if rr.Code != http.StatusNoContent {
t.Errorf("delete status = %d, want 204; body: %s", rr.Code, rr.Body.String())
}
}
func TestSetAndGetRoles(t *testing.T) {
srv, _, priv, _ := newTestServer(t)
adminToken, _ := issueAdminToken(t, srv, priv, "admin4")
target := createTestHumanAccount(t, srv, "role-target")
handler := srv.Handler()
// Set roles.
rr := doRequest(t, handler, "PUT", "/v1/accounts/"+target.UUID+"/roles", map[string][]string{
"roles": {"reader", "writer"},
}, adminToken)
if rr.Code != http.StatusNoContent {
t.Errorf("set roles status = %d, want 204; body: %s", rr.Code, rr.Body.String())
}
// Get roles.
rr2 := doRequest(t, handler, "GET", "/v1/accounts/"+target.UUID+"/roles", nil, adminToken)
if rr2.Code != http.StatusOK {
t.Errorf("get roles status = %d, want 200", rr2.Code)
}
var resp rolesResponse
if err := json.Unmarshal(rr2.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if len(resp.Roles) != 2 {
t.Errorf("expected 2 roles, got %d", len(resp.Roles))
}
}
func TestRenewToken(t *testing.T) {
srv, _, priv, _ := newTestServer(t)
acct := createTestHumanAccount(t, srv, "renew-user")
handler := srv.Handler()
oldTokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
oldJTI := claims.JTI
if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
t.Fatalf("TrackToken: %v", err)
}
rr := doRequest(t, handler, "POST", "/v1/auth/renew", nil, oldTokenStr)
if rr.Code != http.StatusOK {
t.Fatalf("renew status = %d, want 200; body: %s", rr.Code, rr.Body.String())
}
var resp loginResponse
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal renew response: %v", err)
}
if resp.Token == "" || resp.Token == oldTokenStr {
t.Error("expected new, distinct token after renewal")
}
// Old token should be revoked in the database.
rec, err := srv.db.GetTokenRecord(oldJTI)
if err != nil {
t.Fatalf("GetTokenRecord: %v", err)
}
if !rec.IsRevoked() {
t.Error("old token should be revoked after renewal")
}
}

181
internal/token/token.go Normal file
View File

@@ -0,0 +1,181 @@
// Package token handles JWT issuance, validation, and revocation for MCIAS.
//
// Security design:
// - Algorithm header is checked FIRST, before any signature verification.
// This prevents algorithm-confusion attacks (CVE-2022-21449 class).
// - Only "EdDSA" is accepted; "none", HS*, RS*, ES* are all rejected.
// - The signing key is taken from the server's keystore, never from the token.
// - All standard claims (exp, iat, iss, jti) are required and validated.
// - JTIs are UUIDs generated from crypto/rand (via google/uuid).
// - Token values are never stored; only JTIs are recorded for revocation.
package token
import (
"crypto/ed25519"
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
const (
// requiredAlg is the only JWT algorithm accepted by MCIAS.
// Security: Hard-coding this as a constant rather than a variable ensures
// it cannot be changed at runtime and cannot be confused by token headers.
requiredAlg = "EdDSA"
)
// Claims holds the MCIAS-specific JWT claims.
type Claims struct {
// Standard registered claims.
Issuer string `json:"iss"`
Subject string `json:"sub"` // account UUID
IssuedAt time.Time `json:"iat"`
ExpiresAt time.Time `json:"exp"`
JTI string `json:"jti"`
// MCIAS-specific claims.
Roles []string `json:"roles"`
}
// jwtClaims adapts Claims to the golang-jwt MapClaims interface.
type jwtClaims struct {
jwt.RegisteredClaims
Roles []string `json:"roles"`
}
// ErrExpiredToken is returned when the token's exp claim is in the past.
var ErrExpiredToken = errors.New("token: expired")
// ErrInvalidSignature is returned when Ed25519 signature verification fails.
var ErrInvalidSignature = errors.New("token: invalid signature")
// ErrWrongAlgorithm is returned when the alg header is not EdDSA.
var ErrWrongAlgorithm = errors.New("token: algorithm must be EdDSA")
// ErrMissingClaim is returned when a required claim is absent or empty.
var ErrMissingClaim = errors.New("token: missing required claim")
// IssueToken creates and signs a new JWT with the given claims.
// The jti is generated automatically using crypto/rand via uuid.New().
// Returns the signed token string.
//
// Security: The signing key is provided by the caller from the server's
// keystore. The alg header is set explicitly to "EdDSA" by the jwt library
// when an ed25519.PrivateKey is passed to SignedString.
func IssueToken(key ed25519.PrivateKey, issuer, subject string, roles []string, expiry time.Duration) (string, *Claims, error) {
now := time.Now().UTC()
exp := now.Add(expiry)
jti := uuid.New().String()
jc := jwtClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: issuer,
Subject: subject,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(exp),
ID: jti,
},
Roles: roles,
}
t := jwt.NewWithClaims(jwt.SigningMethodEdDSA, jc)
signed, err := t.SignedString(key)
if err != nil {
return "", nil, fmt.Errorf("token: sign JWT: %w", err)
}
claims := &Claims{
Issuer: issuer,
Subject: subject,
IssuedAt: now,
ExpiresAt: exp,
JTI: jti,
Roles: roles,
}
return signed, claims, nil
}
// ValidateToken parses and validates a JWT string.
//
// Security order of operations (all must pass):
// 1. Parse the token header and extract the alg field.
// 2. Reject immediately if alg != "EdDSA" (before any signature check).
// 3. Verify Ed25519 signature.
// 4. Validate exp, iat, iss, jti claims.
//
// Returns Claims on success, or a typed error on any failure.
// The caller is responsible for checking revocation status via the DB.
func ValidateToken(key ed25519.PublicKey, tokenString, expectedIssuer string) (*Claims, error) {
// Step 1+2: Parse the header to check alg BEFORE any crypto.
// Security: We use jwt.ParseWithClaims with an explicit key function that
// enforces the algorithm. The key function is called by the library after
// parsing the header but before verifying the signature, which is the
// correct point to enforce algorithm constraints.
var jc jwtClaims
t, err := jwt.ParseWithClaims(tokenString, &jc, func(t *jwt.Token) (interface{}, error) {
// Security: Check alg header first. This must happen in the key
// function — it is the only place where the parsed (but unverified)
// header is available before signature validation.
if t.Method.Alg() != requiredAlg {
return nil, fmt.Errorf("%w: got %q, want %q", ErrWrongAlgorithm, t.Method.Alg(), requiredAlg)
}
return key, nil
},
jwt.WithIssuedAt(),
jwt.WithIssuer(expectedIssuer),
jwt.WithExpirationRequired(),
)
if err != nil {
// Map library errors to our typed errors for consistent handling.
if errors.Is(err, ErrWrongAlgorithm) {
return nil, ErrWrongAlgorithm
}
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrExpiredToken
}
if errors.Is(err, jwt.ErrSignatureInvalid) {
return nil, ErrInvalidSignature
}
return nil, fmt.Errorf("token: parse: %w", err)
}
if !t.Valid {
return nil, ErrInvalidSignature
}
// Step 4: Validate required custom claims.
if jc.ID == "" {
return nil, fmt.Errorf("%w: jti", ErrMissingClaim)
}
if jc.Subject == "" {
return nil, fmt.Errorf("%w: sub", ErrMissingClaim)
}
if jc.ExpiresAt == nil {
return nil, fmt.Errorf("%w: exp", ErrMissingClaim)
}
if jc.IssuedAt == nil {
return nil, fmt.Errorf("%w: iat", ErrMissingClaim)
}
claims := &Claims{
Issuer: jc.Issuer,
Subject: jc.Subject,
IssuedAt: jc.IssuedAt.Time,
ExpiresAt: jc.ExpiresAt.Time,
JTI: jc.ID,
Roles: jc.Roles,
}
return claims, nil
}
// HasRole reports whether the claims include the given role.
func (c *Claims) HasRole(role string) bool {
for _, r := range c.Roles {
if r == role {
return true
}
}
return false
}

View File

@@ -0,0 +1,222 @@
package token
import (
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate key: %v", err)
}
return pub, priv
}
// b64url encodes a string as base64url without padding.
func b64url(s string) string {
return base64.RawURLEncoding.EncodeToString([]byte(s))
}
const testIssuer = "https://auth.example.com"
func TestIssueAndValidateToken(t *testing.T) {
pub, priv := generateTestKey(t)
roles := []string{"admin", "reader"}
tokenStr, claims, err := IssueToken(priv, testIssuer, "user-uuid-1", roles, time.Hour)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
if tokenStr == "" {
t.Fatal("IssueToken returned empty token string")
}
if claims.JTI == "" {
t.Error("JTI must not be empty")
}
if claims.Subject != "user-uuid-1" {
t.Errorf("Subject = %q, want %q", claims.Subject, "user-uuid-1")
}
// Validate the token.
got, err := ValidateToken(pub, tokenStr, testIssuer)
if err != nil {
t.Fatalf("ValidateToken: %v", err)
}
if got.Subject != "user-uuid-1" {
t.Errorf("validated Subject = %q, want %q", got.Subject, "user-uuid-1")
}
if got.JTI != claims.JTI {
t.Errorf("validated JTI = %q, want %q", got.JTI, claims.JTI)
}
if len(got.Roles) != 2 {
t.Errorf("validated Roles = %v, want 2 roles", got.Roles)
}
}
// TestValidateTokenWrongAlgorithm verifies that tokens with non-EdDSA alg are
// rejected immediately, before any signature verification.
// Security: This tests the core defence against algorithm-confusion attacks.
func TestValidateTokenWrongAlgorithm(t *testing.T) {
_, priv := generateTestKey(t)
pub, _ := generateTestKey(t) // different key — but alg check should fail first
// Forge a token signed with HMAC-SHA256 (alg: HS256).
hmacToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": testIssuer,
"sub": "attacker",
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(),
"jti": "fake-jti",
})
// Use the Ed25519 public key bytes as the HMAC secret (classic alg confusion).
hs256Signed, err := hmacToken.SignedString([]byte(priv.Public().(ed25519.PublicKey)))
if err != nil {
t.Fatalf("sign HS256 token: %v", err)
}
_, err = ValidateToken(pub, hs256Signed, testIssuer)
if err == nil {
t.Fatal("expected error for HS256 token, got nil")
}
if err != ErrWrongAlgorithm {
t.Errorf("expected ErrWrongAlgorithm, got: %v", err)
}
}
// TestValidateTokenAlgNone verifies that "none" algorithm is rejected.
// Security: "none" algorithm tokens have no signature and must always be
// rejected regardless of payload content.
func TestValidateTokenAlgNone(t *testing.T) {
pub, _ := generateTestKey(t)
// Construct a "none" algorithm token manually.
// golang-jwt/v5 disallows signing with "none" directly, so we craft it
// using raw base64url encoding.
header := `{"alg":"none","typ":"JWT"}`
payload := `{"iss":"https://auth.example.com","sub":"evil","iat":1000000,"exp":9999999999,"jti":"evil-jti"}`
noneToken := b64url(header) + "." + b64url(payload) + "."
_, err := ValidateToken(pub, noneToken, testIssuer)
if err == nil {
t.Fatal("expected error for 'none' algorithm token, got nil")
}
}
// TestValidateTokenExpired verifies that expired tokens are rejected.
func TestValidateTokenExpired(t *testing.T) {
pub, priv := generateTestKey(t)
// Issue a token with a negative expiry (already expired).
tokenStr, _, err := IssueToken(priv, testIssuer, "user", nil, -time.Minute)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
_, err = ValidateToken(pub, tokenStr, testIssuer)
if err == nil {
t.Fatal("expected error for expired token, got nil")
}
if err != ErrExpiredToken {
t.Errorf("expected ErrExpiredToken, got: %v", err)
}
}
// TestValidateTokenTamperedSignature verifies that signature tampering is caught.
func TestValidateTokenTamperedSignature(t *testing.T) {
pub, priv := generateTestKey(t)
tokenStr, _, err := IssueToken(priv, testIssuer, "user", nil, time.Hour)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
// Tamper: flip a byte in the signature (last segment).
parts := strings.Split(tokenStr, ".")
if len(parts) != 3 {
t.Fatalf("unexpected token format: %d parts", len(parts))
}
sigBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
t.Fatalf("decode signature: %v", err)
}
sigBytes[0] ^= 0x01 // flip one bit
parts[2] = base64.RawURLEncoding.EncodeToString(sigBytes)
tampered := strings.Join(parts, ".")
_, err = ValidateToken(pub, tampered, testIssuer)
if err == nil {
t.Fatal("expected error for tampered signature, got nil")
}
}
// TestValidateTokenWrongKey verifies that a token signed with a different key
// is rejected.
func TestValidateTokenWrongKey(t *testing.T) {
_, priv := generateTestKey(t)
wrongPub, _ := generateTestKey(t)
tokenStr, _, err := IssueToken(priv, testIssuer, "user", nil, time.Hour)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
_, err = ValidateToken(wrongPub, tokenStr, testIssuer)
if err == nil {
t.Fatal("expected error for wrong key, got nil")
}
}
// TestValidateTokenWrongIssuer verifies that tokens from a different issuer
// are rejected.
func TestValidateTokenWrongIssuer(t *testing.T) {
pub, priv := generateTestKey(t)
tokenStr, _, err := IssueToken(priv, "https://evil.example.com", "user", nil, time.Hour)
if err != nil {
t.Fatalf("IssueToken: %v", err)
}
_, err = ValidateToken(pub, tokenStr, testIssuer)
if err == nil {
t.Fatal("expected error for wrong issuer, got nil")
}
}
// TestJTIsAreUnique verifies that two issued tokens have different JTIs.
func TestJTIsAreUnique(t *testing.T) {
_, priv := generateTestKey(t)
_, c1, err := IssueToken(priv, testIssuer, "user", nil, time.Hour)
if err != nil {
t.Fatalf("IssueToken (1): %v", err)
}
_, c2, err := IssueToken(priv, testIssuer, "user", nil, time.Hour)
if err != nil {
t.Fatalf("IssueToken (2): %v", err)
}
if c1.JTI == c2.JTI {
t.Error("two issued tokens have the same JTI")
}
}
// TestClaimsHasRole verifies role checking.
func TestClaimsHasRole(t *testing.T) {
c := &Claims{Roles: []string{"admin", "reader"}}
if !c.HasRole("admin") {
t.Error("expected HasRole(admin) = true")
}
if !c.HasRole("reader") {
t.Error("expected HasRole(reader) = true")
}
if c.HasRole("writer") {
t.Error("expected HasRole(writer) = false")
}
}