
Data Explorer

Synaptic Connectivity

Setup


Import and define functions

#@title {display-mode: "form" }

#@markdown Run this code cell to import packages and define functions 
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import ndimage
from scipy.optimize import curve_fit
from scipy.signal import hilbert,medfilt,resample, find_peaks, unit_impulse
import seaborn as sns
from datetime import datetime,timezone,timedelta
pal = sns.color_palette(n_colors=15)
pal = pal.as_hex()
import matplotlib.pyplot as plt
import random
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

from pathlib import Path

from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
from ipywidgets import widgets, interact, interactive, interactive_output
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")

print('Task completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

Mount Google Drive

#@title {display-mode: "form" }

#@markdown Run this cell to mount your Google Drive.

from google.colab import drive
drive.mount('/content/drive')

print('Task completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

Import data

Import data digitized with a Nidaq USB6211 and recorded with Bonsai-rx as a .bin file.

If you would like to try this Data Explorer but do not have data, you can download the following examples (two channels digitized at 40000 Hz). Channel 0 is the signal measured from N3, and Channel 1 is the signal measured from the Superficial Flexor muscle.
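The data-loading code in this section assumes the .bin file holds 64-bit floats with the channels interleaved sample by sample (channel 0, channel 1, channel 0, ...), which is why it reshapes with `reshape(-1, number_channels)`. A minimal sketch of that layout, using a tiny made-up recording written to a temporary file (the file name and values are illustrative only):

```python
import tempfile
from pathlib import Path

import numpy as np

fs = 40000        # samples per second (as stated for the example files)
n_channels = 2
n_samples = 4     # a tiny fake recording

# Fake recording with one row per sample and one column per channel
fake = np.arange(n_samples * n_channels, dtype=np.float64).reshape(n_samples, n_channels)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "fake_recording.bin"   # hypothetical file name
    fake.tofile(path)                         # written row by row: ch0, ch1, ch0, ch1, ...
    flat = np.fromfile(path, dtype=np.float64)
    data = flat.reshape(-1, n_channels)       # back to (samples, channels)

duration_s = data.shape[0] / fs
print(flat)         # the interleaved samples as stored on disk
print(data[:, 0])   # channel 0 only
```

If the channel count is wrong, the reshape either fails or silently mixes samples from different channels, so it is worth confirming `number_channels` before loading.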

#@title {display-mode: "form" }

#@markdown Specify the file path 
#@markdown to your recorded data on Drive (you can find the filepath in the Colab file manager):

filepath = "full filepath goes here"  #@param 
# filepath = '/Volumes/Untitled/BIOL247/data/crayfish-synaptic-connectivity/KP_20221113/spont-simult_A4_2_good-2input.bin' 

#@markdown Specify the sampling rate and number of channels recorded.

sampling_rate = None #@param
number_channels = None #@param

# downsample = False #@param
# newfs = 10000 #@param

#@markdown After you have filled out all form fields, 
#@markdown run this code cell to load the data. 

filepath = Path(filepath)

# No need to edit below this line
#################################
data = np.fromfile(Path(filepath), dtype = np.float64)
data = data.reshape(-1,number_channels)
data_dur = np.shape(data)[0]/sampling_rate
print('duration of recording was %0.2f seconds' %data_dur)

fs = sampling_rate
# if downsample:
#     # newfs = 10000 #downsample emg data
#     chunksize = int(sampling_rate/newfs)
#     data = data[0::chunksize,:]
#     fs = int(np.shape(data)[0]/data_dur)

time = np.arange(np.shape(data)[0])/fs # time (s) of each sample, spaced exactly 1/fs apart

print('Data upload completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

Visualize Data

If desired, use this visualization tool as a sanity check that the data you imported are the data you intended to import.

#@title {display-mode: "form"}

#@markdown Run this code cell to plot imported data. <br> 
#@markdown Use the range slider to scroll through the data in time.
#@markdown Use the channel slider to choose which channel to plot.
#@markdown Be patient with the range refresh: the more data you plot, the slower it will be. 

slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,data_dur),
    step= 0.5,
    readout=True,
    continuous_update=False,
    description='Time Range (s)')
slider_xrange.layout.width = '600px'

slider_chan = widgets.IntSlider(
    min=0,
    max=number_channels-1,
    value=0,
    step= 1,
    continuous_update=False,
    description='channel')
slider_chan.layout.width = '300px'

# a function that will modify the xaxis range
def update_plot(x,chan):
    fig, ax = plt.subplots(figsize=(10,5),num=1); #specify figure number so that it does not keep creating new ones
    starti = int(x[0]*fs)
    stopi = int(x[1]*fs)
    ax.plot(time[starti:stopi], data[starti:stopi,chan])

w = interact(update_plot, x=slider_xrange, chan=slider_chan);

For a more extensive raw data explorer than the one provided above, use the DataExplorer.py application found in the howto section of the course website.

Define Bouts

To efficiently assess your data with this analysis, make sure to exclude any raw data that does not have a clean (low-noise) signal. For the simultaneously recorded pre- and post-synaptic signals, make sure to exclude raw data in which the post-synaptic electrode was not stably in the cell. The more data you are able to include, the better your spike sorting results will be.
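Bouts are given as [start, end] pairs in seconds, and only events that fall inside a bout are kept for analysis. A minimal sketch of that filtering, with made-up spike times and bouts in the same `[[start, end], ...]` format as `bouts_list`:

```python
import numpy as np

# Hypothetical event times (s) and bouts of clean data (s)
spike_times = np.array([0.5, 1.2, 3.0, 4.8, 7.5, 9.1])
bouts = [[0.0, 2.0], [4.0, 8.0]]

# An event is kept if it lies strictly inside any bout
in_bout = np.zeros(spike_times.shape, dtype=bool)
for start, end in bouts:
    in_bout |= (spike_times > start) & (spike_times < end)

kept = spike_times[in_bout]
print(kept)
```

Events at 3.0 s and 9.1 s fall between or after the bouts and are dropped.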

#@title {display-mode: "form"}

#@markdown For this experiment, the entire file should be one long bout, 
#@markdown but if there were regions where something went wrong or that you want to exclude, you can specify bouts with good data.
#@markdown Specify the list of bout ranges (in seconds) as follows: [[start of bout 0, end of bout 0],[start 1, end 1],...] <br>


bouts_list = [[None,None]] #@param

#@markdown Then run this code cell to programmatically define the list of bouts as 'bouts_list'.

Part I. Spike Sorting

SciPy provides an algorithm for detecting “peaks” in a signal (scipy.signal.find_peaks, imported in the Setup cell). On its own, it detects every local maximum, so the function takes arguments that specify the minimum height that counts as a peak and the minimum acceptable interval between independent peaks.
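A minimal sketch of how these two find_peaks arguments behave on a toy signal. The sampling rate and spike positions are made up; note that `distance` is given in samples, so a minimum inter-event interval in seconds is multiplied by the sampling rate (as the spike-detection cell does with `float(iei)*fs`):

```python
import numpy as np
from scipy.signal import find_peaks

fs = 10000                  # samples per second (assumed for this toy example)
signal = np.zeros(fs)       # 1 s of flat signal
signal[1000] = 1.0          # large spike
signal[1020] = 0.8          # second spike only 2 ms later
signal[5000] = 0.2          # small spike, below the height threshold

min_iei_s = 0.005           # minimum inter-event interval (s)
distance = min_iei_s * fs   # converted to samples for find_peaks

peaks, props = find_peaks(signal, height=0.5, distance=distance)
print(peaks / fs)             # peak times in seconds
print(props["peak_heights"])
```

Only the peak at 0.1 s survives: the 0.2 peak fails the height test, and the 0.8 peak is discarded because it lies within 5 ms of a taller peak.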

First, we will detect all the peaks on the N3 channel. This will give the time of each peak whose amplitude falls between the two given thresholds (putative motor neuron spikes).
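The spike-detection cell applies the lower threshold through find_peaks' `height` argument and the upper threshold as a separate mask on the returned peak heights. A sketch on a toy trace with three peak amplitudes; for comparison, it also uses find_peaks' option of passing a (min, max) pair to `height`. One small difference: that pair is inclusive at the upper bound, while the mask uses a strict `<`.

```python
import numpy as np
from scipy.signal import find_peaks

# Toy trace with peaks of three different amplitudes
signal = np.zeros(100)
signal[[10, 40, 70]] = [0.3, 1.0, 2.5]

thresh_low, thresh_high = 0.5, 2.0

# Lower threshold inside find_peaks, upper threshold as a mask
peaks, props = find_peaks(signal, height=thresh_low)
in_band = peaks[props["peak_heights"] < thresh_high]

# Equivalent band selection using a (min, max) pair for `height`
peaks_band, _ = find_peaks(signal, height=(thresh_low, thresh_high))

print(in_band, peaks_band)
```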

Use the Dash Data Explorer to visualize the spike waveform shapes. This will enable you to determine which “polarity” is optimal for peak detection in your recording.
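The polarity setting works by multiplying the signal by 1 or -1 before detection, so downward deflections become upward peaks, and then multiplying the detected heights back into the original sign. A sketch on a toy trace with one spike of each sign:

```python
import numpy as np
from scipy.signal import find_peaks

signal = np.zeros(100)
signal[20] = 1.0     # upward spike
signal[60] = -1.5    # downward spike

results = {}
for polarity in (1, -1):
    peaks, props = find_peaks(signal * polarity, height=0.5)
    heights = props["peak_heights"] * polarity   # restore the original sign
    results[polarity] = (peaks, heights)
    print(polarity, peaks, heights)
```

With polarity 1 only the upward spike is found; with polarity -1 only the downward one, reported with its negative amplitude.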

#@title {display-mode: "form"}

#@markdown Indicate which channel has the N3 signal with motor neuron spikes.

spike_channel = None #@param

#@markdown Then, run the code cell to create an interactive plot with a slider to scroll 
#@markdown through the raw data and set an upper and lower peak detection threshold.
#@markdown You can set the polarity of the peak detection: upward (1) or downward (-1) peaks. 
#@markdown Peak times (according to your threshold) will be plotted using red markers. <br>
#@markdown Only the signal within the previously specified bout times will be processed (shaded gray regions in the plot).

slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,data_dur),
    step=0.01,
    readout_format='.2f',
    continuous_update=False,
    readout=True,
    description='xrange (s)'
)
slider_xrange.layout.width = '600px'

slider_yrange = widgets.FloatRangeSlider(
    min=np.min(data[:,spike_channel])-0.1,
    max=np.max(data[:,spike_channel])+0.1,
    value=[np.min(data[:,spike_channel])-0.1,np.max(data[:,spike_channel])+0.1],
    step=0.01,
    readout_format='.2f',
    continuous_update=False,
    readout=True,
    description='yrange'
)
slider_yrange.layout.width = '600px'



slider_threshold_low = widgets.FloatSlider(
    min=0,
    max=np.max([np.max(data[:,spike_channel]),np.abs(np.min(data[:,spike_channel]))])+0.1,
    value=0,
    step=0.001,
    readout_format='.3f',
    continuous_update=False,
    readout=True,
    description='lower threshold')
slider_threshold_low.layout.width = '600px'

slider_threshold_high = widgets.FloatSlider(
    min=0,
    max=np.max([np.max(data[:,spike_channel]),np.abs(np.min(data[:,spike_channel]))])+0.1,
    value=np.max([np.max(data[:,spike_channel]),np.abs(np.min(data[:,spike_channel]))])+0.1,
    step=0.001,
    readout_format='.3f',
    continuous_update=False,
    readout=True,
    description='upper threshold')
slider_threshold_high.layout.width = '600px'

radio_polarity = widgets.RadioButtons(
    options=[1, -1],
    value=1,
    description='peaks polarity',
    disabled=False
)

iei_text = widgets.Text(
    value='0.005',
    placeholder='0.005',
    description='min IEI (seconds)',
    style = {'description_width': '200px'},
    disabled=False
)


# a function that will modify the xaxis range
def update_plot(xrange,yrange,thresh_low_,thresh_high_,polarity,iei):
    fig, ax = plt.subplots(figsize=(10,6),num=1); #specify figure number so that it does not keep creating new ones
    fig.tight_layout()    
    
    win_0 = int(xrange[0]*fs)
    win_1 = int(xrange[1]*fs)

    xtime = np.linspace(xrange[0],xrange[1],(win_1 - win_0))

    ax.plot(xtime,data[win_0:win_1,spike_channel],color='black',linewidth=1)
    ax.set_ylim(yrange[0],yrange[1]);
    
    ax.hlines(thresh_low_*polarity, xrange[0],xrange[1],linestyle='--',color='green',zorder=3)
    ax.hlines(thresh_high_*polarity, xrange[0],xrange[1],linestyle='--',color='orange',zorder=3)
    
    
    # calculate spike times based on threshold
    d = float(iei)*fs #minimum time allowed between distinct events
    r = find_peaks(data[:,spike_channel]*polarity,height=thresh_low_,distance=d)

    spike_times = r[0]/fs
    mask_spikes = r[1]['peak_heights']<thresh_high_
    # spike_times = spike_times[mask_spikes]

    inwin_inds = []
    for b_ in bouts_list:
        inwin_inds.extend(r[0][(spike_times>b_[0]) & (spike_times<b_[1]) & mask_spikes])
    inwin_inds = np.asarray(inwin_inds)
    inwin_inds = np.isin(r[0],inwin_inds) # np.isin replaces the deprecated np.in1d
    
    df_props = pd.DataFrame({
            'height': r[1]['peak_heights'][inwin_inds]*polarity,
            'spikeT' : spike_times[inwin_inds],
            'spikeInd' : r[0][inwin_inds]
                })
    
    spike_times = spike_times[mask_spikes]
    
    inwin_spikes = spike_times[(spike_times>(xrange[0])) & (spike_times<(xrange[1]))]
    ax.scatter(inwin_spikes,[np.mean(data[:,spike_channel])] * len(inwin_spikes),
          zorder=3,color='red',s=20)

    for b_ in bouts_list:
        ax.axvspan(b_[0], b_[1], color = 'black', alpha=0.1)    

    ax.set_xlim(xrange[0],xrange[1])
    return spike_times,df_props
    

w_spikes_ = interactive(update_plot, xrange=slider_xrange, 
             yrange=slider_yrange,
             thresh_low_=slider_threshold_low,thresh_high_=slider_threshold_high,polarity=radio_polarity,iei=iei_text);

display(w_spikes_)

#@title {display-mode: "form"}

#@markdown Run this cell to finalize the list of spike times after settling on a threshold in the interactive plot. <br> 
#@markdown This will also create a histogram plot of peak heights.
spike_times,df_props = w_spikes_.result

n,bins = np.histogram(df_props['height'],bins = 500) # calculate the histogram
bins = bins[1:]
hfig,ax = plt.subplots(figsize=(5,4))
ax.step(bins,n,color='black')
ax.set_ylabel('count',fontsize=14)
ax.set_xlabel('amplitude',fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14);

The histogram can give you a sense of how many distinct motor neurons might be in your recording: each distinct peak (mode) in the amplitude distribution may correspond to a different motor neuron.

We can cluster events based on peak height and waveform shape (summarized by principal components) using k-means clustering. This will provide us with “putative single units” for further analysis.
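A minimal sketch of the PCA plus k-means idea on synthetic waveforms from two made-up units (a large narrow spike and a small wide spike). This version uses the common layout of one row per spike; the notebook's own cell builds its features differently (it stores one spike per column, uses the PCA loadings, and appends peak height), so treat this as an illustration of the technique rather than a copy of that cell:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
t = np.linspace(-1, 1, 41)   # toy time axis spanning +/- 1 ms

# Two hypothetical units with different amplitude and width
unit_a = 1.0 * np.exp(-(t / 0.15) ** 2)
unit_b = 0.4 * np.exp(-(t / 0.4) ** 2)
waves = np.vstack([unit_a + 0.05 * rng.standard_normal((50, t.size)),
                   unit_b + 0.05 * rng.standard_normal((50, t.size))])
true_unit = np.repeat([0, 1], 50)   # ground truth, known only because the data are synthetic

# Reduce each waveform (one row per spike) to 3 PCA features, then cluster
features = PCA(n_components=3).fit_transform(waves)
labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(features)
print(np.bincount(labels))
```

Cluster numbers from k-means are arbitrary, so "cluster 0" in one run may be "cluster 1" in another; setting `random_state` makes a run repeatable.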

If you “over-cluster” (try to assign more categories than exist), you can potentially dissociate closely related event waveforms. You can recombine clusters at a later stage as needed.

You will be able to visualize the mean spike waveform (and the standard deviation around the mean) of the events in each cluster (putative motor neurons).
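The mean and std waveforms come from cutting a fixed window around each spike index and averaging across spikes. A sketch with a toy trace and hypothetical spike indices, which also drops spikes too close to either end of the recording (their windows would be truncated and could not be stacked into one array):

```python
import numpy as np

fs = 10000               # samples per second (toy value)
win = int(0.001 * fs)    # half-window: 1 ms (10 samples) on each side of the spike

rng = np.random.default_rng(1)
trace = 0.01 * rng.standard_normal(fs)                  # 1 s of low-amplitude noise
spike_inds = np.array([5, 2000, 4000, 6000, fs - 3])    # first and last are too close to the edges
for i in spike_inds[1:-1]:
    trace[i] += 1.0                                     # plant a spike at each interior index

# Keep only spikes whose full window fits inside the recording
ok = (spike_inds - win >= 0) & (spike_inds + win <= trace.size)
waves = np.array([trace[i - win:i + win] for i in spike_inds[ok]])   # (n_spikes, 2*win)

wav_mean = waves.mean(axis=0)
wav_std = waves.std(axis=0)
t_ms = np.arange(-win, win) / fs * 1000   # time relative to the spike index, in ms
```

Plotting `t_ms` against `wav_mean`, with a band of `wav_mean ± wav_std`, gives the same kind of figure as the waveform cell.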

#@title { display-mode: "form" }


#@markdown Choose the number of clusters you want to split the event-based data into and type that number below. <br>
#@markdown >Note: It can sometimes help to "over-split" the events into more clusters 
#@markdown than you think will be necessary. You can try both strategies and assess the results.
number_of_clusters = 3 #@param {type: "number"}
#@markdown Then run this cell to run the Kmeans algorithm. 


windur = 0.001
winsamp = int(windur*fs)
spkarray = []
for i in df_props['spikeInd'].values:
    spkarray.append(data[i-winsamp : i+winsamp+1,spike_channel])

df = pd.DataFrame(np.asarray(spkarray).T)
df_norm =(df - df.mean()) / df.std() # normalize for pca

n_components=5 #df.shape[0]  
pca = PCA(n_components=n_components)
pca.fit(df_norm)
df_pca = pd.DataFrame(pca.transform(df_norm), columns=['PC%i' % i for i in range(n_components)], index=df.index) # transform the same normalized data the PCA was fit on

loadings = pd.DataFrame(pca.components_.T, columns=df_pca.columns, index=df.columns)
df_data = loadings.join(df_props['height'])

# hfig,ax = plt.subplots(1)
# ax.set_xlabel('seconds')
# ax.set_ylabel('amplitude (a.u.)')
# ax.set_yticklabels([])
# for c in df_pca.columns:
#     ax.plot(df_pca[c],label = c,alpha = 0.75)
# plt.legend(bbox_to_anchor=(1, 1));


kmeans = KMeans(n_clusters=number_of_clusters).fit(df_data)
# df_props['peaks_t'] = peaks_t
df_props['cluster'] = kmeans.labels_

df_data['cluster_id'] = kmeans.labels_
# sns.scatterplot(x='PC0',y='PC1',hue='cluster_id',data=df_data)

#@title {display-mode:'form'}

#@markdown Run this cell to display the mean (and std) waveform for each cluster.

windur = 0.001
winsamps = int(windur * fs)
x = np.linspace(-windur,windur,winsamps*2)*1000
hfig,ax = plt.subplots(1,figsize=(8,6))
ax.set_ylabel('Volts recorded',fontsize=14)
ax.set_xlabel('milliseconds',fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

for k in np.unique(df_props['cluster']):
    spkt = df_props.loc[df_props['cluster']==k]['spikeT'].values #['peaks_t'].values
    spkt = spkt[(spkt>windur) & (spkt<(data_dur)-windur)]
    print(str(len(spkt)) + " spikes in cluster number " + str(k))
    spkwav = np.asarray([data[(int(t*fs)-winsamps):(int(t*fs)+winsamps),spike_channel] for t in spkt])
    wav_u = np.mean(spkwav,0)
    wav_std = np.std(spkwav,0)
    ax.plot(x,wav_u,linewidth = 3,label='cluster '+ str(k),color=pal[k])
    ax.fill_between(x, wav_u-wav_std, wav_u+wav_std, alpha = 0.25,color=pal[k])
plt.legend(bbox_to_anchor=[1.25,1],fontsize=14);

print('Tasks completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

If there are multiple spike clusters you want to merge into a single cell class, edit and run the cell below.

merge_cluster_list is a list of sublists of cluster numbers (the numbers associated with the colors in the legend above); each sublist is one group of clusters to merge.

  • For example, the following list would merge clusters 0 and 2 together, and clusters 1, 4, and 3 together:
    merge_cluster_list = [[0,2],[1,4,3]]

  • For each merge group, the first cluster number listed becomes the reassigned cluster number for that group (in this example, you would end up with a cluster number 0 and a cluster number 1).

  • Leave any clusters that don’t need merging out of the list.
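A worked example of the merge rule on a toy table of cluster labels, using the same loop as the merge cell and the example list `[[0,2],[1,4,3]]` from above (the labels themselves are made up):

```python
import pandas as pd

# Toy cluster labels for eight spikes (clusters 0-4)
df_props = pd.DataFrame({"cluster": [0, 1, 2, 3, 4, 0, 2, 4]})
merge_cluster_list = [[0, 2], [1, 4, 3]]

# Every cluster in a group is relabeled with the group's first number
for k_group in merge_cluster_list:
    for k in k_group:
        df_props.loc[df_props["cluster"] == k, "cluster"] = k_group[0]

print(sorted(df_props["cluster"].unique()))
```

Because the groups are processed in order, each cluster number should appear in only one group; otherwise the result may not be the grouping you intended.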

#@title { display-mode: "form" }

#@markdown ONLY USE THIS CODE CELL IF YOU WANT TO MERGE CLUSTERS. 
#@markdown OTHERWISE, MOVE ON. 
#@markdown <br> Below, create your list (of sublists) of clusters to merge.
#@markdown >Just leave out from the list any clusters that you want unmerged.
merge_cluster_list = [[1,4,3]] #@param
#@markdown Then, run this cell to merge clusters as specified. <br>
#@markdown A new figure of waveform shapes will be generated for the new categorization.

for k_group in merge_cluster_list:
    for k in k_group:
        df_props.loc[df_props['cluster']==k,'cluster'] = k_group[0]
print('you now have the following clusters: ' + str(np.unique(df_props['cluster'])))

print('Tasks completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

windur = 0.001
winsamps = int(windur * fs)
x = np.linspace(-windur,windur,winsamps*2)*1000
hfig,ax = plt.subplots(1,figsize=(8,6))
ax.set_ylabel('Volts recorded',fontsize=14)
ax.set_xlabel('milliseconds',fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

for k in np.unique(df_props['cluster']):
    spkt = df_props.loc[df_props['cluster']==k]['spikeT'].values #['peaks_t'].values
    spkt = spkt[(spkt>windur) & (spkt<(data_dur)-windur)]
    print(str(len(spkt)) + " spikes in cluster number " + str(k))
    spkwav = np.asarray([data[(int(t*fs)-winsamps):(int(t*fs)+winsamps),spike_channel] for t in spkt])
    wav_u = np.mean(spkwav,0)
    wav_std = np.std(spkwav,0)
    ax.plot(x,wav_u,linewidth = 3,label='cluster '+ str(k),color=pal[k])
    ax.fill_between(x, wav_u-wav_std, wav_u+wav_std, alpha = 0.25,color=pal[k])
plt.legend(bbox_to_anchor=[1.25,1],fontsize=14);

Once you are happy with the clustering results based on the waveform shapes, check back with the raw data to make sure the spike assignments look reasonable.

#@title {display-mode: "form"}

#@markdown Then, run the code cell to create an interactive plot with a slider to scroll 
#@markdown through the raw data and overlaid "spike-sorted" event time data. <br>
#@markdown Only the signal within the previously specified bout times will be processed (shaded gray regions in the plot).



slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,data_dur),
    step=0.01,
    readout_format='.2f',
    continuous_update=False,
    readout=True,
    description='xrange (s)'
)
slider_xrange.layout.width = '600px'

slider_yrange = widgets.FloatRangeSlider(
    min=np.min(data[:,spike_channel])-0.1,
    max=np.max(data[:,spike_channel])+0.1,
    value=[np.min(data[:,spike_channel])-0.1,np.max(data[:,spike_channel])+0.1],
    step=0.01,
    readout_format='.2f',
    continuous_update=False,
    readout=True,
    description='yrange'
)
slider_yrange.layout.width = '600px'


# a function that will modify the xaxis range
def update_plot(xrange,yrange):
    fig, ax = plt.subplots(figsize=(10,6),num=1); #specify figure number so that it does not keep creating new ones
    fig.tight_layout()    
    
    win_0 = int(xrange[0]*fs)
    win_1 = int(xrange[1]*fs)

    xtime = np.linspace(xrange[0],xrange[1],(win_1 - win_0))

    ax.plot(xtime,data[win_0:win_1,spike_channel],color='black',linewidth=1)
    ax.set_ylim(yrange[0],yrange[1]);
    ax.set_xlim(xrange[0],xrange[1]);
    
    
    # 1-D mask of spikes that fall inside the displayed time window
    inwin = (df_props['spikeT']>xrange[0]) & (df_props['spikeT']<xrange[1])
    for k in np.unique(df_props['cluster']):
        df_ = df_props[inwin & (df_props['cluster']==k)]
        ax.scatter(df_['spikeT'],df_['height'],
                   color=pal[k],s=20,label='cluster '+ str(k),zorder=3)

    # plt.legend()
    for b_ in bouts_list:
        ax.axvspan(b_[0], b_[1], color = 'black', alpha=0.1)

w_spikes_sorted_ = interactive(update_plot, xrange=slider_xrange, 
             yrange=slider_yrange);

display(w_spikes_sorted_)

If you think that two different spike waveforms are being lumped together, go back to the k-means clustering step, increase the number of clusters, and then merge clusters as needed.

Part II. Spike-triggered voltage


You can use the event times (spike times) to extract the pre-synaptic and/or post-synaptic voltage signal following each event. This is a helpful way to determine if you have recorded any synaptic pairs (a connected pair of pre and post-synaptic cells).
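A spike-triggered average aligns a window of the post-synaptic signal on every pre-synaptic spike and averages the sweeps: a response that follows each spike at a consistent latency survives the averaging, while unrelated noise shrinks. A sketch with synthetic data (the sampling rate, spike times, 2 ms latency, and PSP shape are all made up):

```python
import numpy as np

fs = 10000                              # samples per second (toy value)
rng = np.random.default_rng(2)
n = 5 * fs                              # 5 s of "intracellular" signal
post = 0.2 * rng.standard_normal(n)     # noisy baseline

# Hypothetical pre-synaptic spike indices, one every 0.25 s
spike_inds = np.arange(fs // 2, n - fs // 2, fs // 4)

# Each spike evokes a small PSP: 2 ms latency, exponential decay with tau = 10 ms
latency = int(0.002 * fs)
psp = 0.5 * np.exp(-np.arange(int(0.05 * fs)) / (0.01 * fs))
for i in spike_inds:
    post[i + latency:i + latency + psp.size] += psp

# Average the signal from -10 ms to +50 ms around each spike
onset, offset = int(-0.01 * fs), int(0.05 * fs)
ok = (spike_inds + onset >= 0) & (spike_inds + offset <= n)   # skip windows that run off the ends
sweeps = np.array([post[i + onset:i + offset] for i in spike_inds[ok]])
sta = sweeps.mean(axis=0)
t_s = np.arange(onset, offset) / fs     # time relative to the spike (s)
```

In single sweeps the PSP is buried in noise with a similar amplitude; in `sta` the noise standard deviation drops by roughly the square root of the number of sweeps, so the PSP stands out after the 2 ms latency.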

Visualize the average raw signal time-locked to the spikes of each putative pre-synaptic neuron (spike cluster).

#@title {display-mode:"form"}

#@markdown Specify which channel has the intracellular signal. 
#@markdown Then run this cell to plot the average spike-triggered 
#@markdown post-synaptic potential for each spike cluster you defined in Part I.

intracellular_channel = 1 #@param

slider_xrange = widgets.FloatRangeSlider(
    min=-0.1,
    max=0.500,
    value=(0,0.200),
    step=0.01,
    readout_format='.2f',
    continuous_update=False,
    readout=True,
    description='xrange (s)'
)
slider_xrange.layout.width = '600px'

def update_plot(xrange):
    # No need to edit below this line
    #################################
    windur = xrange[1]-xrange[0]
    winsamps = int(windur * fs)

    onset = int(xrange[0]*fs)
    offset = int(xrange[1]*fs)

    x = np.linspace(xrange[0],xrange[1],offset-onset)
    
    hfig,ax = plt.subplots(figsize=(10,4))
    ax.set_ylabel('volts recorded',fontsize=14)
    ax.set_xlabel('seconds',fontsize=14) # x is in seconds relative to each spike
    # plt.xticks(fontsize=14)
    # plt.yticks(fontsize=14)
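    for k in np.unique(df_props['cluster']):
        spkt = df_props.loc[df_props['cluster']==k]['spikeT'].values
        spk_inds = (spkt*fs).astype(int)
        # keep only spikes whose full window (onset to offset) lies inside the recording
        spk_inds = spk_inds[(spk_inds+onset >= 0) & (spk_inds+offset <= np.shape(data)[0])]
        synwav = np.asarray([data[i+onset:i+offset,intracellular_channel] for i in spk_inds])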
        wav_u = np.mean(synwav,0)
        wav_std = np.std(synwav,0)
        ax.plot(x,wav_u,linewidth = 3,color = pal[k],label='cluster '+str(k));
        # ax.fill_between(x, wav_u-wav_std, wav_u+wav_std, alpha = 0.25, color = pal[k])
    plt.legend(bbox_to_anchor=[1.5,1], fontsize=14);
    ax.set_xlim(xrange[0],xrange[1])
    
w_psps_sorted_ = interactive(update_plot, xrange=slider_xrange);

display(w_psps_sorted_)

Visualize the raw signal time-locked to individual spike times

Finally, you can plot the pre- and post-synaptic signals associated with each individual event time in each cluster.

This visualization enables you to extract more precise quantitative measurements from the signals associated with each event. It is also helpful to look at individual events to get a sense of the variance/reliability in the pre- and post-synaptic signals associated with each spike time.
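As one example of such a measurement, here is a hypothetical helper, `measure_psp` (not part of the notebook), that subtracts the pre-spike baseline from a single spike-aligned sweep and reports the peak amplitude and its latency. A sketch on a synthetic sweep; the baseline level and PSP shape are made up:

```python
import numpy as np


def measure_psp(sweep, t_s, polarity=1):
    """Hypothetical helper: PSP amplitude and peak latency from one spike-aligned sweep.

    sweep    : signal samples aligned to a spike
    t_s      : time of each sample relative to the spike (s); must include t < 0
    polarity : 1 for depolarizing PSPs, -1 for hyperpolarizing PSPs
    """
    baseline = sweep[t_s < 0].mean()            # mean of the pre-spike samples
    rel = (sweep - baseline) * polarity         # flip so the PSP points upward
    after = t_s >= 0
    i_peak = np.argmax(rel[after])              # largest deflection after the spike
    amplitude = rel[after][i_peak] * polarity   # report in the original sign
    latency = t_s[after][i_peak]
    return amplitude, latency


fs = 10000
t_s = np.arange(int(-0.01 * fs), int(0.05 * fs)) / fs
sweep = -0.55 + 0.1 * np.exp(-((t_s - 0.004) / 0.002) ** 2)   # toy PSP peaking 4 ms after the spike
amp, lat = measure_psp(sweep, t_s)
print(amp, lat)
```

On real sweeps, measuring each event separately like this lets you compare amplitudes and latencies across spikes rather than relying on the average alone.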

#@title {display-mode:"form"}

#@markdown Run this code cell to create an interactive plot to  
#@markdown examine the raw signal time-locked to individual spike times (event times).
#@markdown You can overlay multiple channels by selecting more than one.
#@markdown You can overlay multiple spike times by selecting more than one. 
#@markdown (To select more than one item from an option menu, press the control/command key 
#@markdown while mouse clicking or shift while using up/down arrows)

slider_xrange = widgets.FloatRangeSlider(
    min=-0.1, # must not exceed the lower end of the default value below
    max=0.3,
    value=(-0.05,0.1),
    step=0.001,
    continuous_update=False,
    readout=True,
    readout_format='.4f',
    description='xrange (s)'
)
slider_xrange.layout.width = '600px'

slider_yrange_0 = widgets.FloatRangeSlider(
    min=-3,
    max=3, # adjust to the amplitude range of your channel 0 signal
    value=(-2,2),
    step=0.01,
    continuous_update=False,
    readout=True,
    description='yrange ch 0'
)
slider_yrange_0.layout.width = '600px'

slider_yrange_1 = widgets.FloatRangeSlider(
    min=-1,
    max=-0.2, # normal range for crayfish superficial flexor
    value=(-0.7,-0.4),
    step=0.01,
    continuous_update=False,
    readout=True,
    description='yrange ch 1'
)
slider_yrange_1.layout.width = '600px'

ui_range = widgets.VBox([slider_xrange, slider_yrange_0, slider_yrange_1])

select_channels = widgets.SelectMultiple(
    options=np.arange(np.shape(data)[1]), # one option per recorded channel
    value=[0],
    #rows=10,
    description='Channels',
    disabled=False
)

select_clusters = widgets.Select(
    options=np.unique(df_props['cluster']), # one option per spike cluster
    value=np.unique(df_props['cluster'])[0], # Select takes a single value, not a list
    #rows=10,
    description='Clusters',
    disabled=False
)

select_trials = widgets.SelectMultiple(
    options=np.arange(np.sum(df_props['cluster']==np.unique(df_props['cluster'])[0])), # spike numbers within the first cluster; updated when the plot refreshes
    value=[0],
    #rows=10,
    description='Spikes',
    disabled=False
)

ui_trials = widgets.HBox([select_channels, select_trials, select_clusters])


def update_plot(chan_list,trial_list,cluster_,xrange,yrange0,yrange1):
    fig, ax0 = plt.subplots(figsize=(8,4),num=1); #specify figure number so that it does not keep creating new ones
     
    ax1 = ax0.twinx()
    
    win_0 = int(xrange[0]*fs)
    win_1 = int(xrange[1]*fs)
    xtime = np.linspace(xrange[0],xrange[1],(win_1 - win_0))
    
    trial_times = df_props[df_props['cluster']==cluster_]['spikeT']
    trial_inds = df_props[df_props['cluster']==cluster_]['spikeInd'].values       
        
    trials_init_ = np.arange(len(trial_times))
    select_trials.options = trials_init_

    trial_list = [t_try for t_try in trial_list if t_try in trials_init_]
    select_trials.value = trial_list
                 
    channel_colors = ['purple','green','blue','orange']
    for chan_ in chan_list:
                    
        this_chan = data[:,chan_]
        for trial_ in trial_list:

            if trial_ in trials_init_:
                t_ = trial_inds[trial_]

                if ((t_+win_0) >= 0) and ((t_+win_1) <= len(this_chan)): # window must lie inside the recording
                    data_sweep = this_chan[(t_+win_0):(t_+win_1)]
                    
                    if chan_==0:
                        ax0.plot(xtime,data_sweep,color=channel_colors[chan_],linewidth=2,alpha=0.5)
                    if chan_==1:
                        ax1.plot(xtime,data_sweep,color=channel_colors[chan_],linewidth=2,alpha=0.5)
                    
    ax0.set_ylim(yrange0[0],yrange0[1])
    ax1.set_ylim(yrange1[0],yrange1[1])
    ax0.set_xlabel('seconds')

    # add minor ticks on the time axis and draw major/minor grid lines on this figure's axes
    ax0.xaxis.set_minor_locator(AutoMinorLocator(10))
    ax0.grid(which='major', color='gray', linestyle='-')
    ax0.grid(which='minor', color='gray', linestyle=':')

w = interactive_output(update_plot, {'chan_list':select_channels,
                                     'trial_list':select_trials, 
                                     'cluster_':select_clusters,
                                     'xrange':slider_xrange,
                                     'yrange0':slider_yrange_0,
                                     'yrange1':slider_yrange_1});
display(ui_trials,w,ui_range)

Written by Dr. Krista Perks for courses taught at Wesleyan University.