322 lines
8.7 KiB
Go
322 lines
8.7 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
|
|
"github.com/Masterminds/squirrel"
|
|
"github.com/oklog/ulid/v2"
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
var (
|
|
dbUserCmd = &cobra.Command{
|
|
Use: "db-user",
|
|
Short: "Manage database user associations",
|
|
Long: `Commands for managing which users can access which database credentials.`,
|
|
}
|
|
|
|
addDbUserCmd = &cobra.Command{
|
|
Use: "add [database-id] [user-id]",
|
|
Short: "Associate a user with a database",
|
|
Long: `Associate a user with a database, allowing them to read its credentials.`,
|
|
Args: cobra.ExactArgs(2),
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
addUserToDatabase(args[0], args[1])
|
|
},
|
|
}
|
|
|
|
removeDbUserCmd = &cobra.Command{
|
|
Use: "remove [database-id] [user-id]",
|
|
Short: "Remove a user's association with a database",
|
|
Long: `Remove a user's association with a database, preventing them from reading its credentials.`,
|
|
Args: cobra.ExactArgs(2),
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
removeUserFromDatabase(args[0], args[1])
|
|
},
|
|
}
|
|
|
|
listDbUsersCmd = &cobra.Command{
|
|
Use: "list [database-id]",
|
|
Short: "List users associated with a database",
|
|
Long: `List all users who have access to read the credentials of a specific database.`,
|
|
Args: cobra.ExactArgs(1),
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
listDatabaseUsers(args[0])
|
|
},
|
|
}
|
|
|
|
listUserDbsCmd = &cobra.Command{
|
|
Use: "list-user-dbs [user-id]",
|
|
Short: "List databases a user can access",
|
|
Long: `List all databases that a specific user has access to.`,
|
|
Args: cobra.ExactArgs(1),
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
listUserDatabases(args[0])
|
|
},
|
|
}
|
|
)
|
|
|
|
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
|
|
func init() {
|
|
rootCmd.AddCommand(dbUserCmd)
|
|
dbUserCmd.AddCommand(addDbUserCmd)
|
|
dbUserCmd.AddCommand(removeDbUserCmd)
|
|
dbUserCmd.AddCommand(listDbUsersCmd)
|
|
dbUserCmd.AddCommand(listUserDbsCmd)
|
|
}
|
|
|
|
// addUserToDatabase associates a user with a database, allowing them to read
// its credentials, by inserting a (ulid, databaseID, userID) row into
// database_users.
//
// It first verifies that both the database and the user exist. If the
// association is already present, it logs that fact and changes nothing.
// Any failure is logged with the "MCIAS: " prefix to stdout and the process
// exits with status 1.
func addUserToDatabase(databaseID, userID string) {
	logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)

	// Logger.Fatal calls os.Exit, which does not run deferred functions. The
	// work is therefore done in a helper that returns an error, so its deferred
	// db.Close runs before we terminate.
	if err := runAddUserToDatabase(logger, databaseID, userID); err != nil {
		logger.Fatal(err)
	}
}

// runAddUserToDatabase performs the work of addUserToDatabase and returns the
// first failure. Error texts deliberately keep the capitalised wording the
// command has always printed, so CLI output is unchanged.
func runAddUserToDatabase(logger *log.Logger, databaseID, userID string) error {
	db, err := openDatabase(dbPath)
	if err != nil {
		return fmt.Errorf("Failed to open database: %w", err)
	}
	defer db.Close()

	exists, err := checkDatabaseExists(db, databaseID)
	if err != nil {
		return fmt.Errorf("Failed to check if database exists: %w", err)
	}
	if !exists {
		return fmt.Errorf("Database with ID %s does not exist", databaseID)
	}

	exists, err = checkUserExists(db, userID)
	if err != nil {
		return fmt.Errorf("Failed to check if user exists: %w", err)
	}
	if !exists {
		return fmt.Errorf("User with ID %s does not exist", userID)
	}

	exists, err = checkAssociationExists(db, databaseID, userID)
	if err != nil {
		return fmt.Errorf("Failed to check if association exists: %w", err)
	}
	if exists {
		logger.Printf("User %s already has access to database %s", userID, databaseID)
		return nil
	}

	// NOTE(review): the check above and the insert below are not atomic. A
	// concurrent "add" could still create a duplicate row unless database_users
	// has a unique constraint on (db_id, uid); confirm against the schema.
	id := ulid.Make().String()
	query, args, err := squirrel.Insert("database_users").
		Columns("id", "db_id", "uid").
		Values(id, databaseID, userID).
		ToSql()
	if err != nil {
		return fmt.Errorf("Failed to build query: %w", err)
	}

	if _, err := db.Exec(query, args...); err != nil {
		return fmt.Errorf("Failed to add user to database: %w", err)
	}

	logger.Printf("User %s now has access to database %s", userID, databaseID)
	return nil
}
|
|
|
|
// removeUserFromDatabase removes a user's association with a database by
// deleting the matching (db_id, uid) rows from database_users.
//
// If no association exists, it logs that fact and changes nothing. Any
// failure is logged with the "MCIAS: " prefix to stdout and the process
// exits with status 1.
func removeUserFromDatabase(databaseID, userID string) {
	logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)

	// Logger.Fatal calls os.Exit, which skips deferred functions. The helper
	// returns an error instead, so its deferred db.Close runs before we exit.
	if err := runRemoveUserFromDatabase(logger, databaseID, userID); err != nil {
		logger.Fatal(err)
	}
}

