68 lines
1.2 KiB
Go
68 lines
1.2 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// migrationsFS holds the SQL migration scripts (files matching
// migrations/*.sql) compiled into the binary. (*DB).migrate reads the
// "migrations" directory from it and executes each script.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS
|
|
|
|
type DB struct {
|
|
conn *sql.DB
|
|
}
|
|
|
|
// Open opens the SQLite database file "sshkeeper.db" inside dataDir, creating
// it if needed, checks that a connection can be established, and applies the
// embedded migrations.
//
// The file is created with mode 0600, and an existing file is narrowed to
// 0600, before SQLite opens it. The database is therefore never readable by
// other users, not even briefly. On Unix, SQLite creates journal/WAL files with
// the main database file's mode, so those are covered as well. dataDir must
// already exist; Open does not create it.
//
// If any step fails, the connection pool (when one was opened) is closed and a
// wrapped error is returned.
func Open(dataDir string) (*DB, error) {
	dbPath := filepath.Join(dataDir, "sshkeeper.db")

	// Fix the permissions before the driver ever touches the file, instead of
	// chmod-ing afterwards and leaving a window with umask-default permissions.
	if err := ensurePrivateFile(dbPath); err != nil {
		return nil, err
	}

	conn, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, fmt.Errorf("open database: %w", err)
	}

	if err := conn.Ping(); err != nil {
		// Best effort: the ping error is the one worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("ping database: %w", err)
	}

	db := &DB{conn: conn}
	if err := db.migrate(); err != nil {
		// Best effort: the migration error is the one worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("migrate: %w", err)
	}

	return db, nil
}

// ensurePrivateFile makes sure path exists and has owner-only (0600) permissions.
func ensurePrivateFile(path string) error {
	// The mode passed to OpenFile only applies when the file is created, and
	// the umask can only remove bits from it, never add them.
	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0600)
	if err != nil {
		return fmt.Errorf("create database file: %w", err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("create database file: %w", err)
	}
	// Tighten files that already existed with looser permissions.
	if err := os.Chmod(path, 0600); err != nil {
		return fmt.Errorf("restrict database permissions: %w", err)
	}
	return nil
}
|
|
|
|
func (db *DB) Close() error {
|
|
return db.conn.Close()
|
|
}
|
|
|
|
func (db *DB) migrate() error {
|
|
entries, err := migrationsFS.ReadDir("migrations")
|
|
if err != nil {
|
|
return fmt.Errorf("read migrations dir: %w", err)
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
content, err := migrationsFS.ReadFile("migrations/" + entry.Name())
|
|
if err != nil {
|
|
return fmt.Errorf("read migration %s: %w", entry.Name(), err)
|
|
}
|
|
if _, err := db.conn.Exec(string(content)); err != nil {
|
|
return fmt.Errorf("exec migration %s: %w", entry.Name(), err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|