sshkeeper/internal/vault/vault.go

596 lines
13 KiB
Go

package vault
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"os"
"sort"
"strings"
"sync"
"time"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/chacha20poly1305"
)
// Vault file format constants.
const (
	// currentVersion is the only VaultFile.Version that Unlock accepts.
	currentVersion = 1

	// Byte lengths of the random KDF salt, the XChaCha20-Poly1305 nonce, and
	// the derived master key.
	saltLen  = 32
	nonceLen = 24
	keyLen   = 32

	// The verifier record is sealed with verifierID as associated data and
	// must decrypt to verifierPlaintext for a master password to be accepted.
	verifierID        = "__sshkeeper_vault_verifier__"
	verifierPlaintext = "sshkeeper-vault-verifier-v1"
)
// KDFMeta records the password-based key-derivation settings stored in the
// vault file under "kdf". Unlock and VerifyPassword feed these values to
// argon2.IDKey to re-derive the master key from the password.
type KDFMeta struct {
	Name        string `json:"name"`        // algorithm label; code in this file writes "argon2id" (Unlock does not check it)
	MemoryKiB   int    `json:"memory_kib"`  // Argon2 memory cost in KiB
	Iterations  int    `json:"iterations"`  // Argon2 time cost (number of passes)
	Parallelism int    `json:"parallelism"` // Argon2 parallelism; converted to uint8 when deriving
	Salt        string `json:"salt"`        // random KDF salt, standard base64
}
// Record is one encrypted entry as stored in the vault file. The ciphertext is
// sealed with the record ID as AEAD associated data, so it only decrypts when
// paired with the same ID.
type Record struct {
	ID         string `json:"id"`         // secret identifier; server secrets use "server:<alias>:<type>"
	Type       string `json:"type"`       // secret type; "verifier" for the verifier record, may be empty
	Nonce      string `json:"nonce"`      // XChaCha20-Poly1305 nonce, standard base64
	Ciphertext string `json:"ciphertext"` // sealed plaintext (including the auth tag), standard base64
}
// VaultFile is the top-level JSON document stored at the vault path.
type VaultFile struct {
	Version  int      `json:"version"`            // format version; Unlock accepts only currentVersion
	KDF      KDFMeta  `json:"kdf"`                // how to re-derive the master key from the password
	Verifier *Record  `json:"verifier,omitempty"` // known-plaintext record for checking the password; may be absent
	Records  []Record `json:"records"`            // encrypted secrets
}
// Vault is an in-memory handle on a vault file. While unlocked it holds the
// derived master key and every decrypted secret. The methods in this file take
// mu before touching the mutable fields.
type Vault struct {
	mu        sync.Mutex
	path      string                  // location of the vault file on disk
	masterKey []byte                  // derived key; nil while locked
	records   map[string]secretRecord // decrypted secrets keyed by record ID
	modified  bool                    // set by Put/Delete, cleared by a successful Save
}
// secretRecord is one decrypted secret held in memory by an unlocked Vault.
type secretRecord struct {
	secretType string // type label; written back as Record.Type on save
	plaintext  []byte // decrypted secret value
}
// SecretMeta describes a stored server secret without exposing its value:
// the full record ID, the server alias parsed from it, and the secret type.
type SecretMeta struct {
	ID, Alias, Type string
}
// New returns a locked Vault bound to the vault file at path. The file is not
// read until Unlock is called.
func New(path string) *Vault {
	v := new(Vault)
	v.path = path
	v.records = map[string]secretRecord{}
	return v
}
// Exists reports whether path names a file with non-zero size. A path for
// which os.Stat fails, or a zero-length file, counts as absent.
func Exists(path string) bool {
	if info, err := os.Stat(path); err == nil {
		return info.Size() > 0
	}
	return false
}
// Create initializes a new, empty vault file at path protected by
// masterPassword.
//
// It generates a random salt and derives the master key with Argon2id
// (65536 KiB memory, 3 iterations, parallelism 1). It then writes a verifier
// record, so Unlock can reject a wrong password even while the vault holds no
// secrets. Any existing file at path is truncated and overwritten. Progress is
// printed to stdout.
func Create(path string, masterPassword string) error {
	salt := make([]byte, saltLen)
	if _, err := io.ReadFull(rand.Reader, salt); err != nil {
		return fmt.Errorf("generate salt: %w", err)
	}
	kdf := KDFMeta{
		Name:        "argon2id",
		MemoryKiB:   65536,
		Iterations:  3,
		Parallelism: 1,
		Salt:        base64.StdEncoding.EncodeToString(salt),
	}
	fmt.Print("Deriving key...")
	key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB), uint8(kdf.Parallelism), keyLen)
	// Wipe the key on every return path, not only on success.
	defer func() {
		for i := range key {
			key[i] = 0
		}
	}()
	verifier, err := newVerifierRecord(key)
	if err != nil {
		return fmt.Errorf("create verifier: %w", err)
	}
	vf := VaultFile{
		Version:  currentVersion,
		KDF:      kdf,
		Verifier: &verifier,
		Records:  []Record{},
	}
	data, err := json.Marshal(vf)
	if err != nil {
		return fmt.Errorf("marshal vault: %w", err)
	}
	f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("create vault file: %w", err)
	}
	if _, err := f.Write(data); err != nil {
		f.Close() // already failing; the write error is the one worth reporting
		return fmt.Errorf("write vault: %w", err)
	}
	// Flush to stable storage and surface Close errors. A failed Close can
	// mean the data never reached the file.
	if err := f.Sync(); err != nil {
		f.Close()
		return fmt.Errorf("sync vault: %w", err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("close vault: %w", err)
	}
	fmt.Println(" done.")
	return nil
}
// Unlock derives the master key from masterPassword using the KDF parameters
// stored in the vault file. It checks the key against the verifier record; a
// file without one is checked by trial-decrypting its first record. It then
// decrypts every record into memory, replacing any previously loaded state and
// discarding unsaved in-memory changes. Progress is printed to stdout.
//
// Unlock is all-or-nothing. On any failure the vault's previous state is left
// untouched, and the newly derived key and any plaintext decrypted so far are
// wiped. A failed Unlock therefore never leaves the vault looking unlocked
// with only some records loaded; a later Save would otherwise persist that
// partial set.
func (v *Vault) Unlock(masterPassword string) error {
	v.mu.Lock()
	defer v.mu.Unlock()
	fmt.Print("Unlocking vault...")
	data, err := os.ReadFile(v.path)
	if err != nil {
		return fmt.Errorf("read vault file: %w", err)
	}
	var vf VaultFile
	if err := json.Unmarshal(data, &vf); err != nil {
		return fmt.Errorf("parse vault: %w", err)
	}
	if vf.Version != currentVersion {
		return fmt.Errorf("unsupported vault version: %d", vf.Version)
	}
	// argon2.IDKey panics when the round count or parallelism is zero, and
	// negative or oversized values would wrap in the unsigned conversions
	// below. A damaged or tampered file must yield an error, not a crash.
	const maxUint32 = int64(^uint32(0))
	if p := vf.KDF; p.Iterations < 1 || p.MemoryKiB < 1 || p.Parallelism < 1 || p.Parallelism > 255 ||
		int64(p.Iterations) > maxUint32 || int64(p.MemoryKiB) > maxUint32 {
		return fmt.Errorf("invalid KDF parameters in vault file")
	}
	salt, err := base64.StdEncoding.DecodeString(vf.KDF.Salt)
	if err != nil {
		return fmt.Errorf("decode salt: %w", err)
	}
	key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)

	// Until the new state is committed below, key and every plaintext
	// decrypted with it belong to this call and must be wiped on failure.
	committed := false
	records := make(map[string]secretRecord, len(vf.Records))
	defer func() {
		if committed {
			return
		}
		for i := range key {
			key[i] = 0
		}
		for _, r := range records {
			for i := range r.plaintext {
				r.plaintext[i] = 0
			}
		}
	}()

	switch {
	case vf.Verifier != nil:
		if err := verifyRecord(key, *vf.Verifier); err != nil {
			return fmt.Errorf("invalid master password")
		}
	case len(vf.Records) > 0:
		// No verifier: a record that authenticates under key proves the password.
		if _, err := decryptRecord(key, vf.Records[0]); err != nil {
			return fmt.Errorf("invalid master password")
		}
	default:
		return fmt.Errorf("vault cannot verify master password; recreate empty vault")
	}

	for _, rec := range vf.Records {
		plaintext, err := decryptRecord(key, rec)
		if err != nil {
			return fmt.Errorf("decrypt record %s: %w", rec.ID, err)
		}
		records[rec.ID] = secretRecord{
			secretType: inferSecretType(rec.ID, rec.Type),
			plaintext:  plaintext,
		}
	}

	// Everything decrypted: wipe what a previous Unlock left behind, then
	// install the new key and records together.
	for i := range v.masterKey {
		v.masterKey[i] = 0
	}
	for _, r := range v.records {
		for i := range r.plaintext {
			r.plaintext[i] = 0
		}
	}
	v.masterKey = key
	v.records = records
	v.modified = false // memory now mirrors the file on disk
	committed = true
	fmt.Println(" done.")
	return nil
}
// Lock wipes the master key and every decrypted secret held in memory and
// leaves the vault locked, so IsUnlocked reports false. Unsaved changes are
// discarded. Calling Lock on an already-locked vault is harmless.
func (v *Vault) Lock() {
	v.mu.Lock()
	defer v.mu.Unlock()
	for i := range v.masterKey {
		v.masterKey[i] = 0
	}
	v.masterKey = nil
	// Put stores a copy and Get returns a copy, so these plaintext buffers are
	// referenced only by the vault and can be zeroed safely before being
	// dropped.
	for _, rec := range v.records {
		for i := range rec.plaintext {
			rec.plaintext[i] = 0
		}
	}
	v.records = make(map[string]secretRecord)
}
// IsUnlocked reports whether the vault currently holds a master key. Lock
// clears the key, after which this reports false.
func (v *Vault) IsUnlocked() bool {
	v.mu.Lock()
	unlocked := v.masterKey != nil
	v.mu.Unlock()
	return unlocked
}
// Put stores a private copy of plaintext under id with the given secret type,
// replacing any existing entry, and marks the vault as modified. The change
// lives only in memory until Save is called. Put fails if the vault is locked.
func (v *Vault) Put(id string, secretType string, plaintext []byte) error {
	v.mu.Lock()
	defer v.mu.Unlock()
	if v.masterKey != nil {
		v.records[id] = secretRecord{
			secretType: secretType,
			plaintext:  append(make([]byte, 0, len(plaintext)), plaintext...),
		}
		v.modified = true
		return nil
	}
	return fmt.Errorf("vault is locked")
}
// Get returns a copy of the secret stored under id, so callers may modify or
// wipe the result without affecting the vault. It fails if the vault is
// locked or no secret is stored under id.
func (v *Vault) Get(id string) ([]byte, error) {
	v.mu.Lock()
	defer v.mu.Unlock()
	switch rec, found := v.records[id]; {
	case v.masterKey == nil:
		return nil, fmt.Errorf("vault is locked")
	case !found:
		return nil, fmt.Errorf("secret not found: %s", id)
	default:
		return append(make([]byte, 0, len(rec.plaintext)), rec.plaintext...), nil
	}
}
// HasSecret reports whether a secret is stored in memory under id. It does not
// check whether the vault is unlocked; it only consults the in-memory record
// map.
func (v *Vault) HasSecret(id string) bool {
	v.mu.Lock()
	_, present := v.records[id]
	v.mu.Unlock()
	return present
}
// ListSecrets returns metadata for every stored secret whose ID has the form
// "server:<alias>:<type>", sorted by alias and then by type. Other IDs are
// skipped. Type is the record's stored type when set, otherwise the type
// segment of the ID. ListSecrets fails if the vault is locked.
func (v *Vault) ListSecrets() ([]SecretMeta, error) {
	v.mu.Lock()
	defer v.mu.Unlock()
	if v.masterKey == nil {
		return nil, fmt.Errorf("vault is locked")
	}
	out := make([]SecretMeta, 0, len(v.records))
	for id, rec := range v.records {
		alias, idType, ok := parseServerSecretID(id)
		if !ok {
			continue
		}
		kind := idType
		if rec.secretType != "" {
			kind = rec.secretType
		}
		out = append(out, SecretMeta{ID: id, Alias: alias, Type: kind})
	}
	sort.Slice(out, func(a, b int) bool {
		if out[a].Alias != out[b].Alias {
			return out[a].Alias < out[b].Alias
		}
		return out[a].Type < out[b].Type
	})
	return out, nil
}
// Delete removes the secret stored under id from memory and wipes its
// plaintext. The removal reaches disk only with a later Save. Deleting an id
// that is not present is a no-op and does not mark the vault as modified.
func (v *Vault) Delete(id string) {
	v.mu.Lock()
	defer v.mu.Unlock()
	rec, ok := v.records[id]
	if !ok {
		return
	}
	// The buffer is owned solely by the vault (Put copies in, Get copies out).
	for i := range rec.plaintext {
		rec.plaintext[i] = 0
	}
	delete(v.records, id)
	v.modified = true
}
// Save encrypts every in-memory secret under the current master key and
// atomically replaces the vault file: it writes "<path>.tmp", fsyncs it, and
// renames it over the original.
//
// The KDF parameters and salt are copied from the vault file currently on
// disk, because the in-memory master key was derived from those values. If
// they cannot be read, Save fails instead of writing KDF metadata that would
// not reproduce the key, which would lock the user out. A fresh verifier
// record is written with the secrets. Records are written in ID order so
// repeated saves are stable.
func (v *Vault) Save() error {
	v.mu.Lock()
	defer v.mu.Unlock()
	if v.masterKey == nil {
		return fmt.Errorf("vault is locked")
	}
	kdf, err := readKDFMeta(v.path)
	if err != nil {
		return fmt.Errorf("read KDF parameters: %w", err)
	}

	ids := make([]string, 0, len(v.records))
	for id := range v.records {
		ids = append(ids, id)
	}
	sort.Strings(ids)
	records := make([]Record, 0, len(ids))
	for _, id := range ids {
		rec := v.records[id]
		enc, err := encryptRecordWithType(v.masterKey, id, rec.secretType, rec.plaintext)
		if err != nil {
			return fmt.Errorf("encrypt record %s: %w", id, err)
		}
		records = append(records, enc)
	}
	verifier, err := newVerifierRecord(v.masterKey)
	if err != nil {
		return fmt.Errorf("create verifier: %w", err)
	}
	vf := VaultFile{
		Version:  currentVersion,
		KDF:      kdf,
		Verifier: &verifier,
		Records:  records,
	}
	data, err := json.Marshal(vf)
	if err != nil {
		return fmt.Errorf("marshal vault: %w", err)
	}
	if err := writeVaultAtomic(v.path, data); err != nil {
		return err
	}
	v.modified = false
	return nil
}

