safety: make subscription a transaction
Jes Olson j3s@c3f.net
Fri, 11 Apr 2025 16:07:40 -0500
2 files changed,
34 insertions(+),
12 deletions(-)
M
site.go
→
site.go
@@ -249,14 +249,14 @@ }
s.db.WriteFeed(u) } - // TODO: this is insane, make it a transaction - // so people don't lose feed subscriptions - // if vore restarts in the middle of this - // process. - s.db.UnsubscribeAll(s.username(r)) - for _, url := range validatedURLs { - s.db.Subscribe(s.username(r), url) + err := s.db.BatchSubscribe(s.username(r), validatedURLs) + if err != nil { + log.Println(err) + e := fmt.Sprintf("reaper: can't batchsubscribe user=%s err=%s", s.username(r), err) + s.renderErr(w, e, http.StatusInternalServerError) + return } + http.Redirect(w, r, "/settings", http.StatusSeeOther) }
M
sqlite/sqlite.go
→
sqlite/sqlite.go
@@ -124,9 +124,7 @@ _, err := db.sql.Exec("INSERT INTO user (username, password) VALUES (?, ?)", username, passwordHash)
return err } -func (db *DB) Subscribe(username string, feedURL string) { - uid := db.GetUserID(username) - fid := db.GetFeedID(feedURL) +func (db *DB) subscribe(uid int, fid int) { var id int err := db.sql.QueryRow("SELECT id FROM subscribe WHERE user_id=? AND feed_id=?", uid, fid).Scan(&id) if err == sql.ErrNoRows {@@ -141,8 +139,8 @@ log.Fatal(err)
} } -func (db *DB) UnsubscribeAll(username string) { - _, err := db.sql.Exec("DELETE FROM subscribe WHERE user_id=?", db.GetUserID(username)) +func (db *DB) unsubscribeAll(uid int) { + _, err := db.sql.Exec("DELETE FROM subscribe WHERE user_id=?", uid) if err != nil { log.Fatal(err) }@@ -324,3 +322,27 @@ log.Fatal(err)
} return fid, true } + +func (db *DB) BatchSubscribe(username string, feedURLs []string) error { + tx, err := db.sql.Begin() + if err != nil { + return err + } + + defer func() { + if err != nil { + tx.Rollback() + } + }() + + // first, unsub from everything + uid := db.GetUserID(username) + db.unsubscribeAll(uid) + + // Add new subscriptions + for _, url := range feedURLs { + db.subscribe(uid, db.GetFeedID(url)) + } + + return tx.Commit() +}