From 15b88238105e56e42687b70dbaf010ea6e63c162 Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Thu, 26 Mar 2026 11:36:12 -0700 Subject: [PATCH] P1.2-P1.5: Complete Phase 1 core libraries Four packages built in parallel: - P1.2 runtime: Container runtime abstraction with podman implementation. Interface (Pull/Run/Stop/Remove/Inspect/List), ContainerSpec/ContainerInfo types, CLI arg building, version extraction from image tags. 2 tests. - P1.3 servicedef: TOML service definition file parsing. Load/Write/LoadAll, validation (required fields, unique component names), proto conversion. 5 tests. - P1.4 config: CLI and agent config loading from TOML. Duration type for time fields, env var overrides (MCP_*/MCP_AGENT_*), required field validation, sensible defaults. 7 tests. - P1.5 auth: MCIAS integration. Token validator with 30s SHA-256 cache, gRPC unary interceptor (admin role enforcement, audit logging), Login/LoadToken/SaveToken for CLI. 9 tests. All packages pass build, vet, lint, and test. Co-Authored-By: Claude Opus 4.6 (1M context) --- .gitignore | 1 + PROGRESS_V1.md | 8 +- go.mod | 1 + go.sum | 2 + internal/auth/auth.go | 332 +++++++++++++++++ internal/auth/auth_test.go | 391 ++++++++++++++++++++ internal/config/agent.go | 186 ++++++++++ internal/config/cli.go | 97 +++++ internal/config/config_test.go | 476 +++++++++++++++++++++++++ internal/runtime/podman.go | 208 +++++++++++ internal/runtime/runtime.go | 62 ++++ internal/runtime/runtime_test.go | 113 ++++++ internal/servicedef/servicedef.go | 185 ++++++++++ internal/servicedef/servicedef_test.go | 289 +++++++++++++++ 14 files changed, 2347 insertions(+), 4 deletions(-) create mode 100644 internal/auth/auth.go create mode 100644 internal/auth/auth_test.go create mode 100644 internal/config/agent.go create mode 100644 internal/config/cli.go create mode 100644 internal/config/config_test.go create mode 100644 internal/runtime/podman.go create mode 100644 internal/runtime/runtime.go create mode 100644 internal/runtime/runtime_test.go create mode 100644 internal/servicedef/servicedef.go create mode 100644 internal/servicedef/servicedef_test.go diff --git a/.gitignore b/.gitignore index 377373a..e480ccd 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,4 @@ srv/ # OS .DS_Store +.claude/ diff --git a/PROGRESS_V1.md b/PROGRESS_V1.md index 26c67f2..e3afa83 100644 --- a/PROGRESS_V1.md +++ b/PROGRESS_V1.md @@ -8,10 +8,10 @@ ## Phase 1: Core Libraries - [x] **P1.1** Registry package (`internal/registry/`) -- [ ] **P1.2** Runtime package (`internal/runtime/`) -- [ ] **P1.3** Service definition package (`internal/servicedef/`) -- [ ] **P1.4** Config package (`internal/config/`) -- [ ] **P1.5** Auth package (`internal/auth/`) +- [x] **P1.2** Runtime package (`internal/runtime/`) +- [x] **P1.3** Service definition package (`internal/servicedef/`) +- [x] **P1.4** Config package (`internal/config/`) +- [x] **P1.5** Auth package (`internal/auth/`) ## Phase 2: Agent diff --git a/go.mod b/go.mod index 96c5d97..093d789 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module git.wntrmute.dev/kyle/mcp go 1.25.7 require ( + github.com/pelletier/go-toml/v2 v2.3.0 github.com/spf13/cobra v1.10.2 google.golang.org/grpc v1.79.3 google.golang.org/protobuf v1.36.11 diff --git a/go.sum b/go.sum index 5e35c4b..5a14716 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM= +github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..28fb664 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,332 @@ +package auth + +import ( + "bytes" + "context" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// cacheTTL is the duration cached token validation results are valid. +const cacheTTL = 30 * time.Second + +// TokenInfo holds the result of a token validation. +type TokenInfo struct { + Valid bool `json:"valid"` + Username string `json:"username"` + Roles []string `json:"roles"` + AccountType string `json:"account_type"` +} + +// HasRole reports whether the token has the given role. +func (t *TokenInfo) HasRole(role string) bool { + for _, r := range t.Roles { + if r == role { + return true + } + } + return false +} + +// TokenValidator validates bearer tokens against MCIAS. +type TokenValidator interface { + ValidateToken(ctx context.Context, token string) (*TokenInfo, error) +} + +// tokenCache stores validated token results with a TTL. +type tokenCache struct { + mu sync.RWMutex + entries map[string]cacheEntry +} + +type cacheEntry struct { + info *TokenInfo + expiresAt time.Time +} + +func newTokenCache() *tokenCache { + return &tokenCache{ + entries: make(map[string]cacheEntry), + } +} + +func (c *tokenCache) get(hash string) (*TokenInfo, bool) { + c.mu.RLock() + entry, ok := c.entries[hash] + c.mu.RUnlock() + if !ok { + return nil, false + } + if time.Now().After(entry.expiresAt) { + c.mu.Lock() + delete(c.entries, hash) + c.mu.Unlock() + return nil, false + } + return entry.info, true +} + +func (c *tokenCache) put(hash string, info *TokenInfo) { + c.mu.Lock() + c.entries[hash] = cacheEntry{ + info: info, + expiresAt: time.Now().Add(cacheTTL), + } + c.mu.Unlock() +} + +// newHTTPClient creates an HTTP client with TLS 1.3 minimum. If caCertPath +// is non-empty, the CA certificate is loaded and added to the root CA pool. +func newHTTPClient(caCertPath string) (*http.Client, error) { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS13, + } + + if caCertPath != "" { + caCert, err := os.ReadFile(caCertPath) //nolint:gosec // path from trusted config + if err != nil { + return nil, fmt.Errorf("read CA cert %q: %w", caCertPath, err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("parse CA cert %q: no valid certificates found", caCertPath) + } + tlsConfig.RootCAs = pool + } + + return &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + }, nil +} + +// MCIASValidator validates tokens by calling the MCIAS HTTP endpoint. +type MCIASValidator struct { + ServerURL string + CACertPath string + httpClient *http.Client + cache *tokenCache +} + +// NewMCIASValidator creates a validator that calls MCIAS at the given URL. +// If caCertPath is non-empty, the CA certificate is loaded and used for TLS. +func NewMCIASValidator(serverURL, caCertPath string) (*MCIASValidator, error) { + client, err := newHTTPClient(caCertPath) + if err != nil { + return nil, err + } + + return &MCIASValidator{ + ServerURL: strings.TrimRight(serverURL, "/"), + CACertPath: caCertPath, + httpClient: client, + cache: newTokenCache(), + }, nil +} + +func tokenHash(token string) string { + h := sha256.Sum256([]byte(token)) + return hex.EncodeToString(h[:]) +} + +// ValidateToken validates a bearer token against MCIAS. Results are cached +// for 30 seconds keyed by the SHA-256 hash of the token. +func (v *MCIASValidator) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) { + hash := tokenHash(token) + + if info, ok := v.cache.get(hash); ok { + return info, nil + } + + url := v.ServerURL + "/v1/token/validate" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + return nil, fmt.Errorf("create validate request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := v.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("validate token: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read validate response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("validate token: MCIAS returned %d", resp.StatusCode) + } + + var info TokenInfo + if err := json.Unmarshal(body, &info); err != nil { + return nil, fmt.Errorf("parse validate response: %w", err) + } + + v.cache.put(hash, &info) + return &info, nil +} + +// contextKey is an unexported type for context keys in this package. +type contextKey struct{} + +// tokenInfoKey is the context key for TokenInfo. +var tokenInfoKey = contextKey{} + +// ContextWithTokenInfo returns a new context carrying the given TokenInfo. +func ContextWithTokenInfo(ctx context.Context, info *TokenInfo) context.Context { + return context.WithValue(ctx, tokenInfoKey, info) +} + +// TokenInfoFromContext retrieves TokenInfo from the context, or nil if absent. +func TokenInfoFromContext(ctx context.Context) *TokenInfo { + info, _ := ctx.Value(tokenInfoKey).(*TokenInfo) + return info +} + +// AuthInterceptor returns a gRPC unary server interceptor that validates +// bearer tokens and requires the "admin" role. +func AuthInterceptor(validator TokenValidator) grpc.UnaryServerInterceptor { + return func( + ctx context.Context, + req any, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, + ) (any, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "missing metadata") + } + + authValues := md.Get("authorization") + if len(authValues) == 0 { + return nil, status.Error(codes.Unauthenticated, "missing authorization header") + } + + authHeader := authValues[0] + if !strings.HasPrefix(authHeader, "Bearer ") { + return nil, status.Error(codes.Unauthenticated, "malformed authorization header") + } + token := strings.TrimPrefix(authHeader, "Bearer ") + + tokenInfo, err := validator.ValidateToken(ctx, token) + if err != nil { + slog.Error("token validation failed", "method", info.FullMethod, "error", err) + return nil, status.Error(codes.Unauthenticated, "token validation failed") + } + + if !tokenInfo.Valid { + return nil, status.Error(codes.Unauthenticated, "invalid token") + } + + if !tokenInfo.HasRole("admin") { + slog.Warn("permission denied", "method", info.FullMethod, "user", tokenInfo.Username) + return nil, status.Error(codes.PermissionDenied, "admin role required") + } + + slog.Info("rpc", "method", info.FullMethod, "user", tokenInfo.Username, "account_type", tokenInfo.AccountType) + + ctx = ContextWithTokenInfo(ctx, tokenInfo) + return handler(ctx, req) + } +} + +// Login authenticates with MCIAS and returns a bearer token. +func Login(serverURL, caCertPath, username, password string) (string, error) { + client, err := newHTTPClient(caCertPath) + if err != nil { + return "", err + } + + serverURL = strings.TrimRight(serverURL, "/") + url := serverURL + "/v1/auth/login" + + payload := struct { + Username string `json:"username"` + Password string `json:"password"` + }{ + Username: username, + Password: password, + } + body, err := json.Marshal(payload) //nolint:gosec // intentional login credential payload + if err != nil { + return "", fmt.Errorf("marshal login request: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("create login request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("login request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read login response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("login failed: MCIAS returned %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Token string `json:"token"` + } + if err := json.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("parse login response: %w", err) + } + + if result.Token == "" { + return "", fmt.Errorf("login response missing token") + } + + return result.Token, nil +} + +// LoadToken reads a token from the given file path and trims whitespace. +func LoadToken(path string) (string, error) { + data, err := os.ReadFile(path) //nolint:gosec // path from trusted caller + if err != nil { + return "", fmt.Errorf("load token from %q: %w", path, err) + } + return strings.TrimSpace(string(data)), nil +} + +// SaveToken writes a token to the given file with 0600 permissions. +// Parent directories are created if they do not exist. +func SaveToken(path string, token string) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("create token directory %q: %w", dir, err) + } + if err := os.WriteFile(path, []byte(token+"\n"), 0600); err != nil { + return fmt.Errorf("save token to %q: %w", path, err) + } + return nil +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..803a614 --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,391 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync/atomic" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// mockMCIAS creates a test MCIAS server that responds to /v1/token/validate. +// The handler function receives the authorization header and returns (response, statusCode). +func mockMCIAS(t *testing.T, handler func(authHeader string) (any, int)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + resp, code := handler(authHeader) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Errorf("encode response: %v", err) + } + })) +} + +func validatorFromServer(t *testing.T, server *httptest.Server) *MCIASValidator { + t.Helper() + v, err := NewMCIASValidator(server.URL, "") + if err != nil { + t.Fatalf("create validator: %v", err) + } + v.httpClient = server.Client() + return v +} + +// callInterceptor invokes the interceptor with the given context and validator. +func callInterceptor(ctx context.Context, validator TokenValidator) (*TokenInfo, error) { + interceptor := AuthInterceptor(validator) + info := &grpc.UnaryServerInfo{FullMethod: "/mcp.v1.MCPService/TestMethod"} + + var captured *TokenInfo + handler := func(ctx context.Context, req any) (any, error) { + captured = TokenInfoFromContext(ctx) + return "ok", nil + } + + _, err := interceptor(ctx, nil, info, handler) + return captured, err +} + +func TestInterceptorRejectsNoToken(t *testing.T) { + server := mockMCIAS(t, func(authHeader string) (any, int) { + return map[string]any{"valid": false}, http.StatusOK + }) + defer server.Close() + + v := validatorFromServer(t, server) + + // No metadata at all. + _, err := callInterceptor(context.Background(), v) + if err == nil { + t.Fatal("expected error, got nil") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated { + t.Fatalf("expected Unauthenticated, got %v", err) + } + + // Metadata present but no authorization key. + md := metadata.Pairs("other-key", "value") + ctx := metadata.NewIncomingContext(context.Background(), md) + _, err = callInterceptor(ctx, v) + if err == nil { + t.Fatal("expected error with empty auth, got nil") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated { + t.Fatalf("expected Unauthenticated, got %v", err) + } +} + +func TestInterceptorRejectsMalformedToken(t *testing.T) { + server := mockMCIAS(t, func(authHeader string) (any, int) { + return map[string]any{"valid": false}, http.StatusOK + }) + defer server.Close() + + v := validatorFromServer(t, server) + + md := metadata.Pairs("authorization", "NotBearer xxx") + ctx := metadata.NewIncomingContext(context.Background(), md) + + _, err := callInterceptor(ctx, v) + if err == nil { + t.Fatal("expected error, got nil") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated { + t.Fatalf("expected Unauthenticated, got %v", err) + } +} + +func TestInterceptorRejectsInvalidToken(t *testing.T) { + server := mockMCIAS(t, func(authHeader string) (any, int) { + return &TokenInfo{Valid: false}, http.StatusOK + }) + defer server.Close() + + v := validatorFromServer(t, server) + + md := metadata.Pairs("authorization", "Bearer bad-token") + ctx := metadata.NewIncomingContext(context.Background(), md) + + _, err := callInterceptor(ctx, v) + if err == nil { + t.Fatal("expected error, got nil") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated { + t.Fatalf("expected Unauthenticated, got %v", err) + } +} + +func TestInterceptorRejectsNonAdmin(t *testing.T) { + server := mockMCIAS(t, func(authHeader string) (any, int) { + return &TokenInfo{ + Valid: true, + Username: "regularuser", + Roles: []string{"user"}, + AccountType: "human", + }, http.StatusOK + }) + defer server.Close() + + v := validatorFromServer(t, server) + + md := metadata.Pairs("authorization", "Bearer user-token") + ctx := metadata.NewIncomingContext(context.Background(), md) + + _, err := callInterceptor(ctx, v) + if err == nil { + t.Fatal("expected error, got nil") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.PermissionDenied { + t.Fatalf("expected PermissionDenied, got %v", err) + } +} + +func TestInterceptorAcceptsAdmin(t *testing.T) { + server := mockMCIAS(t, func(authHeader string) (any, int) { + return &TokenInfo{ + Valid: true, + Username: "kyle", + Roles: []string{"admin", "user"}, + AccountType: "human", + }, http.StatusOK + }) + defer server.Close() + + v := validatorFromServer(t, server) + + md := metadata.Pairs("authorization", "Bearer admin-token") + ctx := metadata.NewIncomingContext(context.Background(), md) + + captured, err := callInterceptor(ctx, v) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if captured == nil { + t.Fatal("expected token info in context, got nil") + } + if captured.Username != "kyle" { + t.Fatalf("username: got %q, want %q", captured.Username, "kyle") + } + if !captured.HasRole("admin") { + t.Fatal("expected admin role") + } + if captured.AccountType != "human" { + t.Fatalf("account_type: got %q, want %q", captured.AccountType, "human") + } +} + +func TestTokenCaching(t *testing.T) { + var requestCount atomic.Int64 + + server := mockMCIAS(t, func(authHeader string) (any, int) { + requestCount.Add(1) + return &TokenInfo{ + Valid: true, + Username: "kyle", + Roles: []string{"admin"}, + AccountType: "human", + }, http.StatusOK + }) + defer server.Close() + + v := validatorFromServer(t, server) + + ctx := context.Background() + + // First call should hit the server. + info1, err := v.ValidateToken(ctx, "same-token") + if err != nil { + t.Fatalf("first validate: %v", err) + } + if !info1.Valid { + t.Fatal("expected valid token") + } + + // Second call with the same token should be cached. + info2, err := v.ValidateToken(ctx, "same-token") + if err != nil { + t.Fatalf("second validate: %v", err) + } + if info2.Username != info1.Username { + t.Fatalf("cached result mismatch: got %q, want %q", info2.Username, info1.Username) + } + + if count := requestCount.Load(); count != 1 { + t.Fatalf("expected 1 MCIAS request, got %d", count) + } +} + +func TestTokenCacheSeparateEntries(t *testing.T) { + var requestCount atomic.Int64 + + server := mockMCIAS(t, func(authHeader string) (any, int) { + requestCount.Add(1) + // Return different usernames based on the token. + token := authHeader[len("Bearer "):] + return &TokenInfo{ + Valid: true, + Username: fmt.Sprintf("user-for-%s", token), + Roles: []string{"admin"}, + AccountType: "human", + }, http.StatusOK + }) + defer server.Close() + + v := validatorFromServer(t, server) + + ctx := context.Background() + + info1, err := v.ValidateToken(ctx, "token-a") + if err != nil { + t.Fatalf("validate token-a: %v", err) + } + + info2, err := v.ValidateToken(ctx, "token-b") + if err != nil { + t.Fatalf("validate token-b: %v", err) + } + + if info1.Username == info2.Username { + t.Fatalf("different tokens should have different cache entries, both got %q", info1.Username) + } + + if count := requestCount.Load(); count != 2 { + t.Fatalf("expected 2 MCIAS requests for different tokens, got %d", count) + } + + // Repeat calls should be cached. + _, err = v.ValidateToken(ctx, "token-a") + if err != nil { + t.Fatalf("cached validate token-a: %v", err) + } + _, err = v.ValidateToken(ctx, "token-b") + if err != nil { + t.Fatalf("cached validate token-b: %v", err) + } + + if count := requestCount.Load(); count != 2 { + t.Fatalf("expected still 2 MCIAS requests after cache hits, got %d", count) + } +} + +func TestLoadSaveToken(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "subdir", "token") + + token := "test-bearer-token-12345" + + if err := SaveToken(path, token); err != nil { + t.Fatalf("save: %v", err) + } + + // Check file permissions. + fi, err := os.Stat(path) + if err != nil { + t.Fatalf("stat: %v", err) + } + if perm := fi.Mode().Perm(); perm != 0600 { + t.Fatalf("permissions: got %o, want 0600", perm) + } + + // Load and verify. + loaded, err := LoadToken(path) + if err != nil { + t.Fatalf("load: %v", err) + } + if loaded != token { + t.Fatalf("loaded: got %q, want %q", loaded, token) + } +} + +func TestLogin(t *testing.T) { + expectedToken := "mcias-session-token-xyz" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/auth/login" { + t.Errorf("unexpected path: %s", r.URL.Path) + http.NotFound(w, r) + return + } + if r.Method != http.MethodPost { + t.Errorf("unexpected method: %s", r.Method) + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + var body struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("decode request body: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + if body.Username != "kyle" || body.Password != "secret" { + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"}) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"token": expectedToken}) + })) + defer server.Close() + + token, err := Login(server.URL, "", "kyle", "secret") + if err != nil { + t.Fatalf("login: %v", err) + } + if token != expectedToken { + t.Fatalf("token: got %q, want %q", token, expectedToken) + } +} + +func TestLoginBadCredentials(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"}) + })) + defer server.Close() + + _, err := Login(server.URL, "", "kyle", "wrong") + if err == nil { + t.Fatal("expected error for bad credentials") + } +} + +func TestContextTokenInfo(t *testing.T) { + info := &TokenInfo{ + Valid: true, + Username: "kyle", + Roles: []string{"admin"}, + AccountType: "human", + } + + ctx := ContextWithTokenInfo(context.Background(), info) + got := TokenInfoFromContext(ctx) + if got == nil { + t.Fatal("expected token info from context, got nil") + } + if got.Username != "kyle" { + t.Fatalf("username: got %q, want %q", got.Username, "kyle") + } + + // Empty context should return nil. + got = TokenInfoFromContext(context.Background()) + if got != nil { + t.Fatalf("expected nil from empty context, got %+v", got) + } +} diff --git a/internal/config/agent.go b/internal/config/agent.go new file mode 100644 index 0000000..b706d81 --- /dev/null +++ b/internal/config/agent.go @@ -0,0 +1,186 @@ +package config + +import ( + "fmt" + "os" + "time" + + toml "github.com/pelletier/go-toml/v2" +) + +// AgentConfig is the configuration for the mcp-agent daemon. +type AgentConfig struct { + Server ServerConfig `toml:"server"` + Database DatabaseConfig `toml:"database"` + MCIAS MCIASConfig `toml:"mcias"` + Agent AgentSettings `toml:"agent"` + Monitor MonitorConfig `toml:"monitor"` + Log LogConfig `toml:"log"` +} + +// ServerConfig holds gRPC server listen address and TLS paths. +type ServerConfig struct { + GRPCAddr string `toml:"grpc_addr"` + TLSCert string `toml:"tls_cert"` + TLSKey string `toml:"tls_key"` +} + +// DatabaseConfig holds the SQLite database path. +type DatabaseConfig struct { + Path string `toml:"path"` +} + +// AgentSettings holds agent-specific settings. +type AgentSettings struct { + NodeName string `toml:"node_name"` + ContainerRuntime string `toml:"container_runtime"` +} + +// MonitorConfig holds monitoring loop parameters. +type MonitorConfig struct { + Interval Duration `toml:"interval"` + AlertCommand []string `toml:"alert_command"` + Cooldown Duration `toml:"cooldown"` + FlapThreshold int `toml:"flap_threshold"` + FlapWindow Duration `toml:"flap_window"` + Retention Duration `toml:"retention"` +} + +// LogConfig holds logging settings. +type LogConfig struct { + Level string `toml:"level"` +} + +// Duration wraps time.Duration to support TOML string unmarshaling. +// It accepts Go duration strings (e.g. "60s", "15m", "24h") plus a +// "d" suffix for days (e.g. "30d" becomes 30*24h). +type Duration struct { + time.Duration +} + +// UnmarshalText implements encoding.TextUnmarshaler for TOML parsing. +func (d *Duration) UnmarshalText(text []byte) error { + s := string(text) + if s == "" { + d.Duration = 0 + return nil + } + + // Handle "d" suffix for days. + if len(s) > 1 && s[len(s)-1] == 'd' { + var days float64 + if _, err := fmt.Sscanf(s[:len(s)-1], "%f", &days); err != nil { + return fmt.Errorf("parse duration %q: %w", s, err) + } + d.Duration = time.Duration(days * float64(24*time.Hour)) + return nil + } + + dur, err := time.ParseDuration(s) + if err != nil { + return fmt.Errorf("parse duration %q: %w", s, err) + } + d.Duration = dur + return nil +} + +// MarshalText implements encoding.TextMarshaler for TOML serialization. +func (d Duration) MarshalText() ([]byte, error) { + return []byte(d.String()), nil +} + +// LoadAgentConfig reads and validates an agent configuration file. +// Environment variables override file values for select fields. +func LoadAgentConfig(path string) (*AgentConfig, error) { + data, err := os.ReadFile(path) //nolint:gosec // config path from trusted CLI flag + if err != nil { + return nil, fmt.Errorf("read config %q: %w", path, err) + } + + var cfg AgentConfig + if err := toml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config %q: %w", path, err) + } + + applyAgentDefaults(&cfg) + applyAgentEnvOverrides(&cfg) + + if err := validateAgentConfig(&cfg); err != nil { + return nil, fmt.Errorf("validate config: %w", err) + } + + return &cfg, nil +} + +func applyAgentDefaults(cfg *AgentConfig) { + if cfg.Monitor.Interval.Duration == 0 { + cfg.Monitor.Interval.Duration = 60 * time.Second + } + if cfg.Monitor.Cooldown.Duration == 0 { + cfg.Monitor.Cooldown.Duration = 15 * time.Minute + } + if cfg.Monitor.FlapThreshold == 0 { + cfg.Monitor.FlapThreshold = 3 + } + if cfg.Monitor.FlapWindow.Duration == 0 { + cfg.Monitor.FlapWindow.Duration = 10 * time.Minute + } + if cfg.Monitor.Retention.Duration == 0 { + cfg.Monitor.Retention.Duration = 30 * 24 * time.Hour // 30 days + } + if cfg.Log.Level == "" { + cfg.Log.Level = "info" + } + if cfg.Agent.ContainerRuntime == "" { + cfg.Agent.ContainerRuntime = "podman" + } +} + +func applyAgentEnvOverrides(cfg *AgentConfig) { + if v := os.Getenv("MCP_AGENT_SERVER_GRPC_ADDR"); v != "" { + cfg.Server.GRPCAddr = v + } + if v := os.Getenv("MCP_AGENT_SERVER_TLS_CERT"); v != "" { + cfg.Server.TLSCert = v + } + if v := os.Getenv("MCP_AGENT_SERVER_TLS_KEY"); v != "" { + cfg.Server.TLSKey = v + } + if v := os.Getenv("MCP_AGENT_DATABASE_PATH"); v != "" { + cfg.Database.Path = v + } + if v := os.Getenv("MCP_AGENT_NODE_NAME"); v != "" { + cfg.Agent.NodeName = v + } + if v := os.Getenv("MCP_AGENT_CONTAINER_RUNTIME"); v != "" { + cfg.Agent.ContainerRuntime = v + } + if v := os.Getenv("MCP_AGENT_LOG_LEVEL"); v != "" { + cfg.Log.Level = v + } +} + +func validateAgentConfig(cfg *AgentConfig) error { + if cfg.Server.GRPCAddr == "" { + return fmt.Errorf("server.grpc_addr is required") + } + if cfg.Server.TLSCert == "" { + return fmt.Errorf("server.tls_cert is required") + } + if cfg.Server.TLSKey == "" { + return fmt.Errorf("server.tls_key is required") + } + if cfg.Database.Path == "" { + return fmt.Errorf("database.path is required") + } + if cfg.MCIAS.ServerURL == "" { + return fmt.Errorf("mcias.server_url is required") + } + if cfg.MCIAS.ServiceName == "" { + return fmt.Errorf("mcias.service_name is required") + } + if cfg.Agent.NodeName == "" { + return fmt.Errorf("agent.node_name is required") + } + return nil +} diff --git a/internal/config/cli.go b/internal/config/cli.go new file mode 100644 index 0000000..e07aeb4 --- /dev/null +++ b/internal/config/cli.go @@ -0,0 +1,97 @@ +package config + +import ( + "fmt" + "os" + + toml "github.com/pelletier/go-toml/v2" +) + +// CLIConfig is the configuration for the mcp CLI binary. +type CLIConfig struct { + Services ServicesConfig `toml:"services"` + MCIAS MCIASConfig `toml:"mcias"` + Auth AuthConfig `toml:"auth"` + Nodes []NodeConfig `toml:"nodes"` +} + +// ServicesConfig defines where service definition files live. +type ServicesConfig struct { + Dir string `toml:"dir"` +} + +// MCIASConfig holds MCIAS connection settings, shared by CLI and agent. +type MCIASConfig struct { + ServerURL string `toml:"server_url"` + CACert string `toml:"ca_cert"` + ServiceName string `toml:"service_name"` +} + +// AuthConfig holds authentication settings for the CLI. +type AuthConfig struct { + TokenPath string `toml:"token_path"` + Username string `toml:"username"` // optional, for unattended operation + PasswordFile string `toml:"password_file"` // optional, for unattended operation +} + +// NodeConfig defines a managed node that the CLI connects to. +type NodeConfig struct { + Name string `toml:"name"` + Address string `toml:"address"` +} + +// LoadCLIConfig reads and validates a CLI configuration file. +// Environment variables override file values for select fields. +func LoadCLIConfig(path string) (*CLIConfig, error) { + data, err := os.ReadFile(path) //nolint:gosec // config path from trusted CLI flag + if err != nil { + return nil, fmt.Errorf("read config %q: %w", path, err) + } + + var cfg CLIConfig + if err := toml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse config %q: %w", path, err) + } + + applyCLIEnvOverrides(&cfg) + + if err := validateCLIConfig(&cfg); err != nil { + return nil, fmt.Errorf("validate config: %w", err) + } + + return &cfg, nil +} + +func applyCLIEnvOverrides(cfg *CLIConfig) { + if v := os.Getenv("MCP_SERVICES_DIR"); v != "" { + cfg.Services.Dir = v + } + if v := os.Getenv("MCP_MCIAS_SERVER_URL"); v != "" { + cfg.MCIAS.ServerURL = v + } + if v := os.Getenv("MCP_MCIAS_CA_CERT"); v != "" { + cfg.MCIAS.CACert = v + } + if v := os.Getenv("MCP_MCIAS_SERVICE_NAME"); v != "" { + cfg.MCIAS.ServiceName = v + } + if v := os.Getenv("MCP_AUTH_TOKEN_PATH"); v != "" { + cfg.Auth.TokenPath = v + } +} + +func validateCLIConfig(cfg *CLIConfig) error { + if cfg.Services.Dir == "" { + return fmt.Errorf("services.dir is required") + } + if cfg.MCIAS.ServerURL == "" { + return fmt.Errorf("mcias.server_url is required") + } + if cfg.MCIAS.ServiceName == "" { + return fmt.Errorf("mcias.service_name is required") + } + if cfg.Auth.TokenPath == "" { + return fmt.Errorf("auth.token_path is required") + } + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..7fe2704 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,476 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +const testCLIConfig = ` +[services] +dir = "/home/kyle/.config/mcp/services" + +[mcias] +server_url = "https://mcias.metacircular.net:8443" +ca_cert = "/etc/mcp/ca.pem" +service_name = "mcp" + +[auth] +token_path = "/home/kyle/.config/mcp/token" +username = "kyle" +password_file = "/home/kyle/.config/mcp/password" + +[[nodes]] +name = "rift" +address = "100.95.252.120:9444" + +[[nodes]] +name = "cascade" +address = "100.95.252.121:9444" +` + +const testAgentConfig = ` +[server] +grpc_addr = "100.95.252.120:9444" +tls_cert = "/srv/mcp/certs/cert.pem" +tls_key = "/srv/mcp/certs/key.pem" + +[database] +path = "/srv/mcp/mcp.db" + +[mcias] +server_url = "https://mcias.metacircular.net:8443" +ca_cert = "/etc/mcp/ca.pem" +service_name = "mcp-agent" + +[agent] +node_name = "rift" +container_runtime = "podman" + +[monitor] +interval = "60s" +alert_command = ["notify-send", "MCP Alert"] +cooldown = "15m" +flap_threshold = 3 +flap_window = "10m" +retention = "30d" + +[log] +level = "debug" +` + +func writeTempConfig(t *testing.T, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), "config.toml") + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatalf("write temp config: %v", err) + } + return path +} + +func TestLoadCLIConfig(t *testing.T) { + path := writeTempConfig(t, testCLIConfig) + + cfg, err := LoadCLIConfig(path) + if err != nil { + t.Fatalf("load: %v", err) + } + + if cfg.Services.Dir != "/home/kyle/.config/mcp/services" { + t.Fatalf("services.dir: got %q", cfg.Services.Dir) + } + if cfg.MCIAS.ServerURL != "https://mcias.metacircular.net:8443" { + t.Fatalf("mcias.server_url: got %q", cfg.MCIAS.ServerURL) + } + if cfg.MCIAS.CACert != "/etc/mcp/ca.pem" { + t.Fatalf("mcias.ca_cert: got %q", cfg.MCIAS.CACert) + } + if cfg.MCIAS.ServiceName != "mcp" { + t.Fatalf("mcias.service_name: got %q", cfg.MCIAS.ServiceName) + } + if cfg.Auth.TokenPath != "/home/kyle/.config/mcp/token" { + t.Fatalf("auth.token_path: got %q", cfg.Auth.TokenPath) + } + if cfg.Auth.Username != "kyle" { + t.Fatalf("auth.username: got %q", cfg.Auth.Username) + } + if cfg.Auth.PasswordFile != "/home/kyle/.config/mcp/password" { + t.Fatalf("auth.password_file: got %q", cfg.Auth.PasswordFile) + } + if len(cfg.Nodes) != 2 { + t.Fatalf("nodes: got %d, want 2", len(cfg.Nodes)) + } + if cfg.Nodes[0].Name != "rift" || cfg.Nodes[0].Address != "100.95.252.120:9444" { + t.Fatalf("nodes[0]: got %+v", cfg.Nodes[0]) + } + if cfg.Nodes[1].Name != "cascade" || cfg.Nodes[1].Address != "100.95.252.121:9444" { + t.Fatalf("nodes[1]: got %+v", cfg.Nodes[1]) + } +} + +func TestLoadAgentConfig(t *testing.T) { + path := writeTempConfig(t, testAgentConfig) + + cfg, err := LoadAgentConfig(path) + if err != nil { + t.Fatalf("load: %v", err) + } + + if cfg.Server.GRPCAddr != "100.95.252.120:9444" { + t.Fatalf("server.grpc_addr: got %q", cfg.Server.GRPCAddr) + } + if cfg.Server.TLSCert != "/srv/mcp/certs/cert.pem" { + t.Fatalf("server.tls_cert: got %q", cfg.Server.TLSCert) + } + if cfg.Server.TLSKey != "/srv/mcp/certs/key.pem" { + t.Fatalf("server.tls_key: got %q", cfg.Server.TLSKey) + } + if cfg.Database.Path != "/srv/mcp/mcp.db" { + t.Fatalf("database.path: got %q", cfg.Database.Path) + } + if cfg.MCIAS.ServerURL != "https://mcias.metacircular.net:8443" { + t.Fatalf("mcias.server_url: got %q", cfg.MCIAS.ServerURL) + } + if cfg.MCIAS.ServiceName != "mcp-agent" { + t.Fatalf("mcias.service_name: got %q", cfg.MCIAS.ServiceName) + } + if cfg.Agent.NodeName != "rift" { + t.Fatalf("agent.node_name: got %q", cfg.Agent.NodeName) + } + if cfg.Agent.ContainerRuntime != "podman" { + t.Fatalf("agent.container_runtime: got %q", cfg.Agent.ContainerRuntime) + } + if cfg.Monitor.Interval.Duration != 60*time.Second { + t.Fatalf("monitor.interval: got %v", cfg.Monitor.Interval.Duration) + } + if len(cfg.Monitor.AlertCommand) != 2 || cfg.Monitor.AlertCommand[0] != "notify-send" { + t.Fatalf("monitor.alert_command: got %v", cfg.Monitor.AlertCommand) + } + if cfg.Monitor.Cooldown.Duration != 15*time.Minute { + t.Fatalf("monitor.cooldown: got %v", cfg.Monitor.Cooldown.Duration) + } + if cfg.Monitor.FlapThreshold != 3 { + t.Fatalf("monitor.flap_threshold: got %d", cfg.Monitor.FlapThreshold) + } + if cfg.Monitor.FlapWindow.Duration != 10*time.Minute { + t.Fatalf("monitor.flap_window: got %v", cfg.Monitor.FlapWindow.Duration) + } + if cfg.Monitor.Retention.Duration != 30*24*time.Hour { + t.Fatalf("monitor.retention: got %v", cfg.Monitor.Retention.Duration) + } + if cfg.Log.Level != "debug" { + t.Fatalf("log.level: got %q", cfg.Log.Level) + } +} + +func TestCLIConfigValidation(t *testing.T) { + tests := []struct { + name string + config string + errMsg string + }{ + { + name: "missing services.dir", + config: ` +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp" +[auth] +token_path = "/tmp/token" +`, + errMsg: "services.dir is required", + }, + { + name: "missing mcias.server_url", + config: ` +[services] +dir = "/tmp/services" +[mcias] +service_name = "mcp" +[auth] +token_path = "/tmp/token" +`, + errMsg: "mcias.server_url is required", + }, + { + name: "missing mcias.service_name", + config: ` +[services] +dir = "/tmp/services" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +[auth] +token_path = "/tmp/token" +`, + errMsg: "mcias.service_name is required", + }, + { + name: "missing auth.token_path", + config: ` +[services] +dir = "/tmp/services" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp" +`, + errMsg: "auth.token_path is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := writeTempConfig(t, tt.config) + _, err := LoadCLIConfig(path) + if err == nil { + t.Fatal("expected error, got nil") + } + if got := err.Error(); !strings.Contains(got, tt.errMsg) { + t.Fatalf("error %q does not contain %q", got, tt.errMsg) + } + }) + } +} + +func TestAgentConfigValidation(t *testing.T) { + // Minimal valid agent config to start from. + base := ` +[server] +grpc_addr = "0.0.0.0:9444" +tls_cert = "/srv/mcp/cert.pem" +tls_key = "/srv/mcp/key.pem" +[database] +path = "/srv/mcp/mcp.db" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp-agent" +[agent] +node_name = "rift" +` + + tests := []struct { + name string + config string + errMsg string + }{ + { + name: "missing server.grpc_addr", + config: ` +[server] +tls_cert = "/srv/mcp/cert.pem" +tls_key = "/srv/mcp/key.pem" +[database] +path = "/srv/mcp/mcp.db" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp-agent" +[agent] +node_name = "rift" +`, + errMsg: "server.grpc_addr is required", + }, + { + name: "missing server.tls_cert", + config: ` +[server] +grpc_addr = "0.0.0.0:9444" +tls_key = "/srv/mcp/key.pem" +[database] +path = "/srv/mcp/mcp.db" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp-agent" +[agent] +node_name = "rift" +`, + errMsg: "server.tls_cert is required", + }, + { + name: "missing database.path", + config: ` +[server] +grpc_addr = "0.0.0.0:9444" +tls_cert = "/srv/mcp/cert.pem" +tls_key = "/srv/mcp/key.pem" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp-agent" +[agent] +node_name = "rift" +`, + errMsg: "database.path is required", + }, + { + name: "missing agent.node_name", + config: ` +[server] +grpc_addr = "0.0.0.0:9444" +tls_cert = "/srv/mcp/cert.pem" +tls_key = "/srv/mcp/key.pem" +[database] +path = "/srv/mcp/mcp.db" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp-agent" +`, + errMsg: "agent.node_name is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := writeTempConfig(t, tt.config) + _, err := LoadAgentConfig(path) + if err == nil { + t.Fatal("expected error, got nil") + } + if got := err.Error(); !strings.Contains(got, tt.errMsg) { + t.Fatalf("error %q does not contain %q", got, tt.errMsg) + } + }) + } + + // Verify base config actually loads fine. + t.Run("valid base config", func(t *testing.T) { + path := writeTempConfig(t, base) + if _, err := LoadAgentConfig(path); err != nil { + t.Fatalf("base config should be valid: %v", err) + } + }) +} + +func TestAgentConfigDefaults(t *testing.T) { + // Only required fields, no monitor/log/runtime settings. + minimal := ` +[server] +grpc_addr = "0.0.0.0:9444" +tls_cert = "/srv/mcp/cert.pem" +tls_key = "/srv/mcp/key.pem" +[database] +path = "/srv/mcp/mcp.db" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp-agent" +[agent] +node_name = "rift" +` + + path := writeTempConfig(t, minimal) + cfg, err := LoadAgentConfig(path) + if err != nil { + t.Fatalf("load: %v", err) + } + + if cfg.Monitor.Interval.Duration != 60*time.Second { + t.Fatalf("default interval: got %v, want 60s", cfg.Monitor.Interval.Duration) + } + if cfg.Monitor.Cooldown.Duration != 15*time.Minute { + t.Fatalf("default cooldown: got %v, want 15m", cfg.Monitor.Cooldown.Duration) + } + if cfg.Monitor.FlapThreshold != 3 { + t.Fatalf("default flap_threshold: got %d, want 3", cfg.Monitor.FlapThreshold) + } + if cfg.Monitor.FlapWindow.Duration != 10*time.Minute { + t.Fatalf("default flap_window: got %v, want 10m", cfg.Monitor.FlapWindow.Duration) + } + if cfg.Monitor.Retention.Duration != 30*24*time.Hour { + t.Fatalf("default retention: got %v, want 720h", cfg.Monitor.Retention.Duration) + } + if cfg.Log.Level != "info" { + t.Fatalf("default log level: got %q, want info", cfg.Log.Level) + } + if cfg.Agent.ContainerRuntime != "podman" { + t.Fatalf("default container_runtime: got %q, want podman", cfg.Agent.ContainerRuntime) + } +} + +func TestEnvVarOverrides(t *testing.T) { + // Test agent env var override. + minimal := ` +[server] +grpc_addr = "0.0.0.0:9444" +tls_cert = "/srv/mcp/cert.pem" +tls_key = "/srv/mcp/key.pem" +[database] +path = "/srv/mcp/mcp.db" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp-agent" +[agent] +node_name = "rift" +[log] +level = "info" +` + t.Run("agent log level override", func(t *testing.T) { + t.Setenv("MCP_AGENT_LOG_LEVEL", "debug") + path := writeTempConfig(t, minimal) + cfg, err := LoadAgentConfig(path) + if err != nil { + t.Fatalf("load: %v", err) + } + if cfg.Log.Level != "debug" { + t.Fatalf("log level: got %q, want debug", cfg.Log.Level) + } + }) + + t.Run("agent node name override", func(t *testing.T) { + t.Setenv("MCP_AGENT_NODE_NAME", "override-node") + path := writeTempConfig(t, minimal) + cfg, err := LoadAgentConfig(path) + if err != nil { + t.Fatalf("load: %v", err) + } + if cfg.Agent.NodeName != "override-node" { + t.Fatalf("node_name: got %q, want override-node", cfg.Agent.NodeName) + } + }) + + t.Run("CLI services dir override", func(t *testing.T) { + t.Setenv("MCP_SERVICES_DIR", "/override/services") + path := writeTempConfig(t, testCLIConfig) + cfg, err := LoadCLIConfig(path) + if err != nil { + t.Fatalf("load: %v", err) + } + if cfg.Services.Dir != "/override/services" { + t.Fatalf("services.dir: got %q, want /override/services", cfg.Services.Dir) + } + }) +} + +func TestDurationParsing(t *testing.T) { + tests := []struct { + input string + want time.Duration + fail bool + }{ + {input: "60s", want: 60 * time.Second}, + {input: "15m", want: 15 * time.Minute}, + {input: "24h", want: 24 * time.Hour}, + {input: "30d", want: 30 * 24 * time.Hour}, + {input: "1d", want: 24 * time.Hour}, + {input: "", want: 0}, + {input: "bogus", fail: true}, + {input: "notanumber-d", fail: true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + var d Duration + err := d.UnmarshalText([]byte(tt.input)) + if tt.fail { + if err == nil { + t.Fatalf("expected error for %q, got nil", tt.input) + } + return + } + if err != nil { + t.Fatalf("unexpected error for %q: %v", tt.input, err) + } + if d.Duration != tt.want { + t.Fatalf("for %q: got %v, want %v", tt.input, d.Duration, tt.want) + } + }) + } +} diff --git a/internal/runtime/podman.go b/internal/runtime/podman.go new file mode 100644 index 0000000..defc6d7 --- /dev/null +++ b/internal/runtime/podman.go @@ -0,0 +1,208 @@ +package runtime + +import ( + "context" + "encoding/json" + "fmt" + "os/exec" + "strings" +) + +// Podman implements the Runtime interface using the podman CLI. +type Podman struct { + Command string // path to podman binary, default "podman" +} + +func (p *Podman) command() string { + if p.Command != "" { + return p.Command + } + return "podman" +} + +// Pull pulls a container image. +func (p *Podman) Pull(ctx context.Context, image string) error { + cmd := exec.CommandContext(ctx, p.command(), "pull", image) //nolint:gosec // args built programmatically + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("podman pull %q: %w: %s", image, err, out) + } + return nil +} + +// BuildRunArgs constructs the argument list for podman run. Exported for testing. +func (p *Podman) BuildRunArgs(spec ContainerSpec) []string { + args := []string{"run", "-d", "--name", spec.Name} + + if spec.Network != "" { + args = append(args, "--network", spec.Network) + } + if spec.User != "" { + args = append(args, "--user", spec.User) + } + if spec.Restart != "" { + args = append(args, "--restart", spec.Restart) + } + for _, port := range spec.Ports { + args = append(args, "-p", port) + } + for _, vol := range spec.Volumes { + args = append(args, "-v", vol) + } + + args = append(args, spec.Image) + args = append(args, spec.Cmd...) + + return args +} + +// Run creates and starts a container from the given spec. +func (p *Podman) Run(ctx context.Context, spec ContainerSpec) error { + args := p.BuildRunArgs(spec) + cmd := exec.CommandContext(ctx, p.command(), args...) //nolint:gosec // args built programmatically + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("podman run %q: %w: %s", spec.Name, err, out) + } + return nil +} + +// Stop stops a running container. +func (p *Podman) Stop(ctx context.Context, name string) error { + cmd := exec.CommandContext(ctx, p.command(), "stop", name) //nolint:gosec // args built programmatically + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("podman stop %q: %w: %s", name, err, out) + } + return nil +} + +// Remove removes a container. +func (p *Podman) Remove(ctx context.Context, name string) error { + cmd := exec.CommandContext(ctx, p.command(), "rm", name) //nolint:gosec // args built programmatically + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("podman rm %q: %w: %s", name, err, out) + } + return nil +} + +// podmanPortBinding is a single port binding from the inspect output. +type podmanPortBinding struct { + HostIP string `json:"HostIp"` + HostPort string `json:"HostPort"` +} + +// podmanInspectResult is the subset of podman inspect JSON we parse. +type podmanInspectResult struct { + Name string `json:"Name"` + Config struct { + Image string `json:"Image"` + Cmd []string `json:"Cmd"` + User string `json:"User"` + } `json:"Config"` + State struct { + Status string `json:"Status"` + } `json:"State"` + HostConfig struct { + RestartPolicy struct { + Name string `json:"Name"` + } `json:"RestartPolicy"` + NetworkMode string `json:"NetworkMode"` + } `json:"HostConfig"` + NetworkSettings struct { + Networks map[string]struct{} `json:"Networks"` + Ports map[string][]podmanPortBinding `json:"Ports"` + } `json:"NetworkSettings"` + Mounts []struct { + Source string `json:"Source"` + Destination string `json:"Destination"` + } `json:"Mounts"` +} + +// Inspect retrieves information about a container. +func (p *Podman) Inspect(ctx context.Context, name string) (ContainerInfo, error) { + cmd := exec.CommandContext(ctx, p.command(), "inspect", name) //nolint:gosec // args built programmatically + out, err := cmd.Output() + if err != nil { + return ContainerInfo{}, fmt.Errorf("podman inspect %q: %w", name, err) + } + + var results []podmanInspectResult + if err := json.Unmarshal(out, &results); err != nil { + return ContainerInfo{}, fmt.Errorf("parse inspect output: %w", err) + } + if len(results) == 0 { + return ContainerInfo{}, fmt.Errorf("podman inspect %q: no results", name) + } + + r := results[0] + info := ContainerInfo{ + Name: strings.TrimPrefix(r.Name, "/"), + Image: r.Config.Image, + State: r.State.Status, + User: r.Config.User, + Restart: r.HostConfig.RestartPolicy.Name, + Cmd: r.Config.Cmd, + Version: ExtractVersion(r.Config.Image), + } + + info.Network = r.HostConfig.NetworkMode + if len(r.NetworkSettings.Networks) > 0 { + for netName := range r.NetworkSettings.Networks { + info.Network = netName + break + } + } + + for containerPort, bindings := range r.NetworkSettings.Ports { + for _, b := range bindings { + port := strings.SplitN(containerPort, "/", 2)[0] // strip "/tcp" suffix + mapping := b.HostPort + ":" + port + if b.HostIP != "" && b.HostIP != "0.0.0.0" { + mapping = b.HostIP + ":" + mapping + } + info.Ports = append(info.Ports, mapping) + } + } + + for _, m := range r.Mounts { + info.Volumes = append(info.Volumes, m.Source+":"+m.Destination) + } + + return info, nil +} + +// podmanPSEntry is a single entry from podman ps --format json. +type podmanPSEntry struct { + Names []string `json:"Names"` + Image string `json:"Image"` + State string `json:"State"` + Command string `json:"Command"` +} + +// List returns information about all containers. +func (p *Podman) List(ctx context.Context) ([]ContainerInfo, error) { + cmd := exec.CommandContext(ctx, p.command(), "ps", "-a", "--format", "json") //nolint:gosec // args built programmatically + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("podman ps: %w", err) + } + + var entries []podmanPSEntry + if err := json.Unmarshal(out, &entries); err != nil { + return nil, fmt.Errorf("parse ps output: %w", err) + } + + infos := make([]ContainerInfo, 0, len(entries)) + for _, e := range entries { + name := "" + if len(e.Names) > 0 { + name = e.Names[0] + } + infos = append(infos, ContainerInfo{ + Name: name, + Image: e.Image, + State: e.State, + Version: ExtractVersion(e.Image), + }) + } + + return infos, nil +} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go new file mode 100644 index 0000000..42b639a --- /dev/null +++ b/internal/runtime/runtime.go @@ -0,0 +1,62 @@ +package runtime + +import ( + "context" + "strings" +) + +// ContainerSpec describes a container to create and run. +type ContainerSpec struct { + Name string // container name, format: - + Image string // full image reference + Network string // docker network name + User string // container user (e.g., "0:0") + Restart string // restart policy (e.g., "unless-stopped") + Ports []string // "host:container" port mappings + Volumes []string // "host:container" volume mounts + Cmd []string // command and arguments +} + +// ContainerInfo describes the observed state of a running or stopped container. +type ContainerInfo struct { + Name string + Image string + State string // "running", "stopped", "exited", etc. + Network string + User string + Restart string + Ports []string + Volumes []string + Cmd []string + Version string // extracted from image tag +} + +// Runtime is the container runtime abstraction. +type Runtime interface { + Pull(ctx context.Context, image string) error + Run(ctx context.Context, spec ContainerSpec) error + Stop(ctx context.Context, name string) error + Remove(ctx context.Context, name string) error + Inspect(ctx context.Context, name string) (ContainerInfo, error) + List(ctx context.Context) ([]ContainerInfo, error) +} + +// ExtractVersion parses the tag from an image reference. +// Examples: +// +// "registry/img:v1.2.0" -> "v1.2.0" +// "registry/img:latest" -> "latest" +// "registry/img" -> "" +// "registry:5000/img:v1" -> "v1" +func ExtractVersion(image string) string { + // Strip registry/path prefix so that a port like "registry:5000" isn't + // mistaken for a tag separator. + name := image + if i := strings.LastIndex(image, "/"); i >= 0 { + name = image[i+1:] + } + if i := strings.LastIndex(name, ":"); i >= 0 { + return name[i+1:] + } + return "" +} diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go new file mode 100644 index 0000000..5ebbd0b --- /dev/null +++ b/internal/runtime/runtime_test.go @@ -0,0 +1,113 @@ +package runtime + +import ( + "slices" + "testing" +) + +func requireEqualArgs(t *testing.T, got, want []string) { + t.Helper() + if !slices.Equal(got, want) { + t.Fatalf("args mismatch\ngot: %v\nwant: %v", got, want) + } +} + +func TestBuildRunArgs(t *testing.T) { + p := &Podman{} + + t.Run("full spec", func(t *testing.T) { + spec := ContainerSpec{ + Name: "metacrypt-api", + Image: "mcr.svc.mcp.metacircular.net:8443/metacrypt:v1.0.0", + Network: "docker_default", + User: "0:0", + Restart: "unless-stopped", + Ports: []string{"127.0.0.1:18443:8443", "127.0.0.1:19443:9443"}, + Volumes: []string{"/srv/metacrypt:/srv/metacrypt", "/etc/ssl:/etc/ssl:ro"}, + Cmd: []string{"server", "--config", "/srv/metacrypt/metacrypt.toml"}, + } + requireEqualArgs(t, p.BuildRunArgs(spec), []string{ + "run", "-d", "--name", "metacrypt-api", + "--network", "docker_default", + "--user", "0:0", + "--restart", "unless-stopped", + "-p", "127.0.0.1:18443:8443", + "-p", "127.0.0.1:19443:9443", + "-v", "/srv/metacrypt:/srv/metacrypt", + "-v", "/etc/ssl:/etc/ssl:ro", + "mcr.svc.mcp.metacircular.net:8443/metacrypt:v1.0.0", + "server", "--config", "/srv/metacrypt/metacrypt.toml", + }) + }) + + t.Run("minimal spec", func(t *testing.T) { + spec := ContainerSpec{ + Name: "test-app", + Image: "img:latest", + } + requireEqualArgs(t, p.BuildRunArgs(spec), []string{ + "run", "-d", "--name", "test-app", "img:latest", + }) + }) + + t.Run("ports only", func(t *testing.T) { + spec := ContainerSpec{ + Name: "test-app", + Image: "img:latest", + Ports: []string{"8080:80", "8443:443"}, + } + requireEqualArgs(t, p.BuildRunArgs(spec), []string{ + "run", "-d", "--name", "test-app", + "-p", "8080:80", "-p", "8443:443", + "img:latest", + }) + }) + + t.Run("volumes only", func(t *testing.T) { + spec := ContainerSpec{ + Name: "test-app", + Image: "img:latest", + Volumes: []string{"/data:/data", "/config:/config:ro"}, + } + requireEqualArgs(t, p.BuildRunArgs(spec), []string{ + "run", "-d", "--name", "test-app", + "-v", "/data:/data", "-v", "/config:/config:ro", + "img:latest", + }) + }) + + t.Run("cmd after image", func(t *testing.T) { + spec := ContainerSpec{ + Name: "test-app", + Image: "img:latest", + Cmd: []string{"serve", "--port", "8080"}, + } + requireEqualArgs(t, p.BuildRunArgs(spec), []string{ + "run", "-d", "--name", "test-app", + "img:latest", + "serve", "--port", "8080", + }) + }) +} + +func TestExtractVersion(t *testing.T) { + tests := []struct { + image string + want string + }{ + {"registry.example.com:5000/img:v1.2.0", "v1.2.0"}, + {"img:latest", "latest"}, + {"img", ""}, + {"registry.example.com/path/img:v1", "v1"}, + {"registry.example.com:5000/path/img", ""}, + } + + for _, tt := range tests { + t.Run(tt.image, func(t *testing.T) { + got := ExtractVersion(tt.image) + if got != tt.want { + t.Fatalf("ExtractVersion(%q) = %q, want %q", tt.image, got, tt.want) + } + }) + } +} diff --git a/internal/servicedef/servicedef.go b/internal/servicedef/servicedef.go new file mode 100644 index 0000000..11b12d7 --- /dev/null +++ b/internal/servicedef/servicedef.go @@ -0,0 +1,185 @@ +// Package servicedef handles parsing and writing TOML service definition files. +package servicedef + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + toml "github.com/pelletier/go-toml/v2" + + mcpv1 "git.wntrmute.dev/kyle/mcp/gen/mcp/v1" +) + +// ServiceDef is the top-level TOML structure for a service definition file. +type ServiceDef struct { + Name string `toml:"name"` + Node string `toml:"node"` + Active *bool `toml:"active,omitempty"` + Components []ComponentDef `toml:"components"` +} + +// ComponentDef describes a single container component within a service. +type ComponentDef struct { + Name string `toml:"name"` + Image string `toml:"image"` + Network string `toml:"network,omitempty"` + User string `toml:"user,omitempty"` + Restart string `toml:"restart,omitempty"` + Ports []string `toml:"ports,omitempty"` + Volumes []string `toml:"volumes,omitempty"` + Cmd []string `toml:"cmd,omitempty"` +} + +// Load reads and parses a TOML service definition file. If the active field +// is omitted, it defaults to true. +func Load(path string) (*ServiceDef, error) { + data, err := os.ReadFile(path) //nolint:gosec // path from trusted config dir + if err != nil { + return nil, fmt.Errorf("read service def %q: %w", path, err) + } + + var def ServiceDef + if err := toml.Unmarshal(data, &def); err != nil { + return nil, fmt.Errorf("parse service def %q: %w", path, err) + } + + if err := validate(&def); err != nil { + return nil, fmt.Errorf("validate service def %q: %w", path, err) + } + + if def.Active == nil { + t := true + def.Active = &t + } + + return &def, nil +} + +// Write serializes a ServiceDef to TOML and writes it to the given path +// with 0644 permissions. Parent directories are created if needed. +func Write(path string, def *ServiceDef) error { + if err := os.MkdirAll(filepath.Dir(path), 0o750); err != nil { //nolint:gosec // service defs are non-secret + return fmt.Errorf("create parent dirs for %q: %w", path, err) + } + + data, err := toml.Marshal(def) + if err != nil { + return fmt.Errorf("marshal service def: %w", err) + } + + if err := os.WriteFile(path, data, 0o644); err != nil { //nolint:gosec // service defs are non-secret + return fmt.Errorf("write service def %q: %w", path, err) + } + + return nil +} + +// LoadAll reads all .toml files from dir, parses each one, and returns the +// list sorted by service name. +func LoadAll(dir string) ([]*ServiceDef, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("read service dir %q: %w", dir, err) + } + + var defs []*ServiceDef + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".toml") { + continue + } + def, err := Load(filepath.Join(dir, e.Name())) + if err != nil { + return nil, err + } + defs = append(defs, def) + } + + sort.Slice(defs, func(i, j int) bool { + return defs[i].Name < defs[j].Name + }) + + return defs, nil +} + +// validate checks that a ServiceDef has all required fields and that +// component names are unique. +func validate(def *ServiceDef) error { + if def.Name == "" { + return fmt.Errorf("service name is required") + } + if def.Node == "" { + return fmt.Errorf("service node is required") + } + if len(def.Components) == 0 { + return fmt.Errorf("service %q must have at least one component", def.Name) + } + + seen := make(map[string]bool) + for _, c := range def.Components { + if c.Name == "" { + return fmt.Errorf("component name is required in service %q", def.Name) + } + if c.Image == "" { + return fmt.Errorf("component %q image is required in service %q", c.Name, def.Name) + } + if seen[c.Name] { + return fmt.Errorf("duplicate component name %q in service %q", c.Name, def.Name) + } + seen[c.Name] = true + } + + return nil +} + +// ToProto converts a ServiceDef to a proto ServiceSpec. +func ToProto(def *ServiceDef) *mcpv1.ServiceSpec { + spec := &mcpv1.ServiceSpec{ + Name: def.Name, + Active: def.Active != nil && *def.Active, + } + + for _, c := range def.Components { + spec.Components = append(spec.Components, &mcpv1.ComponentSpec{ + Name: c.Name, + Image: c.Image, + Network: c.Network, + User: c.User, + Restart: c.Restart, + Ports: c.Ports, + Volumes: c.Volumes, + Cmd: c.Cmd, + }) + } + + return spec +} + +// FromProto converts a proto ServiceSpec back to a ServiceDef. The node +// parameter is required because ServiceSpec does not include the node field +// (it is a CLI-side routing concern). +func FromProto(spec *mcpv1.ServiceSpec, node string) *ServiceDef { + active := spec.GetActive() + def := &ServiceDef{ + Name: spec.GetName(), + Node: node, + Active: &active, + } + + for _, c := range spec.GetComponents() { + def.Components = append(def.Components, ComponentDef{ + Name: c.GetName(), + Image: c.GetImage(), + Network: c.GetNetwork(), + User: c.GetUser(), + Restart: c.GetRestart(), + Ports: c.GetPorts(), + Volumes: c.GetVolumes(), + Cmd: c.GetCmd(), + }) + } + + return def +} diff --git a/internal/servicedef/servicedef_test.go b/internal/servicedef/servicedef_test.go new file mode 100644 index 0000000..11f52f7 --- /dev/null +++ b/internal/servicedef/servicedef_test.go @@ -0,0 +1,289 @@ +package servicedef + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func boolPtr(b bool) *bool { return &b } + +func sampleDef() *ServiceDef { + return &ServiceDef{ + Name: "metacrypt", + Node: "rift", + Active: boolPtr(true), + Components: []ComponentDef{ + { + Name: "api", + Image: "mcr.svc.mcp.metacircular.net:8443/metacrypt:latest", + Network: "docker_default", + User: "0:0", + Restart: "unless-stopped", + Ports: []string{"127.0.0.1:18443:8443", "127.0.0.1:19443:9443"}, + Volumes: []string{"/srv/metacrypt:/srv/metacrypt"}, + }, + { + Name: "web", + Image: "mcr.svc.mcp.metacircular.net:8443/metacrypt-web:latest", + Network: "docker_default", + User: "0:0", + Restart: "unless-stopped", + Ports: []string{"127.0.0.1:18080:8080"}, + Volumes: []string{"/srv/metacrypt:/srv/metacrypt"}, + Cmd: []string{"server", "--config", "/srv/metacrypt/metacrypt.toml"}, + }, + }, + } +} + +// compareComponents asserts that two component slices are equal field by field. +func compareComponents(t *testing.T, prefix string, got, want []ComponentDef) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("%s components: got %d, want %d", prefix, len(got), len(want)) + } + for i, wantC := range want { + gotC := got[i] + if gotC.Name != wantC.Name { + t.Fatalf("%s component[%d] name: got %q, want %q", prefix, i, gotC.Name, wantC.Name) + } + if gotC.Image != wantC.Image { + t.Fatalf("%s component[%d] image: got %q, want %q", prefix, i, gotC.Image, wantC.Image) + } + if gotC.Network != wantC.Network { + t.Fatalf("%s component[%d] network: got %q, want %q", prefix, i, gotC.Network, wantC.Network) + } + if gotC.User != wantC.User { + t.Fatalf("%s component[%d] user: got %q, want %q", prefix, i, gotC.User, wantC.User) + } + if gotC.Restart != wantC.Restart { + t.Fatalf("%s component[%d] restart: got %q, want %q", prefix, i, gotC.Restart, wantC.Restart) + } + compareStrSlice(t, prefix, i, "ports", gotC.Ports, wantC.Ports) + compareStrSlice(t, prefix, i, "volumes", gotC.Volumes, wantC.Volumes) + compareStrSlice(t, prefix, i, "cmd", gotC.Cmd, wantC.Cmd) + } +} + +func compareStrSlice(t *testing.T, prefix string, idx int, field string, got, want []string) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("%s component[%d] %s: got %d, want %d", prefix, idx, field, len(got), len(want)) + } + for j := range want { + if got[j] != want[j] { + t.Fatalf("%s component[%d] %s[%d]: got %q, want %q", prefix, idx, field, j, got[j], want[j]) + } + } +} + +func TestLoadWrite(t *testing.T) { + def := sampleDef() + dir := t.TempDir() + path := filepath.Join(dir, "metacrypt.toml") + + if err := Write(path, def); err != nil { + t.Fatalf("write: %v", err) + } + + got, err := Load(path) + if err != nil { + t.Fatalf("load: %v", err) + } + + if got.Name != def.Name { + t.Fatalf("name: got %q, want %q", got.Name, def.Name) + } + if got.Node != def.Node { + t.Fatalf("node: got %q, want %q", got.Node, def.Node) + } + if *got.Active != *def.Active { + t.Fatalf("active: got %v, want %v", *got.Active, *def.Active) + } + compareComponents(t, "load-write", got.Components, def.Components) +} + +func TestValidation(t *testing.T) { + tests := []struct { + name string + def *ServiceDef + wantErr string + }{ + { + name: "missing name", + def: &ServiceDef{ + Node: "rift", + Components: []ComponentDef{{Name: "api", Image: "img:v1"}}, + }, + wantErr: "service name is required", + }, + { + name: "missing node", + def: &ServiceDef{ + Name: "svc", + Components: []ComponentDef{{Name: "api", Image: "img:v1"}}, + }, + wantErr: "service node is required", + }, + { + name: "empty components", + def: &ServiceDef{ + Name: "svc", + Node: "rift", + }, + wantErr: "must have at least one component", + }, + { + name: "duplicate component names", + def: &ServiceDef{ + Name: "svc", + Node: "rift", + Components: []ComponentDef{ + {Name: "api", Image: "img:v1"}, + {Name: "api", Image: "img:v2"}, + }, + }, + wantErr: "duplicate component name", + }, + { + name: "component missing name", + def: &ServiceDef{ + Name: "svc", + Node: "rift", + Components: []ComponentDef{{Image: "img:v1"}}, + }, + wantErr: "component name is required", + }, + { + name: "component missing image", + def: &ServiceDef{ + Name: "svc", + Node: "rift", + Components: []ComponentDef{{Name: "api"}}, + }, + wantErr: "image is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validate(tt.def) + if err == nil { + t.Fatal("expected validation error") + } + if got := err.Error(); !strings.Contains(got, tt.wantErr) { + t.Fatalf("error %q does not contain %q", got, tt.wantErr) + } + }) + } +} + +func TestLoadAll(t *testing.T) { + dir := t.TempDir() + + // Write services in non-alphabetical order. + defs := []*ServiceDef{ + { + Name: "mcr", + Node: "rift", + Active: boolPtr(true), + Components: []ComponentDef{{Name: "api", Image: "mcr:latest"}}, + }, + { + Name: "metacrypt", + Node: "rift", + Active: boolPtr(true), + Components: []ComponentDef{{Name: "api", Image: "metacrypt:latest"}}, + }, + { + Name: "mcias", + Node: "rift", + Active: boolPtr(false), + Components: []ComponentDef{{Name: "api", Image: "mcias:latest"}}, + }, + } + + for _, d := range defs { + if err := Write(filepath.Join(dir, d.Name+".toml"), d); err != nil { + t.Fatalf("write %s: %v", d.Name, err) + } + } + + // Write a non-TOML file that should be ignored. + if err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("# ignore"), 0o600); err != nil { + t.Fatalf("write readme: %v", err) + } + + got, err := LoadAll(dir) + if err != nil { + t.Fatalf("load all: %v", err) + } + + if len(got) != 3 { + t.Fatalf("count: got %d, want 3", len(got)) + } + + wantOrder := []string{"mcias", "mcr", "metacrypt"} + for i, name := range wantOrder { + if got[i].Name != name { + t.Fatalf("order[%d]: got %q, want %q", i, got[i].Name, name) + } + } +} + +func TestActiveDefault(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "svc.toml") + + content := `name = "svc" +node = "rift" + +[[components]] +name = "api" +image = "img:latest" +` + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatalf("write: %v", err) + } + + def, err := Load(path) + if err != nil { + t.Fatalf("load: %v", err) + } + + if def.Active == nil { + t.Fatal("active should not be nil") + } + if !*def.Active { + t.Fatal("active should default to true") + } +} + +func TestProtoConversion(t *testing.T) { + def := sampleDef() + + spec := ToProto(def) + if spec.Name != def.Name { + t.Fatalf("proto name: got %q, want %q", spec.Name, def.Name) + } + if !spec.Active { + t.Fatal("proto active should be true") + } + if len(spec.Components) != len(def.Components) { + t.Fatalf("proto components: got %d, want %d", len(spec.Components), len(def.Components)) + } + + got := FromProto(spec, def.Node) + if got.Name != def.Name { + t.Fatalf("round-trip name: got %q, want %q", got.Name, def.Name) + } + if got.Node != def.Node { + t.Fatalf("round-trip node: got %q, want %q", got.Node, def.Node) + } + if *got.Active != *def.Active { + t.Fatalf("round-trip active: got %v, want %v", *got.Active, *def.Active) + } + compareComponents(t, "round-trip", got.Components, def.Components) +}