From e9cd8caaed76b33debf616c1a613471c01f9c9b6 Mon Sep 17 00:00:00 2001
From: strawberry <strawberry@pupbrain.dev>
Date: Sat, 25 Nov 2023 02:11:41 -0500
Subject: [PATCH] add feature flagged support for migrating from base64 file
 name keys to sha256 ones

core implementation and tests from https://gitlab.com/famedly/conduit/-/merge_requests/467
feature flag, base64 encode update, and tweaks were me

Signed-off-by: strawberry <strawberry@pupbrain.dev>
---
 Cargo.lock                   |   5 +-
 Cargo.toml                   |   2 +
 src/api/ruma_wrapper/axum.rs |   5 +-
 src/database/mod.rs          |  24 +++++++
 src/service/globals/mod.rs   |  23 ++++++
 src/service/media/mod.rs     | 133 +++++++++++++++++++++++++++++++++--
 6 files changed, 183 insertions(+), 9 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 25b6d41e..265e7bca 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -468,6 +468,7 @@ dependencies = [
  "serde_json",
  "serde_yaml",
  "sha-1",
+ "sha2",
  "thiserror",
  "thread_local",
  "threadpool",
@@ -2601,9 +2602,9 @@ dependencies = [
 
 [[package]]
 name = "sha2"
-version = "0.10.7"
+version = "0.10.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8"
+checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
 dependencies = [
  "cfg-if",
  "cpufeatures",
diff --git a/Cargo.toml b/Cargo.toml
index f09d13dd..0a6c15fd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -96,6 +96,7 @@ thread_local = "1.1.7"
 # used for TURN server authentication
 hmac = "0.12.1"
 sha-1 = "0.10.1"
+sha2 = { version = "0.10.8" }
 # used for conduit's CLI and admin room command parsing
 clap = { version = "4.4.8", default-features = false, features = ["std", "derive", "help", "usage", "error-context"] }
 futures-util = { version = "0.3.29", default-features = false }
@@ -127,6 +128,7 @@ systemd = ["sd-notify"]
 zstd_compression = ["tower-http/compression-zstd"]
 #brotli_compression = ["tower-http/compression-br"]
 #compression = ["tower-http/compression-full"]
+sha256_media = []
 
 [[bin]]
 name = "conduit"
diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs
index 42eb6aa0..9cf1fe16 100644
--- a/src/api/ruma_wrapper/axum.rs
+++ b/src/api/ruma_wrapper/axum.rs
@@ -220,7 +220,10 @@ where
                         let keys_result = services()
                             .rooms
                             .event_handler
-                            .fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.to_owned()])
+                            .fetch_signing_keys_for_server(
+                                &x_matrix.origin,
+                                vec![x_matrix.key.to_owned()],
+                            )
                             .await;
 
                         let keys = match keys_result {
diff --git a/src/database/mod.rs b/src/database/mod.rs
index e18842a9..08d8e666 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -942,6 +942,30 @@ impl KeyValueDatabase {
                 warn!("Migration: 12 -> 13 finished");
             }
 
+            if services().globals.database_version()? < 14 && cfg!(feature = "sha256_media") {
+                warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names");
+                // Move old media files to new names
+                for (key, _) in db.mediaid_file.iter() {
+                    // the deprecated getter is still needed here to locate the legacy base64-named files
+                    // TODO: remove this once we're sure that all users have migrated
+                    #[allow(deprecated)]
+                    let old_path = services().globals.get_media_file(&key);
+                    let path = services().globals.get_media_file_new(&key);
+                    // move the file to the new location
+                    if old_path.exists() {
+                        tokio::fs::rename(&old_path, &path).await?;
+                    }
+                }
+
+                // Record version 14 (matching the `< 14` guard above) so this migration runs only once.
+                // NOTE(review): the assert below needs latest_database_version == 14 with this feature on - confirm
+                services().globals.bump_database_version(14)?;
+
+                warn!("Migration: 13 -> 14 finished");
+            } else {
+                warn!("Skipping migration from version 13 to 14 for converting legacy base64 key file names to sha256 hashes of the base64 keys");
+            }
+
             assert_eq!(
                 services().globals.database_version().unwrap(),
                 latest_database_version
diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs
index 330633c8..5eb2fa12 100644
--- a/src/service/globals/mod.rs
+++ b/src/service/globals/mod.rs
@@ -5,6 +5,8 @@ use ruma::{
     OwnedServerSigningKeyId, OwnedUserId,
 };
 
+use sha2::Digest;
+
 use crate::api::server_server::FedDest;
 
 use crate::{services, Config, Error, Result};
@@ -447,6 +449,27 @@ impl Service {
         r
     }
 
+    /// Builds the path of a media file under the SHA256 naming scheme ("sha256_media" feature, migrated
+    /// database): the file name is the URL-safe, unpadded base64 encoding of the SHA256 digest of `key`.
+    pub fn get_media_file_new(&self, key: &[u8]) -> PathBuf {
+        if services().globals.database_version().unwrap() < 14 {
+            error!("Using SHA256 key file names requires database to be migrated.")
+        }
+        // A digest has a fixed size, so the file name stays short no matter how long the key is
+        let file_name = general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key));
+        let mut path = PathBuf::new();
+        path.push(self.config.database_path.clone());
+        path.push("media");
+        path.push(file_name);
+        path
+    }
+
+    /// old base64 file name media function
+    /// This is the old version of `get_media_file` that uses the full base64 key as the filename.
+    ///
+    /// This is deprecated and will be removed in a future release.
+    /// Please use `get_media_file_new` instead.
+    #[deprecated(note = "Use get_media_file_new instead")]
     pub fn get_media_file(&self, key: &[u8]) -> PathBuf {
         let mut r = PathBuf::new();
         r.push(self.config.database_path.clone());
diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs
index fc8fa569..179182cb 100644
--- a/src/service/media/mod.rs
+++ b/src/service/media/mod.rs
@@ -3,6 +3,8 @@ use std::io::Cursor;
 
 pub use data::Data;
 
+use base64::{engine::general_purpose, Engine as _};
+
 use crate::{services, Result};
 use image::imageops::FilterType;
 
@@ -35,7 +37,12 @@ impl Service {
             .db
             .create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
 
-        let path = services().globals.get_media_file(&key);
+        let path: std::path::PathBuf;
+        if cfg!(feature = "sha256_media") {
+            path = services().globals.get_media_file_new(&key);
+        } else {
+            path = services().globals.get_media_file(&key);
+        }
         let mut f = File::create(path).await?;
         f.write_all(file).await?;
         Ok(())
@@ -56,7 +63,12 @@ impl Service {
             self.db
                 .create_file_metadata(mxc, width, height, content_disposition, content_type)?;
 
-        let path = services().globals.get_media_file(&key);
+        let path: std::path::PathBuf;
+        if cfg!(feature = "sha256_media") {
+            path = services().globals.get_media_file_new(&key);
+        } else {
+            path = services().globals.get_media_file(&key);
+        }
         let mut f = File::create(path).await?;
         f.write_all(file).await?;
 
@@ -68,7 +80,12 @@ impl Service {
         if let Ok((content_disposition, content_type, key)) =
             self.db.search_file_metadata(mxc, 0, 0)
         {
-            let path = services().globals.get_media_file(&key);
+            let path: std::path::PathBuf;
+            if cfg!(feature = "sha256_media") {
+                path = services().globals.get_media_file_new(&key);
+            } else {
+                path = services().globals.get_media_file(&key);
+            }
             let mut file = Vec::new();
             BufReader::new(File::open(path).await?)
                 .read_to_end(&mut file)
@@ -121,7 +138,12 @@ impl Service {
             self.db.search_file_metadata(mxc.clone(), width, height)
         {
             // Using saved thumbnail
-            let path = services().globals.get_media_file(&key);
+            let path: std::path::PathBuf;
+            if cfg!(feature = "sha256_media") {
+                path = services().globals.get_media_file_new(&key);
+            } else {
+                path = services().globals.get_media_file(&key);
+            }
             let mut file = Vec::new();
             File::open(path).await?.read_to_end(&mut file).await?;
 
@@ -134,7 +156,12 @@ impl Service {
             self.db.search_file_metadata(mxc.clone(), 0, 0)
         {
             // Generate a thumbnail
-            let path = services().globals.get_media_file(&key);
+            let path: std::path::PathBuf;
+            if cfg!(feature = "sha256_media") {
+                path = services().globals.get_media_file_new(&key);
+            } else {
+                path = services().globals.get_media_file(&key);
+            }
             let mut file = Vec::new();
             File::open(path).await?.read_to_end(&mut file).await?;
 
@@ -204,7 +231,12 @@ impl Service {
                     content_type.as_deref(),
                 )?;
 
-                let path = services().globals.get_media_file(&thumbnail_key);
+                // Store the thumbnail under its own key; using `key` here would overwrite the original media file
+                let path = if cfg!(feature = "sha256_media") {
+                    services().globals.get_media_file_new(&thumbnail_key)
+                } else {
+                    services().globals.get_media_file(&thumbnail_key)
+                };
                 let mut f = File::create(path).await?;
                 f.write_all(&thumbnail_bytes).await?;
 
@@ -226,3 +258,92 @@ impl Service {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::path::PathBuf;
+
+    use sha2::Digest;
+
+    use super::*;
+
+    struct MockedKVDatabase;
+
+    impl Data for MockedKVDatabase {
+        fn create_file_metadata(
+            &self,
+            mxc: String,
+            width: u32,
+            height: u32,
+            content_disposition: Option<&str>,
+            content_type: Option<&str>,
+        ) -> Result<Vec<u8>> {
+            // key layout copied from src/database/key_value/media.rs; fields are separated by 0xff bytes
+            let mut key = mxc.as_bytes().to_vec();
+            key.push(0xff);
+            key.extend_from_slice(&width.to_be_bytes());
+            key.extend_from_slice(&height.to_be_bytes());
+            key.push(0xff);
+            key.extend_from_slice(
+                content_disposition
+                    .as_ref()
+                    .map(|f| f.as_bytes())
+                    .unwrap_or_default(),
+            );
+            key.push(0xff);
+            key.extend_from_slice(
+                content_type
+                    .as_ref()
+                    .map(|c| c.as_bytes())
+                    .unwrap_or_default(),
+            );
+
+            Ok(key)
+        }
+
+        fn search_file_metadata(
+            &self,
+            _mxc: String,
+            _width: u32,
+            _height: u32,
+        ) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
+            todo!()
+        }
+    }
+
+    #[tokio::test]
+    async fn long_file_names_works() {
+        static DB: MockedKVDatabase = MockedKVDatabase;
+        let media = Service { db: &DB };
+
+        let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned();
+        let width = 100;
+        let height = 100;
+        let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special characters like äöüß and even emoji like 🦀.png\"";
+        let content_type = "image/png";
+        let key = media
+            .db
+            .create_file_metadata(
+                mxc,
+                width,
+                height,
+                Some(content_disposition),
+                Some(content_type),
+            )
+            .unwrap();
+        let mut r = PathBuf::new();
+        r.push("/tmp");
+        r.push("media");
+        // name the file after the URL-safe, unpadded base64 of the sha256 hash of the key
+        // rather than the base64 of the key itself: the encoded key grows with the key
+        // (which embeds the content disposition) and can be longer than 255 characters.
+        r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(&key)));
+        // Check that the whole path is not longer than 255 characters
+        // (NOTE(review): 255 is usually the limit for a single file name, not a whole path - confirm intent)
+        assert!(
+            r.to_str().unwrap().len() <= 255,
+            "File path is too long: {}",
+            r.to_str().unwrap().len()
+        );
+    }
+}