diff --git a/selection/learning/learners.py b/selection/learning/learners.py index 0941f9054..cc9963376 100644 --- a/selection/learning/learners.py +++ b/selection/learning/learners.py @@ -209,9 +209,20 @@ def learn(self, """ learning_selection, learning_T, random_algorithm = self.generate_data(B=B, - check_selection=check_selection) + check_selection=check_selection) + print('prob(select): ', np.mean(learning_selection, 0)) - conditional_laws = fit_probability(learning_T, learning_selection, **fit_args) + no_selection = np.mean(learning_selection, 0) == 1 + learned_laws = fit_probability(learning_T, + learning_selection[:, ~no_selection], + **fit_args) + conditional_laws = [] + for i in range(learning_selection.shape[1]): + if no_selection[i]: + print('just a constant', i, learning_selection.shape[1]) + conditional_laws.append(lambda t: np.ones(np.asarray(t).shape[0])) + else: + conditional_laws.append(learned_laws.pop(0)) return conditional_laws, (learning_T, learning_selection) class sparse_mixture_learner(mixture_learner): diff --git a/setup.py b/setup.py index c8fc1e0ec..c5658e110 100755 --- a/setup.py +++ b/setup.py @@ -108,6 +108,7 @@ def main(**extra_args): requires=info.REQUIRES, provides=info.PROVIDES, packages = ['selection', + 'selection.learning', 'selection.utils', 'selection.truncated', 'selection.truncated.tests',