Junie: cleanups. Code fixups.
This commit is contained in:
@@ -33,7 +33,7 @@ var getCredentialsCmd = &cobra.Command{
|
||||
This command requires authentication with a username and token.
|
||||
If database-id is provided, it returns credentials for that specific database.
|
||||
If database-id is not provided, it returns the first database the user has access to.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run: func(_ *cobra.Command, args []string) {
|
||||
getCredentials()
|
||||
},
|
||||
}
|
||||
@@ -43,7 +43,7 @@ var listCredentialsCmd = &cobra.Command{
|
||||
Short: "List all accessible database credentials",
|
||||
Long: `List all database credentials the user has access to.
|
||||
This command requires authentication with a username and token.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run: func(_ *cobra.Command, args []string) {
|
||||
listCredentials()
|
||||
},
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ var databaseCmd = &cobra.Command{
|
||||
Long: `Commands for managing database credentials in the MCIAS system.`,
|
||||
}
|
||||
|
||||
var addUserCmd = &cobra.Command{
|
||||
var addDBUserCmd = &cobra.Command{
|
||||
Use: "add-user [database-id] [user-id]",
|
||||
Short: "Associate a user with a database",
|
||||
Long: `Associate a user with a database, allowing them to read its credentials.`,
|
||||
@@ -47,22 +47,22 @@ var addUserCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
var removeUserCmd = &cobra.Command{
|
||||
var removeDBUserCmd = &cobra.Command{
|
||||
Use: "remove-user [database-id] [user-id]",
|
||||
Short: "Remove a user's association with a database",
|
||||
Long: `Remove a user's association with a database, preventing them from reading its credentials.`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run: func(_ *cobra.Command, args []string) {
|
||||
removeUserFromDatabase(args[0], args[1])
|
||||
},
|
||||
}
|
||||
|
||||
var listUsersCmd = &cobra.Command{
|
||||
var listDBUsersCmd = &cobra.Command{
|
||||
Use: "list-users [database-id]",
|
||||
Short: "List users associated with a database",
|
||||
Long: `List all users who have access to read the credentials of a specific database.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run: func(_ *cobra.Command, args []string) {
|
||||
listDatabaseUsers(args[0])
|
||||
},
|
||||
}
|
||||
@@ -72,18 +72,17 @@ var getCredentialsCmd = &cobra.Command{
|
||||
Short: "Get database credentials",
|
||||
Long: `Retrieve database credentials from the MCIAS system.
|
||||
This command requires authentication with a username and token.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
Run: func(_ *cobra.Command, args []string) {
|
||||
getCredentials()
|
||||
},
|
||||
}
|
||||
|
||||
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
|
||||
func init() {
|
||||
rootCmd.AddCommand(databaseCmd)
|
||||
databaseCmd.AddCommand(getCredentialsCmd)
|
||||
databaseCmd.AddCommand(addUserCmd)
|
||||
databaseCmd.AddCommand(removeUserCmd)
|
||||
databaseCmd.AddCommand(listUsersCmd)
|
||||
databaseCmd.AddCommand(addDBUserCmd)
|
||||
databaseCmd.AddCommand(removeDBUserCmd)
|
||||
databaseCmd.AddCommand(listDBUsersCmd)
|
||||
|
||||
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
|
||||
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
|
||||
@@ -105,7 +104,6 @@ func getCredentials() {
|
||||
|
||||
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
|
||||
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -69,7 +70,6 @@ func runMigration(direction string, steps int) {
|
||||
dbPath := viper.GetString("db")
|
||||
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
||||
|
||||
// Ensure migrations directory exists
|
||||
absPath, err := filepath.Abs(migrationsDir)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
||||
@@ -79,20 +79,17 @@ func runMigration(direction string, steps int) {
|
||||
logger.Fatalf("Migrations directory does not exist: %s", absPath)
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err := openDatabase(dbPath)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create migration driver
|
||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to create migration driver: %v", err)
|
||||
}
|
||||
|
||||
// Create migrate instance
|
||||
m, err := migrate.NewWithDatabaseInstance(
|
||||
fmt.Sprintf("file://%s", absPath),
|
||||
"sqlite3", driver)
|
||||
@@ -100,14 +97,13 @@ func runMigration(direction string, steps int) {
|
||||
logger.Fatalf("Failed to create migration instance: %v", err)
|
||||
}
|
||||
|
||||
// Run migration
|
||||
if direction == "up" {
|
||||
if steps > 0 {
|
||||
err = m.Steps(steps)
|
||||
} else {
|
||||
err = m.Up()
|
||||
}
|
||||
if err != nil && err != migrate.ErrNoChange {
|
||||
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
logger.Fatalf("Failed to apply migrations: %v", err)
|
||||
}
|
||||
logger.Println("Migrations applied successfully")
|
||||
@@ -117,7 +113,7 @@ func runMigration(direction string, steps int) {
|
||||
} else {
|
||||
err = m.Down()
|
||||
}
|
||||
if err != nil && err != migrate.ErrNoChange {
|
||||
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
logger.Fatalf("Failed to revert migrations: %v", err)
|
||||
}
|
||||
logger.Println("Migrations reverted successfully")
|
||||
@@ -128,20 +124,17 @@ func showMigrationVersion() {
|
||||
dbPath := viper.GetString("db")
|
||||
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
||||
|
||||
// Open database connection
|
||||
db, err := openDatabase(dbPath)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create migration driver
|
||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to create migration driver: %v", err)
|
||||
}
|
||||
|
||||
// Create migrate instance
|
||||
absPath, err := filepath.Abs(migrationsDir)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
||||
@@ -154,10 +147,9 @@ func showMigrationVersion() {
|
||||
logger.Fatalf("Failed to create migration instance: %v", err)
|
||||
}
|
||||
|
||||
// Get current version
|
||||
version, dirty, err := m.Version()
|
||||
if err != nil {
|
||||
if err == migrate.ErrNilVersion {
|
||||
if errors.Is(err, migrate.ErrNilVersion) {
|
||||
logger.Println("No migrations have been applied yet")
|
||||
return
|
||||
}
|
||||
@@ -168,13 +160,11 @@ func showMigrationVersion() {
|
||||
}
|
||||
|
||||
func openDatabase(dbPath string) (*sql.DB, error) {
|
||||
// Ensure database directory exists
|
||||
dbDir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -167,7 +168,7 @@ func assignRole() {
|
||||
var userID string
|
||||
err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("User %s not found", roleUser)
|
||||
}
|
||||
logger.Fatalf("Failed to get user ID: %v", err)
|
||||
@@ -177,7 +178,7 @@ func assignRole() {
|
||||
var roleID string
|
||||
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("Role %s not found", roleName)
|
||||
}
|
||||
logger.Fatalf("Failed to get role ID: %v", err)
|
||||
@@ -219,7 +220,7 @@ func revokeRole() {
|
||||
var userID string
|
||||
err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("User %s not found", roleUser)
|
||||
}
|
||||
logger.Fatalf("Failed to get user ID: %v", err)
|
||||
@@ -229,7 +230,7 @@ func revokeRole() {
|
||||
var roleID string
|
||||
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("Role %s not found", roleName)
|
||||
}
|
||||
logger.Fatalf("Failed to get role ID: %v", err)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -72,7 +73,7 @@ func addToken() {
|
||||
var userID string
|
||||
err = db.QueryRow("SELECT id FROM users WHERE user = ?", tokenUsername).Scan(&userID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("User %s does not exist", tokenUsername)
|
||||
}
|
||||
logger.Fatalf("Failed to check if user exists: %v", err)
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -13,7 +14,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// userQuery is the SQL query to get user information from the database
|
||||
userQuery = `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
|
||||
)
|
||||
|
||||
@@ -60,7 +60,6 @@ This command requires a username. It will emit the secret, and optionally output
|
||||
},
|
||||
}
|
||||
|
||||
// setupTOTPCommands initializes TOTP commands and flags
|
||||
func setupTOTPCommands() {
|
||||
rootCmd.AddCommand(totpCmd)
|
||||
totpCmd.AddCommand(enableTOTPCmd)
|
||||
@@ -100,7 +99,6 @@ func enableTOTP() {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Get the user from the database
|
||||
var userID string
|
||||
var created int64
|
||||
var username string
|
||||
@@ -109,7 +107,7 @@ func enableTOTP() {
|
||||
|
||||
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("User %s does not exist", totpUsername)
|
||||
}
|
||||
logger.Fatalf("Failed to get user: %v", err)
|
||||
@@ -168,7 +166,7 @@ func validateTOTP() {
|
||||
|
||||
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("User %s does not exist", totpUsername)
|
||||
}
|
||||
logger.Fatalf("Failed to get user: %v", err)
|
||||
@@ -226,7 +224,7 @@ func addTOTP() {
|
||||
|
||||
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("User %s does not exist", totpUsername)
|
||||
}
|
||||
logger.Fatalf("Failed to get user: %v", err)
|
||||
|
||||
Reference in New Issue
Block a user