Open In Colab

Data Explorer

Setup

toc

Import and define functions

#@title {display-mode: "form" }

#@markdown Run this code cell to import packages and define functions 
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import ndimage
from scipy.optimize import curve_fit
from scipy.signal import hilbert,medfilt,resample, find_peaks, unit_impulse
import seaborn as sns
from datetime import datetime,timezone,timedelta
pal = sns.color_palette(n_colors=15)
pal = pal.as_hex()
import matplotlib.pyplot as plt
import random
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.io import wavfile
from numpy import NaN
from sklearn.metrics import mean_squared_error

from pathlib import Path

from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
from ipywidgets import widgets, interact, interactive, interactive_output
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")

print('Task completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

Mount Google Drive

#@title {display-mode: "form" }

#@markdown Run this cell to mount your Google Drive.

# Expose Google Drive under /content/drive so that file paths copied from
# the Colab file manager can be pasted into the data-import form.
import google.colab.drive as drive
drive.mount('/content/drive')

print('Task completed at', datetime.now(timezone(-timedelta(hours=5))))

Import data

Import data digitized with a National Instruments USB-6211 DAQ (NI-DAQ) and recorded using Bonsai-rx as a .bin file.

#@title {display-mode: "form" }

## If need to import from Spike Recorder App as .wav
# filepath = '/Users/kperks/Music/Spike Recorder/BYB_Recording_2022-11-26_12.44.42.wav'
# number_channels = 1

# fs, data = wavfile.read(filepath)
# data = data.reshape(-1,number_channels)
# data_dur = np.shape(data)[0]/fs
# time = np.arange(np.shape(data)[0])/fs

#@markdown Specify the file path 
#@markdown to your recorded data on Drive (find the filepath in the colab file manager):

filepath = "full filepath goes here"  #@param 
# filepath = '/Volumes/Untitled/BIOL247/data/crayfish-synaptic-connectivity/KP_20221113/spont-simult_A4_2_good-2input.bin' 

#@markdown Specify the sampling rate and number of channels recorded.

sampling_rate = None #@param
number_channels = None #@param

# downsample = False #@param
# newfs = 10000 #@param

#@markdown After you have filled out all form fields, 
#@markdown run this code cell to load the data. 

filepath = Path(filepath)

# No need to edit below this line
#################################
# Fail early with a readable message instead of a cryptic numpy error.
if sampling_rate is None or number_channels is None:
    raise ValueError('Fill in sampling_rate and number_channels in the form above, then re-run this cell.')
if not filepath.is_file():
    raise FileNotFoundError(f'No file found at {filepath}; copy the full path from the Colab file manager.')

# The .bin file is read as a flat float64 stream; reshape(-1, number_channels)
# treats consecutive values as the channels of one sample (one row per sample).
data = np.fromfile(filepath, dtype = np.float64)
if data.size % number_channels:
    raise ValueError(f'{data.size} values cannot be split evenly into {number_channels} channels; check number_channels.')
data = data.reshape(-1,number_channels)
data_dur = data.shape[0]/sampling_rate
print('duration of recording was %0.2f seconds' %data_dur)

fs = sampling_rate
# if downsample:
#     newfs = 10000 #downsample emg data
#     chunksize = int(sampling_rate/newfs)
#     data = data[0::chunksize,:]
#     fs = int(np.shape(data)[0]/data_dur)

# Sample k was taken at k/fs seconds. np.linspace(0, data_dur, N) would include
# data_dur as an endpoint and make the spacing slightly larger than 1/fs.
time = np.arange(data.shape[0])/fs

print('Data upload completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

Visualize Raw Data

Use this visualization tool if desired to get a sanity check that the data you thought you imported is actually the data that got imported.

#@title {display-mode: "form"}

#@markdown Run this code cell to plot imported data. <br> 
#@markdown Use the range slider to scroll through the data in time.
#@markdown Use the channel slider to choose which channel to plot
#@markdown Be patient with the range refresh... the more data you are plotting the slower it will be. 

# Time window (seconds) to display.
slider_xrange = widgets.FloatRangeSlider(
    min=0, max=data_dur, value=(0, data_dur), step=0.5,
    readout=True, continuous_update=False,
    description='Time Range (s)',
    layout=widgets.Layout(width='600px'),
)

# Column of `data` to display.
slider_chan = widgets.IntSlider(
    min=0, max=number_channels - 1, value=0, step=1,
    continuous_update=False,
    description='channel',
    layout=widgets.Layout(width='300px'),
)


def update_plot(x, chan):
    """Draw channel `chan` of `data` against `time` for the window x = (start, stop) in seconds."""
    # num=1 reuses a single figure rather than opening a new one per slider move.
    fig, ax = plt.subplots(figsize=(10, 5), num=1)
    lo_s, hi_s = x
    span = slice(int(lo_s * fs), int(hi_s * fs))
    ax.plot(time[span], data[span, chan])


w = interact(update_plot, x=slider_xrange, chan=slider_chan);

For a more extensive RAW Data Explorer than the one provided in the above figure, use the DataExplorer.py application found in the howto section of the course website.

Section I. Motor Unit Coding and Recruitment

How does motor unit activity within a muscle change as the load requirements of the movement change?

Detect motor unit events

SciPy provides a function (`scipy.signal.find_peaks`) for detecting “peaks” in a signal. On its own, however, it detects every local maximum. Therefore, the function takes arguments that specify the minimum height that can count as a peak and the minimum acceptable interval between independent peaks.

First, we will detect all the peaks in the signal within a set of thresholds. This will give the time of each peak.

#@title {display-mode: "form"}

#@markdown Indicate which channel has the EMG signal (should be 0).

spike_channel = 0 #@param

#@markdown Then, run the code cell to create an interactive plot with a slider to scroll 
#@markdown through the raw data and set an upper and lower peak detection threshold.
#@markdown You can set the polarity of the peak detection: upward (1) or downward (-1) peaks. 
#@markdown Peak times (according to your threshold) will be plotted using red markers. <br>

# Signal extremes, computed once and reused for the slider ranges below.
_emg = data[:, spike_channel]
_emg_min = np.min(_emg)
_emg_max = np.max(_emg)
# Thresholds are entered as magnitudes and multiplied by the polarity, so their
# range must cover the largest excursion in either direction (+0.1 headroom).
_thresh_limit = np.max([_emg_max, np.abs(_emg_min)]) + 0.1

slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,data_dur),
    step=0.01,
    readout_format='.2f',
    continuous_update=False,
    readout=True,
    description='xrange (s)'
)
slider_xrange.layout.width = '600px'

