You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
4.9 KiB
208 lines
4.9 KiB
package auth
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Sentinel errors returned by Store methods; match them with errors.Is.
var (
	// ErrInvalidCredentials is returned by Authenticate when the username is
	// unknown or the password does not match the stored hash.
	ErrInvalidCredentials = errors.New("invalid credentials")

	// ErrSessionNotFound is returned by GetSession when no session exists for
	// the token or the session has already expired.
	ErrSessionNotFound = errors.New("session not found")
)
|
|
|
|
// Store persists users and sessions in a SQL database.
type Store struct {
	// db is the handle every query in this package runs against.
	db *sql.DB
}
|
|
|
|
// User is the caller-facing view of an account. It carries no password
// material.
type User struct {
	// ID is the value of the users.id column.
	ID int64
	// Username is the value of the users.username column.
	Username string
	// IsAdmin is true when the users.is_admin column is non-zero.
	IsAdmin bool
}
|
|
|
|
type Session struct {
|
|
Token string
|
|
UserID int64
|
|
ExpiresAt time.Time
|
|
User *User
|
|
}
|
|
|
|
func NewStore(database *sql.DB) *Store {
|
|
return &Store{db: database}
|
|
}
|
|
|
|
func (s *Store) EnsureAdmin(username, password string) error {
|
|
username = strings.TrimSpace(username)
|
|
if username == "" || password == "" {
|
|
return fmt.Errorf("admin username and password are required")
|
|
}
|
|
|
|
var count int
|
|
if err := s.db.QueryRow(`SELECT COUNT(*) FROM users`).Scan(&count); err != nil {
|
|
return fmt.Errorf("count users: %w", err)
|
|
}
|
|
if count > 0 {
|
|
return nil
|
|
}
|
|
|
|
return s.CreateUser(username, password, true)
|
|
}
|
|
|
|
func (s *Store) CreateUser(username, password string, isAdmin bool) error {
|
|
username = strings.TrimSpace(username)
|
|
if username == "" {
|
|
return fmt.Errorf("username is required")
|
|
}
|
|
if password == "" {
|
|
return fmt.Errorf("password is required")
|
|
}
|
|
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return fmt.Errorf("hash password: %w", err)
|
|
}
|
|
|
|
adminValue := 0
|
|
if isAdmin {
|
|
adminValue = 1
|
|
}
|
|
|
|
if _, err := s.db.Exec(
|
|
`INSERT INTO users (username, password_hash, is_admin) VALUES (?, ?, ?)`,
|
|
username,
|
|
string(hash),
|
|
adminValue,
|
|
); err != nil {
|
|
return fmt.Errorf("create user: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) DeleteUser(id int64) error {
|
|
if _, err := s.db.Exec(`DELETE FROM users WHERE id = ?`, id); err != nil {
|
|
return fmt.Errorf("delete user: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) ListUsers() ([]User, error) {
|
|
rows, err := s.db.Query(`SELECT id, username, is_admin FROM users ORDER BY username`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list users: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var users []User
|
|
for rows.Next() {
|
|
var user User
|
|
var isAdmin int
|
|
if err := rows.Scan(&user.ID, &user.Username, &isAdmin); err != nil {
|
|
return nil, fmt.Errorf("scan user: %w", err)
|
|
}
|
|
user.IsAdmin = isAdmin != 0
|
|
users = append(users, user)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate users: %w", err)
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
func (s *Store) Authenticate(username, password string) (*User, error) {
|
|
var user User
|
|
var hash string
|
|
var isAdmin int
|
|
err := s.db.QueryRow(
|
|
`SELECT id, username, password_hash, is_admin FROM users WHERE username = ?`,
|
|
strings.TrimSpace(username),
|
|
).Scan(&user.ID, &user.Username, &hash, &isAdmin)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrInvalidCredentials
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load user: %w", err)
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
|
|
return nil, ErrInvalidCredentials
|
|
}
|
|
|
|
user.IsAdmin = isAdmin != 0
|
|
return &user, nil
|
|
}
|
|
|
|
func (s *Store) CreateSession(userID int64, ttl time.Duration) (string, error) {
|
|
tokenBytes := make([]byte, 32)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
return "", fmt.Errorf("generate session token: %w", err)
|
|
}
|
|
token := base64.RawURLEncoding.EncodeToString(tokenBytes)
|
|
expiresAt := time.Now().UTC().Add(ttl)
|
|
|
|
if _, err := s.db.Exec(
|
|
`INSERT INTO sessions (token, user_id, expires_at) VALUES (?, ?, ?)`,
|
|
token,
|
|
userID,
|
|
expiresAt.Format(time.RFC3339Nano),
|
|
); err != nil {
|
|
return "", fmt.Errorf("create session: %w", err)
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func (s *Store) GetSession(token string) (*Session, error) {
|
|
var session Session
|
|
var expiresAt string
|
|
var user User
|
|
var isAdmin int
|
|
err := s.db.QueryRow(
|
|
`SELECT s.token, s.user_id, s.expires_at, u.username, u.is_admin
|
|
FROM sessions s
|
|
JOIN users u ON u.id = s.user_id
|
|
WHERE s.token = ?`,
|
|
token,
|
|
).Scan(&session.Token, &session.UserID, &expiresAt, &user.Username, &isAdmin)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrSessionNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("load session: %w", err)
|
|
}
|
|
|
|
parsedExpiresAt, err := time.Parse(time.RFC3339Nano, expiresAt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse session expiry: %w", err)
|
|
}
|
|
if !parsedExpiresAt.After(time.Now().UTC()) {
|
|
_ = s.DeleteSession(token)
|
|
return nil, ErrSessionNotFound
|
|
}
|
|
|
|
user.ID = session.UserID
|
|
user.IsAdmin = isAdmin != 0
|
|
session.ExpiresAt = parsedExpiresAt
|
|
session.User = &user
|
|
return &session, nil
|
|
}
|
|
|
|
func (s *Store) DeleteSession(token string) error {
|
|
if _, err := s.db.Exec(`DELETE FROM sessions WHERE token = ?`, token); err != nil {
|
|
return fmt.Errorf("delete session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) PurgeExpiredSessions() error {
|
|
if _, err := s.db.Exec(`DELETE FROM sessions WHERE expires_at <= ?`, time.Now().UTC().Format(time.RFC3339Nano)); err != nil {
|
|
return fmt.Errorf("purge expired sessions: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|