ParaFold for AlphaFold 2.3.1
Zuricho committed Feb 24, 2023
1 parent 8aeebbf · commit 8ccf00b
Showing 52 changed files with 784 additions and 724 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
*.slurm
21 changes: 0 additions & 21 deletions LICENSE

This file was deleted.

56 changes: 32 additions & 24 deletions README.md
@@ -2,69 +2,77 @@
<img src="./docs/parafoldlogo.png" width="400" >
</div>

# ParaFold

Author: Bozitao Zhong - zbztzhz@gmail.com

:station: We are adding new functions to ParallelFold; you can see our [Roadmap](https://trello.com/b/sAqBIxBC/parallelfold).

:bookmark_tabs: Please cite our [paper](https://arxiv.org/abs/2111.06340) if you used ParaFold (ParallelFold) in your research.

## Overview

Recent change: **ParaFold now supports AlphaFold 2.3.1**

This project is a modified version of DeepMind's [AlphaFold2](https://github.com/deepmind/alphafold), aimed at high-throughput protein structure prediction.

We made the following modifications to the original AlphaFold pipeline:

- Split the **CPU part** (MSA and template search) from the **GPU part** (model inference) so that they can run separately (see the two-step sketch under *How to run* below)

## How to install

We recommend installing AlphaFold locally rather than using **Docker**.

```bash
# clone this repo
git clone https://github.com/Zuricho/ParallelFold.git

# create a miniconda environment for ParaFold/AlphaFold
# python 3.8 is recommended: versions below 3.7 are missing required packages, and versions newer than 3.8 have not been tested
conda create -n parafold python=3.8

# openmm 7.7 is recommended (the original AlphaFold used 7.5.1, which is no longer supported)
conda install -c conda-forge openmm=7.7 pdbfixer

# use pip3 to install most of the packages
pip3 install -r requirements.txt

# install cuda and cudnn
# cudatoolkit 11.3.1 matches cudnn 8.2.1
conda install cudatoolkit=11.3 cudnn

# downgrade jaxlib to the version that matches the cuda and cudnn versions above
pip3 install --upgrade --no-cache-dir jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# install packages for multiple sequence alignment
conda install -c bioconda hmmer=3.3.2 hhsuite=3.3.0 kalign2=2.04

chmod +x run_alphafold.sh
```
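
After installation, it is worth a quick sanity check that JAX was installed against the right CUDA/cuDNN stack and can see the GPU. The commands below are a minimal, optional check and are not part of the ParaFold scripts themselves:

```bash
# activate the environment created above
conda activate parafold

# confirm that jax/jaxlib import cleanly and that a GPU device is visible
python3 -c "import jax; print(jax.__version__); print(jax.local_devices())"

# optional: confirm the NVIDIA driver and CUDA runtime are visible on this node
nvidia-smi
```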

## How to run

Visit the [usage page](./docs/usage.md) to learn how to run ParaFold.
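
As a rough sketch of how the CPU/GPU split is used in practice, a run is typically divided into a feature-only step and an inference step. The flag names below are illustrative assumptions rather than the definitive interface; see the [usage page](./docs/usage.md) for the actual options of `run_alphafold.sh`:

```bash
# step 1 (CPU only): run the MSA and template search and write feature.pkl to the output folder
# (the -f "feature-only" flag here is an assumption; check the usage page)
./run_alphafold.sh -d /path/to/databases -o output -i input/test.fasta -m model_1 -t 2021-07-27 -f

# step 2 (GPU): reuse the precomputed feature.pkl and run only the model inference
./run_alphafold.sh -d /path/to/databases -o output -i input/test.fasta -m model_1 -t 2021-07-27
```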

## Details of the modified files

- `run_alphafold.py`: a modified version of the original `run_alphafold.py`; it adds several functions, such as skipping the feature step when `feature.pkl` already exists in the output folder
- `run_alphafold.sh`: a bash script that wraps and runs `run_alphafold.py`
- `run_figure.py`: a helper script that makes figures for your system

## Functions

You can use the following flags to change how ParaFold runs its prediction models:

`-r`: Skip AMBER refinement [under repair]

`-b`: Use benchmark mode - run the JAX model twice; the second run can be used to evaluate running time

`-R`: Change the number of recycling iterations
**More functions are under development.**
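
As an illustration, a benchmark run with a reduced number of recycling iterations could be launched roughly as follows. The `-b` and `-R` flags are the ones listed above; the remaining flags are placeholder assumptions, so check the [usage page](./docs/usage.md) for the exact interface:

```bash
# run the JAX model twice (-b) so the second pass measures pure inference time,
# and limit recycling to 3 iterations (-R 3)
./run_alphafold.sh -d /path/to/databases -o output -i input/test.fasta -m model_1 -t 2021-07-27 -b -R 3
```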



## What is this for

ParaFold helps you accelerate AlphaFold when you want to predict structures for multiple sequences. Because the CPU part and the GPU part are separated, the feature step can be completed on many processors in parallel. Using ParaFold, you can run AlphaFold 2~3 times faster than with DeepMind's original procedure.
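
A minimal sketch of a high-throughput workflow, assuming the same illustrative flags as in the *How to run* section: finish the CPU-only feature step for every target first (these jobs can be spread across many CPU nodes or submitted separately), then sweep through the GPU inference step once the features exist:

```bash
# CPU stage: generate feature.pkl for each target (each iteration can be a separate CPU job)
for fasta in input/*.fasta; do
    ./run_alphafold.sh -d /path/to/databases -o output -i "$fasta" -m model_1 -t 2021-07-27 -f
done

# GPU stage: model inference only, reusing the precomputed features
for fasta in input/*.fasta; do
    ./run_alphafold.sh -d /path/to/databases -o output -i "$fasta" -m model_1 -t 2021-07-27
done
```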

**If you have any questions, please raise a GitHub issue.**



6 changes: 3 additions & 3 deletions alphafold/common/residue_constants.py
@@ -120,7 +120,7 @@
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-definiting atom (the last entry in
# is defined such that the dihedral-angle-defining atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
@@ -772,10 +772,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
14 changes: 7 additions & 7 deletions alphafold/data/pipeline.py
@@ -117,7 +117,7 @@ def __init__(self,
uniref90_database_path: str,
mgnify_database_path: str,
bfd_database_path: Optional[str],
uniclust30_database_path: Optional[str],
uniref30_database_path: Optional[str],
small_bfd_database_path: Optional[str],
template_searcher: TemplateSearcher,
template_featurizer: templates.TemplateHitFeaturizer,
@@ -135,9 +135,9 @@ def __init__(self,
binary_path=jackhmmer_binary_path,
database_path=small_bfd_database_path)
else:
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=[bfd_database_path, uniclust30_database_path])
databases=[bfd_database_path, uniref30_database_path])
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path)
@@ -211,14 +211,14 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
else:
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
hhblits_bfd_uniclust_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniclust_runner,
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
hhblits_bfd_uniref_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniref_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='a3m',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
8 changes: 4 additions & 4 deletions alphafold/data/templates.py
@@ -89,8 +89,8 @@ class LengthError(PrefilterError):
'template_aatype': np.float32,
'template_all_atom_masks': np.float32,
'template_all_atom_positions': np.float32,
'template_domain_names': np.object,
'template_sequence': np.object,
'template_domain_names': object,
'template_sequence': object,
'template_sum_probs': np.float32,
}

@@ -1002,8 +1002,8 @@ def get_templates(
(1, num_res, residue_constants.atom_type_num), np.float32),
'template_all_atom_positions': np.zeros(
(1, num_res, residue_constants.atom_type_num, 3), np.float32),
'template_domain_names': np.array([''.encode()], dtype=np.object),
'template_sequence': np.array([''.encode()], dtype=np.object),
'template_domain_names': np.array([''.encode()], dtype=object),
'template_sequence': np.array([''.encode()], dtype=object),
'template_sum_probs': np.array([0], dtype=np.float32)
}
return TemplateSearchResult(
26 changes: 18 additions & 8 deletions alphafold/data/tools/jackhmmer.py
@@ -167,10 +167,20 @@ def query(self,
input_fasta_path: str,
max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]:
"""Queries the database using Jackhmmer."""
return self.query_multiple([input_fasta_path], max_sequences)[0]

def query_multiple(
self,
input_fasta_paths: Sequence[str],
max_sequences: Optional[int] = None,
) -> Sequence[Sequence[Mapping[str, Any]]]:
"""Queries the database for multiple queries using Jackhmmer."""
if self.num_streamed_chunks is None:
single_chunk_result = self._query_chunk(
input_fasta_path, self.database_path, max_sequences)
return [single_chunk_result]
single_chunk_results = []
for input_fasta_path in input_fasta_paths:
single_chunk_results.append([self._query_chunk(
input_fasta_path, self.database_path, max_sequences)])
return single_chunk_results

db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
@@ -185,7 +195,7 @@

# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with futures.ThreadPoolExecutor(max_workers=2) as executor:
chunked_output = []
chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
for i in range(1, self.num_streamed_chunks + 1):
# Copy the chunk locally
if i == 1:
Expand All @@ -197,9 +207,9 @@ def query(self,

# Run Jackhmmer with the chunk
future.result()
chunked_output.append(self._query_chunk(
input_fasta_path, db_local_chunk(i), max_sequences))

for fasta_index, input_fasta_path in enumerate(input_fasta_paths):
chunked_outputs[fasta_index].append(self._query_chunk(
input_fasta_path, db_local_chunk(i), max_sequences))
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
# Do not set next_future for the last chunk so that this works even for
Expand All @@ -208,4 +218,4 @@ def query(self,
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_output
return chunked_outputs
2 changes: 1 addition & 1 deletion alphafold/model/all_atom_multimer.py
@@ -426,7 +426,7 @@ def torsion_angles_to_frames(
chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6]
chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7]

all_frames_to_backb = jax.tree_multimap(
all_frames_to_backb = jax.tree_map(
lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5],
chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None],
chi4_frame_to_backb[:, None])
61 changes: 61 additions & 0 deletions alphafold/model/common_modules.py
@@ -128,3 +128,64 @@ def __call__(self, inputs):

return output


class LayerNorm(hk.LayerNorm):
"""LayerNorm module.
Equivalent to hk.LayerNorm but with different parameter shapes: they are
always vectors rather than possibly higher-rank tensors. This makes it easier
to change the layout whilst keeping the model weight-compatible.
"""

def __init__(self,
axis,
create_scale: bool,
create_offset: bool,
eps: float = 1e-5,
scale_init=None,
offset_init=None,
use_fast_variance: bool = False,
name=None,
param_axis=None):
super().__init__(
axis=axis,
create_scale=False,
create_offset=False,
eps=eps,
scale_init=None,
offset_init=None,
use_fast_variance=use_fast_variance,
name=name,
param_axis=param_axis)
self._temp_create_scale = create_scale
self._temp_create_offset = create_offset

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
is_bf16 = (x.dtype == jnp.bfloat16)
if is_bf16:
x = x.astype(jnp.float32)

param_axis = self.param_axis[0] if self.param_axis else -1
param_shape = (x.shape[param_axis],)

param_broadcast_shape = [1] * x.ndim
param_broadcast_shape[param_axis] = x.shape[param_axis]
scale = None
offset = None
if self._temp_create_scale:
scale = hk.get_parameter(
'scale', param_shape, x.dtype, init=self.scale_init)
scale = scale.reshape(param_broadcast_shape)

if self._temp_create_offset:
offset = hk.get_parameter(
'offset', param_shape, x.dtype, init=self.offset_init)
offset = offset.reshape(param_broadcast_shape)

out = super().__call__(x, scale=scale, offset=offset)

if is_bf16:
out = out.astype(jnp.bfloat16)

return out

