diff --git a/internal/core/secrets/store.go b/internal/core/secrets/store.go new file mode 100644 index 0000000..e895922 --- /dev/null +++ b/internal/core/secrets/store.go @@ -0,0 +1,188 @@ +// Package secrets provides encrypted local storage for secret values. +package secrets + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +const ( + keySize = 32 + nonceSize = 12 +) + +// Store encrypts secret records before writing them to disk. +type Store struct { + mu sync.RWMutex + root string + key []byte +} + +type encryptedRecord struct { + Version int `json:"version"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` + UpdatedAt string `json:"updatedAt"` +} + +type plaintextRecord struct { + ID string `json:"id"` + Value string `json:"value"` +} + +// NewStore creates an encrypted secret store rooted at root. +func NewStore(root string, key []byte) (*Store, error) { + if root == "" { + return nil, fmt.Errorf("secret store root is empty") + } + if len(key) != keySize { + return nil, fmt.Errorf("secret store key must be %d bytes", keySize) + } + + copiedKey := make([]byte, keySize) + copy(copiedKey, key) + return &Store{ + root: root, + key: copiedKey, + }, nil +} + +// Write encrypts and stores a secret value by ID. +func (s *Store) Write(id, value string) error { + if err := validateID(id); err != nil { + return err + } + + plaintext, err := json.Marshal(plaintextRecord{ID: id, Value: value}) + if err != nil { + return fmt.Errorf("marshal secret: %w", err) + } + + nonce := make([]byte, nonceSize) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return fmt.Errorf("generate nonce: %w", err) + } + + aead, err := s.aead() + if err != nil { + return err + } + record := encryptedRecord{ + Version: 1, + Nonce: nonce, + Ciphertext: aead.Seal(nil, nonce, plaintext, []byte(id)), + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + } + data, err := json.MarshalIndent(record, "", " ") + if err != nil { + return fmt.Errorf("marshal encrypted secret: %w", err) + } + + s.mu.Lock() + defer s.mu.Unlock() + return atomicWrite0600(s.pathForID(id), data) +} + +// Read decrypts and returns a secret value by ID. +func (s *Store) Read(id string) (string, error) { + if err := validateID(id); err != nil { + return "", err + } + + s.mu.RLock() + data, err := os.ReadFile(s.pathForID(id)) + s.mu.RUnlock() + if err != nil { + return "", fmt.Errorf("read secret %q: %w", id, err) + } + + var record encryptedRecord + if err := json.Unmarshal(data, &record); err != nil { + return "", fmt.Errorf("decode encrypted secret %q: %w", id, err) + } + if record.Version != 1 { + return "", fmt.Errorf("unsupported secret version %d", record.Version) + } + + aead, err := s.aead() + if err != nil { + return "", err + } + plaintext, err := aead.Open(nil, record.Nonce, record.Ciphertext, []byte(id)) + if err != nil { + return "", fmt.Errorf("decrypt secret %q: %w", id, err) + } + + var decoded plaintextRecord + if err := json.Unmarshal(plaintext, &decoded); err != nil { + return "", fmt.Errorf("decode secret %q: %w", id, err) + } + if decoded.ID != id { + return "", fmt.Errorf("secret %q contains mismatched id", id) + } + return decoded.Value, nil +} + +func (s *Store) aead() (cipher.AEAD, error) { + block, err := aes.NewCipher(s.key) + if err != nil { + return nil, fmt.Errorf("create secret cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("create secret gcm: %w", err) + } + return aead, nil +} + +func (s *Store) pathForID(id string) string { + sum := sha256.Sum256([]byte(id)) + return filepath.Join(s.root, hex.EncodeToString(sum[:])+".json") +} + +func validateID(id string) error { + if id == "" { + return fmt.Errorf("secret id is empty") + } + if len(id) > 256 { + return fmt.Errorf("secret id is too long") + } + if id == "." || id == ".." { + return fmt.Errorf("secret id %q is a path traversal reference", id) + } + if strings.ContainsAny(id, `/\`) { + return fmt.Errorf("secret id %q contains path separators", id) + } + if filepath.Clean(id) != id { + return fmt.Errorf("secret id %q contains path traversal", id) + } + return nil +} + +func atomicWrite0600(path string, data []byte) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0o700); err != nil { + return fmt.Errorf("create secret store dir: %w", err) + } + + tmpFile := filepath.Join(dir, fmt.Sprintf(".tmp.%d", time.Now().UnixNano())) + if err := os.WriteFile(tmpFile, data, 0o600); err != nil { + return fmt.Errorf("write secret temp file: %w", err) + } + if err := os.Rename(tmpFile, path); err != nil { + os.Remove(tmpFile) + return fmt.Errorf("commit secret file: %w", err) + } + return nil +} diff --git a/internal/core/secrets/store_test.go b/internal/core/secrets/store_test.go new file mode 100644 index 0000000..ab62ca6 --- /dev/null +++ b/internal/core/secrets/store_test.go @@ -0,0 +1,107 @@ +package secrets + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" +) + +func testKey(seed byte) []byte { + key := make([]byte, 32) + for i := range key { + key[i] = seed + } + return key +} + +func readAllSecretStoreBytes(t *testing.T, root string) []byte { + t.Helper() + + var data []byte + if err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + fileData, err := os.ReadFile(path) + if err != nil { + return err + } + data = append(data, fileData...) + return nil + }); err != nil { + t.Fatalf("read secret store bytes: %v", err) + } + return data +} + +func TestStoreRoundTripsSecretWithoutPlaintextOnDisk(t *testing.T) { + root := t.TempDir() + store, err := NewStore(root, testKey(0x11)) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + const id = "server.password" + const value = "s3cr3t-value" + if err := store.Write(id, value); err != nil { + t.Fatalf("Write: %v", err) + } + + got, err := store.Read(id) + if err != nil { + t.Fatalf("Read: %v", err) + } + if got != value { + t.Fatalf("Read = %q, want %q", got, value) + } + + raw := readAllSecretStoreBytes(t, root) + if bytes.Contains(raw, []byte(value)) { + t.Fatal("secret value is stored as plaintext") + } + if bytes.Contains(raw, []byte(id)) { + t.Fatal("secret id is stored as plaintext") + } +} + +func TestStoreRejectsWrongKey(t *testing.T) { + root := t.TempDir() + store, err := NewStore(root, testKey(0x11)) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + if err := store.Write("server.password", "s3cr3t-value"); err != nil { + t.Fatalf("Write: %v", err) + } + + wrongKeyStore, err := NewStore(root, testKey(0x22)) + if err != nil { + t.Fatalf("NewStore wrong key: %v", err) + } + if _, err := wrongKeyStore.Read("server.password"); err == nil { + t.Fatal("Read with wrong key succeeded") + } +} + +func TestStoreRejectsUnsafeIDs(t *testing.T) { + store, err := NewStore(t.TempDir(), testKey(0x11)) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + + for _, id := range []string{"", ".", "..", "../secret", `folder\secret`, strings.Repeat("a", 257)} { + t.Run(id, func(t *testing.T) { + if err := store.Write(id, "value"); err == nil { + t.Fatalf("Write(%q): expected error", id) + } + if _, err := store.Read(id); err == nil { + t.Fatalf("Read(%q): expected error", id) + } + }) + } +}