diff --git a/src/muse/agents/agent.py b/src/muse/agents/agent.py index eddc6600..15801bd9 100644 --- a/src/muse/agents/agent.py +++ b/src/muse/agents/agent.py @@ -315,7 +315,9 @@ def next( # Calculate the search space search_space = ( - self.search_rules(self, demand, technologies, market).fillna(0).astype(int) + self.search_rules(self, demand, technologies=technologies, market=market) + .fillna(0) + .astype(int) ) # Skip forward if the search space is empty diff --git a/src/muse/examples.py b/src/muse/examples.py index 3aa3e287..5b141482 100644 --- a/src/muse/examples.py +++ b/src/muse/examples.py @@ -420,7 +420,7 @@ def _trade_search_space(sector: str, model: str = "default") -> xr.DataArray: a.uuid: cast(Agent, a).search_rules( agent=a, demand=market.consumption.isel(year=0, drop=True), - technologies=loaded_sector.technologies, + technologies=loaded_sector.technologies.isel(year=0, drop=True), market=market, ) for a in loaded_sector.agents diff --git a/src/muse/filters.py b/src/muse/filters.py index 42473108..1b9d351a 100644 --- a/src/muse/filters.py +++ b/src/muse/filters.py @@ -20,8 +20,10 @@ def search_space_filter( agent: Agent, search_space: xr.DataArray, - technologies: xr.Dataset, - market: xr.Dataset + *, + additional_argument1, + additional_argument2, + **kwargs, ) -> xr.DataArray: pass @@ -29,10 +31,8 @@ def search_space_filter( agent: the agent relevant to the search space. The filters may need to query the agent for parameters, e.g. the current year, the tolerance, etc. search_space: the current search space. - technologies: A data set characterising the technologies from which the - agent can draw assets. - market: Market variables, such as prices or current capacity and retirement - profile. + any additional arguments must follow the `*` argument, and must be passed as keyword + arguments. Returns: A new search space with the same data-type as the input search-space, but @@ -53,7 +53,6 @@ def search_space_initializer( agent: Agent, demand: xr.DataArray, technologies: xr.Dataset, - market: xr.Dataset ) -> xr.DataArray: pass @@ -64,8 +63,6 @@ def search_space_initializer( assets). technologies: A data set characterising the technologies from which the agent can draw assets. - market: Market variables, such as prices or current capacity and retirement - profile. Returns: An initial search space @@ -103,14 +100,14 @@ def search_space_initializer( from muse.agents import Agent from muse.registration import registrator -SSF_SIGNATURE = Callable[[Agent, xr.DataArray, xr.Dataset, xr.Dataset], xr.DataArray] +SSF_SIGNATURE = Callable[[Agent, xr.DataArray], xr.DataArray] """ Search space filter signature """ SEARCH_SPACE_FILTERS: MutableMapping[str, SSF_SIGNATURE] = {} """Filters for selecting technology search spaces.""" -SSI_SIGNATURE = Callable[[Agent, xr.DataArray, xr.Dataset, xr.Dataset], xr.DataArray] +SSI_SIGNATURE = Callable[[Agent, xr.DataArray], xr.DataArray] """ Search space initializer signature """ SEARCH_SPACE_INITIALIZERS: MutableMapping[str, SSI_SIGNATURE] = {} @@ -131,10 +128,12 @@ def register_filter(function: SSF_SIGNATURE) -> Callable: from functools import wraps @wraps(function) - def decorated( - agent: Agent, search_space: xr.DataArray, *args, **kwargs - ) -> xr.DataArray: - result = function(agent, search_space, *args, **kwargs) # type: ignore + def decorated(agent: Agent, search_space: xr.DataArray, **kwargs) -> xr.DataArray: + # Check inputs + if "technologies" in kwargs: + assert "year" not in kwargs["technologies"].dims + + result = function(agent, search_space, **kwargs) # type: ignore if isinstance(result, xr.DataArray): result.name = search_space.name return result @@ -150,8 +149,12 @@ def register_initializer(function: SSI_SIGNATURE) -> Callable: from functools import wraps @wraps(function) - def decorated(agent: Agent, *args, **kwargs) -> xr.DataArray: - result = function(agent, *args, **kwargs) # type: ignore + def decorated(agent: Agent, demand: xr.DataArray, **kwargs) -> xr.DataArray: + # Check inputs + if "technologies" in kwargs: + assert "year" not in kwargs["technologies"].dims + + result = function(agent, demand, **kwargs) # type: ignore if isinstance(result, xr.DataArray): result.name = "search_space" return result @@ -221,11 +224,11 @@ def factory( ), ] - def filters(agent: Agent, demand: xr.DataArray, *args, **kwargs) -> xr.DataArray: + def filters(agent: Agent, demand: xr.DataArray, **kwargs) -> xr.DataArray: """Applies a series of filter to determine the search space.""" - result = functions[0](agent, demand, *args, **kwargs) + result = functions[0](agent, demand, **kwargs) for function in functions[1:]: - result = function(agent, result, *args, **kwargs) + result = function(agent, result, **kwargs) return result return filters @@ -235,8 +238,8 @@ def filters(agent: Agent, demand: xr.DataArray, *args, **kwargs) -> xr.DataArray def same_enduse( agent: Agent, search_space: xr.DataArray, + *, technologies: xr.Dataset, - *args, **kwargs, ) -> xr.DataArray: """Only allow for technologies with at least the same end-use.""" @@ -244,7 +247,6 @@ def same_enduse( tech_enduses = agent.filter_input( technologies.fixed_outputs, - year=agent.year, commodity=is_enduse(technologies.comm_usage), ) tech_enduses = (tech_enduses > 0).astype(int).rename(technology="replacement") @@ -253,14 +255,14 @@ def same_enduse( @register_filter(name="all") -def identity(agent: Agent, search_space: xr.DataArray, *args, **kwargs) -> xr.DataArray: +def identity(agent: Agent, search_space: xr.DataArray, **kwargs) -> xr.DataArray: """Returns search space as given.""" return search_space @register_filter(name="similar") def similar_technology( - agent: Agent, search_space: xr.DataArray, technologies: xr.Dataset, *args, **kwargs + agent: Agent, search_space: xr.DataArray, *, technologies: xr.Dataset, **kwargs ): """Filters technologies with the same type.""" tech_type = agent.filter_input(technologies.tech_type) @@ -271,7 +273,7 @@ def similar_technology( @register_filter(name="fueltype") def same_fuels( - agent: Agent, search_space: xr.DataArray, technologies: xr.Dataset, *args, **kwargs + agent: Agent, search_space: xr.DataArray, *, technologies: xr.Dataset, **kwargs ): """Filters technologies with the same fuel type.""" fuel = agent.filter_input(technologies.fuel) @@ -284,8 +286,9 @@ def same_fuels( def currently_existing_tech( agent: Agent, search_space: xr.DataArray, - technologies: xr.Dataset, + *, market: xr.Dataset, + **kwargs, ) -> xr.DataArray: """Only consider technologies that currently exist in the market. @@ -308,8 +311,9 @@ def currently_existing_tech( def currently_referenced_tech( agent: Agent, search_space: xr.DataArray, - technologies: xr.Dataset, + *, market: xr.Dataset, + **kwargs, ) -> xr.DataArray: """Only consider technologies that are currently referenced in the market. @@ -327,9 +331,8 @@ def currently_referenced_tech( def maturity( agent: Agent, search_space: xr.DataArray, - technologies: xr.Dataset, + *, market: xr.Dataset, - enduse_label: str = "service", **kwargs, ) -> xr.DataArray: """Only allows technologies that have achieve a given market share. @@ -359,14 +362,13 @@ def maturity( def spend_limit( agent: Agent, search_space: xr.DataArray, + *, technologies: xr.Dataset, - market: xr.Dataset, - enduse_label: str = "service", **kwargs, ) -> xr.DataArray: """Only allows technologies with a unit capital cost lower than the spend limit.""" limit = agent.spend_limit - unit_capex = agent.filter_input(technologies.cap_par, year=agent.year) + unit_capex = agent.filter_input(technologies.cap_par) condition = (unit_capex <= limit).rename("spend_limit") techs = ( condition.technology.where(condition, drop=True).drop_vars("technology").values @@ -387,8 +389,6 @@ def spend_limit( def compress( agent: Agent, search_space: xr.DataArray, - technologies: xr.Dataset, - market: xr.Dataset, **kwargs, ) -> xr.DataArray: """Compress search space to include only potential technologies. @@ -411,8 +411,6 @@ def compress( def reduce_asset( agent: Agent, search_space: xr.DataArray, - technologies: xr.Dataset, - market: xr.Dataset, **kwargs, ) -> xr.DataArray: """Reduce over assets.""" @@ -423,8 +421,6 @@ def reduce_asset( def with_asset_technology( agent: Agent, search_space: xr.DataArray, - technologies: xr.Dataset, - market: xr.Dataset, **kwargs, ) -> xr.DataArray: """Search space *also* contains its asset technology for each asset.""" @@ -433,7 +429,7 @@ def with_asset_technology( @register_initializer(name="from_techs") def initialize_from_technologies( - agent: Agent, demand: xr.DataArray, technologies: xr.Dataset, *args, **kwargs + agent: Agent, demand: xr.DataArray, *, technologies: xr.Dataset, **kwargs ): """Initialize a search space from existing technologies.""" coords = ( @@ -452,8 +448,8 @@ def initialize_from_technologies( def initialize_from_assets( agent: Agent, demand: xr.DataArray, + *, technologies: xr.Dataset, - *args, coords: Sequence[str] = ("region", "technology"), **kwargs, ): diff --git a/tests/test_filters.py b/tests/test_filters.py index 43c5a01c..40edc423 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -18,6 +18,12 @@ def search_space(retro_agent, technologies): ) +@fixture +def technologies(technologies): + # Filters must take technology data for a single year + return technologies.sel(year=2010) + + @mark.usefixtures("save_registries") def test_filter_registering(): from muse.filters import SEARCH_SPACE_FILTERS @@ -40,7 +46,7 @@ def b_filter(retro_agent, search_space: xr.DataArray): @mark.usefixtures("save_registries") def test_filtering(): @register_initializer - def start(*args, **kwargs): + def start(retro_agent, demand, **kwargs): return list(range(5)) @register_filter @@ -53,24 +59,24 @@ def first(retro_agent, search_space, switch=True, data=None): def second(retro_agent, search_space, switch=True, data=None): return [u for u in search_space if u in data] - sp = start(None, None, None) + sp = start(None, None) assert factory(["start", "first"])(None, sp) == sp[2:] - assert factory(["start", "first"])(None, sp, False) == sp[:2] + assert factory(["start", "first"])(None, sp, switch=False) == sp[:2] assert factory(["start", "second"])(None, sp, data=(1, 3, 5)) == [1, 3] assert factory(["start", "first", "second"])(None, sp, data=(1, 3, 5)) == [3] - assert factory(["start", "first", "second"])(None, sp, False, (1, 3, 5)) == [1] + assert factory(["start", "first", "second"])( + None, sp, switch=False, data=(1, 3, 5) + ) == [1] def test_same_enduse(retro_agent, technologies, search_space): from muse.commodities import is_enduse from muse.filters import same_enduse - result = same_enduse(retro_agent, search_space, technologies) + result = same_enduse(retro_agent, search_space, technologies=technologies) enduses = is_enduse(technologies.comm_usage) - finputs = technologies.sel( - region=retro_agent.region, year=retro_agent.year, commodity=enduses - ) + finputs = technologies.sel(region=retro_agent.region, commodity=enduses) finputs = finputs.fixed_outputs > 0 expected = search_space.copy() @@ -91,7 +97,7 @@ def test_same_enduse(retro_agent, technologies, search_space): def test_similar_tech(retro_agent, search_space, technologies): from muse.filters import similar_technology - actual = similar_technology(retro_agent, search_space, technologies) + actual = similar_technology(retro_agent, search_space, technologies=technologies) assert sorted(actual.dims) == sorted(search_space.dims) tech_type = technologies.tech_type @@ -104,7 +110,7 @@ def test_similar_tech(retro_agent, search_space, technologies): def test_similar_fuels(retro_agent, search_space, technologies): from muse.filters import same_fuels - actual = same_fuels(retro_agent, search_space, technologies) + actual = same_fuels(retro_agent, search_space, technologies=technologies) assert sorted(actual.dims) == sorted(search_space.dims) fuel_type = technologies.fuel @@ -119,14 +125,14 @@ def test_currently_existing(retro_agent, search_space, technologies, agent_marke agent_market.capacity[:] = 0 actual = currently_existing_tech( - retro_agent, search_space, technologies, agent_market + retro_agent, search_space, technologies=technologies, market=agent_market ) assert sorted(actual.dims) == sorted(search_space.dims) assert not actual.any() agent_market.capacity[:] = 1 actual = currently_existing_tech( - retro_agent, search_space, technologies, agent_market + retro_agent, search_space, technologies=technologies, market=agent_market ) assert sorted(actual.dims) == sorted(search_space.dims) in_market = search_space.replacement.isin(agent_market.technology) @@ -141,7 +147,7 @@ def test_currently_existing(retro_agent, search_space, technologies, agent_marke agent_market.capacity[:] = 0 agent_market.capacity.loc[{"technology": agent_market.technology.isin(techs)}] = 1 actual = currently_existing_tech( - retro_agent, search_space, technologies, agent_market + retro_agent, search_space, technologies=technologies, market=agent_market ) assert sorted(actual.dims) == sorted(search_space.dims) assert not actual.sel(replacement=~in_market).any() @@ -188,7 +194,7 @@ def test_init_from_tech(demand_share, technologies, agent_market): agent = namedtuple("DummyAgent", ["tolerance"])(tolerance=1e-8) - space = initialize_from_technologies(agent, demand_share, technologies) + space = initialize_from_technologies(agent, demand_share, technologies=technologies) assert set(space.dims) == {"asset", "replacement"} assert (space.asset.values == demand_share.asset.values).all() assert (space.replacement.values == technologies.technology.values).all()