slider_yrange = widgets.FloatRangeSlider(
    min=_emg_min-0.1,
    max=_emg_max+0.1,
    value=[_emg_min-0.1,_emg_max+0.1],
    step=0.01,
    readout_format='.2f',
    continuous_update=False,
    readout=True,
    description='yrange'
)
slider_yrange.layout.width = '600px'

slider_threshold_low = widgets.FloatSlider(
    min=0,
    max=_thresh_limit,
    value=0,
    step=0.001,
    readout_format='.3f',
    continuous_update=False,
    readout=True,
    description='lower threshold')
slider_threshold_low.layout.width = '600px'

slider_threshold_high = widgets.FloatSlider(
    min=0,
    max=_thresh_limit,
    value=_thresh_limit,
    step=0.001,
    readout_format='.3f',
    continuous_update=False,
    readout=True,
    description='upper threshold')
slider_threshold_high.layout.width = '600px'

radio_polarity = widgets.RadioButtons(
    options=[1, -1],
    value=1,
    description='peaks polarity',
    disabled=False
)

iei_text = widgets.Text(
    value='0.005',
    placeholder='0.005',
    description='min IEI (seconds)',
    style = {'description_width': '200px'},
    disabled=False
)


def update_plot(xrange,yrange,thresh_low_,thresh_high_,polarity,iei):
    """Plot the EMG trace with the thresholds and detected peaks, and return the peaks.

    Parameters
    ----------
    xrange : (float, float)
        Time window to display, in seconds.
    yrange : (float, float)
        y-axis limits of the plot.
    thresh_low_, thresh_high_ : float
        Peak-height bounds, measured on ``data * polarity``. A peak is kept when
        ``thresh_low_ <= height < thresh_high_``.
    polarity : int
        1 detects upward peaks, -1 detects downward peaks.
    iei : str
        Minimum inter-event interval in seconds (contents of the text box).

    Returns
    -------
    spike_times : ndarray
        Times (s) of the kept peaks.
    df_props : DataFrame
        One row per kept peak, in the same order as ``spike_times``, with columns
        'height' (on ``data * polarity``), 'spikeT' (s) and 'spikeInd' (sample index).
    """
    fig, ax = plt.subplots(figsize=(10,6),num=1) # reuse figure 1 instead of creating a new one per update
    fig.tight_layout()

    win_0 = int(xrange[0]*fs)
    win_1 = int(xrange[1]*fs)

    # Slice the shared time vector so x and y always have the same length.
    ax.plot(time[win_0:win_1],data[win_0:win_1,spike_channel],color='black',linewidth=1)
    ax.set_ylim(yrange[0],yrange[1])

    ax.hlines(thresh_low_*polarity, xrange[0],xrange[1],linestyle='--',color='green',zorder=3)
    ax.hlines(thresh_high_*polarity, xrange[0],xrange[1],linestyle='--',color='orange',zorder=3)

    # find_peaks rejects distance < 1 sample, so clamp very small IEIs.
    min_distance = max(1, float(iei)*fs)
    peak_inds, props = find_peaks(data[:,spike_channel]*polarity,height=thresh_low_,distance=min_distance)
    heights = props['peak_heights']

    # Apply the upper threshold to BOTH outputs so spike_times and df_props
    # describe the same peaks, row for row.
    keep = heights < thresh_high_
    spike_times = peak_inds[keep]/fs
    df_props = pd.DataFrame({
            'height': heights[keep],
            'spikeT' : spike_times,
            'spikeInd' : peak_inds[keep],
                })

    inwin_spikes = spike_times[(spike_times>xrange[0]) & (spike_times<xrange[1])]
    ax.scatter(inwin_spikes,[np.mean(data[:,spike_channel])] * len(inwin_spikes),
          zorder=3,color='red',s=20)

    return spike_times,df_props


