diff --git a/src/bot.rs b/src/bot.rs index 3280e14..f697558 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -3,6 +3,7 @@ pub mod bot { use std::ops::Deref; + use std::str; use log::{debug, info}; use matrix_sdk::{ @@ -29,7 +30,7 @@ pub mod bot { joined_room: Joined, swear_list: Vec, database_handle: sled::Db, - spam_info: SpamInfo, + spam_db_handle: sled::Db, members_list: Vec, } @@ -109,21 +110,25 @@ pub mod bot { let utc: DateTime = Utc::now(); let timestamp: i64 = utc.timestamp(); let bytes = convert_to_bytes_sled(timestamp, default_reputation); + let spam_path = "spam_tracking_database"; + let spam_db_handle = sled::open(spam_path).unwrap(); for member in members_list.clone() { if member != creds.user_id { - dbg!(db.insert(member.as_str(), &bytes).unwrap()); + { + dbg!(db.insert(member.as_str(), &bytes).unwrap()); + } + { + dbg!(spam_db_handle.insert(member.as_str(), "[]".as_bytes()).unwrap()); + } } } - - let spam_info = SpamInfo::new(); - let bot = Bot { client, info: creds, joined_room, swear_list: swear_list.clone(), database_handle: db, - spam_info, + spam_db_handle, members_list: members_list.clone(), }; @@ -291,51 +296,46 @@ pub mod bot { .unwrap(); let author_name = author.user_id().as_str().to_string(); let curr_utc = Utc::now().timestamp(); - if !self.spam_info.author_msg_times.contains_key(&author_name) - || self.spam_info.author_msg_times.is_empty() - { - self.spam_info - .author_msg_times - .insert(author_name.clone(), vec![curr_utc]); - } else if let Some(msg_times) = self.spam_info.author_msg_times.get_mut(&author_name) { - msg_times.push(curr_utc) - } - - let expire_time: i64 = curr_utc - self.spam_info.detection_window; - + let expire_time: i64 = curr_utc - 5; let mut expired_msgs: Vec = vec![]; - if let Some(msg_time) = self.spam_info.author_msg_times.get(&author_name) { - for msg in msg_time { - if msg < &expire_time { - expired_msgs.push(*msg) + + let spam_data = self.spam_db_handle.get(&author_name); + match spam_data { + Ok(_) => { + if spam_data.clone().unwrap().is_some() { + let mut data_vec = convert_vec_to_str(str::from_utf8(&spam_data.unwrap().unwrap()[..]).unwrap().as_ref()); + if !data_vec.is_empty() { + for time in &data_vec { + if time < &expire_time { + expired_msgs.push(*time) + } + } + + for msg in expired_msgs { + let _ = &data_vec.retain(|value| *value != msg); + } + } + + data_vec.push(curr_utc); + + if data_vec.len() > 5 && author_name != self.info.user_id { + self.delete_message_from_room(&event.event_id, "Spamming") + .await; + self.update_reputation_for_member(&author, -1) + .await + .unwrap(); + } + + dbg!(self.spam_db_handle.insert(&author_name, format!("{:?}", data_vec).as_str().as_bytes()).unwrap()); } + else { + dbg!(self.spam_db_handle.insert(&author_name, format!("{:?}", vec![curr_utc]).as_str().as_bytes()).unwrap()); + } + }, + Err(_) => { + dbg!(self.spam_db_handle.insert(&author_name, "[]".as_bytes()).unwrap()); } } - - for msg in expired_msgs { - self.spam_info - .author_msg_times - .get_mut(&author_name) - .unwrap() - .retain(|value| *value != msg); - } - - if i64::try_from( - self.spam_info - .author_msg_times - .get(&author_name) - .expect("Cannot find author") - .len(), - ) - .unwrap() - > self.spam_info.max_msg_per_window - { - self.delete_message_from_room(&event.event_id, "Spamming") - .await; - self.update_reputation_for_member(&author, -1) - .await - .unwrap(); - } } async fn detect_command(&self, event: &OriginalSyncRoomMessageEvent, message: &str) { diff --git a/src/utils.rs b/src/utils.rs index 79378a7..28c6f7a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -7,30 +7,11 @@ pub mod utils { }; use matrix_sdk::ruma::OwnedEventId; use sled::IVec; - use std::str::Chars; use std::{ - collections::HashMap, fs::File, io::{copy, BufRead, BufReader, Result}, }; - #[derive(Clone)] - pub struct SpamInfo { - pub detection_window: i64, - pub max_msg_per_window: i64, - pub author_msg_times: HashMap>, - } - - impl SpamInfo { - pub fn new() -> Self { - SpamInfo { - detection_window: 5, - max_msg_per_window: 5, - author_msg_times: HashMap::new(), - } - } - } - pub fn get_message_event_text(event: &OriginalSyncRoomMessageEvent) -> Option { if let MessageType::Text(TextMessageEventContent { body, .. }) = &event.content.msgtype { Some(body.to_owned()) @@ -93,7 +74,7 @@ pub mod utils { *bytes } - pub fn convert_str_to_vec(input: &str) -> Vec { + pub fn convert_vec_to_str(input: &str) -> Vec { // Implementation 1: // input // .split_at(input.len() - 1) @@ -111,10 +92,15 @@ pub mod utils { // .collect::>(); // // Improved implementation - input[1..input.len() - 1] - .split(',') - .map(|n| n.trim().parse().unwrap()) - .collect() + if !input.is_empty() && input != "[]" { + input[1..input.len() - 1] + .split(',') + .map(|n| n.trim().parse().unwrap()) + .collect() + } + else { + vec![] + } } pub fn detect_caps(message: &str) -> bool {