sqlite: panic -> log.fatal refactor
@@ -218,11 +218,14 @@ // username fetches a client's username based
// on the sessionToken that user has set. username // will return "" if there is no sessionToken. func (s *Site) username(r *http.Request) string { - sessionToken, err := r.Cookie("session_token") + cookie, err := r.Cookie("session_token") + if err == http.ErrNoCookie { + return "" + } if err != nil { - return "" + log.Println(err) } - username := s.db.GetUsernameBySessionToken(sessionToken.Value) + username := s.db.GetUsernameBySessionToken(cookie.Value) return username }@@ -237,13 +240,13 @@ // login compares the sqlite password field against the user supplied password and
// sets a session token against the supplied writer. func (s *Site) login(w http.ResponseWriter, username string, password string) error { if username == "" { - return fmt.Errorf("username cannot be nil") + return fmt.Errorf("username cannot be empty") } if password == "" { - return fmt.Errorf("password cannot be nil") + return fmt.Errorf("password cannot be empty") } if !s.db.UserExists(username) { - return fmt.Errorf("user does not exist") + return fmt.Errorf("user '%s' does not exist", username) } storedPassword := s.db.GetPassword(username) err := bcrypt.CompareHashAndPassword([]byte(storedPassword), []byte(password))@@ -251,7 +254,11 @@ if err != nil {
return fmt.Errorf("invalid password") } sessionToken := lib.GenerateSessionToken() - s.db.SetSessionToken(username, sessionToken) + err = s.db.SetSessionToken(username, sessionToken) + if err != nil { + log.Println(err) + } + http.SetCookie(w, &http.Cookie{ Name: "session_token", // 18 years
@@ -28,7 +28,7 @@ session_token TEXT UNIQUE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )`) if err != nil { - panic(err) + log.Fatal(err) } // feed _, err = db.Exec(`CREATE TABLE IF NOT EXISTS feed (@@ -38,7 +38,7 @@ fetch_error TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )`) if err != nil { - panic(err) + log.Fatal(err) } // subscribe _, err = db.Exec(`CREATE TABLE IF NOT EXISTS subscribe (@@ -48,13 +48,11 @@ feed_id INTEGER NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )`) if err != nil { - panic(err) + log.Fatal(err) } return &DB{sql: db} } - -// TODO: think more about errors func (db *DB) GetUsernameBySessionToken(token string) string { var username string@@ -63,7 +61,7 @@ if err == sql.ErrNoRows {
return "" } if err != nil { - panic(err) + log.Fatal(err) } return username }@@ -75,16 +73,14 @@ if err == sql.ErrNoRows {
return "" } if err != nil { - panic(err) + log.Fatal(err) } return password } -func (db *DB) SetSessionToken(username string, token string) { +func (db *DB) SetSessionToken(username string, token string) error { _, err := db.sql.Exec("UPDATE user SET session_token=? WHERE username=?", token, username) - if err != nil { - panic(err) - } + return err } func (db *DB) AddUser(username string, passwordHash string) error {@@ -100,19 +96,19 @@ err := db.sql.QueryRow("SELECT id FROM subscribe WHERE user_id=? AND feed_id=?", uid, fid).Scan(&id)
if err == sql.ErrNoRows { _, err := db.sql.Exec("INSERT INTO subscribe (user_id, feed_id) VALUES (?, ?)", uid, fid) if err != nil { - panic(err) + log.Fatal(err) } return } if err != nil { - panic(err) + log.Fatal(err) } } func (db *DB) UnsubscribeAll(username string) { _, err := db.sql.Exec("DELETE FROM subscribe WHERE user_id=?", db.GetUserID(username)) if err != nil { - panic(err) + log.Fatal(err) } }@@ -123,7 +119,7 @@ if err == sql.ErrNoRows {
return false } if err != nil { - panic(err) + log.Fatal(err) } return true }@@ -132,7 +128,7 @@ func (db *DB) GetAllFeedURLs() []string {
// TODO: BAD SELECT STATEMENT!! SORRY :( --wesley rows, err := db.sql.Query("SELECT url FROM feed") if err != nil { - panic(err) + log.Fatal(err) } defer rows.Close()@@ -141,7 +137,7 @@ for rows.Next() {
var url string err = rows.Scan(&url) if err != nil { - panic(err) + log.Fatal(err) } urls = append(urls, url) }@@ -163,7 +159,7 @@ if err == sql.ErrNoRows {
return []string{} } if err != nil { - panic(err) + log.Fatal(err) } defer rows.Close()@@ -172,7 +168,7 @@ for rows.Next() {
var url string err = rows.Scan(&url) if err != nil { - panic(err) + log.Fatal(err) } urls = append(urls, url) }@@ -183,7 +179,7 @@ func (db *DB) GetUserID(username string) int {
var uid int err := db.sql.QueryRow("SELECT id FROM user WHERE username=?", username).Scan(&uid) if err != nil { - panic(err) + log.Fatal(err) } return uid }@@ -192,7 +188,7 @@ func (db *DB) GetFeedID(feedURL string) int {
var fid int err := db.sql.QueryRow("SELECT id FROM feed WHERE url=?", feedURL).Scan(&fid) if err != nil { - panic(err) + log.Fatal(err) } return fid }@@ -203,7 +199,7 @@ func (db *DB) WriteFeed(url string) {
_, err := db.sql.Exec(`INSERT INTO feed(url) VALUES(?) ON CONFLICT(url) DO NOTHING`, url) if err != nil { - panic(err) + log.Fatal(err) } }