diff --git a/src/storage/feed.go b/src/storage/feed.go index a415955..b77abc5 100644 --- a/src/storage/feed.go +++ b/src/storage/feed.go @@ -20,20 +20,21 @@ func (s *Storage) CreateFeed(title, description, link, feedLink string, folderId if title == "" { title = feedLink } - result, err := s.db.Exec(` + row := s.db.QueryRow(` insert into feeds (title, description, link, feed_link, folder_id) values (?, ?, ?, ?, ?) - on conflict (feed_link) do update set folder_id=?`, + on conflict (feed_link) do update set folder_id = ? + returning id`, title, description, link, feedLink, folderId, folderId, ) - if err != nil { - return nil - } - id, idErr := result.LastInsertId() - if idErr != nil { - return nil - } + + var id int64 + err := row.Scan(&id) + if err != nil { + log.Print(err) + return nil + } return &Feed{ Id: id, Title: title, diff --git a/src/storage/feed_test.go b/src/storage/feed_test.go index 149be2a..2492757 100644 --- a/src/storage/feed_test.go +++ b/src/storage/feed_test.go @@ -17,6 +17,23 @@ func TestCreateFeed(t *testing.T) { } } +func TestCreateFeedSameLink(t *testing.T) { + db := testDB() + feed1 := db.CreateFeed("title", "", "", "http://example1.com/feed.xml", nil) + if feed1 == nil || feed1.Id == 0 { + t.Fatal("expected feed") + } + + for i := 0; i < 10; i++ { + db.CreateFeed("title", "", "", "http://example2.com/feed.xml", nil) + } + + feed2 := db.CreateFeed("title", "", "http://example.com", "http://example1.com/feed.xml", nil) + if feed1.Id != feed2.Id { + t.Fatalf("expected the same feed.\nwant: %#v\nhave: %#v", feed1, feed2) + } +} + func TestReadFeed(t *testing.T) { db := testDB() if db.GetFeed(100500) != nil { diff --git a/src/storage/folder.go b/src/storage/folder.go index a533653..eb31ecb 100644 --- a/src/storage/folder.go +++ b/src/storage/folder.go @@ -1,7 +1,6 @@ package storage import ( - "fmt" "log" ) @@ -13,35 +12,21 @@ type Folder struct { func (s *Storage) CreateFolder(title string) *Folder { expanded := true - result, err := s.db.Exec(` + row := s.db.QueryRow(` insert into folders (title, is_expanded) values (?, ?) - on conflict (title) do nothing`, + on conflict (title) do update set title = ? + returning id`, title, expanded, + // provide title again so that we can extract row id + title, ) - if err != nil { - fmt.Println(err) - return nil - } + var id int64 + err := row.Scan(&id) - var id int64 - numrows, err := result.RowsAffected() if err != nil { log.Print(err) return nil } - if numrows == 1 { - id, err = result.LastInsertId() - if err != nil { - log.Print(err) - return nil - } - } else { - err = s.db.QueryRow(`select id, is_expanded from folders where title=?`, title).Scan(&id, &expanded) - if err != nil { - log.Print(err) - return nil - } - } return &Folder{Id: id, Title: title, IsExpanded: expanded} }