diff --git a/src/serde_bencode.rs b/src/serde_bencode.rs index c17eac7..3f64e03 100644 --- a/src/serde_bencode.rs +++ b/src/serde_bencode.rs @@ -1,6 +1,6 @@ use std::fmt; use std::fmt::Display; - +use std::str::Utf8Error; use serde::{de, Deserialize, ser}; use serde::de::Visitor; @@ -12,6 +12,7 @@ pub enum Error { Eof, Syntax, + InvalidUtf8, ExpectedInteger, ExpectedIntegerEnd, ExpectedListEnd, @@ -31,12 +32,19 @@ impl de::Error for Error { } } +impl From for Error { + fn from(value: Utf8Error) -> Self { + Error::InvalidUtf8 + } +} + impl Display for Error { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { match self { Error::Message(msg) => formatter.write_str(msg), Error::Eof => formatter.write_str("unexpected end of input"), Error::Syntax => formatter.write_str("syntax error"), + Error::InvalidUtf8 => formatter.write_str("could not decoded as UTF-8"), Error::ExpectedInteger => formatter.write_str("expected integer"), Error::ExpectedIntegerEnd => formatter.write_str("expected integer end char 'e'"), Error::ExpectedListEnd => formatter.write_str("expected list end char 'e'"), @@ -95,11 +103,22 @@ impl<'de> Deserializer<'de> { } let end_pos = self.input.iter().position(|&x| x == 'e' as u8) .ok_or_else(|| Error::ExpectedIntegerEnd)?; - let int_str = std::str::from_utf8(&self.input[..end_pos]).map_err(|_| Error::Syntax)?; + let int_str = std::str::from_utf8(&self.input[..end_pos])?; let int = int_str.parse::().map_err(|_| Error::ExpectedInteger)?; self.input = &self.input[end_pos + 1..]; Ok(int) } + + fn parse_byte_string(&mut self) -> Result<&'de [u8]> { + let delim_pos = self.input.iter().position(|&x| x == ':' as u8) + .ok_or_else(|| Error::Syntax)?; + let int_str = std::str::from_utf8(&self.input[..delim_pos])?; + let str_len = int_str.parse().map_err(|_| Error::ExpectedInteger)?; + self.input = &self.input[delim_pos + 1..]; + let res = &self.input[..str_len]; + self.input = &self.input[str_len..]; + Ok(res) + } } impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { @@ -196,42 +215,56 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { unimplemented!() } - fn deserialize_char(self, _: V) -> std::result::Result + fn deserialize_char(self, visitor: V) -> std::result::Result where V: Visitor<'de> { - unimplemented!() + let chars = std::str::from_utf8(self.parse_byte_string()?)?.chars().collect::>(); + match chars.len() { + 1 => visitor.visit_char(chars[0]), + 0 => Err(Error::Eof), + _ => Err(Error::TrailingCharacters), + } } fn deserialize_str(self, visitor: V) -> Result where V: Visitor<'de> { - todo!() + visitor.visit_borrowed_str( + std::str::from_utf8( + self.parse_byte_string()? + )? + ) } fn deserialize_string(self, visitor: V) -> Result where V: Visitor<'de> { - todo!() + visitor.visit_string( + std::str::from_utf8( + self.parse_byte_string()? + )? + .to_string() + ) } fn deserialize_bytes(self, visitor: V) -> Result where V: Visitor<'de> { - todo!() + visitor.visit_borrowed_bytes(self.parse_byte_string()?) } fn deserialize_byte_buf(self, visitor: V) -> std::result::Result where V: Visitor<'de> { - todo!() + visitor.visit_byte_buf(Vec::from(self.parse_byte_string()?)) } - fn deserialize_option(self, visitor: V) -> std::result::Result + fn deserialize_option(self, _: V) -> std::result::Result where V: Visitor<'de> { @@ -328,4 +361,14 @@ mod test { // 2**16 + 1 = 65537 assert_eq!(from_bytes(&"i65537e".as_bytes()), Ok(65537)); } + + #[test] + fn test_str() { + assert_eq!(from_bytes(&"3:foo".as_bytes()), Ok("foo")); + assert_eq!(from_bytes(&"3:bar".as_bytes()), Ok("bar")); + assert_eq!(from_bytes(&"3:bar".as_bytes()), Ok("bar".to_string())); + assert_eq!(from_bytes(&"3:foo".as_bytes()), Ok("foo".as_bytes())); + assert_eq!(from_bytes(&"1:a".as_bytes()), Ok('a')); + assert_eq!(from_bytes(&"4:💩".as_bytes()), Ok('💩')); + } }