package api

import (
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"git.wntrmute.dev/kyle/mcias/data"
	"github.com/Masterminds/squirrel"
	"github.com/oklog/ulid/v2"
)

// LoginRequest is the JSON body decoded by the password and token login
// handlers. Version must be "v1"; Login carries the user name together with
// either a password or a token, depending on the endpoint.
type LoginRequest struct {
	Version string     `json:"version"`
	Login   data.Login `json:"login"`
}

// TokenResponse is the JSON body returned when a token is issued or renewed.
// Expires is a Unix timestamp in seconds.
type TokenResponse struct {
	Token   string `json:"token"`
	Expires int64  `json:"expires"`
}

// TOTPVerifyRequest is the JSON body decoded by the TOTP verification handler.
type TOTPVerifyRequest struct {
	Version  string `json:"version"`
	Username string `json:"username"`
	TOTPCode string `json:"totp_code"`
}

// ErrorResponse is a JSON envelope carrying a single error message.
type ErrorResponse struct {
	Error string `json:"error"`
}

// DatabaseCredentials is the JSON representation of one database's connection
// details as returned by the database credentials handler.
type DatabaseCredentials struct {
	Host     string `json:"host"`
	Port     int    `json:"port"`
	Name     string `json:"name"`
	User     string `json:"user"`
	Password string `json:"password"`
}

