127 lines
3.4 KiB
Rust
127 lines
3.4 KiB
Rust
|
|
use axum_login::{AuthUser, AuthnBackend, UserId};
|
||
|
|
use password_auth::verify_password;
|
||
|
|
use serde::{Deserialize, Serialize};
|
||
|
|
use sqlx::{FromRow, SqlitePool};
|
||
|
|
use tokio::task;
|
||
|
|
|
||
|
|
#[derive(Clone, Serialize, Deserialize, FromRow)]
|
||
|
|
pub struct User {
|
||
|
|
id: i64,
|
||
|
|
pub username: String,
|
||
|
|
password: String,
|
||
|
|
pub ephemeral: bool,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Here we've implemented `Debug` manually to avoid accidentally logging the
|
||
|
|
// password hash.
|
||
|
|
impl std::fmt::Debug for User {
|
||
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
|
|
f.debug_struct("User")
|
||
|
|
.field("id", &self.id)
|
||
|
|
.field("username", &self.username)
|
||
|
|
.field("password", &"[redacted]")
|
||
|
|
.field("ephemeral", &self.ephemeral)
|
||
|
|
.finish()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl AuthUser for User {
    type Id = i64;

    /// Identifies the user in the session by their database primary key.
    fn id(&self) -> Self::Id {
        self.id
    }

    /// Returns the stored password hash as the session auth hash.
    ///
    /// Tying the auth hash to the password hash means that once the password is
    /// changed, sessions created with the old hash no longer validate.
    fn session_auth_hash(&self) -> &[u8] {
        let hash: &str = &self.password;
        hash.as_bytes()
    }
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone, Deserialize)]
|
||
|
|
pub struct Credentials {
|
||
|
|
pub username: String,
|
||
|
|
pub password: String,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, Clone)]
|
||
|
|
pub struct Backend {
|
||
|
|
db: SqlitePool,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Backend {
|
||
|
|
pub fn new(db: SqlitePool) -> Self {
|
||
|
|
Self { db }
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn set_password(&self, creds: Credentials) -> Result<(), anyhow::Error> {
|
||
|
|
if creds.username != "demo" {
|
||
|
|
sqlx::query("update users set password=$2 where username=$1")
|
||
|
|
.bind(creds.username)
|
||
|
|
.bind(password_auth::generate_hash(creds.password))
|
||
|
|
.execute(&self.db)
|
||
|
|
.await?;
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(())
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn find_user(
|
||
|
|
&self,
|
||
|
|
username: impl AsRef<str>,
|
||
|
|
) -> Result<Option<User>, anyhow::Error> {
|
||
|
|
let user = sqlx::query_as("select * from users where username = ?")
|
||
|
|
.bind(username.as_ref())
|
||
|
|
.fetch_optional(&self.db)
|
||
|
|
.await?;
|
||
|
|
|
||
|
|
Ok(user)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[derive(Debug, thiserror::Error)]
|
||
|
|
pub enum Error {
|
||
|
|
#[error(transparent)]
|
||
|
|
Sqlx(#[from] sqlx::Error),
|
||
|
|
|
||
|
|
#[error(transparent)]
|
||
|
|
TaskJoin(#[from] task::JoinError),
|
||
|
|
}
|
||
|
|
|
||
|
|
impl AuthnBackend for Backend {
|
||
|
|
type User = User;
|
||
|
|
type Credentials = Credentials;
|
||
|
|
type Error = Error;
|
||
|
|
|
||
|
|
async fn authenticate(
|
||
|
|
&self,
|
||
|
|
creds: Self::Credentials,
|
||
|
|
) -> Result<Option<Self::User>, Self::Error> {
|
||
|
|
let user: Option<Self::User> = sqlx::query_as("select * from users where username = $1")
|
||
|
|
.bind(creds.username)
|
||
|
|
.fetch_optional(&self.db)
|
||
|
|
.await?;
|
||
|
|
|
||
|
|
// Verifying the password is blocking and potentially slow, so we'll do so via
|
||
|
|
// `spawn_blocking`.
|
||
|
|
task::spawn_blocking(|| {
|
||
|
|
// We're using password-based authentication--this works by comparing our form
|
||
|
|
// input with an argon2 password hash.
|
||
|
|
Ok(user.filter(|user| verify_password(creds.password, &user.password).is_ok()))
|
||
|
|
})
|
||
|
|
.await?
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
|
||
|
|
let user = sqlx::query_as("select * from users where id = ?")
|
||
|
|
.bind(user_id)
|
||
|
|
.fetch_optional(&self.db)
|
||
|
|
.await?;
|
||
|
|
|
||
|
|
Ok(user)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// The `axum_login` auth session, specialised to this module's [`Backend`].
pub type AuthSession = axum_login::AuthSession<Backend>;
|