P1.2-P1.5: Complete Phase 1 core libraries
Four packages built in parallel: - P1.2 runtime: Container runtime abstraction with podman implementation. Interface (Pull/Run/Stop/Remove/Inspect/List), ContainerSpec/ContainerInfo types, CLI arg building, version extraction from image tags. 2 tests. - P1.3 servicedef: TOML service definition file parsing. Load/Write/LoadAll, validation (required fields, unique component names), proto conversion. 5 tests. - P1.4 config: CLI and agent config loading from TOML. Duration type for time fields, env var overrides (MCP_*/MCP_AGENT_*), required field validation, sensible defaults. 7 tests. - P1.5 auth: MCIAS integration. Token validator with 30s SHA-256 cache, gRPC unary interceptor (admin role enforcement, audit logging), Login/LoadToken/SaveToken for CLI. 9 tests. All packages pass build, vet, lint, and test. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -16,3 +16,4 @@ srv/
|
|||||||
|
|
||||||
# OS
|
# OS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
.claude/
|
||||||
|
|||||||
@@ -8,10 +8,10 @@
|
|||||||
## Phase 1: Core Libraries
|
## Phase 1: Core Libraries
|
||||||
|
|
||||||
- [x] **P1.1** Registry package (`internal/registry/`)
|
- [x] **P1.1** Registry package (`internal/registry/`)
|
||||||
- [ ] **P1.2** Runtime package (`internal/runtime/`)
|
- [x] **P1.2** Runtime package (`internal/runtime/`)
|
||||||
- [ ] **P1.3** Service definition package (`internal/servicedef/`)
|
- [x] **P1.3** Service definition package (`internal/servicedef/`)
|
||||||
- [ ] **P1.4** Config package (`internal/config/`)
|
- [x] **P1.4** Config package (`internal/config/`)
|
||||||
- [ ] **P1.5** Auth package (`internal/auth/`)
|
- [x] **P1.5** Auth package (`internal/auth/`)
|
||||||
|
|
||||||
## Phase 2: Agent
|
## Phase 2: Agent
|
||||||
|
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -3,6 +3,7 @@ module git.wntrmute.dev/kyle/mcp
|
|||||||
go 1.25.7
|
go 1.25.7
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/pelletier/go-toml/v2 v2.3.0
|
||||||
github.com/spf13/cobra v1.10.2
|
github.com/spf13/cobra v1.10.2
|
||||||
google.golang.org/grpc v1.79.3
|
google.golang.org/grpc v1.79.3
|
||||||
google.golang.org/protobuf v1.36.11
|
google.golang.org/protobuf v1.36.11
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -23,6 +23,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
|
|||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
|
|||||||
332
internal/auth/auth.go
Normal file
332
internal/auth/auth.go
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// cacheTTL is the duration cached token validation results are valid.
|
||||||
|
const cacheTTL = 30 * time.Second
|
||||||
|
|
||||||
|
// TokenInfo holds the result of a token validation.
|
||||||
|
type TokenInfo struct {
|
||||||
|
Valid bool `json:"valid"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Roles []string `json:"roles"`
|
||||||
|
AccountType string `json:"account_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRole reports whether the token has the given role.
|
||||||
|
func (t *TokenInfo) HasRole(role string) bool {
|
||||||
|
for _, r := range t.Roles {
|
||||||
|
if r == role {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenValidator validates bearer tokens against MCIAS.
|
||||||
|
type TokenValidator interface {
|
||||||
|
ValidateToken(ctx context.Context, token string) (*TokenInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenCache stores validated token results with a TTL.
|
||||||
|
type tokenCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
entries map[string]cacheEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheEntry struct {
|
||||||
|
info *TokenInfo
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTokenCache() *tokenCache {
|
||||||
|
return &tokenCache{
|
||||||
|
entries: make(map[string]cacheEntry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) get(hash string) (*TokenInfo, bool) {
|
||||||
|
c.mu.RLock()
|
||||||
|
entry, ok := c.entries[hash]
|
||||||
|
c.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if time.Now().After(entry.expiresAt) {
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.entries, hash)
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return entry.info, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) put(hash string, info *TokenInfo) {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.entries[hash] = cacheEntry{
|
||||||
|
info: info,
|
||||||
|
expiresAt: time.Now().Add(cacheTTL),
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHTTPClient creates an HTTP client with TLS 1.3 minimum. If caCertPath
|
||||||
|
// is non-empty, the CA certificate is loaded and added to the root CA pool.
|
||||||
|
func newHTTPClient(caCertPath string) (*http.Client, error) {
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS13,
|
||||||
|
}
|
||||||
|
|
||||||
|
if caCertPath != "" {
|
||||||
|
caCert, err := os.ReadFile(caCertPath) //nolint:gosec // path from trusted config
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read CA cert %q: %w", caCertPath, err)
|
||||||
|
}
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
if !pool.AppendCertsFromPEM(caCert) {
|
||||||
|
return nil, fmt.Errorf("parse CA cert %q: no valid certificates found", caCertPath)
|
||||||
|
}
|
||||||
|
tlsConfig.RootCAs = pool
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCIASValidator validates tokens by calling the MCIAS HTTP endpoint.
|
||||||
|
type MCIASValidator struct {
|
||||||
|
ServerURL string
|
||||||
|
CACertPath string
|
||||||
|
httpClient *http.Client
|
||||||
|
cache *tokenCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMCIASValidator creates a validator that calls MCIAS at the given URL.
|
||||||
|
// If caCertPath is non-empty, the CA certificate is loaded and used for TLS.
|
||||||
|
func NewMCIASValidator(serverURL, caCertPath string) (*MCIASValidator, error) {
|
||||||
|
client, err := newHTTPClient(caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MCIASValidator{
|
||||||
|
ServerURL: strings.TrimRight(serverURL, "/"),
|
||||||
|
CACertPath: caCertPath,
|
||||||
|
httpClient: client,
|
||||||
|
cache: newTokenCache(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenHash(token string) string {
|
||||||
|
h := sha256.Sum256([]byte(token))
|
||||||
|
return hex.EncodeToString(h[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken validates a bearer token against MCIAS. Results are cached
|
||||||
|
// for 30 seconds keyed by the SHA-256 hash of the token.
|
||||||
|
func (v *MCIASValidator) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) {
|
||||||
|
hash := tokenHash(token)
|
||||||
|
|
||||||
|
if info, ok := v.cache.get(hash); ok {
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
url := v.ServerURL + "/v1/token/validate"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create validate request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
|
||||||
|
resp, err := v.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validate token: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read validate response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("validate token: MCIAS returned %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var info TokenInfo
|
||||||
|
if err := json.Unmarshal(body, &info); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse validate response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
v.cache.put(hash, &info)
|
||||||
|
return &info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextKey is an unexported type for context keys in this package.
|
||||||
|
type contextKey struct{}
|
||||||
|
|
||||||
|
// tokenInfoKey is the context key for TokenInfo.
|
||||||
|
var tokenInfoKey = contextKey{}
|
||||||
|
|
||||||
|
// ContextWithTokenInfo returns a new context carrying the given TokenInfo.
|
||||||
|
func ContextWithTokenInfo(ctx context.Context, info *TokenInfo) context.Context {
|
||||||
|
return context.WithValue(ctx, tokenInfoKey, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenInfoFromContext retrieves TokenInfo from the context, or nil if absent.
|
||||||
|
func TokenInfoFromContext(ctx context.Context) *TokenInfo {
|
||||||
|
info, _ := ctx.Value(tokenInfoKey).(*TokenInfo)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthInterceptor returns a gRPC unary server interceptor that validates
|
||||||
|
// bearer tokens and requires the "admin" role.
|
||||||
|
func AuthInterceptor(validator TokenValidator) grpc.UnaryServerInterceptor {
|
||||||
|
return func(
|
||||||
|
ctx context.Context,
|
||||||
|
req any,
|
||||||
|
info *grpc.UnaryServerInfo,
|
||||||
|
handler grpc.UnaryHandler,
|
||||||
|
) (any, error) {
|
||||||
|
md, ok := metadata.FromIncomingContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "missing metadata")
|
||||||
|
}
|
||||||
|
|
||||||
|
authValues := md.Get("authorization")
|
||||||
|
if len(authValues) == 0 {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "missing authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
authHeader := authValues[0]
|
||||||
|
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "malformed authorization header")
|
||||||
|
}
|
||||||
|
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
|
||||||
|
tokenInfo, err := validator.ValidateToken(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("token validation failed", "method", info.FullMethod, "error", err)
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "token validation failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tokenInfo.Valid {
|
||||||
|
return nil, status.Error(codes.Unauthenticated, "invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tokenInfo.HasRole("admin") {
|
||||||
|
slog.Warn("permission denied", "method", info.FullMethod, "user", tokenInfo.Username)
|
||||||
|
return nil, status.Error(codes.PermissionDenied, "admin role required")
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("rpc", "method", info.FullMethod, "user", tokenInfo.Username, "account_type", tokenInfo.AccountType)
|
||||||
|
|
||||||
|
ctx = ContextWithTokenInfo(ctx, tokenInfo)
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login authenticates with MCIAS and returns a bearer token.
|
||||||
|
func Login(serverURL, caCertPath, username, password string) (string, error) {
|
||||||
|
client, err := newHTTPClient(caCertPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
serverURL = strings.TrimRight(serverURL, "/")
|
||||||
|
url := serverURL + "/v1/auth/login"
|
||||||
|
|
||||||
|
payload := struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}{
|
||||||
|
Username: username,
|
||||||
|
Password: password,
|
||||||
|
}
|
||||||
|
body, err := json.Marshal(payload) //nolint:gosec // intentional login credential payload
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshal login request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create login request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("login request: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read login response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("login failed: MCIAS returned %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return "", fmt.Errorf("parse login response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Token == "" {
|
||||||
|
return "", fmt.Errorf("login response missing token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.Token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadToken reads a token from the given file path and trims whitespace.
|
||||||
|
func LoadToken(path string) (string, error) {
|
||||||
|
data, err := os.ReadFile(path) //nolint:gosec // path from trusted caller
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("load token from %q: %w", path, err)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(string(data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveToken writes a token to the given file with 0600 permissions.
|
||||||
|
// Parent directories are created if they do not exist.
|
||||||
|
func SaveToken(path string, token string) error {
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||||
|
return fmt.Errorf("create token directory %q: %w", dir, err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, []byte(token+"\n"), 0600); err != nil {
|
||||||
|
return fmt.Errorf("save token to %q: %w", path, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
391
internal/auth/auth_test.go
Normal file
391
internal/auth/auth_test.go
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockMCIAS creates a test MCIAS server that responds to /v1/token/validate.
|
||||||
|
// The handler function receives the authorization header and returns (response, statusCode).
|
||||||
|
func mockMCIAS(t *testing.T, handler func(authHeader string) (any, int)) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
resp, code := handler(authHeader)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(code)
|
||||||
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
|
t.Errorf("encode response: %v", err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatorFromServer(t *testing.T, server *httptest.Server) *MCIASValidator {
|
||||||
|
t.Helper()
|
||||||
|
v, err := NewMCIASValidator(server.URL, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create validator: %v", err)
|
||||||
|
}
|
||||||
|
v.httpClient = server.Client()
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// callInterceptor invokes the interceptor with the given context and validator.
|
||||||
|
func callInterceptor(ctx context.Context, validator TokenValidator) (*TokenInfo, error) {
|
||||||
|
interceptor := AuthInterceptor(validator)
|
||||||
|
info := &grpc.UnaryServerInfo{FullMethod: "/mcp.v1.MCPService/TestMethod"}
|
||||||
|
|
||||||
|
var captured *TokenInfo
|
||||||
|
handler := func(ctx context.Context, req any) (any, error) {
|
||||||
|
captured = TokenInfoFromContext(ctx)
|
||||||
|
return "ok", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := interceptor(ctx, nil, info, handler)
|
||||||
|
return captured, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterceptorRejectsNoToken(t *testing.T) {
|
||||||
|
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||||
|
return map[string]any{"valid": false}, http.StatusOK
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
v := validatorFromServer(t, server)
|
||||||
|
|
||||||
|
// No metadata at all.
|
||||||
|
_, err := callInterceptor(context.Background(), v)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
|
||||||
|
t.Fatalf("expected Unauthenticated, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metadata present but no authorization key.
|
||||||
|
md := metadata.Pairs("other-key", "value")
|
||||||
|
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||||
|
_, err = callInterceptor(ctx, v)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error with empty auth, got nil")
|
||||||
|
}
|
||||||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
|
||||||
|
t.Fatalf("expected Unauthenticated, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterceptorRejectsMalformedToken(t *testing.T) {
|
||||||
|
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||||
|
return map[string]any{"valid": false}, http.StatusOK
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
v := validatorFromServer(t, server)
|
||||||
|
|
||||||
|
md := metadata.Pairs("authorization", "NotBearer xxx")
|
||||||
|
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||||
|
|
||||||
|
_, err := callInterceptor(ctx, v)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
|
||||||
|
t.Fatalf("expected Unauthenticated, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterceptorRejectsInvalidToken(t *testing.T) {
|
||||||
|
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||||
|
return &TokenInfo{Valid: false}, http.StatusOK
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
v := validatorFromServer(t, server)
|
||||||
|
|
||||||
|
md := metadata.Pairs("authorization", "Bearer bad-token")
|
||||||
|
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||||
|
|
||||||
|
_, err := callInterceptor(ctx, v)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
|
||||||
|
t.Fatalf("expected Unauthenticated, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterceptorRejectsNonAdmin(t *testing.T) {
|
||||||
|
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||||
|
return &TokenInfo{
|
||||||
|
Valid: true,
|
||||||
|
Username: "regularuser",
|
||||||
|
Roles: []string{"user"},
|
||||||
|
AccountType: "human",
|
||||||
|
}, http.StatusOK
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
v := validatorFromServer(t, server)
|
||||||
|
|
||||||
|
md := metadata.Pairs("authorization", "Bearer user-token")
|
||||||
|
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||||
|
|
||||||
|
_, err := callInterceptor(ctx, v)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.PermissionDenied {
|
||||||
|
t.Fatalf("expected PermissionDenied, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterceptorAcceptsAdmin(t *testing.T) {
|
||||||
|
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||||
|
return &TokenInfo{
|
||||||
|
Valid: true,
|
||||||
|
Username: "kyle",
|
||||||
|
Roles: []string{"admin", "user"},
|
||||||
|
AccountType: "human",
|
||||||
|
}, http.StatusOK
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
v := validatorFromServer(t, server)
|
||||||
|
|
||||||
|
md := metadata.Pairs("authorization", "Bearer admin-token")
|
||||||
|
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||||
|
|
||||||
|
captured, err := callInterceptor(ctx, v)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if captured == nil {
|
||||||
|
t.Fatal("expected token info in context, got nil")
|
||||||
|
}
|
||||||
|
if captured.Username != "kyle" {
|
||||||
|
t.Fatalf("username: got %q, want %q", captured.Username, "kyle")
|
||||||
|
}
|
||||||
|
if !captured.HasRole("admin") {
|
||||||
|
t.Fatal("expected admin role")
|
||||||
|
}
|
||||||
|
if captured.AccountType != "human" {
|
||||||
|
t.Fatalf("account_type: got %q, want %q", captured.AccountType, "human")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenCaching(t *testing.T) {
|
||||||
|
var requestCount atomic.Int64
|
||||||
|
|
||||||
|
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||||
|
requestCount.Add(1)
|
||||||
|
return &TokenInfo{
|
||||||
|
Valid: true,
|
||||||
|
Username: "kyle",
|
||||||
|
Roles: []string{"admin"},
|
||||||
|
AccountType: "human",
|
||||||
|
}, http.StatusOK
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
v := validatorFromServer(t, server)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// First call should hit the server.
|
||||||
|
info1, err := v.ValidateToken(ctx, "same-token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first validate: %v", err)
|
||||||
|
}
|
||||||
|
if !info1.Valid {
|
||||||
|
t.Fatal("expected valid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second call with the same token should be cached.
|
||||||
|
info2, err := v.ValidateToken(ctx, "same-token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second validate: %v", err)
|
||||||
|
}
|
||||||
|
if info2.Username != info1.Username {
|
||||||
|
t.Fatalf("cached result mismatch: got %q, want %q", info2.Username, info1.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := requestCount.Load(); count != 1 {
|
||||||
|
t.Fatalf("expected 1 MCIAS request, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenCacheSeparateEntries(t *testing.T) {
|
||||||
|
var requestCount atomic.Int64
|
||||||
|
|
||||||
|
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||||
|
requestCount.Add(1)
|
||||||
|
// Return different usernames based on the token.
|
||||||
|
token := authHeader[len("Bearer "):]
|
||||||
|
return &TokenInfo{
|
||||||
|
Valid: true,
|
||||||
|
Username: fmt.Sprintf("user-for-%s", token),
|
||||||
|
Roles: []string{"admin"},
|
||||||
|
AccountType: "human",
|
||||||
|
}, http.StatusOK
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
v := validatorFromServer(t, server)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
info1, err := v.ValidateToken(ctx, "token-a")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("validate token-a: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info2, err := v.ValidateToken(ctx, "token-b")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("validate token-b: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info1.Username == info2.Username {
|
||||||
|
t.Fatalf("different tokens should have different cache entries, both got %q", info1.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := requestCount.Load(); count != 2 {
|
||||||
|
t.Fatalf("expected 2 MCIAS requests for different tokens, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Repeat calls should be cached.
|
||||||
|
_, err = v.ValidateToken(ctx, "token-a")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cached validate token-a: %v", err)
|
||||||
|
}
|
||||||
|
_, err = v.ValidateToken(ctx, "token-b")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("cached validate token-b: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := requestCount.Load(); count != 2 {
|
||||||
|
t.Fatalf("expected still 2 MCIAS requests after cache hits, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadSaveToken(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "subdir", "token")
|
||||||
|
|
||||||
|
token := "test-bearer-token-12345"
|
||||||
|
|
||||||
|
if err := SaveToken(path, token); err != nil {
|
||||||
|
t.Fatalf("save: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check file permissions.
|
||||||
|
fi, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stat: %v", err)
|
||||||
|
}
|
||||||
|
if perm := fi.Mode().Perm(); perm != 0600 {
|
||||||
|
t.Fatalf("permissions: got %o, want 0600", perm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load and verify.
|
||||||
|
loaded, err := LoadToken(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load: %v", err)
|
||||||
|
}
|
||||||
|
if loaded != token {
|
||||||
|
t.Fatalf("loaded: got %q, want %q", loaded, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogin(t *testing.T) {
|
||||||
|
expectedToken := "mcias-session-token-xyz"
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/auth/login" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("unexpected method: %s", r.Method)
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var body struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||||
|
t.Errorf("decode request body: %v", err)
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if body.Username != "kyle" || body.Password != "secret" {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"token": expectedToken})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
token, err := Login(server.URL, "", "kyle", "secret")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("login: %v", err)
|
||||||
|
}
|
||||||
|
if token != expectedToken {
|
||||||
|
t.Fatalf("token: got %q, want %q", token, expectedToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginBadCredentials(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := Login(server.URL, "", "kyle", "wrong")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for bad credentials")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextTokenInfo(t *testing.T) {
|
||||||
|
info := &TokenInfo{
|
||||||
|
Valid: true,
|
||||||
|
Username: "kyle",
|
||||||
|
Roles: []string{"admin"},
|
||||||
|
AccountType: "human",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := ContextWithTokenInfo(context.Background(), info)
|
||||||
|
got := TokenInfoFromContext(ctx)
|
||||||
|
if got == nil {
|
||||||
|
t.Fatal("expected token info from context, got nil")
|
||||||
|
}
|
||||||
|
if got.Username != "kyle" {
|
||||||
|
t.Fatalf("username: got %q, want %q", got.Username, "kyle")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty context should return nil.
|
||||||
|
got = TokenInfoFromContext(context.Background())
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("expected nil from empty context, got %+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
186
internal/config/agent.go
Normal file
186
internal/config/agent.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
toml "github.com/pelletier/go-toml/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AgentConfig is the configuration for the mcp-agent daemon.
|
||||||
|
type AgentConfig struct {
|
||||||
|
Server ServerConfig `toml:"server"`
|
||||||
|
Database DatabaseConfig `toml:"database"`
|
||||||
|
MCIAS MCIASConfig `toml:"mcias"`
|
||||||
|
Agent AgentSettings `toml:"agent"`
|
||||||
|
Monitor MonitorConfig `toml:"monitor"`
|
||||||
|
Log LogConfig `toml:"log"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig holds gRPC server listen address and TLS paths.
|
||||||
|
type ServerConfig struct {
|
||||||
|
GRPCAddr string `toml:"grpc_addr"`
|
||||||
|
TLSCert string `toml:"tls_cert"`
|
||||||
|
TLSKey string `toml:"tls_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DatabaseConfig holds the SQLite database path.
|
||||||
|
type DatabaseConfig struct {
|
||||||
|
Path string `toml:"path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AgentSettings holds agent-specific settings.
|
||||||
|
type AgentSettings struct {
|
||||||
|
NodeName string `toml:"node_name"`
|
||||||
|
ContainerRuntime string `toml:"container_runtime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MonitorConfig holds monitoring loop parameters.
|
||||||
|
type MonitorConfig struct {
|
||||||
|
Interval Duration `toml:"interval"`
|
||||||
|
AlertCommand []string `toml:"alert_command"`
|
||||||
|
Cooldown Duration `toml:"cooldown"`
|
||||||
|
FlapThreshold int `toml:"flap_threshold"`
|
||||||
|
FlapWindow Duration `toml:"flap_window"`
|
||||||
|
Retention Duration `toml:"retention"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogConfig holds logging settings.
|
||||||
|
type LogConfig struct {
|
||||||
|
Level string `toml:"level"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Duration wraps time.Duration to support TOML string unmarshaling.
|
||||||
|
// It accepts Go duration strings (e.g. "60s", "15m", "24h") plus a
|
||||||
|
// "d" suffix for days (e.g. "30d" becomes 30*24h).
|
||||||
|
type Duration struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalText implements encoding.TextUnmarshaler for TOML parsing.
|
||||||
|
func (d *Duration) UnmarshalText(text []byte) error {
|
||||||
|
s := string(text)
|
||||||
|
if s == "" {
|
||||||
|
d.Duration = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle "d" suffix for days.
|
||||||
|
if len(s) > 1 && s[len(s)-1] == 'd' {
|
||||||
|
var days float64
|
||||||
|
if _, err := fmt.Sscanf(s[:len(s)-1], "%f", &days); err != nil {
|
||||||
|
return fmt.Errorf("parse duration %q: %w", s, err)
|
||||||
|
}
|
||||||
|
d.Duration = time.Duration(days * float64(24*time.Hour))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dur, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse duration %q: %w", s, err)
|
||||||
|
}
|
||||||
|
d.Duration = dur
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalText implements encoding.TextMarshaler for TOML serialization.
|
||||||
|
func (d Duration) MarshalText() ([]byte, error) {
|
||||||
|
return []byte(d.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAgentConfig reads and validates an agent configuration file.
|
||||||
|
// Environment variables override file values for select fields.
|
||||||
|
func LoadAgentConfig(path string) (*AgentConfig, error) {
|
||||||
|
data, err := os.ReadFile(path) //nolint:gosec // config path from trusted CLI flag
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read config %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg AgentConfig
|
||||||
|
if err := toml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse config %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
applyAgentDefaults(&cfg)
|
||||||
|
applyAgentEnvOverrides(&cfg)
|
||||||
|
|
||||||
|
if err := validateAgentConfig(&cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("validate config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyAgentDefaults(cfg *AgentConfig) {
|
||||||
|
if cfg.Monitor.Interval.Duration == 0 {
|
||||||
|
cfg.Monitor.Interval.Duration = 60 * time.Second
|
||||||
|
}
|
||||||
|
if cfg.Monitor.Cooldown.Duration == 0 {
|
||||||
|
cfg.Monitor.Cooldown.Duration = 15 * time.Minute
|
||||||
|
}
|
||||||
|
if cfg.Monitor.FlapThreshold == 0 {
|
||||||
|
cfg.Monitor.FlapThreshold = 3
|
||||||
|
}
|
||||||
|
if cfg.Monitor.FlapWindow.Duration == 0 {
|
||||||
|
cfg.Monitor.FlapWindow.Duration = 10 * time.Minute
|
||||||
|
}
|
||||||
|
if cfg.Monitor.Retention.Duration == 0 {
|
||||||
|
cfg.Monitor.Retention.Duration = 30 * 24 * time.Hour // 30 days
|
||||||
|
}
|
||||||
|
if cfg.Log.Level == "" {
|
||||||
|
cfg.Log.Level = "info"
|
||||||
|
}
|
||||||
|
if cfg.Agent.ContainerRuntime == "" {
|
||||||
|
cfg.Agent.ContainerRuntime = "podman"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyAgentEnvOverrides(cfg *AgentConfig) {
|
||||||
|
if v := os.Getenv("MCP_AGENT_SERVER_GRPC_ADDR"); v != "" {
|
||||||
|
cfg.Server.GRPCAddr = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_AGENT_SERVER_TLS_CERT"); v != "" {
|
||||||
|
cfg.Server.TLSCert = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_AGENT_SERVER_TLS_KEY"); v != "" {
|
||||||
|
cfg.Server.TLSKey = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_AGENT_DATABASE_PATH"); v != "" {
|
||||||
|
cfg.Database.Path = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_AGENT_NODE_NAME"); v != "" {
|
||||||
|
cfg.Agent.NodeName = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_AGENT_CONTAINER_RUNTIME"); v != "" {
|
||||||
|
cfg.Agent.ContainerRuntime = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_AGENT_LOG_LEVEL"); v != "" {
|
||||||
|
cfg.Log.Level = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAgentConfig(cfg *AgentConfig) error {
|
||||||
|
if cfg.Server.GRPCAddr == "" {
|
||||||
|
return fmt.Errorf("server.grpc_addr is required")
|
||||||
|
}
|
||||||
|
if cfg.Server.TLSCert == "" {
|
||||||
|
return fmt.Errorf("server.tls_cert is required")
|
||||||
|
}
|
||||||
|
if cfg.Server.TLSKey == "" {
|
||||||
|
return fmt.Errorf("server.tls_key is required")
|
||||||
|
}
|
||||||
|
if cfg.Database.Path == "" {
|
||||||
|
return fmt.Errorf("database.path is required")
|
||||||
|
}
|
||||||
|
if cfg.MCIAS.ServerURL == "" {
|
||||||
|
return fmt.Errorf("mcias.server_url is required")
|
||||||
|
}
|
||||||
|
if cfg.MCIAS.ServiceName == "" {
|
||||||
|
return fmt.Errorf("mcias.service_name is required")
|
||||||
|
}
|
||||||
|
if cfg.Agent.NodeName == "" {
|
||||||
|
return fmt.Errorf("agent.node_name is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
97
internal/config/cli.go
Normal file
97
internal/config/cli.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
toml "github.com/pelletier/go-toml/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CLIConfig is the configuration for the mcp CLI binary.
|
||||||
|
type CLIConfig struct {
|
||||||
|
Services ServicesConfig `toml:"services"`
|
||||||
|
MCIAS MCIASConfig `toml:"mcias"`
|
||||||
|
Auth AuthConfig `toml:"auth"`
|
||||||
|
Nodes []NodeConfig `toml:"nodes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServicesConfig defines where service definition files live.
|
||||||
|
type ServicesConfig struct {
|
||||||
|
Dir string `toml:"dir"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCIASConfig holds MCIAS connection settings, shared by CLI and agent.
|
||||||
|
type MCIASConfig struct {
|
||||||
|
ServerURL string `toml:"server_url"`
|
||||||
|
CACert string `toml:"ca_cert"`
|
||||||
|
ServiceName string `toml:"service_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthConfig holds authentication settings for the CLI.
|
||||||
|
type AuthConfig struct {
|
||||||
|
TokenPath string `toml:"token_path"`
|
||||||
|
Username string `toml:"username"` // optional, for unattended operation
|
||||||
|
PasswordFile string `toml:"password_file"` // optional, for unattended operation
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodeConfig defines a managed node that the CLI connects to.
|
||||||
|
type NodeConfig struct {
|
||||||
|
Name string `toml:"name"`
|
||||||
|
Address string `toml:"address"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCLIConfig reads and validates a CLI configuration file.
|
||||||
|
// Environment variables override file values for select fields.
|
||||||
|
func LoadCLIConfig(path string) (*CLIConfig, error) {
|
||||||
|
data, err := os.ReadFile(path) //nolint:gosec // config path from trusted CLI flag
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read config %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg CLIConfig
|
||||||
|
if err := toml.Unmarshal(data, &cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse config %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
applyCLIEnvOverrides(&cfg)
|
||||||
|
|
||||||
|
if err := validateCLIConfig(&cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("validate config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyCLIEnvOverrides(cfg *CLIConfig) {
|
||||||
|
if v := os.Getenv("MCP_SERVICES_DIR"); v != "" {
|
||||||
|
cfg.Services.Dir = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_MCIAS_SERVER_URL"); v != "" {
|
||||||
|
cfg.MCIAS.ServerURL = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_MCIAS_CA_CERT"); v != "" {
|
||||||
|
cfg.MCIAS.CACert = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_MCIAS_SERVICE_NAME"); v != "" {
|
||||||
|
cfg.MCIAS.ServiceName = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("MCP_AUTH_TOKEN_PATH"); v != "" {
|
||||||
|
cfg.Auth.TokenPath = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateCLIConfig(cfg *CLIConfig) error {
|
||||||
|
if cfg.Services.Dir == "" {
|
||||||
|
return fmt.Errorf("services.dir is required")
|
||||||
|
}
|
||||||
|
if cfg.MCIAS.ServerURL == "" {
|
||||||
|
return fmt.Errorf("mcias.server_url is required")
|
||||||
|
}
|
||||||
|
if cfg.MCIAS.ServiceName == "" {
|
||||||
|
return fmt.Errorf("mcias.service_name is required")
|
||||||
|
}
|
||||||
|
if cfg.Auth.TokenPath == "" {
|
||||||
|
return fmt.Errorf("auth.token_path is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
476
internal/config/config_test.go
Normal file
476
internal/config/config_test.go
Normal file
@@ -0,0 +1,476 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testCLIConfig = `
|
||||||
|
[services]
|
||||||
|
dir = "/home/kyle/.config/mcp/services"
|
||||||
|
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
ca_cert = "/etc/mcp/ca.pem"
|
||||||
|
service_name = "mcp"
|
||||||
|
|
||||||
|
[auth]
|
||||||
|
token_path = "/home/kyle/.config/mcp/token"
|
||||||
|
username = "kyle"
|
||||||
|
password_file = "/home/kyle/.config/mcp/password"
|
||||||
|
|
||||||
|
[[nodes]]
|
||||||
|
name = "rift"
|
||||||
|
address = "100.95.252.120:9444"
|
||||||
|
|
||||||
|
[[nodes]]
|
||||||
|
name = "cascade"
|
||||||
|
address = "100.95.252.121:9444"
|
||||||
|
`
|
||||||
|
|
||||||
|
const testAgentConfig = `
|
||||||
|
[server]
|
||||||
|
grpc_addr = "100.95.252.120:9444"
|
||||||
|
tls_cert = "/srv/mcp/certs/cert.pem"
|
||||||
|
tls_key = "/srv/mcp/certs/key.pem"
|
||||||
|
|
||||||
|
[database]
|
||||||
|
path = "/srv/mcp/mcp.db"
|
||||||
|
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
ca_cert = "/etc/mcp/ca.pem"
|
||||||
|
service_name = "mcp-agent"
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
node_name = "rift"
|
||||||
|
container_runtime = "podman"
|
||||||
|
|
||||||
|
[monitor]
|
||||||
|
interval = "60s"
|
||||||
|
alert_command = ["notify-send", "MCP Alert"]
|
||||||
|
cooldown = "15m"
|
||||||
|
flap_threshold = 3
|
||||||
|
flap_window = "10m"
|
||||||
|
retention = "30d"
|
||||||
|
|
||||||
|
[log]
|
||||||
|
level = "debug"
|
||||||
|
`
|
||||||
|
|
||||||
|
func writeTempConfig(t *testing.T, content string) string {
|
||||||
|
t.Helper()
|
||||||
|
path := filepath.Join(t.TempDir(), "config.toml")
|
||||||
|
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
|
||||||
|
t.Fatalf("write temp config: %v", err)
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadCLIConfig(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, testCLIConfig)
|
||||||
|
|
||||||
|
cfg, err := LoadCLIConfig(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Services.Dir != "/home/kyle/.config/mcp/services" {
|
||||||
|
t.Fatalf("services.dir: got %q", cfg.Services.Dir)
|
||||||
|
}
|
||||||
|
if cfg.MCIAS.ServerURL != "https://mcias.metacircular.net:8443" {
|
||||||
|
t.Fatalf("mcias.server_url: got %q", cfg.MCIAS.ServerURL)
|
||||||
|
}
|
||||||
|
if cfg.MCIAS.CACert != "/etc/mcp/ca.pem" {
|
||||||
|
t.Fatalf("mcias.ca_cert: got %q", cfg.MCIAS.CACert)
|
||||||
|
}
|
||||||
|
if cfg.MCIAS.ServiceName != "mcp" {
|
||||||
|
t.Fatalf("mcias.service_name: got %q", cfg.MCIAS.ServiceName)
|
||||||
|
}
|
||||||
|
if cfg.Auth.TokenPath != "/home/kyle/.config/mcp/token" {
|
||||||
|
t.Fatalf("auth.token_path: got %q", cfg.Auth.TokenPath)
|
||||||
|
}
|
||||||
|
if cfg.Auth.Username != "kyle" {
|
||||||
|
t.Fatalf("auth.username: got %q", cfg.Auth.Username)
|
||||||
|
}
|
||||||
|
if cfg.Auth.PasswordFile != "/home/kyle/.config/mcp/password" {
|
||||||
|
t.Fatalf("auth.password_file: got %q", cfg.Auth.PasswordFile)
|
||||||
|
}
|
||||||
|
if len(cfg.Nodes) != 2 {
|
||||||
|
t.Fatalf("nodes: got %d, want 2", len(cfg.Nodes))
|
||||||
|
}
|
||||||
|
if cfg.Nodes[0].Name != "rift" || cfg.Nodes[0].Address != "100.95.252.120:9444" {
|
||||||
|
t.Fatalf("nodes[0]: got %+v", cfg.Nodes[0])
|
||||||
|
}
|
||||||
|
if cfg.Nodes[1].Name != "cascade" || cfg.Nodes[1].Address != "100.95.252.121:9444" {
|
||||||
|
t.Fatalf("nodes[1]: got %+v", cfg.Nodes[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadAgentConfig(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, testAgentConfig)
|
||||||
|
|
||||||
|
cfg, err := LoadAgentConfig(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Server.GRPCAddr != "100.95.252.120:9444" {
|
||||||
|
t.Fatalf("server.grpc_addr: got %q", cfg.Server.GRPCAddr)
|
||||||
|
}
|
||||||
|
if cfg.Server.TLSCert != "/srv/mcp/certs/cert.pem" {
|
||||||
|
t.Fatalf("server.tls_cert: got %q", cfg.Server.TLSCert)
|
||||||
|
}
|
||||||
|
if cfg.Server.TLSKey != "/srv/mcp/certs/key.pem" {
|
||||||
|
t.Fatalf("server.tls_key: got %q", cfg.Server.TLSKey)
|
||||||
|
}
|
||||||
|
if cfg.Database.Path != "/srv/mcp/mcp.db" {
|
||||||
|
t.Fatalf("database.path: got %q", cfg.Database.Path)
|
||||||
|
}
|
||||||
|
if cfg.MCIAS.ServerURL != "https://mcias.metacircular.net:8443" {
|
||||||
|
t.Fatalf("mcias.server_url: got %q", cfg.MCIAS.ServerURL)
|
||||||
|
}
|
||||||
|
if cfg.MCIAS.ServiceName != "mcp-agent" {
|
||||||
|
t.Fatalf("mcias.service_name: got %q", cfg.MCIAS.ServiceName)
|
||||||
|
}
|
||||||
|
if cfg.Agent.NodeName != "rift" {
|
||||||
|
t.Fatalf("agent.node_name: got %q", cfg.Agent.NodeName)
|
||||||
|
}
|
||||||
|
if cfg.Agent.ContainerRuntime != "podman" {
|
||||||
|
t.Fatalf("agent.container_runtime: got %q", cfg.Agent.ContainerRuntime)
|
||||||
|
}
|
||||||
|
if cfg.Monitor.Interval.Duration != 60*time.Second {
|
||||||
|
t.Fatalf("monitor.interval: got %v", cfg.Monitor.Interval.Duration)
|
||||||
|
}
|
||||||
|
if len(cfg.Monitor.AlertCommand) != 2 || cfg.Monitor.AlertCommand[0] != "notify-send" {
|
||||||
|
t.Fatalf("monitor.alert_command: got %v", cfg.Monitor.AlertCommand)
|
||||||
|
}
|
||||||
|
if cfg.Monitor.Cooldown.Duration != 15*time.Minute {
|
||||||
|
t.Fatalf("monitor.cooldown: got %v", cfg.Monitor.Cooldown.Duration)
|
||||||
|
}
|
||||||
|
if cfg.Monitor.FlapThreshold != 3 {
|
||||||
|
t.Fatalf("monitor.flap_threshold: got %d", cfg.Monitor.FlapThreshold)
|
||||||
|
}
|
||||||
|
if cfg.Monitor.FlapWindow.Duration != 10*time.Minute {
|
||||||
|
t.Fatalf("monitor.flap_window: got %v", cfg.Monitor.FlapWindow.Duration)
|
||||||
|
}
|
||||||
|
if cfg.Monitor.Retention.Duration != 30*24*time.Hour {
|
||||||
|
t.Fatalf("monitor.retention: got %v", cfg.Monitor.Retention.Duration)
|
||||||
|
}
|
||||||
|
if cfg.Log.Level != "debug" {
|
||||||
|
t.Fatalf("log.level: got %q", cfg.Log.Level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCLIConfigValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config string
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing services.dir",
|
||||||
|
config: `
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
service_name = "mcp"
|
||||||
|
[auth]
|
||||||
|
token_path = "/tmp/token"
|
||||||
|
`,
|
||||||
|
errMsg: "services.dir is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing mcias.server_url",
|
||||||
|
config: `
|
||||||
|
[services]
|
||||||
|
dir = "/tmp/services"
|
||||||
|
[mcias]
|
||||||
|
service_name = "mcp"
|
||||||
|
[auth]
|
||||||
|
token_path = "/tmp/token"
|
||||||
|
`,
|
||||||
|
errMsg: "mcias.server_url is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing mcias.service_name",
|
||||||
|
config: `
|
||||||
|
[services]
|
||||||
|
dir = "/tmp/services"
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
[auth]
|
||||||
|
token_path = "/tmp/token"
|
||||||
|
`,
|
||||||
|
errMsg: "mcias.service_name is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing auth.token_path",
|
||||||
|
config: `
|
||||||
|
[services]
|
||||||
|
dir = "/tmp/services"
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
service_name = "mcp"
|
||||||
|
`,
|
||||||
|
errMsg: "auth.token_path is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, tt.config)
|
||||||
|
_, err := LoadCLIConfig(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if got := err.Error(); !strings.Contains(got, tt.errMsg) {
|
||||||
|
t.Fatalf("error %q does not contain %q", got, tt.errMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentConfigValidation(t *testing.T) {
|
||||||
|
// Minimal valid agent config to start from.
|
||||||
|
base := `
|
||||||
|
[server]
|
||||||
|
grpc_addr = "0.0.0.0:9444"
|
||||||
|
tls_cert = "/srv/mcp/cert.pem"
|
||||||
|
tls_key = "/srv/mcp/key.pem"
|
||||||
|
[database]
|
||||||
|
path = "/srv/mcp/mcp.db"
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
service_name = "mcp-agent"
|
||||||
|
[agent]
|
||||||
|
node_name = "rift"
|
||||||
|
`
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config string
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing server.grpc_addr",
|
||||||
|
config: `
|
||||||
|
[server]
|
||||||
|
tls_cert = "/srv/mcp/cert.pem"
|
||||||
|
tls_key = "/srv/mcp/key.pem"
|
||||||
|
[database]
|
||||||
|
path = "/srv/mcp/mcp.db"
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
service_name = "mcp-agent"
|
||||||
|
[agent]
|
||||||
|
node_name = "rift"
|
||||||
|
`,
|
||||||
|
errMsg: "server.grpc_addr is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing server.tls_cert",
|
||||||
|
config: `
|
||||||
|
[server]
|
||||||
|
grpc_addr = "0.0.0.0:9444"
|
||||||
|
tls_key = "/srv/mcp/key.pem"
|
||||||
|
[database]
|
||||||
|
path = "/srv/mcp/mcp.db"
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
service_name = "mcp-agent"
|
||||||
|
[agent]
|
||||||
|
node_name = "rift"
|
||||||
|
`,
|
||||||
|
errMsg: "server.tls_cert is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing database.path",
|
||||||
|
config: `
|
||||||
|
[server]
|
||||||
|
grpc_addr = "0.0.0.0:9444"
|
||||||
|
tls_cert = "/srv/mcp/cert.pem"
|
||||||
|
tls_key = "/srv/mcp/key.pem"
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
service_name = "mcp-agent"
|
||||||
|
[agent]
|
||||||
|
node_name = "rift"
|
||||||
|
`,
|
||||||
|
errMsg: "database.path is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing agent.node_name",
|
||||||
|
config: `
|
||||||
|
[server]
|
||||||
|
grpc_addr = "0.0.0.0:9444"
|
||||||
|
tls_cert = "/srv/mcp/cert.pem"
|
||||||
|
tls_key = "/srv/mcp/key.pem"
|
||||||
|
[database]
|
||||||
|
path = "/srv/mcp/mcp.db"
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
service_name = "mcp-agent"
|
||||||
|
`,
|
||||||
|
errMsg: "agent.node_name is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, tt.config)
|
||||||
|
_, err := LoadAgentConfig(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if got := err.Error(); !strings.Contains(got, tt.errMsg) {
|
||||||
|
t.Fatalf("error %q does not contain %q", got, tt.errMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify base config actually loads fine.
|
||||||
|
t.Run("valid base config", func(t *testing.T) {
|
||||||
|
path := writeTempConfig(t, base)
|
||||||
|
if _, err := LoadAgentConfig(path); err != nil {
|
||||||
|
t.Fatalf("base config should be valid: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentConfigDefaults(t *testing.T) {
|
||||||
|
// Only required fields, no monitor/log/runtime settings.
|
||||||
|
minimal := `
|
||||||
|
[server]
|
||||||
|
grpc_addr = "0.0.0.0:9444"
|
||||||
|
tls_cert = "/srv/mcp/cert.pem"
|
||||||
|
tls_key = "/srv/mcp/key.pem"
|
||||||
|
[database]
|
||||||
|
path = "/srv/mcp/mcp.db"
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
service_name = "mcp-agent"
|
||||||
|
[agent]
|
||||||
|
node_name = "rift"
|
||||||
|
`
|
||||||
|
|
||||||
|
path := writeTempConfig(t, minimal)
|
||||||
|
cfg, err := LoadAgentConfig(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Monitor.Interval.Duration != 60*time.Second {
|
||||||
|
t.Fatalf("default interval: got %v, want 60s", cfg.Monitor.Interval.Duration)
|
||||||
|
}
|
||||||
|
if cfg.Monitor.Cooldown.Duration != 15*time.Minute {
|
||||||
|
t.Fatalf("default cooldown: got %v, want 15m", cfg.Monitor.Cooldown.Duration)
|
||||||
|
}
|
||||||
|
if cfg.Monitor.FlapThreshold != 3 {
|
||||||
|
t.Fatalf("default flap_threshold: got %d, want 3", cfg.Monitor.FlapThreshold)
|
||||||
|
}
|
||||||
|
if cfg.Monitor.FlapWindow.Duration != 10*time.Minute {
|
||||||
|
t.Fatalf("default flap_window: got %v, want 10m", cfg.Monitor.FlapWindow.Duration)
|
||||||
|
}
|
||||||
|
if cfg.Monitor.Retention.Duration != 30*24*time.Hour {
|
||||||
|
t.Fatalf("default retention: got %v, want 720h", cfg.Monitor.Retention.Duration)
|
||||||
|
}
|
||||||
|
if cfg.Log.Level != "info" {
|
||||||
|
t.Fatalf("default log level: got %q, want info", cfg.Log.Level)
|
||||||
|
}
|
||||||
|
if cfg.Agent.ContainerRuntime != "podman" {
|
||||||
|
t.Fatalf("default container_runtime: got %q, want podman", cfg.Agent.ContainerRuntime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvVarOverrides(t *testing.T) {
|
||||||
|
// Test agent env var override.
|
||||||
|
minimal := `
|
||||||
|
[server]
|
||||||
|
grpc_addr = "0.0.0.0:9444"
|
||||||
|
tls_cert = "/srv/mcp/cert.pem"
|
||||||
|
tls_key = "/srv/mcp/key.pem"
|
||||||
|
[database]
|
||||||
|
path = "/srv/mcp/mcp.db"
|
||||||
|
[mcias]
|
||||||
|
server_url = "https://mcias.metacircular.net:8443"
|
||||||
|
service_name = "mcp-agent"
|
||||||
|
[agent]
|
||||||
|
node_name = "rift"
|
||||||
|
[log]
|
||||||
|
level = "info"
|
||||||
|
`
|
||||||
|
t.Run("agent log level override", func(t *testing.T) {
|
||||||
|
t.Setenv("MCP_AGENT_LOG_LEVEL", "debug")
|
||||||
|
path := writeTempConfig(t, minimal)
|
||||||
|
cfg, err := LoadAgentConfig(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Log.Level != "debug" {
|
||||||
|
t.Fatalf("log level: got %q, want debug", cfg.Log.Level)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("agent node name override", func(t *testing.T) {
|
||||||
|
t.Setenv("MCP_AGENT_NODE_NAME", "override-node")
|
||||||
|
path := writeTempConfig(t, minimal)
|
||||||
|
cfg, err := LoadAgentConfig(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Agent.NodeName != "override-node" {
|
||||||
|
t.Fatalf("node_name: got %q, want override-node", cfg.Agent.NodeName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CLI services dir override", func(t *testing.T) {
|
||||||
|
t.Setenv("MCP_SERVICES_DIR", "/override/services")
|
||||||
|
path := writeTempConfig(t, testCLIConfig)
|
||||||
|
cfg, err := LoadCLIConfig(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Services.Dir != "/override/services" {
|
||||||
|
t.Fatalf("services.dir: got %q, want /override/services", cfg.Services.Dir)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDurationParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want time.Duration
|
||||||
|
fail bool
|
||||||
|
}{
|
||||||
|
{input: "60s", want: 60 * time.Second},
|
||||||
|
{input: "15m", want: 15 * time.Minute},
|
||||||
|
{input: "24h", want: 24 * time.Hour},
|
||||||
|
{input: "30d", want: 30 * 24 * time.Hour},
|
||||||
|
{input: "1d", want: 24 * time.Hour},
|
||||||
|
{input: "", want: 0},
|
||||||
|
{input: "bogus", fail: true},
|
||||||
|
{input: "notanumber-d", fail: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
var d Duration
|
||||||
|
err := d.UnmarshalText([]byte(tt.input))
|
||||||
|
if tt.fail {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for %q, got nil", tt.input)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error for %q: %v", tt.input, err)
|
||||||
|
}
|
||||||
|
if d.Duration != tt.want {
|
||||||
|
t.Fatalf("for %q: got %v, want %v", tt.input, d.Duration, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
208
internal/runtime/podman.go
Normal file
208
internal/runtime/podman.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package runtime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Podman implements the Runtime interface using the podman CLI.
|
||||||
|
type Podman struct {
|
||||||
|
Command string // path to podman binary, default "podman"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Podman) command() string {
|
||||||
|
if p.Command != "" {
|
||||||
|
return p.Command
|
||||||
|
}
|
||||||
|
return "podman"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pull pulls a container image.
|
||||||
|
func (p *Podman) Pull(ctx context.Context, image string) error {
|
||||||
|
cmd := exec.CommandContext(ctx, p.command(), "pull", image) //nolint:gosec // args built programmatically
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("podman pull %q: %w: %s", image, err, out)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRunArgs constructs the argument list for podman run. Exported for testing.
|
||||||
|
func (p *Podman) BuildRunArgs(spec ContainerSpec) []string {
|
||||||
|
args := []string{"run", "-d", "--name", spec.Name}
|
||||||
|
|
||||||
|
if spec.Network != "" {
|
||||||
|
args = append(args, "--network", spec.Network)
|
||||||
|
}
|
||||||
|
if spec.User != "" {
|
||||||
|
args = append(args, "--user", spec.User)
|
||||||
|
}
|
||||||
|
if spec.Restart != "" {
|
||||||
|
args = append(args, "--restart", spec.Restart)
|
||||||
|
}
|
||||||
|
for _, port := range spec.Ports {
|
||||||
|
args = append(args, "-p", port)
|
||||||
|
}
|
||||||
|
for _, vol := range spec.Volumes {
|
||||||
|
args = append(args, "-v", vol)
|
||||||
|
}
|
||||||
|
|
||||||
|
args = append(args, spec.Image)
|
||||||
|
args = append(args, spec.Cmd...)
|
||||||
|
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run creates and starts a container from the given spec.
|
||||||
|
func (p *Podman) Run(ctx context.Context, spec ContainerSpec) error {
|
||||||
|
args := p.BuildRunArgs(spec)
|
||||||
|
cmd := exec.CommandContext(ctx, p.command(), args...) //nolint:gosec // args built programmatically
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("podman run %q: %w: %s", spec.Name, err, out)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops a running container.
|
||||||
|
func (p *Podman) Stop(ctx context.Context, name string) error {
|
||||||
|
cmd := exec.CommandContext(ctx, p.command(), "stop", name) //nolint:gosec // args built programmatically
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("podman stop %q: %w: %s", name, err, out)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes a container.
|
||||||
|
func (p *Podman) Remove(ctx context.Context, name string) error {
|
||||||
|
cmd := exec.CommandContext(ctx, p.command(), "rm", name) //nolint:gosec // args built programmatically
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("podman rm %q: %w: %s", name, err, out)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// podmanPortBinding is a single port binding from the inspect output.
|
||||||
|
type podmanPortBinding struct {
|
||||||
|
HostIP string `json:"HostIp"`
|
||||||
|
HostPort string `json:"HostPort"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// podmanInspectResult is the subset of podman inspect JSON we parse.
|
||||||
|
type podmanInspectResult struct {
|
||||||
|
Name string `json:"Name"`
|
||||||
|
Config struct {
|
||||||
|
Image string `json:"Image"`
|
||||||
|
Cmd []string `json:"Cmd"`
|
||||||
|
User string `json:"User"`
|
||||||
|
} `json:"Config"`
|
||||||
|
State struct {
|
||||||
|
Status string `json:"Status"`
|
||||||
|
} `json:"State"`
|
||||||
|
HostConfig struct {
|
||||||
|
RestartPolicy struct {
|
||||||
|
Name string `json:"Name"`
|
||||||
|
} `json:"RestartPolicy"`
|
||||||
|
NetworkMode string `json:"NetworkMode"`
|
||||||
|
} `json:"HostConfig"`
|
||||||
|
NetworkSettings struct {
|
||||||
|
Networks map[string]struct{} `json:"Networks"`
|
||||||
|
Ports map[string][]podmanPortBinding `json:"Ports"`
|
||||||
|
} `json:"NetworkSettings"`
|
||||||
|
Mounts []struct {
|
||||||
|
Source string `json:"Source"`
|
||||||
|
Destination string `json:"Destination"`
|
||||||
|
} `json:"Mounts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inspect retrieves information about a container.
|
||||||
|
func (p *Podman) Inspect(ctx context.Context, name string) (ContainerInfo, error) {
|
||||||
|
cmd := exec.CommandContext(ctx, p.command(), "inspect", name) //nolint:gosec // args built programmatically
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return ContainerInfo{}, fmt.Errorf("podman inspect %q: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []podmanInspectResult
|
||||||
|
if err := json.Unmarshal(out, &results); err != nil {
|
||||||
|
return ContainerInfo{}, fmt.Errorf("parse inspect output: %w", err)
|
||||||
|
}
|
||||||
|
if len(results) == 0 {
|
||||||
|
return ContainerInfo{}, fmt.Errorf("podman inspect %q: no results", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := results[0]
|
||||||
|
info := ContainerInfo{
|
||||||
|
Name: strings.TrimPrefix(r.Name, "/"),
|
||||||
|
Image: r.Config.Image,
|
||||||
|
State: r.State.Status,
|
||||||
|
User: r.Config.User,
|
||||||
|
Restart: r.HostConfig.RestartPolicy.Name,
|
||||||
|
Cmd: r.Config.Cmd,
|
||||||
|
Version: ExtractVersion(r.Config.Image),
|
||||||
|
}
|
||||||
|
|
||||||
|
info.Network = r.HostConfig.NetworkMode
|
||||||
|
if len(r.NetworkSettings.Networks) > 0 {
|
||||||
|
for netName := range r.NetworkSettings.Networks {
|
||||||
|
info.Network = netName
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for containerPort, bindings := range r.NetworkSettings.Ports {
|
||||||
|
for _, b := range bindings {
|
||||||
|
port := strings.SplitN(containerPort, "/", 2)[0] // strip "/tcp" suffix
|
||||||
|
mapping := b.HostPort + ":" + port
|
||||||
|
if b.HostIP != "" && b.HostIP != "0.0.0.0" {
|
||||||
|
mapping = b.HostIP + ":" + mapping
|
||||||
|
}
|
||||||
|
info.Ports = append(info.Ports, mapping)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range r.Mounts {
|
||||||
|
info.Volumes = append(info.Volumes, m.Source+":"+m.Destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// podmanPSEntry is a single entry from podman ps --format json.
|
||||||
|
type podmanPSEntry struct {
|
||||||
|
Names []string `json:"Names"`
|
||||||
|
Image string `json:"Image"`
|
||||||
|
State string `json:"State"`
|
||||||
|
Command string `json:"Command"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns information about all containers.
|
||||||
|
func (p *Podman) List(ctx context.Context) ([]ContainerInfo, error) {
|
||||||
|
cmd := exec.CommandContext(ctx, p.command(), "ps", "-a", "--format", "json") //nolint:gosec // args built programmatically
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("podman ps: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var entries []podmanPSEntry
|
||||||
|
if err := json.Unmarshal(out, &entries); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse ps output: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
infos := make([]ContainerInfo, 0, len(entries))
|
||||||
|
for _, e := range entries {
|
||||||
|
name := ""
|
||||||
|
if len(e.Names) > 0 {
|
||||||
|
name = e.Names[0]
|
||||||
|
}
|
||||||
|
infos = append(infos, ContainerInfo{
|
||||||
|
Name: name,
|
||||||
|
Image: e.Image,
|
||||||
|
State: e.State,
|
||||||
|
Version: ExtractVersion(e.Image),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return infos, nil
|
||||||
|
}
|
||||||
62
internal/runtime/runtime.go
Normal file
62
internal/runtime/runtime.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package runtime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContainerSpec describes a container to create and run.
|
||||||
|
type ContainerSpec struct {
|
||||||
|
Name string // container name, format: <service>-<component>
|
||||||
|
Image string // full image reference
|
||||||
|
Network string // docker network name
|
||||||
|
User string // container user (e.g., "0:0")
|
||||||
|
Restart string // restart policy (e.g., "unless-stopped")
|
||||||
|
Ports []string // "host:container" port mappings
|
||||||
|
Volumes []string // "host:container" volume mounts
|
||||||
|
Cmd []string // command and arguments
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContainerInfo describes the observed state of a running or stopped container.
|
||||||
|
type ContainerInfo struct {
|
||||||
|
Name string
|
||||||
|
Image string
|
||||||
|
State string // "running", "stopped", "exited", etc.
|
||||||
|
Network string
|
||||||
|
User string
|
||||||
|
Restart string
|
||||||
|
Ports []string
|
||||||
|
Volumes []string
|
||||||
|
Cmd []string
|
||||||
|
Version string // extracted from image tag
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runtime is the container runtime abstraction.
|
||||||
|
type Runtime interface {
|
||||||
|
Pull(ctx context.Context, image string) error
|
||||||
|
Run(ctx context.Context, spec ContainerSpec) error
|
||||||
|
Stop(ctx context.Context, name string) error
|
||||||
|
Remove(ctx context.Context, name string) error
|
||||||
|
Inspect(ctx context.Context, name string) (ContainerInfo, error)
|
||||||
|
List(ctx context.Context) ([]ContainerInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractVersion parses the tag from an image reference.
|
||||||
|
// Examples:
|
||||||
|
//
|
||||||
|
// "registry/img:v1.2.0" -> "v1.2.0"
|
||||||
|
// "registry/img:latest" -> "latest"
|
||||||
|
// "registry/img" -> ""
|
||||||
|
// "registry:5000/img:v1" -> "v1"
|
||||||
|
func ExtractVersion(image string) string {
|
||||||
|
// Strip registry/path prefix so that a port like "registry:5000" isn't
|
||||||
|
// mistaken for a tag separator.
|
||||||
|
name := image
|
||||||
|
if i := strings.LastIndex(image, "/"); i >= 0 {
|
||||||
|
name = image[i+1:]
|
||||||
|
}
|
||||||
|
if i := strings.LastIndex(name, ":"); i >= 0 {
|
||||||
|
return name[i+1:]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
113
internal/runtime/runtime_test.go
Normal file
113
internal/runtime/runtime_test.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package runtime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func requireEqualArgs(t *testing.T, got, want []string) {
|
||||||
|
t.Helper()
|
||||||
|
if !slices.Equal(got, want) {
|
||||||
|
t.Fatalf("args mismatch\ngot: %v\nwant: %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildRunArgs(t *testing.T) {
|
||||||
|
p := &Podman{}
|
||||||
|
|
||||||
|
t.Run("full spec", func(t *testing.T) {
|
||||||
|
spec := ContainerSpec{
|
||||||
|
Name: "metacrypt-api",
|
||||||
|
Image: "mcr.svc.mcp.metacircular.net:8443/metacrypt:v1.0.0",
|
||||||
|
Network: "docker_default",
|
||||||
|
User: "0:0",
|
||||||
|
Restart: "unless-stopped",
|
||||||
|
Ports: []string{"127.0.0.1:18443:8443", "127.0.0.1:19443:9443"},
|
||||||
|
Volumes: []string{"/srv/metacrypt:/srv/metacrypt", "/etc/ssl:/etc/ssl:ro"},
|
||||||
|
Cmd: []string{"server", "--config", "/srv/metacrypt/metacrypt.toml"},
|
||||||
|
}
|
||||||
|
requireEqualArgs(t, p.BuildRunArgs(spec), []string{
|
||||||
|
"run", "-d", "--name", "metacrypt-api",
|
||||||
|
"--network", "docker_default",
|
||||||
|
"--user", "0:0",
|
||||||
|
"--restart", "unless-stopped",
|
||||||
|
"-p", "127.0.0.1:18443:8443",
|
||||||
|
"-p", "127.0.0.1:19443:9443",
|
||||||
|
"-v", "/srv/metacrypt:/srv/metacrypt",
|
||||||
|
"-v", "/etc/ssl:/etc/ssl:ro",
|
||||||
|
"mcr.svc.mcp.metacircular.net:8443/metacrypt:v1.0.0",
|
||||||
|
"server", "--config", "/srv/metacrypt/metacrypt.toml",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("minimal spec", func(t *testing.T) {
|
||||||
|
spec := ContainerSpec{
|
||||||
|
Name: "test-app",
|
||||||
|
Image: "img:latest",
|
||||||
|
}
|
||||||
|
requireEqualArgs(t, p.BuildRunArgs(spec), []string{
|
||||||
|
"run", "-d", "--name", "test-app", "img:latest",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ports only", func(t *testing.T) {
|
||||||
|
spec := ContainerSpec{
|
||||||
|
Name: "test-app",
|
||||||
|
Image: "img:latest",
|
||||||
|
Ports: []string{"8080:80", "8443:443"},
|
||||||
|
}
|
||||||
|
requireEqualArgs(t, p.BuildRunArgs(spec), []string{
|
||||||
|
"run", "-d", "--name", "test-app",
|
||||||
|
"-p", "8080:80", "-p", "8443:443",
|
||||||
|
"img:latest",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("volumes only", func(t *testing.T) {
|
||||||
|
spec := ContainerSpec{
|
||||||
|
Name: "test-app",
|
||||||
|
Image: "img:latest",
|
||||||
|
Volumes: []string{"/data:/data", "/config:/config:ro"},
|
||||||
|
}
|
||||||
|
requireEqualArgs(t, p.BuildRunArgs(spec), []string{
|
||||||
|
"run", "-d", "--name", "test-app",
|
||||||
|
"-v", "/data:/data", "-v", "/config:/config:ro",
|
||||||
|
"img:latest",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cmd after image", func(t *testing.T) {
|
||||||
|
spec := ContainerSpec{
|
||||||
|
Name: "test-app",
|
||||||
|
Image: "img:latest",
|
||||||
|
Cmd: []string{"serve", "--port", "8080"},
|
||||||
|
}
|
||||||
|
requireEqualArgs(t, p.BuildRunArgs(spec), []string{
|
||||||
|
"run", "-d", "--name", "test-app",
|
||||||
|
"img:latest",
|
||||||
|
"serve", "--port", "8080",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractVersion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
image string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"registry.example.com:5000/img:v1.2.0", "v1.2.0"},
|
||||||
|
{"img:latest", "latest"},
|
||||||
|
{"img", ""},
|
||||||
|
{"registry.example.com/path/img:v1", "v1"},
|
||||||
|
{"registry.example.com:5000/path/img", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.image, func(t *testing.T) {
|
||||||
|
got := ExtractVersion(tt.image)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("ExtractVersion(%q) = %q, want %q", tt.image, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
185
internal/servicedef/servicedef.go
Normal file
185
internal/servicedef/servicedef.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
// Package servicedef handles parsing and writing TOML service definition files.
|
||||||
|
package servicedef
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
toml "github.com/pelletier/go-toml/v2"
|
||||||
|
|
||||||
|
mcpv1 "git.wntrmute.dev/kyle/mcp/gen/mcp/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServiceDef is the top-level TOML structure for a service definition file.
|
||||||
|
type ServiceDef struct {
|
||||||
|
Name string `toml:"name"`
|
||||||
|
Node string `toml:"node"`
|
||||||
|
Active *bool `toml:"active,omitempty"`
|
||||||
|
Components []ComponentDef `toml:"components"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ComponentDef describes a single container component within a service.
|
||||||
|
type ComponentDef struct {
|
||||||
|
Name string `toml:"name"`
|
||||||
|
Image string `toml:"image"`
|
||||||
|
Network string `toml:"network,omitempty"`
|
||||||
|
User string `toml:"user,omitempty"`
|
||||||
|
Restart string `toml:"restart,omitempty"`
|
||||||
|
Ports []string `toml:"ports,omitempty"`
|
||||||
|
Volumes []string `toml:"volumes,omitempty"`
|
||||||
|
Cmd []string `toml:"cmd,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reads and parses a TOML service definition file. If the active field
|
||||||
|
// is omitted, it defaults to true.
|
||||||
|
func Load(path string) (*ServiceDef, error) {
|
||||||
|
data, err := os.ReadFile(path) //nolint:gosec // path from trusted config dir
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read service def %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var def ServiceDef
|
||||||
|
if err := toml.Unmarshal(data, &def); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse service def %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validate(&def); err != nil {
|
||||||
|
return nil, fmt.Errorf("validate service def %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if def.Active == nil {
|
||||||
|
t := true
|
||||||
|
def.Active = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
return &def, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write serializes a ServiceDef to TOML and writes it to the given path
|
||||||
|
// with 0644 permissions. Parent directories are created if needed.
|
||||||
|
func Write(path string, def *ServiceDef) error {
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil { //nolint:gosec // service defs are non-secret
|
||||||
|
return fmt.Errorf("create parent dirs for %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := toml.Marshal(def)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal service def: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(path, data, 0o644); err != nil { //nolint:gosec // service defs are non-secret
|
||||||
|
return fmt.Errorf("write service def %q: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAll reads all .toml files from dir, parses each one, and returns the
|
||||||
|
// list sorted by service name.
|
||||||
|
func LoadAll(dir string) ([]*ServiceDef, error) {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read service dir %q: %w", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var defs []*ServiceDef
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() || !strings.HasSuffix(e.Name(), ".toml") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
def, err := Load(filepath.Join(dir, e.Name()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defs = append(defs, def)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(defs, func(i, j int) bool {
|
||||||
|
return defs[i].Name < defs[j].Name
|
||||||
|
})
|
||||||
|
|
||||||
|
return defs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate checks that a ServiceDef has all required fields and that
|
||||||
|
// component names are unique.
|
||||||
|
func validate(def *ServiceDef) error {
|
||||||
|
if def.Name == "" {
|
||||||
|
return fmt.Errorf("service name is required")
|
||||||
|
}
|
||||||
|
if def.Node == "" {
|
||||||
|
return fmt.Errorf("service node is required")
|
||||||
|
}
|
||||||
|
if len(def.Components) == 0 {
|
||||||
|
return fmt.Errorf("service %q must have at least one component", def.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, c := range def.Components {
|
||||||
|
if c.Name == "" {
|
||||||
|
return fmt.Errorf("component name is required in service %q", def.Name)
|
||||||
|
}
|
||||||
|
if c.Image == "" {
|
||||||
|
return fmt.Errorf("component %q image is required in service %q", c.Name, def.Name)
|
||||||
|
}
|
||||||
|
if seen[c.Name] {
|
||||||
|
return fmt.Errorf("duplicate component name %q in service %q", c.Name, def.Name)
|
||||||
|
}
|
||||||
|
seen[c.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToProto converts a ServiceDef to a proto ServiceSpec.
|
||||||
|
func ToProto(def *ServiceDef) *mcpv1.ServiceSpec {
|
||||||
|
spec := &mcpv1.ServiceSpec{
|
||||||
|
Name: def.Name,
|
||||||
|
Active: def.Active != nil && *def.Active,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range def.Components {
|
||||||
|
spec.Components = append(spec.Components, &mcpv1.ComponentSpec{
|
||||||
|
Name: c.Name,
|
||||||
|
Image: c.Image,
|
||||||
|
Network: c.Network,
|
||||||
|
User: c.User,
|
||||||
|
Restart: c.Restart,
|
||||||
|
Ports: c.Ports,
|
||||||
|
Volumes: c.Volumes,
|
||||||
|
Cmd: c.Cmd,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return spec
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromProto converts a proto ServiceSpec back to a ServiceDef. The node
|
||||||
|
// parameter is required because ServiceSpec does not include the node field
|
||||||
|
// (it is a CLI-side routing concern).
|
||||||
|
func FromProto(spec *mcpv1.ServiceSpec, node string) *ServiceDef {
|
||||||
|
active := spec.GetActive()
|
||||||
|
def := &ServiceDef{
|
||||||
|
Name: spec.GetName(),
|
||||||
|
Node: node,
|
||||||
|
Active: &active,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range spec.GetComponents() {
|
||||||
|
def.Components = append(def.Components, ComponentDef{
|
||||||
|
Name: c.GetName(),
|
||||||
|
Image: c.GetImage(),
|
||||||
|
Network: c.GetNetwork(),
|
||||||
|
User: c.GetUser(),
|
||||||
|
Restart: c.GetRestart(),
|
||||||
|
Ports: c.GetPorts(),
|
||||||
|
Volumes: c.GetVolumes(),
|
||||||
|
Cmd: c.GetCmd(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
289
internal/servicedef/servicedef_test.go
Normal file
289
internal/servicedef/servicedef_test.go
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
package servicedef
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func boolPtr(b bool) *bool { return &b }
|
||||||
|
|
||||||
|
func sampleDef() *ServiceDef {
|
||||||
|
return &ServiceDef{
|
||||||
|
Name: "metacrypt",
|
||||||
|
Node: "rift",
|
||||||
|
Active: boolPtr(true),
|
||||||
|
Components: []ComponentDef{
|
||||||
|
{
|
||||||
|
Name: "api",
|
||||||
|
Image: "mcr.svc.mcp.metacircular.net:8443/metacrypt:latest",
|
||||||
|
Network: "docker_default",
|
||||||
|
User: "0:0",
|
||||||
|
Restart: "unless-stopped",
|
||||||
|
Ports: []string{"127.0.0.1:18443:8443", "127.0.0.1:19443:9443"},
|
||||||
|
Volumes: []string{"/srv/metacrypt:/srv/metacrypt"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "web",
|
||||||
|
Image: "mcr.svc.mcp.metacircular.net:8443/metacrypt-web:latest",
|
||||||
|
Network: "docker_default",
|
||||||
|
User: "0:0",
|
||||||
|
Restart: "unless-stopped",
|
||||||
|
Ports: []string{"127.0.0.1:18080:8080"},
|
||||||
|
Volumes: []string{"/srv/metacrypt:/srv/metacrypt"},
|
||||||
|
Cmd: []string{"server", "--config", "/srv/metacrypt/metacrypt.toml"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compareComponents asserts that two component slices are equal field by field.
|
||||||
|
func compareComponents(t *testing.T, prefix string, got, want []ComponentDef) {
|
||||||
|
t.Helper()
|
||||||
|
if len(got) != len(want) {
|
||||||
|
t.Fatalf("%s components: got %d, want %d", prefix, len(got), len(want))
|
||||||
|
}
|
||||||
|
for i, wantC := range want {
|
||||||
|
gotC := got[i]
|
||||||
|
if gotC.Name != wantC.Name {
|
||||||
|
t.Fatalf("%s component[%d] name: got %q, want %q", prefix, i, gotC.Name, wantC.Name)
|
||||||
|
}
|
||||||
|
if gotC.Image != wantC.Image {
|
||||||
|
t.Fatalf("%s component[%d] image: got %q, want %q", prefix, i, gotC.Image, wantC.Image)
|
||||||
|
}
|
||||||
|
if gotC.Network != wantC.Network {
|
||||||
|
t.Fatalf("%s component[%d] network: got %q, want %q", prefix, i, gotC.Network, wantC.Network)
|
||||||
|
}
|
||||||
|
if gotC.User != wantC.User {
|
||||||
|
t.Fatalf("%s component[%d] user: got %q, want %q", prefix, i, gotC.User, wantC.User)
|
||||||
|
}
|
||||||
|
if gotC.Restart != wantC.Restart {
|
||||||
|
t.Fatalf("%s component[%d] restart: got %q, want %q", prefix, i, gotC.Restart, wantC.Restart)
|
||||||
|
}
|
||||||
|
compareStrSlice(t, prefix, i, "ports", gotC.Ports, wantC.Ports)
|
||||||
|
compareStrSlice(t, prefix, i, "volumes", gotC.Volumes, wantC.Volumes)
|
||||||
|
compareStrSlice(t, prefix, i, "cmd", gotC.Cmd, wantC.Cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareStrSlice(t *testing.T, prefix string, idx int, field string, got, want []string) {
|
||||||
|
t.Helper()
|
||||||
|
if len(got) != len(want) {
|
||||||
|
t.Fatalf("%s component[%d] %s: got %d, want %d", prefix, idx, field, len(got), len(want))
|
||||||
|
}
|
||||||
|
for j := range want {
|
||||||
|
if got[j] != want[j] {
|
||||||
|
t.Fatalf("%s component[%d] %s[%d]: got %q, want %q", prefix, idx, field, j, got[j], want[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadWrite(t *testing.T) {
|
||||||
|
def := sampleDef()
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "metacrypt.toml")
|
||||||
|
|
||||||
|
if err := Write(path, def); err != nil {
|
||||||
|
t.Fatalf("write: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got.Name != def.Name {
|
||||||
|
t.Fatalf("name: got %q, want %q", got.Name, def.Name)
|
||||||
|
}
|
||||||
|
if got.Node != def.Node {
|
||||||
|
t.Fatalf("node: got %q, want %q", got.Node, def.Node)
|
||||||
|
}
|
||||||
|
if *got.Active != *def.Active {
|
||||||
|
t.Fatalf("active: got %v, want %v", *got.Active, *def.Active)
|
||||||
|
}
|
||||||
|
compareComponents(t, "load-write", got.Components, def.Components)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
def *ServiceDef
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing name",
|
||||||
|
def: &ServiceDef{
|
||||||
|
Node: "rift",
|
||||||
|
Components: []ComponentDef{{Name: "api", Image: "img:v1"}},
|
||||||
|
},
|
||||||
|
wantErr: "service name is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing node",
|
||||||
|
def: &ServiceDef{
|
||||||
|
Name: "svc",
|
||||||
|
Components: []ComponentDef{{Name: "api", Image: "img:v1"}},
|
||||||
|
},
|
||||||
|
wantErr: "service node is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty components",
|
||||||
|
def: &ServiceDef{
|
||||||
|
Name: "svc",
|
||||||
|
Node: "rift",
|
||||||
|
},
|
||||||
|
wantErr: "must have at least one component",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "duplicate component names",
|
||||||
|
def: &ServiceDef{
|
||||||
|
Name: "svc",
|
||||||
|
Node: "rift",
|
||||||
|
Components: []ComponentDef{
|
||||||
|
{Name: "api", Image: "img:v1"},
|
||||||
|
{Name: "api", Image: "img:v2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: "duplicate component name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "component missing name",
|
||||||
|
def: &ServiceDef{
|
||||||
|
Name: "svc",
|
||||||
|
Node: "rift",
|
||||||
|
Components: []ComponentDef{{Image: "img:v1"}},
|
||||||
|
},
|
||||||
|
wantErr: "component name is required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "component missing image",
|
||||||
|
def: &ServiceDef{
|
||||||
|
Name: "svc",
|
||||||
|
Node: "rift",
|
||||||
|
Components: []ComponentDef{{Name: "api"}},
|
||||||
|
},
|
||||||
|
wantErr: "image is required",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validate(tt.def)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected validation error")
|
||||||
|
}
|
||||||
|
if got := err.Error(); !strings.Contains(got, tt.wantErr) {
|
||||||
|
t.Fatalf("error %q does not contain %q", got, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadAll(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
|
// Write services in non-alphabetical order.
|
||||||
|
defs := []*ServiceDef{
|
||||||
|
{
|
||||||
|
Name: "mcr",
|
||||||
|
Node: "rift",
|
||||||
|
Active: boolPtr(true),
|
||||||
|
Components: []ComponentDef{{Name: "api", Image: "mcr:latest"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "metacrypt",
|
||||||
|
Node: "rift",
|
||||||
|
Active: boolPtr(true),
|
||||||
|
Components: []ComponentDef{{Name: "api", Image: "metacrypt:latest"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "mcias",
|
||||||
|
Node: "rift",
|
||||||
|
Active: boolPtr(false),
|
||||||
|
Components: []ComponentDef{{Name: "api", Image: "mcias:latest"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, d := range defs {
|
||||||
|
if err := Write(filepath.Join(dir, d.Name+".toml"), d); err != nil {
|
||||||
|
t.Fatalf("write %s: %v", d.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write a non-TOML file that should be ignored.
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("# ignore"), 0o600); err != nil {
|
||||||
|
t.Fatalf("write readme: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := LoadAll(dir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load all: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(got) != 3 {
|
||||||
|
t.Fatalf("count: got %d, want 3", len(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
wantOrder := []string{"mcias", "mcr", "metacrypt"}
|
||||||
|
for i, name := range wantOrder {
|
||||||
|
if got[i].Name != name {
|
||||||
|
t.Fatalf("order[%d]: got %q, want %q", i, got[i].Name, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestActiveDefault(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "svc.toml")
|
||||||
|
|
||||||
|
content := `name = "svc"
|
||||||
|
node = "rift"
|
||||||
|
|
||||||
|
[[components]]
|
||||||
|
name = "api"
|
||||||
|
image = "img:latest"
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
|
||||||
|
t.Fatalf("write: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
def, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if def.Active == nil {
|
||||||
|
t.Fatal("active should not be nil")
|
||||||
|
}
|
||||||
|
if !*def.Active {
|
||||||
|
t.Fatal("active should default to true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtoConversion(t *testing.T) {
|
||||||
|
def := sampleDef()
|
||||||
|
|
||||||
|
spec := ToProto(def)
|
||||||
|
if spec.Name != def.Name {
|
||||||
|
t.Fatalf("proto name: got %q, want %q", spec.Name, def.Name)
|
||||||
|
}
|
||||||
|
if !spec.Active {
|
||||||
|
t.Fatal("proto active should be true")
|
||||||
|
}
|
||||||
|
if len(spec.Components) != len(def.Components) {
|
||||||
|
t.Fatalf("proto components: got %d, want %d", len(spec.Components), len(def.Components))
|
||||||
|
}
|
||||||
|
|
||||||
|
got := FromProto(spec, def.Node)
|
||||||
|
if got.Name != def.Name {
|
||||||
|
t.Fatalf("round-trip name: got %q, want %q", got.Name, def.Name)
|
||||||
|
}
|
||||||
|
if got.Node != def.Node {
|
||||||
|
t.Fatalf("round-trip node: got %q, want %q", got.Node, def.Node)
|
||||||
|
}
|
||||||
|
if *got.Active != *def.Active {
|
||||||
|
t.Fatalf("round-trip active: got %v, want %v", *got.Active, *def.Active)
|
||||||
|
}
|
||||||
|
compareComponents(t, "round-trip", got.Components, def.Components)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user