#================================================================================
#
#  PyPCSIM implementation of a variation of a benchmark simulation described in 
#  the paper "Simulation of networks of spiking neurons: A review of tools 
#  and strategies"
#
#  Benchmark 2: Current-based (CUBA) IF network. This benchmark consists of a 
#               network of integrate-and-fire neurons connected with
#               current-based synapses.
#
#  PyPCSIM is freely available from http://www.sf.net/projects/pcsim
#
#  Authors: Dejan Pecevski, dejan@igi.tugraz.at
#           Thomas Natschlaeger, thomas.natschlaeger@scch.at
#
#  Date: September 2006
#
#================================================================================
import sys

sys.path.append('../_build/lib')
from pypcsim import *
import random, getopt
from datetime import datetime
from math import *

# Seed Python's RNG with a fixed value so that repeated runs are reproducible.
# (A time-based seeding call used to precede this one, but it was dead code:
# seeding again with a constant fully resets the generator state.)
random.seed( 134987 )

# Wall-clock reference taken at script start; used at the end of the script
# to report the total run time.
tstart = datetime.today()

###################################################################
# Global parameter values
###################################################################

# --- network size and connectivity ---
nNeurons = 4000          # number of neurons
Frac_EXC = 0.8           # fraction of excitatory neurons
ConnP = 0.02             # connectivity probability
minDelay = 1e-3          # minimum synapse delay [sec]

# --- simulation timing ---
Tsim = 0.4               # duration of the simulation [sec]
DTsim = 1e-4             # simulation time step [sec]

# --- recording ---
nRecordNeurons = 250     # number of neurons to plot the spikes from

# --- initial stimulus ---
Tinp = 50e-3             # length of the initial stimulus [sec]
nInputNeurons = 10       # number of neurons which provide initial input (for a time span of Tinp)
inpConnP = 0.01          # connectivity from input neurons to network neurons
inputFiringRate = 80     # firing rate of the input neurons during the initial input [spikes/sec]

# --- execution mode (both can be overridden on the command line) ---
nThreads = 1             # thread count handed to the multi-threaded network types
networkType = 'DST'      # one of 'ST', 'MT', 'DST', 'DMT'
# Command line overrides: every recognised long option replaces the default
# value of the matching global parameter. Only long options are accepted
# (the short-option string is empty). '--logfile' is accepted by the parser
# as a value-less flag but is not acted upon here.
long_options = ['networkType=', 'nthreads=' , 'nNeurons=', 'connectionProbability=', 'logfile']
optlist, args = getopt.getopt( sys.argv[1:], '', long_options )

for flag, value in optlist:
    if flag == '--networkType':
        networkType = value
        continue
    if flag == '--nthreads':
        nThreads = int( value )
        continue
    if flag == '--nNeurons':
        nNeurons = int( value )
        continue
    if flag == '--connectionProbability':
        ConnP = float( value )

def sub_time( t1, t2 ):
    """Return the time elapsed from t2 to t1 in seconds, as a float.

    t1 -- datetime marking the end of the interval
    t2 -- datetime marking the start of the interval

    All three fields of the resulting timedelta (days, seconds,
    microseconds) are taken into account, so intervals of 24 hours or more
    and negative intervals (t1 earlier than t2) are reported correctly.
    """
    delta = t1 - t2
    # timedelta stores its value normalised as days + seconds + microseconds;
    # the days part must be added explicitly or long intervals wrap around.
    return delta.days*86400 + delta.seconds + 1e-6*delta.microseconds

###################################################################
# Create an empty network
###################################################################

sp = SimParameter( dt=Time.sec( DTsim ) , minDelay = Time.sec(minDelay), simulationRNGSeed = 345678, constructionRNGSeed = 349871 );

if networkType == 'ST':
    net = SingleThreadNetwork( sp )
elif networkType == 'MT':
        if nThreads < 2:
                nThreads = 2;
        print 'multi thread:', nThreads
        net = MultiThreadNetwork( nThreads, sp )
elif networkType == 'DST':
    net = DistributedSingleThreadNetwork( sp )
elif networkType == 'DMT':
    net = DistributedMultiThreadNetwork( nThreads, sp )

