diff --git a/src/agent.rs b/src/agent.rs index a69570ea1..b0de8a60b 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,7 +1,8 @@ //! Agents drive the economy of the MUSE 2.0 simulation, through relative investment in different //! assets. use crate::commodity::Commodity; -use crate::id::define_id_getter; +use crate::id::{define_id_getter, define_id_type}; +use crate::process::ProcessID; use crate::region::RegionSelection; use indexmap::IndexMap; use serde::Deserialize; @@ -9,14 +10,16 @@ use serde_string_enum::DeserializeLabeledStringEnum; use std::collections::HashSet; use std::rc::Rc; +define_id_type! {AgentID} + /// A map of [`Agent`]s, keyed by agent ID -pub type AgentMap = IndexMap, Agent>; +pub type AgentMap = IndexMap; /// An agent in the simulation #[derive(Debug, Clone, PartialEq)] pub struct Agent { /// A unique identifier for the agent. - pub id: Rc, + pub id: AgentID, /// A text description of the agent. pub description: String, /// The commodities that the agent is responsible for servicing. @@ -34,7 +37,7 @@ pub struct Agent { /// The agent's objectives. pub objectives: Vec, } -define_id_getter! {Agent} +define_id_getter! {Agent, AgentID} /// Which processes apply to this agent #[derive(Debug, Clone, PartialEq)] @@ -42,7 +45,7 @@ pub enum SearchSpace { /// All processes are considered AllProcesses, /// Only these specific processes are considered - Some(HashSet>), + Some(HashSet), } /// Search space for an agent @@ -74,7 +77,7 @@ pub enum DecisionRule { #[derive(Debug, Clone, Deserialize, PartialEq)] pub struct AgentObjective { /// Unique agent id identifying the agent this objective belongs to - pub agent_id: String, + pub agent_id: AgentID, /// The year the objective is relevant for pub year: u32, /// Acronym identifying the objective (e.g. LCOX) diff --git a/src/asset.rs b/src/asset.rs index 5f13dc831..bca91ad7d 100644 --- a/src/asset.rs +++ b/src/asset.rs @@ -1,6 +1,8 @@ //! Assets are instances of a process which are owned and invested in by agents. +use crate::agent::AgentID; use crate::commodity::Commodity; use crate::process::Process; +use crate::region::RegionID; use crate::time_slice::TimeSliceID; use std::collections::HashSet; use std::ops::RangeInclusive; @@ -21,11 +23,11 @@ pub struct Asset { /// A unique identifier for the asset pub id: AssetID, /// A unique identifier for the agent - pub agent_id: Rc, + pub agent_id: AgentID, /// The [`Process`] that this asset corresponds to pub process: Rc, /// The region in which the asset is located - pub region_id: Rc, + pub region_id: RegionID, /// Capacity of asset pub capacity: f64, /// The year the asset comes online @@ -38,9 +40,9 @@ impl Asset { /// The `id` field is initially set to [`AssetID::INVALID`], but is changed to a unique value /// when the asset is stored in an [`AssetPool`]. pub fn new( - agent_id: Rc, + agent_id: AgentID, process: Rc, - region_id: Rc, + region_id: RegionID, capacity: f64, commission_year: u32, ) -> Self { @@ -146,7 +148,7 @@ impl AssetPool { /// Iterate over active assets for a particular region pub fn iter_for_region<'a>( &'a self, - region_id: &'a Rc, + region_id: &'a RegionID, ) -> impl Iterator { self.iter().filter(|asset| asset.region_id == *region_id) } @@ -155,7 +157,7 @@ impl AssetPool { /// commodity pub fn iter_for_region_and_commodity<'a>( &'a self, - region_id: &'a Rc, + region_id: &'a RegionID, commodity: &'a Rc, ) -> impl Iterator { self.iter_for_region(region_id) diff --git a/src/commodity.rs b/src/commodity.rs index f1eb9d4a4..0fc15d1fa 100644 --- a/src/commodity.rs +++ b/src/commodity.rs @@ -1,5 +1,6 @@ #![allow(missing_docs)] -use crate::id::define_id_getter; +use crate::id::{define_id_getter, define_id_type}; +use crate::region::RegionID; use crate::time_slice::{TimeSliceID, TimeSliceLevel}; use indexmap::IndexMap; use serde::Deserialize; @@ -7,15 +8,17 @@ use serde_string_enum::DeserializeLabeledStringEnum; use std::collections::HashMap; use std::rc::Rc; +define_id_type! {CommodityID} + /// A map of [`Commodity`]s, keyed by commodity ID -pub type CommodityMap = IndexMap, Rc>; +pub type CommodityMap = IndexMap>; /// A commodity within the simulation. Represents a substance (e.g. CO2) or form of energy (e.g. /// electricity) that can be produced and/or consumed by technologies in the model. #[derive(PartialEq, Debug, Deserialize)] pub struct Commodity { /// Unique identifier for the commodity (e.g. "ELC") - pub id: Rc, + pub id: CommodityID, /// Text description of commodity (e.g. "electricity") pub description: String, #[serde(rename = "type")] // NB: we can't name a field type as it's a reserved keyword @@ -29,7 +32,7 @@ pub struct Commodity { #[serde(skip)] pub demand: DemandMap, } -define_id_getter! {Commodity} +define_id_getter! {Commodity, CommodityID} /// Type of balance for application of cost #[derive(PartialEq, Clone, Debug, DeserializeLabeledStringEnum)] @@ -54,7 +57,7 @@ pub struct CommodityCost { /// Used for looking up [`CommodityCost`]s in a [`CommodityCostMap`] #[derive(PartialEq, Eq, Hash, Debug, Clone)] struct CommodityCostKey { - region_id: Rc, + region_id: RegionID, year: u32, time_slice: TimeSliceID, } @@ -72,7 +75,7 @@ impl CommodityCostMap { /// Insert a [`CommodityCost`] into the map pub fn insert( &mut self, - region_id: Rc, + region_id: RegionID, year: u32, time_slice: TimeSliceID, value: CommodityCost, @@ -88,12 +91,12 @@ impl CommodityCostMap { /// Retrieve a [`CommodityCost`] from the map pub fn get( &self, - region_id: &Rc, + region_id: &RegionID, year: u32, time_slice: &TimeSliceID, ) -> Option<&CommodityCost> { let key = CommodityCostKey { - region_id: Rc::clone(region_id), + region_id: region_id.clone(), year, time_slice: time_slice.clone(), }; @@ -124,7 +127,7 @@ pub struct DemandMap(HashMap); /// The key for a [`DemandMap`] #[derive(PartialEq, Eq, Hash, Debug, Clone)] struct DemandMapKey { - region_id: Rc, + region_id: RegionID, year: u32, time_slice: TimeSliceID, } @@ -136,7 +139,7 @@ impl DemandMap { } /// Retrieve the demand for the specified region, year and time slice - pub fn get(&self, region_id: &Rc, year: u32, time_slice: &TimeSliceID) -> f64 { + pub fn get(&self, region_id: &RegionID, year: u32, time_slice: &TimeSliceID) -> f64 { self.0 .get(&DemandMapKey { region_id: region_id.clone(), @@ -148,7 +151,7 @@ impl DemandMap { } /// Insert a new demand entry for the specified region, year and time slice - pub fn insert(&mut self, region_id: Rc, year: u32, time_slice: TimeSliceID, demand: f64) { + pub fn insert(&mut self, region_id: RegionID, year: u32, time_slice: TimeSliceID, demand: f64) { self.0.insert( DemandMapKey { region_id, diff --git a/src/id.rs b/src/id.rs index 8a29c955b..f76d1612f 100644 --- a/src/id.rs +++ b/src/id.rs @@ -1,25 +1,80 @@ //! Code for handing IDs +use crate::region::RegionID; use anyhow::{Context, Result}; use std::collections::HashSet; -use std::rc::Rc; + +/// A trait alias for ID types +pub trait IDLike: + Eq + std::hash::Hash + std::borrow::Borrow + Clone + std::fmt::Display +{ +} +impl IDLike for T where + T: Eq + std::hash::Hash + std::borrow::Borrow + Clone + std::fmt::Display +{ +} + +macro_rules! define_id_type { + ($name:ident) => { + #[derive( + Clone, std::hash::Hash, PartialEq, Eq, serde::Deserialize, Debug, serde::Serialize, + )] + /// An ID type (e.g. `AgentID`, `CommodityID`, etc.) + pub struct $name(pub std::rc::Rc); + + impl std::borrow::Borrow for $name { + fn borrow(&self) -> &str { + &self.0 + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + + impl From<&str> for $name { + fn from(s: &str) -> Self { + $name(std::rc::Rc::from(s)) + } + } + + impl From for $name { + fn from(s: String) -> Self { + $name(std::rc::Rc::from(s)) + } + } + + impl $name { + /// Create a new ID from a string slice + pub fn new(id: &str) -> Self { + $name(std::rc::Rc::from(id)) + } + } + }; +} +pub(crate) use define_id_type; + +#[cfg(test)] +define_id_type!(GenericID); /// Indicates that the struct has an ID field -pub trait HasID { - /// Get a string representation of the struct's ID - fn get_id(&self) -> &str; +pub trait HasID { + /// Get the struct's ID + fn get_id(&self) -> &ID; } /// An object which is associated with a single region pub trait HasRegionID { /// Get the associated region ID - fn get_region_id(&self) -> &str; + fn get_region_id(&self) -> &RegionID; } /// Implement the `HasID` trait for the given type, assuming it has a field called `id` macro_rules! define_id_getter { - ($t:ty) => { - impl crate::id::HasID for $t { - fn get_id(&self) -> &str { + ($t:ty, $id_ty:ty) => { + impl crate::id::HasID<$id_ty> for $t { + fn get_id(&self) -> &$id_ty { &self.id } } @@ -31,7 +86,7 @@ pub(crate) use define_id_getter; macro_rules! define_region_id_getter { ($t:ty) => { impl crate::id::HasRegionID for $t { - fn get_region_id(&self) -> &str { + fn get_region_id(&self) -> &RegionID { &self.region_id } } @@ -40,24 +95,42 @@ macro_rules! define_region_id_getter { pub(crate) use define_region_id_getter; /// A data structure containing a set of IDs -pub trait IDCollection { - /// Get the ID after checking that it exists this collection. +pub trait IDCollection { + /// Get the ID from the collection by its string representation. + /// + /// # Arguments + /// + /// * `id` - The string representation of the ID + /// + /// # Returns + /// + /// A copy of the ID in `self`, or an error if not found. + fn get_id_by_str(&self, id: &str) -> Result; + + /// Check if the ID is in the collection, returning a copy of it if found. /// /// # Arguments /// - /// * `id` - The ID to look up + /// * `id` - The ID to check /// /// # Returns /// - /// A copy of the `Rc` in `self` or an error if not found. - fn get_id(&self, id: &str) -> Result>; + /// A copy of the ID in `self`, or an error if not found. + fn get_id(&self, id: &ID) -> Result; } -impl IDCollection for HashSet> { - fn get_id(&self, id: &str) -> Result> { - let id = self +impl IDCollection for HashSet { + fn get_id_by_str(&self, id: &str) -> Result { + let found = self .get(id) .with_context(|| format!("Unknown ID {id} found"))?; - Ok(Rc::clone(id)) + Ok(found.clone()) + } + + fn get_id(&self, id: &ID) -> Result { + let found = self + .get(id.borrow()) + .with_context(|| format!("Unknown ID {id} found"))?; + Ok(found.clone()) } } diff --git a/src/input.rs b/src/input.rs index ac302c5a6..1b08d337f 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,6 +1,6 @@ //! Common routines for handling input data. use crate::asset::AssetPool; -use crate::id::HasID; +use crate::id::{HasID, IDLike}; use crate::model::{Model, ModelFile}; use anyhow::{bail, ensure, Context, Result}; use float_cmp::approx_eq; @@ -9,8 +9,8 @@ use itertools::Itertools; use serde::de::{Deserialize, DeserializeOwned, Deserializer}; use std::collections::HashSet; use std::fs; +use std::hash::Hash; use std::path::Path; -use std::rc::Rc; mod agent; use agent::read_agents; @@ -101,18 +101,18 @@ pub fn input_err_msg>(file_path: P) -> String { /// /// As this function is only ever used for top-level CSV files (i.e. the ones which actually define /// the IDs for a given type), we use an ordered map to maintain the order in the input files. -fn read_csv_id_file(file_path: &Path) -> Result, T>> +fn read_csv_id_file(file_path: &Path) -> Result> where - T: HasID + DeserializeOwned, + T: HasID + DeserializeOwned, { - fn fill_and_validate_map(file_path: &Path) -> Result, T>> + fn fill_and_validate_map(file_path: &Path) -> Result> where - T: HasID + DeserializeOwned, + T: HasID + DeserializeOwned, { let mut map = IndexMap::new(); for record in read_csv::(file_path)? { - let id = record.get_id().into(); - let existing = map.insert(Rc::clone(&id), record).is_some(); + let id = record.get_id().clone(); + let existing = map.insert(id.clone(), record).is_some(); ensure!(!existing, "Duplicate ID found: {id}"); } ensure!(!map.is_empty(), "CSV file is empty"); @@ -186,6 +186,8 @@ pub fn load_model>(model_dir: P) -> Result<(Model, AssetPool)> { #[cfg(test)] mod tests { + use crate::id::GenericID; + use super::*; use serde::de::value::{Error as ValueError, F64Deserializer}; use serde::de::IntoDeserializer; @@ -197,12 +199,12 @@ mod tests { #[derive(Debug, PartialEq, Deserialize)] struct Record { - id: String, + id: GenericID, value: u32, } - impl HasID for Record { - fn get_id(&self) -> &str { + impl HasID for Record { + fn get_id(&self) -> &GenericID { &self.id } } @@ -225,11 +227,11 @@ mod tests { records, &[ Record { - id: "hello".to_string(), + id: "hello".into(), value: 1, }, Record { - id: "world".to_string(), + id: "world".into(), value: 2, } ] @@ -256,7 +258,7 @@ mod tests { assert_eq!( read_toml::(&file_path).unwrap(), Record { - id: "hello".to_string(), + id: "hello".into(), value: 1, } ); diff --git a/src/input/agent.rs b/src/input/agent.rs index 301168573..a5b7f95ce 100644 --- a/src/input/agent.rs +++ b/src/input/agent.rs @@ -1,14 +1,13 @@ //! Code for reading in agent-related data from CSV files. use super::*; -use crate::agent::{Agent, AgentMap, DecisionRule}; +use crate::agent::{Agent, AgentID, AgentMap, DecisionRule}; use crate::commodity::CommodityMap; use crate::process::ProcessMap; -use crate::region::RegionSelection; +use crate::region::{RegionID, RegionSelection}; use anyhow::{bail, ensure, Context, Result}; use serde::Deserialize; use std::collections::HashSet; use std::path::Path; -use std::rc::Rc; mod objective; use objective::read_agent_objectives; @@ -25,7 +24,7 @@ const AGENT_FILE_NAME: &str = "agents.csv"; #[derive(Debug, Deserialize, PartialEq, Clone)] struct AgentRaw { /// A unique identifier for the agent. - id: Rc, + id: String, /// A text description of the agent. description: String, /// The decision rule that the agent uses to decide investment. @@ -54,7 +53,7 @@ pub fn read_agents( model_dir: &Path, commodities: &CommodityMap, processes: &ProcessMap, - region_ids: &HashSet>, + region_ids: &HashSet, milestone_years: &[u32], ) -> Result { let process_ids = processes.keys().cloned().collect(); @@ -128,7 +127,7 @@ where }; let agent = Agent { - id: Rc::clone(&agent_raw.id), + id: AgentID(agent_raw.id.into()), description: agent_raw.description, commodities: Vec::new(), search_space: Vec::new(), @@ -140,7 +139,7 @@ where }; ensure!( - agents.insert(agent_raw.id, agent).is_none(), + agents.insert(agent.id.clone(), agent).is_none(), "Duplicate agent ID" ); } diff --git a/src/input/agent/commodity.rs b/src/input/agent/commodity.rs index 4575003ce..1e984e15c 100644 --- a/src/input/agent/commodity.rs +++ b/src/input/agent/commodity.rs @@ -1,7 +1,8 @@ //! Code for reading the agent commodities CSV file. use super::super::*; -use crate::agent::{AgentCommodity, AgentMap}; +use crate::agent::{AgentCommodity, AgentID, AgentMap}; use crate::commodity::{CommodityMap, CommodityType}; +use crate::region::RegionID; use anyhow::{ensure, Context, Result}; use serde::Deserialize; use std::collections::HashMap; @@ -63,9 +64,9 @@ pub fn read_agent_commodities( model_dir: &Path, agents: &AgentMap, commodities: &CommodityMap, - region_ids: &HashSet>, + region_ids: &HashSet, milestone_years: &[u32], -) -> Result, Vec>> { +) -> Result>> { let file_path = model_dir.join(AGENT_COMMODITIES_FILE_NAME); let agent_commodities_csv = read_csv(&file_path)?; read_agent_commodities_from_iter( @@ -82,9 +83,9 @@ fn read_agent_commodities_from_iter( iter: I, agents: &AgentMap, commodities: &CommodityMap, - region_ids: &HashSet>, + region_ids: &HashSet, milestone_years: &[u32], -) -> Result, Vec>> +) -> Result>> where I: Iterator, { @@ -99,7 +100,7 @@ where // Append to Vec with the corresponding key or create agent_commodities - .entry(Rc::clone(id)) + .entry(id.clone()) .or_insert_with(|| Vec::with_capacity(1)) .push(agent_commodity); } @@ -116,10 +117,10 @@ where } fn validate_agent_commodities( - agent_commodities: &HashMap, Vec>, + agent_commodities: &HashMap>, agents: &AgentMap, commodities: &CommodityMap, - region_ids: &HashSet>, + region_ids: &HashSet, milestone_years: &[u32], ) -> Result<()> { // CHECK 1: For each agent there must be at least one commodity for all years @@ -181,14 +182,14 @@ fn validate_agent_commodities( CommodityType::SupplyEqualsDemand | CommodityType::ServiceDemand ) }) - .map(|(id, _)| Rc::clone(id)); + .map(|(id, _)| id.clone()); // Check that summed_portions contains all SVD/SED commodities for all regions and milestone // years for commodity_id in svd_and_sed_commodities { for year in milestone_years { for region in region_ids { - let key = (&*commodity_id, *year, region); + let key = (&commodity_id, *year, region); ensure!( summed_portions.contains_key(&key), "Commodity {} in year {} and region {} is not covered", @@ -207,7 +208,7 @@ fn validate_agent_commodities( mod tests { use super::*; use crate::agent::{Agent, DecisionRule}; - use crate::commodity::{Commodity, CommodityCostMap, CommodityType, DemandMap}; + use crate::commodity::{Commodity, CommodityCostMap, CommodityID, CommodityType, DemandMap}; use crate::region::RegionSelection; use crate::time_slice::TimeSliceLevel; @@ -263,9 +264,9 @@ mod tests { #[test] fn test_validate_agent_commodities() { let agents = IndexMap::from([( - Rc::from("agent1"), + AgentID::new("agent1"), Agent { - id: Rc::from("agent1"), + id: "agent1".into(), description: "An agent".into(), commodities: Vec::new(), search_space: Vec::new(), @@ -277,7 +278,7 @@ mod tests { }, )]); let mut commodities = IndexMap::from([( - Rc::from("commodity1"), + CommodityID::new("commodity1"), Rc::new(Commodity { id: "commodity1".into(), description: "A commodity".into(), @@ -287,7 +288,7 @@ mod tests { demand: DemandMap::new(), }), )]); - let region_ids = HashSet::from([Rc::from("region1")]); + let region_ids = HashSet::from([RegionID::new("region1")]); let milestone_years = vec![2020]; // Valid case @@ -296,7 +297,7 @@ mod tests { commodity: Rc::clone(commodities.get("commodity1").unwrap()), commodity_portion: 1.0, }; - let agent_commodities = HashMap::from([(Rc::from("agent1"), vec![agent_commodity])]); + let agent_commodities = HashMap::from([(AgentID::new("agent1"), vec![agent_commodity])]); assert!(validate_agent_commodities( &agent_commodities, &agents, @@ -312,7 +313,8 @@ mod tests { commodity: Rc::clone(commodities.get("commodity1").unwrap()), commodity_portion: 0.5, }; - let agent_commodities_v2 = HashMap::from([(Rc::from("agent1"), vec![agent_commodity_v2])]); + let agent_commodities_v2 = + HashMap::from([(AgentID::new("agent1"), vec![agent_commodity_v2])]); assert!(validate_agent_commodities( &agent_commodities_v2, &agents, @@ -324,7 +326,7 @@ mod tests { // Invalid case: SED commodity without associated commodity portions commodities.insert( - Rc::from("commodity2"), + CommodityID::new("commodity2"), Rc::new(Commodity { id: "commodity2".into(), description: "Another commodity".into(), diff --git a/src/input/agent/objective.rs b/src/input/agent/objective.rs index 26c263381..94fe9e062 100644 --- a/src/input/agent/objective.rs +++ b/src/input/agent/objective.rs @@ -1,10 +1,9 @@ //! Code for reading the agent objectives CSV file. use super::super::*; -use crate::agent::{AgentMap, AgentObjective, DecisionRule}; +use crate::agent::{AgentID, AgentMap, AgentObjective, DecisionRule}; use anyhow::{ensure, Context, Result}; use std::collections::HashMap; use std::path::Path; -use std::rc::Rc; const AGENT_OBJECTIVES_FILE_NAME: &str = "agent_objectives.csv"; @@ -21,7 +20,7 @@ pub fn read_agent_objectives( model_dir: &Path, agents: &AgentMap, milestone_years: &[u32], -) -> Result, Vec>> { +) -> Result>> { let file_path = model_dir.join(AGENT_OBJECTIVES_FILE_NAME); let agent_objectives_csv = read_csv(&file_path)?; read_agent_objectives_from_iter(agent_objectives_csv, agents, milestone_years) @@ -32,14 +31,14 @@ fn read_agent_objectives_from_iter( iter: I, agents: &AgentMap, milestone_years: &[u32], -) -> Result, Vec>> +) -> Result>> where I: Iterator, { let mut objectives = HashMap::new(); for objective in iter { let (id, agent) = agents - .get_key_value(objective.agent_id.as_str()) + .get_key_value(&objective.agent_id) .context("Invalid agent ID")?; // Check that required parameters are present and others are absent @@ -54,7 +53,7 @@ where // Append to Vec with the corresponding key or create objectives - .entry(Rc::clone(id)) + .entry(id.clone()) .or_insert_with(|| Vec::with_capacity(1)) .push(objective); } @@ -125,7 +124,7 @@ fn check_objective_parameter( fn check_agent_objectives( objectives: &[&AgentObjective], decision_rule: &DecisionRule, - agent_id: &str, + agent_id: &AgentID, year: u32, ) -> Result<()> { let count = objectives.len(); @@ -284,15 +283,16 @@ mod tests { #[test] fn test_check_agent_objectives() { + let agent_id = AgentID::new("agent"); let objective1 = AgentObjective { - agent_id: "agent".into(), + agent_id: agent_id.clone(), year: 2020, objective_type: ObjectiveType::EquivalentAnnualCost, decision_weight: None, decision_lexico_order: Some(1), }; let objective2 = AgentObjective { - agent_id: "agent".into(), + agent_id: agent_id.clone(), year: 2020, objective_type: ObjectiveType::EquivalentAnnualCost, decision_weight: None, @@ -302,22 +302,23 @@ mod tests { // DecisionRule::Single let decision_rule = DecisionRule::Single; let objectives = [&objective1]; - assert!(check_agent_objectives(&objectives, &decision_rule, "agent", 2020).is_ok()); + + assert!(check_agent_objectives(&objectives, &decision_rule, &agent_id, 2020).is_ok()); let objectives = [&objective1, &objective2]; - assert!(check_agent_objectives(&objectives, &decision_rule, "agent", 2020).is_err()); + assert!(check_agent_objectives(&objectives, &decision_rule, &agent_id, 2020).is_err()); // DecisionRule::Weighted let decision_rule = DecisionRule::Weighted; let objectives = [&objective1, &objective2]; - assert!(check_agent_objectives(&objectives, &decision_rule, "agent", 2020).is_ok()); + assert!(check_agent_objectives(&objectives, &decision_rule, &agent_id, 2020).is_ok()); let objectives = [&objective1]; - assert!(check_agent_objectives(&objectives, &decision_rule, "agent", 2020).is_err()); + assert!(check_agent_objectives(&objectives, &decision_rule, &agent_id, 2020).is_err()); // DecisionRule::Lexicographical let decision_rule = DecisionRule::Lexicographical { tolerance: 1.0 }; let objectives = [&objective1, &objective2]; - assert!(check_agent_objectives(&objectives, &decision_rule, "agent", 2020).is_ok()); + assert!(check_agent_objectives(&objectives, &decision_rule, &agent_id, 2020).is_ok()); let objectives = [&objective1, &objective1]; - assert!(check_agent_objectives(&objectives, &decision_rule, "agent", 2020).is_err()); + assert!(check_agent_objectives(&objectives, &decision_rule, &agent_id, 2020).is_err()); } } diff --git a/src/input/agent/region.rs b/src/input/agent/region.rs index b75650ed3..bfb46d4c7 100644 --- a/src/input/agent/region.rs +++ b/src/input/agent/region.rs @@ -1,25 +1,25 @@ //! Code for loading the agent regions CSV file. use super::super::region::read_regions_for_entity; +use crate::agent::AgentID; use crate::id::{define_region_id_getter, HasID}; -use crate::region::RegionSelection; +use crate::region::{RegionID, RegionSelection}; use anyhow::Result; use serde::Deserialize; use std::collections::{HashMap, HashSet}; use std::path::Path; -use std::rc::Rc; const AGENT_REGIONS_FILE_NAME: &str = "agent_regions.csv"; #[derive(Debug, Deserialize, PartialEq)] struct AgentRegion { - agent_id: String, + agent_id: AgentID, /// The region to which an agent belongs. - region_id: String, + region_id: RegionID, } define_region_id_getter!(AgentRegion); -impl HasID for AgentRegion { - fn get_id(&self) -> &str { +impl HasID for AgentRegion { + fn get_id(&self) -> &AgentID { &self.agent_id } } @@ -37,9 +37,9 @@ impl HasID for AgentRegion { /// A map of [`RegionSelection`]s, with the agent ID as the key. pub fn read_agent_regions( model_dir: &Path, - agent_ids: &HashSet>, - region_ids: &HashSet>, -) -> Result, RegionSelection>> { + agent_ids: &HashSet, + region_ids: &HashSet, +) -> Result> { let file_path = model_dir.join(AGENT_REGIONS_FILE_NAME); - read_regions_for_entity::(&file_path, agent_ids, region_ids) + read_regions_for_entity::(&file_path, agent_ids, region_ids) } diff --git a/src/input/agent/search_space.rs b/src/input/agent/search_space.rs index f619112f2..0e2506ed2 100644 --- a/src/input/agent/search_space.rs +++ b/src/input/agent/search_space.rs @@ -1,8 +1,9 @@ //! Code for reading the agent search space CSV file. use super::super::*; -use crate::agent::{AgentMap, AgentSearchSpace, SearchSpace}; +use crate::agent::{AgentID, AgentMap, AgentSearchSpace, SearchSpace}; use crate::commodity::CommodityMap; use crate::id::IDCollection; +use crate::process::ProcessID; use anyhow::{Context, Result}; use serde::Deserialize; use std::collections::HashMap; @@ -27,7 +28,7 @@ struct AgentSearchSpaceRaw { impl AgentSearchSpaceRaw { fn to_agent_search_space( &self, - process_ids: &HashSet>, + process_ids: &HashSet, commodities: &CommodityMap, milestone_years: &[u32], ) -> Result { @@ -37,7 +38,7 @@ impl AgentSearchSpaceRaw { Some(processes) => { let mut set = HashSet::new(); for id in processes.split(';') { - set.insert(process_ids.get_id(id)?); + set.insert(process_ids.get_id_by_str(id)?); } SearchSpace::Some(set) } @@ -76,10 +77,10 @@ impl AgentSearchSpaceRaw { pub fn read_agent_search_space( model_dir: &Path, agents: &AgentMap, - process_ids: &HashSet>, + process_ids: &HashSet, commodities: &CommodityMap, milestone_years: &[u32], -) -> Result, Vec>> { +) -> Result>> { let file_path = model_dir.join(AGENT_SEARCH_SPACE_FILE_NAME); let iter = read_csv_optional::(&file_path)?; read_agent_search_space_from_iter(iter, agents, process_ids, commodities, milestone_years) @@ -89,10 +90,10 @@ pub fn read_agent_search_space( fn read_agent_search_space_from_iter( iter: I, agents: &AgentMap, - process_ids: &HashSet>, + process_ids: &HashSet, commodities: &CommodityMap, milestone_years: &[u32], -) -> Result, Vec>> +) -> Result>> where I: Iterator, { @@ -107,7 +108,7 @@ where // Append to Vec with the corresponding key or create search_spaces - .entry(Rc::clone(id)) + .entry(id.clone()) .or_insert_with(|| Vec::with_capacity(1)) .push(search_space); } diff --git a/src/input/asset.rs b/src/input/asset.rs index 03002f029..d0e1eec6f 100644 --- a/src/input/asset.rs +++ b/src/input/asset.rs @@ -1,8 +1,10 @@ //! Code for reading [Asset]s from a CSV file. use super::*; +use crate::agent::AgentID; use crate::asset::Asset; use crate::id::IDCollection; use crate::process::ProcessMap; +use crate::region::RegionID; use anyhow::{ensure, Context, Result}; use itertools::Itertools; use serde::Deserialize; @@ -35,9 +37,9 @@ struct AssetRaw { /// A `HashMap` containing assets grouped by agent ID. pub fn read_assets( model_dir: &Path, - agent_ids: &HashSet>, + agent_ids: &HashSet, processes: &ProcessMap, - region_ids: &HashSet>, + region_ids: &HashSet, ) -> Result> { let file_path = model_dir.join(ASSETS_FILE_NAME); let assets_csv = read_csv(&file_path)?; @@ -59,19 +61,19 @@ pub fn read_assets( /// A [`Vec`] of [`Asset`]s or an error. fn read_assets_from_iter( iter: I, - agent_ids: &HashSet>, + agent_ids: &HashSet, processes: &ProcessMap, - region_ids: &HashSet>, + region_ids: &HashSet, ) -> Result> where I: Iterator, { iter.map(|asset| -> Result<_> { - let agent_id = agent_ids.get_id(&asset.agent_id)?; + let agent_id = agent_ids.get_id_by_str(&asset.agent_id)?; let process = processes .get(asset.process_id.as_str()) .with_context(|| format!("Invalid process ID: {}", &asset.process_id))?; - let region_id = region_ids.get_id(&asset.region_id)?; + let region_id = region_ids.get_id_by_str(&asset.region_id)?; ensure!( process.regions.contains(®ion_id), "Region {} is not one of the regions in which process {} operates", @@ -117,7 +119,7 @@ mod tests { parameter: process_param.clone(), regions: RegionSelection::All, }); - let processes = [(Rc::clone(&process.id), Rc::clone(&process))] + let processes = [(process.id.clone(), Rc::clone(&process))] .into_iter() .collect(); let agent_ids = ["agent1".into()].into_iter().collect(); @@ -199,7 +201,7 @@ mod tests { capacity: 1.0, commission_year: 2010, }; - let processes = [(Rc::clone(&process.id), Rc::clone(&process))] + let processes = [(process.id.clone(), Rc::clone(&process))] .into_iter() .collect(); assert!( diff --git a/src/input/commodity.rs b/src/input/commodity.rs index c8636e795..18ab1d478 100644 --- a/src/input/commodity.rs +++ b/src/input/commodity.rs @@ -1,11 +1,11 @@ //! Code for reading in commodity-related data from CSV files. use super::*; -use crate::commodity::{Commodity, CommodityMap}; +use crate::commodity::{Commodity, CommodityID, CommodityMap}; +use crate::region::RegionID; use crate::time_slice::TimeSliceInfo; use anyhow::Result; use std::collections::HashSet; use std::path::Path; -use std::rc::Rc; mod cost; use cost::read_commodity_costs; @@ -29,11 +29,12 @@ const COMMODITY_FILE_NAME: &str = "commodities.csv"; /// A map containing commodities, grouped by commodity ID or an error. pub fn read_commodities( model_dir: &Path, - region_ids: &HashSet>, + region_ids: &HashSet, time_slice_info: &TimeSliceInfo, milestone_years: &[u32], ) -> Result { - let commodities = read_csv_id_file::(&model_dir.join(COMMODITY_FILE_NAME))?; + let commodities = + read_csv_id_file::(&model_dir.join(COMMODITY_FILE_NAME))?; let commodity_ids = commodities.keys().cloned().collect(); let mut costs = read_commodity_costs( model_dir, diff --git a/src/input/commodity/cost.rs b/src/input/commodity/cost.rs index afd80e705..8da53e984 100644 --- a/src/input/commodity/cost.rs +++ b/src/input/commodity/cost.rs @@ -1,13 +1,13 @@ //! Code for reading in the commodity cost CSV file. use super::super::*; -use crate::commodity::{BalanceType, CommodityCost, CommodityCostMap}; +use crate::commodity::{BalanceType, CommodityCost, CommodityCostMap, CommodityID}; use crate::id::IDCollection; +use crate::region::RegionID; use crate::time_slice::TimeSliceInfo; use anyhow::{ensure, Context, Result}; use serde::Deserialize; use std::collections::{HashMap, HashSet}; use std::path::Path; -use std::rc::Rc; const COMMODITY_COSTS_FILE_NAME: &str = "commodity_costs.csv"; @@ -43,11 +43,11 @@ struct CommodityCostRaw { /// A map containing commodity costs, grouped by commodity ID. pub fn read_commodity_costs( model_dir: &Path, - commodity_ids: &HashSet>, - region_ids: &HashSet>, + commodity_ids: &HashSet, + region_ids: &HashSet, time_slice_info: &TimeSliceInfo, milestone_years: &[u32], -) -> Result, CommodityCostMap>> { +) -> Result> { let file_path = model_dir.join(COMMODITY_COSTS_FILE_NAME); let commodity_costs_csv = read_csv::(&file_path)?; read_commodity_costs_iter( @@ -62,11 +62,11 @@ pub fn read_commodity_costs( fn read_commodity_costs_iter( iter: I, - commodity_ids: &HashSet>, - region_ids: &HashSet>, + commodity_ids: &HashSet, + region_ids: &HashSet, time_slice_info: &TimeSliceInfo, milestone_years: &[u32], -) -> Result, CommodityCostMap>> +) -> Result> where I: Iterator, { @@ -78,8 +78,8 @@ where let mut used_milestone_years = HashMap::new(); for cost in iter { - let commodity_id = commodity_ids.get_id(&cost.commodity_id)?; - let region_id = region_ids.get_id(&cost.region_id)?; + let commodity_id = commodity_ids.get_id_by_str(&cost.commodity_id)?; + let region_id = region_ids.get_id_by_str(&cost.region_id)?; let ts_selection = time_slice_info.get_selection(&cost.time_slice)?; ensure!( @@ -101,7 +101,7 @@ where }; ensure!( - map.insert(Rc::clone(®ion_id), cost.year, time_slice.clone(), value) + map.insert(region_id.clone(), cost.year, time_slice.clone(), value) .is_none(), "Commodity cost entry covered by more than one time slice \ (region: {}, year: {}, time slice: {})", diff --git a/src/input/commodity/demand.rs b/src/input/commodity/demand.rs index 46c42997a..2d801ec1c 100644 --- a/src/input/commodity/demand.rs +++ b/src/input/commodity/demand.rs @@ -2,14 +2,14 @@ //! slice. use super::super::*; use super::demand_slicing::{read_demand_slices, DemandSliceMap, DemandSliceMapKey}; -use crate::commodity::DemandMap; +use crate::commodity::{CommodityID, DemandMap}; use crate::id::IDCollection; +use crate::region::RegionID; use crate::time_slice::TimeSliceInfo; use anyhow::{ensure, Result}; use serde::Deserialize; use std::collections::{HashMap, HashSet}; use std::path::Path; -use std::rc::Rc; const DEMAND_FILE_NAME: &str = "demand.csv"; @@ -33,15 +33,15 @@ pub type AnnualDemandMap = HashMap; #[derive(PartialEq, Eq, Hash, Debug)] pub struct AnnualDemandMapKey { /// The commodity to which this demand applies - commodity_id: Rc, + commodity_id: CommodityID, /// The region to which this demand applies - region_id: Rc, + region_id: RegionID, /// The simulation year to which this demand applies year: u32, } /// A set of commodity + region pairs -pub type CommodityRegionPairs = HashSet<(Rc, Rc)>; +pub type CommodityRegionPairs = HashSet<(CommodityID, RegionID)>; /// Reads demand data from CSV files. /// @@ -58,11 +58,11 @@ pub type CommodityRegionPairs = HashSet<(Rc, Rc)>; /// This function returns [`DemandMap`]s grouped by commodity ID. pub fn read_demand( model_dir: &Path, - commodity_ids: &HashSet>, - region_ids: &HashSet>, + commodity_ids: &HashSet, + region_ids: &HashSet, time_slice_info: &TimeSliceInfo, milestone_years: &[u32], -) -> Result, DemandMap>> { +) -> Result> { let (demand, commodity_regions) = read_demand_file(model_dir, commodity_ids, region_ids, milestone_years)?; let slices = read_demand_slices( @@ -90,8 +90,8 @@ pub fn read_demand( /// Annual demand data, grouped by commodity, region and milestone year. fn read_demand_file( model_dir: &Path, - commodity_ids: &HashSet>, - region_ids: &HashSet>, + commodity_ids: &HashSet, + region_ids: &HashSet, milestone_years: &[u32], ) -> Result<(AnnualDemandMap, CommodityRegionPairs)> { let file_path = model_dir.join(DEMAND_FILE_NAME); @@ -114,8 +114,8 @@ fn read_demand_file( /// commodity + region pairs included in the file. fn read_demand_from_iter( iter: I, - commodity_ids: &HashSet>, - region_ids: &HashSet>, + commodity_ids: &HashSet, + region_ids: &HashSet, milestone_years: &[u32], ) -> Result<(AnnualDemandMap, CommodityRegionPairs)> where @@ -128,8 +128,8 @@ where let mut commodity_regions = HashSet::new(); for demand in iter { - let commodity_id = commodity_ids.get_id(&demand.commodity_id)?; - let region_id = region_ids.get_id(&demand.region_id)?; + let commodity_id = commodity_ids.get_id_by_str(&demand.commodity_id)?; + let region_id = region_ids.get_id_by_str(&demand.region_id)?; ensure!( milestone_years.binary_search(&demand.year).is_ok(), @@ -144,8 +144,8 @@ where ); let key = AnnualDemandMapKey { - commodity_id: Rc::clone(&commodity_id), - region_id: Rc::clone(®ion_id), + commodity_id: commodity_id.clone(), + region_id: region_id.clone(), year: demand.year, }; ensure!( @@ -164,8 +164,8 @@ where for (commodity_id, region_id) in commodity_regions.iter() { for year in milestone_years.iter().copied() { let key = AnnualDemandMapKey { - commodity_id: Rc::clone(commodity_id), - region_id: Rc::clone(region_id), + commodity_id: commodity_id.clone(), + region_id: region_id.clone(), year, }; ensure!( @@ -194,15 +194,15 @@ fn compute_demand_maps( demand: &AnnualDemandMap, slices: &DemandSliceMap, time_slice_info: &TimeSliceInfo, -) -> HashMap, DemandMap> { +) -> HashMap { let mut map = HashMap::new(); for (demand_key, annual_demand) in demand.iter() { let commodity_id = &demand_key.commodity_id; let region_id = &demand_key.region_id; for time_slice in time_slice_info.iter_ids() { let slice_key = DemandSliceMapKey { - commodity_id: Rc::clone(commodity_id), - region_id: Rc::clone(region_id), + commodity_id: commodity_id.clone(), + region_id: region_id.clone(), time_slice: time_slice.clone(), }; @@ -211,12 +211,12 @@ fn compute_demand_maps( // Get or create entry let map = map - .entry(Rc::clone(commodity_id)) + .entry(commodity_id.clone()) .or_insert_with(DemandMap::new); // Add a new demand entry map.insert( - Rc::clone(region_id), + region_id.clone(), demand_key.year, time_slice.clone(), annual_demand * demand_fraction, diff --git a/src/input/commodity/demand_slicing.rs b/src/input/commodity/demand_slicing.rs index 513154fab..fd0b2e6e4 100644 --- a/src/input/commodity/demand_slicing.rs +++ b/src/input/commodity/demand_slicing.rs @@ -1,14 +1,15 @@ //! Demand slicing determines how annual demand is distributed across the year. use super::super::*; use super::demand::*; +use crate::commodity::CommodityID; use crate::id::IDCollection; +use crate::region::RegionID; use crate::time_slice::{TimeSliceID, TimeSliceInfo}; use anyhow::{ensure, Context, Result}; use itertools::Itertools; use serde::Deserialize; use std::collections::{HashMap, HashSet}; use std::path::Path; -use std::rc::Rc; const DEMAND_SLICING_FILE_NAME: &str = "demand_slicing.csv"; @@ -28,9 +29,9 @@ pub type DemandSliceMap = HashMap; #[derive(PartialEq, Eq, Hash, Debug)] pub struct DemandSliceMapKey { /// The commodity to which this demand applies - pub commodity_id: Rc, + pub commodity_id: CommodityID, /// The region to which this demand applies - pub region_id: Rc, + pub region_id: RegionID, /// The time slice to which this demand applies pub time_slice: TimeSliceID, } @@ -46,8 +47,8 @@ pub struct DemandSliceMapKey { /// * `time_slice_info` - Information about seasons and times of day pub fn read_demand_slices( model_dir: &Path, - commodity_ids: &HashSet>, - region_ids: &HashSet>, + commodity_ids: &HashSet, + region_ids: &HashSet, commodity_regions: &CommodityRegionPairs, time_slice_info: &TimeSliceInfo, ) -> Result { @@ -66,8 +67,8 @@ pub fn read_demand_slices( /// Read demand slices from an iterator fn read_demand_slices_from_iter( iter: I, - commodity_ids: &HashSet>, - region_ids: &HashSet>, + commodity_ids: &HashSet, + region_ids: &HashSet, commodity_regions: &CommodityRegionPairs, time_slice_info: &TimeSliceInfo, ) -> Result @@ -77,10 +78,10 @@ where let mut demand_slices = DemandSliceMap::new(); for slice in iter { - let commodity_id = commodity_ids.get_id(&slice.commodity_id)?; - let region_id = region_ids.get_id(&slice.region_id)?; + let commodity_id = commodity_ids.get_id_by_str(&slice.commodity_id)?; + let region_id = region_ids.get_id_by_str(&slice.region_id)?; ensure!( - commodity_regions.contains(&(Rc::clone(&commodity_id), Rc::clone(®ion_id))), + commodity_regions.contains(&(commodity_id.clone(), region_id.clone())), "Demand slicing provided for commodity {commodity_id} in region {region_id} \ without a corresponding entry in demand CSV file" ); @@ -92,8 +93,8 @@ where for (ts, demand_fraction) in time_slice_info.calculate_share(&ts_selection, slice.fraction) { let key = DemandSliceMapKey { - commodity_id: Rc::clone(&commodity_id), - region_id: Rc::clone(®ion_id), + commodity_id: commodity_id.clone(), + region_id: region_id.clone(), time_slice: ts.clone(), }; @@ -128,8 +129,8 @@ fn validate_demand_slices( .iter_ids() .map(|time_slice| { let key = DemandSliceMapKey { - commodity_id: Rc::clone(commodity_id), - region_id: Rc::clone(region_id), + commodity_id: commodity_id.clone(), + region_id: region_id.clone(), time_slice: time_slice.clone(), }; diff --git a/src/input/process.rs b/src/input/process.rs index 589aab1e3..96d8b097f 100644 --- a/src/input/process.rs +++ b/src/input/process.rs @@ -1,8 +1,10 @@ //! Code for reading process-related information from CSV files. use super::*; -use crate::commodity::{Commodity, CommodityMap, CommodityType}; -use crate::process::{ActivityLimitsMap, Process, ProcessFlow, ProcessMap, ProcessParameter}; -use crate::region::RegionSelection; +use crate::commodity::{Commodity, CommodityID, CommodityMap, CommodityType}; +use crate::process::{ + ActivityLimitsMap, Process, ProcessFlow, ProcessID, ProcessMap, ProcessParameter, +}; +use crate::region::{RegionID, RegionSelection}; use crate::time_slice::TimeSliceInfo; use anyhow::{bail, ensure, Context, Result}; use serde::Deserialize; @@ -24,10 +26,10 @@ const PROCESSES_FILE_NAME: &str = "processes.csv"; #[derive(PartialEq, Debug, Deserialize)] struct ProcessDescription { - id: Rc, + id: ProcessID, description: String, } -define_id_getter! {ProcessDescription} +define_id_getter! {ProcessDescription, ProcessID} /// Read process information from the specified CSV files. /// @@ -45,12 +47,12 @@ define_id_getter! {ProcessDescription} pub fn read_processes( model_dir: &Path, commodities: &CommodityMap, - region_ids: &HashSet>, + region_ids: &HashSet, time_slice_info: &TimeSliceInfo, milestone_years: &[u32], ) -> Result { let file_path = model_dir.join(PROCESSES_FILE_NAME); - let descriptions = read_csv_id_file::(&file_path)?; + let descriptions = read_csv_id_file::(&file_path)?; let process_ids = HashSet::from_iter(descriptions.keys().cloned()); let availabilities = read_process_availabilities(model_dir, &process_ids, time_slice_info)?; @@ -80,23 +82,23 @@ pub fn read_processes( } struct ValidationParams<'a> { - flows: &'a HashMap, Vec>, - region_ids: &'a HashSet>, + flows: &'a HashMap>, + region_ids: &'a HashSet, milestone_years: &'a [u32], time_slice_info: &'a TimeSliceInfo, - parameters: &'a HashMap, ProcessParameter>, - availabilities: &'a HashMap, ActivityLimitsMap>, + parameters: &'a HashMap, + availabilities: &'a HashMap, } /// Perform consistency checks for commodity flows. fn validate_commodities( commodities: &CommodityMap, - flows: &HashMap, Vec>, - region_ids: &HashSet>, + flows: &HashMap>, + region_ids: &HashSet, milestone_years: &[u32], time_slice_info: &TimeSliceInfo, - parameters: &HashMap, ProcessParameter>, - availabilities: &HashMap, ActivityLimitsMap>, + parameters: &HashMap, + availabilities: &HashMap, ) -> anyhow::Result<()> { let params = ValidationParams { flows, @@ -121,9 +123,9 @@ fn validate_commodities( } fn validate_sed_commodity( - commodity_id: &Rc, + commodity_id: &CommodityID, commodity: &Rc, - flows: &HashMap, Vec>, + flows: &HashMap>, ) -> Result<()> { let mut has_producer = false; let mut has_consumer = false; @@ -149,7 +151,7 @@ fn validate_sed_commodity( } fn validate_svd_commodity( - commodity_id: &Rc, + commodity_id: &CommodityID, commodity: &Rc, params: &ValidationParams, ) -> Result<()> { @@ -203,10 +205,10 @@ fn validate_svd_commodity( fn create_process_map( descriptions: I, - mut availabilities: HashMap, ActivityLimitsMap>, - mut flows: HashMap, Vec>, - mut parameters: HashMap, ProcessParameter>, - mut regions: HashMap, RegionSelection>, + mut availabilities: HashMap, + mut flows: HashMap>, + mut parameters: HashMap, + mut regions: HashMap, ) -> Result where I: Iterator, @@ -228,7 +230,7 @@ where let regions = regions.remove(id).unwrap(); let process = Process { - id: Rc::clone(id), + id: id.clone(), description: description.description, activity_limits: availabilities, flows, @@ -253,22 +255,22 @@ mod tests { struct ProcessData { descriptions: Vec, - availabilities: HashMap, ActivityLimitsMap>, - flows: HashMap, Vec>, - parameters: HashMap, ProcessParameter>, - regions: HashMap, RegionSelection>, - region_ids: HashSet>, + availabilities: HashMap, + flows: HashMap>, + parameters: HashMap, + regions: HashMap, + region_ids: HashSet, } /// Returns example data (without errors) for processes fn get_process_data() -> ProcessData { let descriptions = vec![ ProcessDescription { - id: Rc::from("process1"), + id: "process1".into(), description: "Process 1".to_string(), }, ProcessDescription { - id: Rc::from("process2"), + id: "process2".into(), description: "Process 2".to_string(), }, ]; @@ -424,17 +426,14 @@ mod tests { }); let commodities: CommodityMap = [ - (Rc::clone(&commodity_sed.id), Rc::clone(&commodity_sed)), - ( - Rc::clone(&commodity_non_sed.id), - Rc::clone(&commodity_non_sed), - ), + (commodity_sed.id.clone(), Rc::clone(&commodity_sed)), + (commodity_non_sed.id.clone(), Rc::clone(&commodity_non_sed)), ] .into_iter() .collect(); // Create mock flows - let process_flows: HashMap, Vec> = [ + let process_flows: HashMap> = [ ( "process1".into(), vec![ @@ -484,7 +483,7 @@ mod tests { .is_ok()); // Modify flows to make the validation fail - let process_flows_invalid: HashMap, Vec> = [( + let process_flows_invalid: HashMap> = [( "process1".into(), vec![ProcessFlow { process_id: "process1".into(), diff --git a/src/input/process/availability.rs b/src/input/process/availability.rs index 87502181a..598367cfb 100644 --- a/src/input/process/availability.rs +++ b/src/input/process/availability.rs @@ -1,14 +1,13 @@ //! Code for reading process availabilities CSV file use super::super::*; use crate::id::IDCollection; -use crate::process::ActivityLimitsMap; +use crate::process::{ActivityLimitsMap, ProcessID}; use crate::time_slice::TimeSliceInfo; use anyhow::{Context, Result}; use serde::Deserialize; use serde_string_enum::DeserializeLabeledStringEnum; use std::collections::{HashMap, HashSet}; use std::path::Path; -use std::rc::Rc; const PROCESS_AVAILABILITIES_FILE_NAME: &str = "process_availabilities.csv"; @@ -49,9 +48,9 @@ enum LimitType { /// error. pub fn read_process_availabilities( model_dir: &Path, - process_ids: &HashSet>, + process_ids: &HashSet, time_slice_info: &TimeSliceInfo, -) -> Result, ActivityLimitsMap>> { +) -> Result> { let file_path = model_dir.join(PROCESS_AVAILABILITIES_FILE_NAME); let process_availabilities_csv = read_csv(&file_path)?; read_process_availabilities_from_iter(process_availabilities_csv, process_ids, time_slice_info) @@ -61,16 +60,16 @@ pub fn read_process_availabilities( /// Process raw process availabilities input data into [`ActivityLimitsMap`]s fn read_process_availabilities_from_iter( iter: I, - process_ids: &HashSet>, + process_ids: &HashSet, time_slice_info: &TimeSliceInfo, -) -> Result, ActivityLimitsMap>> +) -> Result> where I: Iterator, { let mut map = HashMap::new(); for record in iter { - let process_id = process_ids.get_id(&record.process_id)?; + let process_id = process_ids.get_id_by_str(&record.process_id)?; ensure!( record.value >= 0.0 && record.value <= 1.0, @@ -108,7 +107,7 @@ where /// Check that every capacity map has an entry for every time slice fn validate_capacity_maps( - map: &HashMap, ActivityLimitsMap>, + map: &HashMap, time_slice_info: &TimeSliceInfo, ) -> Result<()> { for (process_id, map) in map.iter() { diff --git a/src/input/process/flow.rs b/src/input/process/flow.rs index 92ec927e7..22a7c368e 100644 --- a/src/input/process/flow.rs +++ b/src/input/process/flow.rs @@ -1,8 +1,8 @@ //! Code for reading process flows file use super::super::*; -use crate::commodity::CommodityMap; +use crate::commodity::{CommodityID, CommodityMap}; use crate::id::IDCollection; -use crate::process::{FlowType, ProcessFlow}; +use crate::process::{FlowType, ProcessFlow, ProcessID}; use anyhow::{ensure, Context, Result}; use serde::Deserialize; use std::collections::{HashMap, HashSet}; @@ -25,9 +25,9 @@ struct ProcessFlowRaw { /// Read process flows from a CSV file pub fn read_process_flows( model_dir: &Path, - process_ids: &HashSet>, + process_ids: &HashSet, commodities: &CommodityMap, -) -> Result, Vec>> { +) -> Result>> { let file_path = model_dir.join(PROCESS_FLOWS_FILE_NAME); let process_flow_csv = read_csv(&file_path)?; read_process_flows_from_iter(process_flow_csv, process_ids, commodities) @@ -37,9 +37,9 @@ pub fn read_process_flows( /// Read 'ProcessFlowRaw' records from an iterator and convert them into 'ProcessFlow' records. fn read_process_flows_from_iter( iter: I, - process_ids: &HashSet>, + process_ids: &HashSet, commodities: &CommodityMap, -) -> Result, Vec>> +) -> Result>> where I: Iterator, { @@ -72,7 +72,7 @@ where } // Create ProcessFlow object - let process_id = process_ids.get_id(&flow.process_id)?; + let process_id = process_ids.get_id_by_str(&flow.process_id)?; let process_flow = ProcessFlow { process_id: flow.process_id, commodity: Rc::clone(commodity), @@ -102,14 +102,14 @@ where /// /// # Returns /// An `Ok(())` if the check is successful, or an error. -fn validate_flows(flows: &HashMap, Vec>) -> Result<()> { +fn validate_flows(flows: &HashMap>) -> Result<()> { for (process_id, flows) in flows.iter() { - let mut commodities: HashSet> = HashSet::new(); + let mut commodities: HashSet = HashSet::new(); for flow in flows.iter() { let commodity_id = &flow.commodity.id; ensure!( - commodities.insert(Rc::clone(commodity_id)), + commodities.insert(commodity_id.clone()), "Process {process_id} has multiple flows for commodity {commodity_id}", ); } @@ -126,7 +126,7 @@ fn validate_flows(flows: &HashMap, Vec>) -> Result<()> { /// /// # Returns /// An `Ok(())` if the check is successful, or an error. -fn validate_pac_flows(flows: &HashMap, Vec>) -> Result<()> { +fn validate_pac_flows(flows: &HashMap>) -> Result<()> { for (process_id, flows) in flows.iter() { let mut flow_sign: Option = None; // False for inputs, true for outputs @@ -173,7 +173,7 @@ mod tests { demand: DemandMap::new(), }; - (Rc::clone(&commodity.id), commodity.into()) + (commodity.id.clone(), commodity.into()) }) .collect(); @@ -260,7 +260,7 @@ mod tests { demand: DemandMap::new(), }; - (Rc::clone(&commodity.id), commodity.into()) + (commodity.id.clone(), commodity.into()) }) .collect(); @@ -341,7 +341,7 @@ mod tests { demand: DemandMap::new(), }; - (Rc::clone(&commodity.id), commodity.into()) + (commodity.id.clone(), commodity.into()) }) .collect(); @@ -385,7 +385,7 @@ mod tests { demand: DemandMap::new(), }; - (Rc::clone(&commodity.id), commodity.into()) + (commodity.id.clone(), commodity.into()) }) .collect(); @@ -466,7 +466,7 @@ mod tests { demand: DemandMap::new(), }; - (Rc::clone(&commodity.id), commodity.into()) + (commodity.id.clone(), commodity.into()) }) .collect(); diff --git a/src/input/process/parameter.rs b/src/input/process/parameter.rs index 37d4c03ae..3878f044e 100644 --- a/src/input/process/parameter.rs +++ b/src/input/process/parameter.rs @@ -1,14 +1,13 @@ //! Code for reading process parameters CSV file use super::super::*; use crate::id::IDCollection; -use crate::process::ProcessParameter; +use crate::process::{ProcessID, ProcessParameter}; use ::log::warn; use anyhow::{ensure, Context, Result}; use serde::Deserialize; use std::collections::{HashMap, HashSet}; use std::ops::RangeInclusive; use std::path::Path; -use std::rc::Rc; const PROCESS_PARAMETERS_FILE_NAME: &str = "process_parameters.csv"; @@ -106,9 +105,9 @@ impl ProcessParameterRaw { /// Read process parameters from the specified model directory pub fn read_process_parameters( model_dir: &Path, - process_ids: &HashSet>, + process_ids: &HashSet, year_range: &RangeInclusive, -) -> Result, ProcessParameter>> { +) -> Result> { let file_path = model_dir.join(PROCESS_PARAMETERS_FILE_NAME); let iter = read_csv::(&file_path)?; read_process_parameters_from_iter(iter, process_ids, year_range) @@ -117,18 +116,18 @@ pub fn read_process_parameters( fn read_process_parameters_from_iter( iter: I, - process_ids: &HashSet>, + process_ids: &HashSet, year_range: &RangeInclusive, -) -> Result, ProcessParameter>> +) -> Result> where I: Iterator, { let mut params = HashMap::new(); for param_raw in iter { - let id = process_ids.get_id(¶m_raw.process_id)?; + let id = process_ids.get_id_by_str(¶m_raw.process_id)?; let param = param_raw.into_parameter(year_range)?; ensure!( - params.insert(Rc::clone(&id), param).is_none(), + params.insert(id.clone(), param).is_none(), "More than one parameter provided for process {id}" ); } @@ -307,7 +306,7 @@ mod tests { }, ]; - let expected: HashMap, _> = [ + let expected: HashMap = [ ( "A".into(), ProcessParameter { diff --git a/src/input/process/region.rs b/src/input/process/region.rs index 9ce39ee73..ce2cc31b3 100644 --- a/src/input/process/region.rs +++ b/src/input/process/region.rs @@ -1,24 +1,24 @@ //! Code for reading the process region CSV file use super::super::region::read_regions_for_entity; use crate::id::{define_region_id_getter, HasID}; -use crate::region::RegionSelection; +use crate::process::ProcessID; +use crate::region::{RegionID, RegionSelection}; use anyhow::Result; use serde::Deserialize; use std::collections::{HashMap, HashSet}; use std::path::Path; -use std::rc::Rc; const PROCESS_REGIONS_FILE_NAME: &str = "process_regions.csv"; #[derive(PartialEq, Debug, Deserialize)] struct ProcessRegion { - process_id: String, - region_id: String, + process_id: ProcessID, + region_id: RegionID, } define_region_id_getter! {ProcessRegion} -impl HasID for ProcessRegion { - fn get_id(&self) -> &str { +impl HasID for ProcessRegion { + fn get_id(&self) -> &ProcessID { &self.process_id } } @@ -36,9 +36,9 @@ impl HasID for ProcessRegion { /// A map of [`RegionSelection`]s, with the process ID as the key. pub fn read_process_regions( model_dir: &Path, - process_ids: &HashSet>, - region_ids: &HashSet>, -) -> Result, RegionSelection>> { + process_ids: &HashSet, + region_ids: &HashSet, +) -> Result> { let file_path = model_dir.join(PROCESS_REGIONS_FILE_NAME); - read_regions_for_entity::(&file_path, process_ids, region_ids) + read_regions_for_entity::(&file_path, process_ids, region_ids) } diff --git a/src/input/region.rs b/src/input/region.rs index d697c32f3..2a322d939 100644 --- a/src/input/region.rs +++ b/src/input/region.rs @@ -1,12 +1,11 @@ //! Code for reading region-related information from CSV files. use super::*; -use crate::id::{HasID, HasRegionID, IDCollection}; -use crate::region::{RegionMap, RegionSelection}; +use crate::id::{HasID, HasRegionID, IDCollection, IDLike}; +use crate::region::{RegionID, RegionMap, RegionSelection}; use anyhow::{anyhow, ensure, Context, Result}; use serde::de::DeserializeOwned; use std::collections::{HashMap, HashSet}; use std::path::Path; -use std::rc::Rc; const REGIONS_FILE_NAME: &str = "regions.csv"; @@ -18,7 +17,7 @@ const REGIONS_FILE_NAME: &str = "regions.csv"; /// /// # Returns /// -/// A `HashMap, Region>` with the parsed regions data or an error. The keys are region IDs. +/// A `HashMap` with the parsed regions data or an error pub fn read_regions(model_dir: &Path) -> Result { read_csv_id_file(&model_dir.join(REGIONS_FILE_NAME)) } @@ -30,26 +29,26 @@ pub fn read_regions(model_dir: &Path) -> Result { /// `file_path` - Path to CSV file /// `entity_ids` - All possible valid IDs for the entity type /// `region_ids` - All possible valid region IDs -pub fn read_regions_for_entity( +pub fn read_regions_for_entity( file_path: &Path, - entity_ids: &HashSet>, - region_ids: &HashSet>, -) -> Result, RegionSelection>> + entity_ids: &HashSet, + region_ids: &HashSet, +) -> Result> where - T: HasID + HasRegionID + DeserializeOwned, + T: HasID + HasRegionID + DeserializeOwned, { read_regions_for_entity_from_iter(read_csv::(file_path)?, entity_ids, region_ids) .with_context(|| input_err_msg(file_path)) } -fn read_regions_for_entity_from_iter( +fn read_regions_for_entity_from_iter( entity_iter: I, - entity_ids: &HashSet>, - region_ids: &HashSet>, -) -> Result, RegionSelection>> + entity_ids: &HashSet, + region_ids: &HashSet, +) -> Result> where I: Iterator, - T: HasID + HasRegionID, + T: HasID + HasRegionID, { let mut entity_regions = HashMap::new(); for entity in entity_iter { @@ -70,15 +69,15 @@ where } /// Try to insert a region ID into the specified map -fn try_insert_region( - entity_id: Rc, - region_id: &str, - region_ids: &HashSet>, - entity_regions: &mut HashMap, RegionSelection>, +fn try_insert_region( + entity_id: ID, + region_id: &RegionID, + region_ids: &HashSet, + entity_regions: &mut HashMap, ) -> Result<()> { let entity_name = entity_id.clone(); - if region_id.eq_ignore_ascii_case("all") { + if region_id.0.eq_ignore_ascii_case("all") { // Valid for all regions return match entity_regions.insert(entity_id, RegionSelection::All) { None => Ok(()), @@ -119,7 +118,7 @@ fn try_insert_region( #[cfg(test)] mod tests { use super::*; - use crate::id::{define_id_getter, define_region_id_getter}; + use crate::id::{define_id_getter, define_region_id_getter, GenericID}; use crate::region::Region; use serde::Deserialize; use std::fs::File; @@ -179,8 +178,14 @@ AP,Asia Pacific" let region_ids = ["GBR".into(), "FRA".into()].into_iter().collect(); // Insert new - let mut entity_regions = HashMap::new(); - assert!(try_insert_region("key".into(), "GBR", ®ion_ids, &mut entity_regions).is_ok()); + let mut entity_regions: HashMap = HashMap::new(); + assert!(try_insert_region( + "key".into(), + &"GBR".into(), + ®ion_ids, + &mut entity_regions + ) + .is_ok()); let selected: HashSet<_> = ["GBR".into()].into_iter().collect(); assert_eq!( *entity_regions.get("key").unwrap(), @@ -188,16 +193,31 @@ AP,Asia Pacific" ); // Insert "all" - let mut entity_regions = HashMap::new(); - assert!(try_insert_region("key".into(), "all", ®ion_ids, &mut entity_regions).is_ok()); + let mut entity_regions: HashMap = HashMap::new(); + assert!(try_insert_region( + "key".into(), + &"all".into(), + ®ion_ids, + &mut entity_regions + ) + .is_ok()); assert_eq!(*entity_regions.get("key").unwrap(), RegionSelection::All); // Append to existing let selected: HashSet<_> = ["FRA".into()].into_iter().collect(); - let mut entity_regions = [("key".into(), RegionSelection::Some(selected.clone()))] - .into_iter() - .collect(); - assert!(try_insert_region("key".into(), "GBR", ®ion_ids, &mut entity_regions).is_ok()); + let mut entity_regions = [( + GenericID::new("key"), + RegionSelection::Some(selected.clone()), + )] + .into_iter() + .collect(); + assert!(try_insert_region( + "key".into(), + &"GBR".into(), + ®ion_ids, + &mut entity_regions + ) + .is_ok()); let selected: HashSet<_> = ["FRA".into(), "GBR".into()].into_iter().collect(); assert_eq!( *entity_regions.get("key").unwrap(), @@ -205,30 +225,53 @@ AP,Asia Pacific" ); // "All" already specified - let mut entity_regions = [("key".into(), RegionSelection::All)].into_iter().collect(); - assert!(try_insert_region("key".into(), "GBR", ®ion_ids, &mut entity_regions).is_err()); + let mut entity_regions = [(GenericID::new("key"), RegionSelection::All)] + .into_iter() + .collect(); + assert!(try_insert_region( + "key".into(), + &"GBR".into(), + ®ion_ids, + &mut entity_regions + ) + .is_err()); // "GBR" specified twice let selected: HashSet<_> = ["GBR".into()].into_iter().collect(); - let mut entity_regions = [("key".into(), RegionSelection::Some(selected))] + let mut entity_regions = [(GenericID::new("key"), RegionSelection::Some(selected))] .into_iter() .collect(); - assert!(try_insert_region("key".into(), "GBR", ®ion_ids, &mut entity_regions).is_err()); + assert!(try_insert_region( + "key".into(), + &"GBR".into(), + ®ion_ids, + &mut entity_regions + ) + .is_err()); // Try appending "all" to existing let selected: HashSet<_> = ["FRA".into()].into_iter().collect(); - let mut entity_regions = [("key".into(), RegionSelection::Some(selected.clone()))] - .into_iter() - .collect(); - assert!(try_insert_region("key".into(), "all", ®ion_ids, &mut entity_regions).is_err()); + let mut entity_regions = [( + GenericID::new("key"), + RegionSelection::Some(selected.clone()), + )] + .into_iter() + .collect(); + assert!(try_insert_region( + "key".into(), + &"all".into(), + ®ion_ids, + &mut entity_regions + ) + .is_err()); } #[derive(Deserialize, PartialEq)] struct Record { - id: String, - region_id: String, + id: GenericID, + region_id: RegionID, } - define_id_getter! {Record} + define_id_getter! {Record, GenericID} define_region_id_getter! {Record} #[test] diff --git a/src/model.rs b/src/model.rs index ab0c8ffcf..58849602b 100644 --- a/src/model.rs +++ b/src/model.rs @@ -4,12 +4,11 @@ use crate::agent::AgentMap; use crate::commodity::CommodityMap; use crate::input::{input_err_msg, read_toml}; use crate::process::ProcessMap; -use crate::region::RegionMap; +use crate::region::{RegionID, RegionMap}; use crate::time_slice::TimeSliceInfo; use anyhow::{ensure, Context, Result}; use serde::Deserialize; use std::path::Path; -use std::rc::Rc; const MODEL_FILE_NAME: &str = "model.toml"; @@ -85,7 +84,7 @@ impl Model { } /// Iterate over the model's regions (region IDs). - pub fn iter_regions(&self) -> impl Iterator> + '_ { + pub fn iter_regions(&self) -> impl Iterator + '_ { self.regions.keys() } } diff --git a/src/output.rs b/src/output.rs index c0c7460ea..fa3cc1359 100644 --- a/src/output.rs +++ b/src/output.rs @@ -1,5 +1,9 @@ //! The module responsible for writing output data to disk. +use crate::agent::AgentID; use crate::asset::{Asset, AssetID, AssetPool}; +use crate::commodity::CommodityID; +use crate::process::ProcessID; +use crate::region::RegionID; use crate::simulation::CommodityPrices; use crate::time_slice::TimeSliceID; use anyhow::{Context, Result}; @@ -8,7 +12,6 @@ use serde::{Deserialize, Serialize}; use std::fs; use std::fs::File; use std::path::{Path, PathBuf}; -use std::rc::Rc; /// The root folder in which model-specific output folders will be created const OUTPUT_DIRECTORY_ROOT: &str = "muse2_results"; @@ -52,9 +55,9 @@ pub fn create_output_directory(model_dir: &Path) -> Result { #[derive(Serialize, Deserialize, Debug, PartialEq)] struct AssetRow { milestone_year: u32, - process_id: Rc, - region_id: Rc, - agent_id: Rc, + process_id: ProcessID, + region_id: RegionID, + agent_id: AgentID, commission_year: u32, } @@ -62,9 +65,9 @@ impl AssetRow { fn new(milestone_year: u32, asset: &Asset) -> Self { Self { milestone_year, - process_id: Rc::clone(&asset.process.id), - region_id: Rc::clone(&asset.region_id), - agent_id: Rc::clone(&asset.agent_id), + process_id: asset.process.id.clone(), + region_id: asset.region_id.clone(), + agent_id: asset.agent_id.clone(), commission_year: asset.commission_year, } } @@ -75,7 +78,7 @@ impl AssetRow { /// This will be written along with an [`AssetRow`] containing asset-related info. #[derive(Serialize, Deserialize, Debug, PartialEq)] struct CommodityFlowRow { - commodity_id: Rc, + commodity_id: CommodityID, time_slice: String, flow: f64, } @@ -84,7 +87,7 @@ struct CommodityFlowRow { #[derive(Serialize, Deserialize, Debug, PartialEq)] struct CommodityPriceRow { milestone_year: u32, - commodity_id: Rc, + commodity_id: CommodityID, time_slice: String, price: f64, } @@ -132,13 +135,13 @@ impl DataWriter { flows: I, ) -> Result<()> where - I: Iterator, &'a TimeSliceID, f64)>, + I: Iterator, { for (asset_id, commodity_id, time_slice, flow) in flows { let asset = assets.get(asset_id).unwrap(); let asset_row = AssetRow::new(milestone_year, asset); let flow_row = CommodityFlowRow { - commodity_id: Rc::clone(commodity_id), + commodity_id: commodity_id.clone(), time_slice: time_slice.to_string(), flow, }; @@ -153,7 +156,7 @@ impl DataWriter { for (commodity_id, time_slice, price) in prices.iter() { let row = CommodityPriceRow { milestone_year, - commodity_id: Rc::clone(commodity_id), + commodity_id: commodity_id.clone(), time_slice: time_slice.to_string(), price, }; @@ -180,11 +183,12 @@ mod tests { use crate::region::RegionSelection; use crate::time_slice::TimeSliceID; use itertools::{assert_equal, Itertools}; + use std::rc::Rc; use std::{collections::HashMap, iter}; use tempfile::tempdir; fn get_asset() -> Asset { - let process_id = "process1".into(); + let process_id = ProcessID::new("process1"); let region_id = "GBR".into(); let agent_id = "agent1".into(); let commission_year = 2015; @@ -198,7 +202,7 @@ mod tests { capacity_to_activity: 3.0, }; let process = Rc::new(Process { - id: Rc::clone(&process_id), + id: process_id, description: "Description".into(), activity_limits: HashMap::new(), flows: vec![], diff --git a/src/process.rs b/src/process.rs index ac9e20567..9d45e7894 100644 --- a/src/process.rs +++ b/src/process.rs @@ -1,6 +1,7 @@ //! Processes are used for converting between different commodities. The data structures in this //! module are used to represent these conversions along with the associated costs. use crate::commodity::Commodity; +use crate::id::define_id_type; use crate::region::RegionSelection; use crate::time_slice::TimeSliceID; use indexmap::IndexMap; @@ -10,14 +11,16 @@ use std::collections::HashMap; use std::ops::RangeInclusive; use std::rc::Rc; +define_id_type! {ProcessID} + /// A map of [`Process`]es, keyed by process ID -pub type ProcessMap = IndexMap, Rc>; +pub type ProcessMap = IndexMap>; /// Represents a process within the simulation #[derive(PartialEq, Debug)] pub struct Process { /// A unique identifier for the process (e.g. GASDRV) - pub id: Rc, + pub id: ProcessID, /// A human-readable description for the process (e.g. dry gas extraction) pub description: String, /// The activity limits for each time slice (as a fraction of maximum) diff --git a/src/region.rs b/src/region.rs index a9d188cd2..ea7d9d7c2 100644 --- a/src/region.rs +++ b/src/region.rs @@ -1,24 +1,26 @@ //! Regions represent different geographical areas in which agents, processes, etc. are active. use crate::id::define_id_getter; +use crate::id::define_id_type; use indexmap::IndexMap; use itertools::Itertools; use serde::Deserialize; use std::collections::HashSet; use std::fmt::Display; -use std::rc::Rc; + +define_id_type! {RegionID} /// A map of [`Region`]s, keyed by region ID -pub type RegionMap = IndexMap, Region>; +pub type RegionMap = IndexMap; /// Represents a region with an ID and a longer description. #[derive(Debug, Deserialize, PartialEq)] pub struct Region { /// A unique identifier for a region (e.g. "GBR"). - pub id: Rc, + pub id: RegionID, /// A text description of the region (e.g. "United Kingdom"). pub description: String, } -define_id_getter! {Region} +define_id_getter! {Region, RegionID} /// Represents multiple regions #[derive(PartialEq, Debug, Clone, Default)] @@ -27,12 +29,12 @@ pub enum RegionSelection { #[default] All, /// Only some regions are covered - Some(HashSet>), + Some(HashSet), } impl RegionSelection { /// Returns true if the [`RegionSelection`] covers a given region - pub fn contains(&self, region_id: &str) -> bool { + pub fn contains(&self, region_id: &RegionID) -> bool { match self { Self::All => true, Self::Some(regions) => regions.contains(region_id), diff --git a/src/simulation/optimisation.rs b/src/simulation/optimisation.rs index 491f6d128..50e1fc1da 100644 --- a/src/simulation/optimisation.rs +++ b/src/simulation/optimisation.rs @@ -2,14 +2,13 @@ //! //! This is used to calculate commodity flows and prices. use crate::asset::{Asset, AssetID, AssetPool}; -use crate::commodity::BalanceType; +use crate::commodity::{BalanceType, CommodityID}; use crate::model::Model; use crate::process::ProcessFlow; use crate::time_slice::{TimeSliceID, TimeSliceInfo}; use anyhow::{anyhow, Result}; use highs::{HighsModelStatus, RowProblem as Problem, Sense}; use indexmap::IndexMap; -use std::rc::Rc; mod constraints; use constraints::{add_asset_constraints, CapacityConstraintKeys, CommodityBalanceConstraintKeys}; @@ -34,10 +33,15 @@ pub struct VariableMap(IndexMap); impl VariableMap { /// Get the [`Variable`] corresponding to the given parameters. - fn get(&self, asset_id: AssetID, commodity_id: &Rc, time_slice: &TimeSliceID) -> Variable { + fn get( + &self, + asset_id: AssetID, + commodity_id: &CommodityID, + time_slice: &TimeSliceID, + ) -> Variable { let key = VariableMapKey { asset_id, - commodity_id: Rc::clone(commodity_id), + commodity_id: commodity_id.clone(), time_slice: time_slice.clone(), }; @@ -52,13 +56,13 @@ impl VariableMap { #[derive(Eq, PartialEq, Hash)] struct VariableMapKey { asset_id: AssetID, - commodity_id: Rc, + commodity_id: CommodityID, time_slice: TimeSliceID, } impl VariableMapKey { /// Create a new [`VariableMapKey`] - fn new(asset_id: AssetID, commodity_id: Rc, time_slice: TimeSliceID) -> Self { + fn new(asset_id: AssetID, commodity_id: CommodityID, time_slice: TimeSliceID) -> Self { Self { asset_id, commodity_id, @@ -87,7 +91,7 @@ impl Solution<'_> { /// An iterator of tuples containing an asset ID, commodity, time slice and flow. pub fn iter_commodity_flows_for_assets( &self, - ) -> impl Iterator, &TimeSliceID, f64)> { + ) -> impl Iterator { self.variables .0 .keys() @@ -98,7 +102,7 @@ impl Solution<'_> { /// Keys and dual values for commodity balance constraints. pub fn iter_commodity_balance_duals( &self, - ) -> impl Iterator, &TimeSliceID, f64)> { + ) -> impl Iterator { // Each commodity balance constraint applies to a particular time slice // selection (depending on time slice level). Where this covers multiple timeslices, // we return the same dual for each individual timeslice. @@ -223,11 +227,8 @@ fn add_variables( problem.add_column(coeff, 0.0..) }; - let key = VariableMapKey::new( - asset.id, - Rc::clone(&flow.commodity.id), - time_slice.clone(), - ); + let key = + VariableMapKey::new(asset.id, flow.commodity.id.clone(), time_slice.clone()); let existing = variables.0.insert(key, var).is_some(); assert!(!existing, "Duplicate entry for var"); diff --git a/src/simulation/optimisation/constraints.rs b/src/simulation/optimisation/constraints.rs index f77ea45c3..0abe29e7e 100644 --- a/src/simulation/optimisation/constraints.rs +++ b/src/simulation/optimisation/constraints.rs @@ -1,6 +1,6 @@ //! Code for adding constraints to the dispatch optimisation problem. use crate::asset::{AssetID, AssetPool}; -use crate::commodity::CommodityType; +use crate::commodity::{CommodityID, CommodityType}; use crate::model::Model; use crate::time_slice::{TimeSliceID, TimeSliceInfo, TimeSliceSelection}; use highs::RowProblem as Problem; @@ -9,7 +9,7 @@ use std::rc::Rc; use super::VariableMap; /// Indicates the commodity ID and time slice selection covered by each commodity balance constraint -pub type CommodityBalanceConstraintKeys = Vec<(Rc, TimeSliceSelection)>; +pub type CommodityBalanceConstraintKeys = Vec<(CommodityID, TimeSliceSelection)>; /// Indicates the asset ID and time slice covered by each capacity constraint pub type CapacityConstraintKeys = Vec<(AssetID, TimeSliceID)>; @@ -135,7 +135,7 @@ fn add_commodity_balance_constraints( problem.add_row(rhs..=rhs, terms.drain(0..)); // Keep track of the order in which constraints were added - keys.push((Rc::clone(&commodity.id), ts_selection)); + keys.push((commodity.id.clone(), ts_selection)); } } } diff --git a/src/simulation/prices.rs b/src/simulation/prices.rs index 3b2dad265..31e0f49b2 100644 --- a/src/simulation/prices.rs +++ b/src/simulation/prices.rs @@ -1,15 +1,15 @@ //! Code for updating the simulation state. use super::optimisation::Solution; use crate::asset::AssetPool; +use crate::commodity::CommodityID; use crate::model::Model; use crate::time_slice::{TimeSliceID, TimeSliceInfo}; use indexmap::IndexMap; use log::warn; use std::collections::{HashMap, HashSet}; -use std::rc::Rc; /// A combination of commodity ID and time slice -type CommodityPriceKey = (Rc, TimeSliceID); +type CommodityPriceKey = (CommodityID, TimeSliceID); /// A map relating commodity ID + time slice to current price (endogenous) #[derive(Default)] @@ -46,7 +46,11 @@ impl CommodityPrices { /// # Returns /// /// The set of commodities for which prices were added. - fn add_from_solution(&mut self, solution: &Solution, assets: &AssetPool) -> HashSet> { + fn add_from_solution( + &mut self, + solution: &Solution, + assets: &AssetPool, + ) -> HashSet { let mut commodities_updated = HashSet::new(); // Calculate highest capacity dual for each commodity/timeslice @@ -77,10 +81,10 @@ impl CommodityPrices { // Add the highest capacity dual for each commodity/timeslice to each commodity balance dual for (commodity_id, time_slice, dual) in solution.iter_commodity_balance_duals() { - let key = (Rc::clone(commodity_id), time_slice.clone()); + let key = (commodity_id.clone(), time_slice.clone()); let price = dual + highest_duals.get(&key).unwrap_or(&0.0); self.insert(commodity_id, time_slice, price); - commodities_updated.insert(Rc::clone(commodity_id)); + commodities_updated.insert(commodity_id.clone()); } commodities_updated @@ -94,7 +98,7 @@ impl CommodityPrices { /// * `time_slice_info` - Information about time slices fn add_remaining<'a, I>(&mut self, commodity_ids: I, time_slice_info: &TimeSliceInfo) where - I: Iterator>, + I: Iterator, { for commodity_id in commodity_ids { warn!("No prices calculated for commodity {commodity_id}; setting to NaN"); @@ -105,8 +109,8 @@ impl CommodityPrices { } /// Insert a price for the given commodity and time slice - pub fn insert(&mut self, commodity_id: &Rc, time_slice: &TimeSliceID, price: f64) { - let key = (Rc::clone(commodity_id), time_slice.clone()); + pub fn insert(&mut self, commodity_id: &CommodityID, time_slice: &TimeSliceID, price: f64) { + let key = (commodity_id.clone(), time_slice.clone()); self.0.insert(key, price); } @@ -115,7 +119,7 @@ impl CommodityPrices { /// # Returns /// /// An iterator of tuples containing commodity ID, time slice and price. - pub fn iter(&self) -> impl Iterator, &TimeSliceID, f64)> { + pub fn iter(&self) -> impl Iterator { self.0 .iter() .map(|((commodity_id, ts), price)| (commodity_id, ts, *price))