Apply review fixes: validation, idempotency, SOA dedup, startup cleanup

- Migration v2: INSERT → INSERT OR IGNORE for idempotency
- Config: validate server.tls_cert and server.tls_key are non-empty
- gRPC: add input validation matching REST handlers
- gRPC: add logger to zone/record services, log timestamp parse errors
- REST+gRPC: extract SOA defaults into shared db.ApplySOADefaults()
- DNS: simplify SOA query condition (remove dead code from precedence bug)
- Startup: consolidate shutdown into shutdownAll(), clean up gRPC listener
  on error path, shut down sibling servers when one fails

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 21:17:15 -07:00
parent edcf99e8d1
commit f8f3a9868a
9 changed files with 155 additions and 102 deletions

View File

@@ -3,6 +3,7 @@ package grpcserver
import (
"context"
"errors"
"log/slog"
"time"
"google.golang.org/grpc/codes"
@@ -15,10 +16,15 @@ import (
type recordService struct {
pb.UnimplementedRecordServiceServer
db *db.DB
db *db.DB
logger *slog.Logger
}
func (s *recordService) ListRecords(_ context.Context, req *pb.ListRecordsRequest) (*pb.ListRecordsResponse, error) {
if req.Zone == "" {
return nil, status.Error(codes.InvalidArgument, "zone is required")
}
records, err := s.db.ListRecords(req.Zone, req.Name, req.Type)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
@@ -29,12 +35,16 @@ func (s *recordService) ListRecords(_ context.Context, req *pb.ListRecordsReques
resp := &pb.ListRecordsResponse{}
for _, r := range records {
resp.Records = append(resp.Records, recordToProto(r))
resp.Records = append(resp.Records, s.recordToProto(r))
}
return resp, nil
}
func (s *recordService) GetRecord(_ context.Context, req *pb.GetRecordRequest) (*pb.Record, error) {
if req.Id <= 0 {
return nil, status.Error(codes.InvalidArgument, "id must be positive")
}
record, err := s.db.GetRecord(req.Id)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "record not found")
@@ -42,10 +52,23 @@ func (s *recordService) GetRecord(_ context.Context, req *pb.GetRecordRequest) (
if err != nil {
return nil, status.Error(codes.Internal, "failed to get record")
}
return recordToProto(*record), nil
return s.recordToProto(*record), nil
}
func (s *recordService) CreateRecord(_ context.Context, req *pb.CreateRecordRequest) (*pb.Record, error) {
if req.Zone == "" {
return nil, status.Error(codes.InvalidArgument, "zone is required")
}
if req.Name == "" {
return nil, status.Error(codes.InvalidArgument, "name is required")
}
if req.Type == "" {
return nil, status.Error(codes.InvalidArgument, "type is required")
}
if req.Value == "" {
return nil, status.Error(codes.InvalidArgument, "value is required")
}
record, err := s.db.CreateRecord(req.Zone, req.Name, req.Type, req.Value, int(req.Ttl))
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
@@ -56,10 +79,23 @@ func (s *recordService) CreateRecord(_ context.Context, req *pb.CreateRecordRequ
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
return recordToProto(*record), nil
return s.recordToProto(*record), nil
}
func (s *recordService) UpdateRecord(_ context.Context, req *pb.UpdateRecordRequest) (*pb.Record, error) {
if req.Id <= 0 {
return nil, status.Error(codes.InvalidArgument, "id must be positive")
}
if req.Name == "" {
return nil, status.Error(codes.InvalidArgument, "name is required")
}
if req.Type == "" {
return nil, status.Error(codes.InvalidArgument, "type is required")
}
if req.Value == "" {
return nil, status.Error(codes.InvalidArgument, "value is required")
}
record, err := s.db.UpdateRecord(req.Id, req.Name, req.Type, req.Value, int(req.Ttl))
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "record not found")
@@ -70,10 +106,14 @@ func (s *recordService) UpdateRecord(_ context.Context, req *pb.UpdateRecordRequ
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
return recordToProto(*record), nil
return s.recordToProto(*record), nil
}
func (s *recordService) DeleteRecord(_ context.Context, req *pb.DeleteRecordRequest) (*pb.DeleteRecordResponse, error) {
if req.Id <= 0 {
return nil, status.Error(codes.InvalidArgument, "id must be positive")
}
err := s.db.DeleteRecord(req.Id)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "record not found")
@@ -84,7 +124,7 @@ func (s *recordService) DeleteRecord(_ context.Context, req *pb.DeleteRecordRequ
return &pb.DeleteRecordResponse{}, nil
}
func recordToProto(r db.Record) *pb.Record {
func (s *recordService) recordToProto(r db.Record) *pb.Record {
return &pb.Record{
Id: r.ID,
Zone: r.ZoneName,
@@ -92,14 +132,15 @@ func recordToProto(r db.Record) *pb.Record {
Type: r.Type,
Value: r.Value,
Ttl: int32(r.TTL),
CreatedAt: parseRecordTimestamp(r.CreatedAt),
UpdatedAt: parseRecordTimestamp(r.UpdatedAt),
CreatedAt: s.parseRecordTimestamp(r.CreatedAt),
UpdatedAt: s.parseRecordTimestamp(r.UpdatedAt),
}
}
func parseRecordTimestamp(s string) *timestamppb.Timestamp {
t, err := parseTime(s)
func (s *recordService) parseRecordTimestamp(v string) *timestamppb.Timestamp {
t, err := parseTime(v)
if err != nil {
s.logger.Warn("failed to parse record timestamp", "value", v, "error", err)
return nil
}
return timestamppb.New(t)

View File

@@ -33,8 +33,8 @@ func New(certFile, keyFile string, deps Deps, logger *slog.Logger) (*Server, err
pb.RegisterAdminServiceServer(srv.GRPCServer, &adminService{db: deps.DB})
pb.RegisterAuthServiceServer(srv.GRPCServer, &authService{auth: deps.Authenticator})
pb.RegisterZoneServiceServer(srv.GRPCServer, &zoneService{db: deps.DB})
pb.RegisterRecordServiceServer(srv.GRPCServer, &recordService{db: deps.DB})
pb.RegisterZoneServiceServer(srv.GRPCServer, &zoneService{db: deps.DB, logger: logger})
pb.RegisterRecordServiceServer(srv.GRPCServer, &recordService{db: deps.DB, logger: logger})
return s, nil
}

View File

@@ -3,6 +3,7 @@ package grpcserver
import (
"context"
"errors"
"log/slog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -14,7 +15,8 @@ import (
type zoneService struct {
pb.UnimplementedZoneServiceServer
db *db.DB
db *db.DB
logger *slog.Logger
}
func (s *zoneService) ListZones(_ context.Context, _ *pb.ListZonesRequest) (*pb.ListZonesResponse, error) {
@@ -25,12 +27,16 @@ func (s *zoneService) ListZones(_ context.Context, _ *pb.ListZonesRequest) (*pb.
resp := &pb.ListZonesResponse{}
for _, z := range zones {
resp.Zones = append(resp.Zones, zoneToProto(z))
resp.Zones = append(resp.Zones, s.zoneToProto(z))
}
return resp, nil
}
func (s *zoneService) GetZone(_ context.Context, req *pb.GetZoneRequest) (*pb.Zone, error) {
if req.Name == "" {
return nil, status.Error(codes.InvalidArgument, "name is required")
}
zone, err := s.db.GetZone(req.Name)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
@@ -38,27 +44,22 @@ func (s *zoneService) GetZone(_ context.Context, req *pb.GetZoneRequest) (*pb.Zo
if err != nil {
return nil, status.Error(codes.Internal, "failed to get zone")
}
return zoneToProto(*zone), nil
return s.zoneToProto(*zone), nil
}
func (s *zoneService) CreateZone(_ context.Context, req *pb.CreateZoneRequest) (*pb.Zone, error) {
refresh := int(req.Refresh)
if refresh == 0 {
refresh = 3600
if req.Name == "" {
return nil, status.Error(codes.InvalidArgument, "name is required")
}
retry := int(req.Retry)
if retry == 0 {
retry = 600
if req.PrimaryNs == "" {
return nil, status.Error(codes.InvalidArgument, "primary_ns is required")
}
expire := int(req.Expire)
if expire == 0 {
expire = 86400
}
minTTL := int(req.MinimumTtl)
if minTTL == 0 {
minTTL = 300
if req.AdminEmail == "" {
return nil, status.Error(codes.InvalidArgument, "admin_email is required")
}
refresh, retry, expire, minTTL := db.ApplySOADefaults(int(req.Refresh), int(req.Retry), int(req.Expire), int(req.MinimumTtl))
zone, err := s.db.CreateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
if errors.Is(err, db.ErrConflict) {
return nil, status.Error(codes.AlreadyExists, err.Error())
@@ -66,27 +67,22 @@ func (s *zoneService) CreateZone(_ context.Context, req *pb.CreateZoneRequest) (
if err != nil {
return nil, status.Error(codes.Internal, "failed to create zone")
}
return zoneToProto(*zone), nil
return s.zoneToProto(*zone), nil
}
func (s *zoneService) UpdateZone(_ context.Context, req *pb.UpdateZoneRequest) (*pb.Zone, error) {
refresh := int(req.Refresh)
if refresh == 0 {
refresh = 3600
if req.Name == "" {
return nil, status.Error(codes.InvalidArgument, "name is required")
}
retry := int(req.Retry)
if retry == 0 {
retry = 600
if req.PrimaryNs == "" {
return nil, status.Error(codes.InvalidArgument, "primary_ns is required")
}
expire := int(req.Expire)
if expire == 0 {
expire = 86400
}
minTTL := int(req.MinimumTtl)
if minTTL == 0 {
minTTL = 300
if req.AdminEmail == "" {
return nil, status.Error(codes.InvalidArgument, "admin_email is required")
}
refresh, retry, expire, minTTL := db.ApplySOADefaults(int(req.Refresh), int(req.Retry), int(req.Expire), int(req.MinimumTtl))
zone, err := s.db.UpdateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
@@ -94,10 +90,14 @@ func (s *zoneService) UpdateZone(_ context.Context, req *pb.UpdateZoneRequest) (
if err != nil {
return nil, status.Error(codes.Internal, "failed to update zone")
}
return zoneToProto(*zone), nil
return s.zoneToProto(*zone), nil
}
func (s *zoneService) DeleteZone(_ context.Context, req *pb.DeleteZoneRequest) (*pb.DeleteZoneResponse, error) {
if req.Name == "" {
return nil, status.Error(codes.InvalidArgument, "name is required")
}
err := s.db.DeleteZone(req.Name)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
@@ -108,7 +108,7 @@ func (s *zoneService) DeleteZone(_ context.Context, req *pb.DeleteZoneRequest) (
return &pb.DeleteZoneResponse{}, nil
}
func zoneToProto(z db.Zone) *pb.Zone {
func (s *zoneService) zoneToProto(z db.Zone) *pb.Zone {
return &pb.Zone{
Id: z.ID,
Name: z.Name,
@@ -119,15 +119,15 @@ func zoneToProto(z db.Zone) *pb.Zone {
Expire: int32(z.Expire),
MinimumTtl: int32(z.MinimumTTL),
Serial: z.Serial,
CreatedAt: parseTimestamp(z.CreatedAt),
UpdatedAt: parseTimestamp(z.UpdatedAt),
CreatedAt: s.parseTimestamp(z.CreatedAt),
UpdatedAt: s.parseTimestamp(z.UpdatedAt),
}
}
func parseTimestamp(s string) *timestamppb.Timestamp {
// SQLite stores as "2006-01-02T15:04:05Z".
t, err := parseTime(s)
func (s *zoneService) parseTimestamp(v string) *timestamppb.Timestamp {
t, err := parseTime(v)
if err != nil {
s.logger.Warn("failed to parse zone timestamp", "value", v, "error", err)
return nil
}
return timestamppb.New(t)