From 4d9b7c91a11338f0de78a83d33f00732f86fc6c6 Mon Sep 17 00:00:00 2001 From: Cameron Cordes Date: Wed, 16 Mar 2022 20:51:37 -0400 Subject: [PATCH] Improve testability and remove boxing Leverage generics to remove the extra heap allocation for the response handlers using Dao's. Also moved some of the environment variables to app state to allow for easier testing. --- src/auth.rs | 19 +++-- src/files.rs | 87 ++++++++++++++--------- src/main.rs | 169 ++++++++++++++++++++++++++++----------------- src/testhelpers.rs | 21 ++---- 4 files changed, 172 insertions(+), 124 deletions(-) diff --git a/src/auth.rs b/src/auth.rs index 58dc00e..f938ffb 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,4 +1,4 @@ -use actix_web::{post, Responder}; +use actix_web::Responder; use actix_web::{ web::{self, Json}, HttpResponse, @@ -12,10 +12,10 @@ use crate::{ database::UserDao, }; -#[post("/register")] -async fn register( +#[allow(dead_code)] +async fn register( user: Json, - user_dao: web::Data>, + user_dao: web::Data, ) -> impl Responder { if !user.username.is_empty() && user.password.len() > 5 && user.password == user.confirmation { if user_dao.user_exists(&user.username) { @@ -30,10 +30,7 @@ async fn register( } } -pub async fn login( - creds: Json, - user_dao: web::Data>, -) -> HttpResponse { +pub async fn login(creds: Json, user_dao: web::Data) -> HttpResponse { debug!("Logging in: {}", creds.username); if let Some(user) = user_dao.get_user(&creds.username, &creds.password) { @@ -74,7 +71,7 @@ mod tests { password: "pass".to_string(), }); - let response = login(j, web::Data::new(Box::new(dao))).await; + let response = login::(j, web::Data::new(dao)).await; assert_eq!(response.status(), 200); } @@ -89,7 +86,7 @@ mod tests { password: "password".to_string(), }); - let response = login(j, web::Data::new(Box::new(dao))).await; + let response = login::(j, web::Data::new(dao)).await; assert_eq!(response.status(), 200); let response_text: String = response.read_to_str(); @@ -107,7 +104,7 @@ mod tests { password: "password".to_string(), }); - let response = login(j, web::Data::new(Box::new(dao))).await; + let response = login::(j, web::Data::new(dao)).await; assert_eq!(response.status(), 404); } diff --git a/src/files.rs b/src/files.rs index 9ccb6ec..f5df626 100644 --- a/src/files.rs +++ b/src/files.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use std::fs::read_dir; use std::io; use std::path::{Path, PathBuf}; @@ -5,17 +6,25 @@ use std::path::{Path, PathBuf}; use ::anyhow; use anyhow::{anyhow, Context}; -use actix_web::{web::Query, HttpResponse}; +use actix_web::{ + web::{self, Query}, + HttpResponse, +}; use log::{debug, error}; use crate::data::{Claims, PhotosResponse, ThumbnailRequest}; +use crate::AppState; use path_absolutize::*; -pub async fn list_photos(_: Claims, req: Query) -> HttpResponse { +pub async fn list_photos( + _: Claims, + req: Query, + app_state: web::Data, +) -> HttpResponse { let path = &req.path; - if let Some(path) = is_valid_path(path) { + if let Some(path) = is_valid_full_path(&PathBuf::from(&app_state.base_path), path) { debug!("Valid path: {:?}", path); let files = list_files(&path).unwrap_or_default(); @@ -31,9 +40,7 @@ pub async fn list_photos(_: Claims, req: Query) -> HttpRespons ) }) .map(|path: &PathBuf| { - let relative = path - .strip_prefix(dotenv::var("BASE_PATH").unwrap()) - .unwrap(); + let relative = path.strip_prefix(&app_state.base_path).unwrap(); relative.to_path_buf() }) .map(|f| f.to_str().unwrap().to_string()) @@ -43,9 +50,7 @@ pub async fn list_photos(_: Claims, req: Query) -> HttpRespons .iter() .filter(|&f| f.metadata().map_or(false, |md| md.is_dir())) .map(|path: &PathBuf| { - let relative = path - .strip_prefix(dotenv::var("BASE_PATH").unwrap()) - .unwrap(); + let relative = path.strip_prefix(&app_state.base_path).unwrap(); relative.to_path_buf() }) .map(|f| f.to_str().unwrap().to_string()) @@ -82,18 +87,13 @@ pub fn is_image_or_video(path: &Path) -> bool { || extension == "nef" } -pub fn is_valid_path(path: &str) -> Option { - let base = PathBuf::from(dotenv::var("BASE_PATH").unwrap()); - - is_valid_full_path(&base, path) -} - -fn is_valid_full_path(base: &Path, path: &str) -> Option { +pub fn is_valid_full_path + Debug>(base: &P, path: &str) -> Option { debug!("Base: {:?}. Path: {}", base, path); let path = PathBuf::from(path); let mut path = if path.is_relative() { - let mut full_path = PathBuf::from(base); + let mut full_path = PathBuf::new(); + full_path.push(base); full_path.push(&path); full_path } else { @@ -109,7 +109,10 @@ fn is_valid_full_path(base: &Path, path: &str) -> Option { } } -fn is_path_above_base_dir(base: &Path, full_path: &mut PathBuf) -> anyhow::Result { +fn is_path_above_base_dir + Debug>( + base: P, + full_path: &mut PathBuf, +) -> anyhow::Result { full_path .absolutize() .with_context(|| format!("Unable to resolve absolute path: {:?}", full_path)) @@ -135,15 +138,21 @@ mod tests { use super::*; mod api { - use actix_web::{web::Query, HttpResponse}; + use super::*; + use actix::Actor; + use actix_web::{ + web::{self, Query}, + HttpResponse, + }; - use super::list_photos; use crate::{ data::{Claims, PhotosResponse, ThumbnailRequest}, testhelpers::BodyReader, + video::StreamActor, + AppState, }; - use std::fs; + use std::{fs, sync::Arc}; fn setup() { let _ = env_logger::builder().is_test(true).try_init(); @@ -160,7 +169,6 @@ mod tests { let request: Query = Query::from_query("path=").unwrap(); - std::env::set_var("BASE_PATH", "/tmp"); let mut temp_photo = std::env::temp_dir(); let mut tmp = temp_photo.clone(); @@ -169,9 +177,17 @@ mod tests { temp_photo.push("photo.jpg"); - fs::File::create(temp_photo).unwrap(); + fs::File::create(temp_photo.clone()).unwrap(); - let response: HttpResponse = list_photos(claims, request).await; + let response: HttpResponse = list_photos( + claims, + request, + web::Data::new(AppState::new( + Arc::new(StreamActor {}.start()), + String::from("/tmp"), + )), + ) + .await; let status = response.status(); let body: PhotosResponse = serde_json::from_str(&response.read_to_str()).unwrap(); @@ -200,7 +216,15 @@ mod tests { let request: Query = Query::from_query("path=..").unwrap(); - let response = list_photos(claims, request).await; + let response = list_photos( + claims, + request, + web::Data::new(AppState::new( + Arc::new(StreamActor {}.start()), + String::from("/tmp"), + )), + ) + .await; assert_eq!(response.status(), 400); } @@ -208,12 +232,13 @@ mod tests { #[test] fn directory_traversal_test() { - assert_eq!(None, is_valid_path("../")); - assert_eq!(None, is_valid_path("..")); - assert_eq!(None, is_valid_path("fake/../../../")); - assert_eq!(None, is_valid_path("../../../etc/passwd")); - assert_eq!(None, is_valid_path("..//etc/passwd")); - assert_eq!(None, is_valid_path("../../etc/passwd")); + let base = env::temp_dir(); + assert_eq!(None, is_valid_full_path(&base, "../")); + assert_eq!(None, is_valid_full_path(&base, "..")); + assert_eq!(None, is_valid_full_path(&base, "fake/../../../")); + assert_eq!(None, is_valid_full_path(&base, "../../../etc/passwd")); + assert_eq!(None, is_valid_full_path(&base, "..//etc/passwd")); + assert_eq!(None, is_valid_full_path(&base, "../../etc/passwd")); } #[test] diff --git a/src/main.rs b/src/main.rs index 12e1b08..0a62a28 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,7 +32,7 @@ use log::{debug, error, info}; use crate::auth::login; use crate::data::*; use crate::database::*; -use crate::files::{is_image_or_video, is_valid_path}; +use crate::files::{is_image_or_video, is_valid_full_path}; use crate::video::*; mod auth; @@ -62,13 +62,15 @@ async fn get_image( _claims: Claims, request: HttpRequest, req: web::Query, + app_state: web::Data, ) -> impl Responder { - if let Some(path) = is_valid_path(&req.path) { + if let Some(path) = is_valid_full_path(&app_state.base_path, &req.path) { if req.size.is_some() { - let thumbs = dotenv::var("THUMBNAILS").unwrap(); let relative_path = path - .strip_prefix(dotenv::var("BASE_PATH").unwrap()) - .expect("Error stripping prefix"); + .strip_prefix(&app_state.base_path) + .expect("Error stripping base path prefix from thumbnail"); + + let thumbs = &app_state.thumbnail_path; let thumb_path = Path::new(&thumbs).join(relative_path); debug!("{:?}", thumb_path); @@ -89,8 +91,12 @@ async fn get_image( } #[get("/image/metadata")] -async fn get_file_metadata(_: Claims, path: web::Query) -> impl Responder { - match is_valid_path(&path.path) +async fn get_file_metadata( + _: Claims, + path: web::Query, + app_state: web::Data, +) -> impl Responder { + match is_valid_full_path(&app_state.base_path, &path.path) .ok_or_else(|| ErrorKind::InvalidData.into()) .and_then(File::open) .and_then(|file| file.metadata()) @@ -107,7 +113,11 @@ async fn get_file_metadata(_: Claims, path: web::Query) -> imp } #[post("/image")] -async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder { +async fn upload_image( + _: Claims, + mut payload: mp::Multipart, + app_state: web::Data, +) -> impl Responder { let mut file_content: BytesMut = BytesMut::new(); let mut file_name: Option = None; let mut file_path: Option = None; @@ -131,10 +141,12 @@ async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder { } } - let path = file_path.unwrap_or_else(|| dotenv::var("BASE_PATH").unwrap()); + let path = file_path.unwrap_or_else(|| app_state.base_path.clone()); if !file_content.is_empty() { let full_path = PathBuf::from(&path).join(file_name.unwrap()); - if let Some(full_path) = is_valid_path(full_path.to_str().unwrap_or("")) { + if let Some(full_path) = + is_valid_full_path(&app_state.base_path, full_path.to_str().unwrap_or("")) + { if !full_path.is_file() && is_image_or_video(&full_path) { let mut file = File::create(full_path).unwrap(); file.write_all(&file_content).unwrap(); @@ -155,7 +167,7 @@ async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder { #[post("/video/generate")] async fn generate_video( _claims: Claims, - data: web::Data, + app_state: web::Data, body: web::Json, ) -> impl Responder { let filename = PathBuf::from(&body.path); @@ -163,9 +175,10 @@ async fn generate_video( if let Some(name) = filename.file_stem() { let filename = name.to_str().expect("Filename should convert to string"); let playlist = format!("tmp/{}.m3u8", filename); - if let Some(path) = is_valid_path(&body.path) { + if let Some(path) = is_valid_full_path(&app_state.base_path, &body.path) { if let Ok(child) = create_playlist(path.to_str().unwrap(), &playlist).await { - data.stream_manager + app_state + .stream_manager .do_send(ProcessMessage(playlist.clone(), child)); } } else { @@ -184,12 +197,13 @@ async fn stream_video( request: HttpRequest, _: Claims, path: web::Query, + app_state: web::Data, ) -> impl Responder { let playlist = &path.path; debug!("Playlist: {}", playlist); // Extract video playlist dir to dotenv - if !playlist.starts_with("tmp") && is_valid_path(playlist) != None { + if !playlist.starts_with("tmp") && is_valid_full_path(&app_state.base_path, playlist) != None { HttpResponse::BadRequest().finish() } else if let Ok(file) = NamedFile::open(playlist) { file.into_response(&request) @@ -406,7 +420,56 @@ fn main() -> std::io::Result<()> { env_logger::init(); create_thumbnails(); + watch_files(); + let system = actix::System::new(); + system.block_on(async { + let app_data = web::Data::new(AppState::default()); + + let labels = HashMap::new(); + let prometheus = PrometheusMetricsBuilder::new("api") + .const_labels(labels) + .build() + .expect("Unable to build prometheus metrics middleware"); + + prometheus + .registry + .register(Box::new(IMAGE_GAUGE.clone())) + .unwrap(); + prometheus + .registry + .register(Box::new(VIDEO_GAUGE.clone())) + .unwrap(); + + HttpServer::new(move || { + let user_dao = SqliteUserDao::new(); + let favorites_dao = SqliteFavoriteDao::new(); + App::new() + .wrap(middleware::Logger::default()) + .service(web::resource("/login").route(web::post().to(login::))) + .service(web::resource("/photos").route(web::get().to(files::list_photos))) + .service(get_image) + .service(upload_image) + .service(generate_video) + .service(stream_video) + .service(get_video_part) + .service(favorites) + .service(put_add_favorite) + .service(delete_favorite) + .service(get_file_metadata) + .app_data(app_data.clone()) + .app_data::>(Data::new(user_dao)) + .app_data::>>(Data::new(Box::new(favorites_dao))) + .wrap(prometheus.clone()) + }) + .bind(dotenv::var("BIND_URL").unwrap())? + .bind("localhost:8088")? + .run() + .await + }) +} + +fn watch_files() { std::thread::spawn(|| { let (wtx, wrx) = channel(); let mut watcher = watcher(wtx, std::time::Duration::from_secs(10)).unwrap(); @@ -449,58 +512,34 @@ fn main() -> std::io::Result<()> { } } }); - - let system = actix::System::new(); - system.block_on(async { - let act = StreamActor {}.start(); - - let app_data = web::Data::new(AppState { - stream_manager: Arc::new(act), - }); - - let labels = HashMap::new(); - let prometheus = PrometheusMetricsBuilder::new("api") - .const_labels(labels) - .build() - .expect("Unable to build prometheus metrics middleware"); - - prometheus - .registry - .register(Box::new(IMAGE_GAUGE.clone())) - .unwrap(); - prometheus - .registry - .register(Box::new(VIDEO_GAUGE.clone())) - .unwrap(); - - HttpServer::new(move || { - let user_dao = SqliteUserDao::new(); - let favorites_dao = SqliteFavoriteDao::new(); - App::new() - .wrap(middleware::Logger::default()) - .service(web::resource("/login").route(web::post().to(login))) - .service(web::resource("/photos").route(web::get().to(files::list_photos))) - .service(get_image) - .service(upload_image) - .service(generate_video) - .service(stream_video) - .service(get_video_part) - .service(favorites) - .service(put_add_favorite) - .service(delete_favorite) - .service(get_file_metadata) - .app_data(app_data.clone()) - .app_data::>>(Data::new(Box::new(user_dao))) - .app_data::>>(Data::new(Box::new(favorites_dao))) - .wrap(prometheus.clone()) - }) - .bind(dotenv::var("BIND_URL").unwrap())? - .bind("localhost:8088")? - .run() - .await - }) } -struct AppState { +pub struct AppState { stream_manager: Arc>, + base_path: String, + thumbnail_path: String, +} + +impl AppState { + fn new( + stream_manager: Arc>, + base_path: String, + thumbnail_path: String, + ) -> Self { + Self { + stream_manager, + base_path, + thumbnail_path, + } + } +} + +impl Default for AppState { + fn default() -> Self { + Self::new( + Arc::new(StreamActor {}.start()), + env::var("BASE_PATH").expect("BASE_PATH was not set in the env"), + env::var("THUMBNAILS").expect("THUMBNAILS was not set in the env"), + ) + } } diff --git a/src/testhelpers.rs b/src/testhelpers.rs index fc34ecd..f51e066 100644 --- a/src/testhelpers.rs +++ b/src/testhelpers.rs @@ -1,6 +1,7 @@ -use actix_web::body::MessageBody; -use actix_web::{body::BoxBody, HttpResponse}; -use serde::de::DeserializeOwned; +use actix_web::{ + body::{BoxBody, MessageBody}, + HttpResponse, +}; use crate::database::{models::User, UserDao}; use std::cell::RefCell; @@ -65,17 +66,3 @@ impl BodyReader for HttpResponse { std::str::from_utf8(&body).unwrap().to_string() } } - -pub trait TypedBodyReader -where - T: DeserializeOwned, -{ - fn read_body(self) -> T; -} - -impl TypedBodyReader for HttpResponse { - fn read_body(self) -> T { - let body = self.read_to_str(); - serde_json::from_value(serde_json::Value::String(body.clone())).unwrap() - } -}