From cbe1f971a5273dbd96c521a622cc086af5958a49 Mon Sep 17 00:00:00 2001 From: nkanaev Date: Sat, 25 Apr 2026 21:19:18 +0100 Subject: [PATCH] refactor: use sql named arg --- src/storage/feed.go | 47 ++++++++++++------ src/storage/folder.go | 22 +++++---- src/storage/http.go | 16 +++---- src/storage/item.go | 100 ++++++++++++++++++++++----------------- src/storage/item_test.go | 15 +++--- src/storage/migration.go | 5 +- src/storage/settings.go | 10 ++-- 7 files changed, 129 insertions(+), 86 deletions(-) diff --git a/src/storage/feed.go b/src/storage/feed.go index b829f03..5588bae 100644 --- a/src/storage/feed.go +++ b/src/storage/feed.go @@ -22,11 +22,14 @@ func (s *Storage) CreateFeed(title, description, link, feedLink string, folderId } row := s.db.QueryRow(` insert into feeds (title, description, link, feed_link, folder_id) - values (?, ?, ?, ?, ?) - on conflict (feed_link) do update set folder_id = ? + values (:title, :description, :link, :feed_link, :folder_id) + on conflict (feed_link) do update set folder_id = :folder_id returning id`, - title, description, link, feedLink, folderId, - folderId, + sql.Named("title", title), + sql.Named("description", description), + sql.Named("link", link), + sql.Named("feed_link", feedLink), + sql.Named("folder_id", folderId), ) var id int64 @@ -46,7 +49,7 @@ func (s *Storage) CreateFeed(title, description, link, feedLink string, folderId } func (s *Storage) DeleteFeed(feedId int64) bool { - result, err := s.db.Exec(`delete from feeds where id = ?`, feedId) + result, err := s.db.Exec(`delete from feeds where id = :id`, sql.Named("id", feedId)) if err != nil { log.Print(err) return false @@ -62,22 +65,34 @@ func (s *Storage) DeleteFeed(feedId int64) bool { } func (s *Storage) RenameFeed(feedId int64, newTitle string) bool { - _, err := s.db.Exec(`update feeds set title = ? where id = ?`, newTitle, feedId) + _, err := s.db.Exec(`update feeds set title = :title where id = :id`, + sql.Named("title", newTitle), + sql.Named("id", feedId), + ) return err == nil } func (s *Storage) UpdateFeedFolder(feedId int64, newFolderId *int64) bool { - _, err := s.db.Exec(`update feeds set folder_id = ? where id = ?`, newFolderId, feedId) + _, err := s.db.Exec(`update feeds set folder_id = :folder_id where id = :id`, + sql.Named("folder_id", newFolderId), + sql.Named("id", feedId), + ) return err == nil } func (s *Storage) UpdateFeedLink(feedId int64, newLink string) bool { - _, err := s.db.Exec(`update feeds set feed_link = ? where id = ?`, newLink, feedId) + _, err := s.db.Exec(`update feeds set feed_link = :feed_link where id = :id`, + sql.Named("feed_link", newLink), + sql.Named("id", feedId), + ) return err == nil } func (s *Storage) UpdateFeedIcon(feedId int64, icon *[]byte) bool { - _, err := s.db.Exec(`update feeds set icon = ? where id = ?`, icon, feedId) + _, err := s.db.Exec(`update feeds set icon = :icon where id = :id`, + sql.Named("icon", icon), + sql.Named("id", feedId), + ) return err == nil } @@ -149,8 +164,8 @@ func (s *Storage) GetFeed(id int64) *Feed { select id, folder_id, title, link, feed_link, icon, ifnull(icon, '') != '' as has_icon - from feeds where id = ? - `, id).Scan( + from feeds where id = :id + `, sql.Named("id", id)).Scan( &f.Id, &f.FolderId, &f.Title, &f.Link, &f.FeedLink, &f.Icon, &f.HasIcon, ) @@ -172,9 +187,10 @@ func (s *Storage) ResetFeedErrors() { func (s *Storage) SetFeedError(feedID int64, lastError error) { _, err := s.db.Exec(` insert into feed_errors (feed_id, error) - values (?, ?) + values (:feed_id, :error) on conflict (feed_id) do update set error = excluded.error`, - feedID, lastError.Error(), + sql.Named("feed_id", feedID), + sql.Named("error", lastError.Error()), ) if err != nil { log.Print(err) @@ -204,9 +220,10 @@ func (s *Storage) GetFeedErrors() map[int64]string { func (s *Storage) SetFeedSize(feedId int64, size int) { _, err := s.db.Exec(` insert into feed_sizes (feed_id, size) - values (?, ?) + values (:feed_id, :size) on conflict (feed_id) do update set size = excluded.size`, - feedId, size, + sql.Named("feed_id", feedId), + sql.Named("size", size), ) if err != nil { log.Print(err) diff --git a/src/storage/folder.go b/src/storage/folder.go index 27c33d5..40717c7 100644 --- a/src/storage/folder.go +++ b/src/storage/folder.go @@ -1,6 +1,7 @@ package storage import ( + "database/sql" "log" ) @@ -13,12 +14,11 @@ type Folder struct { func (s *Storage) CreateFolder(title string) *Folder { expanded := true row := s.db.QueryRow(` - insert into folders (title, is_expanded) values (?, ?) - on conflict (title) do update set title = ? + insert into folders (title, is_expanded) values (:title, :is_expanded) + on conflict (title) do update set title = :title returning id`, - title, expanded, - // provide title again so that we can extract row id - title, + sql.Named("title", title), + sql.Named("is_expanded", expanded), ) var id int64 err := row.Scan(&id) @@ -31,7 +31,7 @@ func (s *Storage) CreateFolder(title string) *Folder { } func (s *Storage) DeleteFolder(folderId int64) bool { - _, err := s.db.Exec(`delete from folders where id = ?`, folderId) + _, err := s.db.Exec(`delete from folders where id = :id`, sql.Named("id", folderId)) if err != nil { log.Print(err) } @@ -39,12 +39,18 @@ func (s *Storage) DeleteFolder(folderId int64) bool { } func (s *Storage) RenameFolder(folderId int64, newTitle string) bool { - _, err := s.db.Exec(`update folders set title = ? where id = ?`, newTitle, folderId) + _, err := s.db.Exec(`update folders set title = :title where id = :id`, + sql.Named("title", newTitle), + sql.Named("id", folderId), + ) return err == nil } func (s *Storage) ToggleFolderExpanded(folderId int64, isExpanded bool) bool { - _, err := s.db.Exec(`update folders set is_expanded = ? where id = ?`, isExpanded, folderId) + _, err := s.db.Exec(`update folders set is_expanded = :is_expanded where id = :id`, + sql.Named("is_expanded", isExpanded), + sql.Named("id", folderId), + ) return err == nil } diff --git a/src/storage/http.go b/src/storage/http.go index dd87ca2..6997b49 100644 --- a/src/storage/http.go +++ b/src/storage/http.go @@ -1,6 +1,7 @@ package storage import ( + "database/sql" "log" "time" ) @@ -40,8 +41,8 @@ func (s *Storage) ListHTTPStates() map[int64]HTTPState { func (s *Storage) GetHTTPState(feedID int64) *HTTPState { row := s.db.QueryRow(` select feed_id, last_refreshed, last_modified, etag - from http_states where feed_id = ? - `, feedID) + from http_states where feed_id = :feed_id + `, sql.Named("feed_id", feedID)) if row == nil { return nil @@ -60,12 +61,11 @@ func (s *Storage) GetHTTPState(feedID int64) *HTTPState { func (s *Storage) SetHTTPState(feedID int64, lastModified, etag string) { _, err := s.db.Exec(` insert into http_states (feed_id, last_modified, etag, last_refreshed) - values (?, ?, ?, datetime()) - on conflict (feed_id) do update set last_modified = ?, etag = ?, last_refreshed = datetime()`, - // insert - feedID, lastModified, etag, - // upsert - lastModified, etag, + values (:feed_id, :last_modified, :etag, datetime()) + on conflict (feed_id) do update set last_modified = :last_modified, etag = :etag, last_refreshed = datetime()`, + sql.Named("feed_id", feedID), + sql.Named("last_modified", lastModified), + sql.Named("etag", etag), ) if err != nil { log.Print(err) diff --git a/src/storage/item.go b/src/storage/item.go index 367d80b..ae34d38 100644 --- a/src/storage/item.go +++ b/src/storage/item.go @@ -1,6 +1,7 @@ package storage import ( + "database/sql" "database/sql/driver" "encoding/json" "fmt" @@ -137,14 +138,20 @@ func (s *Storage) CreateItems(items []Item) bool { date_arrived, status ) values ( - ?, ?, ?, ?, strftime('%Y-%m-%d %H:%M:%f', ?), - ?, ?, - ?, ? + :guid, :feed_id, :title, :link, strftime('%Y-%m-%d %H:%M:%f', :date), + :content, :media_links, + :date_arrived, :status ) on conflict (feed_id, guid) do nothing`, - item.GUID, item.FeedId, item.Title, item.Link, item.Date, - item.Content, item.MediaLinks, - now, UNREAD, + sql.Named("guid", item.GUID), + sql.Named("feed_id", item.FeedId), + sql.Named("title", item.Title), + sql.Named("link", item.Link), + sql.Named("date", item.Date), + sql.Named("content", item.Content), + sql.Named("media_links", item.MediaLinks), + sql.Named("date_arrived", now), + sql.Named("status", UNREAD), ) if err != nil { log.Print(err) @@ -166,16 +173,16 @@ func listQueryPredicate(filter ItemFilter, newestFirst bool) (string, []interfac cond := make([]string, 0) args := make([]interface{}, 0) if filter.FolderID != nil { - cond = append(cond, "i.feed_id in (select id from feeds where folder_id = ?)") - args = append(args, *filter.FolderID) + cond = append(cond, "i.feed_id in (select id from feeds where folder_id = :folder_id)") + args = append(args, sql.Named("folder_id", *filter.FolderID)) } if filter.FeedID != nil { - cond = append(cond, "i.feed_id = ?") - args = append(args, *filter.FeedID) + cond = append(cond, "i.feed_id = :feed_id") + args = append(args, sql.Named("feed_id", *filter.FeedID)) } if filter.Status != nil { - cond = append(cond, "i.status = ?") - args = append(args, *filter.Status) + cond = append(cond, "i.status = :status") + args = append(args, sql.Named("status", *filter.Status)) } if filter.Search != nil { words := strings.Fields(*filter.Search) @@ -184,38 +191,37 @@ func listQueryPredicate(filter ItemFilter, newestFirst bool) (string, []interfac terms[idx] = word + "*" } - cond = append(cond, "i.search_rowid in (select rowid from search where search match ?)") - args = append(args, strings.Join(terms, " ")) + cond = append(cond, "i.search_rowid in (select rowid from search where search match :search)") + args = append(args, sql.Named("search", strings.Join(terms, " "))) } if filter.After != nil { compare := ">" if newestFirst { compare = "<" } - cond = append(cond, fmt.Sprintf("(i.date, i.id) %s (select date, id from items where id = ?)", compare)) - args = append(args, *filter.After) + cond = append(cond, fmt.Sprintf("(i.date, i.id) %s (select date, id from items where id = :after_id)", compare)) + args = append(args, sql.Named("after_id", *filter.After)) } if filter.IDs != nil && len(*filter.IDs) > 0 { qmarks := make([]string, len(*filter.IDs)) - idargs := make([]interface{}, len(*filter.IDs)) for i, id := range *filter.IDs { - qmarks[i] = "?" - idargs[i] = id + name := fmt.Sprintf("id%d", i) + qmarks[i] = ":" + name + args = append(args, sql.Named(name, id)) } cond = append(cond, "i.id in ("+strings.Join(qmarks, ",")+")") - args = append(args, idargs...) } if filter.SinceID != nil { - cond = append(cond, "i.id > ?") - args = append(args, filter.SinceID) + cond = append(cond, "i.id > :since_id") + args = append(args, sql.Named("since_id", filter.SinceID)) } if filter.MaxID != nil { - cond = append(cond, "i.id < ?") - args = append(args, filter.MaxID) + cond = append(cond, "i.id < :max_id") + args = append(args, sql.Named("max_id", filter.MaxID)) } if filter.Before != nil { - cond = append(cond, "i.date < ?") - args = append(args, filter.Before) + cond = append(cond, "i.date < :before") + args = append(args, sql.Named("before", filter.Before)) } predicate := "1" @@ -299,8 +305,8 @@ func (s *Storage) GetItem(id int64) *Item { i.id, i.guid, i.feed_id, i.title, i.link, i.content, i.date, i.status, i.media_links from items i - where i.id = ? - `, id).Scan( + where i.id = :id + `, sql.Named("id", id)).Scan( &i.Id, &i.GUID, &i.FeedId, &i.Title, &i.Link, &i.Content, &i.Date, &i.Status, &i.MediaLinks, ) @@ -312,7 +318,10 @@ func (s *Storage) GetItem(id int64) *Item { } func (s *Storage) UpdateItemStatus(item_id int64, status ItemStatus) bool { - _, err := s.db.Exec(`update items set status = ? where id = ?`, status, item_id) + _, err := s.db.Exec(`update items set status = :status where id = :id`, + sql.Named("status", status), + sql.Named("id", item_id), + ) return err == nil } @@ -381,8 +390,9 @@ func (s *Storage) SyncSearch() { for _, item := range items { result, err := s.db.Exec(` - insert into search (title, description, content) values (?, "", ?)`, - item.Title, htmlutil.ExtractText(item.Content), + insert into search (title, description, content) values (:title, "", :content)`, + sql.Named("title", item.Title), + sql.Named("content", htmlutil.ExtractText(item.Content)), ) if err != nil { log.Print(err) @@ -391,8 +401,9 @@ func (s *Storage) SyncSearch() { if numrows, err := result.RowsAffected(); err == nil && numrows == 1 { if rowId, err := result.LastInsertId(); err == nil { s.db.Exec( - `update items set search_rowid = ? where id = ?`, - rowId, item.Id, + `update items set search_rowid = :search_rowid where id = :id`, + sql.Named("search_rowid", rowId), + sql.Named("id", item.Id), ) } } @@ -416,13 +427,16 @@ func (s *Storage) DeleteOldItems() { rows, err := s.db.Query(` select i.feed_id, - max(coalesce(s.size, 0), ?) as max_items, + max(coalesce(s.size, 0), :keep_size) as max_items, count(*) as num_items from items i left outer join feed_sizes s on s.feed_id = i.feed_id - where status != ? + where status != :starred_status group by i.feed_id - `, itemsKeepSize, STARRED) + `, + sql.Named("keep_size", itemsKeepSize), + sql.Named("starred_status", STARRED), + ) if err != nil { log.Print(err) return @@ -441,15 +455,15 @@ func (s *Storage) DeleteOldItems() { where id in ( select i.id from items i - where i.feed_id = ? and status != ? + where i.feed_id = :feed_id and status != :starred_status order by date desc - limit -1 offset ? - ) and date_arrived < ? + limit -1 offset :limit + ) and date_arrived < :date_limit `, - feedId, - STARRED, - limit, - time.Now().UTC().Add(-time.Hour*time.Duration(24*itemsKeepDays)), + sql.Named("feed_id", feedId), + sql.Named("starred_status", STARRED), + sql.Named("limit", limit), + sql.Named("date_limit", time.Now().UTC().Add(-time.Hour*time.Duration(24*itemsKeepDays))), ) if err != nil { log.Print(err) diff --git a/src/storage/item_test.go b/src/storage/item_test.go index 88fc63d..a11cf83 100644 --- a/src/storage/item_test.go +++ b/src/storage/item_test.go @@ -1,6 +1,7 @@ package storage import ( + "database/sql" "log" "reflect" "strconv" @@ -59,8 +60,8 @@ func testItemsSetup(db *Storage) testItemScope { {GUID: "item012", FeedId: feed01.Id, Title: "title012", Date: now.Add(time.Hour * 24 * 9)}, // read {GUID: "item013", FeedId: feed01.Id, Title: "title013", Date: now.Add(time.Hour * 24 * 10)}, // starred }) - db.db.Exec(`update items set status = ? where guid in ("item112", "item122", "item211", "item012")`, READ) - db.db.Exec(`update items set status = ? where guid in ("item113", "item212", "item013")`, STARRED) + db.db.Exec(`update items set status = :status where guid in ("item112", "item122", "item211", "item012")`, sql.Named("status", READ)) + db.db.Exec(`update items set status = :status where guid in ("item113", "item212", "item013")`, sql.Named("status", STARRED)) return testItemScope{ feed11: feed11, @@ -79,8 +80,8 @@ func getItem(db *Storage, guid string) *Item { i.id, i.guid, i.feed_id, i.title, i.link, i.content, i.date, i.status, i.media_links from items i - where i.guid = ? - `, guid).Scan( + where i.guid = :guid + `, sql.Named("guid", guid)).Scan( &i.Id, &i.GUID, &i.FeedId, &i.Title, &i.Link, &i.Content, &i.Date, &i.Status, &i.MediaLinks, ) @@ -295,7 +296,7 @@ func TestDeleteOldItems(t *testing.T) { db.SetFeedSize(feed.Id, itemsKeepSize) var feedSize int err := db.db.QueryRow( - `select size from feed_sizes where feed_id = ?`, feed.Id, + `select size from feed_sizes where feed_id = :feed_id`, sql.Named("feed_id", feed.Id), ).Scan(&feedSize) if err != nil { t.Fatal(err) @@ -310,9 +311,9 @@ func TestDeleteOldItems(t *testing.T) { // expire only the first 3 articles _, err = db.db.Exec( - `update items set date_arrived = ? + `update items set date_arrived = :date_arrived where id in (select id from items limit 3)`, - now.Add(-time.Hour*time.Duration(itemsKeepDays*24)), + sql.Named("date_arrived", now.Add(-time.Hour*time.Duration(itemsKeepDays*24))), ) if err != nil { t.Fatal(err) diff --git a/src/storage/migration.go b/src/storage/migration.go index abe13b1..a0dbdcd 100644 --- a/src/storage/migration.go +++ b/src/storage/migration.go @@ -290,7 +290,10 @@ func m08_normalize_datetime(tx *sql.Tx) error { if err != nil { return err } - _, err = tx.Exec(`update items set date_arrived = ? where id = ?;`, dateArrived.UTC(), id) + _, err = tx.Exec(`update items set date_arrived = :date_arrived where id = :id;`, + sql.Named("date_arrived", dateArrived.UTC()), + sql.Named("id", id), + ) if err != nil { return err } diff --git a/src/storage/settings.go b/src/storage/settings.go index d9ece3d..14eacc5 100644 --- a/src/storage/settings.go +++ b/src/storage/settings.go @@ -1,6 +1,7 @@ package storage import ( + "database/sql" "encoding/json" "log" ) @@ -20,7 +21,7 @@ func settingsDefaults() map[string]interface{} { } func (s *Storage) GetSettingsValue(key string) interface{} { - row := s.db.QueryRow(`select val from settings where key=?`, key) + row := s.db.QueryRow(`select val from settings where key=:key`, sql.Named("key", key)) if row == nil { return settingsDefaults()[key] } @@ -81,9 +82,10 @@ func (s *Storage) UpdateSettings(kv map[string]interface{}) bool { return false } _, err = s.db.Exec(` - insert into settings (key, val) values (?, ?) - on conflict (key) do update set val=?`, - key, valEncoded, valEncoded, + insert into settings (key, val) values (:key, :val) + on conflict (key) do update set val=:val`, + sql.Named("key", key), + sql.Named("val", valEncoded), ) if err != nil { log.Print(err)