package api

import (
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"net/http"
	"strings"
	"time"

	"git.wntrmute.dev/kyle/mcias/data"
	"github.com/oklog/ulid/v2"
)

const (
	// authTokenLifetime is how long a newly issued or renewed token stays valid.
	authTokenLifetime = 24 * time.Hour

	// authTokenRandomBytes is the amount of crypto/rand data in a session
	// token (256 bits, hex-encoded to 64 characters).
	authTokenRandomBytes = 32

	// maxAuthRequestBytes caps the JSON request bodies accepted by the
	// authentication handlers, which are reachable without credentials.
	maxAuthRequestBytes = 1 << 20
)

// LoginRequest is the body accepted by the password and token login handlers.
type LoginRequest struct {
	Version string     `json:"version"`
	Login   data.Login `json:"login"`
}

// TokenResponse is returned on successful authentication. Expires is a Unix
// timestamp in seconds.
type TokenResponse struct {
	Token   string `json:"token"`
	Expires int64  `json:"expires"`
}

// TOTPVerifyRequest is the body accepted by handleTOTPVerify.
type TOTPVerifyRequest struct {
	Version  string `json:"version"`
	Username string `json:"username"`
	TOTPCode string `json:"totp_code"`
}

// ErrorResponse is the JSON body sent for every error response.
type ErrorResponse struct {
	Error string `json:"error"`
}

// DatabaseCredentials is the JSON body returned by handleDatabaseCredentials.
type DatabaseCredentials struct {
	Host     string `json:"host"`
	Port     int    `json:"port"`
	Name     string `json:"name"`
	User     string `json:"user"`
	Password string `json:"password"`
}

// decodeAuthRequest decodes the JSON request body into dst, refusing bodies
// larger than maxAuthRequestBytes.
func decodeAuthRequest(w http.ResponseWriter, r *http.Request, dst interface{}) error {
	r.Body = http.MaxBytesReader(w, r.Body, maxAuthRequestBytes)
	return json.NewDecoder(r.Body).Decode(dst)
}

// parseBearerToken extracts the token from an "Authorization: Bearer <token>"
// header value. The scheme is matched case-insensitively and extra whitespace
// between or around the two parts is tolerated. ok is false for any other shape.
func parseBearerToken(header string) (token string, ok bool) {
	fields := strings.Fields(header)
	if len(fields) != 2 || !strings.EqualFold(fields[0], "bearer") {
		return "", false
	}
	return fields[1], true
}

// newAuthToken returns a fresh hex-encoded session token drawn from
// crypto/rand. Bearer tokens must be unguessable, so ulid.Make (whose default
// entropy is a time-seeded, monotonic math/rand source) is not used here.
func newAuthToken() (string, error) {
	buf := make([]byte, authTokenRandomBytes)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return hex.EncodeToString(buf), nil
}

// handlePasswordLogin authenticates a user by username and password and, when
// the account has TOTP enabled, a TOTP code. On success it issues a new token.
//
// Responses: 400 for malformed requests, 401 for bad credentials or when a
// required TOTP code is missing ("TOTP code required"), 500 on internal errors.
func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
	var req LoginRequest
	if err := decodeAuthRequest(w, r, &req); err != nil {
		s.sendError(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	if req.Version != "v1" || req.Login.User == "" || req.Login.Password == "" {
		s.sendError(w, "Invalid login request", http.StatusBadRequest)
		return
	}

	user, err := s.getUserByUsername(req.Login.User)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Same message as a wrong password, so the response does not
			// say whether the account exists.
			s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
		} else {
			s.Logger.Printf("Database error: %v", err)
			s.sendError(w, "Internal server error", http.StatusInternalServerError)
		}
		return
	}

	// Check the password before looking at TOTP, so the "TOTP code required"
	// prompt is only ever shown to a caller who supplied the right password.
	if !user.CheckPassword(&req.Login) {
		s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
		return
	}

	// If TOTP is enabled, a valid code is mandatory.
	if user.HasTOTP() {
		if req.Login.TOTPCode == "" {
			// Tell the client to resubmit with a TOTP code.
			s.sendError(w, "TOTP code required", http.StatusUnauthorized)
			return
		}

		valid, validErr := user.ValidateTOTPCode(req.Login.TOTPCode)
		if validErr != nil {
			s.Logger.Printf("TOTP validation error for %q: %v", req.Login.User, validErr)
		}
		if validErr != nil || !valid {
			s.sendError(w, "Invalid TOTP code", http.StatusUnauthorized)
			return
		}
	}

	token, expires, err := s.createToken(user.ID)
	if err != nil {
		s.Logger.Printf("Token creation error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(TokenResponse{
		Token:   token,
		Expires: expires,
	}); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}

// handleTokenLogin re-authenticates with an existing, unexpired token and
// extends its expiry. It returns the same token with the new expiry time.
func (s *Server) handleTokenLogin(w http.ResponseWriter, r *http.Request) {
	var req LoginRequest
	if err := decodeAuthRequest(w, r, &req); err != nil {
		s.sendError(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	if req.Version != "v1" || req.Login.User == "" || req.Login.Token == "" {
		s.sendError(w, "Invalid login request", http.StatusBadRequest)
		return
	}

	// renewToken does not check expiry itself, so validate the token first.
	if _, err := s.verifyToken(req.Login.User, req.Login.Token); err != nil {
		s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
		return
	}

	// Renew the existing token instead of creating a new one.
	expires, err := s.renewToken(req.Login.User, req.Login.Token)
	if err != nil {
		s.Logger.Printf("Token renewal error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(TokenResponse{
		Token:   req.Login.Token,
		Expires: expires,
	}); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}

// sendError writes status and a JSON ErrorResponse carrying message.
func (s *Server) sendError(w http.ResponseWriter, message string, status int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(ErrorResponse{Error: message}); err != nil {
		s.Logger.Printf("Error encoding error response: %v", err)
	}
}

// getUserByUsername loads a user row and its role names.
// It returns sql.ErrNoRows when no user has that username.
func (s *Server) getUserByUsername(username string) (*data.User, error) {
	query := `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
	row := s.DB.QueryRow(query, username)

	user := &data.User{}
	var totpSecret sql.NullString
	if err := row.Scan(&user.ID, &user.Created, &user.User, &user.Password, &user.Salt, &totpSecret); err != nil {
		return nil, err
	}

	// totp_secret is nullable; leave TOTPSecret empty when it is NULL.
	if totpSecret.Valid {
		user.TOTPSecret = totpSecret.String
	}

	rolesQuery := `
		SELECT r.role FROM roles r
		JOIN user_roles ur ON r.id = ur.rid
		WHERE ur.uid = ?
	`
	rows, err := s.DB.Query(rolesQuery, user.ID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var roles []string
	for rows.Next() {
		var role string
		if err := rows.Scan(&role); err != nil {
			return nil, err
		}
		roles = append(roles, role)
	}
	// rows.Next returns false on error as well as at the end; without this
	// check a failed iteration would yield a silently truncated role list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	user.Roles = roles

	return user, nil
}

// createToken issues a new token for userID, valid for authTokenLifetime,
// and stores it. It returns the token and its expiry as a Unix timestamp.
func (s *Server) createToken(userID string) (string, int64, error) {
	token, err := newAuthToken()
	if err != nil {
		return "", 0, err
	}
	expires := time.Now().Add(authTokenLifetime).Unix()

	// The row ID is not a secret, so a ULID is fine for it.
	tokenID := ulid.Make().String()
	query := `INSERT INTO tokens (id, uid, token, expires) VALUES (?, ?, ?, ?)`
	if _, err := s.DB.Exec(query, tokenID, userID, token, expires); err != nil {
		return "", 0, err
	}

	return token, expires, nil
}

// verifyToken returns the ID of the user owning token if the token belongs to
// username and has not expired. A stored expiry <= 0 is treated as
// non-expiring. A token that does not exist yields sql.ErrNoRows.
func (s *Server) verifyToken(username, token string) (string, error) {
	query := `
		SELECT t.uid, t.expires FROM tokens t
		JOIN users u ON t.uid = u.id
		WHERE u.user = ? AND t.token = ?
	`
	var userID string
	var expires int64
	if err := s.DB.QueryRow(query, username, token).Scan(&userID, &expires); err != nil {
		return "", err
	}

	if expires > 0 && expires < time.Now().Unix() {
		return "", errors.New("token expired")
	}

	return userID, nil
}

// renewToken pushes the expiry of username's token to authTokenLifetime from
// now and returns the new expiry. It does not check whether the token has
// already expired; callers are expected to call verifyToken first.
func (s *Server) renewToken(username, token string) (int64, error) {
	// Resolve the token row, scoped to the given user.
	query := `
		SELECT t.id FROM tokens t
		JOIN users u ON t.uid = u.id
		WHERE u.user = ? AND t.token = ?
	`
	var tokenID string
	if err := s.DB.QueryRow(query, username, token).Scan(&tokenID); err != nil {
		return 0, err
	}

	expires := time.Now().Add(authTokenLifetime).Unix()
	updateQuery := `UPDATE tokens SET expires = ? WHERE id = ?`
	if _, err := s.DB.Exec(updateQuery, expires, tokenID); err != nil {
		return 0, err
	}

	return expires, nil
}

// handleTOTPVerify validates a TOTP code for a user and, if it is valid,
// issues a new token.
//
// NOTE(review): this endpoint issues a session token on a valid TOTP code
// alone; no password is checked here. Unless something outside this file
// gates access (rate limiting, a bound prior password step), a 6-digit code
// is a weak sole factor. The distinct "User not found" / "TOTP not enabled"
// responses also reveal which accounts exist and which use TOTP. Behavior is
// left unchanged to preserve the API contract; confirm intent.
func (s *Server) handleTOTPVerify(w http.ResponseWriter, r *http.Request) {
	var req TOTPVerifyRequest
	if err := decodeAuthRequest(w, r, &req); err != nil {
		s.sendError(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	if req.Version != "v1" || req.Username == "" || req.TOTPCode == "" {
		s.sendError(w, "Invalid TOTP verification request", http.StatusBadRequest)
		return
	}

	user, err := s.getUserByUsername(req.Username)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			s.sendError(w, "User not found", http.StatusUnauthorized)
		} else {
			s.Logger.Printf("Database error: %v", err)
			s.sendError(w, "Internal server error", http.StatusInternalServerError)
		}
		return
	}

	if !user.HasTOTP() {
		s.sendError(w, "TOTP not enabled for user", http.StatusBadRequest)
		return
	}

	valid, validErr := user.ValidateTOTPCode(req.TOTPCode)
	if validErr != nil {
		s.Logger.Printf("TOTP validation error for %q: %v", req.Username, validErr)
	}
	if validErr != nil || !valid {
		s.sendError(w, "Invalid TOTP code", http.StatusUnauthorized)
		return
	}

	token, expires, err := s.createToken(user.ID)
	if err != nil {
		s.Logger.Printf("Token creation error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(TokenResponse{
		Token:   token,
		Expires: expires,
	}); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}

// handleDatabaseCredentials returns the stored database credentials to a
// caller that presents a valid bearer token for the user named in the
// "username" query parameter. That user must hold the
// database_credentials:read permission.
func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Request) {
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		s.sendError(w, "Authorization header required", http.StatusUnauthorized)
		return
	}

	token, ok := parseBearerToken(authHeader)
	if !ok {
		s.sendError(w, "Invalid authorization format", http.StatusUnauthorized)
		return
	}

	username := r.URL.Query().Get("username")
	if username == "" {
		s.sendError(w, "Username parameter required", http.StatusBadRequest)
		return
	}

	// verifyToken only matches tokens owned by username, so a caller cannot
	// use its own token to act as a different user.
	if _, err := s.verifyToken(username, token); err != nil {
		s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
		return
	}

	user, err := s.getUserByUsername(username)
	if err != nil {
		s.Logger.Printf("Database error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	hasPermission, err := user.HasPermission(s.Auth, "database_credentials", "read")
	if err != nil {
		s.Logger.Printf("Permission check error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	if !hasPermission {
		s.sendError(w, "Insufficient permissions: requires database_credentials:read permission", http.StatusForbidden)
		return
	}

	query := `SELECT id, host, port, name, user, password FROM database LIMIT 1`
	row := s.DB.QueryRow(query)

	var id string
	var creds DatabaseCredentials
	err = row.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			s.sendError(w, "No database credentials found", http.StatusNotFound)
		} else {
			s.Logger.Printf("Database error: %v", err)
			s.sendError(w, "Internal server error", http.StatusInternalServerError)
		}
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(creds); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}