diff --git a/src/muse/sectors/subsector.py b/src/muse/sectors/subsector.py index 3b2ddee5..c411fbe7 100644 --- a/src/muse/sectors/subsector.py +++ b/src/muse/sectors/subsector.py @@ -175,9 +175,7 @@ def factory( if hasattr(settings, "commodities"): commodities = settings.commodities else: - commodities = aggregate_enduses( - [agent.assets for agent in agents], technologies - ) + commodities = aggregate_enduses(technologies) # len(commodities) == 0 may happen only if # we run only one region or all regions have no outputs @@ -208,21 +206,10 @@ def factory( ) -def aggregate_enduses( - assets: Sequence[xr.Dataset | xr.DataArray], technologies: xr.Dataset -) -> Sequence[str]: - """Aggregate enduse commodities for input assets. - - This function is meant as a helper to figure out the commodities attached to a group - of agents. - """ +def aggregate_enduses(technologies: xr.Dataset) -> Sequence[str]: + """Aggregate enduse commodities for a set of technologies.""" from muse.commodities import is_enduse - techs = set.union(*(set(data.technology.values) for data in assets)) - outputs = technologies.fixed_outputs.sel( - commodity=is_enduse(technologies.comm_usage), technology=list(techs) - ) - - return outputs.commodity.sel( - commodity=outputs.any([u for u in outputs.dims if u != "commodity"]) - ).values.tolist() + return technologies.sel( + commodity=is_enduse(technologies.comm_usage) + ).commodity.values.tolist() diff --git a/tests/test_subsector.py b/tests/test_subsector.py index 23406139..efcbc1fb 100644 --- a/tests/test_subsector.py +++ b/tests/test_subsector.py @@ -36,9 +36,7 @@ def test_subsector_investing_aggregation(): agents = list(examples.sector(sname, model).agents) sector = next(sector for sector in mca.sectors if sector.name == sname) technologies = sector.technologies - commodities = aggregate_enduses( - (agent.assets for agent in agents), technologies - ) + commodities = aggregate_enduses(technologies) market = mca.market.sel( commodity=technologies.commodity, region=technologies.region ).interp(year=[2020, 2025]) @@ -89,7 +87,7 @@ def test_subsector_noninvesting_aggregation(market, model, technologies, tmp_pat param["decision"]["parameters"] = ("ALCOE", False, 1) param.pop("quantity") agents = [create_agent(technologies=technologies, **param) for param in params] - commodities = aggregate_enduses((agent.assets for agent in agents), technologies) + commodities = aggregate_enduses(technologies) subsector = Subsector( agents,