# Source code for spikeometric.stimulus.poisson_stimulus

import torch
import torch.nn as nn
from spikeometric.stimulus.base_stimulus import BaseStimulus
import math
from typing import Union

class PoissonStimulus(BaseStimulus):
    r"""
    Poisson stimulus of neurons.

    The stimulus times are modeled as a Poisson process with mean interval
    :math:`\lambda` and total duration :math:`T`. The stimulus events are of
    duration :math:`\tau` and strength :math:`s`.

    Parameters
    ----------
    strength : float
        Strength of the stimulus :math:`s`
    mean_interval : int
        Mean interval :math:`\lambda` between stimulus events
    duration : int
        Total duration :math:`T` of the stimulus.
    stimulus_masks : torch.Tensor[bool]
        A mask of shape (n_neurons,) indicating which neurons to stimulate.
        A 1-D tensor is wrapped in a single-element list and a 2-D tensor is
        split along dim 0 into one mask per row; any other value is handed to
        ``batch_stimulus_masks`` unchanged.
    batch_size : int
        Number of networks to stimulate in parallel.
    tau : int
        Duration of stimulus events
    dt : float
        Time step of the simulation in ms.
    start : float
        Start time of the first stimulus event. (ms)
        Stored in an integer buffer, so a fractional part is not kept.
    rng : torch.Generator
        Random number generator. A fresh ``torch.Generator()`` is created
        when ``None`` is given.

    Raises
    ------
    ValueError
        If ``tau`` or ``mean_interval`` is negative.
    """

    def __init__(self, strength: float, mean_interval: int, duration: int, stimulus_masks: torch.Tensor, batch_size: int = 1, tau: int = 1, dt: float = 1, start: float = 0, rng=None):
        super().__init__()
        # NOTE(review): both checks let zero through although the messages say
        # "positive" -- confirm whether zero is meant to be a legal value.
        if tau < 0:
            raise ValueError("Stimulus length must be positive.")
        if mean_interval < 0:
            raise ValueError("Intervals must be positive.")

        # Buffers
        self.register_buffer("dt", torch.tensor(dt, dtype=torch.float))
        self.register_buffer("tau", torch.tensor(tau, dtype=torch.int))
        self.register_buffer("mean_interval", torch.tensor(mean_interval, dtype=torch.int))
        self.register_buffer("duration", torch.tensor(duration, dtype=torch.int))
        self.register_buffer("start", torch.tensor(start, dtype=torch.int))

        # Normalise tensor input to a list of per-network masks.
        if isinstance(stimulus_masks, torch.Tensor) and stimulus_masks.ndim == 1:
            stimulus_masks = [stimulus_masks]
        if isinstance(stimulus_masks, torch.Tensor) and stimulus_masks.ndim == 2:
            stimulus_masks = [sm.squeeze() for sm in torch.split(stimulus_masks, 1, dim=0)]

        # batch_stimulus_masks is not defined in this file (presumably
        # inherited from BaseStimulus); it returns the concatenated masks and
        # the split points that are stored below.
        conc_stimulus_masks, split_points = self.batch_stimulus_masks(stimulus_masks, batch_size)
        self.register_buffer("split_points", torch.tensor(split_points, dtype=torch.int))
        self.register_buffer("conc_stimulus_masks", conc_stimulus_masks)

        self.n_batches = math.ceil(len(stimulus_masks) / batch_size)
        self._idx = 0  # index of the mask batch used by __call__

        self.register_parameter("strength", nn.Parameter(torch.tensor(strength, dtype=torch.float)))

        self._rng = rng if rng is not None else torch.Generator()

        # Freeze everything registered so far (i.e. ``strength``).
        self.requires_grad_(False)

        # The event times are drawn once, at construction time.
        stimulus_times = self._generate_stimulus_plan(self.mean_interval, self.duration)
        self.register_buffer("stimulus_times", stimulus_times)

    def _generate_stimulus_plan(self, mean_interval, duration):
        """Generate poisson stimulus times.

        Parameters
        ----------
        mean_interval : torch.Tensor
            Mean interval between stimuli.
        duration : torch.Tensor
            Duration of the stimulus.

        Returns
        -------
        torch.Tensor
            Stimulus times with intervals drawn from a poisson distribution.
            Times are accumulated from zero, those not below ``duration``
            are dropped, and ``self.start`` is added to the remainder.
        """
        # ``duration`` intervals are drawn; the surplus is removed by the
        # ``< duration`` filter below.
        intervals = torch.poisson(mean_interval*torch.ones(duration), generator=self._rng)
        # Prepend a zero so the first event sits at offset 0 (before ``start``
        # is added).
        stimulus_times = torch.cumsum(
            torch.cat([torch.zeros(1), intervals]), dim=0
        )
        return stimulus_times[stimulus_times < duration] + self.start

    def __call__(self, t: Union[torch.Tensor, float]) -> torch.Tensor:
        r"""
        Computes stimulus at time t.

        The stimulus is 0 if t is not in the interval of a stimulus event
        and :math:`s` if t is.

        Parameters
        ----------
        t : torch.Tensor or float
            Time at which to compute the stimulus. Either a single time
            (number or 0-d tensor) or a tensor of times indexed along dim 0.

        Returns
        -------
        torch.Tensor
            Stimulus at time t. For a tensor of times the result has one row
            per entry of the current stimulus mask and one column per time;
            for a single time it has the shape of the current stimulus mask.
        """
        # ``stimulus_masks`` is not defined in this file (presumably provided
        # by BaseStimulus); look the current mask up once.
        neuron_mask = self.stimulus_masks[self._idx]

        # A 0-d tensor has no ``shape[0]``; treat it like a plain number.
        if torch.is_tensor(t) and t.ndim > 0:
            result = torch.zeros(neuron_mask.shape[0], t.shape[0], dtype=torch.float, device=neuron_mask.device)
            # Broadcast (n_events, 1) against t to test every event/time pair.
            stimulus_times = self.stimulus_times.unsqueeze(1)
            mask = (t - stimulus_times < self.tau) * (t - stimulus_times >= 0)
            # Only the first ``duration`` columns are written (selected by
            # column index, not by time value); the rest stay zero.
            result[:, :self.duration] = self.strength * mask.any(0)[:self.duration] * neuron_mask.unsqueeze(1)
            return result

        mask = (t - self.stimulus_times < self.tau) * (t - self.stimulus_times >= 0)
        return self.strength * mask.any() * neuron_mask