Junie: TOTP flow update and db migrations.
This commit is contained in:
parent
396214739e
commit
95d96732d2
|
@ -1,3 +1,4 @@
|
||||||
mcias.db
|
*.db
|
||||||
cmd/mcias/mcias
|
cmd/mcias/mcias
|
||||||
|
cmd/mcias-client/mcias-client
|
||||||
.idea
|
.idea
|
||||||
|
|
79
README.org
79
README.org
|
@ -41,7 +41,12 @@
|
||||||
|
|
||||||
* CLI Commands
|
* CLI Commands
|
||||||
|
|
||||||
MCIAS provides a command-line interface with the following commands:
|
MCIAS provides two command-line interfaces:
|
||||||
|
|
||||||
|
1. The server CLI (`mcias`) for managing the MCIAS server
|
||||||
|
2. The client CLI (`mcias-client`) for interacting with the MCIAS server
|
||||||
|
|
||||||
|
** Server CLI Commands
|
||||||
|
|
||||||
** Server Command
|
** Server Command
|
||||||
|
|
||||||
|
@ -81,6 +86,23 @@
|
||||||
go run main.go token list
|
go run main.go token list
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
|
** Migrate Commands
|
||||||
|
|
||||||
|
Apply database migrations:
|
||||||
|
#+begin_src bash
|
||||||
|
go run main.go migrate up [--migrations <dir>] [--steps <n>]
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
Revert database migrations:
|
||||||
|
#+begin_src bash
|
||||||
|
go run main.go migrate down [--migrations <dir>] [--steps <n>]
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
Show current migration version:
|
||||||
|
#+begin_src bash
|
||||||
|
go run main.go migrate version [--migrations <dir>]
|
||||||
|
#+end_src
|
||||||
|
|
||||||
* API Overview
|
* API Overview
|
||||||
|
|
||||||
** Authentication Endpoints
|
** Authentication Endpoints
|
||||||
|
@ -128,3 +150,58 @@
|
||||||
- Error handling correctness
|
- Error handling correctness
|
||||||
|
|
||||||
See the [[file:docs/installation.org][Installation and Usage Guide]] for more details.
|
See the [[file:docs/installation.org][Installation and Usage Guide]] for more details.
|
||||||
|
|
||||||
|
* Client Tool
|
||||||
|
|
||||||
|
MCIAS includes a separate command-line client tool (`mcias-client`) that can be used to interact with the MCIAS server. The client tool provides access to all the APIs defined in the server.
|
||||||
|
|
||||||
|
** Installation
|
||||||
|
|
||||||
|
To build and install the client tool:
|
||||||
|
|
||||||
|
#+begin_src bash
|
||||||
|
cd cmd/mcias-client
|
||||||
|
go build -o mcias-client
|
||||||
|
# Optional: Move to a directory in your PATH
|
||||||
|
sudo mv mcias-client /usr/local/bin/
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
** Client CLI Commands
|
||||||
|
|
||||||
|
*** Login Commands
|
||||||
|
|
||||||
|
Login with username and password:
|
||||||
|
#+begin_src bash
|
||||||
|
mcias-client login password --username <username> --password <password> [--totp <code>]
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
Login with a token:
|
||||||
|
#+begin_src bash
|
||||||
|
mcias-client login token --username <username> --token <token>
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
*** Database Commands
|
||||||
|
|
||||||
|
Get database credentials:
|
||||||
|
#+begin_src bash
|
||||||
|
mcias-client database credentials --username <username> --token <token>
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
Or use a stored token from a previous login:
|
||||||
|
#+begin_src bash
|
||||||
|
mcias-client database credentials --use-stored
|
||||||
|
#+end_src
|
||||||
|
|
||||||
|
** Configuration
|
||||||
|
|
||||||
|
The client tool can be configured using command-line flags or a configuration file:
|
||||||
|
|
||||||
|
- `--server`: MCIAS server address (default: http://localhost:8080)
|
||||||
|
- `--token-file`: File to store authentication token (default: $HOME/.mcias-token)
|
||||||
|
- `--config`: Config file (default: $HOME/.mcias-client.yaml)
|
||||||
|
|
||||||
|
Example configuration file ($HOME/.mcias-client.yaml):
|
||||||
|
#+begin_src yaml
|
||||||
|
server: "http://mcias.example.com:8080"
|
||||||
|
token-file: "/path/to/token/file"
|
||||||
|
#+end_src
|
||||||
|
|
92
api/auth.go
92
api/auth.go
|
@ -22,6 +22,12 @@ type TokenResponse struct {
|
||||||
Expires int64 `json:"expires"`
|
Expires int64 `json:"expires"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TOTPVerifyRequest struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
TOTPCode string `json:"totp_code"`
|
||||||
|
}
|
||||||
|
|
||||||
type ErrorResponse struct {
|
type ErrorResponse struct {
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
}
|
}
|
||||||
|
@ -57,17 +63,35 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check password and TOTP if enabled
|
// Check password only first
|
||||||
if !user.Check(&req.Login) {
|
if !user.CheckPassword(&req.Login) {
|
||||||
// If TOTP is enabled but no code was provided, return a special error
|
|
||||||
if user.HasTOTP() && req.Login.TOTPCode == "" {
|
|
||||||
s.sendError(w, "TOTP code required", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
|
s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If TOTP is enabled and a code was provided, verify it
|
||||||
|
if user.HasTOTP() {
|
||||||
|
if req.Login.TOTPCode == "" {
|
||||||
|
// TOTP is enabled but no code was provided
|
||||||
|
// Return a special response indicating TOTP is required
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
if err := json.NewEncoder(w).Encode(ErrorResponse{
|
||||||
|
Error: "TOTP code required",
|
||||||
|
}); err != nil {
|
||||||
|
s.Logger.Printf("Error encoding response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the TOTP code
|
||||||
|
valid, validErr := user.ValidateTOTPCode(req.Login.TOTPCode)
|
||||||
|
if validErr != nil || !valid {
|
||||||
|
s.sendError(w, "Invalid TOTP code", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
token, expires, err := s.createToken(user.ID)
|
token, expires, err := s.createToken(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Logger.Printf("Token creation error: %v", err)
|
s.Logger.Printf("Token creation error: %v", err)
|
||||||
|
@ -228,6 +252,60 @@ func (s *Server) renewToken(username, token string) (int64, error) {
|
||||||
return expires, nil
|
return expires, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleTOTPVerify(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req TOTPVerifyRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
s.sendError(w, "Invalid request format", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Version != "v1" || req.Username == "" || req.TOTPCode == "" {
|
||||||
|
s.sendError(w, "Invalid TOTP verification request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := s.getUserByUsername(req.Username)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
s.sendError(w, "User not found", http.StatusUnauthorized)
|
||||||
|
} else {
|
||||||
|
s.Logger.Printf("Database error: %v", err)
|
||||||
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if TOTP is enabled for the user
|
||||||
|
if !user.HasTOTP() {
|
||||||
|
s.sendError(w, "TOTP not enabled for user", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the TOTP code
|
||||||
|
valid, validErr := user.ValidateTOTPCode(req.TOTPCode)
|
||||||
|
if validErr != nil || !valid {
|
||||||
|
s.sendError(w, "Invalid TOTP code", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TOTP code is valid, create a token
|
||||||
|
token, expires, err := s.createToken(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
s.Logger.Printf("Token creation error: %v", err)
|
||||||
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if err := json.NewEncoder(w).Encode(TokenResponse{
|
||||||
|
Token: token,
|
||||||
|
Expires: expires,
|
||||||
|
}); err != nil {
|
||||||
|
s.Logger.Printf("Error encoding response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Request) {
|
||||||
// Extract authorization header
|
// Extract authorization header
|
||||||
authHeader := r.Header.Get("Authorization")
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
|
|
@ -32,6 +32,7 @@ func NewServer(db *sql.DB, logger *log.Logger) *Server {
|
||||||
func (s *Server) registerRoutes() {
|
func (s *Server) registerRoutes() {
|
||||||
s.Router.HandleFunc("POST /v1/login/password", s.handlePasswordLogin)
|
s.Router.HandleFunc("POST /v1/login/password", s.handlePasswordLogin)
|
||||||
s.Router.HandleFunc("POST /v1/login/token", s.handleTokenLogin)
|
s.Router.HandleFunc("POST /v1/login/token", s.handleTokenLogin)
|
||||||
|
s.Router.HandleFunc("POST /v1/login/totp", s.handleTOTPVerify)
|
||||||
s.Router.HandleFunc("GET /v1/database/credentials", s.handleDatabaseCredentials)
|
s.Router.HandleFunc("GET /v1/database/credentials", s.handleDatabaseCredentials)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,137 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DatabaseCredentials struct {
|
||||||
|
Host string `json:"host"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
User string `json:"user"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
dbUsername string
|
||||||
|
dbToken string
|
||||||
|
useStored bool
|
||||||
|
)
|
||||||
|
|
||||||
|
var databaseCmd = &cobra.Command{
|
||||||
|
Use: "database",
|
||||||
|
Short: "Manage database credentials",
|
||||||
|
Long: `Commands for managing database credentials in the MCIAS system.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var getCredentialsCmd = &cobra.Command{
|
||||||
|
Use: "credentials",
|
||||||
|
Short: "Get database credentials",
|
||||||
|
Long: `Retrieve database credentials from the MCIAS system.
|
||||||
|
This command requires authentication with a username and token.`,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
getCredentials()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(databaseCmd)
|
||||||
|
databaseCmd.AddCommand(getCredentialsCmd)
|
||||||
|
|
||||||
|
getCredentialsCmd.Flags().StringVarP(&dbUsername, "username", "u", "", "Username for authentication")
|
||||||
|
getCredentialsCmd.Flags().StringVarP(&dbToken, "token", "t", "", "Authentication token")
|
||||||
|
getCredentialsCmd.Flags().BoolVarP(&useStored, "use-stored", "s", false, "Use stored token from previous login")
|
||||||
|
|
||||||
|
// Make username required only if not using stored token
|
||||||
|
getCredentialsCmd.MarkFlagsMutuallyExclusive("token", "use-stored")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCredentials() {
|
||||||
|
// If using stored token, load it from the token file
|
||||||
|
if useStored {
|
||||||
|
tokenInfo, err := loadToken()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error loading token: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
dbUsername = tokenInfo.Username
|
||||||
|
dbToken = tokenInfo.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required parameters
|
||||||
|
if dbUsername == "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: username is required\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbToken == "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: token is required (either provide --token or use --use-stored)\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := viper.GetString("server")
|
||||||
|
if serverAddr == "" {
|
||||||
|
serverAddr = "http://localhost:8080"
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
|
||||||
|
|
||||||
|
// Create a context with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", dbToken))
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
var errResp ErrorResponse
|
||||||
|
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
|
||||||
|
}
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var creds DatabaseCredentials
|
||||||
|
if unmarshalErr := json.Unmarshal(body, &creds); unmarshalErr != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Database Credentials:")
|
||||||
|
fmt.Printf("Host: %s\n", creds.Host)
|
||||||
|
fmt.Printf("Port: %d\n", creds.Port)
|
||||||
|
fmt.Printf("Name: %s\n", creds.Name)
|
||||||
|
fmt.Printf("User: %s\n", creds.User)
|
||||||
|
fmt.Printf("Password: %s\n", creds.Password)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,368 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
token string
|
||||||
|
totpCode string
|
||||||
|
)
|
||||||
|
|
||||||
|
type LoginRequest struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
Login LoginParams `json:"login"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TOTPVerifyRequest struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
TOTPCode string `json:"totp_code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoginParams struct {
|
||||||
|
User string `json:"user"`
|
||||||
|
Password string `json:"password,omitempty"`
|
||||||
|
Token string `json:"token,omitempty"`
|
||||||
|
TOTPCode string `json:"totp_code,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
Expires int64 `json:"expires"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenInfo struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
Expires int64 `json:"expires"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var loginCmd = &cobra.Command{
|
||||||
|
Use: "login",
|
||||||
|
Short: "Login to the MCIAS server",
|
||||||
|
Long: `Login to the MCIAS server using either a username/password or a token.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var passwordLoginCmd = &cobra.Command{
|
||||||
|
Use: "password",
|
||||||
|
Short: "Login with username and password",
|
||||||
|
Long: `Login to the MCIAS server using a username and password.`,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
loginWithPassword()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenLoginCmd = &cobra.Command{
|
||||||
|
Use: "token",
|
||||||
|
Short: "Login with a token",
|
||||||
|
Long: `Login to the MCIAS server using a token.`,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
loginWithToken()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var totpVerifyCmd = &cobra.Command{
|
||||||
|
Use: "totp",
|
||||||
|
Short: "Verify TOTP code",
|
||||||
|
Long: `Verify a TOTP code after password authentication.`,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
verifyTOTP()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(loginCmd)
|
||||||
|
loginCmd.AddCommand(passwordLoginCmd)
|
||||||
|
loginCmd.AddCommand(tokenLoginCmd)
|
||||||
|
loginCmd.AddCommand(totpVerifyCmd)
|
||||||
|
|
||||||
|
// TOTP verification flags
|
||||||
|
totpVerifyCmd.Flags().StringVarP(&username, "username", "u", "", "Username for authentication")
|
||||||
|
totpVerifyCmd.Flags().StringVarP(&totpCode, "code", "c", "", "TOTP code to verify")
|
||||||
|
if err := totpVerifyCmd.MarkFlagRequired("username"); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error marking username flag as required: %v\n", err)
|
||||||
|
}
|
||||||
|
if err := totpVerifyCmd.MarkFlagRequired("code"); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error marking code flag as required: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Password login flags
|
||||||
|
passwordLoginCmd.Flags().StringVarP(&username, "username", "u", "", "Username for authentication")
|
||||||
|
passwordLoginCmd.Flags().StringVarP(&password, "password", "p", "", "Password for authentication")
|
||||||
|
passwordLoginCmd.Flags().StringVarP(&totpCode, "totp", "t", "", "TOTP code (if enabled)")
|
||||||
|
if err := passwordLoginCmd.MarkFlagRequired("username"); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error marking username flag as required: %v\n", err)
|
||||||
|
}
|
||||||
|
if err := passwordLoginCmd.MarkFlagRequired("password"); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error marking password flag as required: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token login flags
|
||||||
|
tokenLoginCmd.Flags().StringVarP(&username, "username", "u", "", "Username for authentication")
|
||||||
|
tokenLoginCmd.Flags().StringVarP(&token, "token", "t", "", "Authentication token")
|
||||||
|
if err := tokenLoginCmd.MarkFlagRequired("username"); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error marking username flag as required: %v\n", err)
|
||||||
|
}
|
||||||
|
if err := tokenLoginCmd.MarkFlagRequired("token"); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error marking token flag as required: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loginWithPassword() {
|
||||||
|
serverAddr := viper.GetString("server")
|
||||||
|
if serverAddr == "" {
|
||||||
|
serverAddr = "http://localhost:8080"
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/v1/login/password", serverAddr)
|
||||||
|
|
||||||
|
loginReq := LoginRequest{
|
||||||
|
Version: "v1",
|
||||||
|
Login: LoginParams{
|
||||||
|
User: username,
|
||||||
|
Password: password,
|
||||||
|
TOTPCode: totpCode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(loginReq)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error creating request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a context with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
var errResp ErrorResponse
|
||||||
|
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
|
||||||
|
}
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp TokenResponse
|
||||||
|
if unmarshalErr := json.Unmarshal(body, &tokenResp); unmarshalErr != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the token to the token file
|
||||||
|
tokenInfo := TokenInfo{
|
||||||
|
Username: username,
|
||||||
|
Token: tokenResp.Token,
|
||||||
|
Expires: tokenResp.Expires,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := saveToken(tokenInfo); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error saving token: %v\n", err)
|
||||||
|
// Continue anyway, as we can still display the token
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Login successful!")
|
||||||
|
fmt.Printf("Token: %s\n", tokenResp.Token)
|
||||||
|
fmt.Printf("Expires: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyTOTP() {
|
||||||
|
serverAddr := viper.GetString("server")
|
||||||
|
if serverAddr == "" {
|
||||||
|
serverAddr = "http://localhost:8080"
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/v1/login/totp", serverAddr)
|
||||||
|
|
||||||
|
totpReq := TOTPVerifyRequest{
|
||||||
|
Version: "v1",
|
||||||
|
Username: username,
|
||||||
|
TOTPCode: totpCode,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(totpReq)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error creating request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a context with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
var errResp ErrorResponse
|
||||||
|
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
|
||||||
|
}
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp TokenResponse
|
||||||
|
if unmarshalErr := json.Unmarshal(body, &tokenResp); unmarshalErr != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the token to the token file
|
||||||
|
tokenInfo := TokenInfo{
|
||||||
|
Username: username,
|
||||||
|
Token: tokenResp.Token,
|
||||||
|
Expires: tokenResp.Expires,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := saveToken(tokenInfo); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error saving token: %v\n", err)
|
||||||
|
// Continue anyway, as we can still display the token
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("TOTP verification successful!")
|
||||||
|
fmt.Printf("Token: %s\n", tokenResp.Token)
|
||||||
|
fmt.Printf("Expires: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
|
||||||
|
func loginWithToken() {
|
||||||
|
serverAddr := viper.GetString("server")
|
||||||
|
if serverAddr == "" {
|
||||||
|
serverAddr = "http://localhost:8080"
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/v1/login/token", serverAddr)
|
||||||
|
|
||||||
|
loginReq := LoginRequest{
|
||||||
|
Version: "v1",
|
||||||
|
Login: LoginParams{
|
||||||
|
User: username,
|
||||||
|
Token: token,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(loginReq)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error creating request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a context with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to send request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to read response: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
var errResp ErrorResponse
|
||||||
|
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %s\n", errResp.Error)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %s\n", resp.Status)
|
||||||
|
}
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp TokenResponse
|
||||||
|
if unmarshalErr := json.Unmarshal(body, &tokenResp); unmarshalErr != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to parse response: %v\n", unmarshalErr)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the token to the token file
|
||||||
|
tokenInfo := TokenInfo{
|
||||||
|
Username: username,
|
||||||
|
Token: tokenResp.Token,
|
||||||
|
Expires: tokenResp.Expires,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := saveToken(tokenInfo); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error saving token: %v\n", err)
|
||||||
|
// Continue anyway, as we can still display the token
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Token login successful!")
|
||||||
|
fmt.Printf("Token: %s\n", tokenResp.Token)
|
||||||
|
fmt.Printf("Expires: %s\n", time.Unix(tokenResp.Expires, 0).Format(time.RFC3339))
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := Execute(); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cfgFile string
|
||||||
|
serverAddr string
|
||||||
|
tokenFile string
|
||||||
|
|
||||||
|
rootCmd = &cobra.Command{
|
||||||
|
Use: "mcias-client",
|
||||||
|
Short: "MCIAS Client - Command line client for the Metacircular Identity and Access System",
|
||||||
|
Long: `MCIAS Client is a command line tool for interacting with the MCIAS server.
|
||||||
|
It provides access to the MCIAS API endpoints for authentication and resource access.
|
||||||
|
|
||||||
|
It currently supports the following operations:
|
||||||
|
1. User password authentication
|
||||||
|
2. User token authentication
|
||||||
|
3. Database credential retrieval`,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func Execute() error {
|
||||||
|
return rootCmd.Execute()
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupRootCommand initializes the root command and its flags
|
||||||
|
func setupRootCommand() {
|
||||||
|
cobra.OnInitialize(initConfig)
|
||||||
|
|
||||||
|
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.mcias-client.yaml)")
|
||||||
|
rootCmd.PersistentFlags().StringVar(&serverAddr, "server", "http://localhost:8080", "MCIAS server address")
|
||||||
|
rootCmd.PersistentFlags().StringVar(&tokenFile, "token-file", "", "File to store authentication token (default is $HOME/.mcias-token)")
|
||||||
|
|
||||||
|
if err := viper.BindPFlag("server", rootCmd.PersistentFlags().Lookup("server")); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error binding server flag: %v\n", err)
|
||||||
|
}
|
||||||
|
if err := viper.BindPFlag("token-file", rootCmd.PersistentFlags().Lookup("token-file")); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error binding token-file flag: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initConfig() {
|
||||||
|
if cfgFile != "" {
|
||||||
|
viper.SetConfigFile(cfgFile)
|
||||||
|
} else {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
cobra.CheckErr(err)
|
||||||
|
|
||||||
|
viper.AddConfigPath(home)
|
||||||
|
viper.SetConfigType("yaml")
|
||||||
|
viper.SetConfigName(".mcias-client")
|
||||||
|
}
|
||||||
|
|
||||||
|
viper.AutomaticEnv()
|
||||||
|
|
||||||
|
if err := viper.ReadInConfig(); err == nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "Using config file:", viper.ConfigFileUsed())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default token file if not specified
|
||||||
|
if viper.GetString("token-file") == "" {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err == nil {
|
||||||
|
viper.Set("token-file", fmt.Sprintf("%s/.mcias-token", home))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
setupRootCommand()
|
||||||
|
}
|
|
@ -0,0 +1,63 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// loadToken loads the token from the token file
|
||||||
|
func loadToken() (*TokenInfo, error) {
|
||||||
|
tokenFilePath := viper.GetString("token-file")
|
||||||
|
if tokenFilePath == "" {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting home directory: %w", err)
|
||||||
|
}
|
||||||
|
tokenFilePath = fmt.Sprintf("%s/.mcias-token", home)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(tokenFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error reading token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenInfo TokenInfo
|
||||||
|
if err := json.Unmarshal(data, &tokenInfo); err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if token is expired
|
||||||
|
if tokenInfo.Expires > 0 && tokenInfo.Expires < time.Now().Unix() {
|
||||||
|
return nil, fmt.Errorf("token has expired, please login again")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveToken saves the token to the token file
|
||||||
|
func saveToken(tokenInfo TokenInfo) error {
|
||||||
|
tokenFilePath := viper.GetString("token-file")
|
||||||
|
if tokenFilePath == "" {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting home directory: %w", err)
|
||||||
|
}
|
||||||
|
tokenFilePath = fmt.Sprintf("%s/.mcias-token", home)
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(tokenInfo)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error encoding token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(tokenFilePath, jsonData, 0600); err != nil {
|
||||||
|
return fmt.Errorf("error saving token to file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Token saved to %s\n", tokenFilePath)
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,12 +1,14 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
@ -45,6 +47,7 @@ This command requires authentication with a username and token.`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
|
||||||
func init() {
|
func init() {
|
||||||
rootCmd.AddCommand(databaseCmd)
|
rootCmd.AddCommand(databaseCmd)
|
||||||
databaseCmd.AddCommand(getCredentialsCmd)
|
databaseCmd.AddCommand(getCredentialsCmd)
|
||||||
|
@ -68,7 +71,12 @@ func getCredentials() {
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
|
url := fmt.Sprintf("%s/v1/database/credentials?username=%s", serverAddr, dbUsername)
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
|
||||||
|
// Create a context with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to create request: %v", err)
|
logger.Fatalf("Failed to create request: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -89,7 +97,7 @@ func getCredentials() {
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
var errResp ErrorResponse
|
var errResp ErrorResponse
|
||||||
if err := json.Unmarshal(body, &errResp); err == nil {
|
if unmarshalErr := json.Unmarshal(body, &errResp); unmarshalErr == nil {
|
||||||
logger.Fatalf("Error: %s", errResp.Error)
|
logger.Fatalf("Error: %s", errResp.Error)
|
||||||
} else {
|
} else {
|
||||||
logger.Fatalf("Error: %s", resp.Status)
|
logger.Fatalf("Error: %s", resp.Status)
|
||||||
|
@ -97,8 +105,8 @@ func getCredentials() {
|
||||||
}
|
}
|
||||||
|
|
||||||
var creds DatabaseCredentials
|
var creds DatabaseCredentials
|
||||||
if err := json.Unmarshal(body, &creds); err != nil {
|
if unmarshalErr := json.Unmarshal(body, &creds); unmarshalErr != nil {
|
||||||
logger.Fatalf("Failed to parse response: %v", err)
|
logger.Fatalf("Failed to parse response: %v", unmarshalErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("Database Credentials:")
|
fmt.Println("Database Credentials:")
|
||||||
|
@ -107,4 +115,4 @@ func getCredentials() {
|
||||||
fmt.Printf("Name: %s\n", creds.Name)
|
fmt.Printf("Name: %s\n", creds.Name)
|
||||||
fmt.Printf("User: %s\n", creds.User)
|
fmt.Printf("User: %s\n", creds.User)
|
||||||
fmt.Printf("Password: %s\n", creds.Password)
|
fmt.Printf("Password: %s\n", creds.Password)
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,4 +10,4 @@ func main() {
|
||||||
fmt.Fprintln(os.Stderr, err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,184 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/golang-migrate/migrate/v4"
|
||||||
|
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||||
|
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
migrationsDir string
|
||||||
|
steps int
|
||||||
|
)
|
||||||
|
|
||||||
|
var migrateCmd = &cobra.Command{
|
||||||
|
Use: "migrate",
|
||||||
|
Short: "Manage database migrations",
|
||||||
|
Long: `Commands for managing database migrations in the MCIAS system.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var migrateUpCmd = &cobra.Command{
|
||||||
|
Use: "up [steps]",
|
||||||
|
Short: "Apply migrations",
|
||||||
|
Long: `Apply all or a specific number of migrations.
|
||||||
|
If steps is not provided, all pending migrations will be applied.`,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
runMigration("up", steps)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var migrateDownCmd = &cobra.Command{
|
||||||
|
Use: "down [steps]",
|
||||||
|
Short: "Revert migrations",
|
||||||
|
Long: `Revert all or a specific number of migrations.
|
||||||
|
If steps is not provided, all applied migrations will be reverted.`,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
runMigration("down", steps)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var migrateVersionCmd = &cobra.Command{
|
||||||
|
Use: "version",
|
||||||
|
Short: "Show current migration version",
|
||||||
|
Long: `Display the current migration version of the database.`,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
showMigrationVersion()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(migrateCmd)
|
||||||
|
migrateCmd.AddCommand(migrateUpCmd)
|
||||||
|
migrateCmd.AddCommand(migrateDownCmd)
|
||||||
|
migrateCmd.AddCommand(migrateVersionCmd)
|
||||||
|
|
||||||
|
migrateCmd.PersistentFlags().StringVarP(&migrationsDir, "migrations", "m", "database/migrations", "Directory containing migration files")
|
||||||
|
migrateCmd.PersistentFlags().IntVarP(&steps, "steps", "s", 0, "Number of migrations to apply or revert (0 means all)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func runMigration(direction string, steps int) {
|
||||||
|
dbPath := viper.GetString("db")
|
||||||
|
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
||||||
|
|
||||||
|
// Ensure migrations directory exists
|
||||||
|
absPath, err := filepath.Abs(migrationsDir)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(absPath); os.IsNotExist(err) {
|
||||||
|
logger.Fatalf("Migrations directory does not exist: %s", absPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open database connection
|
||||||
|
db, err := openDatabase(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Create migration driver
|
||||||
|
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to create migration driver: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create migrate instance
|
||||||
|
m, err := migrate.NewWithDatabaseInstance(
|
||||||
|
fmt.Sprintf("file://%s", absPath),
|
||||||
|
"sqlite3", driver)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to create migration instance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run migration
|
||||||
|
if direction == "up" {
|
||||||
|
if steps > 0 {
|
||||||
|
err = m.Steps(steps)
|
||||||
|
} else {
|
||||||
|
err = m.Up()
|
||||||
|
}
|
||||||
|
if err != nil && err != migrate.ErrNoChange {
|
||||||
|
logger.Fatalf("Failed to apply migrations: %v", err)
|
||||||
|
}
|
||||||
|
logger.Println("Migrations applied successfully")
|
||||||
|
} else if direction == "down" {
|
||||||
|
if steps > 0 {
|
||||||
|
err = m.Steps(-steps)
|
||||||
|
} else {
|
||||||
|
err = m.Down()
|
||||||
|
}
|
||||||
|
if err != nil && err != migrate.ErrNoChange {
|
||||||
|
logger.Fatalf("Failed to revert migrations: %v", err)
|
||||||
|
}
|
||||||
|
logger.Println("Migrations reverted successfully")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func showMigrationVersion() {
|
||||||
|
dbPath := viper.GetString("db")
|
||||||
|
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
|
||||||
|
|
||||||
|
// Open database connection
|
||||||
|
db, err := openDatabase(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Create migration driver
|
||||||
|
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to create migration driver: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create migrate instance
|
||||||
|
absPath, err := filepath.Abs(migrationsDir)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := migrate.NewWithDatabaseInstance(
|
||||||
|
fmt.Sprintf("file://%s", absPath),
|
||||||
|
"sqlite3", driver)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to create migration instance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current version
|
||||||
|
version, dirty, err := m.Version()
|
||||||
|
if err != nil {
|
||||||
|
if err == migrate.ErrNilVersion {
|
||||||
|
logger.Println("No migrations have been applied yet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Fatalf("Failed to get migration version: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Printf("Current migration version: %d (dirty: %t)", version, dirty)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openDatabase(dbPath string) (*sql.DB, error) {
|
||||||
|
// Ensure database directory exists
|
||||||
|
dbDir := filepath.Dir(dbPath)
|
||||||
|
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open database connection
|
||||||
|
db, err := sql.Open("sqlite3", dbPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
@ -51,6 +52,7 @@ var revokePermissionCmd = &cobra.Command{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:gochecknoinits // This is a standard pattern in Cobra applications
|
||||||
func init() {
|
func init() {
|
||||||
rootCmd.AddCommand(permissionCmd)
|
rootCmd.AddCommand(permissionCmd)
|
||||||
permissionCmd.AddCommand(listPermissionsCmd)
|
permissionCmd.AddCommand(listPermissionsCmd)
|
||||||
|
@ -104,14 +106,14 @@ func listPermissions() {
|
||||||
fmt.Println(strings.Repeat("-", 90))
|
fmt.Println(strings.Repeat("-", 90))
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var id, resource, action, description string
|
var id, resource, action, description string
|
||||||
if err := rows.Scan(&id, &resource, &action, &description); err != nil {
|
if scanErr := rows.Scan(&id, &resource, &action, &description); scanErr != nil {
|
||||||
logger.Fatalf("Failed to scan permission row: %v", err)
|
logger.Fatalf("Failed to scan permission row: %v", scanErr)
|
||||||
}
|
}
|
||||||
fmt.Printf("%-24s %-20s %-15s %-30s\n", id, resource, action, description)
|
fmt.Printf("%-24s %-20s %-15s %-30s\n", id, resource, action, description)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if rowErr := rows.Err(); rowErr != nil {
|
||||||
logger.Fatalf("Error iterating permission rows: %v", err)
|
logger.Fatalf("Error iterating permission rows: %v", rowErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,7 +131,7 @@ func grantPermission() {
|
||||||
var roleID string
|
var roleID string
|
||||||
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", permissionRole).Scan(&roleID)
|
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", permissionRole).Scan(&roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("Role %s not found", permissionRole)
|
logger.Fatalf("Role %s not found", permissionRole)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get role ID: %v", err)
|
logger.Fatalf("Failed to get role ID: %v", err)
|
||||||
|
@ -137,11 +139,11 @@ func grantPermission() {
|
||||||
|
|
||||||
// Get permission ID
|
// Get permission ID
|
||||||
var permissionID string
|
var permissionID string
|
||||||
err = db.QueryRow("SELECT id FROM permissions WHERE resource = ? AND action = ?",
|
err = db.QueryRow("SELECT id FROM permissions WHERE resource = ? AND action = ?",
|
||||||
permissionResource, permissionAction).Scan(&permissionID)
|
permissionResource, permissionAction).Scan(&permissionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("Permission with resource '%s' and action '%s' not found",
|
logger.Fatalf("Permission with resource '%s' and action '%s' not found",
|
||||||
permissionResource, permissionAction)
|
permissionResource, permissionAction)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get permission ID: %v", err)
|
logger.Fatalf("Failed to get permission ID: %v", err)
|
||||||
|
@ -149,13 +151,13 @@ func grantPermission() {
|
||||||
|
|
||||||
// Check if role already has this permission
|
// Check if role already has this permission
|
||||||
var count int
|
var count int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM role_permissions WHERE rid = ? AND pid = ?",
|
err = db.QueryRow("SELECT COUNT(*) FROM role_permissions WHERE rid = ? AND pid = ?",
|
||||||
roleID, permissionID).Scan(&count)
|
roleID, permissionID).Scan(&count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to check if role has permission: %v", err)
|
logger.Fatalf("Failed to check if role has permission: %v", err)
|
||||||
}
|
}
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
logger.Fatalf("Role %s already has permission %s:%s",
|
logger.Fatalf("Role %s already has permission %s:%s",
|
||||||
permissionRole, permissionResource, permissionAction)
|
permissionRole, permissionResource, permissionAction)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,13 +165,13 @@ func grantPermission() {
|
||||||
id := ulid.Make().String()
|
id := ulid.Make().String()
|
||||||
|
|
||||||
// Grant permission to role
|
// Grant permission to role
|
||||||
_, err = db.Exec("INSERT INTO role_permissions (id, rid, pid) VALUES (?, ?, ?)",
|
_, err = db.Exec("INSERT INTO role_permissions (id, rid, pid) VALUES (?, ?, ?)",
|
||||||
id, roleID, permissionID)
|
id, roleID, permissionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to grant permission: %v", err)
|
logger.Fatalf("Failed to grant permission: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Permission %s:%s granted to role %s successfully\n",
|
fmt.Printf("Permission %s:%s granted to role %s successfully\n",
|
||||||
permissionResource, permissionAction, permissionRole)
|
permissionResource, permissionAction, permissionRole)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,7 +189,7 @@ func revokePermission() {
|
||||||
var roleID string
|
var roleID string
|
||||||
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", permissionRole).Scan(&roleID)
|
err = db.QueryRow("SELECT id FROM roles WHERE role = ?", permissionRole).Scan(&roleID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("Role %s not found", permissionRole)
|
logger.Fatalf("Role %s not found", permissionRole)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get role ID: %v", err)
|
logger.Fatalf("Failed to get role ID: %v", err)
|
||||||
|
@ -195,11 +197,11 @@ func revokePermission() {
|
||||||
|
|
||||||
// Get permission ID
|
// Get permission ID
|
||||||
var permissionID string
|
var permissionID string
|
||||||
err = db.QueryRow("SELECT id FROM permissions WHERE resource = ? AND action = ?",
|
err = db.QueryRow("SELECT id FROM permissions WHERE resource = ? AND action = ?",
|
||||||
permissionResource, permissionAction).Scan(&permissionID)
|
permissionResource, permissionAction).Scan(&permissionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
logger.Fatalf("Permission with resource '%s' and action '%s' not found",
|
logger.Fatalf("Permission with resource '%s' and action '%s' not found",
|
||||||
permissionResource, permissionAction)
|
permissionResource, permissionAction)
|
||||||
}
|
}
|
||||||
logger.Fatalf("Failed to get permission ID: %v", err)
|
logger.Fatalf("Failed to get permission ID: %v", err)
|
||||||
|
@ -207,13 +209,13 @@ func revokePermission() {
|
||||||
|
|
||||||
// Check if role has this permission
|
// Check if role has this permission
|
||||||
var count int
|
var count int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM role_permissions WHERE rid = ? AND pid = ?",
|
err = db.QueryRow("SELECT COUNT(*) FROM role_permissions WHERE rid = ? AND pid = ?",
|
||||||
roleID, permissionID).Scan(&count)
|
roleID, permissionID).Scan(&count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("Failed to check if role has permission: %v", err)
|
logger.Fatalf("Failed to check if role has permission: %v", err)
|
||||||
}
|
}
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
logger.Fatalf("Role %s does not have permission %s:%s",
|
logger.Fatalf("Role %s does not have permission %s:%s",
|
||||||
permissionRole, permissionResource, permissionAction)
|
permissionRole, permissionResource, permissionAction)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -223,6 +225,6 @@ func revokePermission() {
|
||||||
logger.Fatalf("Failed to revoke permission: %v", err)
|
logger.Fatalf("Failed to revoke permission: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Permission %s:%s revoked from role %s successfully\n",
|
fmt.Printf("Permission %s:%s revoked from role %s successfully\n",
|
||||||
permissionResource, permissionAction, permissionRole)
|
permissionResource, permissionAction, permissionRole)
|
||||||
}
|
}
|
||||||
|
|
|
@ -252,4 +252,4 @@ func revokeRole() {
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Role %s revoked from user %s successfully\n", roleName, roleUser)
|
fmt.Printf("Role %s revoked from user %s successfully\n", roleName, roleUser)
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,10 +27,17 @@ It currently provides the following across metacircular services:
|
||||||
)
|
)
|
||||||
|
|
||||||
func Execute() error {
|
func Execute() error {
|
||||||
|
// Setup commands and flags
|
||||||
|
setupRootCommand()
|
||||||
|
setupTOTPCommands()
|
||||||
|
// The migrate command is already set up in its init function
|
||||||
|
|
||||||
|
// Execute the root command
|
||||||
return rootCmd.Execute()
|
return rootCmd.Execute()
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
// setupRootCommand initializes the root command and its flags
|
||||||
|
func setupRootCommand() {
|
||||||
cobra.OnInitialize(initConfig)
|
cobra.OnInitialize(initConfig)
|
||||||
|
|
||||||
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.mcias.yaml)")
|
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.mcias.yaml)")
|
||||||
|
|
|
@ -1,78 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRootCommand(t *testing.T) {
|
|
||||||
if rootCmd.Use != "mcias" {
|
|
||||||
t.Errorf("Expected root command Use to be 'mcias', got '%s'", rootCmd.Use)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rootCmd.Short == "" {
|
|
||||||
t.Error("Expected root command Short to be set")
|
|
||||||
}
|
|
||||||
|
|
||||||
if rootCmd.Long == "" {
|
|
||||||
t.Error("Expected root command Long to be set")
|
|
||||||
}
|
|
||||||
dbFlag := rootCmd.PersistentFlags().Lookup("db")
|
|
||||||
if dbFlag == nil {
|
|
||||||
t.Error("Expected 'db' flag to be defined")
|
|
||||||
} else {
|
|
||||||
if dbFlag.DefValue != "mcias.db" {
|
|
||||||
t.Errorf("Expected 'db' flag default value to be 'mcias.db', got '%s'", dbFlag.DefValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
addrFlag := rootCmd.PersistentFlags().Lookup("addr")
|
|
||||||
if addrFlag == nil {
|
|
||||||
t.Error("Expected 'addr' flag to be defined")
|
|
||||||
} else {
|
|
||||||
if addrFlag.DefValue != ":8080" {
|
|
||||||
t.Errorf("Expected 'addr' flag default value to be ':8080', got '%s'", addrFlag.DefValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hasServerCmd := false
|
|
||||||
hasInitCmd := false
|
|
||||||
hasUserCmd := false
|
|
||||||
hasTokenCmd := false
|
|
||||||
|
|
||||||
for _, cmd := range rootCmd.Commands() {
|
|
||||||
switch cmd.Use {
|
|
||||||
case "server":
|
|
||||||
hasServerCmd = true
|
|
||||||
case "init":
|
|
||||||
hasInitCmd = true
|
|
||||||
case "user":
|
|
||||||
hasUserCmd = true
|
|
||||||
case "token":
|
|
||||||
hasTokenCmd = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasServerCmd {
|
|
||||||
t.Error("Expected 'server' command to be added to root command")
|
|
||||||
}
|
|
||||||
if !hasInitCmd {
|
|
||||||
t.Error("Expected 'init' command to be added to root command")
|
|
||||||
}
|
|
||||||
if !hasUserCmd {
|
|
||||||
t.Error("Expected 'user' command to be added to root command")
|
|
||||||
}
|
|
||||||
if !hasTokenCmd {
|
|
||||||
t.Error("Expected 'token' command to be added to root command")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExecute(t *testing.T) {
|
|
||||||
origCmd := rootCmd
|
|
||||||
defer func() { rootCmd = origCmd }()
|
|
||||||
|
|
||||||
rootCmd = &cobra.Command{Use: "test"}
|
|
||||||
if err := Execute(); err != nil {
|
|
||||||
t.Errorf("Execute() returned an error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -12,9 +12,16 @@ import (
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// userQuery is the SQL query to get user information from the database
|
||||||
|
userQuery = `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
totpUsername string
|
totpUsername string
|
||||||
totpCode string
|
totpCode string
|
||||||
|
qrCodeOutput string
|
||||||
|
issuer string
|
||||||
)
|
)
|
||||||
|
|
||||||
var totpCmd = &cobra.Command{
|
var totpCmd = &cobra.Command{
|
||||||
|
@ -43,10 +50,22 @@ This command requires a username and a TOTP code.`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
var addTOTPCmd = &cobra.Command{
|
||||||
|
Use: "add",
|
||||||
|
Short: "Add a new TOTP token for a user",
|
||||||
|
Long: `Add a new TOTP (Time-based One-Time Password) token for a user in the MCIAS system.
|
||||||
|
This command requires a username. It will emit the secret, and optionally output a QR code image file.`,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
addTOTP()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupTOTPCommands initializes TOTP commands and flags
|
||||||
|
func setupTOTPCommands() {
|
||||||
rootCmd.AddCommand(totpCmd)
|
rootCmd.AddCommand(totpCmd)
|
||||||
totpCmd.AddCommand(enableTOTPCmd)
|
totpCmd.AddCommand(enableTOTPCmd)
|
||||||
totpCmd.AddCommand(validateTOTPCmd)
|
totpCmd.AddCommand(validateTOTPCmd)
|
||||||
|
totpCmd.AddCommand(addTOTPCmd)
|
||||||
|
|
||||||
enableTOTPCmd.Flags().StringVarP(&totpUsername, "username", "u", "", "Username to enable TOTP for")
|
enableTOTPCmd.Flags().StringVarP(&totpUsername, "username", "u", "", "Username to enable TOTP for")
|
||||||
if err := enableTOTPCmd.MarkFlagRequired("username"); err != nil {
|
if err := enableTOTPCmd.MarkFlagRequired("username"); err != nil {
|
||||||
|
@ -61,6 +80,13 @@ func init() {
|
||||||
if err := validateTOTPCmd.MarkFlagRequired("code"); err != nil {
|
if err := validateTOTPCmd.MarkFlagRequired("code"); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Error marking code flag as required: %v\n", err)
|
fmt.Fprintf(os.Stderr, "Error marking code flag as required: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addTOTPCmd.Flags().StringVarP(&totpUsername, "username", "u", "", "Username to add TOTP token for")
|
||||||
|
addTOTPCmd.Flags().StringVarP(&qrCodeOutput, "qr-output", "q", "", "Path to save QR code image (optional)")
|
||||||
|
addTOTPCmd.Flags().StringVarP(&issuer, "issuer", "i", "MCIAS", "Issuer name for TOTP token (optional)")
|
||||||
|
if err := addTOTPCmd.MarkFlagRequired("username"); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error marking username flag as required: %v\n", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableTOTP() {
|
func enableTOTP() {
|
||||||
|
@ -81,8 +107,7 @@ func enableTOTP() {
|
||||||
var password, salt []byte
|
var password, salt []byte
|
||||||
var totpSecret sql.NullString
|
var totpSecret sql.NullString
|
||||||
|
|
||||||
query := `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
|
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||||
err = db.QueryRow(query, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
logger.Fatalf("User %s does not exist", totpUsername)
|
logger.Fatalf("User %s does not exist", totpUsername)
|
||||||
|
@ -141,8 +166,7 @@ func validateTOTP() {
|
||||||
var password, salt []byte
|
var password, salt []byte
|
||||||
var totpSecret sql.NullString
|
var totpSecret sql.NullString
|
||||||
|
|
||||||
query := `SELECT id, created, user, password, salt, totp_secret FROM users WHERE user = ?`
|
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||||
err = db.QueryRow(query, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
logger.Fatalf("User %s does not exist", totpUsername)
|
logger.Fatalf("User %s does not exist", totpUsername)
|
||||||
|
@ -171,10 +195,80 @@ func validateTOTP() {
|
||||||
logger.Fatalf("Failed to validate TOTP code: %v", err)
|
logger.Fatalf("Failed to validate TOTP code: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close the database before potentially exiting
|
||||||
|
db.Close()
|
||||||
|
|
||||||
if valid {
|
if valid {
|
||||||
fmt.Println("TOTP code is valid")
|
fmt.Println("TOTP code is valid")
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("TOTP code is invalid")
|
fmt.Println("TOTP code is invalid")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addTOTP() {
|
||||||
|
dbPath := viper.GetString("db")
|
||||||
|
|
||||||
|
logger := log.New(os.Stdout, "MCIAS: ", log.LstdFlags)
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", dbPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Get the user from the database
|
||||||
|
var userID string
|
||||||
|
var created int64
|
||||||
|
var username string
|
||||||
|
var password, salt []byte
|
||||||
|
var totpSecret sql.NullString
|
||||||
|
|
||||||
|
err = db.QueryRow(userQuery, totpUsername).Scan(&userID, &created, &username, &password, &salt, &totpSecret)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
logger.Fatalf("User %s does not exist", totpUsername)
|
||||||
|
}
|
||||||
|
logger.Fatalf("Failed to get user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if TOTP is already enabled
|
||||||
|
if totpSecret.Valid && totpSecret.String != "" {
|
||||||
|
logger.Fatalf("TOTP is already enabled for user %s", totpUsername)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a user object
|
||||||
|
user := &data.User{
|
||||||
|
ID: userID,
|
||||||
|
Created: created,
|
||||||
|
User: username,
|
||||||
|
Password: password,
|
||||||
|
Salt: salt,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a TOTP secret
|
||||||
|
secret, err := user.GenerateTOTPSecret()
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to generate TOTP secret: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the user in the database
|
||||||
|
updateQuery := `UPDATE users SET totp_secret = ? WHERE id = ?`
|
||||||
|
_, err = db.Exec(updateQuery, secret, userID)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to update user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("TOTP token added for user %s\n", totpUsername)
|
||||||
|
fmt.Printf("Secret: %s\n", secret)
|
||||||
|
fmt.Println("Please save this secret in your authenticator app.")
|
||||||
|
|
||||||
|
// Generate QR code if output path is specified
|
||||||
|
if qrCodeOutput != "" {
|
||||||
|
err = data.GenerateTOTPQRCode(secret, username, issuer, qrCodeOutput)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatalf("Failed to generate QR code: %v", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("QR code saved to %s\n", qrCodeOutput)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
47
data/auth.go
47
data/auth.go
|
@ -7,7 +7,16 @@ import (
|
||||||
"github.com/oklog/ulid/v2"
|
"github.com/oklog/ulid/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Permission represents a system permission
|
const (
|
||||||
|
// Constants for error messages
|
||||||
|
errScanPermission = "failed to scan permission: %w"
|
||||||
|
errIteratePermissions = "error iterating permissions: %w"
|
||||||
|
|
||||||
|
// Constants for comparison
|
||||||
|
zeroCount = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
// Permission represents a system permission.
|
||||||
type Permission struct {
|
type Permission struct {
|
||||||
ID string
|
ID string
|
||||||
Resource string
|
Resource string
|
||||||
|
@ -15,12 +24,12 @@ type Permission struct {
|
||||||
Description string
|
Description string
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizationService provides methods for checking user permissions
|
// AuthorizationService provides methods for checking user permissions.
|
||||||
type AuthorizationService struct {
|
type AuthorizationService struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthorizationService creates a new authorization service
|
// NewAuthorizationService creates a new authorization service.
|
||||||
func NewAuthorizationService(db *sql.DB) *AuthorizationService {
|
func NewAuthorizationService(db *sql.DB) *AuthorizationService {
|
||||||
return &AuthorizationService{db: db}
|
return &AuthorizationService{db: db}
|
||||||
}
|
}
|
||||||
|
@ -40,10 +49,10 @@ func (a *AuthorizationService) UserHasPermission(userID, resource, action string
|
||||||
return false, fmt.Errorf("failed to check user permission: %w", err)
|
return false, fmt.Errorf("failed to check user permission: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return count > 0, nil
|
return count > zeroCount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserPermissions returns all permissions for a user based on their roles
|
// GetUserPermissions returns all permissions for a user based on their roles.
|
||||||
func (a *AuthorizationService) GetUserPermissions(userID string) ([]Permission, error) {
|
func (a *AuthorizationService) GetUserPermissions(userID string) ([]Permission, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT DISTINCT p.id, p.resource, p.action, p.description FROM permissions p
|
SELECT DISTINCT p.id, p.resource, p.action, p.description FROM permissions p
|
||||||
|
@ -61,20 +70,20 @@ func (a *AuthorizationService) GetUserPermissions(userID string) ([]Permission,
|
||||||
var permissions []Permission
|
var permissions []Permission
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var perm Permission
|
var perm Permission
|
||||||
if err := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); err != nil {
|
if scanErr := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); scanErr != nil {
|
||||||
return nil, fmt.Errorf("failed to scan permission: %w", err)
|
return nil, fmt.Errorf(errScanPermission, scanErr)
|
||||||
}
|
}
|
||||||
permissions = append(permissions, perm)
|
permissions = append(permissions, perm)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if rowErr := rows.Err(); rowErr != nil {
|
||||||
return nil, fmt.Errorf("error iterating permissions: %w", err)
|
return nil, fmt.Errorf(errIteratePermissions, rowErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return permissions, nil
|
return permissions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRolePermissions returns all permissions for a specific role
|
// GetRolePermissions returns all permissions for a specific role.
|
||||||
func (a *AuthorizationService) GetRolePermissions(roleID string) ([]Permission, error) {
|
func (a *AuthorizationService) GetRolePermissions(roleID string) ([]Permission, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT p.id, p.resource, p.action, p.description FROM permissions p
|
SELECT p.id, p.resource, p.action, p.description FROM permissions p
|
||||||
|
@ -91,14 +100,14 @@ func (a *AuthorizationService) GetRolePermissions(roleID string) ([]Permission,
|
||||||
var permissions []Permission
|
var permissions []Permission
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var perm Permission
|
var perm Permission
|
||||||
if err := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); err != nil {
|
if scanErr := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); scanErr != nil {
|
||||||
return nil, fmt.Errorf("failed to scan permission: %w", err)
|
return nil, fmt.Errorf(errScanPermission, scanErr)
|
||||||
}
|
}
|
||||||
permissions = append(permissions, perm)
|
permissions = append(permissions, perm)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if rowErr := rows.Err(); rowErr != nil {
|
||||||
return nil, fmt.Errorf("error iterating permissions: %w", err)
|
return nil, fmt.Errorf(errIteratePermissions, rowErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return permissions, nil
|
return permissions, nil
|
||||||
|
@ -142,7 +151,7 @@ func (a *AuthorizationService) RevokePermissionFromRole(roleID, permissionID str
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPermissions returns all permissions in the system
|
// GetAllPermissions returns all permissions in the system.
|
||||||
func (a *AuthorizationService) GetAllPermissions() ([]Permission, error) {
|
func (a *AuthorizationService) GetAllPermissions() ([]Permission, error) {
|
||||||
query := `SELECT id, resource, action, description FROM permissions`
|
query := `SELECT id, resource, action, description FROM permissions`
|
||||||
|
|
||||||
|
@ -155,14 +164,14 @@ func (a *AuthorizationService) GetAllPermissions() ([]Permission, error) {
|
||||||
var permissions []Permission
|
var permissions []Permission
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var perm Permission
|
var perm Permission
|
||||||
if err := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); err != nil {
|
if scanErr := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); scanErr != nil {
|
||||||
return nil, fmt.Errorf("failed to scan permission: %w", err)
|
return nil, fmt.Errorf(errScanPermission, scanErr)
|
||||||
}
|
}
|
||||||
permissions = append(permissions, perm)
|
permissions = append(permissions, perm)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Err(); err != nil {
|
if rowErr := rows.Err(); rowErr != nil {
|
||||||
return nil, fmt.Errorf("error iterating permissions: %w", err)
|
return nil, fmt.Errorf(errIteratePermissions, rowErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return permissions, nil
|
return permissions, nil
|
||||||
|
|
|
@ -150,12 +150,12 @@ func TestGetUserPermissions(t *testing.T) {
|
||||||
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Admin should have 4 permissions
|
// Admin should have 4 permissions
|
||||||
if len(permissions) != 4 {
|
if len(permissions) != 4 {
|
||||||
t.Errorf("Admin should have 4 permissions, got %d", len(permissions))
|
t.Errorf("Admin should have 4 permissions, got %d", len(permissions))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for specific permissions
|
// Check for specific permissions
|
||||||
hasDBRead := false
|
hasDBRead := false
|
||||||
hasDBWrite := false
|
hasDBWrite := false
|
||||||
|
@ -167,7 +167,7 @@ func TestGetUserPermissions(t *testing.T) {
|
||||||
hasDBWrite = true
|
hasDBWrite = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasDBRead {
|
if !hasDBRead {
|
||||||
t.Errorf("Admin should have database_credentials:read permission")
|
t.Errorf("Admin should have database_credentials:read permission")
|
||||||
}
|
}
|
||||||
|
@ -182,12 +182,12 @@ func TestGetUserPermissions(t *testing.T) {
|
||||||
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// DB Operator should have 1 permission
|
// DB Operator should have 1 permission
|
||||||
if len(permissions) != 1 {
|
if len(permissions) != 1 {
|
||||||
t.Errorf("DB Operator should have 1 permission, got %d", len(permissions))
|
t.Errorf("DB Operator should have 1 permission, got %d", len(permissions))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for specific permissions
|
// Check for specific permissions
|
||||||
hasDBRead := false
|
hasDBRead := false
|
||||||
hasDBWrite := false
|
hasDBWrite := false
|
||||||
|
@ -199,7 +199,7 @@ func TestGetUserPermissions(t *testing.T) {
|
||||||
hasDBWrite = true
|
hasDBWrite = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasDBRead {
|
if !hasDBRead {
|
||||||
t.Errorf("DB Operator should have database_credentials:read permission")
|
t.Errorf("DB Operator should have database_credentials:read permission")
|
||||||
}
|
}
|
||||||
|
@ -214,10 +214,10 @@ func TestGetUserPermissions(t *testing.T) {
|
||||||
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Regular user should have 0 permissions
|
// Regular user should have 0 permissions
|
||||||
if len(permissions) != 0 {
|
if len(permissions) != 0 {
|
||||||
t.Errorf("Regular user should have 0 permissions, got %d", len(permissions))
|
t.Errorf("Regular user should have 0 permissions, got %d", len(permissions))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
91
data/totp.go
91
data/totp.go
|
@ -3,14 +3,39 @@ package data
|
||||||
import (
|
import (
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
|
||||||
|
// #nosec G505 - SHA1 is used here because TOTP (RFC 6238) specifically uses HMAC-SHA1
|
||||||
|
// as the default algorithm, and many authenticator apps still use it.
|
||||||
|
// In the future, we should consider supporting stronger algorithms like SHA256 or SHA512.
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"image/png"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"rsc.io/qr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GenerateRandomBase32 generates a random base32 encoded string of the specified length
|
const (
|
||||||
|
// TOTPTimeStep is the time step in seconds for TOTP.
|
||||||
|
TOTPTimeStep = 30
|
||||||
|
// TOTPDigits is the number of digits in a TOTP code.
|
||||||
|
TOTPDigits = 6
|
||||||
|
// TOTPModulo is the modulo value for truncating the TOTP hash.
|
||||||
|
TOTPModulo = 1000000
|
||||||
|
// TOTPTimeWindow is the number of time steps to check before and after the current time.
|
||||||
|
TOTPTimeWindow = 1
|
||||||
|
|
||||||
|
// Constants for TOTP calculation
|
||||||
|
timeBytesLength = 8
|
||||||
|
dynamicTruncationMask = 0x0F
|
||||||
|
truncationModulusMask = 0x7FFFFFFF
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateRandomBase32 generates a random base32 encoded string of the specified length.
|
||||||
func GenerateRandomBase32(length int) (string, error) {
|
func GenerateRandomBase32(length int) (string, error) {
|
||||||
// Generate random bytes
|
// Generate random bytes
|
||||||
randomBytes := make([]byte, length)
|
randomBytes := make([]byte, length)
|
||||||
|
@ -29,12 +54,11 @@ func GenerateRandomBase32(length int) (string, error) {
|
||||||
|
|
||||||
// ValidateTOTP validates a TOTP code against a secret
|
// ValidateTOTP validates a TOTP code against a secret
|
||||||
func ValidateTOTP(secret, code string) bool {
|
func ValidateTOTP(secret, code string) bool {
|
||||||
// Allow for a time skew of 30 seconds in either direction
|
// Get current time step
|
||||||
timeWindow := 1 // 1 before and 1 after current time
|
currentTime := time.Now().Unix() / TOTPTimeStep
|
||||||
currentTime := time.Now().Unix() / 30
|
|
||||||
|
|
||||||
// Try the time window
|
// Try the time window (allow for time skew)
|
||||||
for i := -timeWindow; i <= timeWindow; i++ {
|
for i := -TOTPTimeWindow; i <= TOTPTimeWindow; i++ {
|
||||||
if calculateTOTP(secret, currentTime+int64(i)) == code {
|
if calculateTOTP(secret, currentTime+int64(i)) == code {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -53,7 +77,7 @@ func calculateTOTP(secret string, timeCounter int64) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert time counter to bytes (big endian)
|
// Convert time counter to bytes (big endian)
|
||||||
timeBytes := make([]byte, 8)
|
timeBytes := make([]byte, timeBytesLength)
|
||||||
binary.BigEndian.PutUint64(timeBytes, uint64(timeCounter))
|
binary.BigEndian.PutUint64(timeBytes, uint64(timeCounter))
|
||||||
|
|
||||||
// Calculate HMAC-SHA1
|
// Calculate HMAC-SHA1
|
||||||
|
@ -62,25 +86,40 @@ func calculateTOTP(secret string, timeCounter int64) string {
|
||||||
hash := h.Sum(nil)
|
hash := h.Sum(nil)
|
||||||
|
|
||||||
// Dynamic truncation
|
// Dynamic truncation
|
||||||
offset := hash[len(hash)-1] & 0x0F
|
offset := hash[len(hash)-1] & dynamicTruncationMask
|
||||||
truncatedHash := binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7FFFFFFF
|
truncatedHash := binary.BigEndian.Uint32(hash[offset:offset+4]) & truncationModulusMask
|
||||||
otp := truncatedHash % 1000000
|
otp := truncatedHash % TOTPModulo
|
||||||
|
|
||||||
// Convert to 6-digit string with leading zeros if needed
|
// Format as a 6-digit string with leading zeros
|
||||||
result := ""
|
return fmt.Sprintf("%0*d", TOTPDigits, otp)
|
||||||
if otp < 10 {
|
}
|
||||||
result = "00000" + string(otp+'0')
|
|
||||||
} else if otp < 100 {
|
// GenerateTOTPQRCode generates a QR code for a TOTP secret and saves it to a file
|
||||||
result = "0000" + string((otp/10)+'0') + string((otp%10)+'0')
|
func GenerateTOTPQRCode(secret, username, issuer, outputPath string) error {
|
||||||
} else if otp < 1000 {
|
// Format the TOTP URI according to the KeyURI format
|
||||||
result = "000" + string((otp/100)+'0') + string(((otp/10)%10)+'0') + string((otp%10)+'0')
|
// https://github.com/google/google-authenticator/wiki/Key-Uri-Format
|
||||||
} else if otp < 10000 {
|
uri := fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&algorithm=SHA1&digits=%d&period=%d",
|
||||||
result = "00" + string((otp/1000)+'0') + string(((otp/100)%10)+'0') + string(((otp/10)%10)+'0') + string((otp%10)+'0')
|
issuer, username, secret, issuer, TOTPDigits, TOTPTimeStep)
|
||||||
} else if otp < 100000 {
|
|
||||||
result = "0" + string((otp/10000)+'0') + string(((otp/1000)%10)+'0') + string(((otp/100)%10)+'0') + string(((otp/10)%10)+'0') + string((otp%10)+'0')
|
// Generate QR code
|
||||||
} else {
|
code, err := qr.Encode(uri, qr.M)
|
||||||
result = string((otp/100000)+'0') + string(((otp/10000)%10)+'0') + string(((otp/1000)%10)+'0') + string(((otp/100)%10)+'0') + string(((otp/10)%10)+'0') + string((otp%10)+'0')
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate QR code: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
// Create output file
|
||||||
}
|
file, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create output file: %w", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Write QR code as PNG
|
||||||
|
img := code.Image()
|
||||||
|
err = png.Encode(file, img)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write QR code to file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
50
data/user.go
50
data/user.go
|
@ -14,6 +14,16 @@ const (
|
||||||
scryptN = 32768
|
scryptN = 32768
|
||||||
scryptR = 8
|
scryptR = 8
|
||||||
scryptP = 2
|
scryptP = 2
|
||||||
|
|
||||||
|
// Constants for derived key length and comparison
|
||||||
|
derivedKeyLength = 32
|
||||||
|
validCompareResult = 1
|
||||||
|
|
||||||
|
// Empty string constant
|
||||||
|
emptyString = ""
|
||||||
|
|
||||||
|
// TOTP secret length in bytes (160 bits)
|
||||||
|
totpSecretLength = 20
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
|
@ -48,16 +58,17 @@ func (u *User) GetPermissions(authService *AuthorizationService) ([]Permission,
|
||||||
|
|
||||||
type Login struct {
|
type Login struct {
|
||||||
User string `json:"user"`
|
User string `json:"user"`
|
||||||
Password string `json:"password,omitzero"`
|
Password string `json:"password,omitempty"`
|
||||||
Token string `json:"token,omitzero"`
|
Token string `json:"token,omitempty"`
|
||||||
TOTPCode string `json:"totp_code,omitzero"`
|
TOTPCode string `json:"totp_code,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func derive(password string, salt []byte) ([]byte, error) {
|
func derive(password string, salt []byte) ([]byte, error) {
|
||||||
return scrypt.Key([]byte(password), salt, scryptN, scryptR, scryptP, 32)
|
return scrypt.Key([]byte(password), salt, scryptN, scryptR, scryptP, derivedKeyLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) Check(login *Login) bool {
|
// CheckPassword verifies only the username and password, without TOTP verification
|
||||||
|
func (u *User) CheckPassword(login *Login) bool {
|
||||||
if u.User != login.User {
|
if u.User != login.User {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -67,18 +78,23 @@ func (u *User) Check(login *Login) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if subtle.ConstantTimeCompare(derived, u.Password) != 1 {
|
return subtle.ConstantTimeCompare(derived, u.Password) == validCompareResult
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Check(login *Login) bool {
|
||||||
|
// First check username and password
|
||||||
|
if !u.CheckPassword(login) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// If TOTP is enabled for the user, validate the TOTP code
|
// If TOTP is enabled for the user, validate the TOTP code
|
||||||
if u.TOTPSecret != "" && login.TOTPCode != "" {
|
if u.TOTPSecret != emptyString && login.TOTPCode != emptyString {
|
||||||
// Use the ValidateTOTPCode method to validate the TOTP code
|
// Use the ValidateTOTPCode method to validate the TOTP code
|
||||||
valid, err := u.ValidateTOTPCode(login.TOTPCode)
|
valid, validErr := u.ValidateTOTPCode(login.TOTPCode)
|
||||||
if err != nil || !valid {
|
if validErr != nil || !valid {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
} else if u.TOTPSecret != "" && login.TOTPCode == "" {
|
} else if u.TOTPSecret != emptyString && login.TOTPCode == emptyString {
|
||||||
// TOTP is enabled but no code was provided
|
// TOTP is enabled but no code was provided
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -89,11 +105,11 @@ func (u *User) Check(login *Login) bool {
|
||||||
func (u *User) Register(login *Login) error {
|
func (u *User) Register(login *Login) error {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if u.User != "" && u.User != login.User {
|
if u.User != emptyString && u.User != login.User {
|
||||||
return errors.New("invalid user")
|
return errors.New("invalid user")
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.ID == "" {
|
if u.ID == emptyString {
|
||||||
u.ID = ulid.Make().String()
|
u.ID = ulid.Make().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,9 +131,9 @@ func (u *User) Register(login *Login) error {
|
||||||
// GenerateTOTPSecret generates a new TOTP secret for the user
|
// GenerateTOTPSecret generates a new TOTP secret for the user
|
||||||
func (u *User) GenerateTOTPSecret() (string, error) {
|
func (u *User) GenerateTOTPSecret() (string, error) {
|
||||||
// Generate a random secret
|
// Generate a random secret
|
||||||
secret, err := GenerateRandomBase32(20) // 20 bytes = 160 bits
|
secret, err := GenerateRandomBase32(totpSecretLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to generate TOTP secret: %w", err)
|
return emptyString, fmt.Errorf("failed to generate TOTP secret: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
u.TOTPSecret = secret
|
u.TOTPSecret = secret
|
||||||
|
@ -126,7 +142,7 @@ func (u *User) GenerateTOTPSecret() (string, error) {
|
||||||
|
|
||||||
// ValidateTOTPCode validates a TOTP code against the user's TOTP secret
|
// ValidateTOTPCode validates a TOTP code against the user's TOTP secret
|
||||||
func (u *User) ValidateTOTPCode(code string) (bool, error) {
|
func (u *User) ValidateTOTPCode(code string) (bool, error) {
|
||||||
if u.TOTPSecret == "" {
|
if u.TOTPSecret == emptyString {
|
||||||
return false, errors.New("TOTP not enabled for user")
|
return false, errors.New("TOTP not enabled for user")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,5 +153,5 @@ func (u *User) ValidateTOTPCode(code string) (bool, error) {
|
||||||
|
|
||||||
// HasTOTP returns true if TOTP is enabled for the user
|
// HasTOTP returns true if TOTP is enabled for the user
|
||||||
func (u *User) HasTOTP() bool {
|
func (u *User) HasTOTP() bool {
|
||||||
return u.TOTPSecret != ""
|
return u.TOTPSecret != emptyString
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
-- Drop tables in reverse order of creation to avoid foreign key constraints
|
||||||
|
DROP TABLE IF EXISTS role_permissions;
|
||||||
|
DROP TABLE IF EXISTS permissions;
|
||||||
|
DROP TABLE IF EXISTS user_roles;
|
||||||
|
DROP TABLE IF EXISTS roles;
|
||||||
|
DROP TABLE IF EXISTS registrations;
|
||||||
|
DROP TABLE IF EXISTS database;
|
||||||
|
DROP TABLE IF EXISTS tokens;
|
||||||
|
DROP TABLE IF EXISTS users;
|
|
@ -0,0 +1,84 @@
|
||||||
|
CREATE TABLE users (
|
||||||
|
id text primary key,
|
||||||
|
created integer,
|
||||||
|
user text not null,
|
||||||
|
password blob not null,
|
||||||
|
salt blob not null,
|
||||||
|
totp_secret text
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE tokens (
|
||||||
|
id text primary key,
|
||||||
|
uid text not null,
|
||||||
|
token text not null,
|
||||||
|
expires integer default 0,
|
||||||
|
FOREIGN KEY(uid) REFERENCES user(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE database (
|
||||||
|
id text primary key,
|
||||||
|
host text not null,
|
||||||
|
port integer default 5432,
|
||||||
|
name text not null,
|
||||||
|
user text not null,
|
||||||
|
password text not null
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE registrations (
|
||||||
|
id text primary key,
|
||||||
|
code text not null
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE roles (
|
||||||
|
id text primary key,
|
||||||
|
role text not null
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE user_roles (
|
||||||
|
id text primary key,
|
||||||
|
uid text not null,
|
||||||
|
rid text not null,
|
||||||
|
FOREIGN KEY(uid) REFERENCES user(id),
|
||||||
|
FOREIGN KEY(rid) REFERENCES roles(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Add permissions table
|
||||||
|
CREATE TABLE permissions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
resource TEXT NOT NULL,
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
description TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Link roles to permissions
|
||||||
|
CREATE TABLE role_permissions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
rid TEXT NOT NULL,
|
||||||
|
pid TEXT NOT NULL,
|
||||||
|
FOREIGN KEY(rid) REFERENCES roles(id),
|
||||||
|
FOREIGN KEY(pid) REFERENCES permissions(id)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Add default permissions
|
||||||
|
INSERT INTO permissions (id, resource, action, description) VALUES
|
||||||
|
('perm_db_read', 'database_credentials', 'read', 'Read database credentials'),
|
||||||
|
('perm_db_write', 'database_credentials', 'write', 'Modify database credentials'),
|
||||||
|
('perm_user_manage', 'users', 'manage', 'Manage user accounts'),
|
||||||
|
('perm_token_manage', 'tokens', 'manage', 'Manage authentication tokens');
|
||||||
|
|
||||||
|
-- Add default roles
|
||||||
|
INSERT INTO roles (id, role) VALUES
|
||||||
|
('role_admin', 'admin'),
|
||||||
|
('role_db_operator', 'db_operator'),
|
||||||
|
('role_user', 'user');
|
||||||
|
|
||||||
|
-- Grant permissions to admin role
|
||||||
|
INSERT INTO role_permissions (id, rid, pid) VALUES
|
||||||
|
('rp_admin_db_read', 'role_admin', 'perm_db_read'),
|
||||||
|
('rp_admin_db_write', 'role_admin', 'perm_db_write'),
|
||||||
|
('rp_admin_user_manage', 'role_admin', 'perm_user_manage'),
|
||||||
|
('rp_admin_token_manage', 'role_admin', 'perm_token_manage');
|
||||||
|
|
||||||
|
-- Grant database access to db_operator role
|
||||||
|
INSERT INTO role_permissions (id, rid, pid) VALUES
|
||||||
|
('rp_dbop_db_read', 'role_db_operator', 'perm_db_read');
|
|
@ -15,4 +15,4 @@ func DefaultSchema() (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return string(schemaBytes), nil
|
return string(schemaBytes), nil
|
||||||
}
|
}
|
||||||
|
|
8
go.mod
8
go.mod
|
@ -2,10 +2,14 @@ module git.wntrmute.dev/kyle/mcias
|
||||||
|
|
||||||
go 1.23.8
|
go 1.23.8
|
||||||
|
|
||||||
|
require github.com/gokyle/twofactor v1.0.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||||
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
|
||||||
github.com/gokyle/twofactor v1.0.1
|
github.com/golang-migrate/migrate/v4 v4.18.3 // indirect
|
||||||
|
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||||
|
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||||
github.com/oklog/ulid/v2 v2.1.0 // indirect
|
github.com/oklog/ulid/v2 v2.1.0 // indirect
|
||||||
|
@ -18,7 +22,7 @@ require (
|
||||||
github.com/spf13/pflag v1.0.6 // indirect
|
github.com/spf13/pflag v1.0.6 // indirect
|
||||||
github.com/spf13/viper v1.20.1 // indirect
|
github.com/spf13/viper v1.20.1 // indirect
|
||||||
github.com/subosito/gotenv v1.6.0 // indirect
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
go.uber.org/atomic v1.9.0 // indirect
|
go.uber.org/atomic v1.11.0 // indirect
|
||||||
go.uber.org/multierr v1.9.0 // indirect
|
go.uber.org/multierr v1.9.0 // indirect
|
||||||
golang.org/x/crypto v0.38.0 // indirect
|
golang.org/x/crypto v0.38.0 // indirect
|
||||||
golang.org/x/sys v0.33.0 // indirect
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
|
|
9
go.sum
9
go.sum
|
@ -7,6 +7,13 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx
|
||||||
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||||
github.com/gokyle/twofactor v1.0.1 h1:uRhvx0S4Hb82RPIDALnf7QxbmPL49LyyaCkJDpWx+Ek=
|
github.com/gokyle/twofactor v1.0.1 h1:uRhvx0S4Hb82RPIDALnf7QxbmPL49LyyaCkJDpWx+Ek=
|
||||||
github.com/gokyle/twofactor v1.0.1/go.mod h1:4gxzH1eaE/F3Pct/sCDNOylP0ClofUO5j4XZN9tKtLE=
|
github.com/gokyle/twofactor v1.0.1/go.mod h1:4gxzH1eaE/F3Pct/sCDNOylP0ClofUO5j4XZN9tKtLE=
|
||||||
|
github.com/golang-migrate/migrate/v4 v4.18.3 h1:EYGkoOsvgHHfm5U/naS1RP/6PL/Xv3S4B/swMiAmDLs=
|
||||||
|
github.com/golang-migrate/migrate/v4 v4.18.3/go.mod h1:99BKpIi6ruaaXRM1A77eqZ+FWPQ3cfRa+ZVy5bmWMaY=
|
||||||
|
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
|
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||||
|
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
|
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||||
|
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||||
|
@ -38,6 +45,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8
|
||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||||
|
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||||
|
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||||
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
|
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
|
||||||
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
|
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
|
||||||
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
|
||||||
|
|
Loading…
Reference in New Issue