package grpcserver import ( "context" "encoding/json" "log/slog" "net/http" "net/http/httptest" "testing" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "git.wntrmute.dev/kyle/mcdsl/auth" ) func mockMCIAS(t *testing.T) *httptest.Server { t.Helper() mux := http.NewServeMux() mux.HandleFunc("POST /v1/token/validate", func(w http.ResponseWriter, r *http.Request) { var req struct { Token string `json:"token"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "bad", http.StatusBadRequest) return } switch req.Token { case "admin-token": w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]interface{}{ "valid": true, "username": "admin", "roles": []string{"admin", "user"}, }) case "user-token": w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]interface{}{ "valid": true, "username": "alice", "roles": []string{"user"}, }) default: w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]interface{}{"valid": false}) } }) return httptest.NewServer(mux) } func testAuth(t *testing.T, serverURL string) *auth.Authenticator { t.Helper() a, err := auth.New(auth.Config{ServerURL: serverURL}, slog.Default()) if err != nil { t.Fatalf("auth.New: %v", err) } return a } var testMethods = MethodMap{ Public: map[string]bool{"/test.Service/Health": true}, AuthRequired: map[string]bool{"/test.Service/List": true}, AdminRequired: map[string]bool{"/test.Service/Delete": true}, } // callInterceptor simulates calling a gRPC interceptor with the given // method and authorization metadata. func callInterceptor(interceptor grpc.UnaryServerInterceptor, method, authHeader string) (any, error) { ctx := context.Background() if authHeader != "" { md := metadata.Pairs("authorization", authHeader) ctx = metadata.NewIncomingContext(ctx, md) } info := &grpc.UnaryServerInfo{FullMethod: method} handler := func(ctx context.Context, _ any) (any, error) { // Return the TokenInfo from context to verify it was set. return auth.TokenInfoFromContext(ctx), nil } return interceptor(ctx, nil, info, handler) } func TestPublicMethodNoAuth(t *testing.T) { srv := mockMCIAS(t) defer srv.Close() a := testAuth(t, srv.URL) interceptor := authInterceptor(a, testMethods) resp, err := callInterceptor(interceptor, "/test.Service/Health", "") if err != nil { t.Fatalf("public method error: %v", err) } // Public methods don't set TokenInfo. info, _ := resp.(*auth.TokenInfo) if info != nil { t.Fatal("expected nil TokenInfo for public method") } } func TestAuthRequiredWithValidToken(t *testing.T) { srv := mockMCIAS(t) defer srv.Close() a := testAuth(t, srv.URL) interceptor := authInterceptor(a, testMethods) resp, err := callInterceptor(interceptor, "/test.Service/List", "Bearer user-token") if err != nil { t.Fatalf("auth method error: %v", err) } info, ok := resp.(*auth.TokenInfo) if !ok || info == nil { t.Fatal("expected TokenInfo in context") } if info.Username != "alice" { t.Fatalf("Username = %q, want %q", info.Username, "alice") } } func TestAuthRequiredWithoutToken(t *testing.T) { srv := mockMCIAS(t) defer srv.Close() a := testAuth(t, srv.URL) interceptor := authInterceptor(a, testMethods) _, err := callInterceptor(interceptor, "/test.Service/List", "") if err == nil { t.Fatal("expected error for missing token") } if status.Code(err) != codes.Unauthenticated { t.Fatalf("code = %v, want Unauthenticated", status.Code(err)) } } func TestAuthRequiredWithInvalidToken(t *testing.T) { srv := mockMCIAS(t) defer srv.Close() a := testAuth(t, srv.URL) interceptor := authInterceptor(a, testMethods) _, err := callInterceptor(interceptor, "/test.Service/List", "Bearer bad-token") if err == nil { t.Fatal("expected error for invalid token") } if status.Code(err) != codes.Unauthenticated { t.Fatalf("code = %v, want Unauthenticated", status.Code(err)) } } func TestAdminRequiredWithAdminToken(t *testing.T) { srv := mockMCIAS(t) defer srv.Close() a := testAuth(t, srv.URL) interceptor := authInterceptor(a, testMethods) resp, err := callInterceptor(interceptor, "/test.Service/Delete", "Bearer admin-token") if err != nil { t.Fatalf("admin method error: %v", err) } info, ok := resp.(*auth.TokenInfo) if !ok || info == nil { t.Fatal("expected TokenInfo in context") } if !info.IsAdmin { t.Fatal("expected IsAdmin=true") } } func TestAdminRequiredWithUserToken(t *testing.T) { srv := mockMCIAS(t) defer srv.Close() a := testAuth(t, srv.URL) interceptor := authInterceptor(a, testMethods) _, err := callInterceptor(interceptor, "/test.Service/Delete", "Bearer user-token") if err == nil { t.Fatal("expected error for non-admin on admin method") } if status.Code(err) != codes.PermissionDenied { t.Fatalf("code = %v, want PermissionDenied", status.Code(err)) } } func TestUnmappedMethodDenied(t *testing.T) { srv := mockMCIAS(t) defer srv.Close() a := testAuth(t, srv.URL) interceptor := authInterceptor(a, testMethods) _, err := callInterceptor(interceptor, "/test.Service/Unknown", "Bearer admin-token") if err == nil { t.Fatal("expected error for unmapped method") } if status.Code(err) != codes.PermissionDenied { t.Fatalf("code = %v, want PermissionDenied", status.Code(err)) } } func TestLoggingInterceptor(t *testing.T) { interceptor := loggingInterceptor(slog.Default()) info := &grpc.UnaryServerInfo{FullMethod: "/test.Service/Ping"} handler := func(_ context.Context, _ any) (any, error) { return "pong", nil } resp, err := interceptor(context.Background(), nil, info, handler) if err != nil { t.Fatalf("logging interceptor error: %v", err) } if resp != "pong" { t.Fatalf("resp = %v, want pong", resp) } } func TestNewWithoutTLS(t *testing.T) { srv := mockMCIAS(t) defer srv.Close() a := testAuth(t, srv.URL) s, err := New("", "", a, testMethods, slog.Default(), nil) if err != nil { t.Fatalf("New: %v", err) } if s.GRPCServer == nil { t.Fatal("GRPCServer is nil") } } func TestTokenInfoFromContext(t *testing.T) { info := &auth.TokenInfo{Username: "test", IsAdmin: true} ctx := auth.ContextWithTokenInfo(context.Background(), info) got := TokenInfoFromContext(ctx) if got == nil { t.Fatal("nil from context") } if got.Username != "test" { t.Fatalf("Username = %q, want %q", got.Username, "test") } }