From 24d2123fc2f1d360b491e209400282a347c016da Mon Sep 17 00:00:00 2001 From: Cameron Date: Sun, 18 May 2025 19:57:16 -0400 Subject: [PATCH] Fix recursive-any tag counting The query is now built by interpolating the tag IDs into raw SQL, which is bad security-wise, so it'll need another pass. --- Cargo.lock | 23 ++---------------- Cargo.toml | 4 +-- src/tags.rs | 70 ++++++++++++++++++++++++++++++++--------------------- 3 files changed, 45 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9f8bb85..71d6d3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -908,9 +908,9 @@ dependencies = [ [[package]] name = "diesel" -version = "2.2.5" +version = "2.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbf9649c05e0a9dbd6d0b0b8301db5182b972d0fd02f0a7c6736cf632d7c0fd5" +checksum = "ff3e1edb1f37b4953dd5176916347289ed43d7119cc2e6c7c3f7849ff44ea506" dependencies = [ "diesel_derives", "libsqlite3-sys", @@ -1624,7 +1624,6 @@ dependencies = [ "opentelemetry", "opentelemetry-appender-log", "opentelemetry-otlp", - "opentelemetry-resource-detectors", "opentelemetry-stdout", "opentelemetry_sdk", "path-absolutize", @@ -1633,7 +1632,6 @@ dependencies = [ "rayon", "serde", "serde_json", - "tempfile", "tokio", "walkdir", ] @@ -2176,23 +2174,6 @@ dependencies = [ "tonic", ] -[[package]] -name = "opentelemetry-resource-detectors" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0cd3cf373f6f7f3a8f25a189acf1300c8b87e85f7959b45ba83c01e305f5cc3" -dependencies = [ - "opentelemetry", - "opentelemetry-semantic-conventions", - "opentelemetry_sdk", -] - -[[package]] -name = "opentelemetry-semantic-conventions" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fb3a2f78c2d55362cd6c313b8abedfbc0142ab3c2676822068fd2ab7d51f9b7" - [[package]] name = "opentelemetry-stdout" version = "0.28.0" diff --git a/Cargo.toml b/Cargo.toml index f4a0ecd..7c98e1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ futures = 
"0.3.5" jsonwebtoken = "9.3.0" serde = "1" serde_json = "1" -diesel = { version = "2.2.5", features = ["sqlite"] } +diesel = { version = "2.2.10", features = ["sqlite"] } diesel_migrations = "2.2.0" chrono = "0.4" dotenv = "0.15" @@ -37,10 +37,8 @@ prometheus = "0.13" lazy_static = "1.5" anyhow = "1.0" rand = "0.8.5" -tempfile = "3.14.0" opentelemetry = { version = "0.28.0", features = ["default", "metrics", "tracing"] } opentelemetry_sdk = { version = "0.28.0", features = ["default", "rt-tokio-current-thread", "tracing", "metrics"] } opentelemetry-otlp = { version = "0.28.0", features = ["default", "metrics", "tracing", "grpc-tonic"] } opentelemetry-stdout = "0.28.0" opentelemetry-appender-log = "0.28.0" -opentelemetry-resource-detectors = "0.7.0" diff --git a/src/tags.rs b/src/tags.rs index 0339dfd..79632f4 100644 --- a/src/tags.rs +++ b/src/tags.rs @@ -6,6 +6,7 @@ use anyhow::Context; use chrono::Utc; use diesel::dsl::count_star; use diesel::prelude::*; +use diesel::sql_types::*; use log::{debug, info, trace}; use schema::{tagged_photo, tags}; use serde::{Deserialize, Serialize}; @@ -390,30 +391,42 @@ impl TagDao for SqliteTagDao { ) -> anyhow::Result> { use diesel::dsl::*; - let exclude_subquery = tagged_photo::table - .filter(tagged_photo::tag_id.eq_any(exclude_tag_ids.clone())) - .select(tagged_photo::photo_name) - .into_boxed(); + let tag_ids_str = tag_ids + .iter() + .map(|id| id.to_string()) + .collect::>() + .join(","); - tagged_photo::table - .filter(tagged_photo::tag_id.eq_any(tag_ids.clone())) - .filter(tagged_photo::photo_name.ne_all(exclude_subquery)) - .group_by(tagged_photo::photo_name) - .select(( - tagged_photo::photo_name, - count_distinct(tagged_photo::tag_id), - )) - .get_results::<(String, i64)>(&mut self.connection) - .map(|results| { - results - .into_iter() - .map(|(file_name, tag_count)| FileWithTagCount { - file_name, - tag_count, - }) - .collect() - }) - .with_context(|| format!("Unable to get Tagged photos with ids: {:?}", tag_ids)) + 
let exclude_tag_ids_str = exclude_tag_ids + .iter() + .map(|id| id.to_string()) + .collect::>() + .join(","); + + let query = sql_query(format!( + r#" +WITH filtered_photos AS ( + SELECT DISTINCT photo_name + FROM tagged_photo tp + WHERE tp.tag_id IN ({}) + AND tp.photo_name NOT IN ( + SELECT photo_name + FROM tagged_photo + WHERE tag_id IN ({}) + ) + ) + SELECT + fp.photo_name as file_name, + COUNT(DISTINCT tp2.tag_id) as tag_count + FROM filtered_photos fp + JOIN tagged_photo tp2 ON fp.photo_name = tp2.photo_name + GROUP BY fp.photo_name"#, + tag_ids_str, exclude_tag_ids_str + )); + + // Execute the query: + let results = query.load::(&mut self.connection)?; + Ok(results) } } @@ -629,9 +642,10 @@ mod tests { ); } } - -#[derive(Debug, Clone)] -pub struct FileWithTagCount { - pub file_name: String, - pub tag_count: i64, +#[derive(QueryableByName, Debug, Clone)] +pub(crate) struct FileWithTagCount { + #[diesel(sql_type = Text)] + pub(crate) file_name: String, + #[diesel(sql_type = BigInt)] + pub(crate) tag_count: i64, }