// readKDFMeta returns the KDF metadata recorded in the vault file at path. It
// fails if the file cannot be read or parsed, or if the salt is missing or is
// not valid base64.
func readKDFMeta(path string) (KDFMeta, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return KDFMeta{}, err
	}
	var vf VaultFile
	if err := json.Unmarshal(data, &vf); err != nil {
		return KDFMeta{}, err
	}
	if vf.KDF.Salt == "" {
		return KDFMeta{}, fmt.Errorf("vault file has no salt")
	}
	if _, err := base64.StdEncoding.DecodeString(vf.KDF.Salt); err != nil {
		return KDFMeta{}, fmt.Errorf("decode salt: %w", err)
	}
	return vf.KDF, nil
}

// writeVaultAtomic writes data to path+".tmp" with mode 0600, flushes it to
// stable storage, and renames it over path. The temp file is removed on
// failure.
func writeVaultAtomic(path string, data []byte) error {
	tmpPath := path + ".tmp"
	f, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("create temp vault: %w", err)
	}
	if _, err := f.Write(data); err != nil {
		f.Close()
		os.Remove(tmpPath)
		return fmt.Errorf("write vault: %w", err)
	}
	// Without Sync, a crash after the rename can leave an empty or truncated
	// file in place of the old vault.
	if err := f.Sync(); err != nil {
		f.Close()
		os.Remove(tmpPath)
		return fmt.Errorf("sync vault: %w", err)
	}
	if err := f.Close(); err != nil {
		os.Remove(tmpPath)
		return fmt.Errorf("close vault: %w", err)
	}
	if err := os.Rename(tmpPath, path); err != nil {
		os.Remove(tmpPath)
		return fmt.Errorf("rename vault: %w", err)
	}
	return nil
}
// ChangePassword re-encrypts the vault with a new master password
func (v *Vault) ChangePassword(newPassword string) error {
v.mu.Lock()
defer v.mu.Unlock()
if v.masterKey == nil {
return fmt.Errorf("vault is locked")
}
salt := make([]byte, saltLen)
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return fmt.Errorf("generate salt: %w", err)
}
newKey := argon2.IDKey([]byte(newPassword), salt, 3, 65536, 1, keyLen)
kdf := KDFMeta{
Name: "argon2id",
MemoryKiB: 65536,
Iterations: 3,
Parallelism: 1,
Salt: base64.StdEncoding.EncodeToString(salt),
}
fmt.Print("Deriving key...")
var records []Record
for id, record := range v.records {
rec, err := encryptRecordWithType(newKey, id, record.secretType, record.plaintext)
if err != nil {
return fmt.Errorf("encrypt record: %w", err)
}
records = append(records, rec)
}
verifier, err := newVerifierRecord(newKey)
if err != nil {
return fmt.Errorf("create verifier: %w", err)
}
vf := VaultFile{
Version: currentVersion,
KDF: kdf,
Verifier: &verifier,
Records: records,
}
data, err := json.Marshal(vf)
if err != nil {
return err
}
tmpPath := v.path + ".tmp"
f, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return err
}
if _, err := f.Write(data); err != nil {
f.Close()
os.Remove(tmpPath)
return err
}
f.Close()
if err := os.Rename(tmpPath, v.path); err != nil {
os.Remove(tmpPath)
return err
}
// Swap key
for i := range v.masterKey {
v.masterKey[i] = 0
}
v.masterKey = newKey
return nil
}
// getSalt returns the base64-encoded KDF salt recorded in the vault file at
// v.path, or "" if the file cannot be read or parsed. It reads only the file,
// not the in-memory state, and does not take mu.
func (v *Vault) getSalt() string {
	var vf VaultFile
	raw, err := os.ReadFile(v.path)
	if err == nil {
		err = json.Unmarshal(raw, &vf)
	}
	if err != nil {
		return ""
	}
	return vf.KDF.Salt
}
// encryptRecord seals plaintext for id under key and leaves the record's Type
// empty. It is shorthand for encryptRecordWithType with an empty type.
func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
	const untyped = ""
	return encryptRecordWithType(key, id, untyped, plaintext)
}
// encryptRecordWithType seals plaintext with XChaCha20-Poly1305 under key. It
// uses a fresh random nonceLen-byte nonce and the record ID as associated
// data, and returns a Record with the nonce and ciphertext in standard base64,
// tagged with secretType. The ID binding means the ciphertext only opens under
// the same ID.
func encryptRecordWithType(key []byte, id string, secretType string, plaintext []byte) (Record, error) {
	aead, err := chacha20poly1305.NewX(key)
	if err != nil {
		return Record{}, err
	}
	nonce := make([]byte, nonceLen)
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return Record{}, err
	}
	b64 := base64.StdEncoding
	return Record{
		ID:         id,
		Type:       secretType,
		Nonce:      b64.EncodeToString(nonce),
		Ciphertext: b64.EncodeToString(aead.Seal(nil, nonce, plaintext, []byte(id))),
	}, nil
}
// inferSecretType chooses the in-memory type for a decrypted record. It
// returns the stored recordType when set, otherwise the type segment of a
// "server:<alias>:<type>" ID, otherwise "".
func inferSecretType(id string, recordType string) string {
	if recordType == "" {
		if _, fromID, ok := parseServerSecretID(id); ok {
			return fromID
		}
	}
	return recordType
}
func parseServerSecretID(id string) (string, string, bool) {
parts := strings.Split(id, ":")
if len(parts) != 3 || parts[0] != "server" || parts[1] == "" || parts[2] == "" {
return "", "", false
}
return parts[1], parts[2], true
}
// decryptRecord opens rec with XChaCha20-Poly1305 under key, using rec.ID as
// associated data, and returns the plaintext. It returns an error if:
//   - the nonce or ciphertext is not valid standard base64,
//   - the nonce has the wrong length, or
//   - authentication fails (wrong key, tampered data, or an ID different from
//     the one used at encryption).
func decryptRecord(key []byte, rec Record) ([]byte, error) {
	aead, err := chacha20poly1305.NewX(key)
	if err != nil {
		return nil, err
	}
	nonce, err := base64.StdEncoding.DecodeString(rec.Nonce)
	if err != nil {
		return nil, fmt.Errorf("decode nonce: %w", err)
	}
	// AEAD.Open panics on a nonce of the wrong size. The nonce comes from the
	// vault file, so a corrupt or tampered file must produce an error, not a
	// crash.
	if len(nonce) != aead.NonceSize() {
		return nil, fmt.Errorf("invalid nonce length %d", len(nonce))
	}
	ciphertext, err := base64.StdEncoding.DecodeString(rec.Ciphertext)
	if err != nil {
		return nil, fmt.Errorf("decode ciphertext: %w", err)
	}
	plaintext, err := aead.Open(nil, nonce, ciphertext, []byte(rec.ID))
	if err != nil {
		return nil, fmt.Errorf("decrypt failed: %w", err)
	}
	return plaintext, nil
}
// newVerifierRecord seals the fixed verifierPlaintext under key, with
// verifierID as associated data, and tags it with Type "verifier". Unlock and
// VerifyPassword later decrypt it to check a candidate master password.
func newVerifierRecord(key []byte) (Record, error) {
	return encryptRecordWithType(key, verifierID, "verifier", []byte(verifierPlaintext))
}
// verifyRecord returns nil if rec decrypts under key to verifierPlaintext. It
// returns the decryption error, or an "invalid verifier" error if the
// plaintext differs. The plaintext comparison is constant-time.
func verifyRecord(key []byte, rec Record) error {
	got, err := decryptRecord(key, rec)
	switch {
	case err != nil:
		return err
	case !SecureCompare(string(got), verifierPlaintext):
		return fmt.Errorf("invalid verifier")
	}
	return nil
}
// VerifyPassword reports whether masterPassword unlocks the vault file at
// path, without loading any secrets into memory. It checks the verifier record
// when present; otherwise it trial-decrypts the first record. A file with
// neither yields (false, nil).
//
// It returns an error for an unreadable or unparseable file, an unsupported
// version, or invalid KDF parameters. Unlock rejects the same files before
// trying the password.
func VerifyPassword(path string, masterPassword string) (bool, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return false, err
	}
	var vf VaultFile
	if err := json.Unmarshal(data, &vf); err != nil {
		return false, err
	}
	if vf.Version != currentVersion {
		return false, fmt.Errorf("unsupported vault version: %d", vf.Version)
	}
	// argon2.IDKey panics on zero rounds or zero parallelism, and negative or
	// oversized values would wrap in the unsigned conversions below.
	const maxUint32 = int64(^uint32(0))
	if p := vf.KDF; p.Iterations < 1 || p.MemoryKiB < 1 || p.Parallelism < 1 || p.Parallelism > 255 ||
		int64(p.Iterations) > maxUint32 || int64(p.MemoryKiB) > maxUint32 {
		return false, fmt.Errorf("invalid KDF parameters in vault file")
	}
	salt, err := base64.StdEncoding.DecodeString(vf.KDF.Salt)
	if err != nil {
		return false, err
	}
	key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
	defer func() {
		for i := range key {
			key[i] = 0
		}
	}()
	if vf.Verifier != nil {
		return verifyRecord(key, *vf.Verifier) == nil, nil
	}
	if len(vf.Records) == 0 {
		return false, nil
	}
	_, err = decryptRecord(key, vf.Records[0])
	return err == nil, nil
}
// SecureCompare reports whether a and b are equal. The contents are compared
// in constant time, so timing does not reveal how much of a secret matched.
// Lengths are not hidden: strings of different length compare unequal
// immediately.
func SecureCompare(a, b string) bool {
	x, y := []byte(a), []byte(b)
	return subtle.ConstantTimeCompare(x, y) != 0
}
// Blank reference that keeps the "time" import in use so the file compiles.
// NOTE(review): nothing in the visible part of this file uses the time
// package. If that holds for the whole file, remove both this line and the
// import instead.
var _ time.Duration