package auth

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync/atomic"
	"testing"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// mockMCIAS starts a test HTTP server standing in for MCIAS token validation.
// For every request, handler receives the request's Authorization header and
// returns (response body, status code); the body is JSON-encoded with a
// Content-Type of application/json.
//
// The server answers every request the same way regardless of path or method.
// The validator is presumably calling /v1/token/validate, but this mock does
// not check that.
func mockMCIAS(t *testing.T, handler func(authHeader string) (any, int)) *httptest.Server {
	t.Helper()
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		resp, code := handler(authHeader)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(code)
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
}

// validatorFromServer builds an MCIASValidator aimed at server's URL and
// replaces its HTTP client with server.Client(), so requests reach the test
// server. The empty second argument to NewMCIASValidator is passed through
// unchanged (NOTE(review): presumably an optional TLS/CA setting — confirm).
func validatorFromServer(t *testing.T, server *httptest.Server) *MCIASValidator {
	t.Helper()
	v, err := NewMCIASValidator(server.URL, "")
	if err != nil {
		t.Fatalf("create validator: %v", err)
	}
	v.httpClient = server.Client()
	return v
}

// callInterceptor runs AuthInterceptor(validator) for a fake unary call to
// /mcp.v1.MCPService/TestMethod. It returns the TokenInfo the downstream
// handler found in its context (nil if the handler never ran) and the
// interceptor's error.
func callInterceptor(ctx context.Context, validator TokenValidator) (*TokenInfo, error) {
	interceptor := AuthInterceptor(validator)
	info := &grpc.UnaryServerInfo{FullMethod: "/mcp.v1.MCPService/TestMethod"}

	var captured *TokenInfo
	handler := func(ctx context.Context, req any) (any, error) {
		captured = TokenInfoFromContext(ctx)
		return "ok", nil
	}

	_, err := interceptor(ctx, nil, info, handler)
	return captured, err
}

// TestInterceptorRejectsNoToken checks that a call with no incoming metadata,
// or with metadata lacking an authorization key, fails with Unauthenticated.
func TestInterceptorRejectsNoToken(t *testing.T) {
	server := mockMCIAS(t, func(authHeader string) (any, int) {
		return map[string]any{"valid": false}, http.StatusOK
	})
	defer server.Close()
	v := validatorFromServer(t, server)

	// No metadata at all.
	_, err := callInterceptor(context.Background(), v)
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
		t.Fatalf("expected Unauthenticated, got %v", err)
	}

	// Metadata present but no authorization key.
	md := metadata.Pairs("other-key", "value")
	ctx := metadata.NewIncomingContext(context.Background(), md)
	_, err = callInterceptor(ctx, v)
	if err == nil {
		t.Fatal("expected error with missing authorization key, got nil")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
		t.Fatalf("expected Unauthenticated, got %v", err)
	}
}

// TestInterceptorRejectsMalformedToken checks that an authorization value
// without the "Bearer " scheme yields Unauthenticated.
func TestInterceptorRejectsMalformedToken(t *testing.T) {
	server := mockMCIAS(t, func(authHeader string) (any, int) {
		return map[string]any{"valid": false}, http.StatusOK
	})
	defer server.Close()
	v := validatorFromServer(t, server)

	md := metadata.Pairs("authorization", "NotBearer xxx")
	ctx := metadata.NewIncomingContext(context.Background(), md)
	_, err := callInterceptor(ctx, v)
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
		t.Fatalf("expected Unauthenticated, got %v", err)
	}
}

// TestInterceptorRejectsInvalidToken checks that a token MCIAS reports as
// invalid yields Unauthenticated.
func TestInterceptorRejectsInvalidToken(t *testing.T) {
	server := mockMCIAS(t, func(authHeader string) (any, int) {
		return &TokenInfo{Valid: false}, http.StatusOK
	})
	defer server.Close()
	v := validatorFromServer(t, server)

	md := metadata.Pairs("authorization", "Bearer bad-token")
	ctx := metadata.NewIncomingContext(context.Background(), md)
	_, err := callInterceptor(ctx, v)
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.Unauthenticated {
		t.Fatalf("expected Unauthenticated, got %v", err)
	}
}