// runRemoveUserFromDatabase performs the work of removeUserFromDatabase and
// returns the first failure. Error texts keep the wording the command has
// always printed, so CLI output is unchanged.
func runRemoveUserFromDatabase(logger *log.Logger, databaseID, userID string) error {
	db, err := openDatabase(dbPath)
	if err != nil {
		return fmt.Errorf("Failed to open database: %w", err)
	}
	defer db.Close()

	exists, err := checkAssociationExists(db, databaseID, userID)
	if err != nil {
		return fmt.Errorf("Failed to check if association exists: %w", err)
	}
	if !exists {
		logger.Printf("User %s does not have access to database %s", userID, databaseID)
		return nil
	}

	query, args, err := squirrel.Delete("database_users").
		Where(squirrel.Eq{"db_id": databaseID, "uid": userID}).
		ToSql()
	if err != nil {
		return fmt.Errorf("Failed to build query: %w", err)
	}

	if _, err := db.Exec(query, args...); err != nil {
		return fmt.Errorf("Failed to remove user from database: %w", err)
	}

	logger.Printf("User %s no longer has access to database %s", userID, databaseID)
	return nil
}
|
|
|
|
// listDatabaseUsers prints to stdout every user associated with databaseID
// (via database_users), ordered by username. The list is followed by a total,
// or by a notice when no users have access.
//
// Failures are logged with the "MCIAS: " prefix to stdout and the process
// exits with status 1.
func listDatabaseUsers(databaseID string) {
	logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)

	// Logger.Fatal calls os.Exit, which would skip the deferred rows.Close and
	// db.Close. The helper returns an error so those defers run first.
	if err := runListDatabaseUsers(databaseID); err != nil {
		logger.Fatal(err)
	}
}

// runListDatabaseUsers performs the work of listDatabaseUsers and returns the
// first failure. Error texts keep the wording the command has always printed.
func runListDatabaseUsers(databaseID string) error {
	db, err := openDatabase(dbPath)
	if err != nil {
		return fmt.Errorf("Failed to open database: %w", err)
	}
	defer db.Close()

	exists, err := checkDatabaseExists(db, databaseID)
	if err != nil {
		return fmt.Errorf("Failed to check if database exists: %w", err)
	}
	if !exists {
		return fmt.Errorf("Database with ID %s does not exist", databaseID)
	}

	// Fetch the database name for the listing header.
	var dbName string
	if err := db.QueryRow("SELECT name FROM database WHERE id = ?", databaseID).Scan(&dbName); err != nil {
		return fmt.Errorf("Failed to get database name: %w", err)
	}

	query, args, err := squirrel.Select("u.id", "u.user").
		From("users u").
		Join("database_users du ON u.id = du.uid").
		Where(squirrel.Eq{"du.db_id": databaseID}).
		OrderBy("u.user").
		ToSql()
	if err != nil {
		return fmt.Errorf("Failed to build query: %w", err)
	}

	rows, err := db.Query(query, args...)
	if err != nil {
		return fmt.Errorf("Failed to query users: %w", err)
	}
	defer rows.Close()

	fmt.Printf("Users with access to database '%s' (ID: %s):\n\n", dbName, databaseID)

	count := 0
	for rows.Next() {
		var id, username string
		if err := rows.Scan(&id, &username); err != nil {
			return fmt.Errorf("Failed to scan row: %w", err)
		}
		fmt.Printf(" %s (ID: %s)\n", username, id)
		count++
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails. Without this check, a failure would be shown as a
	// truncated (or empty) list.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("Failed to iterate users: %w", err)
	}

	if count == 0 {
		fmt.Println(" No users have access to this database")
		return nil
	}
	fmt.Printf("\nTotal: %d user(s)\n", count)
	return nil
}
|
|
|
|
// listUserDatabases prints to stdout every database userID is associated with
// (via database_users), ordered by database name. Each line shows name, host
// and ID. The list is followed by a total, or by a notice when there are none.
//
// Failures are logged with the "MCIAS: " prefix to stdout and the process
// exits with status 1.
func listUserDatabases(userID string) {
	logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)

	// Logger.Fatal calls os.Exit, which would skip the deferred rows.Close and
	// db.Close. The helper returns an error so those defers run first.
	if err := runListUserDatabases(userID); err != nil {
		logger.Fatal(err)
	}
}

