package main import ( "bytes" "context" "crypto/tls" "crypto/x509" "encoding/json" "fmt" "io" "net/http" "os" "google.golang.org/grpc" "google.golang.org/grpc/credentials" mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1" ) // apiClient wraps both REST and gRPC transports. When grpcAddr is set // the gRPC service clients are used; otherwise requests go via REST. type apiClient struct { serverURL string token string httpClient *http.Client // gRPC (nil when --grpc is not set). grpcConn *grpc.ClientConn registry mcrv1.RegistryServiceClient policy mcrv1.PolicyServiceClient audit mcrv1.AuditServiceClient admin mcrv1.AdminServiceClient } // newClient builds an apiClient from the resolved flags. func newClient(serverURL, grpcAddr, token, caCertFile string) (*apiClient, error) { tlsCfg := &tls.Config{ MinVersion: tls.VersionTLS13, } if caCertFile != "" { pem, err := os.ReadFile(caCertFile) //nolint:gosec // CA cert path is operator-supplied if err != nil { return nil, fmt.Errorf("reading CA cert: %w", err) } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(pem) { return nil, fmt.Errorf("ca-cert file contains no valid certificates") } tlsCfg.RootCAs = pool } c := &apiClient{ serverURL: serverURL, token: token, httpClient: &http.Client{ Transport: &http.Transport{ TLSClientConfig: tlsCfg, }, }, } if grpcAddr != "" { creds := credentials.NewTLS(tlsCfg) cc, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(creds), grpc.WithDefaultCallOptions(grpc.ForceCodecV2(mcrv1.JSONCodec{})), ) if err != nil { return nil, fmt.Errorf("grpc dial: %w", err) } c.grpcConn = cc c.registry = mcrv1.NewRegistryServiceClient(cc) c.policy = mcrv1.NewPolicyServiceClient(cc) c.audit = mcrv1.NewAuditServiceClient(cc) c.admin = mcrv1.NewAdminServiceClient(cc) } return c, nil } // close shuts down the gRPC connection if open. func (c *apiClient) close() { if c.grpcConn != nil { _ = c.grpcConn.Close() } } // useGRPC returns true when the client should use gRPC transport. 
func (c *apiClient) useGRPC() bool {
	return c.grpcConn != nil
}

// apiError is the JSON error envelope returned by the REST API, i.e. a
// body of the form {"error": "..."}.
type apiError struct {
	Error string `json:"error"`
}

// restDo sends method to c.serverURL+path and returns the raw response
// body on success.
//
// A non-nil body is JSON-encoded and sent with Content-Type
// application/json; a non-empty c.token is sent as a bearer token. For any
// response status >= 400 an error is returned: the message comes from the
// apiError envelope when the body decodes into one with a non-empty Error
// field, and is the raw body text otherwise.
func (c *apiClient) restDo(method, path string, body any) ([]byte, error) {
	// payload must stay a nil io.Reader (not a typed nil) when there is no
	// body, so the request is sent without one.
	var payload io.Reader
	if body != nil {
		encoded, marshalErr := json.Marshal(body)
		if marshalErr != nil {
			return nil, fmt.Errorf("marshal request: %w", marshalErr)
		}
		payload = bytes.NewReader(encoded)
	}

	target := c.serverURL + path
	req, err := http.NewRequestWithContext(context.Background(), method, target, payload)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}

	hdr := req.Header
	if body != nil {
		hdr.Set("Content-Type", "application/json")
	}
	if c.token != "" {
		hdr.Set("Authorization", "Bearer "+c.token)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("http %s %s: %w", method, path, err)
	}
	defer func() { _ = resp.Body.Close() }()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}

	status := resp.StatusCode
	if status < 400 {
		return raw, nil
	}

	// Prefer the structured message; fall back to whatever the server sent.
	msg := string(raw)
	var envelope apiError
	if json.Unmarshal(raw, &envelope) == nil && envelope.Error != "" {
		msg = envelope.Error
	}
	return nil, fmt.Errorf("server error (%d): %s", status, msg)
}