From 20705704ae6b2fdd5e84c27c8f59fdb686a7ba80 Mon Sep 17 00:00:00 2001 From: Logan Ward Date: Thu, 7 Aug 2025 10:23:16 -0400 Subject: [PATCH 1/3] Update to use the new HF docstring tools --- safe/trainer/model.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/safe/trainer/model.py b/safe/trainer/model.py index 9d84d6d..dfec626 100644 --- a/safe/trainer/model.py +++ b/safe/trainer/model.py @@ -5,12 +5,9 @@ from torch.nn import CrossEntropyLoss, MSELoss from transformers import GPT2DoubleHeadsModel, PretrainedConfig from transformers.activations import get_activation +from transformers.utils import auto_docstring from transformers.models.gpt2.modeling_gpt2 import ( - _CONFIG_FOR_DOC, - GPT2_INPUTS_DOCSTRING, GPT2DoubleHeadsModelOutput, - add_start_docstrings_to_model_forward, - replace_return_docstrings, ) @@ -114,8 +111,7 @@ def __init__(self, config): del self.multiple_choice_head self.multiple_choice_head = PropertyHead(config) - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + @auto_docstring() def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -149,8 +145,6 @@ def forward( mc_labels (`torch.LongTensor` of shape `(batch_size, n_tasks)`, *optional*): Labels for computing the supervized loss for regularization. inputs: List of inputs, put here because the trainer removes information not in signature - Returns: - output (GPT2DoubleHeadsModelOutput): output of the model """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( From 3a2303d37db016a6aa5f92633a917b4110c7e2d8 Mon Sep 17 00:00:00 2001 From: Logan Ward Date: Thu, 7 Aug 2025 11:42:53 -0400 Subject: [PATCH 2/3] Pin the maximum version of dependencies --- env.yml | 2 +- pyproject.toml | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/env.yml b/env.yml index fd3ab49..e5a9aec 100644 --- a/env.yml +++ b/env.yml @@ -37,7 +37,7 @@ dependencies: - mkdocs <1.6.0 - mkdocs-material >=7.1.1 - mkdocs-material-extensions - - mkdocstrings + - mkdocstrings < 0.28.0 - mkdocstrings-python - mkdocs-jupyter - markdown-include diff --git a/pyproject.toml b/pyproject.toml index 4395da7..ba369b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,19 @@ dependencies = [ "rdkit" ] +[project.optional-dependencies] +docs = [ + "mkdocsi <= 1.6.0", + "mkdocs-material>=7.1.1", + "mkdocs-material-extensions", + "mkdocstrings < 0.28.0", + "mkdocstrings-python", + "mkdocs-jupyter", + "markdown-include", + "mdx_truly_sane_lists", + "mike >=1.0.0", +] + [project.urls] "Source Code" = "https://github.com/datamol-io/safe" "Bug Tracker" = "https://github.com/datamol-io/safe/issues" @@ -91,10 +104,10 @@ lint.select = [ "F", # see: https://pypi.org/project/pyflakes ] lint.extend-select = [ - "C4", # see: https://pypi.org/project/flake8-comprehensions + "C4", # see: https://pypi.org/project/flake8-comprehensions "SIM", # see: https://pypi.org/project/flake8-simplify "RET", # see: https://pypi.org/project/flake8-return - "PT", # see: https://pypi.org/project/flake8-pytest-style + "PT", # see: https://pypi.org/project/flake8-pytest-style ] lint.ignore = [ "E731", # Do not assign a lambda expression, use a def From ae2165d1f14871fee62da1eaeb581b8d75f2de1a Mon Sep 17 00:00:00 2001 From: Logan Ward Date: Mon, 18 Aug 2025 07:21:19 -0400 Subject: [PATCH 3/3] Code style fix in viz --- safe/viz.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/safe/viz.py b/safe/viz.py index 8021d73..ea5a949 100644 --- a/safe/viz.py +++ b/safe/viz.py @@ -79,8 +79,8 @@ def to_image( bond_matches = list(itertools.chain(*bond_matches)) atom_indices.extend(atom_matches) bond_indices.extend(bond_matches) - atom_colors.update({x: current_colors[i] for x in atom_matches}) - bond_colors.update({x: current_colors[i] for x in bond_matches}) + atom_colors.update(dict.fromkeys(atom_matches, current_colors[i])) + bond_colors.update(dict.fromkeys(bond_matches, current_colors[i])) return dm.viz.to_image( mol,