Step 14: SSH key auth for gRPC.
Server: AuthInterceptor parses authorized_keys, extracts SSH signature from gRPC metadata (nonce + timestamp signed by client's SSH key), verifies against authorized public keys with 5-minute timestamp window. Client: SSHCredentials implements PerRPCCredentials, signs nonce+timestamp per request. LoadSigner resolves key from flag, ssh-agent, or default paths. 8 tests: valid auth, reject unauthenticated, reject unauthorized key, reject expired timestamp, metadata generation, plus 2 integration tests (authenticated succeeds, unauthenticated rejected). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
124
client/auth.go
Normal file
124
client/auth.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
// SSHCredentials implements grpc.PerRPCCredentials using an SSH signer.
|
||||
type SSHCredentials struct {
|
||||
signer ssh.Signer
|
||||
}
|
||||
|
||||
// NewSSHCredentials creates credentials from an SSH signer.
|
||||
func NewSSHCredentials(signer ssh.Signer) *SSHCredentials {
|
||||
return &SSHCredentials{signer: signer}
|
||||
}
|
||||
|
||||
// GetRequestMetadata signs a nonce+timestamp and returns auth metadata.
|
||||
func (c *SSHCredentials) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) {
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("generating nonce: %w", err)
|
||||
}
|
||||
|
||||
tsUnix := time.Now().Unix()
|
||||
payload := buildPayload(nonce, tsUnix)
|
||||
|
||||
sig, err := c.signer.Sign(rand.Reader, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("signing payload: %w", err)
|
||||
}
|
||||
|
||||
pubkey := c.signer.PublicKey()
|
||||
pubkeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubkey)))
|
||||
|
||||
return map[string]string{
|
||||
"x-sgard-auth-nonce": base64.StdEncoding.EncodeToString(nonce),
|
||||
"x-sgard-auth-timestamp": strconv.FormatInt(tsUnix, 10),
|
||||
"x-sgard-auth-signature": base64.StdEncoding.EncodeToString(ssh.Marshal(sig)),
|
||||
"x-sgard-auth-pubkey": pubkeyStr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RequireTransportSecurity returns false — auth is via SSH signatures,
|
||||
// not TLS. Transport security can be added separately.
|
||||
func (c *SSHCredentials) RequireTransportSecurity() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify that SSHCredentials implements the interface.
|
||||
var _ credentials.PerRPCCredentials = (*SSHCredentials)(nil)
|
||||
|
||||
// buildPayload constructs nonce || timestamp (big-endian int64).
|
||||
func buildPayload(nonce []byte, tsUnix int64) []byte {
|
||||
payload := make([]byte, len(nonce)+8)
|
||||
copy(payload, nonce)
|
||||
for i := 7; i >= 0; i-- {
|
||||
payload[len(nonce)+i] = byte(tsUnix & 0xff)
|
||||
tsUnix >>= 8
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// LoadSigner loads an SSH signer. Resolution order:
|
||||
// 1. keyPath (if non-empty)
|
||||
// 2. SSH agent (if SSH_AUTH_SOCK is set)
|
||||
// 3. Default key paths: ~/.ssh/id_ed25519, ~/.ssh/id_rsa
|
||||
func LoadSigner(keyPath string) (ssh.Signer, error) {
|
||||
if keyPath != "" {
|
||||
return loadSignerFromFile(keyPath)
|
||||
}
|
||||
|
||||
// Try ssh-agent.
|
||||
if sock := os.Getenv("SSH_AUTH_SOCK"); sock != "" {
|
||||
conn, err := net.Dial("unix", sock)
|
||||
if err == nil {
|
||||
ag := agent.NewClient(conn)
|
||||
signers, err := ag.Signers()
|
||||
if err == nil && len(signers) > 0 {
|
||||
return signers[0], nil
|
||||
}
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Try default key paths.
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no SSH key found: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range []string{"id_ed25519", "id_rsa"} {
|
||||
path := home + "/.ssh/" + name
|
||||
signer, err := loadSignerFromFile(path)
|
||||
if err == nil {
|
||||
return signer, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no SSH key found (tried --ssh-key, agent, ~/.ssh/id_ed25519, ~/.ssh/id_rsa)")
|
||||
}
|
||||
|
||||
func loadSignerFromFile(path string) (ssh.Signer, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading key %s: %w", path, err)
|
||||
}
|
||||
signer, err := ssh.ParsePrivateKey(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing key %s: %w", path, err)
|
||||
}
|
||||
return signer, nil
|
||||
}
|
||||
57
client/auth_test.go
Normal file
57
client/auth_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestSSHCredentialsMetadata(t *testing.T) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generating key: %v", err)
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(priv)
|
||||
if err != nil {
|
||||
t.Fatalf("creating signer: %v", err)
|
||||
}
|
||||
|
||||
creds := NewSSHCredentials(signer)
|
||||
|
||||
md, err := creds.GetRequestMetadata(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetRequestMetadata: %v", err)
|
||||
}
|
||||
|
||||
// Verify all required fields are present and non-empty.
|
||||
for _, key := range []string{
|
||||
"x-sgard-auth-nonce",
|
||||
"x-sgard-auth-timestamp",
|
||||
"x-sgard-auth-signature",
|
||||
"x-sgard-auth-pubkey",
|
||||
} {
|
||||
val, ok := md[key]
|
||||
if !ok || val == "" {
|
||||
t.Errorf("missing or empty metadata key %s", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCredentialsNoTransportSecurity(t *testing.T) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generating key: %v", err)
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(priv)
|
||||
if err != nil {
|
||||
t.Fatalf("creating signer: %v", err)
|
||||
}
|
||||
|
||||
creds := NewSSHCredentials(signer)
|
||||
if creds.RequireTransportSecurity() {
|
||||
t.Error("RequireTransportSecurity should be false")
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
@@ -12,6 +14,7 @@ import (
|
||||
"github.com/kisom/sgard/garden"
|
||||
"github.com/kisom/sgard/server"
|
||||
"github.com/kisom/sgard/sgardpb"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
@@ -206,3 +209,100 @@ func TestPrune(t *testing.T) {
|
||||
t.Errorf("removed %d blobs, want 1", removed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthIntegration(t *testing.T) {
|
||||
// Generate an ed25519 key pair.
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generating key: %v", err)
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(priv)
|
||||
if err != nil {
|
||||
t.Fatalf("creating signer: %v", err)
|
||||
}
|
||||
|
||||
serverDir := t.TempDir()
|
||||
serverGarden, err := garden.Init(serverDir)
|
||||
if err != nil {
|
||||
t.Fatalf("init server garden: %v", err)
|
||||
}
|
||||
|
||||
// Set up server with auth interceptor.
|
||||
auth := server.NewAuthInterceptorFromKeys([]ssh.PublicKey{signer.PublicKey()})
|
||||
lis := bufconn.Listen(bufSize)
|
||||
srv := grpc.NewServer(
|
||||
grpc.UnaryInterceptor(auth.UnaryInterceptor()),
|
||||
grpc.StreamInterceptor(auth.StreamInterceptor()),
|
||||
)
|
||||
sgardpb.RegisterGardenSyncServer(srv, server.New(serverGarden))
|
||||
t.Cleanup(func() { srv.Stop() })
|
||||
go func() { _ = srv.Serve(lis) }()
|
||||
|
||||
// Client with SSH credentials.
|
||||
creds := NewSSHCredentials(signer)
|
||||
conn, err := grpc.NewClient("passthrough:///bufconn",
|
||||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
|
||||
return lis.Dial()
|
||||
}),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithPerRPCCredentials(creds),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
c := New(conn)
|
||||
|
||||
// Authenticated request should succeed.
|
||||
_, err = c.Pull(context.Background(), serverGarden)
|
||||
if err != nil {
|
||||
t.Fatalf("authenticated Pull should succeed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthIntegrationRejectsUnauthenticated(t *testing.T) {
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generating key: %v", err)
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(priv)
|
||||
if err != nil {
|
||||
t.Fatalf("creating signer: %v", err)
|
||||
}
|
||||
|
||||
serverDir := t.TempDir()
|
||||
serverGarden, err := garden.Init(serverDir)
|
||||
if err != nil {
|
||||
t.Fatalf("init server garden: %v", err)
|
||||
}
|
||||
|
||||
auth := server.NewAuthInterceptorFromKeys([]ssh.PublicKey{signer.PublicKey()})
|
||||
lis := bufconn.Listen(bufSize)
|
||||
srv := grpc.NewServer(
|
||||
grpc.UnaryInterceptor(auth.UnaryInterceptor()),
|
||||
grpc.StreamInterceptor(auth.StreamInterceptor()),
|
||||
)
|
||||
sgardpb.RegisterGardenSyncServer(srv, server.New(serverGarden))
|
||||
t.Cleanup(func() { srv.Stop() })
|
||||
go func() { _ = srv.Serve(lis) }()
|
||||
|
||||
// Client WITHOUT credentials.
|
||||
conn, err := grpc.NewClient("passthrough:///bufconn",
|
||||
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
|
||||
return lis.Dial()
|
||||
}),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
c := New(conn)
|
||||
|
||||
_, err = c.Pull(context.Background(), serverGarden)
|
||||
if err == nil {
|
||||
t.Fatal("unauthenticated Pull should fail")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user