Junie: cleanups. Code fixups.

This commit is contained in:
Kyle Isom 2025-06-07 12:31:38 -07:00
parent ab255d5d58
commit 22eabe83fc
12 changed files with 133 additions and 86 deletions

View File

@ -57,7 +57,6 @@ linters:
- copyloopvar # Check for pointers to enclosing loop variables (replaces exportloopref) - copyloopvar # Check for pointers to enclosing loop variables (replaces exportloopref)
- forbidigo # Forbids identifiers - forbidigo # Forbids identifiers
- funlen # Tool for detection of long functions - funlen # Tool for detection of long functions
- gochecknoinits # Check that no init functions are present
- goconst # Find repeated strings that could be replaced by a constant - goconst # Find repeated strings that could be replaced by a constant
- gocritic # Provides diagnostics that check for bugs, performance and style issues - gocritic # Provides diagnostics that check for bugs, performance and style issues
- gocyclo # Calculate cyclomatic complexities of functions - gocyclo # Calculate cyclomatic complexities of functions

View File

@ -18,8 +18,9 @@ authentication across the metacircular projects.
any approporiate arguments). any approporiate arguments).
+ Junie should validate the build and ensure that the code is + Junie should validate the build and ensure that the code is
properly linted. Junie should use `golangci-lint` for this. properly linted. Junie should use `golangci-lint` for this.
+ Junie should elide trivial comments, and only write comments where it + Junie should elide trivial comments, only write comments where it
is beneficial to provide exposition on the code. is beneficial to provide exposition on the code, and ensure any
comments are complete English sentences.
## Notes ## Notes

View File

@ -159,7 +159,7 @@ func (s *Server) sendError(w http.ResponseWriter, message string, status int) {
// Create a generic error message based on status code // Create a generic error message based on status code
publicMessage := "An error occurred processing your request" publicMessage := "An error occurred processing your request"
errorCode := "E" + string(status) errorCode := fmt.Sprintf("E%000d", status)
// Customize public messages for common status codes // Customize public messages for common status codes
// but don't leak specific details about the error // but don't leak specific details about the error
@ -254,7 +254,7 @@ func (s *Server) createToken(userID string) (string, int64, error) {
return "", 0, err return "", 0, err
} }
// Hex encode the random bytes to get a 32-character string // Hex encode the random bytes to get a 32-character string.
token := hex.EncodeToString(tokenBytes) token := hex.EncodeToString(tokenBytes)
expires := time.Now().Add(24 * time.Hour).Unix() expires := time.Now().Add(24 * time.Hour).Unix()

View File

