Source code for spikeometric.stimulus.loaded_stimulus

import torch
import torch.nn as nn
import math
from spikeometric.stimulus.base_stimulus import BaseStimulus
from typing import Union

class LoadedStimulus(BaseStimulus):
    r"""
    Stimulus loaded from a file.

    The file should be a .pt file containing a :obj:`torch.Tensor` of shape
    (n_neurons, n_steps, n_networks), where n_neurons is the number of neurons
    in each network, n_steps is the number of time steps in the stimulus and
    n_networks is the number of networks the stimulus was recorded for.
    A 2-dimensional tensor of shape (n_neurons, n_steps) is treated as the
    stimulus of a single network. A 4-dimensional tensor is also accepted, in
    which case the networks are read from the last axis.

    The networks are stacked along the neuron axis and served in batches of
    ``batch_size`` networks. At each time step, the stimulus of the current
    batch is a :obj:`torch.Tensor` with one entry per neuron of every network
    in that batch.

    Example:
        >>> from spikeometric.stimulus import LoadedStimulus
        >>> stimulus = LoadedStimulus("example_directory/stimulus_file.pt", batch_size=3)
        >>> stimulus(0).shape
        torch.Size([60])
        >>> stimulus.set_batch(3)
        >>> stimulus(0).shape
        torch.Size([20])
        >>> stimulus.reset()
        >>> model.add_stimulus(stimulus)

    Parameters
    ----------
    path : str
        Path to the file containing the stimulus.
    batch_size : int, optional
        Number of networks in each batch. If the number of networks in the
        stimulus is not a multiple of the batch size, the last batch will
        contain fewer networks. Default: 1.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1, if the file does not contain a
        tensor with 2 to 4 dimensions, or if the stimulus holds fewer networks
        than ``batch_size``.
    """

    def __init__(self, path: str, batch_size: int = 1):
        super().__init__()
        if batch_size < 1:
            # Guards the division below and the batch-size arithmetic.
            raise ValueError("The batch size must be a positive integer.")

        # NOTE(security): torch.load unpickles the file, so only stimulus files
        # from trusted sources should be loaded.
        stimuli = torch.load(path)
        if not torch.is_tensor(stimuli) or not 2 <= stimuli.dim() <= 4:
            raise ValueError(
                "The stimulus file must contain a torch.Tensor with 2 to 4 dimensions."
            )

        n_neurons, n_steps = stimuli.shape[0], stimuli.shape[1]
        # The networks live on the last axis; a 2D tensor is a single network.
        n_networks = 1 if stimuli.dim() == 2 else stimuli.shape[-1]
        if n_networks < batch_size:
            raise ValueError("The number of networks in the stimulus is smaller than the batch size.")
        n_batches = math.ceil(n_networks / batch_size)

        if stimuli.dim() > 2:
            # Stack the networks on top of each other along the neuron axis, so
            # that a batch of networks is a contiguous block of rows.
            # NOTE(review): for 4D input n_steps is read from axis 1, while
            # __call__ indexes the last axis left after this merge (axis 2 of
            # the input) -- confirm the intended 4D layout.
            stimuli = torch.cat(torch.unbind(stimuli, dim=-1), dim=0)

        # Every batch holds batch_size networks, except possibly the last one,
        # which holds whatever remains.
        leftover = n_networks % batch_size
        networks_in_last_batch = leftover if leftover else batch_size
        neurons_per_batch = [n_neurons * batch_size] * (n_batches - 1)
        neurons_per_batch.append(n_neurons * networks_in_last_batch)

        self.register_buffer("n_neurons", torch.tensor(n_neurons, dtype=torch.int))
        self.register_buffer("n_steps", torch.tensor(n_steps, dtype=torch.int))
        self.register_buffer("n_networks", torch.tensor(n_networks, dtype=torch.int))
        self.register_buffer("batch_size", torch.tensor(batch_size, dtype=torch.int))
        self.register_buffer("n_batches", torch.tensor(n_batches, dtype=torch.int))
        self.register_buffer("neurons_per_batch", torch.tensor(neurons_per_batch, dtype=torch.int))
        self.register_buffer("stimuli", stimuli)

        # Index of the batch served by `stimulus`; starts at the first batch.
        self._idx = 0

    @property
    def stimulus(self):
        """Rows of the loaded stimulus that belong to the current batch."""
        batches = torch.split(self.stimuli, self.neurons_per_batch.tolist(), dim=0)
        return batches[self._idx]

    def __call__(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        r"""
        Returns the stimulus of the current batch at time :math:`t`.

        If :math:`t` is a scalar, returns the stimulus at time step :math:`t`,
        or zeros if :math:`t` lies outside the loaded stimulus.
        If :math:`t` is a :obj:`torch.Tensor` with at least one dimension, it
        is read as a flat sequence of time steps and the stimulus at each of
        them is returned, one column per entry of :math:`t`. Columns whose
        time step lies outside the loaded stimulus are filled with zeros.

        Parameters
        ----------
        t : torch.Tensor or float
            Time :math:`t` at which to compute the stimulus. It is used as an
            index into the time axis of the loaded stimulus.

        Returns
        -------
        torch.Tensor [n_neurons, t.shape[0]] or [n_neurons]
            Stimulus at time :math:`t`, where n_neurons counts the neurons of
            all networks in the current batch.
        """
        block = self.stimulus

        if torch.is_tensor(t) and t.dim() != 0:
            steps = t.reshape(-1).to(block.device)
            # Time is the last axis of the block; entries of t outside of it
            # keep their zero column, mirroring the scalar case below.
            in_range = (steps >= 0) & (steps < block.shape[-1])
            result = block.new_zeros((*block.shape[:-1], steps.numel()))
            result[..., in_range] = block[..., steps[in_range].long()]
            return result

        if 0 <= t < self.n_steps:
            return block[..., t]
        return torch.zeros(self.neurons_per_batch[self._idx], dtype=torch.float, device=block.device)