-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_page.py
More file actions
146 lines (119 loc) · 4.73 KB
/
predict_page.py
File metadata and controls
146 lines (119 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import logging
import pickle
from pathlib import Path

import pandas as pd
import streamlit as st
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline as SklearnPipeline
from sklearn.preprocessing import OneHotEncoder

from processing import load_survey_data, normalize_inputs
class PipelineCompat(SklearnPipeline):
    """Pipeline subclass used to unpickle legacy model files.

    Older pickles may reference ``Pipeline`` from a module other than
    ``sklearn.pipeline`` and may construct it with ``memory``/``verbose``
    passed positionally, which the stock class no longer accepts.

    The constructor lists its parameters explicitly (no ``*args``) because
    scikit-learn's ``get_params`` introspects ``__init__`` and raises
    ``RuntimeError`` on var-positional parameters; ``set_params``, ``clone``
    and ``repr`` all rely on ``get_params``.
    """

    def __init__(self, steps, memory=None, verbose=False, **kwargs):
        # memory/verbose stay positional-or-keyword for legacy callers and are
        # forwarded by keyword to the parent constructor.
        super().__init__(steps=steps, memory=memory, verbose=verbose, **kwargs)

    def __setstate__(self, state):
        """Restore the pipeline from a dict state or a legacy sequence state.

        A sequence of at least four items is read as
        ``(params, steps, memory, verbose, *rest)``: the pipeline is rebuilt
        through ``__init__`` (trailing items are ignored) and ``params`` is
        re-applied with ``set_params`` when it is a dict. A shorter sequence
        must be convertible with ``dict()``. Dict states use the default
        scikit-learn restoration.

        Raises:
            pickle.UnpicklingError: if a sequence state cannot be converted,
                so the caller can fall back instead of getting a broken object.
        """
        if isinstance(state, (list, tuple)):
            if len(state) >= 4:
                params, steps, memory, verbose = state[:4]
                self.__init__(steps=steps, memory=memory, verbose=verbose)
                if isinstance(params, dict):
                    self.set_params(**params)
                return
            try:
                state = dict(state)
            except (TypeError, ValueError) as exc:
                raise pickle.UnpicklingError(
                    f"Unrecognised legacy Pipeline state: {type(state).__name__} "
                    f"of length {len(state)}"
                ) from exc
        # Normal dict state: defer to scikit-learn's own restoration.
        super().__setstate__(state)
class _SafeUnpickler(pickle.Unpickler):
"""Redirect old pickled references to the current helpers."""
def find_class(self, module, name):
if module == "__main__" and name == "normalize_inputs":
return normalize_inputs
if name == "Pipeline":
return PipelineCompat
try:
return super().find_class(module, name)
except ModuleNotFoundError as exc:
# Fallback if the module name was stripped during pickling
if exc.name == "Pipeline":
return PipelineCompat
raise
def _load_pickled_model(path: Path):
    """Unpickle *path* with the compat unpickler and return its pipeline.

    The file is expected to hold a mapping with a ``"pipeline"`` entry (the
    format ``load_model`` writes); lookup fails with an exception otherwise.
    """
    with path.open("rb") as stream:
        saved = _SafeUnpickler(stream).load()
    return saved["pipeline"]
def _train_fresh_model():
    """Fit a small replacement pipeline when the saved pickle is unusable.

    Country and EdLevel are one-hot encoded (unseen categories are ignored
    at predict time) and YearsCode is median-imputed; a 150-tree random
    forest regressor is then fit against the Salary column.
    """
    survey = load_survey_data()
    cat_cols = ["Country", "EdLevel"]
    num_cols = ["YearsCode"]

    numeric_steps = SklearnPipeline([("imputer", SimpleImputer(strategy="median"))])
    column_prep = ColumnTransformer(
        transformers=[
            ("cat", OneHotEncoder(handle_unknown="ignore"), cat_cols),
            ("num", numeric_steps, num_cols),
        ]
    )
    regressor = RandomForestRegressor(n_estimators=150, random_state=42, n_jobs=-1)

    fitted = SklearnPipeline([("preprocess", column_prep), ("model", regressor)])
    fitted.fit(survey[cat_cols + num_cols], survey["Salary"])
    return fitted
@st.cache_resource
def load_model():
    """Return the salary-prediction pipeline, retraining it if necessary.

    ``saved_steps.pkl`` (resolved against the current working directory) is
    unpickled first. If that fails for any reason, the file is removed, a
    replacement is fitted with ``_train_fresh_model`` and written back to the
    same path so a later cold start can reuse it. Problems deleting or saving
    the file are logged rather than raised, so the freshly trained model is
    still returned. The result is cached with ``st.cache_resource``.

    Security note: unpickling can execute code from the file; only deploy
    trusted ``saved_steps.pkl`` files.
    """
    logger = logging.getLogger(__name__)
    path = Path("saved_steps.pkl")
    if path.exists():
        try:
            return _load_pickled_model(path)
        except Exception:
            # Keep the reason visible instead of silently discarding it.
            logger.warning("Could not load %s; retraining the model.", path, exc_info=True)
            # Drop the broken pickle so we don't retry a bad state on next run.
            try:
                path.unlink(missing_ok=True)
            except OSError:
                logger.warning("Could not remove unusable model file %s.", path, exc_info=True)

    model = _train_fresh_model()

    # Dump to a sibling temp file and rename it into place, so a failed or
    # interrupted write never leaves a truncated pickle at ``path``.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with tmp_path.open("wb") as f:
            pickle.dump({"pipeline": model}, f, protocol=pickle.HIGHEST_PROTOCOL)
        tmp_path.replace(path)
    except Exception:
        # Best effort: the UI still works with the in-memory model.
        logger.warning("Could not save the retrained model to %s.", path, exc_info=True)
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            # Leftover temp file is harmless; nothing more to do.
            pass
    return model
@st.cache_data
def dropdown_options():
    """Return the choices for the Country and Education Level selectboxes.

    Returns:
        tuple: ``(countries, education)``. ``countries`` is a sorted list of
        the distinct, non-missing ``Country`` values in the survey data.
        ``education`` is a fixed tuple of education labels.
    """
    df = load_survey_data()
    # Drop missing values first: presumably Country holds strings, and a NaN
    # mixed in would make sorted() raise TypeError (it is not a valid choice
    # anyway).
    countries = sorted(df["Country"].dropna().unique())
    education = (
        "Less than a Bachelors",
        "Bachelor's degree",
        "Master's degree",
        "Doctoral degree",
    )
    return countries, education
def show_predict_page():
    """Render the salary form and, when requested, display an estimate."""
    st.title("Software Developer Salary Prediction 2025")
    st.write("Share a few details to get a quick salary estimate.")

    country_choices, education_choices = dropdown_options()
    country = st.selectbox("Country", country_choices)
    education_level = st.selectbox("Education Level", education_choices)
    experience = st.slider("Years of Experience", 0, 50, 3)

    # Nothing to compute until the user presses the button.
    if not st.button("Calculate salary"):
        return

    row = {
        "Country": [country],
        "EdLevel": [education_level],
        "YearsCode": [experience],
    }
    features = normalize_inputs(pd.DataFrame(row))
    prediction = load_model().predict(features)
    salary = float(prediction[0])
    st.success(f"Estimated salary: ${salary:,.0f} USD")