Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions flexml/_feature_engineer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,49 @@ def transform(self, X):
A DataFrame with the specified columns dropped
"""
return X.drop(columns=self.drop_columns, axis=1, errors='ignore')


class CategoricalTypeConverter(BaseEstimator, TransformerMixin):
    """
    A transformer to convert categorical columns to 'category' dtype.
    Used for tree-based models that support native categorical features.
    Supports ordered categories via ordinal_encode_map.

    Parameters
    ----------
    categorical_columns : list of str, optional
        Names of the columns to convert. Columns missing from the data are skipped.

    ordinal_encode_map : dict of {str: list}, optional
        Maps a column name to its categories in ascending order. Columns listed here
        (and in categorical_columns) become ordered categoricals.

    Attributes
    ----------
    categories_ : dict
        Set by fit. Maps each unordered categorical column found in the fitted data
        to the categories observed in it.
    """
    def __init__(self, categorical_columns: Optional[List[str]] = None, ordinal_encode_map: Optional[Dict[str, List]] = None):
        # Keep original values for sklearn clone compatibility
        self.categorical_columns = categorical_columns
        self.ordinal_encode_map = ordinal_encode_map

    def fit(self, X, y=None):
        """
        Records the categories of every unordered categorical column present in X.

        The categories are remembered so that transform encodes later data against
        the same category set (and therefore the same integer codes) instead of
        re-inferring them from whatever data it is given.

        Returns
        -------
        self
        """
        ordinal_map = self.ordinal_encode_map or {}
        self.categories_ = {
            col: pd.Categorical(X[col]).categories
            for col in (self.categorical_columns or [])
            if col in X.columns and col not in ordinal_map
        }
        return self

    def transform(self, X):
        """
        Converts specified categorical columns to 'category' dtype.
        For columns in ordinal_encode_map, creates ordered categorical with specified order.
        For other columns, the categories recorded by fit are reused; values that are
        not among them become NaN.

        Returns
        -------
        pd.DataFrame
            A DataFrame with categorical columns converted to 'category' dtype
        """
        X = X.copy()
        ordinal_map = self.ordinal_encode_map or {}
        # Empty when transform is called without a prior fit; the per-column
        # fallback below then infers categories from X itself.
        fitted_categories = getattr(self, 'categories_', {})

        for col in self.categorical_columns or []:
            if col not in X.columns:
                continue

            if col in ordinal_map:
                # Values outside the given categories (including missing values,
                # which stringify to 'nan') are mapped to NaN by pd.Categorical.
                categories = [str(c) for c in ordinal_map[col]]
                X[col] = pd.Categorical(X[col].astype(str), categories=categories, ordered=True)
            elif col in fitted_categories:
                # Regular unordered categorical with the categories seen during fit
                X[col] = pd.Categorical(X[col], categories=fitted_categories[col])
            else:
                # Column was not available during fit: infer categories from X
                X[col] = X[col].astype('category')
        return X


class ColumnImputer(BaseEstimator, TransformerMixin):
Expand All @@ -47,6 +90,8 @@ def fit(self, X, y=None):
return self

def transform(self, X) -> pd.DataFrame:
X = X.copy() # Avoid modifying original data

# Categorical columns are converted to string
categorical_cols = X.select_dtypes(exclude=['number']).columns
X[categorical_cols] = X[categorical_cols].astype(str)
Expand Down Expand Up @@ -104,9 +149,17 @@ def __init__(
self.ordinal_encoders = {}

def fit(self, X, y=None):
# Categorical columns are converted to string
X = X.copy() # Avoid modifying original data

# First, convert all non-numeric columns to string (original behavior)
categorical_cols = X.select_dtypes(exclude=['number']).columns
X[categorical_cols] = X[categorical_cols].astype(str)

# Also ensure columns in encoding_method_mapper are string
# (handles case where column is numeric but needs encoding)
for col in self.encoding_method_mapper.keys():
if col in X.columns and col not in categorical_cols:
X[col] = X[col].astype(str)

for col, method in self.encoding_method_mapper.items():
if method == "label_encoder":
Expand All @@ -133,9 +186,16 @@ def fit(self, X, y=None):
return self

def transform(self, X) -> pd.DataFrame:
# Categorical columns are converted to string
X = X.copy() # Avoid modifying original data

# First, convert all non-numeric columns to string (original behavior)
categorical_cols = X.select_dtypes(exclude=['number']).columns
X[categorical_cols] = X[categorical_cols].astype(str)

# Also ensure columns in encoding_method_mapper are string
for col in self.encoding_method_mapper.keys():
if col in X.columns and col not in categorical_cols:
X[col] = X[col].astype(str)

for col, method in self.encoding_method_mapper.items():
if method == "label_encoder":
Expand Down
2 changes: 0 additions & 2 deletions flexml/_model_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,6 @@ def _setup_tuning(

* 'tuned_model_evaluation_metric': The evaluation metric that is used to evaluate the tuned model
"""
model_params = None

if isinstance(model, Pipeline):
model = model.named_steps['model']

Expand Down
1 change: 1 addition & 0 deletions flexml/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
)

