From 51f51f768bc330ef2a4e15f68d44d71ceda7dcd5 Mon Sep 17 00:00:00 2001 From: otsmr Date: Mon, 13 Apr 2026 16:38:59 +0200 Subject: [PATCH] working tests --- rust/Cargo.lock | 127 ++--- rust/Cargo.toml | 4 +- rust/src/sss/sss.rs | 0 rust/src/user_discovery/mod.rs | 455 +++++++++++++++--- .../user_discovery/stores/in_memory_store.rs | 101 ++-- rust/src/user_discovery/traits.rs | 74 ++- rust/src/user_discovery/types.proto | 9 +- 7 files changed, 576 insertions(+), 194 deletions(-) delete mode 100644 rust/src/sss/sss.rs diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 12b0104..07f9c79 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -437,6 +437,19 @@ dependencies = [ "regex", ] +[[package]] +name = "env_logger" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +dependencies = [ + "humantime", + "is-terminal", + "log", + "regex", + "termcolor", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -786,6 +799,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "hex" version = "0.4.3" @@ -819,6 +838,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + [[package]] name = "iana-time-zone" version = "0.1.65" @@ -964,6 +989,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi 0.5.2", + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "itertools" version = "0.14.0" @@ -1107,15 +1143,6 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" -[[package]] -name = "nu-ansi-term" -version = "0.50.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "num-bigint-dig" version = "0.8.6" @@ -1168,7 +1195,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.3", "libc", ] @@ -1335,6 +1362,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "pretty_env_logger" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "865724d4dbe39d9f3dd3b52b88d859d66bcb2d6a0acfd5ea68a65fb66d4bdc1c" +dependencies = [ + "env_logger", + "log", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -1541,6 +1578,7 @@ dependencies = [ "blahaj", "flutter_rust_bridge", "postcard", + "pretty_env_logger", "prost", "prost-build", "rand 0.10.1", @@ -1549,7 +1587,6 @@ dependencies = [ "thiserror", "tokio", "tracing", - "tracing-subscriber", ] [[package]] @@ -1675,15 +1712,6 @@ dependencies = [ "digest", ] -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - [[package]] name = "shlex" version = "1.3.0" @@ -2011,6 +2039,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "thiserror" version = "2.0.18" @@ -2031,15 +2068,6 @@ dependencies = [ "syn", ] -[[package]] -name = "thread_local" -version = "1.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" -dependencies = [ - "cfg-if", -] - [[package]] name = "threadpool" version = "1.8.1" @@ -2143,32 +2171,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" -dependencies = [ - "nu-ansi-term", - "sharded-slab", - "smallvec", - "thread_local", - "tracing-core", - "tracing-log", ] [[package]] @@ -2228,12 +2230,6 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" -[[package]] -name = "valuable" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" - [[package]] name = "vcpkg" version = "0.2.15" @@ -2396,6 +2392,15 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "windows-core" version = "0.62.2" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 1ee6689..eaf86a7 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -23,10 +23,12 @@ tracing = "0.1.44" serde = "1.0.228" prost = "0.14.1" rand = "0.10.1" -tracing-subscriber = "0.3.23" blahaj = "0.6.0" postcard = { version = "1.1.3", features = ["alloc"] } +[dev-dependencies] +pretty_env_logger = "0.5.0" + [build-dependencies] prost-build = "0.14.1" diff --git a/rust/src/sss/sss.rs b/rust/src/sss/sss.rs deleted file mode 100644 index e69de29..0000000 diff --git a/rust/src/user_discovery/mod.rs b/rust/src/user_discovery/mod.rs index 6701a67..55ee1b4 100644 --- a/rust/src/user_discovery/mod.rs +++ b/rust/src/user_discovery/mod.rs @@ -2,17 +2,23 @@ mod error; pub mod stores; mod traits; +use std::collections::HashMap; + use blahaj::{Share, Sharks}; use postcard::{from_bytes, to_allocvec}; use prost::Message; use serde::{Deserialize, Serialize}; use crate::user_discovery::error::{Result, UserDiscoveryError}; -use crate::user_discovery::traits::UserDiscoveryUtils; +use crate::user_discovery::traits::{AnnouncedUser, OtherPromotion, UserDiscoveryUtils}; use crate::user_discovery::user_discovery_message::{UserDiscoveryAnnouncement, UserDiscoveryPromotion}; use crate::user_discovery::user_discovery_message::user_discovery_promotion::AnnouncementShareDecrypted; use crate::user_discovery::user_discovery_message::user_discovery_promotion::announcement_share_decrypted::SignedData; pub use traits::UserDiscoveryStore; +#[cfg(test)] +static TRANSMITTED_NETWORK_BYTES: std::sync::OnceLock> = + std::sync::OnceLock::new(); + pub type UserID = i64; include!(concat!(env!("OUT_DIR"), "/_.rs")); @@ -31,6 +37,8 @@ struct UserDiscoveryConfig { public_id: u64, /// Verification shares verification_shares: Vec>, + // The users' id: + user_id: UserID, } impl Default for UserDiscoveryConfig { @@ -42,6 +50,7 @@ impl Default for UserDiscoveryConfig { promotion_version: 0, verification_shares: vec![], public_id: 0, + user_id: 0, } } } @@ -74,6 +83,7 @@ impl UserDiscovery UserDiscoveryConfig { threshold, + user_id, ..Default::default() }, }; @@ -86,7 +96,7 @@ impl UserDiscovery UserDiscovery Result)>>> { + self.store.get_all_announced_users() + } + /// Returns all new user discovery messages for the provided contact pub fn get_new_messages( &self, @@ -206,9 +223,17 @@ impl UserDiscovery = std::sync::OnceLock::new(); Ok(messages) } @@ -223,8 +248,9 @@ impl UserDiscovery {stored_version:?}"); Ok(received_version.announcement > stored_version.announcement - || received_version.promotion < received_version.promotion) + || received_version.promotion > stored_version.promotion) } pub(crate) fn get_contact_version(&self, contact_id: UserID) -> Result>> { @@ -246,10 +272,11 @@ impl UserDiscovery UserDiscovery UserDiscovery { - return Err(UserDiscoveryError::ShamirsSecret(err.to_string())); - } + Err(err) => Err(UserDiscoveryError::ShamirsSecret(err.to_string())), } } fn handle_user_discovery_promotion( &self, - contact_id: UserID, - uda: UserDiscoveryPromotion, - ) { + from_contact_id: UserID, + udp: UserDiscoveryPromotion, + ) -> Result<()> { + tracing::debug!("Received a new UDP with public_id = {}.", &udp.public_id); - // contact_id - // uda.promotion_id - // uda.public_id - // uda.threshold - // uda.announcement_share - // uda.verification_state + self.store.store_other_promotion(OtherPromotion { + from_contact_id, + promotion_id: udp.promotion_id, + threshold: udp.threshold as u8, + public_id: udp.public_id, + announcement_share: udp.announcement_share, + public_key_verified_timestamp: udp.public_key_verified_timestamp, + })?; - // store this into the received_promotion_table - // check if the threshold is reached - // in case thr threshold is reached -> CALL STORE -> NEW USER - // otherwise do nothing + if let Some(contact) = self.store.get_announced_user_by_public_id(udp.public_id)? { + tracing::debug!( + "NEW PROMOTION 2: {} knows {}", + from_contact_id, + contact.user_id + ); + // The user is already known, just propagate the relation ship + self.store.push_new_user_relation( + from_contact_id, + contact, + udp.public_key_verified_timestamp, + )?; + return Ok(()); + } + let promotions = self.store.get_other_promotions_by_public_id(udp.public_id); + + if promotions.len() < udp.threshold as usize { + tracing::debug!( + "Not enough shares ({} < {}) to decrypt announcement. Waiting for next share.", + promotions.len(), + udp.threshold + ); + return Ok(()); + } + + tracing::debug!("Enough shares decrypting announcement."); + + let shares: Vec<_> = promotions + .iter() + .map(|x| x.announcement_share.to_owned()) + .filter_map(|x| Share::try_from(x.as_slice()).ok()) + .collect(); + + match Sharks(udp.threshold as u8).recover(&shares) { + Ok(secret) => { + tracing::debug!("Could decrypt announcement."); + let asd = AnnouncementShareDecrypted::decode(secret.as_slice())?; + if let Some(signed_data) = asd.signed_data { + if udp.public_id != signed_data.public_id { + tracing::error!( + "Mismatch of the announced public id and the signed public id " + ); + return Ok(()); + } + + if !self.utils.verify_signature( + &signed_data.encode_to_vec(), + &signed_data.public_key, + &asd.signature, + )? { + return Err(UserDiscoveryError::MaliciousAnnouncementData(format!( + "signature is invalid", + ))); + } + + tracing::debug!("Announcement valid."); + + let announced_user = AnnouncedUser { + user_id: signed_data.user_id, + public_key: signed_data.public_key, + public_id: udp.public_id, + }; + + let user_id = self.get_config()?.user_id; + for promotion in promotions { + // Do not store the announcement of the users itself. + // Or in case the promotion promotes myself + if promotion.from_contact_id == announced_user.user_id + || announced_user.user_id == user_id + { + continue; + } + tracing::debug!( + "NEW PROMOTION: {} knows {}", + promotion.from_contact_id, + announced_user.user_id + ); + self.store.push_new_user_relation( + promotion.from_contact_id, + announced_user.clone(), + promotion.public_key_verified_timestamp, + )?; + } + } + Ok(()) + } + Err(err) => Err(UserDiscoveryError::ShamirsSecret(err.to_string())), + } } - } #[cfg(test)] mod tests { + use std::collections::{HashMap, HashSet}; + use std::vec; + use crate::user_discovery::stores::InMemoryStore; use crate::user_discovery::traits::tests::TestingUtils; - use crate::user_discovery::{UserDiscovery, UserDiscoveryVersion, UserID}; + use crate::user_discovery::{ + UserDiscovery, UserDiscoveryVersion, UserID, TRANSMITTED_NETWORK_BYTES, + }; use prost::Message; fn get_version_bytes(announcement: u32, promotion: u32) -> Vec { @@ -380,16 +526,11 @@ mod tests { .encode_to_vec() } - fn zero() -> Vec { - get_version_bytes(0, 0) - } - - fn get_ud(user_id: UserID, friends: &[UserID]) -> UserDiscovery { + fn get_ud(user_id: usize) -> UserDiscovery { let store = InMemoryStore::default(); - store.set_friends(friends.iter().copied().collect()); let ud = UserDiscovery::new(store.to_owned(), TestingUtils::default()).unwrap(); - ud.initialize_or_update(3, user_id, vec![user_id as u8; 32]) + ud.initialize_or_update(2, user_id as UserID, vec![user_id as u8; 32]) .unwrap(); let version = ud.get_current_version().unwrap(); @@ -398,74 +539,236 @@ mod tests { ud } - fn init() { - tracing_subscriber::fmt() - .with_max_level(tracing::Level::DEBUG) - .init(); + fn assert_new_messages( + from: (usize, &UserDiscovery), + to: (usize, &UserDiscovery), + has_new_messages: bool, + ) { + // From sends a message with his current version to To + let to_received_version = &from.1.get_current_version().unwrap(); + assert_eq!( + to.1.should_request_new_messages(from.0 as UserID, to_received_version) + .unwrap(), + has_new_messages + ); } fn request_and_handle_messages( - from: (UserID, &UserDiscovery), - to: (UserID, &UserDiscovery), + from: (usize, &UserDiscovery), + to: (usize, &UserDiscovery), messages_count: usize, ) { // From sends a message with his current version to To let to_received_version = &from.1.get_current_version().unwrap(); assert_eq!( - to.1.should_request_new_messages(from.0, to_received_version) + to.1.should_request_new_messages(from.0 as UserID, to_received_version) .unwrap(), true ); // As To has a older version stored he sends a request to From: Give me all messages since version. let from_request_version_from_to = - to.1.get_contact_version(from.0).unwrap().unwrap_or(zero()); + to.1.get_contact_version(from.0 as UserID) + .unwrap() + .unwrap_or(get_version_bytes(0, 0)); let new_messages = from .1 - .get_new_messages(to.0, &from_request_version_from_to) + .get_new_messages(to.0 as UserID, &from_request_version_from_to) .unwrap(); - assert_eq!(new_messages.len(), messages_count); + assert!(new_messages.len() <= messages_count); - to.1.handle_user_discovery_messages(from.0, new_messages) + to.1.handle_user_discovery_messages(from.0 as UserID, new_messages) .unwrap(); assert_eq!( - to.1.should_request_new_messages(from.0, &from.1.get_current_version().unwrap()) - .unwrap(), + to.1.should_request_new_messages( + from.0 as UserID, + &from.1.get_current_version().unwrap() + ) + .unwrap(), false ); } + const ALICE: usize = 0; + const BOB: usize = 1; + const CHARLIE: usize = 2; + const DAVID: usize = 3; + const FRANK: usize = 4; + const TEST_USER_COUNT: usize = 5; + struct TestUsers { + names: [&'static str; TEST_USER_COUNT], + friends: [Vec; TEST_USER_COUNT], + uds: Vec>, + } + + impl TestUsers { + fn get() -> Self { + let names = ["ALICE", "BOB", "CHARLIE", "DAVID", "FRANK"]; + let mut uds = vec![]; + for index in 0..names.len() { + uds.push(get_ud(index)); + } + let friends = [ + vec![BOB, CHARLIE], + vec![ALICE, CHARLIE, DAVID], + vec![ALICE, BOB, DAVID, FRANK], + vec![BOB, CHARLIE], + vec![CHARLIE], + ]; + Self { + names, + uds, + friends, + } + } + } + #[tokio::test] async fn test_initialize_user_discovery() { - init(); + pretty_env_logger::init(); + let counter = TRANSMITTED_NETWORK_BYTES.get_or_init(|| std::sync::Mutex::new(0)); - const ALICE: UserID = 0; - const BOB: UserID = 1; - const CHARLIE: UserID = 2; - const DAVID: UserID = 3; - const FRANK: UserID = 4; + let users = TestUsers::get(); - let alice_ud = get_ud(ALICE, &[BOB, CHARLIE]); - let bob_ud = get_ud(BOB, &[ALICE, CHARLIE, DAVID]); - let charlie_ud = get_ud(CHARLIE, &[ALICE, BOB, DAVID, FRANK]); - let david_ud = get_ud(DAVID, &[BOB, CHARLIE]); - let frank_ud = get_ud(FRANK, &[CHARLIE]); + fn to_all_friends(from: usize, message_count: usize, users: &TestUsers) { + for friend in &users.friends[from] { + tracing::debug!("From {} to {}", users.names[from], users.names[*friend]); - { - // Alice send UDA to Bob and Charlie - request_and_handle_messages((ALICE, &alice_ud), (BOB, &bob_ud), 1); - request_and_handle_messages((ALICE, &alice_ud), (CHARLIE, &charlie_ud), 1); + if message_count == 0 { + assert_new_messages( + (from, &users.uds[from]), + (*friend, &users.uds[*friend]), + false, + ); + } else { + request_and_handle_messages( + (from, &users.uds[from]), + (*friend, &users.uds[*friend]), + message_count, + ); + } + } } - { - // This now contains Bob's own announcement in addition this now also contains the promotion from Alice - request_and_handle_messages((BOB, &bob_ud), (DAVID, &david_ud), 2); - request_and_handle_messages((BOB, &bob_ud), (CHARLIE, &charlie_ud), 2); + let message_flows = [ + // ALICE: own announcement sending to BOB and CHARLIE + (ALICE, 1), + // BOB: own announcement + promotion for ALICE + (BOB, 2), + // BOBs version should not have any new messages for his friends + (BOB, 0), + // ALICE: promotion for BOB + (ALICE, 1), + // CHARLIE: own announcement + promotion for ALICE, BOB + (CHARLIE, 3), + // DAVID: own announcement + promotion for BOB, CHARLIE + (DAVID, 3), + // BOB: promotion for CHARLIE, DAVID + (BOB, 2), + // CHARLIE: promotion for DAVID + (CHARLIE, 1), + // FRANK: own announcement + promotion for CHARLIE + (FRANK, 2), + // CHARLIE: promotion for FRANK + (CHARLIE, 1), + // ALICE: promotion for CHARLIE + (ALICE, 1), + ]; + + for (i, (from, count)) in message_flows.into_iter().enumerate() { + tracing::debug!("MESSAGE FLOW: {i}"); + to_all_friends(from, count, &users); } - todo!(); + tracing::debug!("Now all users should have the newest version."); + + for from in 0..TEST_USER_COUNT { + for to in &users.friends[from] { + tracing::debug!( + "Does {} has open messages for {}?", + &users.names[from], + &users.names[*to] + ); + assert_new_messages((from, &users.uds[from]), (*to, &users.uds[*to]), false); + } + } + + tracing::debug!("Test if all exchanges where successful."); + + let announced_users_expected = [ + // ALICE should now know that BOB and CHARLIE, BOB and DAVID and CHARLIE and DAVID are friends. + // Alice should also have one protected share from Frank. + ( + ALICE, + vec![ + (BOB, vec![CHARLIE]), // ALICE knows Bob and that CHARLIE is connected with BOB + (CHARLIE, vec![BOB]), // ALICE knows CHARLIE and that BOB is connected with CHARLIE + (DAVID, vec![BOB, CHARLIE]), // ALICE knows DAVID and that BOB and CHARLIE are connected with DAVID + ], + ), + ( + BOB, + vec![ + (ALICE, vec![CHARLIE]), + (CHARLIE, vec![ALICE, DAVID]), + (DAVID, vec![CHARLIE]), + ], + ), + ( + CHARLIE, + vec![ + (ALICE, vec![BOB]), + (BOB, vec![ALICE, DAVID]), + (DAVID, vec![BOB]), + (FRANK, vec![]), + ], + ), + ( + DAVID, + vec![ + (ALICE, vec![BOB, CHARLIE]), + (BOB, vec![CHARLIE]), + (CHARLIE, vec![BOB]), + ], + ), + (FRANK, vec![(CHARLIE, vec![])]), + ]; + + for (user, announcements) in announced_users_expected { + let announced_users2 = users.uds[user].get_all_announced_users().unwrap(); + let mut announced_users = HashMap::new(); + for a in announced_users2 { + announced_users.insert(a.0.user_id, a.1.iter().map(|x| x.0).collect::>()); + } + tracing::debug!("{} knows now: {}", users.names[user], announced_users.len()); + assert_eq!(announced_users.len(), announcements.len()); + for (contact_id, announced_users_expected) in announcements { + let announced_users = announced_users.get(&(contact_id as i64)).unwrap(); + tracing::debug!( + "{} knows now that {} has the following friends: {}", + users.names[user], + users.names[contact_id], + announced_users + .iter() + .map(|x| users.names[*x as usize]) + .collect::>() + .join(", ") + ); + let announced_users: HashSet = announced_users.iter().cloned().collect(); + let announced_users_expected: HashSet = announced_users_expected + .iter() + .cloned() + .map(|x| x as i64) + .collect(); + assert_eq!(announced_users, announced_users_expected); + } + } + + let count = TRANSMITTED_NETWORK_BYTES.get().unwrap().lock().unwrap(); + + tracing::info!("Transmitted a total of {} bytes.", *count); } } diff --git a/rust/src/user_discovery/stores/in_memory_store.rs b/rust/src/user_discovery/stores/in_memory_store.rs index 2b7ec34..4cf8d31 100644 --- a/rust/src/user_discovery/stores/in_memory_store.rs +++ b/rust/src/user_discovery/stores/in_memory_store.rs @@ -1,7 +1,7 @@ -use crate::user_discovery::error::UserDiscoveryError; -use crate::user_discovery::traits::UserDiscoveryStore; +use crate::user_discovery::error::{Result, UserDiscoveryError}; +use crate::user_discovery::traits::{AnnouncedUser, OtherPromotion, UserDiscoveryStore}; use crate::user_discovery::UserID; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::{Arc, Mutex}; #[derive(Default)] @@ -10,8 +10,9 @@ pub(crate) struct Storage { unused_shares: Vec>, used_shares: HashMap>, contact_versions: HashMap>, - friends: HashSet, - promotions: Vec>, + other_promotions: Vec, + announced_users: HashMap)>>, + own_promotions: Vec<(UserID, Vec)>, } #[derive(Default, Clone)] @@ -23,33 +24,27 @@ impl InMemoryStore { fn storage(&self) -> std::sync::MutexGuard<'_, Storage> { self.storage.lock().unwrap() } - pub fn set_friends(&self, friends: HashSet) { - self.storage().friends = friends; - } } impl UserDiscoveryStore for InMemoryStore { - fn get_config(&self) -> crate::user_discovery::error::Result> { + fn get_config(&self) -> Result> { if let Some(storage) = self.storage().config.clone() { return Ok(storage); } Err(UserDiscoveryError::NotInitialized) } - fn update_config(&self, update: Vec) -> crate::user_discovery::error::Result<()> { + fn update_config(&self, update: Vec) -> Result<()> { self.storage().config = Some(update); Ok(()) } - fn set_shares(&self, shares: Vec>) -> crate::user_discovery::error::Result<()> { + fn set_shares(&self, shares: Vec>) -> Result<()> { self.storage().unused_shares = shares; Ok(()) } - fn get_share_for_contact( - &self, - contact_id: UserID, - ) -> crate::user_discovery::error::Result> { + fn get_share_for_contact(&self, contact_id: UserID) -> Result> { let mut storage = self.storage(); if let Some(share) = storage.used_shares.get(&contact_id) { return Ok(share.to_vec()); @@ -61,42 +56,82 @@ impl UserDiscoveryStore for InMemoryStore { Err(UserDiscoveryError::NoSharesLeft) } - fn get_contact_version( - &self, - contact_id: UserID, - ) -> crate::user_discovery::error::Result>> { + fn get_contact_version(&self, contact_id: UserID) -> Result>> { Ok(self.storage().contact_versions.get(&contact_id).cloned()) } - fn set_contact_version( - &self, - contact_id: UserID, - update: Vec, - ) -> crate::user_discovery::error::Result<()> { + fn set_contact_version(&self, contact_id: UserID, update: Vec) -> Result<()> { self.storage().contact_versions.insert(contact_id, update); Ok(()) } - fn push_promotion( + fn push_own_promotion( &self, + contact_id: UserID, version: u32, promotion: Vec, - ) -> crate::user_discovery::error::Result<()> { + ) -> Result<()> { let mut storage = self.storage(); // println!("{} != {}", version, storage.promotions.len()); - if version as usize != storage.promotions.len() + 1 { + if version as usize != storage.own_promotions.len() + 1 { return Err(UserDiscoveryError::PushedInvalidVersion); } - storage.promotions.push(promotion); + storage.own_promotions.push((contact_id, promotion)); Ok(()) } - fn get_promotions_after_version( - &self, - version: u32, - ) -> crate::user_discovery::error::Result>> { + fn get_own_promotions_after_version(&self, version: u32) -> Result>> { let storage = self.storage(); - let elements = storage.promotions[(version as usize)..].to_vec(); + let elements = storage.own_promotions[(version as usize)..] + .into_iter() + .map(|(_, promotion)| promotion.to_owned()) + .collect(); Ok(elements) } + + fn store_other_promotion(&self, promotion: OtherPromotion) -> Result<()> { + self.storage().other_promotions.push(promotion); + Ok(()) + } + + fn get_other_promotions_by_public_id(&self, public_id: u64) -> Vec { + self.storage() + .other_promotions + .iter() + .filter(|other| other.public_id == public_id) + .map(OtherPromotion::to_owned) + .collect() + } + + fn get_announced_user_by_public_id(&self, public_id: u64) -> Result> { + Ok(self + .storage() + .announced_users + .iter() + .find(|(u, _)| u.public_id == public_id) + .map(|u| u.0.to_owned())) + } + + fn get_all_announced_users( + &self, + ) -> Result)>>> { + Ok(self.storage().announced_users.clone()) + } + + fn push_new_user_relation( + &self, + from_contact_id: UserID, + announced_user: AnnouncedUser, + public_key_verified_timestamp: Option, + ) -> Result<()> { + let mut storage = self.storage(); + let entry = storage + .announced_users + .entry(announced_user.clone()) + .or_insert(vec![]); + if announced_user.user_id != from_contact_id { + entry.push((from_contact_id, public_key_verified_timestamp)); + } + Ok(()) + } } diff --git a/rust/src/user_discovery/traits.rs b/rust/src/user_discovery/traits.rs index 67dca8c..8646d1f 100644 --- a/rust/src/user_discovery/traits.rs +++ b/rust/src/user_discovery/traits.rs @@ -1,6 +1,25 @@ +use std::collections::HashMap; + use crate::user_discovery::error::Result; use crate::user_discovery::UserID; +#[derive(Clone)] +pub struct OtherPromotion { + pub promotion_id: u32, + pub public_id: u64, + pub from_contact_id: UserID, + pub threshold: u8, + pub announcement_share: Vec, + pub public_key_verified_timestamp: Option, +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub struct AnnouncedUser { + pub user_id: UserID, + pub public_key: Vec, + pub public_id: u64, +} + pub trait UserDiscoveryStore { fn get_config(&self) -> Result>; fn update_config(&self, update: Vec) -> Result<()>; @@ -8,22 +27,38 @@ pub trait UserDiscoveryStore { fn get_share_for_contact(&self, contact_id: UserID) -> Result>; - fn push_promotion(&self, version: u32, promotion: Vec) -> Result<()>; - fn get_promotions_after_version(&self, version: u32) -> Result>>; + fn push_own_promotion( + &self, + contact_id: UserID, + version: u32, + promotion: Vec, + ) -> Result<()>; + + fn get_own_promotions_after_version(&self, version: u32) -> Result>>; + + fn store_other_promotion(&self, promotion: OtherPromotion) -> Result<()>; + fn get_other_promotions_by_public_id(&self, public_id: u64) -> Vec; + + fn get_announced_user_by_public_id(&self, public_id: u64) -> Result>; + + fn push_new_user_relation( + &self, + from_contact_id: UserID, + announced_user: AnnouncedUser, + public_key_verified_timestamp: Option, + ) -> Result<()>; + + fn get_all_announced_users(&self) + -> Result)>>>; fn get_contact_version(&self, contact_id: UserID) -> Result>>; fn set_contact_version(&self, contact_id: UserID, update: Vec) -> Result<()>; } pub trait UserDiscoveryUtils { - fn sign_data(&self, input_data: Vec) -> Result>; - fn verify_pubkey_and_signature_from( - &self, - from_contact_id: UserID, - data: Vec, - pubkey: Vec, - signature: Vec, - ) -> Result; + fn sign_data(&self, input_data: &[u8]) -> Result>; + fn verify_signature(&self, input_data: &[u8], pubkey: &[u8], signature: &[u8]) -> Result; + fn verify_stored_pubkey(&self, from_contact_id: UserID, pubkey: &[u8]) -> Result; } #[cfg(test)] @@ -33,16 +68,23 @@ pub(crate) mod tests { #[derive(Default)] pub(crate) struct TestingUtils {} impl UserDiscoveryUtils for TestingUtils { - fn sign_data(&self, _input_data: Vec) -> crate::user_discovery::error::Result> { + fn sign_data(&self, _input_data: &[u8]) -> crate::user_discovery::error::Result> { Ok(vec![0; 64]) } - fn verify_pubkey_and_signature_from( + fn verify_signature( &self, - from_contact_id: crate::user_discovery::UserID, - data: Vec, - pubkey: Vec, - signature: Vec, + _data: &[u8], + _pubkey: &[u8], + _signature: &[u8], + ) -> crate::user_discovery::error::Result { + Ok(true) + } + + fn verify_stored_pubkey( + &self, + _from_contact_id: crate::user_discovery::UserID, + _pubkey: &[u8], ) -> crate::user_discovery::error::Result { Ok(true) } diff --git a/rust/src/user_discovery/types.proto b/rust/src/user_discovery/types.proto index f719ece..9237586 100644 --- a/rust/src/user_discovery/types.proto +++ b/rust/src/user_discovery/types.proto @@ -17,7 +17,7 @@ message UserDiscoveryMessage { uint64 public_id = 1; uint32 threshold = 2; bytes announcement_share = 4; - repeated bytes verification_shares = 5; + repeated bytes verification_shares = 6; } message UserDiscoveryPromotion { @@ -26,7 +26,7 @@ message UserDiscoveryMessage { uint32 threshold = 3; bytes announcement_share = 5; - optional VerificationState verification_state = 6; + optional int64 public_key_verified_timestamp = 6; message AnnouncementShareDecrypted { message SignedData { @@ -38,11 +38,6 @@ message UserDiscoveryMessage { bytes signature = 2; } - message VerificationState { - int64 timestamp = 1; - bytes signature = 2; - } - } message UserDiscoveryRecall {