package web
import (
"fmt"
"html/template"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"testing"
"envault/vault"
)
// newTestServer builds a *Server suitable for handler tests.
//
// The vault is opened with the password "test-password" at a path inside a
// fresh temporary directory, the embedded templates/index.html is parsed with
// a "dict" template helper registered, and VaultPath points at "vault.json"
// inside a second, separate temporary directory. Any setup failure aborts the
// calling test via t.Fatalf.
func newTestServer(t *testing.T) *Server {
	t.Helper()

	const password = "test-password"

	vaultFile := filepath.Join(t.TempDir(), "missing-vault.json")
	opened, openErr := vault.Open(vaultFile, password)
	if openErr != nil {
		t.Fatalf("vault.Open failed: %v", openErr)
	}

	// dict turns alternating key/value arguments into a map so templates can
	// pass several named values at once. Keys must be strings.
	dict := func(pairs ...any) (map[string]any, error) {
		if len(pairs)%2 == 1 {
			return nil, fmt.Errorf("dict requires an even number of arguments")
		}
		out := make(map[string]any, len(pairs)/2)
		for i := 1; i < len(pairs); i += 2 {
			name, isString := pairs[i-1].(string)
			if !isString {
				return nil, fmt.Errorf("dict keys must be strings")
			}
			out[name] = pairs[i]
		}
		return out, nil
	}

	root := template.New("").Funcs(template.FuncMap{"dict": dict})
	parsed, parseErr := root.ParseFS(templateFiles, "templates/index.html")
	if parseErr != nil {
		t.Fatalf("ParseFS failed: %v", parseErr)
	}

	srv := &Server{
		Vault:     opened,
		VaultPath: filepath.Join(t.TempDir(), "vault.json"),
		Password:  password,
		tmpl:      parsed,
	}
	return srv
}
// TestHandleCreateProjectReturnsConflictForDuplicate creates the project
// "demo" directly in the vault and then checks that handleCreateProject
// answers a request for the same name with 409 Conflict.
func TestHandleCreateProjectReturnsConflictForDuplicate(t *testing.T) {
	const name = "demo"

	srv := newTestServer(t)
	err := srv.Vault.CreateProject(name)
	if err != nil {
		t.Fatalf("CreateProject setup failed: %v", err)
	}

	// The form is populated directly instead of encoding a request body.
	r := httptest.NewRequest(http.MethodPost, "/projects", nil)
	r.Form = map[string][]string{"name": {name}}
	w := httptest.NewRecorder()

	srv.handleCreateProject(w, r)

	want := http.StatusConflict
	if got := w.Code; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
}
// TestHandleDeleteSecretReturnsNotFoundForMissingKey checks that deleting a
// key that was never set in the existing project "demo" makes
// handleDeleteSecret respond with 404 Not Found.
func TestHandleDeleteSecretReturnsNotFoundForMissingKey(t *testing.T) {
	const (
		project = "demo"
		key     = "MISSING"
	)

	srv := newTestServer(t)
	err := srv.Vault.CreateProject(project)
	if err != nil {
		t.Fatalf("CreateProject setup failed: %v", err)
	}

	r := httptest.NewRequest(http.MethodDelete, "/projects/demo/secrets/MISSING", nil)
	// The handler is called directly, so path values are set by hand.
	r.SetPathValue("project", project)
	r.SetPathValue("key", key)
	w := httptest.NewRecorder()

	srv.handleDeleteSecret(w, r)

	want := http.StatusNotFound
	if got := w.Code; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
}
// TestHandleDeleteProjectReturnsNotFoundForMissingProject checks that
// handleDeleteProject responds with 404 Not Found when asked to delete a
// project that was never created.
func TestHandleDeleteProjectReturnsNotFoundForMissingProject(t *testing.T) {
	srv := newTestServer(t)

	r := httptest.NewRequest(http.MethodDelete, "/projects/missing", nil)
	// The handler is called directly, so the path value is set by hand.
	r.SetPathValue("project", "missing")
	w := httptest.NewRecorder()

	srv.handleDeleteProject(w, r)

	want := http.StatusNotFound
	if got := w.Code; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}
}
// --- CSRF middleware ---

