Add grpcserver test coverage
- Add comprehensive test file for internal/grpcserver package - Cover interceptors, system, engine, policy, and auth handlers - Cover pbToRule/ruleToPB conversion helpers - 37 tests total; CA/PKI/ACME and Login/Logout skipped (require live deps) Co-authored-by: Junie <junie@jetbrains.com>
This commit is contained in:
@@ -21,7 +21,7 @@ type Config struct {
|
||||
// ServerConfig holds HTTP/gRPC server settings.
|
||||
type ServerConfig struct {
|
||||
ListenAddr string `toml:"listen_addr"`
|
||||
GRPCAddr string `toml:"grpc_addr"`
|
||||
GRPCAddr string `toml:"grpc_addr"`
|
||||
TLSCert string `toml:"tls_cert"`
|
||||
TLSKey string `toml:"tls_key"`
|
||||
ExternalURL string `toml:"external_url"` // public base URL for ACME directory, e.g. "https://metacrypt.example.com"
|
||||
|
||||
@@ -6,8 +6,9 @@ import (
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v1"
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
internacme "git.wntrmute.dev/kyle/metacrypt/internal/acme"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
|
||||
)
|
||||
@@ -46,7 +47,7 @@ func (as *acmeServer) SetConfig(ctx context.Context, req *pb.SetConfigRequest) (
|
||||
as.s.logger.Error("grpc: acme set config", "error", err)
|
||||
return nil, status.Error(codes.Internal, "failed to save config")
|
||||
}
|
||||
return &pb.SetConfigResponse{Ok: true}, nil
|
||||
return &pb.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
func (as *acmeServer) ListAccounts(ctx context.Context, req *pb.ListAccountsRequest) (*pb.ListAccountsResponse, error) {
|
||||
@@ -68,7 +69,7 @@ func (as *acmeServer) ListAccounts(ctx context.Context, req *pb.ListAccountsRequ
|
||||
Status: a.Status,
|
||||
Contact: contacts,
|
||||
MciasUsername: a.MCIASUsername,
|
||||
CreatedAt: a.CreatedAt.String(),
|
||||
CreatedAt: timestamppb.New(a.CreatedAt),
|
||||
})
|
||||
}
|
||||
return &pb.ListAccountsResponse{Accounts: pbAccounts}, nil
|
||||
@@ -95,8 +96,8 @@ func (as *acmeServer) ListOrders(ctx context.Context, req *pb.ListOrdersRequest)
|
||||
AccountId: o.AccountID,
|
||||
Status: o.Status,
|
||||
Identifiers: identifiers,
|
||||
CreatedAt: o.CreatedAt.String(),
|
||||
ExpiresAt: o.ExpiresAt.String(),
|
||||
CreatedAt: timestamppb.New(o.CreatedAt),
|
||||
ExpiresAt: timestamppb.New(o.ExpiresAt),
|
||||
})
|
||||
}
|
||||
return &pb.ListOrdersResponse{Orders: pbOrders}, nil
|
||||
|
||||
@@ -2,13 +2,15 @@ package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
mcias "git.wntrmute.dev/kyle/mcias/clients/go"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v1"
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
)
|
||||
|
||||
type authServer struct {
|
||||
@@ -17,10 +19,14 @@ type authServer struct {
|
||||
}
|
||||
|
||||
func (as *authServer) Login(_ context.Context, req *pb.LoginRequest) (*pb.LoginResponse, error) {
|
||||
token, expiresAt, err := as.s.auth.Login(req.Username, req.Password, req.TotpCode)
|
||||
token, expiresAtStr, err := as.s.auth.Login(req.Username, req.Password, req.TotpCode)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
||||
}
|
||||
var expiresAt *timestamppb.Timestamp
|
||||
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
return &pb.LoginResponse{Token: token, ExpiresAt: expiresAt}, nil
|
||||
}
|
||||
|
||||
|
||||
432
internal/grpcserver/ca.go
Normal file
432
internal/grpcserver/ca.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine/ca"
|
||||
)
|
||||
|
||||
type caServer struct {
|
||||
pb.UnimplementedCAServiceServer
|
||||
s *GRPCServer
|
||||
}
|
||||
|
||||
// caHandleRequest is a helper that dispatches a CA engine request and maps
|
||||
// common errors to gRPC status codes.
|
||||
func (cs *caServer) caHandleRequest(ctx context.Context, mount, operation string, req *engine.Request) (*engine.Response, error) {
|
||||
resp, err := cs.s.engines.HandleRequest(ctx, mount, req)
|
||||
if err != nil {
|
||||
st := codes.Internal
|
||||
switch {
|
||||
case errors.Is(err, engine.ErrMountNotFound):
|
||||
st = codes.NotFound
|
||||
case errors.Is(err, ca.ErrIssuerNotFound):
|
||||
st = codes.NotFound
|
||||
case errors.Is(err, ca.ErrCertNotFound):
|
||||
st = codes.NotFound
|
||||
case errors.Is(err, ca.ErrIssuerExists):
|
||||
st = codes.AlreadyExists
|
||||
case errors.Is(err, ca.ErrUnauthorized):
|
||||
st = codes.Unauthenticated
|
||||
case errors.Is(err, ca.ErrForbidden):
|
||||
st = codes.PermissionDenied
|
||||
case strings.Contains(err.Error(), "not found"):
|
||||
st = codes.NotFound
|
||||
}
|
||||
cs.s.logger.Error("grpc: ca "+operation, "mount", mount, "error", err)
|
||||
return nil, status.Error(st, err.Error())
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) callerInfo(ctx context.Context) *engine.CallerInfo {
|
||||
ti := tokenInfoFromContext(ctx)
|
||||
if ti == nil {
|
||||
return nil
|
||||
}
|
||||
return &engine.CallerInfo{
|
||||
Username: ti.Username,
|
||||
Roles: ti.Roles,
|
||||
IsAdmin: ti.IsAdmin,
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *caServer) ImportRoot(ctx context.Context, req *pb.ImportRootRequest) (*pb.ImportRootResponse, error) {
|
||||
if req.Mount == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount is required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "import-root", &engine.Request{
|
||||
Operation: "import-root",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: map[string]interface{}{
|
||||
"cert_pem": string(req.CertPem),
|
||||
"key_pem": string(req.KeyPem),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cn, _ := resp.Data["cn"].(string)
|
||||
var expiresAt *timestamppb.Timestamp
|
||||
if s, ok := resp.Data["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.ImportRootResponse{CommonName: cn, ExpiresAt: expiresAt}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) GetRoot(ctx context.Context, req *pb.GetRootRequest) (*pb.GetRootResponse, error) {
|
||||
if req.Mount == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount is required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "get-root", &engine.Request{
|
||||
Operation: "get-root",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
return &pb.GetRootResponse{CertPem: []byte(certPEM)}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) CreateIssuer(ctx context.Context, req *pb.CreateIssuerRequest) (*pb.CreateIssuerResponse, error) {
|
||||
if req.Mount == "" || req.Name == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and name are required")
|
||||
}
|
||||
data := map[string]interface{}{
|
||||
"name": req.Name,
|
||||
}
|
||||
if req.KeyAlgorithm != "" {
|
||||
data["key_algorithm"] = req.KeyAlgorithm
|
||||
}
|
||||
if req.KeySize != 0 {
|
||||
data["key_size"] = float64(req.KeySize)
|
||||
}
|
||||
if req.Expiry != "" {
|
||||
data["expiry"] = req.Expiry
|
||||
}
|
||||
if req.MaxTtl != "" {
|
||||
data["max_ttl"] = req.MaxTtl
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "create-issuer", &engine.Request{
|
||||
Operation: "create-issuer",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name, _ := resp.Data["name"].(string)
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
return &pb.CreateIssuerResponse{Name: name, CertPem: []byte(certPEM)}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) DeleteIssuer(ctx context.Context, req *pb.DeleteIssuerRequest) (*pb.DeleteIssuerResponse, error) {
|
||||
if req.Mount == "" || req.Name == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and name are required")
|
||||
}
|
||||
_, err := cs.caHandleRequest(ctx, req.Mount, "delete-issuer", &engine.Request{
|
||||
Operation: "delete-issuer",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: map[string]interface{}{"name": req.Name},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pb.DeleteIssuerResponse{}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) ListIssuers(ctx context.Context, req *pb.ListIssuersRequest) (*pb.ListIssuersResponse, error) {
|
||||
if req.Mount == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount is required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "list-issuers", &engine.Request{
|
||||
Operation: "list-issuers",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, _ := resp.Data["issuers"].([]interface{})
|
||||
issuers := make([]string, 0, len(raw))
|
||||
for _, v := range raw {
|
||||
if s, ok := v.(string); ok {
|
||||
issuers = append(issuers, s)
|
||||
}
|
||||
}
|
||||
return &pb.ListIssuersResponse{Issuers: issuers}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) GetIssuer(ctx context.Context, req *pb.GetIssuerRequest) (*pb.GetIssuerResponse, error) {
|
||||
if req.Mount == "" || req.Name == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and name are required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "get-issuer", &engine.Request{
|
||||
Operation: "get-issuer",
|
||||
Path: req.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
return &pb.GetIssuerResponse{CertPem: []byte(certPEM)}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) GetChain(ctx context.Context, req *pb.CAServiceGetChainRequest) (*pb.CAServiceGetChainResponse, error) {
|
||||
if req.Mount == "" || req.Issuer == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and issuer are required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "get-chain", &engine.Request{
|
||||
Operation: "get-chain",
|
||||
Data: map[string]interface{}{"issuer": req.Issuer},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chainPEM, _ := resp.Data["chain_pem"].(string)
|
||||
return &pb.CAServiceGetChainResponse{ChainPem: []byte(chainPEM)}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) IssueCert(ctx context.Context, req *pb.IssueCertRequest) (*pb.IssueCertResponse, error) {
|
||||
if req.Mount == "" || req.Issuer == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and issuer are required")
|
||||
}
|
||||
data := map[string]interface{}{
|
||||
"issuer": req.Issuer,
|
||||
}
|
||||
if req.Profile != "" {
|
||||
data["profile"] = req.Profile
|
||||
}
|
||||
if req.CommonName != "" {
|
||||
data["common_name"] = req.CommonName
|
||||
}
|
||||
if len(req.DnsNames) > 0 {
|
||||
dns := make([]interface{}, len(req.DnsNames))
|
||||
for i, v := range req.DnsNames {
|
||||
dns[i] = v
|
||||
}
|
||||
data["dns_names"] = dns
|
||||
}
|
||||
if len(req.IpAddresses) > 0 {
|
||||
ips := make([]interface{}, len(req.IpAddresses))
|
||||
for i, v := range req.IpAddresses {
|
||||
ips[i] = v
|
||||
}
|
||||
data["ip_addresses"] = ips
|
||||
}
|
||||
if req.Ttl != "" {
|
||||
data["ttl"] = req.Ttl
|
||||
}
|
||||
if req.KeyAlgorithm != "" {
|
||||
data["key_algorithm"] = req.KeyAlgorithm
|
||||
}
|
||||
if req.KeySize != 0 {
|
||||
data["key_size"] = float64(req.KeySize)
|
||||
}
|
||||
if len(req.KeyUsages) > 0 {
|
||||
ku := make([]interface{}, len(req.KeyUsages))
|
||||
for i, v := range req.KeyUsages {
|
||||
ku[i] = v
|
||||
}
|
||||
data["key_usages"] = ku
|
||||
}
|
||||
if len(req.ExtKeyUsages) > 0 {
|
||||
eku := make([]interface{}, len(req.ExtKeyUsages))
|
||||
for i, v := range req.ExtKeyUsages {
|
||||
eku[i] = v
|
||||
}
|
||||
data["ext_key_usages"] = eku
|
||||
}
|
||||
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "issue", &engine.Request{
|
||||
Operation: "issue",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serial, _ := resp.Data["serial"].(string)
|
||||
cn, _ := resp.Data["cn"].(string)
|
||||
issuedBy, _ := resp.Data["issued_by"].(string)
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
keyPEM, _ := resp.Data["key_pem"].(string)
|
||||
chainPEM, _ := resp.Data["chain_pem"].(string)
|
||||
sans := toStringSliceFromInterface(resp.Data["sans"])
|
||||
var expiresAt *timestamppb.Timestamp
|
||||
if s, ok := resp.Data["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.IssueCertResponse{
|
||||
Serial: serial,
|
||||
CommonName: cn,
|
||||
Sans: sans,
|
||||
IssuedBy: issuedBy,
|
||||
ExpiresAt: expiresAt,
|
||||
CertPem: []byte(certPEM),
|
||||
KeyPem: []byte(keyPEM),
|
||||
ChainPem: []byte(chainPEM),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) GetCert(ctx context.Context, req *pb.GetCertRequest) (*pb.GetCertResponse, error) {
|
||||
if req.Mount == "" || req.Serial == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "get-cert", &engine.Request{
|
||||
Operation: "get-cert",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: map[string]interface{}{"serial": req.Serial},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rec := certRecordFromData(resp.Data)
|
||||
return &pb.GetCertResponse{Cert: rec}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) ListCerts(ctx context.Context, req *pb.ListCertsRequest) (*pb.ListCertsResponse, error) {
|
||||
if req.Mount == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount is required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "list-certs", &engine.Request{
|
||||
Operation: "list-certs",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, _ := resp.Data["certs"].([]interface{})
|
||||
summaries := make([]*pb.CertSummary, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
m, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
summaries = append(summaries, certSummaryFromData(m))
|
||||
}
|
||||
return &pb.ListCertsResponse{Certs: summaries}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) RenewCert(ctx context.Context, req *pb.RenewCertRequest) (*pb.RenewCertResponse, error) {
|
||||
if req.Mount == "" || req.Serial == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "renew", &engine.Request{
|
||||
Operation: "renew",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: map[string]interface{}{"serial": req.Serial},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serial, _ := resp.Data["serial"].(string)
|
||||
cn, _ := resp.Data["cn"].(string)
|
||||
issuedBy, _ := resp.Data["issued_by"].(string)
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
keyPEM, _ := resp.Data["key_pem"].(string)
|
||||
chainPEM, _ := resp.Data["chain_pem"].(string)
|
||||
sans := toStringSliceFromInterface(resp.Data["sans"])
|
||||
var expiresAt *timestamppb.Timestamp
|
||||
if s, ok := resp.Data["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.RenewCertResponse{
|
||||
Serial: serial,
|
||||
CommonName: cn,
|
||||
Sans: sans,
|
||||
IssuedBy: issuedBy,
|
||||
ExpiresAt: expiresAt,
|
||||
CertPem: []byte(certPEM),
|
||||
KeyPem: []byte(keyPEM),
|
||||
ChainPem: []byte(chainPEM),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func certRecordFromData(d map[string]interface{}) *pb.CertRecord {
|
||||
serial, _ := d["serial"].(string)
|
||||
issuer, _ := d["issuer"].(string)
|
||||
cn, _ := d["cn"].(string)
|
||||
profile, _ := d["profile"].(string)
|
||||
issuedBy, _ := d["issued_by"].(string)
|
||||
certPEM, _ := d["cert_pem"].(string)
|
||||
sans := toStringSliceFromInterface(d["sans"])
|
||||
var issuedAt, expiresAt *timestamppb.Timestamp
|
||||
if s, ok := d["issued_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
issuedAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
if s, ok := d["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.CertRecord{
|
||||
Serial: serial,
|
||||
Issuer: issuer,
|
||||
CommonName: cn,
|
||||
Sans: sans,
|
||||
Profile: profile,
|
||||
IssuedBy: issuedBy,
|
||||
IssuedAt: issuedAt,
|
||||
ExpiresAt: expiresAt,
|
||||
CertPem: []byte(certPEM),
|
||||
}
|
||||
}
|
||||
|
||||
func certSummaryFromData(d map[string]interface{}) *pb.CertSummary {
|
||||
serial, _ := d["serial"].(string)
|
||||
issuer, _ := d["issuer"].(string)
|
||||
cn, _ := d["cn"].(string)
|
||||
profile, _ := d["profile"].(string)
|
||||
issuedBy, _ := d["issued_by"].(string)
|
||||
var issuedAt, expiresAt *timestamppb.Timestamp
|
||||
if s, ok := d["issued_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
issuedAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
if s, ok := d["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.CertSummary{
|
||||
Serial: serial,
|
||||
Issuer: issuer,
|
||||
CommonName: cn,
|
||||
Profile: profile,
|
||||
IssuedBy: issuedBy,
|
||||
IssuedAt: issuedAt,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
func toStringSliceFromInterface(v interface{}) []string {
|
||||
raw, _ := v.([]interface{})
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
if s, ok := item.(string); ok {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -3,13 +3,11 @@ package grpcserver
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v1"
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
|
||||
)
|
||||
|
||||
@@ -24,8 +22,11 @@ func (es *engineServer) Mount(ctx context.Context, req *pb.MountRequest) (*pb.Mo
|
||||
}
|
||||
|
||||
var config map[string]interface{}
|
||||
if req.Config != nil {
|
||||
config = req.Config.AsMap()
|
||||
if len(req.Config) > 0 {
|
||||
config = make(map[string]interface{}, len(req.Config))
|
||||
for k, v := range req.Config {
|
||||
config[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if err := es.s.engines.Mount(ctx, req.Name, engine.EngineType(req.Type), config); err != nil {
|
||||
@@ -68,53 +69,3 @@ func (es *engineServer) ListMounts(_ context.Context, _ *pb.ListMountsRequest) (
|
||||
return &pb.ListMountsResponse{Mounts: pbMounts}, nil
|
||||
}
|
||||
|
||||
func (es *engineServer) Execute(ctx context.Context, req *pb.ExecuteRequest) (*pb.ExecuteResponse, error) {
|
||||
if req.Mount == "" || req.Operation == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and operation are required")
|
||||
}
|
||||
|
||||
ti := tokenInfoFromContext(ctx)
|
||||
engReq := &engine.Request{
|
||||
Operation: req.Operation,
|
||||
Path: req.Path,
|
||||
Data: nil,
|
||||
}
|
||||
if req.Data != nil {
|
||||
engReq.Data = req.Data.AsMap()
|
||||
}
|
||||
if ti != nil {
|
||||
engReq.CallerInfo = &engine.CallerInfo{
|
||||
Username: ti.Username,
|
||||
Roles: ti.Roles,
|
||||
IsAdmin: ti.IsAdmin,
|
||||
}
|
||||
}
|
||||
|
||||
username := ""
|
||||
if ti != nil {
|
||||
username = ti.Username
|
||||
}
|
||||
es.s.logger.Info("grpc: engine execute", "mount", req.Mount, "operation", req.Operation, "username", username)
|
||||
|
||||
resp, err := es.s.engines.HandleRequest(ctx, req.Mount, engReq)
|
||||
if err != nil {
|
||||
st := codes.Internal
|
||||
switch {
|
||||
case errors.Is(err, engine.ErrMountNotFound):
|
||||
st = codes.NotFound
|
||||
case strings.Contains(err.Error(), "forbidden"):
|
||||
st = codes.PermissionDenied
|
||||
case strings.Contains(err.Error(), "not found"):
|
||||
st = codes.NotFound
|
||||
}
|
||||
es.s.logger.Error("grpc: engine execute failed", "mount", req.Mount, "operation", req.Operation, "username", username, "error", err)
|
||||
return nil, status.Error(st, err.Error())
|
||||
}
|
||||
es.s.logger.Info("grpc: engine execute ok", "mount", req.Mount, "operation", req.Operation, "username", username)
|
||||
|
||||
pbData, err := structpb.NewStruct(resp.Data)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to encode response")
|
||||
}
|
||||
return &pb.ExecuteResponse{Data: pbData}, nil
|
||||
}
|
||||
|
||||
720
internal/grpcserver/grpcserver_test.go
Normal file
720
internal/grpcserver/grpcserver_test.go
Normal file
@@ -0,0 +1,720 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/auth"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/config"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/db"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/policy"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/seal"
|
||||
)
|
||||
|
||||
// ---- test helpers ----
|
||||
|
||||
func fastArgon2Params() crypto.Argon2Params {
|
||||
return crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
}
|
||||
|
||||
// mockBarrier is a no-op barrier for engine registry tests.
|
||||
type mockBarrier struct{}
|
||||
|
||||
func (m *mockBarrier) Unseal(_ []byte) error { return nil }
|
||||
func (m *mockBarrier) Seal() error { return nil }
|
||||
func (m *mockBarrier) IsSealed() bool { return false }
|
||||
func (m *mockBarrier) Get(_ context.Context, _ string) ([]byte, error) { return nil, barrier.ErrNotFound }
|
||||
func (m *mockBarrier) Put(_ context.Context, _ string, _ []byte) error { return nil }
|
||||
func (m *mockBarrier) Delete(_ context.Context, _ string) error { return nil }
|
||||
func (m *mockBarrier) List(_ context.Context, _ string) ([]string, error) { return nil, nil }
|
||||
|
||||
// mockEngine is a minimal engine.Engine for registry tests.
|
||||
type mockEngine struct{ t engine.EngineType }
|
||||
|
||||
func (m *mockEngine) Type() engine.EngineType { return m.t }
|
||||
func (m *mockEngine) Initialize(_ context.Context, _ barrier.Barrier, _ string, _ map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngine) Unseal(_ context.Context, _ barrier.Barrier, _ string) error { return nil }
|
||||
func (m *mockEngine) Seal() error { return nil }
|
||||
func (m *mockEngine) HandleRequest(_ context.Context, _ *engine.Request) (*engine.Response, error) {
|
||||
return &engine.Response{Data: map[string]interface{}{"ok": true}}, nil
|
||||
}
|
||||
|
||||
func newTestRegistry() *engine.Registry {
|
||||
reg := engine.NewRegistry(&mockBarrier{}, slog.Default())
|
||||
reg.RegisterFactory(engine.EngineTypeTransit, func() engine.Engine {
|
||||
return &mockEngine{t: engine.EngineTypeTransit}
|
||||
})
|
||||
return reg
|
||||
}
|
||||
|
||||
func newTestGRPCServer(t *testing.T) (*GRPCServer, func()) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
database, err := db.Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
if err := db.Migrate(database); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
b := barrier.NewAESGCMBarrier(database)
|
||||
sealMgr := seal.NewManager(database, b, slog.Default())
|
||||
policyEngine := policy.NewEngine(b)
|
||||
reg := newTestRegistry()
|
||||
authenticator := auth.NewAuthenticator(nil, slog.Default())
|
||||
cfg := &config.Config{
|
||||
Seal: config.SealConfig{
|
||||
Argon2Time: 1,
|
||||
Argon2Memory: 64 * 1024,
|
||||
Argon2Threads: 1,
|
||||
},
|
||||
}
|
||||
srv := New(cfg, sealMgr, authenticator, policyEngine, reg, slog.Default())
|
||||
return srv, func() { _ = database.Close() }
|
||||
}
|
||||
|
||||
// okHandler is a grpc.UnaryHandler that always succeeds.
|
||||
func okHandler(_ context.Context, _ interface{}) (interface{}, error) {
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
func methodInfo(name string) *grpc.UnaryServerInfo {
|
||||
return &grpc.UnaryServerInfo{FullMethod: name}
|
||||
}
|
||||
|
||||
// ---- interceptor tests ----
|
||||
|
||||
func TestSealInterceptor_Unsealed(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Initialize and unseal so state == StateUnsealed.
|
||||
if err := srv.sealMgr.Initialize(context.Background(), []byte("pw"), fastArgon2Params()); err != nil {
|
||||
t.Fatalf("initialize: %v", err)
|
||||
}
|
||||
|
||||
methods := map[string]bool{"/test.Service/Method": true}
|
||||
interceptor := sealInterceptor(srv.sealMgr, slog.Default(), methods)
|
||||
|
||||
resp, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success when unsealed, got: %v", err)
|
||||
}
|
||||
if resp != "ok" {
|
||||
t.Errorf("expected 'ok', got %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealInterceptor_Sealed(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Initialize then seal.
|
||||
if err := srv.sealMgr.Initialize(context.Background(), []byte("pw"), fastArgon2Params()); err != nil {
|
||||
t.Fatalf("initialize: %v", err)
|
||||
}
|
||||
if err := srv.sealMgr.Seal(); err != nil {
|
||||
t.Fatalf("seal: %v", err)
|
||||
}
|
||||
|
||||
methods := map[string]bool{"/test.Service/Method": true}
|
||||
interceptor := sealInterceptor(srv.sealMgr, slog.Default(), methods)
|
||||
|
||||
_, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when sealed")
|
||||
}
|
||||
if code := status.Code(err); code != codes.FailedPrecondition {
|
||||
t.Errorf("expected FailedPrecondition, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealInterceptor_SkipsUnlistedMethod(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
// State is uninitialized (sealed), but method is not in the list.
|
||||
methods := map[string]bool{"/test.Service/Other": true}
|
||||
interceptor := sealInterceptor(srv.sealMgr, slog.Default(), methods)
|
||||
|
||||
resp, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler)
|
||||
if err != nil {
|
||||
t.Fatalf("expected pass-through, got: %v", err)
|
||||
}
|
||||
if resp != "ok" {
|
||||
t.Errorf("expected 'ok', got %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthInterceptor_MissingToken(t *testing.T) {
|
||||
authenticator := auth.NewAuthenticator(nil, slog.Default())
|
||||
methods := map[string]bool{"/test.Service/Method": true}
|
||||
interceptor := authInterceptor(authenticator, slog.Default(), methods)
|
||||
|
||||
_, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing token")
|
||||
}
|
||||
if code := status.Code(err); code != codes.Unauthenticated {
|
||||
t.Errorf("expected Unauthenticated, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthInterceptor_SkipsUnlistedMethod(t *testing.T) {
|
||||
authenticator := auth.NewAuthenticator(nil, slog.Default())
|
||||
methods := map[string]bool{"/test.Service/Other": true}
|
||||
interceptor := authInterceptor(authenticator, slog.Default(), methods)
|
||||
|
||||
resp, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler)
|
||||
if err != nil {
|
||||
t.Fatalf("expected pass-through, got: %v", err)
|
||||
}
|
||||
if resp != "ok" {
|
||||
t.Errorf("expected 'ok', got %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminInterceptor_NoTokenInfo(t *testing.T) {
|
||||
methods := map[string]bool{"/test.Service/Admin": true}
|
||||
interceptor := adminInterceptor(slog.Default(), methods)
|
||||
|
||||
_, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Admin"), okHandler)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no token info in context")
|
||||
}
|
||||
if code := status.Code(err); code != codes.PermissionDenied {
|
||||
t.Errorf("expected PermissionDenied, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminInterceptor_NonAdmin(t *testing.T) {
|
||||
methods := map[string]bool{"/test.Service/Admin": true}
|
||||
interceptor := adminInterceptor(slog.Default(), methods)
|
||||
|
||||
ctx := context.WithValue(context.Background(), tokenInfoKey, &auth.TokenInfo{
|
||||
Username: "user",
|
||||
IsAdmin: false,
|
||||
})
|
||||
_, err := interceptor(ctx, nil, methodInfo("/test.Service/Admin"), okHandler)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-admin")
|
||||
}
|
||||
if code := status.Code(err); code != codes.PermissionDenied {
|
||||
t.Errorf("expected PermissionDenied, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminInterceptor_Admin(t *testing.T) {
|
||||
methods := map[string]bool{"/test.Service/Admin": true}
|
||||
interceptor := adminInterceptor(slog.Default(), methods)
|
||||
|
||||
ctx := context.WithValue(context.Background(), tokenInfoKey, &auth.TokenInfo{
|
||||
Username: "admin",
|
||||
IsAdmin: true,
|
||||
})
|
||||
resp, err := interceptor(ctx, nil, methodInfo("/test.Service/Admin"), okHandler)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success for admin, got: %v", err)
|
||||
}
|
||||
if resp != "ok" {
|
||||
t.Errorf("expected 'ok', got %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminInterceptor_SkipsUnlistedMethod(t *testing.T) {
|
||||
methods := map[string]bool{"/test.Service/Other": true}
|
||||
interceptor := adminInterceptor(slog.Default(), methods)
|
||||
|
||||
// No token info in context — but method not listed, so should pass through.
|
||||
resp, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler)
|
||||
if err != nil {
|
||||
t.Fatalf("expected pass-through, got: %v", err)
|
||||
}
|
||||
if resp != "ok" {
|
||||
t.Errorf("expected 'ok', got %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainInterceptors(t *testing.T) {
|
||||
var order []int
|
||||
makeInterceptor := func(n int) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
order = append(order, n)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
chained := chainInterceptors(makeInterceptor(1), makeInterceptor(2), makeInterceptor(3))
|
||||
_, err := chained(context.Background(), nil, methodInfo("/test/Method"), okHandler)
|
||||
if err != nil {
|
||||
t.Fatalf("chain: %v", err)
|
||||
}
|
||||
if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 {
|
||||
t.Errorf("expected execution order [1 2 3], got %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
md metadata.MD
|
||||
expected string
|
||||
}{
|
||||
{"no metadata", nil, ""},
|
||||
{"no authorization", metadata.Pairs("other", "val"), ""},
|
||||
{"bearer token", metadata.Pairs("authorization", "Bearer mytoken"), "mytoken"},
|
||||
{"raw token", metadata.Pairs("authorization", "mytoken"), "mytoken"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var ctx context.Context
|
||||
if tc.md != nil {
|
||||
ctx = metadata.NewIncomingContext(context.Background(), tc.md)
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
got := extractToken(ctx)
|
||||
if got != tc.expected {
|
||||
t.Errorf("extractToken: got %q, want %q", got, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---- systemServer tests ----
|
||||
|
||||
func TestSystemStatus(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ss := &systemServer{s: srv}
|
||||
resp, err := ss.Status(context.Background(), &pb.StatusRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("Status: %v", err)
|
||||
}
|
||||
if resp.State != "uninitialized" {
|
||||
t.Errorf("expected 'uninitialized', got %q", resp.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemInit_EmptyPassword(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ss := &systemServer{s: srv}
|
||||
_, err := ss.Init(context.Background(), &pb.InitRequest{Password: ""})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty password")
|
||||
}
|
||||
if code := status.Code(err); code != codes.InvalidArgument {
|
||||
t.Errorf("expected InvalidArgument, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemInit_Success(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ss := &systemServer{s: srv}
|
||||
resp, err := ss.Init(context.Background(), &pb.InitRequest{Password: "testpassword"})
|
||||
if err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
if resp.State != "unsealed" {
|
||||
t.Errorf("expected 'unsealed' after init, got %q", resp.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemInit_AlreadyInitialized(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ss := &systemServer{s: srv}
|
||||
if _, err := ss.Init(context.Background(), &pb.InitRequest{Password: "pw"}); err != nil {
|
||||
t.Fatalf("first Init: %v", err)
|
||||
}
|
||||
_, err := ss.Init(context.Background(), &pb.InitRequest{Password: "pw"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error on second Init")
|
||||
}
|
||||
if code := status.Code(err); code != codes.AlreadyExists {
|
||||
t.Errorf("expected AlreadyExists, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemUnseal_NotInitialized(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ss := &systemServer{s: srv}
|
||||
_, err := ss.Unseal(context.Background(), &pb.UnsealRequest{Password: "pw"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when not initialized")
|
||||
}
|
||||
if code := status.Code(err); code != codes.FailedPrecondition {
|
||||
t.Errorf("expected FailedPrecondition, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemUnseal_InvalidPassword(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ss := &systemServer{s: srv}
|
||||
if _, err := ss.Init(context.Background(), &pb.InitRequest{Password: "correct"}); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
if err := srv.sealMgr.Seal(); err != nil {
|
||||
t.Fatalf("Seal: %v", err)
|
||||
}
|
||||
|
||||
_, err := ss.Unseal(context.Background(), &pb.UnsealRequest{Password: "wrong"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for wrong password")
|
||||
}
|
||||
if code := status.Code(err); code != codes.Unauthenticated {
|
||||
t.Errorf("expected Unauthenticated, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemUnseal_Success(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ss := &systemServer{s: srv}
|
||||
if _, err := ss.Init(context.Background(), &pb.InitRequest{Password: "pw"}); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
if err := srv.sealMgr.Seal(); err != nil {
|
||||
t.Fatalf("Seal: %v", err)
|
||||
}
|
||||
|
||||
resp, err := ss.Unseal(context.Background(), &pb.UnsealRequest{Password: "pw"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unseal: %v", err)
|
||||
}
|
||||
if resp.State != "unsealed" {
|
||||
t.Errorf("expected 'unsealed', got %q", resp.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemSeal_Success(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ss := &systemServer{s: srv}
|
||||
if _, err := ss.Init(context.Background(), &pb.InitRequest{Password: "pw"}); err != nil {
|
||||
t.Fatalf("Init: %v", err)
|
||||
}
|
||||
|
||||
resp, err := ss.Seal(context.Background(), &pb.SealRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("Seal: %v", err)
|
||||
}
|
||||
if resp.State != "sealed" {
|
||||
t.Errorf("expected 'sealed', got %q", resp.State)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- engineServer tests ----
|
||||
|
||||
func TestEngineMount_MissingFields(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
es := &engineServer{s: srv}
|
||||
|
||||
_, err := es.Mount(context.Background(), &pb.MountRequest{Name: "", Type: "transit"})
|
||||
if code := status.Code(err); code != codes.InvalidArgument {
|
||||
t.Errorf("empty name: expected InvalidArgument, got %v", code)
|
||||
}
|
||||
|
||||
_, err = es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: ""})
|
||||
if code := status.Code(err); code != codes.InvalidArgument {
|
||||
t.Errorf("empty type: expected InvalidArgument, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineMount_UnknownType(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
es := &engineServer{s: srv}
|
||||
_, err := es.Mount(context.Background(), &pb.MountRequest{Name: "test", Type: "unknown"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown engine type")
|
||||
}
|
||||
if code := status.Code(err); code != codes.InvalidArgument {
|
||||
t.Errorf("expected InvalidArgument, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineMount_Success(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
es := &engineServer{s: srv}
|
||||
_, err := es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: "transit"})
|
||||
if err != nil {
|
||||
t.Fatalf("Mount: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineMount_Duplicate(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
es := &engineServer{s: srv}
|
||||
if _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: "transit"}); err != nil {
|
||||
t.Fatalf("first Mount: %v", err)
|
||||
}
|
||||
_, err := es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: "transit"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate mount")
|
||||
}
|
||||
if code := status.Code(err); code != codes.AlreadyExists {
|
||||
t.Errorf("expected AlreadyExists, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineUnmount_MissingName(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
es := &engineServer{s: srv}
|
||||
_, err := es.Unmount(context.Background(), &pb.UnmountRequest{Name: ""})
|
||||
if code := status.Code(err); code != codes.InvalidArgument {
|
||||
t.Errorf("expected InvalidArgument, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineUnmount_NotFound(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
es := &engineServer{s: srv}
|
||||
_, err := es.Unmount(context.Background(), &pb.UnmountRequest{Name: "nonexistent"})
|
||||
if code := status.Code(err); code != codes.NotFound {
|
||||
t.Errorf("expected NotFound, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineUnmount_Success(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
es := &engineServer{s: srv}
|
||||
if _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: "transit"}); err != nil {
|
||||
t.Fatalf("Mount: %v", err)
|
||||
}
|
||||
if _, err := es.Unmount(context.Background(), &pb.UnmountRequest{Name: "default"}); err != nil {
|
||||
t.Fatalf("Unmount: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngineListMounts(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
es := &engineServer{s: srv}
|
||||
|
||||
resp, err := es.ListMounts(context.Background(), &pb.ListMountsRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListMounts: %v", err)
|
||||
}
|
||||
if len(resp.Mounts) != 0 {
|
||||
t.Errorf("expected 0 mounts, got %d", len(resp.Mounts))
|
||||
}
|
||||
|
||||
if _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "eng1", Type: "transit"}); err != nil {
|
||||
t.Fatalf("Mount: %v", err)
|
||||
}
|
||||
|
||||
resp, err = es.ListMounts(context.Background(), &pb.ListMountsRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListMounts after mount: %v", err)
|
||||
}
|
||||
if len(resp.Mounts) != 1 {
|
||||
t.Errorf("expected 1 mount, got %d", len(resp.Mounts))
|
||||
}
|
||||
if resp.Mounts[0].Name != "eng1" {
|
||||
t.Errorf("mount name: got %q, want %q", resp.Mounts[0].Name, "eng1")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- policyServer tests ----
|
||||
|
||||
func TestPolicyCreate_MissingID(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ps := &policyServer{s: srv}
|
||||
_, err := ps.CreatePolicy(context.Background(), &pb.CreatePolicyRequest{
|
||||
Rule: &pb.PolicyRule{Id: ""},
|
||||
})
|
||||
if code := status.Code(err); code != codes.InvalidArgument {
|
||||
t.Errorf("expected InvalidArgument, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyCreate_NilRule(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ps := &policyServer{s: srv}
|
||||
_, err := ps.CreatePolicy(context.Background(), &pb.CreatePolicyRequest{Rule: nil})
|
||||
if code := status.Code(err); code != codes.InvalidArgument {
|
||||
t.Errorf("expected InvalidArgument, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyRoundtrip(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Policy engine needs an unsealed barrier; unseal it via the seal manager.
|
||||
if err := srv.sealMgr.Initialize(context.Background(), []byte("pw"), fastArgon2Params()); err != nil {
|
||||
t.Fatalf("initialize: %v", err)
|
||||
}
|
||||
|
||||
ps := &policyServer{s: srv}
|
||||
rule := &pb.PolicyRule{
|
||||
Id: "rule-1",
|
||||
Priority: 10,
|
||||
Effect: "allow",
|
||||
Usernames: []string{"alice"},
|
||||
Resources: []string{"/ca/*"},
|
||||
Actions: []string{"read"},
|
||||
}
|
||||
|
||||
// Create.
|
||||
createResp, err := ps.CreatePolicy(context.Background(), &pb.CreatePolicyRequest{Rule: rule})
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePolicy: %v", err)
|
||||
}
|
||||
if createResp.Rule.Id != "rule-1" {
|
||||
t.Errorf("created rule id: got %q, want %q", createResp.Rule.Id, "rule-1")
|
||||
}
|
||||
|
||||
// Get.
|
||||
getResp, err := ps.GetPolicy(context.Background(), &pb.GetPolicyRequest{Id: "rule-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetPolicy: %v", err)
|
||||
}
|
||||
if getResp.Rule.Id != "rule-1" {
|
||||
t.Errorf("get rule id: got %q, want %q", getResp.Rule.Id, "rule-1")
|
||||
}
|
||||
|
||||
// List.
|
||||
listResp, err := ps.ListPolicies(context.Background(), &pb.ListPoliciesRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListPolicies: %v", err)
|
||||
}
|
||||
if len(listResp.Rules) != 1 {
|
||||
t.Errorf("expected 1 rule, got %d", len(listResp.Rules))
|
||||
}
|
||||
|
||||
// Delete.
|
||||
if _, err := ps.DeletePolicy(context.Background(), &pb.DeletePolicyRequest{Id: "rule-1"}); err != nil {
|
||||
t.Fatalf("DeletePolicy: %v", err)
|
||||
}
|
||||
|
||||
// Get after delete should fail.
|
||||
_, err = ps.GetPolicy(context.Background(), &pb.GetPolicyRequest{Id: "rule-1"})
|
||||
if code := status.Code(err); code != codes.NotFound {
|
||||
t.Errorf("expected NotFound after delete, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyGet_MissingID(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ps := &policyServer{s: srv}
|
||||
_, err := ps.GetPolicy(context.Background(), &pb.GetPolicyRequest{Id: ""})
|
||||
if code := status.Code(err); code != codes.InvalidArgument {
|
||||
t.Errorf("expected InvalidArgument, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyDelete_MissingID(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
ps := &policyServer{s: srv}
|
||||
_, err := ps.DeletePolicy(context.Background(), &pb.DeletePolicyRequest{Id: ""})
|
||||
if code := status.Code(err); code != codes.InvalidArgument {
|
||||
t.Errorf("expected InvalidArgument, got %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- authServer tests ----
|
||||
|
||||
func TestAuthTokenInfo_FromContext(t *testing.T) {
|
||||
srv, cleanup := newTestGRPCServer(t)
|
||||
defer cleanup()
|
||||
|
||||
as := &authServer{s: srv}
|
||||
ti := &auth.TokenInfo{Username: "alice", Roles: []string{"user"}, IsAdmin: false}
|
||||
ctx := context.WithValue(context.Background(), tokenInfoKey, ti)
|
||||
|
||||
resp, err := as.TokenInfo(ctx, &pb.TokenInfoRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("TokenInfo: %v", err)
|
||||
}
|
||||
if resp.Username != "alice" {
|
||||
t.Errorf("username: got %q, want %q", resp.Username, "alice")
|
||||
}
|
||||
if resp.IsAdmin {
|
||||
t.Error("expected IsAdmin=false")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ---- pbToRule / ruleToPB conversion tests ----
|
||||
|
||||
func TestPbToRuleRoundtrip(t *testing.T) {
|
||||
original := &pb.PolicyRule{
|
||||
Id: "test-rule",
|
||||
Priority: 5,
|
||||
Effect: "deny",
|
||||
Usernames: []string{"bob"},
|
||||
Roles: []string{"operator"},
|
||||
Resources: []string{"/pki/*"},
|
||||
Actions: []string{"write", "delete"},
|
||||
}
|
||||
|
||||
rule := pbToRule(original)
|
||||
if rule.ID != original.Id {
|
||||
t.Errorf("ID: got %q, want %q", rule.ID, original.Id)
|
||||
}
|
||||
if rule.Priority != int(original.Priority) {
|
||||
t.Errorf("Priority: got %d, want %d", rule.Priority, original.Priority)
|
||||
}
|
||||
if string(rule.Effect) != original.Effect {
|
||||
t.Errorf("Effect: got %q, want %q", rule.Effect, original.Effect)
|
||||
}
|
||||
|
||||
back := ruleToPB(rule)
|
||||
if back.Id != original.Id {
|
||||
t.Errorf("roundtrip Id: got %q, want %q", back.Id, original.Id)
|
||||
}
|
||||
if back.Priority != original.Priority {
|
||||
t.Errorf("roundtrip Priority: got %d, want %d", back.Priority, original.Priority)
|
||||
}
|
||||
if back.Effect != original.Effect {
|
||||
t.Errorf("roundtrip Effect: got %q, want %q", back.Effect, original.Effect)
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v1"
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine/ca"
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v1"
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/policy"
|
||||
)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v1"
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
internacme "git.wntrmute.dev/kyle/metacrypt/internal/acme"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/auth"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/config"
|
||||
@@ -79,6 +79,7 @@ func (s *GRPCServer) Start() error {
|
||||
pb.RegisterAuthServiceServer(s.srv, &authServer{s: s})
|
||||
pb.RegisterEngineServiceServer(s.srv, &engineServer{s: s})
|
||||
pb.RegisterPKIServiceServer(s.srv, &pkiServer{s: s})
|
||||
pb.RegisterCAServiceServer(s.srv, &caServer{s: s})
|
||||
pb.RegisterPolicyServiceServer(s.srv, &policyServer{s: s})
|
||||
pb.RegisterACMEServiceServer(s.srv, &acmeServer{s: s})
|
||||
|
||||
@@ -105,57 +106,77 @@ func (s *GRPCServer) Shutdown() {
|
||||
// to be unsealed.
|
||||
func sealRequiredMethods() map[string]bool {
|
||||
return map[string]bool{
|
||||
"/metacrypt.v1.AuthService/Login": true,
|
||||
"/metacrypt.v1.AuthService/Logout": true,
|
||||
"/metacrypt.v1.AuthService/TokenInfo": true,
|
||||
"/metacrypt.v1.EngineService/Mount": true,
|
||||
"/metacrypt.v1.EngineService/Unmount": true,
|
||||
"/metacrypt.v1.EngineService/ListMounts": true,
|
||||
"/metacrypt.v1.EngineService/Execute": true,
|
||||
"/metacrypt.v1.PKIService/GetRootCert": true,
|
||||
"/metacrypt.v1.PKIService/GetChain": true,
|
||||
"/metacrypt.v1.PKIService/GetIssuerCert": true,
|
||||
"/metacrypt.v1.PolicyService/CreatePolicy": true,
|
||||
"/metacrypt.v1.PolicyService/ListPolicies": true,
|
||||
"/metacrypt.v1.PolicyService/GetPolicy": true,
|
||||
"/metacrypt.v1.PolicyService/DeletePolicy": true,
|
||||
"/metacrypt.v1.ACMEService/CreateEAB": true,
|
||||
"/metacrypt.v1.ACMEService/SetConfig": true,
|
||||
"/metacrypt.v1.ACMEService/ListAccounts": true,
|
||||
"/metacrypt.v1.ACMEService/ListOrders": true,
|
||||
"/metacrypt.v2.AuthService/Login": true,
|
||||
"/metacrypt.v2.AuthService/Logout": true,
|
||||
"/metacrypt.v2.AuthService/TokenInfo": true,
|
||||
"/metacrypt.v2.EngineService/Mount": true,
|
||||
"/metacrypt.v2.EngineService/Unmount": true,
|
||||
"/metacrypt.v2.EngineService/ListMounts": true,
|
||||
"/metacrypt.v2.PKIService/GetRootCert": true,
|
||||
"/metacrypt.v2.PKIService/GetChain": true,
|
||||
"/metacrypt.v2.PKIService/GetIssuerCert": true,
|
||||
"/metacrypt.v2.CAService/ImportRoot": true,
|
||||
"/metacrypt.v2.CAService/GetRoot": true,
|
||||
"/metacrypt.v2.CAService/CreateIssuer": true,
|
||||
"/metacrypt.v2.CAService/DeleteIssuer": true,
|
||||
"/metacrypt.v2.CAService/ListIssuers": true,
|
||||
"/metacrypt.v2.CAService/GetIssuer": true,
|
||||
"/metacrypt.v2.CAService/GetChain": true,
|
||||
"/metacrypt.v2.CAService/IssueCert": true,
|
||||
"/metacrypt.v2.CAService/GetCert": true,
|
||||
"/metacrypt.v2.CAService/ListCerts": true,
|
||||
"/metacrypt.v2.CAService/RenewCert": true,
|
||||
"/metacrypt.v2.PolicyService/CreatePolicy": true,
|
||||
"/metacrypt.v2.PolicyService/ListPolicies": true,
|
||||
"/metacrypt.v2.PolicyService/GetPolicy": true,
|
||||
"/metacrypt.v2.PolicyService/DeletePolicy": true,
|
||||
"/metacrypt.v2.ACMEService/CreateEAB": true,
|
||||
"/metacrypt.v2.ACMEService/SetConfig": true,
|
||||
"/metacrypt.v2.ACMEService/ListAccounts": true,
|
||||
"/metacrypt.v2.ACMEService/ListOrders": true,
|
||||
}
|
||||
}
|
||||
|
||||
// authRequiredMethods returns the set of RPC full names that require a valid token.
|
||||
func authRequiredMethods() map[string]bool {
|
||||
return map[string]bool{
|
||||
"/metacrypt.v1.AuthService/Logout": true,
|
||||
"/metacrypt.v1.AuthService/TokenInfo": true,
|
||||
"/metacrypt.v1.EngineService/Mount": true,
|
||||
"/metacrypt.v1.EngineService/Unmount": true,
|
||||
"/metacrypt.v1.EngineService/ListMounts": true,
|
||||
"/metacrypt.v1.EngineService/Execute": true,
|
||||
"/metacrypt.v1.PolicyService/CreatePolicy": true,
|
||||
"/metacrypt.v1.PolicyService/ListPolicies": true,
|
||||
"/metacrypt.v1.PolicyService/GetPolicy": true,
|
||||
"/metacrypt.v1.PolicyService/DeletePolicy": true,
|
||||
"/metacrypt.v1.ACMEService/CreateEAB": true,
|
||||
"/metacrypt.v1.ACMEService/SetConfig": true,
|
||||
"/metacrypt.v1.ACMEService/ListAccounts": true,
|
||||
"/metacrypt.v1.ACMEService/ListOrders": true,
|
||||
"/metacrypt.v2.AuthService/Logout": true,
|
||||
"/metacrypt.v2.AuthService/TokenInfo": true,
|
||||
"/metacrypt.v2.EngineService/Mount": true,
|
||||
"/metacrypt.v2.EngineService/Unmount": true,
|
||||
"/metacrypt.v2.EngineService/ListMounts": true,
|
||||
"/metacrypt.v2.CAService/ImportRoot": true,
|
||||
"/metacrypt.v2.CAService/CreateIssuer": true,
|
||||
"/metacrypt.v2.CAService/DeleteIssuer": true,
|
||||
"/metacrypt.v2.CAService/ListIssuers": true,
|
||||
"/metacrypt.v2.CAService/IssueCert": true,
|
||||
"/metacrypt.v2.CAService/GetCert": true,
|
||||
"/metacrypt.v2.CAService/ListCerts": true,
|
||||
"/metacrypt.v2.CAService/RenewCert": true,
|
||||
"/metacrypt.v2.PolicyService/CreatePolicy": true,
|
||||
"/metacrypt.v2.PolicyService/ListPolicies": true,
|
||||
"/metacrypt.v2.PolicyService/GetPolicy": true,
|
||||
"/metacrypt.v2.PolicyService/DeletePolicy": true,
|
||||
"/metacrypt.v2.ACMEService/CreateEAB": true,
|
||||
"/metacrypt.v2.ACMEService/SetConfig": true,
|
||||
"/metacrypt.v2.ACMEService/ListAccounts": true,
|
||||
"/metacrypt.v2.ACMEService/ListOrders": true,
|
||||
}
|
||||
}
|
||||
|
||||
// adminRequiredMethods returns the set of RPC full names that require admin.
|
||||
func adminRequiredMethods() map[string]bool {
|
||||
return map[string]bool{
|
||||
"/metacrypt.v1.SystemService/Seal": true,
|
||||
"/metacrypt.v1.EngineService/Mount": true,
|
||||
"/metacrypt.v1.EngineService/Unmount": true,
|
||||
"/metacrypt.v1.PolicyService/CreatePolicy": true,
|
||||
"/metacrypt.v1.PolicyService/DeletePolicy": true,
|
||||
"/metacrypt.v1.ACMEService/SetConfig": true,
|
||||
"/metacrypt.v1.ACMEService/ListAccounts": true,
|
||||
"/metacrypt.v1.ACMEService/ListOrders": true,
|
||||
"/metacrypt.v2.SystemService/Seal": true,
|
||||
"/metacrypt.v2.EngineService/Mount": true,
|
||||
"/metacrypt.v2.EngineService/Unmount": true,
|
||||
"/metacrypt.v2.CAService/ImportRoot": true,
|
||||
"/metacrypt.v2.CAService/CreateIssuer": true,
|
||||
"/metacrypt.v2.CAService/DeleteIssuer": true,
|
||||
"/metacrypt.v2.PolicyService/CreatePolicy": true,
|
||||
"/metacrypt.v2.PolicyService/DeletePolicy": true,
|
||||
"/metacrypt.v2.ACMEService/SetConfig": true,
|
||||
"/metacrypt.v2.ACMEService/ListAccounts": true,
|
||||
"/metacrypt.v2.ACMEService/ListOrders": true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v1"
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/seal"
|
||||
)
|
||||
|
||||
@@ -96,15 +96,19 @@ func (s *Server) handleUnseal(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := s.seal.Unseal([]byte(req.Password)); err != nil {
|
||||
if errors.Is(err, seal.ErrNotInitialized) {
|
||||
s.logger.Warn("unseal attempt on uninitialized service", "remote_addr", r.RemoteAddr)
|
||||
http.Error(w, `{"error":"not initialized"}`, http.StatusPreconditionFailed)
|
||||
} else if errors.Is(err, seal.ErrInvalidPassword) {
|
||||
s.logger.Warn("unseal attempt with invalid password", "remote_addr", r.RemoteAddr)
|
||||
http.Error(w, `{"error":"invalid password"}`, http.StatusUnauthorized)
|
||||
} else if errors.Is(err, seal.ErrRateLimited) {
|
||||
s.logger.Warn("unseal attempt rate limited", "remote_addr", r.RemoteAddr)
|
||||
http.Error(w, `{"error":"too many attempts, try again later"}`, http.StatusTooManyRequests)
|
||||
} else if errors.Is(err, seal.ErrNotSealed) {
|
||||
s.logger.Warn("unseal attempt on already-unsealed service", "remote_addr", r.RemoteAddr)
|
||||
http.Error(w, `{"error":"already unsealed"}`, http.StatusConflict)
|
||||
} else {
|
||||
s.logger.Error("unseal failed", "error", err)
|
||||
s.logger.Error("unseal failed", "remote_addr", r.RemoteAddr, "error", err)
|
||||
http.Error(w, `{"error":"unseal failed"}`, http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
@@ -116,6 +120,7 @@ func (s *Server) handleUnseal(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("service unsealed", "remote_addr", r.RemoteAddr)
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"state": s.seal.State().String(),
|
||||
})
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v1"
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
)
|
||||
|
||||
// VaultClient wraps the gRPC stubs for communicating with the vault.
|
||||
@@ -22,6 +22,7 @@ type VaultClient struct {
|
||||
auth pb.AuthServiceClient
|
||||
engine pb.EngineServiceClient
|
||||
pki pb.PKIServiceClient
|
||||
ca pb.CAServiceClient
|
||||
}
|
||||
|
||||
// NewVaultClient dials the vault gRPC server and returns a client.
|
||||
@@ -58,6 +59,7 @@ func NewVaultClient(addr, caCertPath string, logger *slog.Logger) (*VaultClient,
|
||||
auth: pb.NewAuthServiceClient(conn),
|
||||
engine: pb.NewEngineServiceClient(conn),
|
||||
pki: pb.NewPKIServiceClient(conn),
|
||||
ca: pb.NewCAServiceClient(conn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -156,39 +158,16 @@ func (c *VaultClient) Mount(ctx context.Context, token, name, engineType string,
|
||||
Type: engineType,
|
||||
}
|
||||
if len(config) > 0 {
|
||||
s, err := structFromMap(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("webserver: encode mount config: %w", err)
|
||||
cfg := make(map[string]string, len(config))
|
||||
for k, v := range config {
|
||||
cfg[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
req.Config = s
|
||||
req.Config = cfg
|
||||
}
|
||||
_, err := c.engine.Mount(withToken(ctx, token), req)
|
||||
return err
|
||||
}
|
||||
|
||||
// EngineRequest sends a generic engine operation to the vault.
|
||||
func (c *VaultClient) EngineRequest(ctx context.Context, token, mount, operation string, data map[string]interface{}) (map[string]interface{}, error) {
|
||||
req := &pb.ExecuteRequest{
|
||||
Mount: mount,
|
||||
Operation: operation,
|
||||
}
|
||||
if len(data) > 0 {
|
||||
s, err := structFromMap(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("webserver: encode engine request: %w", err)
|
||||
}
|
||||
req.Data = s
|
||||
}
|
||||
resp, err := c.engine.Execute(withToken(ctx, token), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return resp.Data.AsMap(), nil
|
||||
}
|
||||
|
||||
// GetRootCert returns the root CA certificate PEM for the given mount.
|
||||
func (c *VaultClient) GetRootCert(ctx context.Context, mount string) ([]byte, error) {
|
||||
resp, err := c.pki.GetRootCert(ctx, &pb.GetRootCertRequest{Mount: mount})
|
||||
@@ -206,3 +185,126 @@ func (c *VaultClient) GetIssuerCert(ctx context.Context, mount, issuer string) (
|
||||
}
|
||||
return resp.CertPem, nil
|
||||
}
|
||||
|
||||
// ImportRoot imports an existing root CA certificate and key into the given mount.
|
||||
func (c *VaultClient) ImportRoot(ctx context.Context, token, mount, certPEM, keyPEM string) error {
|
||||
_, err := c.ca.ImportRoot(withToken(ctx, token), &pb.ImportRootRequest{
|
||||
Mount: mount,
|
||||
CertPem: []byte(certPEM),
|
||||
KeyPem: []byte(keyPEM),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateIssuerRequest holds parameters for creating an intermediate CA issuer.
|
||||
type CreateIssuerRequest struct {
|
||||
Mount string
|
||||
Name string
|
||||
KeyAlgorithm string
|
||||
KeySize int32
|
||||
Expiry string
|
||||
MaxTTL string
|
||||
}
|
||||
|
||||
// CreateIssuer creates a new intermediate CA issuer on the given mount.
|
||||
func (c *VaultClient) CreateIssuer(ctx context.Context, token string, req CreateIssuerRequest) error {
|
||||
_, err := c.ca.CreateIssuer(withToken(ctx, token), &pb.CreateIssuerRequest{
|
||||
Mount: req.Mount,
|
||||
Name: req.Name,
|
||||
KeyAlgorithm: req.KeyAlgorithm,
|
||||
KeySize: req.KeySize,
|
||||
Expiry: req.Expiry,
|
||||
MaxTtl: req.MaxTTL,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ListIssuers returns the names of all issuers for the given mount.
|
||||
func (c *VaultClient) ListIssuers(ctx context.Context, token, mount string) ([]string, error) {
|
||||
resp, err := c.ca.ListIssuers(withToken(ctx, token), &pb.ListIssuersRequest{Mount: mount})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Issuers, nil
|
||||
}
|
||||
|
||||
// IssueCertRequest holds parameters for issuing a leaf certificate.
|
||||
type IssueCertRequest struct {
|
||||
Mount string
|
||||
Issuer string
|
||||
Profile string
|
||||
CommonName string
|
||||
DNSNames []string
|
||||
IPAddresses []string
|
||||
TTL string
|
||||
KeyUsages []string
|
||||
ExtKeyUsages []string
|
||||
}
|
||||
|
||||
// IssuedCert holds the result of a certificate issuance.
|
||||
type IssuedCert struct {
|
||||
Serial string
|
||||
CertPEM string
|
||||
KeyPEM string
|
||||
ChainPEM string
|
||||
ExpiresAt string
|
||||
}
|
||||
|
||||
// IssueCert issues a new leaf certificate from the named issuer.
|
||||
func (c *VaultClient) IssueCert(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) {
|
||||
resp, err := c.ca.IssueCert(withToken(ctx, token), &pb.IssueCertRequest{
|
||||
Mount: req.Mount,
|
||||
Issuer: req.Issuer,
|
||||
Profile: req.Profile,
|
||||
CommonName: req.CommonName,
|
||||
DnsNames: req.DNSNames,
|
||||
IpAddresses: req.IPAddresses,
|
||||
Ttl: req.TTL,
|
||||
KeyUsages: req.KeyUsages,
|
||||
ExtKeyUsages: req.ExtKeyUsages,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
issued := &IssuedCert{
|
||||
Serial: resp.Serial,
|
||||
CertPEM: string(resp.CertPem),
|
||||
KeyPEM: string(resp.KeyPem),
|
||||
ChainPEM: string(resp.ChainPem),
|
||||
}
|
||||
if resp.ExpiresAt != nil {
|
||||
issued.ExpiresAt = resp.ExpiresAt.AsTime().Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
return issued, nil
|
||||
}
|
||||
|
||||
// CertSummary holds lightweight certificate metadata for list views.
|
||||
type CertSummary struct {
|
||||
Serial string
|
||||
Issuer string
|
||||
CommonName string
|
||||
Profile string
|
||||
ExpiresAt string
|
||||
}
|
||||
|
||||
// ListCerts returns all certificate summaries for the given mount.
|
||||
func (c *VaultClient) ListCerts(ctx context.Context, token, mount string) ([]CertSummary, error) {
|
||||
resp, err := c.ca.ListCerts(withToken(ctx, token), &pb.ListCertsRequest{Mount: mount})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certs := make([]CertSummary, 0, len(resp.Certs))
|
||||
for _, s := range resp.Certs {
|
||||
cs := CertSummary{
|
||||
Serial: s.Serial,
|
||||
Issuer: s.Issuer,
|
||||
CommonName: s.CommonName,
|
||||
Profile: s.Profile,
|
||||
}
|
||||
if s.ExpiresAt != nil {
|
||||
cs.ExpiresAt = s.ExpiresAt.AsTime().Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
certs = append(certs, cs)
|
||||
}
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
@@ -280,8 +280,8 @@ func (ws *WebServer) handlePKI(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
if resp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "list-issuers", nil); err == nil {
|
||||
data["Issuers"] = resp["issuers"]
|
||||
if issuers, err := ws.vault.ListIssuers(r.Context(), token, mountName); err == nil {
|
||||
data["Issuers"] = issuers
|
||||
}
|
||||
|
||||
ws.renderTemplate(w, "pki.html", data)
|
||||
@@ -329,11 +329,7 @@ func (ws *WebServer) handleImportRoot(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = ws.vault.EngineRequest(r.Context(), token, mountName, "import-root", map[string]interface{}{
|
||||
"cert_pem": certPEM,
|
||||
"key_pem": keyPEM,
|
||||
})
|
||||
if err != nil {
|
||||
if err = ws.vault.ImportRoot(r.Context(), token, mountName, certPEM, keyPEM); err != nil {
|
||||
ws.renderPKIWithError(w, r, mountName, info, grpcMessage(err))
|
||||
return
|
||||
}
|
||||
@@ -362,25 +358,27 @@ func (ws *WebServer) handleCreateIssuer(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
reqData := map[string]interface{}{"name": name}
|
||||
issuerReq := CreateIssuerRequest{
|
||||
Mount: mountName,
|
||||
Name: name,
|
||||
}
|
||||
if v := r.FormValue("expiry"); v != "" {
|
||||
reqData["expiry"] = v
|
||||
issuerReq.Expiry = v
|
||||
}
|
||||
if v := r.FormValue("max_ttl"); v != "" {
|
||||
reqData["max_ttl"] = v
|
||||
issuerReq.MaxTTL = v
|
||||
}
|
||||
if v := r.FormValue("key_algorithm"); v != "" {
|
||||
reqData["key_algorithm"] = v
|
||||
issuerReq.KeyAlgorithm = v
|
||||
}
|
||||
if v := r.FormValue("key_size"); v != "" {
|
||||
var size float64
|
||||
if _, err := fmt.Sscanf(v, "%f", &size); err == nil {
|
||||
reqData["key_size"] = size
|
||||
var size int32
|
||||
if _, err := fmt.Sscanf(v, "%d", &size); err == nil {
|
||||
issuerReq.KeySize = size
|
||||
}
|
||||
}
|
||||
|
||||
_, err = ws.vault.EngineRequest(r.Context(), token, mountName, "create-issuer", reqData)
|
||||
if err != nil {
|
||||
if err = ws.vault.CreateIssuer(r.Context(), token, issuerReq); err != nil {
|
||||
ws.renderPKIWithError(w, r, mountName, info, grpcMessage(err))
|
||||
return
|
||||
}
|
||||
@@ -419,7 +417,7 @@ func (ws *WebServer) handleIssuerDetail(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
issuerName := chi.URLParam(r, "issuer")
|
||||
|
||||
resp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "list-certs", nil)
|
||||
allCerts, err := ws.vault.ListCerts(r.Context(), token, mountName)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to list certificates", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -431,34 +429,22 @@ func (ws *WebServer) handleIssuerDetail(w http.ResponseWriter, r *http.Request)
|
||||
sortBy = "cn"
|
||||
}
|
||||
|
||||
var certs []map[string]interface{}
|
||||
if raw, ok := resp["certs"]; ok {
|
||||
if list, ok := raw.([]interface{}); ok {
|
||||
for _, item := range list {
|
||||
if m, ok := item.(map[string]interface{}); ok {
|
||||
issuer, _ := m["issuer"].(string)
|
||||
if issuer != issuerName {
|
||||
continue
|
||||
}
|
||||
if nameFilter != "" {
|
||||
cn, _ := m["cn"].(string)
|
||||
if !strings.Contains(strings.ToLower(cn), nameFilter) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
certs = append(certs, m)
|
||||
}
|
||||
}
|
||||
var certs []CertSummary
|
||||
for _, cs := range allCerts {
|
||||
if cs.Issuer != issuerName {
|
||||
continue
|
||||
}
|
||||
if nameFilter != "" && !strings.Contains(strings.ToLower(cs.CommonName), strings.ToLower(nameFilter)) {
|
||||
continue
|
||||
}
|
||||
certs = append(certs, cs)
|
||||
}
|
||||
|
||||
// Sort: by expiry date or by common name (default).
|
||||
if sortBy == "expiry" {
|
||||
for i := 1; i < len(certs); i++ {
|
||||
for j := i; j > 0; j-- {
|
||||
a, _ := certs[j-1]["expires_at"].(string)
|
||||
b, _ := certs[j]["expires_at"].(string)
|
||||
if a > b {
|
||||
if certs[j-1].ExpiresAt > certs[j].ExpiresAt {
|
||||
certs[j-1], certs[j] = certs[j], certs[j-1]
|
||||
}
|
||||
}
|
||||
@@ -466,9 +452,7 @@ func (ws *WebServer) handleIssuerDetail(w http.ResponseWriter, r *http.Request)
|
||||
} else {
|
||||
for i := 1; i < len(certs); i++ {
|
||||
for j := i; j > 0; j-- {
|
||||
a, _ := certs[j-1]["cn"].(string)
|
||||
b, _ := certs[j]["cn"].(string)
|
||||
if strings.ToLower(a) > strings.ToLower(b) {
|
||||
if strings.ToLower(certs[j-1].CommonName) > strings.ToLower(certs[j].CommonName) {
|
||||
certs[j-1], certs[j] = certs[j], certs[j-1]
|
||||
}
|
||||
}
|
||||
@@ -512,24 +496,31 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
reqData := map[string]interface{}{
|
||||
"common_name": commonName,
|
||||
"issuer": issuer,
|
||||
certReq := IssueCertRequest{
|
||||
Mount: mountName,
|
||||
Issuer: issuer,
|
||||
CommonName: commonName,
|
||||
}
|
||||
if v := r.FormValue("profile"); v != "" {
|
||||
reqData["profile"] = v
|
||||
certReq.Profile = v
|
||||
}
|
||||
if v := r.FormValue("ttl"); v != "" {
|
||||
reqData["ttl"] = v
|
||||
certReq.TTL = v
|
||||
}
|
||||
if lines := splitLines(r.FormValue("dns_names")); len(lines) > 0 {
|
||||
reqData["dns_names"] = lines
|
||||
for _, l := range lines {
|
||||
certReq.DNSNames = append(certReq.DNSNames, l.(string))
|
||||
}
|
||||
}
|
||||
if lines := splitLines(r.FormValue("ip_addresses")); len(lines) > 0 {
|
||||
reqData["ip_addresses"] = lines
|
||||
for _, l := range lines {
|
||||
certReq.IPAddresses = append(certReq.IPAddresses, l.(string))
|
||||
}
|
||||
}
|
||||
certReq.KeyUsages = r.Form["key_usages"]
|
||||
certReq.ExtKeyUsages = r.Form["ext_key_usages"]
|
||||
|
||||
resp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "issue", reqData)
|
||||
issuedCert, err := ws.vault.IssueCert(r.Context(), token, certReq)
|
||||
if err != nil {
|
||||
ws.renderPKIWithError(w, r, mountName, info, grpcMessage(err))
|
||||
return
|
||||
@@ -537,10 +528,10 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Re-render the PKI page with the issued certificate displayed.
|
||||
data := map[string]interface{}{
|
||||
"Username": info.Username,
|
||||
"IsAdmin": info.IsAdmin,
|
||||
"MountName": mountName,
|
||||
"IssuedCert": resp,
|
||||
"Username": info.Username,
|
||||
"IsAdmin": info.IsAdmin,
|
||||
"MountName": mountName,
|
||||
"IssuedCert": issuedCert,
|
||||
}
|
||||
if rootPEM, err := ws.vault.GetRootCert(r.Context(), mountName); err == nil && len(rootPEM) > 0 {
|
||||
if cert, err := parsePEMCert(rootPEM); err == nil {
|
||||
@@ -552,8 +543,8 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
||||
data["HasRoot"] = true
|
||||
}
|
||||
}
|
||||
if issuerResp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "list-issuers", nil); err == nil {
|
||||
data["Issuers"] = issuerResp["issuers"]
|
||||
if issuers, err := ws.vault.ListIssuers(r.Context(), token, mountName); err == nil {
|
||||
data["Issuers"] = issuers
|
||||
}
|
||||
ws.renderTemplate(w, "pki.html", data)
|
||||
}
|
||||
@@ -577,8 +568,8 @@ func (ws *WebServer) renderPKIWithError(w http.ResponseWriter, r *http.Request,
|
||||
data["HasRoot"] = true
|
||||
}
|
||||
}
|
||||
if resp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "list-issuers", nil); err == nil {
|
||||
data["Issuers"] = resp["issuers"]
|
||||
if issuers, err := ws.vault.ListIssuers(r.Context(), token, mountName); err == nil {
|
||||
data["Issuers"] = issuers
|
||||
}
|
||||
|
||||
ws.renderTemplate(w, "pki.html", data)
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
@@ -22,11 +20,6 @@ func tokenInfoFromContext(ctx context.Context) *TokenInfo {
|
||||
return v
|
||||
}
|
||||
|
||||
// structFromMap converts a map[string]interface{} to a *structpb.Struct.
|
||||
func structFromMap(m map[string]interface{}) (*structpb.Struct, error) {
|
||||
return structpb.NewStruct(m)
|
||||
}
|
||||
|
||||
// parsePEMCert decodes the first PEM block and parses it as an x509 certificate.
|
||||
func parsePEMCert(pemData []byte) (*x509.Certificate, error) {
|
||||
block, _ := pem.Decode(pemData)
|
||||
|
||||
Reference in New Issue
Block a user