#*******************************************************************
#*   
#*   PCSIM Tutorial Exercises
#*
#*   Exercise 2: Balanced Random Network 
#*                 with Excitatory and Inhibitory Populations
#*
#*   Part 2: Random Distributions for Parameter Values 
#*           
#*      
#*       FIAS Theoretical Neuroscience Summer School, August 2008
#*   
#*       Author: Dejan Pecevski 
#*
#*******************************************************************

# We need the pypcsim package for the simulation,
# the pypcsimplus package for the spike raster plotting,
# the numerical package numpy,
# and the pylab (matplotlib) package for the plotting.
from pypcsimplus import *
from pypcsim import *
from numpy import *
from pylab import *

###################################################################
# Global parameter values
###################################################################

# --- network size and composition ---
nNeurons = 1000            # total number of network neurons
Frac_EXC = 0.8             # fraction of the network neurons that are excitatory
ConnP = 0.1                # recurrent connection probability

# --- external Poisson input ---
nInputNeurons = 100        # intended number of neurons providing the external input
inputFiringRate = 10       # firing rate of each input neuron [spikes/sec]
inpConnP = 0.1             # probability of a connection from an input neuron to a network neuron

# --- timing ---
Tsim = 1.0                 # total simulated time (presumably seconds, like the other SI-style constants)

# Create an empty network. Every population below is created in it (net is
# passed to each SimObjectPopulation), and net.simulate() runs it later.
net = SingleThreadNetwork()


###################################################################
# CREATING THE NEURONS
###################################################################
# Create the LIF neuron model used as the template for all network neurons.
# The values look like SI units (F, Ohm, V, s) -- presumably what PCSIM
# expects; confirm against the LifNeuron documentation.
# NOTE(review): Vresting (-49 mV) lies above Vthresh (-50 mV), so under the
# usual LIF dynamics a neuron drifts to threshold even without input; this
# appears intentional (it keeps the network active) -- confirm.
lifmodel = LifNeuron( Cm=2e-10, 
                      Rm=1e8, 
                      Vthresh=-50e-3, 
                      Vresting=-49e-3, 
                      Vreset=-60e-3, 
                      Trefract=5e-3, 
                      Vinit=-60e-3 ) ;

# Create a variation factory based on the model. Neurons created from it take
# the model's parameter values, except for parameters associated with a random
# distribution via set(), which are presumably drawn for each created neuron.
nrn_factory = SimObjectVariationFactory( lifmodel )

# Associate the initial membrane potential parameter with a bounded normal
# distribution: mean -55 mV, bounded to [-60 mV, -50 mV] (between Vreset and
# Vthresh). This presumably replaces the fixed Vinit=-60e-3 of the model.
# (cv is presumably the coefficient of variation, i.e. std = cv * |mu|.)
nrn_factory.set( 'Vinit', BndNormalDistribution(mu = -55e-3, 
                                                cv = 0.1, 
                                                lowerBound = -60e-3, 
                                                upperBound = -50e-3) )

# Create the population of excitatory neurons with the neuron variation factory:
# int(nNeurons * Frac_EXC) neurons (800 with the default parameters).
exc_nrn_popul = SimObjectPopulation(net, nrn_factory, int( nNeurons *  Frac_EXC ) );

# Create the population of inhibitory neurons from the same factory, with the
# remaining nNeurons - (number of excitatory neurons) neurons (200 by default).
# Excitatory and inhibitory neurons therefore share the same LIF parameters;
# they differ only in their outgoing synapses (see SYNAPTIC CONNECTIONS).
inh_nrn_popul = SimObjectPopulation(net, nrn_factory, nNeurons - exc_nrn_popul.size() );

# Combine the two populations into one population (for easier creation of
# connections and recording). The ids are ordered: all excitatory neurons
# first, then all inhibitory neurons.
all_nrn_popul = SimObjectPopulation(net, list(exc_nrn_popul.idVector()) + list(inh_nrn_popul.idVector()));

print "Created", exc_nrn_popul.size(), "exc and", inh_nrn_popul.size(), "inh neurons";

###################################################################
# SYNAPTIC CONNECTIONS
###################################################################
# Create the excitatory synapse model (weight W, time constant tau, delay).
exc_syn_model = StaticSpikingSynapse( W=1.62e-11, tau=5e-3, delay=1e-3 )

# Create the variation factory based on the model, so that individual synapse
# parameters can be drawn from random distributions.
exc_syn_factory = SimObjectVariationFactory( exc_syn_model )

# Associate a bounded gamma distribution with the weight parameter W. Its mean
# equals the model's W (1.62e-11), cv = 0.1, and it is bounded above by twice
# the mean. A gamma distribution only yields non-negative values, so the
# excitatory weights keep their sign.
exc_syn_factory.set( "W", BndGammaDistribution(mu = 1.62e-11, cv = 0.1, 
                                               upperBound = 2 * 1.62e-11 ) )