// TestCSRFMiddleware_RejectsWrongOrigin checks that POST, PUT and DELETE
// requests carrying a foreign Origin header are answered with 403 by
// csrfMiddleware and never reach the wrapped handler.
func TestCSRFMiddleware_RejectsWrongOrigin(t *testing.T) {
	srv := newTestServer(t)
	srv.Port = 9871

	// reached records whether the wrapped handler ran for the current method.
	var reached bool
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})
	guarded := srv.csrfMiddleware(next)

	methods := []string{http.MethodPost, http.MethodPut, http.MethodDelete}
	for i := range methods {
		method := methods[i]
		reached = false

		req := httptest.NewRequest(method, "/", nil)
		req.Header.Set("Origin", "http://evil.example.com")
		rec := httptest.NewRecorder()
		guarded.ServeHTTP(rec, req)

		if got := rec.Code; got != http.StatusForbidden {
			t.Errorf("%s with wrong origin: got %d, want 403", method, got)
		}
		if reached {
			t.Errorf("%s with wrong origin should not reach inner handler", method)
		}
	}
}
// TestCSRFMiddleware_AllowsCorrectOrigin checks that a POST whose Origin is
// http://localhost on the server's configured port passes through
// csrfMiddleware and receives the wrapped handler's 200 response.
func TestCSRFMiddleware_AllowsCorrectOrigin(t *testing.T) {
	srv := newTestServer(t)
	srv.Port = 9872

	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	guarded := srv.csrfMiddleware(next)

	// The port in the Origin matches srv.Port set above.
	req := httptest.NewRequest(http.MethodPost, "/", nil)
	req.Header.Set("Origin", "http://localhost:9872")
	rec := httptest.NewRecorder()
	guarded.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("correct origin: got %d, want 200", got)
	}
}
// TestCSRFMiddleware_AllowsNoOrigin checks that a POST without any Origin
// header passes through csrfMiddleware and receives the wrapped handler's
// 200 response.
func TestCSRFMiddleware_AllowsNoOrigin(t *testing.T) {
	srv := newTestServer(t)
	srv.Port = 9873

	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	guarded := srv.csrfMiddleware(next)

	// Requests from tools such as curl, or plain form submissions, may carry
	// no Origin header at all; the middleware is expected to let them through.
	req := httptest.NewRequest(http.MethodPost, "/", nil)
	rec := httptest.NewRecorder()
	guarded.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("no origin: got %d, want 200", got)
	}
}
// TestCSRFMiddleware_GETNotChecked checks that csrfMiddleware lets a GET
// request through to the wrapped handler even when it carries a foreign
// Origin header.
func TestCSRFMiddleware_GETNotChecked(t *testing.T) {
	srv := newTestServer(t)
	srv.Port = 9874

	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	guarded := srv.csrfMiddleware(next)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Origin", "http://evil.example.com")
	rec := httptest.NewRecorder()
	guarded.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("GET with wrong origin should not be blocked: got %d", got)
	}
}
// --- handleSetSecret ---

// TestHandleSetSecretRejectsBadKey checks that handleSetSecret responds with
// 400 Bad Request for an empty key, a key containing a space, and a key
// containing '='.
func TestHandleSetSecretRejectsBadKey(t *testing.T) {
	srv := newTestServer(t)
	err := srv.Vault.CreateProject("demo")
	if err != nil {
		t.Fatal(err)
	}

	invalid := []string{"", "KEY WITH SPACE", "KEY=BAD"}
	for i := range invalid {
		badKey := invalid[i]

		r := httptest.NewRequest(http.MethodPost, "/projects/demo/secrets", nil)
		r.SetPathValue("project", "demo")
		// The form is populated directly instead of encoding a request body.
		r.Form = map[string][]string{
			"key":   {badKey},
			"value": {"v"},
		}
		w := httptest.NewRecorder()

		srv.handleSetSecret(w, r)

		if got := w.Code; got != http.StatusBadRequest {
			t.Errorf("key %q: got %d, want 400", badKey, got)
		}
	}
}
// --- handleRevealSecret ---

