wrap tests in dbtest

This commit is contained in:
nkanaev
2026-06-14 14:56:25 +01:00
parent 32cfc3bc1a
commit 4dbedb2f99
6 changed files with 652 additions and 642 deletions

View File

@@ -4,89 +4,95 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/nkanaev/yarr/src/storage"
"github.com/nkanaev/yarr/src/storage/model" "github.com/nkanaev/yarr/src/storage/model"
) )
func TestCreateFeed(t *testing.T) { func TestCreateFeed(t *testing.T) {
db := testDB() dbtest(t, func(t *testing.T, db storage.Storage) {
feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example.com/feed.xml"}) feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example.com/feed.xml"})
if feed1 == nil || feed1.Id == 0 { if feed1 == nil || feed1.Id == 0 {
t.Fatal("expected feed") t.Fatal("expected feed")
} }
feed2 := db.GetFeed(feed1.Id) feed2 := db.GetFeed(feed1.Id)
if feed2 == nil || !reflect.DeepEqual(feed1, feed2) { if feed2 == nil || !reflect.DeepEqual(feed1, feed2) {
t.Fatal("invalid feed") t.Fatal("invalid feed")
} }
})
} }
func TestCreateFeedSameLink(t *testing.T) { func TestCreateFeedSameLink(t *testing.T) {
db := testDB() dbtest(t, func(t *testing.T, db storage.Storage) {
feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", FeedLink: "http://example1.com/feed.xml"}) feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", FeedLink: "http://example1.com/feed.xml"})
if feed1 == nil || feed1.Id == 0 { if feed1 == nil || feed1.Id == 0 {
t.Fatal("expected feed") t.Fatal("expected feed")
} }
for range 10 { for range 10 {
db.CreateFeed(model.CreateFeedParams{Title: "title", FeedLink: "http://example2.com/feed.xml"}) db.CreateFeed(model.CreateFeedParams{Title: "title", FeedLink: "http://example2.com/feed.xml"})
} }
feed2 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example1.com/feed.xml"}) feed2 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example1.com/feed.xml"})
if feed1.Id != feed2.Id { if feed1.Id != feed2.Id {
t.Fatalf("expected the same feed.\nwant: %#v\nhave: %#v", feed1, feed2) t.Fatalf("expected the same feed.\nwant: %#v\nhave: %#v", feed1, feed2)
} }
})
} }
func TestReadFeed(t *testing.T) { func TestReadFeed(t *testing.T) {
db := testDB() dbtest(t, func(t *testing.T, db storage.Storage) {
if db.GetFeed(100500) != nil { if db.GetFeed(100500) != nil {
t.Fatal("cannot get nonexistent feed") t.Fatal("cannot get nonexistent feed")
} }
feed1 := db.CreateFeed(model.CreateFeedParams{Title: "feed 1", Link: "http://example1.com", FeedLink: "http://example1.com/feed.xml"}) feed1 := db.CreateFeed(model.CreateFeedParams{Title: "feed 1", Link: "http://example1.com", FeedLink: "http://example1.com/feed.xml"})
feed2 := db.CreateFeed(model.CreateFeedParams{Title: "feed 2", Link: "http://example2.com", FeedLink: "http://example2.com/feed.xml"}) feed2 := db.CreateFeed(model.CreateFeedParams{Title: "feed 2", Link: "http://example2.com", FeedLink: "http://example2.com/feed.xml"})
feeds := db.ListFeeds() feeds := db.ListFeeds()
if !reflect.DeepEqual(feeds, []model.Feed{*feed1, *feed2}) { if !reflect.DeepEqual(feeds, []model.Feed{*feed1, *feed2}) {
t.Fatalf("invalid feed list: %#v", feeds) t.Fatalf("invalid feed list: %#v", feeds)
} }
})
} }
func TestUpdateFeed(t *testing.T) { func TestUpdateFeed(t *testing.T) {
db := testDB() dbtest(t, func(t *testing.T, db storage.Storage) {
feed1 := db.CreateFeed(model.CreateFeedParams{Title: "feed 1", Link: "http://example1.com", FeedLink: "http://example1.com/feed.xml"}) feed1 := db.CreateFeed(model.CreateFeedParams{Title: "feed 1", Link: "http://example1.com", FeedLink: "http://example1.com/feed.xml"})
folder := db.CreateFolder("test") folder := db.CreateFolder("test")
icon := []byte("icon") icon := []byte("icon")
title := "newtitle" title := "newtitle"
db.UpdateFeed(feed1.Id, model.UpdateFeedParams{ db.UpdateFeed(feed1.Id, model.UpdateFeedParams{
Title: &title, Title: &title,
FolderID: model.SetNullable(&folder.Id), FolderID: model.SetNullable(&folder.Id),
Icon: model.SetNullable(&icon), Icon: model.SetNullable(&icon),
})
feed2 := db.GetFeed(feed1.Id)
if feed2.Title != "newtitle" {
t.Error("invalid title")
}
if feed2.FolderId == nil || *feed2.FolderId != folder.Id {
t.Error("invalid folder")
}
if !feed2.HasIcon || string(*feed2.Icon) != "icon" {
t.Error("invalid icon")
}
}) })
feed2 := db.GetFeed(feed1.Id)
if feed2.Title != "newtitle" {
t.Error("invalid title")
}
if feed2.FolderId == nil || *feed2.FolderId != folder.Id {
t.Error("invalid folder")
}
if !feed2.HasIcon || string(*feed2.Icon) != "icon" {
t.Error("invalid icon")
}
} }
func TestDeleteFeed(t *testing.T) { func TestDeleteFeed(t *testing.T) {
db := testDB() dbtest(t, func(t *testing.T, db storage.Storage) {
feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example.com/feed.xml"}) feed1 := db.CreateFeed(model.CreateFeedParams{Title: "title", Link: "http://example.com", FeedLink: "http://example.com/feed.xml"})
if db.DeleteFeed(100500) { if db.DeleteFeed(100500) {
t.Error("cannot delete what does not exist") t.Error("cannot delete what does not exist")
} }
if !db.DeleteFeed(feed1.Id) { if !db.DeleteFeed(feed1.Id) {
t.Fatal("did not delete existing feed") t.Fatal("did not delete existing feed")
} }
if db.GetFeed(feed1.Id) != nil { if db.GetFeed(feed1.Id) != nil {
t.Fatal("feed still exists") t.Fatal("feed still exists")
} }
})
} }

