diff --git a/alphafold_pytorch_jit/net.py b/alphafold_pytorch_jit/net.py index 2581c7bfc089fc725c6d3c010366efc5d7f8e511..f195a2fa94a1dd6a87465da850647ed6b180ca67 100644 --- a/alphafold_pytorch_jit/net.py +++ b/alphafold_pytorch_jit/net.py @@ -102,7 +102,8 @@ class RunModel(object): ### create compatible structure module # time cost is low at structure-module # no need to cvt it to PyTorch version - _, struct_apply = get_pure_fn(StructureModule, sc, gc) + # _, struct_apply = get_pure_fn(StructureModule, sc, gc) + struct_apply = None ### create AlphaFold instance #evo_init_dims = { # 'target_feat':batch['target_feat'].shape[-1], diff --git a/alphafold_pytorch_jit/structure_module/__init__.py b/alphafold_pytorch_jit/structure_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/alphafold_pytorch_jit/structure_module/common/__init__.py b/alphafold_pytorch_jit/structure_module/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/alphafold_pytorch_jit/structure_module/common/protein.py b/alphafold_pytorch_jit/structure_module/common/protein.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7fd4c7cea7a2b8a1fa3749e675afe79f7acefc --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/common/protein.py @@ -0,0 +1,358 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Protein data type.""" +import dataclasses +import io +from typing import Any, Mapping, Optional +import re + +from . import residue_constants +from Bio.PDB import PDBParser +import numpy as np + + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. +PICO_TO_ANGSTROM = 0.01 + +PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) +assert(PDB_MAX_CHAINS == 62) + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # 0-indexed number corresponding to the chain in the protein that this + # residue belongs to + chain_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + def __post_init__(self): + if(len(np.unique(self.chain_index)) > PDB_MAX_CHAINS): + raise ValueError( + f"Cannot build an instance with more than {PDB_MAX_CHAINS} " + "chains because these cannot be written to PDB format" + ) + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + Args: + pdb_str: The contents of the pdb file + chain_id: If chain_id is specified (e.g. A), then only that chain is + parsed. Else, all chains are parsed. + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser(QUIET=True) + structure = parser.get_structure("none", pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f"Only single model PDBs are supported. Found {len(models)} models." + ) + model = models[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + chain_ids = [] + b_factors = [] + + for chain in model: + if(chain_id is not None and chain.id != chain_id): + continue + + + for res in chain: + if res.id[2] != " ": + raise ValueError( + f"PDB contains an insertion code at chain {chain.id} and residue " + f"index {res.id[1]}. These are not supported." + ) + res_shortname = residue_constants.restype_3to1.get(res.resname, "X") + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num + ) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1.0 + res_b_factors[ + residue_constants.atom_order[atom.name] + ] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + chain_ids.append(chain.id) + b_factors.append(res_b_factors) + + # Chain IDs are usually characters so map these to ints + unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=chain_index, + b_factors=np.array(b_factors), + ) + + +def from_proteinnet_string(proteinnet_str: str) -> Protein: + tag_re = r'(\[[A-Z]+\]\n)' + tags = [ + tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0 + ] + groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]]) + + atoms = ['N', 'CA', 'C'] + aatype = None + atom_positions = None + atom_mask = None + for g in groups: + if("[PRIMARY]" == g[0]): + seq = g[1][0].strip() + for i in range(len(seq)): + if(seq[i] not in residue_constants.restypes): + seq[i] = 'X' + aatype = np.array([ + residue_constants.restype_order.get( + res_symbol, residue_constants.restype_num + ) for res_symbol in seq + ]) + elif("[TERTIARY]" == g[0]): + tertiary = [] + for axis in range(3): + tertiary.append(list(map(float, g[1][axis].split()))) + tertiary_np = np.array(tertiary) + atom_positions = np.zeros( + (len(tertiary[0])//3, residue_constants.atom_type_num, 3) + ).astype(np.float32) + for i, atom in enumerate(atoms): + atom_positions[:, residue_constants.atom_order[atom], :] = ( + np.transpose(tertiary_np[:, i::3]) + ) + atom_positions *= PICO_TO_ANGSTROM + elif("[MASK]" == g[0]): + mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip()))) + atom_mask = np.zeros( + (len(mask), residue_constants.atom_type_num,) + ).astype(np.float32) + for i, atom in enumerate(atoms): + atom_mask[:, residue_constants.atom_order[atom]] = 1 + atom_mask *= mask[..., None] + + return Protein( + atom_positions=atom_positions, + atom_mask=atom_mask, + aatype=aatype, + residue_index=np.arange(len(aatype)), + b_factors=None, + ) + + +def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: + chain_end = 'TER' + return( + f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' + f'{chain_name:>1}{residue_index:>4}' + ) + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + Args: + prot: The protein to convert to PDB. + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ["X"] + res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK") + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + chain_index = prot.chain_index.astype(np.int32) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError("Invalid aatypes.") + + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. + if i >= PDB_MAX_CHAINS: + raise ValueError( + f"The PDB format supports at most {PDB_MAX_CHAINS} chains." + ) + chain_ids[i] = PDB_CHAIN_IDS[i] + + pdb_lines.append("MODEL 1") + atom_index = 1 + last_chain_index = chain_index[0] + # Add all atom sites. + for i in range(aatype.shape[0]): + # Close the previous chain if in a multichain PDB. + if last_chain_index != chain_index[i]: + pdb_lines.append( + _chain_end( + atom_index, + res_1to3(aatype[i - 1]), + chain_ids[chain_index[i - 1]], + residue_index[i - 1] + ) + ) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip( + atom_types, atom_positions[i], atom_mask[i], b_factors[i] + ): + if mask < 0.5: + continue + + record_type = "ATOM" + name = atom_name if len(atom_name) == 4 else f" {atom_name}" + alt_loc = "" + insertion_code = "" + occupancy = 1.00 + element = atom_name[ + 0 + ] # Protein supports only C, N, O, S, this works. + charge = "" + # PDB is a columnar format, every space matters here! + atom_line = ( + f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}" + f"{res_name_3:>3} {chain_ids[chain_index[i]]:>1}" + f"{residue_index[i]:>4}{insertion_code:>1} " + f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}" + f"{occupancy:>6.2f}{b_factor:>6.2f} " + f"{element:>2}{charge:>2}" + ) + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the final chain. + pdb_lines.append( + _chain_end( + atom_index, + res_1to3(aatype[-1]), + chain_ids[chain_index[-1]], + residue_index[-1] + ) + ) + + pdb_lines.append("ENDMDL") + pdb_lines.append("END") + + # Pad all lines to 80 characters + pdb_lines = [line.ljust(80) for line in pdb_lines] + return '\n'.join(pdb_lines) + '\n' # Add terminating newline. + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction( + features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None, + remove_leading_feature_dimension: bool = False, +) -> Protein: + """Assembles a protein from a prediction. + Args: + features: Dictionary holding model inputs. + result: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + remove_leading_feature_dimension: Whether to remove the leading dimension + of the `features` values + Returns: + A protein instance. + """ + def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: + return arr[0] if remove_leading_feature_dimension else arr + + if 'asym_id' in features: + chain_index = _maybe_remove_leading_dim(features["asym_id"]) + else: + chain_index = np.zeros_like( + _maybe_remove_leading_dim(features["aatype"]) + ) + + if b_factors is None: + b_factors = np.zeros_like(result["final_atom_mask"]) + + return Protein( + aatype=_maybe_remove_leading_dim(features["aatype"]), + atom_positions=result["final_atom_positions"], + atom_mask=result["final_atom_mask"], + residue_index=_maybe_remove_leading_dim(features["residue_index"]) + 1, + chain_index=chain_index, + b_factors=b_factors, + ) \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/common/residue_constants.py b/alphafold_pytorch_jit/structure_module/common/residue_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..c5b969d99af24a8c20282fe16703c6bbbbabf046 --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/common/residue_constants.py @@ -0,0 +1,1492 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import os +import urllib.request +import collections +import functools +from typing import Mapping, List, Tuple + +import numpy as np +import tree + +# Internal import (35fd). + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + "ALA": [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + "ARG": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "NE"], + ["CG", "CD", "NE", "CZ"], + ], + "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "CYS": [["N", "CA", "CB", "SG"]], + "GLN": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLU": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "OE1"], + ], + "GLY": [], + "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]], + "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]], + "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "LYS": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "CD"], + ["CB", "CG", "CD", "CE"], + ["CG", "CD", "CE", "NZ"], + ], + "MET": [ + ["N", "CA", "CB", "CG"], + ["CA", "CB", "CG", "SD"], + ["CB", "CG", "SD", "CE"], + ], + "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]], + "SER": [["N", "CA", "CB", "OG"]], + "THR": [["N", "CA", "CB", "OG1"]], + "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "VAL": [["N", "CA", "CB", "CG1"]], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. +chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + "ALA": [ + ["N", 0, (-0.525, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.529, -0.774, -1.205)], + ["O", 3, (0.627, 1.062, 0.000)], + ], + "ARG": [ + ["N", 0, (-0.524, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.524, -0.778, -1.209)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.616, 1.390, -0.000)], + ["CD", 5, (0.564, 1.414, 0.000)], + ["NE", 6, (0.539, 1.357, -0.000)], + ["NH1", 7, (0.206, 2.301, 0.000)], + ["NH2", 7, (2.078, 0.978, -0.000)], + ["CZ", 7, (0.758, 1.093, -0.000)], + ], + "ASN": [ + ["N", 0, (-0.536, 1.357, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.531, -0.787, -1.200)], + ["O", 3, (0.625, 1.062, 0.000)], + ["CG", 4, (0.584, 1.399, 0.000)], + ["ND2", 5, (0.593, -1.188, 0.001)], + ["OD1", 5, (0.633, 1.059, 0.000)], + ], + "ASP": [ + ["N", 0, (-0.525, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, 0.000, -0.000)], + ["CB", 0, (-0.526, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.593, 1.398, -0.000)], + ["OD1", 5, (0.610, 1.091, 0.000)], + ["OD2", 5, (0.592, -1.101, -0.003)], + ], + "CYS": [ + ["N", 0, (-0.522, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, 0.000)], + ["CB", 0, (-0.519, -0.773, -1.212)], + ["O", 3, (0.625, 1.062, -0.000)], + ["SG", 4, (0.728, 1.653, 0.000)], + ], + "GLN": [ + ["N", 0, (-0.526, 1.361, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.779, -1.207)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.615, 1.393, 0.000)], + ["CD", 5, (0.587, 1.399, -0.000)], + ["NE2", 6, (0.593, -1.189, -0.001)], + ["OE1", 6, (0.634, 1.060, 0.000)], + ], + "GLU": [ + ["N", 0, (-0.528, 1.361, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, -0.000, -0.000)], + ["CB", 0, (-0.526, -0.781, -1.207)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG", 4, (0.615, 1.392, 0.000)], + ["CD", 5, (0.600, 1.397, 0.000)], + ["OE1", 6, (0.607, 1.095, -0.000)], + ["OE2", 6, (0.589, -1.104, -0.001)], + ], + "GLY": [ + ["N", 0, (-0.572, 1.337, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.517, -0.000, -0.000)], + ["O", 3, (0.626, 1.062, -0.000)], + ], + "HIS": [ + ["N", 0, (-0.527, 1.360, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.525, -0.778, -1.208)], + ["O", 3, (0.625, 1.063, 0.000)], + ["CG", 4, (0.600, 1.370, -0.000)], + ["CD2", 5, (0.889, -1.021, 0.003)], + ["ND1", 5, (0.744, 1.160, -0.000)], + ["CE1", 5, (2.030, 0.851, 0.002)], + ["NE2", 5, (2.145, -0.466, 0.004)], + ], + "ILE": [ + ["N", 0, (-0.493, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.536, -0.793, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.534, 1.437, -0.000)], + ["CG2", 4, (0.540, -0.785, -1.199)], + ["CD1", 5, (0.619, 1.391, 0.000)], + ], + "LEU": [ + ["N", 0, (-0.520, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.773, -1.214)], + ["O", 3, (0.625, 1.063, -0.000)], + ["CG", 4, (0.678, 1.371, 0.000)], + ["CD1", 5, (0.530, 1.430, -0.000)], + ["CD2", 5, (0.535, -0.774, 1.200)], + ], + "LYS": [ + ["N", 0, (-0.526, 1.362, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, 0.000)], + ["CB", 0, (-0.524, -0.778, -1.208)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.619, 1.390, 0.000)], + ["CD", 5, (0.559, 1.417, 0.000)], + ["CE", 6, (0.560, 1.416, 0.000)], + ["NZ", 7, (0.554, 1.387, 0.000)], + ], + "MET": [ + ["N", 0, (-0.521, 1.364, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, 0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.210)], + ["O", 3, (0.625, 1.062, -0.000)], + ["CG", 4, (0.613, 1.391, -0.000)], + ["SD", 5, (0.703, 1.695, 0.000)], + ["CE", 6, (0.320, 1.786, -0.000)], + ], + "PHE": [ + ["N", 0, (-0.518, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, 0.000, -0.000)], + ["CB", 0, (-0.525, -0.776, -1.212)], + ["O", 3, (0.626, 1.062, -0.000)], + ["CG", 4, (0.607, 1.377, 0.000)], + ["CD1", 5, (0.709, 1.195, -0.000)], + ["CD2", 5, (0.706, -1.196, 0.000)], + ["CE1", 5, (2.102, 1.198, -0.000)], + ["CE2", 5, (2.098, -1.201, -0.000)], + ["CZ", 5, (2.794, -0.003, -0.001)], + ], + "PRO": [ + ["N", 0, (-0.566, 1.351, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, 0.000)], + ["CB", 0, (-0.546, -0.611, -1.293)], + ["O", 3, (0.621, 1.066, 0.000)], + ["CG", 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + "SER": [ + ["N", 0, (-0.529, 1.360, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, -0.000)], + ["CB", 0, (-0.518, -0.777, -1.211)], + ["O", 3, (0.626, 1.062, -0.000)], + ["OG", 4, (0.503, 1.325, 0.000)], + ], + "THR": [ + ["N", 0, (-0.517, 1.364, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.526, 0.000, -0.000)], + ["CB", 0, (-0.516, -0.793, -1.215)], + ["O", 3, (0.626, 1.062, 0.000)], + ["CG2", 4, (0.550, -0.718, -1.228)], + ["OG1", 4, (0.472, 1.353, 0.000)], + ], + "TRP": [ + ["N", 0, (-0.521, 1.363, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.525, -0.000, 0.000)], + ["CB", 0, (-0.523, -0.776, -1.212)], + ["O", 3, (0.627, 1.062, 0.000)], + ["CG", 4, (0.609, 1.370, -0.000)], + ["CD1", 5, (0.824, 1.091, 0.000)], + ["CD2", 5, (0.854, -1.148, -0.005)], + ["CE2", 5, (2.186, -0.678, -0.007)], + ["CE3", 5, (0.622, -2.530, -0.007)], + ["NE1", 5, (2.140, 0.690, -0.004)], + ["CH2", 5, (3.028, -2.890, -0.013)], + ["CZ2", 5, (3.283, -1.543, -0.011)], + ["CZ3", 5, (1.715, -3.389, -0.011)], + ], + "TYR": [ + ["N", 0, (-0.522, 1.362, 0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.524, -0.000, -0.000)], + ["CB", 0, (-0.522, -0.776, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG", 4, (0.607, 1.382, -0.000)], + ["CD1", 5, (0.716, 1.195, -0.000)], + ["CD2", 5, (0.713, -1.194, -0.001)], + ["CE1", 5, (2.107, 1.200, -0.002)], + ["CE2", 5, (2.104, -1.201, -0.003)], + ["OH", 5, (4.168, -0.002, -0.005)], + ["CZ", 5, (2.791, -0.001, -0.003)], + ], + "VAL": [ + ["N", 0, (-0.494, 1.373, -0.000)], + ["CA", 0, (0.000, 0.000, 0.000)], + ["C", 0, (1.527, -0.000, -0.000)], + ["CB", 0, (-0.533, -0.795, -1.213)], + ["O", 3, (0.627, 1.062, -0.000)], + ["CG1", 4, (0.540, 1.429, -0.000)], + ["CG2", 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + "ALA": ["C", "CA", "CB", "N", "O"], + "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"], + "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], + "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], + "CYS": ["C", "CA", "CB", "N", "O", "SG"], + "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"], + "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"], + "GLY": ["C", "CA", "N", "O"], + "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"], + "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"], + "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"], + "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"], + "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"], + "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"], + "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"], + "SER": ["C", "CA", "CB", "N", "O", "OG"], + "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"], + "TRP": [ + "C", + "CA", + "CB", + "CG", + "CD1", + "CD2", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + "N", + "NE1", + "O", + ], + "TYR": [ + "C", + "CA", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "N", + "O", + "OH", + ], + "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"], +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +# TODO: ^ interpret this +residue_atom_renaming_swaps = { + "ASP": {"OD1": "OD2"}, + "GLU": {"OE1": "OE2"}, + "PHE": {"CD1": "CD2", "CE1": "CE2"}, + "TYR": {"CD1": "CD2", "CE1": "CE2"}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + "C": 1.7, + "N": 1.55, + "O": 1.52, + "S": 1.8, +} + +Bond = collections.namedtuple( + "Bond", ["atom1_name", "atom2_name", "length", "stddev"] +) +BondAngle = collections.namedtuple( + "BondAngle", + ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"], +) + +def get_cache_path(): + cache_path = os.path.join(os.path.expanduser("~"), '.fastfold') + if not os.path.exists(cache_path): + os.makedirs(cache_path, exist_ok=True) + return cache_path + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[ + Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]], +]: + """Load stereo_chemical_props.txt into a nice structure. + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). + Returns: + residue_bonds: dict that maps resname --> list of Bond tuples + residue_virtual_bonds: dict that maps resname --> list of Bond tuples + residue_bond_angles: dict that maps resname --> list of BondAngle tuples + """ + stereo_chemical_props_path = os.path.join(get_cache_path(), 'stereo_chemical_props.txt') + if not os.path.exists(stereo_chemical_props_path): + url = "https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt" + urllib.request.urlretrieve(url, stereo_chemical_props_path) + with open(stereo_chemical_props_path, 'rt') as f: + stereo_chemical_props = f.read() + + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split("-") + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev)) + ) + residue_bonds["UNK"] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split("-") + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle( + atom1, + atom2, + atom3, + float(angle_degree) / 180.0 * np.pi, + float(stddev_degree) / 180.0 * np.pi, + ) + ) + residue_bond_angles["UNK"] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return "-".join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt( + bond1.length ** 2 + + bond2.length ** 2 + - 2 * bond1.length * bond2.length * np.cos(gamma) + ) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = ( + 2 * bond1.length * bond2.length * np.sin(gamma) + ) * dl_outer + dl_db1 = ( + 2 * bond1.length - 2 * bond2.length * np.cos(gamma) + ) * dl_outer + dl_db2 = ( + 2 * bond2.length - 2 * bond1.length * np.cos(gamma) + ) * dl_outer + stddev = np.sqrt( + (dl_dgamma * ba.stddev) ** 2 + + (dl_db1 * bond1.stddev) ** 2 + + (dl_db2 * bond2.stddev) ** 2 + ) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev) + ) + + return (residue_bonds, residue_virtual_bonds, residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_types = [ + "N", + "CA", + "C", + "CB", + "O", + "CG", + "CG1", + "CG2", + "OG", + "OG1", + "SG", + "CD", + "CD1", + "CD2", + "ND1", + "ND2", + "OD1", + "OD2", + "SD", + "CE", + "CE1", + "CE2", + "CE3", + "NE", + "NE1", + "NE2", + "OE1", + "OE2", + "CH2", + "NH1", + "NH2", + "OH", + "CZ", + "CZ2", + "CZ3", + "NZ", + "OXT", +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""], + "ARG": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "NE", + "CZ", + "NH1", + "NH2", + "", + "", + "", + ], + "ASN": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "OD1", + "ND2", + "", + "", + "", + "", + "", + "", + ], + "ASP": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "OD1", + "OD2", + "", + "", + "", + "", + "", + "", + ], + "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""], + "GLN": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "NE2", + "", + "", + "", + "", + "", + ], + "GLU": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "OE1", + "OE2", + "", + "", + "", + "", + "", + ], + "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""], + "HIS": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "ND1", + "CD2", + "CE1", + "NE2", + "", + "", + "", + "", + ], + "ILE": [ + "N", + "CA", + "C", + "O", + "CB", + "CG1", + "CG2", + "CD1", + "", + "", + "", + "", + "", + "", + ], + "LEU": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "", + "", + "", + "", + "", + "", + ], + "LYS": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD", + "CE", + "NZ", + "", + "", + "", + "", + "", + ], + "MET": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "SD", + "CE", + "", + "", + "", + "", + "", + "", + ], + "PHE": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "", + "", + "", + ], + "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""], + "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""], + "THR": [ + "N", + "CA", + "C", + "O", + "CB", + "OG1", + "CG2", + "", + "", + "", + "", + "", + "", + "", + ], + "TRP": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "NE1", + "CE2", + "CE3", + "CZ2", + "CZ3", + "CH2", + ], + "TYR": [ + "N", + "CA", + "C", + "O", + "CB", + "CG", + "CD1", + "CD2", + "CE1", + "CE2", + "CZ", + "OH", + "", + "", + ], + "VAL": [ + "N", + "CA", + "C", + "O", + "CB", + "CG1", + "CG2", + "", + "", + "", + "", + "", + "", + "", + ], + "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""], +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ["X"] +restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot( + sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False +) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError( + "The mapping must have values from 0 to num_unique_aas-1 " + "without any gaps. Got: %s" % sorted(mapping.values()) + ) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping["X"]) + else: + raise ValueError( + f"Invalid character in the sequence: {aa_type}" + ) + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = "UNK" + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + "A": 0, + "B": 2, + "C": 1, + "D": 2, + "E": 3, + "F": 4, + "G": 5, + "H": 6, + "I": 7, + "J": 20, + "K": 8, + "L": 9, + "M": 10, + "N": 11, + "O": 20, + "P": 12, + "Q": 13, + "R": 14, + "S": 15, + "T": 16, + "U": 1, + "V": 17, + "W": 18, + "X": 20, + "Y": 19, + "Z": 3, + "-": 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: "A", + 1: "C", # Also U. + 2: "D", # Also B. + 3: "E", # Also Z. + 4: "F", + 5: "G", + 6: "H", + 7: "I", + 8: "K", + 9: "L", + 10: "M", + 11: "N", + 12: "P", + 13: "Q", + 14: "R", + 15: "S", + 16: "T", + 17: "V", + 18: "W", + 19: "Y", + 20: "X", # Includes J and O. + 21: "-", +} + +restypes_with_x_and_gap = restypes + ["X", "-"] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap)) +) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree.map_structure( + lambda atom_name: atom_order[atom_name], chi_angles_atom_indices +) +chi_angles_atom_indices = np.array( + [ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices + ] +) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. + ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack( + [ex_normalized, ey_normalized, eznorm, translation] + ).transpose() + m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int32) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int32) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[ + resname + ]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[ + restype, atomtype, : + ] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[ + restype, atom14idx, : + ] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = { + name: np.array(pos) + for name, _, pos in rigid_group_atom_positions[resname] + } + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["N"] - atom_positions["CA"], + ey=np.array([1.0, 0.0, 0.0]), + translation=atom_positions["N"], + ) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["C"] - atom_positions["CA"], + ey=atom_positions["CA"] - atom_positions["N"], + translation=atom_positions["C"], + ) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [ + atom_positions[name] for name in base_atom_names + ] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2], + ) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1.0, 0.0, 0.0]), + translation=axis_end_atom_position, + ) + restype_rigid_group_default_frame[ + restype, 4 + chi_idx, :, : + ] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds( + overlap_tolerance=1.5, bond_length_tolerance_factor=15 +): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[ + restype, atom1_idx, atom2_idx + ] = lower + restype_atom14_bond_lower_bound[ + restype, atom2_idx, atom1_idx + ] = lower + restype_atom14_bond_upper_bound[ + restype, atom1_idx, atom2_idx + ] = upper + restype_atom14_bond_upper_bound[ + restype, atom2_idx, atom1_idx + ] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[ + restype, atom1_idx, atom2_idx + ] = lower + restype_atom14_bond_lower_bound[ + restype, atom2_idx, atom1_idx + ] = lower + restype_atom14_bond_upper_bound[ + restype, atom1_idx, atom2_idx + ] = upper + restype_atom14_bond_upper_bound[ + restype, atom2_idx, atom1_idx + ] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return { + "lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14) + "upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14) + "stddev": restype_atom14_bond_stddev, # shape (21,14,14) + } + + +restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32) +restype_atom14_ambiguous_atoms_swap_idx = np.tile( + np.arange(14, dtype=np.int32), (21, 1) +) + + +def _make_atom14_ambiguity_feats(): + for res, pairs in residue_atom_renaming_swaps.items(): + res_idx = restype_order[restype_3to1[res]] + for atom1, atom2 in pairs.items(): + atom1_idx = restype_name_to_atom14_names[res].index(atom1) + atom2_idx = restype_name_to_atom14_names[res].index(atom2) + restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1 + restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1 + restype_atom14_ambiguous_atoms_swap_idx[ + res_idx, atom1_idx + ] = atom2_idx + restype_atom14_ambiguous_atoms_swap_idx[ + res_idx, atom2_idx + ] = atom1_idx + + +_make_atom14_ambiguity_feats() + + +def aatype_to_str_sequence(aatype): + return ''.join([ + restypes_with_x[aatype[i]] + for i in range(len(aatype)) + ]) + + +### ALPHAFOLD MULTIMER STUFF ### +def _make_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in restypes: + residue_name = restype_1to3[residue_name] + residue_chi_angles = chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return np.array(chi_atom_indices) + + +def _make_renaming_matrices(): + """Matrices to map atoms to symmetry partners in ambiguous case.""" + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative groundtruth coordinates where the naming is swapped + restype_3 = [ + restype_1to3[res] for res in restypes + ] + restype_3 += ['UNK'] + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = restype_name_to_atom14_names[ + resname].index(source_atom_swap) + target_index = restype_name_to_atom14_names[ + resname].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. + all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +def _make_restype_atom37_mask(): + """Mask of which atoms are present for which residue type in atom37.""" + # create the corresponding mask + restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + return restype_atom37_mask + + +def _make_restype_atom14_mask(): + """Mask of which atoms are present for which residue type in atom14.""" + restype_atom14_mask = [] + + for rt in restypes: + atom_names = restype_name_to_atom14_names[ + restype_1to3[rt]] + restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + restype_atom14_mask.append([0.] * 14) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + return restype_atom14_mask + + +def _make_restype_atom37_to_atom14(): + """Map from atom37 to atom14 per residue type.""" + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + for rt in restypes: + atom_names = restype_name_to_atom14_names[ + restype_1to3[rt]] + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in atom_types + ]) + + restype_atom37_to_atom14.append([0] * 37) + restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) + return restype_atom37_to_atom14 + + +def _make_restype_atom14_to_atom37(): + """Map from atom14 to atom37 per residue type.""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + for rt in restypes: + atom_names = restype_name_to_atom14_names[ + restype_1to3[rt]] + restype_atom14_to_atom37.append([ + (atom_order[name] if name else 0) + for name in atom_names + ]) + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) + return restype_atom14_to_atom37 + + +def _make_restype_atom14_is_ambiguous(): + """Mask which atoms are ambiguous in atom14.""" + # create an ambiguous atoms mask. shape: (21, 14) + restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32) + for resname, swap in residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = restype_order[ + restype_3to1[resname]] + atom_idx1 = restype_name_to_atom14_names[resname].index( + atom_name1) + atom_idx2 = restype_name_to_atom14_names[resname].index( + atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + return restype_atom14_is_ambiguous + + +def _make_restype_rigidgroup_base_atom37_idx(): + """Create Map from rigidgroups to atom37 indices.""" + # Create an array with the atom names. + # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3) + base_atom_names = np.full([21, 8, 3], '', dtype=object) + + # 0: backbone frame + base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + + # 3: 'psi-group' + base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + # 4,5,6,7: 'chi1,2,3,4-group' + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for chi_idx in range(4): + if chi_angles_mask[restype][chi_idx]: + atom_names = chi_angles_atoms[resname][chi_idx] + base_atom_names[restype, chi_idx + 4, :] = atom_names[1:] + + # Translate atom names into atom37 indices. + lookuptable = atom_order.copy() + lookuptable[''] = 0 + restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])( + base_atom_names) + return restype_rigidgroup_base_atom37_idx + + +CHI_ATOM_INDICES = _make_chi_atom_indices() +RENAMING_MATRICES = _make_renaming_matrices() +RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37() +RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14() +RESTYPE_ATOM37_MASK = _make_restype_atom37_mask() +RESTYPE_ATOM14_MASK = _make_restype_atom14_mask() +RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous() +RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx() + +# Create mask for existing rigid groups. +RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32) +RESTYPE_RIGIDGROUP_MASK[:, 0] = 1 +RESTYPE_RIGIDGROUP_MASK[:, 3] = 1 +RESTYPE_RIGIDGROUP_MASK[:20, 4:] = chi_angles_mask \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/primitives.py b/alphafold_pytorch_jit/structure_module/primitives.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf6e47cf67595a6cb60078d308bfade06e3f329 --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/primitives.py @@ -0,0 +1,554 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import math +from typing import Optional, Callable, List, Tuple, Sequence +import numpy as np + +import torch +import torch.nn as nn +from scipy.stats import truncnorm + +from .utils.checkpointing import get_checkpoint_fn +from .utils.tensor_utils import ( + permute_final_dims, + flatten_final_dims, + _chunk_slice, +) + +def _prod(nums): + out = 1 + for n in nums: + out = out * n + return out + + +def _calculate_fan(linear_weight_shape, fan="fan_in"): + fan_out, fan_in = linear_weight_shape + + if fan == "fan_in": + f = fan_in + elif fan == "fan_out": + f = fan_out + elif fan == "fan_avg": + f = (fan_in + fan_out) / 2 + else: + raise ValueError("Invalid fan option") + + return f + + +def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): + shape = weights.shape + f = _calculate_fan(shape, fan) + scale = scale / max(1, f) + a = -2 + b = 2 + std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) + size = _prod(shape) + samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) + samples = np.reshape(samples, shape) + with torch.no_grad(): + weights.copy_(torch.tensor(samples, device=weights.device)) + + +def lecun_normal_init_(weights): + trunc_normal_init_(weights, scale=1.0) + + +def he_normal_init_(weights): + trunc_normal_init_(weights, scale=2.0) + + +def glorot_uniform_init_(weights): + nn.init.xavier_uniform_(weights, gain=1) + + +def final_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def gating_init_(weights): + with torch.no_grad(): + weights.fill_(0.0) + + +def normal_init_(weights): + torch.nn.init.kaiming_normal_(weights, nonlinearity="linear") + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +class Linear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just + like torch.nn.Linear. + + Implements the initializers in 1.11.4, plus some additional ones found + in the code. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + init: str = "default", + init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, + ): + """ + Args: + in_dim: + The final dimension of inputs to the layer + out_dim: + The final dimension of layer outputs + bias: + Whether to learn an additive bias. True by default + init: + The initializer to use. Choose from: + + "default": LeCun fan-in truncated normal initialization + "relu": He initialization w/ truncated normal distribution + "glorot": Fan-average Glorot uniform initialization + "gating": Weights=0, Bias=1 + "normal": Normal initialization with std=1/sqrt(fan_in) + "final": Weights=0, Bias=0 + + Overridden by init_fn if the latter is not None. + init_fn: + A custom initializer taking weight and bias as inputs. + Overrides init if not None. + """ + super(Linear, self).__init__(in_dim, out_dim, bias=bias) + + if bias: + with torch.no_grad(): + self.bias.fill_(0) + + if init_fn is not None: + init_fn(self.weight, self.bias) + else: + if init == "default": + lecun_normal_init_(self.weight) + elif init == "relu": + he_normal_init_(self.weight) + elif init == "glorot": + glorot_uniform_init_(self.weight) + elif init == "gating": + gating_init_(self.weight) + if bias: + with torch.no_grad(): + self.bias.fill_(1.0) + elif init == "normal": + normal_init_(self.weight) + elif init == "final": + final_init_(self.weight) + else: + raise ValueError("Invalid init string.") + + +class LayerNorm(nn.Module): + + def __init__(self, c_in, eps=1e-5): + super(LayerNorm, self).__init__() + + self.c_in = (c_in,) + self.eps = eps + + self.weight = nn.Parameter(torch.ones(c_in)) + self.bias = nn.Parameter(torch.zeros(c_in)) + + def forward(self, x): + out = nn.functional.layer_norm( + x, + self.c_in, + self.weight, + self.bias, + self.eps, + ) + + return out + + +@torch.jit.ignore +def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of + type bfloat16 + """ + s = torch.nn.functional.softmax(t, dim=dim) + + return s + + +#@torch.jit.script +def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + biases: List[torch.Tensor]) -> torch.Tensor: + # [*, H, Q, C_hidden] + query = permute_final_dims(query, (1, 0, 2)) + + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 2, 0)) + + # [*, H, V, C_hidden] + value = permute_final_dims(value, (1, 0, 2)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + + for b in biases: + a += b + + a = softmax(a, -1) + + # [*, H, Q, C_hidden] + a = a.to(dtype=value.dtype) + a = torch.matmul(a, value) + + # [*, Q, H, C_hidden] + a = a.transpose(-2, -3) + + return a + + +@torch.jit.ignore +def _attention_chunked_trainable( + query, + key, + value, + biases, + chunk_size, + chunk_dim, + checkpoint, +): + if (checkpoint and len(biases) > 2): + raise ValueError("Checkpointed version permits only permits two bias terms") + + def _checkpointable_attention(q, k, v, b1, b2): + bs = [b for b in [b1, b2] if b is not None] + return _attention(q, k, v, bs) + + o_chunks = [] + checkpoint_fn = get_checkpoint_fn() + count = query.shape[chunk_dim] + for start in range(0, count, chunk_size): + end = start + chunk_size + idx = [slice(None)] * len(query.shape) + idx[chunk_dim] = slice(start, end) + idx_tup = tuple(idx) + q_chunk = query[idx_tup] + k_chunk = key[idx_tup] + v_chunk = value[idx_tup] + + def _slice_bias(b): + idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)) + return b[tuple(idx)] + + if (checkpoint): + bias_1_chunk, bias_2_chunk = [ + _slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2] + ] + + o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk, + bias_1_chunk, bias_2_chunk) + else: + bias_chunks = [_slice_bias(b) for b in biases] + + o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks) + + o_chunks.append(o_chunk) + + o = torch.cat(o_chunks, dim=chunk_dim) + return o + + +class Attention(nn.Module): + """ + Standard multi-head attention using AlphaFold's default layer + initialization. Allows multiple bias vectors. + """ + + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super(Attention, self).__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. + + self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final") + + self.linear_g = None + if self.gating: + self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating") + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, q_x: torch.Tensor, + kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: + if (self.linear_g is not None): + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[List[torch.Tensor]] = None, + use_lma: bool = False, + q_chunk_size: Optional[int] = None, + kv_chunk_size: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_lma: + Whether to use low-memory attention + q_chunk_size: + Query chunk size (for LMA) + kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if (biases is None): + biases = [] + if (use_lma and (q_chunk_size is None or kv_chunk_size is None)): + raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must " + "be provided") + + q, k, v = self._prep_qkv(q_x, kv_x) + + if (use_lma): + biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases] + + o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size) + else: + o = _attention(q, k, v, biases) + + o = self._wrap_up(o, q_x) + + return o + + +class GlobalAttention(nn.Module): + + def __init__(self, c_in, c_hidden, no_heads, inf, eps): + super(GlobalAttention, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.inf = inf + self.eps = eps + + self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot") + + self.linear_k = Linear( + c_in, + c_hidden, + bias=False, + init="glorot", + ) + self.linear_v = Linear( + c_in, + c_hidden, + bias=False, + init="glorot", + ) + self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating") + self.linear_o = Linear(c_hidden * no_heads, c_in, init="final") + + self.sigmoid = nn.Sigmoid() + + def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # [*, N_res, C_in] + q = torch.sum(m * mask.unsqueeze(-1), + dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps) + + # [*, N_res, H * C_hidden] + q = self.linear_q(q) + q *= (self.c_hidden**(-0.5)) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, C_hidden] + k = self.linear_k(m) + v = self.linear_v(m) + + # [*, N_res, H, N_seq] + a = torch.matmul( + q, + k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq] + ) + bias = (self.inf * (mask - 1))[..., :, None, :] + + a += bias + a = softmax(a) + + # [*, N_res, H, C_hidden] + a = a.to(dtype=v.dtype) + o = torch.matmul( + a, + v, + ) + + # [*, N_res, N_seq, C_hidden] + g = self.sigmoid(self.linear_g(m)) + + # [*, N_res, N_seq, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, N_seq, H, C_hidden] + o = o.unsqueeze(-3) * g + + # [*, N_res, N_seq, H * C_hidden] + o = o.reshape(o.shape[:-2] + (-1,)) + + # [*, N_res, N_seq, C_in] + m = self.linear_o(o) + + return m + + +def _lma( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + biases: List[torch.Tensor], + q_chunk_size: int, + kv_chunk_size: int, +): + no_q, no_kv = q.shape[-3], k.shape[-3] + + # [*, Q, H, C_hidden] + o = q.new_zeros(q.shape) + for q_s in range(0, no_q, q_chunk_size): + q_chunk = q[..., q_s:q_s + q_chunk_size, :, :] + large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases] + + maxes = [] + weights = [] + values = [] + for kv_s in range(0, no_kv, kv_chunk_size): + k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :] + v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :] + small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks] + + a = torch.einsum( + "...qhd,...khd->...hqk", + q_chunk, + k_chunk, + ) + + for b in small_bias_chunks: + a += b + + a = a.transpose(-2, -3) + + max_a = torch.max(a, dim=-1, keepdim=True)[0] + exp_a = torch.exp(a - max_a) + exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a) + + maxes.append(max_a.detach().squeeze(-1)) + weights.append(torch.sum(exp_a, dim=-1)) + values.append(exp_v) + + chunk_max = torch.stack(maxes, dim=-3) + chunk_weights = torch.stack(weights, dim=-3) + chunk_values = torch.stack(values, dim=-4) + + global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0] + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= max_diffs.unsqueeze(-1) + chunk_weights *= max_diffs + + all_values = torch.sum(chunk_values, dim=-4) + all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4) + + q_chunk_out = all_values / all_weights + + o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out + + return o \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/structure_module.py b/alphafold_pytorch_jit/structure_module/structure_module.py new file mode 100644 index 0000000000000000000000000000000000000000..50952a13ee949f04d346b4d6dc3ed8a40f095654 --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/structure_module.py @@ -0,0 +1,961 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +import torch.nn as nn +from typing import Any, Dict, Optional, Tuple, Union + +from .primitives import Linear, LayerNorm, ipa_point_weights_init_ +from .common.residue_constants import ( + restype_rigid_group_default_frame, + restype_atom14_to_rigid_group, + restype_atom14_mask, + restype_atom14_rigid_group_positions, +) +from .utils.geometry.quat_rigid import QuatRigid +from .utils.geometry.rigid_matrix_vector import Rigid3Array +from .utils.geometry.vector import Vec3Array +from .utils.feats import ( + frames_and_literature_positions_to_atom14_pos, + torsion_angles_to_frames, +) +from .utils.rigid_utils import Rotation, Rigid +from .utils.tensor_utils import ( + dict_multimap, + permute_final_dims, + flatten_final_dims, +) + +class AngleResnetBlock(nn.Module): + def __init__(self, c_hidden): + """ + Args: + c_hidden: + Hidden channel dimension + """ + super(AngleResnetBlock, self).__init__() + + self.c_hidden = c_hidden + + self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu") + self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final") + + self.relu = nn.ReLU() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + + s_initial = a + + a = self.relu(a) + a = self.linear_1(a) + a = self.relu(a) + a = self.linear_2(a) + + return a + s_initial + + +class AngleResnet(nn.Module): + """ + Implements Algorithm 20, lines 11-14 + """ + + def __init__( + self, c_in: int, c_hidden: int, no_blocks: int, no_angles: int, epsilon: float + ): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Hidden channel dimension + no_blocks: + Number of resnet blocks + no_angles: + Number of torsion angles to generate + epsilon: + Small constant for normalization + """ + super(AngleResnet, self).__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_blocks = no_blocks + self.no_angles = no_angles + self.eps = epsilon + + self.linear_in = Linear(self.c_in, self.c_hidden) + self.linear_initial = Linear(self.c_in, self.c_hidden) + + self.layers = nn.ModuleList() + for _ in range(self.no_blocks): + layer = AngleResnetBlock(c_hidden=self.c_hidden) + self.layers.append(layer) + + self.linear_out = Linear(self.c_hidden, self.no_angles * 2) + + self.relu = nn.ReLU() + + def forward( + self, s: torch.Tensor, s_initial: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + s: + [*, C_hidden] single embedding + s_initial: + [*, C_hidden] single embedding as of the start of the + StructureModule + Returns: + [*, no_angles, 2] predicted angles + """ + # NOTE: The ReLU's applied to the inputs are absent from the supplement + # pseudocode but present in the source. For maximal compatibility with + # the pretrained weights, I'm going with the source. + + # [*, C_hidden] + s_initial = self.relu(s_initial) + s_initial = self.linear_initial(s_initial) + s = self.relu(s) + s = self.linear_in(s) + s = s + s_initial + + for l in self.layers: + s = l(s) + + s = self.relu(s) + + # [*, no_angles * 2] + s = self.linear_out(s) + + # [*, no_angles, 2] + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s**2, dim=-1, keepdim=True), + min=self.eps, + ) + ) + s = s / norm_denom + + return unnormalized_s, s + + +class PointProjection(nn.Module): + def __init__( + self, + c_hidden: int, + num_points: int, + no_heads: int, + return_local_points: bool = False, + ): + super().__init__() + self.return_local_points = return_local_points + self.no_heads = no_heads + + self.linear = Linear(c_hidden, no_heads * 3 * num_points) + + def forward( + self, + activations: torch.Tensor, + rigids: Rigid3Array, + ) -> Union[Vec3Array, Tuple[Vec3Array, Vec3Array]]: + # TODO: Needs to run in high precision during training + points_local = self.linear(activations) + points_local = points_local.reshape( + *points_local.shape[:-1], + self.no_heads, + -1, + ) + points_local = torch.split(points_local, points_local.shape[-1] // 3, dim=-1) + points_local = Vec3Array(*points_local) + points_global = rigids[..., None, None].apply_to_point(points_local) + + if self.return_local_points: + return points_global, points_local + + return points_global + + +class InvariantPointAttention(nn.Module): + """ + Implements Algorithm 22. + """ + + def __init__( + self, + c_s: int, + c_z: int, + c_hidden: int, + no_heads: int, + no_qk_points: int, + no_v_points: int, + inf: float = 1e5, + eps: float = 1e-8, + is_multimer: bool = False, + ): + """ + Args: + c_s: + Single representation channel dimension + c_z: + Pair representation channel dimension + c_hidden: + Hidden channel dimension + no_heads: + Number of attention heads + no_qk_points: + Number of query/key points to generate + no_v_points: + Number of value points to generate + """ + super(InvariantPointAttention, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_hidden = c_hidden + self.no_heads = no_heads + self.no_qk_points = no_qk_points + self.no_v_points = no_v_points + self.inf = inf + self.eps = eps + self.is_multimer = is_multimer + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. + if not self.is_multimer: + hc = self.c_hidden * self.no_heads + self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer)) + self.linear_kv = Linear(self.c_s, 2 * hc) + + hpq = self.no_heads * self.no_qk_points * 3 + self.linear_q_points = Linear(self.c_s, hpq) + + hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3 + self.linear_kv_points = Linear(self.c_s, hpkv) + + # hpv = self.no_heads * self.no_v_points * 3 + + else: + hc = self.c_hidden * self.no_heads + self.linear_q = Linear(self.c_s, hc, bias=(not is_multimer)) + self.linear_q_points = PointProjection( + self.c_s, self.no_qk_points, self.no_heads + ) + + self.linear_k = Linear(self.c_s, hc, bias=False) + self.linear_v = Linear(self.c_s, hc, bias=False) + self.linear_k_points = PointProjection( + self.c_s, + self.no_qk_points, + self.no_heads, + ) + + self.linear_v_points = PointProjection( + self.c_s, + self.no_v_points, + self.no_heads, + ) + self.linear_b = Linear(self.c_z, self.no_heads) + + self.head_weights = nn.Parameter(torch.zeros((no_heads))) + ipa_point_weights_init_(self.head_weights) + + concat_out_dim = self.no_heads * ( + self.c_z + self.c_hidden + self.no_v_points * 4 + ) + self.linear_out = Linear(concat_out_dim, self.c_s, init="final") + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: torch.Tensor, + r: Union[Rigid, Rigid3Array], + mask: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + ####################################### + # Generate scalar and point activations + ####################################### + + # The following two blocks are equivalent + # They're separated only to preserve compatibility with old AF weights + if self.is_multimer: + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, P_qk] + q_pts = self.linear_q_points(s, r) + # [*, N_res, H * C_hidden] + k = self.linear_k(s) + v = self.linear_v(s) + + # [*, N_res, H, C_hidden] + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, P_qk, 3] + k_pts = self.linear_k_points(s, r) + + # [*, N_res, H, P_v, 3] + v_pts = self.linear_v_points(s, r) + else: + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + kv = self.linear_kv(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.no_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.c_hidden, dim=-1) + + # [*, N_res, H * P_q * 3] + q_pts = self.linear_q_points(s) + + # This is kind of clunky, but it's how the original does it + # [*, N_res, H * P_q, 3] + q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) + q_pts = torch.stack(q_pts, dim=-1) + q_pts = r[..., None].apply(q_pts) + + # [*, N_res, H, P_q, 3] + q_pts = q_pts.view(q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)) + + # [*, N_res, H * (P_q + P_v) * 3] + kv_pts = self.linear_kv_points(s) + + # [*, N_res, H * (P_q + P_v), 3] + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = r[..., None].apply(kv_pts) + + # [*, N_res, H, (P_q + P_v), 3] + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3)) + + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split( + kv_pts, [self.no_qk_points, self.no_v_points], dim=-2 + ) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z) + + # [*, H, N_res, N_res] + a = torch.matmul( + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + a *= math.sqrt(1.0 / (3 * self.c_hidden)) + a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)) + + if self.is_multimer: + # [*, N_res, N_res, H, P_q, 3] + pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :] + # [*, N_res, N_res, H, P_q] + pt_att = sum([c**2 for c in pt_att]) + else: + # [*, N_res, N_res, H, P_q, 3] + ###################################### + q_pts_t0 = q_pts.unsqueeze(-4) + q_shape = q_pts_t0.shape + q_pts_t0 = q_pts_t0.reshape([q_shape[0], q_shape[1], -1]) + k_pts_t0 = k_pts.unsqueeze(-5) + k_shape = k_pts_t0.shape + k_pts_t0 = k_pts_t0.reshape([k_shape[0], k_shape[1], -1]) + q_k = q_pts_t0 - k_pts_t0 + q_k = q_k ** 2 + q_k_shape = q_k.shape + pt_att = q_k.reshape(q_k_shape[:2] + q_shape[-3:]) + ##################################### + pt_att = pt_att.permute(0, 4, 1, 2, 3) + pt_att = torch.sum(pt_att, 1) + + head_weights = self.softplus(self.head_weights).view( + *((1,) * len(pt_att.shape[:-2]) + (-1, 1)) + ) + head_weights = head_weights * math.sqrt( + 1.0 / (3 * (self.no_qk_points * 9.0 / 2)) + ) + ############################## + pt_att_t0 = pt_att.permute(0, 3, 1, 2) + head_weights_t0 = head_weights.permute(0, 3, 1, 2) + pt_att_o = pt_att_t0 * head_weights_t0 + pt_att = pt_att_o.permute(0, 2,3, 1) + ############################## + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # As DeepMind explains, this manual matmul ensures that the operation + # happens in float32. + if self.is_multimer: + # [*, N_res, H, P_v] + o_pt = v_pts * permute_final_dims(a, (1, 2, 0)).unsqueeze(-1) + o_pt = o_pt.sum(dim=-3) + + # [*, N_res, H, P_v] + o_pt = r[..., None, None].apply_inverse_to_point(o_pt) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,)) + + # [*, N_res, H * P_v] + o_pt_norm = o_pt.norm(self.eps) + else: + # [*, H, 3, N_res, P_v] + ################################### + a1 = a[..., None, :, :, None] + a1 = a1.permute(0, 1, 2, 4, 3) + b = permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :] + b = b.permute(0, 1, 2, 4, 3) + c = a1 * b + o_pt = torch.sum(c, -1) + ################################### + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + o_pt = r[..., None, None].invert_apply(o_pt) + + # [*, N_res, H * P_v] + o_pt_norm = flatten_final_dims( + torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.eps), 2 + ) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) + + # [*, N_res, H, C_z] + o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype)) + + # [*, N_res, H * C_z] + o_pair = flatten_final_dims(o_pair, 2) + + # [*, N_res, C_s] + if self.is_multimer: + s = self.linear_out( + torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype) + ) + else: + s = self.linear_out( + torch.cat( + (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1 + ).to(dtype=z.dtype) + ) + + return s + + +class BackboneUpdate(nn.Module): + """ + Implements part of Algorithm 23. + """ + + def __init__(self, c_s: int): + """ + Args: + c_s: + Single representation channel dimension + """ + super(BackboneUpdate, self).__init__() + + self.c_s = c_s + + self.linear = Linear(self.c_s, 6, init="final") + + def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + [*, N_res, C_s] single representation + Returns: + [*, N_res, 6] update vector + """ + # [*, 6] + update = self.linear(s) + + return update + + +class StructureModuleTransitionLayer(nn.Module): + def __init__(self, c: int): + super(StructureModuleTransitionLayer, self).__init__() + + self.c = c + + self.linear_1 = Linear(self.c, self.c, init="relu") + self.linear_2 = Linear(self.c, self.c, init="relu") + self.linear_3 = Linear(self.c, self.c, init="final") + + self.relu = nn.ReLU() + + def forward(self, s: torch.Tensor): + s_initial = s + s = self.linear_1(s) + s = self.relu(s) + s = self.linear_2(s) + s = self.relu(s) + s = self.linear_3(s) + + s = s + s_initial + + return s + + +class StructureModuleTransition(nn.Module): + def __init__(self, c: int, num_layers: int, dropout_rate: float): + super(StructureModuleTransition, self).__init__() + + self.c = c + self.num_layers = num_layers + self.dropout_rate = dropout_rate + + self.layers = nn.ModuleList() + for _ in range(self.num_layers): + l = StructureModuleTransitionLayer(self.c) + self.layers.append(l) + + self.dropout = nn.Dropout(self.dropout_rate) + self.layer_norm = LayerNorm(self.c) + + def forward(self, s: torch.Tensor) -> torch.Tensor: + for l in self.layers: + s = l(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class StructureModule(nn.Module): + def __init__( + self, + c_s: int, + c_z: int, + c_ipa: int, + c_resnet: int, + no_heads_ipa: int, + no_qk_points: int, + no_v_points: int, + dropout_rate: float, + no_blocks: int, + no_transition_layers: int, + no_resnet_blocks: int, + no_angles: int, + trans_scale_factor: float, + epsilon: float, + inf: float, + is_multimer: bool = False, + **kwargs, + ): + """ + Args: + c_s: + Single representation channel dimension + c_z: + Pair representation channel dimension + c_ipa: + IPA hidden channel dimension + c_resnet: + Angle resnet (Alg. 23 lines 11-14) hidden channel dimension + no_heads_ipa: + Number of IPA heads + no_qk_points: + Number of query/key points to generate during IPA + no_v_points: + Number of value points to generate during IPA + dropout_rate: + Dropout rate used throughout the layer + no_blocks: + Number of structure module blocks + no_transition_layers: + Number of layers in the single representation transition + (Alg. 23 lines 8-9) + no_resnet_blocks: + Number of blocks in the angle resnet + no_angles: + Number of angles to generate in the angle resnet + trans_scale_factor: + Scale of single representation transition hidden dimension + epsilon: + Small number used in angle resnet normalization + inf: + Large number used for attention masking + is_multimer: + whether running under multimer mode + """ + super(StructureModule, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_ipa = c_ipa + self.c_resnet = c_resnet + self.no_heads_ipa = no_heads_ipa + self.no_qk_points = no_qk_points + self.no_v_points = no_v_points + self.dropout_rate = dropout_rate + self.no_blocks = no_blocks + self.no_transition_layers = no_transition_layers + self.no_resnet_blocks = no_resnet_blocks + self.no_angles = no_angles + self.trans_scale_factor = trans_scale_factor + self.epsilon = epsilon + self.inf = inf + self.is_multimer = is_multimer + + # To be lazily initialized later + self.default_frames = None + self.group_idx = None + self.atom_mask = None + self.lit_positions = None + + self.layer_norm_s = LayerNorm(self.c_s) + self.layer_norm_z = LayerNorm(self.c_z) + + self.linear_in = Linear(self.c_s, self.c_s) + + self.ipa = InvariantPointAttention( + self.c_s, + self.c_z, + self.c_ipa, + self.no_heads_ipa, + self.no_qk_points, + self.no_v_points, + inf=self.inf, + eps=self.epsilon, + is_multimer=self.is_multimer, + ) + + self.ipa_dropout = nn.Dropout(self.dropout_rate) + self.layer_norm_ipa = LayerNorm(self.c_s) + + self.transition = StructureModuleTransition( + self.c_s, + self.no_transition_layers, + self.dropout_rate, + ) + + if is_multimer: + self.bb_update = QuatRigid(self.c_s, full_quat=False) + else: + self.bb_update = BackboneUpdate(self.c_s) + + self.angle_resnet = AngleResnet( + self.c_s, + self.c_resnet, + self.no_resnet_blocks, + self.no_angles, + self.epsilon, + ) + + def _forward_monomer( + self, + s: torch.Tensor, + z: torch.Tensor, + aatype: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Dict[str, Any]: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + aatype: + [*, N_res] amino acid indices + mask: + Optional [*, N_res] sequence mask + Returns: + A dictionary of outputs + """ + if mask is None: + # [*, N] + mask = s.new_ones(s.shape[:-1]) + + # [*, N, C_s] + s = self.layer_norm_s(s) + + # [*, N, N, C_z] + z = self.layer_norm_z(z) + + # [*, N, C_s] + s_initial = s + s = self.linear_in(s) + + # [*, N] + rigids = Rigid.identity( + s.shape[:-1], + s.dtype, + s.device, + self.training, + fmt="quat", + ) + outputs = [] + for i in range(self.no_blocks): + # [*, N, C_s] + s = s + self.ipa(s, z, rigids, mask) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # [*, N] + rigids = rigids.compose_q_update_vec(self.bb_update(s)) + + # To hew as closely as possible to AlphaFold, we convert our + # quaternion-based transformations to rotation-matrix ones + # here + backb_to_global = Rigid( + Rotation( + rot_mats=rigids.get_rots().get_rot_mats(), + quats=None + ), + rigids.get_trans(), + ) + + backb_to_global = backb_to_global.scale_translation( + self.trans_scale_factor + ) + + # [*, N, 7, 2] + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + all_frames_to_global = self.torsion_angles_to_frames( + backb_to_global, + angles, + aatype, + ) + + pred_xyz = self.frames_and_literature_positions_to_atom14_pos( + all_frames_to_global, + aatype, + ) + + scaled_rigids = rigids.scale_translation(self.trans_scale_factor) + + preds = { + "frames": scaled_rigids.to_tensor_7(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "unnormalized_angles": unnormalized_angles, + "angles": angles, + "positions": pred_xyz, + } + + outputs.append(preds) + + if i < (self.no_blocks - 1): + rigids = rigids.stop_rot_gradient() + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def _forward_multimer( + self, + s: torch.Tensor, + z: torch.Tensor, + aatype: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Dict[str, Any]: + if mask is None: + # [*, N] + mask = s.new_ones(s.shape[:-1]) + + # [*, N, C_s] + s = self.layer_norm_s(s) + + # [*, N, N, C_z] + z = self.layer_norm_z(z) + + # [*, N, C_s] + s_initial = s + s = self.linear_in(s) + + # [*, N] + rigids = Rigid3Array.identity( + s.shape[:-1], + s.device, + ) + outputs = [] + for i in range(self.no_blocks): + # [*, N, C_s] + s = s + self.ipa(s, z, rigids, mask) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # [*, N] + rigids = rigids @ self.bb_update(s) + + # [*, N, 7, 2] + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + all_frames_to_global = self.torsion_angles_to_frames( + rigids.scale_translation(self.trans_scale_factor), + angles, + aatype, + ) + + pred_xyz = self.frames_and_literature_positions_to_atom14_pos( + all_frames_to_global, + aatype, + ) + + preds = { + "frames": rigids.scale_translation(self.trans_scale_factor).to_tensor(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "unnormalized_angles": unnormalized_angles, + "angles": angles, + "positions": pred_xyz.to_tensor(), + } + + outputs.append(preds) + + if i < (self.no_blocks - 1): + rigids = rigids.stop_rot_gradient() + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def forward( + self, + s: torch.Tensor, + z: torch.Tensor, + aatype: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ): + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + aatype: + [*, N_res] amino acid indices + mask: + Optional [*, N_res] sequence mask + Returns: + A dictionary of outputs + """ + if self.is_multimer: + outputs = self._forward_multimer(s, z, aatype, mask) + else: + outputs = self._forward_monomer(s, z, aatype, mask) + + return outputs + + def _init_residue_constants(self, float_dtype: torch.dtype, device: torch.device): + if self.default_frames is None: + self.default_frames = torch.tensor( + restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.group_idx is None: + self.group_idx = torch.tensor( + restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ) + if self.atom_mask is None: + self.atom_mask = torch.tensor( + restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.lit_positions is None: + self.lit_positions = torch.tensor( + restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + + def torsion_angles_to_frames( + self, r: Union[Rigid, Rigid3Array], alpha: torch.Tensor, f + ): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(alpha.dtype, alpha.device) + # Separated purely to make testing less annoying + return torsion_angles_to_frames(r, alpha, f, self.default_frames) + + def frames_and_literature_positions_to_atom14_pos( + self, r: Union[Rigid, Rigid3Array], f # [*, N, 8] # [*, N] + ): + # Lazily initialize the residue constants on the correct device + if type(r) == Rigid: + self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) + elif type(r) == Rigid3Array: + self._init_residue_constants(r.dtype, r.device) + else: + raise ValueError("Unknown rigid type") + return frames_and_literature_positions_to_atom14_pos( + r, + f, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/utils/__init__.py b/alphafold_pytorch_jit/structure_module/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/alphafold_pytorch_jit/structure_module/utils/checkpointing.py b/alphafold_pytorch_jit/structure_module/utils/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..8a42c9127052eeec2ecfaa5542ff202dcfbfa75f --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/checkpointing.py @@ -0,0 +1,84 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.utils.checkpoint +from typing import Any, Tuple, List, Callable, Optional + + +BLOCK_ARG = Any +BLOCK_ARGS = List[BLOCK_ARG] + + +def get_checkpoint_fn(): + checkpoint = torch.utils.checkpoint.checkpoint + + return checkpoint + + +@torch.jit.ignore +def checkpoint_blocks( + blocks: List[Callable], + args: BLOCK_ARGS, + blocks_per_ckpt: Optional[int], +) -> BLOCK_ARGS: + """ + Chunk a list of blocks and run each chunk with activation + checkpointing. We define a "block" as a callable whose only inputs are + the outputs of the previous block. + + Implements Subsection 1.11.8 + + Args: + blocks: + List of blocks + args: + Tuple of arguments for the first block. + blocks_per_ckpt: + Size of each chunk. A higher value corresponds to fewer + checkpoints, and trades memory for speed. If None, no checkpointing + is performed. + Returns: + The output of the final block + """ + def wrap(a): + return (a,) if type(a) is not tuple else a + + def exec(b, a): + for block in b: + a = wrap(block(*a)) + return a + + def chunker(s, e): + def exec_sliced(*a): + return exec(blocks[s:e], a) + + return exec_sliced + + # Avoids mishaps when the blocks take just one argument + args = wrap(args) + + if blocks_per_ckpt is None: + return exec(blocks, args) + elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): + raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") + + checkpoint = get_checkpoint_fn() + + for s in range(0, len(blocks), blocks_per_ckpt): + e = s + blocks_per_ckpt + args = checkpoint(chunker(s, e), *args) + args = wrap(args) + + return args \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/utils/feats.py b/alphafold_pytorch_jit/structure_module/utils/feats.py new file mode 100644 index 0000000000000000000000000000000000000000..12e27b60adcbf4d9d3c3782178f7488938b7db73 --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/feats.py @@ -0,0 +1,319 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn +from typing import Any, Dict, Optional, Tuple, Union + +from ..common import protein +from ..common import residue_constants as rc +from .geometry.rigid_matrix_vector import Rigid3Array +from .geometry.rotation_matrix import Rot3Array +from .rigid_utils import Rotation, Rigid +from .tensor_utils import ( + batched_gather, + one_hot, + tree_map, + tensor_tree_map, +) + +def dgram_from_positions( + pos: torch.Tensor, + min_bin: float = 3.25, + max_bin: float = 50.75, + no_bins: float = 39, + inf: float = 1e8, +) -> torch.Tensor: + dgram = torch.sum( + (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True + ) + lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower).type(dgram.dtype) * (dgram < upper)).type(dgram.dtype) + + return dgram + +def pseudo_beta_fn( + aatype, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + is_gly = aatype == rc.restype_order["G"] + ca_idx = rc.atom_order["CA"] + cb_idx = rc.atom_order["CB"] + pseudo_beta = torch.where( + is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_masks is not None: + pseudo_beta_mask = torch.where( + is_gly, + all_atom_masks[..., ca_idx], + all_atom_masks[..., cb_idx], + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta, None + + +def atom14_to_atom37(atom14, batch: Dict[str, Any]): + atom37_data = batched_gather( + atom14, + batch["residx_atom37_to_atom14"], + dim=-2, + no_batch_dims=len(atom14.shape[:-2]), + ) + + atom37_data = atom37_data * batch["atom37_atom_exists"][..., None] + + return atom37_data + + +def build_template_angle_feat(template_feats: Dict[str, Any]) -> torch.Tensor: + template_aatype = template_feats["template_aatype"] + torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"] + alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"] + torsion_angles_mask = template_feats["template_torsion_angles_mask"] + template_angle_feat = torch.cat( + [ + nn.functional.one_hot(template_aatype, 22).to(torch.float32), + torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14), + alt_torsion_angles_sin_cos.reshape( + *alt_torsion_angles_sin_cos.shape[:-2], 14 + ), + torsion_angles_mask, + ], + dim=-1, + ) + + return template_angle_feat + + +def build_template_pair_feat( + batch: Dict[str, Any], + min_bin: float, + max_bin: float, + no_bins: int, + use_unit_vector: bool = False, + eps: float = 1e-20, + inf: float = 1e8, + chunk=None +): + if chunk and 1 <= chunk <= 4: + for k, v in batch.items(): + batch[k] = v.cpu() + + template_mask = batch["template_pseudo_beta_mask"] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + # Compute distogram (this seems to differ slightly from Alg. 5) + tpb = batch["template_pseudo_beta"] + dgram = dgram_from_positions(tpb, min_bin, max_bin, no_bins, inf) + + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = nn.functional.one_hot( + batch["template_aatype"], + rc.restype_num + 2, + ) + + n_res = batch["template_aatype"].shape[-1] + to_concat.append( + aatype_one_hot[..., None, :, :].expand( + *aatype_one_hot.shape[:-2], n_res, -1, -1 + ).to(dgram.dtype) + ) + to_concat.append( + aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1).to(dgram.dtype) + ) + + n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] + rigids = Rigid.make_transform_from_reference( + n_xyz=batch["template_all_atom_positions"][..., n, :], + ca_xyz=batch["template_all_atom_positions"][..., ca, :], + c_xyz=batch["template_all_atom_positions"][..., c, :], + eps=eps, + ) + points = rigids.get_trans()[..., None, :, :] + rigid_vec = rigids[..., None].invert_apply(points) + del rigids, points + + inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) + + t_aa_masks = batch["template_all_atom_mask"] + template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c] + del t_aa_masks, n, ca, c + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + inv_distance_scalar = inv_distance_scalar * template_mask_2d + unit_vector = rigid_vec * inv_distance_scalar[..., None] + + if not use_unit_vector: + unit_vector = unit_vector * 0.0 + + to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) + to_concat.append(template_mask_2d[..., None]) + del unit_vector, rigid_vec, inv_distance_scalar + + act = torch.cat(to_concat, dim=-1) + act = act * template_mask_2d[..., None] + + return act + + +def build_extra_msa_feat(batch: Dict[str, Any]) -> torch.Tensor: + msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23) + msa_feat = [ + msa_1hot.to(torch.float32), + batch["extra_has_deletion"].unsqueeze(-1), + batch["extra_deletion_value"].unsqueeze(-1), + ] + return torch.cat(msa_feat, dim=-1) + + +def torsion_angles_to_frames( + r: Union[Rigid3Array, Rigid], + alpha: torch.Tensor, + aatype: torch.Tensor, + rrgdf: torch.Tensor, +) -> Union[Rigid, Rigid3Array]: + # [*, N, 8, 4, 4] + default_4x4 = rrgdf[aatype, ...] + + # [*, N, 8] transformations, i.e. + # One [*, N, 8, 3, 3] rotation matrix and + # One [*, N, 8, 3] translation matrix + default_r = r.from_tensor_4x4(default_4x4) + + bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) + bb_rot[..., 1] = 1 + + # [*, N, 8, 2] + alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) + + # [*, N, 8, 3, 3] + # Produces rotation matrices of the form: + # [ + # [1, 0 , 0 ], + # [0, a_2,-a_1], + # [0, a_1, a_2] + # ] + # This follows the original code rather than the supplement, which uses + # different indices. + if type(r) == Rigid3Array: + all_rots = alpha.new_zeros(default_r.shape + (3, 3)) + elif type(r) == Rigid: + all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) + else: + raise TypeError(f"Wrong type of Rigid: {type(r)}") + + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + if type(r) == Rigid3Array: + all_rots = Rot3Array.from_array(all_rots) + all_frames = default_r.compose_rotation(all_rots) + elif type(r) == Rigid: + all_rots = Rigid(Rotation(rot_mats=all_rots), None) + all_frames = default_r.compose(all_rots) + else: + raise TypeError(f"Wrong type of Rigid: {type(r)}") + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + if type(all_frames) == Rigid3Array: + all_frames_to_bb = Rigid3Array.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + elif type(all_frames) == Rigid: + all_frames_to_bb = Rigid.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = r[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + r: Union[Rigid3Array, Rigid], + aatype: torch.Tensor, + default_frames: torch.Tensor, + group_idx: torch.Tensor, + atom_mask: torch.Tensor, + lit_positions: torch.Tensor, +) -> torch.Tensor: + # [*, N, 14, 4, 4] + default_4x4 = default_frames[aatype, ...] + + # [*, N, 14] + group_mask = group_idx[aatype, ...] + + # [*, N, 14, 8] + if type(r) == Rigid3Array: + group_mask = nn.functional.one_hot( + group_mask.long(), + num_classes=default_frames.shape[-3], + ) + elif type(r) == Rigid: + group_mask = nn.functional.one_hot( + group_mask.long(), + num_classes=default_frames.shape[-3], + ) + else: + raise TypeError(f"Wrong type of Rigid: {type(r)}") + + # [*, N, 14, 8] + t_atoms_to_global = r[..., None, :] * group_mask + + # [*, N, 14] + t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) + + # [*, N, 14, 1] + if type(r) == Rigid: + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + elif type(r) == Rigid3Array: + atom_mask = atom_mask[aatype, ...] + + # [*, N, 14, 3] + lit_positions = lit_positions[aatype, ...] + pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/utils/geometry/__init__.py b/alphafold_pytorch_jit/structure_module/utils/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/alphafold_pytorch_jit/structure_module/utils/geometry/quat_rigid.py b/alphafold_pytorch_jit/structure_module/utils/geometry/quat_rigid.py new file mode 100644 index 0000000000000000000000000000000000000000..23d1b58aabba42d7d1d50ce0bb69dd6f3199416f --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/geometry/quat_rigid.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn + +from ...primitives import Linear +from .rigid_matrix_vector import Rigid3Array +from .rotation_matrix import Rot3Array +from .vector import Vec3Array + + +class QuatRigid(nn.Module): + def __init__(self, c_hidden, full_quat): + super().__init__() + self.full_quat = full_quat + if self.full_quat: + rigid_dim = 7 + else: + rigid_dim = 6 + + self.linear = Linear(c_hidden, rigid_dim) + + def forward(self, activations: torch.Tensor) -> Rigid3Array: + # NOTE: During training, this needs to be run in higher precision + rigid_flat = self.linear(activations.to(torch.float32)) + + rigid_flat = torch.unbind(rigid_flat, dim=-1) + if(self.full_quat): + qw, qx, qy, qz = rigid_flat[:4] + translation = rigid_flat[4:] + else: + qx, qy, qz = rigid_flat[:3] + qw = torch.ones_like(qx) + translation = rigid_flat[3:] + + rotation = Rot3Array.from_quaternion( + qw, qx, qy, qz, normalize=True, + ) + translation = Vec3Array(*translation) + return Rigid3Array(rotation, translation) \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/utils/geometry/rigid_matrix_vector.py b/alphafold_pytorch_jit/structure_module/utils/geometry/rigid_matrix_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..40deb2753d10a2dc80a2d737c968e90e8887ce45 --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/geometry/rigid_matrix_vector.py @@ -0,0 +1,175 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rigid3Array Transformations represented by a Matrix and a Vector.""" + +from __future__ import annotations +import dataclasses +from typing import Union, List + +import torch + +from . import rotation_matrix +from . import vector + + +Float = Union[float, torch.Tensor] + + +@dataclasses.dataclass(frozen=True) +class Rigid3Array: + """Rigid Transformation, i.e. element of special euclidean group.""" + + rotation: rotation_matrix.Rot3Array + translation: vector.Vec3Array + + def __matmul__(self, other: Rigid3Array) -> Rigid3Array: + new_rotation = self.rotation @ other.rotation # __matmul__ + new_translation = self.apply_to_point(other.translation) + return Rigid3Array(new_rotation, new_translation) + + def __getitem__(self, index) -> Rigid3Array: + return Rigid3Array( + self.rotation[index], + self.translation[index], + ) + + def __mul__(self, other: torch.Tensor) -> Rigid3Array: + return Rigid3Array( + self.rotation * other, + self.translation * other, + ) + + def map_tensor_fn(self, fn) -> Rigid3Array: + return Rigid3Array( + self.rotation.map_tensor_fn(fn), + self.translation.map_tensor_fn(fn), + ) + + def inverse(self) -> Rigid3Array: + """Return Rigid3Array corresponding to inverse transform.""" + inv_rotation = self.rotation.inverse() + inv_translation = inv_rotation.apply_to_point(-self.translation) + return Rigid3Array(inv_rotation, inv_translation) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply Rigid3Array transform to point.""" + return self.rotation.apply_to_point(point) + self.translation + + def apply(self, point: torch.Tensor) -> vector.Vec3Array: + return self.apply_to_point(vector.Vec3Array.from_array(point)) + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply inverse Rigid3Array transform to point.""" + new_point = point - self.translation + return self.rotation.apply_inverse_to_point(new_point) + + def compose_rotation(self, other_rotation): + rot = self.rotation @ other_rotation + return Rigid3Array(rot, self.translation.clone()) + + def compose(self, other_rigid): + return self @ other_rigid + + def unsqueeze(self, dim: int): + return Rigid3Array( + self.rotation.unsqueeze(dim), + self.translation.unsqueeze(dim), + ) + + @property + def shape(self) -> torch.Size: + return self.rotation.xx.shape + + @property + def dtype(self) -> torch.dtype: + return self.rotation.xx.dtype + + @property + def device(self) -> torch.device: + return self.rotation.xx.device + + @classmethod + def identity(cls, shape, device) -> Rigid3Array: + """Return identity Rigid3Array of given shape.""" + return cls( + rotation_matrix.Rot3Array.identity(shape, device), + vector.Vec3Array.zeros(shape, device) + ) + + @classmethod + def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array: + return cls( + rotation_matrix.Rot3Array.cat( + [r.rotation for r in rigids], dim=dim + ), + vector.Vec3Array.cat( + [r.translation for r in rigids], dim=dim + ), + ) + + def scale_translation(self, factor: Float) -> Rigid3Array: + """Scale translation in Rigid3Array by 'factor'.""" + return Rigid3Array(self.rotation, self.translation * factor) + + def to_tensor(self) -> torch.Tensor: + rot_array = self.rotation.to_tensor() + vec_array = self.translation.to_tensor() + array = torch.zeros( + rot_array.shape[:-2] + (4, 4), + device=rot_array.device, + dtype=rot_array.dtype + ) + array[..., :3, :3] = rot_array + array[..., :3, 3] = vec_array + array[..., 3, 3] = 1. + return array + + def to_tensor_4x4(self) -> torch.Tensor: + return self.to_tensor() + + def reshape(self, new_shape) -> Rigid3Array: + rots = self.rotation.reshape(new_shape) + trans = self.translation.reshape(new_shape) + return Rigid3Aray(rots, trans) + + def stop_rot_gradient(self) -> Rigid3Array: + return Rigid3Array( + self.rotation.stop_gradient(), + self.translation, + ) + + @classmethod + def from_array(cls, array): + rot = rotation_matrix.Rot3Array.from_array( + array[..., :3, :3], + ) + vec = vector.Vec3Array.from_array(array[..., :3, 3]) + return cls(rot, vec) + + @classmethod + def from_tensor_4x4(cls, array): + return cls.from_array(array) + + @classmethod + def from_array4x4(cls, array: torch.tensor) -> Rigid3Array: + """Construct Rigid3Array from homogeneous 4x4 array.""" + rotation = rotation_matrix.Rot3Array( + array[..., 0, 0], array[..., 0, 1], array[..., 0, 2], + array[..., 1, 0], array[..., 1, 1], array[..., 1, 2], + array[..., 2, 0], array[..., 2, 1], array[..., 2, 2] + ) + translation = vector.Vec3Array( + array[..., 0, 3], array[..., 1, 3], array[..., 2, 3] + ) + return cls(rotation, translation) diff --git a/alphafold_pytorch_jit/structure_module/utils/geometry/rotation_matrix.py b/alphafold_pytorch_jit/structure_module/utils/geometry/rotation_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..835e460a6e79e9f53c340f3617b4ebf7d7a8277d --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/geometry/rotation_matrix.py @@ -0,0 +1,208 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rot3Array Matrix Class.""" + +from __future__ import annotations +import dataclasses + +import torch +import numpy as np + +from . import utils +from . import vector +from ..tensor_utils import tensor_tree_map + + +COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] + +@dataclasses.dataclass(frozen=True) +class Rot3Array: + """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" + xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32}) + xy: torch.Tensor + xz: torch.Tensor + yx: torch.Tensor + yy: torch.Tensor + yz: torch.Tensor + zx: torch.Tensor + zy: torch.Tensor + zz: torch.Tensor + + __array_ufunc__ = None + + def __getitem__(self, index): + field_names = utils.get_field_names(Rot3Array) + return Rot3Array( + **{ + name: getattr(self, name)[index] + for name in field_names + } + ) + + def __mul__(self, other: torch.Tensor): + field_names = utils.get_field_names(Rot3Array) + return Rot3Array( + **{ + name: getattr(self, name) * other + for name in field_names + } + ) + + def __matmul__(self, other: Rot3Array) -> Rot3Array: + """Composes two Rot3Arrays.""" + c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) + c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) + c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) + return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + def map_tensor_fn(self, fn) -> Rot3Array: + field_names = utils.get_field_names(Rot3Array) + return Rot3Array( + **{ + name: fn(getattr(self, name)) + for name in field_names + } + ) + + def inverse(self) -> Rot3Array: + """Returns inverse of Rot3Array.""" + return Rot3Array( + self.xx, self.yx, self.zx, + self.xy, self.yy, self.zy, + self.xz, self.yz, self.zz + ) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies Rot3Array to point.""" + return vector.Vec3Array( + self.xx * point.x + self.xy * point.y + self.xz * point.z, + self.yx * point.x + self.yy * point.y + self.yz * point.z, + self.zx * point.x + self.zy * point.y + self.zz * point.z + ) + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies inverse Rot3Array to point.""" + return self.inverse().apply_to_point(point) + + + def unsqueeze(self, dim: int): + return Rot3Array( + *tensor_tree_map( + lambda t: t.unsqueeze(dim), + [getattr(self, c) for c in COMPONENTS] + ) + ) + + def stop_gradient(self) -> Rot3Array: + return Rot3Array( + *[getattr(self, c).detach() for c in COMPONENTS] + ) + + @classmethod + def identity(cls, shape, device) -> Rot3Array: + """Returns identity of given shape.""" + ones = torch.ones(shape, dtype=torch.float32, device=device) + zeros = torch.zeros(shape, dtype=torch.float32, device=device) + return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) + + @classmethod + def from_two_vectors( + cls, e0: vector.Vec3Array, + e1: vector.Vec3Array + ) -> Rot3Array: + """Construct Rot3Array from two Vectors. + + Rot3Array is constructed such that in the corresponding frame 'e0' lies on + the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. + + Args: + e0: Vector + e1: Vector + Returns: + Rot3Array + """ + # Normalize the unit vector for the x-axis, e0. + e0 = e0.normalized() + # make e1 perpendicular to e0. + c = e1.dot(e0) + e1 = (e1 - c * e0).normalized() + # Compute e2 as cross product of e0 and e1. + e2 = e0.cross(e1) + return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + + @classmethod + def from_array(cls, array: torch.Tensor) -> Rot3Array: + """Construct Rot3Array Matrix from array of shape. [..., 3, 3].""" + rows = torch.unbind(array, dim=-2) + rc = [torch.unbind(e, dim=-1) for e in rows] + return cls(*[e for row in rc for e in row]) + + def to_tensor(self) -> torch.Tensor: + """Convert Rot3Array to array of shape [..., 3, 3].""" + return torch.stack( + [ + torch.stack([self.xx, self.xy, self.xz], dim=-1), + torch.stack([self.yx, self.yy, self.yz], dim=-1), + torch.stack([self.zx, self.zy, self.zz], dim=-1) + ], + dim=-2 + ) + + @classmethod + def from_quaternion(cls, + w: torch.Tensor, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + normalize: bool = True, + eps: float = 1e-6 + ) -> Rot3Array: + """Construct Rot3Array from components of quaternion.""" + if normalize: + inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2) + w *= inv_norm + x *= inv_norm + y *= inv_norm + z *= inv_norm + xx = 1 - 2 * (y ** 2 + z ** 2) + xy = 2 * (x * y - w * z) + xz = 2 * (x * z + w * y) + yx = 2 * (x * y + w * z) + yy = 1 - 2 * (x ** 2 + z ** 2) + yz = 2 * (y * z - w * x) + zx = 2 * (x * z - w * y) + zy = 2 * (y * z + w * x) + zz = 1 - 2 * (x ** 2 + y ** 2) + return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) + + def reshape(self, new_shape): + field_names = utils.get_field_names(Rot3Array) + reshape_fn = lambda t: t.reshape(new_shape) + return Rot3Array( + **{ + name: reshape_fn(getattr(self, name)) + for name in field_names + } + ) + + @classmethod + def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array: + field_names = utils.get_field_names(Rot3Array) + cat_fn = lambda l: torch.cat(l, dim=dim) + return cls( + **{ + name: cat_fn([getattr(r, name) for r in rots]) + for name in field_names + } + ) diff --git a/alphafold_pytorch_jit/structure_module/utils/geometry/utils.py b/alphafold_pytorch_jit/structure_module/utils/geometry/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..30da0a23e5ea5278cab38b344f758843327709eb --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/geometry/utils.py @@ -0,0 +1,22 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for geometry library.""" + +import dataclasses + + +def get_field_names(cls): + fields = dataclasses.fields(cls) + field_names = [f.name for f in fields] + return field_names \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/utils/geometry/vector.py b/alphafold_pytorch_jit/structure_module/utils/geometry/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..17c375a9b70372e97ac6c2c478f89564bb06e56f --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/geometry/vector.py @@ -0,0 +1,263 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Vec3Array Class.""" + +from __future__ import annotations +import dataclasses +from typing import Union, List + +import torch + +from . import utils + +Float = Union[float, torch.Tensor] + +@dataclasses.dataclass(frozen=True) +class Vec3Array: + x: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32}) + y: torch.Tensor + z: torch.Tensor + + def __post_init__(self): + if hasattr(self.x, 'dtype'): + assert self.x.dtype == self.y.dtype + assert self.x.dtype == self.z.dtype + assert all([x == y for x, y in zip(self.x.shape, self.y.shape)]) + assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) + + def __add__(self, other: Vec3Array) -> Vec3Array: + return Vec3Array( + self.x + other.x, + self.y + other.y, + self.z + other.z, + ) + + def __sub__(self, other: Vec3Array) -> Vec3Array: + return Vec3Array( + self.x - other.x, + self.y - other.y, + self.z - other.z, + ) + + def __mul__(self, other: Float) -> Vec3Array: + return Vec3Array( + self.x * other, + self.y * other, + self.z * other, + ) + + def __rmul__(self, other: Float) -> Vec3Array: + return self * other + + def __truediv__(self, other: Float) -> Vec3Array: + return Vec3Array( + self.x / other, + self.y / other, + self.z / other, + ) + + def __neg__(self) -> Vec3Array: + return self * -1 + + def __pos__(self) -> Vec3Array: + return self * 1 + + def __getitem__(self, index) -> Vec3Array: + return Vec3Array( + self.x[index], + self.y[index], + self.z[index], + ) + + def __iter__(self): + return iter((self.x, self.y, self.z)) + + @property + def shape(self): + return self.x.shape + + def map_tensor_fn(self, fn) -> Vec3Array: + return Vec3Array( + fn(self.x), + fn(self.y), + fn(self.z), + ) + + def cross(self, other: Vec3Array) -> Vec3Array: + """Compute cross product between 'self' and 'other'.""" + new_x = self.y * other.z - self.z * other.y + new_y = self.z * other.x - self.x * other.z + new_z = self.x * other.y - self.y * other.x + return Vec3Array(new_x, new_y, new_z) + + def dot(self, other: Vec3Array) -> Float: + """Compute dot product between 'self' and 'other'.""" + return self.x * other.x + self.y * other.y + self.z * other.z + + def norm(self, epsilon: float = 1e-6) -> Float: + """Compute Norm of Vec3Array, clipped to epsilon.""" + # To avoid NaN on the backward pass, we must use maximum before the sqrt + norm2 = self.dot(self) + if epsilon: + norm2 = torch.clamp(norm2, min=epsilon**2) + return torch.sqrt(norm2) + + def norm2(self): + return self.dot(self) + + def normalized(self, epsilon: float = 1e-6) -> Vec3Array: + """Return unit vector with optional clipping.""" + return self / self.norm(epsilon) + + def clone(self) -> Vec3Array: + return Vec3Array( + self.x.clone(), + self.y.clone(), + self.z.clone(), + ) + + def reshape(self, new_shape) -> Vec3Array: + x = self.x.reshape(new_shape) + y = self.y.reshape(new_shape) + z = self.z.reshape(new_shape) + + return Vec3Array(x, y, z) + + def sum(self, dim: int) -> Vec3Array: + return Vec3Array( + torch.sum(self.x, dim=dim), + torch.sum(self.y, dim=dim), + torch.sum(self.z, dim=dim), + ) + + def unsqueeze(self, dim: int): + return Vec3Array( + self.x.unsqueeze(dim), + self.y.unsqueeze(dim), + self.z.unsqueeze(dim), + ) + + @classmethod + def zeros(cls, shape, device="cpu"): + """Return Vec3Array corresponding to zeros of given shape.""" + return cls( + torch.zeros(shape, dtype=torch.float32, device=device), + torch.zeros(shape, dtype=torch.float32, device=device), + torch.zeros(shape, dtype=torch.float32, device=device) + ) + + def to_tensor(self) -> torch.Tensor: + return torch.stack([self.x, self.y, self.z], dim=-1) + + @classmethod + def from_array(cls, tensor): + return cls(*torch.unbind(tensor, dim=-1)) + + @classmethod + def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array: + return cls( + torch.cat([v.x for v in vecs], dim=dim), + torch.cat([v.y for v in vecs], dim=dim), + torch.cat([v.z for v in vecs], dim=dim), + ) + + +def square_euclidean_distance( + vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6 +) -> Float: + """Computes square of euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute distance to + vec2: Vec3Array to compute distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of square euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + difference = vec1 - vec2 + distance = difference.dot(difference) + if epsilon: + distance = torch.maximum(distance, epsilon) + return distance + + +def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.dot(vector2) + + +def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.cross(vector2) + + +def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float: + return vector.norm(epsilon) + + +def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array: + return vector.normalized(epsilon) + + +def euclidean_distance( + vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6 +) -> Float: + """Computes euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute euclidean distance to + vec2: Vec3Array to compute euclidean distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2) + distance = torch.sqrt(distance_sq) + return distance + + +def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, + d: Vec3Array) -> Float: + """Computes torsion angle for a quadruple of points. + + For points (a, b, c, d), this is the angle between the planes defined by + points (a, b, c) and (b, c, d). It is also known as the dihedral angle. + + Arguments: + a: A Vec3Array of coordinates. + b: A Vec3Array of coordinates. + c: A Vec3Array of coordinates. + d: A Vec3Array of coordinates. + + Returns: + A tensor of angles in radians: [-pi, pi]. + """ + v1 = a - b + v2 = b - c + v3 = d - c + + c1 = v1.cross(v2) + c2 = v3.cross(v2) + c3 = c2.cross(c1) + + v2_mag = v2.norm() + return torch.atan2(c3.dot(v2), v2_mag * c1.dot(c2)) \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/utils/import_weights.py b/alphafold_pytorch_jit/structure_module/utils/import_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..62ba2b883986e1ae855c4c3f244f020a9ea6f037 --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/import_weights.py @@ -0,0 +1,272 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from dataclasses import dataclass +from functools import partial +import numpy as np +import torch +from typing import Union, List + +#_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/" +_NPZ_KEY_PREFIX = "" + +# With Param, a poor man's enum with attributes (Rust-style) +class ParamType(Enum): + LinearWeight = partial( # hack: partial prevents fns from becoming methods + lambda w: w.transpose(-1, -2) + ) + LinearWeightMHA = partial( + lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2) + ) + LinearMHAOutputWeight = partial( + lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2) + ) + LinearBiasMHA = partial(lambda w: w.reshape(*w.shape[:-2], -1)) + LinearWeightOPM = partial( + lambda w: w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2) + ) + LinearWeightMultimer = partial( + lambda w: w.unsqueeze(-1) + if len(w.shape) == 1 + else w.reshape(w.shape[0], -1).transpose(-1, -2) + ) + LinearBiasMultimer = partial(lambda w: w.reshape(-1)) + Other = partial(lambda w: w) + + def __init__(self, fn): + self.transformation = fn + + +@dataclass +class Param: + param: Union[torch.Tensor, List[torch.Tensor]] + param_type: ParamType = ParamType.Other + stacked: bool = False + + +def _process_translations_dict(d, top_layer=True): + flat = {} + for k, v in d.items(): + if type(v) == dict: + prefix = _NPZ_KEY_PREFIX if top_layer else "" + sub_flat = { + (prefix + "/".join([k, k_prime])): v_prime + for k_prime, v_prime in _process_translations_dict( + v, top_layer=False + ).items() + } + flat.update(sub_flat) + else: + k = "/" + k if not top_layer else k + flat[k] = v + + return flat + + +def stacked(param_dict_list, out=None): + """ + Args: + param_dict_list: + A list of (nested) Param dicts to stack. The structure of + each dict must be the identical (down to the ParamTypes of + "parallel" Params). There must be at least one dict + in the list. + """ + if out is None: + out = {} + template = param_dict_list[0] + for k, _ in template.items(): + v = [d[k] for d in param_dict_list] + if type(v[0]) is dict: + out[k] = {} + stacked(v, out=out[k]) + elif type(v[0]) is Param: + stacked_param = Param( + param=[param.param for param in v], + param_type=v[0].param_type, + stacked=True, + ) + + out[k] = stacked_param + + return out + + +def assign(translation_dict, orig_weights): + for k, param in translation_dict.items(): + with torch.no_grad(): + weights = torch.as_tensor(orig_weights[k]) + ref, param_type = param.param, param.param_type + if param.stacked: + weights = torch.unbind(weights, 0) + else: + weights = [weights] + ref = [ref] + + try: + weights = list(map(param_type.transformation, weights)) + for p, w in zip(ref, weights): + p.copy_(w) + except: + print(k) + print(ref[0].shape) + print(weights[0].shape) + raise + +def get_translation_dict(model, version): + is_multimer = "multimer" in version + ####################### + # Some templates + ####################### + LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight)) + LinearBias = lambda l: (Param(l)) + LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA)) + LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA)) + LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM)) + LinearWeightMultimer = lambda l: ( + Param(l, param_type=ParamType.LinearWeightMultimer) + ) + LinearBiasMultimer = lambda l: (Param(l, param_type=ParamType.LinearBiasMultimer)) + + LinearParams = lambda l: { + "weights": LinearWeight(l.weight), + "bias": LinearBias(l.bias), + } + + LinearParamsMultimer = lambda l: { + "weights": LinearWeightMultimer(l.weight), + "bias": LinearBiasMultimer(l.bias), + } + + LayerNormParams = lambda l: { + "scale": Param(l.weight), + "offset": Param(l.bias), + } + + IPAParams = lambda ipa: { + "q_scalar": LinearParams(ipa.linear_q), + "kv_scalar": LinearParams(ipa.linear_kv), + "q_point_local": LinearParams(ipa.linear_q_points), + # New style IPA param + # "q_point_local": LinearParams(ipa.linear_q_points.linear), + "kv_point_local": LinearParams(ipa.linear_kv_points), + # New style IPA param + # "kv_point_local": LinearParams(ipa.linear_kv_points.linear), + "trainable_point_weights": Param( + param=ipa.head_weights, param_type=ParamType.Other + ), + "attention_2d": LinearParams(ipa.linear_b), + "output_projection": LinearParams(ipa.linear_out), + } + + PointProjectionParams = lambda pp: { + "point_projection": LinearParamsMultimer( + pp.linear, + ), + } + + IPAParamsMultimer = lambda ipa: { + "q_scalar_projection": { + "weights": LinearWeightMultimer( + ipa.linear_q.weight, + ), + }, + "k_scalar_projection": { + "weights": LinearWeightMultimer( + ipa.linear_k.weight, + ), + }, + "v_scalar_projection": { + "weights": LinearWeightMultimer( + ipa.linear_v.weight, + ), + }, + "q_point_projection": PointProjectionParams(ipa.linear_q_points), + "k_point_projection": PointProjectionParams(ipa.linear_k_points), + "v_point_projection": PointProjectionParams(ipa.linear_v_points), + "trainable_point_weights": Param( + param=ipa.head_weights, param_type=ParamType.Other + ), + "attention_2d": LinearParams(ipa.linear_b), + "output_projection": LinearParams(ipa.linear_out), + } + + def FoldIterationParams(sm): + d = { + "invariant_point_attention": IPAParamsMultimer(sm.ipa) + if is_multimer + else IPAParams(sm.ipa), + "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa), + "transition": LinearParams(sm.transition.layers[0].linear_1), + "transition_1": LinearParams(sm.transition.layers[0].linear_2), + "transition_2": LinearParams(sm.transition.layers[0].linear_3), + "transition_layer_norm": LayerNormParams(sm.transition.layer_norm), + "affine_update": LinearParams(sm.bb_update.linear), + "rigid_sidechain": { + "input_projection": LinearParams(sm.angle_resnet.linear_in), + "input_projection_1": LinearParams(sm.angle_resnet.linear_initial), + "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1), + "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2), + "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1), + "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2), + "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out), + }, + } + + if is_multimer: + d.pop("affine_update") + d["quat_rigid"] = {"rigid": LinearParams(sm.bb_update.linear)} + + return d + + translations = { + "structure_module": { + "single_layer_norm": LayerNormParams( + model.layer_norm_s + ), + "initial_projection": LinearParams(model.linear_in), + "pair_layer_norm": LayerNormParams(model.layer_norm_z), + "fold_iteration": FoldIterationParams(model), + }, + } + + return translations + + +def import_jax_weights_(model, Struct_Params, version="model_1"): + data = {} + for k,v in Struct_Params.items(): + for name,value in v.items(): + data[f'{k}//{name}'] = value + + translations = get_translation_dict(model, version) + + # Flatten keys and insert missing key prefixes + flat = _process_translations_dict(translations) + + # Sanity check + keys = list(data.keys()) + flat_keys = list(flat.keys()) + incorrect = [k for k in flat_keys if k not in keys] + missing = [k for k in keys if k not in flat_keys] + # print(f"Incorrect: {incorrect}") + # print(f"Missing: {missing}") + + assert len(incorrect) == 0 + # assert(sorted(list(flat.keys())) == sorted(list(data.keys()))) + + # Set weights + assign(flat, data) diff --git a/alphafold_pytorch_jit/structure_module/utils/rigid_utils.py b/alphafold_pytorch_jit/structure_module/utils/rigid_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..947799fa89346c1f0a425967ac109ccbb3572b94 --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/rigid_utils.py @@ -0,0 +1,1391 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Tuple, Any, Sequence, Callable, Optional + +import numpy as np +import torch + +def rot_matmul( + a: torch.Tensor, + b: torch.Tensor +) -> torch.Tensor: + """ + Performs matrix multiplication of two rotation matrix tensors. Written + out by hand to avoid AMP downcasting. + + Args: + a: [*, 3, 3] left multiplicand + b: [*, 3, 3] right multiplicand + Returns: + The product ab + """ + + + row_1 = torch.stack( + [ + a[..., 0, 0] * b[..., 0, 0] + + a[..., 0, 1] * b[..., 1, 0] + + a[..., 0, 2] * b[..., 2, 0], + a[..., 0, 0] * b[..., 0, 1] + + a[..., 0, 1] * b[..., 1, 1] + + a[..., 0, 2] * b[..., 2, 1], + a[..., 0, 0] * b[..., 0, 2] + + a[..., 0, 1] * b[..., 1, 2] + + a[..., 0, 2] * b[..., 2, 2], + ], + dim=-1, + ) + row_2 = torch.stack( + [ + a[..., 1, 0] * b[..., 0, 0] + + a[..., 1, 1] * b[..., 1, 0] + + a[..., 1, 2] * b[..., 2, 0], + a[..., 1, 0] * b[..., 0, 1] + + a[..., 1, 1] * b[..., 1, 1] + + a[..., 1, 2] * b[..., 2, 1], + a[..., 1, 0] * b[..., 0, 2] + + a[..., 1, 1] * b[..., 1, 2] + + a[..., 1, 2] * b[..., 2, 2], + ], + dim=-1, + ) + row_3 = torch.stack( + [ + a[..., 2, 0] * b[..., 0, 0] + + a[..., 2, 1] * b[..., 1, 0] + + a[..., 2, 2] * b[..., 2, 0], + a[..., 2, 0] * b[..., 0, 1] + + a[..., 2, 1] * b[..., 1, 1] + + a[..., 2, 2] * b[..., 2, 1], + a[..., 2, 0] * b[..., 0, 2] + + a[..., 2, 1] * b[..., 1, 2] + + a[..., 2, 2] * b[..., 2, 2], + ], + dim=-1, + ) + + return torch.stack([row_1, row_2, row_3], dim=-2) + + +def rot_vec_mul( + r: torch.Tensor, + t: torch.Tensor +) -> torch.Tensor: + """ + Applies a rotation to a vector. Written out by hand to avoid transfer + to avoid AMP downcasting. + + Args: + r: [*, 3, 3] rotation matrices + t: [*, 3] coordinate tensors + Returns: + [*, 3] rotated coordinates + """ + + x = t[..., 0] + y = t[..., 1] + z = t[..., 2] + return torch.stack( + [ + r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z, + r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z, + r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z, + ], + dim=-1, + ) + + +def identity_rot_mats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + rots = torch.eye( + 3, dtype=dtype, device=device, requires_grad=requires_grad + ) + rots = rots.view(*((1,) * len(batch_dims)), 3, 3) + rots = rots.expand(*batch_dims, -1, -1) + + return rots + + +def identity_trans( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + trans = torch.zeros( + (*batch_dims, 3), + dtype=dtype, + device=device, + requires_grad=requires_grad + ) + return trans + + +def identity_quats( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, +) -> torch.Tensor: + quat = torch.zeros( + (*batch_dims, 4), + dtype=dtype, + device=device, + requires_grad=requires_grad + ) + + with torch.no_grad(): + quat[..., 0] = 1 + + return quat + + +_quat_elements = ["a", "b", "c", "d"] +_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements] +_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} + + +def _to_mat(pairs): + mat = np.zeros((4, 4)) + for pair in pairs: + key, value = pair + ind = _qtr_ind_dict[key] + mat[ind // 4][ind % 4] = value + + return mat + + +_QTR_MAT = np.zeros((4, 4, 3, 3)) +_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)]) +_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)]) +_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)]) +_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)]) +_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)]) +_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)]) +_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)]) +_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)]) +_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)]) + + +def quat_to_rot(quat: torch.Tensor) -> torch.Tensor: + """ + Converts a quaternion to a rotation matrix. + + Args: + quat: [*, 4] quaternions + Returns: + [*, 3, 3] rotation matrices + """ + # [*, 4, 4] + quat = quat[..., None] * quat[..., None, :] + + # [4, 4, 3, 3] + mat = quat.new_tensor(_QTR_MAT, requires_grad=False) + + # [*, 4, 4, 3, 3] + shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape) + quat = quat[..., None, None] * shaped_qtr_mat + + # [*, 3, 3] + return torch.sum(quat, dim=(-3, -4)) + + +def rot_to_quat( + rot: torch.Tensor, +): + if(rot.shape[-2:] != (3, 3)): + raise ValueError("Input rotation is incorrectly shaped") + + rot = [[rot[..., i, j] for j in range(3)] for i in range(3)] + [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot + + k = [ + [ xx + yy + zz, zy - yz, xz - zx, yx - xy,], + [ zy - yz, xx - yy - zz, xy + yx, xz + zx,], + [ xz - zx, xy + yx, yy - xx - zz, yz + zy,], + [ yx - xy, xz + zx, yz + zy, zz - xx - yy,] + ] + + k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2) + + _, vectors = torch.linalg.eigh(k) + return vectors[..., -1] + + +_QUAT_MULTIPLY = np.zeros((4, 4, 4)) +_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0], + [ 0,-1, 0, 0], + [ 0, 0,-1, 0], + [ 0, 0, 0,-1]] + +_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0], + [ 1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 0,-1, 0]] + +_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0], + [ 0, 0, 0,-1], + [ 1, 0, 0, 0], + [ 0, 1, 0, 0]] + +_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1], + [ 0, 0, 1, 0], + [ 0,-1, 0, 0], + [ 1, 0, 0, 0]] + +_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] + + +def quat_multiply(quat1, quat2): + """Multiply a quaternion by another quaternion.""" + mat = quat1.new_tensor(_QUAT_MULTIPLY) + reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape) + return torch.sum( + reshaped_mat * + quat1[..., :, None, None] * + quat2[..., None, :, None], + dim=(-3, -2) + ) + + +def quat_multiply_by_vec(quat, vec): + """Multiply a quaternion by a pure-vector quaternion.""" + mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC) + reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape) + return torch.sum( + reshaped_mat * + quat[..., :, None, None] * + vec[..., None, :, None], + dim=(-3, -2) + ) + + +def invert_rot_mat(rot_mat: torch.Tensor): + return rot_mat.transpose(-1, -2) + + +def invert_quat(quat: torch.Tensor): + quat_prime = quat.clone() + quat_prime[..., 1:] *= -1 + inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True) + return inv + + +class Rotation: + """ + A 3D rotation. Depending on how the object is initialized, the + rotation is represented by either a rotation matrix or a + quaternion, though both formats are made available by helper functions. + To simplify gradient computation, the underlying format of the + rotation cannot be changed in-place. Like Rigid, the class is designed + to mimic the behavior of a torch Tensor, almost as if each Rotation + object were a tensor of rotations, in one format or another. + """ + def __init__(self, + rot_mats: Optional[torch.Tensor] = None, + quats: Optional[torch.Tensor] = None, + normalize_quats: bool = True, + ): + """ + Args: + rot_mats: + A [*, 3, 3] rotation matrix tensor. Mutually exclusive with + quats + quats: + A [*, 4] quaternion. Mutually exclusive with rot_mats. If + normalize_quats is not True, must be a unit quaternion + normalize_quats: + If quats is specified, whether to normalize quats + """ + if((rot_mats is None and quats is None) or + (rot_mats is not None and quats is not None)): + raise ValueError("Exactly one input argument must be specified") + + if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or + (quats is not None and quats.shape[-1] != 4)): + raise ValueError( + "Incorrectly shaped rotation matrix or quaternion" + ) + + # Force full-precision + if(quats is not None): + quats = quats.to(dtype=torch.float32) + if(rot_mats is not None): + rot_mats = rot_mats.to(dtype=torch.float32) + + if(quats is not None and normalize_quats): + quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) + + self._rot_mats = rot_mats + self._quats = quats + + @staticmethod + def identity( + shape, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, + fmt: str = "quat", + ) -> Rotation: + """ + Returns an identity Rotation. + + Args: + shape: + The "shape" of the resulting Rotation object. See documentation + for the shape property + dtype: + The torch dtype for the rotation + device: + The torch device for the new rotation + requires_grad: + Whether the underlying tensors in the new rotation object + should require gradient computation + fmt: + One of "quat" or "rot_mat". Determines the underlying format + of the new object's rotation + Returns: + A new identity rotation + """ + if(fmt == "rot_mat"): + rot_mats = identity_rot_mats( + shape, dtype, device, requires_grad, + ) + return Rotation(rot_mats=rot_mats, quats=None) + elif(fmt == "quat"): + quats = identity_quats(shape, dtype, device, requires_grad) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError(f"Invalid format: f{fmt}") + + # Magic methods + + def __getitem__(self, index: Any) -> Rotation: + """ + Allows torch-style indexing over the virtual shape of the rotation + object. See documentation for the shape property. + + Args: + index: + A torch index. E.g. (1, 3, 2), or (slice(None,)) + Returns: + The indexed rotation + """ + if type(index) != tuple: + index = (index,) + + if(self._rot_mats is not None): + rot_mats = self._rot_mats[index + (slice(None), slice(None))] + return Rotation(rot_mats=rot_mats) + elif(self._quats is not None): + quats = self._quats[index + (slice(None),)] + return Rotation(quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def __mul__(self, + right: torch.Tensor, + ) -> Rotation: + """ + Pointwise left multiplication of the rotation with a tensor. Can be + used to e.g. mask the Rotation. + + Args: + right: + The tensor multiplicand + Returns: + The product + """ + if not(isinstance(right, torch.Tensor)): + raise TypeError("The other multiplicand must be a Tensor") + + if(self._rot_mats is not None): + rot_mats = self._rot_mats * right[..., None, None] + return Rotation(rot_mats=rot_mats, quats=None) + elif(self._quats is not None): + quats = self._quats * right[..., None] + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def __rmul__(self, + left: torch.Tensor, + ) -> Rotation: + """ + Reverse pointwise multiplication of the rotation with a tensor. + + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + # Properties + + @property + def shape(self) -> torch.Size: + """ + Returns the virtual shape of the rotation object. This shape is + defined as the batch dimensions of the underlying rotation matrix + or quaternion. If the Rotation was initialized with a [10, 3, 3] + rotation matrix tensor, for example, the resulting shape would be + [10]. + + Returns: + The virtual shape of the rotation object + """ + s = None + if(self._quats is not None): + s = self._quats.shape[:-1] + else: + s = self._rot_mats.shape[:-2] + + return s + + @property + def dtype(self) -> torch.dtype: + """ + Returns the dtype of the underlying rotation. + + Returns: + The dtype of the underlying rotation + """ + if(self._rot_mats is not None): + return self._rot_mats.dtype + elif(self._quats is not None): + return self._quats.dtype + else: + raise ValueError("Both rotations are None") + + @property + def device(self) -> torch.device: + """ + The device of the underlying rotation + + Returns: + The device of the underlying rotation + """ + if(self._rot_mats is not None): + return self._rot_mats.device + elif(self._quats is not None): + return self._quats.device + else: + raise ValueError("Both rotations are None") + + @property + def requires_grad(self) -> bool: + """ + Returns the requires_grad property of the underlying rotation + + Returns: + The requires_grad property of the underlying tensor + """ + if(self._rot_mats is not None): + return self._rot_mats.requires_grad + elif(self._quats is not None): + return self._quats.requires_grad + else: + raise ValueError("Both rotations are None") + + def get_rot_mats(self) -> torch.Tensor: + """ + Returns the underlying rotation as a rotation matrix tensor. + + Returns: + The rotation as a rotation matrix tensor + """ + rot_mats = self._rot_mats + if(rot_mats is None): + if(self._quats is None): + raise ValueError("Both rotations are None") + else: + rot_mats = quat_to_rot(self._quats) + + return rot_mats + + def get_quats(self) -> torch.Tensor: + """ + Returns the underlying rotation as a quaternion tensor. + + Depending on whether the Rotation was initialized with a + quaternion, this function may call torch.linalg.eigh. + + Returns: + The rotation as a quaternion tensor. + """ + quats = self._quats + if(quats is None): + if(self._rot_mats is None): + raise ValueError("Both rotations are None") + else: + quats = rot_to_quat(self._rot_mats) + + return quats + + def get_cur_rot(self) -> torch.Tensor: + """ + Return the underlying rotation in its current form + + Returns: + The stored rotation + """ + if(self._rot_mats is not None): + return self._rot_mats + elif(self._quats is not None): + return self._quats + else: + raise ValueError("Both rotations are None") + + # Rotation functions + + def compose_q_update_vec(self, + q_update_vec: torch.Tensor, + normalize_quats: bool = True + ) -> Rotation: + """ + Returns a new quaternion Rotation after updating the current + object's underlying rotation with a quaternion update, formatted + as a [*, 3] tensor whose final three columns represent x, y, z such + that (1, x, y, z) is the desired (not necessarily unit) quaternion + update. + + Args: + q_update_vec: + A [*, 3] quaternion update tensor + normalize_quats: + Whether to normalize the output quaternion + Returns: + An updated Rotation + """ + quats = self.get_quats() + new_quats = quats + quat_multiply_by_vec(quats, q_update_vec) + return Rotation( + rot_mats=None, + quats=new_quats, + normalize_quats=normalize_quats, + ) + + def compose_r(self, r: Rotation) -> Rotation: + """ + Compose the rotation matrices of the current Rotation object with + those of another. + + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + r1 = self.get_rot_mats() + r2 = r.get_rot_mats() + new_rot_mats = rot_matmul(r1, r2) + return Rotation(rot_mats=new_rot_mats, quats=None) + + def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation: + """ + Compose the quaternions of the current Rotation object with those + of another. + + Depending on whether either Rotation was initialized with + quaternions, this function may call torch.linalg.eigh. + + Args: + r: + An update rotation object + Returns: + An updated rotation object + """ + q1 = self.get_quats() + q2 = r.get_quats() + new_quats = quat_multiply(q1, q2) + return Rotation( + rot_mats=None, quats=new_quats, normalize_quats=normalize_quats + ) + + def apply(self, pts: torch.Tensor) -> torch.Tensor: + """ + Apply the current Rotation as a rotation matrix to a set of 3D + coordinates. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] rotated points + """ + rot_mats = self.get_rot_mats() + return rot_vec_mul(rot_mats, pts) + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + """ + The inverse of the apply() method. + + Args: + pts: + A [*, 3] set of points + Returns: + [*, 3] inverse-rotated points + """ + rot_mats = self.get_rot_mats() + inv_rot_mats = invert_rot_mat(rot_mats) + return rot_vec_mul(inv_rot_mats, pts) + + def invert(self) -> Rotation: + """ + Returns the inverse of the current Rotation. + + Returns: + The inverse of the current Rotation + """ + if(self._rot_mats is not None): + return Rotation( + rot_mats=invert_rot_mat(self._rot_mats), + quats=None + ) + elif(self._quats is not None): + return Rotation( + rot_mats=None, + quats=invert_quat(self._quats), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + # "Tensor" stuff + + def unsqueeze(self, + dim: int, + ) -> Rigid: + """ + Analogous to torch.unsqueeze. The dimension is relative to the + shape of the Rotation object. + + Args: + dim: A positive or negative dimension index. + Returns: + The unsqueezed Rotation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + + if(self._rot_mats is not None): + rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2) + return Rotation(rot_mats=rot_mats, quats=None) + elif(self._quats is not None): + quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + @staticmethod + def cat( + rs: Sequence[Rotation], + dim: int, + ) -> Rigid: + """ + Concatenates rotations along one of the batch dimensions. Analogous + to torch.cat(). + + Note that the output of this operation is always a rotation matrix, + regardless of the format of input rotations. + + Args: + rs: + A list of rotation objects + dim: + The dimension along which the rotations should be + concatenated + Returns: + A concatenated Rotation object in rotation matrix format + """ + rot_mats = [r.get_rot_mats() for r in rs] + rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) + + return Rotation(rot_mats=rot_mats, quats=None) + + def map_tensor_fn(self, + fn: Callable[torch.Tensor, torch.Tensor] + ) -> Rotation: + """ + Apply a Tensor -> Tensor function to underlying rotation tensors, + mapping over the rotation dimension(s). Can be used e.g. to sum out + a one-hot batch dimension. + + Args: + fn: + A Tensor -> Tensor function to be mapped over the Rotation + Returns: + The transformed Rotation object + """ + if(self._rot_mats is not None): + rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,)) + rot_mats = torch.stack( + list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1 + ) + rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3)) + return Rotation(rot_mats=rot_mats, quats=None) + elif(self._quats is not None): + quats = torch.stack( + list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1 + ) + return Rotation(rot_mats=None, quats=quats, normalize_quats=False) + else: + raise ValueError("Both rotations are None") + + def cuda(self) -> Rotation: + """ + Analogous to the cuda() method of torch Tensors + + Returns: + A copy of the Rotation in CUDA memory + """ + if(self._rot_mats is not None): + return Rotation(rot_mats=self._rot_mats.cuda(), quats=None) + elif(self._quats is not None): + return Rotation( + rot_mats=None, + quats=self._quats.cuda(), + normalize_quats=False + ) + else: + raise ValueError("Both rotations are None") + + def to(self, + device: Optional[torch.device], + dtype: Optional[torch.dtype] + ) -> Rotation: + """ + Analogous to the to() method of torch Tensors + + Args: + device: + A torch device + dtype: + A torch dtype + Returns: + A copy of the Rotation using the new device and dtype + """ + if(self._rot_mats is not None): + return Rotation( + rot_mats=self._rot_mats.to(device=device, dtype=dtype), + quats=None, + ) + elif(self._quats is not None): + return Rotation( + rot_mats=None, + quats=self._quats.to(device=device, dtype=dtype), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + def detach(self) -> Rotation: + """ + Returns a copy of the Rotation whose underlying Tensor has been + detached from its torch graph. + + Returns: + A copy of the Rotation whose underlying Tensor has been detached + from its torch graph + """ + if(self._rot_mats is not None): + return Rotation(rot_mats=self._rot_mats.detach(), quats=None) + elif(self._quats is not None): + return Rotation( + rot_mats=None, + quats=self._quats.detach(), + normalize_quats=False, + ) + else: + raise ValueError("Both rotations are None") + + +class Rigid: + """ + A class representing a rigid transformation. Little more than a wrapper + around two objects: a Rotation object and a [*, 3] translation + Designed to behave approximately like a single torch tensor with the + shape of the shared batch dimensions of its component parts. + """ + def __init__(self, + rots: Optional[Rotation], + trans: Optional[torch.Tensor], + ): + """ + Args: + rots: A [*, 3, 3] rotation tensor + trans: A corresponding [*, 3] translation tensor + """ + # (we need device, dtype, etc. from at least one input) + + batch_dims, dtype, device, requires_grad = None, None, None, None + if(trans is not None): + batch_dims = trans.shape[:-1] + dtype = trans.dtype + device = trans.device + requires_grad = trans.requires_grad + elif(rots is not None): + batch_dims = rots.shape + dtype = rots.dtype + device = rots.device + requires_grad = rots.requires_grad + else: + raise ValueError("At least one input argument must be specified") + + if(rots is None): + rots = Rotation.identity( + batch_dims, dtype, device, requires_grad, + ) + elif(trans is None): + trans = identity_trans( + batch_dims, dtype, device, requires_grad, + ) + + if((rots.shape != trans.shape[:-1]) or + (rots.device != trans.device)): + raise ValueError("Rots and trans incompatible") + + # Force full precision. Happens to the rotations automatically. + trans = trans.to(dtype=torch.float32) + + self._rots = rots + self._trans = trans + + @staticmethod + def identity( + shape: Tuple[int], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + requires_grad: bool = True, + fmt: str = "quat", + ) -> Rigid: + """ + Constructs an identity transformation. + + Args: + shape: + The desired shape + dtype: + The dtype of both internal tensors + device: + The device of both internal tensors + requires_grad: + Whether grad should be enabled for the internal tensors + Returns: + The identity transformation + """ + return Rigid( + Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt), + identity_trans(shape, dtype, device, requires_grad), + ) + + def __getitem__(self, + index: Any, + ) -> Rigid: + """ + Indexes the affine transformation with PyTorch-style indices. + The index is applied to the shared dimensions of both the rotation + and the translation. + + E.g.:: + + r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None) + t = Rigid(r, torch.rand(10, 10, 3)) + indexed = t[3, 4:6] + assert(indexed.shape == (2,)) + assert(indexed.get_rots().shape == (2,)) + assert(indexed.get_trans().shape == (2, 3)) + + Args: + index: A standard torch tensor index. E.g. 8, (10, None, 3), + or (3, slice(0, 1, None)) + Returns: + The indexed tensor + """ + if type(index) != tuple: + index = (index,) + + return Rigid( + self._rots[index], + self._trans[index + (slice(None),)], + ) + + def __mul__(self, + right: torch.Tensor, + ) -> Rigid: + """ + Pointwise left multiplication of the transformation with a tensor. + Can be used to e.g. mask the Rigid. + + Args: + right: + The tensor multiplicand + Returns: + The product + """ + if not(isinstance(right, torch.Tensor)): + raise TypeError("The other multiplicand must be a Tensor") + + new_rots = self._rots * right + new_trans = self._trans * right[..., None] + + return Rigid(new_rots, new_trans) + + def __rmul__(self, + left: torch.Tensor, + ) -> Rigid: + """ + Reverse pointwise multiplication of the transformation with a + tensor. + + Args: + left: + The left multiplicand + Returns: + The product + """ + return self.__mul__(left) + + @property + def shape(self) -> torch.Size: + """ + Returns the shape of the shared dimensions of the rotation and + the translation. + + Returns: + The shape of the transformation + """ + s = self._trans.shape[:-1] + return s + + @property + def device(self) -> torch.device: + """ + Returns the device on which the Rigid's tensors are located. + + Returns: + The device on which the Rigid's tensors are located + """ + return self._trans.device + + def get_rots(self) -> Rotation: + """ + Getter for the rotation. + + Returns: + The rotation object + """ + return self._rots + + def get_trans(self) -> torch.Tensor: + """ + Getter for the translation. + + Returns: + The stored translation + """ + return self._trans + + def compose_q_update_vec(self, + q_update_vec: torch.Tensor, + ) -> Rigid: + """ + Composes the transformation with a quaternion update vector of + shape [*, 6], where the final 6 columns represent the x, y, and + z values of a quaternion of form (1, x, y, z) followed by a 3D + translation. + + Args: + q_vec: The quaternion update vector. + Returns: + The composed transformation. + """ + q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:] + new_rots = self._rots.compose_q_update_vec(q_vec) + + trans_update = self._rots.apply(t_vec) + new_translation = self._trans + trans_update + + return Rigid(new_rots, new_translation) + + def compose(self, + r: Rigid, + ) -> Rigid: + """ + Composes the current rigid object with another. + + Args: + r: + Another Rigid object + Returns: + The composition of the two transformations + """ + new_rot = self._rots.compose_r(r._rots) + new_trans = self._rots.apply(r._trans) + self._trans + return Rigid(new_rot, new_trans) + + def apply(self, + pts: torch.Tensor, + ) -> torch.Tensor: + """ + Applies the transformation to a coordinate tensor. + + Args: + pts: A [*, 3] coordinate tensor. + Returns: + The transformed points. + """ + rotated = self._rots.apply(pts) + return rotated + self._trans + + def invert_apply(self, + pts: torch.Tensor + ) -> torch.Tensor: + """ + Applies the inverse of the transformation to a coordinate tensor. + + Args: + pts: A [*, 3] coordinate tensor + Returns: + The transformed points. + """ + pts = pts - self._trans + return self._rots.invert_apply(pts) + + def invert(self) -> Rigid: + """ + Inverts the transformation. + + Returns: + The inverse transformation. + """ + rot_inv = self._rots.invert() + trn_inv = rot_inv.apply(self._trans) + + return Rigid(rot_inv, -1 * trn_inv) + + def map_tensor_fn(self, + fn: Callable[torch.Tensor, torch.Tensor] + ) -> Rigid: + """ + Apply a Tensor -> Tensor function to underlying translation and + rotation tensors, mapping over the translation/rotation dimensions + respectively. + + Args: + fn: + A Tensor -> Tensor function to be mapped over the Rigid + Returns: + The transformed Rigid object + """ + new_rots = self._rots.map_tensor_fn(fn) + new_trans = torch.stack( + list(map(fn, torch.unbind(self._trans, dim=-1))), + dim=-1 + ) + + return Rigid(new_rots, new_trans) + + def to_tensor_4x4(self) -> torch.Tensor: + """ + Converts a transformation to a homogenous transformation tensor. + + Returns: + A [*, 4, 4] homogenous transformation tensor + """ + tensor = self._trans.new_zeros((*self.shape, 4, 4)) + tensor[..., :3, :3] = self._rots.get_rot_mats() + tensor[..., :3, 3] = self._trans + tensor[..., 3, 3] = 1 + return tensor + + @staticmethod + def from_tensor_4x4( + t: torch.Tensor + ) -> Rigid: + """ + Constructs a transformation from a homogenous transformation + tensor. + + Args: + t: [*, 4, 4] homogenous transformation tensor + Returns: + T object with shape [*] + """ + if(t.shape[-2:] != (4, 4)): + raise ValueError("Incorrectly shaped input tensor") + + rots = Rotation(rot_mats=t[..., :3, :3], quats=None) + trans = t[..., :3, 3] + + return Rigid(rots, trans) + + def to_tensor_7(self) -> torch.Tensor: + """ + Converts a transformation to a tensor with 7 final columns, four + for the quaternion followed by three for the translation. + + Returns: + A [*, 7] tensor representation of the transformation + """ + tensor = self._trans.new_zeros((*self.shape, 7)) + tensor[..., :4] = self._rots.get_quats() + tensor[..., 4:] = self._trans + + return tensor + + @staticmethod + def from_tensor_7( + t: torch.Tensor, + normalize_quats: bool = False, + ) -> Rigid: + if(t.shape[-1] != 7): + raise ValueError("Incorrectly shaped input tensor") + + quats, trans = t[..., :4], t[..., 4:] + + rots = Rotation( + rot_mats=None, + quats=quats, + normalize_quats=normalize_quats + ) + + return Rigid(rots, trans) + + @staticmethod + def from_3_points( + p_neg_x_axis: torch.Tensor, + origin: torch.Tensor, + p_xy_plane: torch.Tensor, + eps: float = 1e-8 + ) -> Rigid: + """ + Implements algorithm 21. Constructs transformations from sets of 3 + points using the Gram-Schmidt algorithm. + + Args: + p_neg_x_axis: [*, 3] coordinates + origin: [*, 3] coordinates used as frame origins + p_xy_plane: [*, 3] coordinates + eps: Small epsilon value + Returns: + A transformation object of shape [*] + """ + p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) + origin = torch.unbind(origin, dim=-1) + p_xy_plane = torch.unbind(p_xy_plane, dim=-1) + + e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] + e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] + + denom = torch.sqrt(sum((c * c for c in e0)) + eps) + e0 = [c / denom for c in e0] + dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) + e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] + denom = torch.sqrt(sum((c * c for c in e1)) + eps) + e1 = [c / denom for c in e1] + e2 = [ + e0[1] * e1[2] - e0[2] * e1[1], + e0[2] * e1[0] - e0[0] * e1[2], + e0[0] * e1[1] - e0[1] * e1[0], + ] + + rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) + rots = rots.reshape(rots.shape[:-1] + (3, 3)) + + rot_obj = Rotation(rot_mats=rots, quats=None) + + return Rigid(rot_obj, torch.stack(origin, dim=-1)) + + def unsqueeze(self, + dim: int, + ) -> Rigid: + """ + Analogous to torch.unsqueeze. The dimension is relative to the + shared dimensions of the rotation/translation. + + Args: + dim: A positive or negative dimension index. + Returns: + The unsqueezed transformation. + """ + if dim >= len(self.shape): + raise ValueError("Invalid dimension") + rots = self._rots.unsqueeze(dim) + trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1) + + return Rigid(rots, trans) + + @staticmethod + def cat( + ts: Sequence[Rigid], + dim: int, + ) -> Rigid: + """ + Concatenates transformations along a new dimension. + + Args: + ts: + A list of T objects + dim: + The dimension along which the transformations should be + concatenated + Returns: + A concatenated transformation object + """ + rots = Rotation.cat([t._rots for t in ts], dim) + trans = torch.cat( + [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1 + ) + + return Rigid(rots, trans) + + def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid: + """ + Applies a Rotation -> Rotation function to the stored rotation + object. + + Args: + fn: A function of type Rotation -> Rotation + Returns: + A transformation object with a transformed rotation. + """ + return Rigid(fn(self._rots), self._trans) + + def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid: + """ + Applies a Tensor -> Tensor function to the stored translation. + + Args: + fn: + A function of type Tensor -> Tensor to be applied to the + translation + Returns: + A transformation object with a transformed translation. + """ + return Rigid(self._rots, fn(self._trans)) + + def scale_translation(self, trans_scale_factor: float) -> Rigid: + """ + Scales the translation by a constant factor. + + Args: + trans_scale_factor: + The constant factor + Returns: + A transformation object with a scaled translation. + """ + fn = lambda t: t * trans_scale_factor + return self.apply_trans_fn(fn) + + def stop_rot_gradient(self) -> Rigid: + """ + Detaches the underlying rotation object + + Returns: + A transformation object with detached rotations + """ + fn = lambda r: r.detach() + return self.apply_rot_fn(fn) + + @staticmethod + def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): + """ + Returns a transformation object from reference coordinates. + + Note that this method does not take care of symmetries. If you + provide the atom positions in the non-standard way, the N atom will + end up not at [-0.527250, 1.359329, 0.0] but instead at + [-0.527250, -1.359329, 0.0]. You need to take care of such cases in + your code. + + Args: + n_xyz: A [*, 3] tensor of nitrogen xyz coordinates. + ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates. + c_xyz: A [*, 3] tensor of carbon xyz coordinates. + Returns: + A transformation object. After applying the translation and + rotation to the reference backbone, the coordinates will + approximately equal to the input coordinates. + """ + translation = -1 * ca_xyz + n_xyz = n_xyz + translation + c_xyz = c_xyz + translation + + c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2) + sin_c1 = -c_y / norm + cos_c1 = c_x / norm + zeros = sin_c1.new_zeros(sin_c1.shape) + ones = sin_c1.new_ones(sin_c1.shape) + + c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) + c1_rots[..., 0, 0] = cos_c1 + c1_rots[..., 0, 1] = -1 * sin_c1 + c1_rots[..., 1, 0] = sin_c1 + c1_rots[..., 1, 1] = cos_c1 + c1_rots[..., 2, 2] = 1 + + norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2) + sin_c2 = c_z / norm + cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm + + c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + c2_rots[..., 0, 0] = cos_c2 + c2_rots[..., 0, 2] = sin_c2 + c2_rots[..., 1, 1] = 1 + c2_rots[..., 2, 0] = -1 * sin_c2 + c2_rots[..., 2, 2] = cos_c2 + + c_rots = rot_matmul(c2_rots, c1_rots) + n_xyz = rot_vec_mul(c_rots, n_xyz) + + _, n_y, n_z = [n_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2) + sin_n = -n_z / norm + cos_n = n_y / norm + + n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + n_rots[..., 0, 0] = 1 + n_rots[..., 1, 1] = cos_n + n_rots[..., 1, 2] = -1 * sin_n + n_rots[..., 2, 1] = sin_n + n_rots[..., 2, 2] = cos_n + + rots = rot_matmul(n_rots, c_rots) + + rots = rots.transpose(-1, -2) + translation = -1 * translation + + rot_obj = Rotation(rot_mats=rots, quats=None) + + return Rigid(rot_obj, translation) + + def cuda(self) -> Rigid: + """ + Moves the transformation object to GPU memory + + Returns: + A version of the transformation on GPU + """ + return Rigid(self._rots.cuda(), self._trans.cuda()) \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/utils/superimposition.py b/alphafold_pytorch_jit/structure_module/utils/superimposition.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c74c54a55a004933109dc9e07e36ec3ea0ec31 --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/superimposition.py @@ -0,0 +1,100 @@ +# Copyright 2023 HPC-AI Tech Inc. +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from Bio.SVDSuperimposer import SVDSuperimposer +import torch + + +def _superimpose_np(reference, coords): + """ + Superimposes coordinates onto a reference by minimizing RMSD using SVD. + + Args: + reference: + [N, 3] reference array + coords: + [N, 3] array + Returns: + A tuple of [N, 3] superimposed coords and the final RMSD. + """ + sup = SVDSuperimposer() + sup.set(reference, coords) + sup.run() + return sup.get_transformed(), sup.get_rms() + + +def _superimpose_single(reference, coords): + reference_np = reference.detach().cpu().numpy() + coords_np = coords.detach().cpu().numpy() + superimposed, rmsd = _superimpose_np(reference_np, coords_np) + return coords.new_tensor(superimposed), coords.new_tensor(rmsd) + + +def superimpose(reference, coords, mask): + """ + Superimposes coordinates onto a reference by minimizing RMSD using SVD. + + Args: + reference: + [*, N, 3] reference tensor + coords: + [*, N, 3] tensor + mask: + [*, N] tensor + Returns: + A tuple of [*, N, 3] superimposed coords and [*] final RMSDs. + """ + def select_unmasked_coords(coords, mask): + return torch.masked_select( + coords, + (mask > 0.)[..., None], + ).reshape(-1, 3) + + batch_dims = reference.shape[:-2] + flat_reference = reference.reshape((-1,) + reference.shape[-2:]) + flat_coords = coords.reshape((-1,) + reference.shape[-2:]) + flat_mask = mask.reshape((-1,) + mask.shape[-1:]) + superimposed_list = [] + rmsds = [] + for r, c, m in zip(flat_reference, flat_coords, flat_mask): + r_unmasked_coords = select_unmasked_coords(r, m) + c_unmasked_coords = select_unmasked_coords(c, m) + superimposed, rmsd = _superimpose_single( + r_unmasked_coords, + c_unmasked_coords + ) + + # This is very inelegant, but idk how else to invert the masking + # procedure. + count = 0 + superimposed_full_size = torch.zeros_like(r) + for i, unmasked in enumerate(m): + if(unmasked): + superimposed_full_size[i] = superimposed[count] + count += 1 + + superimposed_list.append(superimposed_full_size) + rmsds.append(rmsd) + + superimposed_stacked = torch.stack(superimposed_list, dim=0) + rmsds_stacked = torch.stack(rmsds, dim=0) + + superimposed_reshaped = superimposed_stacked.reshape( + batch_dims + coords.shape[-2:] + ) + rmsds_reshaped = rmsds_stacked.reshape( + batch_dims + ) + + return superimposed_reshaped, rmsds_reshaped \ No newline at end of file diff --git a/alphafold_pytorch_jit/structure_module/utils/tensor_utils.py b/alphafold_pytorch_jit/structure_module/utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4c205ee4e24e92d68ce2d7581e2428cc5d714ff8 --- /dev/null +++ b/alphafold_pytorch_jit/structure_module/utils/tensor_utils.py @@ -0,0 +1,409 @@ +# Copyright 2022 HPC-AI Tech Inc +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import torch +import torch.nn as nn +from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def masked_mean(mask, value, dim, eps=1e-4): + mask = mask.expand(*value.shape) + return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + + +def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): + boundaries = torch.linspace( + min_bin, max_bin, no_bins - 1, device=pts.device + ) + dists = torch.sqrt( + torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) + ) + return torch.bucketize(dists, boundaries) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if type(v) is dict: + new_dict[k] = dict_multimap(fn, all_v) + else: + # when bs = 1, returns [...] rather than [1, ...] + new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0] + + return new_dict + + +def one_hot(x, v_bins): + reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) + diffs = x[..., None] - reshaped_bins + am = torch.argmin(torch.abs(diffs), dim=-1) + return nn.functional.one_hot(am, num_classes=len(v_bins)).float() + + +def batched_gather(data, inds, dim=0, no_batch_dims=0): + ranges = [] + for i, s in enumerate(data.shape[:no_batch_dims]): + r = torch.arange(s) + r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims = [ + slice(None) for _ in range(len(data.shape) - no_batch_dims) + ] + remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds + ranges.extend(remaining_dims) + return data[ranges] + + +# With tree_map, a poor man's JAX tree_map +def dict_map(fn, dic, leaf_type): + new_dict = {} + for k, v in dic.items(): + if type(v) is dict: + new_dict[k] = dict_map(fn, v, leaf_type) + else: + new_dict[k] = tree_map(fn, v, leaf_type) + + return new_dict + + +def tree_map(fn, tree, leaf_type): + if isinstance(tree, dict): + return dict_map(fn, tree, leaf_type) + elif isinstance(tree, list): + return [tree_map(fn, x, leaf_type) for x in tree] + elif isinstance(tree, tuple): + return tuple([tree_map(fn, x, leaf_type) for x in tree]) + elif isinstance(tree, leaf_type): + return fn(tree) + else: + print(type(tree)) + raise ValueError("Not supported") + + +tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) + +def _fetch_dims(tree): + shapes = [] + tree_type = type(tree) + if tree_type is dict: + for v in tree.values(): + shapes.extend(_fetch_dims(v)) + elif tree_type is list or tree_type is tuple: + for t in tree: + shapes.extend(_fetch_dims(t)) + elif tree_type is torch.Tensor: + shapes.append(tree.shape) + else: + raise ValueError("Not supported") + + return shapes + + +@torch.jit.ignore +def _flat_idx_to_idx( + flat_idx: int, + dims: Tuple[int], +) -> Tuple[int]: + idx = [] + for d in reversed(dims): + idx.append(flat_idx % d) + flat_idx = flat_idx // d + + return tuple(reversed(idx)) + + +@torch.jit.ignore +def _get_minimal_slice_set( + start: Sequence[int], + end: Sequence[int], + dims: int, + start_edges: Optional[Sequence[bool]] = None, + end_edges: Optional[Sequence[bool]] = None, +) -> Sequence[Tuple[int]]: + """ + Produces an ordered sequence of tensor slices that, when used in + sequence on a tensor with shape dims, yields tensors that contain every + leaf in the contiguous range [start, end]. Care is taken to yield a + short sequence of slices, and perhaps even the shortest possible (I'm + pretty sure it's the latter). + + end is INCLUSIVE. + """ + # start_edges and end_edges both indicate whether, starting from any given + # dimension, the start/end index is at the top/bottom edge of the + # corresponding tensor, modeled as a tree + def reduce_edge_list(l): + tally = 1 + for i in range(len(l)): + reversed_idx = -1 * (i + 1) + l[reversed_idx] *= tally + tally = l[reversed_idx] + + if(start_edges is None): + start_edges = [s == 0 for s in start] + reduce_edge_list(start_edges) + if(end_edges is None): + end_edges = [e == (d - 1) for e,d in zip(end, dims)] + reduce_edge_list(end_edges) + + # Base cases. Either start/end are empty and we're done, or the final, + # one-dimensional tensor can be simply sliced + if(len(start) == 0): + return [tuple()] + elif(len(start) == 1): + return [(slice(start[0], end[0] + 1),)] + + slices = [] + path = [] + + # Dimensions common to start and end can be selected directly + for s,e in zip(start, end): + if(s == e): + path.append(slice(s, s + 1)) + else: + break + + path = tuple(path) + divergence_idx = len(path) + + # start == end, and we're done + if(divergence_idx == len(dims)): + return [tuple(path)] + + def upper(): + sdi = start[divergence_idx] + return [ + path + (slice(sdi, sdi + 1),) + s for s in + _get_minimal_slice_set( + start[divergence_idx + 1:], + [d - 1 for d in dims[divergence_idx + 1:]], + dims[divergence_idx + 1:], + start_edges=start_edges[divergence_idx + 1:], + end_edges=[1 for _ in end_edges[divergence_idx + 1:]] + ) + ] + + def lower(): + edi = end[divergence_idx] + return [ + path + (slice(edi, edi + 1),) + s for s in + _get_minimal_slice_set( + [0 for _ in start[divergence_idx + 1:]], + end[divergence_idx + 1:], + dims[divergence_idx + 1:], + start_edges=[1 for _ in start_edges[divergence_idx + 1:]], + end_edges=end_edges[divergence_idx + 1:], + ) + ] + + # If both start and end are at the edges of the subtree rooted at + # divergence_idx, we can just select the whole subtree at once + if(start_edges[divergence_idx] and end_edges[divergence_idx]): + slices.append( + path + (slice(start[divergence_idx], end[divergence_idx] + 1),) + ) + # If just start is at the edge, we can grab almost all of the subtree, + # treating only the ragged bottom edge as an edge case + elif(start_edges[divergence_idx]): + slices.append( + path + (slice(start[divergence_idx], end[divergence_idx]),) + ) + slices.extend(lower()) + # Analogous to the previous case, but the top is ragged this time + elif(end_edges[divergence_idx]): + slices.extend(upper()) + slices.append( + path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),) + ) + # If both sides of the range are ragged, we need to handle both sides + # separately. If there's contiguous meat in between them, we can index it + # in one big chunk + else: + slices.extend(upper()) + middle_ground = end[divergence_idx] - start[divergence_idx] + if(middle_ground > 1): + slices.append( + path + (slice(start[divergence_idx] + 1, end[divergence_idx]),) + ) + slices.extend(lower()) + + return [tuple(s) for s in slices] + + +@torch.jit.ignore +def _chunk_slice( + t: torch.Tensor, + flat_start: int, + flat_end: int, + no_batch_dims: int, +) -> torch.Tensor: + """ + Equivalent to + + t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end] + + but without the need for the initial reshape call, which can be + memory-intensive in certain situations. The only reshape operations + in this function are performed on sub-tensors that scale with + (flat_end - flat_start), the chunk size. + """ + + batch_dims = t.shape[:no_batch_dims] + start_idx = list(_flat_idx_to_idx(flat_start, batch_dims)) + # _get_minimal_slice_set is inclusive + end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims)) + + # Get an ordered list of slices to perform + slices = _get_minimal_slice_set( + start_idx, + end_idx, + batch_dims, + ) + + sliced_tensors = [t[s] for s in slices] + + return torch.cat( + [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors] + ) + + +def chunk_layer( + layer: Callable, + inputs: Dict[str, Any], + chunk_size: int, + no_batch_dims: int, + low_mem: bool = False, +) -> Any: + """ + Implements the "chunking" procedure described in section 1.11.8. + + Layer outputs and inputs are assumed to be simple "pytrees," + consisting only of (arbitrarily nested) lists, tuples, and dicts with + torch.Tensor leaves. + + Args: + layer: + The layer to be applied chunk-wise + inputs: + A (non-nested) dictionary of keyworded inputs. All leaves must + be tensors and must share the same batch dimensions. + chunk_size: + The number of sub-batches per chunk. If multiple batch + dimensions are specified, a "sub-batch" is defined as a single + indexing of all batch dimensions simultaneously (s.t. the + number of sub-batches is the product of the batch dimensions). + no_batch_dims: + How many of the initial dimensions of each input tensor can + be considered batch dimensions. + low_mem: + Avoids flattening potentially large input tensors. Unnecessary + in most cases, and is ever so slightly slower than the default + setting. + Returns: + The reassembled output of the layer on the inputs. + """ + if not (len(inputs) > 0): + raise ValueError("Must provide at least one input") + + initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] + orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) + + def _prep_inputs(t): + # TODO: make this more memory efficient. This sucks + if(not low_mem): + if not sum(t.shape[:no_batch_dims]) == no_batch_dims: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + t = t.reshape(-1, *t.shape[no_batch_dims:]) + else: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + return t + + prepped_inputs = tensor_tree_map(_prep_inputs, inputs) + + flat_batch_dim = 1 + for d in orig_batch_dims: + flat_batch_dim *= d + + no_chunks = flat_batch_dim // chunk_size + ( + flat_batch_dim % chunk_size != 0 + ) + + i = 0 + out = None + for _ in range(no_chunks): + # Chunk the input + if(not low_mem): + select_chunk = ( + lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t + ) + else: + select_chunk = ( + partial( + _chunk_slice, + flat_start=i, + flat_end=min(flat_batch_dim, i + chunk_size), + no_batch_dims=len(orig_batch_dims) + ) + ) + + chunks = tensor_tree_map(select_chunk, prepped_inputs) + + # Run the layer on the chunk + output_chunk = layer(**chunks) + + # Allocate space for the output + if out is None: + allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]) + out = tensor_tree_map(allocate, output_chunk) + + # Put the chunk in its pre-allocated space + out_type = type(output_chunk) + if out_type is dict: + def assign(d1, d2): + for k, v in d1.items(): + if type(v) is dict: + assign(v, d2[k]) + else: + v[i : i + chunk_size] = d2[k] + + assign(out, output_chunk) + elif out_type is tuple: + for x1, x2 in zip(out, output_chunk): + x1[i : i + chunk_size] = x2 + elif out_type is torch.Tensor: + out[i : i + chunk_size] = output_chunk + else: + raise ValueError("Not supported") + + i += chunk_size + + reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) + out = tensor_tree_map(reshape, out) + + return out \ No newline at end of file diff --git a/alphafold_pytorch_jit/subnets.py b/alphafold_pytorch_jit/subnets.py index a9038e3b2a55042945dbedb76dda1521a6862e1b..8928b54a96a3ebe576d14ef53a4396c0c11114b4 100644 --- a/alphafold_pytorch_jit/subnets.py +++ b/alphafold_pytorch_jit/subnets.py @@ -22,6 +22,9 @@ import datetime import os import pickle import kpex +from alphafold_pytorch_jit.structure_module.structure_module import StructureModule +from alphafold_pytorch_jit.structure_module.utils.import_weights import import_jax_weights_ +from alphafold_pytorch_jit.structure_module.utils.feats import atom14_to_atom37 global fast_test global fast_test_msax2_evox12 @@ -491,6 +494,33 @@ class AlphaFoldIteration(nn.Module): res = OrderedDict(res) self.heads['predicted_lddt'].load_state_dict(res) + sm_config = self.c.heads.structure_module + structure_config = { + 'c_ipa': 16, + 'c_z': 128, + 'c_s': sm_config.num_channel, + 'c_resnet': sm_config.sidechain.num_channel, + 'dropout_rate': sm_config.dropout, + 'epsilon': 1e-08, + 'inf': 100000.0, + 'no_angles': 7, + 'no_blocks': sm_config.num_layer, + 'no_heads_ipa': sm_config.num_head, + 'no_qk_points': sm_config.num_point_qk, + 'no_resnet_blocks': sm_config.sidechain.num_residual_block, + 'no_transition_layers': 1, + 'no_v_points': sm_config.num_point_v, + 'trans_scale_factor': sm_config.position_scale + } + # Meta 可以让模型加载时忽略 structure_module + class Meta: pass + self.structure_module = Meta() + self.structure_module.model = StructureModule( + is_multimer=False, + **structure_config, + ) + import_jax_weights_(self.structure_module.model,struct_params,) + def _slice_batch(self, i, ensembled_batch, non_ensembled_batch): b = {k: v[i] for k, v in ensembled_batch.items()} if non_ensembled_batch is not None: # omit if prev-keys not exist @@ -615,22 +645,49 @@ class AlphaFoldIteration(nn.Module): elif name in ['structure_module']: if Struct_Params is None: continue - representations_hk = jax.tree_map(detached,representations) - batch_hk = jax.tree_map(detached,batch0) head_module = self.read_time() - res_hk = module(Struct_Params,rng,representations_hk,batch_hk) - head_module2 = self.read_time() - ret[name] = jax.tree_map(list2tensor,res_hk) - del res_hk - if 'representations' in ret[name].keys(): - representations.update(ret[name].pop('representations')) + + + # representations_hk = jax.tree_map(detached,representations) + # batch_hk = jax.tree_map(detached,batch0) + # res_hk = module(Struct_Params,rng,representations_hk,batch_hk) + # ret[name] = jax.tree_map(list2tensor,res_hk) + # del res_hk + # if 'representations' in ret[name].keys(): + # representations.update(ret[name].pop('representations')) + + # print('# ====> [INFO] pLDDTHead input has been saved.') # f_tmp_plddt = 'structure_module_input.pkl' # while os.path.isfile(f_tmp_plddt): # f_tmp_plddt = f_tmp_plddt + '-1.pkl' # with open(f_tmp_plddt, 'wb') as h_tmp: # pickle.dump(representations['structure_module'], h_tmp, protocol=4) - print(' # [TIME] head module duration =', (head_module2 - head_module), 'sec') + + s = representations['single'] + z = representations['pair'] + aatype = batch0['aatype'] + seq_mask = batch0['seq_mask'] + structure_outputs = self.structure_module.model( + s, z, aatype, seq_mask + ) + ret['representations']['structure_module'] = structure_outputs['single'] + ret['structure_module'] = {} + ret['structure_module']['traj'] = structure_outputs['frames'] + ret['structure_module']['final_affines'] = structure_outputs['frames'][-1] + + atom14_pred_positions = structure_outputs['positions'][-1] + ret['structure_module']['final_atom14_positions'] = atom14_pred_positions + ret['structure_module']['final_atom14_mask'] = batch0['atom14_atom_exists'] + + atom37_pred_positions = atom14_to_atom37(atom14_pred_positions, + batch0) + atom37_pred_positions *= batch0['atom37_atom_exists'][:, :, None] + ret['structure_module']['final_atom_positions'] = atom37_pred_positions # (N, 37, 3) + ret['structure_module']['final_atom_mask'] = batch0['atom37_atom_exists'] # (N, 37) + + head_module2 = self.read_time() + print(' # [TIME] structure module duration =', (head_module2 - head_module), 'sec') else: ret[name] = module(representations) if 'representations' in ret[name]: diff --git a/psi_run_af2.sh b/psi_run_af2.sh index 98439a454c1d9878e67d66ee6140b1342a3b2237..a71d82230973bc9adf06fd4d2eb668d68d2f7126 100644 --- a/psi_run_af2.sh +++ b/psi_run_af2.sh @@ -40,7 +40,7 @@ export DISTRIBUTED_EMBEDDING_TRIANGLE=1 export SHMID=test thread=36 rankfile=rankfile32corer -rank_num=2 +rank_num=8 export SEQ_TYPE=T1050 @@ -49,7 +49,7 @@ if [ ! "$DISTRIBUTED_MPI" = "1" ]; then start_cpu_id=0 numactl -C $start_cpu_id-$((start_cpu_id+thread-1)) -m 0-3 python run_psi_af2.py else - mpirun --allow-run-as-root -n $rank_num --map-by rankfile:file=$rankfile -x OMP_NUM_THREADS=$thread --report-bindings -x UCX_TLS=self,sm python -u run_psi_af2.py + mpirun --allow-run-as-root -n $rank_num --rankfile $rankfile -x OMP_NUM_THREADS=$thread -x UCX_RC_VERBS_TX_MIN_SGE=2 -x UCX_UD_VERBS_TX_MIN_SGE=1 -x UCX_TLS=rc python -u run_psi_af2.py fi # 在HBM-flat 模式运行