from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
import pandas as pd
from pydantic import BaseModel, Field
from NEExT.collections import GraphCollection
[docs]
class GraphIO:
"""
Input/Output class for reading and writing graph data.
This class provides methods to read graph data from various file formats
and create a GraphCollection instance.
"""
[docs]
def __init__(self, logger=None):
"""Initialize GraphIO with optional logger."""
self.logger = logger
[docs]
def read_from_csv(
self,
edges_path: Union[str, Path],
node_graph_mapping_path: Union[str, Path],
graph_label_path: Optional[Union[str, Path]] = None,
node_features_path: Optional[Union[str, Path]] = None,
edge_features_path: Optional[Union[str, Path]] = None,
graph_type: str = "networkx",
reindex_nodes: bool = True,
filter_largest_component: bool = True,
node_sample_rate: float = 1.0
) -> GraphCollection:
"""
Read graph data from CSV files and create a GraphCollection.
Args:
edges_path: Path to edges CSV file (src_node_id, dest_node_id)
node_graph_mapping_path: Path to node-graph mapping CSV file (node_id, graph_id)
graph_label_path: Optional path to graph labels CSV file (graph_id, graph_label)
node_features_path: Optional path to node features CSV file
edge_features_path: Optional path to edge features CSV file
graph_type: Backend to use ("networkx" or "igraph"). Defaults to "networkx"
reindex_nodes: Whether to reindex nodes to start from 0 (default: True)
filter_largest_component: Whether to keep only the largest connected
component of each graph (default: True)
node_sample_rate: Rate at which to sample nodes from each graph (default: 1.0).
Must be between 0 and 1.
Returns:
GraphCollection: Collection of graphs created from the CSV data
"""
# Read required CSV files
edges_df = pd.read_csv(edges_path)
node_graph_df = pd.read_csv(node_graph_mapping_path)
# Validate required columns
if not {'src_node_id', 'dest_node_id'}.issubset(edges_df.columns):
raise ValueError("edges.csv must contain 'src_node_id' and 'dest_node_id' columns")
if not {'node_id', 'graph_id'}.issubset(node_graph_df.columns):
raise ValueError("node_graph_mapping.csv must contain 'node_id' and 'graph_id' columns")
# Read graph labels if provided
graph_labels_df = None
if graph_label_path:
graph_labels_df = pd.read_csv(graph_label_path)
if not {'graph_id', 'graph_label'}.issubset(graph_labels_df.columns):
raise ValueError("graph_labels.csv must contain 'graph_id' and 'graph_label' columns")
# Read optional feature files
node_features_df = None
if node_features_path:
node_features_df = pd.read_csv(node_features_path)
if 'node_id' not in node_features_df.columns:
raise ValueError("node_features.csv must contain 'node_id' column")
edge_features_df = None
if edge_features_path:
edge_features_df = pd.read_csv(edge_features_path)
if not {'src_node_id', 'dest_node_id'}.issubset(edge_features_df.columns):
raise ValueError("edge_features.csv must contain 'src_node_id' and 'dest_node_id' columns")
# Validate node_sample_rate
if not 0.0 < node_sample_rate <= 1.0:
raise ValueError("node_sample_rate must be between 0 and 1")
# Organize data by graph
graphs_data = self._organize_graph_data(
edges_df,
node_graph_df,
node_features_df,
edge_features_df,
graph_labels_df
)
# Create GraphCollection and add graphs
collection = GraphCollection(graph_type=graph_type, node_sample_rate=node_sample_rate)
collection.add_graphs(
graph_data_list=graphs_data,
graph_type=graph_type,
reindex_nodes=reindex_nodes,
filter_largest_component=filter_largest_component,
node_sample_rate=node_sample_rate
)
return collection
[docs]
def load_from_dfs(
self,
edges_df: pd.DataFrame,
node_graph_df: pd.DataFrame,
graph_labels_df: Optional[pd.DataFrame] = None,
node_features_df: Optional[pd.DataFrame] = None,
edge_features_df: Optional[pd.DataFrame] = None,
graph_type: str = "networkx",
reindex_nodes: bool = True,
filter_largest_component: bool = True,
node_sample_rate: float = 1.0,
) -> GraphCollection:
# Validate required columns
if not {"src_node_id", "dest_node_id"}.issubset(edges_df.columns):
raise ValueError("edges_df must contain 'src_node_id' and 'dest_node_id' columns")
if not {"node_id", "graph_id"}.issubset(node_graph_df.columns):
raise ValueError("node_graph_df must contain 'node_id' and 'graph_id' columns")
if graph_labels_df is not None:
if not {"graph_id", "graph_label"}.issubset(graph_labels_df.columns):
raise ValueError("graph_labels_df must contain 'graph_id' and 'graph_label' columns")
# Read optional feature files
if node_features_df is not None:
if "node_id" not in node_features_df.columns:
raise ValueError("node_features_df must contain 'node_id' column")
if edge_features_df is not None:
if not {"src_node_id", "dest_node_id"}.issubset(edge_features_df.columns):
raise ValueError("edge_features_df must contain 'src_node_id' and 'dest_node_id' columns")
# Validate node_sample_rate
if not 0.0 < node_sample_rate <= 1.0:
raise ValueError("node_sample_rate must be between 0 and 1")
# Organize data by graph
graphs_data = self._organize_graph_data(
edges_df,
node_graph_df,
node_features_df,
edge_features_df,
graph_labels_df,
)
# Create GraphCollection and add graphs
collection = GraphCollection(graph_type=graph_type, node_sample_rate=node_sample_rate)
collection.add_graphs(
graph_data_list=graphs_data,
graph_type=graph_type,
reindex_nodes=reindex_nodes,
filter_largest_component=filter_largest_component,
node_sample_rate=node_sample_rate,
)
return collection
[docs]
def _organize_graph_data(
self,
edges_df: pd.DataFrame,
node_graph_df: pd.DataFrame,
node_features_df: Optional[pd.DataFrame],
edge_features_df: Optional[pd.DataFrame],
graph_labels_df: Optional[pd.DataFrame]
) -> List[Dict]:
"""
Organizes the data from DataFrames into a list of graph dictionaries.
Args:
edges_df (pd.DataFrame): DataFrame containing edge information
node_graph_df (pd.DataFrame): DataFrame containing node-to-graph mapping
node_features_df (Optional[pd.DataFrame]): DataFrame containing node features
edge_features_df (Optional[pd.DataFrame]): DataFrame containing edge features
graph_labels_df (Optional[pd.DataFrame]): DataFrame containing graph labels
Returns:
List[Dict]: List of dictionaries containing organized graph data
"""
# Group nodes by graph_id
graph_nodes = defaultdict(list)
for _, row in node_graph_df.iterrows():
graph_nodes[row['graph_id']].append(row['node_id'])
# Create graph labels dictionary if available
graph_labels = {}
if graph_labels_df is not None:
graph_labels = dict(zip(graph_labels_df['graph_id'], graph_labels_df['graph_label']))
# Create graph data dictionaries
graphs_data = []
for graph_id, nodes in graph_nodes.items():
# Get edges for this graph
graph_edges = edges_df[
(edges_df['src_node_id'].isin(nodes)) &
(edges_df['dest_node_id'].isin(nodes))
]
edges = list(zip(graph_edges['src_node_id'], graph_edges['dest_node_id']))
# Initialize graph data
graph_data = {
"graph_id": graph_id,
"graph_label": graph_labels.get(graph_id),
"nodes": nodes,
"edges": edges,
"node_attributes": {},
"edge_attributes": {}
}
# Add node features if available
if node_features_df is not None:
node_features = node_features_df[node_features_df['node_id'].isin(nodes)]
feature_cols = [col for col in node_features.columns if col != 'node_id']
for _, row in node_features.iterrows():
graph_data["node_attributes"][row['node_id']] = {
col: row[col] for col in feature_cols
}
# Add edge features if available
if edge_features_df is not None:
edge_features = edge_features_df[
(edge_features_df['src_node_id'].isin(nodes)) &
(edge_features_df['dest_node_id'].isin(nodes))
]
feature_cols = [col for col in edge_features.columns
if col not in ['src_node_id', 'dest_node_id']]
for _, row in edge_features.iterrows():
edge_key = (row['src_node_id'], row['dest_node_id'])
graph_data["edge_attributes"][edge_key] = {
col: row[col] for col in feature_cols
}
graphs_data.append(graph_data)
return graphs_data