Step 14: SSH key auth for gRPC.
Server: AuthInterceptor parses authorized_keys, extracts SSH signature from gRPC metadata (nonce + timestamp signed by client's SSH key), verifies against authorized public keys with 5-minute timestamp window. Client: SSHCredentials implements PerRPCCredentials, signs nonce+timestamp per request. LoadSigner resolves key from flag, ssh-agent, or default paths. 8 tests: valid auth, reject unauthenticated, reject unauthorized key, reject expired timestamp, metadata generation, plus 2 integration tests (authenticated succeeds, unauthenticated rejected). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
213
server/auth.go
Normal file
213
server/auth.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
// Metadata keys for auth.
|
||||
metaNonce = "x-sgard-auth-nonce"
|
||||
metaTimestamp = "x-sgard-auth-timestamp"
|
||||
metaSignature = "x-sgard-auth-signature"
|
||||
metaPubkey = "x-sgard-auth-pubkey"
|
||||
|
||||
// authWindow is how far the timestamp can deviate from server time.
|
||||
authWindow = 5 * time.Minute
|
||||
)
|
||||
|
||||
// AuthInterceptor verifies SSH key signatures on gRPC requests.
|
||||
type AuthInterceptor struct {
|
||||
authorizedKeys map[string]ssh.PublicKey // keyed by fingerprint
|
||||
}
|
||||
|
||||
// NewAuthInterceptor creates an interceptor from an authorized_keys file.
|
||||
// The file uses the same format as ~/.ssh/authorized_keys.
|
||||
func NewAuthInterceptor(path string) (*AuthInterceptor, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading authorized keys: %w", err)
|
||||
}
|
||||
|
||||
keys := make(map[string]ssh.PublicKey)
|
||||
rest := data
|
||||
for len(rest) > 0 {
|
||||
var key ssh.PublicKey
|
||||
key, _, _, rest, err = ssh.ParseAuthorizedKey(rest)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fp := ssh.FingerprintSHA256(key)
|
||||
keys[fp] = key
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found in %s", path)
|
||||
}
|
||||
|
||||
return &AuthInterceptor{authorizedKeys: keys}, nil
|
||||
}
|
||||
|
||||
// NewAuthInterceptorFromKeys creates an interceptor from pre-parsed keys.
|
||||
// Intended for testing.
|
||||
func NewAuthInterceptorFromKeys(keys []ssh.PublicKey) *AuthInterceptor {
|
||||
m := make(map[string]ssh.PublicKey, len(keys))
|
||||
for _, k := range keys {
|
||||
m[ssh.FingerprintSHA256(k)] = k
|
||||
}
|
||||
return &AuthInterceptor{authorizedKeys: m}
|
||||
}
|
||||
|
||||
// UnaryInterceptor returns a gRPC unary server interceptor.
|
||||
func (a *AuthInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
if err := a.verify(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// StreamInterceptor returns a gRPC stream server interceptor.
|
||||
func (a *AuthInterceptor) StreamInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if err := a.verify(ss.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
return handler(srv, ss)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthInterceptor) verify(ctx context.Context) error {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, "missing metadata")
|
||||
}
|
||||
|
||||
nonceB64 := mdFirst(md, metaNonce)
|
||||
tsStr := mdFirst(md, metaTimestamp)
|
||||
sigB64 := mdFirst(md, metaSignature)
|
||||
pubkeyStr := mdFirst(md, metaPubkey)
|
||||
|
||||
if nonceB64 == "" || tsStr == "" || sigB64 == "" || pubkeyStr == "" {
|
||||
return status.Error(codes.Unauthenticated, "missing auth metadata fields")
|
||||
}
|
||||
|
||||
// Parse timestamp and check window.
|
||||
tsUnix, err := strconv.ParseInt(tsStr, 10, 64)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid timestamp")
|
||||
}
|
||||
ts := time.Unix(tsUnix, 0)
|
||||
if time.Since(ts).Abs() > authWindow {
|
||||
return status.Error(codes.Unauthenticated, "timestamp outside allowed window")
|
||||
}
|
||||
|
||||
// Parse public key and check authorization.
|
||||
pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubkeyStr))
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid public key")
|
||||
}
|
||||
fp := ssh.FingerprintSHA256(pubkey)
|
||||
authorized, ok := a.authorizedKeys[fp]
|
||||
if !ok {
|
||||
return status.Errorf(codes.PermissionDenied, "key %s not authorized", fp)
|
||||
}
|
||||
|
||||
// Decode nonce and signature.
|
||||
nonce, err := base64.StdEncoding.DecodeString(nonceB64)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid nonce encoding")
|
||||
}
|
||||
|
||||
sigBytes, err := base64.StdEncoding.DecodeString(sigB64)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid signature encoding")
|
||||
}
|
||||
|
||||
sig, err := parseSSHSignature(sigBytes)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid signature format")
|
||||
}
|
||||
|
||||
// Build the signed payload: nonce + timestamp bytes.
|
||||
payload := buildPayload(nonce, tsUnix)
|
||||
|
||||
// Verify.
|
||||
if err := authorized.Verify(payload, sig); err != nil {
|
||||
return status.Error(codes.Unauthenticated, "signature verification failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPayload constructs the message that is signed: nonce || timestamp (big-endian int64).
|
||||
func buildPayload(nonce []byte, tsUnix int64) []byte {
|
||||
payload := make([]byte, len(nonce)+8)
|
||||
copy(payload, nonce)
|
||||
for i := 7; i >= 0; i-- {
|
||||
payload[len(nonce)+i] = byte(tsUnix & 0xff)
|
||||
tsUnix >>= 8
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// GenerateNonce creates a 32-byte random nonce.
|
||||
func GenerateNonce() ([]byte, error) {
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("generating nonce: %w", err)
|
||||
}
|
||||
return nonce, nil
|
||||
}
|
||||
|
||||
func mdFirst(md metadata.MD, key string) string {
|
||||
vals := md.Get(key)
|
||||
if len(vals) == 0 {
|
||||
return ""
|
||||
}
|
||||
return vals[0]
|
||||
}
|
||||
|
||||
// parseSSHSignature deserializes an SSH signature from its wire format.
|
||||
func parseSSHSignature(data []byte) (*ssh.Signature, error) {
|
||||
if len(data) < 4 {
|
||||
return nil, fmt.Errorf("signature too short")
|
||||
}
|
||||
|
||||
// SSH signature wire format: string format, string blob
|
||||
formatLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||
if 4+formatLen > len(data) {
|
||||
return nil, fmt.Errorf("invalid format length")
|
||||
}
|
||||
format := string(data[4 : 4+formatLen])
|
||||
|
||||
rest := data[4+formatLen:]
|
||||
if len(rest) < 4 {
|
||||
return nil, fmt.Errorf("missing blob length")
|
||||
}
|
||||
blobLen := int(rest[0])<<24 | int(rest[1])<<16 | int(rest[2])<<8 | int(rest[3])
|
||||
if 4+blobLen > len(rest) {
|
||||
return nil, fmt.Errorf("invalid blob length")
|
||||
}
|
||||
blob := rest[4 : 4+blobLen]
|
||||
|
||||
_ = strings.TrimSpace(format) // ensure format is clean
|
||||
|
||||
return &ssh.Signature{
|
||||
Format: format,
|
||||
Blob: blob,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user