From Molecules to Materials: Pre-training Large Generalizable Models for Atomic Property Prediction

Machine Learning for Chemistry Needs a Foundation Model

Natural language processing has seen some of the most exciting advances in machine learning in the last few years. Modern language models like GPT-4 can translate text, answer questions, and even write code. They're not perfect, but they're light-years ahead of where we were just a few years ago. What changed? These models gained "common sense" by pre-training on massive, diverse datasets, much like humans learn concepts that generalize across domains through varied life experiences.

So why haven't we seen similar breakthroughs in chemistry? Despite having massive datasets from quantum simulations (e.g., the OC20 dataset with ~250 million DFT calculations), state-of-the-art models for tasks like catalyst screening and crystal structure prediction still start training from scratch.

The t-SNE figure above shows JMP's learned representations on randomly sampled structures from all pre-training and fine-tuning datasets, colored by dataset. The model clusters structures by chemical domain, showing that it learns to distinguish between domains despite never being explicitly trained to do so.

This is because machine learning for chemistry poses unique challenges. Molecules and materials have diverse structures in continuous 3D space. Their properties vary over many orders of magnitude. And modeling dynamic non-equilibrium systems gets quite messy.

NLP researchers faced similarly complex challenges. But they overcame them by pre-training on massive, diverse datasets containing billions of sentences in hundreds of languages, both natural and programming. This allowed models to learn "common sense" linguistic patterns that generalize across domains, ultimately leading to groundbreaking performance on downstream tasks.

We believe a similar approach is key for chemistry. A single foundation model pre-trained on diverse chemical data could learn universal atomistic representations. Our atoms and molecules speak different dialects, but a polyglot foundation model could become fluent in the universal language of chemistry.

Here's a peek at what we found: Our method, JMP, significantly outperformed training from scratch and matched or exceeded the state-of-the-art on the vast majority of benchmarks we evaluated. On the popular QM9 dataset, it achieved ~50% lower error compared to previous best results! This suggests JMP effectively learns transferable representations that broadly generalize across diverse chemical domains (see the t-SNE figure above for a visualization of these representations; also, look at our interactive t-SNE on JMP's webpage).

Teaching a model “all of chemistry”

We started with a graph neural network model called GemNet-OC. This architecture takes a graph representing the 3D structure of a molecule or material and predicts its properties, such as per-atom forces and per-system energies. Our goal was to pre-train this model to learn general chemical representations that work for any task.
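
To make the setup concrete, here is a minimal sketch of the input/output contract such a model exposes. This is a toy stand-in, not GemNet-OC itself: atomic numbers and 3D positions go in, a per-system energy and per-atom forces come out. Real models predict forces either with a dedicated output head or, as in this sketch, as the negative gradient of the energy with respect to positions.

```python
# A hypothetical toy stand-in (not GemNet-OC) showing the interface of an atomistic
# model: atomic numbers + 3D positions in, per-system energy and per-atom forces out.
import torch
import torch.nn as nn

class ToyAtomisticModel(nn.Module):
    def __init__(self, num_elements: int = 100, hidden: int = 64):
        super().__init__()
        self.embed = nn.Embedding(num_elements, hidden)
        self.mlp = nn.Sequential(
            nn.Linear(hidden + 1, hidden), nn.SiLU(), nn.Linear(hidden, 1)
        )

    def forward(self, atomic_numbers: torch.Tensor, positions: torch.Tensor):
        positions = positions.clone().requires_grad_(True)
        # Toy geometric feature: each atom's distance to the system centroid.
        dists = (positions - positions.mean(dim=0, keepdim=True)).norm(dim=-1, keepdim=True)
        per_atom_energy = self.mlp(torch.cat([self.embed(atomic_numbers), dists], dim=-1))
        energy = per_atom_energy.sum()                       # per-system energy
        # create_graph=True keeps the graph so a force loss can be backpropagated later.
        forces = -torch.autograd.grad(energy, positions, create_graph=True)[0]
        return energy, forces

model = ToyAtomisticModel()
z = torch.tensor([8, 1, 1])                 # e.g., water: O, H, H
pos = torch.randn(3, 3)                     # 3 atoms in 3D
energy, forces = model(z, pos)              # scalar energy, (3, 3) force tensor
```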

The key was training GemNet-OC on diverse datasets, spanning different chemical domains. In this paper, we focused on four chemical domains: small molecules (1-20 atoms), macromolecules (more than 20 atoms), materials, and catalysis (molecules placed on material surfaces). To ensure that our method is generalizable, we only pre-trained on datasets from the small molecule and catalysis domains. 

All of our pre-training datasets are DFT relaxation trajectory datasets: each consists of relaxation trajectories in which the atomic positions are iteratively updated along the forces, computed with density functional theory (DFT), to minimize the energy. The prediction task is to take an intermediate structure from a trajectory and predict (1) its total energy and (2) the atomic forces acting on each atom. The pre-training datasets are listed below, followed by a small sketch of what a relaxation trajectory looks like:

  • OC20: 250 million DFT calculations across a wide swath of materials, surfaces, and adsorbates (nitrogen, carbon, and oxygen chemistries).

  • OC22: 9.9 million DFT calculations of metal oxide surfaces and adsorbates for renewable energy applications.

  • Transition1x: 9.6 million DFT calculations of organic reactions sampled along minimum energy paths using the nudged elastic band method.

  • ANI-1x: 5 million DFT calculations of non-equilibrium molecular conformations of small organic molecules.
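
To make "relaxation trajectory" concrete, here is a toy illustration of the idea, where a cheap pairwise spring potential stands in for DFT (which is far more expensive): positions are repeatedly nudged along the forces, and every intermediate snapshot of (positions, energy, forces) becomes one training example.

```python
# Toy illustration of a relaxation trajectory. A cheap pairwise "spring" potential
# stands in for DFT; in the real datasets every energy/force evaluation is a DFT call.
import torch

def toy_energy_and_forces(pos: torch.Tensor):
    """Energy of a toy potential and forces as its negative position-gradient."""
    pos = pos.clone().requires_grad_(True)
    diff = pos.unsqueeze(0) - pos.unsqueeze(1)           # (N, N, 3) pairwise displacements
    dist = (diff.pow(2).sum(-1) + 1e-9).sqrt()           # pairwise distances
    energy = ((dist - 1.5) ** 2).triu(1).sum()           # springs with rest length 1.5
    forces = -torch.autograd.grad(energy, pos)[0]
    return energy.detach(), forces

positions = 2.0 * torch.randn(4, 3)                      # 4 atoms, random starting geometry
trajectory = []
for step in range(100):                                  # steepest-descent relaxation
    energy, forces = toy_energy_and_forces(positions)
    trajectory.append({"pos": positions.clone(), "energy": energy, "forces": forces})
    positions = positions + 0.02 * forces                # move atoms along the forces

# Every entry of `trajectory` is one (structure -> energy, forces) training example.
```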

We simultaneously trained GemNet-OC on all these datasets in a multi-task framework we call Joint Multi-Domain Pre-Training (JMP). Here's how it works:

An overview of the Joint Multi-Domain Pre-Training (JMP) method. Left demonstrates the joint pre-training setup, where a single model is simultaneously trained on a set of diverse pre-training datasets using multi-task learning. Right shows the fine-tuning process, where the pre-trained JMP backbone is equipped with new prediction heads and trained on downstream tasks.

For each dataset, GemNet-OC had a dedicated head attached to its shared backbone. These heads predicted total energies and forces for their specific dataset. The loss combined contributions from all datasets and drove the model to build unified representations that generalized across tasks.
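
Here is a hedged sketch of that multi-task setup, with toy modules standing in for the actual GemNet-OC code: a shared backbone produces per-atom features, each pre-training dataset gets its own energy/force head, and the training loss sums energy and force errors over a batch from each dataset. The loss coefficients shown are illustrative, not the values used in the paper.

```python
# Hedged sketch of joint multi-task pre-training: shared backbone, one head per
# pre-training dataset, and a summed loss. Toy modules; coefficients are illustrative.
import torch
import torch.nn as nn

DATASETS = ["oc20", "oc22", "transition1x", "ani1x"]

class ToyBackbone(nn.Module):
    """Stand-in for the shared GemNet-OC backbone: atoms -> per-atom features."""
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.embed = nn.Embedding(100, hidden)
        self.mix = nn.Sequential(nn.Linear(hidden + 3, hidden), nn.SiLU())

    def forward(self, atomic_numbers, positions):
        return self.mix(torch.cat([self.embed(atomic_numbers), positions], dim=-1))

class EnergyForceHead(nn.Module):
    """Dataset-specific head: per-atom features -> system energy + per-atom forces."""
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.energy = nn.Linear(hidden, 1)
        self.force = nn.Linear(hidden, 3)

    def forward(self, feats):
        return self.energy(feats).sum(), self.force(feats)

class JointModel(nn.Module):
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.backbone = ToyBackbone(hidden)                       # shared across datasets
        self.heads = nn.ModuleDict({d: EnergyForceHead(hidden) for d in DATASETS})

    def forward(self, dataset, atomic_numbers, positions):
        return self.heads[dataset](self.backbone(atomic_numbers, positions))

def joint_loss(model, batches, energy_coef=1.0, force_coef=10.0):
    """Sum (energy error + force error) over one batch per pre-training dataset."""
    total = 0.0
    for name, (z, pos, e_true, f_true) in batches.items():
        e_pred, f_pred = model(name, z, pos)
        total = total + (energy_coef * (e_pred - e_true).abs()
                         + force_coef * (f_pred - f_true).abs().mean())
    return total

model = JointModel()
batches = {d: (torch.tensor([6, 1, 1, 1, 1]), torch.randn(5, 3),
               torch.tensor(-40.0), torch.randn(5, 3)) for d in DATASETS}
loss = joint_loss(model, batches)
loss.backward()   # gradients flow into the shared backbone from every dataset
```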

Given the vast differences between the datasets, careful dataset balancing, loss balancing, and strong regularization were key to training success. These techniques are described in detail in our paper; one common balancing idea is sketched below.
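
As one concrete example of dataset balancing, a common recipe for mixtures with very different sizes (OC20 alone has roughly ten times more frames than the other three datasets combined) is temperature-based sampling, where dataset i is sampled with probability proportional to n_i^(1/T). The snippet below illustrates this idea; it is not necessarily the paper's exact recipe or temperature value.

```python
# Illustrative temperature-based dataset sampling: probability of drawing a batch from
# dataset i is proportional to n_i ** (1/T). T = 1 reproduces size-proportional sampling;
# larger T flattens the mix so small datasets are not drowned out.
sizes = {"oc20": 250e6, "oc22": 9.9e6, "transition1x": 9.6e6, "ani1x": 5e6}

def sampling_probs(sizes: dict, temperature: float) -> dict:
    weights = {name: n ** (1.0 / temperature) for name, n in sizes.items()}
    total = sum(weights.values())
    return {name: w / total for name, w in weights.items()}

print(sampling_probs(sizes, temperature=1.0))   # OC20 dominates (~91% of batches)
print(sampling_probs(sizes, temperature=2.0))   # flatter mix (~65% OC20)
```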

After pre-training on all our data, we tested how well the model performed on 40 fine-tuning tasks covering molecules and materials from three different chemical domains. This evaluated whether the pre-trained representations generalize to new domains. The tasks are taken from the following datasets (a sketch of the fine-tuning setup follows the list):

  • QM9: QM geometries, energies, enthalpies, and other properties for 134k small organic molecules made up of CHONF atoms.

  • rMD17: Molecular dynamics trajectories with forces and energies for 10 small organic molecules like benzene and aspirin.

  • SPICE: Molecular dynamics trajectories of different dipeptides and solvated amino acids, relevant for drug design applications.

  • MD22: Molecular dynamics trajectories with forces and energies for 7 large organic molecules and supramolecular complexes up to 370 atoms.

  • MatBench: Calculated properties for inorganic materials, including bandgaps, formation energies, exfoliation energies, and more.

  • QMOF: 14k+ metal-organic frameworks with DFT-calculated electronic properties such as band gaps.
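
Fine-tuning itself follows the standard transfer-learning recipe: load the pre-trained backbone, attach a fresh head for the downstream target, and train on the new dataset. The sketch below illustrates that pattern with toy modules; the checkpoint path and state-dict key are hypothetical placeholders, not the layout of our released checkpoints.

```python
# Hedged fine-tuning sketch: reuse the pre-trained backbone, add a fresh head for a
# downstream scalar property (e.g., a QM9 target), and train on the new dataset.
# The checkpoint path / key below are hypothetical placeholders.
import torch
import torch.nn as nn

class ToyBackbone(nn.Module):
    """Stand-in for the pre-trained JMP backbone: atoms -> per-atom features."""
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.embed = nn.Embedding(100, hidden)
        self.mix = nn.Sequential(nn.Linear(hidden + 3, hidden), nn.SiLU())

    def forward(self, atomic_numbers, positions):
        return self.mix(torch.cat([self.embed(atomic_numbers), positions], dim=-1))

class FineTuneModel(nn.Module):
    def __init__(self, backbone: nn.Module, hidden: int = 64):
        super().__init__()
        self.backbone = backbone                 # weights come from pre-training
        self.head = nn.Linear(hidden, 1)         # new, randomly initialized head

    def forward(self, atomic_numbers, positions):
        feats = self.backbone(atomic_numbers, positions)
        return self.head(feats).sum()            # pool atoms -> one scalar property

backbone = ToyBackbone()
# backbone.load_state_dict(torch.load("jmp_pretrained.pt")["backbone"])  # hypothetical
model = FineTuneModel(backbone)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

z, pos, target = torch.tensor([6, 8, 1, 1]), torch.randn(4, 3), torch.tensor(-1.23)
loss = (model(z, pos) - target).abs()            # e.g., MAE on a single structure
loss.backward()
optimizer.step()
```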

Results: Towards universal chemistry models

JMP consistently beat training from scratch, achieving a 59% average improvement across tasks.

More importantly, it matched or beat state-of-the-art on 34 out of 40 benchmarks, including outperforming other pre-training methods. A single model dominating so many tasks suggests JMP effectively learns transferable representations.

JMP also enabled scaling to larger models without overfitting. The comparison below shows the relative change in performance when going from a small to a large model, for both scratch-trained and JMP models:

JMP Scaling

While the large scratch-trained model (GN-OC-L) performed worse than the small one (GN-OC-S), the large JMP model (JMP-L) significantly outperformed its smaller variant (JMP-S). Pre-training acts as a shield against overfitting even on small datasets. Since model size has been key to recent AI breakthroughs, this is really exciting!

Finally, by stopping JMP fine-tuning early to match the performance of scratch-trained models, we found that JMP's pre-trained representations sped up downstream training by 12x on average. This presents a trade-off: While the pre-training cost of the model is substantial, the downstream training cost is significantly reduced. Provided that pre-training is done once and fine-tuning is done many times, this trade-off is favorable in practice.
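
As a sketch of how such a speedup number can be computed (our exact measurement protocol is described in the paper): find the point at which the fine-tuned model first matches the scratch-trained model's final validation error, and divide the scratch model's training budget by that number. The curves below are hypothetical.

```python
# Illustrative speedup calculation on hypothetical validation-error curves recorded at
# each epoch; the paper's exact measurement protocol may differ.
def epochs_to_reach(errors: list, target: float) -> int:
    """First epoch (1-indexed) at which validation error drops to the target."""
    for epoch, err in enumerate(errors, start=1):
        if err <= target:
            return epoch
    return len(errors)

scratch_errors = [0.90, 0.60, 0.45, 0.38, 0.35, 0.34]   # hypothetical numbers
jmp_errors     = [0.40, 0.33, 0.30, 0.28, 0.27, 0.26]   # hypothetical numbers

scratch_best = min(scratch_errors)
speedup = len(scratch_errors) / epochs_to_reach(jmp_errors, scratch_best)
print(f"fine-tuning speedup ~ {speedup:.1f}x")           # 6 / 2 = 3.0x in this toy example
```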

The future of chemistry AI

JMP demonstrates that pre-training on diverse data can produce universal atomistic representations. There's still a long way to go, but we envision this approach enabling GPT-3-level foundation models for chemical prediction and synthesis.

Such models would guide drug design, materials discovery, catalyst optimization and more. They could rapidly predict properties and reactions for any reasonable molecule you dream up.

Research like JMP brings us closer to this future. Next steps are exploring larger models, more pre-training techniques, and ever more diverse training data. Exciting times ahead at the intersection of machine learning and chemistry!

For more information on our approach, including additional details on the JMP model, pre-training datasets, and extensive evaluation results, please refer to our paper: https://arxiv.org/abs/2310.16802

Our code, pre-trained checkpoints, recorded presentation, poster, and some additional interactive visualizations are available at: https://nima.sh/jmp

If you have any questions or would like to further discuss our work, feel free to reach out to me. My Twitter/LinkedIn links are on my website: https://nima.sh
