Skip to content

Commit 00b66cc

Browse files
committed
Update simple pytorch model to use BaseModel class, and to optionally subset to particular amenities.
1 parent bbe6caa commit 00b66cc

1 file changed

Lines changed: 52 additions & 42 deletions

File tree

exploratory/models/pytorch_simple.py

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,27 @@
1414
import plotnine as gg
1515
from pathlib import Path
1616

17+
from openpois.models.base_model import BaseModel, EventRate
18+
1719
# Globals
1820
DATA_VERSION = "20260129"
1921
MODEL_VERSION = "20260212"
2022
DATA_DIR = Path("~/data/openpois").expanduser() / DATA_VERSION
2123
MODEL_DIR = Path("~/data/openpois").expanduser() / MODEL_VERSION
2224
TAG_KEY = "name"
25+
GROUP_KEY = "leisure"
26+
GROUP_VALUES = ["park"]
2327

2428
# Load data
2529
observations_df = pd.read_csv(DATA_DIR / f"osm_observations_{TAG_KEY}.csv")
2630

2731
# Ensure model directory exists
2832
MODEL_DIR.mkdir(parents = True, exist_ok = True)
29-
33+
model_suffix = f"_simple_{TAG_KEY}"
34+
if GROUP_KEY is not None:
35+
model_suffix += f"_{GROUP_KEY}"
36+
if GROUP_VALUES is not None:
37+
model_suffix += f"_{'-'.join(GROUP_VALUES)}"
3038
# Device setup
3139
DTYPE = torch.float64
3240
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -36,6 +44,17 @@
3644

3745
## Input data preparation --------------------------------------------------------------->
3846

47+
# If a group key was set, subset to those observations
48+
if GROUP_KEY is not None:
49+
keep_ids = observations_df.dropna(subset = [GROUP_KEY]).id.unique().tolist()
50+
observations_df = observations_df.query('id in @keep_ids')
51+
# If group values were set, further subset to observations matching those values
52+
if GROUP_VALUES is not None:
53+
keep_ids = observations_df.loc[
54+
observations_df[GROUP_KEY].isin(GROUP_VALUES), 'id'
55+
].unique().tolist()
56+
observations_df = observations_df.query('id in @keep_ids')
57+
3958
timestamp_cols = ['obs_timestamp', 'last_obs_timestamp', 'last_tag_timestamp']
4059
for timestamp_col in timestamp_cols:
4160
observations_df[timestamp_col] = pd.to_datetime(observations_df[timestamp_col])
@@ -52,51 +71,36 @@
5271
## Define model ------------------------------------------------------------------------->
5372

5473
# Only parameters need requires_grad=True; data tensors must not, or memory explodes
55-
X = torch.tensor(obs_sub[['tag_years']].values, dtype=DTYPE, device=DEVICE)
5674
y = torch.tensor(obs_sub['changed'].values, dtype=DTYPE, device=DEVICE)
57-
58-
# Estimand: lambda, the rate parameter that is always positive
59-
omega = torch.tensor(
75+
X = torch.zeros(obs_sub.shape[0], 1, dtype=DTYPE, device=DEVICE)
76+
t1 = torch.zeros(obs_sub.shape[0], 1, dtype=DTYPE, device=DEVICE)
77+
t2 = torch.tensor(obs_sub[['tag_years']].values, dtype=DTYPE, device=DEVICE)
78+
# Estimand: (log) lambda, log of the rate parameter
79+
starting_params = torch.tensor(
6080
np.array([0.0]),
61-
dtype=DTYPE,
62-
device=DEVICE,
63-
requires_grad=True,
81+
dtype = DTYPE,
82+
device = DEVICE,
83+
requires_grad = True,
6484
)
6585

66-
# Small epsilon to avoid log(0) and log(1-p) = -inf -> NaN
67-
def nll_torchmin(params, y, X, DELTA = 1e-6, EPSILON = 1e-7):
68-
log_lambda = params[0].clamp(-20.0, 20.0) # keep lambda in [2e-9, 5e8]
69-
lambda_ = torch.exp(log_lambda)
70-
# X is (n,1); ensure positive so p is in (0,1)
71-
x = X.clamp(min = DELTA)
72-
p = (
73-
(1.0 - torch.exp(-lambda_ * x))
74-
.squeeze(-1)
75-
.clamp(min = EPSILON, max = 1.0 - EPSILON)
76-
)
77-
ll = torch.sum(y * torch.log(p) + (1.0 - y) * torch.log(1.0 - p))
78-
return -ll
79-
80-
model_fit = torchmin.minimize(
81-
fun = lambda params: nll_torchmin(params = params, y = y, X = X),
82-
x0 = omega,
83-
method = 'l-bfgs',
84-
tol = 1e-5,
85-
disp = True,
86-
)
87-
88-
# Prepare model results
89-
hessian_ = torch.autograd.functional.hessian(
90-
lambda params: nll_torchmin(params, y, X),
91-
model_fit.x
86+
def simple_model_fun(params, covariates = None):
87+
return torch.exp(params)
88+
89+
simple_model = BaseModel(
90+
event_rate = EventRate(
91+
type = 'constant',
92+
fun = simple_model_fun,
93+
),
94+
params = starting_params,
95+
covariates = X,
96+
target = y,
97+
t1 = t1,
98+
t2 = t2,
99+
verbose = True
92100
)
93-
se_torch_ = torch.sqrt(torch.linalg.diagonal(torch.linalg.inv(hessian_)))
101+
simple_model.fit()
94102

95-
m1 = pd.DataFrame({
96-
'parameter': ['log_lambda'],
97-
'estimate': model_fit.x.data.cpu().numpy(),
98-
'std_err': se_torch_.data.cpu().numpy(),
99-
})
103+
m1 = simple_model.get_results().assign(parameter = 'log_lambda')
100104
m2 = (
101105
m1
102106
.copy()
@@ -108,8 +112,14 @@ def nll_torchmin(params, y, X, DELTA = 1e-6, EPSILON = 1e-7):
108112
)
109113
model_results = pd.concat([m1, m2])
110114

115+
predictions = simple_model.predict(
116+
t2 = torch.tensor(np.arange(11), dtype = DTYPE, device = DEVICE),
117+
covariates = None,
118+
).assign(units = 'years')
119+
predictions.to_csv(MODEL_DIR / f"predictions{model_suffix}.csv", index = False)
120+
111121

112122
## Save model results ------------------------------------------------------------------->
113123

114-
model_results.to_csv(MODEL_DIR / f"fitted_params_{TAG_KEY}.csv", index = False)
115-
torch.save(model_fit, MODEL_DIR / f"fitted_params_{TAG_KEY}.pt")
124+
model_results.to_csv(MODEL_DIR / f"fitted_params{model_suffix}.csv", index = False)
125+
torch.save(simple_model, MODEL_DIR / f"fitted_params{model_suffix}.pt")

0 commit comments

Comments
 (0)