Starting over.
This commit is contained in:
183
data/auth.go
183
data/auth.go
@@ -1,183 +0,0 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// Constants for error messages
|
||||
errScanPermission = "failed to scan permission: %w"
|
||||
errIteratePermissions = "error iterating permissions: %w"
|
||||
|
||||
// Constants for comparison
|
||||
zeroCount = 0
|
||||
)
|
||||
|
||||
// Permission represents a system permission.
|
||||
type Permission struct {
|
||||
ID string
|
||||
Resource string
|
||||
Action string
|
||||
Description string
|
||||
}
|
||||
|
||||
// AuthorizationService provides methods for checking user permissions.
|
||||
type AuthorizationService struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewAuthorizationService creates a new authorization service.
|
||||
func NewAuthorizationService(db *sql.DB) *AuthorizationService {
|
||||
return &AuthorizationService{db: db}
|
||||
}
|
||||
|
||||
// UserHasPermission checks if a user has a specific permission for a resource and action
|
||||
func (a *AuthorizationService) UserHasPermission(userID, resource, action string) (bool, error) {
|
||||
query := `
|
||||
SELECT COUNT(*) FROM permissions p
|
||||
JOIN role_permissions rp ON p.id = rp.pid
|
||||
JOIN user_roles ur ON rp.rid = ur.rid
|
||||
WHERE ur.uid = ? AND p.resource = ? AND p.action = ?
|
||||
`
|
||||
|
||||
var count int
|
||||
err := a.db.QueryRow(query, userID, resource, action).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check user permission: %w", err)
|
||||
}
|
||||
|
||||
return count > zeroCount, nil
|
||||
}
|
||||
|
||||
// GetUserPermissions returns all permissions for a user based on their roles.
|
||||
func (a *AuthorizationService) GetUserPermissions(userID string) ([]Permission, error) {
|
||||
query := `
|
||||
SELECT DISTINCT p.id, p.resource, p.action, p.description FROM permissions p
|
||||
JOIN role_permissions rp ON p.id = rp.pid
|
||||
JOIN user_roles ur ON rp.rid = ur.rid
|
||||
WHERE ur.uid = ?
|
||||
`
|
||||
|
||||
rows, err := a.db.Query(query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user permissions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var permissions []Permission
|
||||
for rows.Next() {
|
||||
var perm Permission
|
||||
if scanErr := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); scanErr != nil {
|
||||
return nil, fmt.Errorf(errScanPermission, scanErr)
|
||||
}
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
return nil, fmt.Errorf(errIteratePermissions, rowErr)
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// GetRolePermissions returns all permissions for a specific role.
|
||||
func (a *AuthorizationService) GetRolePermissions(roleID string) ([]Permission, error) {
|
||||
query := `
|
||||
SELECT p.id, p.resource, p.action, p.description FROM permissions p
|
||||
JOIN role_permissions rp ON p.id = rp.pid
|
||||
WHERE rp.rid = ?
|
||||
`
|
||||
|
||||
rows, err := a.db.Query(query, roleID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get role permissions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var permissions []Permission
|
||||
for rows.Next() {
|
||||
var perm Permission
|
||||
if scanErr := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); scanErr != nil {
|
||||
return nil, fmt.Errorf(errScanPermission, scanErr)
|
||||
}
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
return nil, fmt.Errorf(errIteratePermissions, rowErr)
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// GrantPermissionToRole grants a permission to a role
|
||||
func (a *AuthorizationService) GrantPermissionToRole(roleID, permissionID string) error {
|
||||
// Check if the role-permission relationship already exists
|
||||
checkQuery := `SELECT COUNT(*) FROM role_permissions WHERE rid = ? AND pid = ?`
|
||||
var count int
|
||||
err := a.db.QueryRow(checkQuery, roleID, permissionID).Scan(&count)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check role permission: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return nil // Permission already granted
|
||||
}
|
||||
|
||||
// Generate a new ID for the role-permission relationship
|
||||
id := GenerateID()
|
||||
|
||||
// Insert the new role-permission relationship
|
||||
insertQuery := `INSERT INTO role_permissions (id, rid, pid) VALUES (?, ?, ?)`
|
||||
_, err = a.db.Exec(insertQuery, id, roleID, permissionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to grant permission to role: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokePermissionFromRole revokes a permission from a role
|
||||
func (a *AuthorizationService) RevokePermissionFromRole(roleID, permissionID string) error {
|
||||
query := `DELETE FROM role_permissions WHERE rid = ? AND pid = ?`
|
||||
_, err := a.db.Exec(query, roleID, permissionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke permission from role: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllPermissions returns all permissions in the system.
|
||||
func (a *AuthorizationService) GetAllPermissions() ([]Permission, error) {
|
||||
query := `SELECT id, resource, action, description FROM permissions`
|
||||
|
||||
rows, err := a.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get permissions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var permissions []Permission
|
||||
for rows.Next() {
|
||||
var perm Permission
|
||||
if scanErr := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); scanErr != nil {
|
||||
return nil, fmt.Errorf(errScanPermission, scanErr)
|
||||
}
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
return nil, fmt.Errorf(errIteratePermissions, rowErr)
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// GenerateID generates a unique ID for database records
|
||||
func GenerateID() string {
|
||||
return ulid.Make().String()
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) (*sql.DB, func()) {
|
||||
// Create a temporary database for testing
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open in-memory database: %v", err)
|
||||
}
|
||||
|
||||
// Read the schema file
|
||||
schemaBytes, err := os.ReadFile("../database/schema.sql")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read schema file: %v", err)
|
||||
}
|
||||
schema := string(schemaBytes)
|
||||
|
||||
// Execute the schema
|
||||
_, err = db.Exec(schema)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute schema: %v", err)
|
||||
}
|
||||
|
||||
// Create test data
|
||||
setupTestData(t, db)
|
||||
|
||||
// Return the database and a cleanup function
|
||||
return db, func() {
|
||||
db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestData(t *testing.T, db *sql.DB) {
|
||||
// Create test users
|
||||
_, err := db.Exec(`INSERT INTO users (id, created, user, password, salt) VALUES
|
||||
('user1', 1622505600, 'testadmin', 'dummy', 'dummy'),
|
||||
('user2', 1622505600, 'testoperator', 'dummy', 'dummy'),
|
||||
('user3', 1622505600, 'testuser', 'dummy', 'dummy')`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert test users: %v", err)
|
||||
}
|
||||
|
||||
// Create test roles (these should already exist from schema.sql)
|
||||
// But we'll check and insert if needed
|
||||
var count int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM roles WHERE role = 'admin'").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check roles: %v", err)
|
||||
}
|
||||
if count == 0 {
|
||||
_, err = db.Exec(`INSERT INTO roles (id, role) VALUES
|
||||
('role_admin', 'admin'),
|
||||
('role_db_operator', 'db_operator'),
|
||||
('role_user', 'user')`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert test roles: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign roles to users
|
||||
_, err = db.Exec(`INSERT INTO user_roles (id, uid, rid) VALUES
|
||||
('ur1', 'user1', 'role_admin'),
|
||||
('ur2', 'user2', 'role_db_operator'),
|
||||
('ur3', 'user3', 'role_user')`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign roles to users: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHasPermission(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
authService := NewAuthorizationService(db)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
resource string
|
||||
action string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Admin has database read permission",
|
||||
userID: "user1",
|
||||
resource: "database_credentials",
|
||||
action: "read",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Admin has database write permission",
|
||||
userID: "user1",
|
||||
resource: "database_credentials",
|
||||
action: "write",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "DB Operator has database read permission",
|
||||
userID: "user2",
|
||||
resource: "database_credentials",
|
||||
action: "read",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "DB Operator does not have database write permission",
|
||||
userID: "user2",
|
||||
resource: "database_credentials",
|
||||
action: "write",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Regular user does not have database read permission",
|
||||
userID: "user3",
|
||||
resource: "database_credentials",
|
||||
action: "read",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := authService.UserHasPermission(tt.userID, tt.resource, tt.action)
|
||||
if err != nil {
|
||||
t.Errorf("AuthorizationService.UserHasPermission() error = %v", err)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("AuthorizationService.UserHasPermission() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserPermissions(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
authService := NewAuthorizationService(db)
|
||||
|
||||
t.Run("Admin has all permissions", func(t *testing.T) {
|
||||
permissions, err := authService.GetUserPermissions("user1")
|
||||
if err != nil {
|
||||
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Admin should have 4 permissions
|
||||
if len(permissions) != 4 {
|
||||
t.Errorf("Admin should have 4 permissions, got %d", len(permissions))
|
||||
}
|
||||
|
||||
// Check for specific permissions
|
||||
hasDBRead := false
|
||||
hasDBWrite := false
|
||||
for _, p := range permissions {
|
||||
if p.Resource == "database_credentials" && p.Action == "read" {
|
||||
hasDBRead = true
|
||||
}
|
||||
if p.Resource == "database_credentials" && p.Action == "write" {
|
||||
hasDBWrite = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasDBRead {
|
||||
t.Errorf("Admin should have database_credentials:read permission")
|
||||
}
|
||||
if !hasDBWrite {
|
||||
t.Errorf("Admin should have database_credentials:write permission")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DB Operator has limited permissions", func(t *testing.T) {
|
||||
permissions, err := authService.GetUserPermissions("user2")
|
||||
if err != nil {
|
||||
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// DB Operator should have 1 permission
|
||||
if len(permissions) != 1 {
|
||||
t.Errorf("DB Operator should have 1 permission, got %d", len(permissions))
|
||||
}
|
||||
|
||||
// Check for specific permissions
|
||||
hasDBRead := false
|
||||
hasDBWrite := false
|
||||
for _, p := range permissions {
|
||||
if p.Resource == "database_credentials" && p.Action == "read" {
|
||||
hasDBRead = true
|
||||
}
|
||||
if p.Resource == "database_credentials" && p.Action == "write" {
|
||||
hasDBWrite = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasDBRead {
|
||||
t.Errorf("DB Operator should have database_credentials:read permission")
|
||||
}
|
||||
if hasDBWrite {
|
||||
t.Errorf("DB Operator should not have database_credentials:write permission")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Regular user has no permissions", func(t *testing.T) {
|
||||
permissions, err := authService.GetUserPermissions("user3")
|
||||
if err != nil {
|
||||
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Regular user should have 0 permissions
|
||||
if len(permissions) != 0 {
|
||||
t.Errorf("Regular user should have 0 permissions, got %d", len(permissions))
|
||||
}
|
||||
})
|
||||
}
|
||||
15
data/rand.go
15
data/rand.go
@@ -1,15 +0,0 @@
|
||||
package data
|
||||
|
||||
import "crypto/rand"
|
||||
|
||||
const saltLength = 32
|
||||
|
||||
func Salt() ([]byte, error) {
|
||||
salt := make([]byte, saltLength)
|
||||
_, err := rand.Read(salt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return salt, nil
|
||||
}
|
||||
125
data/totp.go
125
data/totp.go
@@ -1,125 +0,0 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
|
||||
// #nosec G505 - SHA1 is used here because TOTP (RFC 6238) specifically uses HMAC-SHA1
|
||||
// as the default algorithm, and many authenticator apps still use it.
|
||||
// In the future, we should consider supporting stronger algorithms like SHA256 or SHA512.
|
||||
"crypto/sha1"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image/png"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"rsc.io/qr"
|
||||
)
|
||||
|
||||
const (
|
||||
// TOTPTimeStep is the time step in seconds for TOTP.
|
||||
TOTPTimeStep = 30
|
||||
// TOTPDigits is the number of digits in a TOTP code.
|
||||
TOTPDigits = 6
|
||||
// TOTPModulo is the modulo value for truncating the TOTP hash.
|
||||
TOTPModulo = 1000000
|
||||
// TOTPTimeWindow is the number of time steps to check before and after the current time.
|
||||
TOTPTimeWindow = 1
|
||||
|
||||
// Constants for TOTP calculation
|
||||
timeBytesLength = 8
|
||||
dynamicTruncationMask = 0x0F
|
||||
truncationModulusMask = 0x7FFFFFFF
|
||||
)
|
||||
|
||||
// GenerateRandomBase32 generates a random base32 encoded string of the specified length.
|
||||
func GenerateRandomBase32(length int) (string, error) {
|
||||
// Generate random bytes
|
||||
randomBytes := make([]byte, length)
|
||||
_, err := rand.Read(randomBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Encode to base32
|
||||
encoder := base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||
encoded := encoder.EncodeToString(randomBytes)
|
||||
|
||||
// Convert to uppercase and remove any padding
|
||||
return strings.ToUpper(encoded), nil
|
||||
}
|
||||
|
||||
// ValidateTOTP validates a TOTP code against a secret
|
||||
func ValidateTOTP(secret, code string) bool {
|
||||
// Get current time step
|
||||
currentTime := time.Now().Unix() / TOTPTimeStep
|
||||
|
||||
// Try the time window (allow for time skew)
|
||||
for i := -TOTPTimeWindow; i <= TOTPTimeWindow; i++ {
|
||||
if calculateTOTP(secret, currentTime+int64(i)) == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// calculateTOTP calculates the TOTP code for a given secret and time
|
||||
func calculateTOTP(secret string, timeCounter int64) string {
|
||||
// Decode the secret from base32
|
||||
encoder := base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||
secretBytes, err := encoder.DecodeString(strings.ToUpper(secret))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Convert time counter to bytes (big endian)
|
||||
timeBytes := make([]byte, timeBytesLength)
|
||||
binary.BigEndian.PutUint64(timeBytes, uint64(timeCounter))
|
||||
|
||||
// Calculate HMAC-SHA1
|
||||
h := hmac.New(sha1.New, secretBytes)
|
||||
h.Write(timeBytes)
|
||||
hash := h.Sum(nil)
|
||||
|
||||
// Dynamic truncation
|
||||
offset := hash[len(hash)-1] & dynamicTruncationMask
|
||||
truncatedHash := binary.BigEndian.Uint32(hash[offset:offset+4]) & truncationModulusMask
|
||||
otp := truncatedHash % TOTPModulo
|
||||
|
||||
// Format as a 6-digit string with leading zeros
|
||||
return fmt.Sprintf("%0*d", TOTPDigits, otp)
|
||||
}
|
||||
|
||||
// GenerateTOTPQRCode generates a QR code for a TOTP secret and saves it to a file
|
||||
func GenerateTOTPQRCode(secret, username, issuer, outputPath string) error {
|
||||
// Format the TOTP URI according to the KeyURI format
|
||||
// https://github.com/google/google-authenticator/wiki/Key-Uri-Format
|
||||
uri := fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&algorithm=SHA1&digits=%d&period=%d",
|
||||
issuer, username, secret, issuer, TOTPDigits, TOTPTimeStep)
|
||||
|
||||
// Generate QR code
|
||||
code, err := qr.Encode(uri, qr.M)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate QR code: %w", err)
|
||||
}
|
||||
|
||||
// Create output file
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create output file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write QR code as PNG
|
||||
img := code.Image()
|
||||
err = png.Encode(file, img)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write QR code to file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/gokyle/twofactor"
|
||||
)
|
||||
|
||||
func TestTOTPBasic(t *testing.T) {
|
||||
// This test verifies that we can import and use the twofactor package.
|
||||
totp := twofactor.TOTP{}
|
||||
fmt.Printf("TOTP: %+v\n", totp)
|
||||
}
|
||||
143
data/user.go
143
data/user.go
@@ -1,143 +0,0 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
scryptN = 32768
|
||||
scryptR = 8
|
||||
scryptP = 2
|
||||
|
||||
// Constants for derived key length and comparison
|
||||
derivedKeyLength = 32
|
||||
validCompareResult = 1
|
||||
|
||||
// Empty string constant
|
||||
emptyString = ""
|
||||
|
||||
// TOTP secret length in bytes (160 bits)
|
||||
totpSecretLength = 20
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
Created int64
|
||||
User string
|
||||
Password []byte
|
||||
Salt []byte
|
||||
TOTPSecret string
|
||||
Roles []string
|
||||
}
|
||||
|
||||
// HasRole checks if the user has a specific role
|
||||
func (u *User) HasRole(role string) bool {
|
||||
for _, r := range u.Roles {
|
||||
if r == role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasPermission checks if the user has a specific permission using the authorization service
|
||||
func (u *User) HasPermission(authService *AuthorizationService, resource, action string) (bool, error) {
|
||||
return authService.UserHasPermission(u.ID, resource, action)
|
||||
}
|
||||
|
||||
// GetPermissions returns all permissions for the user using the authorization service
|
||||
func (u *User) GetPermissions(authService *AuthorizationService) ([]Permission, error) {
|
||||
return authService.GetUserPermissions(u.ID)
|
||||
}
|
||||
|
||||
type Login struct {
|
||||
User string `json:"user"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TOTPCode string `json:"totp_code,omitempty"`
|
||||
}
|
||||
|
||||
func derive(password string, salt []byte) ([]byte, error) {
|
||||
return scrypt.Key([]byte(password), salt, scryptN, scryptR, scryptP, derivedKeyLength)
|
||||
}
|
||||
|
||||
// CheckPassword verifies only the username and password, without TOTP verification
|
||||
func (u *User) CheckPassword(login *Login) bool {
|
||||
if u.User != login.User {
|
||||
return false
|
||||
}
|
||||
|
||||
derived, err := derive(login.Password, u.Salt)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return subtle.ConstantTimeCompare(derived, u.Password) == validCompareResult
|
||||
}
|
||||
|
||||
// Check is a legacy method that now only checks the password
|
||||
// It's kept for backward compatibility but is equivalent to CheckPassword
|
||||
func (u *User) Check(login *Login) bool {
|
||||
// Only check username and password, TOTP verification is now a separate flow
|
||||
return u.CheckPassword(login)
|
||||
}
|
||||
|
||||
func (u *User) Register(login *Login) error {
|
||||
var err error
|
||||
|
||||
if u.User != emptyString && u.User != login.User {
|
||||
return errors.New("invalid user")
|
||||
}
|
||||
|
||||
if u.ID == emptyString {
|
||||
u.ID = ulid.Make().String()
|
||||
}
|
||||
|
||||
u.User = login.User
|
||||
u.Salt, err = Salt()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register user: %w", err)
|
||||
}
|
||||
|
||||
u.Password, err = derive(login.Password, u.Salt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("key derivation failed: %w", err)
|
||||
}
|
||||
|
||||
u.Created = time.Now().Unix()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateTOTPSecret generates a new TOTP secret for the user
|
||||
func (u *User) GenerateTOTPSecret() (string, error) {
|
||||
// Generate a random secret
|
||||
secret, err := GenerateRandomBase32(totpSecretLength)
|
||||
if err != nil {
|
||||
return emptyString, fmt.Errorf("failed to generate TOTP secret: %w", err)
|
||||
}
|
||||
|
||||
u.TOTPSecret = secret
|
||||
return u.TOTPSecret, nil
|
||||
}
|
||||
|
||||
// ValidateTOTPCode validates a TOTP code against the user's TOTP secret
|
||||
func (u *User) ValidateTOTPCode(code string) (bool, error) {
|
||||
if u.TOTPSecret == emptyString {
|
||||
return false, errors.New("TOTP not enabled for user")
|
||||
}
|
||||
|
||||
// Use the twofactor package to validate the code
|
||||
valid := ValidateTOTP(u.TOTPSecret, code)
|
||||
return valid, nil
|
||||
}
|
||||
|
||||
// HasTOTP returns true if TOTP is enabled for the user
|
||||
func (u *User) HasTOTP() bool {
|
||||
return u.TOTPSecret != emptyString
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserRegisterAndCheck(t *testing.T) {
|
||||
user := &User{}
|
||||
login := &Login{
|
||||
User: "testuser",
|
||||
Password: "testpassword",
|
||||
}
|
||||
|
||||
err := user.Register(login)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register user: %v", err)
|
||||
}
|
||||
|
||||
if user.ID == "" {
|
||||
t.Error("Expected user ID to be set, got empty string")
|
||||
}
|
||||
if user.User != "testuser" {
|
||||
t.Errorf("Expected username 'testuser', got '%s'", user.User)
|
||||
}
|
||||
if len(user.Password) == 0 {
|
||||
t.Error("Expected password to be hashed and set, got empty slice")
|
||||
}
|
||||
if len(user.Salt) == 0 {
|
||||
t.Error("Expected salt to be set, got empty slice")
|
||||
}
|
||||
|
||||
correctLogin := &Login{
|
||||
User: "testuser",
|
||||
Password: "testpassword",
|
||||
}
|
||||
if !user.Check(correctLogin) {
|
||||
t.Error("Check failed for correct password")
|
||||
}
|
||||
|
||||
incorrectLogin := &Login{
|
||||
User: "testuser",
|
||||
Password: "wrongpassword",
|
||||
}
|
||||
if user.Check(incorrectLogin) {
|
||||
t.Error("Check passed for incorrect password")
|
||||
}
|
||||
|
||||
wrongUserLogin := &Login{
|
||||
User: "wronguser",
|
||||
Password: "testpassword",
|
||||
}
|
||||
if user.Check(wrongUserLogin) {
|
||||
t.Error("Check passed for incorrect username")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSalt(t *testing.T) {
|
||||
salt, err := Salt()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate salt: %v", err)
|
||||
}
|
||||
|
||||
if len(salt) != saltLength {
|
||||
t.Errorf("Expected salt length %d, got %d", saltLength, len(salt))
|
||||
}
|
||||
|
||||
salt2, err := Salt()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate second salt: %v", err)
|
||||
}
|
||||
|
||||
// Check that salts are different (extremely unlikely to be the same)
|
||||
different := false
|
||||
for i := 0; i < len(salt); i++ {
|
||||
if salt[i] != salt2[i] {
|
||||
different = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !different {
|
||||
t.Error("Expected different salts, got identical values")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user