Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions segregation/inference/comparative.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,13 @@ def _prepare_random_label(seg_class_1, seg_class_2):
seg_class_1.group_pop_var,
seg_class_1.total_pop_var,
"grouping_variable",
data_1.geometry.name if hasattr(data_1, "geometry") else None,
]
]
data_1.columns = ["group", "total", "grouping_variable"]
cols = ["group", "total", "grouping_variable"]
if hasattr(data_1, "geometry") :
cols += 'geometry',
data_1.columns = cols

data_2.loc[:, (seg_class_2.group_pop_var, seg_class_2.total_pop_var)] = (
data_2.loc[:, (seg_class_2.group_pop_var, seg_class_2.total_pop_var)]
Expand All @@ -352,9 +356,10 @@ def _prepare_random_label(seg_class_1, seg_class_2):
seg_class_2.group_pop_var,
seg_class_2.total_pop_var,
"grouping_variable",
data_2.geometry.name if hasattr(data_2, "geometry") else None,
]
]
data_2.columns = ["group", "total", "grouping_variable"]
data_2.columns = cols

stacked_data = pd.concat([data_1, data_2], axis=0)

Expand All @@ -367,11 +372,9 @@ def _prepare_random_label(seg_class_1, seg_class_2):

if seg_class_1.groups != seg_class_2.groups:
raise ValueError("MultiGroup groups should be the same")
# geometry has been discarded, but the CRS can cause concatenation problems
data_1.crs = None
data_2.crs = None

stacked_data = pd.concat([data_1, data_2], ignore_index=True)
return stacked_data
return stacked_data.reset_index(drop=True)


def _estimate_random_label_difference(data):
Expand Down Expand Up @@ -403,9 +406,14 @@ def _estimate_random_label_difference(data):
stacked_data["grouping_variable"] = grouping

else:
stacked_data["grouping_variable"] = np.random.permutation(
stacked_data["grouping_variable"].values
)
shuffled_indices = np.random.permutation(len(stacked_data))
if groups:
stacked_data[groups] = stacked_data.iloc[shuffled_indices][groups].values

else:
# these two cols need to be permuted together to maintain the relationship between group and total population counts
stacked_data[['group', 'total']] = stacked_data.iloc[shuffled_indices][['group','total']].values


stacked_data_1 = stacked_data[stacked_data["grouping_variable"] == "Group_1"]
stacked_data_2 = stacked_data[stacked_data["grouping_variable"] == "Group_2"]
Expand Down
Loading