#================================================================================
#
#  PyPCSIM implementation of a benchmark simulation described in the paper
#  "Simulation of networks of spiking neurons: A review of tools and strategies"
#
#  Benchmark 1: Conductance-based (COBA) IF network. This benchmark consists of a 
#               network of integrate-and-fire neurons connected with 
#               conductance-based synapses.
#
#  PyPCSIM is freely available from www.lsm.tugraz.at/csim
#
#  Authors: Dejan Pecevski, dejan@igi.tugraz.at
#           Thomas Natschlaeger, thomas.natschlaeger@scch.at
#
#  Date: November 2006
#
#================================================================================

import sys

sys.path.append("../_build/lib")

from pypcsim import *
import random
from datetime import datetime
from math import *

# Fixed seed for Python's RNG (used below to draw the input spike times) so
# that runs are reproducible.  For a different stimulus on every run, seed
# from the clock instead, e.g. random.seed( datetime.today().microsecond ).
random.seed( 134987 )
tstart=datetime.today()   # wall-clock start, used for the total run time report

###################################################################
# Global parameter values
###################################################################

nNeurons        = 4000;   # number of neurons
minDelay        = 1e-3;   # minimum synapse delay [sec]; NOTE: not used in this script, the network is built with minDelay = Time.ms(0.1)
ConnP           = 0.02;   # connectivity probability
Frac_EXC        = 0.8;    # fraction of excitatory neurons
Tsim            = 0.4;    # duration of the simulation [sec]
DTsim           = 1e-4;   # simulation time step [sec]
nRecordNeurons  = 250;    # number of neurons to record from; only referenced by the commented-out voltage-recording code
Tinp            = 50e-3;  # length of the initial stimulus [sec]
nInputNeurons   = 10 ;    # number of neurons which provide initial input (for a time span of Tinp)
inpConnP        = 0.01 ;  # connectivity from input neurons to network neurons
inputFiringRate = 80;     # firing rate of the input neurons during the initial input [spikes/sec]


# NOTE: no command-line parsing is implemented; edit the default values above to change the parameters.

###################################################################
# Create an empty network
###################################################################
#net = SingleThreadNetwork( SimParameter( dt=Time.sec( DTsim ) , minDelay = Time.ms(0.1), constructionRNGSeed = 134987 ) )
#net = MultiThreadNetwork( 2, SimParameter( dt=Time.sec( DTsim ) , minDelay = Time.ms(1), constructionRNGSeed = 134987 ) )
net = DistributedSingleThreadNetwork( SimParameter( dt=Time.sec( DTsim ) , minDelay = Time.ms(0.1), constructionRNGSeed = 134987 ) )
#net = DistributedMultiThreadNetwork( 4, SimParameter( dt=Time.sec( DTsim ) , minDelay = Time.ms(0.1), constructionRNGSeed = 134987 ) )

###################################################################
# Create the neurons and set their parameters
###################################################################

#csim('set', neuronIdx, ...
#      'Cm', 2e-10, 'Rm', 1e8, 'Vthresh', -50e-3, 'Vresting', -60e-3, ...
#      'Vreset', -60e-3, 'Trefract', 5e-3, 'Vinit', -60e-3 );

lifmodel = CbLifNeuron( Cm=2e-10, Rm=1e8, Vthresh=-50e-3, Vresting=-60e-3, Vreset=-60e-3, Trefract=5e-3, Vinit=-60e-3 ) ;


exz_nrn = net.add( lifmodel, int( nNeurons *  Frac_EXC ) );
inh_nrn = net.add( lifmodel, nNeurons -len(exz_nrn) );
all_nrn = list(exz_nrn) + list(inh_nrn);

print "Created", len(exz_nrn), "exz and", len(inh_nrn), "inh neurons";

###################################################################
# Create synaptic connections
###################################################################

print 'Making synaptic connections:'
t0=datetime.today()

# csim('set', excSynIdx, 'W',  6e-9,'E',      0, 'tau',  5e-3, 'delay', 0);
n_exz_syn = net.connect( exz_nrn, all_nrn, StaticCondExpSynapse( W=  4e-9, Erev=  0e-3, tau=5e-3, delay=0.1e-3 ), RandomConnections( conn_prob = ConnP ) )

# csim('set', inhSynIdx, 'W', 67e-9,'E', -80e-3, 'tau', 10e-3, 'delay', 0);
n_inh_syn = net.connect( inh_nrn, all_nrn, StaticCondExpSynapse( W= 81e-9, Erev=-80e-3, tau=10e-3, delay=0.1e-3 ), RandomConnections( conn_prob = ConnP ) )

# print "W=",net.object(n_inh_syn[0]).W,"tau=",net.object(n_inh_syn[0]).tau,"d=",net.object(n_inh_syn[0]).delay

print "nex=", n_exz_syn[0], " nin=", n_inh_syn[0]

t1= datetime.today();
print 'Created',int( n_exz_syn[0] + n_inh_syn[0] ),'conductance based synapses in',( t1 - t0 ).seconds,'seconds'

