From 40a298f3cc01f4f5fb3b357dd593d26ac717ca5b Mon Sep 17 00:00:00 2001 From: Faerbit Date: Sun, 11 Aug 2024 23:55:06 +0200 Subject: [PATCH] [torrent] Moved torrent structures to own module and added correctness test --- src/main.rs | 97 +---------------------------------------- src/torrent.rs | 114 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 95 deletions(-) create mode 100644 src/torrent.rs diff --git a/src/main.rs b/src/main.rs index cf17f66..348eb58 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,110 +1,17 @@ // who do you think you are? my mom?! #![allow(uncommon_codepoints)] -use std::{env, fs, path}; +use std::{env, fs}; -use crate::bencode::de::from_bytes; -use crate::bencode::ser::to_bytes; use anyhow::{anyhow, Result}; use rand::prelude::*; use reqwest::blocking::Client; use reqwest::Url; -use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; -use serde::ser::SerializeSeq; -use serde_bytes::ByteBuf; -use sha1::{Digest, Sha1}; +use torrent::Torrent; mod bencode; mod torrent; -#[derive(Serialize, Deserialize, Debug)] -struct FileInfo { - length: i64, - #[serde(deserialize_with = "deserialize_path", serialize_with = "serialize_path")] - path: String, -} - -fn deserialize_path<'de, D>(deserializer: D) -> Result -where D: Deserializer<'de> { - let parts: Vec = Vec::deserialize(deserializer)?; - for p in parts.iter() { - let res = p.find(path::MAIN_SEPARATOR_STR); - if let Some(_) = res { - return Err(de::Error::custom(format!("Unable to deal with {} (platform path separator in path parts", path::MAIN_SEPARATOR_STR))) - } - } - Ok(parts.join(path::MAIN_SEPARATOR_STR)) -} - -fn serialize_path(v: &String, serializer: S) -> Result -where S: Serializer { - let parts: Vec<&str> = v.split(path::MAIN_SEPARATOR_STR).collect(); - let mut seq = serializer.serialize_seq(Some(parts.len()))?; - for p in parts { - seq.serialize_element(p)?; - } - seq.end() -} - -#[derive(Serialize, Deserialize, Debug)] -struct TorrentInfo { - files: Option>, - length: Option, - name: String, - #[serde(rename(serialize = "piece length", deserialize = "piece length"))] - piece_length: i64, - #[serde(deserialize_with = "deserialize_pieces", serialize_with = "serialize_pieces")] - pieces: Vec<[u8; 20]>, -} - -fn deserialize_pieces<'de, D>(deserializer: D) -> Result, D::Error> -where D: Deserializer<'de> { - let all_pieces = ByteBuf::deserialize(deserializer)?.into_vec(); - if all_pieces.len() % 20 != 0 { - return Err(de::Error::custom("Pieces string length not a multiple of 20")) - } - - // TODO maybe handle the error, even though it should be checked above - let pieces = all_pieces.chunks(20).map(|x| x.try_into().unwrap()).collect(); - Ok(pieces) -} - -fn serialize_pieces(v: &Vec<[u8; 20]>, serializer: S) -> Result -where S: Serializer -{ - let mut buf = Vec::new(); - for i in v { - buf.extend(i) - } - serializer.serialize_bytes(&buf[..]) -} - -#[derive(Serialize, Deserialize, Debug)] -struct Torrent { - announce: String, - info: TorrentInfo, - #[serde(skip)] - info_hash: [u8; 20], -} - -impl Torrent { - fn from(input: &[u8]) -> Result { - let mut torrent: Self = from_bytes(input)?; - torrent.compute_info_hash()?; - Ok(torrent) - } - - fn compute_info_hash(&mut self) -> Result<()> { - let mut hasher = Sha1::new(); - - let info_str = to_bytes(&self.info)?; - hasher.update(info_str); - - self.info_hash = hasher.finalize().into(); - Ok(()) - } -} - fn main() -> Result<()> { let args: Vec = env::args().collect(); if args.len() < 2 { diff --git a/src/torrent.rs b/src/torrent.rs new file mode 100644 index 0000000..f1084d3 --- /dev/null +++ b/src/torrent.rs @@ -0,0 +1,114 @@ +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use std::path; +use serde_bytes::ByteBuf; +use sha1::{Digest, Sha1}; +use serde::ser::SerializeSeq; +use crate::bencode::de::from_bytes; +use crate::bencode::ser::to_bytes; + +#[derive(Serialize, Deserialize, Debug)] +pub struct FileInfo { + pub length: i64, + #[serde(deserialize_with = "deserialize_path", serialize_with = "serialize_path")] + pub path: String, +} + +fn deserialize_path<'de, D>(deserializer: D) -> anyhow::Result +where D: Deserializer<'de> { + let parts: Vec = Vec::deserialize(deserializer)?; + for p in parts.iter() { + let res = p.find(path::MAIN_SEPARATOR_STR); + if let Some(_) = res { + return Err(de::Error::custom(format!("Unable to deal with {} (platform path separator in path parts", path::MAIN_SEPARATOR_STR))) + } + } + Ok(parts.join(path::MAIN_SEPARATOR_STR)) +} + +fn serialize_path(v: &String, serializer: S) -> anyhow::Result +where S: Serializer { + let parts: Vec<&str> = v.split(path::MAIN_SEPARATOR_STR).collect(); + let mut seq = serializer.serialize_seq(Some(parts.len()))?; + for p in parts { + seq.serialize_element(p)?; + } + seq.end() +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct TorrentInfo { + pub files: Option>, + pub length: Option, + pub name: String, + #[serde(rename(serialize = "piece length", deserialize = "piece length"))] + pub piece_length: i64, + #[serde(deserialize_with = "deserialize_pieces", serialize_with = "serialize_pieces")] + pub pieces: Vec<[u8; 20]>, +} + +fn deserialize_pieces<'de, D>(deserializer: D) -> anyhow::Result, D::Error> +where D: Deserializer<'de> { + let all_pieces = ByteBuf::deserialize(deserializer)?.into_vec(); + if all_pieces.len() % 20 != 0 { + return Err(de::Error::custom("Pieces string length not a multiple of 20")) + } + + // TODO maybe handle the error, even though it should be checked above + let pieces = all_pieces.chunks(20).map(|x| x.try_into().unwrap()).collect(); + Ok(pieces) +} + +fn serialize_pieces(v: &Vec<[u8; 20]>, serializer: S) -> anyhow::Result +where S: Serializer +{ + let mut buf = Vec::new(); + for i in v { + buf.extend(i) + } + serializer.serialize_bytes(&buf[..]) +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Torrent { + pub announce: String, + pub info: TorrentInfo, + #[serde(skip)] + pub info_hash: [u8; 20], +} + +impl Torrent { + pub fn from(input: &[u8]) -> anyhow::Result { + let mut torrent: Self = from_bytes(input)?; + torrent.compute_info_hash()?; + Ok(torrent) + } + + fn compute_info_hash(&mut self) -> anyhow::Result<()> { + let mut hasher = Sha1::new(); + + let info_str = to_bytes(&self.info)?; + hasher.update(info_str); + + self.info_hash = hasher.finalize().into(); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{env, fs}; + use std::path::Path; + use crate::torrent::Torrent; + + #[test] + fn check_correct_info_hash() { + let manifest_path = env::var_os("CARGO_MANIFEST_DIR").unwrap(); + let path = Path::new(&manifest_path).join("test/Fedora-Workstation-Live-x86_64-40.torrent"); + let bytes = fs::read(&path).unwrap(); + + let torrent = Torrent::from(&bytes).unwrap(); + + let info_hash_str = torrent.info_hash.iter().map(|x| {format!("{x:02x}")}).collect::>().join(""); + assert_eq!(info_hash_str, "1021075bad21641897c85f1a4369569d93315f63"); + } +}