mascarpone/src/main.rs

275 lines
7.9 KiB
Rust
Raw Normal View History

2025-11-27 13:45:21 -06:00
use axum::Router;
use axum::response::{IntoResponse, Response};
use axum::routing::get;
use axum_login::AuthUser;
use axum_login::{AuthManagerLayerBuilder, login_required};
use clap::{Parser, Subcommand, arg, command};
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::{Arc, RwLock};
2025-10-14 08:26:18 -05:00
use tokio::net::TcpListener;
2025-11-27 13:45:21 -06:00
use tokio::signal;
use tokio::task::AbortHandle;
use tower_http::services::ServeDir;
use tower_sessions::{ExpiredDeletion, Expiry, SessionManagerLayer, cookie::Key};
use tower_sessions_sqlx_store::SqliteStore;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
2025-10-14 08:26:18 -05:00
2025-11-27 13:45:21 -06:00
mod models;
use models::contact::ContactTrie;
use models::user::{Backend, User};
mod db;
use db::{Database, DbId};
mod web;
use web::{auth, contact, home, ics, journal, settings};
/// Per-user application state: the user's database handle and the search
/// trie built from that user's contact names.
///
/// Both fields are reference-counted, so cloning an entry shares the
/// underlying values rather than copying them.
#[derive(Clone)]
struct AppStateEntry {
// Handle to this user's database, as opened by `Database::for_user`.
database: Arc<Database>,
// Name -> contact id lookup; behind a lock so it can be updated in place.
contact_search: Arc<RwLock<ContactTrie>>,
}
2025-10-14 08:26:18 -05:00
#[derive(Clone)]
struct AppState {
2025-11-27 13:45:21 -06:00
map: Arc<RwLock<HashMap<DbId, AppStateEntry>>>,
}
/// Row shape for the `names` query in `AppState::init`: a contact name paired
/// with the id of the contact it refers to.
struct NameReference {
// The name text; used as the key when inserted into the search trie.
name: String,
// Id of the contact this name belongs to; stored as the trie value.
contact_id: DbId,
}
impl AppState {
pub fn new() -> Self {
AppState {
map: Arc::new(RwLock::new(HashMap::new())),
2025-10-14 08:26:18 -05:00
}
}
2025-11-27 13:45:21 -06:00
pub async fn init(&mut self, user: &User) -> Result<Option<AppStateEntry>, AppError> {
let database = Database::for_user(&user).await?;
let mut trie = radix_trie::Trie::new();
let rows = sqlx::query_as!(
NameReference,
"select name, contact_id from (
select contact_id, name, count(name) as ct from names group by name
) where ct = 1;",
)
.fetch_all(&database.pool)
.await?;
2025-10-14 08:26:18 -05:00
2025-11-27 13:45:21 -06:00
for row in rows {
trie.insert(row.name, DbId::try_from(row.contact_id)?);
}
let mut map = self.map.write().expect("rwlock poisoned");
Ok(map.insert(
user.id(),
crate::AppStateEntry {
database: Arc::new(database),
contact_search: Arc::new(RwLock::new(trie)),
2025-10-14 08:26:18 -05:00
},
2025-11-27 13:45:21 -06:00
))
}
pub fn remove(&mut self, user: &impl AuthUser<Id = DbId>) {
let mut map = self.map.write().expect("rwlock poisoned");
map.remove(&user.id());
}
pub fn db(&self, user: &impl AuthUser<Id = DbId>) -> Arc<Database> {
let map = self.map.read().expect("rwlock poisoned");
map.get(&user.id()).unwrap().database.clone()
}
pub fn contact_search(&self, user: &impl AuthUser<Id = DbId>) -> Arc<RwLock<ContactTrie>> {
let map = self.map.read().expect("rwlock poisoned");
map.get(&user.id()).unwrap().contact_search.clone()
}
}
/// Handler error type: a thin wrapper around `anyhow::Error` so that any error
/// convertible into `anyhow::Error` can be propagated with `?` and then turned
/// into an HTTP response.
pub struct AppError(anyhow::Error);
impl IntoResponse for AppError {
fn into_response(self) -> Response {
(
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {}", self.0),
)
.into_response()
}
}
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}
// Top-level command-line arguments, parsed with clap's derive API.
// Plain `//` comments are used here on purpose: clap turns `///` doc comments
// into help text, which would change the program's output.
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Cli {
// Subcommand to run; when omitted, `main` serves on port 3000.
#[command(subcommand)]
command: Option<Commands>,
}
// Subcommands of the CLI. The `///` comments below are user-facing: clap
// renders them as help text.
#[derive(Subcommand, Debug)]
enum Commands {
    /// run mascarpone server (default)
    Serve {
        /// port to bind
        // Ports are 16-bit; reject out-of-range values at parse time with a
        // clear message instead of failing later when binding the socket.
        #[arg(
            short,
            long,
            default_value_t = 3000,
            value_parser = clap::value_parser!(u32).range(0..=65535)
        )]
        port: u32,
    },
    /// create a user, or set the password of an existing user
    SetPassword {
        /// username to create or set password
        username: String,
    },
}
async fn serve(port: &u32) -> Result<(), anyhow::Error> {
let users_db = {
let db_options = SqliteConnectOptions::from_str("users.db")?
.create_if_missing(true)
.to_owned();
let db = SqlitePoolOptions::new().connect_with(db_options).await?;
sqlx::migrate!("./migrations/users.db").run(&db).await?;
db
2025-10-14 08:26:18 -05:00
};
2025-11-27 13:45:21 -06:00
let state = AppState::new();
let session_store = SqliteStore::new(users_db.clone());
session_store.migrate().await?;
let deletion_task = tokio::task::spawn(
session_store
.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(600)),
);
// Generate a cryptographic key to sign the session cookie.
let key = Key::generate();
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_expiry(Expiry::OnInactivity(time::Duration::days(10)))
.with_signed(key);
let backend = Backend::new(users_db.clone());
let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build();
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
format!(
"{}=debug,tower_http=debug,axum=trace",
env!("CARGO_CRATE_NAME")
)
.into()
}),
)
.with(tracing_subscriber::fmt::layer().without_time())
.init();
2025-10-14 08:26:18 -05:00
let app = Router::new()
2025-11-27 13:45:21 -06:00
.route("/", get(home::get::home))
.merge(contact::router())
.merge(journal::router())
.merge(settings::router())
.route_layer(login_required!(Backend, login_url = "/login"))
.merge(auth::router())
.merge(ics::router())
.nest_service("/static", ServeDir::new("./static"))
.layer(auth_layer)
2025-10-14 08:26:18 -05:00
.with_state(state);
let mut listenfd = listenfd::ListenFd::from_env();
let listener = match listenfd.take_tcp_listener(0)? {
Some(listener) => {
listener.set_nonblocking(true)?;
TcpListener::from_std(listener)
}
2025-11-27 13:45:21 -06:00
None => TcpListener::bind(format!("0.0.0.0:{}", port)).await,
2025-10-14 08:26:18 -05:00
}?;
2025-11-27 13:45:21 -06:00
tracing::debug!("Starting axum on 0.0.0.0:3000...");
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal(deletion_task.abort_handle()))
.await
.unwrap();
deletion_task.await??;
Ok(())
}
/// Program entry point: parses the command line and dispatches.
///
/// * `set-password <username>` prompts for a password on the terminal and
///   inserts the user, or updates the stored password hash on conflict.
/// * `serve` (also the default when no subcommand is given, on port 3000)
///   runs the web server.
///
/// # Errors
/// Returns any database, migration, terminal-prompt or server error.
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let cli = Cli::parse();

    match &cli.command {
        Some(Commands::SetPassword { username }) => {
            let users_db = {
                let db_options = SqliteConnectOptions::from_str("users.db")?
                    .create_if_missing(true)
                    .to_owned();
                let db = SqlitePoolOptions::new().connect_with(db_options).await?;
                sqlx::migrate!("./migrations/users.db").run(&db).await?;
                db
            };

            // Reading from the terminal is I/O and can fail (e.g. no TTY);
            // report that as an error instead of panicking.
            let password =
                rpassword::prompt_password(format!("New password for {}: ", username))?;

            // Only the hash produced by `password_auth` is stored.
            let update = sqlx::query(
                "insert into users (username, password) values ($1, $2) on conflict do update set password=excluded.password",
            )
            .bind(username)
            .bind(password_auth::generate_hash(password))
            .execute(&users_db)
            .await?;

            if update.rows_affected() > 0 {
                println!("Updated password for {}.", username);
            } else {
                println!("No update was made; probably something went wrong.");
            }
        }
        Some(Commands::Serve { port }) => {
            serve(port).await?;
        }
        None => {
            serve(&3000).await?;
        }
    }

    Ok(())
}
2025-11-27 13:45:21 -06:00
/// Completes once Ctrl+C is received (or, on Unix, SIGTERM), aborting the task
/// behind `deletion_task_abort_handle` before returning.
///
/// # Panics
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal(deletion_task_abort_handle: AbortHandle) {
    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        stream.recv().await;
    };

    // On non-Unix targets there is no SIGTERM; use a future that never resolves.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    // Whichever signal fires first ends the wait; the reaction is the same.
    tokio::select! {
        _ = interrupt => {},
        _ = sigterm => {},
    }

    deletion_task_abort_handle.abort();
}