Junie: add user permissions to databases.

This commit is contained in:
2025-06-07 11:38:04 -07:00
parent ccdbcce9c0
commit ab255d5d58
9 changed files with 620 additions and 72 deletions

View File

@@ -445,7 +445,7 @@ func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Reques
}
// Verify the token
_, err := s.verifyToken(username, token)
userID, err := s.verifyToken(username, token)
if err != nil {
s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
return
@@ -471,37 +471,86 @@ func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Reques
return
}
// Retrieve database credentials
// Use squirrel to build the query safely
query, args, err := squirrel.Select("id", "host", "port", "name", "user", "password").
From("database").
Limit(1).
ToSql()
// Get database ID from query parameter if provided
databaseID := r.URL.Query().Get("database_id")
// Build the query to retrieve databases the user has access to
queryBuilder := squirrel.Select("d.id", "d.host", "d.port", "d.name", "d.user", "d.password").
From("database d")
// If the user is an admin, they can see all databases
isAdmin := false
for _, role := range user.Roles {
if role == "admin" {
isAdmin = true
break
}
}
if !isAdmin {
// Non-admin users can only see databases they're explicitly associated with
queryBuilder = queryBuilder.Join("database_users du ON d.id = du.db_id").
Where(squirrel.Eq{"du.uid": userID})
}
// If a specific database ID was requested, filter by that
if databaseID != "" {
queryBuilder = queryBuilder.Where(squirrel.Eq{"d.id": databaseID})
}
query, args, err := queryBuilder.ToSql()
if err != nil {
s.Logger.Printf("Query building error: %v", err)
s.sendError(w, "Internal server error", http.StatusInternalServerError)
return
}
row := s.DB.QueryRow(query, args...)
var id string
var creds DatabaseCredentials
err = row.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password)
rows, err := s.DB.Query(query, args...)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
s.sendError(w, "No database credentials found", http.StatusNotFound)
} else {
s.Logger.Printf("Database error: %v", err)
s.Logger.Printf("Database error: %v", err)
s.sendError(w, "Internal server error", http.StatusInternalServerError)
return
}
defer rows.Close()
var databases []DatabaseCredentials
for rows.Next() {
var id string
var creds DatabaseCredentials
err = rows.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password)
if err != nil {
s.Logger.Printf("Row scanning error: %v", err)
s.sendError(w, "Internal server error", http.StatusInternalServerError)
return
}
databases = append(databases, creds)
}
if err = rows.Err(); err != nil {
s.Logger.Printf("Rows iteration error: %v", err)
s.sendError(w, "Internal server error", http.StatusInternalServerError)
return
}
if len(databases) == 0 {
s.sendError(w, "No database credentials found", http.StatusNotFound)
return
}
// If a specific database was requested, return just that one
if databaseID != "" && len(databases) == 1 {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(databases[0]); err != nil {
s.Logger.Printf("Error encoding response: %v", err)
}
return
}
// Return the credentials
// Otherwise return all accessible databases
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(creds); err != nil {
if err := json.NewEncoder(w).Encode(databases); err != nil {
s.Logger.Printf("Error encoding response: %v", err)
}
}