diff --git a/src/auth.rs b/src/auth.rs index 58dc00e..f938ffb 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,4 +1,4 @@ -use actix_web::{post, Responder}; +use actix_web::Responder; use actix_web::{ web::{self, Json}, HttpResponse, @@ -12,10 +12,10 @@ use crate::{ database::UserDao, }; -#[post("/register")] -async fn register( +#[allow(dead_code)] +async fn register( user: Json, - user_dao: web::Data>, + user_dao: web::Data, ) -> impl Responder { if !user.username.is_empty() && user.password.len() > 5 && user.password == user.confirmation { if user_dao.user_exists(&user.username) { @@ -30,10 +30,7 @@ async fn register( } } -pub async fn login( - creds: Json, - user_dao: web::Data>, -) -> HttpResponse { +pub async fn login(creds: Json, user_dao: web::Data) -> HttpResponse { debug!("Logging in: {}", creds.username); if let Some(user) = user_dao.get_user(&creds.username, &creds.password) { @@ -74,7 +71,7 @@ mod tests { password: "pass".to_string(), }); - let response = login(j, web::Data::new(Box::new(dao))).await; + let response = login::(j, web::Data::new(dao)).await; assert_eq!(response.status(), 200); } @@ -89,7 +86,7 @@ mod tests { password: "password".to_string(), }); - let response = login(j, web::Data::new(Box::new(dao))).await; + let response = login::(j, web::Data::new(dao)).await; assert_eq!(response.status(), 200); let response_text: String = response.read_to_str(); @@ -107,7 +104,7 @@ mod tests { password: "password".to_string(), }); - let response = login(j, web::Data::new(Box::new(dao))).await; + let response = login::(j, web::Data::new(dao)).await; assert_eq!(response.status(), 404); } diff --git a/src/files.rs b/src/files.rs index 9ccb6ec..f5df626 100644 --- a/src/files.rs +++ b/src/files.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use std::fs::read_dir; use std::io; use std::path::{Path, PathBuf}; @@ -5,17 +6,25 @@ use std::path::{Path, PathBuf}; use ::anyhow; use anyhow::{anyhow, Context}; -use actix_web::{web::Query, HttpResponse}; +use actix_web::{ + web::{self, Query}, + HttpResponse, +}; use log::{debug, error}; use crate::data::{Claims, PhotosResponse, ThumbnailRequest}; +use crate::AppState; use path_absolutize::*; -pub async fn list_photos(_: Claims, req: Query) -> HttpResponse { +pub async fn list_photos( + _: Claims, + req: Query, + app_state: web::Data, +) -> HttpResponse { let path = &req.path; - if let Some(path) = is_valid_path(path) { + if let Some(path) = is_valid_full_path(&PathBuf::from(&app_state.base_path), path) { debug!("Valid path: {:?}", path); let files = list_files(&path).unwrap_or_default(); @@ -31,9 +40,7 @@ pub async fn list_photos(_: Claims, req: Query) -> HttpRespons ) }) .map(|path: &PathBuf| { - let relative = path - .strip_prefix(dotenv::var("BASE_PATH").unwrap()) - .unwrap(); + let relative = path.strip_prefix(&app_state.base_path).unwrap(); relative.to_path_buf() }) .map(|f| f.to_str().unwrap().to_string()) @@ -43,9 +50,7 @@ pub async fn list_photos(_: Claims, req: Query) -> HttpRespons .iter() .filter(|&f| f.metadata().map_or(false, |md| md.is_dir())) .map(|path: &PathBuf| { - let relative = path - .strip_prefix(dotenv::var("BASE_PATH").unwrap()) - .unwrap(); + let relative = path.strip_prefix(&app_state.base_path).unwrap(); relative.to_path_buf() }) .map(|f| f.to_str().unwrap().to_string()) @@ -82,18 +87,13 @@ pub fn is_image_or_video(path: &Path) -> bool { || extension == "nef" } -pub fn is_valid_path(path: &str) -> Option { - let base = PathBuf::from(dotenv::var("BASE_PATH").unwrap()); - - is_valid_full_path(&base, path) -} - -fn is_valid_full_path(base: &Path, path: &str) -> Option { +pub fn is_valid_full_path + Debug>(base: &P, path: &str) -> Option { debug!("Base: {:?}. Path: {}", base, path); let path = PathBuf::from(path); let mut path = if path.is_relative() { - let mut full_path = PathBuf::from(base); + let mut full_path = PathBuf::new(); + full_path.push(base); full_path.push(&path); full_path } else { @@ -109,7 +109,10 @@ fn is_valid_full_path(base: &Path, path: &str) -> Option { } } -fn is_path_above_base_dir(base: &Path, full_path: &mut PathBuf) -> anyhow::Result { +fn is_path_above_base_dir + Debug>( + base: P, + full_path: &mut PathBuf, +) -> anyhow::Result { full_path .absolutize() .with_context(|| format!("Unable to resolve absolute path: {:?}", full_path)) @@ -135,15 +138,21 @@ mod tests { use super::*; mod api { - use actix_web::{web::Query, HttpResponse}; + use super::*; + use actix::Actor; + use actix_web::{ + web::{self, Query}, + HttpResponse, + }; - use super::list_photos; use crate::{ data::{Claims, PhotosResponse, ThumbnailRequest}, testhelpers::BodyReader, + video::StreamActor, + AppState, }; - use std::fs; + use std::{fs, sync::Arc}; fn setup() { let _ = env_logger::builder().is_test(true).try_init(); @@ -160,7 +169,6 @@ mod tests { let request: Query = Query::from_query("path=").unwrap(); - std::env::set_var("BASE_PATH", "/tmp"); let mut temp_photo = std::env::temp_dir(); let mut tmp = temp_photo.clone(); @@ -169,9 +177,17 @@ mod tests { temp_photo.push("photo.jpg"); - fs::File::create(temp_photo).unwrap(); + fs::File::create(temp_photo.clone()).unwrap(); - let response: HttpResponse = list_photos(claims, request).await; + let response: HttpResponse = list_photos( + claims, + request, + web::Data::new(AppState::new( + Arc::new(StreamActor {}.start()), + String::from("/tmp"), + )), + ) + .await; let status = response.status(); let body: PhotosResponse = serde_json::from_str(&response.read_to_str()).unwrap(); @@ -200,7 +216,15 @@ mod tests { let request: Query = Query::from_query("path=..").unwrap(); - let response = list_photos(claims, request).await; + let response = list_photos( + claims, + request, + web::Data::new(AppState::new( + Arc::new(StreamActor {}.start()), + String::from("/tmp"), + )), + ) + .await; assert_eq!(response.status(), 400); } @@ -208,12 +232,13 @@ mod tests { #[test] fn directory_traversal_test() { - assert_eq!(None, is_valid_path("../")); - assert_eq!(None, is_valid_path("..")); - assert_eq!(None, is_valid_path("fake/../../../")); - assert_eq!(None, is_valid_path("../../../etc/passwd")); - assert_eq!(None, is_valid_path("..//etc/passwd")); - assert_eq!(None, is_valid_path("../../etc/passwd")); + let base = env::temp_dir(); + assert_eq!(None, is_valid_full_path(&base, "../")); + assert_eq!(None, is_valid_full_path(&base, "..")); + assert_eq!(None, is_valid_full_path(&base, "fake/../../../")); + assert_eq!(None, is_valid_full_path(&base, "../../../etc/passwd")); + assert_eq!(None, is_valid_full_path(&base, "..//etc/passwd")); + assert_eq!(None, is_valid_full_path(&base, "../../etc/passwd")); } #[test] diff --git a/src/main.rs b/src/main.rs index 12e1b08..0a62a28 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,7 +32,7 @@ use log::{debug, error, info}; use crate::auth::login; use crate::data::*; use crate::database::*; -use crate::files::{is_image_or_video, is_valid_path}; +use crate::files::{is_image_or_video, is_valid_full_path}; use crate::video::*; mod auth; @@ -62,13 +62,15 @@ async fn get_image( _claims: Claims, request: HttpRequest, req: web::Query, + app_state: web::Data, ) -> impl Responder { - if let Some(path) = is_valid_path(&req.path) { + if let Some(path) = is_valid_full_path(&app_state.base_path, &req.path) { if req.size.is_some() { - let thumbs = dotenv::var("THUMBNAILS").unwrap(); let relative_path = path - .strip_prefix(dotenv::var("BASE_PATH").unwrap()) - .expect("Error stripping prefix"); + .strip_prefix(&app_state.base_path) + .expect("Error stripping base path prefix from thumbnail"); + + let thumbs = &app_state.thumbnail_path; let thumb_path = Path::new(&thumbs).join(relative_path); debug!("{:?}", thumb_path); @@ -89,8 +91,12 @@ async fn get_image( } #[get("/image/metadata")] -async fn get_file_metadata(_: Claims, path: web::Query) -> impl Responder { - match is_valid_path(&path.path) +async fn get_file_metadata( + _: Claims, + path: web::Query, + app_state: web::Data, +) -> impl Responder { + match is_valid_full_path(&app_state.base_path, &path.path) .ok_or_else(|| ErrorKind::InvalidData.into()) .and_then(File::open) .and_then(|file| file.metadata()) @@ -107,7 +113,11 @@ async fn get_file_metadata(_: Claims, path: web::Query) -> imp } #[post("/image")] -async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder { +async fn upload_image( + _: Claims, + mut payload: mp::Multipart, + app_state: web::Data, +) -> impl Responder { let mut file_content: BytesMut = BytesMut::new(); let mut file_name: Option = None; let mut file_path: Option = None; @@ -131,10 +141,12 @@ async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder { } } - let path = file_path.unwrap_or_else(|| dotenv::var("BASE_PATH").unwrap()); + let path = file_path.unwrap_or_else(|| app_state.base_path.clone()); if !file_content.is_empty() { let full_path = PathBuf::from(&path).join(file_name.unwrap()); - if let Some(full_path) = is_valid_path(full_path.to_str().unwrap_or("")) { + if let Some(full_path) = + is_valid_full_path(&app_state.base_path, full_path.to_str().unwrap_or("")) + { if !full_path.is_file() && is_image_or_video(&full_path) { let mut file = File::create(full_path).unwrap(); file.write_all(&file_content).unwrap(); @@ -155,7 +167,7 @@ async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder { #[post("/video/generate")] async fn generate_video( _claims: Claims, - data: web::Data, + app_state: web::Data, body: web::Json, ) -> impl Responder { let filename = PathBuf::from(&body.path); @@ -163,9 +175,10 @@ async fn generate_video( if let Some(name) = filename.file_stem() { let filename = name.to_str().expect("Filename should convert to string"); let playlist = format!("tmp/{}.m3u8", filename); - if let Some(path) = is_valid_path(&body.path) { + if let Some(path) = is_valid_full_path(&app_state.base_path, &body.path) { if let Ok(child) = create_playlist(path.to_str().unwrap(), &playlist).await { - data.stream_manager + app_state + .stream_manager .do_send(ProcessMessage(playlist.clone(), child)); } } else { @@ -184,12 +197,13 @@ async fn stream_video( request: HttpRequest, _: Claims, path: web::Query, + app_state: web::Data, ) -> impl Responder { let playlist = &path.path; debug!("Playlist: {}", playlist); // Extract video playlist dir to dotenv - if !playlist.starts_with("tmp") && is_valid_path(playlist) != None { + if !playlist.starts_with("tmp") && is_valid_full_path(&app_state.base_path, playlist) != None { HttpResponse::BadRequest().finish() } else if let Ok(file) = NamedFile::open(playlist) { file.into_response(&request) @@ -406,7 +420,56 @@ fn main() -> std::io::Result<()> { env_logger::init(); create_thumbnails(); + watch_files(); + let system = actix::System::new(); + system.block_on(async { + let app_data = web::Data::new(AppState::default()); + + let labels = HashMap::new(); + let prometheus = PrometheusMetricsBuilder::new("api") + .const_labels(labels) + .build() + .expect("Unable to build prometheus metrics middleware"); + + prometheus + .registry + .register(Box::new(IMAGE_GAUGE.clone())) + .unwrap(); + prometheus + .registry + .register(Box::new(VIDEO_GAUGE.clone())) + .unwrap(); + + HttpServer::new(move || { + let user_dao = SqliteUserDao::new(); + let favorites_dao = SqliteFavoriteDao::new(); + App::new() + .wrap(middleware::Logger::default()) + .service(web::resource("/login").route(web::post().to(login::))) + .service(web::resource("/photos").route(web::get().to(files::list_photos))) + .service(get_image) + .service(upload_image) + .service(generate_video) + .service(stream_video) + .service(get_video_part) + .service(favorites) + .service(put_add_favorite) + .service(delete_favorite) + .service(get_file_metadata) + .app_data(app_data.clone()) + .app_data::>(Data::new(user_dao)) + .app_data::>>(Data::new(Box::new(favorites_dao))) + .wrap(prometheus.clone()) + }) + .bind(dotenv::var("BIND_URL").unwrap())? + .bind("localhost:8088")? + .run() + .await + }) +} + +fn watch_files() { std::thread::spawn(|| { let (wtx, wrx) = channel(); let mut watcher = watcher(wtx, std::time::Duration::from_secs(10)).unwrap(); @@ -449,58 +512,34 @@ fn main() -> std::io::Result<()> { } } }); - - let system = actix::System::new(); - system.block_on(async { - let act = StreamActor {}.start(); - - let app_data = web::Data::new(AppState { - stream_manager: Arc::new(act), - }); - - let labels = HashMap::new(); - let prometheus = PrometheusMetricsBuilder::new("api") - .const_labels(labels) - .build() - .expect("Unable to build prometheus metrics middleware"); - - prometheus - .registry - .register(Box::new(IMAGE_GAUGE.clone())) - .unwrap(); - prometheus - .registry - .register(Box::new(VIDEO_GAUGE.clone())) - .unwrap(); - - HttpServer::new(move || { - let user_dao = SqliteUserDao::new(); - let favorites_dao = SqliteFavoriteDao::new(); - App::new() - .wrap(middleware::Logger::default()) - .service(web::resource("/login").route(web::post().to(login))) - .service(web::resource("/photos").route(web::get().to(files::list_photos))) - .service(get_image) - .service(upload_image) - .service(generate_video) - .service(stream_video) - .service(get_video_part) - .service(favorites) - .service(put_add_favorite) - .service(delete_favorite) - .service(get_file_metadata) - .app_data(app_data.clone()) - .app_data::>>(Data::new(Box::new(user_dao))) - .app_data::>>(Data::new(Box::new(favorites_dao))) - .wrap(prometheus.clone()) - }) - .bind(dotenv::var("BIND_URL").unwrap())? - .bind("localhost:8088")? - .run() - .await - }) } -struct AppState { +pub struct AppState { stream_manager: Arc>, + base_path: String, + thumbnail_path: String, +} + +impl AppState { + fn new( + stream_manager: Arc>, + base_path: String, + thumbnail_path: String, + ) -> Self { + Self { + stream_manager, + base_path, + thumbnail_path, + } + } +} + +impl Default for AppState { + fn default() -> Self { + Self::new( + Arc::new(StreamActor {}.start()), + env::var("BASE_PATH").expect("BASE_PATH was not set in the env"), + env::var("THUMBNAILS").expect("THUMBNAILS was not set in the env"), + ) + } } diff --git a/src/testhelpers.rs b/src/testhelpers.rs index fc34ecd..f51e066 100644 --- a/src/testhelpers.rs +++ b/src/testhelpers.rs @@ -1,6 +1,7 @@ -use actix_web::body::MessageBody; -use actix_web::{body::BoxBody, HttpResponse}; -use serde::de::DeserializeOwned; +use actix_web::{ + body::{BoxBody, MessageBody}, + HttpResponse, +}; use crate::database::{models::User, UserDao}; use std::cell::RefCell; @@ -65,17 +66,3 @@ impl BodyReader for HttpResponse { std::str::from_utf8(&body).unwrap().to_string() } } - -pub trait TypedBodyReader -where - T: DeserializeOwned, -{ - fn read_body(self) -> T; -} - -impl TypedBodyReader for HttpResponse { - fn read_body(self) -> T { - let body = self.read_to_str(); - serde_json::from_value(serde_json::Value::String(body.clone())).unwrap() - } -}