Source code for spikeometric.stimulus.sin_stimulus

import torch
import torch.nn as nn
from spikeometric.stimulus.base_stimulus import BaseStimulus
import math
from typing import Union

class SinStimulus(BaseStimulus):
    r"""
    Sinusoidal stimulus of neurons.

    The stimulus is a sinusoidal function with amplitude :math:`A`, period :math:`T`,
    phase :math:`\phi` and a constant baseline :math:`b`. The stimulus starts at
    time :math:`t_0` and lasts for a duration :math:`\tau`, i.e. it is non-zero only
    for :math:`t_0 \le t < t_0 + \tau`.

    Parameters
    ----------
    amplitude : float
        Amplitude :math:`A` of the stimulus. Must be non-negative.
    period : float
        Period :math:`T` of the stimulus. Must be strictly positive.
    duration : int
        Duration :math:`\tau` of the stimulus, counted from ``start``.
        Must be non-negative.
    stimulus_masks : torch.Tensor or list of torch.Tensor
        A mask of shape (n_neurons,) indicating which neurons to stimulate,
        a 2D tensor holding one such mask per row, or a list of masks.
    batch_size : int
        The number of networks to stimulate in parallel.
    phase : float
        Phase :math:`\phi` of the stimulus.
    baseline : float
        The constant baseline :math:`b` of the stimulus.
    start : float
        Start time :math:`t_0` of the stimulus.
    dt : float
        Time step :math:`\Delta t` of the simulation in ms.

    Raises
    ------
    ValueError
        If ``amplitude`` or ``duration`` is negative, or ``period`` is not
        strictly positive.
    """

    def __init__(
        self,
        amplitude: float,
        period: float,
        duration: int,
        stimulus_masks: Union[torch.Tensor, list],
        batch_size: int = 1,
        phase: float = 0.,
        baseline: float = 0,
        start: float = 0.,
        dt: float = 1.,
    ):
        super().__init__()
        if amplitude < 0:
            raise ValueError("All amplitudes must be positive.")
        # A zero period would make 2*pi/period infinite and the stimulus NaN.
        if period <= 0:
            raise ValueError("All periods must be positive.")
        if duration < 0:
            raise ValueError("All durations must be positive.")

        # Stored as nn.Parameter, but frozen by requires_grad_(False) at the end.
        self.register_parameter("amplitude", nn.Parameter(torch.tensor(amplitude, dtype=torch.float)))
        self.register_parameter("period", nn.Parameter(torch.tensor(period, dtype=torch.float)))
        self.register_parameter("phase", nn.Parameter(torch.tensor(phase, dtype=torch.float)))
        self.register_parameter("baseline", nn.Parameter(torch.tensor(baseline, dtype=torch.float)))
        self.register_buffer("duration", torch.tensor(duration, dtype=torch.int))
        self.register_buffer("start", torch.tensor(start, dtype=torch.float))
        self.register_buffer("dt", torch.tensor(dt, dtype=torch.float))

        # Normalise stimulus_masks to a list of 1D masks (one per network).
        if isinstance(stimulus_masks, torch.Tensor) and stimulus_masks.ndim == 1:
            stimulus_masks = [stimulus_masks]
        elif isinstance(stimulus_masks, torch.Tensor) and stimulus_masks.ndim == 2:
            # unbind keeps each row 1D even with a single neuron; squeeze() would
            # collapse a (1, 1) slice into a 0-d tensor.
            stimulus_masks = list(stimulus_masks.unbind(dim=0))

        conc_stimulus_masks, split_points = self.batch_stimulus_masks(stimulus_masks, batch_size)
        self.n_batches = math.ceil(len(stimulus_masks) / batch_size)
        self.register_buffer("conc_stimulus_masks", conc_stimulus_masks)
        self.register_buffer("split_points", torch.tensor(split_points, dtype=torch.int))
        # Index of the mask batch used by __call__.
        # NOTE(review): self.stimulus_masks (read in __call__) is not assigned here;
        # presumably BaseStimulus derives it from conc_stimulus_masks/split_points — confirm.
        self._idx = 0
        self.requires_grad_(False)

    def __call__(self, t: Union[float, torch.Tensor]) -> torch.Tensor:
        r"""
        Computes the stimulus at time t by applying a sinusoidal function.

        Between the start time :math:`t_0` and the end time :math:`t_0 + \tau`,
        the stimulus is given by

        .. math::
            f(t) = A \sin \left( \frac{2 \pi}{T} (t - t_0)\Delta t + \phi \right) + b

        where :math:`b` is the baseline. Outside that window, and for neurons
        not selected by the current stimulus mask, the stimulus is zero.

        Parameters
        ----------
        t : float or torch.Tensor
            Time(s) at which to evaluate the stimulus.

        Returns
        -------
        torch.Tensor
            For a scalar ``t``, a tensor that broadcasts like the current stimulus
            mask. For a tensor ``t``, the mask gets a trailing axis
            (``unsqueeze(1)``) so that it broadcasts against ``t``; e.g. a 1D mask
            and a 1D ``t`` give shape ``(len(mask), len(t))``.
        """
        # Half-open active window [t_0, t_0 + tau).
        time_mask = (t >= self.start) * (t < self.start + self.duration)
        signal = self.amplitude * torch.sin(
            2 * torch.pi / self.period * (t - self.start) * self.dt + self.phase
        ) + self.baseline

        stim_mask = self.stimulus_masks[self._idx]
        if torch.is_tensor(t):
            # Trailing axis lets the per-neuron mask broadcast over the time axis.
            stim_mask = stim_mask.unsqueeze(1)
        return signal * time_mask * stim_mask