557 lines
15 KiB
Go
557 lines
15 KiB
Go
package api
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/kyle/mcias/data"
|
|
"github.com/Masterminds/squirrel"
|
|
"github.com/oklog/ulid/v2"
|
|
)
|
|
|
|
type LoginRequest struct {
|
|
Version string `json:"version"`
|
|
Login data.Login `json:"login"`
|
|
}
|
|
|
|
type (
	// TokenResponse is returned after a successful login: the opaque session
	// token and its expiry expressed as Unix seconds.
	TokenResponse struct {
		Token   string `json:"token"`
		Expires int64  `json:"expires"`
	}

	// TOTPVerifyRequest is the JSON body accepted by handleTOTPVerify.
	// Version must be "v1" and both Username and TOTPCode must be non-empty.
	TOTPVerifyRequest struct {
		Version  string `json:"version"`
		Username string `json:"username"`
		TOTPCode string `json:"totp_code"`
	}

	// ErrorResponse is a JSON error body carrying a single message.
	ErrorResponse struct {
		Error string `json:"error"`
	}

	// DatabaseCredentials holds the connection details for one database as
	// returned by handleDatabaseCredentials.
	DatabaseCredentials struct {
		Host     string `json:"host"`
		Port     int    `json:"port"`
		Name     string `json:"name"`
		User     string `json:"user"`
		Password string `json:"password"`
	}
)
|
|
|
|
func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
|
|
var req LoginRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.sendError(w, "Invalid request format", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.Version != "v1" || req.Login.User == "" || req.Login.Password == "" {
|
|
s.sendError(w, "Invalid login request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
user, err := s.getUserByUsername(req.Login.User)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
|
|
} else {
|
|
s.Logger.Printf("Database error: %v", err)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Check password only
|
|
if !user.CheckPassword(&req.Login) {
|
|
s.sendError(w, "Invalid username or password", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Password is correct, create a token regardless of TOTP status
|
|
token, expires, err := s.createToken(user.ID)
|
|
if err != nil {
|
|
s.Logger.Printf("Token creation error: %v", err)
|
|
// Log the security event
|
|
details := map[string]string{
|
|
"reason": "Token creation error",
|
|
"error": err.Error(),
|
|
}
|
|
s.LogSecurityEvent(r, "login_attempt", user.ID, user.User, false, details)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// If user has TOTP enabled, include this information in the response
|
|
totpEnabled := user.HasTOTP()
|
|
|
|
// Log successful login
|
|
details := map[string]string{
|
|
"token_expires": time.Unix(expires, 0).UTC().Format(time.RFC3339),
|
|
"totp_enabled": fmt.Sprintf("%t", totpEnabled),
|
|
}
|
|
s.LogSecurityEvent(r, "login_success", user.ID, user.User, true, details)
|
|
|
|
// Include TOTP status in the response
|
|
response := struct {
|
|
TokenResponse
|
|
TOTPEnabled bool `json:"totp_enabled"`
|
|
}{
|
|
TokenResponse: TokenResponse{
|
|
Token: token,
|
|
Expires: expires,
|
|
},
|
|
TOTPEnabled: totpEnabled,
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
s.Logger.Printf("Error encoding response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleTokenLogin(w http.ResponseWriter, r *http.Request) {
|
|
var req LoginRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.sendError(w, "Invalid request format", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.Version != "v1" || req.Login.User == "" || req.Login.Token == "" {
|
|
s.sendError(w, "Invalid login request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Verify the token is valid
|
|
_, err := s.verifyToken(req.Login.User, req.Login.Token)
|
|
if err != nil {
|
|
s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Renew the existing token instead of creating a new one
|
|
expires, err := s.renewToken(req.Login.User, req.Login.Token)
|
|
if err != nil {
|
|
s.Logger.Printf("Token renewal error: %v", err)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(TokenResponse{
|
|
Token: req.Login.Token,
|
|
Expires: expires,
|
|
}); err != nil {
|
|
s.Logger.Printf("Error encoding response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (s *Server) sendError(w http.ResponseWriter, message string, status int) {
|
|
// Log the detailed error message for debugging
|
|
s.Logger.Printf("Error response: %s (Status: %d)", message, status)
|
|
|
|
// Create a generic error message based on status code
|
|
publicMessage := "An error occurred processing your request"
|
|
errorCode := fmt.Sprintf("E%000d", status)
|
|
|
|
// Customize public messages for common status codes
|
|
// but don't leak specific details about the error
|
|
switch status {
|
|
case http.StatusBadRequest:
|
|
publicMessage = "Invalid request format"
|
|
case http.StatusUnauthorized:
|
|
publicMessage = "Authentication required"
|
|
case http.StatusForbidden:
|
|
publicMessage = "Insufficient permissions"
|
|
case http.StatusNotFound:
|
|
publicMessage = "Resource not found"
|
|
case http.StatusTooManyRequests:
|
|
publicMessage = "Rate limit exceeded"
|
|
case http.StatusInternalServerError:
|
|
publicMessage = "Internal server error"
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
|
|
response := struct {
|
|
Error string `json:"error"`
|
|
ErrorCode string `json:"error_code"`
|
|
}{
|
|
Error: publicMessage,
|
|
ErrorCode: errorCode,
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
s.Logger.Printf("Error encoding error response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (s *Server) getUserByUsername(username string) (*data.User, error) {
|
|
// Use squirrel to build the query safely
|
|
query, args, err := squirrel.Select("id", "created", "user", "password", "salt", "totp_secret").
|
|
From("users").
|
|
Where(squirrel.Eq{"user": username}).
|
|
ToSql()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
row := s.DB.QueryRow(query, args...)
|
|
|
|
user := &data.User{}
|
|
var totpSecret sql.NullString
|
|
err = row.Scan(&user.ID, &user.Created, &user.User, &user.Password, &user.Salt, &totpSecret)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set TOTP secret if it exists
|
|
if totpSecret.Valid {
|
|
user.TOTPSecret = totpSecret.String
|
|
}
|
|
|
|
// Use squirrel to build the roles query safely
|
|
rolesQuery, rolesArgs, err := squirrel.Select("r.role").
|
|
From("roles r").
|
|
Join("user_roles ur ON r.id = ur.rid").
|
|
Where(squirrel.Eq{"ur.uid": user.ID}).
|
|
ToSql()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := s.DB.Query(rolesQuery, rolesArgs...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var roles []string
|
|
for rows.Next() {
|
|
var role string
|
|
if err := rows.Scan(&role); err != nil {
|
|
return nil, err
|
|
}
|
|
roles = append(roles, role)
|
|
}
|
|
user.Roles = roles
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (s *Server) createToken(userID string) (string, int64, error) {
|
|
// Generate 16 bytes of random data
|
|
tokenBytes := make([]byte, 16)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
// Hex encode the random bytes to get a 32-character string.
|
|
token := hex.EncodeToString(tokenBytes)
|
|
|
|
expires := time.Now().Add(24 * time.Hour).Unix()
|
|
tokenID := ulid.Make().String()
|
|
|
|
// Use squirrel to build the insert query safely
|
|
query, args, err := squirrel.Insert("tokens").
|
|
Columns("id", "uid", "token", "expires").
|
|
Values(tokenID, userID, token, expires).
|
|
ToSql()
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
_, err = s.DB.Exec(query, args...)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
return token, expires, nil
|
|
}
|
|
|
|
func (s *Server) verifyToken(username, token string) (string, error) {
|
|
// Use squirrel to build the query safely
|
|
query, args, err := squirrel.Select("t.uid", "t.expires").
|
|
From("tokens t").
|
|
Join("users u ON t.uid = u.id").
|
|
Where(squirrel.And{
|
|
squirrel.Eq{"u.user": username},
|
|
squirrel.Eq{"t.token": token},
|
|
}).
|
|
ToSql()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var userID string
|
|
var expires int64
|
|
err = s.DB.QueryRow(query, args...).Scan(&userID, &expires)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if expires > 0 && expires < time.Now().Unix() {
|
|
return "", errors.New("token expired")
|
|
}
|
|
|
|
return userID, nil
|
|
}
|
|
|
|
func (s *Server) renewToken(username, token string) (int64, error) {
|
|
// First, verify the token exists and get the token ID
|
|
// Use squirrel to build the query safely
|
|
query, args, err := squirrel.Select("t.id").
|
|
From("tokens t").
|
|
Join("users u ON t.uid = u.id").
|
|
Where(squirrel.And{
|
|
squirrel.Eq{"u.user": username},
|
|
squirrel.Eq{"t.token": token},
|
|
}).
|
|
ToSql()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
var tokenID string
|
|
err = s.DB.QueryRow(query, args...).Scan(&tokenID)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Update the token's expiry time
|
|
expires := time.Now().Add(24 * time.Hour).Unix()
|
|
|
|
// Use squirrel to build the update query safely
|
|
updateQuery, updateArgs, err := squirrel.Update("tokens").
|
|
Set("expires", expires).
|
|
Where(squirrel.Eq{"id": tokenID}).
|
|
ToSql()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
_, err = s.DB.Exec(updateQuery, updateArgs...)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return expires, nil
|
|
}
|
|
|
|
func (s *Server) handleTOTPVerify(w http.ResponseWriter, r *http.Request) {
|
|
var req TOTPVerifyRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
s.sendError(w, "Invalid request format", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.Version != "v1" || req.Username == "" || req.TOTPCode == "" {
|
|
s.sendError(w, "Invalid TOTP verification request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
user, err := s.getUserByUsername(req.Username)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
s.sendError(w, "User not found", http.StatusUnauthorized)
|
|
} else {
|
|
s.Logger.Printf("Database error: %v", err)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Check if TOTP is enabled for the user
|
|
if !user.HasTOTP() {
|
|
// Log the security event
|
|
details := map[string]string{
|
|
"reason": "TOTP not enabled for user",
|
|
}
|
|
s.LogSecurityEvent(r, "totp_verification_attempt", user.ID, user.User, false, details)
|
|
s.sendError(w, "TOTP not enabled for user", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate the TOTP code
|
|
valid, validErr := user.ValidateTOTPCode(req.TOTPCode)
|
|
if validErr != nil || !valid {
|
|
// Log the security event
|
|
details := map[string]string{
|
|
"reason": "Invalid TOTP code",
|
|
}
|
|
s.LogSecurityEvent(r, "totp_verification_attempt", user.ID, user.User, false, details)
|
|
s.sendError(w, "Invalid TOTP code", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// TOTP code is valid, create a token
|
|
token, expires, err := s.createToken(user.ID)
|
|
if err != nil {
|
|
s.Logger.Printf("Token creation error: %v", err)
|
|
// Log the security event
|
|
details := map[string]string{
|
|
"reason": "Token creation error",
|
|
"error": err.Error(),
|
|
}
|
|
s.LogSecurityEvent(r, "totp_verification_attempt", user.ID, user.User, false, details)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Log successful TOTP verification
|
|
details := map[string]string{
|
|
"token_expires": time.Unix(expires, 0).UTC().Format(time.RFC3339),
|
|
}
|
|
s.LogSecurityEvent(r, "totp_verification_success", user.ID, user.User, true, details)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(TokenResponse{
|
|
Token: token,
|
|
Expires: expires,
|
|
}); err != nil {
|
|
s.Logger.Printf("Error encoding response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleDatabaseCredentials(w http.ResponseWriter, r *http.Request) {
|
|
// Extract authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
s.sendError(w, "Authorization header required", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Check if it's a Bearer token
|
|
parts := strings.Split(authHeader, " ")
|
|
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
|
s.sendError(w, "Invalid authorization format", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
token := parts[1]
|
|
username := r.URL.Query().Get("username")
|
|
if username == "" {
|
|
s.sendError(w, "Username parameter required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Verify the token
|
|
userID, err := s.verifyToken(username, token)
|
|
if err != nil {
|
|
s.sendError(w, "Invalid or expired token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Check if user has permission to read database credentials
|
|
user, err := s.getUserByUsername(username)
|
|
if err != nil {
|
|
s.Logger.Printf("Database error: %v", err)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
hasPermission, err := user.HasPermission(s.Auth, "database_credentials", "read")
|
|
if err != nil {
|
|
s.Logger.Printf("Permission check error: %v", err)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if !hasPermission {
|
|
s.sendError(w, "Insufficient permissions: requires database_credentials:read permission", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Get database ID from query parameter if provided
|
|
databaseID := r.URL.Query().Get("database_id")
|
|
|
|
// Build the query to retrieve databases the user has access to
|
|
queryBuilder := squirrel.Select("d.id", "d.host", "d.port", "d.name", "d.user", "d.password").
|
|
From("database d")
|
|
|
|
// If the user is an admin, they can see all databases
|
|
isAdmin := false
|
|
for _, role := range user.Roles {
|
|
if role == "admin" {
|
|
isAdmin = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !isAdmin {
|
|
// Non-admin users can only see databases they're explicitly associated with
|
|
queryBuilder = queryBuilder.Join("database_users du ON d.id = du.db_id").
|
|
Where(squirrel.Eq{"du.uid": userID})
|
|
}
|
|
|
|
// If a specific database ID was requested, filter by that
|
|
if databaseID != "" {
|
|
queryBuilder = queryBuilder.Where(squirrel.Eq{"d.id": databaseID})
|
|
}
|
|
|
|
query, args, err := queryBuilder.ToSql()
|
|
if err != nil {
|
|
s.Logger.Printf("Query building error: %v", err)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
rows, err := s.DB.Query(query, args...)
|
|
if err != nil {
|
|
s.Logger.Printf("Database error: %v", err)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var databases []DatabaseCredentials
|
|
for rows.Next() {
|
|
var id string
|
|
var creds DatabaseCredentials
|
|
err = rows.Scan(&id, &creds.Host, &creds.Port, &creds.Name, &creds.User, &creds.Password)
|
|
if err != nil {
|
|
s.Logger.Printf("Row scanning error: %v", err)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
databases = append(databases, creds)
|
|
}
|
|
|
|
if err = rows.Err(); err != nil {
|
|
s.Logger.Printf("Rows iteration error: %v", err)
|
|
s.sendError(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if len(databases) == 0 {
|
|
s.sendError(w, "No database credentials found", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
// If a specific database was requested, return just that one
|
|
if databaseID != "" && len(databases) == 1 {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(databases[0]); err != nil {
|
|
s.Logger.Printf("Error encoding response: %v", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Otherwise return all accessible databases
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(databases); err != nil {
|
|
s.Logger.Printf("Error encoding response: %v", err)
|
|
}
|
|
}
|