// runListUserDatabases performs the work of listUserDatabases and returns the
// first failure. Error texts keep the wording the command has always printed.
func runListUserDatabases(userID string) error {
	db, err := openDatabase(dbPath)
	if err != nil {
		return fmt.Errorf("Failed to open database: %w", err)
	}
	defer db.Close()

	exists, err := checkUserExists(db, userID)
	if err != nil {
		return fmt.Errorf("Failed to check if user exists: %w", err)
	}
	if !exists {
		return fmt.Errorf("User with ID %s does not exist", userID)
	}

	// Fetch the username for the listing header.
	var username string
	if err := db.QueryRow("SELECT user FROM users WHERE id = ?", userID).Scan(&username); err != nil {
		return fmt.Errorf("Failed to get username: %w", err)
	}

	query, args, err := squirrel.Select("d.id", "d.name", "d.host").
		From("database d").
		Join("database_users du ON d.id = du.db_id").
		Where(squirrel.Eq{"du.uid": userID}).
		OrderBy("d.name").
		ToSql()
	if err != nil {
		return fmt.Errorf("Failed to build query: %w", err)
	}

	rows, err := db.Query(query, args...)
	if err != nil {
		return fmt.Errorf("Failed to query databases: %w", err)
	}
	defer rows.Close()

	fmt.Printf("Databases accessible by user '%s' (ID: %s):\n\n", username, userID)

	count := 0
	for rows.Next() {
		// NOTE(review): scanning into plain strings assumes name and host are
		// NOT NULL; a NULL would make Scan fail. Confirm against the schema.
		var id, name, host string
		if err := rows.Scan(&id, &name, &host); err != nil {
			return fmt.Errorf("Failed to scan row: %w", err)
		}
		fmt.Printf(" %s (%s) (ID: %s)\n", name, host, id)
		count++
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails. Without this check, a failure would be shown as a
	// truncated (or empty) list.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("Failed to iterate databases: %w", err)
	}

	if count == 0 {
		fmt.Println(" This user does not have access to any databases")
		return nil
	}
	fmt.Printf("\nTotal: %d database(s)\n", count)
	return nil
}
|
|
|
|
// Helper functions
|
|
|
|
// checkDatabaseExists reports whether the database table holds at least one
// row whose id equals databaseID. If the query or scan fails, it returns
// false together with that error, unwrapped.
func checkDatabaseExists(db *sql.DB, databaseID string) (bool, error) {
	row := db.QueryRow("SELECT COUNT(*) FROM database WHERE id = ?", databaseID)

	var matches int
	if scanErr := row.Scan(&matches); scanErr != nil {
		return false, scanErr
	}

	found := matches > 0
	return found, nil
}
|
|
|
|
// checkUserExists reports whether the users table holds at least one row
// whose id equals userID. If the query or scan fails, it returns false
// together with that error, unwrapped.
func checkUserExists(db *sql.DB, userID string) (bool, error) {
	const countQuery = "SELECT COUNT(*) FROM users WHERE id = ?"

	var total int
	err := db.QueryRow(countQuery, userID).Scan(&total)
	if err == nil {
		return total > 0, nil
	}
	return false, err
}
|
|
|
|
// checkAssociationExists reports whether database_users already links userID
// to databaseID, i.e. whether at least one row has db_id = databaseID and
// uid = userID. If the query or scan fails, it returns false together with
// that error, unwrapped.
func checkAssociationExists(db *sql.DB, databaseID, userID string) (bool, error) {
	const query = "SELECT COUNT(*) FROM database_users WHERE db_id = ? AND uid = ?"

	var links int
	if err := db.QueryRow(query, databaseID, userID).Scan(&links); err != nil {
		return false, err
	}
	return links > 0, nil
}
|