Skip to content

🎉 Torch geometric plaid dataset loader #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 271 additions & 0 deletions src/plaid/utils/dataset_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
from plaid.containers.sample import Sample
from plaid.problem_definition import ProblemDefinition
from torch_geometric.data import Data
import numpy as np
import torch
from torch_geometric.utils._coalesce import coalesce as geometric_coalesce
from plaid.containers.dataset import Dataset as PlaidDataset
from plaid.problem_definition import ProblemDefinition
from typing import List, Tuple, Union, Annotated
from datasets import load_dataset
from plaid.bridges.huggingface_bridge import huggingface_dataset_to_plaid
from torch_geometric.data import Data
import os
from tqdm import tqdm
from multiprocessing import Pool

Check warning on line 15 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L1-L15

Added lines #L1 - L15 were not covered by tests

def my_coalesce(edges: torch.Tensor | np.ndarray, num_nodes: int, reduce="add"):
if isinstance(edges, np.ndarray):
edges = torch.tensor(edges).T
return geometric_coalesce(edges, num_nodes=num_nodes, reduce=reduce).T.numpy()
edges = geometric_coalesce(edges.T, num_nodes=num_nodes, reduce=reduce).T
return edges

Check warning on line 22 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L17-L22

Added lines #L17 - L22 were not covered by tests

def faces_to_edges(faces: np.ndarray, num_nodes: int):

Check warning on line 24 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L24

Added line #L24 was not covered by tests
"""Creates a list of edges from a Faces array

Args:
faces (np.ndarray): Array of faces shape (n_faces, face_dim)

Returns:
np.ndarray: the edge list of shape (n, 2)
"""

assert len(faces.shape)==2, "Wrong shape for the faces, should be a 2D array"

Check warning on line 34 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L34

Added line #L34 was not covered by tests

# Generate edges (without duplicates in one pass)
rolled = np.roll(faces, -1, axis=1)
edges = np.vstack((faces.ravel(), rolled.ravel())).T
edges = np.concatenate((edges, edges[:, ::-1]), axis=0)

Check warning on line 39 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L37-L39

Added lines #L37 - L39 were not covered by tests

return edges

Check warning on line 41 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L41

Added line #L41 was not covered by tests


def sample_to_pyg(sample: Sample, sample_id: int, problem_definition: ProblemDefinition, base_name: str) -> Data:

Check warning on line 44 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L44

Added line #L44 was not covered by tests
"""
Converts a Plaid sample to PytorchGeometric Data object

Args:
sample (plaid.containers.sample.Sample): data sample
sample_id (int): Plaid sample id
problem_definition (ProblemDefinition)
base_name (str): Name of the base to extract

Returns:
Data: the converted data sample
"""

vertices = sample.get_vertices(base_name=base_name)
edge_index = []
faces_dict = {}
for key, faces in sample.get_elements(base_name=base_name).items():
edge_index.append(faces_to_edges(faces, num_nodes=vertices.shape[0], coalesce=False))
faces_dict[key] = torch.tensor(faces, dtype=torch.long)
edge_index = np.concatenate(edge_index, axis=0)
edge_index = my_coalesce(edge_index, num_nodes=vertices.shape[0])

Check warning on line 65 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L58-L65

Added lines #L58 - L65 were not covered by tests

v1 = vertices[edge_index[:, 0]]
v2 = vertices[edge_index[:, 1]]
edge_weight = np.linalg.norm(v2 - v1, axis=1)

Check warning on line 69 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L67-L69

Added lines #L67 - L69 were not covered by tests

# loading scalars
input_scalars_names = problem_definition.get_input_scalars_names()
output_scalars_names = problem_definition.get_output_scalars_names()

Check warning on line 73 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L72-L73

Added lines #L72 - L73 were not covered by tests

input_scalars = []
output_scalars = []
for name in input_scalars_names:
input_scalars.append(sample.get_scalar(name))
for name in output_scalars_names:
output_scalars.append(sample.get_scalar(name))

Check warning on line 80 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L75-L80

Added lines #L75 - L80 were not covered by tests

# loading fields
input_fields_names = problem_definition.get_input_fields_names()
output_fields_names = problem_definition.get_output_fields_names()

Check warning on line 84 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L83-L84

Added lines #L83 - L84 were not covered by tests

if "cell_ids" in input_fields_names:
input_fields_names.remove("cell_ids")

Check warning on line 87 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L86-L87

Added lines #L86 - L87 were not covered by tests

