Implement JWT token auth with transparent auto-renewal.
Replace per-call SSH signing with a two-layer auth system: Server: AuthInterceptor verifies JWT tokens (HMAC-SHA256 signed with repo-local jwt.key). Authenticate RPC accepts SSH-signed challenges and issues 30-day JWTs. Expired-but-valid tokens return a ReauthChallenge in error details (server-provided nonce for fast re-auth). Authenticate RPC is exempt from token requirement. Client: TokenCredentials replaces SSHCredentials as the primary PerRPCCredentials. NewWithAuth creates clients with auto-renewal — EnsureAuth obtains initial token, retryOnAuth catches Unauthenticated errors and re-authenticates transparently. Token cached at $XDG_STATE_HOME/sgard/token (0600). CLI: dialRemote() helper handles token loading, connection setup, and initial auth. Push/pull/prune commands simplified to use it. Proto: Added Authenticate RPC, AuthenticateRequest/Response, ReauthChallenge messages. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
236
server/auth.go
236
server/auth.go
@@ -3,13 +3,14 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/kisom/sgard/sgardpb"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -18,25 +19,21 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// Metadata keys for auth.
|
||||
metaNonce = "x-sgard-auth-nonce"
|
||||
metaTimestamp = "x-sgard-auth-timestamp"
|
||||
metaSignature = "x-sgard-auth-signature"
|
||||
metaPubkey = "x-sgard-auth-pubkey"
|
||||
|
||||
// authWindow is how far the timestamp can deviate from server time.
|
||||
metaToken = "x-sgard-auth-token"
|
||||
authWindow = 5 * time.Minute
|
||||
tokenTTL = 30 * 24 * time.Hour // 30 days
|
||||
)
|
||||
|
||||
// AuthInterceptor verifies SSH key signatures on gRPC requests.
|
||||
// AuthInterceptor verifies JWT tokens or SSH key signatures on gRPC requests.
|
||||
type AuthInterceptor struct {
|
||||
authorizedKeys map[string]ssh.PublicKey // keyed by fingerprint
|
||||
jwtKey []byte // HMAC-SHA256 signing key
|
||||
}
|
||||
|
||||
// NewAuthInterceptor creates an interceptor from an authorized_keys file.
|
||||
// The file uses the same format as ~/.ssh/authorized_keys.
|
||||
func NewAuthInterceptor(path string) (*AuthInterceptor, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
// NewAuthInterceptor creates an interceptor from an authorized_keys file
|
||||
// and a repository path (for the JWT secret key).
|
||||
func NewAuthInterceptor(authorizedKeysPath, repoPath string) (*AuthInterceptor, error) {
|
||||
data, err := os.ReadFile(authorizedKeysPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading authorized keys: %w", err)
|
||||
}
|
||||
@@ -54,26 +51,35 @@ func NewAuthInterceptor(path string) (*AuthInterceptor, error) {
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil, fmt.Errorf("no valid keys found in %s", path)
|
||||
return nil, fmt.Errorf("no valid keys found in %s", authorizedKeysPath)
|
||||
}
|
||||
|
||||
return &AuthInterceptor{authorizedKeys: keys}, nil
|
||||
jwtKey, err := loadOrGenerateJWTKey(repoPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading JWT key: %w", err)
|
||||
}
|
||||
|
||||
return &AuthInterceptor{authorizedKeys: keys, jwtKey: jwtKey}, nil
|
||||
}
|
||||
|
||||
// NewAuthInterceptorFromKeys creates an interceptor from pre-parsed keys.
|
||||
// Intended for testing.
|
||||
func NewAuthInterceptorFromKeys(keys []ssh.PublicKey) *AuthInterceptor {
|
||||
// NewAuthInterceptorFromKeys creates an interceptor from pre-parsed keys
|
||||
// and a provided JWT key. Intended for testing.
|
||||
func NewAuthInterceptorFromKeys(keys []ssh.PublicKey, jwtKey []byte) *AuthInterceptor {
|
||||
m := make(map[string]ssh.PublicKey, len(keys))
|
||||
for _, k := range keys {
|
||||
m[ssh.FingerprintSHA256(k)] = k
|
||||
}
|
||||
return &AuthInterceptor{authorizedKeys: m}
|
||||
return &AuthInterceptor{authorizedKeys: m, jwtKey: jwtKey}
|
||||
}
|
||||
|
||||
// UnaryInterceptor returns a gRPC unary server interceptor.
|
||||
func (a *AuthInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
if err := a.verify(ctx); err != nil {
|
||||
// Authenticate RPC is exempt from auth — it's how you get a token.
|
||||
if strings.HasSuffix(info.FullMethod, "/Authenticate") {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
if err := a.verifyToken(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return handler(ctx, req)
|
||||
@@ -83,76 +89,161 @@ func (a *AuthInterceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
|
||||
// StreamInterceptor returns a gRPC stream server interceptor.
|
||||
func (a *AuthInterceptor) StreamInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
if err := a.verify(ss.Context()); err != nil {
|
||||
if err := a.verifyToken(ss.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
return handler(srv, ss)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthInterceptor) verify(ctx context.Context) error {
|
||||
// Authenticate verifies an SSH-signed challenge and issues a JWT.
|
||||
func (a *AuthInterceptor) Authenticate(_ context.Context, req *sgardpb.AuthenticateRequest) (*sgardpb.AuthenticateResponse, error) {
|
||||
pubkeyStr := req.GetPublicKey()
|
||||
pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubkeyStr))
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid public key")
|
||||
}
|
||||
|
||||
fp := ssh.FingerprintSHA256(pubkey)
|
||||
authorized, ok := a.authorizedKeys[fp]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "key %s not authorized", fp)
|
||||
}
|
||||
|
||||
// Verify timestamp window.
|
||||
tsUnix := req.GetTimestamp()
|
||||
ts := time.Unix(tsUnix, 0)
|
||||
if time.Since(ts).Abs() > authWindow {
|
||||
return nil, status.Error(codes.Unauthenticated, "timestamp outside allowed window")
|
||||
}
|
||||
|
||||
// Verify signature.
|
||||
payload := buildPayload(req.GetNonce(), tsUnix)
|
||||
sig, err := parseSSHSignature(req.GetSignature())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid signature format")
|
||||
}
|
||||
if err := authorized.Verify(payload, sig); err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "signature verification failed")
|
||||
}
|
||||
|
||||
// Issue JWT.
|
||||
token, err := a.issueToken(fp)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "issuing token: %v", err)
|
||||
}
|
||||
|
||||
return &sgardpb.AuthenticateResponse{Token: token}, nil
|
||||
}
|
||||
|
||||
func (a *AuthInterceptor) verifyToken(ctx context.Context) error {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, "missing metadata")
|
||||
}
|
||||
|
||||
nonceB64 := mdFirst(md, metaNonce)
|
||||
tsStr := mdFirst(md, metaTimestamp)
|
||||
sigB64 := mdFirst(md, metaSignature)
|
||||
pubkeyStr := mdFirst(md, metaPubkey)
|
||||
|
||||
if nonceB64 == "" || tsStr == "" || sigB64 == "" || pubkeyStr == "" {
|
||||
return status.Error(codes.Unauthenticated, "missing auth metadata fields")
|
||||
tokenStr := mdFirst(md, metaToken)
|
||||
if tokenStr == "" {
|
||||
return status.Error(codes.Unauthenticated, "missing auth token")
|
||||
}
|
||||
|
||||
// Parse timestamp and check window.
|
||||
tsUnix, err := strconv.ParseInt(tsStr, 10, 64)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid timestamp")
|
||||
}
|
||||
ts := time.Unix(tsUnix, 0)
|
||||
if time.Since(ts).Abs() > authWindow {
|
||||
return status.Error(codes.Unauthenticated, "timestamp outside allowed window")
|
||||
claims := &jwt.RegisteredClaims{}
|
||||
token, err := jwt.ParseWithClaims(tokenStr, claims, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return a.jwtKey, nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
// Check if the token is expired but otherwise valid.
|
||||
if a.isExpiredButValid(tokenStr, claims) {
|
||||
return a.reauthError()
|
||||
}
|
||||
return status.Error(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
// Parse public key and check authorization.
|
||||
pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(pubkeyStr))
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid public key")
|
||||
}
|
||||
fp := ssh.FingerprintSHA256(pubkey)
|
||||
authorized, ok := a.authorizedKeys[fp]
|
||||
if !ok {
|
||||
return status.Errorf(codes.PermissionDenied, "key %s not authorized", fp)
|
||||
}
|
||||
|
||||
// Decode nonce and signature.
|
||||
nonce, err := base64.StdEncoding.DecodeString(nonceB64)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid nonce encoding")
|
||||
}
|
||||
|
||||
sigBytes, err := base64.StdEncoding.DecodeString(sigB64)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid signature encoding")
|
||||
}
|
||||
|
||||
sig, err := parseSSHSignature(sigBytes)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "invalid signature format")
|
||||
}
|
||||
|
||||
// Build the signed payload: nonce + timestamp bytes.
|
||||
payload := buildPayload(nonce, tsUnix)
|
||||
|
||||
// Verify.
|
||||
if err := authorized.Verify(payload, sig); err != nil {
|
||||
return status.Error(codes.Unauthenticated, "signature verification failed")
|
||||
// Verify the fingerprint is still authorized.
|
||||
fp := claims.Subject
|
||||
if _, ok := a.authorizedKeys[fp]; !ok {
|
||||
return status.Errorf(codes.PermissionDenied, "key %s no longer authorized", fp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isExpiredButValid checks if a token has a valid signature and the
|
||||
// fingerprint is still in authorized_keys, but the token is expired.
|
||||
func (a *AuthInterceptor) isExpiredButValid(tokenStr string, claims *jwt.RegisteredClaims) bool {
|
||||
// Re-parse without time validation.
|
||||
reClaims := &jwt.RegisteredClaims{}
|
||||
_, err := jwt.ParseWithClaims(tokenStr, reClaims, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method")
|
||||
}
|
||||
return a.jwtKey, nil
|
||||
}, jwt.WithoutClaimsValidation())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
fp := reClaims.Subject
|
||||
_, authorized := a.authorizedKeys[fp]
|
||||
return authorized
|
||||
}
|
||||
|
||||
// reauthError returns an Unauthenticated error with a ReauthChallenge
|
||||
// embedded in the error details.
|
||||
func (a *AuthInterceptor) reauthError() error {
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return status.Error(codes.Internal, "generating reauth nonce")
|
||||
}
|
||||
|
||||
challenge := &sgardpb.ReauthChallenge{
|
||||
Nonce: nonce,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
st, err := status.New(codes.Unauthenticated, "token expired").
|
||||
WithDetails(challenge)
|
||||
if err != nil {
|
||||
return status.Error(codes.Unauthenticated, "token expired")
|
||||
}
|
||||
return st.Err()
|
||||
}
|
||||
|
||||
func (a *AuthInterceptor) issueToken(fingerprint string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := &jwt.RegisteredClaims{
|
||||
Subject: fingerprint,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(tokenTTL)),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(a.jwtKey)
|
||||
}
|
||||
|
||||
func loadOrGenerateJWTKey(repoPath string) ([]byte, error) {
|
||||
keyPath := filepath.Join(repoPath, "jwt.key")
|
||||
|
||||
data, err := os.ReadFile(keyPath)
|
||||
if err == nil && len(data) >= 32 {
|
||||
return data[:32], nil
|
||||
}
|
||||
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, fmt.Errorf("generating JWT key: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(keyPath, key, 0o600); err != nil {
|
||||
return nil, fmt.Errorf("writing JWT key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// buildPayload constructs the message that is signed: nonce || timestamp (big-endian int64).
|
||||
func buildPayload(nonce []byte, tsUnix int64) []byte {
|
||||
payload := make([]byte, len(nonce)+8)
|
||||
@@ -187,7 +278,6 @@ func parseSSHSignature(data []byte) (*ssh.Signature, error) {
|
||||
return nil, fmt.Errorf("signature too short")
|
||||
}
|
||||
|
||||
// SSH signature wire format: string format, string blob
|
||||
formatLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||
if 4+formatLen > len(data) {
|
||||
return nil, fmt.Errorf("invalid format length")
|
||||
@@ -204,8 +294,6 @@ func parseSSHSignature(data []byte) (*ssh.Signature, error) {
|
||||
}
|
||||
blob := rest[4 : 4+blobLen]
|
||||
|
||||
_ = strings.TrimSpace(format) // ensure format is clean
|
||||
|
||||
return &ssh.Signature{
|
||||
Format: format,
|
||||
Blob: blob,
|
||||
|
||||
@@ -4,15 +4,18 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/kisom/sgard/sgardpb"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
var testJWTKey = []byte("test-jwt-secret-key-32-bytes!!")
|
||||
|
||||
func generateTestKey(t *testing.T) (ssh.Signer, ssh.PublicKey) {
|
||||
t.Helper()
|
||||
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
@@ -26,94 +29,118 @@ func generateTestKey(t *testing.T) (ssh.Signer, ssh.PublicKey) {
|
||||
return signer, signer.PublicKey()
|
||||
}
|
||||
|
||||
func signedContext(t *testing.T, signer ssh.Signer) context.Context {
|
||||
t.Helper()
|
||||
func TestAuthenticateAndVerifyToken(t *testing.T) {
|
||||
signer, pubkey := generateTestKey(t)
|
||||
auth := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey}, testJWTKey)
|
||||
|
||||
nonce, err := GenerateNonce()
|
||||
if err != nil {
|
||||
t.Fatalf("generating nonce: %v", err)
|
||||
}
|
||||
// Generate a signed challenge.
|
||||
nonce, _ := GenerateNonce()
|
||||
tsUnix := time.Now().Unix()
|
||||
payload := buildPayload(nonce, tsUnix)
|
||||
|
||||
sig, err := signer.Sign(rand.Reader, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("signing: %v", err)
|
||||
}
|
||||
|
||||
pubkeyStr := string(ssh.MarshalAuthorizedKey(signer.PublicKey()))
|
||||
pubkeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey())))
|
||||
|
||||
md := metadata.New(map[string]string{
|
||||
metaNonce: base64.StdEncoding.EncodeToString(nonce),
|
||||
metaTimestamp: strconv.FormatInt(tsUnix, 10),
|
||||
metaSignature: base64.StdEncoding.EncodeToString(ssh.Marshal(sig)),
|
||||
metaPubkey: pubkeyStr,
|
||||
// Call Authenticate.
|
||||
resp, err := auth.Authenticate(context.Background(), &sgardpb.AuthenticateRequest{
|
||||
Nonce: nonce,
|
||||
Timestamp: tsUnix,
|
||||
Signature: ssh.Marshal(sig),
|
||||
PublicKey: pubkeyStr,
|
||||
})
|
||||
return metadata.NewIncomingContext(context.Background(), md)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Authenticate: %v", err)
|
||||
}
|
||||
if resp.Token == "" {
|
||||
t.Fatal("expected non-empty token")
|
||||
}
|
||||
|
||||
func TestAuthVerifyValid(t *testing.T) {
|
||||
signer, pubkey := generateTestKey(t)
|
||||
interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey})
|
||||
|
||||
ctx := signedContext(t, signer)
|
||||
if err := interceptor.verify(ctx); err != nil {
|
||||
t.Fatalf("verify should succeed: %v", err)
|
||||
// Use the token in metadata.
|
||||
md := metadata.New(map[string]string{metaToken: resp.Token})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
if err := auth.verifyToken(ctx); err != nil {
|
||||
t.Fatalf("verifyToken should accept valid token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRejectUnauthenticated(t *testing.T) {
|
||||
func TestRejectMissingToken(t *testing.T) {
|
||||
_, pubkey := generateTestKey(t)
|
||||
interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey})
|
||||
auth := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey}, testJWTKey)
|
||||
|
||||
// No metadata at all.
|
||||
ctx := context.Background()
|
||||
if err := interceptor.verify(ctx); err == nil {
|
||||
t.Fatal("verify should reject missing metadata")
|
||||
if err := auth.verifyToken(context.Background()); err == nil {
|
||||
t.Fatal("should reject missing metadata")
|
||||
}
|
||||
|
||||
// Empty metadata.
|
||||
md := metadata.New(nil)
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
if err := auth.verifyToken(ctx); err == nil {
|
||||
t.Fatal("should reject missing token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRejectUnauthorizedKey(t *testing.T) {
|
||||
func TestRejectUnauthorizedKey(t *testing.T) {
|
||||
signer1, _ := generateTestKey(t)
|
||||
_, pubkey2 := generateTestKey(t)
|
||||
|
||||
// Interceptor knows key2 but request is signed by key1.
|
||||
interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey2})
|
||||
// Auth only knows pubkey2, but we authenticate with signer1.
|
||||
auth := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey2}, testJWTKey)
|
||||
|
||||
ctx := signedContext(t, signer1)
|
||||
if err := interceptor.verify(ctx); err == nil {
|
||||
t.Fatal("verify should reject unauthorized key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRejectExpiredTimestamp(t *testing.T) {
|
||||
signer, pubkey := generateTestKey(t)
|
||||
interceptor := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey})
|
||||
|
||||
nonce, err := GenerateNonce()
|
||||
if err != nil {
|
||||
t.Fatalf("generating nonce: %v", err)
|
||||
}
|
||||
// Timestamp 10 minutes ago — outside the 5-minute window.
|
||||
tsUnix := time.Now().Add(-10 * time.Minute).Unix()
|
||||
nonce, _ := GenerateNonce()
|
||||
tsUnix := time.Now().Unix()
|
||||
payload := buildPayload(nonce, tsUnix)
|
||||
sig, _ := signer1.Sign(rand.Reader, payload)
|
||||
pubkeyStr := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer1.PublicKey())))
|
||||
|
||||
sig, err := signer.Sign(rand.Reader, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("signing: %v", err)
|
||||
}
|
||||
|
||||
pubkeyStr := string(ssh.MarshalAuthorizedKey(signer.PublicKey()))
|
||||
|
||||
md := metadata.New(map[string]string{
|
||||
metaNonce: base64.StdEncoding.EncodeToString(nonce),
|
||||
metaTimestamp: strconv.FormatInt(tsUnix, 10),
|
||||
metaSignature: base64.StdEncoding.EncodeToString(ssh.Marshal(sig)),
|
||||
metaPubkey: pubkeyStr,
|
||||
_, err := auth.Authenticate(context.Background(), &sgardpb.AuthenticateRequest{
|
||||
Nonce: nonce,
|
||||
Timestamp: tsUnix,
|
||||
Signature: ssh.Marshal(sig),
|
||||
PublicKey: pubkeyStr,
|
||||
})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
|
||||
if err := interceptor.verify(ctx); err == nil {
|
||||
t.Fatal("verify should reject expired timestamp")
|
||||
if err == nil {
|
||||
t.Fatal("should reject unauthorized key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiredTokenReturnsChallenge(t *testing.T) {
|
||||
signer, pubkey := generateTestKey(t)
|
||||
auth := NewAuthInterceptorFromKeys([]ssh.PublicKey{pubkey}, testJWTKey)
|
||||
|
||||
// Issue a token, then manually create an expired one.
|
||||
fp := ssh.FingerprintSHA256(signer.PublicKey())
|
||||
expiredToken, err := auth.issueExpiredToken(fp)
|
||||
if err != nil {
|
||||
t.Fatalf("issuing expired token: %v", err)
|
||||
}
|
||||
|
||||
md := metadata.New(map[string]string{metaToken: expiredToken})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
err = auth.verifyToken(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("should reject expired token")
|
||||
}
|
||||
|
||||
// The error should contain a ReauthChallenge in its details.
|
||||
// We can't easily extract it here without the client helper,
|
||||
// but verify the error message indicates expiry.
|
||||
if !strings.Contains(err.Error(), "expired") {
|
||||
t.Errorf("error should mention expiry, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// issueExpiredToken is a test helper that creates an already-expired JWT.
|
||||
func (a *AuthInterceptor) issueExpiredToken(fingerprint string) (string, error) {
|
||||
past := time.Now().Add(-time.Hour)
|
||||
claims := &jwt.RegisteredClaims{
|
||||
Subject: fingerprint,
|
||||
IssuedAt: jwt.NewNumericDate(past.Add(-24 * time.Hour)),
|
||||
ExpiresAt: jwt.NewNumericDate(past),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(a.jwtKey)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ type Server struct {
|
||||
garden *garden.Garden
|
||||
mu sync.RWMutex
|
||||
pendingManifest *manifest.Manifest
|
||||
auth *AuthInterceptor // nil if auth is disabled
|
||||
}
|
||||
|
||||
// New creates a new Server backed by the given Garden.
|
||||
@@ -31,6 +32,19 @@ func New(g *garden.Garden) *Server {
|
||||
return &Server{garden: g}
|
||||
}
|
||||
|
||||
// NewWithAuth creates a new Server with authentication enabled.
|
||||
func NewWithAuth(g *garden.Garden, auth *AuthInterceptor) *Server {
|
||||
return &Server{garden: g, auth: auth}
|
||||
}
|
||||
|
||||
// Authenticate handles the auth RPC by delegating to the AuthInterceptor.
|
||||
func (s *Server) Authenticate(ctx context.Context, req *sgardpb.AuthenticateRequest) (*sgardpb.AuthenticateResponse, error) {
|
||||
if s.auth == nil {
|
||||
return nil, status.Error(codes.Unimplemented, "authentication not configured")
|
||||
}
|
||||
return s.auth.Authenticate(ctx, req)
|
||||
}
|
||||
|
||||
// PushManifest compares the client manifest against the server manifest and
|
||||
// decides whether to accept, reject, or report up-to-date.
|
||||
func (s *Server) PushManifest(_ context.Context, req *sgardpb.PushManifestRequest) (*sgardpb.PushManifestResponse, error) {
|
||||
|
||||
Reference in New Issue
Block a user