View File

@@ -4,126 +4,123 @@ import (
"testing" "testing"
"time" "time"
"github.com/nkanaev/yarr/src/storage"
"github.com/nkanaev/yarr/src/storage/model" "github.com/nkanaev/yarr/src/storage/model"
) )
func TestUpdateFeedState_Full(t *testing.T) { func TestUpdateFeedState_Full(t *testing.T) {
s := testDB() dbtest(t, func(t *testing.T, s storage.Storage) {
defer s.Close() f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"})
f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"}) now := time.Now().UTC().Truncate(time.Second)
errMsg := "error"
lmod := "today"
etag := "v1"
now := time.Now().UTC().Truncate(time.Second) ok, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{
errMsg := "error" LastRefreshed: &now,
lmod := "today" LastError: &errMsg,
etag := "v1" HTTPLastModified: &lmod,
HTTPEtag: &etag,
})
if err != nil {
t.Fatal(err)
}
if !ok {
t.Error("expected true")
}
ok, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{ state, err := s.GetFeedState(f.Id)
LastRefreshed: &now, if err != nil {
LastError: &errMsg, t.Fatal(err)
HTTPLastModified: &lmod, }
HTTPEtag: &etag, if state == nil {
t.Fatal("expected state, got nil")
}
if !state.LastRefreshed.Equal(now) {
t.Errorf("expected %v, got %v", now, state.LastRefreshed)
}
if state.LastError != errMsg {
t.Errorf("expected %s, got %v", errMsg, state.LastError)
}
if state.HTTPLastModified != lmod {
t.Errorf("expected %s, got %s", lmod, state.HTTPLastModified)
}
if state.HTTPEtag != etag {
t.Errorf("expected %s, got %s", etag, state.HTTPEtag)
}
}) })
if err != nil {
t.Fatal(err)
}
if !ok {
t.Error("expected true")
}
state, err := s.GetFeedState(f.Id)
if err != nil {
t.Fatal(err)
}
if state == nil {
t.Fatal("expected state, got nil")
}
if !state.LastRefreshed.Equal(now) {
t.Errorf("expected %v, got %v", now, state.LastRefreshed)
}
if state.LastError != errMsg {
t.Errorf("expected %s, got %v", errMsg, state.LastError)
}
if state.HTTPLastModified != lmod {
t.Errorf("expected %s, got %s", lmod, state.HTTPLastModified)
}
if state.HTTPEtag != etag {
t.Errorf("expected %s, got %s", etag, state.HTTPEtag)
}
} }
func TestUpdateFeedState_Partial(t *testing.T) { func TestUpdateFeedState_Partial(t *testing.T) {
s := testDB() dbtest(t, func(t *testing.T, s storage.Storage) {
defer s.Close() f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"})
etag := "v1"
s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{HTTPEtag: &etag})
f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"}) newErr := "new error"
etag := "v1" _, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{
s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{HTTPEtag: &etag}) LastError: &newErr,
})
if err != nil {
t.Fatal(err)
}
newErr := "new error" state, err := s.GetFeedState(f.Id)
_, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{ if err != nil {
LastError: &newErr, t.Fatal(err)
}
if state.LastError != newErr {
t.Errorf("expected %s, got %v", newErr, state.LastError)
}
if state.HTTPEtag != etag {
t.Errorf("etag should be unchanged, got %s", state.HTTPEtag)
}
}) })
if err != nil {
t.Fatal(err)
}
state, err := s.GetFeedState(f.Id)
if err != nil {
t.Fatal(err)
}
if state.LastError != newErr {
t.Errorf("expected %s, got %v", newErr, state.LastError)
}
if state.HTTPEtag != etag {
t.Errorf("etag should be unchanged, got %s", state.HTTPEtag)
}
} }
func TestUpdateFeedState_ClearError(t *testing.T) { func TestUpdateFeedState_ClearError(t *testing.T) {
s := testDB() dbtest(t, func(t *testing.T, s storage.Storage) {
defer s.Close() f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"})
errMsg := "error"
s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{LastError: &errMsg})
f := s.CreateFeed(model.CreateFeedParams{Title: "Test", FeedLink: "http://example.com"}) empty := ""
errMsg := "error" _, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{
s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{LastError: &errMsg}) LastError: &empty,
})
if err != nil {
t.Fatal(err)
}
empty := "" state, err := s.GetFeedState(f.Id)
_, err := s.UpdateFeedState(f.Id, model.UpdateFeedStateParams{ if err != nil {
LastError: &empty, t.Fatal(err)
}
if state.LastError != "" {
t.Errorf("expected empty error string, got %v", state.LastError)
}
}) })
if err != nil {
t.Fatal(err)
}
state, err := s.GetFeedState(f.Id)
if err != nil {
t.Fatal(err)
}
if state.LastError != "" {
t.Errorf("expected empty error string, got %v", state.LastError)
}
} }
func TestListFeedStates(t *testing.T) { func TestListFeedStates(t *testing.T) {
s := testDB() dbtest(t, func(t *testing.T, s storage.Storage) {
defer s.Close() f1 := s.CreateFeed(model.CreateFeedParams{Title: "F1", FeedLink: "L1"})
f2 := s.CreateFeed(model.CreateFeedParams{Title: "F2", FeedLink: "L2"})
f1 := s.CreateFeed(model.CreateFeedParams{Title: "F1", FeedLink: "L1"}) errMsg := "fail"
f2 := s.CreateFeed(model.CreateFeedParams{Title: "F2", FeedLink: "L2"}) s.UpdateFeedState(f1.Id, model.UpdateFeedStateParams{LastError: &errMsg})
s.UpdateFeedState(f2.Id, model.UpdateFeedStateParams{HTTPEtag: ptr("e")})
errMsg := "fail" states, err := s.ListFeedStates()
s.UpdateFeedState(f1.Id, model.UpdateFeedStateParams{LastError: &errMsg}) if err != nil {
s.UpdateFeedState(f2.Id, model.UpdateFeedStateParams{HTTPEtag: ptr("e")}) t.Fatal(err)
}
states, err := s.ListFeedStates() if len(states) != 2 {
if err != nil { t.Errorf("expected 2 states, got %d", len(states))
t.Fatal(err) }
} })
if len(states) != 2 {
t.Errorf("expected 2 states, got %d", len(states))
}
} }
func ptr[T any](v T) *T { func ptr[T any](v T) *T {

View File

@@ -3,78 +3,80 @@ package tests
import ( import (
"testing" "testing"
"github.com/nkanaev/yarr/src/storage"
"github.com/nkanaev/yarr/src/storage/model" "github.com/nkanaev/yarr/src/storage/model"
) )
func TestUpdateFolder(t *testing.T) { func TestUpdateFolder(t *testing.T) {
db := testDB() dbtest(t, func(t *testing.T, db storage.Storage) {
folder := db.CreateFolder("old title") folder := db.CreateFolder("old title")
if folder.IsExpanded != true { if folder.IsExpanded != true {
t.Fatal("expected folder to be expanded by default") t.Fatal("expected folder to be expanded by default")
} }
t.Run("rename only", func(t *testing.T) { t.Run("rename only", func(t *testing.T) {
newTitle := "new title" newTitle := "new title"
ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{ ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{
Title: &newTitle, Title: &newTitle,
})
if !ok || err != nil {
t.Fatalf("UpdateFolder failed: %v", err)
}
folders := db.ListFolders()
if len(folders) != 1 || folders[0].Title != "new title" {
t.Errorf("expected title to be updated, got %s", folders[0].Title)
}
if folders[0].IsExpanded != true {
t.Error("expected expansion state to remain unchanged")
}
}) })
if !ok || err != nil {
t.Fatalf("UpdateFolder failed: %v", err)
}
folders := db.ListFolders() t.Run("toggle expanded only", func(t *testing.T) {
if len(folders) != 1 || folders[0].Title != "new title" { isExpanded := false
t.Errorf("expected title to be updated, got %s", folders[0].Title) ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{
} IsExpanded: &isExpanded,
if folders[0].IsExpanded != true { })
t.Error("expected expansion state to remain unchanged") if !ok || err != nil {
} t.Fatalf("UpdateFolder failed: %v", err)
}) }
t.Run("toggle expanded only", func(t *testing.T) { folders := db.ListFolders()
isExpanded := false if len(folders) != 1 || folders[0].IsExpanded != false {
ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{ t.Errorf("expected is_expanded to be false, got %v", folders[0].IsExpanded)
IsExpanded: &isExpanded, }
if folders[0].Title != "new title" {
t.Error("expected title to remain unchanged")
}
}) })
if !ok || err != nil {
t.Fatalf("UpdateFolder failed: %v", err)
}
folders := db.ListFolders() t.Run("update both", func(t *testing.T) {
if len(folders) != 1 || folders[0].IsExpanded != false { bothTitle := "both"
t.Errorf("expected is_expanded to be false, got %v", folders[0].IsExpanded) isExpanded := true
} ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{
if folders[0].Title != "new title" { Title: &bothTitle,
t.Error("expected title to remain unchanged") IsExpanded: &isExpanded,
} })
}) if !ok || err != nil {
t.Fatalf("UpdateFolder failed: %v", err)
}
t.Run("update both", func(t *testing.T) { folders := db.ListFolders()
bothTitle := "both" if len(folders) != 1 || folders[0].Title != "both" || folders[0].IsExpanded != true {
isExpanded := true t.Errorf("expected both to be updated, got title=%s expanded=%v", folders[0].Title, folders[0].IsExpanded)
ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{ }
Title: &bothTitle,
IsExpanded: &isExpanded,
}) })
if !ok || err != nil {
t.Fatalf("UpdateFolder failed: %v", err)
}
folders := db.ListFolders() t.Run("update none", func(t *testing.T) {
if len(folders) != 1 || folders[0].Title != "both" || folders[0].IsExpanded != true { ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{})
t.Errorf("expected both to be updated, got title=%s expanded=%v", folders[0].Title, folders[0].IsExpanded) if !ok || err != nil {
} t.Fatalf("UpdateFolder failed: %v", err)
}) }
t.Run("update none", func(t *testing.T) { folders := db.ListFolders()
ok, err := db.UpdateFolder(folder.Id, model.UpdateFolderParams{}) if len(folders) != 1 || folders[0].Title != "both" || folders[0].IsExpanded != true {
if !ok || err != nil { t.Errorf("expected no changes, got title=%s expanded=%v", folders[0].Title, folders[0].IsExpanded)
t.Fatalf("UpdateFolder failed: %v", err) }
} })
folders := db.ListFolders()
if len(folders) != 1 || folders[0].Title != "both" || folders[0].IsExpanded != true {
t.Errorf("expected no changes, got title=%s expanded=%v", folders[0].Title, folders[0].IsExpanded)
}
}) })
} }

