ConnectivityDataset
- class spikeometric.datasets.ConnectivityDataset(root)
Bases: object
A dataset of connectivity matrices for networks of neurons.
The connectivity matrices are loaded from a directory of .npy or .pt files. Each file should contain a square connectivity matrix for a network of neurons. The connectivity matrices are converted to torch_geometric Data objects with edge_index, W0 and num_nodes attributes.
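As a rough illustration of the conversion described above, the sketch below turns a dense square matrix into an `edge_index` / `W0` / `num_nodes` triple. It is a minimal sketch under stated assumptions, not the library's implementation. It uses NumPy instead of torch tensors so that it stays self-contained, the helper name `matrix_to_graph` is hypothetical, and it assumes that only nonzero entries become edges and that `W[i, j]` is read as an edge from node `i` to node `j`. The first assumption fits the example output, where a 100-node network has `W0=[5042]` rather than one weight for every possible pair.

```python
import numpy as np


def matrix_to_graph(W):
    """Hypothetical helper: dense square matrix -> (edge_index, W0, num_nodes).

    Assumptions: only nonzero entries are kept as edges, and W[i, j] is
    interpreted as a connection from node i to node j.
    """
    W = np.asarray(W)
    if W.ndim != 2 or W.shape[0] != W.shape[1]:
        raise ValueError("connectivity matrix must be square")
    # Row and column indices of all nonzero entries, in row-major order
    rows, cols = np.nonzero(W)
    edge_index = np.stack([rows, cols])  # shape [2, num_edges]
    W0 = W[rows, cols]                   # one weight per edge, same order
    return edge_index, W0, W.shape[0]


W = np.array([[0.0, 0.5, 0.0],
              [-1.0, 0.0, 2.0],
              [0.0, 0.0, 0.0]])
edge_index, W0, num_nodes = matrix_to_graph(W)
print(edge_index.tolist())  # [[0, 1, 1], [1, 0, 2]]
print(W0.tolist())          # [0.5, -1.0, 2.0]
print(num_nodes)            # 3
```

Storing only the nonzero entries keeps memory proportional to the number of connections rather than to the square of the number of neurons. This matters for sparse networks.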
With torch_geometric's DataLoader, several connectivity matrices can be batched into a single graph in which each network is an isolated (disconnected) subgraph.
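The batching idea can be sketched in plain NumPy. The sketch shifts each network's node indices by the number of nodes that come before it, so no edge crosses between networks. It also records a `batch` vector that maps each node to its network and a `ptr` array that holds each network's start offset. This is a simplified sketch of the principle that torch_geometric applies when it collates `Data` objects, not its actual code. The helper name `collate_graphs` is hypothetical.

```python
import numpy as np


def collate_graphs(graphs):
    """Hypothetical helper: merge (edge_index, W0, num_nodes) triples into one graph.

    Each graph's node indices are shifted by the total number of nodes that
    precede it, so the result is a set of disconnected subgraphs.
    Expects a non-empty list.
    """
    edge_indices, weights, batch = [], [], []
    ptr = [0]
    for graph_id, (edge_index, W0, num_nodes) in enumerate(graphs):
        offset = ptr[-1]
        edge_indices.append(edge_index + offset)    # move into this graph's node range
        weights.append(W0)
        batch.append(np.full(num_nodes, graph_id))  # node -> graph assignment
        ptr.append(offset + num_nodes)              # start index of the next graph
    return (np.concatenate(edge_indices, axis=1),
            np.concatenate(weights),
            np.concatenate(batch),
            np.array(ptr))


g = (np.array([[0, 1], [1, 0]]), np.array([0.5, -1.0]), 2)
edge_index, W0, batch, ptr = collate_graphs([g, g, g])
print(edge_index.tolist())  # [[0, 1, 2, 3, 4, 5], [1, 0, 3, 2, 5, 4]]
print(batch.tolist())       # [0, 0, 1, 1, 2, 2]
print(ptr.tolist())         # [0, 2, 4, 6]
```

By the same arithmetic, merging five 100-node networks gives 500 nodes, a `batch` vector of length 500 and a `ptr` array of length 6. These match the `batch=[500]` and `ptr=[6]` fields in the DataBatch output of the example.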
Example
>>> from spikeometric.datasets import ConnectivityDataset
>>> from torch_geometric.loader import DataLoader
>>> dataset = ConnectivityDataset("datasets/example_dataset")
>>> len(dataset)
10
>>> data = dataset[0]
>>> data
Data(edge_index=[2, 5042], W0=[5042], num_nodes=100)
>>> loader = DataLoader(dataset, batch_size=5)
>>> for batch in loader:
...     print(batch)
...
DataBatch(edge_index=[2, 25242], W0=[25242], num_nodes=500, batch=[500], ptr=[6])
DataBatch(edge_index=[2, 25250], W0=[25250], num_nodes=500, batch=[500], ptr=[6])
- Parameters
root (string) – Root directory containing the .npy or .pt files from which the connectivity matrices are loaded.
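To try the class on synthetic data, the `root` directory first needs to be populated with square connectivity matrices. The sketch below writes random matrices as `.npy` files. The helper `write_example_dataset`, the file naming scheme, the connection probability, the Gaussian weights and the zeroed diagonal are all illustrative assumptions, not requirements documented for ConnectivityDataset.

```python
import os
import tempfile

import numpy as np


def write_example_dataset(root, n_networks=10, n_neurons=100, p=0.5, seed=0):
    """Hypothetical helper: save random square connectivity matrices as .npy files.

    File names, sparsity p, weight distribution and the empty diagonal are
    illustrative assumptions only.
    """
    rng = np.random.default_rng(seed)
    os.makedirs(root, exist_ok=True)
    paths = []
    for k in range(n_networks):
        # Keep each possible connection with probability p, with Gaussian weights
        mask = rng.random((n_neurons, n_neurons)) < p
        W = np.where(mask, rng.normal(0.0, 1.0, (n_neurons, n_neurons)), 0.0)
        np.fill_diagonal(W, 0.0)  # no self-connections (assumption)
        path = os.path.join(root, f"network_{k}.npy")
        np.save(path, W)
        paths.append(path)
    return paths


with tempfile.TemporaryDirectory() as root:
    paths = write_example_dataset(root, n_networks=3, n_neurons=20)
    W = np.load(paths[0])
    print(len(paths), W.shape)  # 3 (20, 20)
```

With spikeometric installed, a directory prepared this way could then be passed as `root` to ConnectivityDataset, in the same way as the example directory used above.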