Improve testability and remove boxing
Some checks failed
Core Repos/ImageApi/pipeline/pr-master There was a failure building this commit

Leverage generics to remove the extra heap allocation for the response
handlers using Dao's. Also moved some of the environment variables to
app state to allow for easier testing.
This commit is contained in:
Cameron Cordes
2022-03-16 20:51:37 -04:00
parent e02165082a
commit 4d9b7c91a1
4 changed files with 172 additions and 124 deletions

View File

@@ -1,4 +1,4 @@
use actix_web::{post, Responder};
use actix_web::Responder;
use actix_web::{
web::{self, Json},
HttpResponse,
@@ -12,10 +12,10 @@ use crate::{
database::UserDao,
};
#[post("/register")]
async fn register(
#[allow(dead_code)]
async fn register<D: UserDao>(
user: Json<CreateAccountRequest>,
user_dao: web::Data<Box<dyn UserDao>>,
user_dao: web::Data<D>,
) -> impl Responder {
if !user.username.is_empty() && user.password.len() > 5 && user.password == user.confirmation {
if user_dao.user_exists(&user.username) {
@@ -30,10 +30,7 @@ async fn register(
}
}
pub async fn login(
creds: Json<LoginRequest>,
user_dao: web::Data<Box<dyn UserDao>>,
) -> HttpResponse {
pub async fn login<D: UserDao>(creds: Json<LoginRequest>, user_dao: web::Data<D>) -> HttpResponse {
debug!("Logging in: {}", creds.username);
if let Some(user) = user_dao.get_user(&creds.username, &creds.password) {
@@ -74,7 +71,7 @@ mod tests {
password: "pass".to_string(),
});
let response = login(j, web::Data::new(Box::new(dao))).await;
let response = login::<TestUserDao>(j, web::Data::new(dao)).await;
assert_eq!(response.status(), 200);
}
@@ -89,7 +86,7 @@ mod tests {
password: "password".to_string(),
});
let response = login(j, web::Data::new(Box::new(dao))).await;
let response = login::<TestUserDao>(j, web::Data::new(dao)).await;
assert_eq!(response.status(), 200);
let response_text: String = response.read_to_str();
@@ -107,7 +104,7 @@ mod tests {
password: "password".to_string(),
});
let response = login(j, web::Data::new(Box::new(dao))).await;
let response = login::<TestUserDao>(j, web::Data::new(dao)).await;
assert_eq!(response.status(), 404);
}

View File

@@ -1,3 +1,4 @@
use std::fmt::Debug;
use std::fs::read_dir;
use std::io;
use std::path::{Path, PathBuf};
@@ -5,17 +6,25 @@ use std::path::{Path, PathBuf};
use ::anyhow;
use anyhow::{anyhow, Context};
use actix_web::{web::Query, HttpResponse};
use actix_web::{
web::{self, Query},
HttpResponse,
};
use log::{debug, error};
use crate::data::{Claims, PhotosResponse, ThumbnailRequest};
use crate::AppState;
use path_absolutize::*;
pub async fn list_photos(_: Claims, req: Query<ThumbnailRequest>) -> HttpResponse {
pub async fn list_photos(
_: Claims,
req: Query<ThumbnailRequest>,
app_state: web::Data<AppState>,
) -> HttpResponse {
let path = &req.path;
if let Some(path) = is_valid_path(path) {
if let Some(path) = is_valid_full_path(&PathBuf::from(&app_state.base_path), path) {
debug!("Valid path: {:?}", path);
let files = list_files(&path).unwrap_or_default();
@@ -31,9 +40,7 @@ pub async fn list_photos(_: Claims, req: Query<ThumbnailRequest>) -> HttpRespons
)
})
.map(|path: &PathBuf| {
let relative = path
.strip_prefix(dotenv::var("BASE_PATH").unwrap())
.unwrap();
let relative = path.strip_prefix(&app_state.base_path).unwrap();
relative.to_path_buf()
})
.map(|f| f.to_str().unwrap().to_string())
@@ -43,9 +50,7 @@ pub async fn list_photos(_: Claims, req: Query<ThumbnailRequest>) -> HttpRespons
.iter()
.filter(|&f| f.metadata().map_or(false, |md| md.is_dir()))
.map(|path: &PathBuf| {
let relative = path
.strip_prefix(dotenv::var("BASE_PATH").unwrap())
.unwrap();
let relative = path.strip_prefix(&app_state.base_path).unwrap();
relative.to_path_buf()
})
.map(|f| f.to_str().unwrap().to_string())
@@ -82,18 +87,13 @@ pub fn is_image_or_video(path: &Path) -> bool {
|| extension == "nef"
}
pub fn is_valid_path(path: &str) -> Option<PathBuf> {
let base = PathBuf::from(dotenv::var("BASE_PATH").unwrap());
is_valid_full_path(&base, path)
}
fn is_valid_full_path(base: &Path, path: &str) -> Option<PathBuf> {
pub fn is_valid_full_path<P: AsRef<Path> + Debug>(base: &P, path: &str) -> Option<PathBuf> {
debug!("Base: {:?}. Path: {}", base, path);
let path = PathBuf::from(path);
let mut path = if path.is_relative() {
let mut full_path = PathBuf::from(base);
let mut full_path = PathBuf::new();
full_path.push(base);
full_path.push(&path);
full_path
} else {
@@ -109,7 +109,10 @@ fn is_valid_full_path(base: &Path, path: &str) -> Option<PathBuf> {
}
}
fn is_path_above_base_dir(base: &Path, full_path: &mut PathBuf) -> anyhow::Result<PathBuf> {
fn is_path_above_base_dir<P: AsRef<Path> + Debug>(
base: P,
full_path: &mut PathBuf,
) -> anyhow::Result<PathBuf> {
full_path
.absolutize()
.with_context(|| format!("Unable to resolve absolute path: {:?}", full_path))
@@ -135,15 +138,21 @@ mod tests {
use super::*;
mod api {
use actix_web::{web::Query, HttpResponse};
use super::*;
use actix::Actor;
use actix_web::{
web::{self, Query},
HttpResponse,
};
use super::list_photos;
use crate::{
data::{Claims, PhotosResponse, ThumbnailRequest},
testhelpers::BodyReader,
video::StreamActor,
AppState,
};
use std::fs;
use std::{fs, sync::Arc};
fn setup() {
let _ = env_logger::builder().is_test(true).try_init();
@@ -160,7 +169,6 @@ mod tests {
let request: Query<ThumbnailRequest> = Query::from_query("path=").unwrap();
std::env::set_var("BASE_PATH", "/tmp");
let mut temp_photo = std::env::temp_dir();
let mut tmp = temp_photo.clone();
@@ -169,9 +177,17 @@ mod tests {
temp_photo.push("photo.jpg");
fs::File::create(temp_photo).unwrap();
fs::File::create(temp_photo.clone()).unwrap();
let response: HttpResponse = list_photos(claims, request).await;
let response: HttpResponse = list_photos(
claims,
request,
web::Data::new(AppState::new(
Arc::new(StreamActor {}.start()),
String::from("/tmp"),
)),
)
.await;
let status = response.status();
let body: PhotosResponse = serde_json::from_str(&response.read_to_str()).unwrap();
@@ -200,7 +216,15 @@ mod tests {
let request: Query<ThumbnailRequest> = Query::from_query("path=..").unwrap();
let response = list_photos(claims, request).await;
let response = list_photos(
claims,
request,
web::Data::new(AppState::new(
Arc::new(StreamActor {}.start()),
String::from("/tmp"),
)),
)
.await;
assert_eq!(response.status(), 400);
}
@@ -208,12 +232,13 @@ mod tests {
#[test]
fn directory_traversal_test() {
assert_eq!(None, is_valid_path("../"));
assert_eq!(None, is_valid_path(".."));
assert_eq!(None, is_valid_path("fake/../../../"));
assert_eq!(None, is_valid_path("../../../etc/passwd"));
assert_eq!(None, is_valid_path("..//etc/passwd"));
assert_eq!(None, is_valid_path("../../etc/passwd"));
let base = env::temp_dir();
assert_eq!(None, is_valid_full_path(&base, "../"));
assert_eq!(None, is_valid_full_path(&base, ".."));
assert_eq!(None, is_valid_full_path(&base, "fake/../../../"));
assert_eq!(None, is_valid_full_path(&base, "../../../etc/passwd"));
assert_eq!(None, is_valid_full_path(&base, "..//etc/passwd"));
assert_eq!(None, is_valid_full_path(&base, "../../etc/passwd"));
}
#[test]

View File

@@ -32,7 +32,7 @@ use log::{debug, error, info};
use crate::auth::login;
use crate::data::*;
use crate::database::*;
use crate::files::{is_image_or_video, is_valid_path};
use crate::files::{is_image_or_video, is_valid_full_path};
use crate::video::*;
mod auth;
@@ -62,13 +62,15 @@ async fn get_image(
_claims: Claims,
request: HttpRequest,
req: web::Query<ThumbnailRequest>,
app_state: web::Data<AppState>,
) -> impl Responder {
if let Some(path) = is_valid_path(&req.path) {
if let Some(path) = is_valid_full_path(&app_state.base_path, &req.path) {
if req.size.is_some() {
let thumbs = dotenv::var("THUMBNAILS").unwrap();
let relative_path = path
.strip_prefix(dotenv::var("BASE_PATH").unwrap())
.expect("Error stripping prefix");
.strip_prefix(&app_state.base_path)
.expect("Error stripping base path prefix from thumbnail");
let thumbs = &app_state.thumbnail_path;
let thumb_path = Path::new(&thumbs).join(relative_path);
debug!("{:?}", thumb_path);
@@ -89,8 +91,12 @@ async fn get_image(
}
#[get("/image/metadata")]
async fn get_file_metadata(_: Claims, path: web::Query<ThumbnailRequest>) -> impl Responder {
match is_valid_path(&path.path)
async fn get_file_metadata(
_: Claims,
path: web::Query<ThumbnailRequest>,
app_state: web::Data<AppState>,
) -> impl Responder {
match is_valid_full_path(&app_state.base_path, &path.path)
.ok_or_else(|| ErrorKind::InvalidData.into())
.and_then(File::open)
.and_then(|file| file.metadata())
@@ -107,7 +113,11 @@ async fn get_file_metadata(_: Claims, path: web::Query<ThumbnailRequest>) -> imp
}
#[post("/image")]
async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder {
async fn upload_image(
_: Claims,
mut payload: mp::Multipart,
app_state: web::Data<AppState>,
) -> impl Responder {
let mut file_content: BytesMut = BytesMut::new();
let mut file_name: Option<String> = None;
let mut file_path: Option<String> = None;
@@ -131,10 +141,12 @@ async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder {
}
}
let path = file_path.unwrap_or_else(|| dotenv::var("BASE_PATH").unwrap());
let path = file_path.unwrap_or_else(|| app_state.base_path.clone());
if !file_content.is_empty() {
let full_path = PathBuf::from(&path).join(file_name.unwrap());
if let Some(full_path) = is_valid_path(full_path.to_str().unwrap_or("")) {
if let Some(full_path) =
is_valid_full_path(&app_state.base_path, full_path.to_str().unwrap_or(""))
{
if !full_path.is_file() && is_image_or_video(&full_path) {
let mut file = File::create(full_path).unwrap();
file.write_all(&file_content).unwrap();
@@ -155,7 +167,7 @@ async fn upload_image(_: Claims, mut payload: mp::Multipart) -> impl Responder {
#[post("/video/generate")]
async fn generate_video(
_claims: Claims,
data: web::Data<AppState>,
app_state: web::Data<AppState>,
body: web::Json<ThumbnailRequest>,
) -> impl Responder {
let filename = PathBuf::from(&body.path);
@@ -163,9 +175,10 @@ async fn generate_video(
if let Some(name) = filename.file_stem() {
let filename = name.to_str().expect("Filename should convert to string");
let playlist = format!("tmp/{}.m3u8", filename);
if let Some(path) = is_valid_path(&body.path) {
if let Some(path) = is_valid_full_path(&app_state.base_path, &body.path) {
if let Ok(child) = create_playlist(path.to_str().unwrap(), &playlist).await {
data.stream_manager
app_state
.stream_manager
.do_send(ProcessMessage(playlist.clone(), child));
}
} else {
@@ -184,12 +197,13 @@ async fn stream_video(
request: HttpRequest,
_: Claims,
path: web::Query<ThumbnailRequest>,
app_state: web::Data<AppState>,
) -> impl Responder {
let playlist = &path.path;
debug!("Playlist: {}", playlist);
// Extract video playlist dir to dotenv
if !playlist.starts_with("tmp") && is_valid_path(playlist) != None {
if !playlist.starts_with("tmp") && is_valid_full_path(&app_state.base_path, playlist) != None {
HttpResponse::BadRequest().finish()
} else if let Ok(file) = NamedFile::open(playlist) {
file.into_response(&request)
@@ -406,7 +420,56 @@ fn main() -> std::io::Result<()> {
env_logger::init();
create_thumbnails();
watch_files();
let system = actix::System::new();
system.block_on(async {
let app_data = web::Data::new(AppState::default());
let labels = HashMap::new();
let prometheus = PrometheusMetricsBuilder::new("api")
.const_labels(labels)
.build()
.expect("Unable to build prometheus metrics middleware");
prometheus
.registry
.register(Box::new(IMAGE_GAUGE.clone()))
.unwrap();
prometheus
.registry
.register(Box::new(VIDEO_GAUGE.clone()))
.unwrap();
HttpServer::new(move || {
let user_dao = SqliteUserDao::new();
let favorites_dao = SqliteFavoriteDao::new();
App::new()
.wrap(middleware::Logger::default())
.service(web::resource("/login").route(web::post().to(login::<SqliteUserDao>)))
.service(web::resource("/photos").route(web::get().to(files::list_photos)))
.service(get_image)
.service(upload_image)
.service(generate_video)
.service(stream_video)
.service(get_video_part)
.service(favorites)
.service(put_add_favorite)
.service(delete_favorite)
.service(get_file_metadata)
.app_data(app_data.clone())
.app_data::<Data<SqliteUserDao>>(Data::new(user_dao))
.app_data::<Data<Box<dyn FavoriteDao>>>(Data::new(Box::new(favorites_dao)))
.wrap(prometheus.clone())
})
.bind(dotenv::var("BIND_URL").unwrap())?
.bind("localhost:8088")?
.run()
.await
})
}
fn watch_files() {
std::thread::spawn(|| {
let (wtx, wrx) = channel();
let mut watcher = watcher(wtx, std::time::Duration::from_secs(10)).unwrap();
@@ -449,58 +512,34 @@ fn main() -> std::io::Result<()> {
}
}
});
let system = actix::System::new();
system.block_on(async {
let act = StreamActor {}.start();
let app_data = web::Data::new(AppState {
stream_manager: Arc::new(act),
});
let labels = HashMap::new();
let prometheus = PrometheusMetricsBuilder::new("api")
.const_labels(labels)
.build()
.expect("Unable to build prometheus metrics middleware");
prometheus
.registry
.register(Box::new(IMAGE_GAUGE.clone()))
.unwrap();
prometheus
.registry
.register(Box::new(VIDEO_GAUGE.clone()))
.unwrap();
HttpServer::new(move || {
let user_dao = SqliteUserDao::new();
let favorites_dao = SqliteFavoriteDao::new();
App::new()
.wrap(middleware::Logger::default())
.service(web::resource("/login").route(web::post().to(login)))
.service(web::resource("/photos").route(web::get().to(files::list_photos)))
.service(get_image)
.service(upload_image)
.service(generate_video)
.service(stream_video)
.service(get_video_part)
.service(favorites)
.service(put_add_favorite)
.service(delete_favorite)
.service(get_file_metadata)
.app_data(app_data.clone())
.app_data::<Data<Box<dyn UserDao>>>(Data::new(Box::new(user_dao)))
.app_data::<Data<Box<dyn FavoriteDao>>>(Data::new(Box::new(favorites_dao)))
.wrap(prometheus.clone())
})
.bind(dotenv::var("BIND_URL").unwrap())?
.bind("localhost:8088")?
.run()
.await
})
}
struct AppState {
pub struct AppState {
stream_manager: Arc<Addr<StreamActor>>,
base_path: String,
thumbnail_path: String,
}
impl AppState {
fn new(
stream_manager: Arc<Addr<StreamActor>>,
base_path: String,
thumbnail_path: String,
) -> Self {
Self {
stream_manager,
base_path,
thumbnail_path,
}
}
}
impl Default for AppState {
fn default() -> Self {
Self::new(
Arc::new(StreamActor {}.start()),
env::var("BASE_PATH").expect("BASE_PATH was not set in the env"),
env::var("THUMBNAILS").expect("THUMBNAILS was not set in the env"),
)
}
}

View File

@@ -1,6 +1,7 @@
use actix_web::body::MessageBody;
use actix_web::{body::BoxBody, HttpResponse};
use serde::de::DeserializeOwned;
use actix_web::{
body::{BoxBody, MessageBody},
HttpResponse,
};
use crate::database::{models::User, UserDao};
use std::cell::RefCell;
@@ -65,17 +66,3 @@ impl BodyReader for HttpResponse<BoxBody> {
std::str::from_utf8(&body).unwrap().to_string()
}
}
pub trait TypedBodyReader<T>
where
T: DeserializeOwned,
{
fn read_body(self) -> T;
}
impl<T: DeserializeOwned> TypedBodyReader<T> for HttpResponse<BoxBody> {
fn read_body(self) -> T {
let body = self.read_to_str();
serde_json::from_value(serde_json::Value::String(body.clone())).unwrap()
}
}