Junie: TOTP flow update and db migrations.
This commit is contained in:
137
cmd/mcias-client/database.go
Normal file
137
cmd/mcias-client/database.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type DatabaseCredentials struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Name string `json:"name"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var (
|
||||
dbUsername string
|
||||
dbToken string
|
||||
useStored bool
|
||||
)
|
||||
|
||||
var databaseCmd = &cobra.Command{
|
||||
Use: "database",
|
||||
Short: "Manage database credentials",
|
||||
Long: `Commands for managing database credentials in the MCIAS system.`,
|
||||
}
|
||||
|
||||
var getCredentialsCmd = &cobra.Command{
|
||||
Use: "credentials",
|
||||
Short: "Get database credentials",
|
||||
Long: `Retrieve database credentials from the MCIAS system.
|
||||
This command requires authentication with a username and token.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
getCredentials()
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(databaseCmd)
|
||||
databaseCmd.AddCommand(getCredentialsCmd)
|
||||
|
||||
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
|
||||
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
|
||||
getCredentialsCmd.Flags().BoolVarP(&useStored, "use-stored", "s", false, "Use stored token from previous login")
|
||||
|
||||
// Make username required only if not using stored token
|
||||
getCredentialsCmd.MarkFlagsMutuallyExclusive("token", "use-stored")
|
||||
}
|
||||
|
||||
func getCredentials() {
|
||||
// If using stored token, load it from the token file
|
||||
if useStored {
|
||||
tokenInfo, err := loadToken()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading token: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
dbUsername = tokenInfo.Username
|
||||
dbToken = tokenInfo.Token
|
||||
}
|
||||
|
||||
// Validate required parameters
|
||||
if dbUsername == "" {
|
||||
fmt.Fprintf(os.Stderr, "Error: username is required\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if dbToken == "" {
|
||||
fmt.Fprintf(os.Stderr, "Error: token is required (either provide --token or use --use-stored)\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
serverAddr := viper.GetString("server")
|
||||
if serverAddr == "" {
|
||||
serverAddr = "http://localhost:8080"
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
|
||||
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", dbToken))
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp ErrorResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var creds DatabaseCredentials
|
||||
if unmarshalErr := json.Unmarshal(body, &creds); unmarshalErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Println("Database Credentials:")
|
||||
fmt.Printf("Host: %s\n", creds.Host)
|
||||
fmt.Printf("Port: %d\n", creds.Port)
|
||||
fmt.Printf("Name: %s\n", creds.Name)
|
||||
fmt.Printf("User: %s\n", creds.User)
|
||||
fmt.Printf("Password: %s\n", creds.Password)
|
||||
}
|
||||
|
||||
368
cmd/mcias-client/login.go
Normal file
368
cmd/mcias-client/login.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var (
|
||||
username string
|
||||
password string
|
||||
token string
|
||||
totpCode string
|
||||
)
|
||||
|
||||
type LoginRequest struct {
|
||||
Version string `json:"version"`
|
||||
Login LoginParams `json:"login"`
|
||||
}
|
||||
|
||||
type TOTPVerifyRequest struct {
|
||||
Version string `json:"version"`
|
||||
Username string `json:"username"`
|
||||
TOTPCode string `json:"totp_code"`
|
||||
}
|
||||
|
||||
type LoginParams struct {
|
||||
User string `json:"user"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TOTPCode string `json:"totp_code,omitempty"`
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Expires int64 `json:"expires"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type TokenInfo struct {
|
||||
Username string `json:"username"`
|
||||
Token string `json:"token"`
|
||||
Expires int64 `json:"expires"`
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "Login to the MCIAS server",
|
||||
Long: `Login to the MCIAS server using either a username/password or a token.`,
|
||||
}
|
||||
|
||||
var passwordLoginCmd = &cobra.Command{
|
||||
Use: "password",
|
||||
Short: "Login with username and password",
|
||||
Long: `Login to the MCIAS server using a username and password.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
loginWithPassword()
|
||||
},
|
||||
}
|
||||
|
||||
var tokenLoginCmd = &cobra.Command{
|
||||
Use: "token",
|
||||
Short: "Login with a token",
|
||||
Long: `Login to the MCIAS server using a token.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
loginWithToken()
|
||||
},
|
||||
}
|
||||
|
||||
var totpVerifyCmd = &cobra.Command{
|
||||
Use: "totp",
|
||||
Short: "Verify TOTP code",
|
||||
Long: `Verify a TOTP code after password authentication.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
verifyTOTP()
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(loginCmd)
|
||||
loginCmd.AddCommand(passwordLoginCmd)
|
||||
loginCmd.AddCommand(tokenLoginCmd)
|
||||
loginCmd.AddCommand(totpVerifyCmd)
|
||||
|
||||
// TOTP verification flags
|
||||
totpVerifyCmd.Flags().StringVarP(&username, "username", "u", "", "Username for authentication")
|
||||
totpVerifyCmd.Flags().StringVarP(&totpCode, "code", "c", "", "TOTP code to verify")
|
||||
if err := totpVerifyCmd.MarkFlagRequired("username"); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking username flag as required: %v\n", err)
|
||||
}
|
||||
if err := totpVerifyCmd.MarkFlagRequired("code"); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking code flag as required: %v\n", err)
|
||||
}
|
||||
|
||||
// Password login flags
|
||||
passwordLoginCmd.Flags().StringVarP(&username, "username", "u", "", "Username for authentication")
|
||||
passwordLoginCmd.Flags().StringVarP(&password, "password", "p", "", "Password for authentication")
|
||||
passwordLoginCmd.Flags().StringVarP(&totpCode, "totp", "t", "", "TOTP code (if enabled)")
|
||||
if err := passwordLoginCmd.MarkFlagRequired("username"); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking username flag as required: %v\n", err)
|
||||
}
|
||||
if err := passwordLoginCmd.MarkFlagRequired("password"); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking password flag as required: %v\n", err)
|
||||
}
|
||||
|
||||
// Token login flags
|
||||
tokenLoginCmd.Flags().StringVarP(&username, "username", "u", "", "Username for authentication")
|
||||
tokenLoginCmd.Flags().StringVarP(&token, "token", "t", "", "Authentication token")
|
||||
if err := tokenLoginCmd.MarkFlagRequired("username"); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking username flag as required: %v\n", err)
|
||||
}
|
||||
if err := tokenLoginCmd.MarkFlagRequired("token"); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking token flag as required: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func loginWithPassword() {
|
||||
serverAddr := viper.GetString("server")
|
||||
if serverAddr == "" {
|
||||
serverAddr = "http://localhost:8080"
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/login/password", serverAddr)
|
||||
|
||||
loginReq := LoginRequest{
|
||||
Version: "v1",
|
||||
Login: LoginParams{
|
||||
User: username,
|
||||
Password: password,
|
||||
TOTPCode: totpCode,
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(loginReq)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp ErrorResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &tokenResp); unmarshalErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Save the token to the token file
|
||||
tokenInfo := TokenInfo{
|
||||
Username: username,
|
||||
Token: tokenResp.Token,
|
||||
Expires: tokenResp.Expires,
|
||||
}
|
||||
|
||||
if err := saveToken(tokenInfo); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving token: %v\n", err)
|
||||
// Continue anyway, as we can still display the token
|
||||
}
|
||||
|
||||
fmt.Println("Login successful!")
|
||||
fmt.Printf("Token: %s\n", tokenResp.Token)
|
||||
fmt.Printf("Expires: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
||||
}
|
||||
|
||||
func verifyTOTP() {
|
||||
serverAddr := viper.GetString("server")
|
||||
if serverAddr == "" {
|
||||
serverAddr = "http://localhost:8080"
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/login/totp", serverAddr)
|
||||
|
||||
totpReq := TOTPVerifyRequest{
|
||||
Version: "v1",
|
||||
Username: username,
|
||||
TOTPCode: totpCode,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(totpReq)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp ErrorResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &tokenResp); unmarshalErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Save the token to the token file
|
||||
tokenInfo := TokenInfo{
|
||||
Username: username,
|
||||
Token: tokenResp.Token,
|
||||
Expires: tokenResp.Expires,
|
||||
}
|
||||
|
||||
if err := saveToken(tokenInfo); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving token: %v\n", err)
|
||||
// Continue anyway, as we can still display the token
|
||||
}
|
||||
|
||||
fmt.Println("TOTP verification successful!")
|
||||
fmt.Printf("Token: %s\n", tokenResp.Token)
|
||||
fmt.Printf("Expires: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
||||
}
|
||||
|
||||
func loginWithToken() {
|
||||
serverAddr := viper.GetString("server")
|
||||
if serverAddr == "" {
|
||||
serverAddr = "http://localhost:8080"
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/login/token", serverAddr)
|
||||
|
||||
loginReq := LoginRequest{
|
||||
Version: "v1",
|
||||
Login: LoginParams{
|
||||
User: username,
|
||||
Token: token,
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(loginReq)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp ErrorResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if unmarshalErr := json.Unmarshal(body, &tokenResp); unmarshalErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Save the token to the token file
|
||||
tokenInfo := TokenInfo{
|
||||
Username: username,
|
||||
Token: tokenResp.Token,
|
||||
Expires: tokenResp.Expires,
|
||||
}
|
||||
|
||||
if err := saveToken(tokenInfo); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error saving token: %v\n", err)
|
||||
// Continue anyway, as we can still display the token
|
||||
}
|
||||
|
||||
fmt.Println("Token login successful!")
|
||||
fmt.Printf("Token: %s\n", tokenResp.Token)
|
||||
fmt.Printf("Expires: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
||||
}
|
||||
13
cmd/mcias-client/main.go
Normal file
13
cmd/mcias-client/main.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := Execute(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
78
cmd/mcias-client/root.go
Normal file
78
cmd/mcias-client/root.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var (
|
||||
cfgFile string
|
||||
serverAddr string
|
||||
tokenFile string
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "mcias-client",
|
||||
Short: "MCIAS Client - Command line client for the Metacircular Identity and Access System",
|
||||
Long: `MCIAS Client is a command line tool for interacting with the MCIAS server.
|
||||
It provides access to the MCIAS API endpoints for authentication and resource access.
|
||||
|
||||
It currently supports the following operations:
|
||||
1. User password authentication
|
||||
2. User token authentication
|
||||
3. Database credential retrieval`,
|
||||
}
|
||||
)
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
// setupRootCommand initializes the root command and its flags
|
||||
func setupRootCommand() {
|
||||
cobra.OnInitialize(initConfig)
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.mcias-client.yaml)")
|
||||
rootCmd.PersistentFlags().StringVar(&serverAddr, "server", "http://localhost:8080", "MCIAS server address")
|
||||
rootCmd.PersistentFlags().StringVar(&tokenFile, "token-file", "", "File to store authentication token (default is $HOME/.mcias-token)")
|
||||
|
||||
if err := viper.BindPFlag("server", rootCmd.PersistentFlags().Lookup("server")); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error binding server flag: %v\n", err)
|
||||
}
|
||||
if err := viper.BindPFlag("token-file", rootCmd.PersistentFlags().Lookup("token-file")); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error binding token-file flag: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func initConfig() {
|
||||
if cfgFile != "" {
|
||||
viper.SetConfigFile(cfgFile)
|
||||
} else {
|
||||
home, err := os.UserHomeDir()
|
||||
cobra.CheckErr(err)
|
||||
|
||||
viper.AddConfigPath(home)
|
||||
viper.SetConfigType("yaml")
|
||||
viper.SetConfigName(".mcias-client")
|
||||
}
|
||||
|
||||
viper.AutomaticEnv()
|
||||
|
||||
if err := viper.ReadInConfig(); err == nil {
|
||||
fmt.Fprintln(os.Stderr, "Using config file:", viper.ConfigFileUsed())
|
||||
}
|
||||
|
||||
// Set default token file if not specified
|
||||
if viper.GetString("token-file") == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil {
|
||||
viper.Set("token-file", fmt.Sprintf("%s/.mcias-token", home))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
setupRootCommand()
|
||||
}
|
||||
63
cmd/mcias-client/util.go
Normal file
63
cmd/mcias-client/util.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// loadToken loads the token from the token file
|
||||
func loadToken() (*TokenInfo, error) {
|
||||
tokenFilePath := viper.GetString("token-file")
|
||||
if tokenFilePath == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting home directory: %w", err)
|
||||
}
|
||||
tokenFilePath = fmt.Sprintf("%s/.mcias-token", home)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(tokenFilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading token file: %w", err)
|
||||
}
|
||||
|
||||
var tokenInfo TokenInfo
|
||||
if err := json.Unmarshal(data, &tokenInfo); err != nil {
|
||||
return nil, fmt.Errorf("error parsing token file: %w", err)
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if tokenInfo.Expires > 0 && tokenInfo.Expires < time.Now().Unix() {
|
||||
return nil, fmt.Errorf("token has expired, please login again")
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
// saveToken saves the token to the token file
|
||||
func saveToken(tokenInfo TokenInfo) error {
|
||||
tokenFilePath := viper.GetString("token-file")
|
||||
if tokenFilePath == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting home directory: %w", err)
|
||||
}
|
||||
tokenFilePath = fmt.Sprintf("%s/.mcias-token", home)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(tokenInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error encoding token: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(tokenFilePath, jsonData, 0600); err != nil {
|
||||
return fmt.Errorf("error saving token to file: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Token saved to %s\n", tokenFilePath)
|
||||
return nil
|
||||
}
|
||||
@@ -1,12 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
@@ -45,6 +47,7 @@ This command requires authentication with a username and token.`,
|
||||
},
|
||||
}
|
||||
|
||||
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
|
||||
func init() {
|
||||
rootCmd.AddCommand(databaseCmd)
|
||||
databaseCmd.AddCommand(getCredentialsCmd)
|
||||
@@ -68,7 +71,12 @@ func getCredentials() {
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
|
||||
// Create a context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
@@ -89,7 +97,7 @@ func getCredentials() {
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp ErrorResponse
|
||||
if err := json.Unmarshal(body, &errResp); err == nil {
|
||||
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||
logger.Fatalf("Error: %s", errResp.Error)
|
||||
} else {
|
||||
logger.Fatalf("Error: %s", resp.Status)
|
||||
@@ -97,8 +105,8 @@ func getCredentials() {
|
||||
}
|
||||
|
||||
var creds DatabaseCredentials
|
||||
if err := json.Unmarshal(body, &creds); err != nil {
|
||||
logger.Fatalf("Failed to parse response: %v", err)
|
||||
if unmarshalErr := json.Unmarshal(body, &creds); unmarshalErr != nil {
|
||||
logger.Fatalf("Failed to parse response: %v", unmarshalErr)
|
||||
}
|
||||
|
||||
fmt.Println("Database Credentials:")
|
||||
@@ -107,4 +115,4 @@ func getCredentials() {
|
||||
fmt.Printf("Name: %s\n", creds.Name)
|
||||
fmt.Printf("User: %s\n", creds.User)
|
||||
fmt.Printf("Password: %s\n", creds.Password)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,4 +10,4 @@ func main() {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
184
cmd/mcias/migrate.go
Normal file
184
cmd/mcias/migrate.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var (
|
||||
migrationsDir string
|
||||
steps int
|
||||
)
|
||||
|
||||
var migrateCmd = &cobra.Command{
|
||||
Use: "migrate",
|
||||
Short: "Manage database migrations",
|
||||
Long: `Commands for managing database migrations in the MCIAS system.`,
|
||||
}
|
||||
|
||||
var migrateUpCmd = &cobra.Command{
|
||||
Use: "up [steps]",
|
||||
Short: "Apply migrations",
|
||||
Long: `Apply all or a specific number of migrations.
|
||||
If steps is not provided, all pending migrations will be applied.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runMigration("up", steps)
|
||||
},
|
||||
}
|
||||
|
||||
var migrateDownCmd = &cobra.Command{
|
||||
Use: "down [steps]",
|
||||
Short: "Revert migrations",
|
||||
Long: `Revert all or a specific number of migrations.
|
||||
If steps is not provided, all applied migrations will be reverted.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runMigration("down", steps)
|
||||
},
|
||||
}
|
||||
|
||||
var migrateVersionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Show current migration version",
|
||||
Long: `Display the current migration version of the database.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
showMigrationVersion()
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(migrateCmd)
|
||||
migrateCmd.AddCommand(migrateUpCmd)
|
||||
migrateCmd.AddCommand(migrateDownCmd)
|
||||
migrateCmd.AddCommand(migrateVersionCmd)
|
||||
|
||||
migrateCmd.PersistentFlags().StringVarP(&migrationsDir, "migrations", "m", "database/migrations", "Directory containing migration files")
|
||||
migrateCmd.PersistentFlags().IntVarP(&steps, "steps", "s", 0, "Number of migrations to apply or revert (0 means all)")
|
||||
}
|
||||
|
||||
func runMigration(direction string, steps int) {
|
||||
dbPath := viper.GetString("db")
|
||||
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
||||
|
||||
// Ensure migrations directory exists
|
||||
absPath, err := filepath.Abs(migrationsDir)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(absPath); os.IsNotExist(err) {
|
||||
logger.Fatalf("Migrations directory does not exist: %s", absPath)
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err := openDatabase(dbPath)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create migration driver
|
||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to create migration driver: %v", err)
|
||||
}
|
||||
|
||||
// Create migrate instance
|
||||
m, err := migrate.NewWithDatabaseInstance(
|
||||
fmt.Sprintf("file://%s", absPath),
|
||||
"sqlite3", driver)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to create migration instance: %v", err)
|
||||
}
|
||||
|
||||
// Run migration
|
||||
if direction == "up" {
|
||||
if steps > 0 {
|
||||
err = m.Steps(steps)
|
||||
} else {
|
||||
err = m.Up()
|
||||
}
|
||||
if err != nil && err != migrate.ErrNoChange {
|
||||
logger.Fatalf("Failed to apply migrations: %v", err)
|
||||
}
|
||||
logger.Println("Migrations applied successfully")
|
||||
} else if direction == "down" {
|
||||
if steps > 0 {
|
||||
err = m.Steps(-steps)
|
||||
} else {
|
||||
err = m.Down()
|
||||
}
|
||||
if err != nil && err != migrate.ErrNoChange {
|
||||
logger.Fatalf("Failed to revert migrations: %v", err)
|
||||
}
|
||||
logger.Println("Migrations reverted successfully")
|
||||
}
|
||||
}
|
||||
|
||||
func showMigrationVersion() {
|
||||
dbPath := viper.GetString("db")
|
||||
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
||||
|
||||
// Open database connection
|
||||
db, err := openDatabase(dbPath)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create migration driver
|
||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to create migration driver: %v", err)
|
||||
}
|
||||
|
||||
// Create migrate instance
|
||||
absPath, err := filepath.Abs(migrationsDir)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithDatabaseInstance(
|
||||
fmt.Sprintf("file://%s", absPath),
|
||||
"sqlite3", driver)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to create migration instance: %v", err)
|
||||
}
|
||||
|
||||
// Get current version
|
||||
version, dirty, err := m.Version()
|
||||
if err != nil {
|
||||
if err == migrate.ErrNilVersion {
|
||||
logger.Println("No migrations have been applied yet")
|
||||
return
|
||||
}
|
||||
logger.Fatalf("Failed to get migration version: %v", err)
|
||||
}
|
||||
|
||||
logger.Printf("Current migration version: %d (dirty: %t)", version, dirty)
|
||||
}
|
||||
|
||||
func openDatabase(dbPath string) (*sql.DB, error) {
|
||||
// Ensure database directory exists
|
||||
dbDir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -51,6 +52,7 @@ var revokePermissionCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
|
||||
func init() {
|
||||
rootCmd.AddCommand(permissionCmd)
|
||||
permissionCmd.AddCommand(listPermissionsCmd)
|
||||
@@ -104,14 +106,14 @@ func listPermissions() {
|
||||
fmt.Println(strings.Repeat("-", 90))
|
||||
for rows.Next() {
|
||||
var id, resource, action, description string
|
||||
if err := rows.Scan(&id, &resource, &action, &description); err != nil {
|
||||
logger.Fatalf("Failed to scan permission row: %v", err)
|
||||
if scanErr := rows.Scan(&id, &resource, &action, &description); scanErr != nil {
|
||||
logger.Fatalf("Failed to scan permission row: %v", scanErr)
|
||||
}
|
||||
fmt.Printf("%-24s %-20s %-15s %-30s\n", id, resource, action, description)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
logger.Fatalf("Error iterating permission rows: %v", err)
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
logger.Fatalf("Error iterating permission rows: %v", rowErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,7 +131,7 @@ func grantPermission() {
|
||||
var roleID string
|
||||
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", permissionRole).Scan(&roleID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("Role %s not found", permissionRole)
|
||||
}
|
||||
logger.Fatalf("Failed to get role ID: %v", err)
|
||||
@@ -137,11 +139,11 @@ func grantPermission() {
|
||||
|
||||
// Get permission ID
|
||||
var permissionID string
|
||||
err = db.QueryRow("SELECT id FROM permissions WHERE resource = ? AND action = ?",
|
||||
err = db.QueryRow("SELECT id FROM permissions WHERE resource = ? AND action = ?",
|
||||
permissionResource, permissionAction).Scan(&permissionID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logger.Fatalf("Permission with resource '%s' and action '%s' not found",
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("Permission with resource '%s' and action '%s' not found",
|
||||
permissionResource, permissionAction)
|
||||
}
|
||||
logger.Fatalf("Failed to get permission ID: %v", err)
|
||||
@@ -149,13 +151,13 @@ func grantPermission() {
|
||||
|
||||
// Check if role already has this permission
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM role_permissions WHERE rid = ? AND pid = ?",
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM role_permissions WHERE rid = ? AND pid = ?",
|
||||
roleID, permissionID).Scan(&count)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to check if role has permission: %v", err)
|
||||
}
|
||||
if count > 0 {
|
||||
logger.Fatalf("Role %s already has permission %s:%s",
|
||||
logger.Fatalf("Role %s already has permission %s:%s",
|
||||
permissionRole, permissionResource, permissionAction)
|
||||
}
|
||||
|
||||
@@ -163,13 +165,13 @@ func grantPermission() {
|
||||
id := ulid.Make().String()
|
||||
|
||||
// Grant permission to role
|
||||
_, err = db.Exec("INSERT INTO role_permissions (id, rid, pid) VALUES (?, ?, ?)",
|
||||
_, err = db.Exec("INSERT INTO role_permissions (id, rid, pid) VALUES (?, ?, ?)",
|
||||
id, roleID, permissionID)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to grant permission: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Permission %s:%s granted to role %s successfully\n",
|
||||
fmt.Printf("Permission %s:%s granted to role %s successfully\n",
|
||||
permissionResource, permissionAction, permissionRole)
|
||||
}
|
||||
|
||||
@@ -187,7 +189,7 @@ func revokePermission() {
|
||||
var roleID string
|
||||
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", permissionRole).Scan(&roleID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("Role %s not found", permissionRole)
|
||||
}
|
||||
logger.Fatalf("Failed to get role ID: %v", err)
|
||||
@@ -195,11 +197,11 @@ func revokePermission() {
|
||||
|
||||
// Get permission ID
|
||||
var permissionID string
|
||||
err = db.QueryRow("SELECT id FROM permissions WHERE resource = ? AND action = ?",
|
||||
err = db.QueryRow("SELECT id FROM permissions WHERE resource = ? AND action = ?",
|
||||
permissionResource, permissionAction).Scan(&permissionID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logger.Fatalf("Permission with resource '%s' and action '%s' not found",
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
logger.Fatalf("Permission with resource '%s' and action '%s' not found",
|
||||
permissionResource, permissionAction)
|
||||
}
|
||||
logger.Fatalf("Failed to get permission ID: %v", err)
|
||||
@@ -207,13 +209,13 @@ func revokePermission() {
|
||||
|
||||
// Check if role has this permission
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM role_permissions WHERE rid = ? AND pid = ?",
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM role_permissions WHERE rid = ? AND pid = ?",
|
||||
roleID, permissionID).Scan(&count)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to check if role has permission: %v", err)
|
||||
}
|
||||
if count == 0 {
|
||||
logger.Fatalf("Role %s does not have permission %s:%s",
|
||||
logger.Fatalf("Role %s does not have permission %s:%s",
|
||||
permissionRole, permissionResource, permissionAction)
|
||||
}
|
||||
|
||||
@@ -223,6 +225,6 @@ func revokePermission() {
|
||||
logger.Fatalf("Failed to revoke permission: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Permission %s:%s revoked from role %s successfully\n",
|
||||
fmt.Printf("Permission %s:%s revoked from role %s successfully\n",
|
||||
permissionResource, permissionAction, permissionRole)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,4 +252,4 @@ func revokeRole() {
|
||||
}
|
||||
|
||||
fmt.Printf("Role %s revoked from user %s successfully\n", roleName, roleUser)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,10 +27,17 @@ It currently provides the following across metacircular services:
|
||||
)
|
||||
|
||||
func Execute() error {
|
||||
// Setup commands and flags
|
||||
setupRootCommand()
|
||||
setupTOTPCommands()
|
||||
// The migrate command is already set up in its init function
|
||||
|
||||
// Execute the root command
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func init() {
|
||||
// setupRootCommand initializes the root command and its flags
|
||||
func setupRootCommand() {
|
||||
cobra.OnInitialize(initConfig)
|
||||
|
||||
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.mcias.yaml)")
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestRootCommand(t *testing.T) {
|
||||
if rootCmd.Use != "mcias" {
|
||||
t.Errorf("Expected root command Use to be 'mcias', got '%s'", rootCmd.Use)
|
||||
}
|
||||
|
||||
if rootCmd.Short == "" {
|
||||
t.Error("Expected root command Short to be set")
|
||||
}
|
||||
|
||||
if rootCmd.Long == "" {
|
||||
t.Error("Expected root command Long to be set")
|
||||
}
|
||||
dbFlag := rootCmd.PersistentFlags().Lookup("db")
|
||||
if dbFlag == nil {
|
||||
t.Error("Expected 'db' flag to be defined")
|
||||
} else {
|
||||
if dbFlag.DefValue != "mcias.db" {
|
||||
t.Errorf("Expected 'db' flag default value to be 'mcias.db', got '%s'", dbFlag.DefValue)
|
||||
}
|
||||
}
|
||||
|
||||
addrFlag := rootCmd.PersistentFlags().Lookup("addr")
|
||||
if addrFlag == nil {
|
||||
t.Error("Expected 'addr' flag to be defined")
|
||||
} else {
|
||||
if addrFlag.DefValue != ":8080" {
|
||||
t.Errorf("Expected 'addr' flag default value to be ':8080', got '%s'", addrFlag.DefValue)
|
||||
}
|
||||
}
|
||||
hasServerCmd := false
|
||||
hasInitCmd := false
|
||||
hasUserCmd := false
|
||||
hasTokenCmd := false
|
||||
|
||||
for _, cmd := range rootCmd.Commands() {
|
||||
switch cmd.Use {
|
||||
case "server":
|
||||
hasServerCmd = true
|
||||
case "init":
|
||||
hasInitCmd = true
|
||||
case "user":
|
||||
hasUserCmd = true
|
||||
case "token":
|
||||
hasTokenCmd = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasServerCmd {
|
||||
t.Error("Expected 'server' command to be added to root command")
|
||||
}
|
||||
if !hasInitCmd {
|
||||
t.Error("Expected 'init' command to be added to root command")
|
||||
}
|
||||
if !hasUserCmd {
|
||||
t.Error("Expected 'user' command to be added to root command")
|
||||
}
|
||||
if !hasTokenCmd {
|
||||
t.Error("Expected 'token' command to be added to root command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecute(t *testing.T) {
|
||||
origCmd := rootCmd
|
||||
defer func() { rootCmd = origCmd }()
|
||||
|
||||
rootCmd = &cobra.Command{Use: "test"}
|
||||
if err := Execute(); err != nil {
|
||||
t.Errorf("Execute() returned an error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -12,9 +12,16 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
// userQuery is the SQL query to get user information from the database
|
||||
userQuery = `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
|
||||
)
|
||||
|
||||
var (
|
||||
totpUsername string
|
||||
totpCode string
|
||||
qrCodeOutput string
|
||||
issuer string
|
||||
)
|
||||
|
||||
var totpCmd = &cobra.Command{
|
||||
@@ -43,10 +50,22 @@ This command requires a username and a TOTP code.`,
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
var addTOTPCmd = &cobra.Command{
|
||||
Use: "add",
|
||||
Short: "Add a new TOTP token for a user",
|
||||
Long: `Add a new TOTP (Time-based One-Time Password) token for a user in the MCIAS system.
|
||||
This command requires a username. It will emit the secret, and optionally output a QR code image file.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
addTOTP()
|
||||
},
|
||||
}
|
||||
|
||||
// setupTOTPCommands initializes TOTP commands and flags
|
||||
func setupTOTPCommands() {
|
||||
rootCmd.AddCommand(totpCmd)
|
||||
totpCmd.AddCommand(enableTOTPCmd)
|
||||
totpCmd.AddCommand(validateTOTPCmd)
|
||||
totpCmd.AddCommand(addTOTPCmd)
|
||||
|
||||
enableTOTPCmd.Flags().StringVarP(&totpUsername, "username", "u", "", "Username to enable TOTP for")
|
||||
if err := enableTOTPCmd.MarkFlagRequired("username"); err != nil {
|
||||
@@ -61,6 +80,13 @@ func init() {
|
||||
if err := validateTOTPCmd.MarkFlagRequired("code"); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking code flag as required: %v\n", err)
|
||||
}
|
||||
|
||||
addTOTPCmd.Flags().StringVarP(&totpUsername, "username", "u", "", "Username to add TOTP token for")
|
||||
addTOTPCmd.Flags().StringVarP(&qrCodeOutput, "qr-output", "q", "", "Path to save QR code image (optional)")
|
||||
addTOTPCmd.Flags().StringVarP(&issuer, "issuer", "i", "MCIAS", "Issuer name for TOTP token (optional)")
|
||||
if err := addTOTPCmd.MarkFlagRequired("username"); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error marking username flag as required: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func enableTOTP() {
|
||||
@@ -81,8 +107,7 @@ func enableTOTP() {
|
||||
var password, salt []byte
|
||||
var totpSecret sql.NullString
|
||||
|
||||
query := `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
|
||||
err = db.QueryRow(query, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logger.Fatalf("User %s does not exist", totpUsername)
|
||||
@@ -141,8 +166,7 @@ func validateTOTP() {
|
||||
var password, salt []byte
|
||||
var totpSecret sql.NullString
|
||||
|
||||
query := `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
|
||||
err = db.QueryRow(query, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logger.Fatalf("User %s does not exist", totpUsername)
|
||||
@@ -171,10 +195,80 @@ func validateTOTP() {
|
||||
logger.Fatalf("Failed to validate TOTP code: %v", err)
|
||||
}
|
||||
|
||||
// Close the database before potentially exiting
|
||||
db.Close()
|
||||
|
||||
if valid {
|
||||
fmt.Println("TOTP code is valid")
|
||||
} else {
|
||||
fmt.Println("TOTP code is invalid")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addTOTP() {
|
||||
dbPath := viper.GetString("db")
|
||||
|
||||
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
|
||||
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Get the user from the database
|
||||
var userID string
|
||||
var created int64
|
||||
var username string
|
||||
var password, salt []byte
|
||||
var totpSecret sql.NullString
|
||||
|
||||
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logger.Fatalf("User %s does not exist", totpUsername)
|
||||
}
|
||||
logger.Fatalf("Failed to get user: %v", err)
|
||||
}
|
||||
|
||||
// Check if TOTP is already enabled
|
||||
if totpSecret.Valid && totpSecret.String != "" {
|
||||
logger.Fatalf("TOTP is already enabled for user %s", totpUsername)
|
||||
}
|
||||
|
||||
// Create a user object
|
||||
user := &data.User{
|
||||
ID: userID,
|
||||
Created: created,
|
||||
User: username,
|
||||
Password: password,
|
||||
Salt: salt,
|
||||
}
|
||||
|
||||
// Generate a TOTP secret
|
||||
secret, err := user.GenerateTOTPSecret()
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to generate TOTP secret: %v", err)
|
||||
}
|
||||
|
||||
// Update the user in the database
|
||||
updateQuery := `UPDATE users SET totp_secret = ? WHERE id = ?`
|
||||
_, err = db.Exec(updateQuery, secret, userID)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to update user: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("TOTP token added for user %s\n", totpUsername)
|
||||
fmt.Printf("Secret: %s\n", secret)
|
||||
fmt.Println("Please save this secret in your authenticator app.")
|
||||
|
||||
// Generate QR code if output path is specified
|
||||
if qrCodeOutput != "" {
|
||||
err = data.GenerateTOTPQRCode(secret, username, issuer, qrCodeOutput)
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to generate QR code: %v", err)
|
||||
}
|
||||
fmt.Printf("QR code saved to %s\n", qrCodeOutput)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user