A Python library for training Graph Neural Networks (GNNs) on text-attributed graphs (TAGs) using large language model (LLM) textual features


devanshamin/tag-llm


LLM-Enhanced Text-Attributed Graph Representation Learning

Framework Overview

  1. Node Feature Extraction

    • Prepare prompts containing the article information (title and abstract) for each node.
    • Query an LLM with these prompts to generate a ranked list of predicted labels and an explanation for each node.
  2. Node Feature Encoder

    • Fine-tune a language model (LM) on a sequence classification task with the article title and abstract as input.
  3. GNN Trainer

    • Train a separate GNN model on each of the following node features; TA and E are encoded with the fine-tuned LM encoder:
      1. Title & Abstract (TA)
      2. Prediction (P): the top-k ranked label predictions, encoded with a PyTorch nn.Embedding layer.
      3. Explanation (E)
  4. Model Ensemble

    • Fuse predictions from the trained GNN models on TA, P, and E by averaging them.
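The Prediction (P) feature and the ensemble step can be sketched in plain Python. This is a minimal illustration, not the project's code: `topk_label_ids` and `average_ensemble` are hypothetical names, the padding scheme for fewer than k predictions is an assumption, and the ensemble assumes each GNN outputs per-class probabilities (the project may average logits or other scores instead). The label names are placeholders.

```python
from typing import Dict, List, Sequence


def topk_label_ids(ranked_labels: Sequence[str], label_to_id: Dict[str, int], k: int) -> List[int]:
    """Map an LLM's ranked labels to exactly k integer ids (best first) for an embedding lookup.

    Unused slots get a padding id equal to len(label_to_id), matching an assumed
    torch.nn.Embedding(len(label_to_id) + 1, dim, padding_idx=len(label_to_id)).
    Raises KeyError if the LLM returned a label outside label_to_id.
    """
    pad_id = len(label_to_id)
    ids = [label_to_id[label] for label in ranked_labels[:k]]
    return ids + [pad_id] * (k - len(ids))


def average_ensemble(probs_by_model: Dict[str, List[List[float]]]) -> List[int]:
    """Average per-class probabilities across models and return the argmax class per node."""
    models = list(probs_by_model.values())
    n_nodes, n_classes = len(models[0]), len(models[0][0])
    if any(len(m) != n_nodes or any(len(row) != n_classes for row in m) for m in models):
        raise ValueError("all models must score the same nodes and classes")
    predictions = []
    for node in range(n_nodes):
        avg = [sum(m[node][c] for m in models) / len(models) for c in range(n_classes)]
        predictions.append(max(range(n_classes), key=lambda c: avg[c]))
    return predictions


# Hypothetical labels and scores for 2 nodes and 3 classes.
label_to_id = {"theory": 0, "neural_networks": 1, "reinforcement_learning": 2}
print(topk_label_ids(["neural_networks", "theory"], label_to_id, k=3))  # [1, 0, 3]

probs = {
    "TA": [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]],
    "P": [[0.2, 0.7, 0.1], [0.1, 0.3, 0.6]],
    "E": [[0.5, 0.4, 0.1], [0.3, 0.3, 0.4]],
}
print(average_ensemble(probs))  # [1, 2]
```

For node 0, the TA and E models both prefer class 0, but the confident P model tips the average to class 1, so averaging probabilities can disagree with a simple majority vote.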

Note

Fine-tuning an LM is optional and is not currently supported. Instead, you can use any open-weight fine-tuned embedding model, which significantly reduces time and cost while achieving comparable results.

Design Choices

Two crucial components of this project were LLM inference and LM inference, each with specific challenges and solutions.

LLM Inference

Challenges

  1. Rate Limits and Cost:

    • Using provider APIs (OpenAI, Anthropic, Groq, etc.) was straightforward, but slow because of rate limits on requests per minute (RPM) and tokens per minute (TPM), and expensive.
  2. Throughput with Naive Pipelines:

    • The naive Hugging Face text-generation pipeline was slow with open-weight models.
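Why rate limits make API labeling slow comes down to simple arithmetic: a job cannot finish faster than either limit allows, so the lower bound is the larger of the two. The sketch below assumes every request uses about the same number of tokens and that prompt and completion tokens both count toward TPM; the limits and node count in the example are hypothetical, since real values vary by provider and account tier.

```python
def min_minutes_under_limits(n_requests: int, tokens_per_request: int, rpm: int, tpm: int) -> float:
    """Lower bound on wall-clock minutes when both RPM and TPM limits apply."""
    by_requests = n_requests / rpm  # minutes needed if only the RPM limit applied
    by_tokens = n_requests * tokens_per_request / tpm  # minutes needed if only the TPM limit applied
    return max(by_requests, by_tokens)


# Hypothetical: 10,000 nodes, ~500 tokens per request, 500 RPM, 200,000 TPM.
print(min_minutes_under_limits(10_000, 500, rpm=500, tpm=200_000))  # 25.0
```

Here the TPM limit (25 minutes) binds before the RPM limit (20 minutes), so longer prompts, such as full abstracts, slow the job down even when the request count stays the same.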

Solutions

  1. Online Inference:

    • Provider APIs: Used APIs from providers such as OpenAI, Anthropic, and Groq.
    • Unified Interface: Used the litellm package to connect to different LLM providers via a unified interface.
    • Structured Outputs: Employed the instructor package for structured outputs using Pydantic classes.
    • Rate Limit Handling: Implemented retries with exponential backoff, plus a proactive delay between requests controlled by the rate_limit_per_minute configuration parameter.
    • Durability: Ensured durability with persistent caching of LLM responses using diskcache.
  2. Offline Inference:

    • Open-Weight Models: Used publicly available open-weight models from the Hugging Face Hub, opting for mid-sized (7-8 billion parameter) models to balance performance and cost.
    • Batch Inference: Maximized throughput by using the vLLM engine for batch inference.
    • Structured Output Challenges: Obtained structured outputs from open-weight models through prompt engineering and a generalizable prompt template; responses were validated with Python code and the request retried when validation failed.
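The validate-and-retry loop, exponential backoff, and proactive delay described above can be sketched with the standard library alone. This is a hedged illustration, not the project's implementation: `call_llm`, `RateLimitError`, `NodePrediction`, `parse_prediction`, `query_with_retries`, and the JSON field names are all hypothetical, a dataclass with manual checks stands in for the Pydantic models that instructor would validate against, and the proactive delay is a simple fixed spacing of 60 / RPM seconds for one sequential worker.

```python
import json
import time
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence


class RateLimitError(Exception):
    """Hypothetical stand-in for a provider's rate-limit (HTTP 429) exception."""


@dataclass
class NodePrediction:
    ranked_labels: List[str]  # most likely label first
    explanation: str


def parse_prediction(raw: str, allowed_labels: Sequence[str]) -> NodePrediction:
    """Validate a raw LLM response that should be a JSON object; raise ValueError otherwise."""
    data = json.loads(raw)  # json.JSONDecodeError is a subclass of ValueError
    if not isinstance(data, dict):
        raise ValueError("response must be a JSON object")
    labels, explanation = data.get("ranked_labels"), data.get("explanation")
    if not isinstance(labels, list) or not labels:
        raise ValueError("ranked_labels must be a non-empty list")
    unknown = [label for label in labels if label not in allowed_labels]
    if unknown:
        raise ValueError(f"unknown labels: {unknown}")
    if not isinstance(explanation, str) or not explanation.strip():
        raise ValueError("explanation must be a non-empty string")
    return NodePrediction(labels, explanation)


def query_with_retries(
    call_llm: Callable[[str], str],
    prompt: str,
    allowed_labels: Sequence[str],
    rate_limit_per_minute: Optional[int] = None,
    max_retries: int = 5,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> NodePrediction:
    if max_retries < 0:
        raise ValueError("max_retries must be >= 0")
    # Proactive delay: waiting 60 / RPM seconds before each request keeps a
    # single sequential worker under the requests-per-minute limit.
    interval = 60.0 / rate_limit_per_minute if rate_limit_per_minute else 0.0
    for attempt in range(max_retries + 1):
        if interval:
            sleep(interval)
        try:
            return parse_prediction(call_llm(prompt), allowed_labels)
        except (RateLimitError, ValueError):
            if attempt == max_retries:
                raise  # out of retries: surface the last error
            # Exponential backoff before retrying: base_delay * 1, 2, 4, ...
            sleep(base_delay * 2 ** attempt)
    raise RuntimeError("unreachable")


# Usage with a fake LLM whose first reply is not valid JSON.
replies = iter([
    "Sure! Here is my answer...",
    '{"ranked_labels": ["neural_networks", "theory"], "explanation": "The abstract describes a GNN."}',
])
waits: List[float] = []
prediction = query_with_retries(
    lambda prompt: next(replies),
    "Title: ...\nAbstract: ...",
    allowed_labels=["theory", "neural_networks", "reinforcement_learning"],
    rate_limit_per_minute=30,
    sleep=waits.append,  # record delays instead of actually sleeping
)
print(prediction.ranked_labels, waits)  # ['neural_networks', 'theory'] [2.0, 1.0, 2.0]
```

Injecting `sleep` keeps the retry policy testable without real waiting. The same parse-and-retry step applies to offline batch outputs from vLLM, where only the failed prompts need to be resubmitted.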

LM Inference

Challenges

  1. Encoding Speed:
    • Encoding could be slow depending on the size of the model and dataset.

Solutions

  1. Model Selection:

  2. Caching:

    • Utilized safetensors to store and distribute cached embeddings safely, so later runs can load the embeddings instead of re-encoding the dataset.
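The caching idea can be sketched as "encode once, then load from disk." In practice you would call the `safetensors` package itself (for example `safetensors.numpy.save_file` and `load_file`). To stay dependency-free, this sketch instead writes the published safetensors layout by hand: an 8-byte little-endian header length, a JSON header describing each tensor, then the raw little-endian bytes. Loading only parses JSON and copies bytes, which is why the format is safe compared with pickle. The function names, the `"embeddings"` tensor name, and the toy encoder are hypothetical; including the model and dataset names in the cache path is an assumption that avoids reusing stale embeddings.

```python
import json
import os
import struct
import sys
import tempfile
from array import array
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def save_f32_matrix(path: str, name: str, rows: Matrix) -> None:
    """Write one 2-D float32 tensor in the safetensors layout."""
    n_rows, dim = len(rows), len(rows[0])
    data = array("f", (value for row in rows for value in row))
    if data.itemsize != 4:
        raise RuntimeError("array('f') is not 32-bit on this platform")
    if sys.byteorder == "big":
        data.byteswap()  # tensor bytes are stored little-endian
    payload = data.tobytes()
    header = {name: {"dtype": "F32", "shape": [n_rows, dim], "data_offsets": [0, len(payload)]}}
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # u64 little-endian header size
        f.write(header_bytes)
        f.write(payload)


def load_f32_matrix(path: str, name: str) -> Matrix:
    """Read a tensor back by parsing the JSON header; no code is executed on load."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        buffer = f.read()
    info = header[name]
    if info["dtype"] != "F32":
        raise ValueError(f"expected F32, got {info['dtype']}")
    start, end = info["data_offsets"]
    data = array("f")
    data.frombytes(buffer[start:end])
    if sys.byteorder == "big":
        data.byteswap()
    n_rows, dim = info["shape"]
    return [data[i * dim:(i + 1) * dim].tolist() for i in range(n_rows)]


def encode_with_cache(texts: Sequence[str], encode: Callable[[Sequence[str]], Matrix], cache_path: str) -> Matrix:
    """Return cached embeddings if the file exists; otherwise encode once and cache them."""
    if os.path.exists(cache_path):
        return load_f32_matrix(cache_path, "embeddings")
    rows = encode(texts)
    save_f32_matrix(cache_path, "embeddings", rows)
    return rows


# Usage with a toy encoder standing in for a real LM.
calls: List[int] = []


def toy_encoder(texts: Sequence[str]) -> Matrix:
    calls.append(len(texts))
    return [[float(len(t)), float(t.count("a"))] for t in texts]


cache_file = os.path.join(tempfile.mkdtemp(), "toy-model__toy-dataset.safetensors")
first = encode_with_cache(["graph", "neural network"], toy_encoder, cache_file)
second = encode_with_cache(["graph", "neural network"], toy_encoder, cache_file)
print(first == second, calls)  # True [2]
```

The second call never touches the encoder, which is where the speedup comes from: the expensive forward passes run once per model and dataset.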

Together, these choices in online and offline LLM inference and in LM inference address the framework's main bottlenecks: rate limits, cost, generation throughput, and encoding speed.

🚀 Installation

# (Recommended) Create a new conda environment.
$ conda create -n tag_llm python=3.10 -y
$ conda activate tag_llm

# Replace 'cu118' in both commands below with the CUDA version for your system
$ pip install torch==2.3.0 --index-url https://download.pytorch.org/whl/cu118
$ pip install torch_geometric
$ pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu118.html

# For online LLM inference
$ poetry install
# For offline LLM inference
$ poetry install --extras "llm_offline"

💻 Usage

$ tag_llm_train --config=train_config.yaml
# You can also pass CLI arguments to override values in the `train_config.yaml` file; list them with:
$ tag_llm_train --help

🎓 Citations

If you use this library for research purposes, please cite it with the following reference:

@misc{tag-llm,
   author = {Devansh Amin},
   title = {LLM-Enhanced Text-Attributed Graph Representation Learning},
   year = {2024},
   url = {https://github.com/devanshamin/tag-llm},
   note = {A Python library for training Graph Neural Networks (GNNs) on text-attributed graphs (TAGs) using large language model (LLM) textual features}
}

📜 License

tag-llm is released under the Apache 2.0 license. See the LICENSE file for details.
