Add SQLite persistence and write-through gRPC mutations
Database (internal/db) stores listeners, routes, and firewall rules with WAL mode, foreign keys, and idempotent migrations. First run seeds from TOML config; subsequent runs load from DB as source of truth. gRPC admin API now writes to the database before updating in-memory state (write-through cache pattern). Adds snapshot command for VACUUM INTO backups. Refactors firewall.New to accept raw rule slices instead of config struct for flexibility. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc-proxy/v1"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
||||
)
|
||||
|
||||
@@ -24,11 +26,12 @@ import (
|
||||
type AdminServer struct {
|
||||
pb.UnimplementedProxyAdminServer
|
||||
srv *server.Server
|
||||
store *db.Store
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a gRPC server with TLS and optional mTLS.
|
||||
func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
|
||||
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
|
||||
@@ -39,7 +42,6 @@ func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
// mTLS: require and verify client certificates.
|
||||
if cfg.ClientCA != "" {
|
||||
caCert, err := os.ReadFile(cfg.ClientCA)
|
||||
if err != nil {
|
||||
@@ -58,6 +60,7 @@ func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server
|
||||
|
||||
admin := &AdminServer{
|
||||
srv: srv,
|
||||
store: store,
|
||||
logger: logger,
|
||||
}
|
||||
pb.RegisterProxyAdminServer(grpcServer, admin)
|
||||
@@ -90,7 +93,7 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// AddRoute adds a route to a listener's route table.
|
||||
// AddRoute writes to the database first, then updates in-memory state.
|
||||
func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
|
||||
if req.Route == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "route is required")
|
||||
@@ -104,15 +107,23 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := ls.AddRoute(req.Route.Hostname, req.Route.Backend); err != nil {
|
||||
hostname := strings.ToLower(req.Route.Hostname)
|
||||
|
||||
// Write-through: DB first, then memory.
|
||||
if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend); err != nil {
|
||||
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
||||
}
|
||||
|
||||
a.logger.Info("route added", "listener", ls.Addr, "hostname", req.Route.Hostname, "backend", req.Route.Backend)
|
||||
if err := ls.AddRoute(hostname, req.Route.Backend); err != nil {
|
||||
// DB succeeded but memory failed (should not happen since DB enforces uniqueness).
|
||||
a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
|
||||
}
|
||||
|
||||
a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend)
|
||||
return &pb.AddRouteResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveRoute removes a route from a listener's route table.
|
||||
// RemoveRoute writes to the database first, then updates in-memory state.
|
||||
func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) {
|
||||
if req.Hostname == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "hostname is required")
|
||||
@@ -123,11 +134,17 @@ func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := ls.RemoveRoute(req.Hostname); err != nil {
|
||||
hostname := strings.ToLower(req.Hostname)
|
||||
|
||||
if err := a.store.DeleteRoute(ls.ID, hostname); err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "%v", err)
|
||||
}
|
||||
|
||||
a.logger.Info("route removed", "listener", ls.Addr, "hostname", req.Hostname)
|
||||
if err := ls.RemoveRoute(hostname); err != nil {
|
||||
a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
|
||||
}
|
||||
|
||||
a.logger.Info("route removed", "listener", ls.Addr, "hostname", hostname)
|
||||
return &pb.RemoveRouteResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -158,61 +175,74 @@ func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRules
|
||||
return &pb.GetFirewallRulesResponse{Rules: rules}, nil
|
||||
}
|
||||
|
||||
// AddFirewallRule adds a firewall rule.
|
||||
// AddFirewallRule writes to the database first, then updates in-memory state.
|
||||
func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) {
|
||||
if req.Rule == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "rule is required")
|
||||
}
|
||||
|
||||
fw := a.srv.Firewall()
|
||||
switch req.Rule.Type {
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
|
||||
if err := fw.AddIP(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
|
||||
if err := fw.AddCIDR(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
|
||||
if req.Rule.Value == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "country code is required")
|
||||
}
|
||||
fw.AddCountry(req.Rule.Value)
|
||||
default:
|
||||
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
|
||||
ruleType, err := protoRuleTypeToString(req.Rule.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.logger.Info("firewall rule added", "type", req.Rule.Type, "value", req.Rule.Value)
|
||||
if req.Rule.Value == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "value is required")
|
||||
}
|
||||
|
||||
// Write-through: DB first, then memory.
|
||||
if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
||||
}
|
||||
|
||||
fw := a.srv.Firewall()
|
||||
switch ruleType {
|
||||
case "ip":
|
||||
if err := fw.AddIP(req.Rule.Value); err != nil {
|
||||
a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
|
||||
}
|
||||
case "cidr":
|
||||
if err := fw.AddCIDR(req.Rule.Value); err != nil {
|
||||
a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
|
||||
}
|
||||
case "country":
|
||||
fw.AddCountry(req.Rule.Value)
|
||||
}
|
||||
|
||||
a.logger.Info("firewall rule added", "type", ruleType, "value", req.Rule.Value)
|
||||
return &pb.AddFirewallRuleResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveFirewallRule removes a firewall rule.
|
||||
// RemoveFirewallRule writes to the database first, then updates in-memory state.
|
||||
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) {
|
||||
if req.Rule == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "rule is required")
|
||||
}
|
||||
|
||||
fw := a.srv.Firewall()
|
||||
switch req.Rule.Type {
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
|
||||
if err := fw.RemoveIP(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
|
||||
if err := fw.RemoveCIDR(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
|
||||
if req.Rule.Value == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "country code is required")
|
||||
}
|
||||
fw.RemoveCountry(req.Rule.Value)
|
||||
default:
|
||||
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
|
||||
ruleType, err := protoRuleTypeToString(req.Rule.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.logger.Info("firewall rule removed", "type", req.Rule.Type, "value", req.Rule.Value)
|
||||
if err := a.store.DeleteFirewallRule(ruleType, req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "%v", err)
|
||||
}
|
||||
|
||||
fw := a.srv.Firewall()
|
||||
switch ruleType {
|
||||
case "ip":
|
||||
if err := fw.RemoveIP(req.Rule.Value); err != nil {
|
||||
a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
|
||||
}
|
||||
case "cidr":
|
||||
if err := fw.RemoveCIDR(req.Rule.Value); err != nil {
|
||||
a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
|
||||
}
|
||||
case "country":
|
||||
fw.RemoveCountry(req.Rule.Value)
|
||||
}
|
||||
|
||||
a.logger.Info("firewall rule removed", "type", ruleType, "value", req.Rule.Value)
|
||||
return &pb.RemoveFirewallRuleResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -244,3 +274,16 @@ func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) {
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "listener %q not found", addr)
|
||||
}
|
||||
|
||||
func protoRuleTypeToString(t pb.FirewallRuleType) (string, error) {
|
||||
switch t {
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
|
||||
return "ip", nil
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
|
||||
return "cidr", nil
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
|
||||
return "country", nil
|
||||
default:
|
||||
return "", status.Error(codes.InvalidArgument, "unknown rule type")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user