Junie: add user permissions to databases.
This commit is contained in:
85
api/auth.go
85
api/auth.go
@@ -445,7 +445,7 @@ func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
_, err := s.verifyToken(username, token)
|
||||
userID, err := s.verifyToken(username, token)
|
||||
if err != nil {
|
||||
s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
|
||||
return
|
||||
@@ -471,37 +471,86 @@ func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve database credentials
|
||||
// Use squirrel to build the query safely
|
||||
query, args, err := squirrel.Select("id", "host", "port", "name", "user", "password").
|
||||
From("database").
|
||||
Limit(1).
|
||||
ToSql()
|
||||
// Get database ID from query parameter if provided
|
||||
databaseID := r.URL.Query().Get("database_id")
|
||||
|
||||
// Build the query to retrieve databases the user has access to
|
||||
queryBuilder := squirrel.Select("d.id", "d.host", "d.port", "d.name", "d.user", "d.password").
|
||||
From("database d")
|
||||
|
||||
// If the user is an admin, they can see all databases
|
||||
isAdmin := false
|
||||
for _, role := range user.Roles {
|
||||
if role == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isAdmin {
|
||||
// Non-admin users can only see databases they're explicitly associated with
|
||||
queryBuilder = queryBuilder.Join("database_users du ON d.id = du.db_id").
|
||||
Where(squirrel.Eq{"du.uid": userID})
|
||||
}
|
||||
|
||||
// If a specific database ID was requested, filter by that
|
||||
if databaseID != "" {
|
||||
queryBuilder = queryBuilder.Where(squirrel.Eq{"d.id": databaseID})
|
||||
}
|
||||
|
||||
query, args, err := queryBuilder.ToSql()
|
||||
if err != nil {
|
||||
s.Logger.Printf("Query building error: %v", err)
|
||||
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
row := s.DB.QueryRow(query, args...)
|
||||
|
||||
var id string
|
||||
var creds DatabaseCredentials
|
||||
err = row.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password)
|
||||
rows, err := s.DB.Query(query, args...)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
s.sendError(w, "No database credentials found", http.StatusNotFound)
|
||||
} else {
|
||||
s.Logger.Printf("Database error: %v", err)
|
||||
s.Logger.Printf("Database error: %v", err)
|
||||
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var databases []DatabaseCredentials
|
||||
for rows.Next() {
|
||||
var id string
|
||||
var creds DatabaseCredentials
|
||||
err = rows.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password)
|
||||
if err != nil {
|
||||
s.Logger.Printf("Row scanning error: %v", err)
|
||||
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
databases = append(databases, creds)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
s.Logger.Printf("Rows iteration error: %v", err)
|
||||
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if len(databases) == 0 {
|
||||
s.sendError(w, "No database credentials found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If a specific database was requested, return just that one
|
||||
if databaseID != "" && len(databases) == 1 {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(databases[0]); err != nil {
|
||||
s.Logger.Printf("Error encoding response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Return the credentials
|
||||
// Otherwise return all accessible databases
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(creds); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(databases); err != nil {
|
||||
s.Logger.Printf("Error encoding response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user