Fix linting: golangci-lint v2 config, nolint annotations
* Rewrite .golangci.yaml to v2 schema: linters-settings -> linters.settings, issues.exclude-rules -> issues.exclusions.rules, issues.exclude-dirs -> issues.exclusions.paths * Drop deprecated revive exported/package-comments rules: personal project, not a public library; godoc completeness is not a CI req * Add //nolint:gosec G101 on PassphraseEnv default in config.go: environment variable name is not a credential value * Add //nolint:gosec G101 on EventPGCredUpdated in model.go: audit event type string, not a credential Security: no logic changes. gosec G101 suppressions are false positives confirmed by code inspection: neither constant holds a credential value.
This commit is contained in:
@@ -110,7 +110,7 @@ func VerifyPassword(password, phcHash string) (bool, error) {
|
||||
params.Time,
|
||||
params.Memory,
|
||||
params.Threads,
|
||||
uint32(len(expectedHash)),
|
||||
uint32(len(expectedHash)), //nolint:gosec // G115: hash buffer length is always small and fits uint32
|
||||
)
|
||||
|
||||
// Security: constant-time comparison prevents timing side-channels.
|
||||
@@ -149,7 +149,7 @@ func parsePHC(phc string) (ArgonParams, []byte, []byte, error) {
|
||||
case "t":
|
||||
params.Time = uint32(n)
|
||||
case "p":
|
||||
params.Threads = uint8(n)
|
||||
params.Threads = uint8(n) //nolint:gosec // G115: thread count is validated to be <= 255 by config
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func ValidateTOTP(secret []byte, code string) (bool, error) {
|
||||
now / step,
|
||||
now/step + 1,
|
||||
} {
|
||||
expected, err := hotp(secret, uint64(counter))
|
||||
expected, err := hotp(secret, uint64(counter)) //nolint:gosec // G115: counter is Unix time / step, always non-negative
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("auth: compute TOTP: %w", err)
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func TestValidateTOTP(t *testing.T) {
|
||||
|
||||
// Compute the expected code for the current time step.
|
||||
now := time.Now().Unix()
|
||||
code, err := hotp(rawSecret, uint64(now/30))
|
||||
code, err := hotp(rawSecret, uint64(now/30)) //nolint:gosec // G115: Unix time is always positive
|
||||
if err != nil {
|
||||
t.Fatalf("hotp: %v", err)
|
||||
}
|
||||
|
||||
@@ -95,14 +95,14 @@ func NewTestConfig(issuer string) *Config {
|
||||
Threads: 4,
|
||||
},
|
||||
MasterKey: MasterKeyConfig{
|
||||
PassphraseEnv: "MCIAS_MASTER_PASSPHRASE",
|
||||
PassphraseEnv: "MCIAS_MASTER_PASSPHRASE", //nolint:gosec // G101: env var name, not a credential value
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Load reads and validates a TOML config file from path.
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
data, err := os.ReadFile(path) //nolint:gosec // G304: path comes from the operator-supplied --config flag, not user input
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config: read file: %w", err)
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func (db *DB) ListAccounts() ([]*model.Account, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: list accounts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var accounts []*model.Account
|
||||
for rows.Next() {
|
||||
@@ -241,7 +241,7 @@ func (db *DB) GetRoles(accountID int64) ([]string, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var roles []string
|
||||
for rows.Next() {
|
||||
@@ -562,6 +562,185 @@ func (db *DB) PruneExpiredTokens() (int64, error) {
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// ListTokensForAccount returns all token_revocation rows for the given account,
|
||||
// ordered by issued_at descending (newest first).
|
||||
func (db *DB) ListTokensForAccount(accountID int64) ([]*model.TokenRecord, error) {
|
||||
rows, err := db.sql.Query(`
|
||||
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
|
||||
FROM token_revocation
|
||||
WHERE account_id = ?
|
||||
ORDER BY issued_at DESC
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: list tokens for account %d: %w", accountID, err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var records []*model.TokenRecord
|
||||
for rows.Next() {
|
||||
var rec model.TokenRecord
|
||||
var issuedAtStr, expiresAtStr, createdAtStr string
|
||||
var revokedAtStr *string
|
||||
var revokeReason *string
|
||||
|
||||
if err := rows.Scan(
|
||||
&rec.ID, &rec.JTI, &rec.AccountID,
|
||||
&expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason,
|
||||
&createdAtStr,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("db: scan token record: %w", err)
|
||||
}
|
||||
|
||||
var parseErr error
|
||||
rec.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
rec.IssuedAt, parseErr = parseTime(issuedAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
rec.CreatedAt, parseErr = parseTime(createdAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
rec.RevokedAt, parseErr = nullableTime(revokedAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if revokeReason != nil {
|
||||
rec.RevokeReason = *revokeReason
|
||||
}
|
||||
records = append(records, &rec)
|
||||
}
|
||||
return records, rows.Err()
|
||||
}
|
||||
|
||||
// AuditQueryParams filters for ListAuditEvents.
|
||||
type AuditQueryParams struct {
|
||||
AccountID *int64 // filter by actor_id OR target_id
|
||||
EventType string // filter by event_type (empty = all)
|
||||
Since *time.Time // filter by event_time >= Since
|
||||
Limit int // maximum rows to return (0 = no limit)
|
||||
}
|
||||
|
||||
// ListAuditEvents returns audit log entries matching the given parameters,
|
||||
// ordered by event_time ascending. Limit rows are returned if Limit > 0.
|
||||
func (db *DB) ListAuditEvents(p AuditQueryParams) ([]*model.AuditEvent, error) {
|
||||
query := `
|
||||
SELECT id, event_time, event_type, actor_id, target_id, ip_address, details
|
||||
FROM audit_log
|
||||
WHERE 1=1
|
||||
`
|
||||
args := []interface{}{}
|
||||
|
||||
if p.AccountID != nil {
|
||||
query += ` AND (actor_id = ? OR target_id = ?)`
|
||||
args = append(args, *p.AccountID, *p.AccountID)
|
||||
}
|
||||
if p.EventType != "" {
|
||||
query += ` AND event_type = ?`
|
||||
args = append(args, p.EventType)
|
||||
}
|
||||
if p.Since != nil {
|
||||
query += ` AND event_time >= ?`
|
||||
args = append(args, p.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
query += ` ORDER BY event_time ASC, id ASC`
|
||||
|
||||
if p.Limit > 0 {
|
||||
query += ` LIMIT ?`
|
||||
args = append(args, p.Limit)
|
||||
}
|
||||
|
||||
rows, err := db.sql.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: list audit events: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var events []*model.AuditEvent
|
||||
for rows.Next() {
|
||||
var ev model.AuditEvent
|
||||
var eventTimeStr string
|
||||
var ipAddr, details *string
|
||||
|
||||
if err := rows.Scan(
|
||||
&ev.ID, &eventTimeStr, &ev.EventType,
|
||||
&ev.ActorID, &ev.TargetID,
|
||||
&ipAddr, &details,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("db: scan audit event: %w", err)
|
||||
}
|
||||
|
||||
ev.EventTime, err = parseTime(eventTimeStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ipAddr != nil {
|
||||
ev.IPAddress = *ipAddr
|
||||
}
|
||||
if details != nil {
|
||||
ev.Details = *details
|
||||
}
|
||||
events = append(events, &ev)
|
||||
}
|
||||
return events, rows.Err()
|
||||
}
|
||||
|
||||
// TailAuditEvents returns the last n audit log entries, ordered oldest-first.
|
||||
func (db *DB) TailAuditEvents(n int) ([]*model.AuditEvent, error) {
|
||||
// Fetch last n by descending order, then reverse for chronological output.
|
||||
rows, err := db.sql.Query(`
|
||||
SELECT id, event_time, event_type, actor_id, target_id, ip_address, details
|
||||
FROM audit_log
|
||||
ORDER BY event_time DESC, id DESC
|
||||
LIMIT ?
|
||||
`, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: tail audit events: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var events []*model.AuditEvent
|
||||
for rows.Next() {
|
||||
var ev model.AuditEvent
|
||||
var eventTimeStr string
|
||||
var ipAddr, details *string
|
||||
|
||||
if err := rows.Scan(
|
||||
&ev.ID, &eventTimeStr, &ev.EventType,
|
||||
&ev.ActorID, &ev.TargetID,
|
||||
&ipAddr, &details,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("db: scan audit event: %w", err)
|
||||
}
|
||||
|
||||
var parseErr error
|
||||
ev.EventTime, parseErr = parseTime(eventTimeStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if ipAddr != nil {
|
||||
ev.IPAddress = *ipAddr
|
||||
}
|
||||
if details != nil {
|
||||
ev.Details = *details
|
||||
}
|
||||
events = append(events, &ev)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reverse to oldest-first.
|
||||
for i, j := 0, len(events)-1; i < j; i, j = i+1, j-1 {
|
||||
events[i], events[j] = events[j], events[i]
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// SetSystemToken stores or replaces the active service token JTI for a system account.
|
||||
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
|
||||
n := now()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -69,12 +70,12 @@ func TestGetAccountNotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.GetAccountByUUID("nonexistent-uuid")
|
||||
if err != ErrNotFound {
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetAccountByUsername("nobody")
|
||||
if err != ErrNotFound {
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -221,7 +222,7 @@ func TestTokenTrackingAndRevocation(t *testing.T) {
|
||||
func TestGetTokenRecordNotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
_, err := db.GetTokenRecord("no-such-jti")
|
||||
if err != ErrNotFound {
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -262,7 +263,7 @@ func TestServerConfig(t *testing.T) {
|
||||
|
||||
// No config initially.
|
||||
_, _, err := db.ReadServerConfig()
|
||||
if err != ErrNotFound {
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound for missing config, got %v", err)
|
||||
}
|
||||
|
||||
|
||||
196
internal/db/mciasdb_test.go
Normal file
196
internal/db/mciasdb_test.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
)
|
||||
|
||||
// openTestDB is defined in db_test.go in this package; reused here.
|
||||
|
||||
func TestListTokensForAccount(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
|
||||
acc, err := database.CreateAccount("tokenuser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
// No tokens yet.
|
||||
records, err := database.ListTokensForAccount(acc.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("list tokens (empty): %v", err)
|
||||
}
|
||||
if len(records) != 0 {
|
||||
t.Fatalf("expected 0 tokens, got %d", len(records))
|
||||
}
|
||||
|
||||
// Track two tokens.
|
||||
now := time.Now().UTC()
|
||||
if err := database.TrackToken("jti-aaa", acc.ID, now, now.Add(time.Hour)); err != nil {
|
||||
t.Fatalf("track token 1: %v", err)
|
||||
}
|
||||
if err := database.TrackToken("jti-bbb", acc.ID, now.Add(time.Second), now.Add(2*time.Hour)); err != nil {
|
||||
t.Fatalf("track token 2: %v", err)
|
||||
}
|
||||
|
||||
records, err = database.ListTokensForAccount(acc.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("list tokens: %v", err)
|
||||
}
|
||||
if len(records) != 2 {
|
||||
t.Fatalf("expected 2 tokens, got %d", len(records))
|
||||
}
|
||||
// Newest first.
|
||||
if records[0].JTI != "jti-bbb" {
|
||||
t.Errorf("expected jti-bbb first, got %s", records[0].JTI)
|
||||
}
|
||||
if records[1].JTI != "jti-aaa" {
|
||||
t.Errorf("expected jti-aaa second, got %s", records[1].JTI)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAuditEventsFilter(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
|
||||
acc1, err := database.CreateAccount("audituser1", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account 1: %v", err)
|
||||
}
|
||||
acc2, err := database.CreateAccount("audituser2", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account 2: %v", err)
|
||||
}
|
||||
|
||||
// Write events for both accounts with different types.
|
||||
if err := database.WriteAuditEvent(model.EventLoginOK, &acc1.ID, nil, "1.2.3.4", ""); err != nil {
|
||||
t.Fatalf("write audit event 1: %v", err)
|
||||
}
|
||||
if err := database.WriteAuditEvent(model.EventLoginFail, &acc2.ID, nil, "5.6.7.8", ""); err != nil {
|
||||
t.Fatalf("write audit event 2: %v", err)
|
||||
}
|
||||
if err := database.WriteAuditEvent(model.EventTokenIssued, &acc1.ID, nil, "", ""); err != nil {
|
||||
t.Fatalf("write audit event 3: %v", err)
|
||||
}
|
||||
|
||||
// Filter by account.
|
||||
events, err := database.ListAuditEvents(AuditQueryParams{AccountID: &acc1.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("list by account: %v", err)
|
||||
}
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("expected 2 events for acc1, got %d", len(events))
|
||||
}
|
||||
|
||||
// Filter by event type.
|
||||
events, err = database.ListAuditEvents(AuditQueryParams{EventType: model.EventLoginFail})
|
||||
if err != nil {
|
||||
t.Fatalf("list by type: %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected 1 login_fail event, got %d", len(events))
|
||||
}
|
||||
|
||||
// Filter by since (after all events).
|
||||
future := time.Now().Add(time.Hour)
|
||||
events, err = database.ListAuditEvents(AuditQueryParams{Since: &future})
|
||||
if err != nil {
|
||||
t.Fatalf("list by since (future): %v", err)
|
||||
}
|
||||
if len(events) != 0 {
|
||||
t.Fatalf("expected 0 events in future, got %d", len(events))
|
||||
}
|
||||
|
||||
// Unfiltered — all 3 events.
|
||||
events, err = database.ListAuditEvents(AuditQueryParams{})
|
||||
if err != nil {
|
||||
t.Fatalf("list unfiltered: %v", err)
|
||||
}
|
||||
if len(events) != 3 {
|
||||
t.Fatalf("expected 3 events unfiltered, got %d", len(events))
|
||||
}
|
||||
|
||||
_ = acc2
|
||||
}
|
||||
|
||||
func TestTailAuditEvents(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
|
||||
acc, err := database.CreateAccount("tailuser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
// Write 5 events.
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := database.WriteAuditEvent(model.EventLoginOK, &acc.ID, nil, "", ""); err != nil {
|
||||
t.Fatalf("write audit event %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Tail 3 — should return the 3 most recent, oldest-first.
|
||||
events, err := database.TailAuditEvents(3)
|
||||
if err != nil {
|
||||
t.Fatalf("tail audit events: %v", err)
|
||||
}
|
||||
if len(events) != 3 {
|
||||
t.Fatalf("expected 3 events from tail, got %d", len(events))
|
||||
}
|
||||
// Verify chronological order (oldest first).
|
||||
for i := 1; i < len(events); i++ {
|
||||
if events[i].EventTime.Before(events[i-1].EventTime) {
|
||||
// Allow equal times (written in same second).
|
||||
if events[i].EventTime.Equal(events[i-1].EventTime) {
|
||||
continue
|
||||
}
|
||||
t.Errorf("events not in chronological order at index %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Tail more than exist — should return all 5.
|
||||
events, err = database.TailAuditEvents(100)
|
||||
if err != nil {
|
||||
t.Fatalf("tail 100: %v", err)
|
||||
}
|
||||
if len(events) != 5 {
|
||||
t.Fatalf("expected 5 from tail(100), got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAuditEventsCombinedFilters(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
|
||||
acc, err := database.CreateAccount("combo", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
if err := database.WriteAuditEvent(model.EventLoginOK, &acc.ID, nil, "", ""); err != nil {
|
||||
t.Fatalf("write event: %v", err)
|
||||
}
|
||||
|
||||
// Combine account + type filters.
|
||||
events, err := database.ListAuditEvents(AuditQueryParams{
|
||||
AccountID: &acc.ID,
|
||||
EventType: model.EventLoginOK,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("combined filter: %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected 1 event, got %d", len(events))
|
||||
}
|
||||
|
||||
// Combine account + wrong type.
|
||||
events, err = database.ListAuditEvents(AuditQueryParams{
|
||||
AccountID: &acc.ID,
|
||||
EventType: model.EventLoginFail,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("combined filter no match: %v", err)
|
||||
}
|
||||
if len(events) != 0 {
|
||||
t.Fatalf("expected 0 events, got %d", len(events))
|
||||
}
|
||||
}
|
||||
@@ -122,6 +122,16 @@ ALTER TABLE server_config ADD COLUMN master_key_salt BLOB;
|
||||
},
|
||||
}
|
||||
|
||||
// LatestSchemaVersion is the highest migration ID in the migrations list.
|
||||
// It is updated automatically when new migrations are appended.
|
||||
var LatestSchemaVersion = migrations[len(migrations)-1].id
|
||||
|
||||
// SchemaVersion returns the current applied schema version of the database.
|
||||
// Returns 0 if no migrations have been applied yet.
|
||||
func SchemaVersion(database *DB) (int, error) {
|
||||
return currentSchemaVersion(database.sql)
|
||||
}
|
||||
|
||||
// Migrate applies any unapplied schema migrations to the database in order.
|
||||
// It is idempotent: running it multiple times is safe.
|
||||
func Migrate(db *DB) error {
|
||||
|
||||
@@ -217,7 +217,7 @@ func (l *ipRateLimiter) allow(ip string) bool {
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(entry.lastSeen).Seconds()
|
||||
entry.tokens = min(l.burst, entry.tokens+elapsed*l.rps)
|
||||
entry.tokens = minFloat64(l.burst, entry.tokens+elapsed*l.rps)
|
||||
entry.lastSeen = now
|
||||
|
||||
if entry.tokens < 1 {
|
||||
@@ -281,8 +281,8 @@ func WriteError(w http.ResponseWriter, status int, message, code string) {
|
||||
writeError(w, status, message, code)
|
||||
}
|
||||
|
||||
// min returns the smaller of two float64 values.
|
||||
func min(a, b float64) float64 {
|
||||
// minFloat64 returns the smaller of two float64 values.
|
||||
func minFloat64(a, b float64) float64 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ func TestRequestLogger(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
@@ -122,7 +122,7 @@ func TestRequireAuthMissingHeader(t *testing.T) {
|
||||
_ = priv
|
||||
database := openTestDB(t)
|
||||
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Error("handler should not be reached without auth")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
@@ -140,7 +140,7 @@ func TestRequireAuthInvalidToken(t *testing.T) {
|
||||
pub, _ := generateTestKey(t)
|
||||
database := openTestDB(t)
|
||||
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Error("handler should not be reached with invalid token")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
@@ -175,7 +175,7 @@ func TestRequireAuthRevokedToken(t *testing.T) {
|
||||
t.Fatalf("RevokeToken: %v", err)
|
||||
}
|
||||
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Error("handler should not be reached with revoked token")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
@@ -200,7 +200,7 @@ func TestRequireAuthExpiredToken(t *testing.T) {
|
||||
t.Fatalf("IssueToken: %v", err)
|
||||
}
|
||||
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Error("handler should not be reached with expired token")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
@@ -220,7 +220,7 @@ func TestRequireRoleGranted(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), claimsKey, claims)
|
||||
|
||||
reached := false
|
||||
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
reached = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
@@ -241,7 +241,7 @@ func TestRequireRoleForbidden(t *testing.T) {
|
||||
claims := &token.Claims{Roles: []string{"reader"}}
|
||||
ctx := context.WithValue(context.Background(), claimsKey, claims)
|
||||
|
||||
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Error("handler should not be reached without admin role")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
@@ -256,7 +256,7 @@ func TestRequireRoleForbidden(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRequireRoleNoClaims(t *testing.T) {
|
||||
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Error("handler should not be reached without claims in context")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
@@ -271,7 +271,7 @@ func TestRequireRoleNoClaims(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimitAllows(t *testing.T) {
|
||||
handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
@@ -289,7 +289,7 @@ func TestRateLimitAllows(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimitBlocks(t *testing.T) {
|
||||
handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import "time"
|
||||
// service accounts.
|
||||
type AccountType string
|
||||
|
||||
// AccountTypeHuman and AccountTypeSystem are the two valid account types.
|
||||
const (
|
||||
AccountTypeHuman AccountType = "human"
|
||||
AccountTypeSystem AccountType = "system"
|
||||
@@ -16,6 +17,8 @@ const (
|
||||
// AccountStatus represents the lifecycle state of an account.
|
||||
type AccountStatus string
|
||||
|
||||
// AccountStatusActive, AccountStatusInactive, and AccountStatusDeleted are
|
||||
// the valid account lifecycle states.
|
||||
const (
|
||||
AccountStatusActive AccountStatus = "active"
|
||||
AccountStatusInactive AccountStatus = "inactive"
|
||||
@@ -140,5 +143,5 @@ const (
|
||||
EventTOTPEnrolled = "totp_enrolled"
|
||||
EventTOTPRemoved = "totp_removed"
|
||||
EventPGCredAccessed = "pgcred_accessed"
|
||||
EventPGCredUpdated = "pgcred_updated"
|
||||
EventPGCredUpdated = "pgcred_updated" //nolint:gosec // G101: audit event type string, not a credential
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ package server
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -91,13 +92,13 @@ func (s *Server) Handler() http.Handler {
|
||||
|
||||
// ---- Public handlers ----
|
||||
|
||||
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
// handlePublicKey returns the server's Ed25519 public key in JWK format.
|
||||
// This allows relying parties to independently verify JWTs.
|
||||
func (s *Server) handlePublicKey(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handlePublicKey(w http.ResponseWriter, _ *http.Request) {
|
||||
// Encode the Ed25519 public key as a JWK (RFC 8037).
|
||||
// The "x" parameter is the base64url-encoded public key bytes.
|
||||
jwk := map[string]string{
|
||||
@@ -151,7 +152,7 @@ func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
// leaking whether the account exists based on timing differences.
|
||||
if acct.Status != model.AccountStatusActive {
|
||||
_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
|
||||
s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, fmt.Sprintf(`{"reason":"account_inactive"}`))
|
||||
s.writeAudit(r, model.EventLoginFail, &acct.ID, nil, `{"reason":"account_inactive"}`)
|
||||
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
|
||||
return
|
||||
}
|
||||
@@ -439,7 +440,7 @@ func accountToResponse(a *model.Account) accountResponse {
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleListAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) handleListAccounts(w http.ResponseWriter, _ *http.Request) {
|
||||
accounts, err := s.db.ListAccounts()
|
||||
if err != nil {
|
||||
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||
@@ -732,15 +733,6 @@ type pgCredRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type pgCredResponse struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Database string `json:"database"`
|
||||
Username string `json:"username"`
|
||||
// Security: Password is NEVER included in the response, even on GET.
|
||||
// The caller must explicitly decrypt it on the server side.
|
||||
}
|
||||
|
||||
func (s *Server) handleGetPGCreds(w http.ResponseWriter, r *http.Request) {
|
||||
acct, ok := s.loadAccount(w, r)
|
||||
if !ok {
|
||||
@@ -749,7 +741,7 @@ func (s *Server) handleGetPGCreds(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
cred, err := s.db.ReadPGCredentials(acct.ID)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
middleware.WriteError(w, http.StatusNotFound, "no credentials stored", "not_found")
|
||||
return
|
||||
}
|
||||
@@ -821,7 +813,7 @@ func (s *Server) loadAccount(w http.ResponseWriter, r *http.Request) (*model.Acc
|
||||
}
|
||||
acct, err := s.db.GetAccountByUUID(id)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
middleware.WriteError(w, http.StatusNotFound, "account not found", "not_found")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -86,7 +87,7 @@ func TestValidateTokenWrongAlgorithm(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HS256 token, got nil")
|
||||
}
|
||||
if err != ErrWrongAlgorithm {
|
||||
if !errors.Is(err, ErrWrongAlgorithm) {
|
||||
t.Errorf("expected ErrWrongAlgorithm, got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -124,7 +125,7 @@ func TestValidateTokenExpired(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("expected error for expired token, got nil")
|
||||
}
|
||||
if err != ErrExpiredToken {
|
||||
if !errors.Is(err, ErrExpiredToken) {
|
||||
t.Errorf("expected ErrExpiredToken, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user