diff --git a/Cargo.toml b/Cargo.toml index fc57eec..b319423 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,9 @@ sha2 = "0.10" zip = "0.5" thiserror = "1" rgb = "0.8" +rayon = "1.5" +nom = "7.1" +smol_str = { version = "*", features = ["serde"] } futures = { version = "0.3", optional = true } reqwest = { version = "0.11", optional = true, features = ["blocking"]} diff --git a/src/gtfs.rs b/src/gtfs.rs index 3dea6ba..79a8976 100644 --- a/src/gtfs.rs +++ b/src/gtfs.rs @@ -5,6 +5,7 @@ use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; use std::sync::Arc; +type Map = HashMap; /// Data structure with all the GTFS objects /// /// This structure is easier to use than the [RawGtfs] structure as some relationships are parsed to be easier to use. @@ -24,21 +25,21 @@ pub struct Gtfs { /// Time needed to read and parse the archive in milliseconds pub read_duration: i64, /// All Calendar by `service_id` - pub calendar: HashMap, + pub calendar: Map, /// All calendar dates grouped by service_id - pub calendar_dates: HashMap>, + pub calendar_dates: Map>, /// All stop by `stop_id`. Stops are in an [Arc] because they are also referenced by each [StopTime] - pub stops: HashMap>, + pub stops: Map>, /// All routes by `route_id` - pub routes: HashMap, + pub routes: Map, /// All trips by `trip_id` - pub trips: HashMap, + pub trips: Map, /// All agencies. They can not be read by `agency_id`, as it is not a required field pub agencies: Vec, /// All shapes by shape_id - pub shapes: HashMap>, + pub shapes: Map>, /// All fare attributes by `fare_id` - pub fare_attributes: HashMap, + pub fare_attributes: Map, /// All feed information. There is no identifier pub feed_info: Vec, } @@ -221,24 +222,24 @@ impl Gtfs { } } -fn to_map(elements: impl IntoIterator) -> HashMap { +fn to_map(elements: impl IntoIterator) -> Map { elements .into_iter() - .map(|e| (e.id().to_owned(), e)) + .map(|e| (smol_str::SmolStr::new(e.id()), e)) .collect() } -fn to_stop_map(stops: Vec) -> HashMap> { +fn to_stop_map(stops: Vec) -> Map> { stops .into_iter() .map(|s| (s.id.clone(), Arc::new(s))) .collect() } -fn to_shape_map(shapes: Vec) -> HashMap> { - let mut res = HashMap::default(); +fn to_shape_map(shapes: Vec) -> Map> { + let mut res = Map::default(); for s in shapes { - let shape = res.entry(s.id.to_owned()).or_insert_with(Vec::new); + let shape = res.entry(s.id.clone()).or_insert_with(Vec::new); shape.push(s); } // we sort the shape by it's pt_sequence @@ -249,10 +250,10 @@ fn to_shape_map(shapes: Vec) -> HashMap> { res } -fn to_calendar_dates(cd: Vec) -> HashMap> { - let mut res = HashMap::default(); +fn to_calendar_dates(cd: Vec) -> Map> { + let mut res = Map::default(); for c in cd { - let cal = res.entry(c.service_id.to_owned()).or_insert_with(Vec::new); + let cal = res.entry(c.service_id.clone()).or_insert_with(Vec::new); cal.push(c); } res @@ -262,8 +263,8 @@ fn create_trips( raw_trips: Vec, raw_stop_times: Vec, raw_frequencies: Vec, - stops: &HashMap>, -) -> Result, Error> { + stops: &Map>, +) -> Result, Error> { let mut trips = to_map(raw_trips.into_iter().map(|rt| Trip { id: rt.id, service_id: rt.service_id, @@ -279,23 +280,23 @@ fn create_trips( frequencies: vec![], })); for s in raw_stop_times { - let trip = &mut trips - .get_mut(&s.trip_id) + let trip = trips + .get_mut(s.trip_id.as_str()) .ok_or_else(|| Error::ReferenceError(s.trip_id.to_string()))?; let stop = stops - .get(&s.stop_id) + .get(s.stop_id.as_str()) .ok_or_else(|| Error::ReferenceError(s.stop_id.to_string()))?; trip.stop_times.push(StopTime::from(&s, Arc::clone(stop))); } - for trip in &mut trips.values_mut() { + for trip in trips.values_mut() { trip.stop_times .sort_by(|a, b| a.stop_sequence.cmp(&b.stop_sequence)); } for f in raw_frequencies { let trip = &mut trips - .get_mut(&f.trip_id) + .get_mut(f.trip_id.as_str()) .ok_or_else(|| Error::ReferenceError(f.trip_id.to_string()))?; trip.frequencies.push(Frequency::from(&f)); } diff --git a/src/gtfs_reader.rs b/src/gtfs_reader.rs index e745669..d3bf14e 100644 --- a/src/gtfs_reader.rs +++ b/src/gtfs_reader.rs @@ -3,6 +3,7 @@ use serde::Deserialize; use sha2::{Digest, Sha256}; use crate::{Error, Gtfs, RawGtfs}; +use rayon::prelude::*; use std::collections::HashMap; use std::convert::TryFrom; use std::fs::File; @@ -268,7 +269,7 @@ impl RawGtfsReader { fn read_objs(mut reader: T, file_name: &str) -> Result, Error> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, T: std::io::Read, { let mut bom = [0; 3]; @@ -299,30 +300,33 @@ where })? .clone(); - let mut res = Vec::new(); - for rec in reader.records() { - let r = rec.map_err(|e| Error::CSVError { - file_name: file_name.to_owned(), - source: e, - line_in_error: None, - })?; - let o = r.deserialize(Some(&headers)).map_err(|e| Error::CSVError { - file_name: file_name.to_owned(), - source: e, - line_in_error: Some(crate::error::LineError { - headers: headers.into_iter().map(|s| s.to_owned()).collect(), - values: r.into_iter().map(|s| s.to_owned()).collect(), - }), - })?; - res.push(o); - } - - Ok(res) + let v = reader + .records() + .map(|rec| { + rec.map_err(|e| Error::CSVError { + file_name: file_name.to_owned(), + source: e, + line_in_error: None, + }) + }) + .collect::, Error>>()?; + v.par_iter() + .map(|r| { + r.deserialize(Some(&headers)).map_err(|e| Error::CSVError { + file_name: file_name.to_owned(), + source: e, + line_in_error: Some(crate::error::LineError { + headers: headers.into_iter().map(|s| s.to_owned()).collect(), + values: r.into_iter().map(|s| s.to_owned()).collect(), + }), + }) + }) + .collect() } fn read_objs_from_path(path: std::path::PathBuf) -> Result, Error> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, { let file_name = path .file_name() @@ -346,7 +350,7 @@ fn read_objs_from_optional_path( file_name: &str, ) -> Option, Error>> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, { File::open(dir_path.join(file_name)) .ok() @@ -359,7 +363,7 @@ fn read_file( file_name: &str, ) -> Result, Error> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, T: std::io::Read + std::io::Seek, { read_optional_file(file_mapping, archive, file_name) @@ -372,7 +376,7 @@ fn read_optional_file( file_name: &str, ) -> Option, Error>> where - for<'de> O: Deserialize<'de>, + for<'de> O: Deserialize<'de> + Send, T: std::io::Read + std::io::Seek, { file_mapping.get(&file_name).map(|i| { diff --git a/src/objects.rs b/src/objects.rs index 7cd37a4..58a75df 100644 --- a/src/objects.rs +++ b/src/objects.rs @@ -3,6 +3,7 @@ use crate::serde_helpers::*; use chrono::{Datelike, NaiveDate, Weekday}; use rgb::RGB8; +use smol_str::SmolStr; use std::fmt; use std::sync::Arc; @@ -25,7 +26,7 @@ pub trait Type { pub struct Calendar { /// Unique technical identifier (not for the traveller) of this calendar #[serde(rename = "service_id")] - pub id: String, + pub id: SmolStr, /// Does the service run on mondays #[serde( deserialize_with = "deserialize_bool", @@ -119,7 +120,7 @@ impl Calendar { #[derive(Debug, Deserialize, Serialize)] pub struct CalendarDate { /// Identifier of the service that is modified at this date - pub service_id: String, + pub service_id: SmolStr, #[serde( deserialize_with = "deserialize_date", serialize_with = "serialize_date" @@ -135,7 +136,7 @@ pub struct CalendarDate { pub struct Stop { /// Unique technical identifier (not for the traveller) of the stop #[serde(rename = "stop_id")] - pub id: String, + pub id: SmolStr, /// Short text or a number that identifies the location for riders #[serde(rename = "stop_code")] pub code: Option, @@ -197,7 +198,7 @@ impl fmt::Display for Stop { #[derive(Debug, Serialize, Deserialize, Default)] pub struct RawStopTime { /// [Trip] to which this stop time belongs to - pub trip_id: String, + pub trip_id: smol_str::SmolStr, /// Arrival time of the stop time. /// It's an option since the intermediate stops can have have no arrival /// and this arrival needs to be interpolated @@ -215,7 +216,7 @@ pub struct RawStopTime { )] pub departure_time: Option, /// Identifier of the [Stop] where the vehicle stops - pub stop_id: String, + pub stop_id: smol_str::SmolStr, /// Order of stops for a particular trip. The values must increase along the trip but do not need to be consecutive pub stop_sequence: u16, /// Text that appears on signage identifying the trip's destination to riders @@ -294,7 +295,7 @@ impl StopTime { pub struct Route { /// Unique technical (not for the traveller) identifier for the route #[serde(rename = "route_id")] - pub id: String, + pub id: SmolStr, /// Short name of a route. This will often be a short, abstract identifier like "32", "100X", or "Green" that riders use to identify a route, but which doesn't give any indication of what places the route serves #[serde(rename = "route_short_name")] pub short_name: String, @@ -363,13 +364,13 @@ impl fmt::Display for Route { pub struct RawTrip { /// Unique technical (not for the traveller) identifier for the Trip #[serde(rename = "trip_id")] - pub id: String, + pub id: SmolStr, /// References the [Calendar] on which this trip runs - pub service_id: String, + pub service_id: SmolStr, /// References along which [Route] this trip runs - pub route_id: String, + pub route_id: SmolStr, /// Shape of the trip - pub shape_id: Option, + pub shape_id: Option, /// Text that appears on signage identifying the trip's destination to riders pub trip_headsign: Option, /// Public facing text used to identify the trip to riders, for instance, to identify train numbers for commuter rail trips @@ -412,15 +413,15 @@ impl fmt::Display for RawTrip { #[derive(Debug, Default)] pub struct Trip { /// Unique technical identifier (not for the traveller) for the Trip - pub id: String, + pub id: SmolStr, /// References the [Calendar] on which this trip runs - pub service_id: String, + pub service_id: SmolStr, /// References along which [Route] this trip runs - pub route_id: String, + pub route_id: SmolStr, /// All the [StopTime] that define the trip pub stop_times: Vec, /// Text that appears on signage identifying the trip's destination to riders - pub shape_id: Option, + pub shape_id: Option, /// Text that appears on signage identifying the trip's destination to riders pub trip_headsign: Option, /// Public facing text used to identify the trip to riders, for instance, to identify train numbers for commuter rail trips @@ -514,7 +515,7 @@ impl fmt::Display for Agency { pub struct Shape { /// Unique technical (not for the traveller) identifier for the Shape #[serde(rename = "shape_id")] - pub id: String, + pub id: SmolStr, #[serde(rename = "shape_pt_lat", default)] /// Latitude of a shape point pub latitude: f64, @@ -546,7 +547,7 @@ impl Id for Shape { pub struct FareAttribute { /// Unique technical (not for the traveller) identifier for the FareAttribute #[serde(rename = "fare_id")] - pub id: String, + pub id: SmolStr, /// Fare price, in the unit specified by [FareAttribute::currency] pub price: String, /// Currency used to pay the fare. @@ -578,7 +579,7 @@ impl Type for FareAttribute { #[derive(Debug, Serialize, Deserialize, Default)] pub struct RawFrequency { /// References the [Trip] that uses frequency - pub trip_id: String, + pub trip_id: SmolStr, /// Time at which the first vehicle departs from the first stop of the trip #[serde( deserialize_with = "deserialize_time", diff --git a/src/serde_helpers.rs b/src/serde_helpers.rs index d1720e6..fb143b3 100644 --- a/src/serde_helpers.rs +++ b/src/serde_helpers.rs @@ -1,4 +1,8 @@ use chrono::NaiveDate; +use nom::{ + character::complete::{char, digit1}, + sequence::tuple, +}; use rgb::RGB8; use serde::de::{self, Deserialize, Deserializer}; use serde::ser::Serializer; @@ -41,20 +45,19 @@ where } } -pub fn parse_time_impl(v: Vec<&str>) -> Result { - let hours: u32 = v[0].parse()?; - let minutes: u32 = v[1].parse()?; - let seconds: u32 = v[2].parse()?; - Ok(hours * 3600 + minutes * 60 + seconds) +fn parse_time_impl(h: &str, m: &str, s: &str) -> Result { + Ok(h.parse::()? * 3600 + m.parse::()? * 60 + s.parse::()?) } -pub fn parse_time(s: &str) -> Result { - let v: Vec<&str> = s.trim_start().split(':').collect(); - if v.len() != 3 { - Err(crate::Error::InvalidTime(s.to_owned())) - } else { - parse_time_impl(v).map_err(|_| crate::Error::InvalidTime(s.to_owned())) - } +fn parse_time(s: &str) -> Result { + // Parsing the times in stop_times.txt is a significant bottleneck + // Using nom to parse the times result in an improvement of about 3% in performance + let mut parser = tuple::<&str, _, (_, _), _>((digit1, char(':'), digit1, char(':'), digit1)); + parser(s) + .map_err(|_| crate::Error::InvalidTime(s.to_owned())) + .and_then(|(_, (h, _, m, _, s))| { + parse_time_impl(h, m, s).map_err(|_| crate::Error::InvalidTime(s.to_owned())) + }) } pub fn deserialize_time<'de, D>(deserializer: D) -> Result @@ -107,7 +110,6 @@ where D: Deserializer<'de>, { String::deserialize(de).and_then(|s| { - let s = s.trim(); if s.is_empty() { Ok(None) } else { @@ -134,11 +136,10 @@ where D: Deserializer<'de>, { String::deserialize(de).and_then(|s| { - let s = s.trim(); if s.is_empty() { Ok(None) } else { - parse_color(s).map(Some).map_err(de::Error::custom) + parse_color(&s).map(Some).map_err(de::Error::custom) } }) }