Junie: add user permissions to databases.

This commit is contained in:
2025-06-07 11:38:04 -07:00
parent ccdbcce9c0
commit ab255d5d58
9 changed files with 620 additions and 72 deletions

View File

@@ -4,27 +4,20 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"git.wntrmute.dev/kyle/mcias/client"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type DatabaseCredentials struct {
Host string `json:"host"`
Port int `json:"port"`
Name string `json:"name"`
User string `json:"user"`
Password string `json:"password"`
}
var (
dbUsername string
dbToken string
useStored bool
databaseID string
outputJSON bool
)
var databaseCmd = &cobra.Command{
@@ -37,25 +30,49 @@ var getCredentialsCmd = &cobra.Command{
Use: "credentials",
Short: "Get database credentials",
Long: `Retrieve database credentials from the MCIAS system.
This command requires authentication with a username and token.`,
This command requires authentication with a username and token.
If database-id is provided, it returns credentials for that specific database.
If database-id is not provided, it returns the first database the user has access to.`,
Run: func(cmd *cobra.Command, args []string) {
getCredentials()
},
}
var listCredentialsCmd = &cobra.Command{
Use: "list",
Short: "List all accessible database credentials",
Long: `List all database credentials the user has access to.
This command requires authentication with a username and token.`,
Run: func(cmd *cobra.Command, args []string) {
listCredentials()
},
}
func init() {
rootCmd.AddCommand(databaseCmd)
databaseCmd.AddCommand(getCredentialsCmd)
databaseCmd.AddCommand(listCredentialsCmd)
// Flags for getCredentialsCmd
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
getCredentialsCmd.Flags().BoolVarP(&useStored, "use-stored", "s", false, "Use stored token from previous login")
getCredentialsCmd.Flags().StringVarP(&databaseID, "database-id", "d", "", "ID of the specific database to retrieve")
getCredentialsCmd.Flags().BoolVarP(&outputJSON, "json", "j", false, "Output in JSON format")
// Flags for listCredentialsCmd
listCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
listCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
listCredentialsCmd.Flags().BoolVarP(&useStored, "use-stored", "s", false, "Use stored token from previous login")
listCredentialsCmd.Flags().BoolVarP(&outputJSON, "json", "j", false, "Output in JSON format")
// Make username required only if not using stored token
getCredentialsCmd.MarkFlagsMutuallyExclusive("token", "use-stored")
listCredentialsCmd.MarkFlagsMutuallyExclusive("token", "use-stored")
}
func getCredentials() {
// createClient creates and configures a new MCIAS client
func createClient() *client.Client {
// If using stored token, load it from the token file
if useStored {
tokenInfo, err := loadToken()
@@ -83,50 +100,43 @@ func getCredentials() {
serverAddr = "http://localhost:8080"
}
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
// Create a new client with the appropriate options
c := client.NewClient(
client.WithBaseURL(serverAddr),
client.WithUsername(dbUsername),
client.WithToken(dbToken),
)
return c
}
func getCredentials() {
// Create a new client
c := createClient()
// Create a context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
// Get database credentials
creds, err := c.GetDatabaseCredentials(ctx, databaseID)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
fmt.Fprintf(os.Stderr, "Error retrieving database credentials: %v\n", err)
os.Exit(1)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", dbToken))
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
os.Exit(1)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
os.Exit(1)
}
if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
} else {
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
// Output in JSON format if requested
if outputJSON {
jsonData, err := json.MarshalIndent(creds, "", " ")
if err != nil {
fmt.Fprintf(os.Stderr, "Error formatting JSON: %v\n", err)
os.Exit(1)
}
os.Exit(1)
}
var creds DatabaseCredentials
if unmarshalErr := json.Unmarshal(body, &creds); unmarshalErr != nil {
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
os.Exit(1)
fmt.Println(string(jsonData))
return
}
// Output in human-readable format
fmt.Println("Database Credentials:")
fmt.Printf("Host: %s\n", creds.Host)
fmt.Printf("Port: %d\n", creds.Port)
@@ -135,3 +145,41 @@ func getCredentials() {
fmt.Printf("Password: %s\n", creds.Password)
}
func listCredentials() {
// Create a new client
c := createClient()
// Create a context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Get all database credentials
credsList, err := c.GetDatabaseCredentialsList(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, "Error retrieving database credentials: %v\n", err)
os.Exit(1)
}
// Output in JSON format if requested
if outputJSON {
jsonData, err := json.MarshalIndent(credsList, "", " ")
if err != nil {
fmt.Fprintf(os.Stderr, "Error formatting JSON: %v\n", err)
os.Exit(1)
}
fmt.Println(string(jsonData))
return
}
// Output in human-readable format
fmt.Printf("Found %d database(s):\n\n", len(credsList))
for i, creds := range credsList {
fmt.Printf("Database #%d:\n", i+1)
fmt.Printf(" Host: %s\n", creds.Host)
fmt.Printf(" Port: %d\n", creds.Port)
fmt.Printf(" Name: %s\n", creds.Name)
fmt.Printf(" User: %s\n", creds.User)
fmt.Printf(" Password: %s\n", creds.Password)
fmt.Println()
}
}

View File

@@ -37,6 +37,36 @@ var databaseCmd = &cobra.Command{
Long: `Commands for managing database credentials in the MCIAS system.`,
}
var addUserCmd = &cobra.Command{
Use: "add-user [database-id] [user-id]",
Short: "Associate a user with a database",
Long: `Associate a user with a database, allowing them to read its credentials.`,
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
addUserToDatabase(args[0], args[1])
},
}
var removeUserCmd = &cobra.Command{
Use: "remove-user [database-id] [user-id]",
Short: "Remove a user's association with a database",
Long: `Remove a user's association with a database, preventing them from reading its credentials.`,
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
removeUserFromDatabase(args[0], args[1])
},
}
var listUsersCmd = &cobra.Command{
Use: "list-users [database-id]",
Short: "List users associated with a database",
Long: `List all users who have access to read the credentials of a specific database.`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
listDatabaseUsers(args[0])
},
}
var getCredentialsCmd = &cobra.Command{
Use: "credentials",
Short: "Get database credentials",
@@ -51,6 +81,9 @@ This command requires authentication with a username and token.`,
func init() {
rootCmd.AddCommand(databaseCmd)
databaseCmd.AddCommand(getCredentialsCmd)
databaseCmd.AddCommand(addUserCmd)
databaseCmd.AddCommand(removeUserCmd)
databaseCmd.AddCommand(listUsersCmd)
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")

321
cmd/mcias/database_users.go Normal file
View File

@@ -0,0 +1,321 @@
package main
import (
"database/sql"
"fmt"
"log"
"os"
"github.com/Masterminds/squirrel"
"github.com/oklog/ulid/v2"
"github.com/spf13/cobra"
)
var (
dbUserCmd = &cobra.Command{
Use: "db-user",
Short: "Manage database user associations",
Long: `Commands for managing which users can access which database credentials.`,
}
addDbUserCmd = &cobra.Command{
Use: "add [database-id] [user-id]",
Short: "Associate a user with a database",
Long: `Associate a user with a database, allowing them to read its credentials.`,
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
addUserToDatabase(args[0], args[1])
},
}
removeDbUserCmd = &cobra.Command{
Use: "remove [database-id] [user-id]",
Short: "Remove a user's association with a database",
Long: `Remove a user's association with a database, preventing them from reading its credentials.`,
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
removeUserFromDatabase(args[0], args[1])
},
}
listDbUsersCmd = &cobra.Command{
Use: "list [database-id]",
Short: "List users associated with a database",
Long: `List all users who have access to read the credentials of a specific database.`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
listDatabaseUsers(args[0])
},
}
listUserDbsCmd = &cobra.Command{
Use: "list-user-dbs [user-id]",
Short: "List databases a user can access",
Long: `List all databases that a specific user has access to.`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
listUserDatabases(args[0])
},
}
)
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
func init() {
rootCmd.AddCommand(dbUserCmd)
dbUserCmd.AddCommand(addDbUserCmd)
dbUserCmd.AddCommand(removeDbUserCmd)
dbUserCmd.AddCommand(listDbUsersCmd)
dbUserCmd.AddCommand(listUserDbsCmd)
}
// addUserToDatabase associates a user with a database, allowing them to read its credentials
func addUserToDatabase(databaseID, userID string) {
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
// Open the database
db, err := openDatabase(dbPath)
if err != nil {
logger.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Check if the database exists
exists, err := checkDatabaseExists(db, databaseID)
if err != nil {
logger.Fatalf("Failed to check if database exists: %v", err)
}
if !exists {
logger.Fatalf("Database with ID %s does not exist", databaseID)
}
// Check if the user exists
exists, err = checkUserExists(db, userID)
if err != nil {
logger.Fatalf("Failed to check if user exists: %v", err)
}
if !exists {
logger.Fatalf("User with ID %s does not exist", userID)
}
// Check if the association already exists
exists, err = checkAssociationExists(db, databaseID, userID)
if err != nil {
logger.Fatalf("Failed to check if association exists: %v", err)
}
if exists {
logger.Printf("User %s already has access to database %s", userID, databaseID)
return
}
// Create a new association
id := ulid.Make().String()
query, args, err := squirrel.Insert("database_users").
Columns("id", "db_id", "uid").
Values(id, databaseID, userID).
ToSql()
if err != nil {
logger.Fatalf("Failed to build query: %v", err)
}
_, err = db.Exec(query, args...)
if err != nil {
logger.Fatalf("Failed to add user to database: %v", err)
}
logger.Printf("User %s now has access to database %s", userID, databaseID)
}
// removeUserFromDatabase removes a user's association with a database
func removeUserFromDatabase(databaseID, userID string) {
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
// Open the database
db, err := openDatabase(dbPath)
if err != nil {
logger.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Check if the association exists
exists, err := checkAssociationExists(db, databaseID, userID)
if err != nil {
logger.Fatalf("Failed to check if association exists: %v", err)
}
if !exists {
logger.Printf("User %s does not have access to database %s", userID, databaseID)
return
}
// Remove the association
query, args, err := squirrel.Delete("database_users").
Where(squirrel.Eq{"db_id": databaseID, "uid": userID}).
ToSql()
if err != nil {
logger.Fatalf("Failed to build query: %v", err)
}
_, err = db.Exec(query, args...)
if err != nil {
logger.Fatalf("Failed to remove user from database: %v", err)
}
logger.Printf("User %s no longer has access to database %s", userID, databaseID)
}
// listDatabaseUsers lists all users who have access to a specific database
func listDatabaseUsers(databaseID string) {
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
// Open the database
db, err := openDatabase(dbPath)
if err != nil {
logger.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Check if the database exists
exists, err := checkDatabaseExists(db, databaseID)
if err != nil {
logger.Fatalf("Failed to check if database exists: %v", err)
}
if !exists {
logger.Fatalf("Database with ID %s does not exist", databaseID)
}
// Get the database name for display
var dbName string
err = db.QueryRow("SELECT name FROM database WHERE id = ?", databaseID).Scan(&dbName)
if err != nil {
logger.Fatalf("Failed to get database name: %v", err)
}
// Query all users who have access to this database
query, args, err := squirrel.Select("u.id", "u.user").
From("users u").
Join("database_users du ON u.id = du.uid").
Where(squirrel.Eq{"du.db_id": databaseID}).
OrderBy("u.user").
ToSql()
if err != nil {
logger.Fatalf("Failed to build query: %v", err)
}
rows, err := db.Query(query, args...)
if err != nil {
logger.Fatalf("Failed to query users: %v", err)
}
defer rows.Close()
fmt.Printf("Users with access to database '%s' (ID: %s):\n\n", dbName, databaseID)
count := 0
for rows.Next() {
var id, username string
if err := rows.Scan(&id, &username); err != nil {
logger.Fatalf("Failed to scan row: %v", err)
}
fmt.Printf(" %s (ID: %s)\n", username, id)
count++
}
if count == 0 {
fmt.Println(" No users have access to this database")
} else {
fmt.Printf("\nTotal: %d user(s)\n", count)
}
}
// listUserDatabases lists all databases a specific user has access to
func listUserDatabases(userID string) {
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
// Open the database
db, err := openDatabase(dbPath)
if err != nil {
logger.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Check if the user exists
exists, err := checkUserExists(db, userID)
if err != nil {
logger.Fatalf("Failed to check if user exists: %v", err)
}
if !exists {
logger.Fatalf("User with ID %s does not exist", userID)
}
// Get the username for display
var username string
err = db.QueryRow("SELECT user FROM users WHERE id = ?", userID).Scan(&username)
if err != nil {
logger.Fatalf("Failed to get username: %v", err)
}
// Query all databases this user has access to
query, args, err := squirrel.Select("d.id", "d.name", "d.host").
From("database d").
Join("database_users du ON d.id = du.db_id").
Where(squirrel.Eq{"du.uid": userID}).
OrderBy("d.name").
ToSql()
if err != nil {
logger.Fatalf("Failed to build query: %v", err)
}
rows, err := db.Query(query, args...)
if err != nil {
logger.Fatalf("Failed to query databases: %v", err)
}
defer rows.Close()
fmt.Printf("Databases accessible by user '%s' (ID: %s):\n\n", username, userID)
count := 0
for rows.Next() {
var id, name, host string
if err := rows.Scan(&id, &name, &host); err != nil {
logger.Fatalf("Failed to scan row: %v", err)
}
fmt.Printf(" %s (%s) (ID: %s)\n", name, host, id)
count++
}
if count == 0 {
fmt.Println(" This user does not have access to any databases")
} else {
fmt.Printf("\nTotal: %d database(s)\n", count)
}
}
// Helper functions
// checkDatabaseExists checks if a database with the given ID exists
func checkDatabaseExists(db *sql.DB, databaseID string) (bool, error) {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM database WHERE id = ?", databaseID).Scan(&count)
if err != nil {
return false, err
}
return count > 0, nil
}
// checkUserExists checks if a user with the given ID exists
func checkUserExists(db *sql.DB, userID string) (bool, error) {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE id = ?", userID).Scan(&count)
if err != nil {
return false, err
}
return count > 0, nil
}
// checkAssociationExists checks if a user is already associated with a database
func checkAssociationExists(db *sql.DB, databaseID, userID string) (bool, error) {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM database_users WHERE db_id = ? AND uid = ?", databaseID, userID).Scan(&count)
if err != nil {
return false, err
}
return count > 0, nil
}

View File

@@ -42,7 +42,7 @@ func setupRootCommand() {
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.mcias.yaml)")
rootCmd.PersistentFlags().StringVar(&dbPath, "db", "mcias.db", "Path to SQLite database file")
rootCmd.PersistentFlags().StringVar(&addr, "addr", ":8080", "Address to listen on")
rootCmd.PersistentFlags().StringVar(&addr, "addr", ":5000", "Address to listen on")
if err := viper.BindPFlag("db", rootCmd.PersistentFlags().Lookup("db")); err != nil {
fmt.Fprintf(os.Stderr, "Error binding db flag: %v\n", err)
}