Add grpcserver test coverage
- Add comprehensive test file for internal/grpcserver package - Cover interceptors, system, engine, policy, and auth handlers - Cover pbToRule/ruleToPB conversion helpers - 37 tests total; CA/PKI/ACME and Login/Logout skipped (require live deps) Co-authored-by: Junie <junie@jetbrains.com>
This commit is contained in:
432
internal/grpcserver/ca.go
Normal file
432
internal/grpcserver/ca.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine/ca"
|
||||
)
|
||||
|
||||
type caServer struct {
|
||||
pb.UnimplementedCAServiceServer
|
||||
s *GRPCServer
|
||||
}
|
||||
|
||||
// caHandleRequest is a helper that dispatches a CA engine request and maps
|
||||
// common errors to gRPC status codes.
|
||||
func (cs *caServer) caHandleRequest(ctx context.Context, mount, operation string, req *engine.Request) (*engine.Response, error) {
|
||||
resp, err := cs.s.engines.HandleRequest(ctx, mount, req)
|
||||
if err != nil {
|
||||
st := codes.Internal
|
||||
switch {
|
||||
case errors.Is(err, engine.ErrMountNotFound):
|
||||
st = codes.NotFound
|
||||
case errors.Is(err, ca.ErrIssuerNotFound):
|
||||
st = codes.NotFound
|
||||
case errors.Is(err, ca.ErrCertNotFound):
|
||||
st = codes.NotFound
|
||||
case errors.Is(err, ca.ErrIssuerExists):
|
||||
st = codes.AlreadyExists
|
||||
case errors.Is(err, ca.ErrUnauthorized):
|
||||
st = codes.Unauthenticated
|
||||
case errors.Is(err, ca.ErrForbidden):
|
||||
st = codes.PermissionDenied
|
||||
case strings.Contains(err.Error(), "not found"):
|
||||
st = codes.NotFound
|
||||
}
|
||||
cs.s.logger.Error("grpc: ca "+operation, "mount", mount, "error", err)
|
||||
return nil, status.Error(st, err.Error())
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) callerInfo(ctx context.Context) *engine.CallerInfo {
|
||||
ti := tokenInfoFromContext(ctx)
|
||||
if ti == nil {
|
||||
return nil
|
||||
}
|
||||
return &engine.CallerInfo{
|
||||
Username: ti.Username,
|
||||
Roles: ti.Roles,
|
||||
IsAdmin: ti.IsAdmin,
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *caServer) ImportRoot(ctx context.Context, req *pb.ImportRootRequest) (*pb.ImportRootResponse, error) {
|
||||
if req.Mount == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount is required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "import-root", &engine.Request{
|
||||
Operation: "import-root",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: map[string]interface{}{
|
||||
"cert_pem": string(req.CertPem),
|
||||
"key_pem": string(req.KeyPem),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cn, _ := resp.Data["cn"].(string)
|
||||
var expiresAt *timestamppb.Timestamp
|
||||
if s, ok := resp.Data["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.ImportRootResponse{CommonName: cn, ExpiresAt: expiresAt}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) GetRoot(ctx context.Context, req *pb.GetRootRequest) (*pb.GetRootResponse, error) {
|
||||
if req.Mount == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount is required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "get-root", &engine.Request{
|
||||
Operation: "get-root",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
return &pb.GetRootResponse{CertPem: []byte(certPEM)}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) CreateIssuer(ctx context.Context, req *pb.CreateIssuerRequest) (*pb.CreateIssuerResponse, error) {
|
||||
if req.Mount == "" || req.Name == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and name are required")
|
||||
}
|
||||
data := map[string]interface{}{
|
||||
"name": req.Name,
|
||||
}
|
||||
if req.KeyAlgorithm != "" {
|
||||
data["key_algorithm"] = req.KeyAlgorithm
|
||||
}
|
||||
if req.KeySize != 0 {
|
||||
data["key_size"] = float64(req.KeySize)
|
||||
}
|
||||
if req.Expiry != "" {
|
||||
data["expiry"] = req.Expiry
|
||||
}
|
||||
if req.MaxTtl != "" {
|
||||
data["max_ttl"] = req.MaxTtl
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "create-issuer", &engine.Request{
|
||||
Operation: "create-issuer",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name, _ := resp.Data["name"].(string)
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
return &pb.CreateIssuerResponse{Name: name, CertPem: []byte(certPEM)}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) DeleteIssuer(ctx context.Context, req *pb.DeleteIssuerRequest) (*pb.DeleteIssuerResponse, error) {
|
||||
if req.Mount == "" || req.Name == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and name are required")
|
||||
}
|
||||
_, err := cs.caHandleRequest(ctx, req.Mount, "delete-issuer", &engine.Request{
|
||||
Operation: "delete-issuer",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: map[string]interface{}{"name": req.Name},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pb.DeleteIssuerResponse{}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) ListIssuers(ctx context.Context, req *pb.ListIssuersRequest) (*pb.ListIssuersResponse, error) {
|
||||
if req.Mount == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount is required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "list-issuers", &engine.Request{
|
||||
Operation: "list-issuers",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, _ := resp.Data["issuers"].([]interface{})
|
||||
issuers := make([]string, 0, len(raw))
|
||||
for _, v := range raw {
|
||||
if s, ok := v.(string); ok {
|
||||
issuers = append(issuers, s)
|
||||
}
|
||||
}
|
||||
return &pb.ListIssuersResponse{Issuers: issuers}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) GetIssuer(ctx context.Context, req *pb.GetIssuerRequest) (*pb.GetIssuerResponse, error) {
|
||||
if req.Mount == "" || req.Name == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and name are required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "get-issuer", &engine.Request{
|
||||
Operation: "get-issuer",
|
||||
Path: req.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
return &pb.GetIssuerResponse{CertPem: []byte(certPEM)}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) GetChain(ctx context.Context, req *pb.CAServiceGetChainRequest) (*pb.CAServiceGetChainResponse, error) {
|
||||
if req.Mount == "" || req.Issuer == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and issuer are required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "get-chain", &engine.Request{
|
||||
Operation: "get-chain",
|
||||
Data: map[string]interface{}{"issuer": req.Issuer},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chainPEM, _ := resp.Data["chain_pem"].(string)
|
||||
return &pb.CAServiceGetChainResponse{ChainPem: []byte(chainPEM)}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) IssueCert(ctx context.Context, req *pb.IssueCertRequest) (*pb.IssueCertResponse, error) {
|
||||
if req.Mount == "" || req.Issuer == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and issuer are required")
|
||||
}
|
||||
data := map[string]interface{}{
|
||||
"issuer": req.Issuer,
|
||||
}
|
||||
if req.Profile != "" {
|
||||
data["profile"] = req.Profile
|
||||
}
|
||||
if req.CommonName != "" {
|
||||
data["common_name"] = req.CommonName
|
||||
}
|
||||
if len(req.DnsNames) > 0 {
|
||||
dns := make([]interface{}, len(req.DnsNames))
|
||||
for i, v := range req.DnsNames {
|
||||
dns[i] = v
|
||||
}
|
||||
data["dns_names"] = dns
|
||||
}
|
||||
if len(req.IpAddresses) > 0 {
|
||||
ips := make([]interface{}, len(req.IpAddresses))
|
||||
for i, v := range req.IpAddresses {
|
||||
ips[i] = v
|
||||
}
|
||||
data["ip_addresses"] = ips
|
||||
}
|
||||
if req.Ttl != "" {
|
||||
data["ttl"] = req.Ttl
|
||||
}
|
||||
if req.KeyAlgorithm != "" {
|
||||
data["key_algorithm"] = req.KeyAlgorithm
|
||||
}
|
||||
if req.KeySize != 0 {
|
||||
data["key_size"] = float64(req.KeySize)
|
||||
}
|
||||
if len(req.KeyUsages) > 0 {
|
||||
ku := make([]interface{}, len(req.KeyUsages))
|
||||
for i, v := range req.KeyUsages {
|
||||
ku[i] = v
|
||||
}
|
||||
data["key_usages"] = ku
|
||||
}
|
||||
if len(req.ExtKeyUsages) > 0 {
|
||||
eku := make([]interface{}, len(req.ExtKeyUsages))
|
||||
for i, v := range req.ExtKeyUsages {
|
||||
eku[i] = v
|
||||
}
|
||||
data["ext_key_usages"] = eku
|
||||
}
|
||||
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "issue", &engine.Request{
|
||||
Operation: "issue",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: data,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serial, _ := resp.Data["serial"].(string)
|
||||
cn, _ := resp.Data["cn"].(string)
|
||||
issuedBy, _ := resp.Data["issued_by"].(string)
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
keyPEM, _ := resp.Data["key_pem"].(string)
|
||||
chainPEM, _ := resp.Data["chain_pem"].(string)
|
||||
sans := toStringSliceFromInterface(resp.Data["sans"])
|
||||
var expiresAt *timestamppb.Timestamp
|
||||
if s, ok := resp.Data["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.IssueCertResponse{
|
||||
Serial: serial,
|
||||
CommonName: cn,
|
||||
Sans: sans,
|
||||
IssuedBy: issuedBy,
|
||||
ExpiresAt: expiresAt,
|
||||
CertPem: []byte(certPEM),
|
||||
KeyPem: []byte(keyPEM),
|
||||
ChainPem: []byte(chainPEM),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) GetCert(ctx context.Context, req *pb.GetCertRequest) (*pb.GetCertResponse, error) {
|
||||
if req.Mount == "" || req.Serial == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "get-cert", &engine.Request{
|
||||
Operation: "get-cert",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: map[string]interface{}{"serial": req.Serial},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rec := certRecordFromData(resp.Data)
|
||||
return &pb.GetCertResponse{Cert: rec}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) ListCerts(ctx context.Context, req *pb.ListCertsRequest) (*pb.ListCertsResponse, error) {
|
||||
if req.Mount == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount is required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "list-certs", &engine.Request{
|
||||
Operation: "list-certs",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, _ := resp.Data["certs"].([]interface{})
|
||||
summaries := make([]*pb.CertSummary, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
m, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
summaries = append(summaries, certSummaryFromData(m))
|
||||
}
|
||||
return &pb.ListCertsResponse{Certs: summaries}, nil
|
||||
}
|
||||
|
||||
func (cs *caServer) RenewCert(ctx context.Context, req *pb.RenewCertRequest) (*pb.RenewCertResponse, error) {
|
||||
if req.Mount == "" || req.Serial == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
|
||||
}
|
||||
resp, err := cs.caHandleRequest(ctx, req.Mount, "renew", &engine.Request{
|
||||
Operation: "renew",
|
||||
CallerInfo: cs.callerInfo(ctx),
|
||||
Data: map[string]interface{}{"serial": req.Serial},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serial, _ := resp.Data["serial"].(string)
|
||||
cn, _ := resp.Data["cn"].(string)
|
||||
issuedBy, _ := resp.Data["issued_by"].(string)
|
||||
certPEM, _ := resp.Data["cert_pem"].(string)
|
||||
keyPEM, _ := resp.Data["key_pem"].(string)
|
||||
chainPEM, _ := resp.Data["chain_pem"].(string)
|
||||
sans := toStringSliceFromInterface(resp.Data["sans"])
|
||||
var expiresAt *timestamppb.Timestamp
|
||||
if s, ok := resp.Data["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.RenewCertResponse{
|
||||
Serial: serial,
|
||||
CommonName: cn,
|
||||
Sans: sans,
|
||||
IssuedBy: issuedBy,
|
||||
ExpiresAt: expiresAt,
|
||||
CertPem: []byte(certPEM),
|
||||
KeyPem: []byte(keyPEM),
|
||||
ChainPem: []byte(chainPEM),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func certRecordFromData(d map[string]interface{}) *pb.CertRecord {
|
||||
serial, _ := d["serial"].(string)
|
||||
issuer, _ := d["issuer"].(string)
|
||||
cn, _ := d["cn"].(string)
|
||||
profile, _ := d["profile"].(string)
|
||||
issuedBy, _ := d["issued_by"].(string)
|
||||
certPEM, _ := d["cert_pem"].(string)
|
||||
sans := toStringSliceFromInterface(d["sans"])
|
||||
var issuedAt, expiresAt *timestamppb.Timestamp
|
||||
if s, ok := d["issued_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
issuedAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
if s, ok := d["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.CertRecord{
|
||||
Serial: serial,
|
||||
Issuer: issuer,
|
||||
CommonName: cn,
|
||||
Sans: sans,
|
||||
Profile: profile,
|
||||
IssuedBy: issuedBy,
|
||||
IssuedAt: issuedAt,
|
||||
ExpiresAt: expiresAt,
|
||||
CertPem: []byte(certPEM),
|
||||
}
|
||||
}
|
||||
|
||||
func certSummaryFromData(d map[string]interface{}) *pb.CertSummary {
|
||||
serial, _ := d["serial"].(string)
|
||||
issuer, _ := d["issuer"].(string)
|
||||
cn, _ := d["cn"].(string)
|
||||
profile, _ := d["profile"].(string)
|
||||
issuedBy, _ := d["issued_by"].(string)
|
||||
var issuedAt, expiresAt *timestamppb.Timestamp
|
||||
if s, ok := d["issued_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
issuedAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
if s, ok := d["expires_at"].(string); ok {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
expiresAt = timestamppb.New(t)
|
||||
}
|
||||
}
|
||||
return &pb.CertSummary{
|
||||
Serial: serial,
|
||||
Issuer: issuer,
|
||||
CommonName: cn,
|
||||
Profile: profile,
|
||||
IssuedBy: issuedBy,
|
||||
IssuedAt: issuedAt,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
}
|
||||
|
||||
func toStringSliceFromInterface(v interface{}) []string {
|
||||
raw, _ := v.([]interface{})
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
if s, ok := item.(string); ok {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user