w_spikes_ = interactive(update_plot, xrange=slider_xrange, 
             yrange=slider_yrange,
             thresh_low_=slider_threshold_low,thresh_high_=slider_threshold_high,polarity=radio_polarity,iei=iei_text);

display(w_spikes_)
#@title {display-mode: "form"}

#@markdown Run this cell to finalize the list of spike times after settling on a threshold in the interactive plot. 
# w_spikes_.result holds the (spike_times, df_props) tuple returned by the peak-detection
# plot's update function. NOTE(review): it is presumably None if no update has completed
# successfully — confirm against the ipywidgets `interactive` docs.
if w_spikes_.result is None:
    raise RuntimeError('No peak-detection result yet: run the peak-detection cell and make sure its plot rendered without errors.')
spike_times,df_props = w_spikes_.result
# Work on a copy so later cells that add columns do not modify the widget's stored result.
df_props = df_props.copy()

Define Bout Times

#@title {display-mode: "form"}

#@markdown For this experiment, you don't have precise trial onset markers, but you should be able to 
#@markdown estimate the start and stop time of bouts (bout = set of trials under the same condition). 
#@markdown Specify the bout ranges as follows: [[start of bout 1, end of bout 1], ... , [start of bout n, end of bout n]] <br>
#@markdown For plotting purposes, make sure the bouts are listed in ascending/descending order of the weight lifted.

bouts_list = [[None,None],[None,None],[None,None]] #@param

#@markdown Then run this code cell to programmatically define the bout time ranges 
#@markdown and organize the processed data according to bout.

_unset_bouts = [i for i, bout_ in enumerate(bouts_list) if bout_[0] is None or bout_[1] is None]
if _unset_bouts:
    raise ValueError(f'Fill in start and stop times (s) for bout(s) {_unset_bouts} in the form above, then re-run this cell.')

# Reset the labels so re-running with new ranges leaves no stale bout numbers;
# peaks outside every bout stay NaN and are ignored by the per-bout plots.
df_props['bout'] = np.nan
# Only peaks present in spike_times (those below the upper threshold) get a bout.
# The mask is built from df_props itself so it always has one entry per row.
_accepted = df_props['spikeT'].isin(spike_times)
for i, bout_ in enumerate(bouts_list):
    in_bout = _accepted & (df_props['spikeT'] > bout_[0]) & (df_props['spikeT'] < bout_[1])
    df_props.loc[in_bout,'bout'] = i

# Instantaneous rate = 1 / interval to the previous peak within the same bout, so an
# interval never spans a bout boundary (rows outside any bout get NaN).
df_props['rate'] = 1/df_props.groupby('bout')['spikeT'].diff()

Plot processed data

#@title {display-mode: "form"}

#@markdown Run this cell to create a plot of instantaneous peak rate and a plot of peak amplitude at each peak time.

slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,data_dur),
    step=0.01,
    readout_format='.2f',
    continuous_update=False,
    readout=True,
    description='xrange (s)'
)
slider_xrange.layout.width = '600px'

slider_yrange = widgets.FloatRangeSlider(
    min=np.min(data[:,spike_channel])-0.1,
    max=np.max(data[:,spike_channel])+0.1,
    value=[np.min(data[:,spike_channel])-0.1,np.max(data[:,spike_channel])+0.1],
    step=0.01,
    readout_format='.2f',
    continuous_update=False,
    readout=True,
    description='yrange'
)
slider_yrange.layout.width = '600px'


select_bouts = widgets.Select(
    options=np.arange(len(bouts_list)), # one entry per bout in bouts_list
    value=0,
    description='Bouts',
    disabled=False
)

def update_plot(xrange,yrange,bout_):
    """Plot instantaneous peak rate (top) and peak amplitude (bottom) within one bout.

    Parameters
    ----------
    xrange : (float, float)
        Time limits (s); the view is the overlap of this range with the selected bout.
    yrange : (float, float)
        Value of the y-range slider; not used by this plot.
    bout_ : int
        Index into ``bouts_list``.
    """
    hfig,ax = plt.subplots(nrows=2,ncols=1,figsize=(10,8))
    start_, stop_ = bouts_list[bout_]

    bout_times = spike_times[(spike_times>start_) & (spike_times<stop_)]
    # Look rows up by time, so the heights match bout_times even if df_props has
    # a different number of rows than spike_times.
    bout_rows = df_props[df_props['spikeT'].isin(bout_times)]

    ax[0].scatter(bout_times[1:],1/np.diff(bout_times))
    ax[0].set_xlabel('peak times',fontsize=14)
    ax[0].set_ylabel('peak rate',fontsize=14)

    ax[1].scatter(bout_rows['spikeT'],bout_rows['height'])
    ax[1].set_xlabel('peak times',fontsize=14)
    ax[1].set_ylabel('peak amplitude',fontsize=14)

    # Zoom to the part of the selected time range that lies inside the bout.
    left, right = max(xrange[0], start_), min(xrange[1], stop_)
    if left < right:
        for panel in ax:
            panel.set_xlim(left, right)

w_spikes_processed = interactive(update_plot, xrange=slider_xrange, 
             yrange=slider_yrange,bout_ = select_bouts);