View File

@@ -9,6 +9,7 @@ import (
"testing/synctest" "testing/synctest"
"time" "time"
"github.com/nkanaev/yarr/src/storage"
"github.com/nkanaev/yarr/src/storage/model" "github.com/nkanaev/yarr/src/storage/model"
) )
@@ -144,374 +145,378 @@ func getItemGuids(items []model.Item) []string {
} }
func TestListItems(t *testing.T) { func TestListItems(t *testing.T) {
db := testDB() dbtest(t, func(t *testing.T, db storage.Storage) {
scope := testItemsSetup(db) scope := testItemsSetup(db)
// filter by folder_id // filter by folder_id
have := getItemGuids(db.ListItems(model.ItemFilter{FolderID: &scope.folder1.Id}, 10, false, false)) have := getItemGuids(db.ListItems(model.ItemFilter{FolderID: &scope.folder1.Id}, 10, false, false))
want := []string{"item111", "item112", "item113", "item121", "item122"} want := []string{"item111", "item112", "item113", "item121", "item122"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
have = getItemGuids(db.ListItems(model.ItemFilter{FolderID: &scope.folder2.Id}, 10, false, false)) have = getItemGuids(db.ListItems(model.ItemFilter{FolderID: &scope.folder2.Id}, 10, false, false))
want = []string{"item211", "item212"} want = []string{"item211", "item212"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
// filter by feed_id // filter by feed_id
have = getItemGuids(db.ListItems(model.ItemFilter{FeedID: &scope.feed11.Id}, 10, false, false)) have = getItemGuids(db.ListItems(model.ItemFilter{FeedID: &scope.feed11.Id}, 10, false, false))
want = []string{"item111", "item112", "item113"} want = []string{"item111", "item112", "item113"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
have = getItemGuids(db.ListItems(model.ItemFilter{FeedID: &scope.feed01.Id}, 10, false, false)) have = getItemGuids(db.ListItems(model.ItemFilter{FeedID: &scope.feed01.Id}, 10, false, false))
want = []string{"item011", "item012", "item013"} want = []string{"item011", "item012", "item013"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
// filter by status // filter by status
var starred model.ItemStatus = model.STARRED var starred model.ItemStatus = model.STARRED
have = getItemGuids(db.ListItems(model.ItemFilter{Status: &starred}, 10, false, false)) have = getItemGuids(db.ListItems(model.ItemFilter{Status: &starred}, 10, false, false))
want = []string{"item113", "item212", "item013"} want = []string{"item113", "item212", "item013"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
var unread model.ItemStatus = model.UNREAD var unread model.ItemStatus = model.UNREAD
have = getItemGuids(db.ListItems(model.ItemFilter{Status: &unread}, 10, false, false)) have = getItemGuids(db.ListItems(model.ItemFilter{Status: &unread}, 10, false, false))
want = []string{"item111", "item121", "item011"} want = []string{"item111", "item121", "item011"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
// limit // limit
have = getItemGuids(db.ListItems(model.ItemFilter{}, 2, false, false)) have = getItemGuids(db.ListItems(model.ItemFilter{}, 2, false, false))
want = []string{"item111", "item112"} want = []string{"item111", "item112"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
// filter by search // filter by search
search1 := "title111" search1 := "title111"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &search1}, 4, true, false)) have = getItemGuids(db.ListItems(model.ItemFilter{Search: &search1}, 4, true, false))
want = []string{"item111"} want = []string{"item111"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
// sort by date // sort by date
have = getItemGuids(db.ListItems(model.ItemFilter{}, 4, true, false)) have = getItemGuids(db.ListItems(model.ItemFilter{}, 4, true, false))
want = []string{"item013", "item012", "item011", "item212"} want = []string{"item013", "item012", "item011", "item212"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
})
} }
func TestListItemsPaginated(t *testing.T) { func TestListItemsPaginated(t *testing.T) {
db := testDB() dbtest(t, func(t *testing.T, db storage.Storage) {
testItemsSetup(db) testItemsSetup(db)
item012 := getItem(db, "item012") item012 := getItem(db, "item012")
item121 := getItem(db, "item121") item121 := getItem(db, "item121")
// all, newest first // all, newest first
have := getItemGuids(db.ListItems(model.ItemFilter{After: &item012.Id}, 3, true, false)) have := getItemGuids(db.ListItems(model.ItemFilter{After: &item012.Id}, 3, true, false))
want := []string{"item011", "item212", "item211"} want := []string{"item011", "item212", "item211"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
// unread, newest first // unread, newest first
unread := model.UNREAD unread := model.UNREAD
have = getItemGuids( have = getItemGuids(
db.ListItems(model.ItemFilter{After: &item012.Id, Status: &unread}, 3, true, false), db.ListItems(model.ItemFilter{After: &item012.Id, Status: &unread}, 3, true, false),
) )
want = []string{"item011", "item121", "item111"} want = []string{"item011", "item121", "item111"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
// starred, oldest first // starred, oldest first
starred := model.STARRED starred := model.STARRED
have = getItemGuids( have = getItemGuids(
db.ListItems(model.ItemFilter{After: &item121.Id, Status: &starred}, 3, false, false), db.ListItems(model.ItemFilter{After: &item121.Id, Status: &starred}, 3, false, false),
) )
want = []string{"item212", "item013"} want = []string{"item212", "item013"}
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
})
} }
func TestMarkItemsRead(t *testing.T) { func TestMarkItemsRead(t *testing.T) {
// NOTE: starred items must not be marked as read // NOTE: starred items must not be marked as read
var read model.ItemStatus = model.READ var read model.ItemStatus = model.READ
db1 := testDB() dbtest(t, func(t *testing.T, db1 storage.Storage) {
testItemsSetup(db1) testItemsSetup(db1)
db1.MarkItemsRead(model.MarkFilter{}) db1.MarkItemsRead(model.MarkFilter{})
have := getItemGuids(db1.ListItems(model.ItemFilter{Status: &read}, 10, false, false)) have := getItemGuids(db1.ListItems(model.ItemFilter{Status: &read}, 10, false, false))
want := []string{ want := []string{
"item111", "item112", "item121", "item122", "item111", "item112", "item121", "item122",
"item211", "item011", "item012", "item211", "item011", "item012",
} }
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
})
db2 := testDB() dbtest(t, func(t *testing.T, db2 storage.Storage) {
scope2 := testItemsSetup(db2) scope2 := testItemsSetup(db2)
db2.MarkItemsRead(model.MarkFilter{FolderID: &scope2.folder1.Id}) db2.MarkItemsRead(model.MarkFilter{FolderID: &scope2.folder1.Id})
have = getItemGuids(db2.ListItems(model.ItemFilter{Status: &read}, 10, false, false)) have = getItemGuids(db2.ListItems(model.ItemFilter{Status: &read}, 10, false, false))
want = []string{ want = []string{
"item111", "item112", "item121", "item122", "item111", "item112", "item121", "item122",
"item211", "item012", "item211", "item012",
} }
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
})
db3 := testDB() dbtest(t, func(t *testing.T, db3 storage.Storage) {
scope3 := testItemsSetup(db3) scope3 := testItemsSetup(db3)
db3.MarkItemsRead(model.MarkFilter{FeedID: &scope3.feed11.Id}) db3.MarkItemsRead(model.MarkFilter{FeedID: &scope3.feed11.Id})
have = getItemGuids(db3.ListItems(model.ItemFilter{Status: &read}, 10, false, false)) have = getItemGuids(db3.ListItems(model.ItemFilter{Status: &read}, 10, false, false))
want = []string{ want = []string{
"item111", "item112", "item122", "item111", "item112", "item122",
"item211", "item012", "item211", "item012",
} }
if !reflect.DeepEqual(have, want) { if !reflect.DeepEqual(have, want) {
t.Logf("want: %#v", want) t.Logf("want: %#v", want)
t.Logf("have: %#v", have) t.Logf("have: %#v", have)
t.Fail() t.Fail()
} }
})
} }
func TestDeleteOldItems(t *testing.T) { func TestDeleteOldItems(t *testing.T) {
now := time.Now().UTC() now := time.Now().UTC()
starred := model.STARRED starred := model.STARRED
dbtest(t, func(t *testing.T, db storage.Storage) {
t.Run("keeps at least 50 items", func(t *testing.T) { t.Run("keeps at least 50 items", func(t *testing.T) {
db := testDB() feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"})
feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) items := make([]model.Item, 100)
items := make([]model.Item, 100) for i := range 100 {
for i := range 100 { items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Hour * 24)}
items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Hour * 24)} }
} db.CreateItems(items)
db.CreateItems(items)
// // Set 1 recent (latest), 100 old (100 days ago) // // Set 1 recent (latest), 100 old (100 days ago)
db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now)) db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now))
db.db.Exec(`update items set last_arrived = :la where guid != "99"`, sql.Named("la", now.Add(-time.Hour*24*100))) db.db.Exec(`update items set last_arrived = :la where guid != "99"`, sql.Named("la", now.Add(-time.Hour*24*100)))
db.DeleteOldItems() db.DeleteOldItems()
var have int var have int
db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have) db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have)
if have != 50 { if have != 50 {
t.Errorf("expected 50 items, have %d", have) t.Errorf("expected 50 items, have %d", have)
} }
}) })
t.Run("keeps all less than 90 days old", func(t *testing.T) { t.Run("keeps all less than 90 days old", func(t *testing.T) {
db := testDB() feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"})
feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) items := make([]model.Item, 100)
items := make([]model.Item, 100) for i := 0; i < 100; i++ {
for i := 0; i < 100; i++ { items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Second)}
items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Second)} }
} db.CreateItems(items)
db.CreateItems(items)
// Latest item at "now" // Latest item at "now"
// All others at 80 days ago (keep) // All others at 80 days ago (keep)
db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now)) db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now))
db.db.Exec(`update items set last_arrived = :la where guid != "99"`, sql.Named("la", now.Add(-time.Hour*24*80))) db.db.Exec(`update items set last_arrived = :la where guid != "99"`, sql.Named("la", now.Add(-time.Hour*24*80)))
db.DeleteOldItems() db.DeleteOldItems()
var have int var have int
db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have) db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have)
if have != 100 { if have != 100 {
t.Errorf("expected 100 items, have %d", have) t.Errorf("expected 100 items, have %d", have)
} }
}) })
t.Run("keeps starred", func(t *testing.T) { t.Run("keeps starred", func(t *testing.T) {
db := testDB() feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"})
feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"}) items := make([]model.Item, 100)
items := make([]model.Item, 100) for i := 0; i < 100; i++ {
for i := 0; i < 100; i++ { items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Second)}
items[i] = model.Item{GUID: strconv.Itoa(i), FeedId: feed.Id, Date: now.Add(time.Duration(i) * time.Second)} }
} db.CreateItems(items)
db.CreateItems(items)
// Set all to 100 days ago, except one recent // Set all to 100 days ago, except one recent
db.db.Exec(`update items set last_arrived = :la`, sql.Named("la", now.Add(-time.Hour*24*100))) db.db.Exec(`update items set last_arrived = :la`, sql.Named("la", now.Add(-time.Hour*24*100)))
db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now)) db.db.Exec(`update items set last_arrived = :la where guid = "99"`, sql.Named("la", now))
// Star 10 old items that would otherwise be deleted (rn > 50 and old) // Star 10 old items that would otherwise be deleted (rn > 50 and old)
db.db.Exec(`update items set status = :s where cast(guid as integer) < 10`, sql.Named("s", starred)) db.db.Exec(`update items set status = :s where cast(guid as integer) < 10`, sql.Named("s", starred))
db.DeleteOldItems() db.DeleteOldItems()
var have int var have int
db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have) db.db.QueryRow("select count(*) from items where feed_id = ?", feed.Id).Scan(&have)
// 50 (limit) + 10 (starred) = 60 items should remain. // 50 (limit) + 10 (starred) = 60 items should remain.
if have != 60 { if have != 60 {
t.Errorf("expected 60 items, have %d", have) t.Errorf("expected 60 items, have %d", have)
} }
})
}) })
} }
func TestCreateItemsLastArrived(t *testing.T) { func TestCreateItemsLastArrived(t *testing.T) {
synctest.Test(t, func(t *testing.T) { dbtest(t, func(t *testing.T, db storage.Storage) {
db := testDB() synctest.Test(t, func(t *testing.T) {
defer db.db.Close() db := testDB()
feed := db.CreateFeed(model.CreateFeedParams{Title: "test feed", FeedLink: "http://example.com/feed"}) defer db.db.Close()
feed := db.CreateFeed(model.CreateFeedParams{Title: "test feed", FeedLink: "http://example.com/feed"})
item := model.Item{ item := model.Item{
GUID: "item1", GUID: "item1",
FeedId: feed.Id, FeedId: feed.Id,
Title: "Title 1", Title: "Title 1",
Date: time.Now(), Date: time.Now(),
} }
// 1. Initial creation // 1. Initial creation
db.CreateItems([]model.Item{item}) db.CreateItems([]model.Item{item})
var lastArrived1 time.Time var lastArrived1 time.Time
err := db.db.QueryRow("select last_arrived from items where guid = ?", item.GUID).Scan(&lastArrived1) err := db.db.QueryRow("select last_arrived from items where guid = ?", item.GUID).Scan(&lastArrived1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
// 2. Update on conflict // 2. Update on conflict
db.CreateItems([]model.Item{item}) db.CreateItems([]model.Item{item})
var lastArrived2 time.Time var lastArrived2 time.Time
err = db.db.QueryRow("select last_arrived from items where guid = ?", item.GUID).Scan(&lastArrived2) err = db.db.QueryRow("select last_arrived from items where guid = ?", item.GUID).Scan(&lastArrived2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !lastArrived2.After(lastArrived1) { if !lastArrived2.After(lastArrived1) {
t.Errorf("expected last_arrived to be updated. old: %v, new: %v", lastArrived1, lastArrived2) t.Errorf("expected last_arrived to be updated. old: %v, new: %v", lastArrived1, lastArrived2)
} }
})
}) })
} }
func TestSearch(t *testing.T) { func TestSearch(t *testing.T) {
db := testDB() dbtest(t, func(t *testing.T, db storage.Storage) {
defer db.Close() feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"})
feed := db.CreateFeed(model.CreateFeedParams{Title: "f", FeedLink: "http://f.xml"})
db.CreateItems([]model.Item{ db.CreateItems([]model.Item{
{ {
GUID: "i1", GUID: "i1",
FeedId: feed.Id, FeedId: feed.Id,
Title: "Hello World", Title: "Hello World",
Content: "This is a <b>test</b> of the <i>emergency</i> broadcast system.", Content: "This is a <b>test</b> of the <i>emergency</i> broadcast system.",
}, },
{ {
GUID: "i2", GUID: "i2",
FeedId: feed.Id, FeedId: feed.Id,
Title: "FTS5 Unicode", Title: "FTS5 Unicode",
Content: "Unicode support with characters like: Привет, 世界, 🚀", Content: "Unicode support with characters like: Привет, 世界, 🚀",
}, },
{ {
GUID: "i3", GUID: "i3",
FeedId: feed.Id, FeedId: feed.Id,
Title: "Hidden Tag", Title: "Hidden Tag",
Content: `<div class="secret-class">Don't find me by my class name</div>`, Content: `<div class="secret-class">Don't find me by my class name</div>`,
}, },
})
// 1. Basic search
s1 := "emergency"
have := getItemGuids(db.ListItems(model.ItemFilter{Search: &s1}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i1"}) {
t.Errorf("basic search failed: expected [i1], got %v", have)
}
// 2. HTML stripping: Should find text, but NOT the tags
s2 := "test"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s2}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i1"}) {
t.Errorf("html text search failed: expected [i1], got %v", have)
}
s3 := "secret-class"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s3}, 10, true, false))
if len(have) > 0 {
t.Errorf("html tag search should have failed but found: %v", have)
}
// 3. Multi-word (AND)
s4 := "broadcast system"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s4}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i1"}) {
t.Errorf("multi-word search failed: expected [i1], got %v", have)
}
// 4. Unicode
s5 := "Привет"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s5}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i2"}) {
t.Errorf("unicode search failed: expected [i2], got %v", have)
}
s6 := "世界"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s6}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i2"}) {
t.Errorf("unicode search (CJK) failed: expected [i2], got %v", have)
}
// 5. Trigger: Update
db.db.Exec("update items set title = 'Updated Title' where guid = 'i1'")
s7 := "Updated"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s7}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i1"}) {
t.Errorf("update trigger failed: expected [i1], got %v", have)
}
// 6. Trigger: Delete
db.db.Exec("delete from items where guid = 'i1'")
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s7}, 10, true, false))
if len(have) > 0 {
t.Errorf("delete trigger failed: found deleted item: %v", have)
}
}) })
// 1. Basic search
s1 := "emergency"
have := getItemGuids(db.ListItems(model.ItemFilter{Search: &s1}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i1"}) {
t.Errorf("basic search failed: expected [i1], got %v", have)
}
// 2. HTML stripping: Should find text, but NOT the tags
s2 := "test"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s2}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i1"}) {
t.Errorf("html text search failed: expected [i1], got %v", have)
}
s3 := "secret-class"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s3}, 10, true, false))
if len(have) > 0 {
t.Errorf("html tag search should have failed but found: %v", have)
}
// 3. Multi-word (AND)
s4 := "broadcast system"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s4}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i1"}) {
t.Errorf("multi-word search failed: expected [i1], got %v", have)
}
// 4. Unicode
s5 := "Привет"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s5}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i2"}) {
t.Errorf("unicode search failed: expected [i2], got %v", have)
}
s6 := "世界"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s6}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i2"}) {
t.Errorf("unicode search (CJK) failed: expected [i2], got %v", have)
}
// 5. Trigger: Update
db.db.Exec("update items set title = 'Updated Title' where guid = 'i1'")
s7 := "Updated"
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s7}, 10, true, false))
if !reflect.DeepEqual(have, []string{"i1"}) {
t.Errorf("update trigger failed: expected [i1], got %v", have)
}
// 6. Trigger: Delete
db.db.Exec("delete from items where guid = 'i1'")
have = getItemGuids(db.ListItems(model.ItemFilter{Search: &s7}, 10, true, false))
if len(have) > 0 {
t.Errorf("delete trigger failed: found deleted item: %v", have)
}
} }

