[serde_bencode] Implemented the rest of the fucking owl

This commit is contained in:
Fabian 2024-08-03 16:20:46 +02:00
parent 674f4a8af4
commit cf3fc81f06
3 changed files with 200 additions and 54 deletions

View File

@ -9,7 +9,7 @@ sha1 = "0.10.6"
reqwest = { version = "0.12.5", features = ["blocking"] } reqwest = { version = "0.12.5", features = ["blocking"] }
rand = "0.8.5" rand = "0.8.5"
urlencoding = "2.1.3" urlencoding = "2.1.3"
serde = "1.0.204" serde = { version = "1.0.204", features = ["derive"] }
[dev-dependencies] [dev-dependencies]
hex-literal = "0.4.1" hex-literal = "0.4.1"

View File

@ -1,3 +1,6 @@
// who do you think you are? my mom?!
#![allow(uncommon_codepoints)]
use std::{env, fs, path}; use std::{env, fs, path};
use std::collections::HashMap; use std::collections::HashMap;

View File

@ -1,8 +1,9 @@
use std::fmt; use std::fmt;
use std::fmt::Display; use std::fmt::Display;
use std::str::Utf8Error; use std::str::Utf8Error;
use serde::{de, Deserialize, ser}; use serde::{de, Deserialize, ser};
use serde::de::{DeserializeSeed, SeqAccess, Visitor}; use serde::de::{DeserializeSeed, MapAccess, SeqAccess, Visitor};
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@ -39,7 +40,7 @@ impl de::Error for Error {
} }
impl From<Utf8Error> for Error { impl From<Utf8Error> for Error {
fn from(value: Utf8Error) -> Self { fn from(_: Utf8Error) -> Self {
Error::InvalidUtf8 Error::InvalidUtf8
} }
} }
@ -70,15 +71,19 @@ impl std::error::Error for Error {}
pub struct Deserializer<'de> { pub struct Deserializer<'de> {
input: &'de [u8], input: &'de [u8],
𓁺: Option<u8>,
} }
impl<'de> Deserializer<'de> { impl<'de> Deserializer<'de> {
pub fn from_bytes(input: &'de [u8]) -> Self { pub fn from_bytes(input: &'de [u8]) -> Self {
Deserializer { input } Deserializer {
input,
𓁺: None,
}
} }
} }
pub fn from_bytes<'a, T>(b: &'a &[u8]) -> Result<T> pub fn from_bytes<'a, T>(b: &'a [u8]) -> Result<T>
where where
T: Deserialize<'a>, T: Deserialize<'a>,
{ {
@ -91,6 +96,13 @@ where
} }
} }
pub fn from_str<'a, T>(s: &'a str) -> Result<T>
where
T: Deserialize<'a>,
{
from_bytes(s.as_bytes())
}
impl<'de> Deserializer<'de> { impl<'de> Deserializer<'de> {
fn peek_byte(&mut self) -> Result<u8> { fn peek_byte(&mut self) -> Result<u8> {
if let Some(&result) = self.input.get(0) { if let Some(&result) = self.input.get(0) {
@ -144,7 +156,7 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
Err(Error::WontImplement) Err(Error::WontImplement)
} }
fn deserialize_bool<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_bool<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
@ -186,49 +198,53 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_i128(self.parse_int()?) visitor.visit_i128(self.parse_int()?)
} }
fn deserialize_u8<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>
{
if let Some(ch) = self.𓁺 {
visitor.visit_u8(ch)
} else {
Err(Error::WontImplement)
}
}
fn deserialize_u16<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
Err(Error::WontImplement) Err(Error::WontImplement)
} }
fn deserialize_u16<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_u32<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
Err(Error::WontImplement) Err(Error::WontImplement)
} }
fn deserialize_u32<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_u64<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
Err(Error::WontImplement) Err(Error::WontImplement)
} }
fn deserialize_u64<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_f32<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
Err(Error::WontImplement) Err(Error::WontImplement)
} }
fn deserialize_f32<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_f64<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
Err(Error::WontImplement) Err(Error::WontImplement)
} }
fn deserialize_f64<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>
{
Err(Error::WontImplement)
}
fn deserialize_char<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
@ -270,28 +286,28 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_borrowed_bytes(self.parse_byte_string()?) visitor.visit_borrowed_bytes(self.parse_byte_string()?)
} }
fn deserialize_byte_buf<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
visitor.visit_byte_buf(Vec::from(self.parse_byte_string()?)) visitor.visit_byte_buf(Vec::from(self.parse_byte_string()?))
} }
fn deserialize_option<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_option<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
Err(Error::WontImplement) Err(Error::WontImplement)
} }
fn deserialize_unit<V>(self, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_unit<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
Err(Error::WontImplement) Err(Error::WontImplement)
} }
fn deserialize_unit_struct<V>(self, _: &'static str, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_unit_struct<V>(self, _: &'static str, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
@ -312,7 +328,7 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
if self.next_byte()? as char != 'l'{ if self.next_byte()? as char != 'l'{
Err(Error::ExpectedList) Err(Error::ExpectedList)
} else { } else {
let value = visitor.visit_seq(Access::new(self))?; let value = visitor.visit_seq(RegularAccess::new(self))?;
if self.next_byte()? as char != 'e' { if self.next_byte()? as char != 'e' {
Err(Error::ExpectedListEnd) Err(Error::ExpectedListEnd)
} else { } else {
@ -321,14 +337,22 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
} }
} }
fn deserialize_tuple<V>(self, _: usize, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
Err(Error::WontImplement) match self.peek_byte()? as char{
'l' => self.deserialize_seq(visitor),
'0'..='9' => {
let str = self.parse_byte_string()?;
let ba = ByteAccess::new(self, str);
visitor.visit_seq(ba)
},
_ => Err(Error::Syntax),
}
} }
fn deserialize_tuple_struct<V>(self, _: &'static str, _: usize, _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_tuple_struct<V>(self, _: &'static str, _: usize, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
@ -339,54 +363,101 @@ impl <'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() if self.next_byte()? as char != 'd'{
Err(Error::ExpectedDict)
} else {
let value = visitor.visit_map(RegularAccess::new(self))?;
if self.next_byte()? as char != 'e' {
Err(Error::ExpectedDictEnd)
} else {
Ok(value)
}
}
} }
fn deserialize_struct<V>(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> Result<V::Value> fn deserialize_struct<V>(self, _name: &'static str, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() self.deserialize_map(visitor)
} }
fn deserialize_enum<V>(self, _: &'static str, _: &'static [&'static str], _: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_enum<V>(self, _: &'static str, _: &'static [&'static str], _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
Err(Error::WontImplement) Err(Error::WontImplement)
} }
fn deserialize_identifier<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() self.deserialize_bytes(visitor)
} }
fn deserialize_ignored_any<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error> fn deserialize_ignored_any<V>(self, _: V) -> Result<V::Value>
where where
V: Visitor<'de> V: Visitor<'de>
{ {
todo!() println!("i really don't get this API!");
Err(Error::WontImplement)
} }
} }
struct Access<'a, 'de: 'a> { struct ByteAccess<'a, 'de: 'a> {
de: &'a mut Deserializer<'de>,
bytes: &'a [u8],
}
impl <'a, 'de> ByteAccess<'a, 'de> {
fn new(de: &'a mut Deserializer<'de>, bytes: &'a [u8]) -> Self {
ByteAccess {
de,
bytes,
}
}
}
impl<'a, 'de> SeqAccess<'de> for ByteAccess<'a, 'de> {
type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where
T: DeserializeSeed<'de>
{
if let Some(&ch) = self.bytes.get(0) {
self.de.𓁺 = Some(ch);
self.bytes = &self.bytes[1..];
seed.deserialize(&mut *self.de).map(Some)
} else {
self.de.𓁺 = None;
Ok(None)
}
}
fn size_hint(&self) -> Option<usize> {
Some(self.bytes.len())
}
}
struct RegularAccess<'a, 'de: 'a> {
de: &'a mut Deserializer<'de>, de: &'a mut Deserializer<'de>,
} }
impl <'a, 'de> Access<'a, 'de> { impl <'a, 'de> RegularAccess<'a, 'de> {
fn new(de: &'a mut Deserializer<'de>) -> Self { fn new(de: &'a mut Deserializer<'de>) -> Self {
Access { RegularAccess {
de, de,
} }
} }
} }
impl<'a, 'de> SeqAccess<'de> for Access<'a, 'de> { impl<'a, 'de> SeqAccess<'de> for RegularAccess<'a, 'de> {
type Error = Error; type Error = Error;
fn next_element_seed<T>(&mut self, seed: T) -> std::result::Result<Option<T::Value>, Self::Error> fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
where where
T: DeserializeSeed<'de> T: DeserializeSeed<'de>
{ {
@ -400,32 +471,104 @@ impl<'a, 'de> SeqAccess<'de> for Access<'a, 'de> {
} }
} }
impl<'a, 'de> MapAccess<'de> for RegularAccess<'a, 'de> {
type Error = Error;
fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
where
K: DeserializeSeed<'de>
{
match self.de.peek_byte()? as char {
'0'..='9' => {
seed.deserialize(&mut *self.de).map(Some)
},
'e' => Ok(None),
_ => Err(Error::ExpectedBytes)
}
}
fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
where
V: DeserializeSeed<'de>
{
match self.de.peek_byte()? as char {
'd' | 'i' | 'l' | '0'..='9' => {
seed.deserialize(&mut *self.de)
},
_ => Err(Error::Syntax)
}
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::serde_bencode::{Error, from_bytes}; use std::collections::HashMap;
use serde::Deserialize;
use crate::serde_bencode::{Error, from_bytes, from_str};
#[test] #[test]
fn test_int() { fn test_int() {
assert_eq!(from_bytes(&"i42e".as_bytes()), Ok(42)); assert_eq!(from_str("i42e"), Ok(42));
assert_eq!(from_bytes::<i32>(&"i42".as_bytes()), Err(Error::ExpectedIntegerEnd)); assert_eq!(from_str::<i32>("i42"), Err(Error::ExpectedIntegerEnd));
assert_eq!(from_bytes::<i32>(&"42e".as_bytes()), Err(Error::Syntax)); assert_eq!(from_str::<i32>("42e"), Err(Error::ExpectedInteger));
// 2**16 + 1 = 65537 // 2**16 + 1 = 65537
assert_eq!(from_bytes(&"i65537e".as_bytes()), Ok(65537)); assert_eq!(from_str("i65537e"), Ok(65537));
} }
#[test] #[test]
fn test_str() { fn test_str() {
assert_eq!(from_bytes(&"3:foo".as_bytes()), Ok("foo")); assert_eq!(from_str("3:foo"), Ok("foo"));
assert_eq!(from_bytes(&"3:bar".as_bytes()), Ok("bar")); assert_eq!(from_str("3:bar"), Ok("bar"));
assert_eq!(from_bytes(&"3:bar".as_bytes()), Ok("bar".to_string())); assert_eq!(from_str("3:bar"), Ok("bar".to_string()));
assert_eq!(from_bytes(&"3:foo".as_bytes()), Ok("foo".as_bytes())); assert_eq!(from_str("3:foo"), Ok("foo"));
assert_eq!(from_bytes(&"1:a".as_bytes()), Ok('a')); assert_eq!(from_str("1:a"), Ok('a'));
assert_eq!(from_bytes(&"4:💩".as_bytes()), Ok('💩')); assert_eq!(from_str("4:💩"), Ok('💩'));
} }
#[test] #[test]
fn test_list() { fn test_seq() {
assert_eq!(from_bytes(&"li42ei13ee".as_bytes()), Ok(vec![42, 13])); assert_eq!(from_str("li42ei13ee"), Ok(vec![42, 13]));
assert_eq!(from_bytes(&"l3:foo4:barie".as_bytes()), Ok(vec!["foo", "bari"])); assert_eq!(from_str("l3:foo4:barie"), Ok(vec!["foo", "bari"]));
assert_eq!(from_str("lli42ei13eelee"), Ok(vec![vec![42, 13], vec![]]));
}
#[test]
fn test_map() {
assert_eq!(from_str("d3:fooi42ee"), Ok(HashMap::from([
("foo", 42)
])));
assert_eq!(from_str("d3:foo3:bare"), Ok(HashMap::from([
("foo", "bar")
])));
}
#[test]
fn test_struct() {
#[derive(Deserialize, Debug)]
struct A {
a: i64,
}
let de: A = from_str("d1:ai42ee").unwrap();
assert_eq!(de.a, 42);
#[derive(Deserialize, Debug)]
struct B {
l: Vec<i64>,
d: HashMap<String, [u8; 3]>
}
let mut buf = "d1:lli42ei13ei23ee1:dd3:foo3:".as_bytes().to_vec();
buf.extend(vec![0x1, 0x2, 0x3]);
buf.extend("3:bar3:".as_bytes());
buf.extend(vec![0x19, 0x31, 0x17]);
buf.extend("ee".as_bytes());
let de: B = from_bytes(&buf[..]).unwrap();
assert_eq!(de.l, vec![42, 13, 23]);
assert_eq!(de.d, HashMap::from([
("foo".to_string(), [0x1, 0x2, 0x3]),
("bar".to_string(), [0x19, 0x31, 0x17]),
]));
} }
} }