From 74759747642c4b683a42e2bd414de984c36d369c Mon Sep 17 00:00:00 2001 From: Faerbit Date: Wed, 31 Jul 2024 20:22:37 +0200 Subject: [PATCH] Reimplementing bencoding with Serde. Scaffolding setup. WIP --- Cargo.lock | 1 + Cargo.toml | 1 + src/main.rs | 5 +- src/serde_bencode.rs | 315 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 320 insertions(+), 2 deletions(-) create mode 100644 src/serde_bencode.rs diff --git a/Cargo.lock b/Cargo.lock index 2834eb7..a2ea1b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1069,6 +1069,7 @@ dependencies = [ "hex-literal", "rand", "reqwest", + "serde", "sha1", "urlencoding", ] diff --git a/Cargo.toml b/Cargo.toml index e6caad0..549c4ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ sha1 = "0.10.6" reqwest = { version = "0.12.5", features = ["blocking"] } rand = "0.8.5" urlencoding = "2.1.3" +serde = "1.0.204" [dev-dependencies] hex-literal = "0.4.1" diff --git a/src/main.rs b/src/main.rs index 9c9f44e..e3ca349 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,7 @@ use reqwest::Url; use crate::bencode::{Bencode, ByteString}; mod bencode; +mod serde_bencode; #[derive(Debug)] struct FileInfo { @@ -217,8 +218,8 @@ fn main() -> Result<()> { let resp = client.get(url).send()?; let status = resp.status(); - let body = resp.text()?; - println!("Response: {} {}", status, body); + // let body = resp.text()?; + println!("Response: {status}"); Ok(()) } diff --git a/src/serde_bencode.rs b/src/serde_bencode.rs new file mode 100644 index 0000000..4122f1c --- /dev/null +++ b/src/serde_bencode.rs @@ -0,0 +1,315 @@ +use std::fmt; +use std::fmt::Display; +use std::ops::{AddAssign, MulAssign}; + +use serde::{de, Deserialize, ser}; +use serde::de::Visitor; + +pub type Result = std::result::Result; + +#[derive(Debug, PartialEq)] +pub enum Error { + Message(String), + + Eof, + Syntax, + ExpectedInteger, + ExpectedListEnd, + ExpectedDictEnd, + TrailingCharacters, +} + +impl ser::Error for Error { + fn custom(msg: T) -> Self { + Error::Message(msg.to_string()) + } +} + +impl de::Error for Error { + fn custom(msg: T) -> Self { + Error::Message(msg.to_string()) + } +} + +impl Display for Error { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Message(msg) => formatter.write_str(msg), + Error::Eof => formatter.write_str("unexpected end of input"), + Error::Syntax => formatter.write_str("syntax error"), + Error::ExpectedInteger => formatter.write_str("expected integer"), + Error::ExpectedListEnd => formatter.write_str("expected list end char 'e'"), + Error::ExpectedDictEnd => formatter.write_str("expected dict end char 'e'"), + Error::TrailingCharacters => formatter.write_str("trailing characters") + } + } +} + +impl std::error::Error for Error {} + +pub struct Deserializer<'de> { + input: &'de [u8], +} + +impl<'de> Deserializer<'de> { + pub fn from_bytes(input: &'de [u8]) -> Self { + Deserializer { input } + } +} + +pub fn from_bytes<'a, T>(b: &'a &[u8]) -> Result +where + T: Deserialize<'a>, +{ + let mut deserializer = Deserializer::from_bytes(b); + let t = T::deserialize(&mut deserializer)?; + if deserializer.input.is_empty() { + Ok(t) + } else { + Err(Error::TrailingCharacters) + } +} + +impl<'de> Deserializer<'de> { + fn peek_byte(&mut self) -> Result { + if let Some(&result) = self.input.get(0) { + Ok(result) + } else { + Err(Error::Eof) + } + } + + fn parse_int(&mut self) -> Result + where + //T: AddAssign + MulAssign + From, + T: From, + { + Ok(42.into()) + } +} + +impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_bool(self, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_i8(self, visitor: V) -> Result + where + V: Visitor<'de> + { + visitor.visit_i8(self.parse_int()?) + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: Visitor<'de> + { + visitor.visit_i16(self.parse_int()?) + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: Visitor<'de> + { + visitor.visit_i32(self.parse_int()?) + } + + fn deserialize_i64(self, visitor: V) -> Result + where + V: Visitor<'de> + { + visitor.visit_i64(self.parse_int()?) + } + + fn deserialize_i128(self, visitor: V) -> Result + where + V: Visitor<'de> + { + visitor.visit_i128(self.parse_int()?) + } + + fn deserialize_u8(self, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_u16(self, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_u32(self, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_u64(self, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_f32(self, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_f64(self, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_char(self, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_byte_buf(self, visitor: V) -> std::result::Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_option(self, visitor: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_unit(self, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_unit_struct(self, _: &'static str, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_tuple(self, _: usize, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_tuple_struct(self, _: &'static str, _: usize, _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_struct(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_enum(self, _: &'static str, _: &'static [&'static str], _: V) -> std::result::Result + where + V: Visitor<'de> + { + unimplemented!() + } + + fn deserialize_identifier(self, visitor: V) -> std::result::Result + where + V: Visitor<'de> + { + todo!() + } + + fn deserialize_ignored_any(self, visitor: V) -> std::result::Result + where + V: Visitor<'de> + { + todo!() + } +} + +#[cfg(test)] +mod test { + use crate::serde_bencode::from_bytes; + + #[test] + fn test_int() { + assert_eq!(from_bytes(&"i42e".as_bytes()), Ok(42)); + // 2**16 + 1 = 65537 + assert_eq!(from_bytes(&"i65537e".as_bytes()), Ok(65537)); + } +}