package api import ( "database/sql" "encoding/json" "log" "net" "net/http" "sync" "time" "git.wntrmute.dev/kyle/mcias/data" _ "github.com/mattn/go-sqlite3" "golang.org/x/time/rate" ) // client represents a client with rate limiting information type client struct { limiter *rate.Limiter lastSeen time.Time } // SecurityEvent represents a security-related event type SecurityEvent struct { Timestamp string `json:"timestamp"` EventType string `json:"event_type"` UserID string `json:"user_id,omitempty"` Username string `json:"username,omitempty"` IPAddress string `json:"ip_address"` UserAgent string `json:"user_agent"` RequestURI string `json:"request_uri"` Success bool `json:"success"` Details map[string]string `json:"details,omitempty"` } // RateLimiter manages rate limiting for clients type RateLimiter struct { clients map[string]*client mu sync.Mutex // Requests per second, burst size rate rate.Limit burst int } // NewRateLimiter creates a new rate limiter func NewRateLimiter(r rate.Limit, b int) *RateLimiter { return &RateLimiter{ clients: make(map[string]*client), rate: r, burst: b, } } // GetLimiter returns a rate limiter for a client func (rl *RateLimiter) GetLimiter(ip string) *rate.Limiter { rl.mu.Lock() defer rl.mu.Unlock() c, exists := rl.clients[ip] if !exists { c = &client{ limiter: rate.NewLimiter(rl.rate, rl.burst), lastSeen: time.Now(), } rl.clients[ip] = c } else { c.lastSeen = time.Now() } return c.limiter } // CleanupClients removes old clients func (rl *RateLimiter) CleanupClients() { rl.mu.Lock() defer rl.mu.Unlock() for ip, client := range rl.clients { if time.Since(client.lastSeen) > 1*time.Hour { delete(rl.clients, ip) } } } type Server struct { DB *sql.DB Router *http.ServeMux Logger *log.Logger Auth *data.AuthorizationService RateLimiter *RateLimiter } // getClientIP extracts the client IP address from the request func getClientIP(r *http.Request) string { // Check for X-Forwarded-For header first (for clients behind proxy) ip := r.Header.Get("X-Forwarded-For") if ip != "" { // X-Forwarded-For can contain multiple IPs, use the first one ips := net.ParseIP(ip) if ips != nil { return ips.String() } } // Fall back to RemoteAddr ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return ip } func NewServer(db *sql.DB, logger *log.Logger) *Server { // Create a rate limiter with 10 requests per second and burst of 30 rateLimiter := NewRateLimiter(10, 30) s := &Server{ DB: db, Router: http.NewServeMux(), Logger: logger, Auth: data.NewAuthorizationService(db), RateLimiter: rateLimiter, } // Start a goroutine to clean up old clients go func() { for { time.Sleep(1 * time.Hour) rateLimiter.CleanupClients() } }() s.registerRoutes() return s } func (s *Server) registerRoutes() { s.Router.HandleFunc("POST /v1/login/password", s.handlePasswordLogin) s.Router.HandleFunc("POST /v1/login/token", s.handleTokenLogin) s.Router.HandleFunc("POST /v1/login/totp", s.handleTOTPVerify) s.Router.HandleFunc("GET /v1/database/credentials", s.handleDatabaseCredentials) } // sendRateLimitExceeded sends a rate limit exceeded response func (s *Server) sendRateLimitExceeded(w http.ResponseWriter) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) response := map[string]string{ "error": "Rate limit exceeded. Please try again later.", } if err := json.NewEncoder(w).Encode(response); err != nil { s.Logger.Printf("Error encoding rate limit response: %v", err) } } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Get client IP clientIP := getClientIP(r) // Get rate limiter for this client limiter := s.RateLimiter.GetLimiter(clientIP) // Check if rate limit is exceeded if !limiter.Allow() { s.Logger.Printf("Rate limit exceeded for IP: %s, URI: %s", clientIP, r.RequestURI) s.sendRateLimitExceeded(w) return } // Apply stricter rate limiting for authentication endpoints if r.URL.Path == "/v1/login/password" || r.URL.Path == "/v1/login/token" || r.URL.Path == "/v1/login/totp" { // Use a separate limiter with lower rate for auth endpoints authLimiter := rate.NewLimiter(rate.Limit(1), 5) // 1 request per second, burst of 5 if !authLimiter.Allow() { s.Logger.Printf("Auth rate limit exceeded for IP: %s, URI: %s", clientIP, r.RequestURI) s.sendRateLimitExceeded(w) return } } // Proceed with the request s.Router.ServeHTTP(w, r) } // LogSecurityEvent logs a security-related event func (s *Server) LogSecurityEvent(r *http.Request, eventType string, userID, username string, success bool, details map[string]string) { event := SecurityEvent{ Timestamp: time.Now().UTC().Format(time.RFC3339), EventType: eventType, UserID: userID, Username: username, IPAddress: getClientIP(r), UserAgent: r.UserAgent(), RequestURI: r.RequestURI, Success: success, Details: details, } // Convert to JSON for structured logging eventJSON, err := json.Marshal(event) if err != nil { s.Logger.Printf("Error marshaling security event: %v", err) return } // Log the security event s.Logger.Printf("SECURITY_EVENT: %s", eventJSON) } func (s *Server) Start(addr string) error { s.Logger.Printf("Starting server on %s", addr) s.Logger.Printf("Note: This server is designed to run behind a reverse proxy that handles TLS") return http.ListenAndServe(addr, s) }