From 860e7a97fb5f2c61e4e2c33edd44cf4a2792e28b Mon Sep 17 00:00:00 2001 From: Cameron Date: Sun, 24 Nov 2024 09:49:03 -0500 Subject: [PATCH] Use TagDao for improved filtering --- src/files.rs | 68 ++++++++++++++++++++-------------------------------- src/main.rs | 21 ++++++++-------- src/tags.rs | 16 ++++++++++--- 3 files changed, 50 insertions(+), 55 deletions(-) diff --git a/src/files.rs b/src/files.rs index 7416bc9..9b8f7a8 100644 --- a/src/files.rs +++ b/src/files.rs @@ -24,7 +24,7 @@ use crate::tags::TagDao; use crate::video::StreamActor; use path_absolutize::*; use rand::prelude::SliceRandom; -use rand::{thread_rng}; +use rand::thread_rng; use serde::Deserialize; pub async fn list_photos( @@ -51,54 +51,38 @@ pub async fn list_photos( .filter_map(|t| t.parse().ok()) .collect::>(); - let exclude_tag_ids = req.exclude_tag_ids.clone() + let exclude_tag_ids = req + .exclude_tag_ids + .clone() .unwrap_or(String::new()) .split(',') .filter_map(|t| t.parse().ok()) .collect::>(); - return dao - .get_files_with_any_tag_ids(tag_ids.clone(), exclude_tag_ids.clone()) - .context(format!("Failed to get files with tag_ids: {:?}", tag_ids)) - .map(|tagged_files| { - return if let Some(sort_type) = req.sort { - debug!("Sorting files: {:?}", sort_type); - sort(tagged_files, sort_type) - } else { - tagged_files - }; - }) - .map(|tagged_files| match filter_mode { - FilterMode::Any => tagged_files - .iter() - .filter(|file| file.starts_with(search_path)) - .cloned() - .collect(), - FilterMode::All => tagged_files - .iter() - .filter(|&file_path| { - if !file_path.starts_with(search_path) { - return false; - } + match filter_mode { + FilterMode::Any => dao.get_files_with_any_tag_ids(tag_ids.clone(), exclude_tag_ids), + FilterMode::All => dao.get_files_with_all_tag_ids(tag_ids.clone(), exclude_tag_ids), + } + .context(format!( + "Failed to get files with tag_ids: {:?} with filter_mode: {:?}", + tag_ids, filter_mode + )) + .map(|files| { + files + .into_iter() + .filter(|file_path| file_path.starts_with(search_path)) + .collect() + }) + .map(|tagged_files| { + trace!("Found tagged files: {:?}", tagged_files); - let file_tags = dao.get_tags_for_path(file_path).unwrap_or_default(); - tag_ids - .iter() - .all(|id| file_tags.iter().any(|tag| &tag.id == id)) - }) - .cloned() - .collect::>(), + HttpResponse::Ok().json(PhotosResponse { + photos: tagged_files, + dirs: vec![], }) - .map(|tagged_files| { - trace!("Found tagged files: {:?}", tagged_files); - - HttpResponse::Ok().json(PhotosResponse { - photos: tagged_files, - dirs: vec![], - }) - }) - .into_http_internal_err() - .unwrap_or_else(|e| e.error_response()); + }) + .into_http_internal_err() + .unwrap_or_else(|e| e.error_response()); } } diff --git a/src/main.rs b/src/main.rs index 9f9c9a3..6f8d7de 100644 --- a/src/main.rs +++ b/src/main.rs @@ -282,7 +282,7 @@ async fn favorites( .expect("Unable to get FavoritesDao") .get_favorites(claims.sub.parse::().unwrap()) }) - .await + .await { Ok(Ok(favorites)) => { let favorites = favorites @@ -317,7 +317,7 @@ async fn put_add_favorite( .expect("Unable to get FavoritesDao") .add_favorite(user_id, &path) }) - .await + .await { Ok(Err(e)) if e.kind == DbErrorKind::AlreadyExists => { debug!("Favorite: {} exists for user: {}", &body.path, user_id); @@ -356,8 +356,8 @@ async fn delete_favorite( .expect("Unable to get favorites dao") .remove_favorite(user_id, path); }) - .await - .unwrap(); + .await + .unwrap(); info!( "Removing favorite \"{}\" for userid: {}", @@ -391,7 +391,7 @@ fn create_thumbnails() { .parent() .unwrap_or_else(|| panic!("Thumbnail {:?} has no parent?", thumb_path)), ) - .expect("Error creating directory"); + .expect("Error creating directory"); debug!("Generating video thumbnail: {:?}", thumb_path); generate_video_thumbnail(entry.path(), &thumb_path); @@ -526,10 +526,10 @@ fn main() -> std::io::Result<()> { .app_data::>>(Data::new(Mutex::new(tag_dao))) .wrap(prometheus.clone()) }) - .bind(dotenv::var("BIND_URL").unwrap())? - .bind("localhost:8088")? - .run() - .await + .bind(dotenv::var("BIND_URL").unwrap())? + .bind("localhost:8088")? + .run() + .await }) } @@ -548,7 +548,8 @@ fn watch_files() { let base_str = dotenv::var("BASE_PATH").unwrap(); let base_path = Path::new(&base_str); - watcher.watch(base_path, RecursiveMode::Recursive) + watcher + .watch(base_path, RecursiveMode::Recursive) .context(format!("Unable to watch BASE_PATH: '{}'", base_str)) .unwrap(); diff --git a/src/tags.rs b/src/tags.rs index a1a22ee..a8fd969 100644 --- a/src/tags.rs +++ b/src/tags.rs @@ -198,7 +198,11 @@ pub trait TagDao { fn create_tag(&mut self, name: &str) -> anyhow::Result; fn remove_tag(&mut self, tag_name: &str, path: &str) -> anyhow::Result>; fn tag_file(&mut self, path: &str, tag_id: i32) -> anyhow::Result; - fn get_files_with_all_tag_ids(&mut self, tag_ids: Vec) -> anyhow::Result>; + fn get_files_with_all_tag_ids( + &mut self, + tag_ids: Vec, + exclude_tag_ids: Vec, + ) -> anyhow::Result>; fn get_files_with_any_tag_ids( &mut self, tag_ids: Vec, @@ -345,11 +349,16 @@ impl TagDao for SqliteTagDao { }) } - fn get_files_with_all_tag_ids(&mut self, tag_ids: Vec) -> anyhow::Result> { + fn get_files_with_all_tag_ids( + &mut self, + tag_ids: Vec, + exclude_tag_ids: Vec, + ) -> anyhow::Result> { use diesel::dsl::*; tagged_photo::table .filter(tagged_photo::tag_id.eq_any(tag_ids.clone())) + .filter(tagged_photo::tag_id.ne_all(exclude_tag_ids)) .group_by(tagged_photo::photo_name) .select((tagged_photo::photo_name, count(tagged_photo::tag_id))) .having(count_distinct(tagged_photo::tag_id).eq(tag_ids.len() as i64)) @@ -496,7 +505,8 @@ mod tests { fn get_files_with_all_tag_ids( &mut self, - _tag_ids: Vec, + tag_ids: Vec, + exclude_tag_ids: Vec, ) -> anyhow::Result> { todo!() }