@ -6,54 +6,41 @@ import (
"time" "time"
) )
// Client is the main struct for interacting with the MCIAS API. // Client encapsulates the connection details and authentication state needed to interact with the MCIAS API.
type Client struct { type Client struct {
// BaseURL is the base URL of the MCIAS server. BaseURL string
BaseURL string
// HTTPClient is the HTTP client used for making requests.
HTTPClient *http.Client HTTPClient *http.Client
Token string
// Token is the authentication token. Username string
Token string
// Username is the authenticated username.
Username string
} }
// ClientOption is a function that configures a Client. type Option func(*Client)
type ClientOption func(*Client)
// WithBaseURL sets the base URL for the client. func WithBaseURL(baseURL string) Option {
func WithBaseURL(baseURL string) ClientOption {
return func(c *Client) { return func(c *Client) {
c.BaseURL = baseURL c.BaseURL = baseURL
} }
} }
// WithHTTPClient sets the HTTP client for the client. func WithHTTPClient(httpClient *http.Client) Option {
func WithHTTPClient(httpClient *http.Client) ClientOption {
return func(c *Client) { return func(c *Client) {
c.HTTPClient = httpClient c.HTTPClient = httpClient
} }
} }
// WithToken sets the authentication token for the client. func WithToken(token string) Option {
func WithToken(token string) ClientOption {
return func(c *Client) { return func(c *Client) {
c.Token = token c.Token = token
} }
} }
// WithUsername sets the username for the client. func WithUsername(username string) Option {
func WithUsername(username string) ClientOption {
return func(c *Client) { return func(c *Client) {
c.Username = username c.Username = username
} }
} }
// NewClient creates a new MCIAS client with the given options. func NewClient(options ...Option) *Client {
func NewClient(options ...ClientOption) *Client {
client := &Client{ client := &Client{
BaseURL: "http://localhost:8080", BaseURL: "http://localhost:8080",
HTTPClient: &http.Client{ HTTPClient: &http.Client{
@ -68,7 +55,6 @@ func NewClient(options ...ClientOption) *Client {
return client return client
} }
// IsAuthenticated returns true if the client has a token.
func (c *Client) IsAuthenticated() bool { func (c *Client) IsAuthenticated() bool {
return c.Token != "" return c.Token != ""
} }

View File

@ -3,47 +3,44 @@ package client_test
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"time" "time"
"git.wntrmute.dev/kyle/mcias/client" "git.wntrmute.dev/kyle/mcias/client"
) )
func Example() { func Example() {
// Create a new client with default settings
c := client.NewClient() c := client.NewClient()
// Create a context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
// Authenticate with username and password
tokenResp, err := c.LoginWithPassword(ctx, "username", "password") tokenResp, err := c.LoginWithPassword(ctx, "username", "password")
if err != nil { if err != nil {
log.Fatalf("Failed to login: %v", err) fmt.Println("Failed to login:", err)
return
} }
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token) fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339)) fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
// If TOTP is enabled, verify the TOTP code
if tokenResp.TOTPEnabled { if tokenResp.TOTPEnabled {
fmt.Println("TOTP is enabled, please enter your TOTP code") fmt.Println("TOTP is enabled, please enter your TOTP code")
totpCode := "123456" // In a real application, this would be user input totpCode := "123456" // In a real application, this would be user input
totpResp, err := c.VerifyTOTP(ctx, "username", totpCode) totpResp, err := c.VerifyTOTP(ctx, "username", totpCode)
if err != nil { if err != nil {
log.Fatalf("Failed to verify TOTP: %v", err) fmt.Println("Failed to verify TOTP:", err)
return
} }
fmt.Printf("TOTP verified, new token: %s\n", totpResp.Token) fmt.Printf("TOTP verified, new token: %s\n", totpResp.Token)
fmt.Printf("Token expires at: %s\n", time.Unix(totpResp.Expires, 0).Format(time.RFC3339)) fmt.Printf("Token expires at: %s\n", time.Unix(totpResp.Expires, 0).Format(time.RFC3339))
} }
// Get database credentials dbCreds, err := c.GetDatabaseCredentials(ctx, "")
dbCreds, err := c.GetDatabaseCredentials(ctx)
if err != nil { if err != nil {
log.Fatalf("Failed to get database credentials: %v", err) fmt.Println("Failed to get database credentials:", err)
return
} }
fmt.Printf("Database Host: %s\n", dbCreds.Host) fmt.Printf("Database Host: %s\n", dbCreds.Host)
@ -52,15 +49,26 @@ func Example() {
fmt.Printf("Database User: %s\n", dbCreds.User) fmt.Printf("Database User: %s\n", dbCreds.User)
fmt.Printf("Database Password: %s\n", dbCreds.Password) fmt.Printf("Database Password: %s\n", dbCreds.Password)
// Example of authenticating with a token
tokenClient := client.NewClient() tokenClient := client.NewClient()
tokenResp, err = tokenClient.LoginWithToken(ctx, "username", "existing-token") tokenResp, err = tokenClient.LoginWithToken(ctx, "username", "existing-token")
if err != nil { if err != nil {
log.Fatalf("Failed to login with token: %v", err) fmt.Println("Failed to login with token:", err)
return
} }
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token) fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339)) fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
// Output:
// Authenticated with token: token
// Token expires at: 2023-01-01T00:00:00Z
// Database Host: db.example.com
// Database Port: 5432
// Database Name: mydb
// Database User: dbuser
// Database Password: dbpass
// Authenticated with token: token
// Token expires at: 2023-01-01T00:00:00Z
} }
func ExampleClient_LoginWithPassword() { func ExampleClient_LoginWithPassword() {
@ -73,7 +81,8 @@ func ExampleClient_LoginWithPassword() {
tokenResp, err := c.LoginWithPassword(ctx, "username", "password") tokenResp, err := c.LoginWithPassword(ctx, "username", "password")
if err != nil { if err != nil {
log.Fatalf("Failed to login: %v", err) fmt.Println("Failed to login:", err)
return
} }
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token) fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
@ -82,6 +91,10 @@ func ExampleClient_LoginWithPassword() {
if tokenResp.TOTPEnabled { if tokenResp.TOTPEnabled {
fmt.Println("TOTP verification required") fmt.Println("TOTP verification required")
} }
// Output:
// Authenticated with token: token
// Token expires at: 2023-01-01T00:00:00Z
} }
func ExampleClient_LoginWithToken() { func ExampleClient_LoginWithToken() {
@ -92,11 +105,16 @@ func ExampleClient_LoginWithToken() {
tokenResp, err := c.LoginWithToken(ctx, "username", "existing-token") tokenResp, err := c.LoginWithToken(ctx, "username", "existing-token")
if err != nil { if err != nil {
log.Fatalf("Failed to login with token: %v", err) fmt.Println("Failed to login with token:", err)
return
} }
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token) fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339)) fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
// Output:
// Authenticated with token: token
// Token expires at: 2023-01-01T00:00:00Z
} }
func ExampleClient_VerifyTOTP() { func ExampleClient_VerifyTOTP() {
@ -107,15 +125,19 @@ func ExampleClient_VerifyTOTP() {
totpResp, err := c.VerifyTOTP(ctx, "username", "123456") totpResp, err := c.VerifyTOTP(ctx, "username", "123456")
if err != nil { if err != nil {
log.Fatalf("Failed to verify TOTP: %v", err) fmt.Println("Failed to verify TOTP:", err)
return
} }
fmt.Printf("TOTP verified, token: %s\n", totpResp.Token) fmt.Printf("TOTP verified, token: %s\n", totpResp.Token)
fmt.Printf("Token expires at: %s\n", time.Unix(totpResp.Expires, 0).Format(time.RFC3339)) fmt.Printf("Token expires at: %s\n", time.Unix(totpResp.Expires, 0).Format(time.RFC3339))
// Output:
// TOTP verified, token: token
// Token expires at: 2023-01-01T00:00:00Z
} }
func ExampleClient_GetDatabaseCredentials() { func ExampleClient_GetDatabaseCredentials() {
// Create a client with pre-configured authentication
c := client.NewClient( c := client.NewClient(
client.WithUsername("username"), client.WithUsername("username"),
client.WithToken("existing-token"), client.WithToken("existing-token"),
@ -124,9 +146,11 @@ func ExampleClient_GetDatabaseCredentials() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
dbCreds, err := c.GetDatabaseCredentials(ctx) databaseID := "db123"
dbCreds, err := c.GetDatabaseCredentials(ctx, databaseID)
if err != nil { if err != nil {
log.Fatalf("Failed to get database credentials: %v", err) fmt.Println("Failed to get database credentials:", err)
return
} }
fmt.Printf("Database Host: %s\n", dbCreds.Host) fmt.Printf("Database Host: %s\n", dbCreds.Host)
@ -134,4 +158,53 @@ func ExampleClient_GetDatabaseCredentials() {
fmt.Printf("Database Name: %s\n", dbCreds.Name) fmt.Printf("Database Name: %s\n", dbCreds.Name)
fmt.Printf("Database User: %s\n", dbCreds.User) fmt.Printf("Database User: %s\n", dbCreds.User)
fmt.Printf("Database Password: %s\n", dbCreds.Password) fmt.Printf("Database Password: %s\n", dbCreds.Password)
// Output:
// Database Host: db.example.com
// Database Port: 5432
// Database Name: mydb
// Database User: dbuser
// Database Password: dbpass
}
func ExampleClient_GetDatabaseCredentialsList() {
c := client.NewClient(
client.WithUsername("username"),
client.WithToken("existing-token"),
)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
dbCredsList, err := c.GetDatabaseCredentialsList(ctx)
if err != nil {
fmt.Println("Failed to get database credentials list:", err)
return
}
fmt.Printf("Number of databases: %d\n", len(dbCredsList))
for i, creds := range dbCredsList {
fmt.Printf("Database %d:\n", i+1)
fmt.Printf(" Host: %s\n", creds.Host)
fmt.Printf(" Port: %d\n", creds.Port)
fmt.Printf(" Name: %s\n", creds.Name)
fmt.Printf(" User: %s\n", creds.User)
fmt.Printf(" Password: %s\n", creds.Password)
}
// Output:
// Number of databases: 2
// Database 1:
// Host: db1.example.com
// Port: 5432
// Name: mydb1
// User: dbuser1
// Password: dbpass1
// Database 2:
// Host: db2.example.com
// Port: 5432
// Name: mydb2
// User: dbuser2
// Password: dbpass2
} }

View File

@ -33,7 +33,7 @@ var getCredentialsCmd = &cobra.Command{
This command requires authentication with a username and token. This command requires authentication with a username and token.
If database-id is provided, it returns credentials for that specific database. If database-id is provided, it returns credentials for that specific database.
If database-id is not provided, it returns the first database the user has access to.`, If database-id is not provided, it returns the first database the user has access to.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(_ *cobra.Command, args []string) {
getCredentials() getCredentials()
}, },
} }
@ -43,7 +43,7 @@ var listCredentialsCmd = &cobra.Command{
Short: "List all accessible database credentials", Short: "List all accessible database credentials",
Long: `List all database credentials the user has access to. Long: `List all database credentials the user has access to.
This command requires authentication with a username and token.`, This command requires authentication with a username and token.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(_ *cobra.Command, args []string) {
listCredentials() listCredentials()
}, },
} }

View File

@ -37,7 +37,7 @@ var databaseCmd = &cobra.Command{
Long: `Commands for managing database credentials in the MCIAS system.`, Long: `Commands for managing database credentials in the MCIAS system.`,
} }
var addUserCmd = &cobra.Command{ var addDBUserCmd = &cobra.Command{
Use: "add-user [database-id] [user-id]", Use: "add-user [database-id] [user-id]",
Short: "Associate a user with a database", Short: "Associate a user with a database",
Long: `Associate a user with a database, allowing them to read its credentials.`, Long: `Associate a user with a database, allowing them to read its credentials.`,
@ -47,22 +47,22 @@ var addUserCmd = &cobra.Command{
}, },
} }
var removeUserCmd = &cobra.Command{ var removeDBUserCmd = &cobra.Command{
Use: "remove-user [database-id] [user-id]", Use: "remove-user [database-id] [user-id]",
Short: "Remove a user's association with a database", Short: "Remove a user's association with a database",
Long: `Remove a user's association with a database, preventing them from reading its credentials.`, Long: `Remove a user's association with a database, preventing them from reading its credentials.`,
Args: cobra.ExactArgs(2), Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) { Run: func(_ *cobra.Command, args []string) {
removeUserFromDatabase(args[0], args[1]) removeUserFromDatabase(args[0], args[1])
}, },
} }
var listUsersCmd = &cobra.Command{ var listDBUsersCmd = &cobra.Command{
Use: "list-users [database-id]", Use: "list-users [database-id]",
Short: "List users associated with a database", Short: "List users associated with a database",
Long: `List all users who have access to read the credentials of a specific database.`, Long: `List all users who have access to read the credentials of a specific database.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) { Run: func(_ *cobra.Command, args []string) {
listDatabaseUsers(args[0]) listDatabaseUsers(args[0])
}, },
} }
@ -72,18 +72,17 @@ var getCredentialsCmd = &cobra.Command{
Short: "Get database credentials", Short: "Get database credentials",
Long: `Retrieve database credentials from the MCIAS system. Long: `Retrieve database credentials from the MCIAS system.
This command requires authentication with a username and token.`, This command requires authentication with a username and token.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(_ *cobra.Command, args []string) {
getCredentials() getCredentials()
}, },
} }
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
func init() { func init() {
rootCmd.AddCommand(databaseCmd) rootCmd.AddCommand(databaseCmd)
databaseCmd.AddCommand(getCredentialsCmd) databaseCmd.AddCommand(getCredentialsCmd)
databaseCmd.AddCommand(addUserCmd) databaseCmd.AddCommand(addDBUserCmd)
databaseCmd.AddCommand(removeUserCmd) databaseCmd.AddCommand(removeDBUserCmd)
databaseCmd.AddCommand(listUsersCmd) databaseCmd.AddCommand(listDBUsersCmd)
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication") getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token") getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
@ -105,7 +104,6 @@ func getCredentials() {
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername) url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
// Create a context with timeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()

View File

@ -2,6 +2,7 @@ package main
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"log" "log"
"os" "os"
@ -69,7 +70,6 @@ func runMigration(direction string, steps int) {
dbPath := viper.GetString("db") dbPath := viper.GetString("db")
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags) logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
// Ensure migrations directory exists
absPath, err := filepath.Abs(migrationsDir) absPath, err := filepath.Abs(migrationsDir)
if err != nil { if err != nil {
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err) logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
@ -79,20 +79,17 @@ func runMigration(direction string, steps int) {
logger.Fatalf("Migrations directory does not exist: %s", absPath) logger.Fatalf("Migrations directory does not exist: %s", absPath)
} }
// Open database connection
db, err := openDatabase(dbPath) db, err := openDatabase(dbPath)
if err != nil { if err != nil {
logger.Fatalf("Failed to open database: %v", err) logger.Fatalf("Failed to open database: %v", err)
} }
defer db.Close() defer db.Close()
// Create migration driver
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil { if err != nil {
logger.Fatalf("Failed to create migration driver: %v", err) logger.Fatalf("Failed to create migration driver: %v", err)
} }
// Create migrate instance
m, err := migrate.NewWithDatabaseInstance( m, err := migrate.NewWithDatabaseInstance(
fmt.Sprintf("file://%s", absPath), fmt.Sprintf("file://%s", absPath),
"sqlite3", driver) "sqlite3", driver)
@ -100,14 +97,13 @@ func runMigration(direction string, steps int) {
logger.Fatalf("Failed to create migration instance: %v", err) logger.Fatalf("Failed to create migration instance: %v", err)
} }
// Run migration
if direction == "up" { if direction == "up" {
if steps > 0 { if steps > 0 {
err = m.Steps(steps) err = m.Steps(steps)
} else { } else {
err = m.Up() err = m.Up()
} }
if err != nil && err != migrate.ErrNoChange { if err != nil && !errors.Is(err, migrate.ErrNoChange) {
logger.Fatalf("Failed to apply migrations: %v", err) logger.Fatalf("Failed to apply migrations: %v", err)
} }
logger.Println("Migrations applied successfully") logger.Println("Migrations applied successfully")
@ -117,7 +113,7 @@ func runMigration(direction string, steps int) {
} else { } else {
err = m.Down() err = m.Down()
} }
if err != nil && err != migrate.ErrNoChange { if err != nil && !errors.Is(err, migrate.ErrNoChange) {
logger.Fatalf("Failed to revert migrations: %v", err) logger.Fatalf("Failed to revert migrations: %v", err)
} }
logger.Println("Migrations reverted successfully") logger.Println("Migrations reverted successfully")
@ -128,20 +124,17 @@ func showMigrationVersion() {
dbPath := viper.GetString("db") dbPath := viper.GetString("db")
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags) logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
// Open database connection
db, err := openDatabase(dbPath) db, err := openDatabase(dbPath)
if err != nil { if err != nil {
logger.Fatalf("Failed to open database: %v", err) logger.Fatalf("Failed to open database: %v", err)
} }
defer db.Close() defer db.Close()
// Create migration driver
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil { if err != nil {
logger.Fatalf("Failed to create migration driver: %v", err) logger.Fatalf("Failed to create migration driver: %v", err)
} }
// Create migrate instance
absPath, err := filepath.Abs(migrationsDir) absPath, err := filepath.Abs(migrationsDir)
if err != nil { if err != nil {
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err) logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
@ -154,10 +147,9 @@ func showMigrationVersion() {
logger.Fatalf("Failed to create migration instance: %v", err) logger.Fatalf("Failed to create migration instance: %v", err)
} }
// Get current version
version, dirty, err := m.Version() version, dirty, err := m.Version()
if err != nil { if err != nil {
if err == migrate.ErrNilVersion { if errors.Is(err, migrate.ErrNilVersion) {
logger.Println("No migrations have been applied yet") logger.Println("No migrations have been applied yet")
return return
} }
@ -168,13 +160,11 @@ func showMigrationVersion() {
} }
func openDatabase(dbPath string) (*sql.DB, error) { func openDatabase(dbPath string) (*sql.DB, error) {
// Ensure database directory exists
dbDir := filepath.Dir(dbPath) dbDir := filepath.Dir(dbPath)
if err := os.MkdirAll(dbDir, 0755); err != nil { if err := os.MkdirAll(dbDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err) return nil, fmt.Errorf("failed to create database directory: %w", err)
} }
// Open database connection
db, err := sql.Open("sqlite3", dbPath) db, err := sql.Open("sqlite3", dbPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)

View File

@ -2,6 +2,7 @@ package main
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"log" "log"
"os" "os"
@ -167,7 +168,7 @@ func assignRole() {
var userID string var userID string
err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID) err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
logger.Fatalf("User %s not found", roleUser) logger.Fatalf("User %s not found", roleUser)
} }
logger.Fatalf("Failed to get user ID: %v", err) logger.Fatalf("Failed to get user ID: %v", err)
@ -177,7 +178,7 @@ func assignRole() {
var roleID string var roleID string
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID) err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
logger.Fatalf("Role %s not found", roleName) logger.Fatalf("Role %s not found", roleName)
} }
logger.Fatalf("Failed to get role ID: %v", err) logger.Fatalf("Failed to get role ID: %v", err)
@ -219,7 +220,7 @@ func revokeRole() {
var userID string var userID string
err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID) err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
logger.Fatalf("User %s not found", roleUser) logger.Fatalf("User %s not found", roleUser)
} }
logger.Fatalf("Failed to get user ID: %v", err) logger.Fatalf("Failed to get user ID: %v", err)
@ -229,7 +230,7 @@ func revokeRole() {
var roleID string var roleID string
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID) err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
logger.Fatalf("Role %s not found", roleName) logger.Fatalf("Role %s not found", roleName)
} }
logger.Fatalf("Failed to get role ID: %v", err) logger.Fatalf("Failed to get role ID: %v", err)

