major features update
This commit is contained in:
parent
519fb49901
commit
4e2fab67c5
48 changed files with 3925 additions and 208 deletions
279
src/main.rs
279
src/main.rs
|
|
@ -1,46 +1,189 @@
|
|||
use axum::{Router, extract::State, response::IntoResponse, routing::get};
|
||||
use maud::html;
|
||||
use axum::Router;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::get;
|
||||
use axum_login::AuthUser;
|
||||
use axum_login::{AuthManagerLayerBuilder, login_required};
|
||||
use clap::{Parser, Subcommand, arg, command};
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
|
||||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::signal;
|
||||
use tokio::task::AbortHandle;
|
||||
use tower_http::services::ServeDir;
|
||||
use tower_sessions::{ExpiredDeletion, Expiry, SessionManagerLayer, cookie::Key};
|
||||
use tower_sessions_sqlx_store::SqliteStore;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
mod contact;
|
||||
use contact::Contact;
|
||||
mod models;
|
||||
use models::contact::ContactTrie;
|
||||
use models::user::{Backend, User};
|
||||
|
||||
mod db;
|
||||
use db::{Database, DbId};
|
||||
|
||||
mod web;
|
||||
use web::{auth, contact, home, ics, journal, settings};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppStateEntry {
|
||||
database: Arc<Database>,
|
||||
contact_search: Arc<RwLock<ContactTrie>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
contacts: Vec<Contact>,
|
||||
map: Arc<RwLock<HashMap<DbId, AppStateEntry>>>,
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
async fn contacts(
|
||||
// access the state via the `State` extractor
|
||||
// extracting a state of the wrong type results in a compile error
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
html! {
|
||||
ul {
|
||||
@for contact in &state.contacts {
|
||||
li { (&contact.name) }
|
||||
}
|
||||
struct NameReference {
|
||||
name: String,
|
||||
contact_id: DbId,
|
||||
}
|
||||
impl AppState {
|
||||
pub fn new() -> Self {
|
||||
AppState {
|
||||
map: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
pub async fn init(&mut self, user: &User) -> Result<Option<AppStateEntry>, AppError> {
|
||||
let database = Database::for_user(&user).await?;
|
||||
let mut trie = radix_trie::Trie::new();
|
||||
let rows = sqlx::query_as!(
|
||||
NameReference,
|
||||
"select name, contact_id from (
|
||||
select contact_id, name, count(name) as ct from names group by name
|
||||
) where ct = 1;",
|
||||
)
|
||||
.fetch_all(&database.pool)
|
||||
.await?;
|
||||
|
||||
for row in rows {
|
||||
trie.insert(row.name, DbId::try_from(row.contact_id)?);
|
||||
}
|
||||
|
||||
let mut map = self.map.write().expect("rwlock poisoned");
|
||||
Ok(map.insert(
|
||||
user.id(),
|
||||
crate::AppStateEntry {
|
||||
database: Arc::new(database),
|
||||
contact_search: Arc::new(RwLock::new(trie)),
|
||||
},
|
||||
))
|
||||
}
|
||||
pub fn remove(&mut self, user: &impl AuthUser<Id = DbId>) {
|
||||
let mut map = self.map.write().expect("rwlock poisoned");
|
||||
map.remove(&user.id());
|
||||
}
|
||||
pub fn db(&self, user: &impl AuthUser<Id = DbId>) -> Arc<Database> {
|
||||
let map = self.map.read().expect("rwlock poisoned");
|
||||
|
||||
map.get(&user.id()).unwrap().database.clone()
|
||||
}
|
||||
pub fn contact_search(&self, user: &impl AuthUser<Id = DbId>) -> Arc<RwLock<ContactTrie>> {
|
||||
let map = self.map.read().expect("rwlock poisoned");
|
||||
map.get(&user.id()).unwrap().contact_search.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
let state = AppState {
|
||||
contacts: vec![
|
||||
Contact {
|
||||
name: "Foo Bar".to_string(),
|
||||
},
|
||||
Contact {
|
||||
name: "Baz Qux".to_string(),
|
||||
},
|
||||
],
|
||||
pub struct AppError(anyhow::Error);
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
(
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Something went wrong: {}", self.0),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
impl<E> From<E> for AppError
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
fn from(err: E) -> Self {
|
||||
Self(err.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None)]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
command: Option<Commands>,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Commands {
|
||||
/// run mascarpone server (default)
|
||||
Serve {
|
||||
/// port to bind
|
||||
#[arg(short, long, default_value_t = 3000)]
|
||||
port: u32,
|
||||
},
|
||||
|
||||
SetPassword {
|
||||
/// username to create or set password
|
||||
username: String,
|
||||
},
|
||||
}
|
||||
|
||||
async fn serve(port: &u32) -> Result<(), anyhow::Error> {
|
||||
let users_db = {
|
||||
let db_options = SqliteConnectOptions::from_str("users.db")?
|
||||
.create_if_missing(true)
|
||||
.to_owned();
|
||||
|
||||
let db = SqlitePoolOptions::new().connect_with(db_options).await?;
|
||||
sqlx::migrate!("./migrations/users.db").run(&db).await?;
|
||||
db
|
||||
};
|
||||
|
||||
let state = AppState::new();
|
||||
let session_store = SqliteStore::new(users_db.clone());
|
||||
session_store.migrate().await?;
|
||||
|
||||
let deletion_task = tokio::task::spawn(
|
||||
session_store
|
||||
.clone()
|
||||
.continuously_delete_expired(tokio::time::Duration::from_secs(600)),
|
||||
);
|
||||
|
||||
// Generate a cryptographic key to sign the session cookie.
|
||||
let key = Key::generate();
|
||||
|
||||
let session_layer = SessionManagerLayer::new(session_store)
|
||||
.with_secure(false)
|
||||
.with_expiry(Expiry::OnInactivity(time::Duration::days(10)))
|
||||
.with_signed(key);
|
||||
|
||||
let backend = Backend::new(users_db.clone());
|
||||
let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build();
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
||||
format!(
|
||||
"{}=debug,tower_http=debug,axum=trace",
|
||||
env!("CARGO_CRATE_NAME")
|
||||
)
|
||||
.into()
|
||||
}),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer().without_time())
|
||||
.init();
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, World!" }))
|
||||
.route("/contacts", get(contacts))
|
||||
.route("/", get(home::get::home))
|
||||
.merge(contact::router())
|
||||
.merge(journal::router())
|
||||
.merge(settings::router())
|
||||
.route_layer(login_required!(Backend, login_url = "/login"))
|
||||
.merge(auth::router())
|
||||
.merge(ics::router())
|
||||
.nest_service("/static", ServeDir::new("./static"))
|
||||
.layer(auth_layer)
|
||||
.with_state(state);
|
||||
|
||||
let mut listenfd = listenfd::ListenFd::from_env();
|
||||
|
|
@ -49,9 +192,83 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
listener.set_nonblocking(true)?;
|
||||
TcpListener::from_std(listener)
|
||||
}
|
||||
None => TcpListener::bind("0.0.0.0:3000").await,
|
||||
None => TcpListener::bind(format!("0.0.0.0:{}", port)).await,
|
||||
}?;
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
tracing::debug!("Starting axum on 0.0.0.0:3000...");
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(shutdown_signal(deletion_task.abort_handle()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
deletion_task.await??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
match &cli.command {
|
||||
Some(Commands::SetPassword { username }) => {
|
||||
let users_db = {
|
||||
let db_options = SqliteConnectOptions::from_str("users.db")?
|
||||
.create_if_missing(true)
|
||||
.to_owned();
|
||||
|
||||
let db = SqlitePoolOptions::new().connect_with(db_options).await?;
|
||||
sqlx::migrate!("./migrations/users.db").run(&db).await?;
|
||||
db
|
||||
};
|
||||
|
||||
let password =
|
||||
rpassword::prompt_password(format!("New password for {}: ", username)).unwrap();
|
||||
|
||||
let update = sqlx::query(
|
||||
"insert into users (username, password) values ($1, $2) on conflict do update set password=excluded.password",
|
||||
)
|
||||
.bind(username)
|
||||
.bind(password_auth::generate_hash(password))
|
||||
.execute(&users_db)
|
||||
.await?;
|
||||
|
||||
if update.rows_affected() > 0 {
|
||||
println!("Updated password for {}.", username);
|
||||
} else {
|
||||
println!("No update was made; probably something went wrong.");
|
||||
}
|
||||
}
|
||||
Some(Commands::Serve { port }) => {
|
||||
serve(port).await?;
|
||||
}
|
||||
None => {
|
||||
serve(&3000).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shutdown_signal(deletion_task_abort_handle: AbortHandle) {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => { deletion_task_abort_handle.abort() },
|
||||
_ = terminate => { deletion_task_abort_handle.abort() },
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue