#================================================================================
#
#  PyPCSIM implementation of a benchmark simulation described in the paper
#  "Simulation of networks of spiking neurons: A review of tools and strategies"
#
#  Benchmark 2: Current-based (CUBA) IF network. This benchmark consists of a 
#               network of integrate-and-fire neurons connected with
#               current-based synapses.
#
#  This implementation uses the population based pcsim high level interface.
#
#  PCSIM is freely available from http://sf.net/projects/pcsim
#
#  Authors: Dejan Pecevski, dejan@igi.tugraz.at
#           Thomas Natschlaeger, thomas.natschlaeger@scch.at
#
#  Date: December 2006
#
#================================================================================

import sys

sys.path.append("../_build/lib")

from pypcsim import *
import random, getopt
from datetime import datetime
from math import *




# Seed Python's RNG with a fixed value so that the random input spike trains
# and the random choice of Vm-recorded neurons are reproducible between runs.
# (Seed with e.g. datetime.today().microsecond instead for non-reproducible runs.)
random.seed( 134987 )
# Wall-clock start time of the whole script, used for the total-time report.
tstart=datetime.today()

###################################################################
# Global parameter values
###################################################################

# nNeurons and ConnP may be overridden from the command line (see below).
nNeurons        = 4000     # number of neurons
minDelay        = 1e-3     # minimum synapse delay [sec]
ConnP           = 0.02     # connectivity probability
Frac_EXC        = 0.8      # fraction of excitatory neurons
Tsim            = 0.4      # duration of the simulation [sec]
DTsim           = 1e-4     # simulation time step [sec]
nRecordNeurons  = 250      # number of randomly chosen neurons whose membrane potential (Vm) is recorded
Tinp            = 50e-3    # length of the initial stimulus [sec]
nInputNeurons   = 10       # number of neurons which provide initial input (for a time span of Tinp)
inpConnP        = 0.01     # connectivity from input neurons to the excitatory network neurons
inputFiringRate = 80       # firing rate of the input neurons during the initial input [spikes/sec]


nThreads = 1
networkType = 'ST'
# Analyze command line options and override default values of certain parameters
optlist, args = getopt.getopt(sys.argv[1:] , '', ['networkType=', 'nthreads=' , 'nNeurons=', 'connectionProbability=', 'logfile'] )

for opt, arg in optlist:
    if opt == '--networkType':
      networkType = arg
    elif opt == '--nthreads':
        nThreads = int(arg)
    elif opt == '--nNeurons':
        nNeurons = int(arg)
    elif opt == '--connectionProbability':
        ConnP = float(arg) 

def sub_time( t1, t2 ):
    """Return ``t1 - t2`` in (fractional) seconds.

    t1, t2 -- datetime objects (anything whose difference is a timedelta).

    The ``days`` component of the difference is included.  Intervals longer
    than a day are therefore handled correctly.  So are negative intervals,
    which timedelta normalises to a negative ``days`` plus a positive
    ``seconds``.
    """
    delta = t1 - t2
    return delta.days * 86400 + delta.seconds + 1e-6 * delta.microseconds

###################################################################
# Create an empty network
###################################################################

sp = SimParameter( dt=Time.sec( DTsim ) , minDelay = Time.sec(minDelay), simulationRNGSeed = 345678, constructionRNGSeed = 349871 );

if networkType == 'ST':
    net = SingleThreadNetwork( sp )
elif networkType == 'MT':
	if nThreads < 2:
		nThreads = 2;
	print 'multi thread:', nThreads  
	net = MultiThreadNetwork( nThreads, sp )    
elif networkType == 'DST':
    net = DistributedSingleThreadNetwork( sp )
elif networkType == 'DMT':
    net = DistributedMultiThreadNetwork( nThreads, sp )


###################################################################
# Create the neurons and set their parameters
###################################################################

lifmodel = LifNeuron( Cm=2e-10, Rm=1e8, Vthresh=-50e-3, Vresting=-49e-3, Vreset=-60e-3, Trefract=5e-3, Vinit=-60e-3 ) ;


exz_nrn_popul = SimObjectPopulation(net, lifmodel, int( nNeurons *  Frac_EXC ) );
inh_nrn_popul = SimObjectPopulation(net, lifmodel, nNeurons - exz_nrn_popul.size() );

all_nrn_popul = SimObjectPopulation(net, list(exz_nrn_popul.idVector()) + list(inh_nrn_popul.idVector()));

print "Created", exz_nrn_popul.size(), "exz and", inh_nrn_popul.size(), "inh neurons";

###################################################################
# Create synaptic connections
###################################################################

print 'Making synaptic connections:'
# Wall-clock start of synapse construction (reported at the end of this block).
t0=datetime.today()

# Weights of the current-based synapses.  Each weight is the driving force
# (reversal potential minus an assumed mean membrane potential Vmean) times a
# fixed scale factor.  This makes Wexc positive (60e-3 * 0.27e-9) and Winh
# negative (-20e-3 * 4.5e-9).
# NOTE(review): the scale factors look like conductances in siemens, giving
# weights in amperes -- confirm against the StaticSpikingSynapse documentation.
Erev_exc = 0;
Erev_inh = -80e-3;
Vmean = -60e-3;
Wexc = (Erev_exc-Vmean)*0.27e-9;
Winh = (Erev_inh-Vmean)*4.5e-9;


# Excitatory -> all network neurons, presumably each pair connected with
# probability ConnP (RandomConnections).
# NOTE(review): unlike the inhibitory projection below, this call also passes
# SimpleAllToAllWiringMethod(net) and True; confirm the difference is intended.
n_exz_syn_project = ConnectionsProjection( exz_nrn_popul, all_nrn_popul,
                                              StaticSpikingSynapse( W=Wexc, tau=5e-3, delay=1e-3 ),
                                              RandomConnections( conn_prob = ConnP ),
                                              SimpleAllToAllWiringMethod(net),
                                              True )

# Inhibitory -> all network neurons, same connection probability but a
# longer synaptic time constant (tau=10e-3 instead of 5e-3).
n_inh_syn_project = ConnectionsProjection( inh_nrn_popul, all_nrn_popul,
                             StaticSpikingSynapse( W=Winh, tau=10e-3, delay=1e-3 ),
                             RandomConnections( conn_prob = ConnP ))

# print "W=",net.object(n_inh_syn[0]).W,"tau=",net.object(n_inh_syn[0]).tau,"d=",net.object(n_inh_syn[0]).delay

# Sanity check: look up the synapse at index 100 of the excitatory projection
# and print its delay.  The print runs only on the MPI rank whose node owns
# that synapse.
# NOTE(review): if projection indexing is 0-based this is the 101st synapse,
# not the "100th" stated in the message.
syn_id = n_exz_syn_project(100);

synapse_handle = net.object( syn_id )

if syn_id.node == net.mpi_rank():
    print "The delay of the 100th synapse is " , synapse_handle.delay


print "nex=", n_exz_syn_project.size(), " nin=", n_inh_syn_project.size()

# Report construction time.  timedelta.seconds drops the fractional part
# (and any days component).
t1= datetime.today();
print 'Created',int( n_exz_syn_project.size() + n_inh_syn_project.size()),'current based synapses in',( t1 - t0 ).seconds,'seconds'

###################################################################
# Create input neurons for the initial stimulus
# and connect them to random neurons in circuit
###################################################################

# Each of the nInputNeurons input neurons gets int(inputFiringRate*Tinp)
# spike times drawn uniformly between 0 and Tinp.
# NOTE(review): the spike times are not sorted; confirm SpikingInputNeuron
# accepts unsorted lists.
nInputSpikes = int( inputFiringRate * Tinp )
inputNeuronIds = []
for i in range( nInputNeurons ):
    spikeTimes = [ random.uniform( 0, Tinp ) for x in range( nInputSpikes ) ]
    inputNeuronIds.append( net.add( SpikingInputNeuron( spikeTimes ) ) )
inp_nrn_popul = SimObjectPopulation( net, inputNeuronIds )

# The input neurons project only onto the excitatory population.
inp_syn_project = ConnectionsProjection( inp_nrn_popul, exz_nrn_popul,
                          StaticSpikingSynapse( W=(Erev_exc-Vmean)*2e-9, tau=5e-3, delay=1e-3 ),
                          RandomConnections( conn_prob = inpConnP ) )

