From 4ec0c3a9168d8cfe22d3defdccb4a5a55d4d2f9d Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Thu, 26 Mar 2026 21:05:54 -0700 Subject: [PATCH 1/2] Add REST API handler tests for zones, records, and middleware Cover all REST handlers with httptest-based tests using real SQLite: zones (list, get, create, update, delete), records (list, get, create, update, delete with validation/conflict cases), requireAdmin middleware (admin, non-admin, missing context), and utility functions (writeJSON, writeError, extractBearerToken, tokenInfoFromContext). Co-Authored-By: Claude Opus 4.6 (1M context) --- .gitignore | 3 + gen/mcns/v1/admin.pb.go | 4 +- gen/mcns/v1/admin_grpc.pb.go | 6 +- gen/mcns/v1/auth.pb.go | 13 +- gen/mcns/v1/auth_grpc.pb.go | 6 +- gen/mcns/v1/record.pb.go | 37 +- gen/mcns/v1/record_grpc.pb.go | 6 +- gen/mcns/v1/zone.pb.go | 4 +- gen/mcns/v1/zone_grpc.pb.go | 6 +- internal/server/handlers_test.go | 949 +++++++++++++++++++++++++++++++ proto/mcns/v1/admin.proto | 3 +- proto/mcns/v1/auth.proto | 4 +- proto/mcns/v1/record.proto | 8 +- proto/mcns/v1/zone.proto | 3 +- 14 files changed, 1018 insertions(+), 34 deletions(-) create mode 100644 internal/server/handlers_test.go diff --git a/.gitignore b/.gitignore index cf6450f..21be1a0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,6 @@ srv/ *.db *.db-wal *.db-shm +.idea/ +.vscode/ +.DS_Store diff --git a/gen/mcns/v1/admin.pb.go b/gen/mcns/v1/admin.pb.go index 35d7d14..f69565f 100644 --- a/gen/mcns/v1/admin.pb.go +++ b/gen/mcns/v1/admin.pb.go @@ -4,7 +4,7 @@ // protoc v6.32.1 // source: proto/mcns/v1/admin.proto -package v1 +package mcnsv1 import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" @@ -110,7 +110,7 @@ const file_proto_mcns_v1_admin_proto_rawDesc = "" + "\x0eHealthResponse\x12\x16\n" + "\x06status\x18\x01 \x01(\tR\x06status2I\n" + "\fAdminService\x129\n" + - "\x06Health\x12\x16.mcns.v1.HealthRequest\x1a\x17.mcns.v1.HealthResponseB(Z&git.wntrmute.dev/kyle/mcns/gen/mcns/v1b\x06proto3" + "\x06Health\x12\x16.mcns.v1.HealthRequest\x1a\x17.mcns.v1.HealthResponseB/Z-git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1b\x06proto3" var ( file_proto_mcns_v1_admin_proto_rawDescOnce sync.Once diff --git a/gen/mcns/v1/admin_grpc.pb.go b/gen/mcns/v1/admin_grpc.pb.go index 0734a50..3eb4c74 100644 --- a/gen/mcns/v1/admin_grpc.pb.go +++ b/gen/mcns/v1/admin_grpc.pb.go @@ -4,7 +4,7 @@ // - protoc v6.32.1 // source: proto/mcns/v1/admin.proto -package v1 +package mcnsv1 import ( context "context" @@ -25,6 +25,8 @@ const ( // AdminServiceClient is the client API for AdminService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// AdminService exposes server health and administrative operations. type AdminServiceClient interface { Health(ctx context.Context, in *HealthRequest, opts ...grpc.CallOption) (*HealthResponse, error) } @@ -50,6 +52,8 @@ func (c *adminServiceClient) Health(ctx context.Context, in *HealthRequest, opts // AdminServiceServer is the server API for AdminService service. // All implementations must embed UnimplementedAdminServiceServer // for forward compatibility. +// +// AdminService exposes server health and administrative operations. type AdminServiceServer interface { Health(context.Context, *HealthRequest) (*HealthResponse, error) mustEmbedUnimplementedAdminServiceServer() diff --git a/gen/mcns/v1/auth.pb.go b/gen/mcns/v1/auth.pb.go index 95cf0d5..cb81910 100644 --- a/gen/mcns/v1/auth.pb.go +++ b/gen/mcns/v1/auth.pb.go @@ -4,7 +4,7 @@ // protoc v6.32.1 // source: proto/mcns/v1/auth.proto -package v1 +package mcnsv1 import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" @@ -22,10 +22,11 @@ const ( ) type LoginRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` - Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` - TotpCode string `protobuf:"bytes,3,opt,name=totp_code,json=totpCode,proto3" json:"totp_code,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` + // TOTP code for two-factor authentication, if enabled on the account. + TotpCode string `protobuf:"bytes,3,opt,name=totp_code,json=totpCode,proto3" json:"totp_code,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -221,7 +222,7 @@ const file_proto_mcns_v1_auth_proto_rawDesc = "" + "\x0eLogoutResponse2\x80\x01\n" + "\vAuthService\x126\n" + "\x05Login\x12\x15.mcns.v1.LoginRequest\x1a\x16.mcns.v1.LoginResponse\x129\n" + - "\x06Logout\x12\x16.mcns.v1.LogoutRequest\x1a\x17.mcns.v1.LogoutResponseB(Z&git.wntrmute.dev/kyle/mcns/gen/mcns/v1b\x06proto3" + "\x06Logout\x12\x16.mcns.v1.LogoutRequest\x1a\x17.mcns.v1.LogoutResponseB/Z-git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1b\x06proto3" var ( file_proto_mcns_v1_auth_proto_rawDescOnce sync.Once diff --git a/gen/mcns/v1/auth_grpc.pb.go b/gen/mcns/v1/auth_grpc.pb.go index ba5c146..364cbdc 100644 --- a/gen/mcns/v1/auth_grpc.pb.go +++ b/gen/mcns/v1/auth_grpc.pb.go @@ -4,7 +4,7 @@ // - protoc v6.32.1 // source: proto/mcns/v1/auth.proto -package v1 +package mcnsv1 import ( context "context" @@ -26,6 +26,8 @@ const ( // AuthServiceClient is the client API for AuthService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// AuthService handles authentication by delegating to MCIAS. type AuthServiceClient interface { Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) @@ -62,6 +64,8 @@ func (c *authServiceClient) Logout(ctx context.Context, in *LogoutRequest, opts // AuthServiceServer is the server API for AuthService service. // All implementations must embed UnimplementedAuthServiceServer // for forward compatibility. +// +// AuthService handles authentication by delegating to MCIAS. type AuthServiceServer interface { Login(context.Context, *LoginRequest) (*LoginResponse, error) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) diff --git a/gen/mcns/v1/record.pb.go b/gen/mcns/v1/record.pb.go index 4b4c4c2..de3b179 100644 --- a/gen/mcns/v1/record.pb.go +++ b/gen/mcns/v1/record.pb.go @@ -4,7 +4,7 @@ // protoc v6.32.1 // source: proto/mcns/v1/record.proto -package v1 +package mcnsv1 import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" @@ -23,10 +23,12 @@ const ( ) type Record struct { - state protoimpl.MessageState `protogen:"open.v1"` - Id int64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` - Zone string `protobuf:"bytes,2,opt,name=zone,proto3" json:"zone,omitempty"` - Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + Id int64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + // Zone name this record belongs to (e.g. "example.com."). + Zone string `protobuf:"bytes,2,opt,name=zone,proto3" json:"zone,omitempty"` + Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"` + // DNS record type (A, AAAA, CNAME, MX, TXT, etc.). Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"` Value string `protobuf:"bytes,5,opt,name=value,proto3" json:"value,omitempty"` Ttl int32 `protobuf:"varint,6,opt,name=ttl,proto3" json:"ttl,omitempty"` @@ -123,10 +125,12 @@ func (x *Record) GetUpdatedAt() *timestamppb.Timestamp { } type ListRecordsRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Zone string `protobuf:"bytes,1,opt,name=zone,proto3" json:"zone,omitempty"` - Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` - Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + Zone string `protobuf:"bytes,1,opt,name=zone,proto3" json:"zone,omitempty"` + // Optional filter by record name. + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + // Optional filter by record type (A, AAAA, CNAME, etc.). + Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -227,12 +231,13 @@ func (x *ListRecordsResponse) GetRecords() []*Record { } type CreateRecordRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Zone string `protobuf:"bytes,1,opt,name=zone,proto3" json:"zone,omitempty"` - Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` - Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"` - Value string `protobuf:"bytes,4,opt,name=value,proto3" json:"value,omitempty"` - Ttl int32 `protobuf:"varint,5,opt,name=ttl,proto3" json:"ttl,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + // Zone name the record will be created in; must reference an existing zone. + Zone string `protobuf:"bytes,1,opt,name=zone,proto3" json:"zone,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"` + Value string `protobuf:"bytes,4,opt,name=value,proto3" json:"value,omitempty"` + Ttl int32 `protobuf:"varint,5,opt,name=ttl,proto3" json:"ttl,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -546,7 +551,7 @@ const file_proto_mcns_v1_record_proto_rawDesc = "" + "\fCreateRecord\x12\x1c.mcns.v1.CreateRecordRequest\x1a\x0f.mcns.v1.Record\x127\n" + "\tGetRecord\x12\x19.mcns.v1.GetRecordRequest\x1a\x0f.mcns.v1.Record\x12=\n" + "\fUpdateRecord\x12\x1c.mcns.v1.UpdateRecordRequest\x1a\x0f.mcns.v1.Record\x12K\n" + - "\fDeleteRecord\x12\x1c.mcns.v1.DeleteRecordRequest\x1a\x1d.mcns.v1.DeleteRecordResponseB(Z&git.wntrmute.dev/kyle/mcns/gen/mcns/v1b\x06proto3" + "\fDeleteRecord\x12\x1c.mcns.v1.DeleteRecordRequest\x1a\x1d.mcns.v1.DeleteRecordResponseB/Z-git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1b\x06proto3" var ( file_proto_mcns_v1_record_proto_rawDescOnce sync.Once diff --git a/gen/mcns/v1/record_grpc.pb.go b/gen/mcns/v1/record_grpc.pb.go index 72e5c0b..7e5b2c3 100644 --- a/gen/mcns/v1/record_grpc.pb.go +++ b/gen/mcns/v1/record_grpc.pb.go @@ -4,7 +4,7 @@ // - protoc v6.32.1 // source: proto/mcns/v1/record.proto -package v1 +package mcnsv1 import ( context "context" @@ -29,6 +29,8 @@ const ( // RecordServiceClient is the client API for RecordService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// RecordService manages DNS records within zones. type RecordServiceClient interface { ListRecords(ctx context.Context, in *ListRecordsRequest, opts ...grpc.CallOption) (*ListRecordsResponse, error) CreateRecord(ctx context.Context, in *CreateRecordRequest, opts ...grpc.CallOption) (*Record, error) @@ -98,6 +100,8 @@ func (c *recordServiceClient) DeleteRecord(ctx context.Context, in *DeleteRecord // RecordServiceServer is the server API for RecordService service. // All implementations must embed UnimplementedRecordServiceServer // for forward compatibility. +// +// RecordService manages DNS records within zones. type RecordServiceServer interface { ListRecords(context.Context, *ListRecordsRequest) (*ListRecordsResponse, error) CreateRecord(context.Context, *CreateRecordRequest) (*Record, error) diff --git a/gen/mcns/v1/zone.pb.go b/gen/mcns/v1/zone.pb.go index 6e155c2..f402555 100644 --- a/gen/mcns/v1/zone.pb.go +++ b/gen/mcns/v1/zone.pb.go @@ -4,7 +4,7 @@ // protoc v6.32.1 // source: proto/mcns/v1/zone.proto -package v1 +package mcnsv1 import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" @@ -595,7 +595,7 @@ const file_proto_mcns_v1_zone_proto_rawDesc = "" + "\n" + "UpdateZone\x12\x1a.mcns.v1.UpdateZoneRequest\x1a\r.mcns.v1.Zone\x12E\n" + "\n" + - "DeleteZone\x12\x1a.mcns.v1.DeleteZoneRequest\x1a\x1b.mcns.v1.DeleteZoneResponseB(Z&git.wntrmute.dev/kyle/mcns/gen/mcns/v1b\x06proto3" + "DeleteZone\x12\x1a.mcns.v1.DeleteZoneRequest\x1a\x1b.mcns.v1.DeleteZoneResponseB/Z-git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1b\x06proto3" var ( file_proto_mcns_v1_zone_proto_rawDescOnce sync.Once diff --git a/gen/mcns/v1/zone_grpc.pb.go b/gen/mcns/v1/zone_grpc.pb.go index cec965e..4d35702 100644 --- a/gen/mcns/v1/zone_grpc.pb.go +++ b/gen/mcns/v1/zone_grpc.pb.go @@ -4,7 +4,7 @@ // - protoc v6.32.1 // source: proto/mcns/v1/zone.proto -package v1 +package mcnsv1 import ( context "context" @@ -29,6 +29,8 @@ const ( // ZoneServiceClient is the client API for ZoneService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// ZoneService manages DNS zones and their SOA parameters. type ZoneServiceClient interface { ListZones(ctx context.Context, in *ListZonesRequest, opts ...grpc.CallOption) (*ListZonesResponse, error) CreateZone(ctx context.Context, in *CreateZoneRequest, opts ...grpc.CallOption) (*Zone, error) @@ -98,6 +100,8 @@ func (c *zoneServiceClient) DeleteZone(ctx context.Context, in *DeleteZoneReques // ZoneServiceServer is the server API for ZoneService service. // All implementations must embed UnimplementedZoneServiceServer // for forward compatibility. +// +// ZoneService manages DNS zones and their SOA parameters. type ZoneServiceServer interface { ListZones(context.Context, *ListZonesRequest) (*ListZonesResponse, error) CreateZone(context.Context, *CreateZoneRequest) (*Zone, error) diff --git a/internal/server/handlers_test.go b/internal/server/handlers_test.go new file mode 100644 index 0000000..52c7006 --- /dev/null +++ b/internal/server/handlers_test.go @@ -0,0 +1,949 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth" + "git.wntrmute.dev/kyle/mcns/internal/db" +) + +// openTestDB creates a temporary SQLite database with all migrations applied. +func openTestDB(t *testing.T) *db.DB { + t.Helper() + dir := t.TempDir() + database, err := db.Open(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("open db: %v", err) + } + if err := database.Migrate(); err != nil { + t.Fatalf("migrate: %v", err) + } + t.Cleanup(func() { _ = database.Close() }) + return database +} + +// createTestZone inserts a zone for use by record tests. +func createTestZone(t *testing.T, database *db.DB) *db.Zone { + t.Helper() + zone, err := database.CreateZone("test.example.com", "ns.example.com.", "admin.example.com.", 3600, 600, 86400, 300) + if err != nil { + t.Fatalf("create zone: %v", err) + } + return zone +} + +// newChiRequest builds a request with chi URL params injected into the context. +func newChiRequest(method, target string, body string, params map[string]string) *http.Request { + var r *http.Request + if body != "" { + r = httptest.NewRequest(method, target, strings.NewReader(body)) + } else { + r = httptest.NewRequest(method, target, nil) + } + r.Header.Set("Content-Type", "application/json") + + if len(params) > 0 { + rctx := chi.NewRouteContext() + for k, v := range params { + rctx.URLParams.Add(k, v) + } + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) + } + return r +} + +// decodeJSON decodes the response body into v. +func decodeJSON(t *testing.T, rec *httptest.ResponseRecorder, v any) { + t.Helper() + if err := json.NewDecoder(rec.Body).Decode(v); err != nil { + t.Fatalf("decode json: %v", err) + } +} + +// ---- Zone handler tests ---- + +func TestListZonesHandler_SeedOnly(t *testing.T) { + database := openTestDB(t) + + handler := listZonesHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones", "", nil) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string][]db.Zone + decodeJSON(t, rec, &resp) + zones := resp["zones"] + if len(zones) != 2 { + t.Fatalf("got %d zones, want 2 (seed zones)", len(zones)) + } +} + +func TestListZonesHandler_Populated(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + handler := listZonesHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones", "", nil) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string][]db.Zone + decodeJSON(t, rec, &resp) + zones := resp["zones"] + // 2 seed + 1 created = 3. + if len(zones) != 3 { + t.Fatalf("got %d zones, want 3", len(zones)) + } +} + +func TestGetZoneHandler_Found(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + handler := getZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com", "", map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var zone db.Zone + decodeJSON(t, rec, &zone) + if zone.Name != "test.example.com" { + t.Fatalf("zone name = %q, want %q", zone.Name, "test.example.com") + } +} + +func TestGetZoneHandler_NotFound(t *testing.T) { + database := openTestDB(t) + + handler := getZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones/nonexistent.com", "", map[string]string{"zone": "nonexistent.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestCreateZoneHandler_Success(t *testing.T) { + database := openTestDB(t) + + body := `{"name":"new.example.com","primary_ns":"ns1.example.com.","admin_email":"admin.example.com."}` + handler := createZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones", body, nil) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String()) + } + + var zone db.Zone + decodeJSON(t, rec, &zone) + if zone.Name != "new.example.com" { + t.Fatalf("zone name = %q, want %q", zone.Name, "new.example.com") + } + if zone.PrimaryNS != "ns1.example.com." { + t.Fatalf("primary_ns = %q, want %q", zone.PrimaryNS, "ns1.example.com.") + } + // SOA defaults should be applied. + if zone.Refresh != 3600 { + t.Fatalf("refresh = %d, want 3600", zone.Refresh) + } +} + +func TestCreateZoneHandler_MissingFields(t *testing.T) { + tests := []struct { + name string + body string + }{ + {"missing name", `{"primary_ns":"ns1.example.com.","admin_email":"admin.example.com."}`}, + {"missing primary_ns", `{"name":"new.example.com","admin_email":"admin.example.com."}`}, + {"missing admin_email", `{"name":"new.example.com","primary_ns":"ns1.example.com."}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + database := openTestDB(t) + handler := createZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones", tt.body, nil) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + } +} + +func TestCreateZoneHandler_Duplicate(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + body := `{"name":"test.example.com","primary_ns":"ns1.example.com.","admin_email":"admin.example.com."}` + handler := createZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones", body, nil) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusConflict { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusConflict) + } +} + +func TestCreateZoneHandler_InvalidJSON(t *testing.T) { + database := openTestDB(t) + + handler := createZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones", "not json", nil) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestUpdateZoneHandler_Success(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + body := `{"primary_ns":"ns2.example.com.","admin_email":"newadmin.example.com.","refresh":7200}` + handler := updateZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com", body, map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var zone db.Zone + decodeJSON(t, rec, &zone) + if zone.PrimaryNS != "ns2.example.com." { + t.Fatalf("primary_ns = %q, want %q", zone.PrimaryNS, "ns2.example.com.") + } + if zone.Refresh != 7200 { + t.Fatalf("refresh = %d, want 7200", zone.Refresh) + } +} + +func TestUpdateZoneHandler_NotFound(t *testing.T) { + database := openTestDB(t) + + body := `{"primary_ns":"ns2.example.com.","admin_email":"admin.example.com."}` + handler := updateZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPut, "/v1/zones/nonexistent.com", body, map[string]string{"zone": "nonexistent.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestUpdateZoneHandler_MissingFields(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + body := `{"admin_email":"admin.example.com."}` + handler := updateZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com", body, map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestDeleteZoneHandler_Success(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + handler := deleteZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com", "", map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) + } + + // Verify the zone is gone. + _, err := database.GetZone("test.example.com") + if err != db.ErrNotFound { + t.Fatalf("expected ErrNotFound after delete, got %v", err) + } +} + +func TestDeleteZoneHandler_NotFound(t *testing.T) { + database := openTestDB(t) + + handler := deleteZoneHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodDelete, "/v1/zones/nonexistent.com", "", map[string]string{"zone": "nonexistent.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +// ---- Record handler tests ---- + +func TestListRecordsHandler_WithZone(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + _, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) + if err != nil { + t.Fatalf("create record: %v", err) + } + _, err = database.CreateRecord("test.example.com", "mail", "A", "10.0.0.2", 300) + if err != nil { + t.Fatalf("create record: %v", err) + } + + handler := listRecordsHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records", "", map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string][]db.Record + decodeJSON(t, rec, &resp) + records := resp["records"] + if len(records) != 2 { + t.Fatalf("got %d records, want 2", len(records)) + } +} + +func TestListRecordsHandler_ZoneNotFound(t *testing.T) { + database := openTestDB(t) + + handler := listRecordsHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones/nonexistent.com/records", "", map[string]string{"zone": "nonexistent.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestListRecordsHandler_EmptyZone(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + handler := listRecordsHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records", "", map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string][]db.Record + decodeJSON(t, rec, &resp) + records := resp["records"] + if len(records) != 0 { + t.Fatalf("got %d records, want 0", len(records)) + } +} + +func TestListRecordsHandler_WithFilters(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + _, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) + if err != nil { + t.Fatalf("create record: %v", err) + } + _, err = database.CreateRecord("test.example.com", "www", "A", "10.0.0.2", 300) + if err != nil { + t.Fatalf("create record: %v", err) + } + _, err = database.CreateRecord("test.example.com", "mail", "A", "10.0.0.3", 300) + if err != nil { + t.Fatalf("create record: %v", err) + } + + handler := listRecordsHandler(database) + + // Filter by name. + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records?name=www", "", map[string]string{"zone": "test.example.com"}) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var resp map[string][]db.Record + decodeJSON(t, rec, &resp) + if len(resp["records"]) != 2 { + t.Fatalf("got %d records for name=www, want 2", len(resp["records"])) + } +} + +func TestGetRecordHandler_Found(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) + if err != nil { + t.Fatalf("create record: %v", err) + } + + handler := getRecordHandler(database) + rec := httptest.NewRecorder() + idStr := fmt.Sprintf("%d", created.ID) + req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records/"+idStr, "", map[string]string{ + "zone": "test.example.com", + "id": idStr, + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var record db.Record + decodeJSON(t, rec, &record) + if record.Name != "www" { + t.Fatalf("record name = %q, want %q", record.Name, "www") + } + if record.Value != "10.0.0.1" { + t.Fatalf("record value = %q, want %q", record.Value, "10.0.0.1") + } +} + +func TestGetRecordHandler_NotFound(t *testing.T) { + database := openTestDB(t) + + handler := getRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records/99999", "", map[string]string{ + "zone": "test.example.com", + "id": "99999", + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestGetRecordHandler_InvalidID(t *testing.T) { + database := openTestDB(t) + + handler := getRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records/abc", "", map[string]string{ + "zone": "test.example.com", + "id": "abc", + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestCreateRecordHandler_Success(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + body := `{"name":"www","type":"A","value":"10.0.0.1","ttl":600}` + handler := createRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", body, map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String()) + } + + var record db.Record + decodeJSON(t, rec, &record) + if record.Name != "www" { + t.Fatalf("record name = %q, want %q", record.Name, "www") + } + if record.Type != "A" { + t.Fatalf("record type = %q, want %q", record.Type, "A") + } + if record.TTL != 600 { + t.Fatalf("ttl = %d, want 600", record.TTL) + } +} + +func TestCreateRecordHandler_MissingFields(t *testing.T) { + tests := []struct { + name string + body string + }{ + {"missing name", `{"type":"A","value":"10.0.0.1"}`}, + {"missing type", `{"name":"www","value":"10.0.0.1"}`}, + {"missing value", `{"name":"www","type":"A"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + handler := createRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", tt.body, map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + } +} + +func TestCreateRecordHandler_InvalidIP(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + body := `{"name":"www","type":"A","value":"not-an-ip"}` + handler := createRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", body, map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestCreateRecordHandler_CNAMEConflict(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + // Create an A record first. + _, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) + if err != nil { + t.Fatalf("create A record: %v", err) + } + + // Try to create a CNAME for the same name via handler. + body := `{"name":"www","type":"CNAME","value":"other.example.com."}` + handler := createRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", body, map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusConflict { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusConflict) + } +} + +func TestCreateRecordHandler_ZoneNotFound(t *testing.T) { + database := openTestDB(t) + + body := `{"name":"www","type":"A","value":"10.0.0.1"}` + handler := createRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones/nonexistent.com/records", body, map[string]string{"zone": "nonexistent.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestCreateRecordHandler_InvalidJSON(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + handler := createRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", "not json", map[string]string{"zone": "test.example.com"}) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestUpdateRecordHandler_Success(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) + if err != nil { + t.Fatalf("create record: %v", err) + } + + idStr := fmt.Sprintf("%d", created.ID) + body := `{"name":"www","type":"A","value":"10.0.0.2","ttl":600}` + handler := updateRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/"+idStr, body, map[string]string{ + "zone": "test.example.com", + "id": idStr, + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var record db.Record + decodeJSON(t, rec, &record) + if record.Value != "10.0.0.2" { + t.Fatalf("value = %q, want %q", record.Value, "10.0.0.2") + } + if record.TTL != 600 { + t.Fatalf("ttl = %d, want 600", record.TTL) + } +} + +func TestUpdateRecordHandler_NotFound(t *testing.T) { + database := openTestDB(t) + + body := `{"name":"www","type":"A","value":"10.0.0.1"}` + handler := updateRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/99999", body, map[string]string{ + "zone": "test.example.com", + "id": "99999", + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestUpdateRecordHandler_InvalidID(t *testing.T) { + database := openTestDB(t) + + body := `{"name":"www","type":"A","value":"10.0.0.1"}` + handler := updateRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/abc", body, map[string]string{ + "zone": "test.example.com", + "id": "abc", + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestUpdateRecordHandler_MissingFields(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) + if err != nil { + t.Fatalf("create record: %v", err) + } + + idStr := fmt.Sprintf("%d", created.ID) + + // Missing name. + body := `{"type":"A","value":"10.0.0.1"}` + handler := updateRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/"+idStr, body, map[string]string{ + "zone": "test.example.com", + "id": idStr, + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestDeleteRecordHandler_Success(t *testing.T) { + database := openTestDB(t) + createTestZone(t, database) + + created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) + if err != nil { + t.Fatalf("create record: %v", err) + } + + idStr := fmt.Sprintf("%d", created.ID) + handler := deleteRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com/records/"+idStr, "", map[string]string{ + "zone": "test.example.com", + "id": idStr, + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) + } + + // Verify record is gone. + _, err = database.GetRecord(created.ID) + if err != db.ErrNotFound { + t.Fatalf("expected ErrNotFound after delete, got %v", err) + } +} + +func TestDeleteRecordHandler_NotFound(t *testing.T) { + database := openTestDB(t) + + handler := deleteRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com/records/99999", "", map[string]string{ + "zone": "test.example.com", + "id": "99999", + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) + } +} + +func TestDeleteRecordHandler_InvalidID(t *testing.T) { + database := openTestDB(t) + + handler := deleteRecordHandler(database) + rec := httptest.NewRecorder() + req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com/records/abc", "", map[string]string{ + "zone": "test.example.com", + "id": "abc", + }) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +// ---- Middleware tests ---- + +func TestRequireAdmin_WithAdminContext(t *testing.T) { + called := false + inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + + handler := requireAdmin(inner) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + // Inject admin TokenInfo into context. + info := &mcdslauth.TokenInfo{ + Username: "admin-user", + IsAdmin: true, + Roles: []string{"admin"}, + } + ctx := context.WithValue(req.Context(), tokenInfoKey, info) + req = req.WithContext(ctx) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + if !called { + t.Fatal("inner handler was not called") + } +} + +func TestRequireAdmin_WithNonAdminContext(t *testing.T) { + called := false + inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + + handler := requireAdmin(inner) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + // Inject non-admin TokenInfo into context. + info := &mcdslauth.TokenInfo{ + Username: "regular-user", + IsAdmin: false, + Roles: []string{"viewer"}, + } + ctx := context.WithValue(req.Context(), tokenInfoKey, info) + req = req.WithContext(ctx) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) + } + if called { + t.Fatal("inner handler should not have been called") + } +} + +func TestRequireAdmin_NoTokenInfo(t *testing.T) { + called := false + inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + + handler := requireAdmin(inner) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) + } + if called { + t.Fatal("inner handler should not have been called") + } +} + +func TestExtractBearerToken(t *testing.T) { + tests := []struct { + name string + header string + want string + }{ + {"valid bearer", "Bearer abc123", "abc123"}, + {"empty header", "", ""}, + {"no prefix", "abc123", ""}, + {"basic auth", "Basic abc123", ""}, + {"bearer with spaces", "Bearer token-with-space ", "token-with-space"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + if tt.header != "" { + r.Header.Set("Authorization", tt.header) + } + got := extractBearerToken(r) + if got != tt.want { + t.Fatalf("extractBearerToken(%q) = %q, want %q", tt.header, got, tt.want) + } + }) + } +} + +func TestTokenInfoFromContext(t *testing.T) { + // No token info in context. + ctx := context.Background() + if info := tokenInfoFromContext(ctx); info != nil { + t.Fatal("expected nil, got token info") + } + + // With token info. + expected := &mcdslauth.TokenInfo{Username: "testuser", IsAdmin: true} + ctx = context.WithValue(ctx, tokenInfoKey, expected) + got := tokenInfoFromContext(ctx) + if got == nil { + t.Fatal("expected token info, got nil") + } + if got.Username != expected.Username { + t.Fatalf("username = %q, want %q", got.Username, expected.Username) + } + if !got.IsAdmin { + t.Fatal("expected IsAdmin to be true") + } +} + +// ---- writeJSON / writeError tests ---- + +func TestWriteJSON(t *testing.T) { + rec := httptest.NewRecorder() + writeJSON(rec, http.StatusOK, map[string]string{"key": "value"}) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + t.Fatalf("content-type = %q, want %q", ct, "application/json") + } + + var resp map[string]string + decodeJSON(t, rec, &resp) + if resp["key"] != "value" { + t.Fatalf("got key=%q, want %q", resp["key"], "value") + } +} + +func TestWriteError(t *testing.T) { + rec := httptest.NewRecorder() + writeError(rec, http.StatusBadRequest, "bad input") + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + + var resp map[string]string + decodeJSON(t, rec, &resp) + if resp["error"] != "bad input" { + t.Fatalf("got error=%q, want %q", resp["error"], "bad input") + } +} diff --git a/proto/mcns/v1/admin.proto b/proto/mcns/v1/admin.proto index 9dcf262..f1a4e19 100644 --- a/proto/mcns/v1/admin.proto +++ b/proto/mcns/v1/admin.proto @@ -2,8 +2,9 @@ syntax = "proto3"; package mcns.v1; -option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"; +option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1"; +// AdminService exposes server health and administrative operations. service AdminService { rpc Health(HealthRequest) returns (HealthResponse); } diff --git a/proto/mcns/v1/auth.proto b/proto/mcns/v1/auth.proto index 25fe759..46cdcfc 100644 --- a/proto/mcns/v1/auth.proto +++ b/proto/mcns/v1/auth.proto @@ -2,8 +2,9 @@ syntax = "proto3"; package mcns.v1; -option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"; +option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1"; +// AuthService handles authentication by delegating to MCIAS. service AuthService { rpc Login(LoginRequest) returns (LoginResponse); rpc Logout(LogoutRequest) returns (LogoutResponse); @@ -12,6 +13,7 @@ service AuthService { message LoginRequest { string username = 1; string password = 2; + // TOTP code for two-factor authentication, if enabled on the account. string totp_code = 3; } diff --git a/proto/mcns/v1/record.proto b/proto/mcns/v1/record.proto index 6f9b2ef..2197f46 100644 --- a/proto/mcns/v1/record.proto +++ b/proto/mcns/v1/record.proto @@ -2,10 +2,11 @@ syntax = "proto3"; package mcns.v1; -option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"; +option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1"; import "google/protobuf/timestamp.proto"; +// RecordService manages DNS records within zones. service RecordService { rpc ListRecords(ListRecordsRequest) returns (ListRecordsResponse); rpc CreateRecord(CreateRecordRequest) returns (Record); @@ -16,8 +17,10 @@ service RecordService { message Record { int64 id = 1; + // Zone name this record belongs to (e.g. "example.com."). string zone = 2; string name = 3; + // DNS record type (A, AAAA, CNAME, MX, TXT, etc.). string type = 4; string value = 5; int32 ttl = 6; @@ -27,7 +30,9 @@ message Record { message ListRecordsRequest { string zone = 1; + // Optional filter by record name. string name = 2; + // Optional filter by record type (A, AAAA, CNAME, etc.). string type = 3; } @@ -36,6 +41,7 @@ message ListRecordsResponse { } message CreateRecordRequest { + // Zone name the record will be created in; must reference an existing zone. string zone = 1; string name = 2; string type = 3; diff --git a/proto/mcns/v1/zone.proto b/proto/mcns/v1/zone.proto index eb68aaf..65b63d4 100644 --- a/proto/mcns/v1/zone.proto +++ b/proto/mcns/v1/zone.proto @@ -2,10 +2,11 @@ syntax = "proto3"; package mcns.v1; -option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"; +option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1"; import "google/protobuf/timestamp.proto"; +// ZoneService manages DNS zones and their SOA parameters. service ZoneService { rpc ListZones(ListZonesRequest) returns (ListZonesResponse); rpc CreateZone(CreateZoneRequest) returns (Zone); From 82b7d295effc2c2d1c14e1303b04b86ca70ccef2 Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Thu, 26 Mar 2026 21:06:44 -0700 Subject: [PATCH 2/2] Add gRPC handler tests for zones, records, admin, and interceptors Full integration tests exercising gRPC services through real server with mock MCIAS auth. Covers all CRUD operations for zones and records, health check bypass, auth/admin interceptor enforcement, CNAME exclusivity conflicts, and method map completeness verification. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/grpcserver/handlers_test.go | 815 +++++++++++++++++++++++++++ 1 file changed, 815 insertions(+) create mode 100644 internal/grpcserver/handlers_test.go diff --git a/internal/grpcserver/handlers_test.go b/internal/grpcserver/handlers_test.go new file mode 100644 index 0000000..9104006 --- /dev/null +++ b/internal/grpcserver/handlers_test.go @@ -0,0 +1,815 @@ +package grpcserver + +import ( + "context" + "encoding/json" + "log/slog" + "net" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth" + + pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1" + "git.wntrmute.dev/kyle/mcns/internal/db" +) + +// mockMCIAS starts a fake MCIAS HTTP server for token validation. +// Recognized tokens: +// - "admin-token" -> valid, username=admin-uuid, roles=[admin] +// - "user-token" -> valid, username=user-uuid, roles=[user] +// - anything else -> invalid +func mockMCIAS(t *testing.T) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + mux.HandleFunc("POST /v1/token/validate", func(w http.ResponseWriter, r *http.Request) { + var req struct { + Token string `json:"token"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + switch req.Token { + case "admin-token": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "valid": true, + "username": "admin-uuid", + "account_type": "human", + "roles": []string{"admin"}, + }) + case "user-token": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "valid": true, + "username": "user-uuid", + "account_type": "human", + "roles": []string{"user"}, + }) + default: + _ = json.NewEncoder(w).Encode(map[string]interface{}{"valid": false}) + } + }) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + return srv +} + +// testAuthenticator creates an mcdsl/auth.Authenticator that talks to the given mock MCIAS. +func testAuthenticator(t *testing.T, serverURL string) *mcdslauth.Authenticator { + t.Helper() + a, err := mcdslauth.New(mcdslauth.Config{ServerURL: serverURL}, slog.Default()) + if err != nil { + t.Fatalf("auth.New: %v", err) + } + return a +} + +// openTestDB creates a temporary test database with migrations applied. +func openTestDB(t *testing.T) *db.DB { + t.Helper() + path := filepath.Join(t.TempDir(), "test.db") + d, err := db.Open(path) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { _ = d.Close() }) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + return d +} + +// startTestServer creates a gRPC server with auth interceptors and returns +// a connected client. Passing empty cert/key strings skips TLS. +func startTestServer(t *testing.T, deps Deps) *grpc.ClientConn { + t.Helper() + + srv, err := New("", "", deps, slog.Default()) + if err != nil { + t.Fatalf("New: %v", err) + } + + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen: %v", err) + } + + go func() { + _ = srv.Serve(lis) + }() + t.Cleanup(func() { srv.GracefulStop() }) + + //nolint:gosec // insecure credentials for testing only + cc, err := grpc.NewClient( + lis.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Dial: %v", err) + } + t.Cleanup(func() { _ = cc.Close() }) + + return cc +} + +// withAuth adds a bearer token to the outgoing context metadata. +func withAuth(ctx context.Context, token string) context.Context { + return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token) +} + +// seedZone creates a test zone and returns it. +func seedZone(t *testing.T, database *db.DB, name string) *db.Zone { + t.Helper() + zone, err := database.CreateZone(name, "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300) + if err != nil { + t.Fatalf("seed zone %q: %v", name, err) + } + return zone +} + +// seedRecord creates a test A record and returns it. +func seedRecord(t *testing.T, database *db.DB, zoneName, name, value string) *db.Record { + t.Helper() + rec, err := database.CreateRecord(zoneName, name, "A", value, 300) + if err != nil { + t.Fatalf("seed record %s.%s: %v", name, zoneName, err) + } + return rec +} + +// --------------------------------------------------------------------------- +// Admin tests +// --------------------------------------------------------------------------- + +func TestHealthBypassesAuth(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + client := pb.NewAdminServiceClient(cc) + + // No auth token -- should still succeed because Health is public. + resp, err := client.Health(context.Background(), &pb.HealthRequest{}) + if err != nil { + t.Fatalf("Health should not require auth: %v", err) + } + if resp.Status != "ok" { + t.Fatalf("Health status: got %q, want %q", resp.Status, "ok") + } +} + +// --------------------------------------------------------------------------- +// Zone tests +// --------------------------------------------------------------------------- + +func TestListZones(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "user-token") + client := pb.NewZoneServiceClient(cc) + + resp, err := client.ListZones(ctx, &pb.ListZonesRequest{}) + if err != nil { + t.Fatalf("ListZones: %v", err) + } + // Seed migration creates 2 zones. + if len(resp.Zones) != 2 { + t.Fatalf("got %d zones, want 2 (seed zones)", len(resp.Zones)) + } +} + +func TestGetZoneFound(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "user-token") + client := pb.NewZoneServiceClient(cc) + + zone, err := client.GetZone(ctx, &pb.GetZoneRequest{Name: "example.com"}) + if err != nil { + t.Fatalf("GetZone: %v", err) + } + if zone.Name != "example.com" { + t.Fatalf("got name %q, want %q", zone.Name, "example.com") + } +} + +func TestGetZoneNotFound(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "user-token") + client := pb.NewZoneServiceClient(cc) + + _, err := client.GetZone(ctx, &pb.GetZoneRequest{Name: "nonexistent.com"}) + if err == nil { + t.Fatal("expected error for nonexistent zone") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status, got %v", err) + } + if st.Code() != codes.NotFound { + t.Fatalf("code: got %v, want NotFound", st.Code()) + } +} + +func TestCreateZoneSuccess(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewZoneServiceClient(cc) + + zone, err := client.CreateZone(ctx, &pb.CreateZoneRequest{ + Name: "newzone.com", + PrimaryNs: "ns1.newzone.com.", + AdminEmail: "admin.newzone.com.", + Refresh: 3600, + Retry: 600, + Expire: 86400, + MinimumTtl: 300, + }) + if err != nil { + t.Fatalf("CreateZone: %v", err) + } + if zone.Name != "newzone.com" { + t.Fatalf("got name %q, want %q", zone.Name, "newzone.com") + } + if zone.Serial == 0 { + t.Fatal("serial should not be zero") + } +} + +func TestCreateZoneDuplicate(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewZoneServiceClient(cc) + + _, err := client.CreateZone(ctx, &pb.CreateZoneRequest{ + Name: "example.com", + PrimaryNs: "ns1.example.com.", + AdminEmail: "admin.example.com.", + }) + if err == nil { + t.Fatal("expected error for duplicate zone") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status, got %v", err) + } + if st.Code() != codes.AlreadyExists { + t.Fatalf("code: got %v, want AlreadyExists", st.Code()) + } +} + +func TestUpdateZone(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + original := seedZone(t, database, "example.com") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewZoneServiceClient(cc) + + updated, err := client.UpdateZone(ctx, &pb.UpdateZoneRequest{ + Name: "example.com", + PrimaryNs: "ns2.example.com.", + AdminEmail: "newadmin.example.com.", + Refresh: 7200, + Retry: 1200, + Expire: 172800, + MinimumTtl: 600, + }) + if err != nil { + t.Fatalf("UpdateZone: %v", err) + } + if updated.PrimaryNs != "ns2.example.com." { + t.Fatalf("got primary_ns %q, want %q", updated.PrimaryNs, "ns2.example.com.") + } + if updated.Serial <= original.Serial { + t.Fatalf("serial should have incremented: %d <= %d", updated.Serial, original.Serial) + } +} + +func TestDeleteZone(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewZoneServiceClient(cc) + + _, err := client.DeleteZone(ctx, &pb.DeleteZoneRequest{Name: "example.com"}) + if err != nil { + t.Fatalf("DeleteZone: %v", err) + } + + // Verify it is gone. + _, err = client.GetZone(withAuth(context.Background(), "user-token"), &pb.GetZoneRequest{Name: "example.com"}) + if err == nil { + t.Fatal("expected NotFound after delete") + } + st, _ := status.FromError(err) + if st.Code() != codes.NotFound { + t.Fatalf("code: got %v, want NotFound", st.Code()) + } +} + +func TestDeleteZoneNotFound(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewZoneServiceClient(cc) + + _, err := client.DeleteZone(ctx, &pb.DeleteZoneRequest{Name: "nonexistent.com"}) + if err == nil { + t.Fatal("expected error for nonexistent zone") + } + st, _ := status.FromError(err) + if st.Code() != codes.NotFound { + t.Fatalf("code: got %v, want NotFound", st.Code()) + } +} + +// --------------------------------------------------------------------------- +// Record tests +// --------------------------------------------------------------------------- + +func TestListRecords(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + seedRecord(t, database, "example.com", "www", "10.0.0.1") + seedRecord(t, database, "example.com", "mail", "10.0.0.2") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "user-token") + client := pb.NewRecordServiceClient(cc) + + resp, err := client.ListRecords(ctx, &pb.ListRecordsRequest{Zone: "example.com"}) + if err != nil { + t.Fatalf("ListRecords: %v", err) + } + if len(resp.Records) != 2 { + t.Fatalf("got %d records, want 2", len(resp.Records)) + } +} + +func TestGetRecordFound(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + rec := seedRecord(t, database, "example.com", "www", "10.0.0.1") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "user-token") + client := pb.NewRecordServiceClient(cc) + + got, err := client.GetRecord(ctx, &pb.GetRecordRequest{Id: rec.ID}) + if err != nil { + t.Fatalf("GetRecord: %v", err) + } + if got.Name != "www" { + t.Fatalf("got name %q, want %q", got.Name, "www") + } + if got.Value != "10.0.0.1" { + t.Fatalf("got value %q, want %q", got.Value, "10.0.0.1") + } +} + +func TestGetRecordNotFound(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "user-token") + client := pb.NewRecordServiceClient(cc) + + _, err := client.GetRecord(ctx, &pb.GetRecordRequest{Id: 999999}) + if err == nil { + t.Fatal("expected error for nonexistent record") + } + st, _ := status.FromError(err) + if st.Code() != codes.NotFound { + t.Fatalf("code: got %v, want NotFound", st.Code()) + } +} + +func TestCreateRecordSuccess(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewRecordServiceClient(cc) + + rec, err := client.CreateRecord(ctx, &pb.CreateRecordRequest{ + Zone: "example.com", + Name: "www", + Type: "A", + Value: "10.0.0.1", + Ttl: 300, + }) + if err != nil { + t.Fatalf("CreateRecord: %v", err) + } + if rec.Name != "www" { + t.Fatalf("got name %q, want %q", rec.Name, "www") + } + if rec.Type != "A" { + t.Fatalf("got type %q, want %q", rec.Type, "A") + } +} + +func TestCreateRecordInvalidValue(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewRecordServiceClient(cc) + + _, err := client.CreateRecord(ctx, &pb.CreateRecordRequest{ + Zone: "example.com", + Name: "www", + Type: "A", + Value: "not-an-ip", + Ttl: 300, + }) + if err == nil { + t.Fatal("expected error for invalid A record value") + } + st, _ := status.FromError(err) + if st.Code() != codes.InvalidArgument { + t.Fatalf("code: got %v, want InvalidArgument", st.Code()) + } +} + +func TestCreateRecordCNAMEConflict(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + seedRecord(t, database, "example.com", "www", "10.0.0.1") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewRecordServiceClient(cc) + + // Try to create a CNAME for "www" which already has an A record. + _, err := client.CreateRecord(ctx, &pb.CreateRecordRequest{ + Zone: "example.com", + Name: "www", + Type: "CNAME", + Value: "other.example.com.", + Ttl: 300, + }) + if err == nil { + t.Fatal("expected error for CNAME conflict with existing A record") + } + st, _ := status.FromError(err) + if st.Code() != codes.AlreadyExists { + t.Fatalf("code: got %v, want AlreadyExists", st.Code()) + } +} + +func TestCreateRecordAConflictWithCNAME(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + + // Create a CNAME first. + _, err := database.CreateRecord("example.com", "alias", "CNAME", "target.example.com.", 300) + if err != nil { + t.Fatalf("seed CNAME: %v", err) + } + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewRecordServiceClient(cc) + + // Try to create an A record for "alias" which already has a CNAME. + _, err = client.CreateRecord(ctx, &pb.CreateRecordRequest{ + Zone: "example.com", + Name: "alias", + Type: "A", + Value: "10.0.0.1", + Ttl: 300, + }) + if err == nil { + t.Fatal("expected error for A record conflict with existing CNAME") + } + st, _ := status.FromError(err) + if st.Code() != codes.AlreadyExists { + t.Fatalf("code: got %v, want AlreadyExists", st.Code()) + } +} + +func TestUpdateRecord(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + rec := seedRecord(t, database, "example.com", "www", "10.0.0.1") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewRecordServiceClient(cc) + + updated, err := client.UpdateRecord(ctx, &pb.UpdateRecordRequest{ + Id: rec.ID, + Name: "www", + Type: "A", + Value: "10.0.0.2", + Ttl: 600, + }) + if err != nil { + t.Fatalf("UpdateRecord: %v", err) + } + if updated.Value != "10.0.0.2" { + t.Fatalf("got value %q, want %q", updated.Value, "10.0.0.2") + } + if updated.Ttl != 600 { + t.Fatalf("got ttl %d, want 600", updated.Ttl) + } +} + +func TestUpdateRecordNotFound(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewRecordServiceClient(cc) + + _, err := client.UpdateRecord(ctx, &pb.UpdateRecordRequest{ + Id: 999999, + Name: "www", + Type: "A", + Value: "10.0.0.1", + Ttl: 300, + }) + if err == nil { + t.Fatal("expected error for nonexistent record") + } + st, _ := status.FromError(err) + if st.Code() != codes.NotFound { + t.Fatalf("code: got %v, want NotFound", st.Code()) + } +} + +func TestDeleteRecord(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + seedZone(t, database, "example.com") + rec := seedRecord(t, database, "example.com", "www", "10.0.0.1") + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewRecordServiceClient(cc) + + _, err := client.DeleteRecord(ctx, &pb.DeleteRecordRequest{Id: rec.ID}) + if err != nil { + t.Fatalf("DeleteRecord: %v", err) + } + + // Verify it is gone. + _, err = client.GetRecord(withAuth(context.Background(), "user-token"), &pb.GetRecordRequest{Id: rec.ID}) + if err == nil { + t.Fatal("expected NotFound after delete") + } + st, _ := status.FromError(err) + if st.Code() != codes.NotFound { + t.Fatalf("code: got %v, want NotFound", st.Code()) + } +} + +func TestDeleteRecordNotFound(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewRecordServiceClient(cc) + + _, err := client.DeleteRecord(ctx, &pb.DeleteRecordRequest{Id: 999999}) + if err == nil { + t.Fatal("expected error for nonexistent record") + } + st, _ := status.FromError(err) + if st.Code() != codes.NotFound { + t.Fatalf("code: got %v, want NotFound", st.Code()) + } +} + +// --------------------------------------------------------------------------- +// Auth interceptor tests +// --------------------------------------------------------------------------- + +func TestAuthRequiredNoToken(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + client := pb.NewZoneServiceClient(cc) + + // No auth token on an auth-required method. + _, err := client.ListZones(context.Background(), &pb.ListZonesRequest{}) + if err == nil { + t.Fatal("expected error for unauthenticated request") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.Unauthenticated { + t.Fatalf("code: got %v, want Unauthenticated", st.Code()) + } +} + +func TestAuthRequiredInvalidToken(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "bad-token") + client := pb.NewZoneServiceClient(cc) + + _, err := client.ListZones(ctx, &pb.ListZonesRequest{}) + if err == nil { + t.Fatal("expected error for invalid token") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.Unauthenticated { + t.Fatalf("code: got %v, want Unauthenticated", st.Code()) + } +} + +func TestAdminRequiredDeniedForUser(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "user-token") + client := pb.NewZoneServiceClient(cc) + + // CreateZone requires admin. + _, err := client.CreateZone(ctx, &pb.CreateZoneRequest{ + Name: "forbidden.com", + PrimaryNs: "ns1.forbidden.com.", + AdminEmail: "admin.forbidden.com.", + }) + if err == nil { + t.Fatal("expected error for non-admin user") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("expected gRPC status error, got %v", err) + } + if st.Code() != codes.PermissionDenied { + t.Fatalf("code: got %v, want PermissionDenied", st.Code()) + } +} + +func TestAdminRequiredAllowedForAdmin(t *testing.T) { + mcias := mockMCIAS(t) + auth := testAuthenticator(t, mcias.URL) + database := openTestDB(t) + + cc := startTestServer(t, Deps{DB: database, Authenticator: auth}) + ctx := withAuth(context.Background(), "admin-token") + client := pb.NewZoneServiceClient(cc) + + // Admin should be able to create zones. + zone, err := client.CreateZone(ctx, &pb.CreateZoneRequest{ + Name: "admin-created.com", + PrimaryNs: "ns1.admin-created.com.", + AdminEmail: "admin.admin-created.com.", + }) + if err != nil { + t.Fatalf("CreateZone as admin: %v", err) + } + if zone.Name != "admin-created.com" { + t.Fatalf("got name %q, want %q", zone.Name, "admin-created.com") + } +} + +// --------------------------------------------------------------------------- +// Interceptor map completeness test +// --------------------------------------------------------------------------- + +func TestMethodMapCompleteness(t *testing.T) { + mm := methodMap() + + expectedPublic := []string{ + "/mcns.v1.AdminService/Health", + "/mcns.v1.AuthService/Login", + } + for _, method := range expectedPublic { + if !mm.Public[method] { + t.Errorf("method %s should be public but is not in Public", method) + } + } + if len(mm.Public) != len(expectedPublic) { + t.Errorf("Public has %d entries, expected %d", len(mm.Public), len(expectedPublic)) + } + + expectedAuth := []string{ + "/mcns.v1.AuthService/Logout", + "/mcns.v1.ZoneService/ListZones", + "/mcns.v1.ZoneService/GetZone", + "/mcns.v1.RecordService/ListRecords", + "/mcns.v1.RecordService/GetRecord", + } + for _, method := range expectedAuth { + if !mm.AuthRequired[method] { + t.Errorf("method %s should require auth but is not in AuthRequired", method) + } + } + if len(mm.AuthRequired) != len(expectedAuth) { + t.Errorf("AuthRequired has %d entries, expected %d", len(mm.AuthRequired), len(expectedAuth)) + } + + expectedAdmin := []string{ + "/mcns.v1.ZoneService/CreateZone", + "/mcns.v1.ZoneService/UpdateZone", + "/mcns.v1.ZoneService/DeleteZone", + "/mcns.v1.RecordService/CreateRecord", + "/mcns.v1.RecordService/UpdateRecord", + "/mcns.v1.RecordService/DeleteRecord", + } + for _, method := range expectedAdmin { + if !mm.AdminRequired[method] { + t.Errorf("method %s should require admin but is not in AdminRequired", method) + } + } + if len(mm.AdminRequired) != len(expectedAdmin) { + t.Errorf("AdminRequired has %d entries, expected %d", len(mm.AdminRequired), len(expectedAdmin)) + } + + // Verify no method appears in multiple maps (each RPC in exactly one map). + all := make(map[string]int) + for k := range mm.Public { + all[k]++ + } + for k := range mm.AuthRequired { + all[k]++ + } + for k := range mm.AdminRequired { + all[k]++ + } + for method, count := range all { + if count != 1 { + t.Errorf("method %s appears in %d maps, expected exactly 1", method, count) + } + } +}