diff --git a/src/server/auth/middleware.go b/src/server/auth/middleware.go index 4450c91..4dd12a7 100644 --- a/src/server/auth/middleware.go +++ b/src/server/auth/middleware.go @@ -47,12 +47,12 @@ func (m *Middleware) Handler(c *router.Context) { c.HTML(http.StatusOK, assets.Template("login.html"), map[string]any{ "username": username, "error": "Invalid username/password", - "settings": m.DB.GetSettings(), + "settings": m.DB.GetSettings().Map(), }) return } } c.HTML(http.StatusOK, assets.Template("login.html"), map[string]any{ - "settings": m.DB.GetSettings(), + "settings": m.DB.GetSettings().Map(), }) } diff --git a/src/server/routes.go b/src/server/routes.go index 2e74503..f5f8620 100644 --- a/src/server/routes.go +++ b/src/server/routes.go @@ -65,7 +65,7 @@ func (s *Server) handler() http.Handler { func (s *Server) handleIndex(c *router.Context) { c.HTML(http.StatusOK, assets.Template("index.html"), map[string]any{ - "settings": s.db.GetSettings(), + "settings": s.db.GetSettings().Map(), "authenticated": s.Username != "" && s.Password != "", }) } @@ -423,14 +423,14 @@ func (s *Server) handleSettings(c *router.Context) { if c.Req.Method == "GET" { c.JSON(http.StatusOK, s.db.GetSettings()) } else if c.Req.Method == "PUT" { - settings := make(map[string]any) - if err := json.NewDecoder(c.Req.Body).Decode(&settings); err != nil { + var params storage.UpdateSettingsParams + if err := json.NewDecoder(c.Req.Body).Decode(¶ms); err != nil { c.Out.WriteHeader(http.StatusBadRequest) return } - if s.db.UpdateSettings(settings) { - if _, ok := settings["refresh_rate"]; ok { - s.worker.SetRefreshRate(s.db.GetSettingsValueInt64("refresh_rate")) + if s.db.UpdateSettings(params) { + if params.RefreshRate != nil { + s.worker.SetRefreshRate(s.db.GetSettings().RefreshRate) } c.Out.WriteHeader(http.StatusOK) } else { diff --git a/src/server/server.go b/src/server/server.go index f3e571e..e6a8ed4 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -48,7 +48,7 @@ func (h *Server) GetAddr() string { } func (s *Server) Start() { - refreshRate := s.db.GetSettingsValueInt64("refresh_rate") + refreshRate := s.db.GetSettings().RefreshRate s.worker.FindFavicons() s.worker.StartFeedCleaner() s.worker.SetRefreshRate(refreshRate) diff --git a/src/storage/settings.go b/src/storage/settings.go index ab41a9b..0686710 100644 --- a/src/storage/settings.go +++ b/src/storage/settings.go @@ -6,92 +6,166 @@ import ( "log" ) -func settingsDefaults() map[string]any { +type Settings struct { + Filter string `json:"filter"` + Feed string `json:"feed"` + FeedListWidth int `json:"feed_list_width"` + ItemListWidth int `json:"item_list_width"` + SortNewestFirst bool `json:"sort_newest_first"` + ThemeName string `json:"theme_name"` + ThemeFont string `json:"theme_font"` + ThemeSize int `json:"theme_size"` + RefreshRate int64 `json:"refresh_rate"` + Language string `json:"language"` +} + +func (s Settings) Map() map[string]any { return map[string]any{ - "filter": "", - "feed": "", - "feed_list_width": 300, - "item_list_width": 300, - "sort_newest_first": true, - "theme_name": "light", - "theme_font": "", - "theme_size": 1, - "refresh_rate": 0, - "language": "en", + "filter": s.Filter, + "feed": s.Feed, + "feed_list_width": s.FeedListWidth, + "item_list_width": s.ItemListWidth, + "sort_newest_first": s.SortNewestFirst, + "theme_name": s.ThemeName, + "theme_font": s.ThemeFont, + "theme_size": s.ThemeSize, + "refresh_rate": s.RefreshRate, + "language": s.Language, } } -func (s *Storage) GetSettingsValue(key string) any { - row := s.db.QueryRow(`select val from settings where key=:key`, sql.Named("key", key)) - if row == nil { - return settingsDefaults()[key] +func settingsDefaults() Settings { + return Settings{ + Filter: "", + Feed: "", + FeedListWidth: 300, + ItemListWidth: 300, + SortNewestFirst: true, + ThemeName: "light", + ThemeFont: "", + ThemeSize: 1, + RefreshRate: 0, + Language: "en", } - var val []byte - row.Scan(&val) - if len(val) == 0 { - return nil - } - var valDecoded any - if err := json.Unmarshal([]byte(val), &valDecoded); err != nil { - log.Print(err) - return nil - } - return valDecoded } -func (s *Storage) GetSettingsValueInt64(key string) int64 { - val := s.GetSettingsValue(key) - if val != nil { - if fval, ok := val.(float64); ok { - return int64(fval) - } - } - return 0 -} - -func (s *Storage) GetSettings() map[string]any { +func (s *Storage) GetSettings() Settings { result := settingsDefaults() rows, err := s.db.Query(`select key, val from settings;`) if err != nil { log.Print(err) return result } + defer rows.Close() + for rows.Next() { var key string var val []byte - var valDecoded any - rows.Scan(&key, &val) - if err = json.Unmarshal([]byte(val), &valDecoded); err != nil { - log.Print(err) - continue + + switch key { + case "filter": + json.Unmarshal(val, &result.Filter) + case "feed": + json.Unmarshal(val, &result.Feed) + case "feed_list_width": + json.Unmarshal(val, &result.FeedListWidth) + case "item_list_width": + json.Unmarshal(val, &result.ItemListWidth) + case "sort_newest_first": + json.Unmarshal(val, &result.SortNewestFirst) + case "theme_name": + json.Unmarshal(val, &result.ThemeName) + case "theme_font": + json.Unmarshal(val, &result.ThemeFont) + case "theme_size": + json.Unmarshal(val, &result.ThemeSize) + case "refresh_rate": + json.Unmarshal(val, &result.RefreshRate) + case "language": + json.Unmarshal(val, &result.Language) } - result[key] = valDecoded } return result } -func (s *Storage) UpdateSettings(kv map[string]any) bool { - defaults := settingsDefaults() - for key, val := range kv { - if defaults[key] == nil { - continue - } +type UpdateSettingsParams struct { + Filter *string `json:"filter"` + Feed *string `json:"feed"` + FeedListWidth *int `json:"feed_list_width"` + ItemListWidth *int `json:"item_list_width"` + SortNewestFirst *bool `json:"sort_newest_first"` + ThemeName *string `json:"theme_name"` + ThemeFont *string `json:"theme_font"` + ThemeSize *int `json:"theme_size"` + RefreshRate *int64 `json:"refresh_rate"` + Language *string `json:"language"` +} + +func (s *Storage) UpdateSettings(params UpdateSettingsParams) bool { + tx, err := s.db.Begin() + if err != nil { + log.Print(err) + return false + } + defer tx.Rollback() + + update := func(key string, val any) error { valEncoded, err := json.Marshal(val) if err != nil { - log.Print(err) - return false + return err } - _, err = s.db.Exec(` + _, err = tx.Exec(` insert into settings (key, val) values (:key, :val) on conflict (key) do update set val=:val`, sql.Named("key", key), sql.Named("val", valEncoded), ) + return err + } + + var errs []error + if params.Filter != nil { + errs = append(errs, update("filter", *params.Filter)) + } + if params.Feed != nil { + errs = append(errs, update("feed", *params.Feed)) + } + if params.FeedListWidth != nil { + errs = append(errs, update("feed_list_width", *params.FeedListWidth)) + } + if params.ItemListWidth != nil { + errs = append(errs, update("item_list_width", *params.ItemListWidth)) + } + if params.SortNewestFirst != nil { + errs = append(errs, update("sort_newest_first", *params.SortNewestFirst)) + } + if params.ThemeName != nil { + errs = append(errs, update("theme_name", *params.ThemeName)) + } + if params.ThemeFont != nil { + errs = append(errs, update("theme_font", *params.ThemeFont)) + } + if params.ThemeSize != nil { + errs = append(errs, update("theme_size", *params.ThemeSize)) + } + if params.RefreshRate != nil { + errs = append(errs, update("refresh_rate", *params.RefreshRate)) + } + if params.Language != nil { + errs = append(errs, update("language", *params.Language)) + } + + for _, err := range errs { if err != nil { log.Print(err) return false } } + + if err := tx.Commit(); err != nil { + log.Print(err) + return false + } return true } diff --git a/src/storage/settings_test.go b/src/storage/settings_test.go new file mode 100644 index 0000000..9a79234 --- /dev/null +++ b/src/storage/settings_test.go @@ -0,0 +1,150 @@ +package storage + +import ( + "reflect" + "strings" + "testing" +) + +func TestSettingsDefaults(t *testing.T) { + s := testDB() + defer s.Close() + + settings := s.GetSettings() + defaults := settingsDefaults() + + if !reflect.DeepEqual(settings, defaults) { + t.Errorf("expected defaults %+v, got %+v", defaults, settings) + } +} + +func TestUpdateSettings(t *testing.T) { + s := testDB() + defer s.Close() + + params := UpdateSettingsParams{ + ThemeName: ptr("night"), + FeedListWidth: ptr(400), + RefreshRate: ptr(int64(15)), + } + + if ok := s.UpdateSettings(params); !ok { + t.Fatal("UpdateSettings failed") + } + + settings := s.GetSettings() + + if settings.ThemeName != "night" { + t.Errorf("expected theme_name night, got %s", settings.ThemeName) + } + if settings.FeedListWidth != 400 { + t.Errorf("expected feed_list_width 400, got %d", settings.FeedListWidth) + } + if settings.RefreshRate != 15 { + t.Errorf("expected refresh_rate 15, got %d", settings.RefreshRate) + } +} + +func TestGetSettings(t *testing.T) { + s := testDB() + defer s.Close() + + s.UpdateSettings(UpdateSettingsParams{Language: ptr("fr")}) + + settings := s.GetSettings() + if settings.Language != "fr" { + t.Errorf("expected fr, got %v", settings.Language) + } + if settings.ThemeName != "light" { + t.Errorf("expected light, got %v", settings.ThemeName) + } +} + +func TestSettingsExhaustive(t *testing.T) { + s := testDB() + defer s.Close() + + settingsType := reflect.TypeOf(Settings{}) + paramsType := reflect.TypeOf(UpdateSettingsParams{}) + + settings := s.GetSettings() + m := settings.Map() + + for i := 0; i < settingsType.NumField(); i++ { + field := settingsType.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" { + t.Errorf("Field %s missing json tag", field.Name) + continue + } + // json tags might have options like "name,omitempty", take only the first part + jsonKey := strings.Split(jsonTag, ",")[0] + + // 1. Check Map() + if _, ok := m[jsonKey]; !ok { + t.Errorf("Key %q (from field %s) missing from Settings.Map()", jsonKey, field.Name) + } + + // 2. Check UpdateSettingsParams + foundInParams := false + for j := 0; j < paramsType.NumField(); j++ { + pField := paramsType.Field(j) + pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0] + if pJsonTag == jsonKey { + foundInParams = true + // Also check it's a pointer + if pField.Type.Kind() != reflect.Ptr { + t.Errorf("Field %s in UpdateSettingsParams should be a pointer", pField.Name) + } + break + } + } + if !foundInParams { + t.Errorf("Key %q (from field %s) missing from UpdateSettingsParams", jsonKey, field.Name) + } + + // 3. Test round-trip update + // We'll create a new UpdateSettingsParams and set ONLY this field + paramsValue := reflect.New(paramsType).Elem() + for j := 0; j < paramsType.NumField(); j++ { + pField := paramsType.Field(j) + pJsonTag := strings.Split(pField.Tag.Get("json"), ",")[0] + if pJsonTag == jsonKey { + // Create a new value of the underlying type + val := reflect.New(field.Type).Elem() + switch field.Type.Kind() { + case reflect.String: + val.SetString("test_" + jsonKey) + case reflect.Int, reflect.Int64: + val.SetInt(42) + case reflect.Bool: + val.SetBool(false) + } + paramsValue.Field(j).Set(val.Addr()) + break + } + } + + if ok := s.UpdateSettings(paramsValue.Interface().(UpdateSettingsParams)); !ok { + t.Errorf("UpdateSettings failed for %q", jsonKey) + } + + updated := s.GetSettings() + updatedValue := reflect.ValueOf(updated).Field(i) + + switch field.Type.Kind() { + case reflect.String: + if updatedValue.String() != "test_"+jsonKey { + t.Errorf("Round-trip failed for %q: expected %q, got %q (check UpdateSettings/GetSettings switch)", jsonKey, "test_"+jsonKey, updatedValue.String()) + } + case reflect.Int, reflect.Int64: + if updatedValue.Int() != 42 { + t.Errorf("Round-trip failed for %q: expected 42, got %d (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Int()) + } + case reflect.Bool: + if updatedValue.Bool() != false { + t.Errorf("Round-trip failed for %q: expected false, got %v (check UpdateSettings/GetSettings switch)", jsonKey, updatedValue.Bool()) + } + } + } +}