Improve testability and remove boxing
Some checks failed
Core Repos/ImageApi/pipeline/pr-master There was a failure building this commit

Leverage generics to remove the extra heap allocation for the response
handlers using Dao's. Also moved some of the environment variables to
app state to allow for easier testing.
This commit is contained in:
Cameron Cordes
2022-03-16 20:51:37 -04:00
parent e02165082a
commit 4d9b7c91a1
4 changed files with 172 additions and 124 deletions

View File

@@ -1,4 +1,4 @@
use actix_web::{post, Responder}; use actix_web::Responder;
use actix_web::{ use actix_web::{
web::{self, Json}, web::{self, Json},
HttpResponse, HttpResponse,
@@ -12,10 +12,10 @@ use crate::{
database::UserDao, database::UserDao,
}; };
#[post("/register")] #[allow(dead_code)]
async fn register( async fn register<D: UserDao>(
user: Json<CreateAccountRequest>, user: Json<CreateAccountRequest>,
user_dao: web::Data<Box<dyn UserDao>>, user_dao: web::Data<D>,
) -> impl Responder { ) -> impl Responder {
if !user.username.is_empty() && user.password.len() > 5 && user.password == user.confirmation { if !user.username.is_empty() && user.password.len() > 5 && user.password == user.confirmation {
if user_dao.user_exists(&user.username) { if user_dao.user_exists(&user.username) {
@@ -30,10 +30,7 @@ async fn register(
} }
} }
pub async fn login( pub async fn login<D: UserDao>(creds: Json<LoginRequest>, user_dao: web::Data<D>) -> HttpResponse {
creds: Json<LoginRequest>,
user_dao: web::Data<Box<dyn UserDao>>,
) -> HttpResponse {
debug!("Logging in: {}", creds.username); debug!("Logging in: {}", creds.username);
if let Some(user) = user_dao.get_user(&creds.username, &creds.password) { if let Some(user) = user_dao.get_user(&creds.username, &creds.password) {
@@ -74,7 +71,7 @@ mod tests {
password: "pass".to_string(), password: "pass".to_string(),
}); });
let response = login(j, web::Data::new(Box::new(dao))).await; let response = login::<TestUserDao>(j, web::Data::new(dao)).await;
assert_eq!(response.status(), 200); assert_eq!(response.status(), 200);
} }
@@ -89,7 +86,7 @@ mod tests {
password: "password".to_string(), password: "password".to_string(),
}); });
let response = login(j, web::Data::new(Box::new(dao))).await; let response = login::<TestUserDao>(j, web::Data::new(dao)).await;
assert_eq!(response.status(), 200); assert_eq!(response.status(), 200);
let response_text: String = response.read_to_str(); let response_text: String = response.read_to_str();
@@ -107,7 +104,7 @@ mod tests {
password: "password".to_string(), password: "password".to_string(),
}); });
let response = login(j, web::Data::new(Box::new(dao))).await; let response = login::<TestUserDao>(j, web::Data::new(dao)).await;
assert_eq!(response.status(), 404); assert_eq!(response.status(), 404);
} }

View File