// TestHandleRevealSecretReturnsValueForExistingKey stores a secret under
// project "demo" and checks that handleRevealSecret responds with 200 and a
// body containing the stored value.
func TestHandleRevealSecretReturnsValueForExistingKey(t *testing.T) {
	srv := newTestServer(t)
	// Fail fast on broken setup, matching the other tests in this file,
	// rather than letting a setup error surface as a confusing status mismatch.
	if err := srv.Vault.CreateProject("demo"); err != nil {
		t.Fatalf("CreateProject setup failed: %v", err)
	}
	// NOTE(review): Set's result, if it has one, is still discarded because
	// its signature is not visible from this file. If it returns an error,
	// check it here too. A failed Set would be caught by the assertions below.
	srv.Vault.Set("demo", "MY_KEY", "supersecret")

	req := httptest.NewRequest(http.MethodGet, "/projects/demo/secrets/MY_KEY/reveal", nil)
	// The handler is called directly, so path values are set by hand.
	req.SetPathValue("project", "demo")
	req.SetPathValue("key", "MY_KEY")
	rec := httptest.NewRecorder()

	srv.handleRevealSecret(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want 200", rec.Code)
	}
	if !strings.Contains(rec.Body.String(), "supersecret") {
		t.Error("response should contain the secret value")
	}
}
// TestHandleRevealSecretReturnsNotFoundForMissingKey checks that
// handleRevealSecret responds with 404 Not Found when the project "demo"
// exists but the requested key was never set.
func TestHandleRevealSecretReturnsNotFoundForMissingKey(t *testing.T) {
	srv := newTestServer(t)
	// The project must really exist; otherwise a 404 could come from the
	// missing project instead of the missing key and the test would pass for
	// the wrong reason.
	if err := srv.Vault.CreateProject("demo"); err != nil {
		t.Fatalf("CreateProject setup failed: %v", err)
	}

	req := httptest.NewRequest(http.MethodGet, "/projects/demo/secrets/MISSING/reveal", nil)
	// The handler is called directly, so path values are set by hand.
	req.SetPathValue("project", "demo")
	req.SetPathValue("key", "MISSING")
	rec := httptest.NewRecorder()

	srv.handleRevealSecret(rec, req)

	if rec.Code != http.StatusNotFound {
		t.Fatalf("status = %d, want 404", rec.Code)
	}
}
// --- handleUpdateSecret ---

// TestHandleUpdateSecretRejectsBadProjectOrKey checks that handleUpdateSecret
// responds with 400 Bad Request when the project name contains a space, the
// key contains a space, or the project name is empty.
func TestHandleUpdateSecretRejectsBadProjectOrKey(t *testing.T) {
	srv := newTestServer(t)

	inputs := []struct {
		project, key string
	}{
		{project: "bad project", key: "KEY"},
		{project: "demo", key: "bad key"},
		{project: "", key: "KEY"},
	}

	for i := range inputs {
		in := inputs[i]

		r := httptest.NewRequest(http.MethodPut, "/", nil)
		// The handler is called directly, so path values are set by hand.
		r.SetPathValue("project", in.project)
		r.SetPathValue("key", in.key)
		r.Form = map[string][]string{"value": {"v"}}
		w := httptest.NewRecorder()

		srv.handleUpdateSecret(w, r)

		if got := w.Code; got != http.StatusBadRequest {
			t.Errorf("project=%q key=%q: got %d, want 400", in.project, in.key, got)
		}
	}
}