###################################################################
# Create input neurons for the initial stimulus
# and connect them to randomly chosen neurons in the circuit
###################################################################

# One SpikingInputNeuron per input channel.  Each gets
# int(inputFiringRate*Tinp) spike times drawn uniformly from [0, Tinp].  The
# times are sorted so the neuron always receives them in ascending order,
# whether or not SpikingInputNeuron sorts them itself.  Sorting does not
# change the number or order of random draws.
inp_nrn = []
for i in range(nInputNeurons):
    spike_times = sorted( [ random.uniform(0,Tinp) for x in range( int(inputFiringRate*Tinp) ) ] )
    inp_nrn.append( net.add( SpikingInputNeuron( spike_times ) ) )

# Input neurons -> network neurons: sparse random connections (probability
# inpConnP) with the same reversal potential (0 V) as the excitatory synapses.
net.connect( inp_nrn, all_nrn, StaticCondExpSynapse( W=6e-9, Erev=0e-3, tau=5e-3, delay=1e-3 ), RandomConnections( conn_prob = inpConnP ) );

###########################################################
# Create recorders to record spikes (voltage-trace recording is disabled)
###########################################################


# One SpikeTimeRecorder per network neuron, all added with SimEngine.ID(0,0);
# the statistics at the end read them back on MPI rank 0.  spike_rec[k]
# records all_nrn[k].  The Time.ms(1) argument to connect() is presumably the
# neuron->recorder delay.  add and connect are interleaved per neuron, as
# before.
spike_rec = []
for nrn in all_nrn:
    rec = net.add( SpikeTimeRecorder(), SimEngine.ID(0,0) )
    net.connect( nrn, rec, Time.ms(1) )
    spike_rec.append( rec )

# Optional membrane-voltage recording (disabled):
#rec_nrn = random.sample( all_nrn, nRecordNeurons );
#vm_rec = net.add( AnalogRecorder, [ SimEngine.ID(0,0) for i in range( nRecordNeurons ) ] );
#for i in range( nRecordNeurons ):
#    net.connect( rec_nrn[i], 'Vm', vm_rec[i], 0 );

###########################################################
# Simulate the circuit
###########################################################

print 'Running simulation:';
t0=datetime.today()

net.reset();
net.advance( int( Tsim / DTsim ) )

t1=datetime.today()
print 'Done.', (t1-t0).seconds, 'sec CPU time for', Tsim*1000, 'ms simulation time';
print '==> ', (t1-tstart).seconds, 'seconds total'


#if net.mpi_rank() == 0:
#    net.object(spike_rec[5]).printSpikeTimes()

###########################################################
# Compute summary statistics of the simulation results (spike count, rate, ISI, CV)
###########################################################

def diff(x):
    """Return the successive differences of the sequence *x*.

    The result is [x[1]-x[0], x[2]-x[1], ...]; sequences with fewer than two
    elements give an empty list.  *x* only needs len() and integer indexing.
    """
    return [ x[k] - x[k - 1] for k in range(1, len(x)) ]
   
def std(data):
    """Return the sample standard deviation of *data* (n - 1 denominator).

    Raises ZeroDivisionError when *data* has fewer than two values.
    """
    count = len(data)
    centre = sum(data) / float(count)
    squared_dev = sum([ (value - centre) * (value - centre) for value in data ])
    return sqrt(squared_dev / (count - 1))

def mean(data):
    """Return the arithmetic mean of *data* as a float.

    float() forces true division so that integer input is not truncated under
    Python 2, consistent with std().  Raises ZeroDivisionError for empty data.
    """
    return sum(data) / float(len(data))

def meanisi(spikes):
    """Return the mean inter-spike interval of the spike times *spikes*.

    Returns None when there are fewer than two spikes (no interval exists).
    """
    if len(spikes) < 2:
        return None
    return mean(diff(spikes))

def cv(spikes):
    """Return the coefficient of variation (std / mean) of the inter-spike intervals.

    At least three spikes (two intervals) are needed because std() uses an
    n - 1 denominator; otherwise None is returned.
    """
    if len(spikes) <= 2:
        return None
    intervals = diff(spikes)
    return std(intervals) / mean(intervals)

if net.mpi_rank() == 0:    
    n = [ net.object(spike_rec[i]).spikeCount() for i in range( len(all_nrn) ) ]
    print "spikes total",sum(n)
    print "mean rate:",sum(n)/Tsim/len(n)
    
    isi = [ meanisi(net.object(spike_rec[i]).getSpikeTimes()) for i in range( len(all_nrn) ) ]
    isi = filter( lambda x: x != None, isi);
    print "mean ISI:",sum(isi)/len(isi)
    
    
    CV = [ cv(net.object(spike_rec[i]).getSpikeTimes()) for i in range( len(all_nrn) ) ]
    CV = filter( lambda x: x != None, CV);
    print "mean CV:",mean(CV)


