package auth

import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// cacheTTL is the duration cached token validation results are valid.
const cacheTTL = 30 * time.Second

// TokenInfo holds the result of a token validation.
type TokenInfo struct {
	Valid       bool     `json:"valid"`
	Username    string   `json:"username"`
	Roles       []string `json:"roles"`
	AccountType string   `json:"account_type"`
}

// HasRole reports whether the token has the given role.
func (t *TokenInfo) HasRole(role string) bool {
	for _, r := range t.Roles {
		if r == role {
			return true
		}
	}
	return false
}

// TokenValidator validates bearer tokens against MCIAS.
type TokenValidator interface {
	ValidateToken(ctx context.Context, token string) (*TokenInfo, error)
}

// tokenCache stores validated token results with a TTL. Access to the map is
// guarded by mu, so it is safe for concurrent use.
type tokenCache struct {
	mu      sync.RWMutex
	entries map[string]cacheEntry

	// lastSweep is when put last purged expired entries. The zero value
	// makes the first put sweep, which is a no-op on an empty map.
	lastSweep time.Time
}

// cacheEntry is a cached validation result and the instant after which it
// is no longer served.
type cacheEntry struct {
	info      *TokenInfo
	expiresAt time.Time
}

func newTokenCache() *tokenCache {
	return &tokenCache{
		entries: make(map[string]cacheEntry),
	}
}

// get returns the cached result for hash if present and not expired. An
// expired entry is removed and reported as a miss.
func (c *tokenCache) get(hash string) (*TokenInfo, bool) {
	now := time.Now()
	c.mu.RLock()
	entry, ok := c.entries[hash]
	c.mu.RUnlock()
	if !ok {
		return nil, false
	}
	if now.After(entry.expiresAt) {
		c.mu.Lock()
		// Re-check under the write lock: between RUnlock and Lock another
		// goroutine may have stored a fresh entry for this hash, and that
		// entry must not be evicted.
		if cur, ok := c.entries[hash]; ok && now.After(cur.expiresAt) {
			delete(c.entries, hash)
		}
		c.mu.Unlock()
		return nil, false
	}
	return entry.info, true
}

// put caches info under hash for cacheTTL. It also purges expired entries,
// at most once per cacheTTL.
func (c *tokenCache) put(hash string, info *TokenInfo) {
	now := time.Now()
	c.mu.Lock()
	defer c.mu.Unlock()

	// get only evicts an entry when the same hash is looked up again, so
	// tokens that are never re-presented would otherwise accumulate forever.
	// Sweeping once per TTL bounds the map to roughly the entries stored
	// within the last two TTL windows.
	if now.Sub(c.lastSweep) >= cacheTTL {
		for k, e := range c.entries {
			if now.After(e.expiresAt) {
				delete(c.entries, k)
			}
		}
		c.lastSweep = now
	}

	c.entries[hash] = cacheEntry{
		info:      info,
		expiresAt: now.Add(cacheTTL),
	}
}

