P1.2-P1.5: Complete Phase 1 core libraries
Four packages built in parallel: - P1.2 runtime: Container runtime abstraction with podman implementation. Interface (Pull/Run/Stop/Remove/Inspect/List), ContainerSpec/ContainerInfo types, CLI arg building, version extraction from image tags. 2 tests. - P1.3 servicedef: TOML service definition file parsing. Load/Write/LoadAll, validation (required fields, unique component names), proto conversion. 5 tests. - P1.4 config: CLI and agent config loading from TOML. Duration type for time fields, env var overrides (MCP_*/MCP_AGENT_*), required field validation, sensible defaults. 7 tests. - P1.5 auth: MCIAS integration. Token validator with 30s SHA-256 cache, gRPC unary interceptor (admin role enforcement, audit logging), Login/LoadToken/SaveToken for CLI. 9 tests. All packages pass build, vet, lint, and test. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
332
internal/auth/auth.go
Normal file
332
internal/auth/auth.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// cacheTTL is the duration cached token validation results are valid.
|
||||
const cacheTTL = 30 * time.Second
|
||||
|
||||
// TokenInfo holds the result of a token validation.
|
||||
type TokenInfo struct {
|
||||
Valid bool `json:"valid"`
|
||||
Username string `json:"username"`
|
||||
Roles []string `json:"roles"`
|
||||
AccountType string `json:"account_type"`
|
||||
}
|
||||
|
||||
// HasRole reports whether the token has the given role.
|
||||
func (t *TokenInfo) HasRole(role string) bool {
|
||||
for _, r := range t.Roles {
|
||||
if r == role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TokenValidator validates bearer tokens against MCIAS.
|
||||
type TokenValidator interface {
|
||||
ValidateToken(ctx context.Context, token string) (*TokenInfo, error)
|
||||
}
|
||||
|
||||
// tokenCache stores validated token results with a TTL.
|
||||
type tokenCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]cacheEntry
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
info *TokenInfo
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newTokenCache() *tokenCache {
|
||||
return &tokenCache{
|
||||
entries: make(map[string]cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *tokenCache) get(hash string) (*TokenInfo, bool) {
|
||||
c.mu.RLock()
|
||||
entry, ok := c.entries[hash]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
c.mu.Lock()
|
||||
delete(c.entries, hash)
|
||||
c.mu.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
return entry.info, true
|
||||
}
|
||||
|
||||
func (c *tokenCache) put(hash string, info *TokenInfo) {
|
||||
c.mu.Lock()
|
||||
c.entries[hash] = cacheEntry{
|
||||
info: info,
|
||||
expiresAt: time.Now().Add(cacheTTL),
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// newHTTPClient creates an HTTP client with TLS 1.3 minimum. If caCertPath
|
||||
// is non-empty, the CA certificate is loaded and added to the root CA pool.
|
||||
func newHTTPClient(caCertPath string) (*http.Client, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
if caCertPath != "" {
|
||||
caCert, err := os.ReadFile(caCertPath) //nolint:gosec // path from trusted config
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read CA cert %q: %w", caCertPath, err)
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("parse CA cert %q: no valid certificates found", caCertPath)
|
||||
}
|
||||
tlsConfig.RootCAs = pool
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MCIASValidator validates tokens by calling the MCIAS HTTP endpoint.
|
||||
type MCIASValidator struct {
|
||||
ServerURL string
|
||||
CACertPath string
|
||||
httpClient *http.Client
|
||||
cache *tokenCache
|
||||
}
|
||||
|
||||
// NewMCIASValidator creates a validator that calls MCIAS at the given URL.
|
||||
// If caCertPath is non-empty, the CA certificate is loaded and used for TLS.
|
||||
func NewMCIASValidator(serverURL, caCertPath string) (*MCIASValidator, error) {
|
||||
client, err := newHTTPClient(caCertPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MCIASValidator{
|
||||
ServerURL: strings.TrimRight(serverURL, "/"),
|
||||
CACertPath: caCertPath,
|
||||
httpClient: client,
|
||||
cache: newTokenCache(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func tokenHash(token string) string {
|
||||
h := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// ValidateToken validates a bearer token against MCIAS. Results are cached
|
||||
// for 30 seconds keyed by the SHA-256 hash of the token.
|
||||
func (v *MCIASValidator) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) {
|
||||
hash := tokenHash(token)
|
||||
|
||||
if info, ok := v.cache.get(hash); ok {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
url := v.ServerURL + "/v1/token/validate"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create validate request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := v.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate token: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read validate response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("validate token: MCIAS returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var info TokenInfo
|
||||
if err := json.Unmarshal(body, &info); err != nil {
|
||||
return nil, fmt.Errorf("parse validate response: %w", err)
|
||||
}
|
||||
|
||||
v.cache.put(hash, &info)
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// contextKey is an unexported type for context keys in this package.
|
||||
type contextKey struct{}
|
||||
|
||||
// tokenInfoKey is the context key for TokenInfo.
|
||||
var tokenInfoKey = contextKey{}
|
||||
|
||||
// ContextWithTokenInfo returns a new context carrying the given TokenInfo.
|
||||
func ContextWithTokenInfo(ctx context.Context, info *TokenInfo) context.Context {
|
||||
return context.WithValue(ctx, tokenInfoKey, info)
|
||||
}
|
||||
|
||||
// TokenInfoFromContext retrieves TokenInfo from the context, or nil if absent.
|
||||
func TokenInfoFromContext(ctx context.Context) *TokenInfo {
|
||||
info, _ := ctx.Value(tokenInfoKey).(*TokenInfo)
|
||||
return info
|
||||
}
|
||||
|
||||
// AuthInterceptor returns a gRPC unary server interceptor that validates
|
||||
// bearer tokens and requires the "admin" role.
|
||||
func AuthInterceptor(validator TokenValidator) grpc.UnaryServerInterceptor {
|
||||
return func(
|
||||
ctx context.Context,
|
||||
req any,
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (any, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Unauthenticated, "missing metadata")
|
||||
}
|
||||
|
||||
authValues := md.Get("authorization")
|
||||
if len(authValues) == 0 {
|
||||
return nil, status.Error(codes.Unauthenticated, "missing authorization header")
|
||||
}
|
||||
|
||||
authHeader := authValues[0]
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return nil, status.Error(codes.Unauthenticated, "malformed authorization header")
|
||||
}
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
tokenInfo, err := validator.ValidateToken(ctx, token)
|
||||
if err != nil {
|
||||
slog.Error("token validation failed", "method", info.FullMethod, "error", err)
|
||||
return nil, status.Error(codes.Unauthenticated, "token validation failed")
|
||||
}
|
||||
|
||||
if !tokenInfo.Valid {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
if !tokenInfo.HasRole("admin") {
|
||||
slog.Warn("permission denied", "method", info.FullMethod, "user", tokenInfo.Username)
|
||||
return nil, status.Error(codes.PermissionDenied, "admin role required")
|
||||
}
|
||||
|
||||
slog.Info("rpc", "method", info.FullMethod, "user", tokenInfo.Username, "account_type", tokenInfo.AccountType)
|
||||
|
||||
ctx = ContextWithTokenInfo(ctx, tokenInfo)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// Login authenticates with MCIAS and returns a bearer token.
|
||||
func Login(serverURL, caCertPath, username, password string) (string, error) {
|
||||
client, err := newHTTPClient(caCertPath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
serverURL = strings.TrimRight(serverURL, "/")
|
||||
url := serverURL + "/v1/auth/login"
|
||||
|
||||
payload := struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
body, err := json.Marshal(payload) //nolint:gosec // intentional login credential payload
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal login request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create login request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("login request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read login response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("login failed: MCIAS returned %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", fmt.Errorf("parse login response: %w", err)
|
||||
}
|
||||
|
||||
if result.Token == "" {
|
||||
return "", fmt.Errorf("login response missing token")
|
||||
}
|
||||
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
// LoadToken reads a token from the given file path and trims whitespace.
|
||||
func LoadToken(path string) (string, error) {
|
||||
data, err := os.ReadFile(path) //nolint:gosec // path from trusted caller
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load token from %q: %w", path, err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
// SaveToken writes a token to the given file with 0600 permissions.
|
||||
// Parent directories are created if they do not exist.
|
||||
func SaveToken(path string, token string) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("create token directory %q: %w", dir, err)
|
||||
}
|
||||
if err := os.WriteFile(path, []byte(token+"\n"), 0600); err != nil {
|
||||
return fmt.Errorf("save token to %q: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
391
internal/auth/auth_test.go
Normal file
391
internal/auth/auth_test.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// mockMCIAS creates a test MCIAS server that responds to /v1/token/validate.
|
||||
// The handler function receives the authorization header and returns (response, statusCode).
|
||||
func mockMCIAS(t *testing.T, handler func(authHeader string) (any, int)) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
resp, code := handler(authHeader)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Errorf("encode response: %v", err)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func validatorFromServer(t *testing.T, server *httptest.Server) *MCIASValidator {
|
||||
t.Helper()
|
||||
v, err := NewMCIASValidator(server.URL, "")
|
||||
if err != nil {
|
||||
t.Fatalf("create validator: %v", err)
|
||||
}
|
||||
v.httpClient = server.Client()
|
||||
return v
|
||||
}
|
||||
|
||||
// callInterceptor invokes the interceptor with the given context and validator.
|
||||
func callInterceptor(ctx context.Context, validator TokenValidator) (*TokenInfo, error) {
|
||||
interceptor := AuthInterceptor(validator)
|
||||
info := &grpc.UnaryServerInfo{FullMethod: "/mcp.v1.MCPService/TestMethod"}
|
||||
|
||||
var captured *TokenInfo
|
||||
handler := func(ctx context.Context, req any) (any, error) {
|
||||
captured = TokenInfoFromContext(ctx)
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
_, err := interceptor(ctx, nil, info, handler)
|
||||
return captured, err
|
||||
}
|
||||
|
||||
func TestInterceptorRejectsNoToken(t *testing.T) {
|
||||
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||
return map[string]any{"valid": false}, http.StatusOK
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
v := validatorFromServer(t, server)
|
||||
|
||||
// No metadata at all.
|
||||
_, err := callInterceptor(context.Background(), v)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
|
||||
t.Fatalf("expected Unauthenticated, got %v", err)
|
||||
}
|
||||
|
||||
// Metadata present but no authorization key.
|
||||
md := metadata.Pairs("other-key", "value")
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = callInterceptor(ctx, v)
|
||||
if err == nil {
|
||||
t.Fatal("expected error with empty auth, got nil")
|
||||
}
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
|
||||
t.Fatalf("expected Unauthenticated, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptorRejectsMalformedToken(t *testing.T) {
|
||||
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||
return map[string]any{"valid": false}, http.StatusOK
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
v := validatorFromServer(t, server)
|
||||
|
||||
md := metadata.Pairs("authorization", "NotBearer xxx")
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
|
||||
_, err := callInterceptor(ctx, v)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
|
||||
t.Fatalf("expected Unauthenticated, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptorRejectsInvalidToken(t *testing.T) {
|
||||
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||
return &TokenInfo{Valid: false}, http.StatusOK
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
v := validatorFromServer(t, server)
|
||||
|
||||
md := metadata.Pairs("authorization", "Bearer bad-token")
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
|
||||
_, err := callInterceptor(ctx, v)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
|
||||
t.Fatalf("expected Unauthenticated, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptorRejectsNonAdmin(t *testing.T) {
|
||||
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||
return &TokenInfo{
|
||||
Valid: true,
|
||||
Username: "regularuser",
|
||||
Roles: []string{"user"},
|
||||
AccountType: "human",
|
||||
}, http.StatusOK
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
v := validatorFromServer(t, server)
|
||||
|
||||
md := metadata.Pairs("authorization", "Bearer user-token")
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
|
||||
_, err := callInterceptor(ctx, v)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.PermissionDenied {
|
||||
t.Fatalf("expected PermissionDenied, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptorAcceptsAdmin(t *testing.T) {
|
||||
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||
return &TokenInfo{
|
||||
Valid: true,
|
||||
Username: "kyle",
|
||||
Roles: []string{"admin", "user"},
|
||||
AccountType: "human",
|
||||
}, http.StatusOK
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
v := validatorFromServer(t, server)
|
||||
|
||||
md := metadata.Pairs("authorization", "Bearer admin-token")
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
|
||||
captured, err := callInterceptor(ctx, v)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if captured == nil {
|
||||
t.Fatal("expected token info in context, got nil")
|
||||
}
|
||||
if captured.Username != "kyle" {
|
||||
t.Fatalf("username: got %q, want %q", captured.Username, "kyle")
|
||||
}
|
||||
if !captured.HasRole("admin") {
|
||||
t.Fatal("expected admin role")
|
||||
}
|
||||
if captured.AccountType != "human" {
|
||||
t.Fatalf("account_type: got %q, want %q", captured.AccountType, "human")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenCaching(t *testing.T) {
|
||||
var requestCount atomic.Int64
|
||||
|
||||
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||
requestCount.Add(1)
|
||||
return &TokenInfo{
|
||||
Valid: true,
|
||||
Username: "kyle",
|
||||
Roles: []string{"admin"},
|
||||
AccountType: "human",
|
||||
}, http.StatusOK
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
v := validatorFromServer(t, server)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First call should hit the server.
|
||||
info1, err := v.ValidateToken(ctx, "same-token")
|
||||
if err != nil {
|
||||
t.Fatalf("first validate: %v", err)
|
||||
}
|
||||
if !info1.Valid {
|
||||
t.Fatal("expected valid token")
|
||||
}
|
||||
|
||||
// Second call with the same token should be cached.
|
||||
info2, err := v.ValidateToken(ctx, "same-token")
|
||||
if err != nil {
|
||||
t.Fatalf("second validate: %v", err)
|
||||
}
|
||||
if info2.Username != info1.Username {
|
||||
t.Fatalf("cached result mismatch: got %q, want %q", info2.Username, info1.Username)
|
||||
}
|
||||
|
||||
if count := requestCount.Load(); count != 1 {
|
||||
t.Fatalf("expected 1 MCIAS request, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenCacheSeparateEntries(t *testing.T) {
|
||||
var requestCount atomic.Int64
|
||||
|
||||
server := mockMCIAS(t, func(authHeader string) (any, int) {
|
||||
requestCount.Add(1)
|
||||
// Return different usernames based on the token.
|
||||
token := authHeader[len("Bearer "):]
|
||||
return &TokenInfo{
|
||||
Valid: true,
|
||||
Username: fmt.Sprintf("user-for-%s", token),
|
||||
Roles: []string{"admin"},
|
||||
AccountType: "human",
|
||||
}, http.StatusOK
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
v := validatorFromServer(t, server)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
info1, err := v.ValidateToken(ctx, "token-a")
|
||||
if err != nil {
|
||||
t.Fatalf("validate token-a: %v", err)
|
||||
}
|
||||
|
||||
info2, err := v.ValidateToken(ctx, "token-b")
|
||||
if err != nil {
|
||||
t.Fatalf("validate token-b: %v", err)
|
||||
}
|
||||
|
||||
if info1.Username == info2.Username {
|
||||
t.Fatalf("different tokens should have different cache entries, both got %q", info1.Username)
|
||||
}
|
||||
|
||||
if count := requestCount.Load(); count != 2 {
|
||||
t.Fatalf("expected 2 MCIAS requests for different tokens, got %d", count)
|
||||
}
|
||||
|
||||
// Repeat calls should be cached.
|
||||
_, err = v.ValidateToken(ctx, "token-a")
|
||||
if err != nil {
|
||||
t.Fatalf("cached validate token-a: %v", err)
|
||||
}
|
||||
_, err = v.ValidateToken(ctx, "token-b")
|
||||
if err != nil {
|
||||
t.Fatalf("cached validate token-b: %v", err)
|
||||
}
|
||||
|
||||
if count := requestCount.Load(); count != 2 {
|
||||
t.Fatalf("expected still 2 MCIAS requests after cache hits, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSaveToken(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "subdir", "token")
|
||||
|
||||
token := "test-bearer-token-12345"
|
||||
|
||||
if err := SaveToken(path, token); err != nil {
|
||||
t.Fatalf("save: %v", err)
|
||||
}
|
||||
|
||||
// Check file permissions.
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("stat: %v", err)
|
||||
}
|
||||
if perm := fi.Mode().Perm(); perm != 0600 {
|
||||
t.Fatalf("permissions: got %o, want 0600", perm)
|
||||
}
|
||||
|
||||
// Load and verify.
|
||||
loaded, err := LoadToken(path)
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if loaded != token {
|
||||
t.Fatalf("loaded: got %q, want %q", loaded, token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
expectedToken := "mcias-session-token-xyz"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/auth/login" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("unexpected method: %s", r.Method)
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("decode request body: %v", err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if body.Username != "kyle" || body.Password != "secret" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"token": expectedToken})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
token, err := Login(server.URL, "", "kyle", "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("login: %v", err)
|
||||
}
|
||||
if token != expectedToken {
|
||||
t.Fatalf("token: got %q, want %q", token, expectedToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginBadCredentials(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := Login(server.URL, "", "kyle", "wrong")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for bad credentials")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextTokenInfo(t *testing.T) {
|
||||
info := &TokenInfo{
|
||||
Valid: true,
|
||||
Username: "kyle",
|
||||
Roles: []string{"admin"},
|
||||
AccountType: "human",
|
||||
}
|
||||
|
||||
ctx := ContextWithTokenInfo(context.Background(), info)
|
||||
got := TokenInfoFromContext(ctx)
|
||||
if got == nil {
|
||||
t.Fatal("expected token info from context, got nil")
|
||||
}
|
||||
if got.Username != "kyle" {
|
||||
t.Fatalf("username: got %q, want %q", got.Username, "kyle")
|
||||
}
|
||||
|
||||
// Empty context should return nil.
|
||||
got = TokenInfoFromContext(context.Background())
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil from empty context, got %+v", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user