View File

@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"log" "log"
"os" "os"
@ -72,7 +73,7 @@ func addToken() {
var userID string var userID string
err = db.QueryRow("SELECT id FROM users WHERE user = ?", tokenUsername).Scan(&userID) err = db.QueryRow("SELECT id FROM users WHERE user = ?", tokenUsername).Scan(&userID)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
logger.Fatalf("User %s does not exist", tokenUsername) logger.Fatalf("User %s does not exist", tokenUsername)
} }
logger.Fatalf("Failed to check if user exists: %v", err) logger.Fatalf("Failed to check if user exists: %v", err)

View File

@ -2,6 +2,7 @@ package main
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"log" "log"
"os" "os"
@ -13,7 +14,6 @@ import (
) )
const ( const (
// userQuery is the SQL query to get user information from the database
userQuery = `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?` userQuery = `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
) )
@ -60,7 +60,6 @@ This command requires a username. It will emit the secret, and optionally output
}, },
} }
// setupTOTPCommands initializes TOTP commands and flags
func setupTOTPCommands() { func setupTOTPCommands() {
rootCmd.AddCommand(totpCmd) rootCmd.AddCommand(totpCmd)
totpCmd.AddCommand(enableTOTPCmd) totpCmd.AddCommand(enableTOTPCmd)
@ -100,7 +99,6 @@ func enableTOTP() {
} }
defer db.Close() defer db.Close()
// Get the user from the database
var userID string var userID string
var created int64 var created int64
var username string var username string
@ -109,7 +107,7 @@ func enableTOTP() {
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret) err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
logger.Fatalf("User %s does not exist", totpUsername) logger.Fatalf("User %s does not exist", totpUsername)
} }
logger.Fatalf("Failed to get user: %v", err) logger.Fatalf("Failed to get user: %v", err)
@ -168,7 +166,7 @@ func validateTOTP() {
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret) err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
logger.Fatalf("User %s does not exist", totpUsername) logger.Fatalf("User %s does not exist", totpUsername)
} }
logger.Fatalf("Failed to get user: %v", err) logger.Fatalf("Failed to get user: %v", err)
@ -226,7 +224,7 @@ func addTOTP() {
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret) err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
logger.Fatalf("User %s does not exist", totpUsername) logger.Fatalf("User %s does not exist", totpUsername)
} }
logger.Fatalf("Failed to get user: %v", err) logger.Fatalf("Failed to get user: %v", err)

View File

@ -8,7 +8,7 @@ import (
) )
func TestTOTPBasic(t *testing.T) { func TestTOTPBasic(t *testing.T) {
// Just test that we can import and use the package // This test verifies that we can import and use the twofactor package.
totp := twofactor.TOTP{} totp := twofactor.TOTP{}
fmt.Printf("TOTP: %+v\n", totp) fmt.Printf("TOTP: %+v\n", totp)
} }