diff --git a/segregation/inference/comparative.py b/segregation/inference/comparative.py index d3da57ad..3252ecba 100644 --- a/segregation/inference/comparative.py +++ b/segregation/inference/comparative.py @@ -338,9 +338,13 @@ def _prepare_random_label(seg_class_1, seg_class_2): seg_class_1.group_pop_var, seg_class_1.total_pop_var, "grouping_variable", + data_1.geometry.name if hasattr(data_1, "geometry") else None, ] ] - data_1.columns = ["group", "total", "grouping_variable"] + cols = ["group", "total", "grouping_variable"] + if hasattr(data_1, "geometry") : + cols += 'geometry', + data_1.columns = cols data_2.loc[:, (seg_class_2.group_pop_var, seg_class_2.total_pop_var)] = ( data_2.loc[:, (seg_class_2.group_pop_var, seg_class_2.total_pop_var)] @@ -352,9 +356,10 @@ def _prepare_random_label(seg_class_1, seg_class_2): seg_class_2.group_pop_var, seg_class_2.total_pop_var, "grouping_variable", + data_2.geometry.name if hasattr(data_2, "geometry") else None, ] ] - data_2.columns = ["group", "total", "grouping_variable"] + data_2.columns = cols stacked_data = pd.concat([data_1, data_2], axis=0) @@ -367,11 +372,9 @@ def _prepare_random_label(seg_class_1, seg_class_2): if seg_class_1.groups != seg_class_2.groups: raise ValueError("MultiGroup groups should be the same") - # geometry has been discarded, but the CRS can cause concatenation problems - data_1.crs = None - data_2.crs = None + stacked_data = pd.concat([data_1, data_2], ignore_index=True) - return stacked_data + return stacked_data.reset_index(drop=True) def _estimate_random_label_difference(data): @@ -403,9 +406,14 @@ def _estimate_random_label_difference(data): stacked_data["grouping_variable"] = grouping else: - stacked_data["grouping_variable"] = np.random.permutation( - stacked_data["grouping_variable"].values - ) + shuffled_indices = np.random.permutation(len(stacked_data)) + if groups: + stacked_data[groups] = stacked_data.iloc[shuffled_indices][groups].values + + else: + # these two cols need to be permuted together to maintain the relationship between group and total population counts + stacked_data[['group', 'total']] = stacked_data.iloc[shuffled_indices][['group','total']].values + stacked_data_1 = stacked_data[stacked_data["grouping_variable"] == "Group_1"] stacked_data_2 = stacked_data[stacked_data["grouping_variable"] == "Group_2"]