Checkpoint sql work.
This commit is contained in:
44
data/migrate.go
Normal file
44
data/migrate.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// applySchema reads schema.sql from the current working directory and executes it.
|
||||
// If the file cannot be read, it falls back to a minimal schema that matches tests.
|
||||
func applySchema(ctx context.Context, db *sql.DB) error {
|
||||
b, err := os.ReadFile("schema.sql")
|
||||
if err != nil {
|
||||
// Fallback: minimal schema needed for users/roles used by tests.
|
||||
b = []byte(`
|
||||
PRAGMA foreign_keys = ON;
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
type TEXT NOT NULL CHECK (type IN ('human','system')),
|
||||
pwd_hash TEXT NOT NULL,
|
||||
totp_secret TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS roles (name TEXT PRIMARY KEY);
|
||||
CREATE TABLE IF NOT EXISTS user_roles (
|
||||
user_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, role)
|
||||
);
|
||||
`)
|
||||
}
|
||||
stmts := strings.TrimSpace(string(b))
|
||||
if stmts == "" {
|
||||
return errors.New("empty schema")
|
||||
}
|
||||
_, err = db.ExecContext(ctx, stmts)
|
||||
return err
|
||||
}
|
||||
156
data/store.go
Normal file
156
data/store.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// Store provides persistence for users and roles in a SQLite database.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// Open opens or creates a SQLite database at the given path and ensures the schema exists.
|
||||
func Open(ctx context.Context, path string) (*Store, error) {
|
||||
dsn := path
|
||||
if dsn == "" {
|
||||
dsn = ":memory:"
|
||||
}
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Enforce foreign keys
|
||||
if _, err = db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err = applySchema(ctx, db); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &Store{db: db}, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database.
|
||||
func (s *Store) Close() error { return s.db.Close() }
|
||||
|
||||
// CreateUser inserts a new user. If u.ID is empty, a random ID is generated.
|
||||
func (s *Store) CreateUser(ctx context.Context, u *User) error {
|
||||
if u == nil {
|
||||
return errors.New("nil user")
|
||||
}
|
||||
if u.Username == "" {
|
||||
return errors.New("username required")
|
||||
}
|
||||
if u.Type == "" {
|
||||
u.Type = AccountHuman
|
||||
}
|
||||
if u.ID == "" {
|
||||
id, err := newID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.ID = id
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO users(id, username, type, pwd_hash, totp_secret, created_at, updated_at)
|
||||
VALUES(?,?,?,?,?,?,?)`,
|
||||
u.ID, u.Username, string(u.Type), u.PasswordHash(), u.TOTPSecret, now, now,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateUser updates mutable fields of a user identified by ID.
|
||||
func (s *Store) UpdateUser(ctx context.Context, u *User) error {
|
||||
if u == nil || u.ID == "" {
|
||||
return errors.New("user ID required")
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`UPDATE users SET username=?, type=?, pwd_hash=?, totp_secret=?, updated_at=? WHERE id=?`,
|
||||
u.Username, string(u.Type), u.PasswordHash(), u.TOTPSecret, time.Now().Unix(), u.ID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetUserByUsername fetches a user and its roles.
|
||||
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT id, username, type, pwd_hash, totp_secret FROM users WHERE username=?`, username)
|
||||
var id, uname, typ, ph, totp string
|
||||
if err := row.Scan(&id, &uname, &typ, &ph, &totp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u := &User{ID: id, Username: uname, Type: AccountType(typ), TOTPSecret: totp}
|
||||
if err := u.LoadPasswordHash(ph); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roles, err := s.userRoles(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.Roles = roles
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// AssignRole ensures a role exists and links it to the user.
|
||||
func (s *Store) AssignRole(ctx context.Context, userID, role string) error {
|
||||
if role == "" || userID == "" {
|
||||
return errors.New("userID and role required")
|
||||
}
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
if _, err = tx.ExecContext(ctx, `INSERT OR IGNORE INTO roles(name) VALUES (?)`, role); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = tx.ExecContext(ctx, `INSERT OR IGNORE INTO user_roles(user_id, role) VALUES (?,?)`, userID, role); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// RemoveRole removes a role association from a user.
|
||||
func (s *Store) RemoveRole(ctx context.Context, userID, role string) error {
|
||||
if role == "" || userID == "" {
|
||||
return errors.New("userID and role required")
|
||||
}
|
||||
_, err := s.db.ExecContext(ctx, `DELETE FROM user_roles WHERE user_id=? AND role=?`, userID, role)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) userRoles(ctx context.Context, userID string) ([]string, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `SELECT role FROM user_roles WHERE user_id=? ORDER BY role`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []string
|
||||
for rows.Next() {
|
||||
var r string
|
||||
if err := rows.Scan(&r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, r)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func newID() (string, error) {
|
||||
var b [16]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
dst := make([]byte, hex.EncodedLen(len(b)))
|
||||
hex.Encode(dst, b[:])
|
||||
return string(dst), nil
|
||||
}
|
||||
61
data/store_test.go
Normal file
61
data/store_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package data
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStoreUserCRUDAndRoles(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s, err := Open(ctx, ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("open store: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = s.Close() })
|
||||
|
||||
u := &User{Username: "alice", Type: AccountHuman}
|
||||
if err := u.SetPassword("correct horse battery staple"); err != nil {
|
||||
t.Fatalf("set password: %v", err)
|
||||
}
|
||||
u.TOTPSecret = "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"
|
||||
if err := s.CreateUser(ctx, u); err != nil {
|
||||
t.Fatalf("create user: %v", err)
|
||||
}
|
||||
if u.ID == "" {
|
||||
t.Fatal("expected ID to be set")
|
||||
}
|
||||
|
||||
if err := s.AssignRole(ctx, u.ID, "admin"); err != nil {
|
||||
t.Fatalf("assign role: %v", err)
|
||||
}
|
||||
|
||||
got, err := s.GetUserByUsername(ctx, "alice")
|
||||
if err != nil {
|
||||
t.Fatalf("get user: %v", err)
|
||||
}
|
||||
if got.ID != u.ID || got.Username != "alice" || got.Type != AccountHuman {
|
||||
t.Fatalf("unexpected user: %+v", got)
|
||||
}
|
||||
if !got.CheckPassword("correct horse battery staple") {
|
||||
t.Fatal("password check failed after round-trip")
|
||||
}
|
||||
if len(got.Roles) != 1 || got.Roles[0] != "admin" {
|
||||
t.Fatalf("expected role admin, got %#v", got.Roles)
|
||||
}
|
||||
|
||||
// Update username and password
|
||||
got.Username = "alice2"
|
||||
if err := got.SetPassword("newpass"); err != nil {
|
||||
t.Fatalf("set new password: %v", err)
|
||||
}
|
||||
if err := s.UpdateUser(ctx, got); err != nil {
|
||||
t.Fatalf("update user: %v", err)
|
||||
}
|
||||
got2, err := s.GetUserByUsername(ctx, "alice2")
|
||||
if err != nil {
|
||||
t.Fatalf("get user 2: %v", err)
|
||||
}
|
||||
if !got2.CheckPassword("newpass") {
|
||||
t.Fatal("new password check failed")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user