Add grpcserver test coverage
- Add comprehensive test file for internal/grpcserver package - Cover interceptors, system, engine, policy, and auth handlers - Cover pbToRule/ruleToPB conversion helpers - 37 tests total; CA/PKI/ACME and Login/Logout skipped (require live deps) Co-authored-by: Junie <junie@jetbrains.com>
This commit is contained in:
@@ -12,7 +12,7 @@ import (
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v1"
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
)
|
||||
|
||||
// VaultClient wraps the gRPC stubs for communicating with the vault.
|
||||
@@ -22,6 +22,7 @@ type VaultClient struct {
|
||||
auth pb.AuthServiceClient
|
||||
engine pb.EngineServiceClient
|
||||
pki pb.PKIServiceClient
|
||||
ca pb.CAServiceClient
|
||||
}
|
||||
|
||||
// NewVaultClient dials the vault gRPC server and returns a client.
|
||||
@@ -58,6 +59,7 @@ func NewVaultClient(addr, caCertPath string, logger *slog.Logger) (*VaultClient,
|
||||
auth: pb.NewAuthServiceClient(conn),
|
||||
engine: pb.NewEngineServiceClient(conn),
|
||||
pki: pb.NewPKIServiceClient(conn),
|
||||
ca: pb.NewCAServiceClient(conn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -156,39 +158,16 @@ func (c *VaultClient) Mount(ctx context.Context, token, name, engineType string,
|
||||
Type: engineType,
|
||||
}
|
||||
if len(config) > 0 {
|
||||
s, err := structFromMap(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("webserver: encode mount config: %w", err)
|
||||
cfg := make(map[string]string, len(config))
|
||||
for k, v := range config {
|
||||
cfg[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
req.Config = s
|
||||
req.Config = cfg
|
||||
}
|
||||
_, err := c.engine.Mount(withToken(ctx, token), req)
|
||||
return err
|
||||
}
|
||||
|
||||
// EngineRequest sends a generic engine operation to the vault.
|
||||
func (c *VaultClient) EngineRequest(ctx context.Context, token, mount, operation string, data map[string]interface{}) (map[string]interface{}, error) {
|
||||
req := &pb.ExecuteRequest{
|
||||
Mount: mount,
|
||||
Operation: operation,
|
||||
}
|
||||
if len(data) > 0 {
|
||||
s, err := structFromMap(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("webserver: encode engine request: %w", err)
|
||||
}
|
||||
req.Data = s
|
||||
}
|
||||
resp, err := c.engine.Execute(withToken(ctx, token), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return resp.Data.AsMap(), nil
|
||||
}
|
||||
|
||||
// GetRootCert returns the root CA certificate PEM for the given mount.
|
||||
func (c *VaultClient) GetRootCert(ctx context.Context, mount string) ([]byte, error) {
|
||||
resp, err := c.pki.GetRootCert(ctx, &pb.GetRootCertRequest{Mount: mount})
|
||||
@@ -206,3 +185,126 @@ func (c *VaultClient) GetIssuerCert(ctx context.Context, mount, issuer string) (
|
||||
}
|
||||
return resp.CertPem, nil
|
||||
}
|
||||
|
||||
// ImportRoot imports an existing root CA certificate and key into the given mount.
|
||||
func (c *VaultClient) ImportRoot(ctx context.Context, token, mount, certPEM, keyPEM string) error {
|
||||
_, err := c.ca.ImportRoot(withToken(ctx, token), &pb.ImportRootRequest{
|
||||
Mount: mount,
|
||||
CertPem: []byte(certPEM),
|
||||
KeyPem: []byte(keyPEM),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateIssuerRequest holds parameters for creating an intermediate CA issuer.
|
||||
type CreateIssuerRequest struct {
|
||||
Mount string
|
||||
Name string
|
||||
KeyAlgorithm string
|
||||
KeySize int32
|
||||
Expiry string
|
||||
MaxTTL string
|
||||
}
|
||||
|
||||
// CreateIssuer creates a new intermediate CA issuer on the given mount.
|
||||
func (c *VaultClient) CreateIssuer(ctx context.Context, token string, req CreateIssuerRequest) error {
|
||||
_, err := c.ca.CreateIssuer(withToken(ctx, token), &pb.CreateIssuerRequest{
|
||||
Mount: req.Mount,
|
||||
Name: req.Name,
|
||||
KeyAlgorithm: req.KeyAlgorithm,
|
||||
KeySize: req.KeySize,
|
||||
Expiry: req.Expiry,
|
||||
MaxTtl: req.MaxTTL,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ListIssuers returns the names of all issuers for the given mount.
|
||||
func (c *VaultClient) ListIssuers(ctx context.Context, token, mount string) ([]string, error) {
|
||||
resp, err := c.ca.ListIssuers(withToken(ctx, token), &pb.ListIssuersRequest{Mount: mount})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Issuers, nil
|
||||
}
|
||||
|
||||
// IssueCertRequest holds parameters for issuing a leaf certificate.
|
||||
type IssueCertRequest struct {
|
||||
Mount string
|
||||
Issuer string
|
||||
Profile string
|
||||
CommonName string
|
||||
DNSNames []string
|
||||
IPAddresses []string
|
||||
TTL string
|
||||
KeyUsages []string
|
||||
ExtKeyUsages []string
|
||||
}
|
||||
|
||||
// IssuedCert holds the result of a certificate issuance.
|
||||
type IssuedCert struct {
|
||||
Serial string
|
||||
CertPEM string
|
||||
KeyPEM string
|
||||
ChainPEM string
|
||||
ExpiresAt string
|
||||
}
|
||||
|
||||
// IssueCert issues a new leaf certificate from the named issuer.
|
||||
func (c *VaultClient) IssueCert(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) {
|
||||
resp, err := c.ca.IssueCert(withToken(ctx, token), &pb.IssueCertRequest{
|
||||
Mount: req.Mount,
|
||||
Issuer: req.Issuer,
|
||||
Profile: req.Profile,
|
||||
CommonName: req.CommonName,
|
||||
DnsNames: req.DNSNames,
|
||||
IpAddresses: req.IPAddresses,
|
||||
Ttl: req.TTL,
|
||||
KeyUsages: req.KeyUsages,
|
||||
ExtKeyUsages: req.ExtKeyUsages,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
issued := &IssuedCert{
|
||||
Serial: resp.Serial,
|
||||
CertPEM: string(resp.CertPem),
|
||||
KeyPEM: string(resp.KeyPem),
|
||||
ChainPEM: string(resp.ChainPem),
|
||||
}
|
||||
if resp.ExpiresAt != nil {
|
||||
issued.ExpiresAt = resp.ExpiresAt.AsTime().Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
return issued, nil
|
||||
}
|
||||
|
||||
// CertSummary holds lightweight certificate metadata for list views.
|
||||
type CertSummary struct {
|
||||
Serial string
|
||||
Issuer string
|
||||
CommonName string
|
||||
Profile string
|
||||
ExpiresAt string
|
||||
}
|
||||
|
||||
// ListCerts returns all certificate summaries for the given mount.
|
||||
func (c *VaultClient) ListCerts(ctx context.Context, token, mount string) ([]CertSummary, error) {
|
||||
resp, err := c.ca.ListCerts(withToken(ctx, token), &pb.ListCertsRequest{Mount: mount})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certs := make([]CertSummary, 0, len(resp.Certs))
|
||||
for _, s := range resp.Certs {
|
||||
cs := CertSummary{
|
||||
Serial: s.Serial,
|
||||
Issuer: s.Issuer,
|
||||
CommonName: s.CommonName,
|
||||
Profile: s.Profile,
|
||||
}
|
||||
if s.ExpiresAt != nil {
|
||||
cs.ExpiresAt = s.ExpiresAt.AsTime().Format("2006-01-02T15:04:05Z")
|
||||
}
|
||||
certs = append(certs, cs)
|
||||
}
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
@@ -280,8 +280,8 @@ func (ws *WebServer) handlePKI(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
if resp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "list-issuers", nil); err == nil {
|
||||
data["Issuers"] = resp["issuers"]
|
||||
if issuers, err := ws.vault.ListIssuers(r.Context(), token, mountName); err == nil {
|
||||
data["Issuers"] = issuers
|
||||
}
|
||||
|
||||
ws.renderTemplate(w, "pki.html", data)
|
||||
@@ -329,11 +329,7 @@ func (ws *WebServer) handleImportRoot(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = ws.vault.EngineRequest(r.Context(), token, mountName, "import-root", map[string]interface{}{
|
||||
"cert_pem": certPEM,
|
||||
"key_pem": keyPEM,
|
||||
})
|
||||
if err != nil {
|
||||
if err = ws.vault.ImportRoot(r.Context(), token, mountName, certPEM, keyPEM); err != nil {
|
||||
ws.renderPKIWithError(w, r, mountName, info, grpcMessage(err))
|
||||
return
|
||||
}
|
||||
@@ -362,25 +358,27 @@ func (ws *WebServer) handleCreateIssuer(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
reqData := map[string]interface{}{"name": name}
|
||||
issuerReq := CreateIssuerRequest{
|
||||
Mount: mountName,
|
||||
Name: name,
|
||||
}
|
||||
if v := r.FormValue("expiry"); v != "" {
|
||||
reqData["expiry"] = v
|
||||
issuerReq.Expiry = v
|
||||
}
|
||||
if v := r.FormValue("max_ttl"); v != "" {
|
||||
reqData["max_ttl"] = v
|
||||
issuerReq.MaxTTL = v
|
||||
}
|
||||
if v := r.FormValue("key_algorithm"); v != "" {
|
||||
reqData["key_algorithm"] = v
|
||||
issuerReq.KeyAlgorithm = v
|
||||
}
|
||||
if v := r.FormValue("key_size"); v != "" {
|
||||
var size float64
|
||||
if _, err := fmt.Sscanf(v, "%f", &size); err == nil {
|
||||
reqData["key_size"] = size
|
||||
var size int32
|
||||
if _, err := fmt.Sscanf(v, "%d", &size); err == nil {
|
||||
issuerReq.KeySize = size
|
||||
}
|
||||
}
|
||||
|
||||
_, err = ws.vault.EngineRequest(r.Context(), token, mountName, "create-issuer", reqData)
|
||||
if err != nil {
|
||||
if err = ws.vault.CreateIssuer(r.Context(), token, issuerReq); err != nil {
|
||||
ws.renderPKIWithError(w, r, mountName, info, grpcMessage(err))
|
||||
return
|
||||
}
|
||||
@@ -419,7 +417,7 @@ func (ws *WebServer) handleIssuerDetail(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
issuerName := chi.URLParam(r, "issuer")
|
||||
|
||||
resp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "list-certs", nil)
|
||||
allCerts, err := ws.vault.ListCerts(r.Context(), token, mountName)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to list certificates", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -431,34 +429,22 @@ func (ws *WebServer) handleIssuerDetail(w http.ResponseWriter, r *http.Request)
|
||||
sortBy = "cn"
|
||||
}
|
||||
|
||||
var certs []map[string]interface{}
|
||||
if raw, ok := resp["certs"]; ok {
|
||||
if list, ok := raw.([]interface{}); ok {
|
||||
for _, item := range list {
|
||||
if m, ok := item.(map[string]interface{}); ok {
|
||||
issuer, _ := m["issuer"].(string)
|
||||
if issuer != issuerName {
|
||||
continue
|
||||
}
|
||||
if nameFilter != "" {
|
||||
cn, _ := m["cn"].(string)
|
||||
if !strings.Contains(strings.ToLower(cn), nameFilter) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
certs = append(certs, m)
|
||||
}
|
||||
}
|
||||
var certs []CertSummary
|
||||
for _, cs := range allCerts {
|
||||
if cs.Issuer != issuerName {
|
||||
continue
|
||||
}
|
||||
if nameFilter != "" && !strings.Contains(strings.ToLower(cs.CommonName), strings.ToLower(nameFilter)) {
|
||||
continue
|
||||
}
|
||||
certs = append(certs, cs)
|
||||
}
|
||||
|
||||
// Sort: by expiry date or by common name (default).
|
||||
if sortBy == "expiry" {
|
||||
for i := 1; i < len(certs); i++ {
|
||||
for j := i; j > 0; j-- {
|
||||
a, _ := certs[j-1]["expires_at"].(string)
|
||||
b, _ := certs[j]["expires_at"].(string)
|
||||
if a > b {
|
||||
if certs[j-1].ExpiresAt > certs[j].ExpiresAt {
|
||||
certs[j-1], certs[j] = certs[j], certs[j-1]
|
||||
}
|
||||
}
|
||||
@@ -466,9 +452,7 @@ func (ws *WebServer) handleIssuerDetail(w http.ResponseWriter, r *http.Request)
|
||||
} else {
|
||||
for i := 1; i < len(certs); i++ {
|
||||
for j := i; j > 0; j-- {
|
||||
a, _ := certs[j-1]["cn"].(string)
|
||||
b, _ := certs[j]["cn"].(string)
|
||||
if strings.ToLower(a) > strings.ToLower(b) {
|
||||
if strings.ToLower(certs[j-1].CommonName) > strings.ToLower(certs[j].CommonName) {
|
||||
certs[j-1], certs[j] = certs[j], certs[j-1]
|
||||
}
|
||||
}
|
||||
@@ -512,24 +496,31 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
reqData := map[string]interface{}{
|
||||
"common_name": commonName,
|
||||
"issuer": issuer,
|
||||
certReq := IssueCertRequest{
|
||||
Mount: mountName,
|
||||
Issuer: issuer,
|
||||
CommonName: commonName,
|
||||
}
|
||||
if v := r.FormValue("profile"); v != "" {
|
||||
reqData["profile"] = v
|
||||
certReq.Profile = v
|
||||
}
|
||||
if v := r.FormValue("ttl"); v != "" {
|
||||
reqData["ttl"] = v
|
||||
certReq.TTL = v
|
||||
}
|
||||
if lines := splitLines(r.FormValue("dns_names")); len(lines) > 0 {
|
||||
reqData["dns_names"] = lines
|
||||
for _, l := range lines {
|
||||
certReq.DNSNames = append(certReq.DNSNames, l.(string))
|
||||
}
|
||||
}
|
||||
if lines := splitLines(r.FormValue("ip_addresses")); len(lines) > 0 {
|
||||
reqData["ip_addresses"] = lines
|
||||
for _, l := range lines {
|
||||
certReq.IPAddresses = append(certReq.IPAddresses, l.(string))
|
||||
}
|
||||
}
|
||||
certReq.KeyUsages = r.Form["key_usages"]
|
||||
certReq.ExtKeyUsages = r.Form["ext_key_usages"]
|
||||
|
||||
resp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "issue", reqData)
|
||||
issuedCert, err := ws.vault.IssueCert(r.Context(), token, certReq)
|
||||
if err != nil {
|
||||
ws.renderPKIWithError(w, r, mountName, info, grpcMessage(err))
|
||||
return
|
||||
@@ -537,10 +528,10 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Re-render the PKI page with the issued certificate displayed.
|
||||
data := map[string]interface{}{
|
||||
"Username": info.Username,
|
||||
"IsAdmin": info.IsAdmin,
|
||||
"MountName": mountName,
|
||||
"IssuedCert": resp,
|
||||
"Username": info.Username,
|
||||
"IsAdmin": info.IsAdmin,
|
||||
"MountName": mountName,
|
||||
"IssuedCert": issuedCert,
|
||||
}
|
||||
if rootPEM, err := ws.vault.GetRootCert(r.Context(), mountName); err == nil && len(rootPEM) > 0 {
|
||||
if cert, err := parsePEMCert(rootPEM); err == nil {
|
||||
@@ -552,8 +543,8 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
||||
data["HasRoot"] = true
|
||||
}
|
||||
}
|
||||
if issuerResp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "list-issuers", nil); err == nil {
|
||||
data["Issuers"] = issuerResp["issuers"]
|
||||
if issuers, err := ws.vault.ListIssuers(r.Context(), token, mountName); err == nil {
|
||||
data["Issuers"] = issuers
|
||||
}
|
||||
ws.renderTemplate(w, "pki.html", data)
|
||||
}
|
||||
@@ -577,8 +568,8 @@ func (ws *WebServer) renderPKIWithError(w http.ResponseWriter, r *http.Request,
|
||||
data["HasRoot"] = true
|
||||
}
|
||||
}
|
||||
if resp, err := ws.vault.EngineRequest(r.Context(), token, mountName, "list-issuers", nil); err == nil {
|
||||
data["Issuers"] = resp["issuers"]
|
||||
if issuers, err := ws.vault.ListIssuers(r.Context(), token, mountName); err == nil {
|
||||
data["Issuers"] = issuers
|
||||
}
|
||||
|
||||
ws.renderTemplate(w, "pki.html", data)
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
@@ -22,11 +20,6 @@ func tokenInfoFromContext(ctx context.Context) *TokenInfo {
|
||||
return v
|
||||
}
|
||||
|
||||
// structFromMap converts a map[string]interface{} to a *structpb.Struct.
|
||||
func structFromMap(m map[string]interface{}) (*structpb.Struct, error) {
|
||||
return structpb.NewStruct(m)
|
||||
}
|
||||
|
||||
// parsePEMCert decodes the first PEM block and parses it as an x509 certificate.
|
||||
func parsePEMCert(pemData []byte) (*x509.Certificate, error) {
|
||||
block, _ := pem.Decode(pemData)
|
||||
|
||||
Reference in New Issue
Block a user