// handlePasswordLogin authenticates a user by user name and password and, on
// success, responds with a freshly created token, its expiry, and whether the
// user has TOTP enabled.
//
// Responses: 400 for an undecodable or incomplete request, 401 for an unknown
// user or wrong password, 500 for database or token-creation failures.
func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
	var req LoginRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.sendError(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	if req.Version != "v1" || req.Login.User == "" || req.Login.Password == "" {
		s.sendError(w, "Invalid login request", http.StatusBadRequest)
		return
	}

	user, err := s.getUserByUsername(req.Login.User)
	if err != nil {
		// An unknown user gets the same response as a wrong password.
		if errors.Is(err, sql.ErrNoRows) {
			s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
			return
		}
		s.Logger.Printf("Database error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	// Check password only.
	if !user.CheckPassword(&req.Login) {
		// Record the failed attempt so that password failures are audited the
		// same way as failed TOTP attempts and token-creation failures.
		details := map[string]string{
			"reason": "Invalid password",
		}
		s.LogSecurityEvent(r, "login_attempt", user.ID, user.User, false, details)
		s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
		return
	}

	// Password is correct, create a token regardless of TOTP status.
	// NOTE(review): a token is issued on the password alone even when the user
	// has TOTP enabled; the flag is only reported to the client. Confirm this
	// is the intended flow.
	token, expires, err := s.createToken(user.ID)
	if err != nil {
		s.Logger.Printf("Token creation error: %v", err)
		details := map[string]string{
			"reason": "Token creation error",
			"error":  err.Error(),
		}
		s.LogSecurityEvent(r, "login_attempt", user.ID, user.User, false, details)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	totpEnabled := user.HasTOTP()

	// Log successful login.
	details := map[string]string{
		"token_expires": time.Unix(expires, 0).UTC().Format(time.RFC3339),
		"totp_enabled":  fmt.Sprintf("%t", totpEnabled),
	}
	s.LogSecurityEvent(r, "login_success", user.ID, user.User, true, details)

	// The embedded TokenResponse is flattened by encoding/json, so the body
	// has "token", "expires" and "totp_enabled" at the top level.
	response := struct {
		TokenResponse
		TOTPEnabled bool `json:"totp_enabled"`
	}{
		TokenResponse: TokenResponse{
			Token:   token,
			Expires: expires,
		},
		TOTPEnabled: totpEnabled,
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(response); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}

// handleTokenLogin authenticates a user by an existing token and, on success,
// extends that token's expiry and returns the same token with the new expiry.
//
// Responses: 400 for an undecodable or incomplete request, 401 when the token
// cannot be verified, 500 when renewal fails.
func (s *Server) handleTokenLogin(w http.ResponseWriter, r *http.Request) {
	var req LoginRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.sendError(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	if req.Version != "v1" || req.Login.User == "" || req.Login.Token == "" {
		s.sendError(w, "Invalid login request", http.StatusBadRequest)
		return
	}

	// Verify the token is valid. Every verification error, including a
	// database failure, is reported to the client as 401.
	if _, err := s.verifyToken(req.Login.User, req.Login.Token); err != nil {
		s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
		return
	}

	// Renew the existing token instead of creating a new one.
	expires, err := s.renewToken(req.Login.User, req.Login.Token)
	if err != nil {
		s.Logger.Printf("Token renewal error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(TokenResponse{
		Token:   req.Login.Token,
		Expires: expires,
	}); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}

// sendError logs message and status, then writes a JSON error response with
// the given HTTP status.
//
// message is only logged; it is never sent to the client. The client receives
// a generic public message selected from status, plus an error code made of
// "E" followed by the status code zero-padded to three digits.
func (s *Server) sendError(w http.ResponseWriter, message string, status int) {
	// Log the detailed error message for debugging.
	s.Logger.Printf("Error response: %s (Status: %d)", message, status)

	// Default public message for any status not listed below.
	publicMessage := "An error occurred processing your request"
	errorCode := fmt.Sprintf("E%03d", status)

	// Customize public messages for common status codes without leaking
	// specific details about the error.
	switch status {
	case http.StatusBadRequest:
		publicMessage = "Invalid request format"
	case http.StatusUnauthorized:
		publicMessage = "Authentication required"
	case http.StatusForbidden:
		publicMessage = "Insufficient permissions"
	case http.StatusNotFound:
		publicMessage = "Resource not found"
	case http.StatusTooManyRequests:
		publicMessage = "Rate limit exceeded"
	case http.StatusInternalServerError:
		publicMessage = "Internal server error"
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	response := struct {
		Error     string `json:"error"`
		ErrorCode string `json:"error_code"`
	}{
		Error:     publicMessage,
		ErrorCode: errorCode,
	}

	if err := json.NewEncoder(w).Encode(response); err != nil {
		s.Logger.Printf("Error encoding error response: %v", err)
	}
}

// getUserByUsername loads the user row whose "user" column equals username,
// then loads the names of the roles joined to that user through user_roles.
//
// It returns the error from Scan unchanged when no user row matches, so
// callers can test for sql.ErrNoRows. Any query, scan, or row-iteration error
// is returned as-is with a nil user.
func (s *Server) getUserByUsername(username string) (*data.User, error) {
	// Use squirrel to build the query safely.
	query, args, err := squirrel.Select("id", "created", "user", "password", "salt", "totp_secret").
		From("users").
		Where(squirrel.Eq{"user": username}).
		ToSql()
	if err != nil {
		return nil, err
	}

	user := &data.User{}
	// totp_secret may be NULL, so scan it through a nullable wrapper.
	var totpSecret sql.NullString
	err = s.DB.QueryRow(query, args...).Scan(
		&user.ID, &user.Created, &user.User, &user.Password, &user.Salt, &totpSecret,
	)
	if err != nil {
		return nil, err
	}

	// Set TOTP secret if it exists.
	if totpSecret.Valid {
		user.TOTPSecret = totpSecret.String
	}

	// Use squirrel to build the roles query safely.
	rolesQuery, rolesArgs, err := squirrel.Select("r.role").
		From("roles r").
		Join("user_roles ur ON r.id = ur.rid").
		Where(squirrel.Eq{"ur.uid": user.ID}).
		ToSql()
	if err != nil {
		return nil, err
	}

	rows, err := s.DB.Query(rolesQuery, rolesArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var roles []string
	for rows.Next() {
		var role string
		if err := rows.Scan(&role); err != nil {
			return nil, err
		}
		roles = append(roles, role)
	}
	// Next returns false both at the end of the result set and on failure;
	// without this check a failed iteration would yield a truncated role list.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	user.Roles = roles
	return user, nil
}

// createToken generates a random token for userID, stores it in the tokens
// table with an expiry 24 hours from now, and returns the token and its
// expiry as a Unix timestamp in seconds.
func (s *Server) createToken(userID string) (string, int64, error) {
	// Generate 16 bytes of random data.
	tokenBytes := make([]byte, 16)
	if _, err := rand.Read(tokenBytes); err != nil {
		return "", 0, err
	}

	// Hex encode the random bytes to get a 32-character string.
	token := hex.EncodeToString(tokenBytes)
	expires := time.Now().Add(24 * time.Hour).Unix()
	tokenID := ulid.Make().String()

	// Use squirrel to build the insert query safely.
	query, args, err := squirrel.Insert("tokens").
		Columns("id", "uid", "token", "expires").
		Values(tokenID, userID, token, expires).
		ToSql()
	if err != nil {
		return "", 0, err
	}

	if _, err := s.DB.Exec(query, args...); err != nil {
		return "", 0, err
	}

	return token, expires, nil
}

// verifyToken looks up token for the user named username and returns the
// owning user's ID.
//
// It returns the Scan error (sql.ErrNoRows when nothing matches) if the lookup
// fails, and a "token expired" error when the stored expiry is positive and
// in the past. An expiry of zero or less is never treated as expired.
func (s *Server) verifyToken(username, token string) (string, error) {
	// Use squirrel to build the query safely.
	query, args, err := squirrel.Select("t.uid", "t.expires").
		From("tokens t").
		Join("users u ON t.uid = u.id").
		Where(squirrel.And{
			squirrel.Eq{"u.user": username},
			squirrel.Eq{"t.token": token},
		}).
		ToSql()
	if err != nil {
		return "", err
	}

	var userID string
	var expires int64
	if err := s.DB.QueryRow(query, args...).Scan(&userID, &expires); err != nil {
		return "", err
	}

	if expires > 0 && expires < time.Now().Unix() {
		return "", errors.New("token expired")
	}

	return userID, nil
}

// renewToken sets the expiry of the given user's token to 24 hours from now
// and returns the new expiry as a Unix timestamp in seconds.
//
// It does not check whether the token has already expired; callers are
// expected to verify the token first.
func (s *Server) renewToken(username, token string) (int64, error) {
	// First, confirm the token exists for this user and get the token ID.
	query, args, err := squirrel.Select("t.id").
		From("tokens t").
		Join("users u ON t.uid = u.id").
		Where(squirrel.And{
			squirrel.Eq{"u.user": username},
			squirrel.Eq{"t.token": token},
		}).
		ToSql()
	if err != nil {
		return 0, err
	}

	var tokenID string
	if err := s.DB.QueryRow(query, args...).Scan(&tokenID); err != nil {
		return 0, err
	}

	// Update the token's expiry time.
	expires := time.Now().Add(24 * time.Hour).Unix()

	updateQuery, updateArgs, err := squirrel.Update("tokens").
		Set("expires", expires).
		Where(squirrel.Eq{"id": tokenID}).
		ToSql()
	if err != nil {
		return 0, err
	}

	if _, err := s.DB.Exec(updateQuery, updateArgs...); err != nil {
		return 0, err
	}

	return expires, nil
}

// handleTOTPVerify validates a TOTP code for the named user and, on success,
// responds with a freshly created token and its expiry.
//
// Responses: 400 for an undecodable or incomplete request or when the user has
// no TOTP configured, 401 for an unknown user or invalid code, 500 for
// database or token-creation failures.
//
// NOTE(review): this handler issues a token from a user name and TOTP code
// alone; no password or prior token is checked here. Confirm that routing or
// middleware outside this file enforces the first factor.
func (s *Server) handleTOTPVerify(w http.ResponseWriter, r *http.Request) {
	var req TOTPVerifyRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.sendError(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	if req.Version != "v1" || req.Username == "" || req.TOTPCode == "" {
		s.sendError(w, "Invalid TOTP verification request", http.StatusBadRequest)
		return
	}

	user, err := s.getUserByUsername(req.Username)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			s.sendError(w, "User not found", http.StatusUnauthorized)
			return
		}
		s.Logger.Printf("Database error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	// Check if TOTP is enabled for the user.
	if !user.HasTOTP() {
		details := map[string]string{
			"reason": "TOTP not enabled for user",
		}
		s.LogSecurityEvent(r, "totp_verification_attempt", user.ID, user.User, false, details)
		s.sendError(w, "TOTP not enabled for user", http.StatusBadRequest)
		return
	}

	// Validate the TOTP code; a validation error is treated like a bad code.
	valid, validErr := user.ValidateTOTPCode(req.TOTPCode)
	if validErr != nil || !valid {
		details := map[string]string{
			"reason": "Invalid TOTP code",
		}
		s.LogSecurityEvent(r, "totp_verification_attempt", user.ID, user.User, false, details)
		s.sendError(w, "Invalid TOTP code", http.StatusUnauthorized)
		return
	}

	// TOTP code is valid, create a token.
	token, expires, err := s.createToken(user.ID)
	if err != nil {
		s.Logger.Printf("Token creation error: %v", err)
		details := map[string]string{
			"reason": "Token creation error",
			"error":  err.Error(),
		}
		s.LogSecurityEvent(r, "totp_verification_attempt", user.ID, user.User, false, details)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	// Log successful TOTP verification.
	details := map[string]string{
		"token_expires": time.Unix(expires, 0).UTC().Format(time.RFC3339),
	}
	s.LogSecurityEvent(r, "totp_verification_success", user.ID, user.User, true, details)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(TokenResponse{
		Token:   token,
		Expires: expires,
	}); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}

// handleDatabaseCredentials returns database connection credentials that the
// authenticated user may read.
//
// The caller supplies a bearer token in the Authorization header and the
// "username" query parameter; an optional "database_id" parameter restricts
// the result to one database. Users holding the "admin" role see every row of
// the database table; other users see only rows linked to them through
// database_users. When database_id is given and exactly one row matches, the
// response is a single JSON object; otherwise it is a JSON array.
//
// Responses: 401 for a missing/malformed header or unverifiable token, 400 for
// a missing username, 403 when the database_credentials:read permission is
// absent, 404 when no rows match, 500 for database or permission-check errors.
func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Request) {
	// Extract authorization header.
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		s.sendError(w, "Authorization header required", http.StatusUnauthorized)
		return
	}

	// Require exactly "<scheme> <token>" with a case-insensitive Bearer scheme.
	parts := strings.Split(authHeader, " ")
	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
		s.sendError(w, "Invalid authorization format", http.StatusUnauthorized)
		return
	}
	token := parts[1]

	username := r.URL.Query().Get("username")
	if username == "" {
		s.sendError(w, "Username parameter required", http.StatusBadRequest)
		return
	}

	// Verify the token belongs to the named user and has not expired.
	userID, err := s.verifyToken(username, token)
	if err != nil {
		s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
		return
	}

	// Check if user has permission to read database credentials.
	user, err := s.getUserByUsername(username)
	if err != nil {
		s.Logger.Printf("Database error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	hasPermission, err := user.HasPermission(s.Auth, "database_credentials", "read")
	if err != nil {
		s.Logger.Printf("Permission check error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	if !hasPermission {
		s.sendError(w, "Insufficient permissions: requires database_credentials:read permission", http.StatusForbidden)
		return
	}

	// Get database ID from query parameter if provided.
	databaseID := r.URL.Query().Get("database_id")

	// Build the query to retrieve databases the user has access to.
	queryBuilder := squirrel.Select("d.id", "d.host", "d.port", "d.name", "d.user", "d.password").
		From("database d")

	// If the user is an admin, they can see all databases.
	isAdmin := false
	for _, role := range user.Roles {
		if role == "admin" {
			isAdmin = true
			break
		}
	}

	if !isAdmin {
		// Non-admin users can only see databases they're explicitly
		// associated with.
		queryBuilder = queryBuilder.Join("database_users du ON d.id = du.db_id").
			Where(squirrel.Eq{"du.uid": userID})
	}

	// If a specific database ID was requested, filter by that.
	if databaseID != "" {
		queryBuilder = queryBuilder.Where(squirrel.Eq{"d.id": databaseID})
	}

	query, args, err := queryBuilder.ToSql()
	if err != nil {
		s.Logger.Printf("Query building error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	rows, err := s.DB.Query(query, args...)
	if err != nil {
		s.Logger.Printf("Database error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	var databases []DatabaseCredentials
	for rows.Next() {
		// The id column is selected but not part of the response payload.
		var id string
		var creds DatabaseCredentials
		err = rows.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password)
		if err != nil {
			s.Logger.Printf("Row scanning error: %v", err)
			s.sendError(w, "Internal server error", http.StatusInternalServerError)
			return
		}
		databases = append(databases, creds)
	}
	if err = rows.Err(); err != nil {
		s.Logger.Printf("Rows iteration error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	if len(databases) == 0 {
		s.sendError(w, "No database credentials found", http.StatusNotFound)
		return
	}

	// If a specific database was requested, return just that one.
	if databaseID != "" && len(databases) == 1 {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		if err := json.NewEncoder(w).Encode(databases[0]); err != nil {
			s.Logger.Printf("Error encoding response: %v", err)
		}
		return
	}

	// Otherwise return all accessible databases.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(databases); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}