From ab255d5d585feea8326420ba0c8dac279ee94139 Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Sat, 7 Jun 2025 11:38:04 -0700 Subject: [PATCH] Junie: add user permissions to databases. --- api/auth.go | 85 ++++- client/database.go | 93 ++++- cmd/mcias-client/database.go | 136 +++++--- cmd/mcias/database.go | 33 ++ cmd/mcias/database_users.go | 321 ++++++++++++++++++ cmd/mcias/root.go | 2 +- .../000002_add_database_users.down.sql | 4 + .../000002_add_database_users.up.sql | 12 + mcias.service | 6 +- 9 files changed, 620 insertions(+), 72 deletions(-) create mode 100644 cmd/mcias/database_users.go create mode 100644 database/migrations/000002_add_database_users.down.sql create mode 100644 database/migrations/000002_add_database_users.up.sql diff --git a/api/auth.go b/api/auth.go index a95f55c..0eaa207 100644 --- a/api/auth.go +++ b/api/auth.go @@ -445,7 +445,7 @@ func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Reques } // Verify the token - _, err := s.verifyToken(username, token) + userID, err := s.verifyToken(username, token) if err != nil { s.sendError(w, "Invalid or expired token", http.StatusUnauthorized) return @@ -471,37 +471,86 @@ func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Reques return } - // Retrieve database credentials - // Use squirrel to build the query safely - query, args, err := squirrel.Select("id", "host", "port", "name", "user", "password"). - From("database"). - Limit(1). - ToSql() + // Get database ID from query parameter if provided + databaseID := r.URL.Query().Get("database_id") + + // Build the query to retrieve databases the user has access to + queryBuilder := squirrel.Select("d.id", "d.host", "d.port", "d.name", "d.user", "d.password"). + From("database d") + + // If the user is an admin, they can see all databases + isAdmin := false + for _, role := range user.Roles { + if role == "admin" { + isAdmin = true + break + } + } + + if !isAdmin { + // Non-admin users can only see databases they're explicitly associated with + queryBuilder = queryBuilder.Join("database_users du ON d.id = du.db_id"). + Where(squirrel.Eq{"du.uid": userID}) + } + + // If a specific database ID was requested, filter by that + if databaseID != "" { + queryBuilder = queryBuilder.Where(squirrel.Eq{"d.id": databaseID}) + } + + query, args, err := queryBuilder.ToSql() if err != nil { s.Logger.Printf("Query building error: %v", err) s.sendError(w, "Internal server error", http.StatusInternalServerError) return } - row := s.DB.QueryRow(query, args...) - - var id string - var creds DatabaseCredentials - err = row.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password) + rows, err := s.DB.Query(query, args...) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - s.sendError(w, "No database credentials found", http.StatusNotFound) - } else { - s.Logger.Printf("Database error: %v", err) + s.Logger.Printf("Database error: %v", err) + s.sendError(w, "Internal server error", http.StatusInternalServerError) + return + } + defer rows.Close() + + var databases []DatabaseCredentials + for rows.Next() { + var id string + var creds DatabaseCredentials + err = rows.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password) + if err != nil { + s.Logger.Printf("Row scanning error: %v", err) s.sendError(w, "Internal server error", http.StatusInternalServerError) + return + } + databases = append(databases, creds) + } + + if err = rows.Err(); err != nil { + s.Logger.Printf("Rows iteration error: %v", err) + s.sendError(w, "Internal server error", http.StatusInternalServerError) + return + } + + if len(databases) == 0 { + s.sendError(w, "No database credentials found", http.StatusNotFound) + return + } + + // If a specific database was requested, return just that one + if databaseID != "" && len(databases) == 1 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(databases[0]); err != nil { + s.Logger.Printf("Error encoding response: %v", err) } return } - // Return the credentials + // Otherwise return all accessible databases w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(creds); err != nil { + if err := json.NewEncoder(w).Encode(databases); err != nil { s.Logger.Printf("Error encoding response: %v", err) } } diff --git a/client/database.go b/client/database.go index 8b4a8f3..5586be1 100644 --- a/client/database.go +++ b/client/database.go @@ -19,8 +19,83 @@ type DatabaseCredentials struct { } // GetDatabaseCredentials retrieves database credentials from the MCIAS server. +// If databaseID is provided, it returns credentials for that specific database. +// If databaseID is empty, it returns the first database the user has access to. // This method requires the client to be authenticated (have a valid token). -func (c *Client) GetDatabaseCredentials(ctx context.Context) (*DatabaseCredentials, error) { +func (c *Client) GetDatabaseCredentials(ctx context.Context, databaseID string) (*DatabaseCredentials, error) { + if !c.IsAuthenticated() { + return nil, fmt.Errorf("client is not authenticated, call LoginWithPassword or LoginWithToken first") + } + + if c.Username == "" { + return nil, fmt.Errorf("username is not set, call LoginWithPassword or LoginWithToken first") + } + + // Build the URL with query parameters + baseURL := fmt.Sprintf("%s/v1/database/credentials", c.BaseURL) + params := url.Values{} + params.Add("username", c.Username) + if databaseID != "" { + params.Add("database_id", databaseID) + } + requestURL := fmt.Sprintf("%s?%s", baseURL, params.Encode()) + + // Create the request + req, err := http.NewRequestWithContext(ctx, "GET", requestURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Add authorization header + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.Token)) + + // Send the request + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Read the response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + // Check for errors + if resp.StatusCode != http.StatusOK { + var errResp ErrorResponse + if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil { + return nil, fmt.Errorf("API error: %s (code: %s)", errResp.Error, errResp.ErrorCode) + } + return nil, fmt.Errorf("API error: %s", resp.Status) + } + + // Try to parse as a single database first (when a specific database_id is requested) + var creds DatabaseCredentials + if err := json.Unmarshal(body, &creds); err == nil { + // Successfully parsed as a single database + return &creds, nil + } + + // If that fails, try to parse as an array of databases + var credsList []DatabaseCredentials + if err := json.Unmarshal(body, &credsList); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // If we got an empty list, return an error + if len(credsList) == 0 { + return nil, fmt.Errorf("no database credentials found") + } + + // Return the first database in the list + return &credsList[0], nil +} + +// GetDatabaseCredentialsList retrieves all database credentials the user has access to. +// This method requires the client to be authenticated (have a valid token). +func (c *Client) GetDatabaseCredentialsList(ctx context.Context) ([]DatabaseCredentials, error) { if !c.IsAuthenticated() { return nil, fmt.Errorf("client is not authenticated, call LoginWithPassword or LoginWithToken first") } @@ -66,11 +141,17 @@ func (c *Client) GetDatabaseCredentials(ctx context.Context) (*DatabaseCredentia return nil, fmt.Errorf("API error: %s", resp.Status) } - // Parse the response - var creds DatabaseCredentials - if err := json.Unmarshal(body, &creds); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) + // Try to parse as an array of databases + var credsList []DatabaseCredentials + if err := json.Unmarshal(body, &credsList); err != nil { + // If that fails, try to parse as a single database + var creds DatabaseCredentials + if err := json.Unmarshal(body, &creds); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + // Return as a single-item list + return []DatabaseCredentials{creds}, nil } - return &creds, nil + return credsList, nil } diff --git a/cmd/mcias-client/database.go b/cmd/mcias-client/database.go index c5be9ed..27c3ccd 100644 --- a/cmd/mcias-client/database.go +++ b/cmd/mcias-client/database.go @@ -4,27 +4,20 @@ import ( "context" "encoding/json" "fmt" - "io" - "net/http" "os" "time" + "git.wntrmute.dev/kyle/mcias/client" "github.com/spf13/cobra" "github.com/spf13/viper" ) -type DatabaseCredentials struct { - Host string `json:"host"` - Port int `json:"port"` - Name string `json:"name"` - User string `json:"user"` - Password string `json:"password"` -} - var ( dbUsername string dbToken string useStored bool + databaseID string + outputJSON bool ) var databaseCmd = &cobra.Command{ @@ -37,25 +30,49 @@ var getCredentialsCmd = &cobra.Command{ Use: "credentials", Short: "Get database credentials", Long: `Retrieve database credentials from the MCIAS system. -This command requires authentication with a username and token.`, +This command requires authentication with a username and token. +If database-id is provided, it returns credentials for that specific database. +If database-id is not provided, it returns the first database the user has access to.`, Run: func(cmd *cobra.Command, args []string) { getCredentials() }, } +var listCredentialsCmd = &cobra.Command{ + Use: "list", + Short: "List all accessible database credentials", + Long: `List all database credentials the user has access to. +This command requires authentication with a username and token.`, + Run: func(cmd *cobra.Command, args []string) { + listCredentials() + }, +} + func init() { rootCmd.AddCommand(databaseCmd) databaseCmd.AddCommand(getCredentialsCmd) + databaseCmd.AddCommand(listCredentialsCmd) + // Flags for getCredentialsCmd getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication") getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token") getCredentialsCmd.Flags().BoolVarP(&useStored, "use-stored", "s", false, "Use stored token from previous login") + getCredentialsCmd.Flags().StringVarP(&databaseID, "database-id", "d", "", "ID of the specific database to retrieve") + getCredentialsCmd.Flags().BoolVarP(&outputJSON, "json", "j", false, "Output in JSON format") + + // Flags for listCredentialsCmd + listCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication") + listCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token") + listCredentialsCmd.Flags().BoolVarP(&useStored, "use-stored", "s", false, "Use stored token from previous login") + listCredentialsCmd.Flags().BoolVarP(&outputJSON, "json", "j", false, "Output in JSON format") // Make username required only if not using stored token getCredentialsCmd.MarkFlagsMutuallyExclusive("token", "use-stored") + listCredentialsCmd.MarkFlagsMutuallyExclusive("token", "use-stored") } -func getCredentials() { +// createClient creates and configures a new MCIAS client +func createClient() *client.Client { // If using stored token, load it from the token file if useStored { tokenInfo, err := loadToken() @@ -83,50 +100,43 @@ func getCredentials() { serverAddr = "http://localhost:8080" } - url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername) + // Create a new client with the appropriate options + c := client.NewClient( + client.WithBaseURL(serverAddr), + client.WithUsername(dbUsername), + client.WithToken(dbToken), + ) + + return c +} + +func getCredentials() { + // Create a new client + c := createClient() // Create a context with timeout ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + // Get database credentials + creds, err := c.GetDatabaseCredentials(ctx, databaseID) if err != nil { - fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err) + fmt.Fprintf(os.Stderr, "Error retrieving database credentials: %v\n", err) os.Exit(1) } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", dbToken)) - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err) - os.Exit(1) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err) - os.Exit(1) - } - - if resp.StatusCode != http.StatusOK { - var errResp ErrorResponse - if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil { - fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error) - } else { - fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status) + // Output in JSON format if requested + if outputJSON { + jsonData, err := json.MarshalIndent(creds, "", " ") + if err != nil { + fmt.Fprintf(os.Stderr, "Error formatting JSON: %v\n", err) + os.Exit(1) } - os.Exit(1) - } - - var creds DatabaseCredentials - if unmarshalErr := json.Unmarshal(body, &creds); unmarshalErr != nil { - fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr) - os.Exit(1) + fmt.Println(string(jsonData)) + return } + // Output in human-readable format fmt.Println("Database Credentials:") fmt.Printf("Host: %s\n", creds.Host) fmt.Printf("Port: %d\n", creds.Port) @@ -135,3 +145,41 @@ func getCredentials() { fmt.Printf("Password: %s\n", creds.Password) } +func listCredentials() { + // Create a new client + c := createClient() + + // Create a context with timeout + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Get all database credentials + credsList, err := c.GetDatabaseCredentialsList(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "Error retrieving database credentials: %v\n", err) + os.Exit(1) + } + + // Output in JSON format if requested + if outputJSON { + jsonData, err := json.MarshalIndent(credsList, "", " ") + if err != nil { + fmt.Fprintf(os.Stderr, "Error formatting JSON: %v\n", err) + os.Exit(1) + } + fmt.Println(string(jsonData)) + return + } + + // Output in human-readable format + fmt.Printf("Found %d database(s):\n\n", len(credsList)) + for i, creds := range credsList { + fmt.Printf("Database #%d:\n", i+1) + fmt.Printf(" Host: %s\n", creds.Host) + fmt.Printf(" Port: %d\n", creds.Port) + fmt.Printf(" Name: %s\n", creds.Name) + fmt.Printf(" User: %s\n", creds.User) + fmt.Printf(" Password: %s\n", creds.Password) + fmt.Println() + } +} diff --git a/cmd/mcias/database.go b/cmd/mcias/database.go index abc472f..97e94d7 100644 --- a/cmd/mcias/database.go +++ b/cmd/mcias/database.go @@ -37,6 +37,36 @@ var databaseCmd = &cobra.Command{ Long: `Commands for managing database credentials in the MCIAS system.`, } +var addUserCmd = &cobra.Command{ + Use: "add-user [database-id] [user-id]", + Short: "Associate a user with a database", + Long: `Associate a user with a database, allowing them to read its credentials.`, + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + addUserToDatabase(args[0], args[1]) + }, +} + +var removeUserCmd = &cobra.Command{ + Use: "remove-user [database-id] [user-id]", + Short: "Remove a user's association with a database", + Long: `Remove a user's association with a database, preventing them from reading its credentials.`, + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + removeUserFromDatabase(args[0], args[1]) + }, +} + +var listUsersCmd = &cobra.Command{ + Use: "list-users [database-id]", + Short: "List users associated with a database", + Long: `List all users who have access to read the credentials of a specific database.`, + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + listDatabaseUsers(args[0]) + }, +} + var getCredentialsCmd = &cobra.Command{ Use: "credentials", Short: "Get database credentials", @@ -51,6 +81,9 @@ This command requires authentication with a username and token.`, func init() { rootCmd.AddCommand(databaseCmd) databaseCmd.AddCommand(getCredentialsCmd) + databaseCmd.AddCommand(addUserCmd) + databaseCmd.AddCommand(removeUserCmd) + databaseCmd.AddCommand(listUsersCmd) getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication") getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token") diff --git a/cmd/mcias/database_users.go b/cmd/mcias/database_users.go new file mode 100644 index 0000000..71c5293 --- /dev/null +++ b/cmd/mcias/database_users.go @@ -0,0 +1,321 @@ +package main + +import ( + "database/sql" + "fmt" + "log" + "os" + + "github.com/Masterminds/squirrel" + "github.com/oklog/ulid/v2" + "github.com/spf13/cobra" +) + +var ( + dbUserCmd = &cobra.Command{ + Use: "db-user", + Short: "Manage database user associations", + Long: `Commands for managing which users can access which database credentials.`, + } + + addDbUserCmd = &cobra.Command{ + Use: "add [database-id] [user-id]", + Short: "Associate a user with a database", + Long: `Associate a user with a database, allowing them to read its credentials.`, + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + addUserToDatabase(args[0], args[1]) + }, + } + + removeDbUserCmd = &cobra.Command{ + Use: "remove [database-id] [user-id]", + Short: "Remove a user's association with a database", + Long: `Remove a user's association with a database, preventing them from reading its credentials.`, + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + removeUserFromDatabase(args[0], args[1]) + }, + } + + listDbUsersCmd = &cobra.Command{ + Use: "list [database-id]", + Short: "List users associated with a database", + Long: `List all users who have access to read the credentials of a specific database.`, + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + listDatabaseUsers(args[0]) + }, + } + + listUserDbsCmd = &cobra.Command{ + Use: "list-user-dbs [user-id]", + Short: "List databases a user can access", + Long: `List all databases that a specific user has access to.`, + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + listUserDatabases(args[0]) + }, + } +) + +// nolint:gochecknoinits // This is a standard pattern in Cobra applications +func init() { + rootCmd.AddCommand(dbUserCmd) + dbUserCmd.AddCommand(addDbUserCmd) + dbUserCmd.AddCommand(removeDbUserCmd) + dbUserCmd.AddCommand(listDbUsersCmd) + dbUserCmd.AddCommand(listUserDbsCmd) +} + +// addUserToDatabase associates a user with a database, allowing them to read its credentials +func addUserToDatabase(databaseID, userID string) { + logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags) + + // Open the database + db, err := openDatabase(dbPath) + if err != nil { + logger.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Check if the database exists + exists, err := checkDatabaseExists(db, databaseID) + if err != nil { + logger.Fatalf("Failed to check if database exists: %v", err) + } + if !exists { + logger.Fatalf("Database with ID %s does not exist", databaseID) + } + + // Check if the user exists + exists, err = checkUserExists(db, userID) + if err != nil { + logger.Fatalf("Failed to check if user exists: %v", err) + } + if !exists { + logger.Fatalf("User with ID %s does not exist", userID) + } + + // Check if the association already exists + exists, err = checkAssociationExists(db, databaseID, userID) + if err != nil { + logger.Fatalf("Failed to check if association exists: %v", err) + } + if exists { + logger.Printf("User %s already has access to database %s", userID, databaseID) + return + } + + // Create a new association + id := ulid.Make().String() + query, args, err := squirrel.Insert("database_users"). + Columns("id", "db_id", "uid"). + Values(id, databaseID, userID). + ToSql() + if err != nil { + logger.Fatalf("Failed to build query: %v", err) + } + + _, err = db.Exec(query, args...) + if err != nil { + logger.Fatalf("Failed to add user to database: %v", err) + } + + logger.Printf("User %s now has access to database %s", userID, databaseID) +} + +// removeUserFromDatabase removes a user's association with a database +func removeUserFromDatabase(databaseID, userID string) { + logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags) + + // Open the database + db, err := openDatabase(dbPath) + if err != nil { + logger.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Check if the association exists + exists, err := checkAssociationExists(db, databaseID, userID) + if err != nil { + logger.Fatalf("Failed to check if association exists: %v", err) + } + if !exists { + logger.Printf("User %s does not have access to database %s", userID, databaseID) + return + } + + // Remove the association + query, args, err := squirrel.Delete("database_users"). + Where(squirrel.Eq{"db_id": databaseID, "uid": userID}). + ToSql() + if err != nil { + logger.Fatalf("Failed to build query: %v", err) + } + + _, err = db.Exec(query, args...) + if err != nil { + logger.Fatalf("Failed to remove user from database: %v", err) + } + + logger.Printf("User %s no longer has access to database %s", userID, databaseID) +} + +// listDatabaseUsers lists all users who have access to a specific database +func listDatabaseUsers(databaseID string) { + logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags) + + // Open the database + db, err := openDatabase(dbPath) + if err != nil { + logger.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Check if the database exists + exists, err := checkDatabaseExists(db, databaseID) + if err != nil { + logger.Fatalf("Failed to check if database exists: %v", err) + } + if !exists { + logger.Fatalf("Database with ID %s does not exist", databaseID) + } + + // Get the database name for display + var dbName string + err = db.QueryRow("SELECT name FROM database WHERE id = ?", databaseID).Scan(&dbName) + if err != nil { + logger.Fatalf("Failed to get database name: %v", err) + } + + // Query all users who have access to this database + query, args, err := squirrel.Select("u.id", "u.user"). + From("users u"). + Join("database_users du ON u.id = du.uid"). + Where(squirrel.Eq{"du.db_id": databaseID}). + OrderBy("u.user"). + ToSql() + if err != nil { + logger.Fatalf("Failed to build query: %v", err) + } + + rows, err := db.Query(query, args...) + if err != nil { + logger.Fatalf("Failed to query users: %v", err) + } + defer rows.Close() + + fmt.Printf("Users with access to database '%s' (ID: %s):\n\n", dbName, databaseID) + + count := 0 + for rows.Next() { + var id, username string + if err := rows.Scan(&id, &username); err != nil { + logger.Fatalf("Failed to scan row: %v", err) + } + fmt.Printf(" %s (ID: %s)\n", username, id) + count++ + } + + if count == 0 { + fmt.Println(" No users have access to this database") + } else { + fmt.Printf("\nTotal: %d user(s)\n", count) + } +} + +// listUserDatabases lists all databases a specific user has access to +func listUserDatabases(userID string) { + logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags) + + // Open the database + db, err := openDatabase(dbPath) + if err != nil { + logger.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Check if the user exists + exists, err := checkUserExists(db, userID) + if err != nil { + logger.Fatalf("Failed to check if user exists: %v", err) + } + if !exists { + logger.Fatalf("User with ID %s does not exist", userID) + } + + // Get the username for display + var username string + err = db.QueryRow("SELECT user FROM users WHERE id = ?", userID).Scan(&username) + if err != nil { + logger.Fatalf("Failed to get username: %v", err) + } + + // Query all databases this user has access to + query, args, err := squirrel.Select("d.id", "d.name", "d.host"). + From("database d"). + Join("database_users du ON d.id = du.db_id"). + Where(squirrel.Eq{"du.uid": userID}). + OrderBy("d.name"). + ToSql() + if err != nil { + logger.Fatalf("Failed to build query: %v", err) + } + + rows, err := db.Query(query, args...) + if err != nil { + logger.Fatalf("Failed to query databases: %v", err) + } + defer rows.Close() + + fmt.Printf("Databases accessible by user '%s' (ID: %s):\n\n", username, userID) + + count := 0 + for rows.Next() { + var id, name, host string + if err := rows.Scan(&id, &name, &host); err != nil { + logger.Fatalf("Failed to scan row: %v", err) + } + fmt.Printf(" %s (%s) (ID: %s)\n", name, host, id) + count++ + } + + if count == 0 { + fmt.Println(" This user does not have access to any databases") + } else { + fmt.Printf("\nTotal: %d database(s)\n", count) + } +} + +// Helper functions + +// checkDatabaseExists checks if a database with the given ID exists +func checkDatabaseExists(db *sql.DB, databaseID string) (bool, error) { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM database WHERE id = ?", databaseID).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + +// checkUserExists checks if a user with the given ID exists +func checkUserExists(db *sql.DB, userID string) (bool, error) { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM users WHERE id = ?", userID).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} + +// checkAssociationExists checks if a user is already associated with a database +func checkAssociationExists(db *sql.DB, databaseID, userID string) (bool, error) { + var count int + err := db.QueryRow("SELECT COUNT(*) FROM database_users WHERE db_id = ? AND uid = ?", databaseID, userID).Scan(&count) + if err != nil { + return false, err + } + return count > 0, nil +} diff --git a/cmd/mcias/root.go b/cmd/mcias/root.go index 1b86ca8..86ae64a 100644 --- a/cmd/mcias/root.go +++ b/cmd/mcias/root.go @@ -42,7 +42,7 @@ func setupRootCommand() { rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.mcias.yaml)") rootCmd.PersistentFlags().StringVar(&dbPath, "db", "mcias.db", "Path to SQLite database file") - rootCmd.PersistentFlags().StringVar(&addr, "addr", ":8080", "Address to listen on") + rootCmd.PersistentFlags().StringVar(&addr, "addr", ":5000", "Address to listen on") if err := viper.BindPFlag("db", rootCmd.PersistentFlags().Lookup("db")); err != nil { fmt.Fprintf(os.Stderr, "Error binding db flag: %v\n", err) } diff --git a/database/migrations/000002_add_database_users.down.sql b/database/migrations/000002_add_database_users.down.sql new file mode 100644 index 0000000..08f70c4 --- /dev/null +++ b/database/migrations/000002_add_database_users.down.sql @@ -0,0 +1,4 @@ +-- Drop the database_users table and its indexes +DROP INDEX IF EXISTS idx_database_users_db_id; +DROP INDEX IF EXISTS idx_database_users_uid; +DROP TABLE IF EXISTS database_users; \ No newline at end of file diff --git a/database/migrations/000002_add_database_users.up.sql b/database/migrations/000002_add_database_users.up.sql new file mode 100644 index 0000000..8f551f4 --- /dev/null +++ b/database/migrations/000002_add_database_users.up.sql @@ -0,0 +1,12 @@ +-- Add database_users table to associate users with databases +CREATE TABLE database_users ( + id TEXT PRIMARY KEY, + uid TEXT NOT NULL, + db_id TEXT NOT NULL, + FOREIGN KEY(uid) REFERENCES users(id), + FOREIGN KEY(db_id) REFERENCES database(id) +); + +-- Add index for faster lookups +CREATE INDEX idx_database_users_uid ON database_users(uid); +CREATE INDEX idx_database_users_db_id ON database_users(db_id); diff --git a/mcias.service b/mcias.service index 5ea6153..4c9e3eb 100644 --- a/mcias.service +++ b/mcias.service @@ -6,8 +6,8 @@ After=network.target Type=simple User=mcias Group=mcias -WorkingDirectory=/opt/mcias -ExecStart=/opt/mcias/mcias server --db /opt/mcias/mcias.db +WorkingDirectory=/srv/mcias +ExecStart=/usr/local/bin/mcias/mcias server --db /srv/mcias/mcias.db Restart=on-failure RestartSec=5 StandardOutput=journal @@ -19,7 +19,7 @@ PrivateTmp=true ProtectSystem=full ProtectHome=true NoNewPrivileges=true -ReadWritePaths=/opt/mcias +ReadWritePaths=/srv/mcias [Install] WantedBy=multi-user.target \ No newline at end of file