###########################################################
# Create recorders to record spikes and voltage traces
###########################################################
# Distribution strategy for the recorders created below; presumably it places
# them on local engines of node 0.  This currently doesn't work, to be fixed.
net.setDistributionStrategy(DistributionStrategy.ModuloOverLocalEnginesOnOneNode(0))

# One spike-time recorder per network neuron.  When ConnectionsProjection is
# given only a delay, the wiring method is assumed to be one-to-one.
spike_rec_popul = SimObjectPopulation(net, SpikeTimeRecorder(), all_nrn_popul.size())
rec_conn_project = ConnectionsProjection( all_nrn_popul, spike_rec_popul, Time.ms(1) )

# Record the membrane potential 'Vm' of a random subset of network neurons.
# The subset size is capped at the network size, because random.sample raises
# ValueError when asked for more items than exist (e.g. a small --nNeurons).
nVmRecorded = min( nRecordNeurons, all_nrn_popul.size() )
vm_rec_nrn_popul = SimObjectPopulation(net, random.sample( all_nrn_popul.idVector(), nVmRecorded ))
vm_recorders_popul = SimObjectPopulation(net, AnalogRecorder(), nVmRecorded )
for i in range(vm_recorders_popul.size()):
    net.connect( vm_rec_nrn_popul[i], 'Vm', vm_recorders_popul[i], 0, Time.ms(1) )

###########################################################
# Simulate the circuit
###########################################################

print 'Running simulation:';
t0=datetime.today()

net.reset();
print "advancing 20 steps"
net.advance( 10 )
print "advancing the rest"
net.advance( int( Tsim / DTsim ) )
t1=datetime.today()
print 'Done.', (t1-t0).seconds, 'sec CPU time for', Tsim*1000, 'ms simulation time';
print '==> ', (t1-tstart).seconds, 'seconds total'


#if net.mpi_rank() == 0:
#    net.object(spike_rec[5]).printSpikeTimes()

###########################################################
# Make some figures out of simulation results
###########################################################

def diff(x):
    """Return the successive differences ``x[k] - x[k-1]`` as a list.

    A sequence of length n yields n-1 differences (an empty list for n < 2).
    """
    return [ x[k] - x[k - 1] for k in range(1, len(x)) ]
   
def std(data):
    """Return the sample standard deviation of *data* (n-1 denominator).

    Raises ZeroDivisionError when *data* has fewer than two values.
    """
    count = len(data)
    centre = sum(data) / float(count)
    squaredDeviations = [ (value - centre) * (value - centre) for value in data ]
    return sqrt( sum(squaredDeviations) / (count - 1) )

def mean(data):
    """Return the arithmetic mean of *data* as a float.

    The divisor is converted to float, as in std(), so that integer input
    does not fall into Python 2 integer division.  Raises ZeroDivisionError
    for empty input.
    """
    return sum(data) / float(len(data))

def meanisi(spikes):
    """Return the mean inter-spike interval of *spikes*.

    Returns None when there are fewer than two spikes (no interval exists).
    """
    if len(spikes) <= 1:
        return None
    return mean(diff(spikes))

def cv(spikes):
    """Return the coefficient of variation (std / mean) of the inter-spike intervals.

    Returns None when there are fewer than three spikes, i.e. fewer than the
    two intervals a sample standard deviation needs.
    """
    if len(spikes) <= 2:
        return None
    intervals = diff(spikes)
    return std(intervals) / mean(intervals)

if net.mpi_rank() == 0:
    n = [ spike_rec_popul.object(i).spikeCount() for i in range( spike_rec_popul.size() ) ]
    print "spikes total",sum(n)
    print "mean rate:",sum(n)/Tsim/len(n)
    
    isi = [ meanisi(spike_rec_popul.object(i).getSpikeTimes()) for i in range( spike_rec_popul.size() ) ]
    isi = filter( lambda x: x != None, isi);
    print "mean ISI:",sum(isi)/len(isi)
    
    
    CV = [ cv(spike_rec_popul.object(i).getSpikeTimes()) for i in range( spike_rec_popul.size() ) ]
    CV = filter( lambda x: x != None, CV);
    print "mean CV:",mean(CV)
