Apply review fixes: validation, idempotency, SOA dedup, startup cleanup
- Migration v2: INSERT → INSERT OR IGNORE for idempotency - Config: validate server.tls_cert and server.tls_key are non-empty - gRPC: add input validation matching REST handlers - gRPC: add logger to zone/record services, log timestamp parse errors - REST+gRPC: extract SOA defaults into shared db.ApplySOADefaults() - DNS: simplify SOA query condition (remove dead code from precedence bug) - Startup: consolidate shutdown into shutdownAll(), clean up gRPC listener on error path, shut down sibling servers when one fails Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -138,6 +138,26 @@ func runServer(configPath string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shutdownAll tears down all servers. Safe to call even if some
|
||||||
|
// servers were never started. grpcSrv.Serve takes ownership of
|
||||||
|
// grpcLis, so we only close grpcLis if we never reached Serve.
|
||||||
|
grpcServeStarted := false
|
||||||
|
shutdownAll := func() {
|
||||||
|
dnsServer.Shutdown()
|
||||||
|
if grpcSrv != nil {
|
||||||
|
grpcSrv.GracefulStop()
|
||||||
|
} else if grpcLis != nil && !grpcServeStarted {
|
||||||
|
_ = grpcLis.Close()
|
||||||
|
}
|
||||||
|
shutdownTimeout := 30 * time.Second
|
||||||
|
if cfg.Server.ShutdownTimeout.Duration > 0 {
|
||||||
|
shutdownTimeout = cfg.Server.ShutdownTimeout.Duration
|
||||||
|
}
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||||
|
defer cancel()
|
||||||
|
_ = httpServer.Shutdown(shutdownCtx)
|
||||||
|
}
|
||||||
|
|
||||||
// Graceful shutdown on SIGINT/SIGTERM.
|
// Graceful shutdown on SIGINT/SIGTERM.
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
defer stop()
|
defer stop()
|
||||||
@@ -151,6 +171,7 @@ func runServer(configPath string) error {
|
|||||||
|
|
||||||
// Start gRPC server.
|
// Start gRPC server.
|
||||||
if grpcSrv != nil {
|
if grpcSrv != nil {
|
||||||
|
grpcServeStarted = true
|
||||||
go func() {
|
go func() {
|
||||||
logger.Info("gRPC server listening", "addr", grpcLis.Addr())
|
logger.Info("gRPC server listening", "addr", grpcLis.Addr())
|
||||||
errCh <- grpcSrv.Serve(grpcLis)
|
errCh <- grpcSrv.Serve(grpcLis)
|
||||||
@@ -169,22 +190,11 @@ func runServer(configPath string) error {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
|
shutdownAll()
|
||||||
return fmt.Errorf("server error: %w", err)
|
return fmt.Errorf("server error: %w", err)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
logger.Info("shutting down")
|
logger.Info("shutting down")
|
||||||
dnsServer.Shutdown()
|
shutdownAll()
|
||||||
if grpcSrv != nil {
|
|
||||||
grpcSrv.GracefulStop()
|
|
||||||
}
|
|
||||||
shutdownTimeout := 30 * time.Second
|
|
||||||
if cfg.Server.ShutdownTimeout.Duration > 0 {
|
|
||||||
shutdownTimeout = cfg.Server.ShutdownTimeout.Duration
|
|
||||||
}
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
|
||||||
defer cancel()
|
|
||||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
|
||||||
return fmt.Errorf("shutdown: %w", err)
|
|
||||||
}
|
|
||||||
logger.Info("mcns stopped")
|
logger.Info("mcns stopped")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,5 +45,11 @@ func (c *Config) Validate() error {
|
|||||||
if c.MCIAS.ServerURL == "" {
|
if c.MCIAS.ServerURL == "" {
|
||||||
return fmt.Errorf("mcias.server_url is required")
|
return fmt.Errorf("mcias.server_url is required")
|
||||||
}
|
}
|
||||||
|
if c.Server.TLSCert == "" {
|
||||||
|
return fmt.Errorf("server.tls_cert is required")
|
||||||
|
}
|
||||||
|
if c.Server.TLSKey == "" {
|
||||||
|
return fmt.Errorf("server.tls_key is required")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,28 +43,28 @@ CREATE INDEX IF NOT EXISTS idx_records_zone_name ON records(zone_id, name);`,
|
|||||||
Name: "seed zones and records from CoreDNS zone files",
|
Name: "seed zones and records from CoreDNS zone files",
|
||||||
SQL: `
|
SQL: `
|
||||||
-- Zone: svc.mcp.metacircular.net (service addresses)
|
-- Zone: svc.mcp.metacircular.net (service addresses)
|
||||||
INSERT INTO zones (id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial)
|
INSERT OR IGNORE INTO zones (id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial)
|
||||||
VALUES (1, 'svc.mcp.metacircular.net', 'ns.mcp.metacircular.net.', 'admin.metacircular.net.', 3600, 600, 86400, 300, 2026032601);
|
VALUES (1, 'svc.mcp.metacircular.net', 'ns.mcp.metacircular.net.', 'admin.metacircular.net.', 3600, 600, 86400, 300, 2026032601);
|
||||||
|
|
||||||
-- Zone: mcp.metacircular.net (node addresses)
|
-- Zone: mcp.metacircular.net (node addresses)
|
||||||
INSERT INTO zones (id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial)
|
INSERT OR IGNORE INTO zones (id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial)
|
||||||
VALUES (2, 'mcp.metacircular.net', 'ns.mcp.metacircular.net.', 'admin.metacircular.net.', 3600, 600, 86400, 300, 2026032501);
|
VALUES (2, 'mcp.metacircular.net', 'ns.mcp.metacircular.net.', 'admin.metacircular.net.', 3600, 600, 86400, 300, 2026032501);
|
||||||
|
|
||||||
-- svc.mcp.metacircular.net records
|
-- svc.mcp.metacircular.net records
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (1, 'metacrypt', 'A', '192.168.88.181', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (1, 'metacrypt', 'A', '192.168.88.181', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (1, 'metacrypt', 'A', '100.95.252.120', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (1, 'metacrypt', 'A', '100.95.252.120', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (1, 'mcr', 'A', '192.168.88.181', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (1, 'mcr', 'A', '192.168.88.181', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (1, 'mcr', 'A', '100.95.252.120', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (1, 'mcr', 'A', '100.95.252.120', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (1, 'sgard', 'A', '192.168.88.181', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (1, 'sgard', 'A', '192.168.88.181', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (1, 'sgard', 'A', '100.95.252.120', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (1, 'sgard', 'A', '100.95.252.120', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (1, 'mcp-agent', 'A', '192.168.88.181', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (1, 'mcp-agent', 'A', '192.168.88.181', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (1, 'mcp-agent', 'A', '100.95.252.120', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (1, 'mcp-agent', 'A', '100.95.252.120', 300);
|
||||||
|
|
||||||
-- mcp.metacircular.net records
|
-- mcp.metacircular.net records
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (2, 'rift', 'A', '192.168.88.181', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (2, 'rift', 'A', '192.168.88.181', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (2, 'rift', 'A', '100.95.252.120', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (2, 'rift', 'A', '100.95.252.120', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (2, 'ns', 'A', '192.168.88.181', 300);
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (2, 'ns', 'A', '192.168.88.181', 300);
|
||||||
INSERT INTO records (zone_id, name, type, value, ttl) VALUES (2, 'ns', 'A', '100.95.252.120', 300);`,
|
INSERT OR IGNORE INTO records (zone_id, name, type, value, ttl) VALUES (2, 'ns', 'A', '100.95.252.120', 300);`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -169,6 +169,24 @@ func (d *DB) ZoneNames() ([]string, error) {
|
|||||||
return names, rows.Err()
|
return names, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ApplySOADefaults fills in zero-valued SOA parameters with sensible defaults:
|
||||||
|
// refresh=3600, retry=600, expire=86400, minTTL=300.
|
||||||
|
func ApplySOADefaults(refresh, retry, expire, minTTL int) (int, int, int, int) {
|
||||||
|
if refresh == 0 {
|
||||||
|
refresh = 3600
|
||||||
|
}
|
||||||
|
if retry == 0 {
|
||||||
|
retry = 600
|
||||||
|
}
|
||||||
|
if expire == 0 {
|
||||||
|
expire = 86400
|
||||||
|
}
|
||||||
|
if minTTL == 0 {
|
||||||
|
minTTL = 300
|
||||||
|
}
|
||||||
|
return refresh, retry, expire, minTTL
|
||||||
|
}
|
||||||
|
|
||||||
// nextSerial computes the next SOA serial in YYYYMMDDNN format.
|
// nextSerial computes the next SOA serial in YYYYMMDDNN format.
|
||||||
func nextSerial(current int64) int64 {
|
func nextSerial(current int64) int64 {
|
||||||
today := time.Now().UTC()
|
today := time.Now().UTC()
|
||||||
|
|||||||
@@ -116,8 +116,8 @@ func (s *Server) handleAuthoritativeQuery(w dns.ResponseWriter, r *dns.Msg, zone
|
|||||||
relName = strings.TrimSuffix(qname, "."+zoneFQDN)
|
relName = strings.TrimSuffix(qname, "."+zoneFQDN)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle SOA queries.
|
// SOA queries always return the zone apex SOA regardless of query name.
|
||||||
if qtype == dns.TypeSOA || relName == "@" && qtype == dns.TypeSOA {
|
if qtype == dns.TypeSOA {
|
||||||
soa := s.buildSOA(zone)
|
soa := s.buildSOA(zone)
|
||||||
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{soa}, nil)
|
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{soa}, nil)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package grpcserver
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@@ -16,9 +17,14 @@ import (
|
|||||||
type recordService struct {
|
type recordService struct {
|
||||||
pb.UnimplementedRecordServiceServer
|
pb.UnimplementedRecordServiceServer
|
||||||
db *db.DB
|
db *db.DB
|
||||||
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *recordService) ListRecords(_ context.Context, req *pb.ListRecordsRequest) (*pb.ListRecordsResponse, error) {
|
func (s *recordService) ListRecords(_ context.Context, req *pb.ListRecordsRequest) (*pb.ListRecordsResponse, error) {
|
||||||
|
if req.Zone == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "zone is required")
|
||||||
|
}
|
||||||
|
|
||||||
records, err := s.db.ListRecords(req.Zone, req.Name, req.Type)
|
records, err := s.db.ListRecords(req.Zone, req.Name, req.Type)
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, status.Error(codes.NotFound, "zone not found")
|
return nil, status.Error(codes.NotFound, "zone not found")
|
||||||
@@ -29,12 +35,16 @@ func (s *recordService) ListRecords(_ context.Context, req *pb.ListRecordsReques
|
|||||||
|
|
||||||
resp := &pb.ListRecordsResponse{}
|
resp := &pb.ListRecordsResponse{}
|
||||||
for _, r := range records {
|
for _, r := range records {
|
||||||
resp.Records = append(resp.Records, recordToProto(r))
|
resp.Records = append(resp.Records, s.recordToProto(r))
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *recordService) GetRecord(_ context.Context, req *pb.GetRecordRequest) (*pb.Record, error) {
|
func (s *recordService) GetRecord(_ context.Context, req *pb.GetRecordRequest) (*pb.Record, error) {
|
||||||
|
if req.Id <= 0 {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "id must be positive")
|
||||||
|
}
|
||||||
|
|
||||||
record, err := s.db.GetRecord(req.Id)
|
record, err := s.db.GetRecord(req.Id)
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, status.Error(codes.NotFound, "record not found")
|
return nil, status.Error(codes.NotFound, "record not found")
|
||||||
@@ -42,10 +52,23 @@ func (s *recordService) GetRecord(_ context.Context, req *pb.GetRecordRequest) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Internal, "failed to get record")
|
return nil, status.Error(codes.Internal, "failed to get record")
|
||||||
}
|
}
|
||||||
return recordToProto(*record), nil
|
return s.recordToProto(*record), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *recordService) CreateRecord(_ context.Context, req *pb.CreateRecordRequest) (*pb.Record, error) {
|
func (s *recordService) CreateRecord(_ context.Context, req *pb.CreateRecordRequest) (*pb.Record, error) {
|
||||||
|
if req.Zone == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "zone is required")
|
||||||
|
}
|
||||||
|
if req.Name == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "name is required")
|
||||||
|
}
|
||||||
|
if req.Type == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "type is required")
|
||||||
|
}
|
||||||
|
if req.Value == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "value is required")
|
||||||
|
}
|
||||||
|
|
||||||
record, err := s.db.CreateRecord(req.Zone, req.Name, req.Type, req.Value, int(req.Ttl))
|
record, err := s.db.CreateRecord(req.Zone, req.Name, req.Type, req.Value, int(req.Ttl))
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, status.Error(codes.NotFound, "zone not found")
|
return nil, status.Error(codes.NotFound, "zone not found")
|
||||||
@@ -56,10 +79,23 @@ func (s *recordService) CreateRecord(_ context.Context, req *pb.CreateRecordRequ
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||||
}
|
}
|
||||||
return recordToProto(*record), nil
|
return s.recordToProto(*record), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *recordService) UpdateRecord(_ context.Context, req *pb.UpdateRecordRequest) (*pb.Record, error) {
|
func (s *recordService) UpdateRecord(_ context.Context, req *pb.UpdateRecordRequest) (*pb.Record, error) {
|
||||||
|
if req.Id <= 0 {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "id must be positive")
|
||||||
|
}
|
||||||
|
if req.Name == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "name is required")
|
||||||
|
}
|
||||||
|
if req.Type == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "type is required")
|
||||||
|
}
|
||||||
|
if req.Value == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "value is required")
|
||||||
|
}
|
||||||
|
|
||||||
record, err := s.db.UpdateRecord(req.Id, req.Name, req.Type, req.Value, int(req.Ttl))
|
record, err := s.db.UpdateRecord(req.Id, req.Name, req.Type, req.Value, int(req.Ttl))
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, status.Error(codes.NotFound, "record not found")
|
return nil, status.Error(codes.NotFound, "record not found")
|
||||||
@@ -70,10 +106,14 @@ func (s *recordService) UpdateRecord(_ context.Context, req *pb.UpdateRecordRequ
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||||
}
|
}
|
||||||
return recordToProto(*record), nil
|
return s.recordToProto(*record), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *recordService) DeleteRecord(_ context.Context, req *pb.DeleteRecordRequest) (*pb.DeleteRecordResponse, error) {
|
func (s *recordService) DeleteRecord(_ context.Context, req *pb.DeleteRecordRequest) (*pb.DeleteRecordResponse, error) {
|
||||||
|
if req.Id <= 0 {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "id must be positive")
|
||||||
|
}
|
||||||
|
|
||||||
err := s.db.DeleteRecord(req.Id)
|
err := s.db.DeleteRecord(req.Id)
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, status.Error(codes.NotFound, "record not found")
|
return nil, status.Error(codes.NotFound, "record not found")
|
||||||
@@ -84,7 +124,7 @@ func (s *recordService) DeleteRecord(_ context.Context, req *pb.DeleteRecordRequ
|
|||||||
return &pb.DeleteRecordResponse{}, nil
|
return &pb.DeleteRecordResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func recordToProto(r db.Record) *pb.Record {
|
func (s *recordService) recordToProto(r db.Record) *pb.Record {
|
||||||
return &pb.Record{
|
return &pb.Record{
|
||||||
Id: r.ID,
|
Id: r.ID,
|
||||||
Zone: r.ZoneName,
|
Zone: r.ZoneName,
|
||||||
@@ -92,14 +132,15 @@ func recordToProto(r db.Record) *pb.Record {
|
|||||||
Type: r.Type,
|
Type: r.Type,
|
||||||
Value: r.Value,
|
Value: r.Value,
|
||||||
Ttl: int32(r.TTL),
|
Ttl: int32(r.TTL),
|
||||||
CreatedAt: parseRecordTimestamp(r.CreatedAt),
|
CreatedAt: s.parseRecordTimestamp(r.CreatedAt),
|
||||||
UpdatedAt: parseRecordTimestamp(r.UpdatedAt),
|
UpdatedAt: s.parseRecordTimestamp(r.UpdatedAt),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRecordTimestamp(s string) *timestamppb.Timestamp {
|
func (s *recordService) parseRecordTimestamp(v string) *timestamppb.Timestamp {
|
||||||
t, err := parseTime(s)
|
t, err := parseTime(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Warn("failed to parse record timestamp", "value", v, "error", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return timestamppb.New(t)
|
return timestamppb.New(t)
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ func New(certFile, keyFile string, deps Deps, logger *slog.Logger) (*Server, err
|
|||||||
|
|
||||||
pb.RegisterAdminServiceServer(srv.GRPCServer, &adminService{db: deps.DB})
|
pb.RegisterAdminServiceServer(srv.GRPCServer, &adminService{db: deps.DB})
|
||||||
pb.RegisterAuthServiceServer(srv.GRPCServer, &authService{auth: deps.Authenticator})
|
pb.RegisterAuthServiceServer(srv.GRPCServer, &authService{auth: deps.Authenticator})
|
||||||
pb.RegisterZoneServiceServer(srv.GRPCServer, &zoneService{db: deps.DB})
|
pb.RegisterZoneServiceServer(srv.GRPCServer, &zoneService{db: deps.DB, logger: logger})
|
||||||
pb.RegisterRecordServiceServer(srv.GRPCServer, &recordService{db: deps.DB})
|
pb.RegisterRecordServiceServer(srv.GRPCServer, &recordService{db: deps.DB, logger: logger})
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package grpcserver
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
type zoneService struct {
|
type zoneService struct {
|
||||||
pb.UnimplementedZoneServiceServer
|
pb.UnimplementedZoneServiceServer
|
||||||
db *db.DB
|
db *db.DB
|
||||||
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *zoneService) ListZones(_ context.Context, _ *pb.ListZonesRequest) (*pb.ListZonesResponse, error) {
|
func (s *zoneService) ListZones(_ context.Context, _ *pb.ListZonesRequest) (*pb.ListZonesResponse, error) {
|
||||||
@@ -25,12 +27,16 @@ func (s *zoneService) ListZones(_ context.Context, _ *pb.ListZonesRequest) (*pb.
|
|||||||
|
|
||||||
resp := &pb.ListZonesResponse{}
|
resp := &pb.ListZonesResponse{}
|
||||||
for _, z := range zones {
|
for _, z := range zones {
|
||||||
resp.Zones = append(resp.Zones, zoneToProto(z))
|
resp.Zones = append(resp.Zones, s.zoneToProto(z))
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *zoneService) GetZone(_ context.Context, req *pb.GetZoneRequest) (*pb.Zone, error) {
|
func (s *zoneService) GetZone(_ context.Context, req *pb.GetZoneRequest) (*pb.Zone, error) {
|
||||||
|
if req.Name == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "name is required")
|
||||||
|
}
|
||||||
|
|
||||||
zone, err := s.db.GetZone(req.Name)
|
zone, err := s.db.GetZone(req.Name)
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, status.Error(codes.NotFound, "zone not found")
|
return nil, status.Error(codes.NotFound, "zone not found")
|
||||||
@@ -38,27 +44,22 @@ func (s *zoneService) GetZone(_ context.Context, req *pb.GetZoneRequest) (*pb.Zo
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Internal, "failed to get zone")
|
return nil, status.Error(codes.Internal, "failed to get zone")
|
||||||
}
|
}
|
||||||
return zoneToProto(*zone), nil
|
return s.zoneToProto(*zone), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *zoneService) CreateZone(_ context.Context, req *pb.CreateZoneRequest) (*pb.Zone, error) {
|
func (s *zoneService) CreateZone(_ context.Context, req *pb.CreateZoneRequest) (*pb.Zone, error) {
|
||||||
refresh := int(req.Refresh)
|
if req.Name == "" {
|
||||||
if refresh == 0 {
|
return nil, status.Error(codes.InvalidArgument, "name is required")
|
||||||
refresh = 3600
|
|
||||||
}
|
}
|
||||||
retry := int(req.Retry)
|
if req.PrimaryNs == "" {
|
||||||
if retry == 0 {
|
return nil, status.Error(codes.InvalidArgument, "primary_ns is required")
|
||||||
retry = 600
|
|
||||||
}
|
}
|
||||||
expire := int(req.Expire)
|
if req.AdminEmail == "" {
|
||||||
if expire == 0 {
|
return nil, status.Error(codes.InvalidArgument, "admin_email is required")
|
||||||
expire = 86400
|
|
||||||
}
|
|
||||||
minTTL := int(req.MinimumTtl)
|
|
||||||
if minTTL == 0 {
|
|
||||||
minTTL = 300
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
refresh, retry, expire, minTTL := db.ApplySOADefaults(int(req.Refresh), int(req.Retry), int(req.Expire), int(req.MinimumTtl))
|
||||||
|
|
||||||
zone, err := s.db.CreateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
|
zone, err := s.db.CreateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
|
||||||
if errors.Is(err, db.ErrConflict) {
|
if errors.Is(err, db.ErrConflict) {
|
||||||
return nil, status.Error(codes.AlreadyExists, err.Error())
|
return nil, status.Error(codes.AlreadyExists, err.Error())
|
||||||
@@ -66,27 +67,22 @@ func (s *zoneService) CreateZone(_ context.Context, req *pb.CreateZoneRequest) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Internal, "failed to create zone")
|
return nil, status.Error(codes.Internal, "failed to create zone")
|
||||||
}
|
}
|
||||||
return zoneToProto(*zone), nil
|
return s.zoneToProto(*zone), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *zoneService) UpdateZone(_ context.Context, req *pb.UpdateZoneRequest) (*pb.Zone, error) {
|
func (s *zoneService) UpdateZone(_ context.Context, req *pb.UpdateZoneRequest) (*pb.Zone, error) {
|
||||||
refresh := int(req.Refresh)
|
if req.Name == "" {
|
||||||
if refresh == 0 {
|
return nil, status.Error(codes.InvalidArgument, "name is required")
|
||||||
refresh = 3600
|
|
||||||
}
|
}
|
||||||
retry := int(req.Retry)
|
if req.PrimaryNs == "" {
|
||||||
if retry == 0 {
|
return nil, status.Error(codes.InvalidArgument, "primary_ns is required")
|
||||||
retry = 600
|
|
||||||
}
|
}
|
||||||
expire := int(req.Expire)
|
if req.AdminEmail == "" {
|
||||||
if expire == 0 {
|
return nil, status.Error(codes.InvalidArgument, "admin_email is required")
|
||||||
expire = 86400
|
|
||||||
}
|
|
||||||
minTTL := int(req.MinimumTtl)
|
|
||||||
if minTTL == 0 {
|
|
||||||
minTTL = 300
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
refresh, retry, expire, minTTL := db.ApplySOADefaults(int(req.Refresh), int(req.Retry), int(req.Expire), int(req.MinimumTtl))
|
||||||
|
|
||||||
zone, err := s.db.UpdateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
|
zone, err := s.db.UpdateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, status.Error(codes.NotFound, "zone not found")
|
return nil, status.Error(codes.NotFound, "zone not found")
|
||||||
@@ -94,10 +90,14 @@ func (s *zoneService) UpdateZone(_ context.Context, req *pb.UpdateZoneRequest) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.Internal, "failed to update zone")
|
return nil, status.Error(codes.Internal, "failed to update zone")
|
||||||
}
|
}
|
||||||
return zoneToProto(*zone), nil
|
return s.zoneToProto(*zone), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *zoneService) DeleteZone(_ context.Context, req *pb.DeleteZoneRequest) (*pb.DeleteZoneResponse, error) {
|
func (s *zoneService) DeleteZone(_ context.Context, req *pb.DeleteZoneRequest) (*pb.DeleteZoneResponse, error) {
|
||||||
|
if req.Name == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "name is required")
|
||||||
|
}
|
||||||
|
|
||||||
err := s.db.DeleteZone(req.Name)
|
err := s.db.DeleteZone(req.Name)
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, status.Error(codes.NotFound, "zone not found")
|
return nil, status.Error(codes.NotFound, "zone not found")
|
||||||
@@ -108,7 +108,7 @@ func (s *zoneService) DeleteZone(_ context.Context, req *pb.DeleteZoneRequest) (
|
|||||||
return &pb.DeleteZoneResponse{}, nil
|
return &pb.DeleteZoneResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func zoneToProto(z db.Zone) *pb.Zone {
|
func (s *zoneService) zoneToProto(z db.Zone) *pb.Zone {
|
||||||
return &pb.Zone{
|
return &pb.Zone{
|
||||||
Id: z.ID,
|
Id: z.ID,
|
||||||
Name: z.Name,
|
Name: z.Name,
|
||||||
@@ -119,15 +119,15 @@ func zoneToProto(z db.Zone) *pb.Zone {
|
|||||||
Expire: int32(z.Expire),
|
Expire: int32(z.Expire),
|
||||||
MinimumTtl: int32(z.MinimumTTL),
|
MinimumTtl: int32(z.MinimumTTL),
|
||||||
Serial: z.Serial,
|
Serial: z.Serial,
|
||||||
CreatedAt: parseTimestamp(z.CreatedAt),
|
CreatedAt: s.parseTimestamp(z.CreatedAt),
|
||||||
UpdatedAt: parseTimestamp(z.UpdatedAt),
|
UpdatedAt: s.parseTimestamp(z.UpdatedAt),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseTimestamp(s string) *timestamppb.Timestamp {
|
func (s *zoneService) parseTimestamp(v string) *timestamppb.Timestamp {
|
||||||
// SQLite stores as "2006-01-02T15:04:05Z".
|
t, err := parseTime(v)
|
||||||
t, err := parseTime(s)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.logger.Warn("failed to parse zone timestamp", "value", v, "error", err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return timestamppb.New(t)
|
return timestamppb.New(t)
|
||||||
|
|||||||
@@ -72,18 +72,7 @@ func createZoneHandler(database *db.DB) http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply defaults for SOA params.
|
// Apply defaults for SOA params.
|
||||||
if req.Refresh == 0 {
|
req.Refresh, req.Retry, req.Expire, req.MinimumTTL = db.ApplySOADefaults(req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
|
||||||
req.Refresh = 3600
|
|
||||||
}
|
|
||||||
if req.Retry == 0 {
|
|
||||||
req.Retry = 600
|
|
||||||
}
|
|
||||||
if req.Expire == 0 {
|
|
||||||
req.Expire = 86400
|
|
||||||
}
|
|
||||||
if req.MinimumTTL == 0 {
|
|
||||||
req.MinimumTTL = 300
|
|
||||||
}
|
|
||||||
|
|
||||||
zone, err := database.CreateZone(req.Name, req.PrimaryNS, req.AdminEmail, req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
|
zone, err := database.CreateZone(req.Name, req.PrimaryNS, req.AdminEmail, req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
|
||||||
if errors.Is(err, db.ErrConflict) {
|
if errors.Is(err, db.ErrConflict) {
|
||||||
@@ -117,18 +106,7 @@ func updateZoneHandler(database *db.DB) http.HandlerFunc {
|
|||||||
writeError(w, http.StatusBadRequest, "admin_email is required")
|
writeError(w, http.StatusBadRequest, "admin_email is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if req.Refresh == 0 {
|
req.Refresh, req.Retry, req.Expire, req.MinimumTTL = db.ApplySOADefaults(req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
|
||||||
req.Refresh = 3600
|
|
||||||
}
|
|
||||||
if req.Retry == 0 {
|
|
||||||
req.Retry = 600
|
|
||||||
}
|
|
||||||
if req.Expire == 0 {
|
|
||||||
req.Expire = 86400
|
|
||||||
}
|
|
||||||
if req.MinimumTTL == 0 {
|
|
||||||
req.MinimumTTL = 300
|
|
||||||
}
|
|
||||||
|
|
||||||
zone, err := database.UpdateZone(name, req.PrimaryNS, req.AdminEmail, req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
|
zone, err := database.UpdateZone(name, req.PrimaryNS, req.AdminEmail, req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
|
|||||||
Reference in New Issue
Block a user