new_input_fields_names = []
input_fields = []
if len(input_fields_names)>=1:
for field_name in input_fields_names:
field = sample.get_field(field_name, base_name=base_name)
if field is None: continue
input_fields.append(field)
new_input_fields_names.append(field_name)
if len(input_fields) >= 1:
input_fields = np.vstack(input_fields).T
input_fields = np.concatenate((vertices, input_fields), axis=1)
new_input_fields_names = ["x", "y", *new_input_fields_names]

Check warning on line 100 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L89-L100

Added lines #L89 - L100 were not covered by tests
else:
input_fields = vertices
new_input_fields_names = ["x", "y"]

Check warning on line 103 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L102-L103

Added lines #L102 - L103 were not covered by tests

input_fields_names = new_input_fields_names

Check warning on line 105 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L105

Added line #L105 was not covered by tests

output_fields = []
new_output_fields_names = []
for field_name in output_fields_names:
field = sample.get_field(field_name, base_name=base_name)
if field is None: continue
output_fields.append(field)
new_output_fields_names.append(field_name)
output_fields = np.vstack(output_fields).T

Check warning on line 114 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L107-L114

Added lines #L107 - L114 were not covered by tests

output_fields_names = new_output_fields_names

Check warning on line 116 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L116

Added line #L116 was not covered by tests

# torch tensor conversion
input_scalars = torch.tensor(input_scalars, dtype=torch.float32)
input_fields = torch.tensor(input_fields, dtype=torch.float32)

Check warning on line 120 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L119-L120

Added lines #L119 - L120 were not covered by tests

vertices = torch.tensor(vertices, dtype=torch.float32)
edge_weight = torch.tensor(edge_weight, dtype=torch.float32)
edge_index = torch.tensor(edge_index, dtype=torch.long)

Check warning on line 124 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L122-L124

Added lines #L122 - L124 were not covered by tests

# Extracting special nodal tags
nodal_tags = {}
for k, v in sample.get_nodal_tags(base_name=base_name).items():
nodal_tags["id_" + k] = torch.tensor(v, dtype=torch.long)

Check warning on line 129 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L127-L129

Added lines #L127 - L129 were not covered by tests

if None not in output_scalars and None not in output_fields:
output_scalars = torch.tensor(output_scalars, dtype=torch.float32)
output_fields = torch.tensor(output_fields, dtype=torch.float32)

Check warning on line 133 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L131-L133

Added lines #L131 - L133 were not covered by tests

data = Data(

Check warning on line 135 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L135

Added line #L135 was not covered by tests
pos = vertices,
x = input_fields,
y = output_fields,
x_scalars = input_scalars.reshape(1, -1),
y_scalars = output_scalars.reshape(1, -1),
x_fields_names=input_fields_names,
y_fields_names=output_fields_names,
x_scalars_names=input_scalars_names,
y_scalars_names=output_scalars_names,
edge_index = edge_index.T,
edge_weight = edge_weight,
**faces_dict,
**nodal_tags,
sample_id = sample_id
)

return data

Check warning on line 152 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L152

Added line #L152 was not covered by tests

data = Data(

Check warning on line 154 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L154

Added line #L154 was not covered by tests
pos = vertices,
x_scalars = input_scalars.reshape(1, -1),
x = input_fields,
x_fields_names=input_fields_names,
y_fields_names=output_fields_names,
x_scalars_names=input_scalars_names,
y_scalars_names=output_scalars_names,
edge_index = edge_index.T,
edge_weight = edge_weight,
**faces_dict,
**nodal_tags,
sample_id = sample_id
)

return data

Check warning on line 169 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L169

Added line #L169 was not covered by tests


class Loader():

Check warning on line 172 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L172

Added line #L172 was not covered by tests
"""Loader class to load a PLAID Dataset and convert it to Pytorch Geometric"""
def __init__(self,

Check warning on line 174 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L174

Added line #L174 was not covered by tests
dataset_name: str=None,
cache_dir: str=None):

self.dataset_name = dataset_name
self.cache_dir = cache_dir

Check warning on line 179 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L178-L179

Added lines #L178 - L179 were not covered by tests

def get_dataset_split(self,

Check warning on line 181 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L181

Added line #L181 was not covered by tests
plaid_dataset: PlaidDataset,
problem_definition: ProblemDefinition,
task_split: Union[list[str], str]):
if type(task_split)==list:
assert all(split in problem_definition._split.keys() for split in task_split), f"task_split {task_split} not in set of split keys {problem_definition._split.keys()}"
datasets = []
for split in task_split:
split_ids = problem_definition.get_split(split)
dataset = PlaidDataset()
dataset.set_samples(plaid_dataset.get_samples(ids=split_ids))
datasets.append(dataset)
return tuple(datasets)
assert task_split in problem_definition._split.keys(), f"task_split {task_split} not in set of split keys {problem_definition._split.keys()}"
ids = problem_definition.get_split(task_split)
dataset = PlaidDataset()
dataset.set_samples(plaid_dataset.get_samples(ids=ids))
return dataset

