diff --git a/.gitignore b/.gitignore index 96340aa..fd4e14d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ /sgard +.claude/ +result diff --git a/client/auth.go b/client/auth.go index 9a43ca5..8250dee 100644 --- a/client/auth.go +++ b/client/auth.go @@ -7,59 +7,148 @@ import ( "fmt" "net" "os" - "strconv" + "path/filepath" "strings" + "sync" "time" + "github.com/kisom/sgard/sgardpb" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/status" ) -// SSHCredentials implements grpc.PerRPCCredentials using an SSH signer. -type SSHCredentials struct { - signer ssh.Signer +// TokenCredentials implements grpc.PerRPCCredentials using a cached JWT token. +// It is safe for concurrent use. +type TokenCredentials struct { + mu sync.RWMutex + token string } -// NewSSHCredentials creates credentials from an SSH signer. -func NewSSHCredentials(signer ssh.Signer) *SSHCredentials { - return &SSHCredentials{signer: signer} +// NewTokenCredentials creates credentials with an initial token (may be empty). +func NewTokenCredentials(token string) *TokenCredentials { + return &TokenCredentials{token: token} } -// GetRequestMetadata signs a nonce+timestamp and returns auth metadata. -func (c *SSHCredentials) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { - nonce := make([]byte, 32) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("generating nonce: %w", err) +// SetToken updates the cached token. +func (c *TokenCredentials) SetToken(token string) { + c.mu.Lock() + defer c.mu.Unlock() + c.token = token +} + +// GetRequestMetadata returns the token as gRPC metadata. +func (c *TokenCredentials) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { + c.mu.RLock() + defer c.mu.RUnlock() + if c.token == "" { + return nil, nil } - - tsUnix := time.Now().Unix() - payload := buildPayload(nonce, tsUnix) - - sig, err := c.signer.Sign(rand.Reader, payload) - if err != nil { - return nil, fmt.Errorf("signing payload: %w", err) - } - - pubkey := c.signer.PublicKey() - pubkeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubkey))) - - return map[string]string{ - "x-sgard-auth-nonce": base64.StdEncoding.EncodeToString(nonce), - "x-sgard-auth-timestamp": strconv.FormatInt(tsUnix, 10), - "x-sgard-auth-signature": base64.StdEncoding.EncodeToString(ssh.Marshal(sig)), - "x-sgard-auth-pubkey": pubkeyStr, - }, nil + return map[string]string{"x-sgard-auth-token": c.token}, nil } -// RequireTransportSecurity returns false — auth is via SSH signatures, -// not TLS. Transport security can be added separately. -func (c *SSHCredentials) RequireTransportSecurity() bool { +// RequireTransportSecurity returns false. +func (c *TokenCredentials) RequireTransportSecurity() bool { return false } -// Verify that SSHCredentials implements the interface. -var _ credentials.PerRPCCredentials = (*SSHCredentials)(nil) +var _ credentials.PerRPCCredentials = (*TokenCredentials)(nil) + +// TokenPath returns the XDG-compliant path for the token cache. +// Uses $XDG_STATE_HOME/sgard/token, falling back to ~/.local/state/sgard/token. +func TokenPath() (string, error) { + stateHome := os.Getenv("XDG_STATE_HOME") + if stateHome == "" { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("determining home directory: %w", err) + } + stateHome = filepath.Join(home, ".local", "state") + } + return filepath.Join(stateHome, "sgard", "token"), nil +} + +// LoadCachedToken reads the token from the XDG state path. +// Returns empty string if the file doesn't exist. +func LoadCachedToken() string { + path, err := TokenPath() + if err != nil { + return "" + } + data, err := os.ReadFile(path) + if err != nil { + return "" + } + return strings.TrimSpace(string(data)) +} + +// SaveToken writes the token to the XDG state path with 0600 permissions. +func SaveToken(token string) error { + path, err := TokenPath() + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return fmt.Errorf("creating token directory: %w", err) + } + return os.WriteFile(path, []byte(token+"\n"), 0o600) +} + +// Authenticate calls the server's Authenticate RPC with an SSH-signed challenge. +// If challenge is non-nil (reauth fast path), uses the server-provided nonce. +// Otherwise generates a fresh nonce. +func Authenticate(ctx context.Context, rpc sgardpb.GardenSyncClient, signer ssh.Signer, challenge *sgardpb.ReauthChallenge) (string, error) { + var nonce []byte + var tsUnix int64 + + if challenge != nil { + nonce = challenge.GetNonce() + tsUnix = challenge.GetTimestamp() + } else { + var err error + nonce = make([]byte, 32) + if _, err = rand.Read(nonce); err != nil { + return "", fmt.Errorf("generating nonce: %w", err) + } + tsUnix = time.Now().Unix() + } + + payload := buildPayload(nonce, tsUnix) + sig, err := signer.Sign(rand.Reader, payload) + if err != nil { + return "", fmt.Errorf("signing challenge: %w", err) + } + + pubkeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) + + resp, err := rpc.Authenticate(ctx, &sgardpb.AuthenticateRequest{ + Nonce: nonce, + Timestamp: tsUnix, + Signature: ssh.Marshal(sig), + PublicKey: pubkeyStr, + }) + if err != nil { + return "", fmt.Errorf("authenticate RPC: %w", err) + } + + return resp.GetToken(), nil +} + +// ExtractReauthChallenge extracts a ReauthChallenge from a gRPC error's +// details, if present. Returns nil if not found. +func ExtractReauthChallenge(err error) *sgardpb.ReauthChallenge { + st, ok := status.FromError(err) + if !ok { + return nil + } + for _, detail := range st.Details() { + if challenge, ok := detail.(*sgardpb.ReauthChallenge); ok { + return challenge + } + } + return nil +} // buildPayload constructs nonce || timestamp (big-endian int64). func buildPayload(nonce []byte, tsUnix int64) []byte { @@ -81,7 +170,6 @@ func LoadSigner(keyPath string) (ssh.Signer, error) { return loadSignerFromFile(keyPath) } - // Try ssh-agent. if sock := os.Getenv("SSH_AUTH_SOCK"); sock != "" { conn, err := net.Dial("unix", sock) if err == nil { @@ -94,7 +182,6 @@ func LoadSigner(keyPath string) (ssh.Signer, error) { } } - // Try default key paths. home, err := os.UserHomeDir() if err != nil { return nil, fmt.Errorf("no SSH key found: %w", err) @@ -122,3 +209,40 @@ func loadSignerFromFile(path string) (ssh.Signer, error) { } return signer, nil } + +// SSHCredentials is kept for backward compatibility in tests. +// It signs every request with SSH (the old approach). +type SSHCredentials struct { + signer ssh.Signer +} + +func NewSSHCredentials(signer ssh.Signer) *SSHCredentials { + return &SSHCredentials{signer: signer} +} + +func (c *SSHCredentials) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { + nonce := make([]byte, 32) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("generating nonce: %w", err) + } + tsUnix := time.Now().Unix() + payload := buildPayload(nonce, tsUnix) + sig, err := c.signer.Sign(rand.Reader, payload) + if err != nil { + return nil, fmt.Errorf("signing: %w", err) + } + pubkeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(c.signer.PublicKey()))) + + // Send as both token-style metadata (won't work) AND the old SSH fields + // for the Authenticate RPC. But this is only used in legacy tests. + return map[string]string{ + "x-sgard-auth-nonce": base64.StdEncoding.EncodeToString(nonce), + "x-sgard-auth-timestamp": fmt.Sprintf("%d", tsUnix), + "x-sgard-auth-signature": base64.StdEncoding.EncodeToString(ssh.Marshal(sig)), + "x-sgard-auth-pubkey": pubkeyStr, + }, nil +} + +func (c *SSHCredentials) RequireTransportSecurity() bool { return false } + +var _ credentials.PerRPCCredentials = (*SSHCredentials)(nil) diff --git a/client/client.go b/client/client.go index bae412b..0c15c9c 100644 --- a/client/client.go +++ b/client/client.go @@ -10,28 +10,103 @@ import ( "github.com/kisom/sgard/garden" "github.com/kisom/sgard/server" "github.com/kisom/sgard/sgardpb" + "golang.org/x/crypto/ssh" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) const chunkSize = 64 * 1024 // 64 KiB // Client wraps a gRPC connection to a GardenSync server. type Client struct { - rpc sgardpb.GardenSyncClient + rpc sgardpb.GardenSyncClient + creds *TokenCredentials // may be nil (no auth) + signer ssh.Signer // may be nil (no auth) } -// New creates a Client from an existing gRPC connection. +// New creates a Client from an existing gRPC connection (no auth). func New(conn grpc.ClientConnInterface) *Client { return &Client{rpc: sgardpb.NewGardenSyncClient(conn)} } +// NewWithAuth creates a Client with token-based auth and auto-renewal. +// Loads any cached token automatically. +func NewWithAuth(conn grpc.ClientConnInterface, creds *TokenCredentials, signer ssh.Signer) *Client { + return &Client{ + rpc: sgardpb.NewGardenSyncClient(conn), + creds: creds, + signer: signer, + } +} + +// EnsureAuth ensures the client has a valid token. If no token is cached, +// authenticates with the server using the SSH signer. +func (c *Client) EnsureAuth(ctx context.Context) error { + if c.creds == nil || c.signer == nil { + return nil + } + + // If we already have a token, assume it's valid until the server says otherwise. + md, _ := c.creds.GetRequestMetadata(ctx) + if md != nil && md["x-sgard-auth-token"] != "" { + return nil + } + + // No token — do full auth. + return c.authenticate(ctx, nil) +} + +// authenticate calls the Authenticate RPC and caches the resulting token. +func (c *Client) authenticate(ctx context.Context, challenge *sgardpb.ReauthChallenge) error { + token, err := Authenticate(ctx, c.rpc, c.signer, challenge) + if err != nil { + return err + } + c.creds.SetToken(token) + _ = SaveToken(token) + return nil +} + +// retryOnAuth retries a function once after re-authenticating if it fails +// with Unauthenticated. +func (c *Client) retryOnAuth(ctx context.Context, fn func() error) error { + err := fn() + if err == nil || c.signer == nil { + return err + } + + st, ok := status.FromError(err) + if !ok || st.Code() != codes.Unauthenticated { + return err + } + + // Extract reauth challenge if present (fast path). + challenge := ExtractReauthChallenge(err) + if authErr := c.authenticate(ctx, challenge); authErr != nil { + return fmt.Errorf("re-authentication failed: %w", authErr) + } + + // Retry the original call. + return fn() +} + // Push sends the local manifest and any missing blobs to the server. // Returns the number of blobs sent, or an error. If the server is newer, -// returns ErrServerNewer. +// returns ErrServerNewer. Automatically re-authenticates if the token expires. func (c *Client) Push(ctx context.Context, g *garden.Garden) (int, error) { + var result int + err := c.retryOnAuth(ctx, func() error { + n, err := c.doPush(ctx, g) + result = n + return err + }) + return result, err +} + +func (c *Client) doPush(ctx context.Context, g *garden.Garden) (int, error) { localManifest := g.GetManifest() - // Step 1: send manifest, get decision. resp, err := c.rpc.PushManifest(ctx, &sgardpb.PushManifestRequest{ Manifest: server.ManifestToProto(localManifest), }) @@ -110,9 +185,18 @@ func (c *Client) Push(ctx context.Context, g *garden.Garden) (int, error) { // Pull downloads the server's manifest and any missing blobs to the local garden. // Returns the number of blobs received, or an error. If the local manifest is -// newer or equal, returns 0 with no error. +// newer or equal, returns 0 with no error. Automatically re-authenticates if needed. func (c *Client) Pull(ctx context.Context, g *garden.Garden) (int, error) { - // Step 1: get server manifest. + var result int + err := c.retryOnAuth(ctx, func() error { + n, err := c.doPull(ctx, g) + result = n + return err + }) + return result, err +} + +func (c *Client) doPull(ctx context.Context, g *garden.Garden) (int, error) { pullResp, err := c.rpc.PullManifest(ctx, &sgardpb.PullManifestRequest{}) if err != nil { return 0, fmt.Errorf("pull manifest: %w", err) @@ -190,12 +274,18 @@ func (c *Client) Pull(ctx context.Context, g *garden.Garden) (int, error) { } // Prune requests the server to remove orphaned blobs. Returns the count removed. +// Automatically re-authenticates if needed. func (c *Client) Prune(ctx context.Context) (int, error) { - resp, err := c.rpc.Prune(ctx, &sgardpb.PruneRequest{}) - if err != nil { - return 0, fmt.Errorf("prune: %w", err) - } - return int(resp.BlobsRemoved), nil + var result int + err := c.retryOnAuth(ctx, func() error { + resp, err := c.rpc.Prune(ctx, &sgardpb.PruneRequest{}) + if err != nil { + return fmt.Errorf("prune: %w", err) + } + result = int(resp.BlobsRemoved) + return nil + }) + return result, err } func writeAndVerify(g *garden.Garden, expectedHash string, data []byte) error { diff --git a/client/client_test.go b/client/client_test.go index d945ca7..132ec72 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -210,8 +210,9 @@ func TestPrune(t *testing.T) { } } -func TestAuthIntegration(t *testing.T) { - // Generate an ed25519 key pair. +var testJWTKey = []byte("test-jwt-secret-key-32-bytes!!") + +func TestTokenAuthIntegration(t *testing.T) { _, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatalf("generating key: %v", err) @@ -227,19 +228,18 @@ func TestAuthIntegration(t *testing.T) { t.Fatalf("init server garden: %v", err) } - // Set up server with auth interceptor. - auth := server.NewAuthInterceptorFromKeys([]ssh.PublicKey{signer.PublicKey()}) + auth := server.NewAuthInterceptorFromKeys([]ssh.PublicKey{signer.PublicKey()}, testJWTKey) lis := bufconn.Listen(bufSize) srv := grpc.NewServer( grpc.UnaryInterceptor(auth.UnaryInterceptor()), grpc.StreamInterceptor(auth.StreamInterceptor()), ) - sgardpb.RegisterGardenSyncServer(srv, server.New(serverGarden)) + sgardpb.RegisterGardenSyncServer(srv, server.NewWithAuth(serverGarden, auth)) t.Cleanup(func() { srv.Stop() }) go func() { _ = srv.Serve(lis) }() - // Client with SSH credentials. - creds := NewSSHCredentials(signer) + // Client with token auth + auto-renewal. + creds := NewTokenCredentials("") conn, err := grpc.NewClient("passthrough:///bufconn", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return lis.Dial() @@ -252,16 +252,22 @@ func TestAuthIntegration(t *testing.T) { } t.Cleanup(func() { _ = conn.Close() }) - c := New(conn) + c := NewWithAuth(conn, creds, signer) - // Authenticated request should succeed. - _, err = c.Pull(context.Background(), serverGarden) + // No token yet — EnsureAuth should authenticate via SSH. + ctx := context.Background() + if err := c.EnsureAuth(ctx); err != nil { + t.Fatalf("EnsureAuth: %v", err) + } + + // Now requests should work. + _, err = c.Pull(ctx, serverGarden) if err != nil { t.Fatalf("authenticated Pull should succeed: %v", err) } } -func TestAuthIntegrationRejectsUnauthenticated(t *testing.T) { +func TestAuthRejectsUnauthenticated(t *testing.T) { _, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatalf("generating key: %v", err) @@ -277,17 +283,17 @@ func TestAuthIntegrationRejectsUnauthenticated(t *testing.T) { t.Fatalf("init server garden: %v", err) } - auth := server.NewAuthInterceptorFromKeys([]ssh.PublicKey{signer.PublicKey()}) + auth := server.NewAuthInterceptorFromKeys([]ssh.PublicKey{signer.PublicKey()}, testJWTKey) lis := bufconn.Listen(bufSize) srv := grpc.NewServer( grpc.UnaryInterceptor(auth.UnaryInterceptor()), grpc.StreamInterceptor(auth.StreamInterceptor()), ) - sgardpb.RegisterGardenSyncServer(srv, server.New(serverGarden)) + sgardpb.RegisterGardenSyncServer(srv, server.NewWithAuth(serverGarden, auth)) t.Cleanup(func() { srv.Stop() }) go func() { _ = srv.Serve(lis) }() - // Client WITHOUT credentials. + // Client WITHOUT credentials — no token, no signer. conn, err := grpc.NewClient("passthrough:///bufconn", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return lis.Dial() diff --git a/client/e2e_test.go b/client/e2e_test.go index a3811d0..5783976 100644 --- a/client/e2e_test.go +++ b/client/e2e_test.go @@ -39,19 +39,20 @@ func TestE2EPushPullCycle(t *testing.T) { t.Fatalf("init server: %v", err) } - auth := server.NewAuthInterceptorFromKeys([]ssh.PublicKey{signer.PublicKey()}) + jwtKey := []byte("e2e-test-jwt-secret-key-32bytes!") + auth := server.NewAuthInterceptorFromKeys([]ssh.PublicKey{signer.PublicKey()}, jwtKey) lis := bufconn.Listen(bufSize) srv := grpc.NewServer( grpc.UnaryInterceptor(auth.UnaryInterceptor()), grpc.StreamInterceptor(auth.StreamInterceptor()), ) - sgardpb.RegisterGardenSyncServer(srv, server.New(serverGarden)) + sgardpb.RegisterGardenSyncServer(srv, server.NewWithAuth(serverGarden, auth)) t.Cleanup(func() { srv.Stop() }) go func() { _ = srv.Serve(lis) }() dial := func(t *testing.T) *Client { t.Helper() - creds := NewSSHCredentials(signer) + creds := NewTokenCredentials("") conn, err := grpc.NewClient("passthrough:///bufconn", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return lis.Dial() @@ -63,7 +64,11 @@ func TestE2EPushPullCycle(t *testing.T) { t.Fatalf("dial: %v", err) } t.Cleanup(func() { _ = conn.Close() }) - return New(conn) + c := NewWithAuth(conn, creds, signer) + if err := c.EnsureAuth(context.Background()); err != nil { + t.Fatalf("EnsureAuth: %v", err) + } + return c } ctx := context.Background() diff --git a/cmd/sgard/main.go b/cmd/sgard/main.go index b7fe1bc..311d50d 100644 --- a/cmd/sgard/main.go +++ b/cmd/sgard/main.go @@ -1,12 +1,16 @@ package main import ( + "context" "fmt" "os" "path/filepath" "strings" + "github.com/kisom/sgard/client" "github.com/spf13/cobra" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" ) var ( @@ -47,6 +51,41 @@ func resolveRemote() (string, error) { return "", fmt.Errorf("no remote configured; use --remote, SGARD_REMOTE, or create %s/remote", repoFlag) } +// dialRemote creates a gRPC client with token-based auth and auto-renewal. +func dialRemote(ctx context.Context) (*client.Client, func(), error) { + addr, err := resolveRemote() + if err != nil { + return nil, nil, err + } + + signer, err := client.LoadSigner(sshKeyFlag) + if err != nil { + return nil, nil, err + } + + cachedToken := client.LoadCachedToken() + creds := client.NewTokenCredentials(cachedToken) + + conn, err := grpc.NewClient(addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithPerRPCCredentials(creds), + ) + if err != nil { + return nil, nil, fmt.Errorf("connecting to %s: %w", addr, err) + } + + c := client.NewWithAuth(conn, creds, signer) + + // Ensure we have a valid token before proceeding. + if err := c.EnsureAuth(ctx); err != nil { + _ = conn.Close() + return nil, nil, fmt.Errorf("authentication: %w", err) + } + + cleanup := func() { _ = conn.Close() } + return c, cleanup, nil +} + func main() { rootCmd.PersistentFlags().StringVar(&repoFlag, "repo", defaultRepo(), "path to sgard repository") rootCmd.PersistentFlags().StringVar(&remoteFlag, "remote", "", "gRPC server address (host:port)") diff --git a/cmd/sgard/prune.go b/cmd/sgard/prune.go index 420a986..3bc26a5 100644 --- a/cmd/sgard/prune.go +++ b/cmd/sgard/prune.go @@ -4,11 +4,8 @@ import ( "context" "fmt" - "github.com/kisom/sgard/client" "github.com/kisom/sgard/garden" "github.com/spf13/cobra" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) var pruneCmd = &cobra.Command{ @@ -19,7 +16,7 @@ var pruneCmd = &cobra.Command{ addr, _ := resolveRemote() if addr != "" { - return pruneRemote(addr) + return pruneRemote() } return pruneLocal() }, @@ -40,24 +37,16 @@ func pruneLocal() error { return nil } -func pruneRemote(addr string) error { - signer, err := client.LoadSigner(sshKeyFlag) +func pruneRemote() error { + ctx := context.Background() + + c, cleanup, err := dialRemote(ctx) if err != nil { return err } + defer cleanup() - creds := client.NewSSHCredentials(signer) - conn, err := grpc.NewClient(addr, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithPerRPCCredentials(creds), - ) - if err != nil { - return fmt.Errorf("connecting to %s: %w", addr, err) - } - defer func() { _ = conn.Close() }() - - c := client.New(conn) - removed, err := c.Prune(context.Background()) + removed, err := c.Prune(ctx) if err != nil { return err } diff --git a/cmd/sgard/pull.go b/cmd/sgard/pull.go index 5ffb9af..9e135d2 100644 --- a/cmd/sgard/pull.go +++ b/cmd/sgard/pull.go @@ -4,44 +4,28 @@ import ( "context" "fmt" - "github.com/kisom/sgard/client" "github.com/kisom/sgard/garden" "github.com/spf13/cobra" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) var pullCmd = &cobra.Command{ Use: "pull", Short: "Pull checkpoint from remote server", RunE: func(cmd *cobra.Command, args []string) error { - addr, err := resolveRemote() - if err != nil { - return err - } + ctx := context.Background() g, err := garden.Open(repoFlag) if err != nil { return err } - signer, err := client.LoadSigner(sshKeyFlag) + c, cleanup, err := dialRemote(ctx) if err != nil { return err } + defer cleanup() - creds := client.NewSSHCredentials(signer) - conn, err := grpc.NewClient(addr, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithPerRPCCredentials(creds), - ) - if err != nil { - return fmt.Errorf("connecting to %s: %w", addr, err) - } - defer func() { _ = conn.Close() }() - - c := client.New(conn) - pulled, err := c.Pull(context.Background(), g) + pulled, err := c.Pull(ctx, g) if err != nil { return err } diff --git a/cmd/sgard/push.go b/cmd/sgard/push.go index ea17ec5..ff586cc 100644 --- a/cmd/sgard/push.go +++ b/cmd/sgard/push.go @@ -8,41 +8,26 @@ import ( "github.com/kisom/sgard/client" "github.com/kisom/sgard/garden" "github.com/spf13/cobra" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) var pushCmd = &cobra.Command{ Use: "push", Short: "Push local checkpoint to remote server", RunE: func(cmd *cobra.Command, args []string) error { - addr, err := resolveRemote() - if err != nil { - return err - } + ctx := context.Background() g, err := garden.Open(repoFlag) if err != nil { return err } - signer, err := client.LoadSigner(sshKeyFlag) + c, cleanup, err := dialRemote(ctx) if err != nil { return err } + defer cleanup() - creds := client.NewSSHCredentials(signer) - conn, err := grpc.NewClient(addr, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithPerRPCCredentials(creds), - ) - if err != nil { - return fmt.Errorf("connecting to %s: %w", addr, err) - } - defer func() { _ = conn.Close() }() - - c := client.New(conn) - pushed, err := c.Push(context.Background(), g) + pushed, err := c.Push(ctx, g) if errors.Is(err, client.ErrServerNewer) { fmt.Println("Server is newer; run sgard pull instead.") return nil diff --git a/cmd/sgardd/main.go b/cmd/sgardd/main.go index f98dad0..e9ea261 100644 --- a/cmd/sgardd/main.go +++ b/cmd/sgardd/main.go @@ -29,9 +29,10 @@ var rootCmd = &cobra.Command{ } var opts []grpc.ServerOption + var srvInstance *server.Server if authKeysPath != "" { - auth, err := server.NewAuthInterceptor(authKeysPath) + auth, err := server.NewAuthInterceptor(authKeysPath, repoPath) if err != nil { return fmt.Errorf("loading authorized keys: %w", err) } @@ -39,13 +40,15 @@ var rootCmd = &cobra.Command{ grpc.UnaryInterceptor(auth.UnaryInterceptor()), grpc.StreamInterceptor(auth.StreamInterceptor()), ) + srvInstance = server.NewWithAuth(g, auth) fmt.Printf("Auth enabled: %s\n", authKeysPath) } else { + srvInstance = server.New(g) fmt.Println("WARNING: no --authorized-keys specified, running without authentication") } srv := grpc.NewServer(opts...) - sgardpb.RegisterGardenSyncServer(srv, server.New(g)) + sgardpb.RegisterGardenSyncServer(srv, srvInstance) lis, err := net.Listen("tcp", listenAddr) if err != nil { diff --git a/go.mod b/go.mod index 0083bcb..92a8687 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/kisom/sgard go 1.25.7 require ( + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jonboulle/clockwork v0.5.0 // indirect github.com/spf13/cobra v1.10.2 // indirect diff --git a/go.sum b/go.sum index 89ff90f..5ad7c2d 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I= diff --git a/proto/sgard/v1/sgard.proto b/proto/sgard/v1/sgard.proto index f25453b..40a6c5b 100644 --- a/proto/sgard/v1/sgard.proto +++ b/proto/sgard/v1/sgard.proto @@ -82,8 +82,32 @@ message PruneResponse { int32 blobs_removed = 1; } +// Auth messages. + +message AuthenticateRequest { + bytes nonce = 1; // 32-byte nonce (server-provided or client-generated) + int64 timestamp = 2; // Unix seconds + bytes signature = 3; // SSH signature over (nonce || timestamp) + string public_key = 4; // SSH public key in authorized_keys format +} + +message AuthenticateResponse { + string token = 1; // JWT valid for 30 days +} + +// ReauthChallenge is embedded in Unauthenticated error details when a +// token is expired but was previously valid. The client signs this +// challenge to obtain a new token without generating its own nonce. +message ReauthChallenge { + bytes nonce = 1; // server-generated 32-byte nonce + int64 timestamp = 2; // server's current Unix timestamp +} + // GardenSync is the sgard remote sync service. service GardenSync { + // Authenticate exchanges an SSH-signed challenge for a JWT token. + rpc Authenticate(AuthenticateRequest) returns (AuthenticateResponse); + // Push flow: send manifest, then stream missing blobs. rpc PushManifest(PushManifestRequest) returns (PushManifestResponse); rpc PushBlobs(stream PushBlobsRequest) returns (PushBlobsResponse); diff --git a/server/auth.go b/server/auth.go index c3db067..7858d46 100644 --- a/server/auth.go +++ b/server/auth.go @@ -3,13 +3,14 @@ package server import ( "context" "crypto/rand" - "encoding/base64" "fmt" "os" - "strconv" + "path/filepath" "strings" "time" + "github.com/golang-jwt/jwt/v5" + "github.com/kisom/sgard/sgardpb" "golang.org/x/crypto/ssh" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -18,25 +19,21 @@ import ( ) const ( - // Metadata keys for auth. - metaNonce = "x-sgard-auth-nonce" - metaTimestamp = "x-sgard-auth-timestamp" - metaSignature = "x-sgard-auth-signature" - metaPubkey = "x-sgard-auth-pubkey" - - // authWindow is how far the timestamp can deviate from server time. + metaToken = "x-sgard-auth-token" authWindow = 5 * time.Minute + tokenTTL = 30 * 24 * time.Hour // 30 days ) -// AuthInterceptor verifies SSH key signatures on gRPC requests. +// AuthInterceptor verifies JWT tokens or SSH key signatures on gRPC requests. type AuthInterceptor struct { authorizedKeys map[string]ssh.PublicKey // keyed by fingerprint + jwtKey []byte // HMAC-SHA256 signing key } -// NewAuthInterceptor creates an interceptor from an authorized_keys file. -// The file uses the same format as ~/.ssh/authorized_keys. -func NewAuthInterceptor(path string) (*AuthInterceptor, error) { - data, err := os.ReadFile(path) +// NewAuthInterceptor creates an interceptor from an authorized_keys file +// and a repository path (for the JWT secret key). +func NewAuthInterceptor(authorizedKeysPath, repoPath string) (*AuthInterceptor, error) { + data, err := os.ReadFile(authorizedKeysPath) if err != nil { return nil, fmt.Errorf("reading authorized keys: %w", err) } @@ -54,26 +51,35 @@ func NewAuthInterceptor(path string) (*AuthInterceptor, error) { } if len(keys) == 0 { - return nil, fmt.Errorf("no valid keys found in %s", path) + return nil, fmt.Errorf("no valid keys found in %s", authorizedKeysPath) } - return &AuthInterceptor{authorizedKeys: keys}, nil + jwtKey, err := loadOrGenerateJWTKey(repoPath) + if err != nil { + return nil, fmt.Errorf("loading JWT key: %w", err) + } + + return &AuthInterceptor{authorizedKeys: keys, jwtKey: jwtKey}, nil } -// NewAuthInterceptorFromKeys creates an interceptor from pre-parsed keys. -// Intended for testing. -func NewAuthInterceptorFromKeys(keys []ssh.PublicKey) *AuthInterceptor { +// NewAuthInterceptorFromKeys creates an interceptor from pre-parsed keys +// and a provided JWT key. Intended for testing. +func NewAuthInterceptorFromKeys(keys []ssh.PublicKey, jwtKey []byte) *AuthInterceptor { m := make(map[string]ssh.PublicKey, len(keys)) for _, k := range keys { m[ssh.FingerprintSHA256(k)] = k } - return &AuthInterceptor{authorizedKeys: m} + return &AuthInterceptor{authorizedKeys: m, jwtKey: jwtKey} } // UnaryInterceptor returns a gRPC unary server interceptor. func (a *AuthInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - if err := a.verify(ctx); err != nil { + // Authenticate RPC is exempt from auth — it's how you get a token. + if strings.HasSuffix(info.FullMethod, "/Authenticate") { + return handler(ctx, req) + } + if err := a.verifyToken(ctx); err != nil { return nil, err } return handler(ctx, req) @@ -83,76 +89,161 @@ func (a *AuthInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor { // StreamInterceptor returns a gRPC stream server interceptor. func (a *AuthInterceptor) StreamInterceptor() grpc.StreamServerInterceptor { return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - if err := a.verify(ss.Context()); err != nil { + if err := a.verifyToken(ss.Context()); err != nil { return err } return handler(srv, ss) } } -func (a *AuthInterceptor) verify(ctx context.Context) error { +// Authenticate verifies an SSH-signed challenge and issues a JWT. +func (a *AuthInterceptor) Authenticate(_ context.Context, req *sgardpb.AuthenticateRequest) (*sgardpb.AuthenticateResponse, error) { + pubkeyStr := req.GetPublicKey() + pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubkeyStr)) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "invalid public key") + } + + fp := ssh.FingerprintSHA256(pubkey) + authorized, ok := a.authorizedKeys[fp] + if !ok { + return nil, status.Errorf(codes.PermissionDenied, "key %s not authorized", fp) + } + + // Verify timestamp window. + tsUnix := req.GetTimestamp() + ts := time.Unix(tsUnix, 0) + if time.Since(ts).Abs() > authWindow { + return nil, status.Error(codes.Unauthenticated, "timestamp outside allowed window") + } + + // Verify signature. + payload := buildPayload(req.GetNonce(), tsUnix) + sig, err := parseSSHSignature(req.GetSignature()) + if err != nil { + return nil, status.Error(codes.Unauthenticated, "invalid signature format") + } + if err := authorized.Verify(payload, sig); err != nil { + return nil, status.Error(codes.Unauthenticated, "signature verification failed") + } + + // Issue JWT. + token, err := a.issueToken(fp) + if err != nil { + return nil, status.Errorf(codes.Internal, "issuing token: %v", err) + } + + return &sgardpb.AuthenticateResponse{Token: token}, nil +} + +func (a *AuthInterceptor) verifyToken(ctx context.Context) error { md, ok := metadata.FromIncomingContext(ctx) if !ok { return status.Error(codes.Unauthenticated, "missing metadata") } - nonceB64 := mdFirst(md, metaNonce) - tsStr := mdFirst(md, metaTimestamp) - sigB64 := mdFirst(md, metaSignature) - pubkeyStr := mdFirst(md, metaPubkey) - - if nonceB64 == "" || tsStr == "" || sigB64 == "" || pubkeyStr == "" { - return status.Error(codes.Unauthenticated, "missing auth metadata fields") + tokenStr := mdFirst(md, metaToken) + if tokenStr == "" { + return status.Error(codes.Unauthenticated, "missing auth token") } - // Parse timestamp and check window. - tsUnix, err := strconv.ParseInt(tsStr, 10, 64) - if err != nil { - return status.Error(codes.Unauthenticated, "invalid timestamp") - } - ts := time.Unix(tsUnix, 0) - if time.Since(ts).Abs() > authWindow { - return status.Error(codes.Unauthenticated, "timestamp outside allowed window") + claims := &jwt.RegisteredClaims{} + token, err := jwt.ParseWithClaims(tokenStr, claims, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return a.jwtKey, nil + }) + + if err != nil || !token.Valid { + // Check if the token is expired but otherwise valid. + if a.isExpiredButValid(tokenStr, claims) { + return a.reauthError() + } + return status.Error(codes.Unauthenticated, "invalid token") } - // Parse public key and check authorization. - pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubkeyStr)) - if err != nil { - return status.Error(codes.Unauthenticated, "invalid public key") - } - fp := ssh.FingerprintSHA256(pubkey) - authorized, ok := a.authorizedKeys[fp] - if !ok { - return status.Errorf(codes.PermissionDenied, "key %s not authorized", fp) - } - - // Decode nonce and signature. - nonce, err := base64.StdEncoding.DecodeString(nonceB64) - if err != nil { - return status.Error(codes.Unauthenticated, "invalid nonce encoding") - } - - sigBytes, err := base64.StdEncoding.DecodeString(sigB64) - if err != nil { - return status.Error(codes.Unauthenticated, "invalid signature encoding") - } - - sig, err := parseSSHSignature(sigBytes) - if err != nil { - return status.Error(codes.Unauthenticated, "invalid signature format") - } - - // Build the signed payload: nonce + timestamp bytes. - payload := buildPayload(nonce, tsUnix) - - // Verify. - if err := authorized.Verify(payload, sig); err != nil { - return status.Error(codes.Unauthenticated, "signature verification failed") + // Verify the fingerprint is still authorized. + fp := claims.Subject + if _, ok := a.authorizedKeys[fp]; !ok { + return status.Errorf(codes.PermissionDenied, "key %s no longer authorized", fp) } return nil } +// isExpiredButValid checks if a token has a valid signature and the +// fingerprint is still in authorized_keys, but the token is expired. +func (a *AuthInterceptor) isExpiredButValid(tokenStr string, claims *jwt.RegisteredClaims) bool { + // Re-parse without time validation. + reClaims := &jwt.RegisteredClaims{} + _, err := jwt.ParseWithClaims(tokenStr, reClaims, func(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method") + } + return a.jwtKey, nil + }, jwt.WithoutClaimsValidation()) + if err != nil { + return false + } + + fp := reClaims.Subject + _, authorized := a.authorizedKeys[fp] + return authorized +} + +// reauthError returns an Unauthenticated error with a ReauthChallenge +// embedded in the error details. +func (a *AuthInterceptor) reauthError() error { + nonce := make([]byte, 32) + if _, err := rand.Read(nonce); err != nil { + return status.Error(codes.Internal, "generating reauth nonce") + } + + challenge := &sgardpb.ReauthChallenge{ + Nonce: nonce, + Timestamp: time.Now().Unix(), + } + + st, err := status.New(codes.Unauthenticated, "token expired"). + WithDetails(challenge) + if err != nil { + return status.Error(codes.Unauthenticated, "token expired") + } + return st.Err() +} + +func (a *AuthInterceptor) issueToken(fingerprint string) (string, error) { + now := time.Now() + claims := &jwt.RegisteredClaims{ + Subject: fingerprint, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(tokenTTL)), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(a.jwtKey) +} + +func loadOrGenerateJWTKey(repoPath string) ([]byte, error) { + keyPath := filepath.Join(repoPath, "jwt.key") + + data, err := os.ReadFile(keyPath) + if err == nil && len(data) >= 32 { + return data[:32], nil + } + + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("generating JWT key: %w", err) + } + + if err := os.WriteFile(keyPath, key, 0o600); err != nil { + return nil, fmt.Errorf("writing JWT key: %w", err) + } + + return key, nil +} + // buildPayload constructs the message that is signed: nonce || timestamp (big-endian int64). func buildPayload(nonce []byte, tsUnix int64) []byte { payload := make([]byte, len(nonce)+8) @@ -187,7 +278,6 @@ func parseSSHSignature(data []byte) (*ssh.Signature, error) { return nil, fmt.Errorf("signature too short") } - // SSH signature wire format: string format, string blob formatLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) if 4+formatLen > len(data) { return nil, fmt.Errorf("invalid format length") @@ -204,8 +294,6 @@ func parseSSHSignature(data []byte) (*ssh.Signature, error) { } blob := rest[4 : 4+blobLen] - _ = strings.TrimSpace(format) // ensure format is clean - return &ssh.Signature{ Format: format, Blob: blob, diff --git a/server/auth_test.go b/server/auth_test.go index bb08628..3e5c365 100644 --- a/server/auth_test.go +++ b/server/auth_test.go @@ -4,15 +4,18 @@ import ( "context" "crypto/ed25519" "crypto/rand" - "encoding/base64" - "strconv" + "strings" "testing" "time" + "github.com/golang-jwt/jwt/v5" + "github.com/kisom/sgard/sgardpb" "golang.org/x/crypto/ssh" "google.golang.org/grpc/metadata" ) +var testJWTKey = []byte("test-jwt-secret-key-32-bytes!!") + func generateTestKey(t *testing.T) (ssh.Signer, ssh.PublicKey) { t.Helper() _, priv, err := ed25519.GenerateKey(rand.Reader) @@ -26,94 +29,118 @@ func generateTestKey(t *testing.T) (ssh.Signer, ssh.PublicKey) { return signer, signer.PublicKey() } -func signedContext(t *testing.T, signer ssh.Signer) context.Context { - t.Helper() +func TestAuthenticateAndVerifyToken(t *testing.T) { + signer, pubkey := generateTestKey(t) + auth := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey}, testJWTKey) - nonce, err := GenerateNonce() - if err != nil { - t.Fatalf("generating nonce: %v", err) - } + // Generate a signed challenge. + nonce, _ := GenerateNonce() tsUnix := time.Now().Unix() payload := buildPayload(nonce, tsUnix) - sig, err := signer.Sign(rand.Reader, payload) if err != nil { t.Fatalf("signing: %v", err) } - pubkeyStr := string(ssh.MarshalAuthorizedKey(signer.PublicKey())) + pubkeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) - md := metadata.New(map[string]string{ - metaNonce: base64.StdEncoding.EncodeToString(nonce), - metaTimestamp: strconv.FormatInt(tsUnix, 10), - metaSignature: base64.StdEncoding.EncodeToString(ssh.Marshal(sig)), - metaPubkey: pubkeyStr, + // Call Authenticate. + resp, err := auth.Authenticate(context.Background(), &sgardpb.AuthenticateRequest{ + Nonce: nonce, + Timestamp: tsUnix, + Signature: ssh.Marshal(sig), + PublicKey: pubkeyStr, }) - return metadata.NewIncomingContext(context.Background(), md) -} + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + if resp.Token == "" { + t.Fatal("expected non-empty token") + } -func TestAuthVerifyValid(t *testing.T) { - signer, pubkey := generateTestKey(t) - interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey}) - - ctx := signedContext(t, signer) - if err := interceptor.verify(ctx); err != nil { - t.Fatalf("verify should succeed: %v", err) + // Use the token in metadata. + md := metadata.New(map[string]string{metaToken: resp.Token}) + ctx := metadata.NewIncomingContext(context.Background(), md) + if err := auth.verifyToken(ctx); err != nil { + t.Fatalf("verifyToken should accept valid token: %v", err) } } -func TestAuthRejectUnauthenticated(t *testing.T) { +func TestRejectMissingToken(t *testing.T) { _, pubkey := generateTestKey(t) - interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey}) + auth := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey}, testJWTKey) // No metadata at all. - ctx := context.Background() - if err := interceptor.verify(ctx); err == nil { - t.Fatal("verify should reject missing metadata") + if err := auth.verifyToken(context.Background()); err == nil { + t.Fatal("should reject missing metadata") + } + + // Empty metadata. + md := metadata.New(nil) + ctx := metadata.NewIncomingContext(context.Background(), md) + if err := auth.verifyToken(ctx); err == nil { + t.Fatal("should reject missing token") } } -func TestAuthRejectUnauthorizedKey(t *testing.T) { +func TestRejectUnauthorizedKey(t *testing.T) { signer1, _ := generateTestKey(t) _, pubkey2 := generateTestKey(t) - // Interceptor knows key2 but request is signed by key1. - interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey2}) + // Auth only knows pubkey2, but we authenticate with signer1. + auth := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey2}, testJWTKey) - ctx := signedContext(t, signer1) - if err := interceptor.verify(ctx); err == nil { - t.Fatal("verify should reject unauthorized key") - } -} - -func TestAuthRejectExpiredTimestamp(t *testing.T) { - signer, pubkey := generateTestKey(t) - interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey}) - - nonce, err := GenerateNonce() - if err != nil { - t.Fatalf("generating nonce: %v", err) - } - // Timestamp 10 minutes ago — outside the 5-minute window. - tsUnix := time.Now().Add(-10 * time.Minute).Unix() + nonce, _ := GenerateNonce() + tsUnix := time.Now().Unix() payload := buildPayload(nonce, tsUnix) + sig, _ := signer1.Sign(rand.Reader, payload) + pubkeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer1.PublicKey()))) - sig, err := signer.Sign(rand.Reader, payload) - if err != nil { - t.Fatalf("signing: %v", err) - } - - pubkeyStr := string(ssh.MarshalAuthorizedKey(signer.PublicKey())) - - md := metadata.New(map[string]string{ - metaNonce: base64.StdEncoding.EncodeToString(nonce), - metaTimestamp: strconv.FormatInt(tsUnix, 10), - metaSignature: base64.StdEncoding.EncodeToString(ssh.Marshal(sig)), - metaPubkey: pubkeyStr, + _, err := auth.Authenticate(context.Background(), &sgardpb.AuthenticateRequest{ + Nonce: nonce, + Timestamp: tsUnix, + Signature: ssh.Marshal(sig), + PublicKey: pubkeyStr, }) - ctx := metadata.NewIncomingContext(context.Background(), md) - - if err := interceptor.verify(ctx); err == nil { - t.Fatal("verify should reject expired timestamp") + if err == nil { + t.Fatal("should reject unauthorized key") } } + +func TestExpiredTokenReturnsChallenge(t *testing.T) { + signer, pubkey := generateTestKey(t) + auth := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey}, testJWTKey) + + // Issue a token, then manually create an expired one. + fp := ssh.FingerprintSHA256(signer.PublicKey()) + expiredToken, err := auth.issueExpiredToken(fp) + if err != nil { + t.Fatalf("issuing expired token: %v", err) + } + + md := metadata.New(map[string]string{metaToken: expiredToken}) + ctx := metadata.NewIncomingContext(context.Background(), md) + err = auth.verifyToken(ctx) + if err == nil { + t.Fatal("should reject expired token") + } + + // The error should contain a ReauthChallenge in its details. + // We can't easily extract it here without the client helper, + // but verify the error message indicates expiry. + if !strings.Contains(err.Error(), "expired") { + t.Errorf("error should mention expiry, got: %v", err) + } +} + +// issueExpiredToken is a test helper that creates an already-expired JWT. +func (a *AuthInterceptor) issueExpiredToken(fingerprint string) (string, error) { + past := time.Now().Add(-time.Hour) + claims := &jwt.RegisteredClaims{ + Subject: fingerprint, + IssuedAt: jwt.NewNumericDate(past.Add(-24 * time.Hour)), + ExpiresAt: jwt.NewNumericDate(past), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(a.jwtKey) +} diff --git a/server/server.go b/server/server.go index aedbefe..c0c12dd 100644 --- a/server/server.go +++ b/server/server.go @@ -24,6 +24,7 @@ type Server struct { garden *garden.Garden mu sync.RWMutex pendingManifest *manifest.Manifest + auth *AuthInterceptor // nil if auth is disabled } // New creates a new Server backed by the given Garden. @@ -31,6 +32,19 @@ func New(g *garden.Garden) *Server { return &Server{garden: g} } +// NewWithAuth creates a new Server with authentication enabled. +func NewWithAuth(g *garden.Garden, auth *AuthInterceptor) *Server { + return &Server{garden: g, auth: auth} +} + +// Authenticate handles the auth RPC by delegating to the AuthInterceptor. +func (s *Server) Authenticate(ctx context.Context, req *sgardpb.AuthenticateRequest) (*sgardpb.AuthenticateResponse, error) { + if s.auth == nil { + return nil, status.Error(codes.Unimplemented, "authentication not configured") + } + return s.auth.Authenticate(ctx, req) +} + // PushManifest compares the client manifest against the server manifest and // decides whether to accept, reject, or report up-to-date. func (s *Server) PushManifest(_ context.Context, req *sgardpb.PushManifestRequest) (*sgardpb.PushManifestResponse, error) { diff --git a/sgardpb/sgard.pb.go b/sgardpb/sgard.pb.go index 0b78166..4f5f32c 100644 --- a/sgardpb/sgard.pb.go +++ b/sgardpb/sgard.pb.go @@ -730,6 +730,173 @@ func (x *PruneResponse) GetBlobsRemoved() int32 { return 0 } +type AuthenticateRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Nonce []byte `protobuf:"bytes,1,opt,name=nonce,proto3" json:"nonce,omitempty"` // 32-byte nonce (server-provided or client-generated) + Timestamp int64 `protobuf:"varint,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"` // Unix seconds + Signature []byte `protobuf:"bytes,3,opt,name=signature,proto3" json:"signature,omitempty"` // SSH signature over (nonce || timestamp) + PublicKey string `protobuf:"bytes,4,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"` // SSH public key in authorized_keys format + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthenticateRequest) Reset() { + *x = AuthenticateRequest{} + mi := &file_sgard_v1_sgard_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthenticateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthenticateRequest) ProtoMessage() {} + +func (x *AuthenticateRequest) ProtoReflect() protoreflect.Message { + mi := &file_sgard_v1_sgard_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthenticateRequest.ProtoReflect.Descriptor instead. +func (*AuthenticateRequest) Descriptor() ([]byte, []int) { + return file_sgard_v1_sgard_proto_rawDescGZIP(), []int{13} +} + +func (x *AuthenticateRequest) GetNonce() []byte { + if x != nil { + return x.Nonce + } + return nil +} + +func (x *AuthenticateRequest) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +func (x *AuthenticateRequest) GetSignature() []byte { + if x != nil { + return x.Signature + } + return nil +} + +func (x *AuthenticateRequest) GetPublicKey() string { + if x != nil { + return x.PublicKey + } + return "" +} + +type AuthenticateResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` // JWT valid for 30 days + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AuthenticateResponse) Reset() { + *x = AuthenticateResponse{} + mi := &file_sgard_v1_sgard_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AuthenticateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AuthenticateResponse) ProtoMessage() {} + +func (x *AuthenticateResponse) ProtoReflect() protoreflect.Message { + mi := &file_sgard_v1_sgard_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AuthenticateResponse.ProtoReflect.Descriptor instead. +func (*AuthenticateResponse) Descriptor() ([]byte, []int) { + return file_sgard_v1_sgard_proto_rawDescGZIP(), []int{14} +} + +func (x *AuthenticateResponse) GetToken() string { + if x != nil { + return x.Token + } + return "" +} + +// ReauthChallenge is embedded in Unauthenticated error details when a +// token is expired but was previously valid. The client signs this +// challenge to obtain a new token without generating its own nonce. +type ReauthChallenge struct { + state protoimpl.MessageState `protogen:"open.v1"` + Nonce []byte `protobuf:"bytes,1,opt,name=nonce,proto3" json:"nonce,omitempty"` // server-generated 32-byte nonce + Timestamp int64 `protobuf:"varint,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"` // server's current Unix timestamp + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReauthChallenge) Reset() { + *x = ReauthChallenge{} + mi := &file_sgard_v1_sgard_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReauthChallenge) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReauthChallenge) ProtoMessage() {} + +func (x *ReauthChallenge) ProtoReflect() protoreflect.Message { + mi := &file_sgard_v1_sgard_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReauthChallenge.ProtoReflect.Descriptor instead. +func (*ReauthChallenge) Descriptor() ([]byte, []int) { + return file_sgard_v1_sgard_proto_rawDescGZIP(), []int{15} +} + +func (x *ReauthChallenge) GetNonce() []byte { + if x != nil { + return x.Nonce + } + return nil +} + +func (x *ReauthChallenge) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + var File_sgard_v1_sgard_proto protoreflect.FileDescriptor const file_sgard_v1_sgard_proto_rawDesc = "" + @@ -776,9 +943,21 @@ const file_sgard_v1_sgard_proto_rawDesc = "" + "\x05chunk\x18\x01 \x01(\v2\x13.sgard.v1.BlobChunkR\x05chunk\"\x0e\n" + "\fPruneRequest\"4\n" + "\rPruneResponse\x12#\n" + - "\rblobs_removed\x18\x01 \x01(\x05R\fblobsRemoved2\xf4\x02\n" + + "\rblobs_removed\x18\x01 \x01(\x05R\fblobsRemoved\"\x86\x01\n" + + "\x13AuthenticateRequest\x12\x14\n" + + "\x05nonce\x18\x01 \x01(\fR\x05nonce\x12\x1c\n" + + "\ttimestamp\x18\x02 \x01(\x03R\ttimestamp\x12\x1c\n" + + "\tsignature\x18\x03 \x01(\fR\tsignature\x12\x1d\n" + + "\n" + + "public_key\x18\x04 \x01(\tR\tpublicKey\",\n" + + "\x14AuthenticateResponse\x12\x14\n" + + "\x05token\x18\x01 \x01(\tR\x05token\"E\n" + + "\x0fReauthChallenge\x12\x14\n" + + "\x05nonce\x18\x01 \x01(\fR\x05nonce\x12\x1c\n" + + "\ttimestamp\x18\x02 \x01(\x03R\ttimestamp2\xc3\x03\n" + "\n" + "GardenSync\x12M\n" + + "\fAuthenticate\x12\x1d.sgard.v1.AuthenticateRequest\x1a\x1e.sgard.v1.AuthenticateResponse\x12M\n" + "\fPushManifest\x12\x1d.sgard.v1.PushManifestRequest\x1a\x1e.sgard.v1.PushManifestResponse\x12F\n" + "\tPushBlobs\x12\x1a.sgard.v1.PushBlobsRequest\x1a\x1b.sgard.v1.PushBlobsResponse(\x01\x12M\n" + "\fPullManifest\x12\x1d.sgard.v1.PullManifestRequest\x1a\x1e.sgard.v1.PullManifestResponse\x12F\n" + @@ -798,7 +977,7 @@ func file_sgard_v1_sgard_proto_rawDescGZIP() []byte { } var file_sgard_v1_sgard_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_sgard_v1_sgard_proto_msgTypes = make([]protoimpl.MessageInfo, 13) +var file_sgard_v1_sgard_proto_msgTypes = make([]protoimpl.MessageInfo, 16) var file_sgard_v1_sgard_proto_goTypes = []any{ (PushManifestResponse_Decision)(0), // 0: sgard.v1.PushManifestResponse.Decision (*ManifestEntry)(nil), // 1: sgard.v1.ManifestEntry @@ -814,31 +993,36 @@ var file_sgard_v1_sgard_proto_goTypes = []any{ (*PullBlobsResponse)(nil), // 11: sgard.v1.PullBlobsResponse (*PruneRequest)(nil), // 12: sgard.v1.PruneRequest (*PruneResponse)(nil), // 13: sgard.v1.PruneResponse - (*timestamppb.Timestamp)(nil), // 14: google.protobuf.Timestamp + (*AuthenticateRequest)(nil), // 14: sgard.v1.AuthenticateRequest + (*AuthenticateResponse)(nil), // 15: sgard.v1.AuthenticateResponse + (*ReauthChallenge)(nil), // 16: sgard.v1.ReauthChallenge + (*timestamppb.Timestamp)(nil), // 17: google.protobuf.Timestamp } var file_sgard_v1_sgard_proto_depIdxs = []int32{ - 14, // 0: sgard.v1.ManifestEntry.updated:type_name -> google.protobuf.Timestamp - 14, // 1: sgard.v1.Manifest.created:type_name -> google.protobuf.Timestamp - 14, // 2: sgard.v1.Manifest.updated:type_name -> google.protobuf.Timestamp + 17, // 0: sgard.v1.ManifestEntry.updated:type_name -> google.protobuf.Timestamp + 17, // 1: sgard.v1.Manifest.created:type_name -> google.protobuf.Timestamp + 17, // 2: sgard.v1.Manifest.updated:type_name -> google.protobuf.Timestamp 1, // 3: sgard.v1.Manifest.files:type_name -> sgard.v1.ManifestEntry 2, // 4: sgard.v1.PushManifestRequest.manifest:type_name -> sgard.v1.Manifest 0, // 5: sgard.v1.PushManifestResponse.decision:type_name -> sgard.v1.PushManifestResponse.Decision - 14, // 6: sgard.v1.PushManifestResponse.server_updated:type_name -> google.protobuf.Timestamp + 17, // 6: sgard.v1.PushManifestResponse.server_updated:type_name -> google.protobuf.Timestamp 3, // 7: sgard.v1.PushBlobsRequest.chunk:type_name -> sgard.v1.BlobChunk 2, // 8: sgard.v1.PullManifestResponse.manifest:type_name -> sgard.v1.Manifest 3, // 9: sgard.v1.PullBlobsResponse.chunk:type_name -> sgard.v1.BlobChunk - 4, // 10: sgard.v1.GardenSync.PushManifest:input_type -> sgard.v1.PushManifestRequest - 6, // 11: sgard.v1.GardenSync.PushBlobs:input_type -> sgard.v1.PushBlobsRequest - 8, // 12: sgard.v1.GardenSync.PullManifest:input_type -> sgard.v1.PullManifestRequest - 10, // 13: sgard.v1.GardenSync.PullBlobs:input_type -> sgard.v1.PullBlobsRequest - 12, // 14: sgard.v1.GardenSync.Prune:input_type -> sgard.v1.PruneRequest - 5, // 15: sgard.v1.GardenSync.PushManifest:output_type -> sgard.v1.PushManifestResponse - 7, // 16: sgard.v1.GardenSync.PushBlobs:output_type -> sgard.v1.PushBlobsResponse - 9, // 17: sgard.v1.GardenSync.PullManifest:output_type -> sgard.v1.PullManifestResponse - 11, // 18: sgard.v1.GardenSync.PullBlobs:output_type -> sgard.v1.PullBlobsResponse - 13, // 19: sgard.v1.GardenSync.Prune:output_type -> sgard.v1.PruneResponse - 15, // [15:20] is the sub-list for method output_type - 10, // [10:15] is the sub-list for method input_type + 14, // 10: sgard.v1.GardenSync.Authenticate:input_type -> sgard.v1.AuthenticateRequest + 4, // 11: sgard.v1.GardenSync.PushManifest:input_type -> sgard.v1.PushManifestRequest + 6, // 12: sgard.v1.GardenSync.PushBlobs:input_type -> sgard.v1.PushBlobsRequest + 8, // 13: sgard.v1.GardenSync.PullManifest:input_type -> sgard.v1.PullManifestRequest + 10, // 14: sgard.v1.GardenSync.PullBlobs:input_type -> sgard.v1.PullBlobsRequest + 12, // 15: sgard.v1.GardenSync.Prune:input_type -> sgard.v1.PruneRequest + 15, // 16: sgard.v1.GardenSync.Authenticate:output_type -> sgard.v1.AuthenticateResponse + 5, // 17: sgard.v1.GardenSync.PushManifest:output_type -> sgard.v1.PushManifestResponse + 7, // 18: sgard.v1.GardenSync.PushBlobs:output_type -> sgard.v1.PushBlobsResponse + 9, // 19: sgard.v1.GardenSync.PullManifest:output_type -> sgard.v1.PullManifestResponse + 11, // 20: sgard.v1.GardenSync.PullBlobs:output_type -> sgard.v1.PullBlobsResponse + 13, // 21: sgard.v1.GardenSync.Prune:output_type -> sgard.v1.PruneResponse + 16, // [16:22] is the sub-list for method output_type + 10, // [10:16] is the sub-list for method input_type 10, // [10:10] is the sub-list for extension type_name 10, // [10:10] is the sub-list for extension extendee 0, // [0:10] is the sub-list for field type_name @@ -855,7 +1039,7 @@ func file_sgard_v1_sgard_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_sgard_v1_sgard_proto_rawDesc), len(file_sgard_v1_sgard_proto_rawDesc)), NumEnums: 1, - NumMessages: 13, + NumMessages: 16, NumExtensions: 0, NumServices: 1, }, diff --git a/sgardpb/sgard_grpc.pb.go b/sgardpb/sgard_grpc.pb.go index 62d2477..3dd5976 100644 --- a/sgardpb/sgard_grpc.pb.go +++ b/sgardpb/sgard_grpc.pb.go @@ -19,6 +19,7 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( + GardenSync_Authenticate_FullMethodName = "/sgard.v1.GardenSync/Authenticate" GardenSync_PushManifest_FullMethodName = "/sgard.v1.GardenSync/PushManifest" GardenSync_PushBlobs_FullMethodName = "/sgard.v1.GardenSync/PushBlobs" GardenSync_PullManifest_FullMethodName = "/sgard.v1.GardenSync/PullManifest" @@ -32,6 +33,8 @@ const ( // // GardenSync is the sgard remote sync service. type GardenSyncClient interface { + // Authenticate exchanges an SSH-signed challenge for a JWT token. + Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*AuthenticateResponse, error) // Push flow: send manifest, then stream missing blobs. PushManifest(ctx context.Context, in *PushManifestRequest, opts ...grpc.CallOption) (*PushManifestResponse, error) PushBlobs(ctx context.Context, opts ...grpc.CallOption) (grpc.ClientStreamingClient[PushBlobsRequest, PushBlobsResponse], error) @@ -50,6 +53,16 @@ func NewGardenSyncClient(cc grpc.ClientConnInterface) GardenSyncClient { return &gardenSyncClient{cc} } +func (c *gardenSyncClient) Authenticate(ctx context.Context, in *AuthenticateRequest, opts ...grpc.CallOption) (*AuthenticateResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AuthenticateResponse) + err := c.cc.Invoke(ctx, GardenSync_Authenticate_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *gardenSyncClient) PushManifest(ctx context.Context, in *PushManifestRequest, opts ...grpc.CallOption) (*PushManifestResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(PushManifestResponse) @@ -118,6 +131,8 @@ func (c *gardenSyncClient) Prune(ctx context.Context, in *PruneRequest, opts ... // // GardenSync is the sgard remote sync service. type GardenSyncServer interface { + // Authenticate exchanges an SSH-signed challenge for a JWT token. + Authenticate(context.Context, *AuthenticateRequest) (*AuthenticateResponse, error) // Push flow: send manifest, then stream missing blobs. PushManifest(context.Context, *PushManifestRequest) (*PushManifestResponse, error) PushBlobs(grpc.ClientStreamingServer[PushBlobsRequest, PushBlobsResponse]) error @@ -136,6 +151,9 @@ type GardenSyncServer interface { // pointer dereference when methods are called. type UnimplementedGardenSyncServer struct{} +func (UnimplementedGardenSyncServer) Authenticate(context.Context, *AuthenticateRequest) (*AuthenticateResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Authenticate not implemented") +} func (UnimplementedGardenSyncServer) PushManifest(context.Context, *PushManifestRequest) (*PushManifestResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method PushManifest not implemented") } @@ -172,6 +190,24 @@ func RegisterGardenSyncServer(s grpc.ServiceRegistrar, srv GardenSyncServer) { s.RegisterService(&GardenSync_ServiceDesc, srv) } +func _GardenSync_Authenticate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AuthenticateRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(GardenSyncServer).Authenticate(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: GardenSync_Authenticate_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(GardenSyncServer).Authenticate(ctx, req.(*AuthenticateRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _GardenSync_PushManifest_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(PushManifestRequest) if err := dec(in); err != nil { @@ -251,6 +287,10 @@ var GardenSync_ServiceDesc = grpc.ServiceDesc{ ServiceName: "sgard.v1.GardenSync", HandlerType: (*GardenSyncServer)(nil), Methods: []grpc.MethodDesc{ + { + MethodName: "Authenticate", + Handler: _GardenSync_Authenticate_Handler, + }, { MethodName: "PushManifest", Handler: _GardenSync_PushManifest_Handler,