display(w_spikes_processed)

Since we don’t have precise trial onset markers, let’s just collapse across trials to examine the median (and 95% confidence intervals) of peak rate and amplitude within each bout.

#@title {display-mode: "form"}

#@markdown Run this cell to create a plot comparing the peak rate and amplitude across bouts <br> 
#@markdown > Note that all activity within each bout will count toward the results.

import re

fig,ax = plt.subplots()
ax2=ax.twinx()

# Compare (major, minor) numerically: checking only the minor number would send
# seaborn 1.x down the pre-0.12 path. seaborn 0.12 replaced `ci=` with `errorbar=`.
_sns_version = tuple(int(part) for part in re.match(r'(\d+)\.(\d+)', sns.__version__).groups())
if _sns_version >= (0, 12):
    _summary_kwargs = dict(estimator='median', errorbar=('ci', 95))
else:
    _summary_kwargs = dict(estimator=np.median, ci=95)

# Median rate on the left axis (green), median amplitude on the right axis (purple).
for _y, _axis, _color, _label in [('rate', ax, 'green', 'rate'),
                                  ('height', ax2, 'purple', 'amplitude')]:
    sns.pointplot(data=df_props, x="bout", y=_y, ax=_axis, color=_color, **_summary_kwargs)
    _axis.set_ylabel(_label,color=_color)
    _axis.tick_params(axis ='y', labelcolor = _color)

sns.despine(fig,top=True, right=False, left=False, bottom=False,)

Section II. Fatigue

To process the EMG signal in this experiment, we will calculate a moving RMS value.

The signal throughout the entire recording will be analyzed, so make sure that the data was collected efficiently.

With the interactive plot generated below, you can zoom in on any time range of the recording and move the vertical (time) and horizontal (RMS) marker lines to read out time and RMS values, respectively.

#@title {display-mode: "form"}

#@markdown Run this code cell to create an interactive plot of the moving RMS value of the recorded EMG signal.

slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,data_dur),
    step= 0.5,
    readout=True,
    continuous_update=False,
    description='Time Range (s)',
    style = {'description_width': '100px'})
slider_xrange.layout.width = '600px'

slider_chan = widgets.IntSlider(
    min=0,
    max=number_channels-1,
    value=0,
    step= 1,
    continuous_update=False,
    description='channel')
slider_chan.layout.width = '300px'

window_text = widgets.Text(
    value='1',
    placeholder='1',
    description='window duration (seconds)',
    style = {'description_width': '200px'},
    disabled=False
)

step_text = widgets.Text(
    value='0.5',
    placeholder='0.5',
    description='window steps (seconds)',
    style = {'description_width': '200px'},
    disabled=False
)

# An RMS value can never exceed the largest absolute sample, so that bounds the
# marker slider; use every channel because the channel slider can select any of them.
slider_rms = widgets.FloatSlider(
    min=0,
    max=np.max(np.abs(data)),
    value=0,
    step=0.001,
    readout_format='.3f',
    continuous_update=False,
    description='rms value (purple)',
    style = {'description_width': '200px'})
slider_rms.layout.width = '600px'

slider_time = widgets.FloatSlider(
    min=0,
    max=data_dur,
    value=0,
    step= 0.1,
    continuous_update=False,
    description='time value (green; s)',
    style = {'description_width': '200px'})
slider_time.layout.width = '600px'


