Junie: add user permissions to databases.
This commit is contained in:
parent
ccdbcce9c0
commit
ab255d5d58
85
api/auth.go
85
api/auth.go
|
@ -445,7 +445,7 @@ func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Reques
|
|||
}
|
||||
|
||||
// Verify the token
|
||||
_, err := s.verifyToken(username, token)
|
||||
userID, err := s.verifyToken(username, token)
|
||||
if err != nil {
|
||||
s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
|
||||
return
|
||||
|
@ -471,37 +471,86 @@ func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Reques
|
|||
return
|
||||
}
|
||||
|
||||
// Retrieve database credentials
|
||||
// Use squirrel to build the query safely
|
||||
query, args, err := squirrel.Select("id", "host", "port", "name", "user", "password").
|
||||
From("database").
|
||||
Limit(1).
|
||||
ToSql()
|
||||
// Get database ID from query parameter if provided
|
||||
databaseID := r.URL.Query().Get("database_id")
|
||||
|
||||
// Build the query to retrieve databases the user has access to
|
||||
queryBuilder := squirrel.Select("d.id", "d.host", "d.port", "d.name", "d.user", "d.password").
|
||||
From("database d")
|
||||
|
||||
// If the user is an admin, they can see all databases
|
||||
isAdmin := false
|
||||
for _, role := range user.Roles {
|
||||
if role == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isAdmin {
|
||||
// Non-admin users can only see databases they're explicitly associated with
|
||||
queryBuilder = queryBuilder.Join("database_users du ON d.id = du.db_id").
|
||||
Where(squirrel.Eq{"du.uid": userID})
|
||||
}
|
||||
|
||||
// If a specific database ID was requested, filter by that
|
||||
if databaseID != "" {
|
||||
queryBuilder = queryBuilder.Where(squirrel.Eq{"d.id": databaseID})
|
||||
}
|
||||
|
||||
query, args, err := queryBuilder.ToSql()
|
||||
if err != nil {
|
||||
s.Logger.Printf("Query building error: %v", err)
|
||||
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
row := s.DB.QueryRow(query, args...)
|
||||
|
||||
var id string
|
||||
var creds DatabaseCredentials
|
||||
err = row.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password)
|
||||
rows, err := s.DB.Query(query, args...)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
s.sendError(w, "No database credentials found", http.StatusNotFound)
|
||||
} else {
|
||||
s.Logger.Printf("Database error: %v", err)
|
||||
s.Logger.Printf("Database error: %v", err)
|
||||
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var databases []DatabaseCredentials
|
||||
for rows.Next() {
|
||||
var id string
|
||||
var creds DatabaseCredentials
|
||||
err = rows.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password)
|
||||
if err != nil {
|
||||
s.Logger.Printf("Row scanning error: %v", err)
|
||||
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
databases = append(databases, creds)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
s.Logger.Printf("Rows iteration error: %v", err)
|
||||
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if len(databases) == 0 {
|
||||
s.sendError(w, "No database credentials found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If a specific database was requested, return just that one
|
||||
if databaseID != "" && len(databases) == 1 {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(databases[0]); err != nil {
|
||||
s.Logger.Printf("Error encoding response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Return the credentials
|
||||
// Otherwise return all accessible databases
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(creds); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(databases); err != nil {
|
||||
s.Logger.Printf("Error encoding response: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,8 +19,83 @@ type DatabaseCredentials struct {
|
|||
}
|
||||
|
||||
// GetDatabaseCredentials retrieves database credentials from the MCIAS server.
|
||||
// If databaseID is provided, it returns credentials for that specific database.
|
||||
// If databaseID is empty, it returns the first database the user has access to.
|
||||
// This method requires the client to be authenticated (have a valid token).
|
||||
func (c *Client) GetDatabaseCredentials(ctx context.Context) (*DatabaseCredentials, error) {
|
||||
func (c *Client) GetDatabaseCredentials(ctx context.Context, databaseID string) (*DatabaseCredentials, error) {
|
||||
if !c.IsAuthenticated() {
|
||||
return nil, fmt.Errorf("client is not authenticated, call LoginWithPassword or LoginWithToken first")
|
||||
}
|
||||
|
||||
if c.Username == "" {
|
||||
return nil, fmt.Errorf("username is not set, call LoginWithPassword or LoginWithToken first")
|
||||
}
|
||||
|
||||
// Build the URL with query parameters
|
||||
baseURL := fmt.Sprintf("%s/v1/database/credentials", c.BaseURL)
|
||||
params := url.Values{}
|
||||
params.Add("username", c.Username)
|
||||
if databaseID != "" {
|
||||
params.Add("database_id", databaseID)
|
||||
}
|
||||
requestURL := fmt.Sprintf("%s?%s", baseURL, params.Encode())
|
||||
|
||||
// Create the request
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Add authorization header
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.Token))
|
||||
|
||||
// Send the request
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read the response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp ErrorResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||
return nil, fmt.Errorf("API error: %s (code: %s)", errResp.Error, errResp.ErrorCode)
|
||||
}
|
||||
return nil, fmt.Errorf("API error: %s", resp.Status)
|
||||
}
|
||||
|
||||
// Try to parse as a single database first (when a specific database_id is requested)
|
||||
var creds DatabaseCredentials
|
||||
if err := json.Unmarshal(body, &creds); err == nil {
|
||||
// Successfully parsed as a single database
|
||||
return &creds, nil
|
||||
}
|
||||
|
||||
// If that fails, try to parse as an array of databases
|
||||
var credsList []DatabaseCredentials
|
||||
if err := json.Unmarshal(body, &credsList); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
// If we got an empty list, return an error
|
||||
if len(credsList) == 0 {
|
||||
return nil, fmt.Errorf("no database credentials found")
|
||||
}
|
||||
|
||||
// Return the first database in the list
|
||||
return &credsList[0], nil
|
||||
}
|
||||
|
||||
// GetDatabaseCredentialsList retrieves all database credentials the user has access to.
|
||||
// This method requires the client to be authenticated (have a valid token).
|
||||
func (c *Client) GetDatabaseCredentialsList(ctx context.Context) ([]DatabaseCredentials, error) {
|
||||
if !c.IsAuthenticated() {
|
||||
return nil, fmt.Errorf("client is not authenticated, call LoginWithPassword or LoginWithToken first")
|
||||
}
|
||||
|
@ -66,11 +141,17 @@ func (c *Client) GetDatabaseCredentials(ctx context.Context) (*DatabaseCredentia
|
|||
return nil, fmt.Errorf("API error: %s", resp.Status)
|
||||
}
|
||||
|
||||
// Parse the response
|
||||
var creds DatabaseCredentials
|
||||
if err := json.Unmarshal(body, &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
// Try to parse as an array of databases
|
||||
var credsList []DatabaseCredentials
|
||||
if err := json.Unmarshal(body, &credsList); err != nil {
|
||||
// If that fails, try to parse as a single database
|
||||
var creds DatabaseCredentials
|
||||
if err := json.Unmarshal(body, &creds); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
// Return as a single-item list
|
||||
return []DatabaseCredentials{creds}, nil
|
||||
}
|
||||
|
||||
return &creds, nil
|
||||
return credsList, nil
|
||||
}
|
||||
|
|
|
@ -4,27 +4,20 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/client"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type DatabaseCredentials struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Name string `json:"name"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var (
|
||||
dbUsername string
|
||||
dbToken string
|
||||
useStored bool
|
||||
databaseID string
|
||||
outputJSON bool
|
||||
)
|
||||
|
||||
var databaseCmd = &cobra.Command{
|
||||
|
@ -37,25 +30,49 @@ var getCredentialsCmd = &cobra.Command{
|
|||
Use: "credentials",
|
||||
Short: "Get database credentials",
|
||||
Long: `Retrieve database credentials from the MCIAS system.
|
||||
This command requires authentication with a username and token.`,
|
||||
This command requires authentication with a username and token.
|
||||
If database-id is provided, it returns credentials for that specific database.
|
||||
If database-id is not provided, it returns the first database the user has access to.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
getCredentials()
|
||||
},
|
||||
}
|
||||
|
||||
var listCredentialsCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List all accessible database credentials",
|
||||
Long: `List all database credentials the user has access to.
|
||||
This command requires authentication with a username and token.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
listCredentials()
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(databaseCmd)
|
||||
databaseCmd.AddCommand(getCredentialsCmd)
|
||||
databaseCmd.AddCommand(listCredentialsCmd)
|
||||
|
||||
// Flags for getCredentialsCmd
|
||||
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
|
||||
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
|
||||
getCredentialsCmd.Flags().BoolVarP(&useStored, "use-stored", "s", false, "Use stored token from previous login")
|
||||
getCredentialsCmd.Flags().StringVarP(&databaseID, "database-id", "d", "", "ID of the specific database to retrieve")
|
||||
getCredentialsCmd.Flags().BoolVarP(&outputJSON, "json", "j", false, "Output in JSON format")
|
||||
|
||||
// Flags for listCredentialsCmd
|
||||
listCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
|
||||
listCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
|
||||
listCredentialsCmd.Flags().BoolVarP(&useStored, "use-stored", "s", false, "Use stored token from previous login")
|
||||
listCredentialsCmd.Flags().BoolVarP(&outputJSON, "json", "j", false, "Output in JSON format")
|
||||
|
||||
// Make username required only if not using stored token
|
||||
getCredentialsCmd.MarkFlagsMutuallyExclusive("token", "use-stored")
|
||||
listCredentialsCmd.MarkFlagsMutuallyExclusive("token", "use-stored")
|
||||
}
|
||||
|
||||
func getCredentials() {
|
||||
// createClient creates and configures a new MCIAS client
|
||||
func createClient() *client.Client {
|
||||
// If using stored token, load it from the token file
|
||||
if useStored {
|
||||
tokenInfo, err := loadToken()
|
||||
|
@ -83,50 +100,43 @@ func getCredentials() {
|
|||
serverAddr = "http://localhost:8080"
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
|
||||
// Create a new client with the appropriate options
|
||||
c := client.NewClient(
|
||||
client.WithBaseURL(serverAddr),
|
||||
client.WithUsername(dbUsername),
|
||||
client.WithToken(dbToken),
|
||||
)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func getCredentials() {
|
||||
// Create a new client
|
||||
c := createClient()
|
||||
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
// Get database credentials
|
||||
creds, err := c.GetDatabaseCredentials(ctx, databaseID)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Error retrieving database credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", dbToken))
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp ErrorResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
|
||||
// Output in JSON format if requested
|
||||
if outputJSON {
|
||||
jsonData, err := json.MarshalIndent(creds, "", " ")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error formatting JSON: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var creds DatabaseCredentials
|
||||
if unmarshalErr := json.Unmarshal(body, &creds); unmarshalErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
|
||||
os.Exit(1)
|
||||
fmt.Println(string(jsonData))
|
||||
return
|
||||
}
|
||||
|
||||
// Output in human-readable format
|
||||
fmt.Println("Database Credentials:")
|
||||
fmt.Printf("Host: %s\n", creds.Host)
|
||||
fmt.Printf("Port: %d\n", creds.Port)
|
||||
|
@ -135,3 +145,41 @@ func getCredentials() {
|
|||
fmt.Printf("Password: %s\n", creds.Password)
|
||||
}
|
||||
|
||||
func listCredentials() {
|
||||
// Create a new client
|
||||
c := createClient()
|
||||
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Get all database credentials
|
||||
credsList, err := c.GetDatabaseCredentialsList(ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error retrieving database credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Output in JSON format if requested
|
||||
if outputJSON {
|
||||
jsonData, err := json.MarshalIndent(credsList, "", " ")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error formatting JSON: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println(string(jsonData))
|
||||
return
|
||||
}
|
||||
|
||||
// Output in human-readable format
|
||||
fmt.Printf("Found %d database(s):\n\n", len(credsList))
|
||||
for i, creds := range credsList {
|
||||
fmt.Printf("Database #%d:\n", i+1)
|
||||
fmt.Printf(" Host: %s\n", creds.Host)
|
||||
fmt.Printf(" Port: %d\n", creds.Port)
|
||||
fmt.Printf(" Name: %s\n", creds.Name)
|
||||
fmt.Printf(" User: %s\n", creds.User)
|
||||
fmt.Printf(" Password: %s\n", creds.Password)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,6 +37,36 @@ var databaseCmd = &cobra.Command{
|
|||
Long: `Commands for managing database credentials in the MCIAS system.`,
|
||||
}
|
||||
|
||||
var addUserCmd = &cobra.Command{
|
||||
Use: "add-user [database-id] [user-id]",
|
||||
Short: "Associate a user with a database",
|
||||
Long: `Associate a user with a database, allowing them to read its credentials.`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
addUserToDatabase(args[0], args[1])
|
||||
},
|
||||
}
|
||||
|
||||
var removeUserCmd = &cobra.Command{
|
||||
Use: "remove-user [database-id] [user-id]",
|
||||
Short: "Remove a user's association with a database",
|
||||
Long: `Remove a user's association with a database, preventing them from reading its credentials.`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
removeUserFromDatabase(args[0], args[1])
|
||||
},
|
||||
}
|
||||
|
||||
var listUsersCmd = &cobra.Command{
|
||||
Use: "list-users [database-id]",
|
||||
Short: "List users associated with a database",
|
||||
Long: `List all users who have access to read the credentials of a specific database.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
listDatabaseUsers(args[0])
|
||||
},
|
||||
}
|
||||
|
||||
var getCredentialsCmd = &cobra.Command{
|
||||
Use: "credentials",
|
||||
Short: "Get database credentials",
|
||||
|
@ -51,6 +81,9 @@ This command requires authentication with a username and token.`,
|
|||
func init() {
|
||||
rootCmd.AddCommand(databaseCmd)
|
||||
databaseCmd.AddCommand(getCredentialsCmd)
|
||||
databaseCmd.AddCommand(addUserCmd)
|
||||
databaseCmd.AddCommand(removeUserCmd)
|
||||
databaseCmd.AddCommand(listUsersCmd)
|
||||
|
||||
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
|
||||
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
|
||||
|
|
|
@ -0,0 +1,321 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/oklog/ulid/v2"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
dbUserCmd = &cobra.Command{
|
||||
Use: "db-user",
|
||||
Short: "Manage database user associations",
|
||||
Long: `Commands for managing which users can access which database credentials.`,
|
||||
}
|
||||
|
||||
addDbUserCmd = &cobra.Command{
|
||||
Use: "add [database-id] [user-id]",
|
||||
Short: "Associate a user with a database",
|
||||
Long: `Associate a user with a database, allowing them to read its credentials.`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
addUserToDatabase(args[0], args[1])
|
||||
},
|
||||
}
|
||||
|
||||
removeDbUserCmd = &cobra.Command{
|
||||
Use: "remove [database-id] [user-id]",
|
||||
Short: "Remove a user's association with a database",
|
||||
Long: `Remove a user's association with a database, preventing them from reading its credentials.`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
removeUserFromDatabase(args[0], args[1])
|
||||
},
|
||||
}
|
||||
|
||||
listDbUsersCmd = &cobra.Command{
|
||||
Use: "list [database-id]",
|
||||
Short: "List users associated with a database",
|
||||
Long: `List all users who have access to read the credentials of a specific database.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
listDatabaseUsers(args[0])
|
||||
},
|
||||
}
|
||||
|
||||
listUserDbsCmd = &cobra.Command{
|
||||
Use: "list-user-dbs [user-id]",
|
||||
Short: "List databases a user can access",
|
||||
Long: `List all databases that a specific user has access to.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
listUserDatabases(args[0])
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
|
||||
func init() {
|
||||
rootCmd.AddCommand(dbUserCmd)
|
||||
dbUserCmd.AddCommand(addDbUserCmd)
|
||||
dbUserCmd.AddCommand(removeDbUserCmd)
|
||||
dbUserCmd.AddCommand(listDbUsersCmd)
|
||||
dbUserCmd.AddCommand(listUserDbsCmd)
|
||||
}
|
||||
|
||||
// addUserToDatabase associates a user with a database, allowing them to read its credentials
|
||||
func addUserToDatabase(databaseID, userID string) {
|
||||
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
|
||||
|
||||
// Open the database
|
||||
db, err := openDatabase(dbPath)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Check if the database exists
|
||||
exists, err := checkDatabaseExists(db, databaseID)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to check if database exists: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
logger.Fatalf("Database with ID %s does not exist", databaseID)
|
||||
}
|
||||
|
||||
// Check if the user exists
|
||||
exists, err = checkUserExists(db, userID)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to check if user exists: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
logger.Fatalf("User with ID %s does not exist", userID)
|
||||
}
|
||||
|
||||
// Check if the association already exists
|
||||
exists, err = checkAssociationExists(db, databaseID, userID)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to check if association exists: %v", err)
|
||||
}
|
||||
if exists {
|
||||
logger.Printf("User %s already has access to database %s", userID, databaseID)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new association
|
||||
id := ulid.Make().String()
|
||||
query, args, err := squirrel.Insert("database_users").
|
||||
Columns("id", "db_id", "uid").
|
||||
Values(id, databaseID, userID).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to build query: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(query, args...)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to add user to database: %v", err)
|
||||
}
|
||||
|
||||
logger.Printf("User %s now has access to database %s", userID, databaseID)
|
||||
}
|
||||
|
||||
// removeUserFromDatabase removes a user's association with a database
|
||||
func removeUserFromDatabase(databaseID, userID string) {
|
||||
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
|
||||
|
||||
// Open the database
|
||||
db, err := openDatabase(dbPath)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Check if the association exists
|
||||
exists, err := checkAssociationExists(db, databaseID, userID)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to check if association exists: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
logger.Printf("User %s does not have access to database %s", userID, databaseID)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove the association
|
||||
query, args, err := squirrel.Delete("database_users").
|
||||
Where(squirrel.Eq{"db_id": databaseID, "uid": userID}).
|
||||
ToSql()
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to build query: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(query, args...)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to remove user from database: %v", err)
|
||||
}
|
||||
|
||||
logger.Printf("User %s no longer has access to database %s", userID, databaseID)
|
||||
}
|
||||
|
||||
// listDatabaseUsers lists all users who have access to a specific database
|
||||
func listDatabaseUsers(databaseID string) {
|
||||
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
|
||||
|
||||
// Open the database
|
||||
db, err := openDatabase(dbPath)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Check if the database exists
|
||||
exists, err := checkDatabaseExists(db, databaseID)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to check if database exists: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
logger.Fatalf("Database with ID %s does not exist", databaseID)
|
||||
}
|
||||
|
||||
// Get the database name for display
|
||||
var dbName string
|
||||
err = db.QueryRow("SELECT name FROM database WHERE id = ?", databaseID).Scan(&dbName)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to get database name: %v", err)
|
||||
}
|
||||
|
||||
// Query all users who have access to this database
|
||||
query, args, err := squirrel.Select("u.id", "u.user").
|
||||
From("users u").
|
||||
Join("database_users du ON u.id = du.uid").
|
||||
Where(squirrel.Eq{"du.db_id": databaseID}).
|
||||
OrderBy("u.user").
|
||||
ToSql()
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to build query: %v", err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to query users: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
fmt.Printf("Users with access to database '%s' (ID: %s):\n\n", dbName, databaseID)
|
||||
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
var id, username string
|
||||
if err := rows.Scan(&id, &username); err != nil {
|
||||
logger.Fatalf("Failed to scan row: %v", err)
|
||||
}
|
||||
fmt.Printf(" %s (ID: %s)\n", username, id)
|
||||
count++
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
fmt.Println(" No users have access to this database")
|
||||
} else {
|
||||
fmt.Printf("\nTotal: %d user(s)\n", count)
|
||||
}
|
||||
}
|
||||
|
||||
// listUserDatabases lists all databases a specific user has access to
|
||||
func listUserDatabases(userID string) {
|
||||
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
|
||||
|
||||
// Open the database
|
||||
db, err := openDatabase(dbPath)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Check if the user exists
|
||||
exists, err := checkUserExists(db, userID)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to check if user exists: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
logger.Fatalf("User with ID %s does not exist", userID)
|
||||
}
|
||||
|
||||
// Get the username for display
|
||||
var username string
|
||||
err = db.QueryRow("SELECT user FROM users WHERE id = ?", userID).Scan(&username)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to get username: %v", err)
|
||||
}
|
||||
|
||||
// Query all databases this user has access to
|
||||
query, args, err := squirrel.Select("d.id", "d.name", "d.host").
|
||||
From("database d").
|
||||
Join("database_users du ON d.id = du.db_id").
|
||||
Where(squirrel.Eq{"du.uid": userID}).
|
||||
OrderBy("d.name").
|
||||
ToSql()
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to build query: %v", err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to query databases: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
fmt.Printf("Databases accessible by user '%s' (ID: %s):\n\n", username, userID)
|
||||
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
var id, name, host string
|
||||
if err := rows.Scan(&id, &name, &host); err != nil {
|
||||
logger.Fatalf("Failed to scan row: %v", err)
|
||||
}
|
||||
fmt.Printf(" %s (%s) (ID: %s)\n", name, host, id)
|
||||
count++
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
fmt.Println(" This user does not have access to any databases")
|
||||
} else {
|
||||
fmt.Printf("\nTotal: %d database(s)\n", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// checkDatabaseExists checks if a database with the given ID exists
|
||||
func checkDatabaseExists(db *sql.DB, databaseID string) (bool, error) {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM database WHERE id = ?", databaseID).Scan(&count)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// checkUserExists checks if a user with the given ID exists
|
||||
func checkUserExists(db *sql.DB, userID string) (bool, error) {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE id = ?", userID).Scan(&count)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// checkAssociationExists checks if a user is already associated with a database
|
||||
func checkAssociationExists(db *sql.DB, databaseID, userID string) (bool, error) {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM database_users WHERE db_id = ? AND uid = ?", databaseID, userID).Scan(&count)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
|
@ -42,7 +42,7 @@ func setupRootCommand() {
|
|||
|
||||
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.mcias.yaml)")
|
||||
rootCmd.PersistentFlags().StringVar(&dbPath, "db", "mcias.db", "Path to SQLite database file")
|
||||
rootCmd.PersistentFlags().StringVar(&addr, "addr", ":8080", "Address to listen on")
|
||||
rootCmd.PersistentFlags().StringVar(&addr, "addr", ":5000", "Address to listen on")
|
||||
if err := viper.BindPFlag("db", rootCmd.PersistentFlags().Lookup("db")); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error binding db flag: %v\n", err)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
-- Drop the database_users table and its indexes
|
||||
DROP INDEX IF EXISTS idx_database_users_db_id;
|
||||
DROP INDEX IF EXISTS idx_database_users_uid;
|
||||
DROP TABLE IF EXISTS database_users;
|
|
@ -0,0 +1,12 @@
|
|||
-- Add database_users table to associate users with databases
|
||||
CREATE TABLE database_users (
|
||||
id TEXT PRIMARY KEY,
|
||||
uid TEXT NOT NULL,
|
||||
db_id TEXT NOT NULL,
|
||||
FOREIGN KEY(uid) REFERENCES users(id),
|
||||
FOREIGN KEY(db_id) REFERENCES database(id)
|
||||
);
|
||||
|
||||
-- Add index for faster lookups
|
||||
CREATE INDEX idx_database_users_uid ON database_users(uid);
|
||||
CREATE INDEX idx_database_users_db_id ON database_users(db_id);
|
|
@ -6,8 +6,8 @@ After=network.target
|
|||
Type=simple
|
||||
User=mcias
|
||||
Group=mcias
|
||||
WorkingDirectory=/opt/mcias
|
||||
ExecStart=/opt/mcias/mcias server --db /opt/mcias/mcias.db
|
||||
WorkingDirectory=/srv/mcias
|
||||
ExecStart=/usr/local/bin/mcias/mcias server --db /srv/mcias/mcias.db
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
StandardOutput=journal
|
||||
|
@ -19,7 +19,7 @@ PrivateTmp=true
|
|||
ProtectSystem=full
|
||||
ProtectHome=true
|
||||
NoNewPrivileges=true
|
||||
ReadWritePaths=/opt/mcias
|
||||
ReadWritePaths=/srv/mcias
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
Loading…
Reference in New Issue