Giant Fiber Excitability

In this lab, you recorded activity in the ventral nerve cord. Using electrical stimulation, you evoked action potentials in the medial and lateral giant fibers. The main goal today is to analyze basic physiological markers of excitability. Much of the analysis will compare trials across which the stimulation parameters were varied.

Setup


Import and define functions

#@title {display-mode: "form" }

#@markdown Run this code cell to import packages and define functions 
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import ndimage
from scipy.signal import hilbert,medfilt,resample, find_peaks, unit_impulse
from scipy.optimize import curve_fit
import seaborn as sns
from datetime import datetime,timezone,timedelta
pal = sns.color_palette(n_colors=15)
pal = pal.as_hex()
import matplotlib.pyplot as plt
import random
from numpy import nan as NaN  # np.NaN was removed in NumPy 2.0; np.nan works in all versions

from pathlib import Path

from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
from ipywidgets import widgets, interact, interactive,interactive_output
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")


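def hyper_fit(t,r,c):
    # hyperbolic strength-duration model: threshold amplitude a at stimulus
    # duration t, given rheobase r and chronaxie c (a = r + r*c/t)
    return r + ((r*c)/t)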

print('Task completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

Mount Google Drive

#@title {display-mode: "form" }

#@markdown Run this cell to mount your Google Drive.

from google.colab import drive
drive.mount('/content/drive')

print('Task completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

Strength-Duration Curves

First, enter your stimulus duration and amplitude data into a .csv file. Then import that .csv file so that you can fit a hyperbolic strength-duration curve to the data. Each row in the table should be a different stimulus condition, and each column should hold either the stimulus durations or the stimulus amplitudes for one sample (a different neuron and/or worm). The first row of the spreadsheet must contain a unique header (name) for each column.
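As an illustration of that layout, the sketch below builds a small table with pandas and writes it to CSV. The column names 'duration' and 'amplitude' match the default x_column and y_column in the next cell; the second pair of columns ('duration_2', 'amplitude_2') and every number shown are hypothetical placeholders, not recorded data. When one sample has fewer conditions, its remaining cells are left empty and are read back as NaN.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical placeholder values: one row per stimulus condition,
# one duration column and one amplitude column per sample.
table = pd.DataFrame({
    "duration":    [0.05, 0.1, 0.2, 0.5, 1.0],      # ms, sample 1
    "amplitude":   [3.2, 1.9, 1.2, 0.8, 0.65],      # V, sample 1
    "duration_2":  [0.05, 0.1, 0.2, 0.5, np.nan],   # ms, sample 2 (one fewer condition)
    "amplitude_2": [4.1, 2.4, 1.5, 1.0, np.nan],    # V, sample 2
})

# Write to an in-memory buffer here; in practice pass a file path such as 'rheobase.csv'.
buffer = io.StringIO()
table.to_csv(buffer, index=False)   # index=False keeps only the named columns
buffer.seek(0)

df = pd.read_csv(buffer)
print(df.columns.tolist())
print(df["duration_2"].isna().sum())  # missing conditions come back as NaN
```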

Read, fit, and plot strength-duration threshold data

The following code cell will import the data from your .csv file and plot it as a scatterplot. Additionally, the data will be fit to the following hyperbolic equation:

\[ a = r + \frac{r c}{t} \]

where \(a\) is the stimulus amplitude, \(r\) is the rheobase, \(c\) is the chronaxie, and \(t\) is the stimulus duration.
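Two limits of this equation explain the reference lines drawn by the next cell: for long pulses the threshold approaches the rheobase, and at a duration equal to the chronaxie the threshold is exactly twice the rheobase.

\[ \lim_{t \to \infty} a = r, \qquad a\big|_{t=c} = r + \frac{r c}{c} = 2r \]

This is why the solid green line marks \(r\), the dashed green line marks \(2r\), and the purple line rises from the x-axis to \(2r\) at \(t = c\).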

#@title {display-mode:"form"}

#@markdown Specify the filepath to a csv file
filepath = 'full filepath goes here' #@param
# filepath = '/Users/kperks/Documents/Teaching/Neurophysiology-Lab/modules/earthworm-giant-fiber-ap/rheobase.csv' #@param
#@markdown Specify the header name of the column you want for your x points
x_column = 'duration' #@param
#@markdown Specify the header name of the column you want for your y points
y_column = 'amplitude' #@param

#@markdown Then run this code cell to plot the raw data and the estimated strength-duration curve for this one sample. 
#@markdown The 'r' and 'c' values will be printed out above the data plot.

df = pd.read_csv(filepath)
# drop rows missing either value so that durations and amplitudes stay paired
df_fit = df[[x_column, y_column]].dropna()

hfig,ax = plt.subplots(figsize=(6,5))
sns.scatterplot(x=x_column,y=y_column,data=df_fit,color='black',ax=ax);
ax.set_xscale('log')

# fit a = r + r*c/t; params[0] is the rheobase (r) and params[1] is the chronaxie (c)
params, covs = curve_fit(hyper_fit, df_fit[x_column], df_fit[y_column], maxfev = 5000)

t_ = np.arange(0.01,3,0.01)
a_ = hyper_fit(t_,params[0],params[1])
ax.plot(t_,a_,color = 'gray',linestyle='--')
# solid green: rheobase; dashed green: 2x rheobase; purple: chronaxie
ax.hlines(params[0],t_[0],t_[-1],color = 'green')
ax.hlines(2*params[0],t_[0],t_[-1],color = 'green',linestyle='--')
ax.vlines(params[1],0,2*params[0],color='purple',linestyle='--')
ax.set_ylabel('amplitude (V)')
ax.set_xlabel('duration (ms)');
ax.set_ylim(0,10)

print(f'r = {params[0]}; c = {params[1]}')
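If you want to confirm that the fit recovers sensible values before using your own data, a minimal sketch is to fit hyper_fit to synthetic thresholds generated from known parameters. The values r_true = 0.5 V and c_true = 0.3 ms, the durations, and the noise level are arbitrary assumptions, not measurements. Passing an initial guess p0 is an optional addition here; the notebook cell relies on curve_fit's default starting value of 1 for each parameter.

```python
import numpy as np
from scipy.optimize import curve_fit


def hyper_fit(t, r, c):
    # Same hyperbolic model as the notebook: a = r + r*c/t
    return r + (r * c) / t


# Hypothetical "true" parameters, chosen only for illustration.
r_true, c_true = 0.5, 0.3                                # V, ms
durations = np.array([0.05, 0.1, 0.2, 0.5, 1.0, 2.0])    # ms

# Synthetic thresholds with a small amount of seeded Gaussian noise.
rng = np.random.default_rng(0)
amplitudes = hyper_fit(durations, r_true, c_true) + rng.normal(0, 0.005, durations.size)

# Initial guess: rheobase near the smallest threshold, chronaxie of 1 ms.
p0 = [amplitudes.min(), 1.0]
params, covs = curve_fit(hyper_fit, durations, amplitudes, p0=p0, maxfev=5000)
r_fit, c_fit = params
perr = np.sqrt(np.diag(covs))   # one-standard-deviation uncertainty of each parameter

print(f"r = {r_fit:.3f} +/- {perr[0]:.3f} V; c = {c_fit:.3f} +/- {perr[1]:.3f} ms")
```

The square roots of the diagonal of the covariance matrix (perr) give rough standard errors for r and c, which may be worth reporting alongside the estimates.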

Compare strength-duration estimates across samples

#@title {display-mode:"form"}

#@markdown Specify the rheobase for each sample
r = [None] #@param
#@markdown Specify the chronaxie for each sample
c = [None] #@param
#@markdown Specify a label for each sample
label = ['none'] #@param

##@markdown Specify whether you want to plot the x-axis as 'log' or 'linear'
x_axis_scale = 'log'

#@markdown Then run this code cell to plot the estimated strength-duration curve for each sample. 
#@markdown The chronaxie will be marked with a scatter dot and a gray line connecting it to the x-axis for each sample.
#@markdown The x-axis will be log-scaled.

t_ = np.arange(0.01,3,0.01)

hfig,ax = plt.subplots(figsize=(6,5))

for r_,c_,label_ in zip(r,c,label):
    # strength-duration curve for this sample
    a_ = hyper_fit(t_,r_,c_)
    ax.plot(t_,a_,label=label_)
    # mark the chronaxie: the duration at which the threshold is twice the rheobase
    ax.scatter(c_,2*r_,zorder=3)
    ax.vlines(c_,0,2*r_,color='gray')

ax.set_ylabel('amplitude (V)')
ax.set_xlabel('duration (ms)');
ax.set_ylim(0,10)
ax.set_xscale(x_axis_scale)
plt.legend();
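Rather than running the single-sample cell once per sample, the r, c, and label lists for the comparison cell could be built in one pass. The sketch below uses a hypothetical helper, fit_samples, that assumes you list the (duration column, amplitude column) pair for each sample yourself; the column names and the noise-free synthetic values in the usage example are placeholders.

```python
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit


def hyper_fit(t, r, c):
    # Same hyperbolic model as the notebook: a = r + r*c/t
    return r + (r * c) / t


def fit_samples(df, sample_columns):
    """Hypothetical helper: fit hyper_fit separately for each sample.

    sample_columns maps a sample label to its (duration column, amplitude column).
    Returns lists r, c, label in the format the comparison cell expects.
    """
    r, c, label = [], [], []
    for name, (dur_col, amp_col) in sample_columns.items():
        paired = df[[dur_col, amp_col]].dropna()   # keep only complete rows
        params, _ = curve_fit(hyper_fit, paired[dur_col], paired[amp_col], maxfev=5000)
        r.append(float(params[0]))
        c.append(float(params[1]))
        label.append(name)
    return r, c, label


# Usage with synthetic, noise-free placeholder data for two samples.
t = np.array([0.05, 0.1, 0.2, 0.5, 1.0])
df = pd.DataFrame({
    "duration": t, "amplitude": hyper_fit(t, 0.5, 0.3),
    "duration_2": t, "amplitude_2": hyper_fit(t, 0.8, 0.2),
})
r, c, label = fit_samples(df, {"worm 1": ("duration", "amplitude"),
                               "worm 2": ("duration_2", "amplitude_2")})
print(label, np.round(r, 3), np.round(c, 3))
```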

Paired Pulse Protocol

In this section, you will load the raw data acquired and saved as a .bin file using Bonsai-rx. In this experiment, you delivered paired stimulus pulses instead of single pulses, so a single trial consists of two stimulus pulses. If there were no glitches in stimulation during the experiment or in detection during the “stimulus times” processing step, then for n detected pulses numbered from 0 (with n even), the even pulses (0, 2, …, n-2) should be the ‘control’ pulses and the odd pulses (1, 3, …, n-1) should be the ‘experimental’ pulses. You can adjust your records as needed if there were aberrant pulses.
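Assuming every pulse was detected, splitting the onsets into control and experimental pulses and computing each trial's inter-pulse interval takes only array slicing. The onset times and the max_isi_ms cutoff below are hypothetical; in the notebook the onsets would come from trial_times.

```python
import numpy as np

# Hypothetical stimulus onset times (s): three paired-pulse trials with
# inter-pulse intervals of 20, 10, and 5 ms (placeholder values).
trial_times = np.array([1.000, 1.020, 3.000, 3.010, 5.000, 5.005])

control = trial_times[0::2]        # even pulses: 0, 2, 4, ...
experimental = trial_times[1::2]   # odd pulses: 1, 3, 5, ...

# Guard against an unpaired final pulse (odd number of detected pulses).
n_pairs = min(len(control), len(experimental))
control, experimental = control[:n_pairs], experimental[:n_pairs]

# Inter-stimulus interval of each trial, in milliseconds.
isi_ms = 1000 * (experimental - control)

# Flag pairs whose interval is implausibly long for this protocol
# (max_isi_ms is an assumed cutoff; set it from your own ISI range).
max_isi_ms = 500
suspect_pairs = np.flatnonzero(isi_ms > max_isi_ms)
print(np.round(isi_ms, 3), suspect_pairs)
```

A non-empty suspect_pairs usually means a pulse was missed or doubled, which shifts the even/odd assignment for every later pulse.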

You can also use the data visualization tools in this section to get plots of example trials from your rheobase/chronaxie data.

Import data

Import data digitized with an NI USB-6211 DAQ and recorded using Bonsai-rx as a .bin file

If you would like to explore the analysis for this lab, but do not have data, you can download examples for the following experiments using the linked shared files:
- Medial Giant stimulus-duration pulse examples
- Medial Giant Fiber Paired Pulse stimulation
- Lateral Giant stimulus-duration pulse examples
- Lateral Giant Fiber Paired Pulse stimulation
For all examples, the sampling rate was 30000 Hz, with two channels (channel 0 was the nerve signal and channel 1 was the stimulus monitor). The distance between the stimulating electrodes and the channel 0 recording electrodes was not measured.

#@title {display-mode: "form" }

#@markdown Specify the file path 
#@markdown to your recorded data in the colab runtime (find the filepath in the colab file manager):

filepath = "full filepath goes here"  #@param 
# filepath = '/Volumes/Untitled/BIOL247/data/earthworm-giant-fiber-cap/diff_cv_0.bin'  #@param 
# filepath = '/Volumes/Untitled/BIOL247/data/earthworm-giant-fiber-ap/KP_20220930_30khz_2chan/rheo-chr-mgf.bin'

#@markdown Specify the sampling rate and number of channels recorded.

sampling_rate = NaN #@param
number_channels = NaN #@param

# sampling_rate = 30000 #@param
# number_channels = 2 #@param

# downsample = False #@param
# newfs = 10000 #@param

#@markdown After you have filled out all form fields, 
#@markdown run this code cell to load the data. 

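filepath = Path(filepath)

# No need to edit below this line
#################################
# The .bin file is a flat array of float64 samples interleaved across channels,
# so reshape it to one row per sample and one column per channel.
data = np.fromfile(filepath, dtype = np.float64)
data = data.reshape(-1,int(number_channels))
data_dur = np.shape(data)[0]/sampling_rate
print('duration of recording was %0.2f seconds' %data_dur)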

fs = sampling_rate
# if downsample:
#     # newfs = 10000 #downsample emg data
#     chunksize = int(sampling_rate/newfs)
#     data = data[0::chunksize,:]
#     fs = int(np.shape(data)[0]/data_dur)

time = np.arange(np.shape(data)[0])/fs # sample times (s), spaced exactly 1/fs apart

print('Data upload completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))
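To see why the import cell reshapes the flat array with reshape(-1, number_channels), the sketch below writes a synthetic two-channel recording to a temporary .bin file and reads it back. It assumes, as the import cell does, that samples are stored as float64 values interleaved channel by channel; the sine wave and the 3 s duration are placeholders.

```python
import tempfile
from pathlib import Path

import numpy as np

fs = 30000          # Hz, the rate used for the example recordings
n_channels = 2      # channel 0: nerve, channel 1: stimulus monitor
n_samples = 3 * fs  # 3 s of synthetic signal (placeholder)

# Synthetic recording with shape (samples, channels).
t = np.arange(n_samples) / fs
fake = np.column_stack([np.sin(2 * np.pi * 5 * t), np.zeros(n_samples)])

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "example.bin"
    # C-order tofile writes row by row, i.e. interleaved channels:
    # s0ch0, s0ch1, s1ch0, s1ch1, ...
    fake.astype(np.float64).tofile(path)

    # Same steps as the import cell: flat read, then one column per channel.
    data = np.fromfile(path, dtype=np.float64).reshape(-1, n_channels)

data_dur = data.shape[0] / fs
time = np.arange(data.shape[0]) / fs   # sample times spaced exactly 1/fs apart
print(data.shape, data_dur)
```

If number_channels is set wrong, the reshape either fails or silently mixes channels, so a quick look at the raw plot is a useful check.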

Plot raw data

#@title {display-mode: "form"}

#@markdown Run this code cell to plot the imported data. <br> 
#@markdown Use the range slider to scroll through the data in time.
#@markdown Use the channel slider to choose which channel to plot.
#@markdown Be patient with the range refresh: the more data you are plotting, the slower it will be. 

slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,1),
    step= 1,
    readout=True,
    continuous_update=False,
    description='Time Range (s)')
slider_xrange.layout.width = '600px'

slider_chan = widgets.IntSlider(
    min=0,
    max=number_channels-1,
    value=0,
    step= 1,
    continuous_update=False,
    description='channel')
slider_chan.layout.width = '300px'

# a function that will modify the xaxis range
def update_plot(x,chan):
    fig, ax = plt.subplots(figsize=(10,5),num=1); #specify figure number so that it does not keep creating new ones
    starti = int(x[0]*fs)
    stopi = int(x[1]*fs)
    ax.plot(time[starti:stopi], data[starti:stopi,chan])

w = interact(update_plot, x=slider_xrange, chan=slider_chan);

For a more extensive raw data explorer than the one provided in the figure above, use the DataExplorer.py application found in the howto section of the course website.

Define bout and stimulus times

The time between stimulus onset and the action potential, and the time between the two stimulus pulses, are the critical parameters measured on each trial.

Our first task in processing and analyzing data from this experiment is to determine the stimulus onset times. You can then segment the data into separate bouts if the raw recording was not one continuous, successful protocol.
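The 'level crossing' option in the stimulus-times cell looks for moments when the stimulus monitor rises through the threshold. A minimal standalone sketch of that idea, using a hypothetical helper rising_crossings and a synthetic pulse train, is shown here; multiplying by polarity lets the same code find onsets of negative-going pulses.

```python
import numpy as np


def rising_crossings(signal, threshold, polarity=1):
    """Hypothetical helper: sample indices where polarity*signal rises above threshold.

    A signal that starts above threshold does not count as an onset at sample 0.
    """
    above = (polarity * np.asarray(signal)) > threshold
    # an onset is a sample that is above threshold while the previous one is not
    return np.flatnonzero(above[1:] & ~above[:-1]) + 1


fs = 30000  # Hz
# Synthetic stimulus monitor: two 0.2 ms square pulses, 10 ms apart (placeholders).
stim = np.zeros(fs // 10)
for onset_s in (0.020, 0.030):
    i = int(round(onset_s * fs))
    stim[i:i + 6] = 1.0      # 6 samples = 0.2 ms at 30 kHz

onsets = rising_crossings(stim, threshold=0.5)
print(onsets / fs)           # onset times in seconds
```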

Define stimulus times

#@title {display-mode: "form"}

#@markdown Run this cell to create an interactive plot with a slider to scroll 
#@markdown through the signal
#@markdown and set an appropriate event detection threshold  
#@markdown (you can do so based on level crossing or peaks). 

slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,1),
    step= 0.05,
    readout=True,
    continuous_update=False,
    description='Time Range (s)',
    style = {'description_width': '200px'})
slider_xrange.layout.width = '600px'

# slider_yrange = widgets.FloatRangeSlider(
#     min=np.min(stim)-0.5,
#     max=np.max(stim)+0.5,
#     value=[np.min(stim),np.max(stim)],
#     step=0.05,
#     continuous_update=False,
#     readout=True,
#     description='yrange',
#     style = {'description_width': '200px'})
# slider_yrange.layout.width = '600px'

select_channel = widgets.Select(
    options=np.arange(np.shape(data)[1]), # one option per recorded channel
    value=0,
    #rows=10,
    description='Channel used to detect events',
    style = {'description_width': '200px'},
    disabled=False
)

slider_threshold = widgets.FloatSlider(
    min=-2,
    max=2,
    value=0.2,
    step=0.001,
    readout_format='.3f',
    continuous_update=False,
    readout=True,
    description='event detection threshold',
    style = {'description_width': '200px'})
slider_threshold.layout.width = '600px'

detect_type_radio = widgets.RadioButtons(
    options=['peak', 'level crossing'],
    value='peak', # default detection method
    layout={'width': 'max-content'}, # If the items' names are long
    description='Type of event detection',
    style = {'description_width': '200px'},
    disabled=False
)

radio_polarity = widgets.RadioButtons(
    options=[1, -1],
    value=-1,
    description='peaks polarity',
    disabled=False,
    style = {'description_width': '200px'}
)

iei_text = widgets.Text(
    value='0.005',
    placeholder='0.005',
    description='min IEI (seconds)',
    style = {'description_width': '200px'},
    disabled=False
)

def update_plot(chan_,xrange,thresh_,detect_type,polarity,iei):
    fig, ax = plt.subplots(figsize=(10,5),num=1); #specify figure number so that it does not keep creating new ones
    
    signal = data[:,chan_]
    signal = signal-np.median(signal)
    
    iei = float(iei)
    
    if iei>0.001:
        d = iei*fs #minimum time allowed between distinct events
        
        if detect_type == 'peak':
            r = find_peaks(signal*polarity,height=thresh_,distance=d)
            trial_times = r[0]/fs
            # print(r[1])
            ax.scatter(trial_times,[r[1]['peak_heights']*polarity],marker='^',s=300,color='purple',zorder=3)
            
        if detect_type == 'level crossing':
            # boolean trace: True wherever the polarity-corrected signal is above threshold
            above = signal*polarity > thresh_
            # np.diff on a boolean array flags every change (onsets and offsets)
            threshold_crossings = np.diff(above, prepend=False)
            # get indices where threshold crossings are true
            tcross = np.argwhere(threshold_crossings)[:,0]
            # keep only onsets: crossings where the polarity-corrected signal is now above threshold
            # (this works for both polarities, unlike checking the slope of the raw signal)
            mask_ = above[tcross]
            # trial times are the threshold onsets
            trial_times = tcross[mask_]/fs
            ax.scatter(trial_times,[thresh_*polarity]*len(trial_times),marker='^',s=300,color='purple',zorder=3)

        starti = int(xrange[0]*fs)+1
        stopi = int(xrange[1]*fs)-1
        ax.plot(time[starti:stopi], signal[starti:stopi], color='black')
        
        # ax.plot(tmp,color='black')
        ax.hlines(thresh_*polarity, time[starti],time[stopi],linestyle='--',color='green')
        
        # ax.set_ylim(yrange[0],yrange[1])
        ax.set_xlim(xrange[0],xrange[1])
        

        ax.xaxis.set_minor_locator(AutoMinorLocator(5))

              
        return trial_times

w_trials_ = interactive(update_plot, chan_=select_channel, 
                        xrange=slider_xrange, 
                        thresh_=slider_threshold, detect_type = detect_type_radio, 
                        polarity=radio_polarity, iei = iei_text);
display(w_trials_)

#@title {display-mode: "form"}

#@markdown Run this cell to finalize the list of event times 
#@markdown after settling on a channel and threshold in the interactive plot. <br> 
#@markdown This stores the event times in an array called 'trial_times'. <br>
#@markdown NOTE: For shorter stimulus pulse durations, you may have to use the 'peak' detection method (in a separate pass).
trial_times = w_trials_.result

Define Bouts

#@title {display-mode: "form"}

#@markdown For this experiment, the entire file should be one long bout, 
#@markdown but if there were regions where something went wrong or that you want to exclude, you can specify the bouts that contain good data.
#@markdown Specify the list of bout ranges (in seconds) as follows: [[start of bout 0, end of bout 0],[start of bout 1, end of bout 1],...] <br>

bouts_list = [[NaN,NaN]] #@param
# bouts_list = [[2,10],[10,20],[20,30],[30,45],[45,55],[55,70],[70,85],[85,100],[100,120]]
# bouts_list = [[0,20]]

#@markdown Then run this code cell to define the list of bouts as 'bouts_list'.
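The analysis cell selects the events belonging to a bout with a strict comparison against the bout's start and end times. The sketch below isolates that rule in a hypothetical helper, events_in_bout, with placeholder times; note that an event falling exactly on a bout boundary is excluded from both neighboring bouts.

```python
import numpy as np


def events_in_bout(trial_times, bout):
    """Hypothetical helper: event times strictly inside one [start, end] bout (s)."""
    start, end = bout
    return trial_times[(trial_times > start) & (trial_times < end)]


trial_times = np.array([0.5, 1.0, 12.0, 12.02, 25.0, 25.01])   # placeholder onsets (s)
bouts_list = [[0, 10], [10, 30]]                                # placeholder bouts (s)

for i, bout in enumerate(bouts_list):
    print(i, events_in_bout(trial_times, bout))
```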

Analyze Data

Measure the raw data

Obtain necessary information from the raw signal time-locked to each event (which should be the stimulus pulse onsets).

To give you a ballpark, this data processing step took me about 20 minutes for a paired pulse experiment in which I tested 18 different ISI values with 2 trials at each value. Because each trial contains two pulses, that was 18 × 2 × 2 = 72 events in total.

#@title {display-mode:"form"}

#@markdown Run this code cell to create an interactive plot to  
#@markdown examine the raw signal time-locked to individual events (event_times).
#@markdown You can overlay multiple channels by selecting more than one.
#@markdown You can overlay multiple event times by selecting more than one. 
#@markdown (To select more than one item from an option menu, press the control/command key 
#@markdown while mouse clicking or shift while using up/down arrows)

slider_xrange = widgets.FloatRangeSlider(
    min=-0.01,
    max=0.05,
    value=(-0.001,0.03),
    step=0.0005,
    continuous_update=False,
    readout=True,
    readout_format='.4f',
    description='xrange (s)'
)
slider_xrange.layout.width = '600px'

slider_yrange = widgets.FloatRangeSlider(
    min=-5,
    max=5, # normal range for earthworm experiments
    value=(-0.5,1),
    step=0.05,
    continuous_update=False,
    readout=True,
    description='yrange'
)
slider_yrange.layout.width = '600px'

ui_range = widgets.VBox([slider_xrange, slider_yrange])

# trials in bout 0 to start...
trials_t = trial_times[(trial_times>bouts_list[0][0]) & (trial_times<bouts_list[0][1])]

odd_even_radio = widgets.RadioButtons(
    options=['odd', 'even', 'all'],
    value='all', # show all events by default
    layout={'width': 'max-content'}, # If the items' names are long
    description='show only events by: ',
    style = {'description_width': '400px'},
    disabled=False
)

select_channels = widgets.SelectMultiple(
    options=np.arange(np.shape(data)[1]), # one option per recorded channel
    value=[0],
    #rows=10,
    description='Channels',
    disabled=False
)

select_bouts = widgets.Select(
    options=np.arange(len(bouts_list)), # one option per bout in bouts_list
    value=0,
    #rows=10,
    description='Bouts',
    disabled=False
)

select_trials = widgets.SelectMultiple(
    options=np.arange(len(trials_t)), # events in bout 0 to start; options update when another bout is selected
    value=[0],
    #rows=10,
    description='Events',
    disabled=False
)

ui_trials = widgets.HBox([select_channels, select_trials, select_bouts])

slider_threshold = widgets.FloatSlider(
    min=-1,
    max=1,
    value=0.25,
    step=0.001,
    readout_format='.3f',
    continuous_update=False,
    readout=True,
    description='peak detection threshold',
    style = {'description_width': '200px'})
slider_threshold.layout.width = '600px'

detect_chan_radio = widgets.RadioButtons(
    options=['0', '1', 'none'],
    value='none', # Defaults to 'none'
    layout={'width': 'max-content'}, # If the items' names are long
    description='detect delay to peaks on channel: ',
    style = {'description_width': '400px'},
    disabled=False
)

ui_peaks = widgets.VBox([detect_chan_radio, slider_threshold])

trial_abs_readout = widgets.Label(
    value=f'time of this event is (sec): {NaN}'
)
trial_abs_readout.layout.width = '600px'

trial_readout = widgets.Label(
    value=f'time since last event is: {NaN}'
)
trial_readout.layout.width = '600px'

lagging_time_readout = widgets.Label(
    value=f'lagging peak times are: {NaN}'
)
lagging_time_readout.layout.width = '600px'

lagging_amp_readout = widgets.Label(
    value=f'lagging peak amplitudes are: {NaN}'
)
lagging_amp_readout.layout.width = '600px'

def update_plot(trial_sort_,chan_list,trial_list,bout_,yrange,xrange,lagging_chan_,thresh_):
    fig, ax = plt.subplots(figsize=(8,4))# ,ncols=1, nrows=1); #specify figure number so that it does not keep creating new ones
 
    win_0 = int(xrange[0]*fs)
    win_1 = int(xrange[1]*fs)
    xtime = np.arange(win_0,win_1)/fs # time of each sample relative to the event (s)
    
    trials_t = trial_times[(trial_times>bouts_list[bout_][0]) & (trial_times<bouts_list[bout_][1])]
    trials_init_ = np.arange(len(trials_t))
    
    if trial_sort_=='all':                     
        select_trials.options = trials_init_

        trial_list = [t_try for t_try in trial_list if t_try in trials_init_]
        select_trials.value = trial_list

    if trial_sort_=='even':                     
        select_trials.options = trials_init_[0::2]

        trial_list = [t_try for t_try in trial_list if t_try in trials_init_[0::2]]
        select_trials.value = trial_list                         
    
    if trial_sort_=='odd':                     
        select_trials.options = trials_init_[1::2]

        trial_list = [t_try for t_try in trial_list if t_try in trials_init_[1::2]]
        select_trials.value = trial_list                                 
    
    lagging_time_readout.value=f'lagging peak times are: {NaN}'
    lagging_amp_readout.value=f'lagging peak amplitudes are: {NaN}'
    trial_abs_readout.value=f'time of this event is: {NaN}'
    trial_readout.value=f'time since last event is: {NaN}'
    
    channel_colors = ['purple','green','blue','orange']
    for chan_ in chan_list:
        this_chan = data[:,chan_]
        for trial_ in trial_list:
            if trial_ in trials_init_:
                t_ = trials_t[trial_]

                if ((int(fs*t_)+win_0)>0) & ((int(fs*t_)+win_1))<len(this_chan):
                    data_sweep = this_chan[(int(fs*t_)+win_0):(int(fs*t_)+win_1)]

                    ax.plot(xtime,data_sweep,color=channel_colors[chan_],linewidth=2,alpha=0.5)
    
    d = 0.0005*fs
    if (lagging_chan_ != 'none') & (len(trial_list)==1):
        ax.hlines(thresh_, xrange[0],xrange[1],linestyle='--',color='green')
        lagging_chan_ = int(lagging_chan_)
        lagging_signal = data[(int(fs*t_)+win_0):(int(fs*t_)+win_1),lagging_chan_]
        if thresh_ >=0:
            r = find_peaks(lagging_signal,height=thresh_,distance=d)
            lagging_peak_amp = r[1]['peak_heights']
        if thresh_ <0:
            r = find_peaks(-1*lagging_signal,height=-1*thresh_,distance=d)
            lagging_peak_amp = -1*r[1]['peak_heights']
            # print(r)
            
        lagging_peak_amp = [np.round(a,2) for a in lagging_peak_amp]
        
        lagging_peak_times = [np.round(xtime[lt],5) for lt in r[0]]#r[0]/fs
        lagging_time_readout.value=f'lagging peak times are (ms): {[t*1000 for t in lagging_peak_times]}'
        lagging_amp_readout.value=f'lagging peak amplitudes are (V): {lagging_peak_amp}'
        
        if trial_list[0] == 0:
            trial_readout.value=f'time since last event is: {NaN}'
            trial_abs_readout.value=f'time of this event is: {NaN}'
        if trial_list[0] > 0:
            iti = 1000*(trials_t[trial_list[0]] - trials_t[trial_list[0]-1])
            trial_readout.value=f'time since last event is (msec): {iti:.2f}'
            trial_abs_readout.value=f'time of this event is (sec): {trials_t[trial_list[0]]}'
        
        for lt_ in lagging_peak_times:
            ax.scatter(lagging_peak_times,lagging_peak_amp,color='black',s=50,zorder=3)
    

    ax.set_ylim(yrange[0],yrange[1]);
    ax.set_xlabel('time relative to event (s)')

    # subdivide each major x interval into 10 minor intervals to ease latency readout
    ax.xaxis.set_minor_locator(AutoMinorLocator(10))

    # grid on major and minor ticks, with the minor grid dotted
    ax.grid(which='major', color='gray', linestyle='-')
    ax.grid(which='minor', color='gray', linestyle=':')

w = interactive_output(update_plot, {'trial_sort_':odd_even_radio,
                                     'chan_list':select_channels,
                                     'trial_list':select_trials, 
                                     'bout_':select_bouts,
                                     'yrange':slider_yrange, 
                                     'xrange':slider_xrange,
                                     'lagging_chan_':detect_chan_radio,
                                     'thresh_':slider_threshold});
display(trial_abs_readout,trial_readout,lagging_time_readout,lagging_amp_readout,
        odd_even_radio,ui_trials,ui_peaks,w,ui_range)
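The interactive cell reads out peak latencies and amplitudes one event at a time. If you later want the same measurement for every event in a bout, a sketch of the core computation might look like the hypothetical helper peak_latencies below. It reuses the cell's conventions (window in seconds relative to the event, positive peaks for a non-negative threshold and negative peaks otherwise, 0.5 ms minimum peak separation) and is demonstrated on a synthetic signal with a single placeholder response 4 ms after the event.

```python
import numpy as np
from scipy.signal import find_peaks


def peak_latencies(signal, event_time, fs, window=(-0.001, 0.03), thresh=0.25, min_sep=0.0005):
    """Hypothetical helper: latencies (ms) and amplitudes of peaks around one event.

    window is (start, stop) in seconds relative to event_time. As in the
    interactive cell, positive peaks are found for thresh >= 0 and negative
    peaks otherwise, with peaks at least min_sep seconds apart.
    """
    signal = np.asarray(signal)
    w0 = int(round(window[0] * fs))
    w1 = int(round(window[1] * fs))
    i0 = int(fs * event_time) + w0
    i1 = int(fs * event_time) + w1
    if i0 < 0 or i1 > len(signal):
        # window runs off the recording; skip this event
        return np.array([]), np.array([])
    sweep = signal[i0:i1]
    rel_t = np.arange(w0, w1) / fs             # sample times relative to the event (s)
    sign = 1 if thresh >= 0 else -1
    idx, props = find_peaks(sign * sweep, height=sign * thresh, distance=min_sep * fs)
    return 1000 * rel_t[idx], sign * props["peak_heights"]


# Synthetic nerve signal: 1 s of zeros with one placeholder response 4 ms after the event.
fs = 30000
event_time = 0.5
signal = np.zeros(fs)
signal[int(fs * event_time) + int(round(0.004 * fs))] = 1.0

lat_ms, amp = peak_latencies(signal, event_time, fs)
print(lat_ms, amp)
```

To apply it to every event in a bout, you could loop over the notebook's variables, for example `[peak_latencies(data[:, 0], t, fs) for t in trials_t]`.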

Plot the processed data

#@title {display-mode:"form"}

#@markdown Specify the filepath to a csv file
filepath = 'filepath to csv' #@param

#@markdown Specify the header name of the column you want for your x points. 
#@markdown If more than one header is specified (separated by commas), each will be plotted overlaid in a different color for a violin plot.
x_column = ['column header(s) for x axis'] #@param

#@markdown Specify categorical bins using np.arange(start,stop,step) if the x_column is a continuous variable and you are making a 'point' plot. 
#@markdown Use empty brackets if not needed.
categorical_bins = [] #@param

#@markdown Specify the header name of the column you want for your y points. 
#@markdown If more than one header is specified (separated by commas), each will be plotted overlaid in a different color for a scatter plot.
y_column = ['column header(s) for y axis'] #@param

#@markdown Specify the plot type ('scatter', 'point', or 'violin'). Note that for a 'violin' plot, only the 'x_column' data is used.
plot_type = 'plot type' #@param

#@markdown Specify the x-axis range for the plot
x_lim = [0,10] #@param


df = pd.read_csv(filepath)

hfig,ax = plt.subplots(figsize=(10,5))

if plot_type == 'scatter':
    df_melted = df[y_column+x_column].melt(x_column[0],var_name='headers')
    g = sns.scatterplot(data=df_melted,x=x_column[0],y='value',hue='headers',alpha=0.75);
            
if plot_type == 'point':
    df_melted = df[y_column+x_column].melt(x_column[0],var_name='headers')
    
    if len(categorical_bins)>0:
        df_melted[x_column[0]] = pd.cut(df_melted[x_column[0]],bins=categorical_bins,labels=categorical_bins[1:])
        
    g = sns.pointplot(data=df_melted,x=x_column[0],y='value',hue='headers',alpha=0.75);

        
if plot_type == 'violin':
  # sns.stripplot(y=y_column,data=df,color='black',size=10);
    if len(x_column)==1:
        g = sns.violinplot(x=x_column[0],data=df,color='grey')
    if len(x_column)>1:
        df_melted = df[x_column].dropna().melt(var_name='headers')
        g = sns.violinplot(x='value',y='headers',split=True,data=df_melted, inner="stick")
        
g.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=1)
ax.set_xlim(x_lim[0],x_lim[1])
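For the 'scatter' (and 'point') options, the cell reshapes the table from wide to long format with DataFrame.melt so that seaborn can color each y column separately. The small example below, with hypothetical column names and placeholder values, shows what that reshaped table looks like.

```python
import pandas as pd

# Placeholder processed data: one x column and two y columns.
df = pd.DataFrame({
    "isi_ms": [5, 10, 20],
    "latency_ms_control": [4.1, 4.0, 4.1],
    "latency_ms_test": [6.2, 5.0, 4.3],
})
x_column = ["isi_ms"]
y_column = ["latency_ms_control", "latency_ms_test"]

# Same reshaping as the 'scatter' branch: the y columns are stacked into one
# 'value' column, and each original header is kept in 'headers' for the hue.
df_melted = df[y_column + x_column].melt(x_column[0], var_name="headers")
print(df_melted)
```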

Written by Dr. Krista Perks for courses taught at Wesleyan University.