def update_plot(xrange,chan,window_size,step_size,rms_,t_):
    """Plot a moving RMS of one channel with adjustable time/RMS marker lines.

    Parameters
    ----------
    xrange : (float, float)
        x-axis limits, in seconds.
    chan : int
        Column of ``data`` to analyse.
    window_size, step_size : str
        Window length and hop between window starts, in seconds (text-box contents).
    rms_ : float
        y position of the horizontal (purple) marker.
    t_ : float
        x position of the vertical (green) marker, in seconds.
    """
    fig, ax = plt.subplots(figsize=(10,5),num=1) # reuse figure 1 instead of creating a new one per update

    # Seconds -> samples. Use at least one sample so the hop is never 0.
    window_size = max(1, int(float(window_size)*fs))
    step_size = max(1, int(float(step_size)*fs))

    a = data[:,chan]

    # Start index of every window that fits completely inside the recording.
    starts = np.arange(0, len(a) - window_size + 1, step_size)
    # RMS = sqrt(mean(x**2)) over each window.
    rms_data = np.array([np.sqrt(np.mean(a[s:s+window_size]**2)) for s in starts])
    # Place each value at the centre of its window so the time readout is accurate.
    xtime = (starts + window_size/2)/fs

    ax.plot(xtime,rms_data)
    ax.set_ylabel('rms',fontsize=14)
    ax.set_xlabel('time (s)',fontsize=14)
    ax.set_xlim(xrange[0],xrange[1])

    bottom,top = ax.get_ylim()
    ax.vlines(t_,bottom,top,color = 'green')

    left,right = ax.get_xlim()
    ax.hlines(rms_,left,right,color='purple')

w = interactive_output(update_plot, {'xrange':slider_xrange,
                                     'chan':slider_chan,
                                     'window_size':window_text,
                                     'step_size':step_text,
                                     'rms_':slider_rms,
                                     't_':slider_time});
display(slider_chan,slider_xrange,w, 
        widgets.HBox([window_text,step_text]),
        slider_rms,slider_time)

Additional Resources (not needed)

#@title {display-mode: "form"}

slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,data_dur),
    step= 0.5,
    readout=True,
    continuous_update=False,
    description='Time Range (s)')
slider_xrange.layout.width = '600px'

slider_chan = widgets.IntSlider(
    min=0,
    max=number_channels-1,
    value=0,
    step= 1,
    continuous_update=False,
    description='channel')
slider_chan.layout.width = '300px'

slider_start = widgets.FloatSlider(
    min=0,
    max=data_dur,
    value=0,
    step= 0.1,
    readout=True,
    continuous_update=False,
    description='start (s)')
slider_start.layout.width = '600px'

slider_stop = widgets.FloatSlider(
    min=0,
    max=data_dur,
    value=data_dur,
    step= 0.1,
    readout=True,
    continuous_update=False,
    description='stop (s)')
slider_stop.layout.width = '600px'

rms_readout = widgets.Label(
    value=f'RMS within this window = {np.nan}'
)
rms_readout.layout.width = '600px'

def update_plot(xrange,chan,start,stop):
    """Plot one channel, shade [start, stop] and show the RMS of that span in rms_readout.

    Parameters
    ----------
    xrange : (float, float)
        x-axis limits, in seconds.
    chan : int
        Column of ``data`` to plot and measure.
    start, stop : float
        Edges of the measurement window, in seconds (either order is accepted).
    """
    fig, ax = plt.subplots(figsize=(10,5),num=1) # reuse figure 1 instead of creating a new one per update

    ax.plot(time, data[:,chan])
    ax.set_xlim(xrange[0],xrange[1])
    ax.axvspan(start, stop, color = 'black', alpha=0.1)

    # The sliders are independent, so stop may be left of start; sort the indices.
    starti, stopi = sorted((int(start*fs), int(stop*fs)))
    segment = data[starti:stopi,chan]

    # RMS = sqrt(mean(x**2)); an empty window has no defined RMS.
    rms_ = np.sqrt(np.mean(segment**2)) if segment.size else np.nan

    rms_readout.value = f'RMS within this window = {rms_}'

w = interactive_output(update_plot, {'xrange':slider_xrange,
                                     'chan':slider_chan,
                                     'start':slider_start,
                                     'stop':slider_stop});
display(rms_readout,w,widgets.VBox([slider_xrange, slider_chan,slider_start,slider_stop ]))