From 330f1c74494c5c96c184a5d15e4dfdbabaf23888 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tristram=20Gr=C3=A4bener?= Date: Wed, 29 Dec 2021 10:51:55 +0100 Subject: [PATCH 1/6] Make a loop more rusty --- src/gtfs_reader.rs | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/gtfs_reader.rs b/src/gtfs_reader.rs index e745669..34699cd 100644 --- a/src/gtfs_reader.rs +++ b/src/gtfs_reader.rs @@ -299,25 +299,26 @@ where })? .clone(); - let mut res = Vec::new(); - for rec in reader.records() { - let r = rec.map_err(|e| Error::CSVError { - file_name: file_name.to_owned(), - source: e, - line_in_error: None, - })?; - let o = r.deserialize(Some(&headers)).map_err(|e| Error::CSVError { - file_name: file_name.to_owned(), - source: e, - line_in_error: Some(crate::error::LineError { - headers: headers.into_iter().map(|s| s.to_owned()).collect(), - values: r.into_iter().map(|s| s.to_owned()).collect(), - }), - })?; - res.push(o); - } - - Ok(res) + reader + .records() + .map(|rec| { + rec.map_err(|e| Error::CSVError { + file_name: file_name.to_owned(), + source: e, + line_in_error: None, + }) + .and_then(|r| { + r.deserialize(Some(&headers)).map_err(|e| Error::CSVError { + file_name: file_name.to_owned(), + source: e, + line_in_error: Some(crate::error::LineError { + headers: headers.into_iter().map(|s| s.to_owned()).collect(), + values: r.into_iter().map(|s| s.to_owned()).collect(), + }), + }) + }) + }) + .collect() } fn read_objs_from_path(path: std::path::PathBuf) -> Result, Error> From 2c5f163348843d8b7472563640232e5fdea141b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tristram=20Gr=C3=A4bener?= Date: Wed, 29 Dec 2021 12:38:19 +0100 Subject: [PATCH 2/6] Try to use rayon to speed up the parsing --- Cargo.toml | 1 + src/gtfs_reader.rs | 33 ++++++++++++++++++--------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fc57eec..532db31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ sha2 = "0.10" zip = "0.5" thiserror = "1" rgb = "0.8" +rayon = "1.5" futures = { version = "0.3", optional = true } reqwest = { version = "0.11", optional = true, features = ["blocking"]} diff --git a/src/gtfs_reader.rs b/src/gtfs_reader.rs index 34699cd..d3bf14e 100644 --- a/src/gtfs_reader.rs +++ b/src/gtfs_reader.rs @@ -3,6 +3,7 @@ use serde::Deserialize; use sha2::{Digest, Sha256}; use crate::{Error, Gtfs, RawGtfs}; +use rayon::prelude::*; use std::collections::HashMap; use std::convert::TryFrom; use std::fs::File; @@ -268,7 +269,7 @@ impl RawGtfsReader { fn read_objs(mut reader: T, file_name: &str) -> Result, Error> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, T: std::io::Read, { let mut bom = [0; 3]; @@ -299,7 +300,7 @@ where })? .clone(); - reader + let v = reader .records() .map(|rec| { rec.map_err(|e| Error::CSVError { @@ -307,15 +308,17 @@ where source: e, line_in_error: None, }) - .and_then(|r| { - r.deserialize(Some(&headers)).map_err(|e| Error::CSVError { - file_name: file_name.to_owned(), - source: e, - line_in_error: Some(crate::error::LineError { - headers: headers.into_iter().map(|s| s.to_owned()).collect(), - values: r.into_iter().map(|s| s.to_owned()).collect(), - }), - }) + }) + .collect::, Error>>()?; + v.par_iter() + .map(|r| { + r.deserialize(Some(&headers)).map_err(|e| Error::CSVError { + file_name: file_name.to_owned(), + source: e, + line_in_error: Some(crate::error::LineError { + headers: headers.into_iter().map(|s| s.to_owned()).collect(), + values: r.into_iter().map(|s| s.to_owned()).collect(), + }), }) }) .collect() @@ -323,7 +326,7 @@ where fn read_objs_from_path(path: std::path::PathBuf) -> Result, Error> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, { let file_name = path .file_name() @@ -347,7 +350,7 @@ fn read_objs_from_optional_path( file_name: &str, ) -> Option, Error>> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, { File::open(dir_path.join(file_name)) .ok() @@ -360,7 +363,7 @@ fn read_file( file_name: &str, ) -> Result, Error> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, T: std::io::Read + std::io::Seek, { read_optional_file(file_mapping, archive, file_name) @@ -373,7 +376,7 @@ fn read_optional_file( file_name: &str, ) -> Option, Error>> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, T: std::io::Read + std::io::Seek, { file_mapping.get(&file_name).map(|i| { From ceeb9e2cdebb1904b8e387dfc81448fb67cc794c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tristram=20Gr=C3=A4bener?= Date: Thu, 30 Dec 2021 13:14:00 +0100 Subject: [PATCH 3/6] No need to trim in serde_helpers It is the CSV library that does the trimming --- src/serde_helpers.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/serde_helpers.rs b/src/serde_helpers.rs index d1720e6..ac34509 100644 --- a/src/serde_helpers.rs +++ b/src/serde_helpers.rs @@ -49,7 +49,7 @@ pub fn parse_time_impl(v: Vec<&str>) -> Result { } pub fn parse_time(s: &str) -> Result { - let v: Vec<&str> = s.trim_start().split(':').collect(); + let v: Vec<&str> = s.split(':').collect(); if v.len() != 3 { Err(crate::Error::InvalidTime(s.to_owned())) } else { @@ -107,7 +107,6 @@ where D: Deserializer<'de>, { String::deserialize(de).and_then(|s| { - let s = s.trim(); if s.is_empty() { Ok(None) } else { @@ -134,11 +133,10 @@ where D: Deserializer<'de>, { String::deserialize(de).and_then(|s| { - let s = s.trim(); if s.is_empty() { Ok(None) } else { - parse_color(s).map(Some).map_err(de::Error::custom) + parse_color(&s).map(Some).map_err(de::Error::custom) } }) } From 177b576c57ca9c4df7e67ef1af792d50e9241ad4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tristram=20Gr=C3=A4bener?= Date: Thu, 30 Dec 2021 16:37:01 +0100 Subject: [PATCH 4/6] Use nom to parse datetime --- Cargo.toml | 1 + src/serde_helpers.rs | 27 +++++++++++++++------------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 532db31..5b66276 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ zip = "0.5" thiserror = "1" rgb = "0.8" rayon = "1.5" +nom = "7.1" futures = { version = "0.3", optional = true } reqwest = { version = "0.11", optional = true, features = ["blocking"]} diff --git a/src/serde_helpers.rs b/src/serde_helpers.rs index ac34509..fb143b3 100644 --- a/src/serde_helpers.rs +++ b/src/serde_helpers.rs @@ -1,4 +1,8 @@ use chrono::NaiveDate; +use nom::{ + character::complete::{char, digit1}, + sequence::tuple, +}; use rgb::RGB8; use serde::de::{self, Deserialize, Deserializer}; use serde::ser::Serializer; @@ -41,20 +45,19 @@ where } } -pub fn parse_time_impl(v: Vec<&str>) -> Result { - let hours: u32 = v[0].parse()?; - let minutes: u32 = v[1].parse()?; - let seconds: u32 = v[2].parse()?; - Ok(hours * 3600 + minutes * 60 + seconds) +fn parse_time_impl(h: &str, m: &str, s: &str) -> Result { + Ok(h.parse::()? * 3600 + m.parse::()? * 60 + s.parse::()?) } -pub fn parse_time(s: &str) -> Result { - let v: Vec<&str> = s.split(':').collect(); - if v.len() != 3 { - Err(crate::Error::InvalidTime(s.to_owned())) - } else { - parse_time_impl(v).map_err(|_| crate::Error::InvalidTime(s.to_owned())) - } +fn parse_time(s: &str) -> Result { + // Parsing the times in stop_times.txt is a significant bottleneck + // Using nom to parse the times result in an improvement of about 3% in performance + let mut parser = tuple::<&str, _, (_, _), _>((digit1, char(':'), digit1, char(':'), digit1)); + parser(s) + .map_err(|_| crate::Error::InvalidTime(s.to_owned())) + .and_then(|(_, (h, _, m, _, s))| { + parse_time_impl(h, m, s).map_err(|_| crate::Error::InvalidTime(s.to_owned())) + }) } pub fn deserialize_time<'de, D>(deserializer: D) -> Result From bb33dccca587387d91dacf182c716cf2c1f6d76c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tristram=20Gr=C3=A4bener?= Date: Fri, 31 Dec 2021 10:24:43 +0100 Subject: [PATCH 5/6] Remove useless &mut --- src/gtfs.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gtfs.rs b/src/gtfs.rs index 3dea6ba..6555cef 100644 --- a/src/gtfs.rs +++ b/src/gtfs.rs @@ -279,7 +279,7 @@ fn create_trips( frequencies: vec![], })); for s in raw_stop_times { - let trip = &mut trips + let trip = trips .get_mut(&s.trip_id) .ok_or_else(|| Error::ReferenceError(s.trip_id.to_string()))?; let stop = stops @@ -288,7 +288,7 @@ fn create_trips( trip.stop_times.push(StopTime::from(&s, Arc::clone(stop))); } - for trip in &mut trips.values_mut() { + for trip in trips.values_mut() { trip.stop_times .sort_by(|a, b| a.stop_sequence.cmp(&b.stop_sequence)); } From b6bb77079ee7cd3e166d660022078c9142cf8152 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tristram=20Gr=C3=A4bener?= Date: Sat, 1 Jan 2022 21:56:21 +0100 Subject: [PATCH 6/6] Use SmolStr for id This allows a small performance improvement (~1%) and should require less memory as most Id are commonly less than 22 chars long --- Cargo.toml | 1 + src/gtfs.rs | 43 ++++++++++++++++++++++--------------------- src/objects.rs | 35 ++++++++++++++++++----------------- 3 files changed, 41 insertions(+), 38 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5b66276..b319423 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ thiserror = "1" rgb = "0.8" rayon = "1.5" nom = "7.1" +smol_str = { version = "*", features = ["serde"] } futures = { version = "0.3", optional = true } reqwest = { version = "0.11", optional = true, features = ["blocking"]} diff --git a/src/gtfs.rs b/src/gtfs.rs index 6555cef..79a8976 100644 --- a/src/gtfs.rs +++ b/src/gtfs.rs @@ -5,6 +5,7 @@ use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; use std::sync::Arc; +type Map = HashMap; /// Data structure with all the GTFS objects /// /// This structure is easier to use than the [RawGtfs] structure as some relationships are parsed to be easier to use. @@ -24,21 +25,21 @@ pub struct Gtfs { /// Time needed to read and parse the archive in milliseconds pub read_duration: i64, /// All Calendar by `service_id` - pub calendar: HashMap, + pub calendar: Map, /// All calendar dates grouped by service_id - pub calendar_dates: HashMap>, + pub calendar_dates: Map>, /// All stop by `stop_id`. Stops are in an [Arc] because they are also referenced by each [StopTime] - pub stops: HashMap>, + pub stops: Map>, /// All routes by `route_id` - pub routes: HashMap, + pub routes: Map, /// All trips by `trip_id` - pub trips: HashMap, + pub trips: Map, /// All agencies. They can not be read by `agency_id`, as it is not a required field pub agencies: Vec, /// All shapes by shape_id - pub shapes: HashMap>, + pub shapes: Map>, /// All fare attributes by `fare_id` - pub fare_attributes: HashMap, + pub fare_attributes: Map, /// All feed information. There is no identifier pub feed_info: Vec, } @@ -221,24 +222,24 @@ impl Gtfs { } } -fn to_map(elements: impl IntoIterator) -> HashMap { +fn to_map(elements: impl IntoIterator) -> Map { elements .into_iter() - .map(|e| (e.id().to_owned(), e)) + .map(|e| (smol_str::SmolStr::new(e.id()), e)) .collect() } -fn to_stop_map(stops: Vec) -> HashMap> { +fn to_stop_map(stops: Vec) -> Map> { stops .into_iter() .map(|s| (s.id.clone(), Arc::new(s))) .collect() } -fn to_shape_map(shapes: Vec) -> HashMap> { - let mut res = HashMap::default(); +fn to_shape_map(shapes: Vec) -> Map> { + let mut res = Map::default(); for s in shapes { - let shape = res.entry(s.id.to_owned()).or_insert_with(Vec::new); + let shape = res.entry(s.id.clone()).or_insert_with(Vec::new); shape.push(s); } // we sort the shape by it's pt_sequence @@ -249,10 +250,10 @@ fn to_shape_map(shapes: Vec) -> HashMap> { res } -fn to_calendar_dates(cd: Vec) -> HashMap> { - let mut res = HashMap::default(); +fn to_calendar_dates(cd: Vec) -> Map> { + let mut res = Map::default(); for c in cd { - let cal = res.entry(c.service_id.to_owned()).or_insert_with(Vec::new); + let cal = res.entry(c.service_id.clone()).or_insert_with(Vec::new); cal.push(c); } res @@ -262,8 +263,8 @@ fn create_trips( raw_trips: Vec, raw_stop_times: Vec, raw_frequencies: Vec, - stops: &HashMap>, -) -> Result, Error> { + stops: &Map>, +) -> Result, Error> { let mut trips = to_map(raw_trips.into_iter().map(|rt| Trip { id: rt.id, service_id: rt.service_id, @@ -280,10 +281,10 @@ fn create_trips( })); for s in raw_stop_times { let trip = trips - .get_mut(&s.trip_id) + .get_mut(s.trip_id.as_str()) .ok_or_else(|| Error::ReferenceError(s.trip_id.to_string()))?; let stop = stops - .get(&s.stop_id) + .get(s.stop_id.as_str()) .ok_or_else(|| Error::ReferenceError(s.stop_id.to_string()))?; trip.stop_times.push(StopTime::from(&s, Arc::clone(stop))); } @@ -295,7 +296,7 @@ fn create_trips( for f in raw_frequencies { let trip = &mut trips - .get_mut(&f.trip_id) + .get_mut(f.trip_id.as_str()) .ok_or_else(|| Error::ReferenceError(f.trip_id.to_string()))?; trip.frequencies.push(Frequency::from(&f)); } diff --git a/src/objects.rs b/src/objects.rs index 7cd37a4..58a75df 100644 --- a/src/objects.rs +++ b/src/objects.rs @@ -3,6 +3,7 @@ use crate::serde_helpers::*; use chrono::{Datelike, NaiveDate, Weekday}; use rgb::RGB8; +use smol_str::SmolStr; use std::fmt; use std::sync::Arc; @@ -25,7 +26,7 @@ pub trait Type { pub struct Calendar { /// Unique technical identifier (not for the traveller) of this calendar #[serde(rename = "service_id")] - pub id: String, + pub id: SmolStr, /// Does the service run on mondays #[serde( deserialize_with = "deserialize_bool", @@ -119,7 +120,7 @@ impl Calendar { #[derive(Debug, Deserialize, Serialize)] pub struct CalendarDate { /// Identifier of the service that is modified at this date - pub service_id: String, + pub service_id: SmolStr, #[serde( deserialize_with = "deserialize_date", serialize_with = "serialize_date" @@ -135,7 +136,7 @@ pub struct CalendarDate { pub struct Stop { /// Unique technical identifier (not for the traveller) of the stop #[serde(rename = "stop_id")] - pub id: String, + pub id: SmolStr, /// Short text or a number that identifies the location for riders #[serde(rename = "stop_code")] pub code: Option, @@ -197,7 +198,7 @@ impl fmt::Display for Stop { #[derive(Debug, Serialize, Deserialize, Default)] pub struct RawStopTime { /// [Trip] to which this stop time belongs to - pub trip_id: String, + pub trip_id: smol_str::SmolStr, /// Arrival time of the stop time. /// It's an option since the intermediate stops can have have no arrival /// and this arrival needs to be interpolated @@ -215,7 +216,7 @@ pub struct RawStopTime { )] pub departure_time: Option, /// Identifier of the [Stop] where the vehicle stops - pub stop_id: String, + pub stop_id: smol_str::SmolStr, /// Order of stops for a particular trip. The values must increase along the trip but do not need to be consecutive pub stop_sequence: u16, /// Text that appears on signage identifying the trip's destination to riders @@ -294,7 +295,7 @@ impl StopTime { pub struct Route { /// Unique technical (not for the traveller) identifier for the route #[serde(rename = "route_id")] - pub id: String, + pub id: SmolStr, /// Short name of a route. This will often be a short, abstract identifier like "32", "100X", or "Green" that riders use to identify a route, but which doesn't give any indication of what places the route serves #[serde(rename = "route_short_name")] pub short_name: String, @@ -363,13 +364,13 @@ impl fmt::Display for Route { pub struct RawTrip { /// Unique technical (not for the traveller) identifier for the Trip #[serde(rename = "trip_id")] - pub id: String, + pub id: SmolStr, /// References the [Calendar] on which this trip runs - pub service_id: String, + pub service_id: SmolStr, /// References along which [Route] this trip runs - pub route_id: String, + pub route_id: SmolStr, /// Shape of the trip - pub shape_id: Option, + pub shape_id: Option, /// Text that appears on signage identifying the trip's destination to riders pub trip_headsign: Option, /// Public facing text used to identify the trip to riders, for instance, to identify train numbers for commuter rail trips @@ -412,15 +413,15 @@ impl fmt::Display for RawTrip { #[derive(Debug, Default)] pub struct Trip { /// Unique technical identifier (not for the traveller) for the Trip - pub id: String, + pub id: SmolStr, /// References the [Calendar] on which this trip runs - pub service_id: String, + pub service_id: SmolStr, /// References along which [Route] this trip runs - pub route_id: String, + pub route_id: SmolStr, /// All the [StopTime] that define the trip pub stop_times: Vec, /// Text that appears on signage identifying the trip's destination to riders - pub shape_id: Option, + pub shape_id: Option, /// Text that appears on signage identifying the trip's destination to riders pub trip_headsign: Option, /// Public facing text used to identify the trip to riders, for instance, to identify train numbers for commuter rail trips @@ -514,7 +515,7 @@ impl fmt::Display for Agency { pub struct Shape { /// Unique technical (not for the traveller) identifier for the Shape #[serde(rename = "shape_id")] - pub id: String, + pub id: SmolStr, #[serde(rename = "shape_pt_lat", default)] /// Latitude of a shape point pub latitude: f64, @@ -546,7 +547,7 @@ impl Id for Shape { pub struct FareAttribute { /// Unique technical (not for the traveller) identifier for the FareAttribute #[serde(rename = "fare_id")] - pub id: String, + pub id: SmolStr, /// Fare price, in the unit specified by [FareAttribute::currency] pub price: String, /// Currency used to pay the fare. @@ -578,7 +579,7 @@ impl Type for FareAttribute { #[derive(Debug, Serialize, Deserialize, Default)] pub struct RawFrequency { /// References the [Trip] that uses frequency - pub trip_id: String, + pub trip_id: SmolStr, /// Time at which the first vehicle departs from the first stop of the trip #[serde( deserialize_with = "deserialize_time",