2 Commits

Author SHA1 Message Date
82b7d295ef Add gRPC handler tests for zones, records, admin, and interceptors
Full integration tests exercising gRPC services through real server with
mock MCIAS auth. Covers all CRUD operations for zones and records,
health check bypass, auth/admin interceptor enforcement, CNAME
exclusivity conflicts, and method map completeness verification.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 21:06:44 -07:00
4ec0c3a916 Add REST API handler tests for zones, records, and middleware
Cover all REST handlers with httptest-based tests using real SQLite:
zones (list, get, create, update, delete), records (list, get, create,
update, delete with validation/conflict cases), requireAdmin middleware
(admin, non-admin, missing context), and utility functions (writeJSON,
writeError, extractBearerToken, tokenInfoFromContext).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 21:05:54 -07:00
17 changed files with 1833 additions and 318 deletions

3
.gitignore vendored
View File

@@ -3,3 +3,6 @@ srv/
*.db *.db
*.db-wal *.db-wal
*.db-shm *.db-shm
.idea/
.vscode/
.DS_Store

View File

@@ -1,42 +0,0 @@
# MCNS
Metacircular Networking Service -- an authoritative DNS server for the
Metacircular platform. MCNS serves DNS zones backed by SQLite, forwards
non-authoritative queries to upstream resolvers, and exposes a gRPC and
REST management API authenticated through MCIAS. Records are updated
dynamically via the API and visible to DNS immediately on commit.
## Quick Start
Build the binary:
```bash
make mcns
```
Copy and edit the example configuration:
```bash
cp deploy/examples/mcns.toml /srv/mcns/mcns.toml
# Edit TLS paths, database path, MCIAS URL, upstream resolvers
```
Run the server:
```bash
./mcns server --config /srv/mcns/mcns.toml
```
The server starts three listeners:
| Port | Protocol | Purpose |
|------|----------|---------|
| 53 | UDP + TCP | DNS (no auth) |
| 8443 | TCP | REST management API (TLS, MCIAS auth) |
| 9443 | TCP | gRPC management API (TLS, MCIAS auth) |
## Documentation
- [ARCHITECTURE.md](ARCHITECTURE.md) -- full technical specification, database schema, API surface, and security model.
- [RUNBOOK.md](RUNBOOK.md) -- operational procedures and incident response for operators.
- [CLAUDE.md](CLAUDE.md) -- context for AI-assisted development.

View File

@@ -1,242 +0,0 @@
# MCNS Runbook
## Service Overview
MCNS is an authoritative DNS server for the Metacircular platform. It
listens on port 53 (UDP+TCP) for DNS queries, port 8443 for the REST
management API, and port 9443 for the gRPC management API. Zone and
record data is stored in SQLite. All management operations require MCIAS
authentication; DNS queries are unauthenticated.
## Health Checks
### CLI
```bash
mcns status --addr https://localhost:8443
```
With a custom CA certificate:
```bash
mcns status --addr https://localhost:8443 --ca-cert /srv/mcns/certs/ca.pem
```
Expected output: `ok`
### REST
```bash
curl -k https://localhost:8443/v1/health
```
Expected: HTTP 200.
### gRPC
Use the `AdminService.Health` RPC on port 9443. This method is public
(no auth required).
### DNS
```bash
dig @localhost svc.mcp.metacircular.net SOA +short
```
A valid SOA response confirms the DNS listener and database are working.
## Common Operations
### Start the Service
1. Verify config exists: `ls /srv/mcns/mcns.toml`
2. Start the container:
```bash
docker compose -f deploy/docker/docker-compose-rift.yml up -d
```
3. Verify health:
```bash
mcns status --addr https://localhost:8443
```
### Stop the Service
1. Stop the container:
```bash
docker compose -f deploy/docker/docker-compose-rift.yml stop mcns
```
2. MCNS handles SIGTERM gracefully and drains in-flight requests (30s timeout).
### Restart the Service
1. Restart the container:
```bash
docker compose -f deploy/docker/docker-compose-rift.yml restart mcns
```
2. Verify health:
```bash
mcns status --addr https://localhost:8443
```
### Backup (Snapshot)
1. Run the snapshot command:
```bash
mcns snapshot --config /srv/mcns/mcns.toml
```
2. The snapshot is saved to `/srv/mcns/backups/mcns-YYYYMMDD-HHMMSS.db`.
3. Verify the snapshot file exists and has a reasonable size:
```bash
ls -lh /srv/mcns/backups/
```
### Restore from Snapshot
1. Stop the service (see above).
2. Back up the current database:
```bash
cp /srv/mcns/mcns.db /srv/mcns/mcns.db.pre-restore
```
3. Copy the snapshot into place:
```bash
cp /srv/mcns/backups/mcns-YYYYMMDD-HHMMSS.db /srv/mcns/mcns.db
```
4. Start the service (see above).
5. Verify the service is healthy:
```bash
mcns status --addr https://localhost:8443
```
6. Verify zones are accessible by querying DNS:
```bash
dig @localhost svc.mcp.metacircular.net SOA +short
```
### Log Inspection
Container logs:
```bash
docker compose -f deploy/docker/docker-compose-rift.yml logs --tail 100 mcns
```
Follow logs in real time:
```bash
docker compose -f deploy/docker/docker-compose-rift.yml logs -f mcns
```
MCNS logs to stderr as structured text (slog). Log level is configured
via `[log] level` in `mcns.toml` (debug, info, warn, error).
## Incident Procedures
### Database Corruption
Symptoms: server fails to start with SQLite errors, or queries return
unexpected errors.
1. Stop the service.
2. Check for WAL/SHM files alongside the database:
```bash
ls -la /srv/mcns/mcns.db*
```
3. Attempt an integrity check:
```bash
sqlite3 /srv/mcns/mcns.db "PRAGMA integrity_check;"
```
4. If integrity check fails, restore from the most recent snapshot:
```bash
cp /srv/mcns/mcns.db /srv/mcns/mcns.db.corrupt
cp /srv/mcns/backups/mcns-YYYYMMDD-HHMMSS.db /srv/mcns/mcns.db
```
5. Start the service and verify health.
6. Re-create any records added after the snapshot was taken.
### Certificate Expiry
Symptoms: health check fails with TLS errors, API clients get
certificate errors.
1. Check certificate expiry:
```bash
openssl x509 -in /srv/mcns/certs/cert.pem -noout -enddate
```
2. Replace the certificate and key files at the paths in `mcns.toml`.
3. Restart the service to load the new certificate.
4. Verify health:
```bash
mcns status --addr https://localhost:8443
```
### MCIAS Outage
Symptoms: management API returns 502 or authentication errors. DNS
continues to work normally (DNS has no auth dependency).
1. Confirm MCIAS is unreachable:
```bash
curl -k https://svc.metacircular.net:8443/v1/health
```
2. DNS resolution is unaffected -- no immediate action needed for DNS.
3. Management operations (zone/record create/update/delete) will fail
until MCIAS recovers.
4. Escalate to MCIAS (see Escalation below).
### DNS Not Resolving
Symptoms: `dig @<server> <name>` returns SERVFAIL or times out.
1. Verify the service is running:
```bash
docker compose -f deploy/docker/docker-compose-rift.yml ps mcns
```
2. Check that port 53 is listening:
```bash
ss -ulnp | grep ':53'
ss -tlnp | grep ':53'
```
3. Test an authoritative query:
```bash
dig @localhost svc.mcp.metacircular.net SOA
```
4. Test a forwarded query:
```bash
dig @localhost example.com A
```
5. If authoritative queries fail but forwarding works, the database may
be corrupt (see Database Corruption above).
6. If forwarding fails, check upstream connectivity:
```bash
dig @1.1.1.1 example.com A
```
7. Check logs for errors:
```bash
docker compose -f deploy/docker/docker-compose-rift.yml logs --tail 50 mcns
```
### Port 53 Already in Use
Symptoms: MCNS fails to start with "address already in use" on port 53.
1. Identify what is using the port:
```bash
ss -ulnp | grep ':53'
ss -tlnp | grep ':53'
```
2. Common culprit: `systemd-resolved` listening on `127.0.0.53:53`.
- If on a system with systemd-resolved, either disable it or bind
MCNS to a specific IP instead of `0.0.0.0:53`.
3. If another DNS server is running, stop it or change the MCNS
`[dns] listen_addr` in `mcns.toml` to a different address.
4. Restart MCNS and verify DNS is responding.
## Escalation
Escalate when:
- Database corruption cannot be resolved by restoring a snapshot.
- MCIAS is down and management operations are urgently needed.
- DNS resolution failures persist after following the procedures above.
- Any issue not covered by this runbook.
Escalation path: Kyle (platform owner).

