diff --git a/src/serde_bencode.rs b/src/serde_bencode.rs index 3f64e03..65a93a9 100644 --- a/src/serde_bencode.rs +++ b/src/serde_bencode.rs @@ -2,7 +2,7 @@ use std::fmt; use std::fmt::Display; use std::str::Utf8Error; use serde::{de, Deserialize, ser}; -use serde::de::Visitor; +use serde::de::{DeserializeSeed, SeqAccess, Visitor}; pub type Result = std::result::Result; @@ -10,12 +10,18 @@ pub type Result = std::result::Result; pub enum Error { Message(String), + WontImplement, Eof, Syntax, InvalidUtf8, + ExpectedBytes, + ExpectedBytesSep, + ExpectedNumbers, ExpectedInteger, ExpectedIntegerEnd, + ExpectedList, ExpectedListEnd, + ExpectedDict, ExpectedDictEnd, TrailingCharacters, } @@ -42,12 +48,18 @@ impl Display for Error { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { match self { Error::Message(msg) => formatter.write_str(msg), + Error::WontImplement => formatter.write_str("there is no reasonable way to deserialize to this type"), Error::Eof => formatter.write_str("unexpected end of input"), Error::Syntax => formatter.write_str("syntax error"), Error::InvalidUtf8 => formatter.write_str("could not decoded as UTF-8"), - Error::ExpectedInteger => formatter.write_str("expected integer"), + Error::ExpectedBytes => formatter.write_str("expected byte string start char: any number"), + Error::ExpectedBytesSep => formatter.write_str("expected byte separator char: ':'"), + Error::ExpectedInteger => formatter.write_str("expected integer start char 'i'"), + Error::ExpectedNumbers => formatter.write_str("expected numbers"), Error::ExpectedIntegerEnd => formatter.write_str("expected integer end char 'e'"), + Error::ExpectedList => formatter.write_str("expected list start char 'l'"), Error::ExpectedListEnd => formatter.write_str("expected list end char 'e'"), + Error::ExpectedDict => formatter.write_str("expected dict start char 'd'"), Error::ExpectedDictEnd => formatter.write_str("expected dict end char 'e'"), Error::TrailingCharacters => formatter.write_str("trailing characters") } @@ -99,21 +111,21 @@ impl<'de> Deserializer<'de> { T: std::str::FromStr, { if self.next_byte()? != 'i' as u8 { - return Err(Error::Syntax) + return Err(Error::ExpectedInteger) } let end_pos = self.input.iter().position(|&x| x == 'e' as u8) .ok_or_else(|| Error::ExpectedIntegerEnd)?; let int_str = std::str::from_utf8(&self.input[..end_pos])?; - let int = int_str.parse::().map_err(|_| Error::ExpectedInteger)?; + let int = int_str.parse::().map_err(|_| Error::ExpectedNumbers)?; self.input = &self.input[end_pos + 1..]; Ok(int) } fn parse_byte_string(&mut self) -> Result<&'de [u8]> { let delim_pos = self.input.iter().position(|&x| x == ':' as u8) - .ok_or_else(|| Error::Syntax)?; + .ok_or_else(|| Error::ExpectedBytesSep)?; let int_str = std::str::from_utf8(&self.input[..delim_pos])?; - let str_len = int_str.parse().map_err(|_| Error::ExpectedInteger)?; + let str_len = int_str.parse().map_err(|_| Error::ExpectedNumbers)?; self.input = &self.input[delim_pos + 1..]; let res = &self.input[..str_len]; self.input = &self.input[str_len..]; @@ -124,18 +136,19 @@ impl<'de> Deserializer<'de> { impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { type Error = Error; - fn deserialize_any(self, visitor: V) -> Result + fn deserialize_any(self, _: V) -> Result where V: Visitor<'de> { - todo!() + println!("i don't get this API!"); + Err(Error::WontImplement) } fn deserialize_bool(self, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_i8(self, visitor: V) -> Result @@ -177,42 +190,42 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_u16(self, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_u32(self, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_u64(self, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_f32(self, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_f64(self, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_char(self, visitor: V) -> std::result::Result @@ -268,49 +281,58 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_unit(self, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_unit_struct(self, _: &'static str, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } - fn deserialize_newtype_struct(self, name: &'static str, visitor: V) -> Result + fn deserialize_newtype_struct(self, _name: &'static str, visitor: V) -> Result where V: Visitor<'de> { - todo!() + visitor.visit_newtype_struct(self) } fn deserialize_seq(self, visitor: V) -> Result where V: Visitor<'de> { - todo!() + if self.next_byte()? as char != 'l'{ + Err(Error::ExpectedList) + } else { + let value = visitor.visit_seq(Access::new(self))?; + if self.next_byte()? as char != 'e' { + Err(Error::ExpectedListEnd) + } else { + Ok(value) + } + } } fn deserialize_tuple(self, _: usize, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_tuple_struct(self, _: &'static str, _: usize, _: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_map(self, visitor: V) -> Result @@ -331,7 +353,7 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de> { - unimplemented!() + Err(Error::WontImplement) } fn deserialize_identifier(self, visitor: V) -> std::result::Result @@ -349,6 +371,35 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { } } +struct Access<'a, 'de: 'a> { + de: &'a mut Deserializer<'de>, +} + +impl <'a, 'de> Access<'a, 'de> { + fn new(de: &'a mut Deserializer<'de>) -> Self { + Access { + de, + } + } +} + +impl<'a, 'de> SeqAccess<'de> for Access<'a, 'de> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> std::result::Result, Self::Error> + where + T: DeserializeSeed<'de> + { + match self.de.peek_byte()? as char { + 'd' | 'i' | 'l' | '0'..='9' => { + seed.deserialize(&mut *self.de).map(Some) + }, + 'e' => Ok(None), + _ => Err(Error::Syntax) + } + } +} + #[cfg(test)] mod test { use crate::serde_bencode::{Error, from_bytes}; @@ -371,4 +422,10 @@ mod test { assert_eq!(from_bytes(&"1:a".as_bytes()), Ok('a')); assert_eq!(from_bytes(&"4:💩".as_bytes()), Ok('💩')); } + + #[test] + fn test_list() { + assert_eq!(from_bytes(&"li42ei13ee".as_bytes()), Ok(vec![42, 13])); + assert_eq!(from_bytes(&"l3:foo4:barie".as_bytes()), Ok(vec!["foo", "bari"])); + } }