Want to predict whether a protein will bind to a drug, catalyze a reaction, or fluoresce under UV light? In all these cases, we would like to predict something about a protein from its amino acid sequence, and first you'll need to solve a deceptively tricky problem: how to represent proteins of wildly different sizes in a consistent way. Often, the best tool for this task is a pre-trained protein language model (PLM), such as the ESM family of transformer-based architectures. A PLM takes a protein's amino acid sequence as input and generates a high-dimensional representation (often more than 1,000 dimensions) for each input token. The problem is that protein sequences vary in length, from tens of amino acids to more than a thousand. A protein with n residues therefore yields an n × d embedding matrix whose number of rows changes from protein to protein, while a downstream predictor usually expects a fixed-length vector. Somewhere between the two, you need to aggregate.
A common approach is average pooling. Average pooling is great: it is simple to implement and has no parameters to train. However, it weights every amino acid's representation equally. This is unrealistic, since specific residues in the protein are often particularly important (for example, the residues at an active site). Even when the per-token PLM representations do contain information that distinguishes such residues from the rest, that distinction can be washed out by averaging. The problem gets worse for longer sequences, where each residue contributes only 1/n of the average.
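For reference, here is what average pooling typically looks like on a padded batch. This is a minimal NumPy sketch; the function name and the mask convention (1 for a real residue, 0 for padding) are our assumptions, not part of any particular PLM library. Special tokens such as start/end markers, if your PLM adds them, would also be masked out. Note that every real residue of protein i receives the same weight 1/n_i.

```python
import numpy as np

def masked_mean_pool(emb: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average-pool a padded batch of per-residue embeddings.

    emb:  (batch, max_len, d) per-token embeddings; padded positions hold arbitrary finite values.
    mask: (batch, max_len), 1 for real residues and 0 for padding (hypothetical convention).
    Returns (batch, d), where every real residue of protein i has weight 1 / n_i.
    """
    w = mask[..., None].astype(emb.dtype)      # (batch, max_len, 1)
    summed = (emb * w).sum(axis=1)             # padded positions contribute 0
    counts = np.maximum(w.sum(axis=1), 1.0)    # (batch, 1); guard against empty rows
    return summed / counts


# Hypothetical example: two proteins with 3 and 5 residues, d = 4, padded to length 5.
rng = np.random.default_rng(0)
emb = rng.normal(size=(2, 5, 4))
mask = np.array([[1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 1]])
pooled = masked_mean_pool(emb, mask)
print(pooled.shape)  # (2, 4)
```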
May we suggest an alternative pooling strategy, one that adapts to your desired downstream task?
In our recent work, "Aggregating Residue-Level Protein Language Model Embeddings with Optimal Transport," done jointly with Rohit Singh, we introduce a new method, based on the notion of sliced Wasserstein embeddings (SWE), to convert variable-length PLM outputs into fixed-length representations. Our core idea is simple: suppose you have a sequence of n residues, each embedded in ℝ^d. We learn a set of m reference anchors in ℝ^d and compute optimal transport distances between your sequence embeddings and these references. Specifically, we use sliced Wasserstein distances, which make this computation efficient: both point sets are projected onto a number of one-dimensional directions (slices), and in one dimension optimal transport reduces to sorting and matching quantiles. Because the number of anchors and the number of slices are fixed, the output has the same size no matter how long the protein is.
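The recipe can be sketched in a few lines of NumPy. This is our own minimal illustration in the spirit of sliced-Wasserstein pooling, not the code from our repository: the function name, the quantile resampling via `np.quantile`, the random (rather than learned) anchors and slices, and flattening the m × L output into one vector are all assumptions made for the example. For each slice, it projects the residues and the anchors onto one direction, pairs them by sorted order (the optimal transport plan in 1D), and records how far each anchor has to move.

```python
import numpy as np

def sliced_wasserstein_embedding(x, reference, slices):
    """Map a set of n residue embeddings to a fixed-length vector (illustrative sketch).

    x:         (n, d) per-residue PLM embeddings; n varies between proteins.
    reference: (m, d) reference anchors (learned in practice; random here).
    slices:    (L, d) unit-norm projection directions.
    Returns a vector of length m * L, independent of n.
    """
    m = reference.shape[0]
    # Project both point sets onto each 1D slice: (n, L) and (m, L).
    x_proj = x @ slices.T
    z_proj = reference @ slices.T
    # In 1D, optimal transport matches quantiles. Resample the n projected
    # residues at m evenly spaced quantile levels so they line up with the anchors.
    levels = np.linspace(0.0, 1.0, m)
    x_quant = np.quantile(x_proj, levels, axis=0)          # (m, L), nondecreasing per slice
    # Sort the anchors' projections within each slice.
    order = np.argsort(z_proj, axis=0)                     # (m, L)
    z_sorted = np.take_along_axis(z_proj, order, axis=0)
    # Displacement of each anchor under the monotone (optimal) 1D transport map.
    disp_sorted = x_quant - z_sorted
    # Scatter back so row i always belongs to anchor i, whatever its rank.
    emb = np.empty_like(disp_sorted)
    np.put_along_axis(emb, order, disp_sorted, axis=0)
    return emb.reshape(-1)


# Hypothetical sizes: d = 64, m = 16 anchors, L = 8 slices.
rng = np.random.default_rng(0)
d, m, L = 64, 16, 8
reference = rng.normal(size=(m, d))
slices = rng.normal(size=(L, d))
slices /= np.linalg.norm(slices, axis=1, keepdims=True)

short_protein = rng.normal(size=(40, d))
long_protein = rng.normal(size=(900, d))
print(sliced_wasserstein_embedding(short_protein, reference, slices).shape)  # (128,)
print(sliced_wasserstein_embedding(long_protein, reference, slices).shape)   # (128,)
```

The output length depends only on m and L, so proteins of 40 and 900 residues map to vectors of the same size, and the result does not depend on residue order within the set. In a trained model, the anchors (and possibly the slices) would be parameters updated together with the downstream head.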
Averaging exposes only the first moment (the mean) of the distribution of residue embeddings, while our SWE approach also captures higher moments. SWE particularly shines in two scenarios: when you have substantial training data, or when GPU constraints force you to use a moderate-sized PLM. In both cases, our approach solidly outperforms average pooling and other baseline methods.
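To make the moment argument concrete, consider a toy one-dimensional case, i.e., residue embeddings already projected onto a single slice. The numbers below are made up for illustration. Both sets have mean 0, so average pooling cannot tell them apart, but the 1D Wasserstein distance (computed by sorting, which is why slicing is cheap) separates them because it depends on the whole distribution, not just its mean.

```python
import numpy as np

def wasserstein_1d(a, b):
    """2-Wasserstein distance between two equal-size 1D samples.

    In one dimension the optimal transport plan simply pairs the sorted
    values, so no optimization is needed.
    """
    a, b = np.sort(a), np.sort(b)
    return float(np.sqrt(np.mean((a - b) ** 2)))


# Made-up projections of four residues onto one slice, for two proteins.
tight = np.array([-1.0, 0.0, 1.0, 0.0])
spread = np.array([-3.0, 0.0, 3.0, 0.0])
print(tight.mean(), spread.mean())      # 0.0 0.0  -> identical after average pooling
print(wasserstein_1d(tight, spread))    # 1.414... (sqrt(2)) -> clearly different
```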
Our aggregation scheme is especially effective for longer protein sequences, capturing essential information that might be lost through average pooling. This makes intuitive sense: crucial protein properties, such as drug-target interactions, are typically governed by only a few amino acids. Whereas such information could get lost when the token-level representations generated by PLMs are simply averaged, our proposed SWE method can extract more meaningful protein-level representations, especially for smaller PLMs and longer sequences.
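A quick back-of-the-envelope check of this dilution effect: changing one residue's embedding by a vector of norm δ moves the average-pooled vector by exactly δ/n, so the same change matters ten times less in a 1,000-residue protein than in a 100-residue one. The toy NumPy script below (random embeddings, hypothetical function name and parameters) simply confirms that arithmetic.

```python
import numpy as np

def mean_shift_from_one_residue(n, d=8, delta=10.0, seed=0):
    """Change in the mean-pooled vector when a single residue's embedding moves by delta.

    Hypothetical toy setup: random embeddings, one "key" residue perturbed.
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n, d))
    x_changed = x.copy()
    x_changed[0, 0] += delta               # alter only the key residue
    return float(np.linalg.norm(x_changed.mean(axis=0) - x.mean(axis=0)))


for n in (50, 500, 1000):
    print(n, round(mean_shift_from_one_residue(n), 6))  # delta / n: 0.2, 0.02, 0.01
```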
How much you benefit from SWE pooling depends on your task, the amount of training data, and the model size you can fit in GPU memory. We believe sophisticated summarization approaches, such as SWE, will be crucial for getting the most out of smaller PLMs that fit on common GPUs, and could further unlock parameter-efficient fine-tuning opportunities for PLMs.
This work opens up exciting new possibilities. A particularly promising direction is improving the interpretability of our method. One option is to use known active or binding sites explicitly in a dictionary-learning-style setup, where the key residues on the interaction surface inform the choice of reference set elements. Another is to add auxiliary loss functions that encode more biological intuition into the aggregation operation. Both approaches would make the representations produced by our method more biologically meaningful, and could improve the interpretability and efficiency of PLM-based approaches in biology more broadly.
Give it a try, see if it helps with your task, and let us know if you have any questions or comments. Our code is on GitHub and is a drop-in replacement for average pooling.