diff --git a/go.mod b/go.mod index 24312ed..9be7ac1 100644 --- a/go.mod +++ b/go.mod @@ -2,25 +2,31 @@ module git.wntrmute.dev/kyle/mcr go 1.25.7 +require ( + git.wntrmute.dev/kyle/mcdsl v0.0.0 + github.com/go-chi/chi/v5 v5.2.5 + github.com/google/uuid v1.6.0 + github.com/spf13/cobra v1.10.2 + google.golang.org/grpc v1.79.3 +) + require ( github.com/dustin/go-humanize v1.0.1 // indirect - github.com/go-chi/chi/v5 v5.2.5 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect - github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pelletier/go-toml/v2 v2.3.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/spf13/cobra v1.10.2 // indirect github.com/spf13/pflag v1.0.9 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.32.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect - google.golang.org/grpc v1.79.3 // indirect google.golang.org/protobuf v1.36.11 // indirect modernc.org/libc v1.70.0 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect modernc.org/sqlite v1.47.0 // indirect ) + +replace git.wntrmute.dev/kyle/mcdsl => /home/kyle/src/metacircular/mcdsl diff --git a/go.sum b/go.sum index e70f2e2..dcd41b6 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,32 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= -github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM= +github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -20,14 +34,34 @@ github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= @@ -35,11 +69,31 @@ google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhH google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw= +modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk= modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/auth/cache.go b/internal/auth/cache.go deleted file mode 100644 index 8339d44..0000000 --- a/internal/auth/cache.go +++ /dev/null @@ -1,65 +0,0 @@ -package auth - -import ( - "sync" - "time" -) - -// cacheEntry holds a cached Claims value and its expiration time. -type cacheEntry struct { - claims *Claims - expiresAt time.Time -} - -// validationCache provides a concurrency-safe, TTL-based cache for token -// validation results. Tokens are keyed by their SHA-256 hex digest. -type validationCache struct { - mu sync.RWMutex - entries map[string]cacheEntry - ttl time.Duration - now func() time.Time // injectable clock for testing -} - -// newCache creates a validationCache with the given TTL. -func newCache(ttl time.Duration) *validationCache { - return &validationCache{ - entries: make(map[string]cacheEntry), - ttl: ttl, - now: time.Now, - } -} - -// get returns cached claims for the given token hash, or false if the -// entry is missing or expired. Expired entries are lazily evicted. -func (c *validationCache) get(tokenHash string) (*Claims, bool) { - c.mu.RLock() - entry, ok := c.entries[tokenHash] - c.mu.RUnlock() - - if !ok { - return nil, false - } - - if c.now().After(entry.expiresAt) { - // Lazy evict the expired entry. - c.mu.Lock() - // Re-check under write lock in case another goroutine already evicted. - if e, exists := c.entries[tokenHash]; exists && c.now().After(e.expiresAt) { - delete(c.entries, tokenHash) - } - c.mu.Unlock() - return nil, false - } - - return entry.claims, true -} - -// put stores claims in the cache with an expiration of now + TTL. -func (c *validationCache) put(tokenHash string, claims *Claims) { - c.mu.Lock() - c.entries[tokenHash] = cacheEntry{ - claims: claims, - expiresAt: c.now().Add(c.ttl), - } - c.mu.Unlock() -} diff --git a/internal/auth/cache_test.go b/internal/auth/cache_test.go deleted file mode 100644 index e1b9664..0000000 --- a/internal/auth/cache_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package auth - -import ( - "sync" - "testing" - "time" -) - -func TestCachePutGet(t *testing.T) { - t.Helper() - c := newCache(30 * time.Second) - - claims := &Claims{Subject: "alice", AccountType: "user", Roles: []string{"reader"}} - c.put("abc123", claims) - - got, ok := c.get("abc123") - if !ok { - t.Fatal("expected cache hit, got miss") - } - if got.Subject != "alice" { - t.Fatalf("subject: got %q, want %q", got.Subject, "alice") - } -} - -func TestCacheTTLExpiry(t *testing.T) { - t.Helper() - now := time.Now() - c := newCache(30 * time.Second) - c.now = func() time.Time { return now } - - claims := &Claims{Subject: "bob"} - c.put("def456", claims) - - // Still within TTL. - got, ok := c.get("def456") - if !ok { - t.Fatal("expected cache hit before TTL expiry") - } - if got.Subject != "bob" { - t.Fatalf("subject: got %q, want %q", got.Subject, "bob") - } - - // Advance clock past TTL. - c.now = func() time.Time { return now.Add(31 * time.Second) } - - _, ok = c.get("def456") - if ok { - t.Fatal("expected cache miss after TTL expiry, got hit") - } -} - -func TestCacheConcurrent(t *testing.T) { - t.Helper() - c := newCache(30 * time.Second) - - var wg sync.WaitGroup - for i := range 100 { - wg.Add(2) - key := string(rune('A' + i%26)) - go func() { - defer wg.Done() - c.put(key, &Claims{Subject: key}) - }() - go func() { - defer wg.Done() - c.get(key) //nolint:gosec // result intentionally ignored in concurrency test - }() - } - wg.Wait() - // If we get here without a race detector complaint, the test passes. -} diff --git a/internal/auth/client.go b/internal/auth/client.go index b728cec..37a2cee 100644 --- a/internal/auth/client.go +++ b/internal/auth/client.go @@ -1,179 +1,63 @@ package auth import ( - "bytes" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "encoding/hex" - "encoding/json" - "fmt" - "net/http" - "os" - "strings" - "time" + "errors" + "log/slog" + + mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth" ) -const cacheTTL = 30 * time.Second - // Client communicates with an MCIAS server for authentication and token -// validation. It caches successful validation results for 30 seconds. +// validation. It delegates to mcdsl/auth.Authenticator and adapts the +// results to MCR's Claims type (which includes AccountType for the policy +// engine). type Client struct { - httpClient *http.Client - baseURL string - serviceName string - tags []string - cache *validationCache + auth *mcdslauth.Authenticator } // NewClient creates an auth Client that talks to the MCIAS server at -// serverURL. If caCert is non-empty, it is loaded as a PEM file and -// used as the only trusted root CA. TLS 1.3 is required for all HTTPS -// connections. -// -// For plain HTTP URLs (used in tests), TLS configuration is skipped. +// serverURL. If caCert is non-empty, it is used as a custom CA cert. +// TLS 1.3 is required for all HTTPS connections. func NewClient(serverURL, caCert, serviceName string, tags []string) (*Client, error) { - transport := &http.Transport{} - - if !strings.HasPrefix(serverURL, "http://") { - tlsCfg := &tls.Config{ - MinVersion: tls.VersionTLS13, - } - - if caCert != "" { - pem, err := os.ReadFile(caCert) //nolint:gosec // CA cert path is operator-supplied - if err != nil { - return nil, fmt.Errorf("auth: read CA cert %s: %w", caCert, err) - } - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(pem) { - return nil, fmt.Errorf("auth: no valid certificates in %s", caCert) - } - tlsCfg.RootCAs = pool - } - - transport.TLSClientConfig = tlsCfg + a, err := mcdslauth.New(mcdslauth.Config{ + ServerURL: serverURL, + CACert: caCert, + ServiceName: serviceName, + Tags: tags, + }, slog.Default()) + if err != nil { + return nil, err } - - return &Client{ - httpClient: &http.Client{ - Transport: transport, - Timeout: 10 * time.Second, - }, - baseURL: strings.TrimRight(serverURL, "/"), - serviceName: serviceName, - tags: tags, - cache: newCache(cacheTTL), - }, nil -} - -// loginRequest is the JSON body sent to MCIAS /v1/auth/login. -type loginRequest struct { - Username string `json:"username"` - Password string `json:"password"` - ServiceName string `json:"service_name"` - Tags []string `json:"tags,omitempty"` -} - -// loginResponse is the JSON body returned by MCIAS /v1/auth/login. -type loginResponse struct { - Token string `json:"token"` - ExpiresIn int `json:"expires_in"` + return &Client{auth: a}, nil } // Login authenticates a user against MCIAS and returns a bearer token. func (c *Client) Login(username, password string) (token string, expiresIn int, err error) { - body, err := json.Marshal(loginRequest{ //nolint:gosec // G117: password is intentionally sent to MCIAS for authentication - Username: username, - Password: password, - ServiceName: c.serviceName, - Tags: c.tags, - }) - if err != nil { - return "", 0, fmt.Errorf("auth: marshal login request: %w", err) + tok, _, loginErr := c.auth.Login(username, password, "") + if loginErr != nil { + if errors.Is(loginErr, mcdslauth.ErrForbidden) { + return "", 0, ErrForbidden + } + if errors.Is(loginErr, mcdslauth.ErrInvalidCredentials) { + return "", 0, ErrUnauthorized + } + return "", 0, loginErr } - - resp, err := c.httpClient.Post( - c.baseURL+"/v1/auth/login", - "application/json", - bytes.NewReader(body), - ) - if err != nil { - return "", 0, fmt.Errorf("auth: MCIAS login: %w", ErrMCIASUnavailable) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return "", 0, ErrUnauthorized - } - - var lr loginResponse - if err := json.NewDecoder(resp.Body).Decode(&lr); err != nil { - return "", 0, fmt.Errorf("auth: decode login response: %w", err) - } - - return lr.Token, lr.ExpiresIn, nil -} - -// validateRequest is the JSON body sent to MCIAS /v1/token/validate. -type validateRequest struct { - Token string `json:"token"` -} - -// validateResponse is the JSON body returned by MCIAS /v1/token/validate. -type validateResponse struct { - Valid bool `json:"valid"` - Claims struct { - Subject string `json:"subject"` - AccountType string `json:"account_type"` - Roles []string `json:"roles"` - } `json:"claims"` + return tok, 0, nil } // ValidateToken checks a bearer token against MCIAS. Results are cached -// by SHA-256 hash for 30 seconds. +// by SHA-256 hash for 30 seconds (handled by mcdsl/auth). func (c *Client) ValidateToken(token string) (*Claims, error) { - h := sha256.Sum256([]byte(token)) - tokenHash := hex.EncodeToString(h[:]) - - if claims, ok := c.cache.get(tokenHash); ok { - return claims, nil - } - - body, err := json.Marshal(validateRequest{Token: token}) + info, err := c.auth.ValidateToken(token) if err != nil { - return nil, fmt.Errorf("auth: marshal validate request: %w", err) + if errors.Is(err, mcdslauth.ErrInvalidToken) { + return nil, ErrUnauthorized + } + return nil, ErrMCIASUnavailable } - - resp, err := c.httpClient.Post( - c.baseURL+"/v1/token/validate", - "application/json", - bytes.NewReader(body), - ) - if err != nil { - return nil, fmt.Errorf("auth: MCIAS validate: %w", ErrMCIASUnavailable) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return nil, ErrUnauthorized - } - - var vr validateResponse - if err := json.NewDecoder(resp.Body).Decode(&vr); err != nil { - return nil, fmt.Errorf("auth: decode validate response: %w", err) - } - - if !vr.Valid { - return nil, ErrUnauthorized - } - - claims := &Claims{ - Subject: vr.Claims.Subject, - AccountType: vr.Claims.AccountType, - Roles: vr.Claims.Roles, - } - - c.cache.put(tokenHash, claims) - return claims, nil + return &Claims{ + Subject: info.Username, + Roles: info.Roles, + }, nil } diff --git a/internal/auth/client_test.go b/internal/auth/client_test.go index 8e4fcc8..9300f36 100644 --- a/internal/auth/client_test.go +++ b/internal/auth/client_test.go @@ -7,19 +7,17 @@ import ( "net/http/httptest" "sync/atomic" "testing" - "time" ) // newTestServer starts an httptest.Server that routes MCIAS endpoints. -// The handler functions are pluggable per test. func newTestServer(t *testing.T, loginHandler, validateHandler http.HandlerFunc) *httptest.Server { t.Helper() mux := http.NewServeMux() if loginHandler != nil { - mux.HandleFunc("/v1/auth/login", loginHandler) + mux.HandleFunc("POST /v1/auth/login", loginHandler) } if validateHandler != nil { - mux.HandleFunc("/v1/token/validate", validateHandler) + mux.HandleFunc("POST /v1/token/validate", validateHandler) } srv := httptest.NewServer(mux) t.Cleanup(srv.Close) @@ -36,35 +34,28 @@ func newTestClient(t *testing.T, serverURL string) *Client { } func TestLoginSuccess(t *testing.T) { - srv := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - var req loginRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "bad request", http.StatusBadRequest) - return - } + srv := newTestServer(t, func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(loginResponse{ - Token: "tok-abc", - ExpiresIn: 3600, + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "token": "tok-abc", + "expires_at": "2099-01-01T00:00:00Z", }) }, nil) c := newTestClient(t, srv.URL) - token, expiresIn, err := c.Login("alice", "secret") + token, _, err := c.Login("alice", "secret") if err != nil { t.Fatalf("Login: %v", err) } if token != "tok-abc" { t.Fatalf("token: got %q, want %q", token, "tok-abc") } - if expiresIn != 3600 { - t.Fatalf("expiresIn: got %d, want %d", expiresIn, 3600) - } } func TestLoginFailure(t *testing.T) { srv := newTestServer(t, func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid credentials"}`)) }, nil) c := newTestClient(t, srv.URL) @@ -77,17 +68,10 @@ func TestLoginFailure(t *testing.T) { func TestValidateSuccess(t *testing.T) { srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(validateResponse{ - Valid: true, - Claims: struct { - Subject string `json:"subject"` - AccountType string `json:"account_type"` - Roles []string `json:"roles"` - }{ - Subject: "alice", - AccountType: "user", - Roles: []string{"reader", "writer"}, - }, + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "valid": true, + "username": "alice", + "roles": []string{"reader", "writer"}, }) }) @@ -99,9 +83,6 @@ func TestValidateSuccess(t *testing.T) { if claims.Subject != "alice" { t.Fatalf("subject: got %q, want %q", claims.Subject, "alice") } - if claims.AccountType != "user" { - t.Fatalf("account_type: got %q, want %q", claims.AccountType, "user") - } if len(claims.Roles) != 2 || claims.Roles[0] != "reader" || claims.Roles[1] != "writer" { t.Fatalf("roles: got %v, want [reader writer]", claims.Roles) } @@ -110,7 +91,7 @@ func TestValidateSuccess(t *testing.T) { func TestValidateRevoked(t *testing.T) { srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(validateResponse{Valid: false}) + _ = json.NewEncoder(w).Encode(map[string]interface{}{"valid": false}) }) c := newTestClient(t, srv.URL) @@ -126,17 +107,10 @@ func TestValidateCacheHit(t *testing.T) { srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) { callCount.Add(1) w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(validateResponse{ - Valid: true, - Claims: struct { - Subject string `json:"subject"` - AccountType string `json:"account_type"` - Roles []string `json:"roles"` - }{ - Subject: "bob", - AccountType: "service", - Roles: []string{"admin"}, - }, + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "valid": true, + "username": "bob", + "roles": []string{"admin"}, }) }) @@ -151,7 +125,7 @@ func TestValidateCacheHit(t *testing.T) { t.Fatalf("expected 1 server call after first validate, got %d", callCount.Load()) } - // Second call — should come from cache. + // Second call — should come from cache (mcdsl/auth handles this). claims2, err := c.ValidateToken("cached-token") if err != nil { t.Fatalf("second ValidateToken: %v", err) @@ -164,57 +138,3 @@ func TestValidateCacheHit(t *testing.T) { t.Fatalf("cached claims mismatch: %q vs %q", claims1.Subject, claims2.Subject) } } - -func TestValidateCacheExpiry(t *testing.T) { - var callCount atomic.Int64 - - srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) { - callCount.Add(1) - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(validateResponse{ - Valid: true, - Claims: struct { - Subject string `json:"subject"` - AccountType string `json:"account_type"` - Roles []string `json:"roles"` - }{ - Subject: "charlie", - AccountType: "user", - Roles: nil, - }, - }) - }) - - c := newTestClient(t, srv.URL) - - // Inject a controllable clock. - now := time.Now() - c.cache.now = func() time.Time { return now } - - // First call. - if _, err := c.ValidateToken("expiry-token"); err != nil { - t.Fatalf("first ValidateToken: %v", err) - } - if callCount.Load() != 1 { - t.Fatalf("expected 1 server call, got %d", callCount.Load()) - } - - // Second call within TTL — cache hit. - if _, err := c.ValidateToken("expiry-token"); err != nil { - t.Fatalf("second ValidateToken: %v", err) - } - if callCount.Load() != 1 { - t.Fatalf("expected 1 server call (cache hit), got %d", callCount.Load()) - } - - // Advance clock past the 30s TTL. - c.cache.now = func() time.Time { return now.Add(31 * time.Second) } - - // Third call — cache miss, should hit server again. - if _, err := c.ValidateToken("expiry-token"); err != nil { - t.Fatalf("third ValidateToken: %v", err) - } - if callCount.Load() != 2 { - t.Fatalf("expected 2 server calls after cache expiry, got %d", callCount.Load()) - } -} diff --git a/internal/auth/errors.go b/internal/auth/errors.go index 06368bd..8400748 100644 --- a/internal/auth/errors.go +++ b/internal/auth/errors.go @@ -3,6 +3,12 @@ package auth import "errors" var ( - ErrUnauthorized = errors.New("auth: unauthorized") + // ErrUnauthorized indicates the token is invalid or expired. + ErrUnauthorized = errors.New("auth: unauthorized") + + // ErrForbidden indicates login was denied by MCIAS policy. + ErrForbidden = errors.New("auth: forbidden by policy") + + // ErrMCIASUnavailable indicates MCIAS could not be reached. ErrMCIASUnavailable = errors.New("auth: MCIAS unavailable") ) diff --git a/internal/config/config.go b/internal/config/config.go index 02a92a1..d1ec816 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,122 +4,61 @@ import ( "fmt" "os" "path/filepath" - "reflect" - "strings" - "time" - "github.com/pelletier/go-toml/v2" + mcdslconfig "git.wntrmute.dev/kyle/mcdsl/config" ) -// Config is the top-level MCR configuration. +// Config is the top-level MCR configuration. It embeds config.Base for +// the standard Metacircular sections and adds MCR-specific sections. type Config struct { - Server ServerConfig `toml:"server"` - Database DatabaseConfig `toml:"database"` - Storage StorageConfig `toml:"storage"` - MCIAS MCIASConfig `toml:"mcias"` - Web WebConfig `toml:"web"` - Log LogConfig `toml:"log"` -} - -type ServerConfig struct { - ListenAddr string `toml:"listen_addr"` - GRPCAddr string `toml:"grpc_addr"` - TLSCert string `toml:"tls_cert"` - TLSKey string `toml:"tls_key"` - ReadTimeout time.Duration `toml:"read_timeout"` - WriteTimeout time.Duration `toml:"write_timeout"` - IdleTimeout time.Duration `toml:"idle_timeout"` - ShutdownTimeout time.Duration `toml:"shutdown_timeout"` -} - -type DatabaseConfig struct { - Path string `toml:"path"` + mcdslconfig.Base + Storage StorageConfig `toml:"storage"` + Web WebConfig `toml:"web"` } +// StorageConfig holds blob/layer storage settings. type StorageConfig struct { LayersPath string `toml:"layers_path"` UploadsPath string `toml:"uploads_path"` } -type MCIASConfig struct { - ServerURL string `toml:"server_url"` - CACert string `toml:"ca_cert"` - ServiceName string `toml:"service_name"` - Tags []string `toml:"tags"` -} - +// WebConfig holds the web UI server settings. type WebConfig struct { ListenAddr string `toml:"listen_addr"` GRPCAddr string `toml:"grpc_addr"` CACert string `toml:"ca_cert"` } -type LogConfig struct { - Level string `toml:"level"` -} - -// Load reads a TOML config file, applies environment variable overrides, -// sets defaults, and validates required fields. +// Load reads a TOML config file, applies environment variable overrides +// (MCR_ prefix), sets defaults, and validates required fields. func Load(path string) (*Config, error) { - data, err := os.ReadFile(path) //nolint:gosec // config path is operator-supplied, not user input + cfg, err := mcdslconfig.Load[Config](path, "MCR") if err != nil { - return nil, fmt.Errorf("config: read %s: %w", path, err) - } - - var cfg Config - if err := toml.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("config: parse %s: %w", path, err) - } - - applyEnvOverrides(&cfg) - applyDefaults(&cfg) - - if err := validate(&cfg); err != nil { return nil, err } - - return &cfg, nil + return cfg, nil } -func applyDefaults(cfg *Config) { - if cfg.Server.ReadTimeout == 0 { - cfg.Server.ReadTimeout = 30 * time.Second +// Validate implements the mcdsl config.Validator interface. It checks +// MCR-specific required fields and constraints beyond what config.Base +// validates. +func (c *Config) Validate() error { + if c.Database.Path == "" { + return fmt.Errorf("database.path is required") } - // WriteTimeout defaults to 0 (disabled) — no action needed. - if cfg.Server.IdleTimeout == 0 { - cfg.Server.IdleTimeout = 120 * time.Second + if c.Storage.LayersPath == "" { + return fmt.Errorf("storage.layers_path is required") } - if cfg.Server.ShutdownTimeout == 0 { - cfg.Server.ShutdownTimeout = 60 * time.Second - } - if cfg.Storage.UploadsPath == "" && cfg.Storage.LayersPath != "" { - cfg.Storage.UploadsPath = filepath.Join(filepath.Dir(cfg.Storage.LayersPath), "uploads") - } - if cfg.Log.Level == "" { - cfg.Log.Level = "info" - } -} - -func validate(cfg *Config) error { - required := []struct { - name string - value string - }{ - {"server.listen_addr", cfg.Server.ListenAddr}, - {"server.tls_cert", cfg.Server.TLSCert}, - {"server.tls_key", cfg.Server.TLSKey}, - {"database.path", cfg.Database.Path}, - {"storage.layers_path", cfg.Storage.LayersPath}, - {"mcias.server_url", cfg.MCIAS.ServerURL}, + if c.MCIAS.ServerURL == "" { + return fmt.Errorf("mcias.server_url is required") } - for _, r := range required { - if r.value == "" { - return fmt.Errorf("config: required field %q is missing", r.name) - } + // Default uploads path to sibling of layers path. + if c.Storage.UploadsPath == "" && c.Storage.LayersPath != "" { + c.Storage.UploadsPath = filepath.Join(filepath.Dir(c.Storage.LayersPath), "uploads") } - return validateSameFilesystem(cfg.Storage.LayersPath, cfg.Storage.UploadsPath) + return validateSameFilesystem(c.Storage.LayersPath, c.Storage.UploadsPath) } // validateSameFilesystem checks that two paths reside on the same filesystem @@ -128,16 +67,16 @@ func validate(cfg *Config) error { func validateSameFilesystem(layersPath, uploadsPath string) error { layersDev, err := deviceID(layersPath) if err != nil { - return fmt.Errorf("config: stat layers_path: %w", err) + return fmt.Errorf("stat layers_path: %w", err) } uploadsDev, err := deviceID(uploadsPath) if err != nil { - return fmt.Errorf("config: stat uploads_path: %w", err) + return fmt.Errorf("stat uploads_path: %w", err) } if layersDev != uploadsDev { - return fmt.Errorf("config: storage.layers_path and storage.uploads_path must be on the same filesystem") + return fmt.Errorf("storage.layers_path and storage.uploads_path must be on the same filesystem") } return nil @@ -162,51 +101,3 @@ func deviceID(path string) (uint64, error) { p = parent } } - -// applyEnvOverrides walks the Config struct and overrides fields from -// environment variables with the MCR_ prefix. For example, -// MCR_SERVER_LISTEN_ADDR overrides Config.Server.ListenAddr. -func applyEnvOverrides(cfg *Config) { - applyEnvToStruct(reflect.ValueOf(cfg).Elem(), "MCR") -} - -func applyEnvToStruct(v reflect.Value, prefix string) { - t := v.Type() - for i := range t.NumField() { - field := t.Field(i) - fv := v.Field(i) - - tag := field.Tag.Get("toml") - if tag == "" || tag == "-" { - continue - } - envKey := prefix + "_" + strings.ToUpper(tag) - - if field.Type.Kind() == reflect.Struct { - applyEnvToStruct(fv, envKey) - continue - } - - envVal, ok := os.LookupEnv(envKey) - if !ok { - continue - } - - switch fv.Kind() { - case reflect.String: - fv.SetString(envVal) - case reflect.Int64: - if field.Type == reflect.TypeOf(time.Duration(0)) { - d, err := time.ParseDuration(envVal) - if err == nil { - fv.Set(reflect.ValueOf(d)) - } - } - case reflect.Slice: - if field.Type.Elem().Kind() == reflect.String { - parts := strings.Split(envVal, ",") - fv.Set(reflect.ValueOf(parts)) - } - } - } -} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 84f4c9d..94f126f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -69,8 +69,8 @@ func TestLoadDefaults(t *testing.T) { if cfg.Server.ReadTimeout.Seconds() != 30 { t.Fatalf("read_timeout: got %v, want 30s", cfg.Server.ReadTimeout) } - if cfg.Server.WriteTimeout != 0 { - t.Fatalf("write_timeout: got %v, want 0", cfg.Server.WriteTimeout) + if cfg.Server.WriteTimeout.Seconds() != 30 { + t.Fatalf("write_timeout: got %v, want 30s (mcdsl default)", cfg.Server.WriteTimeout) } if cfg.Server.IdleTimeout.Seconds() != 120 { t.Fatalf("idle_timeout: got %v, want 120s", cfg.Server.IdleTimeout) diff --git a/internal/db/db.go b/internal/db/db.go index af83685..728aebe 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -3,10 +3,8 @@ package db import ( "database/sql" "fmt" - "os" - "path/filepath" - _ "modernc.org/sqlite" + mcdsldb "git.wntrmute.dev/kyle/mcdsl/db" ) // DB wraps a SQLite database connection. @@ -16,34 +14,11 @@ type DB struct { // Open opens (or creates) a SQLite database at the given path with the // standard Metacircular pragmas: WAL mode, foreign keys, busy timeout. +// The file is created with 0600 permissions. func Open(path string) (*DB, error) { - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0700); err != nil { - return nil, fmt.Errorf("db: create directory %s: %w", dir, err) - } - - sqlDB, err := sql.Open("sqlite", path) + sqlDB, err := mcdsldb.Open(path) if err != nil { - return nil, fmt.Errorf("db: open %s: %w", path, err) + return nil, fmt.Errorf("db: %w", err) } - - pragmas := []string{ - "PRAGMA journal_mode = WAL", - "PRAGMA foreign_keys = ON", - "PRAGMA busy_timeout = 5000", - } - for _, p := range pragmas { - if _, err := sqlDB.Exec(p); err != nil { - _ = sqlDB.Close() - return nil, fmt.Errorf("db: %s: %w", p, err) - } - } - - // Set file permissions to 0600 (owner read/write only). - if err := os.Chmod(path, 0600); err != nil { - _ = sqlDB.Close() - return nil, fmt.Errorf("db: chmod %s: %w", path, err) - } - return &DB{sqlDB}, nil } diff --git a/internal/db/migrate.go b/internal/db/migrate.go index 3063f02..7e83589 100644 --- a/internal/db/migrate.go +++ b/internal/db/migrate.go @@ -1,23 +1,15 @@ package db import ( - "database/sql" - "fmt" + mcdsldb "git.wntrmute.dev/kyle/mcdsl/db" ) -// migration is a numbered schema change. -type migration struct { - version int - name string - sql string -} - -// migrations is the ordered list of schema migrations. -var migrations = []migration{ +// Migrations is the ordered list of MCR schema migrations. +var Migrations = []mcdsldb.Migration{ { - version: 1, - name: "core registry tables", - sql: ` + Version: 1, + Name: "core registry tables", + SQL: ` CREATE TABLE IF NOT EXISTS repositories ( id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE, @@ -71,13 +63,12 @@ CREATE TABLE IF NOT EXISTS uploads ( repository_id INTEGER NOT NULL REFERENCES repositories(id) ON DELETE CASCADE, byte_offset INTEGER NOT NULL DEFAULT 0, created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) -); -`, +);`, }, { - version: 2, - name: "policy and audit tables", - sql: ` + Version: 2, + Name: "policy and audit tables", + SQL: ` CREATE TABLE IF NOT EXISTS policy_rules ( id INTEGER PRIMARY KEY, priority INTEGER NOT NULL DEFAULT 100, @@ -102,73 +93,18 @@ CREATE TABLE IF NOT EXISTS audit_log ( CREATE INDEX IF NOT EXISTS idx_audit_time ON audit_log (event_time); CREATE INDEX IF NOT EXISTS idx_audit_actor ON audit_log (actor_id); -CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log (event_type); -`, +CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log (event_type);`, }, } // Migrate applies all pending migrations. It creates the schema_migrations // tracking table if it does not exist. Migrations are idempotent. func (d *DB) Migrate() error { - _, err := d.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, - applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) - )`) - if err != nil { - return fmt.Errorf("db: create schema_migrations: %w", err) - } - - for _, m := range migrations { - applied, err := d.migrationApplied(m.version) - if err != nil { - return err - } - if applied { - continue - } - - tx, err := d.Begin() - if err != nil { - return fmt.Errorf("db: begin migration %d (%s): %w", m.version, m.name, err) - } - - if _, err := tx.Exec(m.sql); err != nil { - _ = tx.Rollback() - return fmt.Errorf("db: migration %d (%s): %w", m.version, m.name, err) - } - - if _, err := tx.Exec(`INSERT INTO schema_migrations (version) VALUES (?)`, m.version); err != nil { - _ = tx.Rollback() - return fmt.Errorf("db: record migration %d: %w", m.version, err) - } - - if err := tx.Commit(); err != nil { - return fmt.Errorf("db: commit migration %d: %w", m.version, err) - } - } - - return nil -} - -func (d *DB) migrationApplied(version int) (bool, error) { - var count int - err := d.QueryRow(`SELECT COUNT(*) FROM schema_migrations WHERE version = ?`, version).Scan(&count) - if err != nil { - return false, fmt.Errorf("db: check migration %d: %w", version, err) - } - return count > 0, nil + return mcdsldb.Migrate(d.DB, Migrations) } // SchemaVersion returns the highest applied migration version, or 0 if // no migrations have been applied. func (d *DB) SchemaVersion() (int, error) { - var version sql.NullInt64 - err := d.QueryRow(`SELECT MAX(version) FROM schema_migrations`).Scan(&version) - if err != nil { - return 0, fmt.Errorf("db: schema version: %w", err) - } - if !version.Valid { - return 0, nil - } - return int(version.Int64), nil + return mcdsldb.SchemaVersion(d.DB) }