// TestInterceptorRejectsNonAdmin checks that a valid token without the
// "admin" role yields PermissionDenied.
func TestInterceptorRejectsNonAdmin(t *testing.T) {
	server := mockMCIAS(t, func(authHeader string) (any, int) {
		return &TokenInfo{
			Valid:       true,
			Username:    "regularuser",
			Roles:       []string{"user"},
			AccountType: "human",
		}, http.StatusOK
	})
	defer server.Close()
	v := validatorFromServer(t, server)

	md := metadata.Pairs("authorization", "Bearer user-token")
	ctx := metadata.NewIncomingContext(context.Background(), md)
	_, err := callInterceptor(ctx, v)
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.PermissionDenied {
		t.Fatalf("expected PermissionDenied, got %v", err)
	}
}

// TestInterceptorAcceptsAdmin checks that an admin token passes and that the
// handler sees the validated TokenInfo in its context.
func TestInterceptorAcceptsAdmin(t *testing.T) {
	server := mockMCIAS(t, func(authHeader string) (any, int) {
		return &TokenInfo{
			Valid:       true,
			Username:    "kyle",
			Roles:       []string{"admin", "user"},
			AccountType: "human",
		}, http.StatusOK
	})
	defer server.Close()
	v := validatorFromServer(t, server)

	md := metadata.Pairs("authorization", "Bearer admin-token")
	ctx := metadata.NewIncomingContext(context.Background(), md)
	captured, err := callInterceptor(ctx, v)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if captured == nil {
		t.Fatal("expected token info in context, got nil")
	}
	if captured.Username != "kyle" {
		t.Fatalf("username: got %q, want %q", captured.Username, "kyle")
	}
	if !captured.HasRole("admin") {
		t.Fatal("expected admin role")
	}
	if captured.AccountType != "human" {
		t.Fatalf("account_type: got %q, want %q", captured.AccountType, "human")
	}
}

// TestTokenCaching checks that validating the same token twice contacts
// MCIAS only once.
func TestTokenCaching(t *testing.T) {
	var requestCount atomic.Int64
	server := mockMCIAS(t, func(authHeader string) (any, int) {
		requestCount.Add(1)
		return &TokenInfo{
			Valid:       true,
			Username:    "kyle",
			Roles:       []string{"admin"},
			AccountType: "human",
		}, http.StatusOK
	})
	defer server.Close()
	v := validatorFromServer(t, server)
	ctx := context.Background()

	// First call should hit the server.
	info1, err := v.ValidateToken(ctx, "same-token")
	if err != nil {
		t.Fatalf("first validate: %v", err)
	}
	if !info1.Valid {
		t.Fatal("expected valid token")
	}

	// Second call with the same token should be cached.
	info2, err := v.ValidateToken(ctx, "same-token")
	if err != nil {
		t.Fatalf("second validate: %v", err)
	}
	if info2.Username != info1.Username {
		t.Fatalf("cached result mismatch: got %q, want %q", info2.Username, info1.Username)
	}

	if count := requestCount.Load(); count != 1 {
		t.Fatalf("expected 1 MCIAS request, got %d", count)
	}
}

// TestTokenCacheSeparateEntries checks that distinct tokens get distinct
// cache entries and that each cache hit returns the entry for its own token.
func TestTokenCacheSeparateEntries(t *testing.T) {
	const bearerPrefix = "Bearer "

	var requestCount atomic.Int64
	server := mockMCIAS(t, func(authHeader string) (any, int) {
		requestCount.Add(1)
		// Guard the prefix instead of slicing blindly: a short or missing
		// header would otherwise panic inside the server goroutine and show
		// up only as an opaque transport error.
		if !strings.HasPrefix(authHeader, bearerPrefix) {
			t.Errorf("authorization header %q lacks %q prefix", authHeader, bearerPrefix)
			return map[string]any{"valid": false}, http.StatusOK
		}
		token := authHeader[len(bearerPrefix):]
		// Echo the token into the username so every response is
		// attributable to the token that produced it.
		return &TokenInfo{
			Valid:       true,
			Username:    fmt.Sprintf("user-for-%s", token),
			Roles:       []string{"admin"},
			AccountType: "human",
		}, http.StatusOK
	})
	defer server.Close()
	v := validatorFromServer(t, server)
	ctx := context.Background()

	// expectUser validates token and asserts the result belongs to that
	// token, so a cache returning another token's entry is caught.
	expectUser := func(label, token string) {
		t.Helper()
		info, err := v.ValidateToken(ctx, token)
		if err != nil {
			t.Fatalf("%s %s: %v", label, token, err)
		}
		if info == nil {
			t.Fatalf("%s %s: got nil token info", label, token)
		}
		if want := "user-for-" + token; info.Username != want {
			t.Fatalf("%s %s: username got %q, want %q", label, token, info.Username, want)
		}
	}

	expectUser("validate", "token-a")
	expectUser("validate", "token-b")
	if count := requestCount.Load(); count != 2 {
		t.Fatalf("expected 2 MCIAS requests for different tokens, got %d", count)
	}

	// Repeat calls should be served from the cache.
	expectUser("cached validate", "token-a")
	expectUser("cached validate", "token-b")
	if count := requestCount.Load(); count != 2 {
		t.Fatalf("expected still 2 MCIAS requests after cache hits, got %d", count)
	}
}

