#![allow(dead_code)] use rand::Rng; use std::fmt::Write; #[derive(PartialEq, PartialOrd, Debug)] pub struct Bytes(pub Vec); impl Bytes { pub fn empty() -> Bytes { Bytes(vec![]) } pub fn from_utf8(s: &str) -> Bytes { Bytes(s.as_bytes().iter().map(|c| c.clone()).collect()) } #[allow(dead_code)] pub fn to_utf8(&self) -> String { let Bytes(v) = self; String::from(std::str::from_utf8(&v).unwrap()) } pub fn len(&self) -> usize { self.0.len() } pub fn get_block(&self, block_index: usize, block_size: usize) -> Bytes { Bytes(self.0[(block_index * block_size)..(block_index + 1) * block_size].to_vec()) } pub fn random(length: usize) -> Bytes { Bytes( (0..length) .map(|_| rand::thread_rng().gen_range(0..255)) .collect(), ) } pub fn random_range(lower: usize, upper: usize) -> Bytes { let length: usize = rand::thread_rng().gen_range(lower..upper); Bytes::random(length) } pub fn from_hex(s: &str) -> Bytes { if s.len() % 2 != 0 { panic!("Input string has uneven number of characters"); } let bytes_result: Result, std::num::ParseIntError> = (0..s.len()) .step_by(2) .map(|i| u8::from_str_radix(&s[i..i + 2], 16)) .collect(); match bytes_result { Ok(b) => Bytes(b), Err(_) => panic!("Could not convert all digit pairs to hex."), } } pub fn to_hex(&self) -> String { let Bytes(v) = self; let mut r = String::new(); for e in v.iter() { write!(r, "{:02x}", e).unwrap(); } r } pub fn is_ascii(&self) -> bool { let Bytes(v) = self; for &c in v.iter() { if c < 32 || c > 127 { return false; } } true } pub fn ascii_score(&self) -> u32 { let Bytes(v) = self; let mut r = 0; for &c in v.iter() { match c { 32 => r += 2, 33..=64 => r += 1, 65..=90 => r += 3, 91..=96 => r += 1, 97..=122 => r += 3, 123..=127 => r += 1, _ => (), } } r } pub fn guess_key(&self) -> u8 { // Assuming all bytes have been xor-encrypted by the same u8 key, find that key // by trying all u8 values and compute an ascii score for the resulting "plaintext". // The u8 key for the "plaintext" with the highest score is returned as the guessed // key. let mut h: Vec<(Bytes, u8)> = (0..=255).map(|i| (Bytes::xor_byte(self, i), i)).collect(); h.sort_by(|a, b| a.0.ascii_score().partial_cmp(&b.0.ascii_score()).unwrap()); h.last().unwrap().1 } pub fn pad_pkcs7(&mut self, block_size: usize) -> () { let Bytes(v) = self; let padding_value = (block_size - v.len() % block_size) as u8; for _ in 0..padding_value { v.push(padding_value); } } pub fn has_valid_pkcs7(&self, block_size: usize) -> bool { if self.len() > 0 && self.len() % block_size != 0 { return false; } let last_block_index = self.len() / block_size - 1; let last_block = self.get_block(last_block_index, block_size).0; let pad_byte = last_block[block_size - 1]; if pad_byte < 1 || pad_byte > 16 { return false; } for i in 0..(pad_byte as usize) { let byte = last_block[block_size - 1 - i]; if byte != pad_byte { return false; } } return true; } pub fn remove_pkcs7(&mut self, block_size: usize) -> () { if !self.has_valid_pkcs7(block_size) { return; } let Bytes(v) = self; let pad_byte_count = v[v.len() - 1]; for _ in 0..(pad_byte_count as usize) { v.pop(); } } pub fn flip_bit(&mut self, byte_index: usize, bit_index: usize) -> () { let Bytes(v) = self; let flip_mask: u8 = 0b1 << bit_index; v[byte_index] ^= flip_mask; } pub fn xor(Bytes(a): &Bytes, Bytes(b): &Bytes) -> Bytes { Bytes(crate::utils::xor(a, b)) } pub fn xor_byte(Bytes(a): &Bytes, byte: u8) -> Bytes { Bytes(a.iter().map(|e| e ^ byte).collect()) } pub fn xor_cycle(Bytes(msg): &Bytes, Bytes(key): &Bytes) -> Bytes { Bytes( Iterator::zip(msg.iter(), 0..msg.len()) .map(|z| *(z.0) ^ key[z.1 % key.len()]) .collect(), ) } pub fn has_duplicated_cycle(&self, block_size: usize) -> bool { let Bytes(v) = self; let chunks: Vec<&[u8]> = v.chunks(block_size).collect(); // we could use a hashmap to get O(n) instead of O(n^2) for i in 0..chunks.len() { for j in (i + 1)..chunks.len() { if chunks[i] == chunks[j] { return true; } } } false } #[allow(dead_code)] pub fn hemming(Bytes(a): &Bytes, Bytes(b): &Bytes) -> u32 { let v: Vec = Iterator::zip(a.iter(), b.iter()) .map(|z| (*(z.0) ^ *(z.1)).count_ones()) .collect(); v.iter().sum() } }