diff --git a/docs/source/Inference.md b/docs/source/Inference.md index 02f35685d..8a163b9e1 100644 --- a/docs/source/Inference.md +++ b/docs/source/Inference.md @@ -67,7 +67,8 @@ python3 run_pretrained_openfold.py \ --pdb70_database_path $BASE_DATA_DIR/pdb70 \ --uniclust30_database_path $BASE_DATA_DIR/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \ --bfd_database_path $BASE_DATA_DIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \ - --model_device "cuda:0" + --model_device "cuda:0" \ + --cyclic_offset FASTA-tag1 FASTA-tag2 ... FASTA-tagN ``` **Required arguments:** @@ -138,7 +139,11 @@ Some commonly used command line flags are here. A full list of flags can be view - `--data_random_seed`: Specifies a random seed to use. - `--save_outputs`: Saves a copy of all outputs from the model, e.g. the output of the msa track, ptm heads. - `--experiment_config_json`: Specify configuration settings using a json file. For example, passing a json with `{globals.relax.max_iterations = 10}` specifies 10 as the maximum number of relaxation iterations. See for [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L283) the full dictionary of configuration settings. Any parameters that are not manually set in these configuration settings will refer to the defaults specified by your `config_preset`. - +- `--cyclic_offset`: Specifies a list of FASTA tags for cyclic peptides. E.g. `--cyclic_offset FASTA-tag1 FASTA-tag2 ... FASTA-tagN`. When the list is not empty OpenFold will apply a cyclic end-to-end offset on the sequence instead of the default linear offset. +The result is that the sequence is treated as a cyclic species instead of a linear one. +It is recommended to use the unrelaxed output with this option as we have noticed worse +cyclization performance with the relaxed output. 
Refer to the [AfCycDesign preprint paper](https://www.biorxiv.org/content/10.1101/2023.02.25.529956v1.full) for original +implementation and explanation on the usage of a cyclic offset. ### Advanced Options for Increasing Efficiency diff --git a/docs/source/original_readme.md b/docs/source/original_readme.md index a070a1d52..e38d59da5 100644 --- a/docs/source/original_readme.md +++ b/docs/source/original_readme.md @@ -162,7 +162,8 @@ python3 run_pretrained_openfold.py \ --config_preset "model_1_ptm" \ --model_device "cuda:0" \ --output_dir ./ \ - --openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt + --openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt \ + --cyclic_offset FASTA-tag1 FASTA-tag2 ... FASTA-tagN ``` where `data` is the same directory as in the previous step. If `jackhmmer`, @@ -182,6 +183,12 @@ OpenFold was trained under a newer training schedule than the one from which the `*_ptm` checkpoints must be run with `*_ptm` config presets and that `_no_templ_` checkpoints are only compatible with template-less presets (`model_3` and above). +`--cyclic_offset` accepts a list of sequence FASTA tags. When the list is not empty OpenFold will +apply a cyclic end-to-end offset on the sequence instead of the default linear offset. +The result is that the sequence is treated as a cyclic species instead of a linear one. +Note: the cyclization bond is not reliably retained during the relaxation step; we recommend using the unrelaxed structure output. Refer to the [AfCycDesign preprint paper](https://www.biorxiv.org/content/10.1101/2023.02.25.529956v1.full) for original +implementation and explanation on the usage of a cyclic offset. + Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement) is enabled by default in inference mode. To disable it, set `globals.chunk_size` to `None` in the config. 
If a value is specified, OpenFold will attempt to @@ -263,7 +270,8 @@ python3 run_pretrained_openfold.py \ --kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \ --config_preset "model_1_multimer_v3" \ --model_device "cuda:0" \ - --output_dir ./ + --output_dir ./ \ + --cyclic_offset FASTA-tag1 FASTA-tag2 ... FASTA-tagN ``` As with monomer inference, if you've already computed alignments for the query, you can use diff --git a/openfold/config.py b/openfold/config.py index 7bf30e391..90c0aeda7 100644 --- a/openfold/config.py +++ b/openfold/config.py @@ -290,6 +290,7 @@ def model_config( "common": { "feat": { "aatype": [NUM_RES], + "cyclic_mask": [NUM_RES], "all_atom_mask": [NUM_RES, None], "all_atom_positions": [NUM_RES, None, None], "alt_chi_angles": [NUM_RES, None], @@ -383,6 +384,7 @@ def model_config( "between_segment_residues", "deletion_matrix", "no_recycling_iters", + 'cyclic_mask' ], "use_templates": templates_enabled, "use_template_torsion_angles": embed_template_torsion_angles, @@ -748,6 +750,7 @@ def model_config( "common": { "feat": { "aatype": [NUM_RES], + "cyclic_mask": [NUM_RES], "all_atom_mask": [NUM_RES, None], "all_atom_positions": [NUM_RES, None, None], # "all_chains_entity_ids": [], # TODO: Resolve missing features, remove processed msa feats @@ -808,6 +811,7 @@ def model_config( "asym_id", "entity_id", "sym_id", + "cyclic_mask" ] }, "supervised": { diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index adde0b73b..849c28c67 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -20,7 +20,7 @@ import dataclasses from multiprocessing import cpu_count import tempfile -from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union +from typing import List, Mapping, Optional, Sequence, Any, MutableMapping, Union import numpy as np import torch from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer @@ -854,6 +854,7 
@@ def process_fasta( alignment_dir: str, alignment_index: Optional[Any] = None, seqemb_mode: bool = False, + cyclic_offset: Optional[List[str]] = [] ) -> FeatureDict: """Assembles features for a single sequence in a FASTA file""" with open(fasta_path) as f: @@ -885,6 +886,9 @@ def process_fasta( num_res=num_res, ) + n_residue_index = sequence_features['residue_index'].shape[0] + sequence_features['cyclic_mask'] = (np.ones(n_residue_index)*(input_description in cyclic_offset)).astype(np.bool_) + sequence_embedding_features = {} # If using seqemb mode, generate a dummy MSA features using just the sequence if seqemb_mode: @@ -1228,7 +1232,8 @@ def read_msa(start, size): def process_fasta(self, fasta_path: str, alignment_dir: str, - alignment_index: Optional[Any] = None + alignment_index: Optional[Any] = None, + cyclic_offset: Optional[List[str]] = [] ) -> FeatureDict: """Creates features.""" with open(fasta_path) as f: @@ -1266,6 +1271,8 @@ def process_fasta(self, chain_features, chain_id=desc ) + + chain_features['cyclic_mask'] = (np.ones(chain_features['seq_length'])*(desc in cyclic_offset)).astype(np.bool_) all_chain_features[desc] = chain_features sequence_features[seq] = chain_features diff --git a/openfold/data/data_transforms.py b/openfold/data/data_transforms.py index bd306a11c..fc162fc04 100755 --- a/openfold/data/data_transforms.py +++ b/openfold/data/data_transforms.py @@ -144,6 +144,7 @@ def squeeze_features(protein): "between_segment_residues", "residue_index", "template_all_atom_mask", + 'cyclic_mask' ]: if k in protein: final_dim = protein[k].shape[-1] diff --git a/openfold/data/feature_processing_multimer.py b/openfold/data/feature_processing_multimer.py index 518babf26..993088609 100644 --- a/openfold/data/feature_processing_multimer.py +++ b/openfold/data/feature_processing_multimer.py @@ -31,7 +31,7 @@ 'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments', 'num_templates', 'queue_size', 'residue_index', 'resolution', 
'seq_length', 'seq_mask', 'sym_id', 'template_aatype', - 'template_all_atom_mask', 'template_all_atom_positions' + 'template_all_atom_mask', 'template_all_atom_positions', 'cyclic_mask' }) MAX_TEMPLATES = 4 diff --git a/openfold/data/msa_pairing.py b/openfold/data/msa_pairing.py index 9b82acaf0..de1e09e5a 100644 --- a/openfold/data/msa_pairing.py +++ b/openfold/data/msa_pairing.py @@ -47,7 +47,7 @@ 'sym_id', 'entity_mask', 'deletion_mean', 'prediction_atom_mask', 'literature_positions', 'atom_indices_to_group_indices', - 'rigid_group_default_frame') + 'rigid_group_default_frame', 'cyclic_mask') TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions', 'template_all_atom_mask') CHAIN_FEATURES = ('num_alignments', 'seq_length') diff --git a/openfold/model/embedders.py b/openfold/model/embedders.py index 69cf4951c..ee74b60e1 100644 --- a/openfold/model/embedders.py +++ b/openfold/model/embedders.py @@ -82,7 +82,30 @@ def __init__( self.no_bins = 2 * relpos_k + 1 self.linear_relpos = Linear(self.no_bins, c_z) - def relpos(self, ri: torch.Tensor): + def cyclic_offset(self, residue_index: torch.Tensor) -> torch.Tensor: + """Calculate the cyclic offset for the given residue index. + + Parameters + ---------- + residue_index : torch.Tensor + The residue index tensor. + + Returns + ------- + torch.Tensor + The cyclic offset tensor. 
+ """ + peptide_length = residue_index.shape[0] + cyclic_offset_array = torch.zeros((peptide_length, peptide_length)) + cyc_row = torch.arange(0, -peptide_length, -1) + pc = int(torch.round(torch.tensor(peptide_length / 2))) # Get centre + cyc_row[pc + 1 :] = torch.arange(len(cyc_row[pc + 1 :]), 0, -1) + for i in range(len(cyclic_offset_array)): + cyclic_offset_array[i] = torch.roll(cyc_row, i) + return cyclic_offset_array + + + def relpos(self, ri: torch.Tensor, cyclic_mask: Optional[torch.Tensor] = None): """ Computes relative positional encodings @@ -93,6 +116,9 @@ def relpos(self, ri: torch.Tensor): "residue_index" features of shape [*, N] """ d = ri[..., None] - ri[..., None, :] + if cyclic_mask is not None and sum(cyclic_mask)!=0: + d = self.cyclic_offset(ri).type(torch.long).to(d.device) + boundaries = torch.arange( start=-self.relpos_k, end=self.relpos_k + 1, device=d.device ) @@ -110,6 +136,7 @@ def forward( ri: torch.Tensor, msa: torch.Tensor, inplace_safe: bool = False, + cyclic_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -132,7 +159,7 @@ def forward( tf_emb_j = self.linear_tf_z_j(tf) # [*, N_res, N_res, c_z] - pair_emb = self.relpos(ri.type(tf_emb_i.dtype)) + pair_emb = self.relpos(ri.type(tf_emb_i.dtype), cyclic_mask=cyclic_mask) pair_emb = add(pair_emb, tf_emb_i[..., None, :], inplace=inplace_safe @@ -211,12 +238,44 @@ def __init__( else: self.no_bins = 2 * max_relative_idx + 1 self.linear_relpos = Linear(self.no_bins, c_z) - + + def cyclic_offset(self, residue_index: torch.Tensor) -> torch.Tensor: + """Calculate the cyclic offset for the given residue index. + + Parameters + ---------- + residue_index : torch.Tensor + The residue index tensor. + + Returns + ------- + torch.Tensor + The cyclic offset tensor. 
+ """ + peptide_length = residue_index.shape[0] + cyclic_offset_array = torch.zeros((peptide_length, peptide_length)) + cyc_row = torch.arange(0, -peptide_length, -1) + pc = int(torch.round(torch.tensor(peptide_length / 2))) # Get centre + cyc_row[pc + 1 :] = torch.arange(len(cyc_row[pc + 1 :]), 0, -1) + for i in range(len(cyclic_offset_array)): + cyclic_offset_array[i] = torch.roll(cyc_row, i) + return cyclic_offset_array + def relpos(self, batch): pos = batch["residue_index"] asym_id = batch["asym_id"] asym_id_same = (asym_id[..., None] == asym_id[..., None, :]) offset = pos[..., None] - pos[..., None, :] + + if sum(batch['cyclic_mask'])!=0: + cyclic_entities = torch.unique(batch['entity_id'][batch['cyclic_mask']]) + for cyclic_entity in cyclic_entities: + entity_mask = batch['entity_id'] == cyclic_entity + entity_idx = torch.where(batch['entity_id']==cyclic_entity)[0] + cyclic_pos = pos[entity_mask] + cyclic_offset = self.cyclic_offset(cyclic_pos).type(torch.long) + offset[entity_idx,entity_idx.view(-1,1)] = cyclic_offset.to(offset.device) + clipped_offset = torch.clamp( offset + self.max_relative_idx, 0, 2 * self.max_relative_idx diff --git a/openfold/model/model.py b/openfold/model/model.py index 52cfda45e..81745faf5 100644 --- a/openfold/model/model.py +++ b/openfold/model/model.py @@ -255,6 +255,7 @@ def iteration(self, feats, prevs, _recycle=True): feats["residue_index"], feats["msa_feat"], inplace_safe=inplace_safe, + cyclic_mask = feats["cyclic_mask"] ) # Unpack the recycling embeddings. 
Removing them from the list allows diff --git a/run_pretrained_openfold.py b/run_pretrained_openfold.py index 3cd7c25c4..1a27118e7 100644 --- a/run_pretrained_openfold.py +++ b/run_pretrained_openfold.py @@ -139,7 +139,7 @@ def generate_feature_dict( '\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)]) ) feature_dict = data_processor.process_fasta( - fasta_path=tmp_fasta_path, alignment_dir=alignment_dir, + fasta_path=tmp_fasta_path, alignment_dir=alignment_dir, cyclic_offset=args.cyclic_offset ) elif len(seqs) == 1: tag = tags[0] @@ -151,7 +151,7 @@ def generate_feature_dict( feature_dict = data_processor.process_fasta( fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir, - seqemb_mode=args.use_single_seq_mode, + seqemb_mode=args.use_single_seq_mode, cyclic_offset=args.cyclic_offset ) else: with open(tmp_fasta_path, "w") as fp: @@ -175,7 +175,6 @@ def list_files_with_extensions(dir, extensions): def main(args): # Create the output directory os.makedirs(args.output_dir, exist_ok=True) - if args.config_preset.startswith("seq"): args.use_single_seq_mode = True @@ -475,6 +474,11 @@ def main(args): "--use_deepspeed_evoformer_attention", action="store_true", default=False, help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.", ) + parser.add_argument( + '--cyclic_offset', metavar='N', type=str, nargs='*', default=[], + help="Space-separated list of sequence tags to apply cyclic offset to" + ) + add_data_args(parser) args = parser.parse_args()