mcias/api/auth.go

557 lines
15 KiB
Go

package api
import (
"crypto/rand"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"git.wntrmute.dev/kyle/mcias/data"
"github.com/Masterminds/squirrel"
"github.com/oklog/ulid/v2"
)
// LoginRequest is the JSON body accepted by the password-login and
// token-login endpoints. Version must be "v1". Login carries the username
// together with either a password (password login) or a previously issued
// token (token login).
type LoginRequest struct {
	Version string     `json:"version"`
	Login   data.Login `json:"login"`
}
type (
	// TokenResponse is returned after a successful authentication. Expires
	// is the token's expiry as a Unix timestamp in seconds.
	TokenResponse struct {
		Token   string `json:"token"`
		Expires int64  `json:"expires"`
	}

	// TOTPVerifyRequest is the JSON body for TOTP verification. Version must
	// be "v1"; TOTPCode is the one-time code produced by the user's
	// authenticator.
	TOTPVerifyRequest struct {
		Version  string `json:"version"`
		Username string `json:"username"`
		TOTPCode string `json:"totp_code"`
	}

	// ErrorResponse is a minimal JSON error payload carrying a single
	// message.
	ErrorResponse struct {
		Error string `json:"error"`
	}

	// DatabaseCredentials holds the connection details for one stored
	// database, as returned by the database-credentials endpoint.
	DatabaseCredentials struct {
		Host     string `json:"host"`
		Port     int    `json:"port"`
		Name     string `json:"name"`
		User     string `json:"user"`
		Password string `json:"password"`
	}
)
// handlePasswordLogin authenticates a user by username and password and, on
// success, issues a new session token.
//
// The JSON response contains the token, its expiry (Unix seconds) and a
// totp_enabled flag. An unknown username and a wrong password both produce
// the same 401 response. Wrong passwords against an existing account are
// recorded as failed "login_attempt" security events, so they appear in the
// audit trail just like failed TOTP checks do.
//
// NOTE(review): a token is issued even when the user has TOTP enabled, so
// the password alone is sufficient to obtain a session; totp_enabled is only
// advisory to the client. Confirm this is the intended 2FA design.
func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
	var req LoginRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.sendError(w, "Invalid request format", http.StatusBadRequest)
		return
	}
	if req.Version != "v1" || req.Login.User == "" || req.Login.Password == "" {
		s.sendError(w, "Invalid login request", http.StatusBadRequest)
		return
	}

	user, err := s.getUserByUsername(req.Login.User)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
		} else {
			s.Logger.Printf("Database error: %v", err)
			s.sendError(w, "Internal server error", http.StatusInternalServerError)
		}
		return
	}

	if !user.CheckPassword(&req.Login) {
		// Audit failed password checks against a known account so that
		// brute-force attempts are visible alongside other security events.
		failDetails := map[string]string{
			"reason": "Invalid password",
		}
		s.LogSecurityEvent(r, "login_attempt", user.ID, user.User, false, failDetails)
		s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
		return
	}

	// Password is correct: issue a token regardless of TOTP status (see the
	// NOTE(review) above).
	token, expires, err := s.createToken(user.ID)
	if err != nil {
		s.Logger.Printf("Token creation error: %v", err)
		details := map[string]string{
			"reason": "Token creation error",
			"error":  err.Error(),
		}
		s.LogSecurityEvent(r, "login_attempt", user.ID, user.User, false, details)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	totpEnabled := user.HasTOTP()

	details := map[string]string{
		"token_expires": time.Unix(expires, 0).UTC().Format(time.RFC3339),
		"totp_enabled":  fmt.Sprintf("%t", totpEnabled),
	}
	s.LogSecurityEvent(r, "login_success", user.ID, user.User, true, details)

	// Embed TokenResponse so its fields are flattened into the JSON object
	// alongside totp_enabled.
	response := struct {
		TokenResponse
		TOTPEnabled bool `json:"totp_enabled"`
	}{
		TokenResponse: TokenResponse{
			Token:   token,
			Expires: expires,
		},
		TOTPEnabled: totpEnabled,
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(response); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}
// handleTokenLogin re-authenticates a user with an existing, unexpired token
// and extends that token's lifetime instead of minting a new one.
//
// The response echoes the same token together with its new expiry (Unix
// seconds). Malformed requests get 400, unknown or expired tokens get 401,
// and a failure while renewing gets 500.
func (s *Server) handleTokenLogin(w http.ResponseWriter, r *http.Request) {
	var body LoginRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		s.sendError(w, "Invalid request format", http.StatusBadRequest)
		return
	}

	username, token := body.Login.User, body.Login.Token
	wellFormed := body.Version == "v1" && username != "" && token != ""
	if !wellFormed {
		s.sendError(w, "Invalid login request", http.StatusBadRequest)
		return
	}

	// Reject unknown or expired tokens before touching their expiry.
	if _, verifyErr := s.verifyToken(username, token); verifyErr != nil {
		s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
		return
	}

	newExpiry, renewErr := s.renewToken(username, token)
	if renewErr != nil {
		s.Logger.Printf("Token renewal error: %v", renewErr)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	payload := TokenResponse{Token: token, Expires: newExpiry}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if encodeErr := json.NewEncoder(w).Encode(payload); encodeErr != nil {
		s.Logger.Printf("Error encoding response: %v", encodeErr)
	}
}
// sendError writes a JSON error response with the given HTTP status.
//
// message is written only to the server log. The client receives a generic
// message chosen from the status code, so internal details are not leaked.
// The body has the shape {"error": <public message>, "error_code": "E<status>"},
// for example "E401" for http.StatusUnauthorized.
func (s *Server) sendError(w http.ResponseWriter, message string, status int) {
	s.Logger.Printf("Error response: %s (Status: %d)", message, status)

	publicMessage := "An error occurred processing your request"
	switch status {
	case http.StatusBadRequest:
		publicMessage = "Invalid request format"
	case http.StatusUnauthorized:
		publicMessage = "Authentication required"
	case http.StatusForbidden:
		publicMessage = "Insufficient permissions"
	case http.StatusNotFound:
		publicMessage = "Resource not found"
	case http.StatusTooManyRequests:
		publicMessage = "Rate limit exceeded"
	case http.StatusInternalServerError:
		publicMessage = "Internal server error"
	}

	// Unpadded decimal status code, e.g. "E401". (The previous "E%000d" only
	// repeated the zero flag without a width, so it produced this same output
	// while suggesting zero padding.)
	errorCode := fmt.Sprintf("E%d", status)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	response := struct {
		Error     string `json:"error"`
		ErrorCode string `json:"error_code"`
	}{
		Error:     publicMessage,
		ErrorCode: errorCode,
	}
	if err := json.NewEncoder(w).Encode(response); err != nil {
		s.Logger.Printf("Error encoding error response: %v", err)
	}
}
// getUserByUsername loads the user row for username together with the names
// of every role linked to it through user_roles.
//
// When no such user exists, the error from the row scan (sql.ErrNoRows) is
// returned unwrapped, so callers can test for it with errors.Is. A NULL
// totp_secret leaves User.TOTPSecret empty. Any error while reading the role
// rows is returned, including one reported only by rows.Err after iteration.
// Such an error never yields a user with a silently truncated role list,
// because roles feed authorization decisions.
func (s *Server) getUserByUsername(username string) (*data.User, error) {
	query, args, err := squirrel.Select("id", "created", "user", "password", "salt", "totp_secret").
		From("users").
		Where(squirrel.Eq{"user": username}).
		ToSql()
	if err != nil {
		return nil, err
	}

	user := &data.User{}
	var totpSecret sql.NullString
	row := s.DB.QueryRow(query, args...)
	if err := row.Scan(&user.ID, &user.Created, &user.User, &user.Password, &user.Salt, &totpSecret); err != nil {
		return nil, err
	}
	if totpSecret.Valid {
		user.TOTPSecret = totpSecret.String
	}

	rolesQuery, rolesArgs, err := squirrel.Select("r.role").
		From("roles r").
		Join("user_roles ur ON r.id = ur.rid").
		Where(squirrel.Eq{"ur.uid": user.ID}).
		ToSql()
	if err != nil {
		return nil, err
	}

	rows, err := s.DB.Query(rolesQuery, rolesArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var roles []string
	for rows.Next() {
		var role string
		if err := rows.Scan(&role); err != nil {
			return nil, err
		}
		roles = append(roles, role)
	}
	// rows.Next returns false both at the end of the result set and on
	// error; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	user.Roles = roles
	return user, nil
}
func (s *Server) createToken(userID string) (string, int64, error) {
// Generate 16 bytes of random data
tokenBytes := make([]byte, 16)
if _, err := rand.Read(tokenBytes); err != nil {
return "", 0, err
}
// Hex encode the random bytes to get a 32-character string.
token := hex.EncodeToString(tokenBytes)
expires := time.Now().Add(24 * time.Hour).Unix()
tokenID := ulid.Make().String()
// Use squirrel to build the insert query safely
query, args, err := squirrel.Insert("tokens").
Columns("id", "uid", "token", "expires").
Values(tokenID, userID, token, expires).
ToSql()
if err != nil {
return "", 0, err
}
_, err = s.DB.Exec(query, args...)
if err != nil {
return "", 0, err
}
return token, expires, nil
}
// verifyToken checks that token was issued to username and has not expired,
// and returns the owning user's ID.
//
// An unknown username/token pair surfaces the row-scan error (sql.ErrNoRows).
// A stored expiry of zero or below is treated as "never expires". Any other
// expiry earlier than the current Unix time yields a "token expired" error.
func (s *Server) verifyToken(username, token string) (string, error) {
	lookup := squirrel.Select("t.uid", "t.expires").
		From("tokens t").
		Join("users u ON t.uid = u.id").
		Where(squirrel.And{
			squirrel.Eq{"u.user": username},
			squirrel.Eq{"t.token": token},
		})
	stmt, params, err := lookup.ToSql()
	if err != nil {
		return "", err
	}

	var (
		ownerID   string
		expiresAt int64
	)
	if err := s.DB.QueryRow(stmt, params...).Scan(&ownerID, &expiresAt); err != nil {
		return "", err
	}

	neverExpires := expiresAt <= 0
	if !neverExpires && time.Now().Unix() > expiresAt {
		return "", errors.New("token expired")
	}
	return ownerID, nil
}
// renewToken moves the expiry of an existing token owned by username to 24
// hours from now and returns the new expiry as a Unix timestamp in seconds.
//
// The token row is first looked up by (username, token); a missing row
// surfaces the scan error (sql.ErrNoRows). The row is then updated by its
// primary key. This function does not itself reject tokens that have already
// expired; handleTokenLogin calls verifyToken before renewing.
func (s *Server) renewToken(username, token string) (int64, error) {
	lookup := squirrel.Select("t.id").
		From("tokens t").
		Join("users u ON t.uid = u.id").
		Where(squirrel.And{
			squirrel.Eq{"u.user": username},
			squirrel.Eq{"t.token": token},
		})
	selectStmt, selectParams, err := lookup.ToSql()
	if err != nil {
		return 0, err
	}

	var rowID string
	if err = s.DB.QueryRow(selectStmt, selectParams...).Scan(&rowID); err != nil {
		return 0, err
	}

	newExpiry := time.Now().Add(24 * time.Hour).Unix()
	update := squirrel.Update("tokens").
		Set("expires", newExpiry).
		Where(squirrel.Eq{"id": rowID})
	updateStmt, updateParams, err := update.ToSql()
	if err != nil {
		return 0, err
	}

	if _, err = s.DB.Exec(updateStmt, updateParams...); err != nil {
		return 0, err
	}
	return newExpiry, nil
}
// handleTOTPVerify validates a user's TOTP code and, on success, issues a
// new session token (response: token and expiry in Unix seconds).
//
// Failures after the user has been loaded (TOTP not enabled, wrong code,
// token creation error) are recorded as failed "totp_verification_attempt"
// security events. Success is recorded as "totp_verification_success".
//
// NOTE(review): this handler checks only the username and the TOTP code. No
// password is verified here, so a valid code alone is enough to obtain a
// session. Confirm this is intended, and that rate limiting exists elsewhere,
// since short numeric codes are brute-forceable.
func (s *Server) handleTOTPVerify(w http.ResponseWriter, r *http.Request) {
	var req TOTPVerifyRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.sendError(w, "Invalid request format", http.StatusBadRequest)
		return
	}
	if !(req.Version == "v1" && req.Username != "" && req.TOTPCode != "") {
		s.sendError(w, "Invalid TOTP verification request", http.StatusBadRequest)
		return
	}

	user, err := s.getUserByUsername(req.Username)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		s.sendError(w, "User not found", http.StatusUnauthorized)
		return
	case err != nil:
		s.Logger.Printf("Database error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	// rejectAttempt audits a failed verification for the loaded user and
	// then writes the error response.
	rejectAttempt := func(details map[string]string, message string, status int) {
		s.LogSecurityEvent(r, "totp_verification_attempt", user.ID, user.User, false, details)
		s.sendError(w, message, status)
	}

	if !user.HasTOTP() {
		rejectAttempt(map[string]string{"reason": "TOTP not enabled for user"},
			"TOTP not enabled for user", http.StatusBadRequest)
		return
	}

	if ok, codeErr := user.ValidateTOTPCode(req.TOTPCode); codeErr != nil || !ok {
		rejectAttempt(map[string]string{"reason": "Invalid TOTP code"},
			"Invalid TOTP code", http.StatusUnauthorized)
		return
	}

	token, expires, err := s.createToken(user.ID)
	if err != nil {
		s.Logger.Printf("Token creation error: %v", err)
		rejectAttempt(map[string]string{
			"reason": "Token creation error",
			"error":  err.Error(),
		}, "Internal server error", http.StatusInternalServerError)
		return
	}

	s.LogSecurityEvent(r, "totp_verification_success", user.ID, user.User, true, map[string]string{
		"token_expires": time.Unix(expires, 0).UTC().Format(time.RFC3339),
	})

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if encodeErr := json.NewEncoder(w).Encode(TokenResponse{Token: token, Expires: expires}); encodeErr != nil {
		s.Logger.Printf("Error encoding response: %v", encodeErr)
	}
}
// handleDatabaseCredentials returns stored database connection credentials
// to an authenticated, authorized caller.
//
// The caller sends a bearer token in the Authorization header and the token
// owner's username in the "username" query parameter. The token must belong
// to that user and be unexpired. The user must also pass
// HasPermission(s.Auth, "database_credentials", "read").
//
// Users holding the "admin" role may see every row of the database table.
// Everyone else sees only databases linked to them through database_users.
// An optional "database_id" query parameter narrows the result. When it is
// given and exactly one row matches, a single JSON object is returned.
// Otherwise a JSON array is returned. No matches yields 404.
func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Request) {
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		s.sendError(w, "Authorization header required", http.StatusUnauthorized)
		return
	}

	// RFC 7235 allows one or more spaces between the auth scheme and its
	// credentials, and the scheme name is case-insensitive. Splitting on
	// runs of whitespace accepts every valid form; a single-" " split would
	// reject headers with repeated spaces.
	parts := strings.Fields(authHeader)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
		s.sendError(w, "Invalid authorization format", http.StatusUnauthorized)
		return
	}
	token := parts[1]

	username := r.URL.Query().Get("username")
	if username == "" {
		s.sendError(w, "Username parameter required", http.StatusBadRequest)
		return
	}

	userID, err := s.verifyToken(username, token)
	if err != nil {
		s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
		return
	}

	user, err := s.getUserByUsername(username)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Only reachable if the account was removed between the token
			// check and this lookup; treat it as an authentication failure,
			// not a server fault.
			s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
			return
		}
		s.Logger.Printf("Database error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	hasPermission, err := user.HasPermission(s.Auth, "database_credentials", "read")
	if err != nil {
		s.Logger.Printf("Permission check error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	if !hasPermission {
		s.sendError(w, "Insufficient permissions: requires database_credentials:read permission", http.StatusForbidden)
		return
	}

	databaseID := r.URL.Query().Get("database_id")

	queryBuilder := squirrel.Select("d.id", "d.host", "d.port", "d.name", "d.user", "d.password").
		From("database d")

	isAdmin := false
	for _, role := range user.Roles {
		if role == "admin" {
			isAdmin = true
			break
		}
	}
	if !isAdmin {
		// Non-admin users only see databases they are explicitly linked to.
		queryBuilder = queryBuilder.Join("database_users du ON d.id = du.db_id").
			Where(squirrel.Eq{"du.uid": userID})
	}
	if databaseID != "" {
		queryBuilder = queryBuilder.Where(squirrel.Eq{"d.id": databaseID})
	}

	query, args, err := queryBuilder.ToSql()
	if err != nil {
		s.Logger.Printf("Query building error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	rows, err := s.DB.Query(query, args...)
	if err != nil {
		s.Logger.Printf("Database error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	var databases []DatabaseCredentials
	for rows.Next() {
		// d.id is selected for filtering but not exposed in the response.
		var id string
		var creds DatabaseCredentials
		if err := rows.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password); err != nil {
			s.Logger.Printf("Row scanning error: %v", err)
			s.sendError(w, "Internal server error", http.StatusInternalServerError)
			return
		}
		databases = append(databases, creds)
	}
	if err := rows.Err(); err != nil {
		s.Logger.Printf("Rows iteration error: %v", err)
		s.sendError(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	if len(databases) == 0 {
		s.sendError(w, "No database credentials found", http.StatusNotFound)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	// A specific, uniquely matched database is returned as a bare object;
	// everything else is returned as an array.
	var payload interface{} = databases
	if databaseID != "" && len(databases) == 1 {
		payload = databases[0]
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		s.Logger.Printf("Error encoding response: %v", err)
	}
}