###################################################################
# Create the neurons and set their parameters
###################################################################
# NOTE(review): this timestamp is overwritten at the start of the synapse
# section before it is ever read, so the time spent creating neurons is not
# part of constructionTime -- confirm whether that is intended.
t0=datetime.today()
# Template LIF neuron from which both populations are derived.
# (values look like SI units: F, Ohm, V, s, A -- TODO confirm against PCSIM docs)
base_lifmodel = LifNeuron( Cm=2e-10, Rm=1e8, Vthresh=-50e-3, Vresting=-49e-3, Vreset=-60e-3, Trefract=5e-3, Vinit=-60e-3, Inoise=0.0e-9 ) ;

# Excitatory model: the template with "Vresting" and "Cm" replaced by
# uniform distributions (presumably sampled per created neuron -- verify).
exc_lifmodel = SimObjectVariationFactory( base_lifmodel );
exc_lifmodel.set( "Vresting", UniformDistribution( -50e-3, -48e-3) )
exc_lifmodel.set( "Cm", UniformDistribution( 1.5e-10, 2.5e-10 ) )

# Inhibitory model: the template with "Vthresh" and "Cm" replaced by
# uniform distributions.
inh_lifmodel = SimObjectVariationFactory( base_lifmodel );
inh_lifmodel.set( "Vthresh", UniformDistribution( -55e-3, -50e-3) )
inh_lifmodel.set( "Cm", UniformDistribution( 2.2e-10, 2.7e-10 ) )

# int(nNeurons * Frac_EXC) excitatory neurons; the inhibitory population
# takes whatever is left so that both together add up to nNeurons.
exz_nrn = net.create( exc_lifmodel, int( nNeurons *  Frac_EXC ) );
inh_nrn = net.create( inh_lifmodel, nNeurons -len(exz_nrn) );
# Plain Python list of all neurons: excitatory first, then inhibitory.
all_nrn = list(exz_nrn) + list(inh_nrn);

print "Created",len(exz_nrn),"exz and",len(inh_nrn),"inh neurons";

###################################################################
# Create synaptic connections
###################################################################

print 'Making synaptic connections:'
# Start of the interval reported as synapse creation / construction time.
t0=datetime.today()

# Synaptic weights are derived from the difference between a reversal
# potential and a fixed mean membrane potential, times a constant factor.
# With the values below Wexc is positive and Winh is negative.
Erev_exc = 0;
Erev_inh = -80e-3;
Vmean    = -60e-3;
Wexc = (Erev_exc-Vmean)*0.27e-9;
Winh = (Erev_inh-Vmean)*4.5e-9;

# Excitatory synapse model: weight Wexc, 1e-3 delay, "tau" replaced by a
# uniform distribution over [4e-3, 6e-3].
# NOTE(review): delay=1e-3 is hard-coded and equals the default minDelay;
# it does not follow minDelay if that parameter is changed.
exz_syn = SimObjectVariationFactory( StaticSpikingSynapse( W=Wexc, tau= 5e-3, delay=1e-3 ) );
exz_syn.set( "tau", UniformDistribution( 4e-3, 6e-3 ) )

# Inhibitory synapse model: weight Winh, 1e-3 delay, "tau" replaced by a
# uniform distribution over [9e-3, 13e-3].
inh_syn = SimObjectVariationFactory( StaticSpikingSynapse( W=Winh, tau=10e-3, delay=1e-3 ) );
inh_syn.set( "tau", UniformDistribution( 9e-3, 13e-3 ) )

# Each population projects onto all neurons with probability ConnP.
# Element [0] of the value returned by net.connect is used below as the
# number of synapses created (presumably its meaning -- verify in PCSIM docs).
n_exz_syn = net.connect( exz_nrn, all_nrn, exz_syn, RandomConnections( conn_prob = ConnP ) )[0]
n_inh_syn = net.connect( inh_nrn, all_nrn, inh_syn, RandomConnections( conn_prob = ConnP ) )[0]

print 'Created',n_exz_syn+n_inh_syn,'current based synapses in',sub_time( datetime.today(), t0 ),'seconds (nex=', n_exz_syn, ' nin=', n_inh_syn,')'
# The clock is sampled a second time here, so constructionTime can be
# marginally larger than the value printed on the line above.
constructionTime = sub_time( datetime.today() , t0 )

###################################################################
# Create input neurons for the initial stimulus
# and connect them to random neurons in circuit
###################################################################

#inp_nrn = [ net.add( SpikingInputNeuron( [ random.uniform(0,Tinp) for x in range( int(inputFiringRate*Tinp) ) ] ) ) for i in range(nInputNeurons) ];
#
#net.connect( inp_nrn, all_nrn, StaticSpikingSynapse( W=(Erev_exc-Vmean)*2e-9, tau=5e-3, delay=1e-3 ), RandomConnections( conn_prob = inpConnP ) );

