[serde_bencode] Impl byte string

This commit is contained in:
Fabian 2024-07-31 23:25:24 +02:00
parent 8fb6eda144
commit 4caab65bc1

View File

@ -1,6 +1,6 @@
use std::fmt; use std::fmt;
use std::fmt::Display; use std::fmt::Display;
use std::str::Utf8Error;
use serde::{de, Deserialize, ser}; use serde::{de, Deserialize, ser};
use serde::de::Visitor; use serde::de::Visitor;
@ -12,6 +12,7 @@ pub enum Error {
Eof, Eof,
Syntax, Syntax,
InvalidUtf8,
ExpectedInteger, ExpectedInteger,
ExpectedIntegerEnd, ExpectedIntegerEnd,
ExpectedListEnd, ExpectedListEnd,
@ -31,12 +32,19 @@ impl de::Error for Error {
} }
} }
impl From<Utf8Error> for Error {
fn from(value: Utf8Error) -> Self {
Error::InvalidUtf8
}
}
impl Display for Error { impl Display for Error {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
Error::Message(msg) => formatter.write_str(msg), Error::Message(msg) => formatter.write_str(msg),
Error::Eof => formatter.write_str("unexpected end of input"), Error::Eof => formatter.write_str("unexpected end of input"),
Error::Syntax => formatter.write_str("syntax error"), Error::Syntax => formatter.write_str("syntax error"),
Error::InvalidUtf8 => formatter.write_str("could not decoded as UTF-8"),
Error::ExpectedInteger => formatter.write_str("expected integer"), Error::ExpectedInteger => formatter.write_str("expected integer"),
Error::ExpectedIntegerEnd => formatter.write_str("expected integer end char 'e'"), Error::ExpectedIntegerEnd => formatter.write_str("expected integer end char 'e'"),
Error::ExpectedListEnd => formatter.write_str("expected list end char 'e'"), Error::ExpectedListEnd => formatter.write_str("expected list end char 'e'"),
@ -95,11 +103,22 @@ impl<'de> Deserializer<'de> {
} }
let end_pos = self.input.iter().position(|&x| x == 'e' as u8) let end_pos = self.input.iter().position(|&x| x == 'e' as u8)
.ok_or_else(|| Error::ExpectedIntegerEnd)?; .ok_or_else(|| Error::ExpectedIntegerEnd)?;
let int_str = std::str::from_utf8(&self.input[..end_pos]).map_err(|_| Error::Syntax)?; let int_str = std::str::from_utf8(&self.input[..end_pos])?;
let int = int_str.parse::<T>().map_err(|_| Error::ExpectedInteger)?; let int = int_str.parse::<T>().map_err(|_| Error::ExpectedInteger)?;
self.input = &self.input[end_pos + 1..]; self.input = &self.input[end_pos + 1..];
Ok(int) Ok(int)
} }
fn parse_byte_string(&mut self) -> Result<&'de [u8]> {
let delim_pos = self.input.iter().position(|&x| x == ':' as u8)
.ok_or_else(|| Error::Syntax)?;
let int_str = std::str::from_utf8(&self.input[..delim_pos])?;
let str_len = int_str.parse().map_err(|_| Error::ExpectedInteger)?;
self.input = &self.input[delim_pos + 1..];
let res = &self.input[..str_len];
self.input = &self.input[str_len..];
Ok(res)
}
} }
impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
@ -196,42 +215,56 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
unimplemented!() unimplemented!()
} }
fn deserialize_char<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_char<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() let chars = std::str::from_utf8(self.parse_byte_string()?)?.chars().collect::<Vec<_>>();
match chars.len() {
1 => visitor.visit_char(chars[0]),
0 => Err(Error::Eof),
_ => Err(Error::TrailingCharacters),
}
} }
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value> fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() visitor.visit_borrowed_str(
std::str::from_utf8(
self.parse_byte_string()?
)?
)
} }
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value> fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() visitor.visit_string(
std::str::from_utf8(
self.parse_byte_string()?
)?
.to_string()
)
} }
fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value> fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() visitor.visit_borrowed_bytes(self.parse_byte_string()?)
} }
fn deserialize_byte_buf<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_byte_buf<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() visitor.visit_byte_buf(Vec::from(self.parse_byte_string()?))
} }
fn deserialize_option<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_option<V>(self, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
@ -328,4 +361,14 @@ mod test {
// 2**16 + 1 = 65537 // 2**16 + 1 = 65537
assert_eq!(from_bytes(&"i65537e".as_bytes()), Ok(65537)); assert_eq!(from_bytes(&"i65537e".as_bytes()), Ok(65537));
} }
#[test]
fn test_str() {
assert_eq!(from_bytes(&"3:foo".as_bytes()), Ok("foo"));
assert_eq!(from_bytes(&"3:bar".as_bytes()), Ok("bar"));
assert_eq!(from_bytes(&"3:bar".as_bytes()), Ok("bar".to_string()));
assert_eq!(from_bytes(&"3:foo".as_bytes()), Ok("foo".as_bytes()));
assert_eq!(from_bytes(&"1:a".as_bytes()), Ok('a'));
assert_eq!(from_bytes(&"4:💩".as_bytes()), Ok('💩'));
}
} }