Step 14: SSH key auth for gRPC.

Server: AuthInterceptor parses authorized_keys, extracts SSH signature
from gRPC metadata (nonce + timestamp signed by client's SSH key),
verifies against authorized public keys with 5-minute timestamp window.

Client: SSHCredentials implements PerRPCCredentials, signs nonce+timestamp
per request. LoadSigner resolves key from flag, ssh-agent, or default paths.

8 tests: valid auth, reject unauthenticated, reject unauthorized key,
reject expired timestamp, metadata generation, plus 2 integration tests
(authenticated succeeds, unauthenticated rejected).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-23 23:58:09 -07:00
parent 525c3f0b4f
commit 4b841cdd82
7 changed files with 621 additions and 6 deletions

213
server/auth.go Normal file
View File

@@ -0,0 +1,213 @@
package server
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"os"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
const (
// Metadata keys for auth.
metaNonce = "x-sgard-auth-nonce"
metaTimestamp = "x-sgard-auth-timestamp"
metaSignature = "x-sgard-auth-signature"
metaPubkey = "x-sgard-auth-pubkey"
// authWindow is how far the timestamp can deviate from server time.
authWindow = 5 * time.Minute
)
// AuthInterceptor verifies SSH key signatures on gRPC requests.
type AuthInterceptor struct {
authorizedKeys map[string]ssh.PublicKey // keyed by fingerprint
}
// NewAuthInterceptor creates an interceptor from an authorized_keys file.
// The file uses the same format as ~/.ssh/authorized_keys.
func NewAuthInterceptor(path string) (*AuthInterceptor, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading authorized keys: %w", err)
}
keys := make(map[string]ssh.PublicKey)
rest := data
for len(rest) > 0 {
var key ssh.PublicKey
key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
if err != nil {
break
}
fp := ssh.FingerprintSHA256(key)
keys[fp] = key
}
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in %s", path)
}
return &AuthInterceptor{authorizedKeys: keys}, nil
}
// NewAuthInterceptorFromKeys creates an interceptor from pre-parsed keys.
// Intended for testing.
func NewAuthInterceptorFromKeys(keys []ssh.PublicKey) *AuthInterceptor {
m := make(map[string]ssh.PublicKey, len(keys))
for _, k := range keys {
m[ssh.FingerprintSHA256(k)] = k
}
return &AuthInterceptor{authorizedKeys: m}
}
// UnaryInterceptor returns a gRPC unary server interceptor.
func (a *AuthInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if err := a.verify(ctx); err != nil {
return nil, err
}
return handler(ctx, req)
}
}
// StreamInterceptor returns a gRPC stream server interceptor.
func (a *AuthInterceptor) StreamInterceptor() grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := a.verify(ss.Context()); err != nil {
return err
}
return handler(srv, ss)
}
}
func (a *AuthInterceptor) verify(ctx context.Context) error {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Error(codes.Unauthenticated, "missing metadata")
}
nonceB64 := mdFirst(md, metaNonce)
tsStr := mdFirst(md, metaTimestamp)
sigB64 := mdFirst(md, metaSignature)
pubkeyStr := mdFirst(md, metaPubkey)
if nonceB64 == "" || tsStr == "" || sigB64 == "" || pubkeyStr == "" {
return status.Error(codes.Unauthenticated, "missing auth metadata fields")
}
// Parse timestamp and check window.
tsUnix, err := strconv.ParseInt(tsStr, 10, 64)
if err != nil {
return status.Error(codes.Unauthenticated, "invalid timestamp")
}
ts := time.Unix(tsUnix, 0)
if time.Since(ts).Abs() > authWindow {
return status.Error(codes.Unauthenticated, "timestamp outside allowed window")
}
// Parse public key and check authorization.
pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubkeyStr))
if err != nil {
return status.Error(codes.Unauthenticated, "invalid public key")
}
fp := ssh.FingerprintSHA256(pubkey)
authorized, ok := a.authorizedKeys[fp]
if !ok {
return status.Errorf(codes.PermissionDenied, "key %s not authorized", fp)
}
// Decode nonce and signature.
nonce, err := base64.StdEncoding.DecodeString(nonceB64)
if err != nil {
return status.Error(codes.Unauthenticated, "invalid nonce encoding")
}
sigBytes, err := base64.StdEncoding.DecodeString(sigB64)
if err != nil {
return status.Error(codes.Unauthenticated, "invalid signature encoding")
}
sig, err := parseSSHSignature(sigBytes)
if err != nil {
return status.Error(codes.Unauthenticated, "invalid signature format")
}
// Build the signed payload: nonce + timestamp bytes.
payload := buildPayload(nonce, tsUnix)
// Verify.
if err := authorized.Verify(payload, sig); err != nil {
return status.Error(codes.Unauthenticated, "signature verification failed")
}
return nil
}
// buildPayload constructs the message that is signed: nonce || timestamp (big-endian int64).
func buildPayload(nonce []byte, tsUnix int64) []byte {
payload := make([]byte, len(nonce)+8)
copy(payload, nonce)
for i := 7; i >= 0; i-- {
payload[len(nonce)+i] = byte(tsUnix & 0xff)
tsUnix >>= 8
}
return payload
}
// GenerateNonce creates a 32-byte random nonce.
func GenerateNonce() ([]byte, error) {
nonce := make([]byte, 32)
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("generating nonce: %w", err)
}
return nonce, nil
}
func mdFirst(md metadata.MD, key string) string {
vals := md.Get(key)
if len(vals) == 0 {
return ""
}
return vals[0]
}
// parseSSHSignature deserializes an SSH signature from its wire format.
func parseSSHSignature(data []byte) (*ssh.Signature, error) {
if len(data) < 4 {
return nil, fmt.Errorf("signature too short")
}
// SSH signature wire format: string format, string blob
formatLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if 4+formatLen > len(data) {
return nil, fmt.Errorf("invalid format length")
}
format := string(data[4 : 4+formatLen])
rest := data[4+formatLen:]
if len(rest) < 4 {
return nil, fmt.Errorf("missing blob length")
}
blobLen := int(rest[0])<<24 | int(rest[1])<<16 | int(rest[2])<<8 | int(rest[3])
if 4+blobLen > len(rest) {
return nil, fmt.Errorf("invalid blob length")
}
blob := rest[4 : 4+blobLen]
_ = strings.TrimSpace(format) // ensure format is clean
return &ssh.Signature{
Format: format,
Blob: blob,
}, nil
}