View File

@@ -4,7 +4,7 @@
// protoc v6.32.1 // protoc v6.32.1
// source: proto/mcns/v1/admin.proto // source: proto/mcns/v1/admin.proto
package v1 package mcnsv1
import ( import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
@@ -110,7 +110,7 @@ const file_proto_mcns_v1_admin_proto_rawDesc = "" +
"\x0eHealthResponse\x12\x16\n" + "\x0eHealthResponse\x12\x16\n" +
"\x06status\x18\x01 \x01(\tR\x06status2I\n" + "\x06status\x18\x01 \x01(\tR\x06status2I\n" +
"\fAdminService\x129\n" + "\fAdminService\x129\n" +
"\x06Health\x12\x16.mcns.v1.HealthRequest\x1a\x17.mcns.v1.HealthResponseB(Z&git.wntrmute.dev/kyle/mcns/gen/mcns/v1b\x06proto3" "\x06Health\x12\x16.mcns.v1.HealthRequest\x1a\x17.mcns.v1.HealthResponseB/Z-git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1b\x06proto3"
var ( var (
file_proto_mcns_v1_admin_proto_rawDescOnce sync.Once file_proto_mcns_v1_admin_proto_rawDescOnce sync.Once

View File

@@ -4,7 +4,7 @@
// - protoc v6.32.1 // - protoc v6.32.1
// source: proto/mcns/v1/admin.proto // source: proto/mcns/v1/admin.proto
package v1 package mcnsv1
import ( import (
context "context" context "context"
@@ -25,6 +25,8 @@ const (
// AdminServiceClient is the client API for AdminService service. // AdminServiceClient is the client API for AdminService service.
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
//
// AdminService exposes server health and administrative operations.
type AdminServiceClient interface { type AdminServiceClient interface {
Health(ctx context.Context, in *HealthRequest, opts ...grpc.CallOption) (*HealthResponse, error) Health(ctx context.Context, in *HealthRequest, opts ...grpc.CallOption) (*HealthResponse, error)
} }
@@ -50,6 +52,8 @@ func (c *adminServiceClient) Health(ctx context.Context, in *HealthRequest, opts
// AdminServiceServer is the server API for AdminService service. // AdminServiceServer is the server API for AdminService service.
// All implementations must embed UnimplementedAdminServiceServer // All implementations must embed UnimplementedAdminServiceServer
// for forward compatibility. // for forward compatibility.
//
// AdminService exposes server health and administrative operations.
type AdminServiceServer interface { type AdminServiceServer interface {
Health(context.Context, *HealthRequest) (*HealthResponse, error) Health(context.Context, *HealthRequest) (*HealthResponse, error)
mustEmbedUnimplementedAdminServiceServer() mustEmbedUnimplementedAdminServiceServer()

View File

@@ -4,7 +4,7 @@
// protoc v6.32.1 // protoc v6.32.1
// source: proto/mcns/v1/auth.proto // source: proto/mcns/v1/auth.proto
package v1 package mcnsv1
import ( import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
@@ -22,10 +22,11 @@ const (
) )
type LoginRequest struct { type LoginRequest struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"`
TotpCode string `protobuf:"bytes,3,opt,name=totp_code,json=totpCode,proto3" json:"totp_code,omitempty"` // TOTP code for two-factor authentication, if enabled on the account.
TotpCode string `protobuf:"bytes,3,opt,name=totp_code,json=totpCode,proto3" json:"totp_code,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -221,7 +222,7 @@ const file_proto_mcns_v1_auth_proto_rawDesc = "" +
"\x0eLogoutResponse2\x80\x01\n" + "\x0eLogoutResponse2\x80\x01\n" +
"\vAuthService\x126\n" + "\vAuthService\x126\n" +
"\x05Login\x12\x15.mcns.v1.LoginRequest\x1a\x16.mcns.v1.LoginResponse\x129\n" + "\x05Login\x12\x15.mcns.v1.LoginRequest\x1a\x16.mcns.v1.LoginResponse\x129\n" +
"\x06Logout\x12\x16.mcns.v1.LogoutRequest\x1a\x17.mcns.v1.LogoutResponseB(Z&git.wntrmute.dev/kyle/mcns/gen/mcns/v1b\x06proto3" "\x06Logout\x12\x16.mcns.v1.LogoutRequest\x1a\x17.mcns.v1.LogoutResponseB/Z-git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1b\x06proto3"
var ( var (
file_proto_mcns_v1_auth_proto_rawDescOnce sync.Once file_proto_mcns_v1_auth_proto_rawDescOnce sync.Once

View File

@@ -4,7 +4,7 @@
// - protoc v6.32.1 // - protoc v6.32.1
// source: proto/mcns/v1/auth.proto // source: proto/mcns/v1/auth.proto
package v1 package mcnsv1
import ( import (
context "context" context "context"
@@ -26,6 +26,8 @@ const (
// AuthServiceClient is the client API for AuthService service. // AuthServiceClient is the client API for AuthService service.
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
//
// AuthService handles authentication by delegating to MCIAS.
type AuthServiceClient interface { type AuthServiceClient interface {
Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error)
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
@@ -62,6 +64,8 @@ func (c *authServiceClient) Logout(ctx context.Context, in *LogoutRequest, opts
// AuthServiceServer is the server API for AuthService service. // AuthServiceServer is the server API for AuthService service.
// All implementations must embed UnimplementedAuthServiceServer // All implementations must embed UnimplementedAuthServiceServer
// for forward compatibility. // for forward compatibility.
//
// AuthService handles authentication by delegating to MCIAS.
type AuthServiceServer interface { type AuthServiceServer interface {
Login(context.Context, *LoginRequest) (*LoginResponse, error) Login(context.Context, *LoginRequest) (*LoginResponse, error)
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)

View File

@@ -4,7 +4,7 @@
// protoc v6.32.1 // protoc v6.32.1
// source: proto/mcns/v1/record.proto // source: proto/mcns/v1/record.proto
package v1 package mcnsv1
import ( import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
@@ -23,10 +23,12 @@ const (
) )
type Record struct { type Record struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Id int64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` Id int64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"`
Zone string `protobuf:"bytes,2,opt,name=zone,proto3" json:"zone,omitempty"` // Zone name this record belongs to (e.g. "example.com.").
Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"` Zone string `protobuf:"bytes,2,opt,name=zone,proto3" json:"zone,omitempty"`
Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"`
// DNS record type (A, AAAA, CNAME, MX, TXT, etc.).
Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"` Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"`
Value string `protobuf:"bytes,5,opt,name=value,proto3" json:"value,omitempty"` Value string `protobuf:"bytes,5,opt,name=value,proto3" json:"value,omitempty"`
Ttl int32 `protobuf:"varint,6,opt,name=ttl,proto3" json:"ttl,omitempty"` Ttl int32 `protobuf:"varint,6,opt,name=ttl,proto3" json:"ttl,omitempty"`
@@ -123,10 +125,12 @@ func (x *Record) GetUpdatedAt() *timestamppb.Timestamp {
} }
type ListRecordsRequest struct { type ListRecordsRequest struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Zone string `protobuf:"bytes,1,opt,name=zone,proto3" json:"zone,omitempty"` Zone string `protobuf:"bytes,1,opt,name=zone,proto3" json:"zone,omitempty"`
Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` // Optional filter by record name.
Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"` Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
// Optional filter by record type (A, AAAA, CNAME, etc.).
Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -227,12 +231,13 @@ func (x *ListRecordsResponse) GetRecords() []*Record {
} }
type CreateRecordRequest struct { type CreateRecordRequest struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Zone string `protobuf:"bytes,1,opt,name=zone,proto3" json:"zone,omitempty"` // Zone name the record will be created in; must reference an existing zone.
Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` Zone string `protobuf:"bytes,1,opt,name=zone,proto3" json:"zone,omitempty"`
Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"` Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
Value string `protobuf:"bytes,4,opt,name=value,proto3" json:"value,omitempty"` Type string `protobuf:"bytes,3,opt,name=type,proto3" json:"type,omitempty"`
Ttl int32 `protobuf:"varint,5,opt,name=ttl,proto3" json:"ttl,omitempty"` Value string `protobuf:"bytes,4,opt,name=value,proto3" json:"value,omitempty"`
Ttl int32 `protobuf:"varint,5,opt,name=ttl,proto3" json:"ttl,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -546,7 +551,7 @@ const file_proto_mcns_v1_record_proto_rawDesc = "" +
"\fCreateRecord\x12\x1c.mcns.v1.CreateRecordRequest\x1a\x0f.mcns.v1.Record\x127\n" + "\fCreateRecord\x12\x1c.mcns.v1.CreateRecordRequest\x1a\x0f.mcns.v1.Record\x127\n" +
"\tGetRecord\x12\x19.mcns.v1.GetRecordRequest\x1a\x0f.mcns.v1.Record\x12=\n" + "\tGetRecord\x12\x19.mcns.v1.GetRecordRequest\x1a\x0f.mcns.v1.Record\x12=\n" +
"\fUpdateRecord\x12\x1c.mcns.v1.UpdateRecordRequest\x1a\x0f.mcns.v1.Record\x12K\n" + "\fUpdateRecord\x12\x1c.mcns.v1.UpdateRecordRequest\x1a\x0f.mcns.v1.Record\x12K\n" +
"\fDeleteRecord\x12\x1c.mcns.v1.DeleteRecordRequest\x1a\x1d.mcns.v1.DeleteRecordResponseB(Z&git.wntrmute.dev/kyle/mcns/gen/mcns/v1b\x06proto3" "\fDeleteRecord\x12\x1c.mcns.v1.DeleteRecordRequest\x1a\x1d.mcns.v1.DeleteRecordResponseB/Z-git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1b\x06proto3"
var ( var (
file_proto_mcns_v1_record_proto_rawDescOnce sync.Once file_proto_mcns_v1_record_proto_rawDescOnce sync.Once

View File

@@ -4,7 +4,7 @@
// - protoc v6.32.1 // - protoc v6.32.1
// source: proto/mcns/v1/record.proto // source: proto/mcns/v1/record.proto
package v1 package mcnsv1
import ( import (
context "context" context "context"
@@ -29,6 +29,8 @@ const (
// RecordServiceClient is the client API for RecordService service. // RecordServiceClient is the client API for RecordService service.
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
//
// RecordService manages DNS records within zones.
type RecordServiceClient interface { type RecordServiceClient interface {
ListRecords(ctx context.Context, in *ListRecordsRequest, opts ...grpc.CallOption) (*ListRecordsResponse, error) ListRecords(ctx context.Context, in *ListRecordsRequest, opts ...grpc.CallOption) (*ListRecordsResponse, error)
CreateRecord(ctx context.Context, in *CreateRecordRequest, opts ...grpc.CallOption) (*Record, error) CreateRecord(ctx context.Context, in *CreateRecordRequest, opts ...grpc.CallOption) (*Record, error)
@@ -98,6 +100,8 @@ func (c *recordServiceClient) DeleteRecord(ctx context.Context, in *DeleteRecord
// RecordServiceServer is the server API for RecordService service. // RecordServiceServer is the server API for RecordService service.
// All implementations must embed UnimplementedRecordServiceServer // All implementations must embed UnimplementedRecordServiceServer
// for forward compatibility. // for forward compatibility.
//
// RecordService manages DNS records within zones.
type RecordServiceServer interface { type RecordServiceServer interface {
ListRecords(context.Context, *ListRecordsRequest) (*ListRecordsResponse, error) ListRecords(context.Context, *ListRecordsRequest) (*ListRecordsResponse, error)
CreateRecord(context.Context, *CreateRecordRequest) (*Record, error) CreateRecord(context.Context, *CreateRecordRequest) (*Record, error)

View File

@@ -4,7 +4,7 @@
// protoc v6.32.1 // protoc v6.32.1
// source: proto/mcns/v1/zone.proto // source: proto/mcns/v1/zone.proto
package v1 package mcnsv1
import ( import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
@@ -595,7 +595,7 @@ const file_proto_mcns_v1_zone_proto_rawDesc = "" +
"\n" + "\n" +
"UpdateZone\x12\x1a.mcns.v1.UpdateZoneRequest\x1a\r.mcns.v1.Zone\x12E\n" + "UpdateZone\x12\x1a.mcns.v1.UpdateZoneRequest\x1a\r.mcns.v1.Zone\x12E\n" +
"\n" + "\n" +
"DeleteZone\x12\x1a.mcns.v1.DeleteZoneRequest\x1a\x1b.mcns.v1.DeleteZoneResponseB(Z&git.wntrmute.dev/kyle/mcns/gen/mcns/v1b\x06proto3" "DeleteZone\x12\x1a.mcns.v1.DeleteZoneRequest\x1a\x1b.mcns.v1.DeleteZoneResponseB/Z-git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1b\x06proto3"
var ( var (
file_proto_mcns_v1_zone_proto_rawDescOnce sync.Once file_proto_mcns_v1_zone_proto_rawDescOnce sync.Once

View File

@@ -4,7 +4,7 @@
// - protoc v6.32.1 // - protoc v6.32.1
// source: proto/mcns/v1/zone.proto // source: proto/mcns/v1/zone.proto
package v1 package mcnsv1
import ( import (
context "context" context "context"
@@ -29,6 +29,8 @@ const (
// ZoneServiceClient is the client API for ZoneService service. // ZoneServiceClient is the client API for ZoneService service.
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
//
// ZoneService manages DNS zones and their SOA parameters.
type ZoneServiceClient interface { type ZoneServiceClient interface {
ListZones(ctx context.Context, in *ListZonesRequest, opts ...grpc.CallOption) (*ListZonesResponse, error) ListZones(ctx context.Context, in *ListZonesRequest, opts ...grpc.CallOption) (*ListZonesResponse, error)
CreateZone(ctx context.Context, in *CreateZoneRequest, opts ...grpc.CallOption) (*Zone, error) CreateZone(ctx context.Context, in *CreateZoneRequest, opts ...grpc.CallOption) (*Zone, error)
@@ -98,6 +100,8 @@ func (c *zoneServiceClient) DeleteZone(ctx context.Context, in *DeleteZoneReques
// ZoneServiceServer is the server API for ZoneService service. // ZoneServiceServer is the server API for ZoneService service.
// All implementations must embed UnimplementedZoneServiceServer // All implementations must embed UnimplementedZoneServiceServer
// for forward compatibility. // for forward compatibility.
//
// ZoneService manages DNS zones and their SOA parameters.
type ZoneServiceServer interface { type ZoneServiceServer interface {
ListZones(context.Context, *ListZonesRequest) (*ListZonesResponse, error) ListZones(context.Context, *ListZonesRequest) (*ListZonesResponse, error)
CreateZone(context.Context, *CreateZoneRequest) (*Zone, error) CreateZone(context.Context, *CreateZoneRequest) (*Zone, error)

View File

@@ -0,0 +1,815 @@
package grpcserver
import (
"context"
"encoding/json"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
// mockMCIAS starts a fake MCIAS HTTP server for token validation.
// Recognized tokens:
// - "admin-token" -> valid, username=admin-uuid, roles=[admin]
// - "user-token" -> valid, username=user-uuid, roles=[user]
// - anything else -> invalid
func mockMCIAS(t *testing.T) *httptest.Server {
t.Helper()
mux := http.NewServeMux()
mux.HandleFunc("POST /v1/token/validate", func(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
switch req.Token {
case "admin-token":
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"valid": true,
"username": "admin-uuid",
"account_type": "human",
"roles": []string{"admin"},
})
case "user-token":
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"valid": true,
"username": "user-uuid",
"account_type": "human",
"roles": []string{"user"},
})
default:
_ = json.NewEncoder(w).Encode(map[string]interface{}{"valid": false})
}
})
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
return srv
}
// testAuthenticator creates an mcdsl/auth.Authenticator that talks to the given mock MCIAS.
func testAuthenticator(t *testing.T, serverURL string) *mcdslauth.Authenticator {
t.Helper()
a, err := mcdslauth.New(mcdslauth.Config{ServerURL: serverURL}, slog.Default())
if err != nil {
t.Fatalf("auth.New: %v", err)
}
return a
}
// openTestDB creates a temporary test database with migrations applied.
func openTestDB(t *testing.T) *db.DB {
t.Helper()
path := filepath.Join(t.TempDir(), "test.db")
d, err := db.Open(path)
if err != nil {
t.Fatalf("Open: %v", err)
}
t.Cleanup(func() { _ = d.Close() })
if err := d.Migrate(); err != nil {
t.Fatalf("Migrate: %v", err)
}
return d
}
// startTestServer creates a gRPC server with auth interceptors and returns
// a connected client. Passing empty cert/key strings skips TLS.
func startTestServer(t *testing.T, deps Deps) *grpc.ClientConn {
t.Helper()
srv, err := New("", "", deps, slog.Default())
if err != nil {
t.Fatalf("New: %v", err)
}
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen: %v", err)
}
go func() {
_ = srv.Serve(lis)
}()
t.Cleanup(func() { srv.GracefulStop() })
//nolint:gosec // insecure credentials for testing only
cc, err := grpc.NewClient(
lis.Addr().String(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
t.Fatalf("Dial: %v", err)
}
t.Cleanup(func() { _ = cc.Close() })
return cc
}
// withAuth adds a bearer token to the outgoing context metadata.
func withAuth(ctx context.Context, token string) context.Context {
return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token)
}
// seedZone creates a test zone and returns it.
func seedZone(t *testing.T, database *db.DB, name string) *db.Zone {
t.Helper()
zone, err := database.CreateZone(name, "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("seed zone %q: %v", name, err)
}
return zone
}
// seedRecord creates a test A record and returns it.
func seedRecord(t *testing.T, database *db.DB, zoneName, name, value string) *db.Record {
t.Helper()
rec, err := database.CreateRecord(zoneName, name, "A", value, 300)
if err != nil {
t.Fatalf("seed record %s.%s: %v", name, zoneName, err)
}
return rec
}
// ---------------------------------------------------------------------------
// Admin tests
// ---------------------------------------------------------------------------
func TestHealthBypassesAuth(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
client := pb.NewAdminServiceClient(cc)
// No auth token -- should still succeed because Health is public.
resp, err := client.Health(context.Background(), &pb.HealthRequest{})
if err != nil {
t.Fatalf("Health should not require auth: %v", err)
}
if resp.Status != "ok" {
t.Fatalf("Health status: got %q, want %q", resp.Status, "ok")
}
}
// ---------------------------------------------------------------------------
// Zone tests
// ---------------------------------------------------------------------------
func TestListZones(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewZoneServiceClient(cc)
resp, err := client.ListZones(ctx, &pb.ListZonesRequest{})
if err != nil {
t.Fatalf("ListZones: %v", err)
}
// Seed migration creates 2 zones.
if len(resp.Zones) != 2 {
t.Fatalf("got %d zones, want 2 (seed zones)", len(resp.Zones))
}
}
func TestGetZoneFound(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewZoneServiceClient(cc)
zone, err := client.GetZone(ctx, &pb.GetZoneRequest{Name: "example.com"})
if err != nil {
t.Fatalf("GetZone: %v", err)
}
if zone.Name != "example.com" {
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
}
}
func TestGetZoneNotFound(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewZoneServiceClient(cc)
_, err := client.GetZone(ctx, &pb.GetZoneRequest{Name: "nonexistent.com"})
if err == nil {
t.Fatal("expected error for nonexistent zone")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestCreateZoneSuccess(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewZoneServiceClient(cc)
zone, err := client.CreateZone(ctx, &pb.CreateZoneRequest{
Name: "newzone.com",
PrimaryNs: "ns1.newzone.com.",
AdminEmail: "admin.newzone.com.",
Refresh: 3600,
Retry: 600,
Expire: 86400,
MinimumTtl: 300,
})
if err != nil {
t.Fatalf("CreateZone: %v", err)
}
if zone.Name != "newzone.com" {
t.Fatalf("got name %q, want %q", zone.Name, "newzone.com")
}
if zone.Serial == 0 {
t.Fatal("serial should not be zero")
}
}
func TestCreateZoneDuplicate(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewZoneServiceClient(cc)
_, err := client.CreateZone(ctx, &pb.CreateZoneRequest{
Name: "example.com",
PrimaryNs: "ns1.example.com.",
AdminEmail: "admin.example.com.",
})
if err == nil {
t.Fatal("expected error for duplicate zone")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != codes.AlreadyExists {
t.Fatalf("code: got %v, want AlreadyExists", st.Code())
}
}
func TestUpdateZone(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
original := seedZone(t, database, "example.com")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewZoneServiceClient(cc)
updated, err := client.UpdateZone(ctx, &pb.UpdateZoneRequest{
Name: "example.com",
PrimaryNs: "ns2.example.com.",
AdminEmail: "newadmin.example.com.",
Refresh: 7200,
Retry: 1200,
Expire: 172800,
MinimumTtl: 600,
})
if err != nil {
t.Fatalf("UpdateZone: %v", err)
}
if updated.PrimaryNs != "ns2.example.com." {
t.Fatalf("got primary_ns %q, want %q", updated.PrimaryNs, "ns2.example.com.")
}
if updated.Serial <= original.Serial {
t.Fatalf("serial should have incremented: %d <= %d", updated.Serial, original.Serial)
}
}
func TestDeleteZone(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewZoneServiceClient(cc)
_, err := client.DeleteZone(ctx, &pb.DeleteZoneRequest{Name: "example.com"})
if err != nil {
t.Fatalf("DeleteZone: %v", err)
}
// Verify it is gone.
_, err = client.GetZone(withAuth(context.Background(), "user-token"), &pb.GetZoneRequest{Name: "example.com"})
if err == nil {
t.Fatal("expected NotFound after delete")
}
st, _ := status.FromError(err)
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestDeleteZoneNotFound(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewZoneServiceClient(cc)
_, err := client.DeleteZone(ctx, &pb.DeleteZoneRequest{Name: "nonexistent.com"})
if err == nil {
t.Fatal("expected error for nonexistent zone")
}
st, _ := status.FromError(err)
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
// ---------------------------------------------------------------------------
// Record tests
// ---------------------------------------------------------------------------
func TestListRecords(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
seedRecord(t, database, "example.com", "www", "10.0.0.1")
seedRecord(t, database, "example.com", "mail", "10.0.0.2")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewRecordServiceClient(cc)
resp, err := client.ListRecords(ctx, &pb.ListRecordsRequest{Zone: "example.com"})
if err != nil {
t.Fatalf("ListRecords: %v", err)
}
if len(resp.Records) != 2 {
t.Fatalf("got %d records, want 2", len(resp.Records))
}
}
func TestGetRecordFound(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
rec := seedRecord(t, database, "example.com", "www", "10.0.0.1")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewRecordServiceClient(cc)
got, err := client.GetRecord(ctx, &pb.GetRecordRequest{Id: rec.ID})
if err != nil {
t.Fatalf("GetRecord: %v", err)
}
if got.Name != "www" {
t.Fatalf("got name %q, want %q", got.Name, "www")
}
if got.Value != "10.0.0.1" {
t.Fatalf("got value %q, want %q", got.Value, "10.0.0.1")
}
}
func TestGetRecordNotFound(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewRecordServiceClient(cc)
_, err := client.GetRecord(ctx, &pb.GetRecordRequest{Id: 999999})
if err == nil {
t.Fatal("expected error for nonexistent record")
}
st, _ := status.FromError(err)
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestCreateRecordSuccess(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewRecordServiceClient(cc)
rec, err := client.CreateRecord(ctx, &pb.CreateRecordRequest{
Zone: "example.com",
Name: "www",
Type: "A",
Value: "10.0.0.1",
Ttl: 300,
})
if err != nil {
t.Fatalf("CreateRecord: %v", err)
}
if rec.Name != "www" {
t.Fatalf("got name %q, want %q", rec.Name, "www")
}
if rec.Type != "A" {
t.Fatalf("got type %q, want %q", rec.Type, "A")
}
}
func TestCreateRecordInvalidValue(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewRecordServiceClient(cc)
_, err := client.CreateRecord(ctx, &pb.CreateRecordRequest{
Zone: "example.com",
Name: "www",
Type: "A",
Value: "not-an-ip",
Ttl: 300,
})
if err == nil {
t.Fatal("expected error for invalid A record value")
}
st, _ := status.FromError(err)
if st.Code() != codes.InvalidArgument {
t.Fatalf("code: got %v, want InvalidArgument", st.Code())
}
}
func TestCreateRecordCNAMEConflict(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
seedRecord(t, database, "example.com", "www", "10.0.0.1")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewRecordServiceClient(cc)
// Try to create a CNAME for "www" which already has an A record.
_, err := client.CreateRecord(ctx, &pb.CreateRecordRequest{
Zone: "example.com",
Name: "www",
Type: "CNAME",
Value: "other.example.com.",
Ttl: 300,
})
if err == nil {
t.Fatal("expected error for CNAME conflict with existing A record")
}
st, _ := status.FromError(err)
if st.Code() != codes.AlreadyExists {
t.Fatalf("code: got %v, want AlreadyExists", st.Code())
}
}
func TestCreateRecordAConflictWithCNAME(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
// Create a CNAME first.
_, err := database.CreateRecord("example.com", "alias", "CNAME", "target.example.com.", 300)
if err != nil {
t.Fatalf("seed CNAME: %v", err)
}
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewRecordServiceClient(cc)
// Try to create an A record for "alias" which already has a CNAME.
_, err = client.CreateRecord(ctx, &pb.CreateRecordRequest{
Zone: "example.com",
Name: "alias",
Type: "A",
Value: "10.0.0.1",
Ttl: 300,
})
if err == nil {
t.Fatal("expected error for A record conflict with existing CNAME")
}
st, _ := status.FromError(err)
if st.Code() != codes.AlreadyExists {
t.Fatalf("code: got %v, want AlreadyExists", st.Code())
}
}
func TestUpdateRecord(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
rec := seedRecord(t, database, "example.com", "www", "10.0.0.1")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewRecordServiceClient(cc)
updated, err := client.UpdateRecord(ctx, &pb.UpdateRecordRequest{
Id: rec.ID,
Name: "www",
Type: "A",
Value: "10.0.0.2",
Ttl: 600,
})
if err != nil {
t.Fatalf("UpdateRecord: %v", err)
}
if updated.Value != "10.0.0.2" {
t.Fatalf("got value %q, want %q", updated.Value, "10.0.0.2")
}
if updated.Ttl != 600 {
t.Fatalf("got ttl %d, want 600", updated.Ttl)
}
}
func TestUpdateRecordNotFound(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewRecordServiceClient(cc)
_, err := client.UpdateRecord(ctx, &pb.UpdateRecordRequest{
Id: 999999,
Name: "www",
Type: "A",
Value: "10.0.0.1",
Ttl: 300,
})
if err == nil {
t.Fatal("expected error for nonexistent record")
}
st, _ := status.FromError(err)
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestDeleteRecord(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
seedZone(t, database, "example.com")
rec := seedRecord(t, database, "example.com", "www", "10.0.0.1")
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewRecordServiceClient(cc)
_, err := client.DeleteRecord(ctx, &pb.DeleteRecordRequest{Id: rec.ID})
if err != nil {
t.Fatalf("DeleteRecord: %v", err)
}
// Verify it is gone.
_, err = client.GetRecord(withAuth(context.Background(), "user-token"), &pb.GetRecordRequest{Id: rec.ID})
if err == nil {
t.Fatal("expected NotFound after delete")
}
st, _ := status.FromError(err)
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestDeleteRecordNotFound(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewRecordServiceClient(cc)
_, err := client.DeleteRecord(ctx, &pb.DeleteRecordRequest{Id: 999999})
if err == nil {
t.Fatal("expected error for nonexistent record")
}
st, _ := status.FromError(err)
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
// ---------------------------------------------------------------------------
// Auth interceptor tests
// ---------------------------------------------------------------------------
func TestAuthRequiredNoToken(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
client := pb.NewZoneServiceClient(cc)
// No auth token on an auth-required method.
_, err := client.ListZones(context.Background(), &pb.ListZonesRequest{})
if err == nil {
t.Fatal("expected error for unauthenticated request")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.Unauthenticated {
t.Fatalf("code: got %v, want Unauthenticated", st.Code())
}
}
func TestAuthRequiredInvalidToken(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "bad-token")
client := pb.NewZoneServiceClient(cc)
_, err := client.ListZones(ctx, &pb.ListZonesRequest{})
if err == nil {
t.Fatal("expected error for invalid token")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.Unauthenticated {
t.Fatalf("code: got %v, want Unauthenticated", st.Code())
}
}
func TestAdminRequiredDeniedForUser(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewZoneServiceClient(cc)
// CreateZone requires admin.
_, err := client.CreateZone(ctx, &pb.CreateZoneRequest{
Name: "forbidden.com",
PrimaryNs: "ns1.forbidden.com.",
AdminEmail: "admin.forbidden.com.",
})
if err == nil {
t.Fatal("expected error for non-admin user")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
}
}
func TestAdminRequiredAllowedForAdmin(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{DB: database, Authenticator: auth})
ctx := withAuth(context.Background(), "admin-token")
client := pb.NewZoneServiceClient(cc)
// Admin should be able to create zones.
zone, err := client.CreateZone(ctx, &pb.CreateZoneRequest{
Name: "admin-created.com",
PrimaryNs: "ns1.admin-created.com.",
AdminEmail: "admin.admin-created.com.",
})
if err != nil {
t.Fatalf("CreateZone as admin: %v", err)
}
if zone.Name != "admin-created.com" {
t.Fatalf("got name %q, want %q", zone.Name, "admin-created.com")
}
}
// ---------------------------------------------------------------------------
// Interceptor map completeness test
// ---------------------------------------------------------------------------
func TestMethodMapCompleteness(t *testing.T) {
mm := methodMap()
expectedPublic := []string{
"/mcns.v1.AdminService/Health",
"/mcns.v1.AuthService/Login",
}
for _, method := range expectedPublic {
if !mm.Public[method] {
t.Errorf("method %s should be public but is not in Public", method)
}
}
if len(mm.Public) != len(expectedPublic) {
t.Errorf("Public has %d entries, expected %d", len(mm.Public), len(expectedPublic))
}
expectedAuth := []string{
"/mcns.v1.AuthService/Logout",
"/mcns.v1.ZoneService/ListZones",
"/mcns.v1.ZoneService/GetZone",
"/mcns.v1.RecordService/ListRecords",
"/mcns.v1.RecordService/GetRecord",
}
for _, method := range expectedAuth {
if !mm.AuthRequired[method] {
t.Errorf("method %s should require auth but is not in AuthRequired", method)
}
}
if len(mm.AuthRequired) != len(expectedAuth) {
t.Errorf("AuthRequired has %d entries, expected %d", len(mm.AuthRequired), len(expectedAuth))
}
expectedAdmin := []string{
"/mcns.v1.ZoneService/CreateZone",
"/mcns.v1.ZoneService/UpdateZone",
"/mcns.v1.ZoneService/DeleteZone",
"/mcns.v1.RecordService/CreateRecord",
"/mcns.v1.RecordService/UpdateRecord",
"/mcns.v1.RecordService/DeleteRecord",
}
for _, method := range expectedAdmin {
if !mm.AdminRequired[method] {
t.Errorf("method %s should require admin but is not in AdminRequired", method)
}
}
if len(mm.AdminRequired) != len(expectedAdmin) {
t.Errorf("AdminRequired has %d entries, expected %d", len(mm.AdminRequired), len(expectedAdmin))
}
// Verify no method appears in multiple maps (each RPC in exactly one map).
all := make(map[string]int)
for k := range mm.Public {
all[k]++
}
for k := range mm.AuthRequired {
all[k]++
}
for k := range mm.AdminRequired {
all[k]++
}
for method, count := range all {
if count != 1 {
t.Errorf("method %s appears in %d maps, expected exactly 1", method, count)
}
}
}

View File

@@ -0,0 +1,949 @@
package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"github.com/go-chi/chi/v5"
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
// openTestDB creates a temporary SQLite database with all migrations applied.
func openTestDB(t *testing.T) *db.DB {
t.Helper()
dir := t.TempDir()
database, err := db.Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := database.Migrate(); err != nil {
t.Fatalf("migrate: %v", err)
}
t.Cleanup(func() { _ = database.Close() })
return database
}
// createTestZone inserts a zone for use by record tests.
func createTestZone(t *testing.T, database *db.DB) *db.Zone {
t.Helper()
zone, err := database.CreateZone("test.example.com", "ns.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
return zone
}
// newChiRequest builds a request with chi URL params injected into the context.
func newChiRequest(method, target string, body string, params map[string]string) *http.Request {
var r *http.Request
if body != "" {
r = httptest.NewRequest(method, target, strings.NewReader(body))
} else {
r = httptest.NewRequest(method, target, nil)
}
r.Header.Set("Content-Type", "application/json")
if len(params) > 0 {
rctx := chi.NewRouteContext()
for k, v := range params {
rctx.URLParams.Add(k, v)
}
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
}
return r
}
// decodeJSON decodes the response body into v.
func decodeJSON(t *testing.T, rec *httptest.ResponseRecorder, v any) {
t.Helper()
if err := json.NewDecoder(rec.Body).Decode(v); err != nil {
t.Fatalf("decode json: %v", err)
}
}
// ---- Zone handler tests ----
func TestListZonesHandler_SeedOnly(t *testing.T) {
database := openTestDB(t)
handler := listZonesHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones", "", nil)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string][]db.Zone
decodeJSON(t, rec, &resp)
zones := resp["zones"]
if len(zones) != 2 {
t.Fatalf("got %d zones, want 2 (seed zones)", len(zones))
}
}
func TestListZonesHandler_Populated(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
handler := listZonesHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones", "", nil)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string][]db.Zone
decodeJSON(t, rec, &resp)
zones := resp["zones"]
// 2 seed + 1 created = 3.
if len(zones) != 3 {
t.Fatalf("got %d zones, want 3", len(zones))
}
}
func TestGetZoneHandler_Found(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
handler := getZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com", "", map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var zone db.Zone
decodeJSON(t, rec, &zone)
if zone.Name != "test.example.com" {
t.Fatalf("zone name = %q, want %q", zone.Name, "test.example.com")
}
}
func TestGetZoneHandler_NotFound(t *testing.T) {
database := openTestDB(t)
handler := getZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones/nonexistent.com", "", map[string]string{"zone": "nonexistent.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestCreateZoneHandler_Success(t *testing.T) {
database := openTestDB(t)
body := `{"name":"new.example.com","primary_ns":"ns1.example.com.","admin_email":"admin.example.com."}`
handler := createZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones", body, nil)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String())
}
var zone db.Zone
decodeJSON(t, rec, &zone)
if zone.Name != "new.example.com" {
t.Fatalf("zone name = %q, want %q", zone.Name, "new.example.com")
}
if zone.PrimaryNS != "ns1.example.com." {
t.Fatalf("primary_ns = %q, want %q", zone.PrimaryNS, "ns1.example.com.")
}
// SOA defaults should be applied.
if zone.Refresh != 3600 {
t.Fatalf("refresh = %d, want 3600", zone.Refresh)
}
}
func TestCreateZoneHandler_MissingFields(t *testing.T) {
tests := []struct {
name string
body string
}{
{"missing name", `{"primary_ns":"ns1.example.com.","admin_email":"admin.example.com."}`},
{"missing primary_ns", `{"name":"new.example.com","admin_email":"admin.example.com."}`},
{"missing admin_email", `{"name":"new.example.com","primary_ns":"ns1.example.com."}`},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
database := openTestDB(t)
handler := createZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones", tt.body, nil)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
})
}
}
func TestCreateZoneHandler_Duplicate(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
body := `{"name":"test.example.com","primary_ns":"ns1.example.com.","admin_email":"admin.example.com."}`
handler := createZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones", body, nil)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusConflict {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusConflict)
}
}
func TestCreateZoneHandler_InvalidJSON(t *testing.T) {
database := openTestDB(t)
handler := createZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones", "not json", nil)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestUpdateZoneHandler_Success(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
body := `{"primary_ns":"ns2.example.com.","admin_email":"newadmin.example.com.","refresh":7200}`
handler := updateZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com", body, map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
}
var zone db.Zone
decodeJSON(t, rec, &zone)
if zone.PrimaryNS != "ns2.example.com." {
t.Fatalf("primary_ns = %q, want %q", zone.PrimaryNS, "ns2.example.com.")
}
if zone.Refresh != 7200 {
t.Fatalf("refresh = %d, want 7200", zone.Refresh)
}
}
func TestUpdateZoneHandler_NotFound(t *testing.T) {
database := openTestDB(t)
body := `{"primary_ns":"ns2.example.com.","admin_email":"admin.example.com."}`
handler := updateZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPut, "/v1/zones/nonexistent.com", body, map[string]string{"zone": "nonexistent.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestUpdateZoneHandler_MissingFields(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
body := `{"admin_email":"admin.example.com."}`
handler := updateZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com", body, map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestDeleteZoneHandler_Success(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
handler := deleteZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com", "", map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
}
// Verify the zone is gone.
_, err := database.GetZone("test.example.com")
if err != db.ErrNotFound {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
}
func TestDeleteZoneHandler_NotFound(t *testing.T) {
database := openTestDB(t)
handler := deleteZoneHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodDelete, "/v1/zones/nonexistent.com", "", map[string]string{"zone": "nonexistent.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
}
}
// ---- Record handler tests ----
func TestListRecordsHandler_WithZone(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
_, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
_, err = database.CreateRecord("test.example.com", "mail", "A", "10.0.0.2", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
handler := listRecordsHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records", "", map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string][]db.Record
decodeJSON(t, rec, &resp)
records := resp["records"]
if len(records) != 2 {
t.Fatalf("got %d records, want 2", len(records))
}
}
func TestListRecordsHandler_ZoneNotFound(t *testing.T) {
database := openTestDB(t)
handler := listRecordsHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones/nonexistent.com/records", "", map[string]string{"zone": "nonexistent.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestListRecordsHandler_EmptyZone(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
handler := listRecordsHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records", "", map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string][]db.Record
decodeJSON(t, rec, &resp)
records := resp["records"]
if len(records) != 0 {
t.Fatalf("got %d records, want 0", len(records))
}
}
func TestListRecordsHandler_WithFilters(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
_, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
_, err = database.CreateRecord("test.example.com", "www", "A", "10.0.0.2", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
_, err = database.CreateRecord("test.example.com", "mail", "A", "10.0.0.3", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
handler := listRecordsHandler(database)
// Filter by name.
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records?name=www", "", map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string][]db.Record
decodeJSON(t, rec, &resp)
if len(resp["records"]) != 2 {
t.Fatalf("got %d records for name=www, want 2", len(resp["records"]))
}
}
func TestGetRecordHandler_Found(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
handler := getRecordHandler(database)
rec := httptest.NewRecorder()
idStr := fmt.Sprintf("%d", created.ID)
req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records/"+idStr, "", map[string]string{
"zone": "test.example.com",
"id": idStr,
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var record db.Record
decodeJSON(t, rec, &record)
if record.Name != "www" {
t.Fatalf("record name = %q, want %q", record.Name, "www")
}
if record.Value != "10.0.0.1" {
t.Fatalf("record value = %q, want %q", record.Value, "10.0.0.1")
}
}
func TestGetRecordHandler_NotFound(t *testing.T) {
database := openTestDB(t)
handler := getRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records/99999", "", map[string]string{
"zone": "test.example.com",
"id": "99999",
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestGetRecordHandler_InvalidID(t *testing.T) {
database := openTestDB(t)
handler := getRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records/abc", "", map[string]string{
"zone": "test.example.com",
"id": "abc",
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestCreateRecordHandler_Success(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
body := `{"name":"www","type":"A","value":"10.0.0.1","ttl":600}`
handler := createRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", body, map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusCreated {
t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String())
}
var record db.Record
decodeJSON(t, rec, &record)
if record.Name != "www" {
t.Fatalf("record name = %q, want %q", record.Name, "www")
}
if record.Type != "A" {
t.Fatalf("record type = %q, want %q", record.Type, "A")
}
if record.TTL != 600 {
t.Fatalf("ttl = %d, want 600", record.TTL)
}
}
func TestCreateRecordHandler_MissingFields(t *testing.T) {
tests := []struct {
name string
body string
}{
{"missing name", `{"type":"A","value":"10.0.0.1"}`},
{"missing type", `{"name":"www","value":"10.0.0.1"}`},
{"missing value", `{"name":"www","type":"A"}`},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
handler := createRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", tt.body, map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
})
}
}
func TestCreateRecordHandler_InvalidIP(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
body := `{"name":"www","type":"A","value":"not-an-ip"}`
handler := createRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", body, map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestCreateRecordHandler_CNAMEConflict(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
// Create an A record first.
_, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300)
if err != nil {
t.Fatalf("create A record: %v", err)
}
// Try to create a CNAME for the same name via handler.
body := `{"name":"www","type":"CNAME","value":"other.example.com."}`
handler := createRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", body, map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusConflict {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusConflict)
}
}
func TestCreateRecordHandler_ZoneNotFound(t *testing.T) {
database := openTestDB(t)
body := `{"name":"www","type":"A","value":"10.0.0.1"}`
handler := createRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones/nonexistent.com/records", body, map[string]string{"zone": "nonexistent.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestCreateRecordHandler_InvalidJSON(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
handler := createRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", "not json", map[string]string{"zone": "test.example.com"})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestUpdateRecordHandler_Success(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
idStr := fmt.Sprintf("%d", created.ID)
body := `{"name":"www","type":"A","value":"10.0.0.2","ttl":600}`
handler := updateRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/"+idStr, body, map[string]string{
"zone": "test.example.com",
"id": idStr,
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String())
}
var record db.Record
decodeJSON(t, rec, &record)
if record.Value != "10.0.0.2" {
t.Fatalf("value = %q, want %q", record.Value, "10.0.0.2")
}
if record.TTL != 600 {
t.Fatalf("ttl = %d, want 600", record.TTL)
}
}
func TestUpdateRecordHandler_NotFound(t *testing.T) {
database := openTestDB(t)
body := `{"name":"www","type":"A","value":"10.0.0.1"}`
handler := updateRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/99999", body, map[string]string{
"zone": "test.example.com",
"id": "99999",
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestUpdateRecordHandler_InvalidID(t *testing.T) {
database := openTestDB(t)
body := `{"name":"www","type":"A","value":"10.0.0.1"}`
handler := updateRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/abc", body, map[string]string{
"zone": "test.example.com",
"id": "abc",
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestUpdateRecordHandler_MissingFields(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
idStr := fmt.Sprintf("%d", created.ID)
// Missing name.
body := `{"type":"A","value":"10.0.0.1"}`
handler := updateRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/"+idStr, body, map[string]string{
"zone": "test.example.com",
"id": idStr,
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
func TestDeleteRecordHandler_Success(t *testing.T) {
database := openTestDB(t)
createTestZone(t, database)
created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
idStr := fmt.Sprintf("%d", created.ID)
handler := deleteRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com/records/"+idStr, "", map[string]string{
"zone": "test.example.com",
"id": idStr,
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
}
// Verify record is gone.
_, err = database.GetRecord(created.ID)
if err != db.ErrNotFound {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
}
func TestDeleteRecordHandler_NotFound(t *testing.T) {
database := openTestDB(t)
handler := deleteRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com/records/99999", "", map[string]string{
"zone": "test.example.com",
"id": "99999",
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound)
}
}
func TestDeleteRecordHandler_InvalidID(t *testing.T) {
database := openTestDB(t)
handler := deleteRecordHandler(database)
rec := httptest.NewRecorder()
req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com/records/abc", "", map[string]string{
"zone": "test.example.com",
"id": "abc",
})
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
}
// ---- Middleware tests ----
func TestRequireAdmin_WithAdminContext(t *testing.T) {
called := false
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
handler := requireAdmin(inner)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
// Inject admin TokenInfo into context.
info := &mcdslauth.TokenInfo{
Username: "admin-user",
IsAdmin: true,
Roles: []string{"admin"},
}
ctx := context.WithValue(req.Context(), tokenInfoKey, info)
req = req.WithContext(ctx)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
if !called {
t.Fatal("inner handler was not called")
}
}
func TestRequireAdmin_WithNonAdminContext(t *testing.T) {
called := false
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
handler := requireAdmin(inner)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
// Inject non-admin TokenInfo into context.
info := &mcdslauth.TokenInfo{
Username: "regular-user",
IsAdmin: false,
Roles: []string{"viewer"},
}
ctx := context.WithValue(req.Context(), tokenInfoKey, info)
req = req.WithContext(ctx)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
}
if called {
t.Fatal("inner handler should not have been called")
}
}
func TestRequireAdmin_NoTokenInfo(t *testing.T) {
called := false
inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
handler := requireAdmin(inner)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
}
if called {
t.Fatal("inner handler should not have been called")
}
}
func TestExtractBearerToken(t *testing.T) {
tests := []struct {
name string
header string
want string
}{
{"valid bearer", "Bearer abc123", "abc123"},
{"empty header", "", ""},
{"no prefix", "abc123", ""},
{"basic auth", "Basic abc123", ""},
{"bearer with spaces", "Bearer token-with-space ", "token-with-space"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
if tt.header != "" {
r.Header.Set("Authorization", tt.header)
}
got := extractBearerToken(r)
if got != tt.want {
t.Fatalf("extractBearerToken(%q) = %q, want %q", tt.header, got, tt.want)
}
})
}
}
func TestTokenInfoFromContext(t *testing.T) {
// No token info in context.
ctx := context.Background()
if info := tokenInfoFromContext(ctx); info != nil {
t.Fatal("expected nil, got token info")
}
// With token info.
expected := &mcdslauth.TokenInfo{Username: "testuser", IsAdmin: true}
ctx = context.WithValue(ctx, tokenInfoKey, expected)
got := tokenInfoFromContext(ctx)
if got == nil {
t.Fatal("expected token info, got nil")
}
if got.Username != expected.Username {
t.Fatalf("username = %q, want %q", got.Username, expected.Username)
}
if !got.IsAdmin {
t.Fatal("expected IsAdmin to be true")
}
}
// ---- writeJSON / writeError tests ----
func TestWriteJSON(t *testing.T) {
rec := httptest.NewRecorder()
writeJSON(rec, http.StatusOK, map[string]string{"key": "value"})
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
t.Fatalf("content-type = %q, want %q", ct, "application/json")
}
var resp map[string]string
decodeJSON(t, rec, &resp)
if resp["key"] != "value" {
t.Fatalf("got key=%q, want %q", resp["key"], "value")
}
}
func TestWriteError(t *testing.T) {
rec := httptest.NewRecorder()
writeError(rec, http.StatusBadRequest, "bad input")
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
var resp map[string]string
decodeJSON(t, rec, &resp)
if resp["error"] != "bad input" {
t.Fatalf("got error=%q, want %q", resp["error"], "bad input")
}
}

View File

@@ -2,8 +2,9 @@ syntax = "proto3";
package mcns.v1; package mcns.v1;
option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"; option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1";
// AdminService exposes server health and administrative operations.
service AdminService { service AdminService {
rpc Health(HealthRequest) returns (HealthResponse); rpc Health(HealthRequest) returns (HealthResponse);
} }

View File

@@ -2,8 +2,9 @@ syntax = "proto3";
package mcns.v1; package mcns.v1;
option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"; option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1";
// AuthService handles authentication by delegating to MCIAS.
service AuthService { service AuthService {
rpc Login(LoginRequest) returns (LoginResponse); rpc Login(LoginRequest) returns (LoginResponse);
rpc Logout(LogoutRequest) returns (LogoutResponse); rpc Logout(LogoutRequest) returns (LogoutResponse);
@@ -12,6 +13,7 @@ service AuthService {
message LoginRequest { message LoginRequest {
string username = 1; string username = 1;
string password = 2; string password = 2;
// TOTP code for two-factor authentication, if enabled on the account.
string totp_code = 3; string totp_code = 3;
} }

View File

@@ -2,10 +2,11 @@ syntax = "proto3";
package mcns.v1; package mcns.v1;
option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"; option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1";
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";
// RecordService manages DNS records within zones.
service RecordService { service RecordService {
rpc ListRecords(ListRecordsRequest) returns (ListRecordsResponse); rpc ListRecords(ListRecordsRequest) returns (ListRecordsResponse);
rpc CreateRecord(CreateRecordRequest) returns (Record); rpc CreateRecord(CreateRecordRequest) returns (Record);
@@ -16,8 +17,10 @@ service RecordService {
message Record { message Record {
int64 id = 1; int64 id = 1;
// Zone name this record belongs to (e.g. "example.com.").
string zone = 2; string zone = 2;
string name = 3; string name = 3;
// DNS record type (A, AAAA, CNAME, MX, TXT, etc.).
string type = 4; string type = 4;
string value = 5; string value = 5;
int32 ttl = 6; int32 ttl = 6;
@@ -27,7 +30,9 @@ message Record {
message ListRecordsRequest { message ListRecordsRequest {
string zone = 1; string zone = 1;
// Optional filter by record name.
string name = 2; string name = 2;
// Optional filter by record type (A, AAAA, CNAME, etc.).
string type = 3; string type = 3;
} }
@@ -36,6 +41,7 @@ message ListRecordsResponse {
} }
message CreateRecordRequest { message CreateRecordRequest {
// Zone name the record will be created in; must reference an existing zone.
string zone = 1; string zone = 1;
string name = 2; string name = 2;
string type = 3; string type = 3;

View File

@@ -2,10 +2,11 @@ syntax = "proto3";
package mcns.v1; package mcns.v1;
option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"; option go_package = "git.wntrmute.dev/kyle/mcns/gen/mcns/v1;mcnsv1";
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";
// ZoneService manages DNS zones and their SOA parameters.
service ZoneService { service ZoneService {
rpc ListZones(ListZonesRequest) returns (ListZonesResponse); rpc ListZones(ListZonesRequest) returns (ListZonesResponse);
rpc CreateZone(CreateZoneRequest) returns (Zone); rpc CreateZone(CreateZoneRequest) returns (Zone);