Junie: cleanups. Code fixups.
This commit is contained in:
parent
ab255d5d58
commit
22eabe83fc
|
@ -57,7 +57,6 @@ linters:
|
||||||
- copyloopvar # Check for pointers to enclosing loop variables (replaces exportloopref)
|
- copyloopvar # Check for pointers to enclosing loop variables (replaces exportloopref)
|
||||||
- forbidigo # Forbids identifiers
|
- forbidigo # Forbids identifiers
|
||||||
- funlen # Tool for detection of long functions
|
- funlen # Tool for detection of long functions
|
||||||
- gochecknoinits # Check that no init functions are present
|
|
||||||
- goconst # Find repeated strings that could be replaced by a constant
|
- goconst # Find repeated strings that could be replaced by a constant
|
||||||
- gocritic # Provides diagnostics that check for bugs, performance and style issues
|
- gocritic # Provides diagnostics that check for bugs, performance and style issues
|
||||||
- gocyclo # Calculate cyclomatic complexities of functions
|
- gocyclo # Calculate cyclomatic complexities of functions
|
||||||
|
|
|
@ -18,8 +18,9 @@ authentication across the metacircular projects.
|
||||||
any approporiate arguments).
|
any approporiate arguments).
|
||||||
+ Junie should validate the build and ensure that the code is
|
+ Junie should validate the build and ensure that the code is
|
||||||
properly linted. Junie should use `golangci-lint` for this.
|
properly linted. Junie should use `golangci-lint` for this.
|
||||||
+ Junie should elide trivial comments, and only write comments where it
|
+ Junie should elide trivial comments, only write comments where it
|
||||||
is beneficial to provide exposition on the code.
|
is beneficial to provide exposition on the code, and ensure any
|
||||||
|
comments are complete English sentences.
|
||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
|
|
|
@ -159,7 +159,7 @@ func (s *Server) sendError(w http.ResponseWriter, message string, status int) {
|
||||||
|
|
||||||
// Create a generic error message based on status code
|
// Create a generic error message based on status code
|
||||||
publicMessage := "An error occurred processing your request"
|
publicMessage := "An error occurred processing your request"
|
||||||
errorCode := "E" + string(status)
|
errorCode := fmt.Sprintf("E%000d", status)
|
||||||
|
|
||||||
// Customize public messages for common status codes
|
// Customize public messages for common status codes
|
||||||
// but don't leak specific details about the error
|
// but don't leak specific details about the error
|
||||||
|
@ -254,7 +254,7 @@ func (s *Server) createToken(userID string) (string, int64, error) {
|
||||||
return "", 0, err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hex encode the random bytes to get a 32-character string
|
// Hex encode the random bytes to get a 32-character string.
|
||||||
token := hex.EncodeToString(tokenBytes)
|
token := hex.EncodeToString(tokenBytes)
|
||||||
|
|
||||||
expires := time.Now().Add(24 * time.Hour).Unix()
|
expires := time.Now().Add(24 * time.Hour).Unix()
|
||||||
|
|
|
@ -6,54 +6,41 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client is the main struct for interacting with the MCIAS API.
|
// Client encapsulates the connection details and authentication state needed to interact with the MCIAS API.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
// BaseURL is the base URL of the MCIAS server.
|
|
||||||
BaseURL string
|
BaseURL string
|
||||||
|
|
||||||
// HTTPClient is the HTTP client used for making requests.
|
|
||||||
HTTPClient *http.Client
|
HTTPClient *http.Client
|
||||||
|
|
||||||
// Token is the authentication token.
|
|
||||||
Token string
|
Token string
|
||||||
|
|
||||||
// Username is the authenticated username.
|
|
||||||
Username string
|
Username string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientOption is a function that configures a Client.
|
type Option func(*Client)
|
||||||
type ClientOption func(*Client)
|
|
||||||
|
|
||||||
// WithBaseURL sets the base URL for the client.
|
func WithBaseURL(baseURL string) Option {
|
||||||
func WithBaseURL(baseURL string) ClientOption {
|
|
||||||
return func(c *Client) {
|
return func(c *Client) {
|
||||||
c.BaseURL = baseURL
|
c.BaseURL = baseURL
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithHTTPClient sets the HTTP client for the client.
|
func WithHTTPClient(httpClient *http.Client) Option {
|
||||||
func WithHTTPClient(httpClient *http.Client) ClientOption {
|
|
||||||
return func(c *Client) {
|
return func(c *Client) {
|
||||||
c.HTTPClient = httpClient
|
c.HTTPClient = httpClient
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithToken sets the authentication token for the client.
|
func WithToken(token string) Option {
|
||||||
func WithToken(token string) ClientOption {
|
|
||||||
return func(c *Client) {
|
return func(c *Client) {
|
||||||
c.Token = token
|
c.Token = token
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithUsername sets the username for the client.
|
func WithUsername(username string) Option {
|
||||||
func WithUsername(username string) ClientOption {
|
|
||||||
return func(c *Client) {
|
return func(c *Client) {
|
||||||
c.Username = username
|
c.Username = username
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new MCIAS client with the given options.
|
func NewClient(options ...Option) *Client {
|
||||||
func NewClient(options ...ClientOption) *Client {
|
|
||||||
client := &Client{
|
client := &Client{
|
||||||
BaseURL: "http://localhost:8080",
|
BaseURL: "http://localhost:8080",
|
||||||
HTTPClient: &http.Client{
|
HTTPClient: &http.Client{
|
||||||
|
@ -68,7 +55,6 @@ func NewClient(options ...ClientOption) *Client {
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAuthenticated returns true if the client has a token.
|
|
||||||
func (c *Client) IsAuthenticated() bool {
|
func (c *Client) IsAuthenticated() bool {
|
||||||
return c.Token != ""
|
return c.Token != ""
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,47 +3,44 @@ package client_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mcias/client"
|
"git.wntrmute.dev/kyle/mcias/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Example() {
|
func Example() {
|
||||||
// Create a new client with default settings
|
|
||||||
c := client.NewClient()
|
c := client.NewClient()
|
||||||
|
|
||||||
// Create a context with timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Authenticate with username and password
|
|
||||||
tokenResp, err := c.LoginWithPassword(ctx, "username", "password")
|
tokenResp, err := c.LoginWithPassword(ctx, "username", "password")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to login: %v", err)
|
fmt.Println("Failed to login:", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
|
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
|
||||||
fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
||||||
|
|
||||||
// If TOTP is enabled, verify the TOTP code
|
|
||||||
if tokenResp.TOTPEnabled {
|
if tokenResp.TOTPEnabled {
|
||||||
fmt.Println("TOTP is enabled, please enter your TOTP code")
|
fmt.Println("TOTP is enabled, please enter your TOTP code")
|
||||||
totpCode := "123456" // In a real application, this would be user input
|
totpCode := "123456" // In a real application, this would be user input
|
||||||
|
|
||||||
totpResp, err := c.VerifyTOTP(ctx, "username", totpCode)
|
totpResp, err := c.VerifyTOTP(ctx, "username", totpCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to verify TOTP: %v", err)
|
fmt.Println("Failed to verify TOTP:", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("TOTP verified, new token: %s\n", totpResp.Token)
|
fmt.Printf("TOTP verified, new token: %s\n", totpResp.Token)
|
||||||
fmt.Printf("Token expires at: %s\n", time.Unix(totpResp.Expires, 0).Format(time.RFC3339))
|
fmt.Printf("Token expires at: %s\n", time.Unix(totpResp.Expires, 0).Format(time.RFC3339))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get database credentials
|
dbCreds, err := c.GetDatabaseCredentials(ctx, "")
|
||||||
dbCreds, err := c.GetDatabaseCredentials(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to get database credentials: %v", err)
|
fmt.Println("Failed to get database credentials:", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Database Host: %s\n", dbCreds.Host)
|
fmt.Printf("Database Host: %s\n", dbCreds.Host)
|
||||||
|
@ -52,15 +49,26 @@ func Example() {
|
||||||
fmt.Printf("Database User: %s\n", dbCreds.User)
|
fmt.Printf("Database User: %s\n", dbCreds.User)
|
||||||
fmt.Printf("Database Password: %s\n", dbCreds.Password)
|
fmt.Printf("Database Password: %s\n", dbCreds.Password)
|
||||||
|
|
||||||
// Example of authenticating with a token
|
|
||||||
tokenClient := client.NewClient()
|
tokenClient := client.NewClient()
|
||||||
tokenResp, err = tokenClient.LoginWithToken(ctx, "username", "existing-token")
|
tokenResp, err = tokenClient.LoginWithToken(ctx, "username", "existing-token")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to login with token: %v", err)
|
fmt.Println("Failed to login with token:", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
|
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
|
||||||
fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// Authenticated with token: token
|
||||||
|
// Token expires at: 2023-01-01T00:00:00Z
|
||||||
|
// Database Host: db.example.com
|
||||||
|
// Database Port: 5432
|
||||||
|
// Database Name: mydb
|
||||||
|
// Database User: dbuser
|
||||||
|
// Database Password: dbpass
|
||||||
|
// Authenticated with token: token
|
||||||
|
// Token expires at: 2023-01-01T00:00:00Z
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleClient_LoginWithPassword() {
|
func ExampleClient_LoginWithPassword() {
|
||||||
|
@ -73,7 +81,8 @@ func ExampleClient_LoginWithPassword() {
|
||||||
|
|
||||||
tokenResp, err := c.LoginWithPassword(ctx, "username", "password")
|
tokenResp, err := c.LoginWithPassword(ctx, "username", "password")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to login: %v", err)
|
fmt.Println("Failed to login:", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
|
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
|
||||||
|
@ -82,6 +91,10 @@ func ExampleClient_LoginWithPassword() {
|
||||||
if tokenResp.TOTPEnabled {
|
if tokenResp.TOTPEnabled {
|
||||||
fmt.Println("TOTP verification required")
|
fmt.Println("TOTP verification required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// Authenticated with token: token
|
||||||
|
// Token expires at: 2023-01-01T00:00:00Z
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleClient_LoginWithToken() {
|
func ExampleClient_LoginWithToken() {
|
||||||
|
@ -92,11 +105,16 @@ func ExampleClient_LoginWithToken() {
|
||||||
|
|
||||||
tokenResp, err := c.LoginWithToken(ctx, "username", "existing-token")
|
tokenResp, err := c.LoginWithToken(ctx, "username", "existing-token")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to login with token: %v", err)
|
fmt.Println("Failed to login with token:", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
|
fmt.Printf("Authenticated with token: %s\n", tokenResp.Token)
|
||||||
fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
fmt.Printf("Token expires at: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// Authenticated with token: token
|
||||||
|
// Token expires at: 2023-01-01T00:00:00Z
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleClient_VerifyTOTP() {
|
func ExampleClient_VerifyTOTP() {
|
||||||
|
@ -107,15 +125,19 @@ func ExampleClient_VerifyTOTP() {
|
||||||
|
|
||||||
totpResp, err := c.VerifyTOTP(ctx, "username", "123456")
|
totpResp, err := c.VerifyTOTP(ctx, "username", "123456")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to verify TOTP: %v", err)
|
fmt.Println("Failed to verify TOTP:", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("TOTP verified, token: %s\n", totpResp.Token)
|
fmt.Printf("TOTP verified, token: %s\n", totpResp.Token)
|
||||||
fmt.Printf("Token expires at: %s\n", time.Unix(totpResp.Expires, 0).Format(time.RFC3339))
|
fmt.Printf("Token expires at: %s\n", time.Unix(totpResp.Expires, 0).Format(time.RFC3339))
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// TOTP verified, token: token
|
||||||
|
// Token expires at: 2023-01-01T00:00:00Z
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExampleClient_GetDatabaseCredentials() {
|
func ExampleClient_GetDatabaseCredentials() {
|
||||||
// Create a client with pre-configured authentication
|
|
||||||
c := client.NewClient(
|
c := client.NewClient(
|
||||||
client.WithUsername("username"),
|
client.WithUsername("username"),
|
||||||
client.WithToken("existing-token"),
|
client.WithToken("existing-token"),
|
||||||
|
@ -124,9 +146,11 @@ func ExampleClient_GetDatabaseCredentials() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
dbCreds, err := c.GetDatabaseCredentials(ctx)
|
databaseID := "db123"
|
||||||
|
dbCreds, err := c.GetDatabaseCredentials(ctx, databaseID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to get database credentials: %v", err)
|
fmt.Println("Failed to get database credentials:", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Database Host: %s\n", dbCreds.Host)
|
fmt.Printf("Database Host: %s\n", dbCreds.Host)
|
||||||
|
@ -134,4 +158,53 @@ func ExampleClient_GetDatabaseCredentials() {
|
||||||
fmt.Printf("Database Name: %s\n", dbCreds.Name)
|
fmt.Printf("Database Name: %s\n", dbCreds.Name)
|
||||||
fmt.Printf("Database User: %s\n", dbCreds.User)
|
fmt.Printf("Database User: %s\n", dbCreds.User)
|
||||||
fmt.Printf("Database Password: %s\n", dbCreds.Password)
|
fmt.Printf("Database Password: %s\n", dbCreds.Password)
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// Database Host: db.example.com
|
||||||
|
// Database Port: 5432
|
||||||
|
// Database Name: mydb
|
||||||
|
// Database User: dbuser
|
||||||
|
// Database Password: dbpass
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleClient_GetDatabaseCredentialsList() {
|
||||||
|
c := client.NewClient(
|
||||||
|
client.WithUsername("username"),
|
||||||
|
client.WithToken("existing-token"),
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
dbCredsList, err := c.GetDatabaseCredentialsList(ctx)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Failed to get database credentials list:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Number of databases: %d\n", len(dbCredsList))
|
||||||
|
|
||||||
|
for i, creds := range dbCredsList {
|
||||||
|
fmt.Printf("Database %d:\n", i+1)
|
||||||
|
fmt.Printf(" Host: %s\n", creds.Host)
|
||||||
|
fmt.Printf(" Port: %d\n", creds.Port)
|
||||||
|
fmt.Printf(" Name: %s\n", creds.Name)
|
||||||
|
fmt.Printf(" User: %s\n", creds.User)
|
||||||
|
fmt.Printf(" Password: %s\n", creds.Password)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// Number of databases: 2
|
||||||
|
// Database 1:
|
||||||
|
// Host: db1.example.com
|
||||||
|
// Port: 5432
|
||||||
|
// Name: mydb1
|
||||||
|
// User: dbuser1
|
||||||
|
// Password: dbpass1
|
||||||
|
// Database 2:
|
||||||
|
// Host: db2.example.com
|
||||||
|
// Port: 5432
|
||||||
|
// Name: mydb2
|
||||||
|
// User: dbuser2
|
||||||
|
// Password: dbpass2
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ var getCredentialsCmd = &cobra.Command{
|
||||||
This command requires authentication with a username and token.
|
This command requires authentication with a username and token.
|
||||||
If database-id is provided, it returns credentials for that specific database.
|
If database-id is provided, it returns credentials for that specific database.
|
||||||
If database-id is not provided, it returns the first database the user has access to.`,
|
If database-id is not provided, it returns the first database the user has access to.`,
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(_ *cobra.Command, args []string) {
|
||||||
getCredentials()
|
getCredentials()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ var listCredentialsCmd = &cobra.Command{
|
||||||
Short: "List all accessible database credentials",
|
Short: "List all accessible database credentials",
|
||||||
Long: `List all database credentials the user has access to.
|
Long: `List all database credentials the user has access to.
|
||||||
This command requires authentication with a username and token.`,
|
This command requires authentication with a username and token.`,
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(_ *cobra.Command, args []string) {
|
||||||
listCredentials()
|
listCredentials()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ var databaseCmd = &cobra.Command{
|
||||||
Long: `Commands for managing database credentials in the MCIAS system.`,
|
Long: `Commands for managing database credentials in the MCIAS system.`,
|
||||||
}
|
}
|
||||||
|
|
||||||
var addUserCmd = &cobra.Command{
|
var addDBUserCmd = &cobra.Command{
|
||||||
Use: "add-user [database-id] [user-id]",
|
Use: "add-user [database-id] [user-id]",
|
||||||
Short: "Associate a user with a database",
|
Short: "Associate a user with a database",
|
||||||
Long: `Associate a user with a database, allowing them to read its credentials.`,
|
Long: `Associate a user with a database, allowing them to read its credentials.`,
|
||||||
|
@ -47,22 +47,22 @@ var addUserCmd = &cobra.Command{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var removeUserCmd = &cobra.Command{
|
var removeDBUserCmd = &cobra.Command{
|
||||||
Use: "remove-user [database-id] [user-id]",
|
Use: "remove-user [database-id] [user-id]",
|
||||||
Short: "Remove a user's association with a database",
|
Short: "Remove a user's association with a database",
|
||||||
Long: `Remove a user's association with a database, preventing them from reading its credentials.`,
|
Long: `Remove a user's association with a database, preventing them from reading its credentials.`,
|
||||||
Args: cobra.ExactArgs(2),
|
Args: cobra.ExactArgs(2),
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(_ *cobra.Command, args []string) {
|
||||||
removeUserFromDatabase(args[0], args[1])
|
removeUserFromDatabase(args[0], args[1])
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
var listUsersCmd = &cobra.Command{
|
var listDBUsersCmd = &cobra.Command{
|
||||||
Use: "list-users [database-id]",
|
Use: "list-users [database-id]",
|
||||||
Short: "List users associated with a database",
|
Short: "List users associated with a database",
|
||||||
Long: `List all users who have access to read the credentials of a specific database.`,
|
Long: `List all users who have access to read the credentials of a specific database.`,
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(_ *cobra.Command, args []string) {
|
||||||
listDatabaseUsers(args[0])
|
listDatabaseUsers(args[0])
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -72,18 +72,17 @@ var getCredentialsCmd = &cobra.Command{
|
||||||
Short: "Get database credentials",
|
Short: "Get database credentials",
|
||||||
Long: `Retrieve database credentials from the MCIAS system.
|
Long: `Retrieve database credentials from the MCIAS system.
|
||||||
This command requires authentication with a username and token.`,
|
This command requires authentication with a username and token.`,
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(_ *cobra.Command, args []string) {
|
||||||
getCredentials()
|
getCredentials()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
|
|
||||||
func init() {
|
func init() {
|
||||||
rootCmd.AddCommand(databaseCmd)
|
rootCmd.AddCommand(databaseCmd)
|
||||||
databaseCmd.AddCommand(getCredentialsCmd)
|
databaseCmd.AddCommand(getCredentialsCmd)
|
||||||
databaseCmd.AddCommand(addUserCmd)
|
databaseCmd.AddCommand(addDBUserCmd)
|
||||||
databaseCmd.AddCommand(removeUserCmd)
|
databaseCmd.AddCommand(removeDBUserCmd)
|
||||||
databaseCmd.AddCommand(listUsersCmd)
|
databaseCmd.AddCommand(listDBUsersCmd)
|
||||||
|
|
||||||
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
|
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
|
||||||
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
|
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
|
||||||
|
@ -105,7 +104,6 @@ func getCredentials() {
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
|
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
|
||||||
|
|
||||||
// Create a context with timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
@ -69,7 +70,6 @@ func runMigration(direction string, steps int) {
|
||||||
dbPath := viper.GetString("db")
|
dbPath := viper.GetString("db")
|
||||||
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
||||||
|
|
||||||
// Ensure migrations directory exists
|
|
||||||
absPath, err := filepath.Abs(migrationsDir)
|
absPath, err := filepath.Abs(migrationsDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
||||||
|
@ -79,20 +79,17 @@ func runMigration(direction string, steps int) {
|
||||||
logger.Fatalf("Migrations directory does not exist: %s", absPath)
|
logger.Fatalf("Migrations directory does not exist: %s", absPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open database connection
|
|
||||||
db, err := openDatabase(dbPath)
|
db, err := openDatabase(dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to open database: %v", err)
|
logger.Fatalf("Failed to open database: %v", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Create migration driver
|
|
||||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to create migration driver: %v", err)
|
logger.Fatalf("Failed to create migration driver: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create migrate instance
|
|
||||||
m, err := migrate.NewWithDatabaseInstance(
|
m, err := migrate.NewWithDatabaseInstance(
|
||||||
fmt.Sprintf("file://%s", absPath),
|
fmt.Sprintf("file://%s", absPath),
|
||||||
"sqlite3", driver)
|
"sqlite3", driver)
|
||||||
|
@ -100,14 +97,13 @@ func runMigration(direction string, steps int) {
|
||||||
logger.Fatalf("Failed to create migration instance: %v", err)
|
logger.Fatalf("Failed to create migration instance: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run migration
|
|
||||||
if direction == "up" {
|
if direction == "up" {
|
||||||
if steps > 0 {
|
if steps > 0 {
|
||||||
err = m.Steps(steps)
|
err = m.Steps(steps)
|
||||||
} else {
|
} else {
|
||||||
err = m.Up()
|
err = m.Up()
|
||||||
}
|
}
|
||||||
if err != nil && err != migrate.ErrNoChange {
|
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||||
logger.Fatalf("Failed to apply migrations: %v", err)
|
logger.Fatalf("Failed to apply migrations: %v", err)
|
||||||
}
|
}
|
||||||
logger.Println("Migrations applied successfully")
|
logger.Println("Migrations applied successfully")
|
||||||
|
@ -117,7 +113,7 @@ func runMigration(direction string, steps int) {
|
||||||
} else {
|
} else {
|
||||||
err = m.Down()
|
err = m.Down()
|
||||||
}
|
}
|
||||||
if err != nil && err != migrate.ErrNoChange {
|
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||||
logger.Fatalf("Failed to revert migrations: %v", err)
|
logger.Fatalf("Failed to revert migrations: %v", err)
|
||||||
}
|
}
|
||||||
logger.Println("Migrations reverted successfully")
|
logger.Println("Migrations reverted successfully")
|
||||||
|
@ -128,20 +124,17 @@ func showMigrationVersion() {
|
||||||
dbPath := viper.GetString("db")
|
dbPath := viper.GetString("db")
|
||||||
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
||||||
|
|
||||||
// Open database connection
|
|
||||||
db, err := openDatabase(dbPath)
|
db, err := openDatabase(dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to open database: %v", err)
|
logger.Fatalf("Failed to open database: %v", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Create migration driver
|
|
||||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to create migration driver: %v", err)
|
logger.Fatalf("Failed to create migration driver: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create migrate instance
|
|
||||||
absPath, err := filepath.Abs(migrationsDir)
|
absPath, err := filepath.Abs(migrationsDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
||||||
|
@ -154,10 +147,9 @@ func showMigrationVersion() {
|
||||||
logger.Fatalf("Failed to create migration instance: %v", err)
|
logger.Fatalf("Failed to create migration instance: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get current version
|
|
||||||
version, dirty, err := m.Version()
|
version, dirty, err := m.Version()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == migrate.ErrNilVersion {
|
if errors.Is(err, migrate.ErrNilVersion) {
|
||||||
logger.Println("No migrations have been applied yet")
|
logger.Println("No migrations have been applied yet")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -168,13 +160,11 @@ func showMigrationVersion() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func openDatabase(dbPath string) (*sql.DB, error) {
|
func openDatabase(dbPath string) (*sql.DB, error) {
|
||||||
// Ensure database directory exists
|
|
||||||
dbDir := filepath.Dir(dbPath)
|
dbDir := filepath.Dir(dbPath)
|
||||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open database connection
|
|
||||||
db, err := sql.Open("sqlite3", dbPath)
|
db, err := sql.Open("sqlite3", dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
|
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
@ -167,7 +168,7 @@ func assignRole() {
|
||||||
var userID string
|
var userID string
|
||||||
err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID)
|
err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("User %s not found", roleUser)
|
logger.Fatalf("User %s not found", roleUser)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get user ID: %v", err)
|
logger.Fatalf("Failed to get user ID: %v", err)
|
||||||
|
@ -177,7 +178,7 @@ func assignRole() {
|
||||||
var roleID string
|
var roleID string
|
||||||
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID)
|
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("Role %s not found", roleName)
|
logger.Fatalf("Role %s not found", roleName)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get role ID: %v", err)
|
logger.Fatalf("Failed to get role ID: %v", err)
|
||||||
|
@ -219,7 +220,7 @@ func revokeRole() {
|
||||||
var userID string
|
var userID string
|
||||||
err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID)
|
err = db.QueryRow("SELECT id FROM users WHERE user = ?", roleUser).Scan(&userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("User %s not found", roleUser)
|
logger.Fatalf("User %s not found", roleUser)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get user ID: %v", err)
|
logger.Fatalf("Failed to get user ID: %v", err)
|
||||||
|
@ -229,7 +230,7 @@ func revokeRole() {
|
||||||
var roleID string
|
var roleID string
|
||||||
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID)
|
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", roleName).Scan(&roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("Role %s not found", roleName)
|
logger.Fatalf("Role %s not found", roleName)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get role ID: %v", err)
|
logger.Fatalf("Failed to get role ID: %v", err)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
@ -72,7 +73,7 @@ func addToken() {
|
||||||
var userID string
|
var userID string
|
||||||
err = db.QueryRow("SELECT id FROM users WHERE user = ?", tokenUsername).Scan(&userID)
|
err = db.QueryRow("SELECT id FROM users WHERE user = ?", tokenUsername).Scan(&userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("User %s does not exist", tokenUsername)
|
logger.Fatalf("User %s does not exist", tokenUsername)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to check if user exists: %v", err)
|
logger.Fatalf("Failed to check if user exists: %v", err)
|
||||||
|
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
@ -13,7 +14,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// userQuery is the SQL query to get user information from the database
|
|
||||||
userQuery = `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
|
userQuery = `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,7 +60,6 @@ This command requires a username. It will emit the secret, and optionally output
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupTOTPCommands initializes TOTP commands and flags
|
|
||||||
func setupTOTPCommands() {
|
func setupTOTPCommands() {
|
||||||
rootCmd.AddCommand(totpCmd)
|
rootCmd.AddCommand(totpCmd)
|
||||||
totpCmd.AddCommand(enableTOTPCmd)
|
totpCmd.AddCommand(enableTOTPCmd)
|
||||||
|
@ -100,7 +99,6 @@ func enableTOTP() {
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Get the user from the database
|
|
||||||
var userID string
|
var userID string
|
||||||
var created int64
|
var created int64
|
||||||
var username string
|
var username string
|
||||||
|
@ -109,7 +107,7 @@ func enableTOTP() {
|
||||||
|
|
||||||
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("User %s does not exist", totpUsername)
|
logger.Fatalf("User %s does not exist", totpUsername)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get user: %v", err)
|
logger.Fatalf("Failed to get user: %v", err)
|
||||||
|
@ -168,7 +166,7 @@ func validateTOTP() {
|
||||||
|
|
||||||
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("User %s does not exist", totpUsername)
|
logger.Fatalf("User %s does not exist", totpUsername)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get user: %v", err)
|
logger.Fatalf("Failed to get user: %v", err)
|
||||||
|
@ -226,7 +224,7 @@ func addTOTP() {
|
||||||
|
|
||||||
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("User %s does not exist", totpUsername)
|
logger.Fatalf("User %s does not exist", totpUsername)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get user: %v", err)
|
logger.Fatalf("Failed to get user: %v", err)
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTOTPBasic(t *testing.T) {
|
func TestTOTPBasic(t *testing.T) {
|
||||||
// Just test that we can import and use the package
|
// This test verifies that we can import and use the twofactor package.
|
||||||
totp := twofactor.TOTP{}
|
totp := twofactor.TOTP{}
|
||||||
fmt.Printf("TOTP: %+v\n", totp)
|
fmt.Printf("TOTP: %+v\n", totp)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue