Junie: TOTP flow update and db migrations.
This commit is contained in:
47
data/auth.go
47
data/auth.go
@@ -7,7 +7,16 @@ import (
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
// Permission represents a system permission
|
||||
const (
|
||||
// Constants for error messages
|
||||
errScanPermission = "failed to scan permission: %w"
|
||||
errIteratePermissions = "error iterating permissions: %w"
|
||||
|
||||
// Constants for comparison
|
||||
zeroCount = 0
|
||||
)
|
||||
|
||||
// Permission represents a system permission.
|
||||
type Permission struct {
|
||||
ID string
|
||||
Resource string
|
||||
@@ -15,12 +24,12 @@ type Permission struct {
|
||||
Description string
|
||||
}
|
||||
|
||||
// AuthorizationService provides methods for checking user permissions
|
||||
// AuthorizationService provides methods for checking user permissions.
|
||||
type AuthorizationService struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewAuthorizationService creates a new authorization service
|
||||
// NewAuthorizationService creates a new authorization service.
|
||||
func NewAuthorizationService(db *sql.DB) *AuthorizationService {
|
||||
return &AuthorizationService{db: db}
|
||||
}
|
||||
@@ -40,10 +49,10 @@ func (a *AuthorizationService) UserHasPermission(userID, resource, action string
|
||||
return false, fmt.Errorf("failed to check user permission: %w", err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
return count > zeroCount, nil
|
||||
}
|
||||
|
||||
// GetUserPermissions returns all permissions for a user based on their roles
|
||||
// GetUserPermissions returns all permissions for a user based on their roles.
|
||||
func (a *AuthorizationService) GetUserPermissions(userID string) ([]Permission, error) {
|
||||
query := `
|
||||
SELECT DISTINCT p.id, p.resource, p.action, p.description FROM permissions p
|
||||
@@ -61,20 +70,20 @@ func (a *AuthorizationService) GetUserPermissions(userID string) ([]Permission,
|
||||
var permissions []Permission
|
||||
for rows.Next() {
|
||||
var perm Permission
|
||||
if err := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan permission: %w", err)
|
||||
if scanErr := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); scanErr != nil {
|
||||
return nil, fmt.Errorf(errScanPermission, scanErr)
|
||||
}
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating permissions: %w", err)
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
return nil, fmt.Errorf(errIteratePermissions, rowErr)
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// GetRolePermissions returns all permissions for a specific role
|
||||
// GetRolePermissions returns all permissions for a specific role.
|
||||
func (a *AuthorizationService) GetRolePermissions(roleID string) ([]Permission, error) {
|
||||
query := `
|
||||
SELECT p.id, p.resource, p.action, p.description FROM permissions p
|
||||
@@ -91,14 +100,14 @@ func (a *AuthorizationService) GetRolePermissions(roleID string) ([]Permission,
|
||||
var permissions []Permission
|
||||
for rows.Next() {
|
||||
var perm Permission
|
||||
if err := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan permission: %w", err)
|
||||
if scanErr := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); scanErr != nil {
|
||||
return nil, fmt.Errorf(errScanPermission, scanErr)
|
||||
}
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating permissions: %w", err)
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
return nil, fmt.Errorf(errIteratePermissions, rowErr)
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
@@ -142,7 +151,7 @@ func (a *AuthorizationService) RevokePermissionFromRole(roleID, permissionID str
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllPermissions returns all permissions in the system
|
||||
// GetAllPermissions returns all permissions in the system.
|
||||
func (a *AuthorizationService) GetAllPermissions() ([]Permission, error) {
|
||||
query := `SELECT id, resource, action, description FROM permissions`
|
||||
|
||||
@@ -155,14 +164,14 @@ func (a *AuthorizationService) GetAllPermissions() ([]Permission, error) {
|
||||
var permissions []Permission
|
||||
for rows.Next() {
|
||||
var perm Permission
|
||||
if err := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan permission: %w", err)
|
||||
if scanErr := rows.Scan(&perm.ID, &perm.Resource, &perm.Action, &perm.Description); scanErr != nil {
|
||||
return nil, fmt.Errorf(errScanPermission, scanErr)
|
||||
}
|
||||
permissions = append(permissions, perm)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating permissions: %w", err)
|
||||
if rowErr := rows.Err(); rowErr != nil {
|
||||
return nil, fmt.Errorf(errIteratePermissions, rowErr)
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
|
||||
@@ -150,12 +150,12 @@ func TestGetUserPermissions(t *testing.T) {
|
||||
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Admin should have 4 permissions
|
||||
if len(permissions) != 4 {
|
||||
t.Errorf("Admin should have 4 permissions, got %d", len(permissions))
|
||||
}
|
||||
|
||||
|
||||
// Check for specific permissions
|
||||
hasDBRead := false
|
||||
hasDBWrite := false
|
||||
@@ -167,7 +167,7 @@ func TestGetUserPermissions(t *testing.T) {
|
||||
hasDBWrite = true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if !hasDBRead {
|
||||
t.Errorf("Admin should have database_credentials:read permission")
|
||||
}
|
||||
@@ -182,12 +182,12 @@ func TestGetUserPermissions(t *testing.T) {
|
||||
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// DB Operator should have 1 permission
|
||||
if len(permissions) != 1 {
|
||||
t.Errorf("DB Operator should have 1 permission, got %d", len(permissions))
|
||||
}
|
||||
|
||||
|
||||
// Check for specific permissions
|
||||
hasDBRead := false
|
||||
hasDBWrite := false
|
||||
@@ -199,7 +199,7 @@ func TestGetUserPermissions(t *testing.T) {
|
||||
hasDBWrite = true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if !hasDBRead {
|
||||
t.Errorf("DB Operator should have database_credentials:read permission")
|
||||
}
|
||||
@@ -214,10 +214,10 @@ func TestGetUserPermissions(t *testing.T) {
|
||||
t.Errorf("AuthorizationService.GetUserPermissions() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Regular user should have 0 permissions
|
||||
if len(permissions) != 0 {
|
||||
t.Errorf("Regular user should have 0 permissions, got %d", len(permissions))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
91
data/totp.go
91
data/totp.go
@@ -3,14 +3,39 @@ package data
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
|
||||
// #nosec G505 - SHA1 is used here because TOTP (RFC 6238) specifically uses HMAC-SHA1
|
||||
// as the default algorithm, and many authenticator apps still use it.
|
||||
// In the future, we should consider supporting stronger algorithms like SHA256 or SHA512.
|
||||
"crypto/sha1"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image/png"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"rsc.io/qr"
|
||||
)
|
||||
|
||||
// GenerateRandomBase32 generates a random base32 encoded string of the specified length
|
||||
const (
|
||||
// TOTPTimeStep is the time step in seconds for TOTP.
|
||||
TOTPTimeStep = 30
|
||||
// TOTPDigits is the number of digits in a TOTP code.
|
||||
TOTPDigits = 6
|
||||
// TOTPModulo is the modulo value for truncating the TOTP hash.
|
||||
TOTPModulo = 1000000
|
||||
// TOTPTimeWindow is the number of time steps to check before and after the current time.
|
||||
TOTPTimeWindow = 1
|
||||
|
||||
// Constants for TOTP calculation
|
||||
timeBytesLength = 8
|
||||
dynamicTruncationMask = 0x0F
|
||||
truncationModulusMask = 0x7FFFFFFF
|
||||
)
|
||||
|
||||
// GenerateRandomBase32 generates a random base32 encoded string of the specified length.
|
||||
func GenerateRandomBase32(length int) (string, error) {
|
||||
// Generate random bytes
|
||||
randomBytes := make([]byte, length)
|
||||
@@ -29,12 +54,11 @@ func GenerateRandomBase32(length int) (string, error) {
|
||||
|
||||
// ValidateTOTP validates a TOTP code against a secret
|
||||
func ValidateTOTP(secret, code string) bool {
|
||||
// Allow for a time skew of 30 seconds in either direction
|
||||
timeWindow := 1 // 1 before and 1 after current time
|
||||
currentTime := time.Now().Unix() / 30
|
||||
// Get current time step
|
||||
currentTime := time.Now().Unix() / TOTPTimeStep
|
||||
|
||||
// Try the time window
|
||||
for i := -timeWindow; i <= timeWindow; i++ {
|
||||
// Try the time window (allow for time skew)
|
||||
for i := -TOTPTimeWindow; i <= TOTPTimeWindow; i++ {
|
||||
if calculateTOTP(secret, currentTime+int64(i)) == code {
|
||||
return true
|
||||
}
|
||||
@@ -53,7 +77,7 @@ func calculateTOTP(secret string, timeCounter int64) string {
|
||||
}
|
||||
|
||||
// Convert time counter to bytes (big endian)
|
||||
timeBytes := make([]byte, 8)
|
||||
timeBytes := make([]byte, timeBytesLength)
|
||||
binary.BigEndian.PutUint64(timeBytes, uint64(timeCounter))
|
||||
|
||||
// Calculate HMAC-SHA1
|
||||
@@ -62,25 +86,40 @@ func calculateTOTP(secret string, timeCounter int64) string {
|
||||
hash := h.Sum(nil)
|
||||
|
||||
// Dynamic truncation
|
||||
offset := hash[len(hash)-1] & 0x0F
|
||||
truncatedHash := binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7FFFFFFF
|
||||
otp := truncatedHash % 1000000
|
||||
offset := hash[len(hash)-1] & dynamicTruncationMask
|
||||
truncatedHash := binary.BigEndian.Uint32(hash[offset:offset+4]) & truncationModulusMask
|
||||
otp := truncatedHash % TOTPModulo
|
||||
|
||||
// Convert to 6-digit string with leading zeros if needed
|
||||
result := ""
|
||||
if otp < 10 {
|
||||
result = "00000" + string(otp+'0')
|
||||
} else if otp < 100 {
|
||||
result = "0000" + string((otp/10)+'0') + string((otp%10)+'0')
|
||||
} else if otp < 1000 {
|
||||
result = "000" + string((otp/100)+'0') + string(((otp/10)%10)+'0') + string((otp%10)+'0')
|
||||
} else if otp < 10000 {
|
||||
result = "00" + string((otp/1000)+'0') + string(((otp/100)%10)+'0') + string(((otp/10)%10)+'0') + string((otp%10)+'0')
|
||||
} else if otp < 100000 {
|
||||
result = "0" + string((otp/10000)+'0') + string(((otp/1000)%10)+'0') + string(((otp/100)%10)+'0') + string(((otp/10)%10)+'0') + string((otp%10)+'0')
|
||||
} else {
|
||||
result = string((otp/100000)+'0') + string(((otp/10000)%10)+'0') + string(((otp/1000)%10)+'0') + string(((otp/100)%10)+'0') + string(((otp/10)%10)+'0') + string((otp%10)+'0')
|
||||
// Format as a 6-digit string with leading zeros
|
||||
return fmt.Sprintf("%0*d", TOTPDigits, otp)
|
||||
}
|
||||
|
||||
// GenerateTOTPQRCode generates a QR code for a TOTP secret and saves it to a file
|
||||
func GenerateTOTPQRCode(secret, username, issuer, outputPath string) error {
|
||||
// Format the TOTP URI according to the KeyURI format
|
||||
// https://github.com/google/google-authenticator/wiki/Key-Uri-Format
|
||||
uri := fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&algorithm=SHA1&digits=%d&period=%d",
|
||||
issuer, username, secret, issuer, TOTPDigits, TOTPTimeStep)
|
||||
|
||||
// Generate QR code
|
||||
code, err := qr.Encode(uri, qr.M)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate QR code: %w", err)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
// Create output file
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create output file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write QR code as PNG
|
||||
img := code.Image()
|
||||
err = png.Encode(file, img)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write QR code to file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
50
data/user.go
50
data/user.go
@@ -14,6 +14,16 @@ const (
|
||||
scryptN = 32768
|
||||
scryptR = 8
|
||||
scryptP = 2
|
||||
|
||||
// Constants for derived key length and comparison
|
||||
derivedKeyLength = 32
|
||||
validCompareResult = 1
|
||||
|
||||
// Empty string constant
|
||||
emptyString = ""
|
||||
|
||||
// TOTP secret length in bytes (160 bits)
|
||||
totpSecretLength = 20
|
||||
)
|
||||
|
||||
type User struct {
|
||||
@@ -48,16 +58,17 @@ func (u *User) GetPermissions(authService *AuthorizationService) ([]Permission,
|
||||
|
||||
type Login struct {
|
||||
User string `json:"user"`
|
||||
Password string `json:"password,omitzero"`
|
||||
Token string `json:"token,omitzero"`
|
||||
TOTPCode string `json:"totp_code,omitzero"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TOTPCode string `json:"totp_code,omitempty"`
|
||||
}
|
||||
|
||||
func derive(password string, salt []byte) ([]byte, error) {
|
||||
return scrypt.Key([]byte(password), salt, scryptN, scryptR, scryptP, 32)
|
||||
return scrypt.Key([]byte(password), salt, scryptN, scryptR, scryptP, derivedKeyLength)
|
||||
}
|
||||
|
||||
func (u *User) Check(login *Login) bool {
|
||||
// CheckPassword verifies only the username and password, without TOTP verification
|
||||
func (u *User) CheckPassword(login *Login) bool {
|
||||
if u.User != login.User {
|
||||
return false
|
||||
}
|
||||
@@ -67,18 +78,23 @@ func (u *User) Check(login *Login) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare(derived, u.Password) != 1 {
|
||||
return subtle.ConstantTimeCompare(derived, u.Password) == validCompareResult
|
||||
}
|
||||
|
||||
func (u *User) Check(login *Login) bool {
|
||||
// First check username and password
|
||||
if !u.CheckPassword(login) {
|
||||
return false
|
||||
}
|
||||
|
||||
// If TOTP is enabled for the user, validate the TOTP code
|
||||
if u.TOTPSecret != "" && login.TOTPCode != "" {
|
||||
// If TOTP is enabled for the user, validate the TOTP code
|
||||
if u.TOTPSecret != emptyString && login.TOTPCode != emptyString {
|
||||
// Use the ValidateTOTPCode method to validate the TOTP code
|
||||
valid, err := u.ValidateTOTPCode(login.TOTPCode)
|
||||
if err != nil || !valid {
|
||||
valid, validErr := u.ValidateTOTPCode(login.TOTPCode)
|
||||
if validErr != nil || !valid {
|
||||
return false
|
||||
}
|
||||
} else if u.TOTPSecret != "" && login.TOTPCode == "" {
|
||||
} else if u.TOTPSecret != emptyString && login.TOTPCode == emptyString {
|
||||
// TOTP is enabled but no code was provided
|
||||
return false
|
||||
}
|
||||
@@ -89,11 +105,11 @@ func (u *User) Check(login *Login) bool {
|
||||
func (u *User) Register(login *Login) error {
|
||||
var err error
|
||||
|
||||
if u.User != "" && u.User != login.User {
|
||||
if u.User != emptyString && u.User != login.User {
|
||||
return errors.New("invalid user")
|
||||
}
|
||||
|
||||
if u.ID == "" {
|
||||
if u.ID == emptyString {
|
||||
u.ID = ulid.Make().String()
|
||||
}
|
||||
|
||||
@@ -115,9 +131,9 @@ func (u *User) Register(login *Login) error {
|
||||
// GenerateTOTPSecret generates a new TOTP secret for the user
|
||||
func (u *User) GenerateTOTPSecret() (string, error) {
|
||||
// Generate a random secret
|
||||
secret, err := GenerateRandomBase32(20) // 20 bytes = 160 bits
|
||||
secret, err := GenerateRandomBase32(totpSecretLength)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate TOTP secret: %w", err)
|
||||
return emptyString, fmt.Errorf("failed to generate TOTP secret: %w", err)
|
||||
}
|
||||
|
||||
u.TOTPSecret = secret
|
||||
@@ -126,7 +142,7 @@ func (u *User) GenerateTOTPSecret() (string, error) {
|
||||
|
||||
// ValidateTOTPCode validates a TOTP code against the user's TOTP secret
|
||||
func (u *User) ValidateTOTPCode(code string) (bool, error) {
|
||||
if u.TOTPSecret == "" {
|
||||
if u.TOTPSecret == emptyString {
|
||||
return false, errors.New("TOTP not enabled for user")
|
||||
}
|
||||
|
||||
@@ -137,5 +153,5 @@ func (u *User) ValidateTOTPCode(code string) (bool, error) {
|
||||
|
||||
// HasTOTP returns true if TOTP is enabled for the user
|
||||
func (u *User) HasTOTP() bool {
|
||||
return u.TOTPSecret != ""
|
||||
return u.TOTPSecret != emptyString
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user