diff --git a/src/auth.rs b/src/auth.rs index b6e7d81..4c019b7 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,19 +1,23 @@ -use actix_web::web::{HttpResponse, Json}; +use actix_web::web::{self, HttpResponse, Json}; use actix_web::{post, Responder}; use chrono::{Duration, Utc}; use jsonwebtoken::{encode, EncodingKey, Header}; use log::{debug, error}; -use crate::data::LoginRequest; -use crate::data::{secret_key, Claims, CreateAccountRequest, Token}; -use crate::database::{create_user, get_user, user_exists}; +use crate::{ + data::{secret_key, Claims, CreateAccountRequest, LoginRequest, Token}, + database::UserDao, +}; #[post("/register")] -async fn register(user: Json) -> impl Responder { +async fn register( + user: Json, + user_dao: web::Data>, +) -> impl Responder { if !user.username.is_empty() && user.password.len() > 5 && user.password == user.confirmation { - if user_exists(&user.username) { + if user_dao.user_exists(&user.username) { HttpResponse::BadRequest() - } else if let Some(_user) = create_user(&user.username, &user.password) { + } else if let Some(_user) = user_dao.create_user(&user.username, &user.password) { HttpResponse::Ok() } else { HttpResponse::InternalServerError() @@ -23,10 +27,12 @@ async fn register(user: Json) -> impl Responder { } } -#[post("/login")] -async fn login(creds: Json) -> impl Responder { +pub async fn login( + creds: Json, + user_dao: web::Data>, +) -> HttpResponse { debug!("Logging in: {}", creds.username); - if let Some(user) = get_user(&creds.username, &creds.password) { + if let Some(user) = user_dao.get_user(&creds.username, &creds.password) { let claims = Claims { sub: user.id.to_string(), exp: (Utc::now() + Duration::days(5)).timestamp(), @@ -43,3 +49,55 @@ async fn login(creds: Json) -> impl Responder { HttpResponse::NotFound().finish() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::database::testhelpers::{BodyReader, TestUserDao}; + + #[actix_rt::test] + async fn test_login_reports_200_when_user_exists() { + let dao = TestUserDao::new(); + dao.create_user("user", "pass"); + + let j = Json(LoginRequest { + username: "user".to_string(), + password: "pass".to_string(), + }); + + let response = login(j, web::Data::new(Box::new(dao))).await; + + assert_eq!(response.status(), 200); + } + + #[actix_rt::test] + async fn test_login_returns_token_on_success() { + let dao = TestUserDao::new(); + dao.create_user("user", "password"); + + let j = Json(LoginRequest { + username: "user".to_string(), + password: "password".to_string(), + }); + + let response = login(j, web::Data::new(Box::new(dao))).await; + + assert_eq!(response.status(), 200); + assert!(response.body().read_to_str().contains("\"token\"")); + } + + #[actix_rt::test] + async fn test_login_reports_400_when_user_does_not_exist() { + let dao = TestUserDao::new(); + dao.create_user("user", "password"); + + let j = Json(LoginRequest { + username: "doesnotexist".to_string(), + password: "password".to_string(), + }); + + let response = login(j, web::Data::new(Box::new(dao))).await; + + assert_eq!(response.status(), 404); + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 22a9dc1..96f5915 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,72 +1,88 @@ use bcrypt::{hash, verify, DEFAULT_COST}; use diesel::prelude::*; use diesel::sqlite::SqliteConnection; -use dotenv::dotenv; use crate::database::models::{Favorite, InsertFavorite, InsertUser, User}; mod models; mod schema; -fn connect() -> SqliteConnection { - dotenv().ok(); - - let db_url = dotenv::var("DATABASE_URL").expect("DATABASE_URL must be set"); - SqliteConnection::establish(&db_url).expect("Error connecting to DB") +pub trait UserDao { + fn create_user(&self, user: &str, password: &str) -> Option; + fn get_user(&self, user: &str, password: &str) -> Option; + fn user_exists(&self, user: &str) -> bool; } -// TODO: Should probably use Result here -pub fn create_user(user: &str, pass: &str) -> Option { - use schema::users::dsl::*; +pub struct SqliteUserDao { + connection: SqliteConnection, +} - let hashed = hash(pass, DEFAULT_COST); - if let Ok(hash) = hashed { - let connection = connect(); - diesel::insert_into(users) - .values(InsertUser { - username: user, - password: &hash, - }) - .execute(&connection) - .unwrap(); +impl SqliteUserDao { + pub fn new() -> Self { + Self { + connection: connect(), + } + } +} + +impl UserDao for SqliteUserDao { + // TODO: Should probably use Result here + fn create_user(&self, user: &str, pass: &str) -> std::option::Option { + use schema::users::dsl::*; + + let hashed = hash(pass, DEFAULT_COST); + if let Ok(hash) = hashed { + diesel::insert_into(users) + .values(InsertUser { + username: user, + password: &hash, + }) + .execute(&self.connection) + .unwrap(); + + match users + .filter(username.eq(username)) + .load::(&self.connection) + .unwrap() + .first() + { + Some(u) => Some(u.clone()), + None => None, + } + } else { + None + } + } + + fn get_user(&self, user: &str, pass: &str) -> Option { + use schema::users::dsl::*; match users .filter(username.eq(user)) - .load::(&connection) - .unwrap() + .load::(&self.connection) + .unwrap_or_default() .first() { - Some(u) => Some(u.clone()), - None => None, + Some(u) if verify(pass, &u.password).unwrap_or(false) => Some(u.clone()), + _ => None, } - } else { - None + } + + fn user_exists(&self, user: &str) -> bool { + use schema::users::dsl::*; + + users + .filter(username.eq(user)) + .load::(&self.connection) + .unwrap_or_default() + .first() + .is_some() } } -pub fn get_user(user: &str, pass: &str) -> Option { - use schema::users::dsl::*; - - match users - .filter(username.eq(user)) - .load::(&connect()) - .unwrap_or_default() - .first() - { - Some(u) if verify(pass, &u.password).unwrap_or(false) => Some(u.clone()), - _ => None, - } -} - -pub fn user_exists(name: &str) -> bool { - use schema::users::dsl::*; - - users - .filter(username.eq(name)) - .load::(&connect()) - .unwrap_or_default() - .first() - .is_some() +fn connect() -> SqliteConnection { + let db_url = dotenv::var("DATABASE_URL").expect("DATABASE_URL must be set"); + SqliteConnection::establish(&db_url).expect("Error connecting to DB") } pub fn add_favorite(user_id: i32, favorite_path: String) { @@ -90,3 +106,81 @@ pub fn get_favorites(user_id: i32) -> Vec { .load::(&connect()) .unwrap_or_default() } + +#[cfg(test)] +pub mod testhelpers { + use actix_web::dev::{Body, ResponseBody}; + + use super::{models::User, UserDao}; + use std::cell::RefCell; + use std::option::Option; + + pub struct TestUserDao { + pub user_map: RefCell>, + } + + impl TestUserDao { + pub fn new() -> Self { + Self { + user_map: RefCell::new(Vec::new()), + } + } + } + + impl UserDao for TestUserDao { + fn create_user(&self, username: &str, password: &str) -> Option { + let u = User { + id: (self.user_map.borrow().len() + 1) as i32, + username: username.to_string(), + password: password.to_string(), + }; + + self.user_map.borrow_mut().push(u.clone()); + + Some(u) + } + + fn get_user(&self, user: &str, pass: &str) -> Option { + match self + .user_map + .borrow() + .iter() + .filter(|u| u.username == user && u.password == pass) + .collect::>() + .first() + { + Some(u) => { + let copy = (*u).clone(); + Some(copy) + } + None => None, + } + } + + fn user_exists(&self, user: &str) -> bool { + self.user_map + .borrow() + .iter() + .filter(|u| u.username == user) + .collect::>() + .first() + .is_some() + } + } + + pub trait BodyReader { + fn read_to_str(&self) -> &str; + } + + impl BodyReader for ResponseBody { + fn read_to_str(&self) -> &str { + match self { + ResponseBody::Body(body) => match body { + Body::Bytes(b) => std::str::from_utf8(&b).unwrap(), + _ => panic!("Unknown response body content"), + }, + _ => panic!("Unknown response body"), + } + } + } +} diff --git a/src/main.rs b/src/main.rs index 3549230..fb7e310 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ extern crate diesel; extern crate rayon; use crate::auth::login; +use database::{SqliteUserDao, UserDao}; use futures::stream::StreamExt; use std::fs::File; use std::io::prelude::*; @@ -328,8 +329,9 @@ fn main() -> std::io::Result<()> { }); HttpServer::new(move || { + let user_dao = SqliteUserDao::new(); App::new() - .service(login) + .service(web::resource("/login").route(web::post().to(login))) .service(list_photos) .service(get_image) .service(upload_image) @@ -339,6 +341,7 @@ fn main() -> std::io::Result<()> { .service(favorites) .service(post_add_favorite) .app_data(app_data.clone()) + .data::>(Box::new(user_dao)) }) .bind(dotenv::var("BIND_URL").unwrap())? .bind("localhost:8088")?