Add Nix flake for mcproxyctl

Vendor dependencies and expose mcproxyctl binary via nix build.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 21:01:29 -07:00
parent 357ad60e42
commit c13c868e77
2463 changed files with 6834069 additions and 0 deletions

View File

@@ -0,0 +1,303 @@
// Package auth provides MCIAS token validation with caching for
// Metacircular services.
//
// Every Metacircular service delegates authentication to MCIAS. This
// package handles the login flow, token validation (with a 30-second
// SHA-256-keyed cache), and logout. It communicates directly with the
// MCIAS REST API.
//
// Security: bearer tokens are never logged or included in error messages.
package auth
import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strings"
"time"
)
const cacheTTL = 30 * time.Second
// Errors returned by the Authenticator.
var (
// ErrInvalidToken indicates the token is expired, revoked, or otherwise
// invalid.
ErrInvalidToken = errors.New("auth: invalid token")
// ErrInvalidCredentials indicates that the username/password combination
// was rejected by MCIAS.
ErrInvalidCredentials = errors.New("auth: invalid credentials")
// ErrForbidden indicates that MCIAS login policy denied access to this
// service (HTTP 403).
ErrForbidden = errors.New("auth: forbidden by policy")
// ErrUnavailable indicates that MCIAS could not be reached.
ErrUnavailable = errors.New("auth: MCIAS unavailable")
)
// Config holds MCIAS connection settings. This matches the standard [mcias]
// TOML section used by all Metacircular services.
type Config struct {
// ServerURL is the base URL of the MCIAS server
// (e.g., "https://mcias.metacircular.net:8443").
ServerURL string `toml:"server_url"`
// CACert is an optional path to a PEM-encoded CA certificate for
// verifying the MCIAS server's TLS certificate.
CACert string `toml:"ca_cert"`
// ServiceName is this service's identity as registered in MCIAS. It is
// sent with every login request so MCIAS can evaluate service-context
// login policy rules.
ServiceName string `toml:"service_name"`
// Tags are sent with every login request. MCIAS evaluates auth:login
// policy against these tags (e.g., ["env:restricted"]).
Tags []string `toml:"tags"`
}
// TokenInfo holds the validated identity of an authenticated caller.
type TokenInfo struct {
// Username is the MCIAS username (the "sub" claim).
Username string
// AccountType is the MCIAS account type: "human" or "system".
// Used by policy engines that need to distinguish interactive users
// from service accounts.
AccountType string
// Roles is the set of MCIAS roles assigned to the account.
Roles []string
// IsAdmin is true if the account has the "admin" role.
IsAdmin bool
}
// Authenticator validates MCIAS bearer tokens with a short-lived cache.
type Authenticator struct {
httpClient *http.Client
baseURL string
serviceName string
tags []string
logger *slog.Logger
cache *validationCache
}
// New creates an Authenticator that talks to the MCIAS server described
// by cfg. TLS 1.3 is required for all HTTPS connections. If cfg.CACert
// is set, that CA certificate is added to the trust pool.
//
// For plain HTTP URLs (used in tests), TLS configuration is skipped.
func New(cfg Config, logger *slog.Logger) (*Authenticator, error) {
if cfg.ServerURL == "" {
return nil, fmt.Errorf("auth: server_url is required")
}
transport := &http.Transport{}
if !strings.HasPrefix(cfg.ServerURL, "http://") {
tlsCfg := &tls.Config{
MinVersion: tls.VersionTLS13,
}
if cfg.CACert != "" {
pem, err := os.ReadFile(cfg.CACert) //nolint:gosec // CA cert path from operator config
if err != nil {
return nil, fmt.Errorf("auth: read CA cert %s: %w", cfg.CACert, err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(pem) {
return nil, fmt.Errorf("auth: no valid certificates in %s", cfg.CACert)
}
tlsCfg.RootCAs = pool
}
transport.TLSClientConfig = tlsCfg
}
return &Authenticator{
httpClient: &http.Client{
Transport: transport,
Timeout: 10 * time.Second,
},
baseURL: strings.TrimRight(cfg.ServerURL, "/"),
serviceName: cfg.ServiceName,
tags: cfg.Tags,
logger: logger,
cache: newCache(cacheTTL),
}, nil
}
// Login authenticates a user against MCIAS and returns a bearer token.
// totpCode may be empty for accounts without TOTP configured.
//
// The service name and tags from Config are included in the login request
// so MCIAS can evaluate service-context login policy.
func (a *Authenticator) Login(username, password, totpCode string) (token string, expiresAt time.Time, err error) {
reqBody := map[string]interface{}{
"username": username,
"password": password,
}
if totpCode != "" {
reqBody["totp_code"] = totpCode
}
if a.serviceName != "" {
reqBody["service_name"] = a.serviceName
}
if len(a.tags) > 0 {
reqBody["tags"] = a.tags
}
var resp struct {
Token string `json:"token"`
ExpiresAt string `json:"expires_at"`
}
status, err := a.doJSON(http.MethodPost, "/v1/auth/login", reqBody, &resp)
if err != nil {
return "", time.Time{}, fmt.Errorf("auth: MCIAS login: %w", ErrUnavailable)
}
switch status {
case http.StatusOK:
// Parse the expiry time.
exp, parseErr := time.Parse(time.RFC3339, resp.ExpiresAt)
if parseErr != nil {
exp = time.Now().Add(1 * time.Hour) // fallback
}
return resp.Token, exp, nil
case http.StatusForbidden:
return "", time.Time{}, ErrForbidden
default:
return "", time.Time{}, ErrInvalidCredentials
}
}
// ValidateToken checks a bearer token against MCIAS. Results are cached
// by the SHA-256 hash of the token for 30 seconds.
//
// Returns ErrInvalidToken if the token is expired, revoked, or otherwise
// not valid.
func (a *Authenticator) ValidateToken(token string) (*TokenInfo, error) {
h := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(h[:])
if info, ok := a.cache.get(tokenHash); ok {
return info, nil
}
var resp struct {
Valid bool `json:"valid"`
Sub string `json:"sub"`
Username string `json:"username"`
AccountType string `json:"account_type"`
Roles []string `json:"roles"`
}
status, err := a.doJSON(http.MethodPost, "/v1/token/validate",
map[string]string{"token": token}, &resp)
if err != nil {
return nil, fmt.Errorf("auth: MCIAS validate: %w", ErrUnavailable)
}
if status != http.StatusOK || !resp.Valid {
return nil, ErrInvalidToken
}
info := &TokenInfo{
Username: resp.Username,
AccountType: resp.AccountType,
Roles: resp.Roles,
IsAdmin: hasRole(resp.Roles, "admin"),
}
if info.Username == "" {
info.Username = resp.Sub
}
a.cache.put(tokenHash, info)
return info, nil
}
// ClearCache removes all cached token validation results. This should be
// called when the service transitions to a state where cached tokens may
// no longer be valid (e.g., Metacrypt sealing).
func (a *Authenticator) ClearCache() {
a.cache.clear()
}
// Logout revokes a token on the MCIAS server.
func (a *Authenticator) Logout(token string) error {
req, err := http.NewRequestWithContext(context.Background(),
http.MethodPost, a.baseURL+"/v1/auth/logout", nil)
if err != nil {
return fmt.Errorf("auth: build logout request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := a.httpClient.Do(req)
if err != nil {
return fmt.Errorf("auth: MCIAS logout: %w", ErrUnavailable)
}
_ = resp.Body.Close()
return nil
}
// doJSON makes a JSON request to the MCIAS server and decodes the response.
// It returns the HTTP status code and any transport error.
func (a *Authenticator) doJSON(method, path string, body, out interface{}) (int, error) {
var reqBody io.Reader
if body != nil {
b, err := json.Marshal(body)
if err != nil {
return 0, fmt.Errorf("marshal request: %w", err)
}
reqBody = bytes.NewReader(b)
}
req, err := http.NewRequestWithContext(context.Background(),
method, a.baseURL+path, reqBody)
if err != nil {
return 0, fmt.Errorf("build request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := a.httpClient.Do(req)
if err != nil {
return 0, err
}
defer func() { _ = resp.Body.Close() }()
if out != nil && resp.StatusCode == http.StatusOK {
respBytes, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return resp.StatusCode, fmt.Errorf("read response: %w", readErr)
}
if len(respBytes) > 0 {
if decErr := json.Unmarshal(respBytes, out); decErr != nil {
return resp.StatusCode, fmt.Errorf("decode response: %w", decErr)
}
}
}
return resp.StatusCode, nil
}
func hasRole(roles []string, target string) bool {
for _, r := range roles {
if r == target {
return true
}
}
return false
}

View File

@@ -0,0 +1,71 @@
package auth
import (
"sync"
"time"
)
// cacheEntry holds a cached TokenInfo and its expiration time.
type cacheEntry struct {
info *TokenInfo
expiresAt time.Time
}
// validationCache provides a concurrency-safe, TTL-based cache for token
// validation results. Tokens are keyed by their SHA-256 hex digest.
type validationCache struct {
mu sync.RWMutex
entries map[string]cacheEntry
ttl time.Duration
now func() time.Time // injectable clock for testing
}
// newCache creates a validationCache with the given TTL.
func newCache(ttl time.Duration) *validationCache {
return &validationCache{
entries: make(map[string]cacheEntry),
ttl: ttl,
now: time.Now,
}
}
// get returns cached TokenInfo for the given token hash, or false if
// the entry is missing or expired. Expired entries are lazily evicted.
func (c *validationCache) get(tokenHash string) (*TokenInfo, bool) {
c.mu.RLock()
entry, ok := c.entries[tokenHash]
c.mu.RUnlock()
if !ok {
return nil, false
}
if c.now().After(entry.expiresAt) {
// Lazy evict the expired entry.
c.mu.Lock()
if e, exists := c.entries[tokenHash]; exists && c.now().After(e.expiresAt) {
delete(c.entries, tokenHash)
}
c.mu.Unlock()
return nil, false
}
return entry.info, true
}
// clear removes all entries from the cache.
func (c *validationCache) clear() {
c.mu.Lock()
c.entries = make(map[string]cacheEntry)
c.mu.Unlock()
}
// put stores TokenInfo in the cache with an expiration of now + TTL.
func (c *validationCache) put(tokenHash string, info *TokenInfo) {
c.mu.Lock()
c.entries[tokenHash] = cacheEntry{
info: info,
expiresAt: c.now().Add(c.ttl),
}
c.mu.Unlock()
}

View File

@@ -0,0 +1,19 @@
package auth
import "context"
// contextKey is an unexported type used as the context key for TokenInfo,
// preventing collisions with keys from other packages.
type contextKey struct{}
// ContextWithTokenInfo returns a new context carrying the given TokenInfo.
func ContextWithTokenInfo(ctx context.Context, info *TokenInfo) context.Context {
return context.WithValue(ctx, contextKey{}, info)
}
// TokenInfoFromContext extracts TokenInfo from the context. It returns nil
// if no TokenInfo is present.
func TokenInfoFromContext(ctx context.Context) *TokenInfo {
info, _ := ctx.Value(contextKey{}).(*TokenInfo)
return info
}

View File

@@ -0,0 +1,307 @@
// Package config provides TOML configuration loading with environment
// variable overrides for Metacircular services.
//
// Services define their own config struct embedding [Base], which provides
// the standard sections (Server, Database, MCIAS, Log). Use [Load] to
// parse a TOML file, apply environment overrides, set defaults, and
// validate required fields.
//
// # Duration fields
//
// Timeout fields in [ServerConfig] use the [Duration] type rather than
// [time.Duration] because go-toml v2 does not natively decode strings
// (e.g., "30s") into time.Duration. Access the underlying value via
// the embedded field:
//
// cfg.Server.ReadTimeout.Duration // time.Duration
//
// In TOML files, durations are written as Go duration strings:
//
// read_timeout = "30s"
// idle_timeout = "2m"
//
// Environment variable overrides also use this format:
//
// MCR_SERVER_READ_TIMEOUT=30s
package config
import (
"fmt"
"os"
"reflect"
"strings"
"time"
"github.com/pelletier/go-toml/v2"
"git.wntrmute.dev/kyle/mcdsl/auth"
)
// Base contains the configuration sections common to all Metacircular
// services. Services embed this in their own config struct and add
// service-specific sections.
//
// Example:
//
// type MyConfig struct {
// config.Base
// MyService MyServiceSection `toml:"my_service"`
// }
type Base struct {
Server ServerConfig `toml:"server"`
Database DatabaseConfig `toml:"database"`
MCIAS auth.Config `toml:"mcias"`
Log LogConfig `toml:"log"`
}
// ServerConfig holds TLS server settings.
type ServerConfig struct {
// ListenAddr is the HTTPS listen address (e.g., ":8443"). Required.
ListenAddr string `toml:"listen_addr"`
// GRPCAddr is the gRPC listen address (e.g., ":9443"). Optional;
// gRPC is disabled if empty.
GRPCAddr string `toml:"grpc_addr"`
// TLSCert is the path to the TLS certificate file (PEM). Required.
TLSCert string `toml:"tls_cert"`
// TLSKey is the path to the TLS private key file (PEM). Required.
TLSKey string `toml:"tls_key"`
// ReadTimeout is the maximum duration for reading the entire request.
// Defaults to 30s.
ReadTimeout Duration `toml:"read_timeout"`
// WriteTimeout is the maximum duration before timing out writes.
// Defaults to 30s.
WriteTimeout Duration `toml:"write_timeout"`
// IdleTimeout is the maximum time to wait for the next request on
// a keep-alive connection. Defaults to 120s.
IdleTimeout Duration `toml:"idle_timeout"`
// ShutdownTimeout is the maximum time to wait for in-flight requests
// to drain during graceful shutdown. Defaults to 60s.
ShutdownTimeout Duration `toml:"shutdown_timeout"`
}
// DatabaseConfig holds SQLite database settings.
type DatabaseConfig struct {
// Path is the path to the SQLite database file. Required.
Path string `toml:"path"`
}
// LogConfig holds logging settings.
type LogConfig struct {
// Level is the log level (debug, info, warn, error). Defaults to "info".
Level string `toml:"level"`
}
// WebConfig holds web UI server settings. This is not part of Base because
// not all services have a web UI — services that do can add it to their
// own config struct.
type WebConfig struct {
// ListenAddr is the web UI listen address (e.g., "127.0.0.1:8080").
ListenAddr string `toml:"listen_addr"`
// GRPCAddr is the gRPC address of the API server that the web UI
// connects to.
GRPCAddr string `toml:"grpc_addr"`
// CACert is an optional CA certificate for verifying the API server's
// TLS certificate.
CACert string `toml:"ca_cert"`
}
// Validator is an optional interface that config structs can implement
// to add service-specific validation. If the config type implements
// Validator, its Validate method is called after defaults and env
// overrides are applied.
type Validator interface {
Validate() error
}
// Load reads a TOML config file at path, applies environment variable
// overrides using envPrefix (e.g., "MCR" maps MCR_SERVER_LISTEN_ADDR to
// Server.ListenAddr), sets defaults for unset optional fields, and
// validates required fields.
//
// If T implements [Validator], its Validate method is called after all
// other processing.
func Load[T any](path string, envPrefix string) (*T, error) {
data, err := os.ReadFile(path) //nolint:gosec // config path is operator-supplied
if err != nil {
return nil, fmt.Errorf("config: read %s: %w", path, err)
}
var cfg T
if err := toml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("config: parse %s: %w", path, err)
}
if envPrefix != "" {
applyEnvToStruct(reflect.ValueOf(&cfg).Elem(), envPrefix)
}
applyBaseDefaults(&cfg)
if err := validateBase(&cfg); err != nil {
return nil, err
}
if v, ok := any(&cfg).(Validator); ok {
if err := v.Validate(); err != nil {
return nil, fmt.Errorf("config: %w", err)
}
}
return &cfg, nil
}
// applyBaseDefaults sets defaults on the embedded Base struct if present.
func applyBaseDefaults(cfg any) {
base := findBase(cfg)
if base == nil {
return
}
if base.Server.ReadTimeout.Duration == 0 {
base.Server.ReadTimeout.Duration = 30 * time.Second
}
if base.Server.WriteTimeout.Duration == 0 {
base.Server.WriteTimeout.Duration = 30 * time.Second
}
if base.Server.IdleTimeout.Duration == 0 {
base.Server.IdleTimeout.Duration = 120 * time.Second
}
if base.Server.ShutdownTimeout.Duration == 0 {
base.Server.ShutdownTimeout.Duration = 60 * time.Second
}
if base.Log.Level == "" {
base.Log.Level = "info"
}
}
// validateBase checks required fields on the embedded Base struct if present.
func validateBase(cfg any) error {
base := findBase(cfg)
if base == nil {
return nil
}
required := []struct {
name string
value string
}{
{"server.listen_addr", base.Server.ListenAddr},
{"server.tls_cert", base.Server.TLSCert},
{"server.tls_key", base.Server.TLSKey},
}
for _, r := range required {
if r.value == "" {
return fmt.Errorf("config: required field %q is missing", r.name)
}
}
return nil
}
// findBase returns a pointer to the embedded Base struct, or nil if the
// config type does not embed Base.
func findBase(cfg any) *Base {
v := reflect.ValueOf(cfg)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil
}
// Check if cfg *is* a Base.
if b, ok := v.Addr().Interface().(*Base); ok {
return b
}
// Check embedded fields.
t := v.Type()
for i := range t.NumField() {
field := t.Field(i)
if field.Anonymous && field.Type == reflect.TypeOf(Base{}) {
b, ok := v.Field(i).Addr().Interface().(*Base)
if ok {
return b
}
}
}
return nil
}
// applyEnvToStruct recursively walks a struct and overrides field values
// from environment variables. The env variable name is built from the
// prefix and the toml tag: PREFIX_SECTION_FIELD (uppercased).
//
// Supported field types: string, time.Duration (as int64), []string
// (comma-separated), bool, int.
func applyEnvToStruct(v reflect.Value, prefix string) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
t := v.Type()
for i := range t.NumField() {
field := t.Field(i)
fv := v.Field(i)
// For anonymous (embedded) fields, recurse with the same prefix.
if field.Anonymous {
applyEnvToStruct(fv, prefix)
continue
}
tag := field.Tag.Get("toml")
if tag == "" || tag == "-" {
continue
}
envKey := prefix + "_" + strings.ToUpper(tag)
// Handle Duration wrapper before generic struct recursion.
if field.Type == reflect.TypeOf(Duration{}) {
envVal, ok := os.LookupEnv(envKey)
if ok {
d, parseErr := time.ParseDuration(envVal)
if parseErr == nil {
fv.Set(reflect.ValueOf(Duration{d}))
}
}
continue
}
if field.Type.Kind() == reflect.Struct {
applyEnvToStruct(fv, envKey)
continue
}
envVal, ok := os.LookupEnv(envKey)
if !ok {
continue
}
switch fv.Kind() {
case reflect.String:
fv.SetString(envVal)
case reflect.Bool:
fv.SetBool(envVal == "true" || envVal == "1")
case reflect.Slice:
if field.Type.Elem().Kind() == reflect.String {
parts := strings.Split(envVal, ",")
for j := range parts {
parts[j] = strings.TrimSpace(parts[j])
}
fv.Set(reflect.ValueOf(parts))
}
}
}
}

View File

@@ -0,0 +1,37 @@
package config
import (
"fmt"
"time"
)
// Duration is a [time.Duration] that can be unmarshalled from a TOML string
// (e.g., "30s", "5m"). go-toml v2 does not natively decode strings into
// time.Duration, so this wrapper implements [encoding.TextUnmarshaler].
//
// Access the underlying time.Duration via the embedded field:
//
// cfg.Server.ReadTimeout.Duration // time.Duration value
//
// Duration values work directly with time functions that accept
// time.Duration because of the embedding:
//
// time.After(cfg.Server.ReadTimeout.Duration)
type Duration struct {
time.Duration
}
// UnmarshalText implements encoding.TextUnmarshaler for TOML string decoding.
func (d *Duration) UnmarshalText(text []byte) error {
parsed, err := time.ParseDuration(string(text))
if err != nil {
return fmt.Errorf("invalid duration %q: %w", string(text), err)
}
d.Duration = parsed
return nil
}
// MarshalText implements encoding.TextMarshaler for TOML string encoding.
func (d Duration) MarshalText() ([]byte, error) {
return []byte(d.String()), nil
}

View File

@@ -0,0 +1,181 @@
// Package db provides SQLite database setup, migrations, and snapshots
// for Metacircular services.
//
// All databases are opened with the standard Metacircular pragmas (WAL mode,
// foreign keys, busy timeout) and restrictive file permissions (0600).
package db
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"time"
_ "modernc.org/sqlite" // SQLite driver (pure Go, no CGo).
)
// Open opens or creates a SQLite database at path with the standard
// Metacircular pragmas:
//
// PRAGMA journal_mode = WAL;
// PRAGMA foreign_keys = ON;
// PRAGMA busy_timeout = 5000;
//
// The file is created with 0600 permissions (owner read/write only).
// The parent directory is created if it does not exist.
//
// Open returns a standard [*sql.DB] — no wrapper types. Services use it
// directly with database/sql.
func Open(path string) (*sql.DB, error) {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, fmt.Errorf("db: create directory %s: %w", dir, err)
}
// Pre-create the file with restrictive permissions if it does not exist.
if _, err := os.Stat(path); os.IsNotExist(err) {
f, createErr := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600) //nolint:gosec // path is caller-provided config, not user input
if createErr != nil {
return nil, fmt.Errorf("db: create file %s: %w", path, createErr)
}
_ = f.Close()
}
database, err := sql.Open("sqlite", path)
if err != nil {
return nil, fmt.Errorf("db: open %s: %w", path, err)
}
pragmas := []string{
"PRAGMA journal_mode = WAL",
"PRAGMA foreign_keys = ON",
"PRAGMA busy_timeout = 5000",
}
for _, p := range pragmas {
if _, execErr := database.Exec(p); execErr != nil {
_ = database.Close()
return nil, fmt.Errorf("db: %s: %w", p, execErr)
}
}
// Ensure permissions are correct even if the file already existed.
if err := os.Chmod(path, 0600); err != nil {
_ = database.Close()
return nil, fmt.Errorf("db: chmod %s: %w", path, err)
}
return database, nil
}
// Migration is a numbered, named schema change. Services define their
// migrations as a []Migration slice — the slice is the schema history.
type Migration struct {
// Version is the migration number. Must be unique and should be
// sequential starting from 1.
Version int
// Name is a short human-readable description (e.g., "initial schema").
Name string
// SQL is the DDL/DML to execute. Multiple statements are allowed
// (separated by semicolons). Each migration runs in a transaction.
SQL string
}
// Migrate applies all pending migrations from the given slice. It creates
// the schema_migrations tracking table if it does not exist.
//
// Each migration runs in its own transaction. Already-applied migrations
// (identified by version number) are skipped. Timestamps are stored as
// RFC 3339 UTC.
func Migrate(database *sql.DB, migrations []Migration) error {
_, err := database.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
name TEXT NOT NULL DEFAULT '',
applied_at TEXT NOT NULL DEFAULT ''
)`)
if err != nil {
return fmt.Errorf("db: create schema_migrations: %w", err)
}
for _, m := range migrations {
applied, checkErr := migrationApplied(database, m.Version)
if checkErr != nil {
return checkErr
}
if applied {
continue
}
tx, txErr := database.Begin()
if txErr != nil {
return fmt.Errorf("db: begin migration %d (%s): %w", m.Version, m.Name, txErr)
}
if _, execErr := tx.Exec(m.SQL); execErr != nil {
_ = tx.Rollback()
return fmt.Errorf("db: migration %d (%s): %w", m.Version, m.Name, execErr)
}
now := time.Now().UTC().Format(time.RFC3339)
if _, execErr := tx.Exec(
`INSERT INTO schema_migrations (version, name, applied_at) VALUES (?, ?, ?)`,
m.Version, m.Name, now,
); execErr != nil {
_ = tx.Rollback()
return fmt.Errorf("db: record migration %d: %w", m.Version, execErr)
}
if commitErr := tx.Commit(); commitErr != nil {
return fmt.Errorf("db: commit migration %d: %w", m.Version, commitErr)
}
}
return nil
}
// SchemaVersion returns the highest applied migration version, or 0 if
// no migrations have been applied.
func SchemaVersion(database *sql.DB) (int, error) {
var version sql.NullInt64
err := database.QueryRow(`SELECT MAX(version) FROM schema_migrations`).Scan(&version)
if err != nil {
return 0, fmt.Errorf("db: schema version: %w", err)
}
if !version.Valid {
return 0, nil
}
return int(version.Int64), nil
}
// Snapshot creates a consistent backup of the database at destPath using
// SQLite's VACUUM INTO. The destination file is created with 0600
// permissions.
func Snapshot(database *sql.DB, destPath string) error {
dir := filepath.Dir(destPath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("db: create snapshot directory %s: %w", dir, err)
}
if _, err := database.Exec("VACUUM INTO ?", destPath); err != nil {
return fmt.Errorf("db: snapshot: %w", err)
}
if err := os.Chmod(destPath, 0600); err != nil {
return fmt.Errorf("db: chmod snapshot %s: %w", destPath, err)
}
return nil
}
func migrationApplied(database *sql.DB, version int) (bool, error) {
var count int
err := database.QueryRow(
`SELECT COUNT(*) FROM schema_migrations WHERE version = ?`, version,
).Scan(&count)
if err != nil {
return false, fmt.Errorf("db: check migration %d: %w", version, err)
}
return count > 0, nil
}