feat: add encrypted local secret store
This commit is contained in:
parent
fe91784a8e
commit
7a7b3c7a3e
|
|
@ -0,0 +1,188 @@
|
|||
// Package secrets provides encrypted local storage for secret values.
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
keySize = 32
|
||||
nonceSize = 12
|
||||
)
|
||||
|
||||
// Store encrypts secret records before writing them to disk.
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
root string
|
||||
key []byte
|
||||
}
|
||||
|
||||
type encryptedRecord struct {
|
||||
Version int `json:"version"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type plaintextRecord struct {
|
||||
ID string `json:"id"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// NewStore creates an encrypted secret store rooted at root.
|
||||
func NewStore(root string, key []byte) (*Store, error) {
|
||||
if root == "" {
|
||||
return nil, fmt.Errorf("secret store root is empty")
|
||||
}
|
||||
if len(key) != keySize {
|
||||
return nil, fmt.Errorf("secret store key must be %d bytes", keySize)
|
||||
}
|
||||
|
||||
copiedKey := make([]byte, keySize)
|
||||
copy(copiedKey, key)
|
||||
return &Store{
|
||||
root: root,
|
||||
key: copiedKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Write encrypts and stores a secret value by ID.
|
||||
func (s *Store) Write(id, value string) error {
|
||||
if err := validateID(id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
plaintext, err := json.Marshal(plaintextRecord{ID: id, Value: value})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal secret: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, nonceSize)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
aead, err := s.aead()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
record := encryptedRecord{
|
||||
Version: 1,
|
||||
Nonce: nonce,
|
||||
Ciphertext: aead.Seal(nil, nonce, plaintext, []byte(id)),
|
||||
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
data, err := json.MarshalIndent(record, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal encrypted secret: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return atomicWrite0600(s.pathForID(id), data)
|
||||
}
|
||||
|
||||
// Read decrypts and returns a secret value by ID.
|
||||
func (s *Store) Read(id string) (string, error) {
|
||||
if err := validateID(id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
data, err := os.ReadFile(s.pathForID(id))
|
||||
s.mu.RUnlock()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read secret %q: %w", id, err)
|
||||
}
|
||||
|
||||
var record encryptedRecord
|
||||
if err := json.Unmarshal(data, &record); err != nil {
|
||||
return "", fmt.Errorf("decode encrypted secret %q: %w", id, err)
|
||||
}
|
||||
if record.Version != 1 {
|
||||
return "", fmt.Errorf("unsupported secret version %d", record.Version)
|
||||
}
|
||||
|
||||
aead, err := s.aead()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
plaintext, err := aead.Open(nil, record.Nonce, record.Ciphertext, []byte(id))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt secret %q: %w", id, err)
|
||||
}
|
||||
|
||||
var decoded plaintextRecord
|
||||
if err := json.Unmarshal(plaintext, &decoded); err != nil {
|
||||
return "", fmt.Errorf("decode secret %q: %w", id, err)
|
||||
}
|
||||
if decoded.ID != id {
|
||||
return "", fmt.Errorf("secret %q contains mismatched id", id)
|
||||
}
|
||||
return decoded.Value, nil
|
||||
}
|
||||
|
||||
func (s *Store) aead() (cipher.AEAD, error) {
|
||||
block, err := aes.NewCipher(s.key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create secret cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create secret gcm: %w", err)
|
||||
}
|
||||
return aead, nil
|
||||
}
|
||||
|
||||
func (s *Store) pathForID(id string) string {
|
||||
sum := sha256.Sum256([]byte(id))
|
||||
return filepath.Join(s.root, hex.EncodeToString(sum[:])+".json")
|
||||
}
|
||||
|
||||
func validateID(id string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("secret id is empty")
|
||||
}
|
||||
if len(id) > 256 {
|
||||
return fmt.Errorf("secret id is too long")
|
||||
}
|
||||
if id == "." || id == ".." {
|
||||
return fmt.Errorf("secret id %q is a path traversal reference", id)
|
||||
}
|
||||
if strings.ContainsAny(id, `/\`) {
|
||||
return fmt.Errorf("secret id %q contains path separators", id)
|
||||
}
|
||||
if filepath.Clean(id) != id {
|
||||
return fmt.Errorf("secret id %q contains path traversal", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func atomicWrite0600(path string, data []byte) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0o700); err != nil {
|
||||
return fmt.Errorf("create secret store dir: %w", err)
|
||||
}
|
||||
|
||||
tmpFile := filepath.Join(dir, fmt.Sprintf(".tmp.%d", time.Now().UnixNano()))
|
||||
if err := os.WriteFile(tmpFile, data, 0o600); err != nil {
|
||||
return fmt.Errorf("write secret temp file: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmpFile, path); err != nil {
|
||||
os.Remove(tmpFile)
|
||||
return fmt.Errorf("commit secret file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
package secrets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testKey(seed byte) []byte {
|
||||
key := make([]byte, 32)
|
||||
for i := range key {
|
||||
key[i] = seed
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func readAllSecretStoreBytes(t *testing.T, root string) []byte {
|
||||
t.Helper()
|
||||
|
||||
var data []byte
|
||||
if err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
fileData, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = append(data, fileData...)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("read secret store bytes: %v", err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestStoreRoundTripsSecretWithoutPlaintextOnDisk(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
store, err := NewStore(root, testKey(0x11))
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore: %v", err)
|
||||
}
|
||||
|
||||
const id = "server.password"
|
||||
const value = "s3cr3t-value"
|
||||
if err := store.Write(id, value); err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
|
||||
got, err := store.Read(id)
|
||||
if err != nil {
|
||||
t.Fatalf("Read: %v", err)
|
||||
}
|
||||
if got != value {
|
||||
t.Fatalf("Read = %q, want %q", got, value)
|
||||
}
|
||||
|
||||
raw := readAllSecretStoreBytes(t, root)
|
||||
if bytes.Contains(raw, []byte(value)) {
|
||||
t.Fatal("secret value is stored as plaintext")
|
||||
}
|
||||
if bytes.Contains(raw, []byte(id)) {
|
||||
t.Fatal("secret id is stored as plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRejectsWrongKey(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
store, err := NewStore(root, testKey(0x11))
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore: %v", err)
|
||||
}
|
||||
if err := store.Write("server.password", "s3cr3t-value"); err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
|
||||
wrongKeyStore, err := NewStore(root, testKey(0x22))
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore wrong key: %v", err)
|
||||
}
|
||||
if _, err := wrongKeyStore.Read("server.password"); err == nil {
|
||||
t.Fatal("Read with wrong key succeeded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRejectsUnsafeIDs(t *testing.T) {
|
||||
store, err := NewStore(t.TempDir(), testKey(0x11))
|
||||
if err != nil {
|
||||
t.Fatalf("NewStore: %v", err)
|
||||
}
|
||||
|
||||
for _, id := range []string{"", ".", "..", "../secret", `folder\secret`, strings.Repeat("a", 257)} {
|
||||
t.Run(id, func(t *testing.T) {
|
||||
if err := store.Write(id, "value"); err == nil {
|
||||
t.Fatalf("Write(%q): expected error", id)
|
||||
}
|
||||
if _, err := store.Read(id); err == nil {
|
||||
t.Fatalf("Read(%q): expected error", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue