// Package auth stores users and login sessions in a SQL database and
// verifies credentials using bcrypt password hashes.
package auth

import (
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
)

// sessionTimeLayout is the layout used to write session expiry timestamps
// and to build the cutoff used when purging them.
//
// It is RFC 3339 with a fixed nine-digit fractional second. time.RFC3339Nano
// is deliberately NOT used for writing: it strips trailing zeros from the
// fraction, so its output is variable-width and does not sort
// lexicographically in chronological order (for example "...:00Z" sorts
// after "...:00.5Z"). PurgeExpiredSessions compares timestamps as strings in
// SQL, so a fixed-width form is required. All timestamps are written in UTC,
// which keeps the zone suffix a constant "Z".
//
// Values in this layout are still parseable with time.RFC3339Nano, so rows
// written in either form can be read back by GetSession.
const sessionTimeLayout = "2006-01-02T15:04:05.000000000Z07:00"

var (
	// ErrInvalidCredentials is returned by Authenticate when the username is
	// unknown or the password does not match the stored hash.
	ErrInvalidCredentials = errors.New("invalid credentials")

	// ErrSessionNotFound is returned by GetSession when no session row
	// matches the token or the matching session has expired.
	ErrSessionNotFound = errors.New("session not found")
)

// Store provides user and session persistence backed by a *sql.DB.
type Store struct {
	db *sql.DB
}

// User is an account as exposed to callers; the password hash is never
// included.
type User struct {
	ID       int64
	Username string
	IsAdmin  bool
}

// Session is a login session together with the user it belongs to.
type Session struct {
	Token     string
	UserID    int64
	ExpiresAt time.Time
	User      *User
}

// NewStore returns a Store that runs its queries against database.
func NewStore(database *sql.DB) *Store {
	return &Store{db: database}
}

// EnsureAdmin creates an admin account with the given credentials when the
// users table is empty. If any user already exists it does nothing and
// returns nil. The username is trimmed of surrounding whitespace; an empty
// username or password is rejected before the database is touched.
//
// NOTE(review): the count and the insert are separate statements, so two
// concurrent callers could both observe an empty table.
func (s *Store) EnsureAdmin(username, password string) error {
	username = strings.TrimSpace(username)
	if username == "" || password == "" {
		return fmt.Errorf("admin username and password are required")
	}

	var count int
	if err := s.db.QueryRow(`SELECT COUNT(*) FROM users`).Scan(&count); err != nil {
		return fmt.Errorf("count users: %w", err)
	}
	if count > 0 {
		return nil
	}

	return s.CreateUser(username, password, true)
}

// CreateUser inserts a new user whose password is stored as a bcrypt hash
// computed at bcrypt.DefaultCost. The username is trimmed of surrounding
// whitespace. It returns an error if the username or password is empty, if
// hashing fails, or if the insert fails.
func (s *Store) CreateUser(username, password string, isAdmin bool) error {
	username = strings.TrimSpace(username)
	if username == "" {
		return fmt.Errorf("username is required")
	}
	if password == "" {
		return fmt.Errorf("password is required")
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("hash password: %w", err)
	}

	// The admin flag is written as an integer 0/1 and read back the same way
	// by ListUsers, Authenticate and GetSession.
	adminValue := 0
	if isAdmin {
		adminValue = 1
	}

	if _, err := s.db.Exec(
		`INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`,
		username, string(hash), adminValue,
	); err != nil {
		return fmt.Errorf("create user: %w", err)
	}
	return nil
}

// DeleteUser removes the user row with the given id. Deleting an id that
// does not exist is not reported as an error.
//
// Session rows are not touched here. GetSession inner-joins sessions with
// users, so a deleted user's tokens stop resolving; whether the rows
// themselves are removed depends on the schema (e.g. a cascading foreign
// key), which is not visible from this file.
func (s *Store) DeleteUser(id int64) error {
	if _, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id); err != nil {
		return fmt.Errorf("delete user: %w", err)
	}
	return nil
}

// ListUsers returns every user ordered by username. The returned slice is
// nil when the table is empty.
func (s *Store) ListUsers() ([]User, error) {
	rows, err := s.db.Query(`SELECT id, username, is_admin FROM users ORDER BY username`)
	if err != nil {
		return nil, fmt.Errorf("list users: %w", err)
	}
	defer rows.Close()

	var users []User
	for rows.Next() {
		var user User
		var isAdmin int
		if err := rows.Scan(&user.ID, &user.Username, &isAdmin); err != nil {
			return nil, fmt.Errorf("scan user: %w", err)
		}
		user.IsAdmin = isAdmin != 0
		users = append(users, user)
	}
	// rows.Next returning false can also mean iteration failed part-way.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate users: %w", err)
	}
	return users, nil
}

