import logging
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from NEExT.collections import GraphCollection
from NEExT.embeddings import Embeddings
from NEExT.embeddings import GraphEmbeddings
from NEExT.features import Features
from NEExT.features import StructuralNodeFeatures
from NEExT.ml_models import FeatureImportance
from .io import GraphIO
[docs]
class NEExT:
"""
Main interface class for the NEExT framework.
This class maintains the state of various components and provides
a unified interface for users to interact with the framework.
Attributes:
logger: Logger instance for the framework
"""
[docs]
def __init__(self, log_level: str = "INFO"):
"""
Initialize the NEExT framework.
Args:
log_level: Initial logging level (default: "INFO")
"""
# Initialize logger
self.logger = logging.getLogger("NEExT")
self.logger.setLevel(logging.INFO)
# Create console handler with formatting
console_handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
# Initialize components
self.graph_io = GraphIO(logger=self.logger)
self.logger.info("NEExT framework initialized")
[docs]
def set_log_level(self, level: str) -> None:
"""
Set the logging level for the framework.
Args:
level: Logging level ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
"""
level_map = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL
}
if level.upper() not in level_map:
self.logger.error(f"Invalid log level: {level}")
raise ValueError(f"Invalid log level. Choose from: {', '.join(level_map.keys())}")
log_level = level_map[level.upper()]
self.logger.setLevel(log_level)
self.logger.info(f"Log level set to: {level}")
[docs]
def read_from_csv(
self,
edges_path: Union[str, Path],
node_graph_mapping_path: Union[str, Path],
graph_label_path: Optional[Union[str, Path]] = None,
node_features_path: Optional[Union[str, Path]] = None,
edge_features_path: Optional[Union[str, Path]] = None,
graph_type: str = "networkx",
reindex_nodes: bool = True,
filter_largest_component: bool = True,
node_sample_rate: float = 1.0
) -> GraphCollection:
"""
Read graph data from CSV files and return a graph collection.
Args:
edges_path: Path to edges CSV file (src_node_id, dest_node_id)
node_graph_mapping_path: Path to node-graph mapping CSV file (node_id, graph_id)
graph_label_path: Optional path to graph labels CSV file (graph_id, graph_label)
node_features_path: Optional path to node features CSV file
edge_features_path: Optional path to edge features CSV file
graph_type: Backend to use ("networkx" or "igraph"). Defaults to "networkx"
reindex_nodes: Whether to reindex nodes to start from 0 (default: True)
filter_largest_component: Whether to keep only the largest connected
component of each graph (default: True)
node_sample_rate: Rate at which to sample nodes from each graph (default: 1.0).
Must be between 0 and 1.
Returns:
GraphCollection: Collection of graphs loaded from CSV files
"""
self.logger.info("Reading graph data from CSV files")
self.logger.debug(f"Edges path: {edges_path}")
self.logger.debug(f"Node-graph mapping path: {node_graph_mapping_path}")
self.logger.debug(f"Reindex nodes: {reindex_nodes}")
self.logger.debug(f"Filter largest component: {filter_largest_component}")
self.logger.debug(f"Node sample rate: {node_sample_rate}")
try:
graph_collection = self.graph_io.read_from_csv(
edges_path=edges_path,
node_graph_mapping_path=node_graph_mapping_path,
graph_label_path=graph_label_path,
node_features_path=node_features_path,
edge_features_path=edge_features_path,
graph_type=graph_type,
reindex_nodes=reindex_nodes,
filter_largest_component=filter_largest_component,
node_sample_rate=node_sample_rate
)
self.logger.info("Successfully loaded graph collection")
self.logger.debug(f"Loaded {len(graph_collection.graphs)} graphs")
return graph_collection
except Exception as e:
self.logger.error(f"Failed to read CSV files: {str(e)}")
raise
[docs]
def get_collection_info(self, graph_collection: GraphCollection) -> dict:
"""
Get basic information about a graph collection.
This method is deprecated. Use graph_collection.describe() instead.
Args:
graph_collection: The graph collection to get information about
Returns:
dict: Dictionary containing collection information
"""
self.logger.warning("get_collection_info is deprecated. Use graph_collection.describe() instead.")
info = graph_collection.describe()
self.logger.debug(f"Collection info: {info}")
return info
[docs]
def compute_node_features(
self,
graph_collection: GraphCollection,
feature_list: List[str],
feature_vector_length: int = 3,
normalize_features: bool = True,
show_progress: bool = True,
n_jobs:int = -1,
my_feature_methods: list = None,
) -> pd.DataFrame:
"""
Compute node features for all graphs in the collection.
Args:
graph_collection: Collection of graphs to compute features for
feature_list: List of features to compute (e.g., ["page_rank", "degree_centrality"])
feature_vector_length: Length of feature vector for each node (default: 3)
normalize_features: Whether to normalize features across all nodes (default: True)
show_progress: Whether to show progress bars during computation (default: True)
Returns:
pd.DataFrame: DataFrame containing computed features for all nodes
"""
self.logger.info(f"Computing node features: {feature_list}")
node_features = StructuralNodeFeatures(
graph_collection=graph_collection,
feature_list=feature_list,
feature_vector_length=feature_vector_length,
normalize_features=normalize_features,
show_progress=show_progress,
n_jobs=n_jobs
)
if my_feature_methods:
for entry in my_feature_methods:
node_features.register_metric(entry["feature_name"], entry["feature_function"])
features = node_features.compute()
self.logger.info(f"Computed features for {len(features.features_df)} nodes")
return features
[docs]
def compute_graph_embeddings(
self,
graph_collection: GraphCollection,
features: Features,
embedding_algorithm: str,
embedding_dimension: int,
feature_columns: Optional[List[str]] = None,
random_state: int = 42,
memory_size: str = "4G"
) -> Embeddings:
"""
Compute graph embeddings based on node features.
Args:
graph_collection: Collection of graphs to compute embeddings for
features: Features object containing node features
embedding_algorithm: Algorithm to use for embedding computation
embedding_dimension: Dimension of the output embeddings
feature_columns: Specific feature columns to use (default: all)
random_state: Random seed for reproducibility
memory_size: Memory limit for algorithms that support it
Returns:
Embeddings: Embeddings object containing computed embeddings
"""
self.logger.info(f"Computing graph embeddings using {embedding_algorithm}")
graph_embeddings = GraphEmbeddings(
graph_collection=graph_collection,
features=features,
embedding_algorithm=embedding_algorithm,
embedding_dimension=embedding_dimension,
feature_columns=feature_columns,
random_state=random_state,
memory_size=memory_size
)
embeddings = graph_embeddings.compute()
self.logger.info(f"Computed embeddings for {len(embeddings.embeddings_df)} graphs")
return embeddings
[docs]
def train_ml_model(
self,
graph_collection: GraphCollection,
embeddings: Embeddings,
model_type: Literal["classifier", "regressor"],
balance_dataset: bool = False,
sample_size: int = 5,
n_jobs: int = -1,
parallel_backend: str = "process"
) -> Dict:
"""
Train and evaluate a machine learning model using graph embeddings.
Args:
graph_collection: Collection of graphs with labels
embeddings: Embeddings object containing graph embeddings
model_type: Type of model to train ("classifier" or "regressor")
balance_dataset: Whether to balance the dataset for classification (default: False)
sample_size: Number of training/testing iterations (default: 5)
n_jobs: Number of parallel jobs (-1 for all CPUs)
parallel_backend: Parallelization backend ("process" or "thread")
Returns:
Dict: Dictionary containing model information and evaluation metrics
"""
self.logger.info(f"Training {model_type} model on graph embeddings")
from .ml_models import MLModels
ml_models = MLModels(
graph_collection=graph_collection,
embeddings=embeddings,
model_type=model_type,
balance_dataset=balance_dataset,
sample_size=sample_size,
n_jobs=n_jobs,
parallel_backend=parallel_backend
)
results = ml_models.compute()
if model_type == "classifier":
self.logger.info(f"Model trained with average accuracy: {np.mean(results['accuracy']):.4f}")
else:
self.logger.info(f"Model trained with average RMSE: {np.mean(results['rmse']):.4f}")
return results
[docs]
def compute_feature_importance(
self,
graph_collection: GraphCollection,
features: Features,
feature_importance_algorithm: str,
embedding_algorithm: str = "approx_wasserstein",
random_state: int = 42,
n_iterations: int = 5
) -> pd.DataFrame:
"""
Compute feature importance for graph embeddings.
Args:
graph_collection: Collection of graphs to analyze
features: Features object containing node features
feature_importance_algorithm: Algorithm to use for importance analysis
("supervised_greedy", "supervised_fast", "unsupervised")
embedding_algorithm: Algorithm to use for embedding computation
random_state: Random seed for reproducibility
n_iterations: Number of iterations for computing average performance
Returns:
pd.DataFrame: DataFrame containing feature importance results
"""
self.logger.info(f"Computing feature importance using {feature_importance_algorithm}")
feature_importance = FeatureImportance(
graph_collection=graph_collection,
features=features,
algorithm=feature_importance_algorithm,
embedding_algorithm=embedding_algorithm,
random_state=random_state,
n_iterations=n_iterations
)
results_df = feature_importance.compute()
self.logger.info("Feature importance analysis completed")
return results_df