print 'Making synaptic connections:'
# Randomly connect the neurons from the excitatory population to all the neurons
# with probability ConnP (0.1 by default), using the excitatory synapse factory.
exc_syn_project = ConnectionsProjection( exc_nrn_popul, all_nrn_popul, 
                                      exc_syn_factory,
                                     RandomConnections( conn_prob = ConnP ))

# Randomly connect the neurons from the inhibitory population to all the neurons
# with probability ConnP, using a single fixed inhibitory synapse model.
# Its weight is negative (W = -1e-10, about 6 times the mean excitatory weight
# in magnitude) and its time constant is longer (10e-3 vs 5e-3).
# Unlike the excitatory weights, the inhibitory weights are not varied.
inh_syn_project = ConnectionsProjection( inh_nrn_popul, all_nrn_popul, 
                                     StaticSpikingSynapse( W=-10e-11, tau=10e-3, delay=1e-3 ), 
                                     RandomConnections( conn_prob = ConnP ))


###################################################################
#  INPUTS
###################################################################
# Create a population of nInputNeurons input neurons, each emitting a random
# Poisson spike train at inputFiringRate. The global parameters are used here,
# instead of literals, so that the settings at the top of the file take effect.
# NOTE(review): if `duration` is the active time in seconds, duration=10
# exceeds Tsim (1.0), so the input stays on for the whole simulation rather
# than only initially -- confirm intent.
inp_nrn_popul = SimObjectPopulation( net,
                                     PoissonInputNeuron( rate = inputFiringRate, duration = 10 ),
                                     nInputNeurons )

# Randomly connect the input neurons to the excitatory network neurons only
# (not to the inhibitory ones), with connection probability inpConnP.
inp_syn_project = ConnectionsProjection( inp_nrn_popul, exc_nrn_popul,
                          StaticSpikingSynapse( W=1.2e-10, tau=5e-3, delay=1e-3 ),
                          RandomConnections( conn_prob = inpConnP ) )


###########################################################
# RECORDERS
###########################################################
# Create a population of recorders for the spike times of the network neurons
# (presumably one recorder per neuron of all_nrn_popul, in the same order; the
# analysis below reads them back by index as one spike train per neuron).
spike_rec_popul = all_nrn_popul.record(SpikeTimeRecorder())

# Population of analog recorders for the membrane potential "Vm" of all
# neurons in the network.
# NOTE(review): vm_rec_popul is not read anywhere in this part of the script.
# Recording Vm of every neuron presumably stores a trace sample per time step
# per neuron; drop it if the traces are not needed.
vm_rec_popul = all_nrn_popul.record( AnalogRecorder(), "Vm" )

###########################################################
# Simulate the network
###########################################################
# Run the network for Tsim (1.0 by default; presumably seconds), with
# progress messages before and after the run.
print "Starting simulation..."
net.simulate( Tsim )
print "Done"

###########################################################
# PRESENTATION OF THE RESULTS
###########################################################

# Gather the recorded spike times: one numpy array per spike recorder, in the
# order of the recorder population.
spikes = []
for i in range(spike_rec_popul.size()):
    spikes.append(
        array(spike_rec_popul.object(i).getRecordedValues())
    )

# Turn the spike trains into raster coordinates over the window [0, Tsim]
# and draw every spike as a single dot.
x, y = create_raster(spikes, 0, Tsim)
plot(x, y, '.')

# Helper functions used to compute the spike-train statistics below
def diff(x):
    """Return the list of successive differences x[i+1] - x[i].

    A sequence with fewer than two elements yields an empty list.
    Note: this shadows the numpy ``diff`` brought in by ``from numpy import *``.
    """
    return [later - earlier for earlier, later in zip(x[:-1], x[1:])]


def meanisi(spikes):
    """Return the mean inter-spike interval of one spike train.

    Returns None when the train has fewer than two spikes, because then no
    interval can be formed.
    """
    if len(spikes) <= 1:
        return None
    return mean(diff(spikes))

def cv(spikes):
    """Return the coefficient of variation (std / mean) of the inter-spike intervals.

    Returns None unless the train has at least three spikes, i.e. at least
    two intervals.
    """
    if len(spikes) <= 2:
        return None
    intervals = diff(spikes)
    return std(intervals) / mean(intervals)

    
# Calculate the mean firing rate of the network activity
n = [ len(s) for s in spikes ]
print "spikes total", sum(n)
print "mean rate:", sum(n)/Tsim/len(n)

# Calculate the mean inter-spike interval
isi = [ meanisi(s) for s in spikes ]
isi = filter( lambda x: x != None, isi);
print "mean ISI:", sum(isi)/len(isi)

# Calculate the coefficient of variation
CV = [ cv(s) for s in spikes ]
CV = filter( lambda x: x != None, CV);
print "mean CV:", mean(CV)