Implement JWT token auth with transparent auto-renewal.

Replace per-call SSH signing with a two-layer auth system:

Server: AuthInterceptor verifies JWT tokens (HMAC-SHA256 signed with
repo-local jwt.key). Authenticate RPC accepts SSH-signed challenges
and issues 30-day JWTs. Expired-but-valid tokens return a
ReauthChallenge in error details (server-provided nonce for fast
re-auth). Authenticate RPC is exempt from token requirement.

Client: TokenCredentials replaces SSHCredentials as the primary
PerRPCCredentials. NewWithAuth creates clients with auto-renewal —
EnsureAuth obtains initial token, retryOnAuth catches Unauthenticated
errors and re-authenticates transparently. Token cached at
$XDG_STATE_HOME/sgard/token (0600).

CLI: dialRemote() helper handles token loading, connection setup,
and initial auth. Push/pull/prune commands simplified to use it.

Proto: Added Authenticate RPC, AuthenticateRequest/Response,
ReauthChallenge messages.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-24 00:52:16 -07:00
parent b7b1b27064
commit edef642025
18 changed files with 890 additions and 283 deletions

View File

@@ -3,13 +3,14 @@ package server
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"os"
"strconv"
"path/filepath"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/kisom/sgard/sgardpb"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@@ -18,25 +19,21 @@ import (
)
const (
// Metadata keys for auth.
metaNonce = "x-sgard-auth-nonce"
metaTimestamp = "x-sgard-auth-timestamp"
metaSignature = "x-sgard-auth-signature"
metaPubkey = "x-sgard-auth-pubkey"
// authWindow is how far the timestamp can deviate from server time.
metaToken = "x-sgard-auth-token"
authWindow = 5 * time.Minute
tokenTTL = 30 * 24 * time.Hour // 30 days
)
// AuthInterceptor verifies SSH key signatures on gRPC requests.
// AuthInterceptor verifies JWT tokens or SSH key signatures on gRPC requests.
type AuthInterceptor struct {
authorizedKeys map[string]ssh.PublicKey // keyed by fingerprint
jwtKey []byte // HMAC-SHA256 signing key
}
// NewAuthInterceptor creates an interceptor from an authorized_keys file.
// The file uses the same format as ~/.ssh/authorized_keys.
func NewAuthInterceptor(path string) (*AuthInterceptor, error) {
data, err := os.ReadFile(path)
// NewAuthInterceptor creates an interceptor from an authorized_keys file
// and a repository path (for the JWT secret key).
func NewAuthInterceptor(authorizedKeysPath, repoPath string) (*AuthInterceptor, error) {
data, err := os.ReadFile(authorizedKeysPath)
if err != nil {
return nil, fmt.Errorf("reading authorized keys: %w", err)
}
@@ -54,26 +51,35 @@ func NewAuthInterceptor(path string) (*AuthInterceptor, error) {
}
if len(keys) == 0 {
return nil, fmt.Errorf("no valid keys found in %s", path)
return nil, fmt.Errorf("no valid keys found in %s", authorizedKeysPath)
}
return &AuthInterceptor{authorizedKeys: keys}, nil
jwtKey, err := loadOrGenerateJWTKey(repoPath)
if err != nil {
return nil, fmt.Errorf("loading JWT key: %w", err)
}
return &AuthInterceptor{authorizedKeys: keys, jwtKey: jwtKey}, nil
}
// NewAuthInterceptorFromKeys creates an interceptor from pre-parsed keys.
// Intended for testing.
func NewAuthInterceptorFromKeys(keys []ssh.PublicKey) *AuthInterceptor {
// NewAuthInterceptorFromKeys creates an interceptor from pre-parsed keys
// and a provided JWT key. Intended for testing.
func NewAuthInterceptorFromKeys(keys []ssh.PublicKey, jwtKey []byte) *AuthInterceptor {
m := make(map[string]ssh.PublicKey, len(keys))
for _, k := range keys {
m[ssh.FingerprintSHA256(k)] = k
}
return &AuthInterceptor{authorizedKeys: m}
return &AuthInterceptor{authorizedKeys: m, jwtKey: jwtKey}
}
// UnaryInterceptor returns a gRPC unary server interceptor.
func (a *AuthInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if err := a.verify(ctx); err != nil {
// Authenticate RPC is exempt from auth — it's how you get a token.
if strings.HasSuffix(info.FullMethod, "/Authenticate") {
return handler(ctx, req)
}
if err := a.verifyToken(ctx); err != nil {
return nil, err
}
return handler(ctx, req)
@@ -83,76 +89,161 @@ func (a *AuthInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
// StreamInterceptor returns a gRPC stream server interceptor.
func (a *AuthInterceptor) StreamInterceptor() grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := a.verify(ss.Context()); err != nil {
if err := a.verifyToken(ss.Context()); err != nil {
return err
}
return handler(srv, ss)
}
}
func (a *AuthInterceptor) verify(ctx context.Context) error {
// Authenticate verifies an SSH-signed challenge and issues a JWT.
func (a *AuthInterceptor) Authenticate(_ context.Context, req *sgardpb.AuthenticateRequest) (*sgardpb.AuthenticateResponse, error) {
pubkeyStr := req.GetPublicKey()
pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubkeyStr))
if err != nil {
return nil, status.Error(codes.Unauthenticated, "invalid public key")
}
fp := ssh.FingerprintSHA256(pubkey)
authorized, ok := a.authorizedKeys[fp]
if !ok {
return nil, status.Errorf(codes.PermissionDenied, "key %s not authorized", fp)
}
// Verify timestamp window.
tsUnix := req.GetTimestamp()
ts := time.Unix(tsUnix, 0)
if time.Since(ts).Abs() > authWindow {
return nil, status.Error(codes.Unauthenticated, "timestamp outside allowed window")
}
// Verify signature.
payload := buildPayload(req.GetNonce(), tsUnix)
sig, err := parseSSHSignature(req.GetSignature())
if err != nil {
return nil, status.Error(codes.Unauthenticated, "invalid signature format")
}
if err := authorized.Verify(payload, sig); err != nil {
return nil, status.Error(codes.Unauthenticated, "signature verification failed")
}
// Issue JWT.
token, err := a.issueToken(fp)
if err != nil {
return nil, status.Errorf(codes.Internal, "issuing token: %v", err)
}
return &sgardpb.AuthenticateResponse{Token: token}, nil
}
func (a *AuthInterceptor) verifyToken(ctx context.Context) error {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Error(codes.Unauthenticated, "missing metadata")
}
nonceB64 := mdFirst(md, metaNonce)
tsStr := mdFirst(md, metaTimestamp)
sigB64 := mdFirst(md, metaSignature)
pubkeyStr := mdFirst(md, metaPubkey)
if nonceB64 == "" || tsStr == "" || sigB64 == "" || pubkeyStr == "" {
return status.Error(codes.Unauthenticated, "missing auth metadata fields")
tokenStr := mdFirst(md, metaToken)
if tokenStr == "" {
return status.Error(codes.Unauthenticated, "missing auth token")
}
// Parse timestamp and check window.
tsUnix, err := strconv.ParseInt(tsStr, 10, 64)
if err != nil {
return status.Error(codes.Unauthenticated, "invalid timestamp")
}
ts := time.Unix(tsUnix, 0)
if time.Since(ts).Abs() > authWindow {
return status.Error(codes.Unauthenticated, "timestamp outside allowed window")
claims := &jwt.RegisteredClaims{}
token, err := jwt.ParseWithClaims(tokenStr, claims, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return a.jwtKey, nil
})
if err != nil || !token.Valid {
// Check if the token is expired but otherwise valid.
if a.isExpiredButValid(tokenStr, claims) {
return a.reauthError()
}
return status.Error(codes.Unauthenticated, "invalid token")
}
// Parse public key and check authorization.
pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubkeyStr))
if err != nil {
return status.Error(codes.Unauthenticated, "invalid public key")
}
fp := ssh.FingerprintSHA256(pubkey)
authorized, ok := a.authorizedKeys[fp]
if !ok {
return status.Errorf(codes.PermissionDenied, "key %s not authorized", fp)
}
// Decode nonce and signature.
nonce, err := base64.StdEncoding.DecodeString(nonceB64)
if err != nil {
return status.Error(codes.Unauthenticated, "invalid nonce encoding")
}
sigBytes, err := base64.StdEncoding.DecodeString(sigB64)
if err != nil {
return status.Error(codes.Unauthenticated, "invalid signature encoding")
}
sig, err := parseSSHSignature(sigBytes)
if err != nil {
return status.Error(codes.Unauthenticated, "invalid signature format")
}
// Build the signed payload: nonce + timestamp bytes.
payload := buildPayload(nonce, tsUnix)
// Verify.
if err := authorized.Verify(payload, sig); err != nil {
return status.Error(codes.Unauthenticated, "signature verification failed")
// Verify the fingerprint is still authorized.
fp := claims.Subject
if _, ok := a.authorizedKeys[fp]; !ok {
return status.Errorf(codes.PermissionDenied, "key %s no longer authorized", fp)
}
return nil
}
// isExpiredButValid checks if a token has a valid signature and the
// fingerprint is still in authorized_keys, but the token is expired.
func (a *AuthInterceptor) isExpiredButValid(tokenStr string, claims *jwt.RegisteredClaims) bool {
// Re-parse without time validation.
reClaims := &jwt.RegisteredClaims{}
_, err := jwt.ParseWithClaims(tokenStr, reClaims, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method")
}
return a.jwtKey, nil
}, jwt.WithoutClaimsValidation())
if err != nil {
return false
}
fp := reClaims.Subject
_, authorized := a.authorizedKeys[fp]
return authorized
}
// reauthError returns an Unauthenticated error with a ReauthChallenge
// embedded in the error details.
func (a *AuthInterceptor) reauthError() error {
nonce := make([]byte, 32)
if _, err := rand.Read(nonce); err != nil {
return status.Error(codes.Internal, "generating reauth nonce")
}
challenge := &sgardpb.ReauthChallenge{
Nonce: nonce,
Timestamp: time.Now().Unix(),
}
st, err := status.New(codes.Unauthenticated, "token expired").
WithDetails(challenge)
if err != nil {
return status.Error(codes.Unauthenticated, "token expired")
}
return st.Err()
}
func (a *AuthInterceptor) issueToken(fingerprint string) (string, error) {
now := time.Now()
claims := &jwt.RegisteredClaims{
Subject: fingerprint,
IssuedAt: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(now.Add(tokenTTL)),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(a.jwtKey)
}
func loadOrGenerateJWTKey(repoPath string) ([]byte, error) {
keyPath := filepath.Join(repoPath, "jwt.key")
data, err := os.ReadFile(keyPath)
if err == nil && len(data) >= 32 {
return data[:32], nil
}
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("generating JWT key: %w", err)
}
if err := os.WriteFile(keyPath, key, 0o600); err != nil {
return nil, fmt.Errorf("writing JWT key: %w", err)
}
return key, nil
}
// buildPayload constructs the message that is signed: nonce || timestamp (big-endian int64).
func buildPayload(nonce []byte, tsUnix int64) []byte {
payload := make([]byte, len(nonce)+8)
@@ -187,7 +278,6 @@ func parseSSHSignature(data []byte) (*ssh.Signature, error) {
return nil, fmt.Errorf("signature too short")
}
// SSH signature wire format: string format, string blob
formatLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if 4+formatLen > len(data) {
return nil, fmt.Errorf("invalid format length")
@@ -204,8 +294,6 @@ func parseSSHSignature(data []byte) (*ssh.Signature, error) {
}
blob := rest[4 : 4+blobLen]
_ = strings.TrimSpace(format) // ensure format is clean
return &ssh.Signature{
Format: format,
Blob: blob,