// newHTTPClient creates an HTTP client with TLS 1.3 minimum and a 10-second
// timeout. If caCertPath is non-empty, the PEM certificates in that file
// become the only trusted roots (the system pool is not consulted);
// otherwise the system roots are used.
func newHTTPClient(caCertPath string) (*http.Client, error) { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS13, } if caCertPath != "" { caCert, err := os.ReadFile(caCertPath) //nolint:gosec // path from trusted config if err != nil { return nil, fmt.Errorf("read CA cert %q: %w", caCertPath, err) } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(caCert) { return nil, fmt.Errorf("parse CA cert %q: no valid certificates found", caCertPath) } tlsConfig.RootCAs = pool } return &http.Client{ Timeout: 10 * time.Second, Transport: &http.Transport{ TLSClientConfig: tlsConfig, }, }, nil } // MCIASValidator validates tokens by calling the MCIAS HTTP endpoint. type MCIASValidator struct { ServerURL string CACertPath string httpClient *http.Client cache *tokenCache } // NewMCIASValidator creates a validator that calls MCIAS at the given URL. // If caCertPath is non-empty, the CA certificate is loaded and used for TLS. func NewMCIASValidator(serverURL, caCertPath string) (*MCIASValidator, error) { client, err := newHTTPClient(caCertPath) if err != nil { return nil, err } return &MCIASValidator{ ServerURL: strings.TrimRight(serverURL, "/"), CACertPath: caCertPath, httpClient: client, cache: newTokenCache(), }, nil } func tokenHash(token string) string { h := sha256.Sum256([]byte(token)) return hex.EncodeToString(h[:]) } // ValidateToken validates a bearer token against MCIAS. Results are cached // for 30 seconds keyed by the SHA-256 hash of the token. 
func (v *MCIASValidator) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) { hash := tokenHash(token) if info, ok := v.cache.get(hash); ok { return info, nil } url := v.ServerURL + "/v1/token/validate" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) if err != nil { return nil, fmt.Errorf("create validate request: %w", err) } req.Header.Set("Authorization", "Bearer "+token) resp, err := v.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("validate token: %w", err) } defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("read validate response: %w", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("validate token: MCIAS returned %d", resp.StatusCode) } var info TokenInfo if err := json.Unmarshal(body, &info); err != nil { return nil, fmt.Errorf("parse validate response: %w", err) } v.cache.put(hash, &info) return &info, nil } // contextKey is an unexported type for context keys in this package. type contextKey struct{} // tokenInfoKey is the context key for TokenInfo. var tokenInfoKey = contextKey{} // ContextWithTokenInfo returns a new context carrying the given TokenInfo. func ContextWithTokenInfo(ctx context.Context, info *TokenInfo) context.Context { return context.WithValue(ctx, tokenInfoKey, info) } // TokenInfoFromContext retrieves TokenInfo from the context, or nil if absent. func TokenInfoFromContext(ctx context.Context) *TokenInfo { info, _ := ctx.Value(tokenInfoKey).(*TokenInfo) return info } // AuthInterceptor returns a gRPC unary server interceptor that validates // bearer tokens and requires the "admin" role. 
func AuthInterceptor(validator TokenValidator) grpc.UnaryServerInterceptor { return func( ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (any, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, status.Error(codes.Unauthenticated, "missing metadata") } authValues := md.Get("authorization") if len(authValues) == 0 { return nil, status.Error(codes.Unauthenticated, "missing authorization header") } authHeader := authValues[0] if !strings.HasPrefix(authHeader, "Bearer ") { return nil, status.Error(codes.Unauthenticated, "malformed authorization header") } token := strings.TrimPrefix(authHeader, "Bearer ") tokenInfo, err := validator.ValidateToken(ctx, token) if err != nil { slog.Error("token validation failed", "method", info.FullMethod, "error", err) return nil, status.Error(codes.Unauthenticated, "token validation failed") } if !tokenInfo.Valid { return nil, status.Error(codes.Unauthenticated, "invalid token") } if !tokenInfo.HasRole("admin") { slog.Warn("permission denied", "method", info.FullMethod, "user", tokenInfo.Username) return nil, status.Error(codes.PermissionDenied, "admin role required") } slog.Info("rpc", "method", info.FullMethod, "user", tokenInfo.Username, "account_type", tokenInfo.AccountType) ctx = ContextWithTokenInfo(ctx, tokenInfo) return handler(ctx, req) } } // Login authenticates with MCIAS and returns a bearer token. 
func Login(serverURL, caCertPath, username, password string) (string, error) { client, err := newHTTPClient(caCertPath) if err != nil { return "", err } serverURL = strings.TrimRight(serverURL, "/") url := serverURL + "/v1/auth/login" payload := struct { Username string `json:"username"` Password string `json:"password"` }{ Username: username, Password: password, } body, err := json.Marshal(payload) //nolint:gosec // intentional login credential payload if err != nil { return "", fmt.Errorf("marshal login request: %w", err) } req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) if err != nil { return "", fmt.Errorf("create login request: %w", err) } req.Header.Set("Content-Type", "application/json") resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("login request: %w", err) } defer func() { _ = resp.Body.Close() }() respBody, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("read login response: %w", err) } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("login failed: MCIAS returned %d: %s", resp.StatusCode, string(respBody)) } var result struct { Token string `json:"token"` } if err := json.Unmarshal(respBody, &result); err != nil { return "", fmt.Errorf("parse login response: %w", err) } if result.Token == "" { return "", fmt.Errorf("login response missing token") } return result.Token, nil } // LoadToken reads a token from the given file path and trims whitespace. func LoadToken(path string) (string, error) { data, err := os.ReadFile(path) //nolint:gosec // path from trusted caller if err != nil { return "", fmt.Errorf("load token from %q: %w", path, err) } return strings.TrimSpace(string(data)), nil } // SaveToken writes a token to the given file with 0600 permissions. // Parent directories are created if they do not exist. 
//
// The token, followed by a newline, is written to a temporary file in the
// same directory, synced, and renamed over path. This replaces an existing
// file atomically. The result is always mode 0600, even if a previous file
// at path had broader permissions (os.WriteFile applies its mode only when
// it creates the file). Parent directories are created with mode 0700.
func SaveToken(path string, token string) error {
	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0700); err != nil {
		return fmt.Errorf("create token directory %q: %w", dir, err)
	}

	// os.CreateTemp creates the file with mode 0600, in the same directory
	// so the final rename does not cross filesystems.
	tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".tmp-*")
	if err != nil {
		return fmt.Errorf("save token to %q: %w", path, err)
	}
	tmpName := tmp.Name()

	// On any failure before the rename, close and remove the temporary file
	// so no partial token is left behind. Errors here are ignored because
	// the primary error is already being returned.
	committed := false
	defer func() {
		if !committed {
			_ = tmp.Close()
			_ = os.Remove(tmpName)
		}
	}()

	if _, err := tmp.WriteString(token + "\n"); err != nil {
		return fmt.Errorf("save token to %q: %w", path, err)
	}
	if err := tmp.Sync(); err != nil {
		return fmt.Errorf("save token to %q: %w", path, err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("save token to %q: %w", path, err)
	}
	if err := os.Rename(tmpName, path); err != nil {
		return fmt.Errorf("save token to %q: %w", path, err)
	}
	committed = true
	return nil
}