diff --git a/src/bencode.rs b/src/bencode.rs index 1673fce..42e8c14 100644 --- a/src/bencode.rs +++ b/src/bencode.rs @@ -1,2 +1,4 @@ pub mod custom; pub mod de; +pub mod ser; +pub mod error; diff --git a/src/bencode/de.rs b/src/bencode/de.rs index 21cf60e..ad2f1fe 100644 --- a/src/bencode/de.rs +++ b/src/bencode/de.rs @@ -1,73 +1,6 @@ -use std::fmt; -use std::fmt::Display; -use std::str::Utf8Error; - -use serde::{de, Deserialize, ser}; +use serde::{de, Deserialize}; use serde::de::{DeserializeSeed, MapAccess, SeqAccess, Visitor}; - -pub type Result = std::result::Result; - -#[derive(Debug, PartialEq)] -pub enum Error { - Message(String), - - WontImplement, - Eof, - Syntax, - InvalidUtf8, - ExpectedBytes, - ExpectedBytesSep, - ExpectedNumbers, - ExpectedInteger, - ExpectedIntegerEnd, - ExpectedList, - ExpectedListEnd, - ExpectedDict, - ExpectedDictEnd, - TrailingCharacters, -} - -impl ser::Error for Error { - fn custom(msg: T) -> Self { - Error::Message(msg.to_string()) - } -} - -impl de::Error for Error { - fn custom(msg: T) -> Self { - Error::Message(msg.to_string()) - } -} - -impl From for Error { - fn from(_: Utf8Error) -> Self { - Error::InvalidUtf8 - } -} - -impl Display for Error { - fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - match self { - Error::Message(msg) => formatter.write_str(msg), - Error::WontImplement => formatter.write_str("there is no reasonable way to deserialize to this type"), - Error::Eof => formatter.write_str("unexpected end of input"), - Error::Syntax => formatter.write_str("syntax error"), - Error::InvalidUtf8 => formatter.write_str("could not decoded as UTF-8"), - Error::ExpectedBytes => formatter.write_str("expected byte string start char: any number"), - Error::ExpectedBytesSep => formatter.write_str("expected byte separator char: ':'"), - Error::ExpectedInteger => formatter.write_str("expected integer start char 'i'"), - Error::ExpectedNumbers => formatter.write_str("expected numbers"), - Error::ExpectedIntegerEnd => formatter.write_str("expected integer end char 'e'"), - Error::ExpectedList => formatter.write_str("expected list start char 'l'"), - Error::ExpectedListEnd => formatter.write_str("expected list end char 'e'"), - Error::ExpectedDict => formatter.write_str("expected dict start char 'd'"), - Error::ExpectedDictEnd => formatter.write_str("expected dict end char 'e'"), - Error::TrailingCharacters => formatter.write_str("trailing characters") - } - } -} - -impl std::error::Error for Error {} +use crate::bencode::error::{Error, Result}; pub struct Deserializer<'de> { input: &'de [u8], @@ -505,8 +438,8 @@ mod test { use std::collections::HashMap; use serde::Deserialize; - - use crate::de::{Error, from_bytes, from_str}; + use crate::bencode::de::{from_bytes, from_str}; + use crate::bencode::error::Error; #[test] fn test_int() { diff --git a/src/bencode/error.rs b/src/bencode/error.rs new file mode 100644 index 0000000..3d1ceb6 --- /dev/null +++ b/src/bencode/error.rs @@ -0,0 +1,68 @@ +use std::fmt::Display; +use std::fmt; +use std::str::Utf8Error; +use serde::{de, ser}; + +#[derive(Debug, PartialEq)] +pub enum Error { + Message(String), + + WontImplement, + Eof, + Syntax, + InvalidUtf8, + ExpectedBytes, + ExpectedBytesSep, + ExpectedNumbers, + ExpectedInteger, + ExpectedIntegerEnd, + ExpectedList, + ExpectedListEnd, + ExpectedDict, + ExpectedDictEnd, + TrailingCharacters, +} + +impl ser::Error for Error { + fn custom(msg: T) -> Self { + Error::Message(msg.to_string()) + } +} + +impl de::Error for Error { + fn custom(msg: T) -> Self { + Error::Message(msg.to_string()) + } +} + +impl From for Error { + fn from(_: Utf8Error) -> Self { + Error::InvalidUtf8 + } +} + +impl Display for Error { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Message(msg) => formatter.write_str(msg), + Error::WontImplement => formatter.write_str("there is no reasonable way to (de)serialize this type"), + Error::Eof => formatter.write_str("unexpected end of input"), + Error::Syntax => formatter.write_str("syntax error"), + Error::InvalidUtf8 => formatter.write_str("could not decoded as UTF-8"), + Error::ExpectedBytes => formatter.write_str("expected byte string start char: any number"), + Error::ExpectedBytesSep => formatter.write_str("expected byte separator char: ':'"), + Error::ExpectedInteger => formatter.write_str("expected integer start char 'i'"), + Error::ExpectedNumbers => formatter.write_str("expected numbers"), + Error::ExpectedIntegerEnd => formatter.write_str("expected integer end char 'e'"), + Error::ExpectedList => formatter.write_str("expected list start char 'l'"), + Error::ExpectedListEnd => formatter.write_str("expected list end char 'e'"), + Error::ExpectedDict => formatter.write_str("expected dict start char 'd'"), + Error::ExpectedDictEnd => formatter.write_str("expected dict end char 'e'"), + Error::TrailingCharacters => formatter.write_str("trailing characters") + } + } +} + +impl std::error::Error for Error {} + +pub type Result = std::result::Result; diff --git a/src/bencode/ser.rs b/src/bencode/ser.rs new file mode 100644 index 0000000..fc4a8cc --- /dev/null +++ b/src/bencode/ser.rs @@ -0,0 +1,361 @@ +use std::collections::BTreeMap; + +use serde::{ser, Serialize}; + +use crate::bencode::error::{Error, Result}; + +pub struct Serializer { + output: Vec, +} + +pub fn to_bytes(value: &T) -> Result> +where + T: Serialize + ?Sized +{ + let mut serializer = Serializer{ + output: Vec::new(), + }; + value.serialize(&mut serializer)?; + Ok(serializer.output) +} + +pub fn to_string(value: &T) -> Result +where + T: Serialize +{ + let bytes = to_bytes(value)?; + Ok(std::str::from_utf8(&bytes[..])?.to_string()) +} + + + +impl<'a> ser::Serializer for &'a mut Serializer +{ + type Ok = (); + + type Error = Error; + + type SerializeSeq = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + type SerializeMap = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + + fn serialize_bool(self, v: bool) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_i8(self, v: i8) -> Result<()> { + self.output.extend(format!("i{v}e").into_bytes()); + Ok(()) + } + + fn serialize_i16(self, v: i16) -> Result<()> { + self.output.extend(format!("i{v}e").into_bytes()); + Ok(()) + } + + fn serialize_i32(self, v: i32) -> Result<()> { + self.output.extend(format!("i{v}e").into_bytes()); + Ok(()) + } + + fn serialize_i64(self, v: i64) -> Result<()> { + self.output.extend(format!("i{v}e").into_bytes()); + Ok(()) + } + + fn serialize_i128(self, v: i128) -> Result<()> { + self.output.extend(format!("i{v}e").into_bytes()); + Ok(()) + } + + fn serialize_u8(self, v: u8) -> Result<()> { + todo!() + } + + fn serialize_u16(self, _v: u16) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_u32(self, _v: u32) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_u64(self, _v: u64) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_u128(self, _v: u128) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_f32(self, _v: f32) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_f64(self, _v: f64) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_char(self, v: char) -> Result<()> { + self.output.extend(format!("{}:", v.len_utf8()).into_bytes()); + let mut b = [0; 4]; + v.encode_utf8(&mut b); + self.output.extend(&b[..v.len_utf8()]); + Ok(()) + } + + fn serialize_str(self, v: &str) -> Result<()> { + self.output.extend(format!("{}:", v.len()).into_bytes()); + self.output.extend(v.as_bytes()); + Ok(()) + } + + fn serialize_bytes(self, v: &[u8]) -> Result<()> { + self.output.extend(format!("{}:", v.len()).into_bytes()); + self.output.extend(v); + Ok(()) + } + + fn serialize_none(self) -> Result<()> { + todo!() + } + + fn serialize_some(self, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + todo!() + } + + fn serialize_unit(self) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_unit_variant(self, _name: &'static str, _variant_index: u32, _variant: &'static str) -> Result<()> { + Err(Error::WontImplement) + } + + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + value.serialize(self) + } + + fn serialize_newtype_variant(self, name: &'static str, variant_index: u32, variant: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + Err(Error::WontImplement) + } + + fn serialize_seq(self, _len: Option) -> Result { + self.output.push('l' as u8); + Ok(self) + } + + fn serialize_tuple(self, len: usize) -> Result { + println!("tuple"); + todo!() + } + + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + Err(Error::WontImplement) + } + + fn serialize_tuple_variant(self, _name: &'static str, _variant_index: u32, _variant: &'static str, _len: usize) -> Result { + Err(Error::WontImplement) + } + + fn serialize_map(self, _len: Option) -> Result { + self.output.push('d' as u8); + Ok(self) + } + + fn serialize_struct(self, name: &'static str, len: usize) -> Result { + todo!() + } + + fn serialize_struct_variant(self, _name: &'static str, _variant_index: u32, _variant: &'static str, _len: usize) -> Result { + Err(Error::WontImplement) + } +} + +impl <'a> ser::SerializeSeq for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + value.serialize(&mut **self) + } + + fn end(self) -> Result<()> { + self.output.push('e' as u8); + Ok(()) + } +} + +impl <'a> ser::SerializeTuple for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + todo!() + } + + fn end(self) -> Result<()> { + todo!() + } +} + +impl <'a> ser::SerializeTupleStruct for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + Err(Error::WontImplement) + } + + fn end(self) -> Result<()> { + Err(Error::WontImplement) + } +} + +impl <'a> ser::SerializeTupleVariant for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + Err(Error::WontImplement) + } + + fn end(self) -> Result<()> { + Err(Error::WontImplement) + } +} + +impl <'a> ser::SerializeMap for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_key(&mut self, key: &T) -> Result<()> + where + T: ?Sized + Serialize + { + key.serialize(&mut **self) + } + + fn serialize_value(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + value.serialize(&mut **self) + } + + fn end(self) -> Result<()> { + // TODO serialize only here and sort first + self.output.push('e' as u8); + Ok(()) + } +} + +impl <'a> ser::SerializeStruct for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + todo!() + } + + fn end(self) -> Result<()> { + self.output.push('e' as u8); + Ok(()) + } +} + +impl <'a> ser::SerializeStructVariant for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize + { + Err(Error::WontImplement) + } + + fn end(self) -> Result<()> { + Err(Error::WontImplement) + } +} + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use crate::bencode::ser::to_string; + + #[test] + fn test_int() { + assert_eq!(to_string(&42).unwrap(), "i42e"); + assert_eq!(to_string(&13).unwrap(), "i13e"); + } + + #[test] + fn test_char() { + assert_eq!(to_string(&'b').unwrap(), "1:b"); + assert_eq!(to_string(&'💩').unwrap(), "4:💩"); + } + + #[test] + fn test_str() { + assert_eq!(to_string(&"foo").unwrap(), "3:foo"); + assert_eq!(to_string(&"bar").unwrap(), "3:bar"); + assert_eq!(to_string(&"foo💩").unwrap(), "7:foo💩"); + } + + #[test] + fn test_seq() { + assert_eq!(to_string(&vec![42, 13, 7]).unwrap(), "li42ei13ei7ee"); + assert_eq!(to_string(&vec!["foo", "bar", "💩"]).unwrap(), "l3:foo3:bar4:💩e"); + } + + #[test] + fn test_map() { + assert_eq!(to_string(&HashMap::from([ + ("foo", 42), + ("bar", 13), + ])).unwrap(), + "d3:bari13e3:fooi42ee" + ); + assert_eq!(to_string(&HashMap::from([ + ("foo", "💩"), + ("bar", "🙈"), + ])).unwrap(), + "d3:bar4:🙈3:foo4:💩e" + ); + } +}