From b4028fa08f965770affb2e9665c09594e0a4a1a0 Mon Sep 17 00:00:00 2001 From: Greg Shuflin Date: Tue, 4 Feb 2025 13:19:47 -0800 Subject: [PATCH] Separate session store module --- src/main.rs | 4 +++- src/session_store.rs | 44 ++++++++++++++++++++++++++++++++++++++++++++ src/user.rs | 44 +------------------------------------------- 3 files changed, 48 insertions(+), 44 deletions(-) create mode 100644 src/session_store.rs diff --git a/src/main.rs b/src/main.rs index 4d0c60f..1142c04 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ mod demo; mod feed_utils; mod feeds; mod poll; +mod session_store; mod user; use rocket::fairing::{self, AdHoc}; @@ -15,6 +16,7 @@ use rocket::response::Redirect; use rocket::{Build, Rocket, State}; use rocket_db_pools::{sqlx, Connection, Database}; use rocket_dyn_templates::{context, Template}; +use session_store::SessionStore; use user::AuthenticatedUser; /// RSS Reader application @@ -136,7 +138,7 @@ fn rocket() -> _ { .attach(Template::fairing()) .attach(Db::init()) .manage(args.demo) - .manage(user::SessionStore::new()) + .manage(SessionStore::new()) .attach(AdHoc::try_on_ignite("DB Setup", move |rocket| async move { setup_database(args.demo, rocket).await })) diff --git a/src/session_store.rs b/src/session_store.rs new file mode 100644 index 0000000..5f2ae99 --- /dev/null +++ b/src/session_store.rs @@ -0,0 +1,44 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::RwLock; +use uuid::Uuid; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; + +pub struct SessionStore(RwLock>>); + +impl SessionStore { + pub fn new() -> Self { + SessionStore(RwLock::new(HashMap::new())) + } + + pub fn generate_secret() -> String { + let mut bytes = [0u8; 32]; + getrandom::getrandom(&mut bytes).expect("Failed to generate random bytes"); + BASE64.encode(bytes) + } + + pub fn store(&self, user_id: Uuid, secret: String) { + let mut store = self.0.write().unwrap(); + store + .entry(user_id) + .or_insert_with(HashSet::new) + .insert(secret); + } + + pub fn verify(&self, user_id: Uuid, secret: &str) -> bool { + let store = self.0.read().unwrap(); + store + .get(&user_id) + .map_or(false, |secrets| secrets.contains(secret)) + } + + pub fn remove(&self, user_id: Uuid, secret: &str) { + let mut store = self.0.write().unwrap(); + if let Some(secrets) = store.get_mut(&user_id) { + secrets.remove(secret); + // Clean up the user entry if no sessions remain + if secrets.is_empty() { + store.remove(&user_id); + } + } + } +} \ No newline at end of file diff --git a/src/user.rs b/src/user.rs index ef21d37..990a4ec 100644 --- a/src/user.rs +++ b/src/user.rs @@ -1,56 +1,14 @@ use time::Duration; -use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use rocket::http::{Cookie, CookieJar, Status}; use rocket::serde::{json::Json, Deserialize, Serialize}; use rocket::State; use rocket_db_pools::Connection; use rocket_dyn_templates::{context, Template}; -use std::collections::{HashMap, HashSet}; -use std::sync::RwLock; use uuid::Uuid; use crate::Db; - -pub struct SessionStore(RwLock>>); - -impl SessionStore { - pub fn new() -> Self { - SessionStore(RwLock::new(HashMap::new())) - } - - fn generate_secret() -> String { - let mut bytes = [0u8; 32]; - getrandom::getrandom(&mut bytes).expect("Failed to generate random bytes"); - BASE64.encode(bytes) - } - - fn store(&self, user_id: Uuid, secret: String) { - let mut store = self.0.write().unwrap(); - store - .entry(user_id) - .or_insert_with(HashSet::new) - .insert(secret); - } - - fn verify(&self, user_id: Uuid, secret: &str) -> bool { - let store = self.0.read().unwrap(); - store - .get(&user_id) - .map_or(false, |secrets| secrets.contains(secret)) - } - - fn remove(&self, user_id: Uuid, secret: &str) { - let mut store = self.0.write().unwrap(); - if let Some(secrets) = store.get_mut(&user_id) { - secrets.remove(secret); - // Clean up the user entry if no sessions remain - if secrets.is_empty() { - store.remove(&user_id); - } - } - } -} +use crate::session_store::SessionStore; #[derive(Debug, Serialize)] #[serde(crate = "rocket::serde")]