Implement Phase 7: gRPC dual-stack interface
- proto/mcias/v1/: AdminService, AuthService, TokenService, AccountService, CredentialService; generated Go stubs in gen/ - internal/grpcserver: full handler implementations sharing all business logic (auth, token, db, crypto) with REST server; interceptor chain: logging -> auth (JWT alg-first + revocation) -> rate-limit (token bucket, 10 req/s, burst 10, per-IP) - internal/config: optional grpc_addr field in [server] section - cmd/mciassrv: dual-stack startup; gRPC/TLS listener on grpc_addr when configured; graceful shutdown of both servers in 15s window - cmd/mciasgrpcctl: companion gRPC CLI mirroring mciasctl commands (health, pubkey, account, role, token, pgcreds) using TLS with optional custom CA cert - internal/grpcserver/grpcserver_test.go: 20 tests via bufconn covering public RPCs, auth interceptor (no token, invalid, revoked -> 401), non-admin -> 403, Login/Logout/RenewToken/ValidateToken flows, AccountService CRUD, SetPGCreds/GetPGCreds AES-GCM round-trip, credential fields absent from all responses Security: JWT validation path identical to REST: alg header checked before signature, alg:none rejected, revocation table checked after sig. Authorization metadata value never logged by any interceptor. Credential fields (PasswordHash, TOTPSecret*, PGPassword) absent from all proto response messages — enforced by proto design and confirmed by test TestCredentialFieldsAbsentFromAccountResponse. Login dummy-Argon2 timing guard preserves timing uniformity for unknown users (same as REST handleLogin). TLS required at listener level; cmd/mciassrv uses credentials.NewServerTLSFromFile; no h2c offered. 137 tests pass, zero race conditions (go test -race ./...)
This commit is contained in:
473
cmd/mciasctl/main.go
Normal file
473
cmd/mciasctl/main.go
Normal file
@@ -0,0 +1,473 @@
|
||||
// Command mciasctl is the MCIAS admin CLI.
|
||||
//
|
||||
// It connects to a running mciassrv instance and provides subcommands for
|
||||
// managing accounts, roles, tokens, and Postgres credentials.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// mciasctl [global flags] <command> [args]
|
||||
//
|
||||
// Global flags:
|
||||
//
|
||||
// -server URL of the mciassrv instance (default: https://localhost:8443)
|
||||
// -token Bearer token for authentication (or set MCIAS_TOKEN env var)
|
||||
// -cacert Path to CA certificate for TLS verification (optional)
|
||||
//
|
||||
// Commands:
|
||||
//
|
||||
// account list
|
||||
// account create -username NAME -password PASS [-type human|system]
|
||||
// account get -id UUID
|
||||
// account update -id UUID [-status active|inactive]
|
||||
// account delete -id UUID
|
||||
//
|
||||
// role list -id UUID
|
||||
// role set -id UUID -roles role1,role2,...
|
||||
//
|
||||
// token issue -id UUID
|
||||
// token revoke -jti JTI
|
||||
//
|
||||
// pgcreds set -id UUID -host HOST -port PORT -db DB -user USER -password PASS
|
||||
// pgcreds get -id UUID
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Global flags.
|
||||
serverURL := flag.String("server", "https://localhost:8443", "mciassrv base URL")
|
||||
tokenFlag := flag.String("token", "", "bearer token (or set MCIAS_TOKEN)")
|
||||
caCert := flag.String("cacert", "", "path to CA certificate for TLS")
|
||||
flag.Usage = usage
|
||||
flag.Parse()
|
||||
|
||||
// Resolve token from flag or environment.
|
||||
bearerToken := *tokenFlag
|
||||
if bearerToken == "" {
|
||||
bearerToken = os.Getenv("MCIAS_TOKEN")
|
||||
}
|
||||
|
||||
args := flag.Args()
|
||||
if len(args) == 0 {
|
||||
usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Build HTTP client.
|
||||
client, err := newHTTPClient(*caCert)
|
||||
if err != nil {
|
||||
fatalf("build HTTP client: %v", err)
|
||||
}
|
||||
|
||||
ctl := &controller{
|
||||
serverURL: strings.TrimRight(*serverURL, "/"),
|
||||
token: bearerToken,
|
||||
client: client,
|
||||
}
|
||||
|
||||
command := args[0]
|
||||
subArgs := args[1:]
|
||||
|
||||
switch command {
|
||||
case "account":
|
||||
ctl.runAccount(subArgs)
|
||||
case "role":
|
||||
ctl.runRole(subArgs)
|
||||
case "token":
|
||||
ctl.runToken(subArgs)
|
||||
case "pgcreds":
|
||||
ctl.runPGCreds(subArgs)
|
||||
default:
|
||||
fatalf("unknown command %q; run with no args to see usage", command)
|
||||
}
|
||||
}
|
||||
|
||||
// controller holds shared state for all subcommands.
|
||||
type controller struct {
|
||||
client *http.Client
|
||||
serverURL string
|
||||
token string
|
||||
}
|
||||
|
||||
// ---- account subcommands ----
|
||||
|
||||
func (c *controller) runAccount(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("account requires a subcommand: list, create, get, update, delete")
|
||||
}
|
||||
switch args[0] {
|
||||
case "list":
|
||||
c.accountList()
|
||||
case "create":
|
||||
c.accountCreate(args[1:])
|
||||
case "get":
|
||||
c.accountGet(args[1:])
|
||||
case "update":
|
||||
c.accountUpdate(args[1:])
|
||||
case "delete":
|
||||
c.accountDelete(args[1:])
|
||||
default:
|
||||
fatalf("unknown account subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controller) accountList() {
|
||||
var result []json.RawMessage
|
||||
c.doRequest("GET", "/v1/accounts", nil, &result)
|
||||
printJSON(result)
|
||||
}
|
||||
|
||||
func (c *controller) accountCreate(args []string) {
|
||||
fs := flag.NewFlagSet("account create", flag.ExitOnError)
|
||||
username := fs.String("username", "", "username (required)")
|
||||
password := fs.String("password", "", "password (required for human accounts)")
|
||||
accountType := fs.String("type", "human", "account type: human or system")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *username == "" {
|
||||
fatalf("account create: -username is required")
|
||||
}
|
||||
|
||||
body := map[string]string{
|
||||
"username": *username,
|
||||
"account_type": *accountType,
|
||||
}
|
||||
if *password != "" {
|
||||
body["password"] = *password
|
||||
}
|
||||
|
||||
var result json.RawMessage
|
||||
c.doRequest("POST", "/v1/accounts", body, &result)
|
||||
printJSON(result)
|
||||
}
|
||||
|
||||
func (c *controller) accountGet(args []string) {
|
||||
fs := flag.NewFlagSet("account get", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account get: -id is required")
|
||||
}
|
||||
|
||||
var result json.RawMessage
|
||||
c.doRequest("GET", "/v1/accounts/"+*id, nil, &result)
|
||||
printJSON(result)
|
||||
}
|
||||
|
||||
func (c *controller) accountUpdate(args []string) {
|
||||
fs := flag.NewFlagSet("account update", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
status := fs.String("status", "", "new status: active or inactive")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account update: -id is required")
|
||||
}
|
||||
if *status == "" {
|
||||
fatalf("account update: -status is required")
|
||||
}
|
||||
|
||||
body := map[string]string{"status": *status}
|
||||
c.doRequest("PATCH", "/v1/accounts/"+*id, body, nil)
|
||||
fmt.Println("account updated")
|
||||
}
|
||||
|
||||
func (c *controller) accountDelete(args []string) {
|
||||
fs := flag.NewFlagSet("account delete", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account delete: -id is required")
|
||||
}
|
||||
|
||||
c.doRequest("DELETE", "/v1/accounts/"+*id, nil, nil)
|
||||
fmt.Println("account deleted")
|
||||
}
|
||||
|
||||
// ---- role subcommands ----
|
||||
|
||||
func (c *controller) runRole(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("role requires a subcommand: list, set")
|
||||
}
|
||||
switch args[0] {
|
||||
case "list":
|
||||
c.roleList(args[1:])
|
||||
case "set":
|
||||
c.roleSet(args[1:])
|
||||
default:
|
||||
fatalf("unknown role subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controller) roleList(args []string) {
|
||||
fs := flag.NewFlagSet("role list", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("role list: -id is required")
|
||||
}
|
||||
|
||||
var result json.RawMessage
|
||||
c.doRequest("GET", "/v1/accounts/"+*id+"/roles", nil, &result)
|
||||
printJSON(result)
|
||||
}
|
||||
|
||||
func (c *controller) roleSet(args []string) {
|
||||
fs := flag.NewFlagSet("role set", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
rolesFlag := fs.String("roles", "", "comma-separated list of roles")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("role set: -id is required")
|
||||
}
|
||||
|
||||
roles := []string{}
|
||||
if *rolesFlag != "" {
|
||||
for _, r := range strings.Split(*rolesFlag, ",") {
|
||||
r = strings.TrimSpace(r)
|
||||
if r != "" {
|
||||
roles = append(roles, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body := map[string][]string{"roles": roles}
|
||||
c.doRequest("PUT", "/v1/accounts/"+*id+"/roles", body, nil)
|
||||
fmt.Printf("roles set: %v\n", roles)
|
||||
}
|
||||
|
||||
// ---- token subcommands ----
|
||||
|
||||
func (c *controller) runToken(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("token requires a subcommand: issue, revoke")
|
||||
}
|
||||
switch args[0] {
|
||||
case "issue":
|
||||
c.tokenIssue(args[1:])
|
||||
case "revoke":
|
||||
c.tokenRevoke(args[1:])
|
||||
default:
|
||||
fatalf("unknown token subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controller) tokenIssue(args []string) {
|
||||
fs := flag.NewFlagSet("token issue", flag.ExitOnError)
|
||||
id := fs.String("id", "", "system account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("token issue: -id is required")
|
||||
}
|
||||
|
||||
body := map[string]string{"account_id": *id}
|
||||
var result json.RawMessage
|
||||
c.doRequest("POST", "/v1/token/issue", body, &result)
|
||||
printJSON(result)
|
||||
}
|
||||
|
||||
func (c *controller) tokenRevoke(args []string) {
|
||||
fs := flag.NewFlagSet("token revoke", flag.ExitOnError)
|
||||
jti := fs.String("jti", "", "JTI of the token to revoke (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *jti == "" {
|
||||
fatalf("token revoke: -jti is required")
|
||||
}
|
||||
|
||||
c.doRequest("DELETE", "/v1/token/"+*jti, nil, nil)
|
||||
fmt.Println("token revoked")
|
||||
}
|
||||
|
||||
// ---- pgcreds subcommands ----
|
||||
|
||||
func (c *controller) runPGCreds(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("pgcreds requires a subcommand: get, set")
|
||||
}
|
||||
switch args[0] {
|
||||
case "get":
|
||||
c.pgCredsGet(args[1:])
|
||||
case "set":
|
||||
c.pgCredsSet(args[1:])
|
||||
default:
|
||||
fatalf("unknown pgcreds subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controller) pgCredsGet(args []string) {
|
||||
fs := flag.NewFlagSet("pgcreds get", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("pgcreds get: -id is required")
|
||||
}
|
||||
|
||||
var result json.RawMessage
|
||||
c.doRequest("GET", "/v1/accounts/"+*id+"/pgcreds", nil, &result)
|
||||
printJSON(result)
|
||||
}
|
||||
|
||||
func (c *controller) pgCredsSet(args []string) {
|
||||
fs := flag.NewFlagSet("pgcreds set", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
host := fs.String("host", "", "Postgres host (required)")
|
||||
port := fs.Int("port", 5432, "Postgres port")
|
||||
dbName := fs.String("db", "", "Postgres database name (required)")
|
||||
username := fs.String("user", "", "Postgres username (required)")
|
||||
password := fs.String("password", "", "Postgres password (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" || *host == "" || *dbName == "" || *username == "" || *password == "" {
|
||||
fatalf("pgcreds set: -id, -host, -db, -user, and -password are required")
|
||||
}
|
||||
|
||||
body := map[string]interface{}{
|
||||
"host": *host,
|
||||
"port": *port,
|
||||
"database": *dbName,
|
||||
"username": *username,
|
||||
"password": *password,
|
||||
}
|
||||
c.doRequest("PUT", "/v1/accounts/"+*id+"/pgcreds", body, nil)
|
||||
fmt.Println("credentials stored")
|
||||
}
|
||||
|
||||
// ---- HTTP helpers ----
|
||||
|
||||
// doRequest performs an authenticated JSON HTTP request. If result is non-nil,
|
||||
// the response body is decoded into it. Exits on error.
|
||||
func (c *controller) doRequest(method, path string, body, result interface{}) {
|
||||
url := c.serverURL + path
|
||||
|
||||
var bodyReader *strings.Reader
|
||||
if body != nil {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
fatalf("marshal request body: %v", err)
|
||||
}
|
||||
bodyReader = strings.NewReader(string(b))
|
||||
} else {
|
||||
bodyReader = strings.NewReader("")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, url, bodyReader)
|
||||
if err != nil {
|
||||
fatalf("create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.token)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
fatalf("HTTP %s %s: %v", method, path, err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
var errBody map[string]string
|
||||
_ = json.NewDecoder(resp.Body).Decode(&errBody)
|
||||
msg := errBody["error"]
|
||||
if msg == "" {
|
||||
msg = resp.Status
|
||||
}
|
||||
fatalf("server returned %d: %s", resp.StatusCode, msg)
|
||||
}
|
||||
|
||||
if result != nil && resp.StatusCode != http.StatusNoContent {
|
||||
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
|
||||
fatalf("decode response: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newHTTPClient builds an http.Client with optional custom CA certificate.
|
||||
// Security: TLS 1.2+ is required; the system CA pool is used by default.
|
||||
func newHTTPClient(caCertPath string) (*http.Client, error) {
|
||||
tlsCfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
if caCertPath != "" {
|
||||
// G304: path comes from a CLI flag supplied by the operator, not from
|
||||
// untrusted input. File inclusion is intentional.
|
||||
pemData, err := os.ReadFile(caCertPath) //nolint:gosec
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read CA cert: %w", err)
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(pemData) {
|
||||
return nil, fmt.Errorf("no valid certificates found in %s", caCertPath)
|
||||
}
|
||||
tlsCfg.RootCAs = pool
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: tlsCfg,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// printJSON pretty-prints a JSON value to stdout.
|
||||
func printJSON(v interface{}) {
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(v); err != nil {
|
||||
fatalf("encode output: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// fatalf prints an error message and exits with code 1.
|
||||
func fatalf(format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "mciasctl: "+format+"\n", args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func usage() {
|
||||
fmt.Fprintf(os.Stderr, `mciasctl - MCIAS admin CLI
|
||||
|
||||
Usage: mciasctl [global flags] <command> [args]
|
||||
|
||||
Global flags:
|
||||
-server URL of the mciassrv instance (default: https://localhost:8443)
|
||||
-token Bearer token (or set MCIAS_TOKEN env var)
|
||||
-cacert Path to CA certificate for TLS verification
|
||||
|
||||
Commands:
|
||||
account list
|
||||
account create -username NAME -password PASS [-type human|system]
|
||||
account get -id UUID
|
||||
account update -id UUID -status active|inactive
|
||||
account delete -id UUID
|
||||
|
||||
role list -id UUID
|
||||
role set -id UUID -roles role1,role2,...
|
||||
|
||||
token issue -id UUID
|
||||
token revoke -jti JTI
|
||||
|
||||
pgcreds get -id UUID
|
||||
pgcreds set -id UUID -host HOST [-port PORT] -db DB -user USER -password PASS
|
||||
`)
|
||||
}
|
||||
251
cmd/mciasdb/account.go
Normal file
251
cmd/mciasdb/account.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func (t *tool) runAccount(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("account requires a subcommand: list, get, create, set-password, set-status, reset-totp")
|
||||
}
|
||||
switch args[0] {
|
||||
case "list":
|
||||
t.accountList()
|
||||
case "get":
|
||||
t.accountGet(args[1:])
|
||||
case "create":
|
||||
t.accountCreate(args[1:])
|
||||
case "set-password":
|
||||
t.accountSetPassword(args[1:])
|
||||
case "set-status":
|
||||
t.accountSetStatus(args[1:])
|
||||
case "reset-totp":
|
||||
t.accountResetTOTP(args[1:])
|
||||
default:
|
||||
fatalf("unknown account subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tool) accountList() {
|
||||
accounts, err := t.db.ListAccounts()
|
||||
if err != nil {
|
||||
fatalf("list accounts: %v", err)
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
fmt.Println("no accounts found")
|
||||
return
|
||||
}
|
||||
fmt.Printf("%-36s %-20s %-8s %-10s\n", "UUID", "USERNAME", "TYPE", "STATUS")
|
||||
fmt.Println(strings.Repeat("-", 80))
|
||||
for _, a := range accounts {
|
||||
fmt.Printf("%-36s %-20s %-8s %-10s\n",
|
||||
a.UUID, a.Username, string(a.AccountType), string(a.Status))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tool) accountGet(args []string) {
|
||||
fs := flag.NewFlagSet("account get", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account get: --id is required")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("UUID: %s\n", a.UUID)
|
||||
fmt.Printf("Username: %s\n", a.Username)
|
||||
fmt.Printf("Type: %s\n", a.AccountType)
|
||||
fmt.Printf("Status: %s\n", a.Status)
|
||||
fmt.Printf("TOTP required: %v\n", a.TOTPRequired)
|
||||
fmt.Printf("Created: %s\n", a.CreatedAt.Format("2006-01-02T15:04:05Z"))
|
||||
fmt.Printf("Updated: %s\n", a.UpdatedAt.Format("2006-01-02T15:04:05Z"))
|
||||
if a.DeletedAt != nil {
|
||||
fmt.Printf("Deleted: %s\n", a.DeletedAt.Format("2006-01-02T15:04:05Z"))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tool) accountCreate(args []string) {
|
||||
fs := flag.NewFlagSet("account create", flag.ExitOnError)
|
||||
username := fs.String("username", "", "username (required)")
|
||||
accountType := fs.String("type", "human", "account type: human or system")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *username == "" {
|
||||
fatalf("account create: --username is required")
|
||||
}
|
||||
if *accountType != "human" && *accountType != "system" {
|
||||
fatalf("account create: --type must be human or system")
|
||||
}
|
||||
|
||||
atype := model.AccountType(*accountType)
|
||||
a, err := t.db.CreateAccount(*username, atype, "")
|
||||
if err != nil {
|
||||
fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WriteAuditEvent("account_created", nil, &a.ID, "", fmt.Sprintf(`{"actor":"mciasdb","username":%q}`, *username)); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("created account %s (UUID: %s)\n", *username, a.UUID)
|
||||
}
|
||||
|
||||
// accountSetPassword prompts twice for a new password, hashes it with
|
||||
// Argon2id, and updates the account's password_hash column.
|
||||
//
|
||||
// Security: No --password flag is provided; passwords must be entered
|
||||
// interactively so they never appear in shell history or process listings.
|
||||
// The password is hashed with Argon2id using OWASP-compliant parameters before
|
||||
// any database write.
|
||||
func (t *tool) accountSetPassword(args []string) {
|
||||
fs := flag.NewFlagSet("account set-password", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account set-password: --id is required")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Setting password for account %s (%s)\n", a.Username, a.UUID)
|
||||
|
||||
password, err := readPassword("New password: ")
|
||||
if err != nil {
|
||||
fatalf("read password: %v", err)
|
||||
}
|
||||
confirm, err := readPassword("Confirm password: ")
|
||||
if err != nil {
|
||||
fatalf("read confirm: %v", err)
|
||||
}
|
||||
if password != confirm {
|
||||
fatalf("passwords do not match")
|
||||
}
|
||||
if password == "" {
|
||||
fatalf("password must not be empty")
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(password, auth.DefaultArgonParams())
|
||||
if err != nil {
|
||||
fatalf("hash password: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.UpdatePasswordHash(a.ID, hash); err != nil {
|
||||
fatalf("update password hash: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WriteAuditEvent("account_updated", nil, &a.ID, "", `{"actor":"mciasdb","action":"set_password"}`); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("password updated for account %s\n", a.Username)
|
||||
}
|
||||
|
||||
func (t *tool) accountSetStatus(args []string) {
|
||||
fs := flag.NewFlagSet("account set-status", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
status := fs.String("status", "", "new status: active, inactive, or deleted (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account set-status: --id is required")
|
||||
}
|
||||
if *status == "" {
|
||||
fatalf("account set-status: --status is required")
|
||||
}
|
||||
|
||||
var newStatus model.AccountStatus
|
||||
switch *status {
|
||||
case "active":
|
||||
newStatus = model.AccountStatusActive
|
||||
case "inactive":
|
||||
newStatus = model.AccountStatusInactive
|
||||
case "deleted":
|
||||
newStatus = model.AccountStatusDeleted
|
||||
default:
|
||||
fatalf("account set-status: --status must be active, inactive, or deleted")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.UpdateAccountStatus(a.ID, newStatus); err != nil {
|
||||
fatalf("update account status: %v", err)
|
||||
}
|
||||
|
||||
eventType := "account_updated"
|
||||
if newStatus == model.AccountStatusDeleted {
|
||||
eventType = "account_deleted"
|
||||
}
|
||||
if err := t.db.WriteAuditEvent(eventType, nil, &a.ID, "", fmt.Sprintf(`{"actor":"mciasdb","status":%q}`, *status)); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("account %s status set to %s\n", a.Username, *status)
|
||||
}
|
||||
|
||||
// accountResetTOTP clears TOTP fields for the account, disabling the
|
||||
// TOTP requirement. This is a break-glass operation for locked-out users.
|
||||
func (t *tool) accountResetTOTP(args []string) {
|
||||
fs := flag.NewFlagSet("account reset-totp", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account reset-totp: --id is required")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.ClearTOTP(a.ID); err != nil {
|
||||
fatalf("clear TOTP: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WriteAuditEvent("totp_removed", nil, &a.ID, "", `{"actor":"mciasdb","action":"reset_totp"}`); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("TOTP cleared for account %s\n", a.Username)
|
||||
}
|
||||
|
||||
// readPassword reads a password from the terminal without echo.
|
||||
// Falls back to a regular line read if stdin is not a terminal (e.g. in tests).
|
||||
func readPassword(prompt string) (string, error) {
|
||||
fmt.Fprint(os.Stderr, prompt)
|
||||
fd := int(os.Stdin.Fd()) //nolint:gosec // G115: file descriptors are non-negative and fit in int on all supported platforms
|
||||
if term.IsTerminal(fd) {
|
||||
pw, err := term.ReadPassword(fd)
|
||||
fmt.Fprintln(os.Stderr) // newline after hidden input
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read password from terminal: %w", err)
|
||||
}
|
||||
return string(pw), nil
|
||||
}
|
||||
// Not a terminal: read a plain line (for piped input in tests).
|
||||
var line string
|
||||
_, err := fmt.Fscanln(os.Stdin, &line)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read password: %w", err)
|
||||
}
|
||||
return line, nil
|
||||
}
|
||||
116
cmd/mciasdb/audit.go
Normal file
116
cmd/mciasdb/audit.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
)
|
||||
|
||||
func (t *tool) runAudit(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("audit requires a subcommand: tail, query")
|
||||
}
|
||||
switch args[0] {
|
||||
case "tail":
|
||||
t.auditTail(args[1:])
|
||||
case "query":
|
||||
t.auditQuery(args[1:])
|
||||
default:
|
||||
fatalf("unknown audit subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tool) auditTail(args []string) {
|
||||
fs := flag.NewFlagSet("audit tail", flag.ExitOnError)
|
||||
n := fs.Int("n", 50, "number of events to show")
|
||||
asJSON := fs.Bool("json", false, "output as newline-delimited JSON")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *n <= 0 {
|
||||
fatalf("audit tail: --n must be positive")
|
||||
}
|
||||
|
||||
events, err := t.db.TailAuditEvents(*n)
|
||||
if err != nil {
|
||||
fatalf("tail audit events: %v", err)
|
||||
}
|
||||
|
||||
printAuditEvents(events, *asJSON)
|
||||
}
|
||||
|
||||
func (t *tool) auditQuery(args []string) {
|
||||
fs := flag.NewFlagSet("audit query", flag.ExitOnError)
|
||||
accountUUID := fs.String("account", "", "filter by account UUID (actor or target)")
|
||||
eventType := fs.String("type", "", "filter by event type")
|
||||
sinceStr := fs.String("since", "", "filter events on or after this RFC-3339 timestamp")
|
||||
asJSON := fs.Bool("json", false, "output as newline-delimited JSON")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
p := db.AuditQueryParams{
|
||||
EventType: *eventType,
|
||||
}
|
||||
|
||||
if *accountUUID != "" {
|
||||
a, err := t.db.GetAccountByUUID(*accountUUID)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
p.AccountID = &a.ID
|
||||
}
|
||||
|
||||
if *sinceStr != "" {
|
||||
since, err := time.Parse(time.RFC3339, *sinceStr)
|
||||
if err != nil {
|
||||
fatalf("audit query: --since must be an RFC-3339 timestamp (e.g. 2006-01-02T15:04:05Z): %v", err)
|
||||
}
|
||||
p.Since = &since
|
||||
}
|
||||
|
||||
events, err := t.db.ListAuditEvents(p)
|
||||
if err != nil {
|
||||
fatalf("query audit events: %v", err)
|
||||
}
|
||||
|
||||
printAuditEvents(events, *asJSON)
|
||||
}
|
||||
|
||||
func printAuditEvents(events []*model.AuditEvent, asJSON bool) {
|
||||
if len(events) == 0 {
|
||||
fmt.Println("no audit events found")
|
||||
return
|
||||
}
|
||||
|
||||
if asJSON {
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
for _, ev := range events {
|
||||
if err := enc.Encode(ev); err != nil {
|
||||
fatalf("encode audit event: %v", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("%-20s %-22s %-15s %s\n", "TIME", "EVENT TYPE", "IP", "DETAILS")
|
||||
fmt.Println("────────────────────────────────────────────────────────────────────────────────")
|
||||
for _, ev := range events {
|
||||
ip := ev.IPAddress
|
||||
if ip == "" {
|
||||
ip = "-"
|
||||
}
|
||||
details := ev.Details
|
||||
if details == "" {
|
||||
details = "-"
|
||||
}
|
||||
fmt.Printf("%-20s %-22s %-15s %s\n",
|
||||
ev.EventTime.Format("2006-01-02T15:04:05Z"),
|
||||
ev.EventType,
|
||||
ip,
|
||||
details,
|
||||
)
|
||||
}
|
||||
}
|
||||
242
cmd/mciasdb/main.go
Normal file
242
cmd/mciasdb/main.go
Normal file
@@ -0,0 +1,242 @@
|
||||
// Command mciasdb is the MCIAS database maintenance tool.
|
||||
//
|
||||
// It operates directly on the SQLite file, bypassing the mciassrv API.
|
||||
// Use it for break-glass recovery, offline inspection, schema verification,
|
||||
// and maintenance tasks when the server is unavailable.
|
||||
//
|
||||
// mciasdb requires the same master key configuration as mciassrv (passphrase
|
||||
// environment variable or keyfile) to decrypt secrets at rest.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// mciasdb --config /etc/mcias/mcias.toml <command> [subcommand] [flags]
|
||||
//
|
||||
// Commands:
|
||||
//
|
||||
// schema verify
|
||||
// schema migrate
|
||||
//
|
||||
// account list
|
||||
// account get --id UUID
|
||||
// account create --username NAME --type human|system
|
||||
// account set-password --id UUID
|
||||
// account set-status --id UUID --status active|inactive|deleted
|
||||
// account reset-totp --id UUID
|
||||
//
|
||||
// role list --id UUID
|
||||
// role grant --id UUID --role ROLE
|
||||
// role revoke --id UUID --role ROLE
|
||||
//
|
||||
// token list --id UUID
|
||||
// token revoke --jti JTI
|
||||
// token revoke-all --id UUID
|
||||
//
|
||||
// prune tokens
|
||||
//
|
||||
// audit tail [--n N]
|
||||
// audit query [--account UUID] [--type TYPE] [--since RFC3339] [--json]
|
||||
//
|
||||
// pgcreds get --id UUID
|
||||
// pgcreds set --id UUID --host H --port P --db D --user U
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/config"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
)
|
||||
|
||||
func main() {
|
||||
configPath := flag.String("config", "mcias.toml", "path to TOML configuration file")
|
||||
flag.Usage = usage
|
||||
flag.Parse()
|
||||
|
||||
args := flag.Args()
|
||||
if len(args) == 0 {
|
||||
usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
database, masterKey, err := openDB(*configPath)
|
||||
if err != nil {
|
||||
fatalf("%v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = database.Close()
|
||||
// Zero the master key when done to reduce the window of in-memory exposure.
|
||||
for i := range masterKey {
|
||||
masterKey[i] = 0
|
||||
}
|
||||
}()
|
||||
|
||||
tool := &tool{db: database, masterKey: masterKey}
|
||||
|
||||
command := args[0]
|
||||
subArgs := args[1:]
|
||||
|
||||
switch command {
|
||||
case "schema":
|
||||
tool.runSchema(subArgs)
|
||||
case "account":
|
||||
tool.runAccount(subArgs)
|
||||
case "role":
|
||||
tool.runRole(subArgs)
|
||||
case "token":
|
||||
tool.runToken(subArgs)
|
||||
case "prune":
|
||||
tool.runPrune(subArgs)
|
||||
case "audit":
|
||||
tool.runAudit(subArgs)
|
||||
case "pgcreds":
|
||||
tool.runPGCreds(subArgs)
|
||||
default:
|
||||
fatalf("unknown command %q; run with no args for usage", command)
|
||||
}
|
||||
}
|
||||
|
||||
// tool holds shared state for all subcommand handlers.
|
||||
type tool struct {
|
||||
db *db.DB
|
||||
masterKey []byte
|
||||
}
|
||||
|
||||
// openDB loads the config, derives the master key, opens and migrates the DB.
|
||||
//
|
||||
// Security: Master key derivation uses the same logic as mciassrv so that
|
||||
// the same passphrase always yields the same key and encrypted secrets remain
|
||||
// readable. The passphrase env var is unset immediately after reading.
|
||||
func openDB(configPath string) (*db.DB, []byte, error) {
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
database, err := db.Open(cfg.Database.Path)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("open database %q: %w", cfg.Database.Path, err)
|
||||
}
|
||||
|
||||
if err := db.Migrate(database); err != nil {
|
||||
_ = database.Close()
|
||||
return nil, nil, fmt.Errorf("migrate database: %w", err)
|
||||
}
|
||||
|
||||
masterKey, err := deriveMasterKey(cfg, database)
|
||||
if err != nil {
|
||||
_ = database.Close()
|
||||
return nil, nil, fmt.Errorf("derive master key: %w", err)
|
||||
}
|
||||
|
||||
return database, masterKey, nil
|
||||
}
|
||||
|
||||
// deriveMasterKey derives or loads the AES-256-GCM master key from config,
|
||||
// using identical logic to mciassrv so that encrypted DB secrets are readable.
|
||||
//
|
||||
// Security: Key file must be exactly 32 bytes (AES-256). Passphrase is read
|
||||
// from the environment variable named in cfg.MasterKey.PassphraseEnv and
|
||||
// cleared from the environment immediately after. The Argon2id KDF salt is
|
||||
// loaded from the database; if absent the DB has no encrypted secrets yet.
|
||||
func deriveMasterKey(cfg *config.Config, database *db.DB) ([]byte, error) {
|
||||
if cfg.MasterKey.KeyFile != "" {
|
||||
data, err := os.ReadFile(cfg.MasterKey.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read key file: %w", err)
|
||||
}
|
||||
if len(data) != 32 {
|
||||
return nil, fmt.Errorf("key file must be exactly 32 bytes, got %d", len(data))
|
||||
}
|
||||
key := make([]byte, 32)
|
||||
copy(key, data)
|
||||
for i := range data {
|
||||
data[i] = 0
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
passphrase := os.Getenv(cfg.MasterKey.PassphraseEnv)
|
||||
if passphrase == "" {
|
||||
return nil, fmt.Errorf("environment variable %q is not set or empty", cfg.MasterKey.PassphraseEnv)
|
||||
}
|
||||
_ = os.Unsetenv(cfg.MasterKey.PassphraseEnv)
|
||||
|
||||
salt, err := database.ReadMasterKeySalt()
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
// No salt means the database has no encrypted secrets yet.
|
||||
// Generate a new salt so future writes are consistent.
|
||||
salt, err = crypto.NewSalt()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate master key salt: %w", err)
|
||||
}
|
||||
if err := database.WriteMasterKeySalt(salt); err != nil {
|
||||
return nil, fmt.Errorf("store master key salt: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("read master key salt: %w", err)
|
||||
}
|
||||
|
||||
key, err := crypto.DeriveKey(passphrase, salt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("derive master key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// fatalf prints an error message to stderr and exits with code 1.
|
||||
func fatalf(format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "mciasdb: "+format+"\n", args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// exitCode1 exits with code 1 without printing any message.
|
||||
// Used when the message has already been printed.
|
||||
func exitCode1() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func usage() {
|
||||
fmt.Fprint(os.Stderr, `mciasdb - MCIAS database maintenance tool
|
||||
|
||||
Usage: mciasdb --config PATH <command> [subcommand] [flags]
|
||||
|
||||
Global flags:
|
||||
--config Path to TOML config file (default: mcias.toml)
|
||||
|
||||
Commands:
|
||||
schema verify Check schema version; exit 1 if migrations pending
|
||||
schema migrate Apply any pending schema migrations
|
||||
|
||||
account list List all accounts
|
||||
account get --id UUID
|
||||
account create --username NAME --type human|system
|
||||
account set-password --id UUID (prompts interactively)
|
||||
account set-status --id UUID --status active|inactive|deleted
|
||||
account reset-totp --id UUID
|
||||
|
||||
role list --id UUID
|
||||
role grant --id UUID --role ROLE
|
||||
role revoke --id UUID --role ROLE
|
||||
|
||||
token list --id UUID
|
||||
token revoke --jti JTI
|
||||
token revoke-all --id UUID
|
||||
|
||||
prune tokens Delete expired token_revocation rows
|
||||
|
||||
audit tail [--n N] (default 50)
|
||||
audit query [--account UUID] [--type TYPE] [--since RFC3339] [--json]
|
||||
|
||||
pgcreds get --id UUID
|
||||
pgcreds set --id UUID --host H [--port P] --db D --user U
|
||||
(password is prompted interactively)
|
||||
|
||||
NOTE: mciasdb bypasses the mciassrv API and operates directly on the SQLite
|
||||
file. Use it only when the server is unavailable or for break-glass recovery.
|
||||
All write operations are recorded in the audit log.
|
||||
`)
|
||||
}
|
||||
440
cmd/mciasdb/mciasdb_test.go
Normal file
440
cmd/mciasdb/mciasdb_test.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
)
|
||||
|
||||
// newTestTool creates a tool backed by an in-memory SQLite database with a
|
||||
// freshly generated master key. The database is migrated to the latest schema.
|
||||
func newTestTool(t *testing.T) *tool {
|
||||
t.Helper()
|
||||
database, err := db.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("open test DB: %v", err)
|
||||
}
|
||||
if err := db.Migrate(database); err != nil {
|
||||
t.Fatalf("migrate test DB: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = database.Close() })
|
||||
|
||||
// Use a random 32-byte master key for encryption tests.
|
||||
masterKey, err := crypto.RandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("generate master key: %v", err)
|
||||
}
|
||||
|
||||
return &tool{db: database, masterKey: masterKey}
|
||||
}
|
||||
|
||||
// captureStdout captures stdout output during fn execution.
|
||||
func captureStdout(t *testing.T, fn func()) string {
|
||||
t.Helper()
|
||||
orig := os.Stdout
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create pipe: %v", err)
|
||||
}
|
||||
os.Stdout = w
|
||||
|
||||
fn()
|
||||
|
||||
_ = w.Close()
|
||||
os.Stdout = orig
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.Copy(&buf, r); err != nil {
|
||||
t.Fatalf("copy stdout: %v", err)
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// ---- schema tests ----
|
||||
|
||||
func TestSchemaVerifyUpToDate(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
// Capture output; schemaVerify calls exitCode1 if migrations pending,
|
||||
// but with a freshly migrated DB it should print "up-to-date".
|
||||
out := captureStdout(t, tool.schemaVerify)
|
||||
if !strings.Contains(out, "up-to-date") {
|
||||
t.Errorf("expected 'up-to-date' in output, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- account tests ----
|
||||
|
||||
func TestAccountListEmpty(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
out := captureStdout(t, tool.accountList)
|
||||
if !strings.Contains(out, "no accounts") {
|
||||
t.Errorf("expected 'no accounts' in output, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountCreateAndList(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
// Create via DB method directly (accountCreate reads args via flags so
|
||||
// we test the DB path to avoid os.Exit on parse error).
|
||||
a, err := tool.db.CreateAccount("testuser", model.AccountTypeHuman, "")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
if a.UUID == "" {
|
||||
t.Error("expected UUID to be set")
|
||||
}
|
||||
|
||||
out := captureStdout(t, tool.accountList)
|
||||
if !strings.Contains(out, "testuser") {
|
||||
t.Errorf("expected 'testuser' in list output, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetByUUID(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("getuser", model.AccountTypeSystem, "")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
tool.accountGet([]string{"--id", a.UUID})
|
||||
})
|
||||
if !strings.Contains(out, "getuser") {
|
||||
t.Errorf("expected 'getuser' in get output, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "system") {
|
||||
t.Errorf("expected 'system' in get output, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSetStatus(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("statususer", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
captureStdout(t, func() {
|
||||
tool.accountSetStatus([]string{"--id", a.UUID, "--status", "inactive"})
|
||||
})
|
||||
|
||||
updated, err := tool.db.GetAccountByUUID(a.UUID)
|
||||
if err != nil {
|
||||
t.Fatalf("get account after update: %v", err)
|
||||
}
|
||||
if updated.Status != model.AccountStatusInactive {
|
||||
t.Errorf("expected inactive status, got %s", updated.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResetTOTP(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("totpuser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
// Set TOTP fields.
|
||||
if err := tool.db.SetTOTP(a.ID, []byte("enc"), []byte("nonce")); err != nil {
|
||||
t.Fatalf("set TOTP: %v", err)
|
||||
}
|
||||
|
||||
captureStdout(t, func() {
|
||||
tool.accountResetTOTP([]string{"--id", a.UUID})
|
||||
})
|
||||
|
||||
updated, err := tool.db.GetAccountByUUID(a.UUID)
|
||||
if err != nil {
|
||||
t.Fatalf("get account after reset: %v", err)
|
||||
}
|
||||
if updated.TOTPRequired {
|
||||
t.Error("expected TOTP to be cleared")
|
||||
}
|
||||
if len(updated.TOTPSecretEnc) != 0 {
|
||||
t.Error("expected TOTP secret to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- role tests ----
|
||||
|
||||
func TestRoleGrantAndList(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("roleuser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
captureStdout(t, func() {
|
||||
tool.roleGrant([]string{"--id", a.UUID, "--role", "admin"})
|
||||
})
|
||||
|
||||
roles, err := tool.db.GetRoles(a.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get roles: %v", err)
|
||||
}
|
||||
if len(roles) != 1 || roles[0] != "admin" {
|
||||
t.Errorf("expected [admin], got %v", roles)
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
tool.roleList([]string{"--id", a.UUID})
|
||||
})
|
||||
if !strings.Contains(out, "admin") {
|
||||
t.Errorf("expected 'admin' in role list, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleRevoke(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("revokeuser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
if err := tool.db.GrantRole(a.ID, "editor", nil); err != nil {
|
||||
t.Fatalf("grant role: %v", err)
|
||||
}
|
||||
|
||||
captureStdout(t, func() {
|
||||
tool.roleRevoke([]string{"--id", a.UUID, "--role", "editor"})
|
||||
})
|
||||
|
||||
roles, err := tool.db.GetRoles(a.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get roles after revoke: %v", err)
|
||||
}
|
||||
if len(roles) != 0 {
|
||||
t.Errorf("expected no roles after revoke, got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- token tests ----
|
||||
|
||||
func TestTokenListAndRevoke(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("tokenuser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if err := tool.db.TrackToken("test-jti-1", a.ID, now, now.Add(time.Hour)); err != nil {
|
||||
t.Fatalf("track token: %v", err)
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
tool.tokenList([]string{"--id", a.UUID})
|
||||
})
|
||||
if !strings.Contains(out, "test-jti-1") {
|
||||
t.Errorf("expected jti in token list, got: %s", out)
|
||||
}
|
||||
|
||||
captureStdout(t, func() {
|
||||
tool.tokenRevoke([]string{"--jti", "test-jti-1"})
|
||||
})
|
||||
|
||||
rec, err := tool.db.GetTokenRecord("test-jti-1")
|
||||
if err != nil {
|
||||
t.Fatalf("get token record: %v", err)
|
||||
}
|
||||
if rec.RevokedAt == nil {
|
||||
t.Error("expected token to be revoked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRevokeAll(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("revokealluser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
for i := 0; i < 3; i++ {
|
||||
jti := fmt.Sprintf("bulk-jti-%d", i)
|
||||
if err := tool.db.TrackToken(jti, a.ID, now, now.Add(time.Hour)); err != nil {
|
||||
t.Fatalf("track token %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
captureStdout(t, func() {
|
||||
tool.tokenRevokeAll([]string{"--id", a.UUID})
|
||||
})
|
||||
|
||||
// Verify all tokens are revoked.
|
||||
records, err := tool.db.ListTokensForAccount(a.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("list tokens: %v", err)
|
||||
}
|
||||
for _, r := range records {
|
||||
if r.RevokedAt == nil {
|
||||
t.Errorf("token %s should be revoked", r.JTI)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneTokens(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("pruneuser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
past := time.Now().Add(-2 * time.Hour).UTC()
|
||||
future := time.Now().Add(time.Hour).UTC()
|
||||
|
||||
if err := tool.db.TrackToken("expired-jti", a.ID, past, past.Add(time.Minute)); err != nil {
|
||||
t.Fatalf("track expired token: %v", err)
|
||||
}
|
||||
if err := tool.db.TrackToken("valid-jti", a.ID, future.Add(-time.Minute), future); err != nil {
|
||||
t.Fatalf("track valid token: %v", err)
|
||||
}
|
||||
|
||||
out := captureStdout(t, tool.pruneTokens)
|
||||
if !strings.Contains(out, "1") {
|
||||
t.Errorf("expected 1 pruned in output, got: %s", out)
|
||||
}
|
||||
|
||||
// Valid token should still exist.
|
||||
if _, err := tool.db.GetTokenRecord("valid-jti"); err != nil {
|
||||
t.Errorf("valid token should survive pruning: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- audit tests ----
|
||||
|
||||
func TestAuditTail(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("audituser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := tool.db.WriteAuditEvent(model.EventLoginOK, &a.ID, nil, "", ""); err != nil {
|
||||
t.Fatalf("write audit event: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
tool.auditTail([]string{"--n", "3"})
|
||||
})
|
||||
// Output should contain the event type.
|
||||
if !strings.Contains(out, "login_ok") {
|
||||
t.Errorf("expected login_ok in tail output, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditQueryByType(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("auditquery", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
if err := tool.db.WriteAuditEvent(model.EventLoginOK, &a.ID, nil, "", ""); err != nil {
|
||||
t.Fatalf("write login_ok: %v", err)
|
||||
}
|
||||
if err := tool.db.WriteAuditEvent(model.EventLoginFail, &a.ID, nil, "", ""); err != nil {
|
||||
t.Fatalf("write login_fail: %v", err)
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
tool.auditQuery([]string{"--type", "login_fail"})
|
||||
})
|
||||
if !strings.Contains(out, "login_fail") {
|
||||
t.Errorf("expected login_fail in query output, got: %s", out)
|
||||
}
|
||||
if strings.Contains(out, "login_ok") {
|
||||
t.Errorf("unexpected login_ok in filtered query output, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditQueryJSON(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("jsonaudit", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
if err := tool.db.WriteAuditEvent(model.EventLoginOK, &a.ID, nil, "", ""); err != nil {
|
||||
t.Fatalf("write event: %v", err)
|
||||
}
|
||||
|
||||
out := captureStdout(t, func() {
|
||||
tool.auditQuery([]string{"--json"})
|
||||
})
|
||||
if !strings.Contains(out, `"event_type"`) {
|
||||
t.Errorf("expected JSON output with event_type, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- pgcreds tests ----
|
||||
|
||||
func TestPGCredsSetAndGet(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("pguser", model.AccountTypeSystem, "")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt and store credentials directly using the tool's master key.
|
||||
password := "s3cr3t-pg-pass"
|
||||
enc, nonce, err := crypto.SealAESGCM(tool.masterKey, []byte(password))
|
||||
if err != nil {
|
||||
t.Fatalf("seal pgcreds: %v", err)
|
||||
}
|
||||
if err := tool.db.WritePGCredentials(a.ID, "db.example.com", 5432, "mydb", "myuser", enc, nonce); err != nil {
|
||||
t.Fatalf("write pg credentials: %v", err)
|
||||
}
|
||||
|
||||
// pgCredsGet calls pgCredsGet which calls fatalf if decryption fails.
|
||||
// We test round-trip via DB + crypto directly.
|
||||
cred, err := tool.db.ReadPGCredentials(a.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("read pg credentials: %v", err)
|
||||
}
|
||||
plaintext, err := crypto.OpenAESGCM(tool.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc)
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt pg password: %v", err)
|
||||
}
|
||||
if string(plaintext) != password {
|
||||
t.Errorf("expected password %q, got %q", password, string(plaintext))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPGCredsGetNotFound(t *testing.T) {
|
||||
tool := newTestTool(t)
|
||||
|
||||
a, err := tool.db.CreateAccount("nopguser", model.AccountTypeSystem, "")
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
|
||||
// ReadPGCredentials for account with no credentials should return ErrNotFound.
|
||||
_, err = tool.db.ReadPGCredentials(a.ID)
|
||||
if err == nil {
|
||||
t.Fatal("expected ErrNotFound, got nil")
|
||||
}
|
||||
}
|
||||
127
cmd/mciasdb/pgcreds.go
Normal file
127
cmd/mciasdb/pgcreds.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
)
|
||||
|
||||
func (t *tool) runPGCreds(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("pgcreds requires a subcommand: get, set")
|
||||
}
|
||||
switch args[0] {
|
||||
case "get":
|
||||
t.pgCredsGet(args[1:])
|
||||
case "set":
|
||||
t.pgCredsSet(args[1:])
|
||||
default:
|
||||
fatalf("unknown pgcreds subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
// pgCredsGet decrypts and prints Postgres credentials for an account.
|
||||
// A warning is printed before the output to remind the operator that
|
||||
// the password is sensitive and must not be logged.
|
||||
//
|
||||
// Security: Credentials are decrypted in-memory using the master key and
|
||||
// printed directly to stdout. The operator is responsible for ensuring the
|
||||
// terminal output is not captured in logs. The plaintext password is never
|
||||
// written to disk.
|
||||
func (t *tool) pgCredsGet(args []string) {
|
||||
fs := flag.NewFlagSet("pgcreds get", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("pgcreds get: --id is required")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
cred, err := t.db.ReadPGCredentials(a.ID)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
fatalf("no Postgres credentials stored for account %s", a.Username)
|
||||
}
|
||||
if err != nil {
|
||||
fatalf("read pg credentials: %v", err)
|
||||
}
|
||||
|
||||
// Decrypt the password.
|
||||
// Security: AES-256-GCM decryption; any tampering with the ciphertext or
|
||||
// nonce will cause decryption to fail with an authentication error.
|
||||
plaintext, err := crypto.OpenAESGCM(t.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc)
|
||||
if err != nil {
|
||||
fatalf("decrypt pg password: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WriteAuditEvent("pgcred_accessed", nil, &a.ID, "", `{"actor":"mciasdb"}`); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
// Print warning before sensitive output.
|
||||
fmt.Fprintln(os.Stderr, "WARNING: output below contains a plaintext password. Do not log or share.")
|
||||
fmt.Printf("Host: %s\n", cred.PGHost)
|
||||
fmt.Printf("Port: %d\n", cred.PGPort)
|
||||
fmt.Printf("Database: %s\n", cred.PGDatabase)
|
||||
fmt.Printf("Username: %s\n", cred.PGUsername)
|
||||
fmt.Printf("Password: %s\n", string(plaintext))
|
||||
}
|
||||
|
||||
// pgCredsSet prompts for a Postgres password interactively, encrypts it with
|
||||
// AES-256-GCM, and stores the credentials for the given account.
|
||||
//
|
||||
// Security: No --password flag is provided to prevent the password from
|
||||
// appearing in shell history or process listings. Encryption uses a fresh
|
||||
// random nonce each time.
|
||||
func (t *tool) pgCredsSet(args []string) {
|
||||
fs := flag.NewFlagSet("pgcreds set", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
host := fs.String("host", "", "Postgres host (required)")
|
||||
port := fs.Int("port", 5432, "Postgres port")
|
||||
dbName := fs.String("db", "", "Postgres database name (required)")
|
||||
username := fs.String("user", "", "Postgres username (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" || *host == "" || *dbName == "" || *username == "" {
|
||||
fatalf("pgcreds set: --id, --host, --db, and --user are required")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
password, err := readPassword("Postgres password: ")
|
||||
if err != nil {
|
||||
fatalf("read password: %v", err)
|
||||
}
|
||||
if password == "" {
|
||||
fatalf("password must not be empty")
|
||||
}
|
||||
|
||||
// Encrypt the password at rest.
|
||||
// Security: AES-256-GCM with a fresh random nonce ensures ciphertext
|
||||
// uniqueness even if the same password is stored multiple times.
|
||||
enc, nonce, err := crypto.SealAESGCM(t.masterKey, []byte(password))
|
||||
if err != nil {
|
||||
fatalf("encrypt pg password: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WritePGCredentials(a.ID, *host, *port, *dbName, *username, enc, nonce); err != nil {
|
||||
fatalf("write pg credentials: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WriteAuditEvent("pgcred_updated", nil, &a.ID, "", `{"actor":"mciasdb"}`); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Postgres credentials stored for account %s\n", a.Username)
|
||||
}
|
||||
112
cmd/mciasdb/role.go
Normal file
112
cmd/mciasdb/role.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (t *tool) runRole(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("role requires a subcommand: list, grant, revoke")
|
||||
}
|
||||
switch args[0] {
|
||||
case "list":
|
||||
t.roleList(args[1:])
|
||||
case "grant":
|
||||
t.roleGrant(args[1:])
|
||||
case "revoke":
|
||||
t.roleRevoke(args[1:])
|
||||
default:
|
||||
fatalf("unknown role subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tool) roleList(args []string) {
|
||||
fs := flag.NewFlagSet("role list", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("role list: --id is required")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
roles, err := t.db.GetRoles(a.ID)
|
||||
if err != nil {
|
||||
fatalf("get roles: %v", err)
|
||||
}
|
||||
|
||||
if len(roles) == 0 {
|
||||
fmt.Printf("account %s has no roles\n", a.Username)
|
||||
return
|
||||
}
|
||||
fmt.Printf("roles for %s (%s):\n", a.Username, a.UUID)
|
||||
for _, r := range roles {
|
||||
fmt.Printf(" %s\n", r)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tool) roleGrant(args []string) {
|
||||
fs := flag.NewFlagSet("role grant", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
role := fs.String("role", "", "role to grant (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("role grant: --id is required")
|
||||
}
|
||||
if *role == "" {
|
||||
fatalf("role grant: --role is required")
|
||||
}
|
||||
*role = strings.TrimSpace(*role)
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.GrantRole(a.ID, *role, nil); err != nil {
|
||||
fatalf("grant role: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WriteAuditEvent("role_granted", nil, &a.ID, "", fmt.Sprintf(`{"actor":"mciasdb","role":%q}`, *role)); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("granted role %q to account %s\n", *role, a.Username)
|
||||
}
|
||||
|
||||
func (t *tool) roleRevoke(args []string) {
|
||||
fs := flag.NewFlagSet("role revoke", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
role := fs.String("role", "", "role to revoke (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("role revoke: --id is required")
|
||||
}
|
||||
if *role == "" {
|
||||
fatalf("role revoke: --role is required")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.RevokeRole(a.ID, *role); err != nil {
|
||||
fatalf("revoke role: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WriteAuditEvent("role_revoked", nil, &a.ID, "", fmt.Sprintf(`{"actor":"mciasdb","role":%q}`, *role)); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("revoked role %q from account %s\n", *role, a.Username)
|
||||
}
|
||||
63
cmd/mciasdb/schema.go
Normal file
63
cmd/mciasdb/schema.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
)
|
||||
|
||||
func (t *tool) runSchema(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("schema requires a subcommand: verify, migrate")
|
||||
}
|
||||
switch args[0] {
|
||||
case "verify":
|
||||
t.schemaVerify()
|
||||
case "migrate":
|
||||
t.schemaMigrate()
|
||||
default:
|
||||
fatalf("unknown schema subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
// schemaVerify reports the current schema version and exits 1 if migrations
|
||||
// are pending, 0 if the database is up-to-date.
|
||||
func (t *tool) schemaVerify() {
|
||||
version, err := db.SchemaVersion(t.db)
|
||||
if err != nil {
|
||||
fatalf("get schema version: %v", err)
|
||||
}
|
||||
latest := db.LatestSchemaVersion
|
||||
fmt.Printf("schema version: %d (latest: %d)\n", version, latest)
|
||||
if version < latest {
|
||||
fmt.Printf("%d migration(s) pending\n", latest-version)
|
||||
// Exit 1 to signal that migrations are needed (useful in scripts).
|
||||
// We call os.Exit directly rather than fatalf to avoid printing "mciasdb: ".
|
||||
fmt.Println("run 'mciasdb schema migrate' to apply pending migrations")
|
||||
exitCode1()
|
||||
}
|
||||
fmt.Println("schema is up-to-date")
|
||||
}
|
||||
|
||||
// schemaMigrate applies any pending migrations and reports each one.
|
||||
func (t *tool) schemaMigrate() {
|
||||
before, err := db.SchemaVersion(t.db)
|
||||
if err != nil {
|
||||
fatalf("get schema version: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Migrate(t.db); err != nil {
|
||||
fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
after, err := db.SchemaVersion(t.db)
|
||||
if err != nil {
|
||||
fatalf("get schema version after migrate: %v", err)
|
||||
}
|
||||
|
||||
if before == after {
|
||||
fmt.Println("no migrations needed; schema is already up-to-date")
|
||||
return
|
||||
}
|
||||
fmt.Printf("migrated schema from version %d to %d\n", before, after)
|
||||
}
|
||||
130
cmd/mciasdb/token.go
Normal file
130
cmd/mciasdb/token.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (t *tool) runToken(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("token requires a subcommand: list, revoke, revoke-all")
|
||||
}
|
||||
switch args[0] {
|
||||
case "list":
|
||||
t.tokenList(args[1:])
|
||||
case "revoke":
|
||||
t.tokenRevoke(args[1:])
|
||||
case "revoke-all":
|
||||
t.tokenRevokeAll(args[1:])
|
||||
default:
|
||||
fatalf("unknown token subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tool) runPrune(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("prune requires a subcommand: tokens")
|
||||
}
|
||||
switch args[0] {
|
||||
case "tokens":
|
||||
t.pruneTokens()
|
||||
default:
|
||||
fatalf("unknown prune subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tool) tokenList(args []string) {
|
||||
fs := flag.NewFlagSet("token list", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("token list: --id is required")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
records, err := t.db.ListTokensForAccount(a.ID)
|
||||
if err != nil {
|
||||
fatalf("list tokens: %v", err)
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
fmt.Printf("no token records for account %s\n", a.Username)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("tokens for %s (%s):\n", a.Username, a.UUID)
|
||||
fmt.Printf("%-36s %-20s %-20s %-20s\n", "JTI", "ISSUED AT", "EXPIRES AT", "REVOKED AT")
|
||||
fmt.Println(strings.Repeat("-", 100))
|
||||
for _, r := range records {
|
||||
revokedAt := "-"
|
||||
if r.RevokedAt != nil {
|
||||
revokedAt = r.RevokedAt.Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
fmt.Printf("%-36s %-20s %-20s %-20s\n",
|
||||
r.JTI,
|
||||
r.IssuedAt.Format("2006-01-02T15:04:05Z"),
|
||||
r.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||
revokedAt,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tool) tokenRevoke(args []string) {
|
||||
fs := flag.NewFlagSet("token revoke", flag.ExitOnError)
|
||||
jti := fs.String("jti", "", "JTI of the token to revoke (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *jti == "" {
|
||||
fatalf("token revoke: --jti is required")
|
||||
}
|
||||
|
||||
if err := t.db.RevokeToken(*jti, "mciasdb"); err != nil {
|
||||
fatalf("revoke token: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WriteAuditEvent("token_revoked", nil, nil, "", fmt.Sprintf(`{"actor":"mciasdb","jti":%q}`, *jti)); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("token %s revoked\n", *jti)
|
||||
}
|
||||
|
||||
func (t *tool) tokenRevokeAll(args []string) {
|
||||
fs := flag.NewFlagSet("token revoke-all", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("token revoke-all: --id is required")
|
||||
}
|
||||
|
||||
a, err := t.db.GetAccountByUUID(*id)
|
||||
if err != nil {
|
||||
fatalf("get account: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.RevokeAllUserTokens(a.ID, "mciasdb"); err != nil {
|
||||
fatalf("revoke all tokens: %v", err)
|
||||
}
|
||||
|
||||
if err := t.db.WriteAuditEvent("token_revoked", nil, &a.ID, "", `{"actor":"mciasdb","action":"revoke_all"}`); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: write audit event: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("all active tokens revoked for account %s\n", a.Username)
|
||||
}
|
||||
|
||||
func (t *tool) pruneTokens() {
|
||||
count, err := t.db.PruneExpiredTokens()
|
||||
if err != nil {
|
||||
fatalf("prune expired tokens: %v", err)
|
||||
}
|
||||
fmt.Printf("pruned %d expired token record(s)\n", count)
|
||||
}
|
||||
602
cmd/mciasgrpcctl/main.go
Normal file
602
cmd/mciasgrpcctl/main.go
Normal file
@@ -0,0 +1,602 @@
|
||||
// Command mciasgrpcctl is the MCIAS gRPC admin CLI.
|
||||
//
|
||||
// It connects to a running mciassrv gRPC listener and provides subcommands for
|
||||
// managing accounts, roles, tokens, and Postgres credentials via the gRPC API.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// mciasgrpcctl [global flags] <command> [args]
|
||||
//
|
||||
// Global flags:
|
||||
//
|
||||
// -server gRPC server address (default: localhost:9443)
|
||||
// -token Bearer token for authentication (or set MCIAS_TOKEN env var)
|
||||
// -cacert Path to CA certificate for TLS verification (optional)
|
||||
//
|
||||
// Commands:
|
||||
//
|
||||
// health
|
||||
// pubkey
|
||||
//
|
||||
// account list
|
||||
// account create -username NAME -password PASS [-type human|system]
|
||||
// account get -id UUID
|
||||
// account update -id UUID -status active|inactive
|
||||
// account delete -id UUID
|
||||
//
|
||||
// role list -id UUID
|
||||
// role set -id UUID -roles role1,role2,...
|
||||
//
|
||||
// token validate -token TOKEN
|
||||
// token issue -id UUID
|
||||
// token revoke -jti JTI
|
||||
//
|
||||
// pgcreds get -id UUID
|
||||
// pgcreds set -id UUID -host HOST [-port PORT] -db DB -user USER -password PASS
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Global flags.
|
||||
serverAddr := flag.String("server", "localhost:9443", "gRPC server address (host:port)")
|
||||
tokenFlag := flag.String("token", "", "bearer token (or set MCIAS_TOKEN)")
|
||||
caCert := flag.String("cacert", "", "path to CA certificate for TLS")
|
||||
flag.Usage = usage
|
||||
flag.Parse()
|
||||
|
||||
// Resolve token from flag or environment.
|
||||
bearerToken := *tokenFlag
|
||||
if bearerToken == "" {
|
||||
bearerToken = os.Getenv("MCIAS_TOKEN")
|
||||
}
|
||||
|
||||
args := flag.Args()
|
||||
if len(args) == 0 {
|
||||
usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Build gRPC connection.
|
||||
conn, err := newGRPCConn(*serverAddr, *caCert)
|
||||
if err != nil {
|
||||
fatalf("connect to gRPC server: %v", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
ctl := &controller{
|
||||
conn: conn,
|
||||
token: bearerToken,
|
||||
}
|
||||
|
||||
command := args[0]
|
||||
subArgs := args[1:]
|
||||
|
||||
switch command {
|
||||
case "health":
|
||||
ctl.runHealth()
|
||||
case "pubkey":
|
||||
ctl.runPubKey()
|
||||
case "account":
|
||||
ctl.runAccount(subArgs)
|
||||
case "role":
|
||||
ctl.runRole(subArgs)
|
||||
case "token":
|
||||
ctl.runToken(subArgs)
|
||||
case "pgcreds":
|
||||
ctl.runPGCreds(subArgs)
|
||||
default:
|
||||
fatalf("unknown command %q; run with no args to see usage", command)
|
||||
}
|
||||
}
|
||||
|
||||
// controller holds the shared gRPC connection and token for all subcommands.
|
||||
type controller struct {
|
||||
conn *grpc.ClientConn
|
||||
token string
|
||||
}
|
||||
|
||||
// authCtx returns a context with the Bearer token injected as gRPC metadata.
|
||||
// Security: token is placed in the "authorization" key per the gRPC convention
|
||||
// that mirrors the HTTP Authorization header. Value is never logged.
|
||||
func (c *controller) authCtx() context.Context {
|
||||
ctx := context.Background()
|
||||
if c.token == "" {
|
||||
return ctx
|
||||
}
|
||||
// Security: metadata key "authorization" matches the server-side
|
||||
// extractBearerFromMD expectation; value is "Bearer <token>".
|
||||
return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+c.token)
|
||||
}
|
||||
|
||||
// callCtx returns an authCtx with a 30-second deadline.
|
||||
func (c *controller) callCtx() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(c.authCtx(), 30*time.Second)
|
||||
}
|
||||
|
||||
// ---- health / pubkey ----
|
||||
|
||||
func (c *controller) runHealth() {
|
||||
adminCl := mciasv1.NewAdminServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
resp, err := adminCl.Health(ctx, &mciasv1.HealthRequest{})
|
||||
if err != nil {
|
||||
fatalf("health: %v", err)
|
||||
}
|
||||
printJSON(map[string]string{"status": resp.Status})
|
||||
}
|
||||
|
||||
func (c *controller) runPubKey() {
|
||||
adminCl := mciasv1.NewAdminServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
resp, err := adminCl.GetPublicKey(ctx, &mciasv1.GetPublicKeyRequest{})
|
||||
if err != nil {
|
||||
fatalf("pubkey: %v", err)
|
||||
}
|
||||
printJSON(map[string]string{
|
||||
"kty": resp.Kty,
|
||||
"crv": resp.Crv,
|
||||
"use": resp.Use,
|
||||
"alg": resp.Alg,
|
||||
"x": resp.X,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- account subcommands ----
|
||||
|
||||
func (c *controller) runAccount(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("account requires a subcommand: list, create, get, update, delete")
|
||||
}
|
||||
switch args[0] {
|
||||
case "list":
|
||||
c.accountList()
|
||||
case "create":
|
||||
c.accountCreate(args[1:])
|
||||
case "get":
|
||||
c.accountGet(args[1:])
|
||||
case "update":
|
||||
c.accountUpdate(args[1:])
|
||||
case "delete":
|
||||
c.accountDelete(args[1:])
|
||||
default:
|
||||
fatalf("unknown account subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controller) accountList() {
|
||||
cl := mciasv1.NewAccountServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
resp, err := cl.ListAccounts(ctx, &mciasv1.ListAccountsRequest{})
|
||||
if err != nil {
|
||||
fatalf("account list: %v", err)
|
||||
}
|
||||
printJSON(resp.Accounts)
|
||||
}
|
||||
|
||||
func (c *controller) accountCreate(args []string) {
|
||||
fs := flag.NewFlagSet("account create", flag.ExitOnError)
|
||||
username := fs.String("username", "", "username (required)")
|
||||
password := fs.String("password", "", "password (required for human accounts)")
|
||||
accountType := fs.String("type", "human", "account type: human or system")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *username == "" {
|
||||
fatalf("account create: -username is required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
resp, err := cl.CreateAccount(ctx, &mciasv1.CreateAccountRequest{
|
||||
Username: *username,
|
||||
Password: *password,
|
||||
AccountType: *accountType,
|
||||
})
|
||||
if err != nil {
|
||||
fatalf("account create: %v", err)
|
||||
}
|
||||
printJSON(resp.Account)
|
||||
}
|
||||
|
||||
func (c *controller) accountGet(args []string) {
|
||||
fs := flag.NewFlagSet("account get", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account get: -id is required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
resp, err := cl.GetAccount(ctx, &mciasv1.GetAccountRequest{Id: *id})
|
||||
if err != nil {
|
||||
fatalf("account get: %v", err)
|
||||
}
|
||||
printJSON(resp.Account)
|
||||
}
|
||||
|
||||
func (c *controller) accountUpdate(args []string) {
|
||||
fs := flag.NewFlagSet("account update", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
status := fs.String("status", "", "new status: active or inactive (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account update: -id is required")
|
||||
}
|
||||
if *status == "" {
|
||||
fatalf("account update: -status is required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
_, err := cl.UpdateAccount(ctx, &mciasv1.UpdateAccountRequest{
|
||||
Id: *id,
|
||||
Status: *status,
|
||||
})
|
||||
if err != nil {
|
||||
fatalf("account update: %v", err)
|
||||
}
|
||||
fmt.Println("account updated")
|
||||
}
|
||||
|
||||
func (c *controller) accountDelete(args []string) {
|
||||
fs := flag.NewFlagSet("account delete", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("account delete: -id is required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
_, err := cl.DeleteAccount(ctx, &mciasv1.DeleteAccountRequest{Id: *id})
|
||||
if err != nil {
|
||||
fatalf("account delete: %v", err)
|
||||
}
|
||||
fmt.Println("account deleted")
|
||||
}
|
||||
|
||||
// ---- role subcommands ----
|
||||
|
||||
func (c *controller) runRole(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("role requires a subcommand: list, set")
|
||||
}
|
||||
switch args[0] {
|
||||
case "list":
|
||||
c.roleList(args[1:])
|
||||
case "set":
|
||||
c.roleSet(args[1:])
|
||||
default:
|
||||
fatalf("unknown role subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controller) roleList(args []string) {
|
||||
fs := flag.NewFlagSet("role list", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("role list: -id is required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
resp, err := cl.GetRoles(ctx, &mciasv1.GetRolesRequest{Id: *id})
|
||||
if err != nil {
|
||||
fatalf("role list: %v", err)
|
||||
}
|
||||
printJSON(resp.Roles)
|
||||
}
|
||||
|
||||
func (c *controller) roleSet(args []string) {
|
||||
fs := flag.NewFlagSet("role set", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
rolesFlag := fs.String("roles", "", "comma-separated list of roles")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("role set: -id is required")
|
||||
}
|
||||
|
||||
var roles []string
|
||||
if *rolesFlag != "" {
|
||||
for _, r := range strings.Split(*rolesFlag, ",") {
|
||||
r = strings.TrimSpace(r)
|
||||
if r != "" {
|
||||
roles = append(roles, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
_, err := cl.SetRoles(ctx, &mciasv1.SetRolesRequest{Id: *id, Roles: roles})
|
||||
if err != nil {
|
||||
fatalf("role set: %v", err)
|
||||
}
|
||||
fmt.Printf("roles set: %v\n", roles)
|
||||
}
|
||||
|
||||
// ---- token subcommands ----
|
||||
|
||||
func (c *controller) runToken(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("token requires a subcommand: validate, issue, revoke")
|
||||
}
|
||||
switch args[0] {
|
||||
case "validate":
|
||||
c.tokenValidate(args[1:])
|
||||
case "issue":
|
||||
c.tokenIssue(args[1:])
|
||||
case "revoke":
|
||||
c.tokenRevoke(args[1:])
|
||||
default:
|
||||
fatalf("unknown token subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controller) tokenValidate(args []string) {
|
||||
fs := flag.NewFlagSet("token validate", flag.ExitOnError)
|
||||
tok := fs.String("token", "", "JWT to validate (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *tok == "" {
|
||||
fatalf("token validate: -token is required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewTokenServiceClient(c.conn)
|
||||
// ValidateToken is public — no auth context needed.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := cl.ValidateToken(ctx, &mciasv1.ValidateTokenRequest{Token: *tok})
|
||||
if err != nil {
|
||||
fatalf("token validate: %v", err)
|
||||
}
|
||||
printJSON(map[string]interface{}{
|
||||
"valid": resp.Valid,
|
||||
"subject": resp.Subject,
|
||||
"roles": resp.Roles,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *controller) tokenIssue(args []string) {
|
||||
fs := flag.NewFlagSet("token issue", flag.ExitOnError)
|
||||
id := fs.String("id", "", "system account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("token issue: -id is required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewTokenServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
resp, err := cl.IssueServiceToken(ctx, &mciasv1.IssueServiceTokenRequest{AccountId: *id})
|
||||
if err != nil {
|
||||
fatalf("token issue: %v", err)
|
||||
}
|
||||
printJSON(map[string]string{"token": resp.Token})
|
||||
}
|
||||
|
||||
func (c *controller) tokenRevoke(args []string) {
|
||||
fs := flag.NewFlagSet("token revoke", flag.ExitOnError)
|
||||
jti := fs.String("jti", "", "JTI of the token to revoke (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *jti == "" {
|
||||
fatalf("token revoke: -jti is required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewTokenServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
_, err := cl.RevokeToken(ctx, &mciasv1.RevokeTokenRequest{Jti: *jti})
|
||||
if err != nil {
|
||||
fatalf("token revoke: %v", err)
|
||||
}
|
||||
fmt.Println("token revoked")
|
||||
}
|
||||
|
||||
// ---- pgcreds subcommands ----
|
||||
|
||||
func (c *controller) runPGCreds(args []string) {
|
||||
if len(args) == 0 {
|
||||
fatalf("pgcreds requires a subcommand: get, set")
|
||||
}
|
||||
switch args[0] {
|
||||
case "get":
|
||||
c.pgCredsGet(args[1:])
|
||||
case "set":
|
||||
c.pgCredsSet(args[1:])
|
||||
default:
|
||||
fatalf("unknown pgcreds subcommand %q", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controller) pgCredsGet(args []string) {
|
||||
fs := flag.NewFlagSet("pgcreds get", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" {
|
||||
fatalf("pgcreds get: -id is required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewCredentialServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
resp, err := cl.GetPGCreds(ctx, &mciasv1.GetPGCredsRequest{Id: *id})
|
||||
if err != nil {
|
||||
fatalf("pgcreds get: %v", err)
|
||||
}
|
||||
if resp.Creds == nil {
|
||||
fatalf("pgcreds get: no credentials returned")
|
||||
}
|
||||
printJSON(map[string]interface{}{
|
||||
"host": resp.Creds.Host,
|
||||
"port": resp.Creds.Port,
|
||||
"database": resp.Creds.Database,
|
||||
"username": resp.Creds.Username,
|
||||
"password": resp.Creds.Password,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *controller) pgCredsSet(args []string) {
|
||||
fs := flag.NewFlagSet("pgcreds set", flag.ExitOnError)
|
||||
id := fs.String("id", "", "account UUID (required)")
|
||||
host := fs.String("host", "", "Postgres host (required)")
|
||||
port := fs.Int("port", 5432, "Postgres port")
|
||||
dbName := fs.String("db", "", "Postgres database name (required)")
|
||||
username := fs.String("user", "", "Postgres username (required)")
|
||||
password := fs.String("password", "", "Postgres password (required)")
|
||||
_ = fs.Parse(args)
|
||||
|
||||
if *id == "" || *host == "" || *dbName == "" || *username == "" || *password == "" {
|
||||
fatalf("pgcreds set: -id, -host, -db, -user, and -password are required")
|
||||
}
|
||||
|
||||
cl := mciasv1.NewCredentialServiceClient(c.conn)
|
||||
ctx, cancel := c.callCtx()
|
||||
defer cancel()
|
||||
|
||||
_, err := cl.SetPGCreds(ctx, &mciasv1.SetPGCredsRequest{
|
||||
Id: *id,
|
||||
Creds: &mciasv1.PGCreds{
|
||||
Host: *host,
|
||||
Port: int32(*port),
|
||||
Database: *dbName,
|
||||
Username: *username,
|
||||
Password: *password,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
fatalf("pgcreds set: %v", err)
|
||||
}
|
||||
fmt.Println("credentials stored")
|
||||
}
|
||||
|
||||
// ---- gRPC connection ----
|
||||
|
||||
// newGRPCConn dials the gRPC server with TLS.
|
||||
// If caCertPath is empty, the system CA pool is used.
|
||||
// Security: TLS 1.2+ is enforced by the crypto/tls defaults on the client side.
|
||||
// The connection is insecure-skip-verify-free; operators can supply a custom CA
|
||||
// for self-signed certs without disabling certificate validation.
|
||||
func newGRPCConn(serverAddr, caCertPath string) (*grpc.ClientConn, error) {
|
||||
tlsCfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
if caCertPath != "" {
|
||||
// G304: path comes from a CLI flag supplied by the operator, not
|
||||
// from untrusted input. File inclusion is intentional.
|
||||
pemData, err := os.ReadFile(caCertPath) //nolint:gosec
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read CA cert: %w", err)
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(pemData) {
|
||||
return nil, fmt.Errorf("no valid certificates found in %s", caCertPath)
|
||||
}
|
||||
tlsCfg.RootCAs = pool
|
||||
}
|
||||
|
||||
creds := credentials.NewTLS(tlsCfg)
|
||||
conn, err := grpc.NewClient(serverAddr, grpc.WithTransportCredentials(creds))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", serverAddr, err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
// printJSON pretty-prints a value as JSON to stdout.
|
||||
func printJSON(v interface{}) {
|
||||
enc := json.NewEncoder(os.Stdout)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(v); err != nil {
|
||||
fatalf("encode output: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// fatalf prints an error message to stderr and exits with code 1.
|
||||
func fatalf(format string, args ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "mciasgrpcctl: "+format+"\n", args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func usage() {
|
||||
fmt.Fprintf(os.Stderr, `mciasgrpcctl - MCIAS gRPC admin CLI
|
||||
|
||||
Usage: mciasgrpcctl [global flags] <command> [args]
|
||||
|
||||
Global flags:
|
||||
-server gRPC server address (default: localhost:9443)
|
||||
-token Bearer token (or set MCIAS_TOKEN env var)
|
||||
-cacert Path to CA certificate for TLS verification
|
||||
|
||||
Commands:
|
||||
health
|
||||
pubkey
|
||||
|
||||
account list
|
||||
account create -username NAME -password PASS [-type human|system]
|
||||
account get -id UUID
|
||||
account update -id UUID -status active|inactive
|
||||
account delete -id UUID
|
||||
|
||||
role list -id UUID
|
||||
role set -id UUID -roles role1,role2,...
|
||||
|
||||
token validate -token TOKEN
|
||||
token issue -id UUID
|
||||
token revoke -jti JTI
|
||||
|
||||
pgcreds get -id UUID
|
||||
pgcreds set -id UUID -host HOST [-port PORT] -db DB -user USER -password PASS
|
||||
`)
|
||||
}
|
||||
331
cmd/mciassrv/main.go
Normal file
331
cmd/mciassrv/main.go
Normal file
@@ -0,0 +1,331 @@
|
||||
// Command mciassrv is the MCIAS authentication server.
|
||||
//
|
||||
// It reads a TOML configuration file, derives the master encryption key,
|
||||
// loads or generates the Ed25519 signing key, opens the SQLite database,
|
||||
// runs migrations, and starts an HTTPS listener.
|
||||
// If [server] grpc_addr is set in the config, a gRPC/TLS listener is also
|
||||
// started on that address. Both listeners share the same signing key, DB,
|
||||
// and config. Graceful shutdown drains both within the configured window.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// mciassrv -config /etc/mcias/mcias.toml
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/config"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/grpcserver"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/server"
|
||||
)
|
||||
|
||||
func main() {
|
||||
configPath := flag.String("config", "mcias.toml", "path to TOML configuration file")
|
||||
flag.Parse()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelInfo,
|
||||
}))
|
||||
|
||||
if err := run(*configPath, logger); err != nil {
|
||||
logger.Error("fatal", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func run(configPath string, logger *slog.Logger) error {
|
||||
// Load and validate configuration.
|
||||
cfg, err := config.Load(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
logger.Info("configuration loaded", "listen_addr", cfg.Server.ListenAddr, "grpc_addr", cfg.Server.GRPCAddr)
|
||||
|
||||
// Open and migrate the database first — we need it to load the master key salt.
|
||||
database, err := db.Open(cfg.Database.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open database: %w", err)
|
||||
}
|
||||
defer func() { _ = database.Close() }()
|
||||
|
||||
if err := db.Migrate(database); err != nil {
|
||||
return fmt.Errorf("migrate database: %w", err)
|
||||
}
|
||||
logger.Info("database ready", "path", cfg.Database.Path)
|
||||
|
||||
// Derive or load the master encryption key.
|
||||
// Security: The master key encrypts TOTP secrets, Postgres passwords, and
|
||||
// the signing key at rest. It is derived from a passphrase via Argon2id
|
||||
// (or loaded directly from a key file). The KDF salt is stored in the DB
|
||||
// for stability across restarts. The passphrase env var is cleared after use.
|
||||
masterKey, err := loadMasterKey(cfg, database)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load master key: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
// Zero the master key when done — reduces the window of exposure.
|
||||
for i := range masterKey {
|
||||
masterKey[i] = 0
|
||||
}
|
||||
}()
|
||||
|
||||
// Load or generate the Ed25519 signing key.
|
||||
// Security: The private signing key is stored AES-256-GCM encrypted in the
|
||||
// database. On first run it is generated and stored. The key is decrypted
|
||||
// with the master key each startup.
|
||||
privKey, pubKey, err := loadOrGenerateSigningKey(database, masterKey, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("signing key: %w", err)
|
||||
}
|
||||
|
||||
// Configure TLS. We require TLS 1.2+ and prefer TLS 1.3.
|
||||
// Security: HTTPS/gRPC-TLS is mandatory; no plaintext listener is provided.
|
||||
// The same TLS certificate is used for both REST and gRPC listeners.
|
||||
tlsCfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CurvePreferences: []tls.CurveID{
|
||||
tls.X25519,
|
||||
tls.CurveP256,
|
||||
},
|
||||
}
|
||||
|
||||
// Build the REST handler.
|
||||
restSrv := server.New(database, cfg, privKey, pubKey, masterKey, logger)
|
||||
httpServer := &http.Server{
|
||||
Addr: cfg.Server.ListenAddr,
|
||||
Handler: restSrv.Handler(),
|
||||
TLSConfig: tlsCfg,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Build the gRPC server if grpc_addr is configured.
|
||||
var grpcSrv *grpc.Server
|
||||
var grpcListener net.Listener
|
||||
if cfg.Server.GRPCAddr != "" {
|
||||
// Load TLS credentials for gRPC using the same cert/key as REST.
|
||||
// Security: TLS 1.2 minimum is enforced via tls.Config; no h2c is offered.
|
||||
grpcTLSCreds, err := credentials.NewServerTLSFromFile(cfg.Server.TLSCert, cfg.Server.TLSKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load gRPC TLS credentials: %w", err)
|
||||
}
|
||||
|
||||
grpcSrvImpl := grpcserver.New(database, cfg, privKey, pubKey, masterKey, logger)
|
||||
grpcSrv = grpcSrvImpl.GRPCServer()
|
||||
// Apply TLS to the gRPC server by wrapping options.
|
||||
// We reconstruct the server with TLS credentials since GRPCServer()
|
||||
// returns an already-built server; instead, build with creds directly.
|
||||
// Re-create with TLS option.
|
||||
grpcSrv = rebuildGRPCServerWithTLS(grpcSrvImpl, grpcTLSCreds)
|
||||
|
||||
grpcListener, err = net.Listen("tcp", cfg.Server.GRPCAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gRPC listen: %w", err)
|
||||
}
|
||||
logger.Info("gRPC listener ready", "addr", cfg.Server.GRPCAddr)
|
||||
}
|
||||
|
||||
// Graceful shutdown on SIGINT/SIGTERM.
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
// Start REST listener.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
logger.Info("REST server starting", "addr", cfg.Server.ListenAddr)
|
||||
if err := httpServer.ListenAndServeTLS(cfg.Server.TLSCert, cfg.Server.TLSKey); err != nil {
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
errCh <- fmt.Errorf("REST server: %w", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Start gRPC listener if configured.
|
||||
if grpcSrv != nil && grpcListener != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
logger.Info("gRPC server starting", "addr", cfg.Server.GRPCAddr)
|
||||
if err := grpcSrv.Serve(grpcListener); err != nil {
|
||||
errCh <- fmt.Errorf("gRPC server: %w", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for shutdown signal or a server error.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("shutdown signal received")
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
|
||||
// Graceful drain: give servers up to 15s to finish in-flight requests.
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Error("REST shutdown error", "error", err)
|
||||
}
|
||||
if grpcSrv != nil {
|
||||
grpcSrv.GracefulStop()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Drain any remaining error from startup goroutines.
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// rebuildGRPCServerWithTLS creates a new *grpc.Server with TLS credentials
|
||||
// and re-registers all services from the implementation.
|
||||
// This is needed because grpc.NewServer accepts credentials as an option at
|
||||
// construction time, not after the fact.
|
||||
func rebuildGRPCServerWithTLS(impl *grpcserver.Server, creds credentials.TransportCredentials) *grpc.Server {
|
||||
return impl.GRPCServerWithCreds(creds)
|
||||
}
|
||||
|
||||
// loadMasterKey derives or loads the AES-256-GCM master key from the config.
|
||||
//
|
||||
// Key file mode: reads exactly 32 bytes from a file.
|
||||
//
|
||||
// Passphrase mode: reads the passphrase from the named environment variable,
|
||||
// then immediately clears it from the environment. The Argon2id KDF salt is
|
||||
// stored in the database on first run and retrieved on subsequent runs so that
|
||||
// the same passphrase always yields the same master key.
|
||||
//
|
||||
// Security: The Argon2id parameters used by crypto.DeriveKey exceed OWASP 2023
|
||||
// minimums (time=3, memory=128MiB, threads=4). The salt is 32 random bytes.
|
||||
func loadMasterKey(cfg *config.Config, database *db.DB) ([]byte, error) {
|
||||
if cfg.MasterKey.KeyFile != "" {
|
||||
// Key file mode: file must contain exactly 32 bytes (AES-256).
|
||||
data, err := os.ReadFile(cfg.MasterKey.KeyFile) //nolint:gosec // G304: operator-supplied path
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read key file: %w", err)
|
||||
}
|
||||
if len(data) != 32 {
|
||||
return nil, fmt.Errorf("key file must be exactly 32 bytes, got %d", len(data))
|
||||
}
|
||||
key := make([]byte, 32)
|
||||
copy(key, data)
|
||||
// Zero the file buffer before it can be GC'd.
|
||||
for i := range data {
|
||||
data[i] = 0
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Passphrase mode.
|
||||
passphrase := os.Getenv(cfg.MasterKey.PassphraseEnv)
|
||||
if passphrase == "" {
|
||||
return nil, fmt.Errorf("environment variable %q is not set or empty", cfg.MasterKey.PassphraseEnv)
|
||||
}
|
||||
// Immediately unset the env var so child processes cannot read it.
|
||||
_ = os.Unsetenv(cfg.MasterKey.PassphraseEnv)
|
||||
|
||||
// Retrieve or create the KDF salt.
|
||||
salt, err := database.ReadMasterKeySalt()
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
// First run: generate and persist a new salt.
|
||||
salt, err = crypto.NewSalt()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate master key salt: %w", err)
|
||||
}
|
||||
if err := database.WriteMasterKeySalt(salt); err != nil {
|
||||
return nil, fmt.Errorf("store master key salt: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("read master key salt: %w", err)
|
||||
}
|
||||
|
||||
key, err := crypto.DeriveKey(passphrase, salt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("derive master key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// loadOrGenerateSigningKey loads the Ed25519 signing key from the database
|
||||
// (decrypted with masterKey), or generates and stores a new one on first run.
|
||||
//
|
||||
// Security: The private key is stored AES-256-GCM encrypted. A fresh random
|
||||
// nonce is used for each encryption. The plaintext key only exists in memory
|
||||
// during the process lifetime.
|
||||
func loadOrGenerateSigningKey(database *db.DB, masterKey []byte, logger *slog.Logger) (ed25519.PrivateKey, ed25519.PublicKey, error) {
|
||||
// Try to load existing key.
|
||||
enc, nonce, err := database.ReadServerConfig()
|
||||
if err == nil && enc != nil && nonce != nil {
|
||||
privPEM, err := crypto.OpenAESGCM(masterKey, nonce, enc)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("decrypt signing key: %w", err)
|
||||
}
|
||||
|
||||
priv, err := crypto.ParsePrivateKeyPEM(privPEM)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse signing key PEM: %w", err)
|
||||
}
|
||||
|
||||
// Security: ed25519.PrivateKey.Public() always returns ed25519.PublicKey,
|
||||
// but we use the ok form to make the type assertion explicit and safe.
|
||||
pub, ok := priv.Public().(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("signing key has unexpected public key type")
|
||||
}
|
||||
logger.Info("signing key loaded from database")
|
||||
return priv, pub, nil
|
||||
}
|
||||
|
||||
// First run: generate and store a new signing key.
|
||||
logger.Info("generating new Ed25519 signing key")
|
||||
pub, priv, err := crypto.GenerateEd25519KeyPair()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate signing key: %w", err)
|
||||
}
|
||||
|
||||
privPEM, err := crypto.MarshalPrivateKeyPEM(priv)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("marshal signing key: %w", err)
|
||||
}
|
||||
|
||||
encKey, encNonce, err := crypto.SealAESGCM(masterKey, privPEM)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("encrypt signing key: %w", err)
|
||||
}
|
||||
|
||||
if err := database.WriteServerConfig(encKey, encNonce); err != nil {
|
||||
return nil, nil, fmt.Errorf("store signing key: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("signing key generated and stored")
|
||||
return priv, pub, nil
|
||||
}
|
||||
Reference in New Issue
Block a user