Files
mcp/internal/auth/auth_test.go
Kyle Isom 86d516acf6 Drop admin requirement from agent interceptor, reject guests
The agent now accepts any authenticated user or system account, except
those with the guest role. Admin is reserved for MCIAS account management
and policy changes, not routine deploy/stop/start operations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-28 16:07:17 -07:00

414 lines
11 KiB
Go

package auth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// mockMCIAS starts an in-process HTTP server that stands in for MCIAS in
// tests.
//
// For every incoming request (the request path is not inspected) the server
// passes the value of the Authorization header to handler, then writes the
// status code handler returned followed by the returned value encoded as
// JSON. A failure to encode the response is reported through t.Errorf. The
// server is not closed here; the caller is expected to close it.
func mockMCIAS(t *testing.T, handler func(authHeader string) (any, int)) *httptest.Server {
	t.Helper()

	serve := func(w http.ResponseWriter, r *http.Request) {
		payload, statusCode := handler(r.Header.Get("Authorization"))

		// Headers have to be set before WriteHeader sends them.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(statusCode)

		enc := json.NewEncoder(w)
		if encErr := enc.Encode(payload); encErr != nil {
			t.Errorf("encode response: %v", encErr)
		}
	}

	return httptest.NewServer(http.HandlerFunc(serve))
}
// validatorFromServer returns an MCIASValidator aimed at the given test
// server. The validator is constructed from server.URL with an empty second
// argument, and its HTTP client is replaced with server.Client(). If
// construction fails the test is aborted via t.Fatalf.
func validatorFromServer(t *testing.T, server *httptest.Server) *MCIASValidator {
	t.Helper()

	validator, err := NewMCIASValidator(server.URL, "")
	if err != nil {
		t.Fatalf("create validator: %v", err)
	}

	// Send the validator's requests through the client provided by the test server.
	validator.httpClient = server.Client()

	return validator
}
// callInterceptor runs AuthInterceptor(validator) once against ctx, using a
// fixed test method name and a nil request. It returns the TokenInfo that the
// downstream handler found in its context (nil if the handler never ran or
// found none) together with the error returned by the interceptor.
func callInterceptor(ctx context.Context, validator TokenValidator) (*TokenInfo, error) {
	var seen *TokenInfo

	// The terminal handler only records what was attached to the context it receives.
	next := func(handlerCtx context.Context, _ any) (any, error) {
		seen = TokenInfoFromContext(handlerCtx)
		return "ok", nil
	}

	callInfo := &grpc.UnaryServerInfo{FullMethod: "/mcp.v1.MCPService/TestMethod"}
	_, err := AuthInterceptor(validator)(ctx, nil, callInfo, next)

	return seen, err
}
// TestInterceptorRejectsNoToken checks that the interceptor answers with
// codes.Unauthenticated both when the context carries no incoming metadata
// and when the metadata has no authorization entry.
func TestInterceptorRejectsNoToken(t *testing.T) {
	server := mockMCIAS(t, func(string) (any, int) {
		return map[string]any{"valid": false}, http.StatusOK
	})
	defer server.Close()

	validator := validatorFromServer(t, server)

	// Both scenarios share the same expectations; only the context and the
	// message used when no error comes back differ.
	expectUnauthenticated := func(ctx context.Context, nilErrMsg string) {
		_, err := callInterceptor(ctx, validator)
		if err == nil {
			t.Fatal(nilErrMsg)
		}
		st, ok := status.FromError(err)
		if !ok || st.Code() != codes.Unauthenticated {
			t.Fatalf("expected Unauthenticated, got %v", err)
		}
	}

	// No metadata at all.
	expectUnauthenticated(context.Background(), "expected error, got nil")

	// Metadata present but no authorization key.
	withoutAuth := metadata.NewIncomingContext(
		context.Background(),
		metadata.Pairs("other-key", "value"),
	)
	expectUnauthenticated(withoutAuth, "expected error with empty auth, got nil")
}
// TestInterceptorRejectsMalformedToken checks that an authorization value
// that does not start with "Bearer " is rejected with codes.Unauthenticated.
func TestInterceptorRejectsMalformedToken(t *testing.T) {
	server := mockMCIAS(t, func(string) (any, int) {
		return map[string]any{"valid": false}, http.StatusOK
	})
	defer server.Close()

	validator := validatorFromServer(t, server)
	ctx := metadata.NewIncomingContext(
		context.Background(),
		metadata.Pairs("authorization", "NotBearer xxx"),
	)

	_, err := callInterceptor(ctx, validator)
	if err == nil {
		t.Fatal("expected error, got nil")
	}

	st, isStatus := status.FromError(err)
	if isStatus && st.Code() == codes.Unauthenticated {
		return
	}
	t.Fatalf("expected Unauthenticated, got %v", err)
}
func TestInterceptorRejectsInvalidToken(t *testing.T) {
server := mockMCIAS(t, func(authHeader string) (any, int) {
return &TokenInfo{Valid: false}, http.StatusOK
})
defer server.Close()
v := validatorFromServer(t, server)
md := metadata.Pairs("authorization", "Bearer bad-token")
ctx := metadata.NewIncomingContext(context.Background(), md)
_, err := callInterceptor(ctx, v)
if err == nil {
t.Fatal("expected error, got nil")
}
if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
t.Fatalf("expected Unauthenticated, got %v", err)
}
}
// TestInterceptorAcceptsRegularUser checks that a valid token whose only
// role is "user" passes the interceptor without error.
func TestInterceptorAcceptsRegularUser(t *testing.T) {
	server := mockMCIAS(t, func(string) (any, int) {
		account := &TokenInfo{
			Valid:       true,
			Username:    "regularuser",
			Roles:       []string{"user"},
			AccountType: "human",
		}
		return account, http.StatusOK
	})
	defer server.Close()

	validator := validatorFromServer(t, server)
	ctx := metadata.NewIncomingContext(
		context.Background(),
		metadata.Pairs("authorization", "Bearer user-token"),
	)

	if _, err := callInterceptor(ctx, validator); err != nil {
		t.Fatalf("expected regular user to be accepted, got %v", err)
	}
}
// TestInterceptorRejectsGuest checks that a valid token whose only role is
// "guest" is refused with codes.PermissionDenied.
func TestInterceptorRejectsGuest(t *testing.T) {
	server := mockMCIAS(t, func(string) (any, int) {
		guest := &TokenInfo{
			Valid:       true,
			Username:    "visitor",
			Roles:       []string{"guest"},
			AccountType: "human",
		}
		return guest, http.StatusOK
	})
	defer server.Close()

	validator := validatorFromServer(t, server)
	ctx := metadata.NewIncomingContext(
		context.Background(),
		metadata.Pairs("authorization", "Bearer guest-token"),
	)

	_, err := callInterceptor(ctx, validator)
	if err == nil {
		t.Fatal("expected error, got nil")
	}

	st, ok := status.FromError(err)
	if !ok || st.Code() != codes.PermissionDenied {
		t.Fatalf("expected PermissionDenied, got %v", err)
	}
}
// TestInterceptorAcceptsAdmin checks that a valid admin token passes the
// interceptor and that the TokenInfo seen by the downstream handler carries
// the username, admin role and account type returned by the mock server.
func TestInterceptorAcceptsAdmin(t *testing.T) {
	const (
		wantUsername    = "kyle"
		wantAccountType = "human"
	)

	server := mockMCIAS(t, func(string) (any, int) {
		admin := &TokenInfo{
			Valid:       true,
			Username:    "kyle",
			Roles:       []string{"admin", "user"},
			AccountType: "human",
		}
		return admin, http.StatusOK
	})
	defer server.Close()

	validator := validatorFromServer(t, server)
	ctx := metadata.NewIncomingContext(
		context.Background(),
		metadata.Pairs("authorization", "Bearer admin-token"),
	)

	captured, err := callInterceptor(ctx, validator)

	// Cases are evaluated top to bottom, so the nil check guards the field reads below it.
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case captured == nil:
		t.Fatal("expected token info in context, got nil")
	case captured.Username != wantUsername:
		t.Fatalf("username: got %q, want %q", captured.Username, wantUsername)
	case !captured.HasRole("admin"):
		t.Fatal("expected admin role")
	case captured.AccountType != wantAccountType:
		t.Fatalf("account_type: got %q, want %q", captured.AccountType, wantAccountType)
	}
}
// TestTokenCaching checks that validating the same token twice results in a
// single request to the mock MCIAS server and that the second result matches
// the first.
func TestTokenCaching(t *testing.T) {
	const token = "same-token"

	var hits atomic.Int64
	server := mockMCIAS(t, func(string) (any, int) {
		hits.Add(1)
		return &TokenInfo{
			Valid:       true,
			Username:    "kyle",
			Roles:       []string{"admin"},
			AccountType: "human",
		}, http.StatusOK
	})
	defer server.Close()

	validator := validatorFromServer(t, server)
	ctx := context.Background()

	// The first lookup has to go to the server.
	first, err := validator.ValidateToken(ctx, token)
	if err != nil {
		t.Fatalf("first validate: %v", err)
	}
	if !first.Valid {
		t.Fatal("expected valid token")
	}

	// The repeat lookup for the same token is expected to be answered from the cache.
	second, err := validator.ValidateToken(ctx, token)
	if err != nil {
		t.Fatalf("second validate: %v", err)
	}
	if second.Username != first.Username {
		t.Fatalf("cached result mismatch: got %q, want %q", second.Username, first.Username)
	}

	if got := hits.Load(); got != 1 {
		t.Fatalf("expected 1 MCIAS request, got %d", got)
	}
}
// TestTokenCacheSeparateEntries verifies that ValidateToken keeps results for
// different tokens apart: two distinct tokens each cause one request to the
// mock MCIAS server and yield different usernames, and repeating both calls
// causes no further requests.
func TestTokenCacheSeparateEntries(t *testing.T) {
	// Number of leading bytes the mock strips from the Authorization header.
	const bearerPrefixLen = len("Bearer ")

	var requestCount atomic.Int64
	server := mockMCIAS(t, func(authHeader string) (any, int) {
		requestCount.Add(1)

		// Slicing a header shorter than the prefix would panic inside the
		// HTTP handler instead of producing a readable test failure, so a
		// too-short header is reported explicitly.
		if len(authHeader) < bearerPrefixLen {
			t.Errorf("authorization header too short for a bearer token: %q", authHeader)
			return &TokenInfo{Valid: false}, http.StatusUnauthorized
		}

		// Return different usernames based on the token.
		token := authHeader[bearerPrefixLen:]
		return &TokenInfo{
			Valid:       true,
			Username:    fmt.Sprintf("user-for-%s", token),
			Roles:       []string{"admin"},
			AccountType: "human",
		}, http.StatusOK
	})
	defer server.Close()

	v := validatorFromServer(t, server)
	ctx := context.Background()

	info1, err := v.ValidateToken(ctx, "token-a")
	if err != nil {
		t.Fatalf("validate token-a: %v", err)
	}
	info2, err := v.ValidateToken(ctx, "token-b")
	if err != nil {
		t.Fatalf("validate token-b: %v", err)
	}

	// The mock derives the username from the token, so equal usernames would
	// mean one token's result was returned for the other.
	if info1.Username == info2.Username {
		t.Fatalf("different tokens should have different cache entries, both got %q", info1.Username)
	}
	if count := requestCount.Load(); count != 2 {
		t.Fatalf("expected 2 MCIAS requests for different tokens, got %d", count)
	}

	// Repeat calls should be cached.
	_, err = v.ValidateToken(ctx, "token-a")
	if err != nil {
		t.Fatalf("cached validate token-a: %v", err)
	}
	_, err = v.ValidateToken(ctx, "token-b")
	if err != nil {
		t.Fatalf("cached validate token-b: %v", err)
	}
	if count := requestCount.Load(); count != 2 {
		t.Fatalf("expected still 2 MCIAS requests after cache hits, got %d", count)
	}
}
// TestLoadSaveToken checks that SaveToken writes a token to a path inside a
// not-yet-existing subdirectory, that the resulting file has mode 0600, and
// that LoadToken returns the same token.
func TestLoadSaveToken(t *testing.T) {
	const token = "test-bearer-token-12345"

	// "subdir" does not exist under the fresh temp directory before SaveToken runs.
	tokenPath := filepath.Join(t.TempDir(), "subdir", "token")

	if err := SaveToken(tokenPath, token); err != nil {
		t.Fatalf("save: %v", err)
	}

	// Check file permissions.
	stat, err := os.Stat(tokenPath)
	if err != nil {
		t.Fatalf("stat: %v", err)
	}
	perm := stat.Mode().Perm()
	if perm != 0600 {
		t.Fatalf("permissions: got %o, want 0600", perm)
	}

	// Load and verify.
	got, err := LoadToken(tokenPath)
	if err != nil {
		t.Fatalf("load: %v", err)
	}
	if got != token {
		t.Fatalf("loaded: got %q, want %q", got, token)
	}
}
// TestLogin checks that Login returns the token issued by a fake login
// endpoint. The fake endpoint accepts only POST requests to /v1/auth/login
// whose JSON body carries username "kyle" and password "secret".
func TestLogin(t *testing.T) {
	const expectedToken = "mcias-session-token-xyz"

	loginEndpoint := func(w http.ResponseWriter, r *http.Request) {
		// Reject anything that is not the expected route and method.
		switch {
		case r.URL.Path != "/v1/auth/login":
			t.Errorf("unexpected path: %s", r.URL.Path)
			http.NotFound(w, r)
			return
		case r.Method != http.MethodPost:
			t.Errorf("unexpected method: %s", r.Method)
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		var creds struct {
			Username string `json:"username"`
			Password string `json:"password"`
		}
		if err := json.NewDecoder(r.Body).Decode(&creds); err != nil {
			t.Errorf("decode request body: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		accepted := creds.Username == "kyle" && creds.Password == "secret"
		if !accepted {
			w.WriteHeader(http.StatusUnauthorized)
			// The encode error is deliberately discarded.
			_ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"})
			return
		}

		w.Header().Set("Content-Type", "application/json")
		// The encode error is deliberately discarded.
		_ = json.NewEncoder(w).Encode(map[string]string{"token": expectedToken})
	}

	server := httptest.NewServer(http.HandlerFunc(loginEndpoint))
	defer server.Close()

	got, err := Login(server.URL, "", "kyle", "secret")
	if err != nil {
		t.Fatalf("login: %v", err)
	}
	if got != expectedToken {
		t.Fatalf("token: got %q, want %q", got, expectedToken)
	}
}
// TestLoginBadCredentials checks that Login returns an error when the server
// answers with 401 Unauthorized.
func TestLoginBadCredentials(t *testing.T) {
	// Every request is refused, whatever credentials it carries.
	rejectAll := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		// The encode error is deliberately discarded.
		_ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"})
	})

	server := httptest.NewServer(rejectAll)
	defer server.Close()

	if _, err := Login(server.URL, "", "kyle", "wrong"); err == nil {
		t.Fatal("expected error for bad credentials")
	}
}
// TestContextTokenInfo checks that a TokenInfo stored with
// ContextWithTokenInfo is returned by TokenInfoFromContext, and that a
// context without one yields nil.
func TestContextTokenInfo(t *testing.T) {
	stored := &TokenInfo{
		Valid:       true,
		Username:    "kyle",
		Roles:       []string{"admin"},
		AccountType: "human",
	}

	got := TokenInfoFromContext(ContextWithTokenInfo(context.Background(), stored))
	switch {
	case got == nil:
		t.Fatal("expected token info from context, got nil")
	case got.Username != "kyle":
		t.Fatalf("username: got %q, want %q", got.Username, "kyle")
	}

	// Empty context should return nil.
	if fromEmpty := TokenInfoFromContext(context.Background()); fromEmpty != nil {
		t.Fatalf("expected nil from empty context, got %+v", fromEmpty)
	}
}