// TestLoadSaveToken round-trips a token through SaveToken/LoadToken, with a
// parent directory that does not exist yet, and checks the file is 0600.
func TestLoadSaveToken(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "subdir", "token")
	token := "test-bearer-token-12345"

	if err := SaveToken(path, token); err != nil {
		t.Fatalf("save: %v", err)
	}

	fi, err := os.Stat(path)
	if err != nil {
		t.Fatalf("stat: %v", err)
	}
	// On Windows Go derives mode bits only from the read-only attribute
	// (0666 or 0444), so a 0600 check cannot pass there.
	if runtime.GOOS != "windows" {
		if perm := fi.Mode().Perm(); perm != 0600 {
			t.Fatalf("permissions: got %o, want 0600", perm)
		}
	}

	// Load and verify.
	loaded, err := LoadToken(path)
	if err != nil {
		t.Fatalf("load: %v", err)
	}
	if loaded != token {
		t.Fatalf("loaded: got %q, want %q", loaded, token)
	}
}

// TestLogin checks that Login POSTs JSON credentials to /v1/auth/login and
// returns the token from the response.
func TestLogin(t *testing.T) {
	expectedToken := "mcias-session-token-xyz"

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/auth/login" {
			t.Errorf("unexpected path: %s", r.URL.Path)
			http.NotFound(w, r)
			return
		}
		if r.Method != http.MethodPost {
			t.Errorf("unexpected method: %s", r.Method)
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		var body struct {
			Username string `json:"username"`
			Password string `json:"password"`
		}
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if body.Username != "kyle" || body.Password != "secret" {
			w.WriteHeader(http.StatusUnauthorized)
			if err := json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"}); err != nil {
				t.Errorf("encode error response: %v", err)
			}
			return
		}

		if err := json.NewEncoder(w).Encode(map[string]string{"token": expectedToken}); err != nil {
			t.Errorf("encode token response: %v", err)
		}
	}))
	defer server.Close()

	token, err := Login(server.URL, "", "kyle", "secret")
	if err != nil {
		t.Fatalf("login: %v", err)
	}
	if token != expectedToken {
		t.Fatalf("token: got %q, want %q", token, expectedToken)
	}
}

// TestLoginBadCredentials checks that a 401 from MCIAS surfaces as an error.
func TestLoginBadCredentials(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		if err := json.NewEncoder(w).Encode(map[string]string{"error": "invalid credentials"}); err != nil {
			t.Errorf("encode error response: %v", err)
		}
	}))
	defer server.Close()

	_, err := Login(server.URL, "", "kyle", "wrong")
	if err == nil {
		t.Fatal("expected error for bad credentials")
	}
}

// TestContextTokenInfo checks the ContextWithTokenInfo/TokenInfoFromContext
// round trip, and that a context without token info yields nil.
func TestContextTokenInfo(t *testing.T) {
	info := &TokenInfo{
		Valid:       true,
		Username:    "kyle",
		Roles:       []string{"admin"},
		AccountType: "human",
	}

	ctx := ContextWithTokenInfo(context.Background(), info)
	got := TokenInfoFromContext(ctx)
	if got == nil {
		t.Fatal("expected token info from context, got nil")
	}
	if got.Username != "kyle" {
		t.Fatalf("username: got %q, want %q", got.Username, "kyle")
	}

	// Empty context should return nil.
	got = TokenInfoFromContext(context.Background())
	if got != nil {
		t.Fatalf("expected nil from empty context, got %+v", got)
	}
}