package main import ( "database/sql" "fmt" "log" "os" "github.com/Masterminds/squirrel" "github.com/oklog/ulid/v2" "github.com/spf13/cobra" ) var ( dbUserCmd = &cobra.Command{ Use: "db-user", Short: "Manage database user associations", Long: `Commands for managing which users can access which database credentials.`, } addDbUserCmd = &cobra.Command{ Use: "add [database-id] [user-id]", Short: "Associate a user with a database", Long: `Associate a user with a database, allowing them to read its credentials.`, Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { addUserToDatabase(args[0], args[1]) }, } removeDbUserCmd = &cobra.Command{ Use: "remove [database-id] [user-id]", Short: "Remove a user's association with a database", Long: `Remove a user's association with a database, preventing them from reading its credentials.`, Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { removeUserFromDatabase(args[0], args[1]) }, } listDbUsersCmd = &cobra.Command{ Use: "list [database-id]", Short: "List users associated with a database", Long: `List all users who have access to read the credentials of a specific database.`, Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { listDatabaseUsers(args[0]) }, } listUserDbsCmd = &cobra.Command{ Use: "list-user-dbs [user-id]", Short: "List databases a user can access", Long: `List all databases that a specific user has access to.`, Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { listUserDatabases(args[0]) }, } ) // nolint:gochecknoinits // This is a standard pattern in Cobra applications func init() { rootCmd.AddCommand(dbUserCmd) dbUserCmd.AddCommand(addDbUserCmd) dbUserCmd.AddCommand(removeDbUserCmd) dbUserCmd.AddCommand(listDbUsersCmd) dbUserCmd.AddCommand(listUserDbsCmd) } // addUserToDatabase associates a user with a database, allowing them to read its credentials func addUserToDatabase(databaseID, userID string) { logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags) // Open the database db, err := openDatabase(dbPath) if err != nil { logger.Fatalf("Failed to open database: %v", err) } defer db.Close() // Check if the database exists exists, err := checkDatabaseExists(db, databaseID) if err != nil { logger.Fatalf("Failed to check if database exists: %v", err) } if !exists { logger.Fatalf("Database with ID %s does not exist", databaseID) } // Check if the user exists exists, err = checkUserExists(db, userID) if err != nil { logger.Fatalf("Failed to check if user exists: %v", err) } if !exists { logger.Fatalf("User with ID %s does not exist", userID) } // Check if the association already exists exists, err = checkAssociationExists(db, databaseID, userID) if err != nil { logger.Fatalf("Failed to check if association exists: %v", err) } if exists { logger.Printf("User %s already has access to database %s", userID, databaseID) return } // Create a new association id := ulid.Make().String() query, args, err := squirrel.Insert("database_users"). Columns("id", "db_id", "uid"). Values(id, databaseID, userID). ToSql() if err != nil { logger.Fatalf("Failed to build query: %v", err) } _, err = db.Exec(query, args...) if err != nil { logger.Fatalf("Failed to add user to database: %v", err) } logger.Printf("User %s now has access to database %s", userID, databaseID) } // removeUserFromDatabase removes a user's association with a database func removeUserFromDatabase(databaseID, userID string) { logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags) // Open the database db, err := openDatabase(dbPath) if err != nil { logger.Fatalf("Failed to open database: %v", err) } defer db.Close() // Check if the association exists exists, err := checkAssociationExists(db, databaseID, userID) if err != nil { logger.Fatalf("Failed to check if association exists: %v", err) } if !exists { logger.Printf("User %s does not have access to database %s", userID, databaseID) return } // Remove the association query, args, err := squirrel.Delete("database_users"). Where(squirrel.Eq{"db_id": databaseID, "uid": userID}). ToSql() if err != nil { logger.Fatalf("Failed to build query: %v", err) } _, err = db.Exec(query, args...) if err != nil { logger.Fatalf("Failed to remove user from database: %v", err) } logger.Printf("User %s no longer has access to database %s", userID, databaseID) } // listDatabaseUsers lists all users who have access to a specific database func listDatabaseUsers(databaseID string) { logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags) // Open the database db, err := openDatabase(dbPath) if err != nil { logger.Fatalf("Failed to open database: %v", err) } defer db.Close() // Check if the database exists exists, err := checkDatabaseExists(db, databaseID) if err != nil { logger.Fatalf("Failed to check if database exists: %v", err) } if !exists { logger.Fatalf("Database with ID %s does not exist", databaseID) } // Get the database name for display var dbName string err = db.QueryRow("SELECT name FROM database WHERE id = ?", databaseID).Scan(&dbName) if err != nil { logger.Fatalf("Failed to get database name: %v", err) } // Query all users who have access to this database query, args, err := squirrel.Select("u.id", "u.user"). From("users u"). Join("database_users du ON u.id = du.uid"). Where(squirrel.Eq{"du.db_id": databaseID}). OrderBy("u.user"). ToSql() if err != nil { logger.Fatalf("Failed to build query: %v", err) } rows, err := db.Query(query, args...) if err != nil { logger.Fatalf("Failed to query users: %v", err) } defer rows.Close() fmt.Printf("Users with access to database '%s' (ID: %s):\n\n", dbName, databaseID) count := 0 for rows.Next() { var id, username string if err := rows.Scan(&id, &username); err != nil { logger.Fatalf("Failed to scan row: %v", err) } fmt.Printf(" %s (ID: %s)\n", username, id) count++ } if count == 0 { fmt.Println(" No users have access to this database") } else { fmt.Printf("\nTotal: %d user(s)\n", count) } } // listUserDatabases lists all databases a specific user has access to func listUserDatabases(userID string) { logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags) // Open the database db, err := openDatabase(dbPath) if err != nil { logger.Fatalf("Failed to open database: %v", err) } defer db.Close() // Check if the user exists exists, err := checkUserExists(db, userID) if err != nil { logger.Fatalf("Failed to check if user exists: %v", err) } if !exists { logger.Fatalf("User with ID %s does not exist", userID) } // Get the username for display var username string err = db.QueryRow("SELECT user FROM users WHERE id = ?", userID).Scan(&username) if err != nil { logger.Fatalf("Failed to get username: %v", err) } // Query all databases this user has access to query, args, err := squirrel.Select("d.id", "d.name", "d.host"). From("database d"). Join("database_users du ON d.id = du.db_id"). Where(squirrel.Eq{"du.uid": userID}). OrderBy("d.name"). ToSql() if err != nil { logger.Fatalf("Failed to build query: %v", err) } rows, err := db.Query(query, args...) if err != nil { logger.Fatalf("Failed to query databases: %v", err) } defer rows.Close() fmt.Printf("Databases accessible by user '%s' (ID: %s):\n\n", username, userID) count := 0 for rows.Next() { var id, name, host string if err := rows.Scan(&id, &name, &host); err != nil { logger.Fatalf("Failed to scan row: %v", err) } fmt.Printf(" %s (%s) (ID: %s)\n", name, host, id) count++ } if count == 0 { fmt.Println(" This user does not have access to any databases") } else { fmt.Printf("\nTotal: %d database(s)\n", count) } } // Helper functions // checkDatabaseExists checks if a database with the given ID exists func checkDatabaseExists(db *sql.DB, databaseID string) (bool, error) { var count int err := db.QueryRow("SELECT COUNT(*) FROM database WHERE id = ?", databaseID).Scan(&count) if err != nil { return false, err } return count > 0, nil } // checkUserExists checks if a user with the given ID exists func checkUserExists(db *sql.DB, userID string) (bool, error) { var count int err := db.QueryRow("SELECT COUNT(*) FROM users WHERE id = ?", userID).Scan(&count) if err != nil { return false, err } return count > 0, nil } // checkAssociationExists checks if a user is already associated with a database func checkAssociationExists(db *sql.DB, databaseID, userID string) (bool, error) { var count int err := db.QueryRow("SELECT COUNT(*) FROM database_users WHERE db_id = ? AND uid = ?", databaseID, userID).Scan(&count) if err != nil { return false, err } return count > 0, nil }