diff --git a/internal/db/accounts.go b/internal/db/accounts.go index 6c58a66..289d976 100644 --- a/internal/db/accounts.go +++ b/internal/db/accounts.go @@ -692,6 +692,70 @@ func (db *DB) RenewToken(oldJTI, reason, newJTI string, accountID int64, issuedA return nil } +// IssueSystemToken atomically revokes an existing system token (if oldJTI is +// non-empty), tracks the new token in token_revocation, and upserts the +// system_tokens table — all within a single SQLite transaction. +// +// Security: these three operations must be atomic so that a crash between them +// cannot leave the database in an inconsistent state (e.g., old token revoked +// but new token not tracked, or token tracked but system_tokens not updated). +// With MaxOpenConns(1) and SQLite's serialised write path, BEGIN IMMEDIATE +// acquires the write lock immediately and prevents any other writer from +// interleaving. +func (db *DB) IssueSystemToken(oldJTI, newJTI string, accountID int64, issuedAt, expiresAt time.Time) error { + tx, err := db.sql.Begin() + if err != nil { + return fmt.Errorf("db: issue system token begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + n := now() + + // If there is an existing token, revoke it. + if oldJTI != "" { + _, err := tx.Exec(` + UPDATE token_revocation + SET revoked_at = ?, revoke_reason = ? + WHERE jti = ? AND revoked_at IS NULL + `, n, nullString("rotated"), oldJTI) + if err != nil { + return fmt.Errorf("db: issue system token revoke old %q: %w", oldJTI, err) + } + // We do not require rows affected > 0 because the old token may + // already be revoked or expired; the important thing is that we + // proceed to track the new token regardless. + } + + // Track the new token in token_revocation. + _, err = tx.Exec(` + INSERT INTO token_revocation (jti, account_id, issued_at, expires_at) + VALUES (?, ?, ?, ?) + `, newJTI, accountID, + issuedAt.UTC().Format(time.RFC3339), + expiresAt.UTC().Format(time.RFC3339)) + if err != nil { + return fmt.Errorf("db: issue system token track new %q: %w", newJTI, err) + } + + // Upsert the system_tokens table so GetSystemToken returns the new JTI. + _, err = tx.Exec(` + INSERT INTO system_tokens (account_id, jti, expires_at, created_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(account_id) DO UPDATE SET + jti = excluded.jti, + expires_at = excluded.expires_at, + created_at = excluded.created_at + `, accountID, newJTI, expiresAt.UTC().Format(time.RFC3339), n) + if err != nil { + return fmt.Errorf("db: issue system token set system token for account %d: %w", accountID, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("db: issue system token commit: %w", err) + } + return nil +} + // RevokeAllUserTokens revokes all non-expired, non-revoked tokens for an account. func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error { n := now() diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 822e0fc..a3d43a3 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -445,6 +445,79 @@ func TestSystemTokenRotationRevokesOld(t *testing.T) { } } +// TestIssueSystemTokenAtomic verifies that IssueSystemToken atomically +// revokes an old token, tracks the new token, and upserts system_tokens. +func TestIssueSystemTokenAtomic(t *testing.T) { + db := openTestDB(t) + acct, err := db.CreateAccount("svc-atomic", model.AccountTypeSystem, "hash") + if err != nil { + t.Fatalf("CreateAccount: %v", err) + } + + now := time.Now().UTC() + exp := now.Add(time.Hour) + + // Issue first system token with no old JTI. + jti1 := "atomic-sys-tok-1" + if err := db.IssueSystemToken("", jti1, acct.ID, now, exp); err != nil { + t.Fatalf("IssueSystemToken first: %v", err) + } + + // Verify the first token is tracked and not revoked. + rec1, err := db.GetTokenRecord(jti1) + if err != nil { + t.Fatalf("GetTokenRecord jti1: %v", err) + } + if rec1.IsRevoked() { + t.Error("first token should not be revoked") + } + + // Verify system_tokens points to the first token. + st1, err := db.GetSystemToken(acct.ID) + if err != nil { + t.Fatalf("GetSystemToken after first issue: %v", err) + } + if st1.JTI != jti1 { + t.Errorf("system token JTI = %q, want %q", st1.JTI, jti1) + } + + // Issue second token, which should atomically revoke the first. + jti2 := "atomic-sys-tok-2" + if err := db.IssueSystemToken(jti1, jti2, acct.ID, now, exp); err != nil { + t.Fatalf("IssueSystemToken second: %v", err) + } + + // First token must be revoked. + rec1After, err := db.GetTokenRecord(jti1) + if err != nil { + t.Fatalf("GetTokenRecord jti1 after rotation: %v", err) + } + if !rec1After.IsRevoked() { + t.Error("first token should be revoked after second issue") + } + if rec1After.RevokeReason != "rotated" { + t.Errorf("revoke reason = %q, want %q", rec1After.RevokeReason, "rotated") + } + + // Second token must be tracked and not revoked. + rec2, err := db.GetTokenRecord(jti2) + if err != nil { + t.Fatalf("GetTokenRecord jti2: %v", err) + } + if rec2.IsRevoked() { + t.Error("second token should not be revoked") + } + + // system_tokens must point to the second token. + st2, err := db.GetSystemToken(acct.ID) + if err != nil { + t.Fatalf("GetSystemToken after second issue: %v", err) + } + if st2.JTI != jti2 { + t.Errorf("system token JTI = %q, want %q", st2.JTI, jti2) + } +} + func TestRevokeAllUserTokens(t *testing.T) { db := openTestDB(t) acct, err := db.CreateAccount("ivan", model.AccountTypeHuman, "hash") diff --git a/internal/grpcserver/tokenservice.go b/internal/grpcserver/tokenservice.go index 421d07d..fcd4501 100644 --- a/internal/grpcserver/tokenservice.go +++ b/internal/grpcserver/tokenservice.go @@ -72,16 +72,15 @@ func (ts *tokenServiceServer) IssueServiceToken(ctx context.Context, req *mciasv return nil, status.Error(codes.Internal, "internal error") } - // Revoke existing system token if any. + // Atomically revoke existing system token (if any), track the new token, + // and update system_tokens — all in a single transaction. + // Security: prevents inconsistent state if a crash occurs mid-operation. + var oldJTI string existing, err := ts.s.db.GetSystemToken(acct.ID) if err == nil && existing != nil { - _ = ts.s.db.RevokeToken(existing.JTI, "rotated") + oldJTI = existing.JTI } - - if err := ts.s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { - return nil, status.Error(codes.Internal, "internal error") - } - if err := ts.s.db.SetSystemToken(acct.ID, claims.JTI, claims.ExpiresAt); err != nil { + if err := ts.s.db.IssueSystemToken(oldJTI, claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { return nil, status.Error(codes.Internal, "internal error") } diff --git a/internal/server/server.go b/internal/server/server.go index f791276..a45bb87 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -470,17 +470,15 @@ func (s *Server) handleTokenIssue(w http.ResponseWriter, r *http.Request) { return } - // Revoke existing system token if any. + // Atomically revoke existing system token (if any), track the new token, + // and update system_tokens — all in a single transaction. + // Security: prevents inconsistent state if a crash occurs mid-operation. + var oldJTI string existing, err := s.db.GetSystemToken(acct.ID) if err == nil && existing != nil { - _ = s.db.RevokeToken(existing.JTI, "rotated") + oldJTI = existing.JTI } - - if err := s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { - middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") - return - } - if err := s.db.SetSystemToken(acct.ID, claims.JTI, claims.ExpiresAt); err != nil { + if err := s.db.IssueSystemToken(oldJTI, claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error") return }