Efficient Evolutionary Scale Modeling: an efficient and simplified implementation of protein language models for inference and training.
ESM-Efficient
Efficient implementation of the ESM family of models.
Installation
```bash
conda install pytorch cudatoolkit=12.5 -c pytorch -c nvidia
pip install flash-attn --no-build-isolation
pip install esm-efficient
```
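To confirm that the packages installed above can be found by Python, a small check like the following may help. It is a sketch that assumes the import names `torch` (PyTorch), `flash_attn` (flash-attn) and `esme` (esm-efficient, as used in the usage examples); it only looks the modules up with `importlib.util.find_spec` and does not import them, so it does not check that CUDA works.

```python
import importlib.util

# Import names assumed for the packages installed above:
# pytorch -> "torch", flash-attn -> "flash_attn", esm-efficient -> "esme"
REQUIRED_MODULES = ("torch", "flash_attn", "esme")


def check_install(modules=REQUIRED_MODULES):
    """Return {module_name: True/False} depending on whether Python can find each module."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_install().items():
        print(f"{name:12s} {'ok' if found else 'MISSING'}")
```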
Usage
Predict the log probabilities of a sequence of tokens using the model.
```python
import torch
from esme import ESME
from esme.alphabet import tokenize, tokenize_unpad

# load the model
model = ESME.load_from_checkpoint("{model}.safetensors")
tokens = tokenize(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])

# predict logits (one score per vocabulary token at every position)
logits = model(tokens)
# logits.shape = (2, seq_len, vocab_size)

# predict log probabilities
log_probs = model.predict_log_prob(tokens, pad_output=True)
# log_probs.shape = (2, seq_len, vocab_size)

# tokenize without padding (more efficient: no computation is spent on padding tokens)
tokens, indices, (cu_lens, max_len) = tokenize_unpad(['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG'])
log_probs = model.predict_log_prob(tokens, (cu_lens, max_len))
# log_probs.shape = (seq_len_protein1 + seq_len_protein2, vocab_size)
```
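In the unpadded format, the proteins are packed one after another into a single token sequence, and `cu_lens` (cumulative sequence lengths) and `max_len` record where each protein starts and ends. These appear to correspond to the `cu_seqlens` and `max_seqlen` arguments of flash-attn's variable-length attention. The sketch below shows that bookkeeping in plain Python and how packed per-token outputs can be split back per protein. It assumes each protein is wrapped in two special tokens (a start and an end token, as in the original ESM alphabet); `unpad_layout` and `split_packed` are hypothetical helpers, not part of esme, so check the actual token count returned by `tokenize_unpad` before relying on them.

```python
from itertools import accumulate


def unpad_layout(seqs, n_special=2):
    """Return (cu_lens, max_len) for sequences packed without padding.

    Assumes every sequence is wrapped with `n_special` extra tokens
    (e.g. a start and an end token); this is an assumption, not esme's code.
    """
    lengths = [len(s) + n_special for s in seqs]
    cu_lens = [0, *accumulate(lengths)]  # start offset of each sequence, plus the total length
    return cu_lens, max(lengths)


def split_packed(rows, cu_lens):
    """Split a packed per-token output (one row per token) back into per-sequence chunks."""
    return [rows[start:end] for start, end in zip(cu_lens[:-1], cu_lens[1:])]


seqs = ['MEEPQSDPSVEPPLSQETFSDLWK', 'MADQLTEEQIAEFKEAFSLFDKDG']
cu_lens, max_len = unpad_layout(seqs)
print(cu_lens, max_len)  # [0, 26, 52] 26
```

With a PyTorch tensor of shape `(total_tokens, vocab_size)`, the same slicing works along the first dimension; `torch.split(log_probs, lengths)` with the per-sequence lengths is an equivalent alternative.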
Predict the effect of variants:
```python
from esme.variant import predict_mask_margin

seq = 'MEEPQSDPSVEPPLSQETFSDLWK'
df = predict_mask_margin(model, seq)
# df is a pandas DataFrame indexed by variant, e.g.:
# ... pd.DataFrame({
# ...     'variant': ['M1A', 'M1C', ..., 'P16Y'],
# ...     'score': [0.1, 0.2, ..., -0.3]
# ... }).set_index('variant')
```
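The name `predict_mask_margin` suggests the masked-marginal score commonly used for zero-shot variant effect prediction with ESM models: mask position i, run the model, and compare the log-probability of the mutant and wild-type residues there, score = log p(x_i = mt | masked context) - log p(x_i = wt | masked context). A negative score means the model finds the substitution less likely than the wild type. Whether esme computes exactly this is an assumption. The sketch takes a hypothetical table `log_probs[i][aa]` (0-based position i, amino acid aa), such as one would get by masking each position in turn, and produces variant names in the `M1A` style shown above (wild type, 1-based position, mutant).

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard amino acids


def masked_marginal_scores(seq, log_probs):
    """Score every single-residue substitution of `seq`.

    log_probs[i][aa] is assumed to hold log p(aa) at 0-based position i when
    that position is masked (hypothetical input format, not esme's API).
    Returns {variant_name: score} with names like 'M1A' (1-based positions).
    """
    scores = {}
    for i, wt in enumerate(seq):
        for mt in AMINO_ACIDS:
            if mt == wt:
                continue  # skip the wild-type residue itself
            scores[f"{wt}{i + 1}{mt}"] = log_probs[i][mt] - log_probs[i][wt]
    return scores
```

To get a table shaped like the `df` above, the dictionary can be converted with pandas, e.g. `pd.Series(scores, name="score").rename_axis("variant").to_frame()`.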
Fine-tune the model with LoRA adapters:
```python
# add LoRA adapters; by default only the adapter weights will be trained
model.add_lora(rank=16, layers=('query', 'key', 'value'), adapter_names=['adapter1', 'adapter2'])
# mark only the LoRA weights as trainable (add_lora already calls this by default)
model.mark_only_lora_as_trainable()
# save the LoRA weights (here only those of adapter1)
model.save_lora('<path>.safetensors', adapter_names=['adapter1'])
# load the LoRA weights into the model
model.load_lora('<path>.safetensors', adapter_names=['adapter1'])
```
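For reference, the standard LoRA formulation keeps a pretrained weight matrix W0 frozen and learns a low-rank update, h = W0 x + (alpha / r) B A x, where A has shape r x d_in, B has shape d_out x r, and B is initialized to zero so training starts from the pretrained model. The pure-Python sketch below illustrates that computation and the resulting parameter count. It is not esme's implementation; in particular, the `alpha` scaling factor is the usual LoRA hyperparameter and is an assumption here, since `add_lora` above only shows `rank`.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(W0, A, B, x, alpha):
    """h = W0 x + (alpha / r) * B (A x), with W0 frozen and A, B trainable.

    W0: d_out x d_in, A: r x d_in, B: d_out x r.
    `alpha` is the usual LoRA scaling hyperparameter (assumed; not an esme argument).
    """
    r = len(A)
    base = matvec(W0, x)
    update = matvec(B, matvec(A, x))  # low-rank path: d_in -> r -> d_out
    return [b + (alpha / r) * u for b, u in zip(base, update)]


def lora_param_count(d_in, d_out, r):
    """Return (trainable parameters in A and B, parameters in the frozen W0)."""
    return r * d_in + d_out * r, d_in * d_out
```

Assuming 320-dimensional query/key/value projections (the embedding size of ESM-2 8M), `lora_param_count(320, 320, 16)` gives 10,240 trainable parameters per adapted matrix, about 10% of the 102,400 in the frozen matrix.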
Quantization of the model:
```python
model = model.from_pretrained('8M.safetensors', quantization='4bit')
```
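4-bit quantization stores each weight in 4 bits instead of 16 or 32, so for a model with roughly 8 million parameters the weights shrink from about 32 MB (float32) or 16 MB (float16) to about 4 MB, plus a small overhead for scaling factors. The README does not say which 4-bit scheme esme uses. The sketch below shows a generic symmetric, blockwise "absmax" quantizer only to illustrate the idea; schemes such as NF4 use different code values, and `block_size=64` is an arbitrary choice here.

```python
def quantize_int4(weights, block_size=64):
    """Symmetric absmax quantization to 4-bit integer codes, one scale per block.

    A signed 4-bit integer spans [-8, 7]; this symmetric scheme uses only [-7, 7].
    """
    codes, scales = [], []
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        absmax = max(abs(w) for w in block)
        scale = absmax / 7 if absmax > 0 else 1.0  # largest |w| in the block maps to 7
        scales.append(scale)
        codes.extend(max(-7, min(7, round(w / scale))) for w in block)
    return codes, scales


def dequantize_int4(codes, scales, block_size=64):
    """Approximately reconstruct the weights from codes and per-block scales."""
    return [q * scales[i // block_size] for i, q in enumerate(codes)]


def weight_memory_mb(n_params, bits):
    """Memory for the weights alone (ignoring scales), in megabytes (10**6 bytes)."""
    return n_params * bits / 8 / 1e6
```

Each reconstructed weight is off by at most half a block's scale, which is why smaller blocks give better accuracy at the cost of storing more scales.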
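Activation checkpointing of each transformer layer (activations are recomputed during the backward pass instead of stored, lowering memory use during training at the cost of extra compute):
```python
model = model.from_pretrained('8M.safetensors', checkpointing=True)
```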
Model Weights
The model weights can be downloaded from Hugging Face: https://huggingface.co/mhcelik/esm-efficient
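The weights can also be fetched from Python with `hf_hub_download` from the `huggingface_hub` package, which caches the file locally and returns its path. This is a sketch: the repository id comes from the link above, while the file-name pattern (`8M.safetensors`, etc.) is inferred from the usage examples, and `weights_filename`/`download_weights` are hypothetical helpers. Check the model page for the file names that actually exist.

```python
from typing import Optional

REPO_ID = "mhcelik/esm-efficient"  # repository from the Hugging Face link above


def weights_filename(model: str) -> str:
    """File name for a model size, following the '8M.safetensors' pattern used above (assumed)."""
    return f"{model}.safetensors"


def download_weights(model: str = "8M", cache_dir: Optional[str] = None) -> str:
    """Download the weights (or reuse a cached copy) and return the local file path."""
    from huggingface_hub import hf_hub_download  # requires: pip install huggingface_hub

    return hf_hub_download(repo_id=REPO_ID, filename=weights_filename(model), cache_dir=cache_dir)
```

The returned path can then be used as the checkpoint path in the usage examples above.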
Evaluation
To perform the evaluation reported in the paper, run the following command:
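```bash
snakemake -n --use-conda
```
The `-n` flag makes this a dry run that only lists the jobs to be executed; remove it to actually download the data, train the models, and evaluate them. The results will be saved in the `results` directory. See `workflow/Snakefile` for more details.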
To generate a specific figure from the paper, run, for example:
```bash
snakemake reports/paper_figures/figure-2.pdf -n --use-conda
```
Citation
Hashes for esm_efficient-0.0.1-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 06370bd1745a67161d27ce119e291244c864473316676274432ff229c953346d |
| MD5 | 963c806251ad21a1ea52a3f4266ba2f5 |
| BLAKE2b-256 | 6ea754951f9cedcc8e735a260aecad67ec24fe938a42d95bfcd84cac953e8fc7 |