Check warning on line 198 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L185-L198

Added lines #L185 - L198 were not covered by tests


def load_plaid(self,

Check warning on line 201 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L201

Added line #L201 was not covered by tests
task_split: Union[List[str], str]=None) -> Union[PlaidDataset, Tuple[PlaidDataset]]:
hf_dataset = None
try:
hf_dataset = load_dataset(self.dataset_name, split="all_samples", cache_dir=self.cache_dir)
except Exception as e:
print(f"Error loading dataset from Hugging Face: {e}")
print(f"Please refer to the documentation (https://huggingface.co/PLAID-datasets) to first download the dataset with the command:")
print(f"load_dataset('PLAID-datasets/DATASET', split='all_samples', cache_dir='cache_dir')")
raise e

Check warning on line 210 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L203-L210

Added lines #L203 - L210 were not covered by tests

plaid_dataset, problem_definition = huggingface_dataset_to_plaid(hf_dataset)

Check warning on line 212 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L212

Added line #L212 was not covered by tests

return problem_definition, self.get_dataset_split(plaid_dataset, problem_definition, task_split)

Check warning on line 214 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L214

Added line #L214 was not covered by tests

def load(self,

Check warning on line 216 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L216

Added line #L216 was not covered by tests
task_split: Union[List[str], str]=None,
base_name: str=None,
processes_number: Annotated[int, ">=-1"]=1,
verbose=False) -> Tuple[ProblemDefinition, List[Data], ...]:
"""
Load and converts a plaid dataset to torch geometric format
"""
if processes_number == -1: processes_number = os.cpu_count()
if verbose: print(f"Number of processes: {processes_number}")

Check warning on line 225 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L224-L225

Added lines #L224 - L225 were not covered by tests

if type(task_split)==list:
problem_definition, dataset_list = self.load_plaid(task_split=task_split)
processed_list = []
for dataset in dataset_list:
processed_list.append(self.plaid_to_bridge(dataset, problem_definition=problem_definition, base_name=base_name, processes_number=processes_number, verbose=verbose))
bridged_dataset = tuple(processed_list)

Check warning on line 232 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L227-L232

Added lines #L227 - L232 were not covered by tests
else:
problem_definition, dataset = self.load_plaid(task_split=task_split)
bridged_dataset = [self.plaid_to_bridge(dataset, problem_definition=problem_definition, base_name=base_name, processes_number=processes_number, verbose=verbose)]

Check warning on line 235 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L234-L235

Added lines #L234 - L235 were not covered by tests

return problem_definition, *bridged_dataset

Check warning on line 237 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L237

Added line #L237 was not covered by tests

def plaid_to_bridge(self,

Check warning on line 239 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L239

Added line #L239 was not covered by tests
dataset: PlaidDataset,
problem_definition: ProblemDefinition,
base_name: str=None,
processes_number: Annotated[int, ">=-1"]=1,
verbose= True) -> List[Data]:
"""
Converts a Plaid dataset to PytorchGeometric dataset

Args:
dataset (plaid.containers.dataset.Dataset): Plaid dataset

Returns:
List[Data]: the converted dataset
"""
if verbose: print("in bridge")
data_list = []
sample_ids, samples = list(zip(*list(dataset.get_samples().items())))
if processes_number==0 or processes_number==1:
if verbose: iterator = tqdm(zip(samples, sample_ids), total=len(samples))
else: iterator = zip(samples, sample_ids)
for sample, sample_id in iterator:
new_data = sample_to_pyg(sample, sample_id, problem_definition, base_name=base_name)
data_list.append(new_data)
return data_list

Check warning on line 263 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L254-L263

Added lines #L254 - L263 were not covered by tests

with Pool(processes=processes_number) as p:
if verbose: iterator = tqdm(p.starmap(sample_to_pyg, zip(samples, sample_ids, [problem_definition]*len(samples), [base_name]*len(samples))), total=len(samples))
else: iterator = p.starmap(sample_to_pyg, zip(samples, sample_ids, [problem_definition]*len(samples)))
for new_data in iterator:
data_list.append(new_data)

Check warning on line 269 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L265-L269

Added lines #L265 - L269 were not covered by tests

return data_list

Check warning on line 271 in src/plaid/utils/dataset_loader.py

View check run for this annotation

Codecov / codecov/patch

src/plaid/utils/dataset_loader.py#L271

Added line #L271 was not covered by tests
Loading