From 4dbedb2f99fa6a2c350e4c2237f4fa8d3bec3dcb Mon Sep 17 00:00:00 2001 From: nkanaev Date: Sun, 14 Jun 2026 14:56:25 +0100 Subject: [PATCH] wrap tests in dbtest --- src/storage/tests/feed_test.go | 130 +++--- src/storage/tests/feedstate_test.go | 185 ++++---- src/storage/tests/folder_test.go | 120 +++--- src/storage/tests/item_test.go | 627 ++++++++++++++-------------- src/storage/tests/settings_test.go | 228 +++++----- src/storage/tests/storage_test.go | 4 +- 6 files changed, 652 insertions(+), 642 deletions(-) diff --git a/src/storage/tests/feed_test.go b/src/storage/tests/feed_test.go index b027546..1fe41c0 100644 --- a/src/storage/tests/feed_test.go +++ b/src/storage/tests/feed_test.go @@ -4,89 +4,95 @@ import ( "reflect" "testing" + "github.com/nkanaev/yarr/src/storage" "github.com/nkanaev/yarr/src/storage/model" ) func TestCreateFeed(t *testing.T) { - db := testDB() - feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example.com/feed.xml"}) - if feed1 == nil || feed1.Id == 0 { - t.Fatal("expected feed") - } - feed2 := db.GetFeed(feed1.Id) - if feed2 == nil || !reflect.DeepEqual(feed1, feed2) { - t.Fatal("invalid feed") - } + dbtest(t, func(t *testing.T, db storage.Storage) { + feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example.com/feed.xml"}) + if feed1 == nil || feed1.Id == 0 { + t.Fatal("expected feed") + } + feed2 := db.GetFeed(feed1.Id) + if feed2 == nil || !reflect.DeepEqual(feed1, feed2) { + t.Fatal("invalid feed") + } + }) } func TestCreateFeedSameLink(t *testing.T) { - db := testDB() - feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", FeedLink: "http://example1.com/feed.xml"}) - if feed1 == nil || feed1.Id == 0 { - t.Fatal("expected feed") - } + dbtest(t, func(t *testing.T, db storage.Storage) { + feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", FeedLink: "http://example1.com/feed.xml"}) + if feed1 == nil || feed1.Id == 0 { + t.Fatal("expected feed") + } - for range 10 { - db.CreateFeed(model.CreateFeedParams{Title: "title", FeedLink: "http://example2.com/feed.xml"}) - } + for range 10 { + db.CreateFeed(model.CreateFeedParams{Title: "title", FeedLink: "http://example2.com/feed.xml"}) + } - feed2 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example1.com/feed.xml"}) - if feed1.Id != feed2.Id { - t.Fatalf("expected the same feed.\nwant: %#v\nhave: %#v", feed1, feed2) - } + feed2 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example1.com/feed.xml"}) + if feed1.Id != feed2.Id { + t.Fatalf("expected the same feed.\nwant: %#v\nhave: %#v", feed1, feed2) + } + }) } func TestReadFeed(t *testing.T) { - db := testDB() - if db.GetFeed(100500) != nil { - t.Fatal("cannot get nonexistent feed") - } + dbtest(t, func(t *testing.T, db storage.Storage) { + if db.GetFeed(100500) != nil { + t.Fatal("cannot get nonexistent feed") + } - feed1 := db.CreateFeed(model.CreateFeedParams{Title: "feed 1", Link: "http://example1.com", FeedLink: "http://example1.com/feed.xml"}) - feed2 := db.CreateFeed(model.CreateFeedParams{Title: "feed 2", Link: "http://example2.com", FeedLink: "http://example2.com/feed.xml"}) - feeds := db.ListFeeds() - if !reflect.DeepEqual(feeds, []model.Feed{*feed1, *feed2}) { - t.Fatalf("invalid feed list: %#v", feeds) - } + feed1 := db.CreateFeed(model.CreateFeedParams{Title: "feed 1", Link: "http://example1.com", FeedLink: "http://example1.com/feed.xml"}) + feed2 := db.CreateFeed(model.CreateFeedParams{Title: "feed 2", Link: "http://example2.com", FeedLink: "http://example2.com/feed.xml"}) + feeds := db.ListFeeds() + if !reflect.DeepEqual(feeds, []model.Feed{*feed1, *feed2}) { + t.Fatalf("invalid feed list: %#v", feeds) + } + }) } func TestUpdateFeed(t *testing.T) { - db := testDB() - feed1 := db.CreateFeed(model.CreateFeedParams{Title: "feed 1", Link: "http://example1.com", FeedLink: "http://example1.com/feed.xml"}) - folder := db.CreateFolder("test") - icon := []byte("icon") + dbtest(t, func(t *testing.T, db storage.Storage) { + feed1 := db.CreateFeed(model.CreateFeedParams{Title: "feed 1", Link: "http://example1.com", FeedLink: "http://example1.com/feed.xml"}) + folder := db.CreateFolder("test") + icon := []byte("icon") - title := "newtitle" - db.UpdateFeed(feed1.Id, model.UpdateFeedParams{ - Title: &title, - FolderID: model.SetNullable(&folder.Id), - Icon: model.SetNullable(&icon), + title := "newtitle" + db.UpdateFeed(feed1.Id, model.UpdateFeedParams{ + Title: &title, + FolderID: model.SetNullable(&folder.Id), + Icon: model.SetNullable(&icon), + }) + + feed2 := db.GetFeed(feed1.Id) + if feed2.Title != "newtitle" { + t.Error("invalid title") + } + if feed2.FolderId == nil || *feed2.FolderId != folder.Id { + t.Error("invalid folder") + } + if !feed2.HasIcon || string(*feed2.Icon) != "icon" { + t.Error("invalid icon") + } }) - - feed2 := db.GetFeed(feed1.Id) - if feed2.Title != "newtitle" { - t.Error("invalid title") - } - if feed2.FolderId == nil || *feed2.FolderId != folder.Id { - t.Error("invalid folder") - } - if !feed2.HasIcon || string(*feed2.Icon) != "icon" { - t.Error("invalid icon") - } } func TestDeleteFeed(t *testing.T) { - db := testDB() - feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example.com/feed.xml"}) + dbtest(t, func(t *testing.T, db storage.Storage) { + feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example.com/feed.xml"}) - if db.DeleteFeed(100500) { - t.Error("cannot delete what does not exist") - } + if db.DeleteFeed(100500) { + t.Error("cannot delete what does not exist") + } - if !db.DeleteFeed(feed1.Id) { - t.Fatal("did not delete existing feed") - } - if db.GetFeed(feed1.Id) != nil { - t.Fatal("feed still exists") - } + if !db.DeleteFeed(feed1.Id) { + t.Fatal("did not delete existing feed") + } + if db.GetFeed(feed1.Id) != nil { + t.Fatal("feed still exists") + } + }) } diff --git a/src/storage/tests/feedstate_test.go b/src/storage/tests/feedstate_test.go index d3c3d8c..23ffc06 100644 --- a/src/storage/tests/feedstate_test.go +++ b/src/storage/tests/feedstate_test.go @@ -4,126 +4,123 @@ import ( "testing" "time" + "github.com/nkanaev/yarr/src/storage" "github.com/nkanaev/yarr/src/storage/model" ) func TestUpdateFeedState_Full(t *testing.T) { - s := testDB() - defer s.Close() + dbtest(t, func(t *testing.T, s storage.Storage) { + f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"}) - f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"}) + now := time.Now().UTC().Truncate(time.Second) + errMsg := "error" + lmod := "today" + etag := "v1" - now := time.Now().UTC().Truncate(time.Second) - errMsg := "error" - lmod := "today" - etag := "v1" + ok, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{ + LastRefreshed: &now, + LastError: &errMsg, + HTTPLastModified: &lmod, + HTTPEtag: &etag, + }) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("expected true") + } - ok, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{ - LastRefreshed: &now, - LastError: &errMsg, - HTTPLastModified: &lmod, - HTTPEtag: &etag, + state, err := s.GetFeedState(f.Id) + if err != nil { + t.Fatal(err) + } + if state == nil { + t.Fatal("expected state, got nil") + } + if !state.LastRefreshed.Equal(now) { + t.Errorf("expected %v, got %v", now, state.LastRefreshed) + } + if state.LastError != errMsg { + t.Errorf("expected %s, got %v", errMsg, state.LastError) + } + if state.HTTPLastModified != lmod { + t.Errorf("expected %s, got %s", lmod, state.HTTPLastModified) + } + if state.HTTPEtag != etag { + t.Errorf("expected %s, got %s", etag, state.HTTPEtag) + } }) - if err != nil { - t.Fatal(err) - } - if !ok { - t.Error("expected true") - } - - state, err := s.GetFeedState(f.Id) - if err != nil { - t.Fatal(err) - } - if state == nil { - t.Fatal("expected state, got nil") - } - if !state.LastRefreshed.Equal(now) { - t.Errorf("expected %v, got %v", now, state.LastRefreshed) - } - if state.LastError != errMsg { - t.Errorf("expected %s, got %v", errMsg, state.LastError) - } - if state.HTTPLastModified != lmod { - t.Errorf("expected %s, got %s", lmod, state.HTTPLastModified) - } - if state.HTTPEtag != etag { - t.Errorf("expected %s, got %s", etag, state.HTTPEtag) - } } func TestUpdateFeedState_Partial(t *testing.T) { - s := testDB() - defer s.Close() + dbtest(t, func(t *testing.T, s storage.Storage) { + f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"}) + etag := "v1" + s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{HTTPEtag: &etag}) - f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"}) - etag := "v1" - s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{HTTPEtag: &etag}) + newErr := "new error" + _, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{ + LastError: &newErr, + }) + if err != nil { + t.Fatal(err) + } - newErr := "new error" - _, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{ - LastError: &newErr, + state, err := s.GetFeedState(f.Id) + if err != nil { + t.Fatal(err) + } + if state.LastError != newErr { + t.Errorf("expected %s, got %v", newErr, state.LastError) + } + if state.HTTPEtag != etag { + t.Errorf("etag should be unchanged, got %s", state.HTTPEtag) + } }) - if err != nil { - t.Fatal(err) - } - - state, err := s.GetFeedState(f.Id) - if err != nil { - t.Fatal(err) - } - if state.LastError != newErr { - t.Errorf("expected %s, got %v", newErr, state.LastError) - } - if state.HTTPEtag != etag { - t.Errorf("etag should be unchanged, got %s", state.HTTPEtag) - } } func TestUpdateFeedState_ClearError(t *testing.T) { - s := testDB() - defer s.Close() + dbtest(t, func(t *testing.T, s storage.Storage) { + f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"}) + errMsg := "error" + s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{LastError: &errMsg}) - f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"}) - errMsg := "error" - s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{LastError: &errMsg}) + empty := "" + _, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{ + LastError: &empty, + }) + if err != nil { + t.Fatal(err) + } - empty := "" - _, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{ - LastError: &empty, + state, err := s.GetFeedState(f.Id) + if err != nil { + t.Fatal(err) + } + if state.LastError != "" { + t.Errorf("expected empty error string, got %v", state.LastError) + } }) - if err != nil { - t.Fatal(err) - } - - state, err := s.GetFeedState(f.Id) - if err != nil { - t.Fatal(err) - } - if state.LastError != "" { - t.Errorf("expected empty error string, got %v", state.LastError) - } } func TestListFeedStates(t *testing.T) { - s := testDB() - defer s.Close() + dbtest(t, func(t *testing.T, s storage.Storage) { + f1 := s.CreateFeed(model.CreateFeedParams{Title: "F1", FeedLink: "L1"}) + f2 := s.CreateFeed(model.CreateFeedParams{Title: "F2", FeedLink: "L2"}) - f1 := s.CreateFeed(model.CreateFeedParams{Title: "F1", FeedLink: "L1"}) - f2 := s.CreateFeed(model.CreateFeedParams{Title: "F2", FeedLink: "L2"}) + errMsg := "fail" + s.UpdateFeedState(f1.Id, model.UpdateFeedStateParams{LastError: &errMsg}) + s.UpdateFeedState(f2.Id, model.UpdateFeedStateParams{HTTPEtag: ptr("e")}) - errMsg := "fail" - s.UpdateFeedState(f1.Id, model.UpdateFeedStateParams{LastError: &errMsg}) - s.UpdateFeedState(f2.Id, model.UpdateFeedStateParams{HTTPEtag: ptr("e")}) + states, err := s.ListFeedStates() + if err != nil { + t.Fatal(err) + } - states, err := s.ListFeedStates() - if err != nil { - t.Fatal(err) - } - - if len(states) != 2 { - t.Errorf("expected 2 states, got %d", len(states)) - } + if len(states) != 2 { + t.Errorf("expected 2 states, got %d", len(states)) + } + }) } func ptr[T any](v T) *T { diff --git a/src/storage/tests/folder_test.go b/src/storage/tests/folder_test.go index 4b406e3..b475ef4 100644 --- a/src/storage/tests/folder_test.go +++ b/src/storage/tests/folder_test.go @@ -3,78 +3,80 @@ package tests import ( "testing" + "github.com/nkanaev/yarr/src/storage" "github.com/nkanaev/yarr/src/storage/model" ) func TestUpdateFolder(t *testing.T) { - db := testDB() - folder := db.CreateFolder("old title") - if folder.IsExpanded != true { - t.Fatal("expected folder to be expanded by default") - } + dbtest(t, func(t *testing.T, db storage.Storage) { + folder := db.CreateFolder("old title") + if folder.IsExpanded != true { + t.Fatal("expected folder to be expanded by default") + } - t.Run("rename only", func(t *testing.T) { - newTitle := "new title" - ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{ - Title: &newTitle, + t.Run("rename only", func(t *testing.T) { + newTitle := "new title" + ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{ + Title: &newTitle, + }) + if !ok || err != nil { + t.Fatalf("UpdateFolder failed: %v", err) + } + + folders := db.ListFolders() + if len(folders) != 1 || folders[0].Title != "new title" { + t.Errorf("expected title to be updated, got %s", folders[0].Title) + } + if folders[0].IsExpanded != true { + t.Error("expected expansion state to remain unchanged") + } }) - if !ok || err != nil { - t.Fatalf("UpdateFolder failed: %v", err) - } - folders := db.ListFolders() - if len(folders) != 1 || folders[0].Title != "new title" { - t.Errorf("expected title to be updated, got %s", folders[0].Title) - } - if folders[0].IsExpanded != true { - t.Error("expected expansion state to remain unchanged") - } - }) + t.Run("toggle expanded only", func(t *testing.T) { + isExpanded := false + ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{ + IsExpanded: &isExpanded, + }) + if !ok || err != nil { + t.Fatalf("UpdateFolder failed: %v", err) + } - t.Run("toggle expanded only", func(t *testing.T) { - isExpanded := false - ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{ - IsExpanded: &isExpanded, + folders := db.ListFolders() + if len(folders) != 1 || folders[0].IsExpanded != false { + t.Errorf("expected is_expanded to be false, got %v", folders[0].IsExpanded) + } + if folders[0].Title != "new title" { + t.Error("expected title to remain unchanged") + } }) - if !ok || err != nil { - t.Fatalf("UpdateFolder failed: %v", err) - } - folders := db.ListFolders() - if len(folders) != 1 || folders[0].IsExpanded != false { - t.Errorf("expected is_expanded to be false, got %v", folders[0].IsExpanded) - } - if folders[0].Title != "new title" { - t.Error("expected title to remain unchanged") - } - }) + t.Run("update both", func(t *testing.T) { + bothTitle := "both" + isExpanded := true + ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{ + Title: &bothTitle, + IsExpanded: &isExpanded, + }) + if !ok || err != nil { + t.Fatalf("UpdateFolder failed: %v", err) + } - t.Run("update both", func(t *testing.T) { - bothTitle := "both" - isExpanded := true - ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{ - Title: &bothTitle, - IsExpanded: &isExpanded, + folders := db.ListFolders() + if len(folders) != 1 || folders[0].Title != "both" || folders[0].IsExpanded != true { + t.Errorf("expected both to be updated, got title=%s expanded=%v", folders[0].Title, folders[0].IsExpanded) + } }) - if !ok || err != nil { - t.Fatalf("UpdateFolder failed: %v", err) - } - folders := db.ListFolders() - if len(folders) != 1 || folders[0].Title != "both" || folders[0].IsExpanded != true { - t.Errorf("expected both to be updated, got title=%s expanded=%v", folders[0].Title, folders[0].IsExpanded) - } - }) + t.Run("update none", func(t *testing.T) { + ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{}) + if !ok || err != nil { + t.Fatalf("UpdateFolder failed: %v", err) + } - t.Run("update none", func(t *testing.T) { - ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{}) - if !ok || err != nil { - t.Fatalf("UpdateFolder failed: %v", err) - } - - folders := db.ListFolders() - if len(folders) != 1 || folders[0].Title != "both" || folders[0].IsExpanded != true { - t.Errorf("expected no changes, got title=%s expanded=%v", folders[0].Title, folders[0].IsExpanded) - } + folders := db.ListFolders() + if len(folders) != 1 || folders[0].Title != "both" || folders[0].IsExpanded != true { + t.Errorf("expected no changes, got title=%s expanded=%v", folders[0].Title, folders[0].IsExpanded) + } + }) }) } diff --git a/src/storage/tests/item_test.go b/src/storage/tests/item_test.go index f7ed569..f598d36 100644 --- a/src/storage/tests/item_test.go +++ b/src/storage/tests/item_test.go @@ -9,6 +9,7 @@ import ( "testing/synctest" "time" + "github.com/nkanaev/yarr/src/storage" "github.com/nkanaev/yarr/src/storage/model" ) @@ -144,374 +145,378 @@ func getItemGuids(items []model.Item) []string { } func TestListItems(t *testing.T) { - db := testDB() - scope := testItemsSetup(db) + dbtest(t, func(t *testing.T, db storage.Storage) { + scope := testItemsSetup(db) - // filter by folder_id + // filter by folder_id - have := getItemGuids(db.ListItems(model.ItemFilter{FolderID: &scope.folder1.Id}, 10, false, false)) - want := []string{"item111", "item112", "item113", "item121", "item122"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + have := getItemGuids(db.ListItems(model.ItemFilter{FolderID: &scope.folder1.Id}, 10, false, false)) + want := []string{"item111", "item112", "item113", "item121", "item122"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - have = getItemGuids(db.ListItems(model.ItemFilter{FolderID: &scope.folder2.Id}, 10, false, false)) - want = []string{"item211", "item212"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + have = getItemGuids(db.ListItems(model.ItemFilter{FolderID: &scope.folder2.Id}, 10, false, false)) + want = []string{"item211", "item212"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - // filter by feed_id + // filter by feed_id - have = getItemGuids(db.ListItems(model.ItemFilter{FeedID: &scope.feed11.Id}, 10, false, false)) - want = []string{"item111", "item112", "item113"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + have = getItemGuids(db.ListItems(model.ItemFilter{FeedID: &scope.feed11.Id}, 10, false, false)) + want = []string{"item111", "item112", "item113"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - have = getItemGuids(db.ListItems(model.ItemFilter{FeedID: &scope.feed01.Id}, 10, false, false)) - want = []string{"item011", "item012", "item013"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + have = getItemGuids(db.ListItems(model.ItemFilter{FeedID: &scope.feed01.Id}, 10, false, false)) + want = []string{"item011", "item012", "item013"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - // filter by status + // filter by status - var starred model.ItemStatus = model.STARRED - have = getItemGuids(db.ListItems(model.ItemFilter{Status: &starred}, 10, false, false)) - want = []string{"item113", "item212", "item013"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + var starred model.ItemStatus = model.STARRED + have = getItemGuids(db.ListItems(model.ItemFilter{Status: &starred}, 10, false, false)) + want = []string{"item113", "item212", "item013"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - var unread model.ItemStatus = model.UNREAD - have = getItemGuids(db.ListItems(model.ItemFilter{Status: &unread}, 10, false, false)) - want = []string{"item111", "item121", "item011"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + var unread model.ItemStatus = model.UNREAD + have = getItemGuids(db.ListItems(model.ItemFilter{Status: &unread}, 10, false, false)) + want = []string{"item111", "item121", "item011"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - // limit + // limit - have = getItemGuids(db.ListItems(model.ItemFilter{}, 2, false, false)) - want = []string{"item111", "item112"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + have = getItemGuids(db.ListItems(model.ItemFilter{}, 2, false, false)) + want = []string{"item111", "item112"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - // filter by search - search1 := "title111" - have = getItemGuids(db.ListItems(model.ItemFilter{Search: &search1}, 4, true, false)) - want = []string{"item111"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + // filter by search + search1 := "title111" + have = getItemGuids(db.ListItems(model.ItemFilter{Search: &search1}, 4, true, false)) + want = []string{"item111"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - // sort by date - have = getItemGuids(db.ListItems(model.ItemFilter{}, 4, true, false)) - want = []string{"item013", "item012", "item011", "item212"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + // sort by date + have = getItemGuids(db.ListItems(model.ItemFilter{}, 4, true, false)) + want = []string{"item013", "item012", "item011", "item212"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } + }) } func TestListItemsPaginated(t *testing.T) { - db := testDB() - testItemsSetup(db) + dbtest(t, func(t *testing.T, db storage.Storage) { + testItemsSetup(db) - item012 := getItem(db, "item012") - item121 := getItem(db, "item121") + item012 := getItem(db, "item012") + item121 := getItem(db, "item121") - // all, newest first - have := getItemGuids(db.ListItems(model.ItemFilter{After: &item012.Id}, 3, true, false)) - want := []string{"item011", "item212", "item211"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + // all, newest first + have := getItemGuids(db.ListItems(model.ItemFilter{After: &item012.Id}, 3, true, false)) + want := []string{"item011", "item212", "item211"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - // unread, newest first - unread := model.UNREAD - have = getItemGuids( - db.ListItems(model.ItemFilter{After: &item012.Id, Status: &unread}, 3, true, false), - ) - want = []string{"item011", "item121", "item111"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + // unread, newest first + unread := model.UNREAD + have = getItemGuids( + db.ListItems(model.ItemFilter{After: &item012.Id, Status: &unread}, 3, true, false), + ) + want = []string{"item011", "item121", "item111"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } - // starred, oldest first - starred := model.STARRED - have = getItemGuids( - db.ListItems(model.ItemFilter{After: &item121.Id, Status: &starred}, 3, false, false), - ) - want = []string{"item212", "item013"} - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + // starred, oldest first + starred := model.STARRED + have = getItemGuids( + db.ListItems(model.ItemFilter{After: &item121.Id, Status: &starred}, 3, false, false), + ) + want = []string{"item212", "item013"} + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } + }) } func TestMarkItemsRead(t *testing.T) { // NOTE: starred items must not be marked as read var read model.ItemStatus = model.READ - db1 := testDB() - testItemsSetup(db1) - db1.MarkItemsRead(model.MarkFilter{}) - have := getItemGuids(db1.ListItems(model.ItemFilter{Status: &read}, 10, false, false)) - want := []string{ - "item111", "item112", "item121", "item122", - "item211", "item011", "item012", - } - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + dbtest(t, func(t *testing.T, db1 storage.Storage) { + testItemsSetup(db1) + db1.MarkItemsRead(model.MarkFilter{}) + have := getItemGuids(db1.ListItems(model.ItemFilter{Status: &read}, 10, false, false)) + want := []string{ + "item111", "item112", "item121", "item122", + "item211", "item011", "item012", + } + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } + }) - db2 := testDB() - scope2 := testItemsSetup(db2) - db2.MarkItemsRead(model.MarkFilter{FolderID: &scope2.folder1.Id}) - have = getItemGuids(db2.ListItems(model.ItemFilter{Status: &read}, 10, false, false)) - want = []string{ - "item111", "item112", "item121", "item122", - "item211", "item012", - } - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + dbtest(t, func(t *testing.T, db2 storage.Storage) { + scope2 := testItemsSetup(db2) + db2.MarkItemsRead(model.MarkFilter{FolderID: &scope2.folder1.Id}) + have = getItemGuids(db2.ListItems(model.ItemFilter{Status: &read}, 10, false, false)) + want = []string{ + "item111", "item112", "item121", "item122", + "item211", "item012", + } + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } + }) - db3 := testDB() - scope3 := testItemsSetup(db3) - db3.MarkItemsRead(model.MarkFilter{FeedID: &scope3.feed11.Id}) - have = getItemGuids(db3.ListItems(model.ItemFilter{Status: &read}, 10, false, false)) - want = []string{ - "item111", "item112", "item122", - "item211", "item012", - } - if !reflect.DeepEqual(have, want) { - t.Logf("want: %#v", want) - t.Logf("have: %#v", have) - t.Fail() - } + dbtest(t, func(t *testing.T, db3 storage.Storage) { + scope3 := testItemsSetup(db3) + db3.MarkItemsRead(model.MarkFilter{FeedID: &scope3.feed11.Id}) + have = getItemGuids(db3.ListItems(model.ItemFilter{Status: &read}, 10, false, false)) + want = []string{ + "item111", "item112", "item122", + "item211", "item012", + } + if !reflect.DeepEqual(have, want) { + t.Logf("want: %#v", want) + t.Logf("have: %#v", have) + t.Fail() + } + }) } func TestDeleteOldItems(t *testing.T) { now := time.Now().UTC() starred := model.STARRED + dbtest(t, func(t *testing.T, db storage.Storage) { - t.Run("keeps at least 50 items", func(t *testing.T) { - db := testDB() - feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) - items := make([]model.Item, 100) - for i := range 100 { - items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Hour * 24)} - } - db.CreateItems(items) + t.Run("keeps at least 50 items", func(t *testing.T) { + feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) + items := make([]model.Item, 100) + for i := range 100 { + items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Hour * 24)} + } + db.CreateItems(items) - // // Set 1 recent (latest), 100 old (100 days ago) - db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now)) - db.db.Exec(`update items set last_arrived = :la where guid != "99"`, sql.Named("la", now.Add(-time.Hour*24*100))) + // // Set 1 recent (latest), 100 old (100 days ago) + db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now)) + db.db.Exec(`update items set last_arrived = :la where guid != "99"`, sql.Named("la", now.Add(-time.Hour*24*100))) - db.DeleteOldItems() - var have int - db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have) - if have != 50 { - t.Errorf("expected 50 items, have %d", have) - } - }) + db.DeleteOldItems() + var have int + db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have) + if have != 50 { + t.Errorf("expected 50 items, have %d", have) + } + }) - t.Run("keeps all less than 90 days old", func(t *testing.T) { - db := testDB() - feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) - items := make([]model.Item, 100) - for i := 0; i < 100; i++ { - items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Second)} - } - db.CreateItems(items) + t.Run("keeps all less than 90 days old", func(t *testing.T) { + feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) + items := make([]model.Item, 100) + for i := 0; i < 100; i++ { + items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Second)} + } + db.CreateItems(items) - // Latest item at "now" - // All others at 80 days ago (keep) - db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now)) - db.db.Exec(`update items set last_arrived = :la where guid != "99"`, sql.Named("la", now.Add(-time.Hour*24*80))) + // Latest item at "now" + // All others at 80 days ago (keep) + db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now)) + db.db.Exec(`update items set last_arrived = :la where guid != "99"`, sql.Named("la", now.Add(-time.Hour*24*80))) - db.DeleteOldItems() - var have int - db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have) - if have != 100 { - t.Errorf("expected 100 items, have %d", have) - } - }) + db.DeleteOldItems() + var have int + db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have) + if have != 100 { + t.Errorf("expected 100 items, have %d", have) + } + }) - t.Run("keeps starred", func(t *testing.T) { - db := testDB() - feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) - items := make([]model.Item, 100) - for i := 0; i < 100; i++ { - items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Second)} - } - db.CreateItems(items) + t.Run("keeps starred", func(t *testing.T) { + feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) + items := make([]model.Item, 100) + for i := 0; i < 100; i++ { + items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Second)} + } + db.CreateItems(items) - // Set all to 100 days ago, except one recent - db.db.Exec(`update items set last_arrived = :la`, sql.Named("la", now.Add(-time.Hour*24*100))) - db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now)) - // Star 10 old items that would otherwise be deleted (rn > 50 and old) - db.db.Exec(`update items set status = :s where cast(guid as integer) < 10`, sql.Named("s", starred)) + // Set all to 100 days ago, except one recent + db.db.Exec(`update items set last_arrived = :la`, sql.Named("la", now.Add(-time.Hour*24*100))) + db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now)) + // Star 10 old items that would otherwise be deleted (rn > 50 and old) + db.db.Exec(`update items set status = :s where cast(guid as integer) < 10`, sql.Named("s", starred)) - db.DeleteOldItems() - var have int - db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have) - // 50 (limit) + 10 (starred) = 60 items should remain. - if have != 60 { - t.Errorf("expected 60 items, have %d", have) - } + db.DeleteOldItems() + var have int + db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have) + // 50 (limit) + 10 (starred) = 60 items should remain. + if have != 60 { + t.Errorf("expected 60 items, have %d", have) + } + }) }) } - - func TestCreateItemsLastArrived(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - db := testDB() - defer db.db.Close() - feed := db.CreateFeed(model.CreateFeedParams{Title: "test feed", FeedLink: "http://example.com/feed"}) + dbtest(t, func(t *testing.T, db storage.Storage) { + synctest.Test(t, func(t *testing.T) { + db := testDB() + defer db.db.Close() + feed := db.CreateFeed(model.CreateFeedParams{Title: "test feed", FeedLink: "http://example.com/feed"}) - item := model.Item{ - GUID: "item1", - FeedId: feed.Id, - Title: "Title 1", - Date: time.Now(), - } + item := model.Item{ + GUID: "item1", + FeedId: feed.Id, + Title: "Title 1", + Date: time.Now(), + } - // 1. Initial creation - db.CreateItems([]model.Item{item}) + // 1. Initial creation + db.CreateItems([]model.Item{item}) - var lastArrived1 time.Time - err := db.db.QueryRow("select last_arrived from items where guid = ?", item.GUID).Scan(&lastArrived1) - if err != nil { - t.Fatal(err) - } + var lastArrived1 time.Time + err := db.db.QueryRow("select last_arrived from items where guid = ?", item.GUID).Scan(&lastArrived1) + if err != nil { + t.Fatal(err) + } - time.Sleep(time.Second * 10) + time.Sleep(time.Second * 10) - // 2. Update on conflict - db.CreateItems([]model.Item{item}) + // 2. Update on conflict + db.CreateItems([]model.Item{item}) - var lastArrived2 time.Time - err = db.db.QueryRow("select last_arrived from items where guid = ?", item.GUID).Scan(&lastArrived2) - if err != nil { - t.Fatal(err) - } + var lastArrived2 time.Time + err = db.db.QueryRow("select last_arrived from items where guid = ?", item.GUID).Scan(&lastArrived2) + if err != nil { + t.Fatal(err) + } - if !lastArrived2.After(lastArrived1) { - t.Errorf("expected last_arrived to be updated. old: %v, new: %v", lastArrived1, lastArrived2) - } + if !lastArrived2.After(lastArrived1) { + t.Errorf("expected last_arrived to be updated. old: %v, new: %v", lastArrived1, lastArrived2) + } + }) }) } func TestSearch(t *testing.T) { - db := testDB() - defer db.Close() - feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) + dbtest(t, func(t *testing.T, db storage.Storage) { + feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) - db.CreateItems([]model.Item{ - { - GUID: "i1", - FeedId: feed.Id, - Title: "Hello World", - Content: "This is a test of the emergency broadcast system.", - }, - { - GUID: "i2", - FeedId: feed.Id, - Title: "FTS5 Unicode", - Content: "Unicode support with characters like: Привет, 世界, 🚀", - }, - { - GUID: "i3", - FeedId: feed.Id, - Title: "Hidden Tag", - Content: `
Don't find me by my class name
`, - }, + db.CreateItems([]model.Item{ + { + GUID: "i1", + FeedId: feed.Id, + Title: "Hello World", + Content: "This is a test of the emergency broadcast system.", + }, + { + GUID: "i2", + FeedId: feed.Id, + Title: "FTS5 Unicode", + Content: "Unicode support with characters like: Привет, 世界, 🚀", + }, + { + GUID: "i3", + FeedId: feed.Id, + Title: "Hidden Tag", + Content: `
Don't find me by my class name
`, + }, + }) + + // 1. Basic search + s1 := "emergency" + have := getItemGuids(db.ListItems(model.ItemFilter{Search: &s1}, 10, true, false)) + if !reflect.DeepEqual(have, []string{"i1"}) { + t.Errorf("basic search failed: expected [i1], got %v", have) + } + + // 2. HTML stripping: Should find text, but NOT the tags + s2 := "test" + have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s2}, 10, true, false)) + if !reflect.DeepEqual(have, []string{"i1"}) { + t.Errorf("html text search failed: expected [i1], got %v", have) + } + + s3 := "secret-class" + have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s3}, 10, true, false)) + if len(have) > 0 { + t.Errorf("html tag search should have failed but found: %v", have) + } + + // 3. Multi-word (AND) + s4 := "broadcast system" + have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s4}, 10, true, false)) + if !reflect.DeepEqual(have, []string{"i1"}) { + t.Errorf("multi-word search failed: expected [i1], got %v", have) + } + + // 4. Unicode + s5 := "Привет" + have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s5}, 10, true, false)) + if !reflect.DeepEqual(have, []string{"i2"}) { + t.Errorf("unicode search failed: expected [i2], got %v", have) + } + + s6 := "世界" + have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s6}, 10, true, false)) + if !reflect.DeepEqual(have, []string{"i2"}) { + t.Errorf("unicode search (CJK) failed: expected [i2], got %v", have) + } + + // 5. Trigger: Update + db.db.Exec("update items set title = 'Updated Title' where guid = 'i1'") + s7 := "Updated" + have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s7}, 10, true, false)) + if !reflect.DeepEqual(have, []string{"i1"}) { + t.Errorf("update trigger failed: expected [i1], got %v", have) + } + + // 6. Trigger: Delete + db.db.Exec("delete from items where guid = 'i1'") + have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s7}, 10, true, false)) + if len(have) > 0 { + t.Errorf("delete trigger failed: found deleted item: %v", have) + } }) - - // 1. Basic search - s1 := "emergency" - have := getItemGuids(db.ListItems(model.ItemFilter{Search: &s1}, 10, true, false)) - if !reflect.DeepEqual(have, []string{"i1"}) { - t.Errorf("basic search failed: expected [i1], got %v", have) - } - - // 2. HTML stripping: Should find text, but NOT the tags - s2 := "test" - have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s2}, 10, true, false)) - if !reflect.DeepEqual(have, []string{"i1"}) { - t.Errorf("html text search failed: expected [i1], got %v", have) - } - - s3 := "secret-class" - have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s3}, 10, true, false)) - if len(have) > 0 { - t.Errorf("html tag search should have failed but found: %v", have) - } - - // 3. Multi-word (AND) - s4 := "broadcast system" - have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s4}, 10, true, false)) - if !reflect.DeepEqual(have, []string{"i1"}) { - t.Errorf("multi-word search failed: expected [i1], got %v", have) - } - - // 4. Unicode - s5 := "Привет" - have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s5}, 10, true, false)) - if !reflect.DeepEqual(have, []string{"i2"}) { - t.Errorf("unicode search failed: expected [i2], got %v", have) - } - - s6 := "世界" - have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s6}, 10, true, false)) - if !reflect.DeepEqual(have, []string{"i2"}) { - t.Errorf("unicode search (CJK) failed: expected [i2], got %v", have) - } - - // 5. Trigger: Update - db.db.Exec("update items set title = 'Updated Title' where guid = 'i1'") - s7 := "Updated" - have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s7}, 10, true, false)) - if !reflect.DeepEqual(have, []string{"i1"}) { - t.Errorf("update trigger failed: expected [i1], got %v", have) - } - - // 6. Trigger: Delete - db.db.Exec("delete from items where guid = 'i1'") - have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s7}, 10, true, false)) - if len(have) > 0 { - t.Errorf("delete trigger failed: found deleted item: %v", have) - } } diff --git a/src/storage/tests/settings_test.go b/src/storage/tests/settings_test.go index 57aecb7..adb33d2 100644 --- a/src/storage/tests/settings_test.go +++ b/src/storage/tests/settings_test.go @@ -5,148 +5,148 @@ import ( "strings" "testing" + "github.com/nkanaev/yarr/src/storage" "github.com/nkanaev/yarr/src/storage/model" ) func TestSettingsDefaults(t *testing.T) { - s := testDB() - defer s.Close() + dbtest(t, func(t *testing.T, s storage.Storage) { + settings := s.GetSettings() + defaults := settingsDefaults() - settings := s.GetSettings() - defaults := settingsDefaults() - - if !reflect.DeepEqual(settings, defaults) { - t.Errorf("expected defaults %+v, got %+v", defaults, settings) - } + if !reflect.DeepEqual(settings, defaults) { + t.Errorf("expected defaults %+v, got %+v", defaults, settings) + } + }) } func TestUpdateSettings(t *testing.T) { - s := testDB() - defer s.Close() + dbtest(t, func(t *testing.T, s storage.Storage) { - params := model.UpdateSettingsParams{ - ThemeName: ptr("night"), - FeedListWidth: ptr(400), - RefreshRate: ptr(int64(15)), - } + params := model.UpdateSettingsParams{ + ThemeName: ptr("night"), + FeedListWidth: ptr(400), + RefreshRate: ptr(int64(15)), + } - if ok := s.UpdateSettings(params); !ok { - t.Fatal("UpdateSettings failed") - } + if ok := s.UpdateSettings(params); !ok { + t.Fatal("UpdateSettings failed") + } - settings := s.GetSettings() + settings := s.GetSettings() - if settings.ThemeName != "night" { - t.Errorf("expected theme_name night, got %s", settings.ThemeName) - } - if settings.FeedListWidth != 400 { - t.Errorf("expected feed_list_width 400, got %d", settings.FeedListWidth) - } - if settings.RefreshRate != 15 { - t.Errorf("expected refresh_rate 15, got %d", settings.RefreshRate) - } + if settings.ThemeName != "night" { + t.Errorf("expected theme_name night, got %s", settings.ThemeName) + } + if settings.FeedListWidth != 400 { + t.Errorf("expected feed_list_width 400, got %d", settings.FeedListWidth) + } + if settings.RefreshRate != 15 { + t.Errorf("expected refresh_rate 15, got %d", settings.RefreshRate) + } + }) } func TestGetSettings(t *testing.T) { - s := testDB() - defer s.Close() + dbtest(t, func(t *testing.T, s storage.Storage) { - s.UpdateSettings(model.UpdateSettingsParams{Language: ptr("fr")}) + s.UpdateSettings(model.UpdateSettingsParams{Language: ptr("fr")}) - settings := s.GetSettings() - if settings.Language != "fr" { - t.Errorf("expected fr, got %v", settings.Language) - } - if settings.ThemeName != "light" { - t.Errorf("expected light, got %v", settings.ThemeName) - } + settings := s.GetSettings() + if settings.Language != "fr" { + t.Errorf("expected fr, got %v", settings.Language) + } + if settings.ThemeName != "light" { + t.Errorf("expected light, got %v", settings.ThemeName) + } + }) } func TestSettingsExhaustive(t *testing.T) { - s := testDB() - defer s.Close() + dbtest(t, func(t *testing.T, s storage.Storage) { - settingsType := reflect.TypeOf(model.Settings{}) - paramsType := reflect.TypeOf(model.UpdateSettingsParams{}) - - settings := s.GetSettings() - m := settings.Map() + settingsType := reflect.TypeOf(model.Settings{}) + paramsType := reflect.TypeOf(model.UpdateSettingsParams{}) - for i := 0; i < settingsType.NumField(); i++ { - field := settingsType.Field(i) - jsonTag := field.Tag.Get("json") - if jsonTag == "" { - t.Errorf("Field %s missing json tag", field.Name) - continue - } - // json tags might have options like "name,omitempty", take only the first part - jsonKey := strings.Split(jsonTag, ",")[0] + settings := s.GetSettings() + m := settings.Map() - // 1. Check Map() - if _, ok := m[jsonKey]; !ok { - t.Errorf("Key %q (from field %s) missing from Settings.Map()", jsonKey, field.Name) - } + for i := 0; i < settingsType.NumField(); i++ { + field := settingsType.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" { + t.Errorf("Field %s missing json tag", field.Name) + continue + } + // json tags might have options like "name,omitempty", take only the first part + jsonKey := strings.Split(jsonTag, ",")[0] - // 2. Check UpdateSettingsParams - foundInParams := false - for j := 0; j < paramsType.NumField(); j++ { - pField := paramsType.Field(j) - pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0] - if pJsonTag == jsonKey { - foundInParams = true - // Also check it's a pointer - if pField.Type.Kind() != reflect.Ptr { - t.Errorf("Field %s in UpdateSettingsParams should be a pointer", pField.Name) + // 1. Check Map() + if _, ok := m[jsonKey]; !ok { + t.Errorf("Key %q (from field %s) missing from Settings.Map()", jsonKey, field.Name) + } + + // 2. Check UpdateSettingsParams + foundInParams := false + for j := 0; j < paramsType.NumField(); j++ { + pField := paramsType.Field(j) + pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0] + if pJsonTag == jsonKey { + foundInParams = true + // Also check it's a pointer + if pField.Type.Kind() != reflect.Ptr { + t.Errorf("Field %s in UpdateSettingsParams should be a pointer", pField.Name) + } + break } - break } - } - if !foundInParams { - t.Errorf("Key %q (from field %s) missing from UpdateSettingsParams", jsonKey, field.Name) - } + if !foundInParams { + t.Errorf("Key %q (from field %s) missing from UpdateSettingsParams", jsonKey, field.Name) + } - // 3. Test round-trip update - // We'll create a new UpdateSettingsParams and set ONLY this field - paramsValue := reflect.New(paramsType).Elem() - for j := 0; j < paramsType.NumField(); j++ { - pField := paramsType.Field(j) - pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0] - if pJsonTag == jsonKey { - // Create a new value of the underlying type - val := reflect.New(field.Type).Elem() - switch field.Type.Kind() { - case reflect.String: - val.SetString("test_" + jsonKey) - case reflect.Int, reflect.Int64: - val.SetInt(42) - case reflect.Bool: - val.SetBool(false) + // 3. Test round-trip update + // We'll create a new UpdateSettingsParams and set ONLY this field + paramsValue := reflect.New(paramsType).Elem() + for j := 0; j < paramsType.NumField(); j++ { + pField := paramsType.Field(j) + pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0] + if pJsonTag == jsonKey { + // Create a new value of the underlying type + val := reflect.New(field.Type).Elem() + switch field.Type.Kind() { + case reflect.String: + val.SetString("test_" + jsonKey) + case reflect.Int, reflect.Int64: + val.SetInt(42) + case reflect.Bool: + val.SetBool(false) + } + paramsValue.Field(j).Set(val.Addr()) + break + } + } + + if ok := s.UpdateSettings(paramsValue.Interface().(model.UpdateSettingsParams)); !ok { + t.Errorf("UpdateSettings failed for %q", jsonKey) + } + + updated := s.GetSettings() + updatedValue := reflect.ValueOf(updated).Field(i) + + switch field.Type.Kind() { + case reflect.String: + if updatedValue.String() != "test_"+jsonKey { + t.Errorf("Round-trip failed for %q: expected %q, got %q (check UpdateSettings/GetSettings switch)", jsonKey, "test_"+jsonKey, updatedValue.String()) + } + case reflect.Int, reflect.Int64: + if updatedValue.Int() != 42 { + t.Errorf("Round-trip failed for %q: expected 42, got %d (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Int()) + } + case reflect.Bool: + if updatedValue.Bool() != false { + t.Errorf("Round-trip failed for %q: expected false, got %v (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Bool()) } - paramsValue.Field(j).Set(val.Addr()) - break } } - - if ok := s.UpdateSettings(paramsValue.Interface().(model.UpdateSettingsParams)); !ok { - t.Errorf("UpdateSettings failed for %q", jsonKey) - } - - updated := s.GetSettings() - updatedValue := reflect.ValueOf(updated).Field(i) - - switch field.Type.Kind() { - case reflect.String: - if updatedValue.String() != "test_"+jsonKey { - t.Errorf("Round-trip failed for %q: expected %q, got %q (check UpdateSettings/GetSettings switch)", jsonKey, "test_"+jsonKey, updatedValue.String()) - } - case reflect.Int, reflect.Int64: - if updatedValue.Int() != 42 { - t.Errorf("Round-trip failed for %q: expected 42, got %d (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Int()) - } - case reflect.Bool: - if updatedValue.Bool() != false { - t.Errorf("Round-trip failed for %q: expected false, got %v (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Bool()) - } - } - } + }) } diff --git a/src/storage/tests/storage_test.go b/src/storage/tests/storage_test.go index a6a85bb..0de1bb4 100644 --- a/src/storage/tests/storage_test.go +++ b/src/storage/tests/storage_test.go @@ -10,8 +10,8 @@ import ( ) func dbtest(t *testing.T, testcase func(t *testing.T, db storage.Storage)) { - testurls := map[string]string { - "sqlite": ":memory:", + testurls := map[string]string{ + "sqlite": ":memory:", "postgres": "postgres://postgres:postgres@localhost:5432/yarr_test", } for testname, url := range testurls {