// Authenticate looks up the user by trimmed username and checks password
// against the stored bcrypt hash. It returns ErrInvalidCredentials both when
// the user does not exist and when the password is wrong, so callers cannot
// tell the two apart from the error value. Any other database failure is
// returned wrapped.
//
// NOTE(review): no bcrypt comparison is performed for an unknown username,
// so the two failure paths differ in response time.
func (s *Store) Authenticate(username, password string) (*User, error) {
	var user User
	var hash string
	var isAdmin int

	err := s.db.QueryRow(
		`SELECT id, username, password_hash, is_admin FROM users WHERE username = ?`,
		strings.TrimSpace(username),
	).Scan(&user.ID, &user.Username, &hash, &isAdmin)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrInvalidCredentials
	}
	if err != nil {
		return nil, fmt.Errorf("load user: %w", err)
	}

	if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
		return nil, ErrInvalidCredentials
	}

	user.IsAdmin = isAdmin != 0
	return &user, nil
}

// CreateSession stores a new session for userID that expires ttl from now
// and returns its token. The token is 32 bytes from crypto/rand encoded as
// unpadded URL-safe base64.
//
// The expiry is written in UTC using sessionTimeLayout so that the stored
// strings sort chronologically (see PurgeExpiredSessions).
func (s *Store) CreateSession(userID int64, ttl time.Duration) (string, error) {
	tokenBytes := make([]byte, 32)
	if _, err := rand.Read(tokenBytes); err != nil {
		return "", fmt.Errorf("generate session token: %w", err)
	}
	token := base64.RawURLEncoding.EncodeToString(tokenBytes)
	expiresAt := time.Now().UTC().Add(ttl)

	if _, err := s.db.Exec(
		`INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)`,
		token, userID, expiresAt.Format(sessionTimeLayout),
	); err != nil {
		return "", fmt.Errorf("create session: %w", err)
	}
	return token, nil
}

// GetSession loads the session identified by token together with its user.
// It returns ErrSessionNotFound when no row matches (including when the
// owning user no longer exists, because of the inner join) or when the
// session has expired. An expired session is also deleted on a best-effort
// basis.
func (s *Store) GetSession(token string) (*Session, error) {
	var session Session
	var expiresAt string
	var user User
	var isAdmin int

	err := s.db.QueryRow(
		`SELECT s.token, s.user_id, s.expires_at, u.username, u.is_admin FROM sessions s JOIN users u ON u.id = s.user_id WHERE s.token = ?`,
		token,
	).Scan(&session.Token, &session.UserID, &expiresAt, &user.Username, &isAdmin)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrSessionNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("load session: %w", err)
	}

	// Parsing with RFC3339Nano accepts a fractional second of any length, so
	// it reads both the fixed-width values written by CreateSession and any
	// older rows that were written with trailing zeros trimmed.
	parsedExpiresAt, err := time.Parse(time.RFC3339Nano, expiresAt)
	if err != nil {
		return nil, fmt.Errorf("parse session expiry: %w", err)
	}

	if !parsedExpiresAt.After(time.Now().UTC()) {
		// Best-effort cleanup: the caller gets ErrSessionNotFound either way,
		// and PurgeExpiredSessions removes the row later if this fails.
		_ = s.DeleteSession(token)
		return nil, ErrSessionNotFound
	}

	user.ID = session.UserID
	user.IsAdmin = isAdmin != 0
	session.ExpiresAt = parsedExpiresAt
	session.User = &user
	return &session, nil
}

// DeleteSession removes the session with the given token. Deleting a token
// that does not exist is not reported as an error.
func (s *Store) DeleteSession(token string) error {
	if _, err := s.db.Exec(`DELETE FROM sessions WHERE token = ?`, token); err != nil {
		return fmt.Errorf("delete session: %w", err)
	}
	return nil
}

// PurgeExpiredSessions deletes every session whose expiry is at or before
// the current time.
//
// The comparison is done on the stored string, so the cutoff is formatted
// with the same fixed-width sessionTimeLayout that CreateSession writes;
// that makes string order match time order. Rows written earlier in the
// trimmed RFC3339Nano form may be purged up to one second early or late;
// GetSession still enforces their expiry exactly because it parses the
// timestamp.
func (s *Store) PurgeExpiredSessions() error {
	cutoff := time.Now().UTC().Format(sessionTimeLayout)
	if _, err := s.db.Exec(`DELETE FROM sessions WHERE expires_at <= ?`, cutoff); err != nil {
		return fmt.Errorf("purge expired sessions: %w", err)
	}
	return nil
}