diff --git a/Cargo.toml b/Cargo.toml index 549c4ed..c8fbe43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ sha1 = "0.10.6" reqwest = { version = "0.12.5", features = ["blocking"] } rand = "0.8.5" urlencoding = "2.1.3" -serde = "1.0.204" +serde = { version = "1.0.204", features = ["derive"] } [dev-dependencies] hex-literal = "0.4.1" diff --git a/src/main.rs b/src/main.rs index e3ca349..a7910e7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,6 @@ +// who do you think you are? my mom?! +#![allow(uncommon_codepoints)] + use std::{env, fs, path}; use std::collections::HashMap; diff --git a/src/serde_bencode.rs b/src/serde_bencode.rs index 65a93a9..0894717 100644 --- a/src/serde_bencode.rs +++ b/src/serde_bencode.rs @@ -1,8 +1,9 @@ use std::fmt; use std::fmt::Display; use std::str::Utf8Error; + use serde::{de, Deserialize, ser}; -use serde::de::{DeserializeSeed, SeqAccess, Visitor}; +use serde::de::{DeserializeSeed, MapAccess, SeqAccess, Visitor}; pub type Result = std::result::Result; @@ -39,7 +40,7 @@ impl de::Error for Error { } impl From for Error { - fn from(value: Utf8Error) -> Self { + fn from(_: Utf8Error) -> Self { Error::InvalidUtf8 } } @@ -70,15 +71,19 @@ impl std::error::Error for Error {} pub struct Deserializer<'de> { input: &'de [u8], + 𓁺: Option, } impl<'de> Deserializer<'de> { pub fn from_bytes(input: &'de [u8]) -> Self { - Deserializer { input } + Deserializer { + input, + 𓁺: None, + } } } -pub fn from_bytes<'a, T>(b: &'a &[u8]) -> Result +pub fn from_bytes<'a, T>(b: &'a [u8]) -> Result where T: Deserialize<'a>, { @@ -91,6 +96,13 @@ where } } +pub fn from_str<'a, T>(s: &'a str) -> Result +where + T: Deserialize<'a>, +{ + from_bytes(s.as_bytes()) +} + impl<'de> Deserializer<'de> { fn peek_byte(&mut self) -> Result { if let Some(&result) = self.input.get(0) { @@ -144,7 +156,7 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { Err(Error::WontImplement) } - fn deserialize_bool(self, _: V) -> std::result::Result + fn deserialize_bool(self, _: V) -> Result where V: Visitor<'de> { @@ -186,49 +198,53 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { visitor.visit_i128(self.parse_int()?) } - fn deserialize_u8(self, _: V) -> std::result::Result + fn deserialize_u8(self, visitor: V) -> Result + where + V: Visitor<'de> + { + if let Some(ch) = self.𓁺 { + visitor.visit_u8(ch) + } else { + Err(Error::WontImplement) + } + } + + fn deserialize_u16(self, _: V) -> Result where V: Visitor<'de> { Err(Error::WontImplement) } - fn deserialize_u16(self, _: V) -> std::result::Result + fn deserialize_u32(self, _: V) -> Result where V: Visitor<'de> { Err(Error::WontImplement) } - fn deserialize_u32(self, _: V) -> std::result::Result + fn deserialize_u64(self, _: V) -> Result where V: Visitor<'de> { Err(Error::WontImplement) } - fn deserialize_u64(self, _: V) -> std::result::Result + fn deserialize_f32(self, _: V) -> Result where V: Visitor<'de> { Err(Error::WontImplement) } - fn deserialize_f32(self, _: V) -> std::result::Result + fn deserialize_f64(self, _: V) -> Result where V: Visitor<'de> { Err(Error::WontImplement) } - fn deserialize_f64(self, _: V) -> std::result::Result - where - V: Visitor<'de> - { - Err(Error::WontImplement) - } - - fn deserialize_char(self, visitor: V) -> std::result::Result + fn deserialize_char(self, visitor: V) -> Result where V: Visitor<'de> { @@ -270,28 +286,28 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { visitor.visit_borrowed_bytes(self.parse_byte_string()?) } - fn deserialize_byte_buf(self, visitor: V) -> std::result::Result + fn deserialize_byte_buf(self, visitor: V) -> Result where V: Visitor<'de> { visitor.visit_byte_buf(Vec::from(self.parse_byte_string()?)) } - fn deserialize_option(self, _: V) -> std::result::Result + fn deserialize_option(self, _: V) -> Result where V: Visitor<'de> { Err(Error::WontImplement) } - fn deserialize_unit(self, _: V) -> std::result::Result + fn deserialize_unit(self, _: V) -> Result where V: Visitor<'de> { Err(Error::WontImplement) } - fn deserialize_unit_struct(self, _: &'static str, _: V) -> std::result::Result + fn deserialize_unit_struct(self, _: &'static str, _: V) -> Result where V: Visitor<'de> { @@ -312,7 +328,7 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { if self.next_byte()? as char != 'l'{ Err(Error::ExpectedList) } else { - let value = visitor.visit_seq(Access::new(self))?; + let value = visitor.visit_seq(RegularAccess::new(self))?; if self.next_byte()? as char != 'e' { Err(Error::ExpectedListEnd) } else { @@ -321,14 +337,22 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } } - fn deserialize_tuple(self, _: usize, _: V) -> std::result::Result + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result where V: Visitor<'de> { - Err(Error::WontImplement) + match self.peek_byte()? as char{ + 'l' => self.deserialize_seq(visitor), + '0'..='9' => { + let str = self.parse_byte_string()?; + let ba = ByteAccess::new(self, str); + visitor.visit_seq(ba) + }, + _ => Err(Error::Syntax), + } } - fn deserialize_tuple_struct(self, _: &'static str, _: usize, _: V) -> std::result::Result + fn deserialize_tuple_struct(self, _: &'static str, _: usize, _: V) -> Result where V: Visitor<'de> { @@ -339,54 +363,101 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de> { - todo!() + if self.next_byte()? as char != 'd'{ + Err(Error::ExpectedDict) + } else { + let value = visitor.visit_map(RegularAccess::new(self))?; + if self.next_byte()? as char != 'e' { + Err(Error::ExpectedDictEnd) + } else { + Ok(value) + } + } } - fn deserialize_struct(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result + fn deserialize_struct(self, _name: &'static str, _fields: &'static [&'static str], visitor: V) -> Result where V: Visitor<'de> { - todo!() + self.deserialize_map(visitor) } - fn deserialize_enum(self, _: &'static str, _: &'static [&'static str], _: V) -> std::result::Result + fn deserialize_enum(self, _: &'static str, _: &'static [&'static str], _: V) -> Result where V: Visitor<'de> { Err(Error::WontImplement) } - fn deserialize_identifier(self, visitor: V) -> std::result::Result + fn deserialize_identifier(self, visitor: V) -> Result where V: Visitor<'de> { - todo!() + self.deserialize_bytes(visitor) } - fn deserialize_ignored_any(self, visitor: V) -> std::result::Result + fn deserialize_ignored_any(self, _: V) -> Result where V: Visitor<'de> { - todo!() + println!("i really don't get this API!"); + Err(Error::WontImplement) } } -struct Access<'a, 'de: 'a> { +struct ByteAccess<'a, 'de: 'a> { + de: &'a mut Deserializer<'de>, + bytes: &'a [u8], +} + +impl <'a, 'de> ByteAccess<'a, 'de> { + fn new(de: &'a mut Deserializer<'de>, bytes: &'a [u8]) -> Self { + ByteAccess { + de, + bytes, + } + } +} + +impl<'a, 'de> SeqAccess<'de> for ByteAccess<'a, 'de> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'de> + { + + if let Some(&ch) = self.bytes.get(0) { + self.de.𓁺 = Some(ch); + self.bytes = &self.bytes[1..]; + seed.deserialize(&mut *self.de).map(Some) + } else { + self.de.𓁺 = None; + Ok(None) + } + } + + fn size_hint(&self) -> Option { + Some(self.bytes.len()) + } +} + +struct RegularAccess<'a, 'de: 'a> { de: &'a mut Deserializer<'de>, } -impl <'a, 'de> Access<'a, 'de> { +impl <'a, 'de> RegularAccess<'a, 'de> { fn new(de: &'a mut Deserializer<'de>) -> Self { - Access { + RegularAccess { de, } } } -impl<'a, 'de> SeqAccess<'de> for Access<'a, 'de> { +impl<'a, 'de> SeqAccess<'de> for RegularAccess<'a, 'de> { type Error = Error; - fn next_element_seed(&mut self, seed: T) -> std::result::Result, Self::Error> + fn next_element_seed(&mut self, seed: T) -> Result> where T: DeserializeSeed<'de> { @@ -400,32 +471,104 @@ impl<'a, 'de> SeqAccess<'de> for Access<'a, 'de> { } } +impl<'a, 'de> MapAccess<'de> for RegularAccess<'a, 'de> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'de> + { + match self.de.peek_byte()? as char { + '0'..='9' => { + seed.deserialize(&mut *self.de).map(Some) + }, + 'e' => Ok(None), + _ => Err(Error::ExpectedBytes) + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de> + { + match self.de.peek_byte()? as char { + 'd' | 'i' | 'l' | '0'..='9' => { + seed.deserialize(&mut *self.de) + }, + _ => Err(Error::Syntax) + } + } +} + #[cfg(test)] mod test { - use crate::serde_bencode::{Error, from_bytes}; + use std::collections::HashMap; + + use serde::Deserialize; + + use crate::serde_bencode::{Error, from_bytes, from_str}; #[test] fn test_int() { - assert_eq!(from_bytes(&"i42e".as_bytes()), Ok(42)); - assert_eq!(from_bytes::(&"i42".as_bytes()), Err(Error::ExpectedIntegerEnd)); - assert_eq!(from_bytes::(&"42e".as_bytes()), Err(Error::Syntax)); + assert_eq!(from_str("i42e"), Ok(42)); + assert_eq!(from_str::("i42"), Err(Error::ExpectedIntegerEnd)); + assert_eq!(from_str::("42e"), Err(Error::ExpectedInteger)); // 2**16 + 1 = 65537 - assert_eq!(from_bytes(&"i65537e".as_bytes()), Ok(65537)); + assert_eq!(from_str("i65537e"), Ok(65537)); } #[test] fn test_str() { - assert_eq!(from_bytes(&"3:foo".as_bytes()), Ok("foo")); - assert_eq!(from_bytes(&"3:bar".as_bytes()), Ok("bar")); - assert_eq!(from_bytes(&"3:bar".as_bytes()), Ok("bar".to_string())); - assert_eq!(from_bytes(&"3:foo".as_bytes()), Ok("foo".as_bytes())); - assert_eq!(from_bytes(&"1:a".as_bytes()), Ok('a')); - assert_eq!(from_bytes(&"4:💩".as_bytes()), Ok('💩')); + assert_eq!(from_str("3:foo"), Ok("foo")); + assert_eq!(from_str("3:bar"), Ok("bar")); + assert_eq!(from_str("3:bar"), Ok("bar".to_string())); + assert_eq!(from_str("3:foo"), Ok("foo")); + assert_eq!(from_str("1:a"), Ok('a')); + assert_eq!(from_str("4:💩"), Ok('💩')); } #[test] - fn test_list() { - assert_eq!(from_bytes(&"li42ei13ee".as_bytes()), Ok(vec![42, 13])); - assert_eq!(from_bytes(&"l3:foo4:barie".as_bytes()), Ok(vec!["foo", "bari"])); + fn test_seq() { + assert_eq!(from_str("li42ei13ee"), Ok(vec![42, 13])); + assert_eq!(from_str("l3:foo4:barie"), Ok(vec!["foo", "bari"])); + assert_eq!(from_str("lli42ei13eelee"), Ok(vec![vec![42, 13], vec![]])); + } + + #[test] + fn test_map() { + assert_eq!(from_str("d3:fooi42ee"), Ok(HashMap::from([ + ("foo", 42) + ]))); + assert_eq!(from_str("d3:foo3:bare"), Ok(HashMap::from([ + ("foo", "bar") + ]))); + } + + #[test] + fn test_struct() { + + #[derive(Deserialize, Debug)] + struct A { + a: i64, + } + let de: A = from_str("d1:ai42ee").unwrap(); + assert_eq!(de.a, 42); + + #[derive(Deserialize, Debug)] + struct B { + l: Vec, + d: HashMap + } + let mut buf = "d1:lli42ei13ei23ee1:dd3:foo3:".as_bytes().to_vec(); + buf.extend(vec![0x1, 0x2, 0x3]); + buf.extend("3:bar3:".as_bytes()); + buf.extend(vec![0x19, 0x31, 0x17]); + buf.extend("ee".as_bytes()); + let de: B = from_bytes(&buf[..]).unwrap(); + assert_eq!(de.l, vec![42, 13, 23]); + assert_eq!(de.d, HashMap::from([ + ("foo".to_string(), [0x1, 0x2, 0x3]), + ("bar".to_string(), [0x19, 0x31, 0x17]), + ])); } }