use axum::Router; use axum::response::{IntoResponse, Response}; use axum::routing::get; use axum_login::AuthUser; use axum_login::{AuthManagerLayerBuilder, login_required}; use clap::{Parser, Subcommand, arg, command}; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use std::collections::HashMap; use std::str::FromStr; use std::sync::{Arc, RwLock}; use tokio::net::TcpListener; use tokio::signal; use tokio::task::AbortHandle; use tower_http::services::ServeDir; use tower_sessions::{ExpiredDeletion, Expiry, SessionManagerLayer, cookie::Key}; use tower_sessions_sqlx_store::SqliteStore; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod models; use models::contact::MentionTrie; use models::user::{Backend, User}; mod db; use db::{Database, DbId}; mod web; use web::{auth, contact, group, home, ics, journal, settings}; #[derive(Clone)] struct AppStateEntry { database: Arc, contact_search: Arc>, } #[derive(Clone)] struct AppState { map: Arc>>, } struct NameReference { name: String, contact_id: DbId, } impl AppState { pub fn new() -> Self { AppState { map: Arc::new(RwLock::new(HashMap::new())), } } pub async fn init(&mut self, user: &User) -> Result, AppError> { let database = Database::for_user(&user).await?; let mut trie = radix_trie::Trie::new(); let mentionable_names = sqlx::query_as!( NameReference, "select name, contact_id from ( select contact_id, name, count(name) as ct from names group by name ) where ct = 1;", ) .fetch_all(&database.pool) .await?; for row in mentionable_names { trie.insert( row.name, format!("/contact/{}", DbId::try_from(row.contact_id)?), ); } let groups: Vec<(String, String)> = sqlx::query_as("select distinct name, slug from groups") .fetch_all(&database.pool) .await?; for (group, slug) in groups { // TODO urlencode trie.insert(group, format!("/group/{}", slug)); } let mut map = self.map.write().expect("rwlock poisoned"); Ok(map.insert( user.id(), crate::AppStateEntry { database: Arc::new(database), contact_search: Arc::new(RwLock::new(trie)), }, )) } pub fn remove(&mut self, user: &impl AuthUser) { let mut map = self.map.write().expect("rwlock poisoned"); map.remove(&user.id()); } pub fn db(&self, user: &impl AuthUser) -> Arc { let map = self.map.read().expect("rwlock poisoned"); map.get(&user.id()).unwrap().database.clone() } pub fn contact_search(&self, user: &impl AuthUser) -> Arc> { let map = self.map.read().expect("rwlock poisoned"); map.get(&user.id()).unwrap().contact_search.clone() } } pub struct AppError(anyhow::Error); impl IntoResponse for AppError { fn into_response(self) -> Response { ( axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("Something went wrong: {}", self.0), ) .into_response() } } impl From for AppError where E: Into, { fn from(err: E) -> Self { Self(err.into()) } } #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Cli { #[command(subcommand)] command: Option, } #[derive(Subcommand, Debug)] enum Commands { /// run mascarpone server (default) Serve { /// port to bind #[arg(short, long, default_value_t = 3000)] port: u32, }, SetPassword { /// username to create or set password username: String, }, } async fn serve(port: &u32) -> Result<(), anyhow::Error> { let users_db = { let db_options = SqliteConnectOptions::from_str("users.db")? .create_if_missing(true) .to_owned(); let db = SqlitePoolOptions::new().connect_with(db_options).await?; sqlx::migrate!("./migrations/users.db").run(&db).await?; db }; let state = AppState::new(); let session_store = SqliteStore::new(users_db.clone()); session_store.migrate().await?; let deletion_task = tokio::task::spawn( session_store .clone() .continuously_delete_expired(tokio::time::Duration::from_secs(600)), ); // Generate a cryptographic key to sign the session cookie. let key = Key::generate(); let session_layer = SessionManagerLayer::new(session_store) .with_secure(false) .with_expiry(Expiry::OnInactivity(time::Duration::days(10))) .with_signed(key); let backend = Backend::new(users_db.clone()); let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build(); tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { format!( "{}=debug,tower_http=debug,axum=trace,sqlx=debug", env!("CARGO_CRATE_NAME") ) .into() }), ) .with(tracing_subscriber::fmt::layer().without_time()) .init(); let app = Router::new() .route("/", get(home::get::home)) .merge(contact::router()) .merge(group::router()) .merge(journal::router()) .merge(settings::router()) .route_layer(login_required!(Backend, login_url = "/login")) .merge(auth::router()) .merge(ics::router()) .nest_service("/static", ServeDir::new("./hashed_static")) .layer(auth_layer) .with_state(state); let listener = TcpListener::bind(format!("0.0.0.0:{}", port)).await?; tracing::debug!("Starting axum on 0.0.0.0:{}...", port); axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal(deletion_task.abort_handle())) .await .unwrap(); deletion_task.await??; Ok(()) } #[tokio::main] async fn main() -> Result<(), anyhow::Error> { let cli = Cli::parse(); match &cli.command { Some(Commands::SetPassword { username }) => { let users_db = { let db_options = SqliteConnectOptions::from_str("users.db")? .create_if_missing(true) .to_owned(); let db = SqlitePoolOptions::new().connect_with(db_options).await?; sqlx::migrate!("./migrations/users.db").run(&db).await?; db }; let password = rpassword::prompt_password(format!("New password for {}: ", username)).unwrap(); let update = sqlx::query( "insert into users (username, password) values ($1, $2) on conflict do update set password=excluded.password", ) .bind(username) .bind(password_auth::generate_hash(password)) .execute(&users_db) .await?; if update.rows_affected() > 0 { println!("Updated password for {}.", username); } else { println!("No update was made; probably something went wrong."); } } Some(Commands::Serve { port }) => { serve(port).await?; } None => { serve(&3000).await?; } } Ok(()) } async fn shutdown_signal(deletion_task_abort_handle: AbortHandle) { let ctrl_c = async { signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { signal::unix::signal(signal::unix::SignalKind::terminate()) .expect("failed to install signal handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => { deletion_task_abort_handle.abort() }, _ = terminate => { deletion_task_abort_handle.abort() }, } }