###########################################################
# Create recorders to record spikes and voltage traces
###########################################################

# Attach one SpikeTimeRecorder to every neuron. spike_rec[k] holds the
# handle returned by net.add for the recorder of all_nrn[k].
# All recorders are added with SimEngine.ID(0,0)
# (presumably node 0 / engine 0 -- verify against the PCSIM docs).
spike_rec = []
for nrn in all_nrn:
    rec = net.add( SpikeTimeRecorder(), SimEngine.ID(0,0) )
    net.connect( nrn, rec, Time.ms(1) )
    spike_rec.append( rec )
    
#rec_nrn = random.sample( all_nrn, nRecordNeurons );
#vm_rec = net.add( AnalogRecorder, [ SimEngine.ID(0,0) for i in range( nRecordNeurons ) ] );
#for i in range( nRecordNeurons ):
#    net.connect( rec_nrn[i], 'Vm', vm_rec[i], 0 );

###########################################################
# Simulate the circuit
###########################################################

print 'Running simulation:';
t0=datetime.today()

net.reset();
net.advance( int( Tsim / DTsim ) )

simulationTime = sub_time( datetime.today(), t0 );
print 'Done.', simulationTime, 'sec CPU time for', Tsim*1000, 'ms simulation time';
totalTime = sub_time( datetime.today(), tstart );
print '==> ', totalTime, 'seconds total'


###########################################################
# Make some figures out of simulation results
###########################################################

def diff(x):
    """Return the list of differences between consecutive elements of x.

    x -- an indexable sequence supporting len()

    The result has len(x) - 1 entries (x[1]-x[0], x[2]-x[1], ...); it is
    empty when x has fewer than two elements.
    """
    gaps = []
    for k in range(1, len(x)):
        gaps.append(x[k] - x[k-1])
    return gaps
   
def std(data):
    """Return the sample standard deviation (n-1 denominator) of data.

    data -- a sequence of numbers supporting len() and iteration

    Raises ZeroDivisionError when data has fewer than two elements.
    """
    count = len(data)
    center = sum(data)/float(count)
    # Accumulate the squared deviations from the mean.
    squared_dev = 0
    for value in data:
        squared_dev += (value-center)*(value-center)
    return sqrt(squared_dev/(count-1))

def mean(data):
    """Return the arithmetic mean of data as a float.

    data -- a non-empty sequence of numbers supporting len()

    The length is converted to float before dividing so that integer input
    is not truncated by integer division (same approach as std()).

    Raises ZeroDivisionError when data is empty.
    """
    return sum(data)/float(len(data))

def meanisi(spikes):
    """Return the mean inter-spike interval of a spike train.

    spikes -- sequence of spike times

    Returns None when the train has fewer than two spikes, because no
    interval can be formed in that case.
    """
    if len(spikes) < 2:
        return None
    return mean(diff(spikes))

def cv(spikes):
    """Return the coefficient of variation of the inter-spike intervals.

    spikes -- sequence of spike times

    The value is std(ISI) / mean(ISI). Returns None when the train has
    fewer than three spikes, i.e. fewer than two intervals.
    """
    if len(spikes) <= 2:
        return None
    intervals = diff(spikes)
    return std(intervals) / mean(intervals)

if net.mpi_rank() == 0:    
    n = [ net.object(spike_rec[i]).spikeCount() for i in range( len(all_nrn) ) ]
    print "spikes total", sum(n)
    meanRate = sum(n)/Tsim/len(n)
    print "mean rate:", meanRate
    
    isi = [ meanisi(net.object(spike_rec[i]).getSpikeTimes()) for i in range( len(all_nrn) ) ]
    isi = filter( lambda x: x != None, isi);
    print "mean ISI:",sum(isi)/len(isi)
    
    
    CV = [ cv(net.object(spike_rec[i]).getSpikeTimes()) for i in range( len(all_nrn) ) ]
    CV = filter( lambda x: x != None, CV);
    print "mean CV:",mean(CV)

    nMachines = 0
    if networkType in ("ST", "MT"):
        nMachines = 1
    else: 
        nMachines = net.mpi_size() 
    print 'BEGIN:%d,%d,%s,%d,%.2f,%.2f,%.2f,%.2f:END' % (nMachines, nThreads, networkType, nNeurons, constructionTime, simulationTime, totalTime, meanRate)
