[serde_bencode] Implemented sequences

This commit is contained in:
Fabian 2024-08-02 22:47:50 +02:00
parent 4caab65bc1
commit 674f4a8af4

View File

@ -2,7 +2,7 @@ use std::fmt;
use std::fmt::Display; use std::fmt::Display;
use std::str::Utf8Error; use std::str::Utf8Error;
use serde::{de, Deserialize, ser}; use serde::{de, Deserialize, ser};
use serde::de::Visitor; use serde::de::{DeserializeSeed, SeqAccess, Visitor};
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@ -10,12 +10,18 @@ pub type Result<T> = std::result::Result<T, Error>;
pub enum Error { pub enum Error {
Message(String), Message(String),
WontImplement,
Eof, Eof,
Syntax, Syntax,
InvalidUtf8, InvalidUtf8,
ExpectedBytes,
ExpectedBytesSep,
ExpectedNumbers,
ExpectedInteger, ExpectedInteger,
ExpectedIntegerEnd, ExpectedIntegerEnd,
ExpectedList,
ExpectedListEnd, ExpectedListEnd,
ExpectedDict,
ExpectedDictEnd, ExpectedDictEnd,
TrailingCharacters, TrailingCharacters,
} }
@ -42,12 +48,18 @@ impl Display for Error {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
Error::Message(msg) => formatter.write_str(msg), Error::Message(msg) => formatter.write_str(msg),
Error::WontImplement => formatter.write_str("there is no reasonable way to deserialize to this type"),
Error::Eof => formatter.write_str("unexpected end of input"), Error::Eof => formatter.write_str("unexpected end of input"),
Error::Syntax => formatter.write_str("syntax error"), Error::Syntax => formatter.write_str("syntax error"),
Error::InvalidUtf8 => formatter.write_str("could not decoded as UTF-8"), Error::InvalidUtf8 => formatter.write_str("could not decoded as UTF-8"),
Error::ExpectedInteger => formatter.write_str("expected integer"), Error::ExpectedBytes => formatter.write_str("expected byte string start char: any number"),
Error::ExpectedBytesSep => formatter.write_str("expected byte separator char: ':'"),
Error::ExpectedInteger => formatter.write_str("expected integer start char 'i'"),
Error::ExpectedNumbers => formatter.write_str("expected numbers"),
Error::ExpectedIntegerEnd => formatter.write_str("expected integer end char 'e'"), Error::ExpectedIntegerEnd => formatter.write_str("expected integer end char 'e'"),
Error::ExpectedList => formatter.write_str("expected list start char 'l'"),
Error::ExpectedListEnd => formatter.write_str("expected list end char 'e'"), Error::ExpectedListEnd => formatter.write_str("expected list end char 'e'"),
Error::ExpectedDict => formatter.write_str("expected dict start char 'd'"),
Error::ExpectedDictEnd => formatter.write_str("expected dict end char 'e'"), Error::ExpectedDictEnd => formatter.write_str("expected dict end char 'e'"),
Error::TrailingCharacters => formatter.write_str("trailing characters") Error::TrailingCharacters => formatter.write_str("trailing characters")
} }
@ -99,21 +111,21 @@ impl<'de> Deserializer<'de> {
T: std::str::FromStr, T: std::str::FromStr,
{ {
if self.next_byte()? != 'i' as u8 { if self.next_byte()? != 'i' as u8 {
return Err(Error::Syntax) return Err(Error::ExpectedInteger)
} }
let end_pos = self.input.iter().position(|&x| x == 'e' as u8) let end_pos = self.input.iter().position(|&x| x == 'e' as u8)
.ok_or_else(|| Error::ExpectedIntegerEnd)?; .ok_or_else(|| Error::ExpectedIntegerEnd)?;
let int_str = std::str::from_utf8(&self.input[..end_pos])?; let int_str = std::str::from_utf8(&self.input[..end_pos])?;
let int = int_str.parse::<T>().map_err(|_| Error::ExpectedInteger)?; let int = int_str.parse::<T>().map_err(|_| Error::ExpectedNumbers)?;
self.input = &self.input[end_pos + 1..]; self.input = &self.input[end_pos + 1..];
Ok(int) Ok(int)
} }
fn parse_byte_string(&mut self) -> Result<&'de [u8]> { fn parse_byte_string(&mut self) -> Result<&'de [u8]> {
let delim_pos = self.input.iter().position(|&x| x == ':' as u8) let delim_pos = self.input.iter().position(|&x| x == ':' as u8)
.ok_or_else(|| Error::Syntax)?; .ok_or_else(|| Error::ExpectedBytesSep)?;
let int_str = std::str::from_utf8(&self.input[..delim_pos])?; let int_str = std::str::from_utf8(&self.input[..delim_pos])?;
let str_len = int_str.parse().map_err(|_| Error::ExpectedInteger)?; let str_len = int_str.parse().map_err(|_| Error::ExpectedNumbers)?;
self.input = &self.input[delim_pos + 1..]; self.input = &self.input[delim_pos + 1..];
let res = &self.input[..str_len]; let res = &self.input[..str_len];
self.input = &self.input[str_len..]; self.input = &self.input[str_len..];
@ -124,18 +136,19 @@ impl<'de> Deserializer<'de> {
impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
type Error = Error; type Error = Error;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value> fn deserialize_any<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() println!("i don't get this API!");
Err(Error::WontImplement)
} }
fn deserialize_bool<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_bool<V>(self, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value> fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
@ -177,42 +190,42 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_u16<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_u16<V>(self, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_u32<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_u32<V>(self, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_u64<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_u64<V>(self, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_f32<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_f32<V>(self, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_f64<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_f64<V>(self, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_char<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_char<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
@ -268,49 +281,58 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_unit<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_unit<V>(self, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_unit_struct<V>(self, _: &'static str, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_unit_struct<V>(self, _: &'static str, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value> fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() visitor.visit_newtype_struct(self)
} }
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value> fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() if self.next_byte()? as char != 'l'{
Err(Error::ExpectedList)
} else {
let value = visitor.visit_seq(Access::new(self))?;
if self.next_byte()? as char != 'e' {
Err(Error::ExpectedListEnd)
} else {
Ok(value)
}
}
} }
fn deserialize_tuple<V>(self, _: usize, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_tuple<V>(self, _: usize, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_tuple_struct<V>(self, _: &'static str, _: usize, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_tuple_struct<V>(self, _: &'static str, _: usize, _: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_map<V>(self, visitor: V) -> Result<V::Value> fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
@ -331,7 +353,7 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
where where
V: Visitor<'de> V: Visitor<'de>
{ {
unimplemented!() Err(Error::WontImplement)
} }
fn deserialize_identifier<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_identifier<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
@ -349,6 +371,35 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
} }
} }
struct Access<'a, 'de: 'a> {
de: &'a mut Deserializer<'de>,
}
impl <'a, 'de> Access<'a, 'de> {
fn new(de: &'a mut Deserializer<'de>) -> Self {
Access {
de,
}
}
}
impl<'a, 'de> SeqAccess<'de> for Access<'a, 'de> {
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> std::result::Result<Option<T::Value>, Self::Error>
where
T: DeserializeSeed<'de>
{
match self.de.peek_byte()? as char {
'd' | 'i' | 'l' | '0'..='9' => {
seed.deserialize(&mut *self.de).map(Some)
},
'e' => Ok(None),
_ => Err(Error::Syntax)
}
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::serde_bencode::{Error, from_bytes}; use crate::serde_bencode::{Error, from_bytes};
@ -371,4 +422,10 @@ mod test {
assert_eq!(from_bytes(&"1:a".as_bytes()), Ok('a')); assert_eq!(from_bytes(&"1:a".as_bytes()), Ok('a'));
assert_eq!(from_bytes(&"4:💩".as_bytes()), Ok('💩')); assert_eq!(from_bytes(&"4:💩".as_bytes()), Ok('💩'));
} }
#[test]
fn test_list() {
assert_eq!(from_bytes(&"li42ei13ee".as_bytes()), Ok(vec![42, 13]));
assert_eq!(from_bytes(&"l3:foo4:barie".as_bytes()), Ok(vec!["foo", "bari"]));
}
} }