@@ -1,3 +1,4 @@
use std::fmt::Debug;
use std::fs::read_dir; use std::fs::read_dir;
use std::io; use std::io;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
@@ -5,17 +6,25 @@ use std::path::{Path, PathBuf};
use ::anyhow; use ::anyhow;
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use actix_web::{web::Query, HttpResponse}; use actix_web::{
web::{self, Query},
HttpResponse,
};
use log::{debug, error}; use log::{debug, error};
use crate::data::{Claims, PhotosResponse, ThumbnailRequest}; use crate::data::{Claims, PhotosResponse, ThumbnailRequest};
use crate::AppState;
use path_absolutize::*; use path_absolutize::*;
pub async fn list_photos(_: Claims, req: Query<ThumbnailRequest>) -> HttpResponse { pub async fn list_photos(
_: Claims,
req: Query<ThumbnailRequest>,
app_state: web::Data<AppState>,
) -> HttpResponse {
let path = &req.path; let path = &req.path;
if let Some(path) = is_valid_path(path) { if let Some(path) = is_valid_full_path(&PathBuf::from(&app_state.base_path), path) {
debug!("Valid path: {:?}", path); debug!("Valid path: {:?}", path);
let files = list_files(&path).unwrap_or_default(); let files = list_files(&path).unwrap_or_default();
@@ -31,9 +40,7 @@ pub async fn list_photos(_: Claims, req: Query<ThumbnailRequest>) -> HttpRespons
) )
}) })
.map(|path: &PathBuf| { .map(|path: &PathBuf| {
let relative = path let relative = path.strip_prefix(&app_state.base_path).unwrap();
.strip_prefix(dotenv::var("BASE_PATH").unwrap())
.unwrap();
relative.to_path_buf() relative.to_path_buf()
}) })
.map(|f| f.to_str().unwrap().to_string()) .map(|f| f.to_str().unwrap().to_string())
@@ -43,9 +50,7 @@ pub async fn list_photos(_: Claims, req: Query<ThumbnailRequest>) -> HttpRespons
.iter() .iter()
.filter(|&f| f.metadata().map_or(false, |md| md.is_dir())) .filter(|&f| f.metadata().map_or(false, |md| md.is_dir()))
.map(|path: &PathBuf| { .map(|path: &PathBuf| {
let relative = path let relative = path.strip_prefix(&app_state.base_path).unwrap();
.strip_prefix(dotenv::var("BASE_PATH").unwrap())
.unwrap();
relative.to_path_buf() relative.to_path_buf()
}) })
.map(|f| f.to_str().unwrap().to_string()) .map(|f| f.to_str().unwrap().to_string())
@@ -82,18 +87,13 @@ pub fn is_image_or_video(path: &Path) -> bool {
|| extension == "nef" || extension == "nef"
} }
pub fn is_valid_path(path: &str) -> Option<PathBuf> { pub fn is_valid_full_path<P: AsRef<Path> + Debug>(base: &P, path: &str) -> Option<PathBuf> {
let base = PathBuf::from(dotenv::var("BASE_PATH").unwrap());
is_valid_full_path(&base, path)
}
fn is_valid_full_path(base: &Path, path: &str) -> Option<PathBuf> {
debug!("Base: {:?}. Path: {}", base, path); debug!("Base: {:?}. Path: {}", base, path);
let path = PathBuf::from(path); let path = PathBuf::from(path);
let mut path = if path.is_relative() { let mut path = if path.is_relative() {
let mut full_path = PathBuf::from(base); let mut full_path = PathBuf::new();
full_path.push(base);
full_path.push(&path); full_path.push(&path);
full_path full_path
} else { } else {
@@ -109,7 +109,10 @@ fn is_valid_full_path(base: &Path, path: &str) -> Option<PathBuf> {
} }
} }
fn is_path_above_base_dir(base: &Path, full_path: &mut PathBuf) -> anyhow::Result<PathBuf> { fn is_path_above_base_dir<P: AsRef<Path> + Debug>(
base: P,
full_path: &mut PathBuf,
) -> anyhow::Result<PathBuf> {
full_path full_path
.absolutize() .absolutize()
.with_context(|| format!("Unable to resolve absolute path: {:?}", full_path)) .with_context(|| format!("Unable to resolve absolute path: {:?}", full_path))
@@ -135,15 +138,21 @@ mod tests {
use super::*; use super::*;
mod api { mod api {
use actix_web::{web::Query, HttpResponse}; use super::*;
use actix::Actor;
use actix_web::{
web::{self, Query},
HttpResponse,
};
use super::list_photos;
use crate::{ use crate::{
data::{Claims, PhotosResponse, ThumbnailRequest}, data::{Claims, PhotosResponse, ThumbnailRequest},
testhelpers::BodyReader, testhelpers::BodyReader,
video::StreamActor,
AppState,
}; };
use std::fs; use std::{fs, sync::Arc};
fn setup() { fn setup() {
let _ = env_logger::builder().is_test(true).try_init(); let _ = env_logger::builder().is_test(true).try_init();
@@ -160,7 +169,6 @@ mod tests {
let request: Query<ThumbnailRequest> = Query::from_query("path=").unwrap(); let request: Query<ThumbnailRequest> = Query::from_query("path=").unwrap();
std::env::set_var("BASE_PATH", "/tmp");
let mut temp_photo = std::env::temp_dir(); let mut temp_photo = std::env::temp_dir();
let mut tmp = temp_photo.clone(); let mut tmp = temp_photo.clone();
@@ -169,9 +177,17 @@ mod tests {
temp_photo.push("photo.jpg"); temp_photo.push("photo.jpg");
fs::File::create(temp_photo).unwrap(); fs::File::create(temp_photo.clone()).unwrap();
let response: HttpResponse = list_photos(claims, request).await; let response: HttpResponse = list_photos(
claims,
request,
web::Data::new(AppState::new(
Arc::new(StreamActor {}.start()),
String::from("/tmp"),
)),
)
.await;
let status = response.status(); let status = response.status();
let body: PhotosResponse = serde_json::from_str(&response.read_to_str()).unwrap(); let body: PhotosResponse = serde_json::from_str(&response.read_to_str()).unwrap();
@@ -200,7 +216,15 @@ mod tests {
let request: Query<ThumbnailRequest> = Query::from_query("path=..").unwrap(); let request: Query<ThumbnailRequest> = Query::from_query("path=..").unwrap();
let response = list_photos(claims, request).await; let response = list_photos(
claims,
request,
web::Data::new(AppState::new(
Arc::new(StreamActor {}.start()),
String::from("/tmp"),
)),
)
.await;
assert_eq!(response.status(), 400); assert_eq!(response.status(), 400);
} }
@@ -208,12 +232,13 @@ mod tests {
#[test] #[test]
fn directory_traversal_test() { fn directory_traversal_test() {
assert_eq!(None, is_valid_path("../")); let base = env::temp_dir();
assert_eq!(None, is_valid_path("..")); assert_eq!(None, is_valid_full_path(&base, "../"));
assert_eq!(None, is_valid_path("fake/../../../")); assert_eq!(None, is_valid_full_path(&base, ".."));
assert_eq!(None, is_valid_path("../../../etc/passwd")); assert_eq!(None, is_valid_full_path(&base, "fake/../../../"));
assert_eq!(None, is_valid_path("..//etc/passwd")); assert_eq!(None, is_valid_full_path(&base, "../../../etc/passwd"));
assert_eq!(None, is_valid_path("../../etc/passwd")); assert_eq!(None, is_valid_full_path(&base, "..//etc/passwd"));
assert_eq!(None, is_valid_full_path(&base, "../../etc/passwd"));
} }
#[test] #[test]

View File

@@ -32,7 +32,7 @@ use log::{debug, error, info};
use crate::auth::login; use crate::auth::login;
use crate::data::*; use crate::data::*;
use crate::database::*; use crate::database::*;
use crate::files::{is_image_or_video, is_valid_path}; use crate::files::{is_image_or_video, is_valid_full_path};
use crate::video::*; use crate::video::*;
mod auth; mod auth;
@@ -62,13 +62,15 @@ async fn get_image(
_claims: Claims, _claims: Claims,
request: HttpRequest, request: HttpRequest,
req: web::Query<ThumbnailRequest>, req: web::Query<ThumbnailRequest>,
app_state: web::Data<AppState>,
) -> impl Responder { ) -> impl Responder {
if let Some(path) = is_valid_path(&req.path) { if let Some(path) = is_valid_full_path(&app_state.base_path, &req.path) {
if req.size.is_some() { if req.size.is_some() {
let thumbs = dotenv::var("THUMBNAILS").unwrap();
let relative_path = path let relative_path = path
.strip_prefix(dotenv::var("BASE_PATH").unwrap()) .strip_prefix(&app_state.base_path)
.expect("Error stripping prefix"); .expect("Error stripping base path prefix from thumbnail");
let thumbs = &app_state.thumbnail_path;
let thumb_path = Path::new(&thumbs).join(relative_path); let thumb_path = Path::new(&thumbs).join(relative_path);
debug!("{:?}", thumb_path); debug!("{:?}", thumb_path);
@@ -89,8 +91,12 @@ async fn get_image(
} }
#[get("/image/metadata")] #[get("/image/metadata")]
async fn get_file_metadata(_: Claims, path: web::Query<ThumbnailRequest>) -> impl Responder { async fn get_file_metadata(
match is_valid_path(&path.path) _: Claims,
path: web::Query<ThumbnailRequest>,
app_state: web::Data<AppState>,
) -> impl Responder {
match is_valid_full_path(&app_state.base_path, &path.path)
.ok_or_else(|| ErrorKind::InvalidData.into()) .ok_or_else(|| ErrorKind::InvalidData.into())
.and_then(File::open) .and_then(File::open)
.and_then(|file| file.metadata()) .and_then(|file| file.metadata())
@@ -107,7 +113,11 @@ async fn get_file_metadata(_: Claims, path: web::Query<ThumbnailRequest>) -> imp
} }
#[post("/image")] #[post("/image")]
async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder { async fn upload_image(
_: Claims,
mut payload: mp::Multipart,
app_state: web::Data<AppState>,
) -> impl Responder {
let mut file_content: BytesMut = BytesMut::new(); let mut file_content: BytesMut = BytesMut::new();
let mut file_name: Option<String> = None; let mut file_name: Option<String> = None;
let mut file_path: Option<String> = None; let mut file_path: Option<String> = None;
@@ -131,10 +141,12 @@ async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder {
} }
} }
let path = file_path.unwrap_or_else(|| dotenv::var("BASE_PATH").unwrap()); let path = file_path.unwrap_or_else(|| app_state.base_path.clone());
if !file_content.is_empty() { if !file_content.is_empty() {
let full_path = PathBuf::from(&path).join(file_name.unwrap()); let full_path = PathBuf::from(&path).join(file_name.unwrap());
if let Some(full_path) = is_valid_path(full_path.to_str().unwrap_or("")) { if let Some(full_path) =
is_valid_full_path(&app_state.base_path, full_path.to_str().unwrap_or(""))
{
if !full_path.is_file() && is_image_or_video(&full_path) { if !full_path.is_file() && is_image_or_video(&full_path) {
let mut file = File::create(full_path).unwrap(); let mut file = File::create(full_path).unwrap();
file.write_all(&file_content).unwrap(); file.write_all(&file_content).unwrap();
@@ -155,7 +167,7 @@ async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder {
#[post("/video/generate")] #[post("/video/generate")]
async fn generate_video( async fn generate_video(
_claims: Claims, _claims: Claims,
data: web::Data<AppState>, app_state: web::Data<AppState>,
body: web::Json<ThumbnailRequest>, body: web::Json<ThumbnailRequest>,
) -> impl Responder { ) -> impl Responder {
let filename = PathBuf::from(&body.path); let filename = PathBuf::from(&body.path);
@@ -163,9 +175,10 @@ async fn generate_video(
if let Some(name) = filename.file_stem() { if let Some(name) = filename.file_stem() {
let filename = name.to_str().expect("Filename should convert to string"); let filename = name.to_str().expect("Filename should convert to string");
let playlist = format!("tmp/{}.m3u8", filename); let playlist = format!("tmp/{}.m3u8", filename);
if let Some(path) = is_valid_path(&body.path) { if let Some(path) = is_valid_full_path(&app_state.base_path, &body.path) {
if let Ok(child) = create_playlist(path.to_str().unwrap(), &playlist).await { if let Ok(child) = create_playlist(path.to_str().unwrap(), &playlist).await {
data.stream_manager app_state
.stream_manager
.do_send(ProcessMessage(playlist.clone(), child)); .do_send(ProcessMessage(playlist.clone(), child));
} }
} else { } else {
@@ -184,12 +197,13 @@ async fn stream_video(
request: HttpRequest, request: HttpRequest,
_: Claims, _: Claims,
path: web::Query<ThumbnailRequest>, path: web::Query<ThumbnailRequest>,
app_state: web::Data<AppState>,
) -> impl Responder { ) -> impl Responder {
let playlist = &path.path; let playlist = &path.path;
debug!("Playlist: {}", playlist); debug!("Playlist: {}", playlist);
// Extract video playlist dir to dotenv // Extract video playlist dir to dotenv
if !playlist.starts_with("tmp") && is_valid_path(playlist) != None { if !playlist.starts_with("tmp") && is_valid_full_path(&app_state.base_path, playlist) != None {
HttpResponse::BadRequest().finish() HttpResponse::BadRequest().finish()
} else if let Ok(file) = NamedFile::open(playlist) { } else if let Ok(file) = NamedFile::open(playlist) {
file.into_response(&request) file.into_response(&request)
@@ -406,7 +420,56 @@ fn main() -> std::io::Result<()> {
env_logger::init(); env_logger::init();
create_thumbnails(); create_thumbnails();
watch_files();
let system = actix::System::new();
system.block_on(async {
let app_data = web::Data::new(AppState::default());
let labels = HashMap::new();
let prometheus = PrometheusMetricsBuilder::new("api")
.const_labels(labels)
.build()
.expect("Unable to build prometheus metrics middleware");
prometheus
.registry
.register(Box::new(IMAGE_GAUGE.clone()))
.unwrap();
prometheus
.registry
.register(Box::new(VIDEO_GAUGE.clone()))
.unwrap();
HttpServer::new(move || {
let user_dao = SqliteUserDao::new();
let favorites_dao = SqliteFavoriteDao::new();
App::new()
.wrap(middleware::Logger::default())
.service(web::resource("/login").route(web::post().to(login::<SqliteUserDao>)))
.service(web::resource("/photos").route(web::get().to(files::list_photos)))
.service(get_image)
.service(upload_image)
.service(generate_video)
.service(stream_video)
.service(get_video_part)
.service(favorites)
.service(put_add_favorite)
.service(delete_favorite)
.service(get_file_metadata)
.app_data(app_data.clone())
.app_data::<Data<SqliteUserDao>>(Data::new(user_dao))
.app_data::<Data<Box<dyn FavoriteDao>>>(Data::new(Box::new(favorites_dao)))
.wrap(prometheus.clone())
})
.bind(dotenv::var("BIND_URL").unwrap())?
.bind("localhost:8088")?
.run()
.await
})
}
fn watch_files() {
std::thread::spawn(|| { std::thread::spawn(|| {
let (wtx, wrx) = channel(); let (wtx, wrx) = channel();
let mut watcher = watcher(wtx, std::time::Duration::from_secs(10)).unwrap(); let mut watcher = watcher(wtx, std::time::Duration::from_secs(10)).unwrap();
@@ -449,58 +512,34 @@ fn main() -> std::io::Result<()> {
} }
} }
}); });
let system = actix::System::new();
system.block_on(async {
let act = StreamActor {}.start();
let app_data = web::Data::new(AppState {
stream_manager: Arc::new(act),
});
let labels = HashMap::new();
let prometheus = PrometheusMetricsBuilder::new("api")
.const_labels(labels)
.build()
.expect("Unable to build prometheus metrics middleware");
prometheus
.registry
.register(Box::new(IMAGE_GAUGE.clone()))
.unwrap();
prometheus
.registry
.register(Box::new(VIDEO_GAUGE.clone()))
.unwrap();
HttpServer::new(move || {
let user_dao = SqliteUserDao::new();
let favorites_dao = SqliteFavoriteDao::new();
App::new()
.wrap(middleware::Logger::default())
.service(web::resource("/login").route(web::post().to(login)))
.service(web::resource("/photos").route(web::get().to(files::list_photos)))
.service(get_image)
.service(upload_image)
.service(generate_video)
.service(stream_video)
.service(get_video_part)
.service(favorites)
.service(put_add_favorite)
.service(delete_favorite)
.service(get_file_metadata)
.app_data(app_data.clone())
.app_data::<Data<Box<dyn UserDao>>>(Data::new(Box::new(user_dao)))
.app_data::<Data<Box<dyn FavoriteDao>>>(Data::new(Box::new(favorites_dao)))
.wrap(prometheus.clone())
})
.bind(dotenv::var("BIND_URL").unwrap())?
.bind("localhost:8088")?
.run()
.await
})
} }
struct AppState { pub struct AppState {
stream_manager: Arc<Addr<StreamActor>>, stream_manager: Arc<Addr<StreamActor>>,
base_path: String,
thumbnail_path: String,
}
impl AppState {
fn new(
stream_manager: Arc<Addr<StreamActor>>,
base_path: String,
thumbnail_path: String,
) -> Self {
Self {
stream_manager,
base_path,
thumbnail_path,
}
}
}
impl Default for AppState {
fn default() -> Self {
Self::new(
Arc::new(StreamActor {}.start()),
env::var("BASE_PATH").expect("BASE_PATH was not set in the env"),
env::var("THUMBNAILS").expect("THUMBNAILS was not set in the env"),
)
}
} }

View File

@@ -1,6 +1,7 @@
use actix_web::body::MessageBody; use actix_web::{
use actix_web::{body::BoxBody, HttpResponse}; body::{BoxBody, MessageBody},
use serde::de::DeserializeOwned; HttpResponse,
};
use crate::database::{models::User, UserDao}; use crate::database::{models::User, UserDao};
use std::cell::RefCell; use std::cell::RefCell;
@@ -65,17 +66,3 @@ impl BodyReader for HttpResponse<BoxBody> {
std::str::from_utf8(&body).unwrap().to_string() std::str::from_utf8(&body).unwrap().to_string()
} }
} }
pub trait TypedBodyReader<T>
where
T: DeserializeOwned,
{
fn read_body(self) -> T;
}
impl<T: DeserializeOwned> TypedBodyReader<T> for HttpResponse<BoxBody> {
fn read_body(self) -> T {
let body = self.read_to_str();
serde_json::from_value(serde_json::Value::String(body.clone())).unwrap()
}
}