View File

@@ -5,148 +5,148 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/nkanaev/yarr/src/storage"
"github.com/nkanaev/yarr/src/storage/model" "github.com/nkanaev/yarr/src/storage/model"
) )
func TestSettingsDefaults(t *testing.T) { func TestSettingsDefaults(t *testing.T) {
s := testDB() dbtest(t, func(t *testing.T, s storage.Storage) {
defer s.Close() settings := s.GetSettings()
defaults := settingsDefaults()
settings := s.GetSettings() if !reflect.DeepEqual(settings, defaults) {
defaults := settingsDefaults() t.Errorf("expected defaults %+v, got %+v", defaults, settings)
}
if !reflect.DeepEqual(settings, defaults) { })
t.Errorf("expected defaults %+v, got %+v", defaults, settings)
}
} }
func TestUpdateSettings(t *testing.T) { func TestUpdateSettings(t *testing.T) {
s := testDB() dbtest(t, func(t *testing.T, s storage.Storage) {
defer s.Close()
params := model.UpdateSettingsParams{ params := model.UpdateSettingsParams{
ThemeName: ptr("night"), ThemeName: ptr("night"),
FeedListWidth: ptr(400), FeedListWidth: ptr(400),
RefreshRate: ptr(int64(15)), RefreshRate: ptr(int64(15)),
} }
if ok := s.UpdateSettings(params); !ok { if ok := s.UpdateSettings(params); !ok {
t.Fatal("UpdateSettings failed") t.Fatal("UpdateSettings failed")
} }
settings := s.GetSettings() settings := s.GetSettings()
if settings.ThemeName != "night" { if settings.ThemeName != "night" {
t.Errorf("expected theme_name night, got %s", settings.ThemeName) t.Errorf("expected theme_name night, got %s", settings.ThemeName)
} }
if settings.FeedListWidth != 400 { if settings.FeedListWidth != 400 {
t.Errorf("expected feed_list_width 400, got %d", settings.FeedListWidth) t.Errorf("expected feed_list_width 400, got %d", settings.FeedListWidth)
} }
if settings.RefreshRate != 15 { if settings.RefreshRate != 15 {
t.Errorf("expected refresh_rate 15, got %d", settings.RefreshRate) t.Errorf("expected refresh_rate 15, got %d", settings.RefreshRate)
} }
})
} }
func TestGetSettings(t *testing.T) { func TestGetSettings(t *testing.T) {
s := testDB() dbtest(t, func(t *testing.T, s storage.Storage) {
defer s.Close()
s.UpdateSettings(model.UpdateSettingsParams{Language: ptr("fr")}) s.UpdateSettings(model.UpdateSettingsParams{Language: ptr("fr")})
settings := s.GetSettings() settings := s.GetSettings()
if settings.Language != "fr" { if settings.Language != "fr" {
t.Errorf("expected fr, got %v", settings.Language) t.Errorf("expected fr, got %v", settings.Language)
} }
if settings.ThemeName != "light" { if settings.ThemeName != "light" {
t.Errorf("expected light, got %v", settings.ThemeName) t.Errorf("expected light, got %v", settings.ThemeName)
} }
})
} }
func TestSettingsExhaustive(t *testing.T) { func TestSettingsExhaustive(t *testing.T) {
s := testDB() dbtest(t, func(t *testing.T, s storage.Storage) {
defer s.Close()
settingsType := reflect.TypeOf(model.Settings{}) settingsType := reflect.TypeOf(model.Settings{})
paramsType := reflect.TypeOf(model.UpdateSettingsParams{}) paramsType := reflect.TypeOf(model.UpdateSettingsParams{})
settings := s.GetSettings() settings := s.GetSettings()
m := settings.Map() m := settings.Map()
for i := 0; i < settingsType.NumField(); i++ { for i := 0; i < settingsType.NumField(); i++ {
field := settingsType.Field(i) field := settingsType.Field(i)
jsonTag := field.Tag.Get("json") jsonTag := field.Tag.Get("json")
if jsonTag == "" { if jsonTag == "" {
t.Errorf("Field %s missing json tag", field.Name) t.Errorf("Field %s missing json tag", field.Name)
continue continue
} }
// json tags might have options like "name,omitempty", take only the first part // json tags might have options like "name,omitempty", take only the first part
jsonKey := strings.Split(jsonTag, ",")[0] jsonKey := strings.Split(jsonTag, ",")[0]
// 1. Check Map() // 1. Check Map()
if _, ok := m[jsonKey]; !ok { if _, ok := m[jsonKey]; !ok {
t.Errorf("Key %q (from field %s) missing from Settings.Map()", jsonKey, field.Name) t.Errorf("Key %q (from field %s) missing from Settings.Map()", jsonKey, field.Name)
} }
// 2. Check UpdateSettingsParams // 2. Check UpdateSettingsParams
foundInParams := false foundInParams := false
for j := 0; j < paramsType.NumField(); j++ { for j := 0; j < paramsType.NumField(); j++ {
pField := paramsType.Field(j) pField := paramsType.Field(j)
pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0] pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0]
if pJsonTag == jsonKey { if pJsonTag == jsonKey {
foundInParams = true foundInParams = true
// Also check it's a pointer // Also check it's a pointer
if pField.Type.Kind() != reflect.Ptr { if pField.Type.Kind() != reflect.Ptr {
t.Errorf("Field %s in UpdateSettingsParams should be a pointer", pField.Name) t.Errorf("Field %s in UpdateSettingsParams should be a pointer", pField.Name)
}
break
} }
break
} }
} if !foundInParams {
if !foundInParams { t.Errorf("Key %q (from field %s) missing from UpdateSettingsParams", jsonKey, field.Name)
t.Errorf("Key %q (from field %s) missing from UpdateSettingsParams", jsonKey, field.Name) }
}
// 3. Test round-trip update // 3. Test round-trip update
// We'll create a new UpdateSettingsParams and set ONLY this field // We'll create a new UpdateSettingsParams and set ONLY this field
paramsValue := reflect.New(paramsType).Elem() paramsValue := reflect.New(paramsType).Elem()
for j := 0; j < paramsType.NumField(); j++ { for j := 0; j < paramsType.NumField(); j++ {
pField := paramsType.Field(j) pField := paramsType.Field(j)
pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0] pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0]
if pJsonTag == jsonKey { if pJsonTag == jsonKey {
// Create a new value of the underlying type // Create a new value of the underlying type
val := reflect.New(field.Type).Elem() val := reflect.New(field.Type).Elem()
switch field.Type.Kind() { switch field.Type.Kind() {
case reflect.String: case reflect.String:
val.SetString("test_" + jsonKey) val.SetString("test_" + jsonKey)
case reflect.Int, reflect.Int64: case reflect.Int, reflect.Int64:
val.SetInt(42) val.SetInt(42)
case reflect.Bool: case reflect.Bool:
val.SetBool(false) val.SetBool(false)
}
paramsValue.Field(j).Set(val.Addr())
break
}
}
if ok := s.UpdateSettings(paramsValue.Interface().(model.UpdateSettingsParams)); !ok {
t.Errorf("UpdateSettings failed for %q", jsonKey)
}
updated := s.GetSettings()
updatedValue := reflect.ValueOf(updated).Field(i)
switch field.Type.Kind() {
case reflect.String:
if updatedValue.String() != "test_"+jsonKey {
t.Errorf("Round-trip failed for %q: expected %q, got %q (check UpdateSettings/GetSettings switch)", jsonKey, "test_"+jsonKey, updatedValue.String())
}
case reflect.Int, reflect.Int64:
if updatedValue.Int() != 42 {
t.Errorf("Round-trip failed for %q: expected 42, got %d (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Int())
}
case reflect.Bool:
if updatedValue.Bool() != false {
t.Errorf("Round-trip failed for %q: expected false, got %v (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Bool())
} }
paramsValue.Field(j).Set(val.Addr())
break
} }
} }
})
if ok := s.UpdateSettings(paramsValue.Interface().(model.UpdateSettingsParams)); !ok {
t.Errorf("UpdateSettings failed for %q", jsonKey)
}
updated := s.GetSettings()
updatedValue := reflect.ValueOf(updated).Field(i)
switch field.Type.Kind() {
case reflect.String:
if updatedValue.String() != "test_"+jsonKey {
t.Errorf("Round-trip failed for %q: expected %q, got %q (check UpdateSettings/GetSettings switch)", jsonKey, "test_"+jsonKey, updatedValue.String())
}
case reflect.Int, reflect.Int64:
if updatedValue.Int() != 42 {
t.Errorf("Round-trip failed for %q: expected 42, got %d (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Int())
}
case reflect.Bool:
if updatedValue.Bool() != false {
t.Errorf("Round-trip failed for %q: expected false, got %v (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Bool())
}
}
}
} }

View File

@@ -10,8 +10,8 @@ import (
) )
func dbtest(t *testing.T, testcase func(t *testing.T, db storage.Storage)) { func dbtest(t *testing.T, testcase func(t *testing.T, db storage.Storage)) {
testurls := map[string]string { testurls := map[string]string{
"sqlite": ":memory:", "sqlite": ":memory:",
"postgres": "postgres://postgres:postgres@localhost:5432/yarr_test", "postgres": "postgres://postgres:postgres@localhost:5432/yarr_test",
} }
for testname, url := range testurls { for testname, url := range testurls {