ConnectivityDataset

class spikeometric.datasets.ConnectivityDataset(root)

Bases: object

A dataset of connectivity matrices for networks of neurons.

The connectivity matrices are loaded from a directory of .npy or .pt files, each containing the square connectivity matrix of one network. Each matrix is converted to a torch_geometric Data object with the attributes edge_index (the connected node pairs), W0 (one weight per edge) and num_nodes.
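To show how a dense matrix maps onto these three attributes, here is a minimal NumPy sketch. The helper `matrix_to_edge_list` is hypothetical, it returns plain arrays rather than a torch_geometric Data object, and it assumes that a nonzero entry W[i, j] is an edge from node i (source) to node j (target); the library's own direction convention may differ.

```python
import numpy as np

def matrix_to_edge_list(W):
    """Hypothetical helper: square connectivity matrix -> (edge_index, W0, num_nodes).

    Assumption: W[i, j] != 0 means an edge from node i (source) to node j (target).
    """
    W = np.asarray(W)
    if W.ndim != 2 or W.shape[0] != W.shape[1]:
        raise ValueError("connectivity matrix must be square")
    sources, targets = np.nonzero(W)           # nonzero entries in row-major order
    edge_index = np.stack([sources, targets])  # shape (2, num_edges)
    W0 = W[sources, targets]                   # one weight per edge, shape (num_edges,)
    return edge_index, W0, W.shape[0]

W = np.array([[0.0, 0.5, 0.0],
              [-1.0, 0.0, 0.0],
              [0.0, 0.2, 0.0]])
edge_index, W0, num_nodes = matrix_to_edge_list(W)
print(edge_index.tolist())  # [[0, 1, 2], [1, 0, 1]]
print(W0.tolist())          # [0.5, -1.0, 0.2]
print(num_nodes)            # 3
```

Only nonzero entries become edges, which is why W0 has the same length as the second dimension of edge_index in the example below.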

Using torch_geometric’s DataLoader, several networks can be batched into a single graph in which each network is an isolated subgraph, so no edge connects nodes belonging to different networks.
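The batching idea can be sketched in plain NumPy: node indices of each network are shifted by the number of nodes in the networks before it, and the edge lists and weights are concatenated. The function `collate` and its tuple input format are hypothetical; torch_geometric performs this step itself (via `Batch.from_data_list`) with additional bookkeeping. The `batch` vector records which network each node belongs to, and `ptr` holds the node offset of each network plus the total, so it has one more entry than there are networks.

```python
import numpy as np

def collate(graphs):
    """Hypothetical sketch of block-diagonal batching.

    Each graph is a tuple (edge_index, W0, num_nodes). Node indices of graph k
    are shifted by the total node count of graphs 0..k-1, so every network
    stays an isolated subgraph of the combined graph.
    """
    counts = [n for _, _, n in graphs]
    ptr = np.concatenate([[0], np.cumsum(counts)])  # offsets of each graph, then the total
    edge_index = np.concatenate(
        [ei + offset for (ei, _, _), offset in zip(graphs, ptr[:-1])], axis=1
    )
    W0 = np.concatenate([w for _, w, _ in graphs])
    batch = np.repeat(np.arange(len(graphs)), counts)  # graph id of every node
    return edge_index, W0, int(ptr[-1]), batch, ptr

g1 = (np.array([[0, 1], [1, 0]]), np.array([0.5, -1.0]), 2)
g2 = (np.array([[0], [2]]), np.array([0.2]), 3)
edge_index, W0, num_nodes, batch, ptr = collate([g1, g2])
print(edge_index.tolist())  # [[0, 1, 2], [1, 0, 4]]
print(batch.tolist())       # [0, 0, 1, 1, 1]
print(ptr.tolist())         # [0, 2, 5]
```

With five 100-node networks per batch, this bookkeeping gives num_nodes=500, a batch vector of length 500 and a ptr of length 6, matching the shapes printed in the example below.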

Example

>>> from spikeometric.datasets import ConnectivityDataset
>>> from torch_geometric.loader import DataLoader
>>> dataset = ConnectivityDataset("datasets/example_dataset")
>>> len(dataset)
10
>>> data = dataset[0]
>>> data
Data(edge_index=[2, 5042], W0=[5042], num_nodes=100)
>>> loader = DataLoader(dataset, batch_size=5)
>>> for batch in loader:
...     print(batch)
...
DataBatch(edge_index=[2, 25242], W0=[25242], num_nodes=500, batch=[500], ptr=[6])
DataBatch(edge_index=[2, 25250], W0=[25250], num_nodes=500, batch=[500], ptr=[6])

Parameters

root (str) – Root directory containing the .npy or .pt connectivity matrix files.

process()

Processes the connectivity matrices in the root directory and returns a list of torch_geometric Data objects.
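A rough sketch of what such a loading step involves is shown below. The function `load_matrices` is hypothetical and is not the library's implementation: it assumes files are read in sorted filename order, that .npy files hold a NumPy array and .pt files hold a CPU torch tensor, and it returns plain (edge_index, W0, num_nodes) tuples instead of Data objects. The torch import happens only when a .pt file is encountered.

```python
from pathlib import Path
import tempfile

import numpy as np

def load_matrices(root):
    """Hypothetical sketch of a process()-like loading step.

    Assumptions: sorted filename order; .npy holds a NumPy array, .pt holds a
    CPU torch tensor; W[i, j] != 0 is an edge from node i to node j.
    """
    graphs = []
    for path in sorted(Path(root).iterdir()):
        if path.suffix == ".npy":
            W = np.load(path)
        elif path.suffix == ".pt":
            import torch  # only needed when .pt files are present
            W = torch.load(path).cpu().numpy()
        else:
            continue  # skip files that are not connectivity matrices
        if W.ndim != 2 or W.shape[0] != W.shape[1]:
            raise ValueError(f"{path.name}: connectivity matrix must be square")
        sources, targets = np.nonzero(W)  # keep only existing connections
        graphs.append((np.stack([sources, targets]), W[sources, targets], W.shape[0]))
    return graphs

# Usage: write two small matrices to a temporary directory and load them back.
with tempfile.TemporaryDirectory() as root:
    np.save(Path(root) / "net_0.npy", np.array([[0.0, 1.5], [0.0, 0.0]]))
    np.save(Path(root) / "net_1.npy", np.array([[0.0, 2.0, 0.0],
                                                [0.0, 0.0, 3.0],
                                                [1.0, 0.0, 0.0]]))
    graphs = load_matrices(root)

print(len(graphs))             # 2
print(graphs[1][0].tolist())   # [[0, 1, 2], [1, 2, 0]]
```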

combine_all()

Combines all the Data objects into a single Data object.