119
server/auth_test.go Normal file
View File

@@ -0,0 +1,119 @@
package server
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"strconv"
"testing"
"time"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc/metadata"
)
func generateTestKey(t *testing.T) (ssh.Signer, ssh.PublicKey) {
t.Helper()
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generating key: %v", err)
}
signer, err := ssh.NewSignerFromKey(priv)
if err != nil {
t.Fatalf("creating signer: %v", err)
}
return signer, signer.PublicKey()
}
func signedContext(t *testing.T, signer ssh.Signer) context.Context {
t.Helper()
nonce, err := GenerateNonce()
if err != nil {
t.Fatalf("generating nonce: %v", err)
}
tsUnix := time.Now().Unix()
payload := buildPayload(nonce, tsUnix)
sig, err := signer.Sign(rand.Reader, payload)
if err != nil {
t.Fatalf("signing: %v", err)
}
pubkeyStr := string(ssh.MarshalAuthorizedKey(signer.PublicKey()))
md := metadata.New(map[string]string{
metaNonce: base64.StdEncoding.EncodeToString(nonce),
metaTimestamp: strconv.FormatInt(tsUnix, 10),
metaSignature: base64.StdEncoding.EncodeToString(ssh.Marshal(sig)),
metaPubkey: pubkeyStr,
})
return metadata.NewIncomingContext(context.Background(), md)
}
func TestAuthVerifyValid(t *testing.T) {
signer, pubkey := generateTestKey(t)
interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey})
ctx := signedContext(t, signer)
if err := interceptor.verify(ctx); err != nil {
t.Fatalf("verify should succeed: %v", err)
}
}
func TestAuthRejectUnauthenticated(t *testing.T) {
_, pubkey := generateTestKey(t)
interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey})
// No metadata at all.
ctx := context.Background()
if err := interceptor.verify(ctx); err == nil {
t.Fatal("verify should reject missing metadata")
}
}
func TestAuthRejectUnauthorizedKey(t *testing.T) {
signer1, _ := generateTestKey(t)
_, pubkey2 := generateTestKey(t)
// Interceptor knows key2 but request is signed by key1.
interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey2})
ctx := signedContext(t, signer1)
if err := interceptor.verify(ctx); err == nil {
t.Fatal("verify should reject unauthorized key")
}
}
func TestAuthRejectExpiredTimestamp(t *testing.T) {
signer, pubkey := generateTestKey(t)
interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey})
nonce, err := GenerateNonce()
if err != nil {
t.Fatalf("generating nonce: %v", err)
}
// Timestamp 10 minutes ago — outside the 5-minute window.
tsUnix := time.Now().Add(-10 * time.Minute).Unix()
payload := buildPayload(nonce, tsUnix)
sig, err := signer.Sign(rand.Reader, payload)
if err != nil {
t.Fatalf("signing: %v", err)
}
pubkeyStr := string(ssh.MarshalAuthorizedKey(signer.PublicKey()))
md := metadata.New(map[string]string{
metaNonce: base64.StdEncoding.EncodeToString(nonce),
metaTimestamp: strconv.FormatInt(tsUnix, 10),
metaSignature: base64.StdEncoding.EncodeToString(ssh.Marshal(sig)),
metaPubkey: pubkeyStr,
})
ctx := metadata.NewIncomingContext(context.Background(), md)
if err := interceptor.verify(ctx); err == nil {
t.Fatal("verify should reject expired timestamp")
}
}