from flexml.config.supervised_config import (
NATIVE_CATEGORICAL_MODELS,
EVALUATION_METRICS,
TUNING_METRIC_TRANSFORMATIONS,
CROSS_VALIDATION_METHODS,
Expand Down
8 changes: 4 additions & 4 deletions flexml/config/ml_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_ml_models(
KNN_REGRESSION = KNeighborsRegressor(n_jobs=n_jobs)
BAYESIAN_RIDGE_REGRESSION = BayesianRidge()
ADA_BOOST_REGRESSION = AdaBoostRegressor(random_state=random_state)
HIST_GRADIENT_BOOSTING_REGRESSION = HistGradientBoostingRegressor(random_state=random_state)
HIST_GRADIENT_BOOSTING_REGRESSION = HistGradientBoostingRegressor(random_state=random_state, categorical_features="from_dtype")
GRADIENT_BOOSTING_REGRESSION = GradientBoostingRegressor(random_state=random_state)
RANDOM_FOREST_REGRESSION = RandomForestRegressor(random_state=random_state, n_jobs=n_jobs)
EXTRA_TREES_REGRESSION = ExtraTreesRegressor(random_state=random_state, n_jobs=n_jobs)
Expand Down Expand Up @@ -308,8 +308,8 @@ def get_ml_models(

# Quick Classification Models
LOGISTIC_REGRESSION = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
XGBOOST_CLASSIFIER = XGBClassifier(objective=xgb_objective, random_state=random_state, n_jobs=n_jobs)
LIGHTGBM_CLASSIFIER = LGBMClassifier(verbose=-1, random_state=random_state, n_jobs=n_jobs)
XGBOOST_CLASSIFIER = XGBClassifier(enable_categorical=True, objective=xgb_objective, random_state=random_state, n_jobs=n_jobs)
LIGHTGBM_CLASSIFIER = LGBMClassifier(enable_categorical=True, verbose=-1, random_state=random_state, n_jobs=n_jobs)
CATBOOST_CLASSIFIER = CatBoostClassifier(allow_writing_files=False, silent=True, random_seed=random_state, thread_count=n_jobs)
DECISION_TREE_CLASSIFIER = DecisionTreeClassifier(random_state=random_state)
RANDOM_FOREST_CLASSIFIER = RandomForestClassifier(random_state=random_state, n_jobs=n_jobs)
Expand All @@ -318,7 +318,7 @@ def get_ml_models(

# Wide Classification Models
ADA_BOOST_CLASSIFIER = AdaBoostClassifier(random_state=random_state)
HIST_GRADIENT_BOOSTING_CLASSIFIER = HistGradientBoostingClassifier(random_state=random_state)
HIST_GRADIENT_BOOSTING_CLASSIFIER = HistGradientBoostingClassifier(random_state=random_state, categorical_features="from_dtype")
GRADIENT_BOOSTING_CLASSIFIER = GradientBoostingClassifier(random_state=random_state)
EXTRA_TREES_CLASSIFIER = ExtraTreesClassifier(random_state=random_state, n_jobs=n_jobs)
QDA_CLASSIFIER = QuadraticDiscriminantAnalysis()
Expand Down
8 changes: 8 additions & 0 deletions flexml/config/supervised_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# Class names of the models that support native categorical features
NATIVE_CATEGORICAL_MODELS = {
    # CatBoost
    'CatBoostRegressor',
    'CatBoostClassifier',
    # LightGBM
    'LGBMRegressor',
    'LGBMClassifier',
    # XGBoost
    'XGBRegressor',
    'XGBClassifier',
    # scikit-learn histogram gradient boosting
    'HistGradientBoostingRegressor',
    'HistGradientBoostingClassifier',
}

# Regression & Classification Evaluation Metrics
EVALUATION_METRICS = {
"Regression": {"DEFAULT": "R2",
Expand Down
15 changes: 6 additions & 9 deletions flexml/helpers/plot_model_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,15 +383,8 @@ def plot_shap(
or an error message if an error occurs during the process.
"""
try:
# Check if model is a tree-based model
model_type = str(type(model))

tree_based_models = [
"RandomForest", "GradientBoosting", "AdaBoost",
"HistGradientBoosting", "DecisionTree", "ExtraTrees",
"XGB", "CatBoost", "LGBM"
]
is_tree_based = any(model_name in model_type for model_name in tree_based_models)
# Check if the model is tree-based
is_tree_based = hasattr(model, 'feature_importances_')

if is_tree_based:
explainer = shap.TreeExplainer(model)
Expand All @@ -410,6 +403,10 @@ def plot_shap(
if shap_type == 'shap_summary':
shap.summary_plot(shap_values, X_test)
elif shap_type == 'shap_violin':
# While shap summary is okay with categorical columns, violin plot is not
cat_cols = X_test.select_dtypes(include=['category']).columns
for col in cat_cols:
X_test[col] = X_test[col].cat.codes
shap.plots.violin(shap_values, X_test)
else:
return f"Invalid shap_type: {shap_type}"
Expand Down
Loading