596 lines
13 KiB
Go
596 lines
13 KiB
Go
package vault
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/argon2"
|
|
"golang.org/x/crypto/chacha20poly1305"
|
|
)
|
|
|
|
const (
|
|
currentVersion = 1
|
|
saltLen = 32
|
|
nonceLen = 24
|
|
keyLen = 32
|
|
verifierID = "__sshkeeper_vault_verifier__"
|
|
verifierPlaintext = "sshkeeper-vault-verifier-v1"
|
|
)
|
|
|
|
type KDFMeta struct {
|
|
Name string `json:"name"`
|
|
MemoryKiB int `json:"memory_kib"`
|
|
Iterations int `json:"iterations"`
|
|
Parallelism int `json:"parallelism"`
|
|
Salt string `json:"salt"`
|
|
}
|
|
|
|
type Record struct {
|
|
ID string `json:"id"`
|
|
Type string `json:"type"`
|
|
Nonce string `json:"nonce"`
|
|
Ciphertext string `json:"ciphertext"`
|
|
}
|
|
|
|
type VaultFile struct {
|
|
Version int `json:"version"`
|
|
KDF KDFMeta `json:"kdf"`
|
|
Verifier *Record `json:"verifier,omitempty"`
|
|
Records []Record `json:"records"`
|
|
}
|
|
|
|
type Vault struct {
|
|
mu sync.Mutex
|
|
path string
|
|
masterKey []byte
|
|
records map[string]secretRecord
|
|
modified bool
|
|
}
|
|
|
|
type secretRecord struct {
|
|
secretType string
|
|
plaintext []byte
|
|
}
|
|
|
|
type SecretMeta struct {
|
|
ID string
|
|
Alias string
|
|
Type string
|
|
}
|
|
|
|
func New(path string) *Vault {
|
|
return &Vault{
|
|
path: path,
|
|
records: make(map[string]secretRecord),
|
|
}
|
|
}
|
|
|
|
// Exists checks if vault file exists and has content
|
|
func Exists(path string) bool {
|
|
info, err := os.Stat(path)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return info.Size() > 0
|
|
}
|
|
|
|
// Create initializes a new vault with a master password
|
|
func Create(path string, masterPassword string) error {
|
|
salt := make([]byte, saltLen)
|
|
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
|
return fmt.Errorf("generate salt: %w", err)
|
|
}
|
|
|
|
kdf := KDFMeta{
|
|
Name: "argon2id",
|
|
MemoryKiB: 65536,
|
|
Iterations: 3,
|
|
Parallelism: 1,
|
|
Salt: base64.StdEncoding.EncodeToString(salt),
|
|
}
|
|
|
|
fmt.Print("Deriving key...")
|
|
|
|
key := argon2.IDKey([]byte(masterPassword), salt, uint32(kdf.Iterations), uint32(kdf.MemoryKiB), uint8(kdf.Parallelism), keyLen)
|
|
|
|
verifier, err := newVerifierRecord(key)
|
|
if err != nil {
|
|
return fmt.Errorf("create verifier: %w", err)
|
|
}
|
|
|
|
vf := VaultFile{
|
|
Version: currentVersion,
|
|
KDF: kdf,
|
|
Verifier: &verifier,
|
|
Records: []Record{},
|
|
}
|
|
|
|
data, err := json.Marshal(vf)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal vault: %w", err)
|
|
}
|
|
|
|
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
return fmt.Errorf("create vault file: %w", err)
|
|
}
|
|
defer f.Close()
|
|
|
|
if _, err := f.Write(data); err != nil {
|
|
return fmt.Errorf("write vault: %w", err)
|
|
}
|
|
|
|
// Clear key from memory
|
|
for i := range key {
|
|
key[i] = 0
|
|
}
|
|
|
|
fmt.Println(" done.")
|
|
return nil
|
|
}
|
|
|
|
// Unlock decrypts the vault with master password
|
|
func (v *Vault) Unlock(masterPassword string) error {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
|
|
fmt.Print("Unlocking vault...")
|
|
|
|
data, err := os.ReadFile(v.path)
|
|
if err != nil {
|
|
return fmt.Errorf("read vault file: %w", err)
|
|
}
|
|
|
|
var vf VaultFile
|
|
if err := json.Unmarshal(data, &vf); err != nil {
|
|
return fmt.Errorf("parse vault: %w", err)
|
|
}
|
|
|
|
if vf.Version != currentVersion {
|
|
return fmt.Errorf("unsupported vault version: %d", vf.Version)
|
|
}
|
|
|
|
salt, err := base64.StdEncoding.DecodeString(vf.KDF.Salt)
|
|
if err != nil {
|
|
return fmt.Errorf("decode salt: %w", err)
|
|
}
|
|
|
|
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
|
|
|
|
if vf.Verifier != nil {
|
|
if err := verifyRecord(key, *vf.Verifier); err != nil {
|
|
return fmt.Errorf("invalid master password")
|
|
}
|
|
} else if len(vf.Records) > 0 {
|
|
if _, err := decryptRecord(key, vf.Records[0]); err != nil {
|
|
return fmt.Errorf("invalid master password")
|
|
}
|
|
} else {
|
|
return fmt.Errorf("vault cannot verify master password; recreate empty vault")
|
|
}
|
|
|
|
v.masterKey = key
|
|
v.records = make(map[string]secretRecord)
|
|
|
|
for _, rec := range vf.Records {
|
|
plaintext, err := decryptRecord(key, rec)
|
|
if err != nil {
|
|
return fmt.Errorf("decrypt record %s: %w", rec.ID, err)
|
|
}
|
|
v.records[rec.ID] = secretRecord{
|
|
secretType: inferSecretType(rec.ID, rec.Type),
|
|
plaintext: plaintext,
|
|
}
|
|
}
|
|
|
|
fmt.Println(" done.")
|
|
return nil
|
|
}
|
|
|
|
// Lock clears the master key and records from memory
|
|
func (v *Vault) Lock() {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
|
|
if v.masterKey != nil {
|
|
for i := range v.masterKey {
|
|
v.masterKey[i] = 0
|
|
}
|
|
}
|
|
v.masterKey = nil
|
|
v.records = make(map[string]secretRecord)
|
|
}
|
|
|
|
// IsUnlocked returns whether the vault is currently unlocked
|
|
func (v *Vault) IsUnlocked() bool {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
return v.masterKey != nil
|
|
}
|
|
|
|
// Put stores a secret in memory (not persisted until Save)
|
|
func (v *Vault) Put(id string, secretType string, plaintext []byte) error {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
|
|
if v.masterKey == nil {
|
|
return fmt.Errorf("vault is locked")
|
|
}
|
|
|
|
data := make([]byte, len(plaintext))
|
|
copy(data, plaintext)
|
|
v.records[id] = secretRecord{secretType: secretType, plaintext: data}
|
|
v.modified = true
|
|
return nil
|
|
}
|
|
|
|
// Get retrieves a secret
|
|
func (v *Vault) Get(id string) ([]byte, error) {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
|
|
if v.masterKey == nil {
|
|
return nil, fmt.Errorf("vault is locked")
|
|
}
|
|
|
|
record, ok := v.records[id]
|
|
if !ok {
|
|
return nil, fmt.Errorf("secret not found: %s", id)
|
|
}
|
|
|
|
result := make([]byte, len(record.plaintext))
|
|
copy(result, record.plaintext)
|
|
return result, nil
|
|
}
|
|
|
|
func (v *Vault) HasSecret(id string) bool {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
_, ok := v.records[id]
|
|
return ok
|
|
}
|
|
|
|
func (v *Vault) ListSecrets() ([]SecretMeta, error) {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
|
|
if v.masterKey == nil {
|
|
return nil, fmt.Errorf("vault is locked")
|
|
}
|
|
|
|
metas := make([]SecretMeta, 0, len(v.records))
|
|
for id, record := range v.records {
|
|
alias, secretType, ok := parseServerSecretID(id)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if record.secretType != "" {
|
|
secretType = record.secretType
|
|
}
|
|
metas = append(metas, SecretMeta{
|
|
ID: id,
|
|
Alias: alias,
|
|
Type: secretType,
|
|
})
|
|
}
|
|
sort.Slice(metas, func(i, j int) bool {
|
|
if metas[i].Alias == metas[j].Alias {
|
|
return metas[i].Type < metas[j].Type
|
|
}
|
|
return metas[i].Alias < metas[j].Alias
|
|
})
|
|
return metas, nil
|
|
}
|
|
|
|
// Delete removes a secret
|
|
func (v *Vault) Delete(id string) {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
delete(v.records, id)
|
|
v.modified = true
|
|
}
|
|
|
|
// Save persists encrypted vault to disk
|
|
func (v *Vault) Save() error {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
|
|
if v.masterKey == nil {
|
|
return fmt.Errorf("vault is locked")
|
|
}
|
|
|
|
salt, err := base64.StdEncoding.DecodeString(v.getSalt())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
kdf := KDFMeta{
|
|
Name: "argon2id",
|
|
MemoryKiB: 65536,
|
|
Iterations: 3,
|
|
Parallelism: 1,
|
|
Salt: base64.StdEncoding.EncodeToString(salt),
|
|
}
|
|
|
|
fmt.Print("Deriving key...")
|
|
|
|
var records []Record
|
|
for id, record := range v.records {
|
|
rec, err := encryptRecordWithType(v.masterKey, id, record.secretType, record.plaintext)
|
|
if err != nil {
|
|
return fmt.Errorf("encrypt record %s: %w", id, err)
|
|
}
|
|
records = append(records, rec)
|
|
}
|
|
|
|
verifier, err := newVerifierRecord(v.masterKey)
|
|
if err != nil {
|
|
return fmt.Errorf("create verifier: %w", err)
|
|
}
|
|
|
|
vf := VaultFile{
|
|
Version: currentVersion,
|
|
KDF: kdf,
|
|
Verifier: &verifier,
|
|
Records: records,
|
|
}
|
|
|
|
data, err := json.Marshal(vf)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal vault: %w", err)
|
|
}
|
|
|
|
tmpPath := v.path + ".tmp"
|
|
f, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
return fmt.Errorf("create temp vault: %w", err)
|
|
}
|
|
|
|
if _, err := f.Write(data); err != nil {
|
|
f.Close()
|
|
os.Remove(tmpPath)
|
|
return fmt.Errorf("write vault: %w", err)
|
|
}
|
|
f.Close()
|
|
|
|
if err := os.Rename(tmpPath, v.path); err != nil {
|
|
os.Remove(tmpPath)
|
|
return fmt.Errorf("rename vault: %w", err)
|
|
}
|
|
|
|
v.modified = false
|
|
return nil
|
|
}
|
|
|
|
// ChangePassword re-encrypts the vault with a new master password
|
|
func (v *Vault) ChangePassword(newPassword string) error {
|
|
v.mu.Lock()
|
|
defer v.mu.Unlock()
|
|
|
|
if v.masterKey == nil {
|
|
return fmt.Errorf("vault is locked")
|
|
}
|
|
|
|
salt := make([]byte, saltLen)
|
|
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
|
return fmt.Errorf("generate salt: %w", err)
|
|
}
|
|
|
|
newKey := argon2.IDKey([]byte(newPassword), salt, 3, 65536, 1, keyLen)
|
|
|
|
kdf := KDFMeta{
|
|
Name: "argon2id",
|
|
MemoryKiB: 65536,
|
|
Iterations: 3,
|
|
Parallelism: 1,
|
|
Salt: base64.StdEncoding.EncodeToString(salt),
|
|
}
|
|
|
|
fmt.Print("Deriving key...")
|
|
|
|
var records []Record
|
|
for id, record := range v.records {
|
|
rec, err := encryptRecordWithType(newKey, id, record.secretType, record.plaintext)
|
|
if err != nil {
|
|
return fmt.Errorf("encrypt record: %w", err)
|
|
}
|
|
records = append(records, rec)
|
|
}
|
|
|
|
verifier, err := newVerifierRecord(newKey)
|
|
if err != nil {
|
|
return fmt.Errorf("create verifier: %w", err)
|
|
}
|
|
|
|
vf := VaultFile{
|
|
Version: currentVersion,
|
|
KDF: kdf,
|
|
Verifier: &verifier,
|
|
Records: records,
|
|
}
|
|
|
|
data, err := json.Marshal(vf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tmpPath := v.path + ".tmp"
|
|
f, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := f.Write(data); err != nil {
|
|
f.Close()
|
|
os.Remove(tmpPath)
|
|
return err
|
|
}
|
|
f.Close()
|
|
|
|
if err := os.Rename(tmpPath, v.path); err != nil {
|
|
os.Remove(tmpPath)
|
|
return err
|
|
}
|
|
|
|
// Swap key
|
|
for i := range v.masterKey {
|
|
v.masterKey[i] = 0
|
|
}
|
|
v.masterKey = newKey
|
|
|
|
return nil
|
|
}
|
|
|
|
// Helper to get salt from existing vault
|
|
func (v *Vault) getSalt() string {
|
|
data, err := os.ReadFile(v.path)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
var vf VaultFile
|
|
if err := json.Unmarshal(data, &vf); err != nil {
|
|
return ""
|
|
}
|
|
return vf.KDF.Salt
|
|
}
|
|
|
|
func encryptRecord(key []byte, id string, plaintext []byte) (Record, error) {
|
|
return encryptRecordWithType(key, id, "", plaintext)
|
|
}
|
|
|
|
func encryptRecordWithType(key []byte, id string, secretType string, plaintext []byte) (Record, error) {
|
|
aead, err := chacha20poly1305.NewX(key)
|
|
if err != nil {
|
|
return Record{}, err
|
|
}
|
|
|
|
nonce := make([]byte, nonceLen)
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return Record{}, err
|
|
}
|
|
|
|
ciphertext := aead.Seal(nil, nonce, plaintext, []byte(id))
|
|
|
|
return Record{
|
|
ID: id,
|
|
Type: secretType,
|
|
Nonce: base64.StdEncoding.EncodeToString(nonce),
|
|
Ciphertext: base64.StdEncoding.EncodeToString(ciphertext),
|
|
}, nil
|
|
}
|
|
|
|
func inferSecretType(id string, recordType string) string {
|
|
if recordType != "" {
|
|
return recordType
|
|
}
|
|
_, secretType, ok := parseServerSecretID(id)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
return secretType
|
|
}
|
|
|
|
func parseServerSecretID(id string) (string, string, bool) {
|
|
parts := strings.Split(id, ":")
|
|
if len(parts) != 3 || parts[0] != "server" || parts[1] == "" || parts[2] == "" {
|
|
return "", "", false
|
|
}
|
|
return parts[1], parts[2], true
|
|
}
|
|
|
|
func decryptRecord(key []byte, rec Record) ([]byte, error) {
|
|
aead, err := chacha20poly1305.NewX(key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
nonce, err := base64.StdEncoding.DecodeString(rec.Nonce)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode nonce: %w", err)
|
|
}
|
|
|
|
ciphertext, err := base64.StdEncoding.DecodeString(rec.Ciphertext)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode ciphertext: %w", err)
|
|
}
|
|
|
|
plaintext, err := aead.Open(nil, nonce, ciphertext, []byte(rec.ID))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decrypt failed: %w", err)
|
|
}
|
|
|
|
return plaintext, nil
|
|
}
|
|
|
|
func newVerifierRecord(key []byte) (Record, error) {
|
|
rec, err := encryptRecord(key, verifierID, []byte(verifierPlaintext))
|
|
if err != nil {
|
|
return Record{}, err
|
|
}
|
|
rec.Type = "verifier"
|
|
return rec, nil
|
|
}
|
|
|
|
func verifyRecord(key []byte, rec Record) error {
|
|
plaintext, err := decryptRecord(key, rec)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !SecureCompare(string(plaintext), verifierPlaintext) {
|
|
return fmt.Errorf("invalid verifier")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// VerifyPassword checks if a master password is correct without unlocking
|
|
func VerifyPassword(path string, masterPassword string) (bool, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
var vf VaultFile
|
|
if err := json.Unmarshal(data, &vf); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
salt, err := base64.StdEncoding.DecodeString(vf.KDF.Salt)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
key := argon2.IDKey([]byte(masterPassword), salt, uint32(vf.KDF.Iterations), uint32(vf.KDF.MemoryKiB), uint8(vf.KDF.Parallelism), keyLen)
|
|
defer func() {
|
|
for i := range key {
|
|
key[i] = 0
|
|
}
|
|
}()
|
|
|
|
if vf.Verifier != nil {
|
|
return verifyRecord(key, *vf.Verifier) == nil, nil
|
|
}
|
|
|
|
if len(vf.Records) == 0 {
|
|
return false, nil
|
|
}
|
|
_, err = decryptRecord(key, vf.Records[0])
|
|
return err == nil, nil
|
|
}
|
|
|
|
// Constant-time comparison to prevent timing attacks
|
|
func SecureCompare(a, b string) bool {
|
|
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
|
}
|
|
|
|
// Ensure time import is used
|
|
var _ time.Duration
|