#================================================================================
#
#  PyPCSIM implementation of a variation of a benchmark simulation described in 
#  the paper "Simulation of networks of spiking neurons: A review of tools 
#  and strategies"
#
# Benchmark 3: Conductance-based HH network. This benchmark consists of a 
#              network of HH point neurons connected by 
#              conductance-based synapses.
#
#  PyPCSIM is freely available from http://www.sf.net/projects/pcsim
#
#  Authors: Dejan Pecevski, dejan@igi.tugraz.at
#           Thomas Natschlaeger, thomas.natschlaeger@scch.at
#
#  Date: March 2007
#
#================================================================================

import sys

sys.path.append('../_build/lib')
from pypcsim import *
import random, getopt
from datetime import datetime
from math import *

# Seed Python's `random` module with a fixed value so that the initial input
# spike times and the choice of Vm-recorded neurons are the same on every run.
# For non-reproducible runs, replace the constant with a time-based seed such
# as datetime.today().microsecond.
random.seed( 134987 )
# Wall-clock start time, used for the "seconds total" report after simulation.
tstart=datetime.today()

###################################################################
# Global parameter values
###################################################################

nNeurons        = 400;   # number of neurons
minDelay        = 1e-3;   # minimum synapse delay [sec]
ConnP           = 0.02;   # connectivity probability
Frac_EXC        = 0.8;    # fraction of excitatory neurons
Tsim            = 0.4;    # duration of the simulation [sec]
DTsim           = 1e-4;   # simulation time step [sec]
nRecordNeurons  = 25;     # number of neurons to plot the spikes from
Tinp            = 50e-3;  # length of the initial stimulus [sec]
nInputNeurons   = 10 ;    # number of neurons which provide initial input (for a time span of Tinp)
inpConnP        = 0.01 ;  # connectivity from input neurons to network neurons
inputFiringRate = 80;     # firing rate of the input neurons during the initial input [spikes/sec]

nThreads = 1
networkType = 'MT'
# Analyze command line options and override default values of certain parameters
optlist, args = getopt.getopt(sys.argv[1:] , '', ['networkType=', 'nthreads=' , 'nNeurons=', 'connectionProbability=', 'logfile'] )

for opt, arg in optlist:
    if opt == '--networkType':
      networkType = arg
    elif opt == '--nthreads':
        nThreads = int(arg)
    elif opt == '--nNeurons':
        nNeurons = int(arg)
    elif opt == '--connectionProbability':
        ConnP = float(arg)

def sub_time( t1, t2 ):
    """Return the elapsed time ``t1 - t2`` in seconds as a float.

    t1, t2 -- datetime objects. The result is negative if t1 precedes t2.
    """
    delta = t1 - t2
    # timedelta normalises into (days, seconds, microseconds), where only
    # `days` may be negative. All three fields must therefore be combined:
    # using `.seconds` alone drops whole days and mis-reports negative deltas.
    return delta.days * 86400 + delta.seconds + 1e-6 * delta.microseconds

###################################################################
# Create an empty network
###################################################################

sp = SimParameter( dt=Time.sec( DTsim ) , minDelay = Time.sec(minDelay), simulationRNGSeed = 345678, constructionRNGSeed = 349871 );

if networkType == 'ST':
    net = SingleThreadNetwork( sp )
elif networkType == 'MT':
        if nThreads < 2:
                nThreads = 2;
        print 'multi thread:', nThreads
        net = MultiThreadNetwork( nThreads, sp )
elif networkType == 'DST':
    net = DistributedSingleThreadNetwork( sp )
elif networkType == 'DMT':
    net = DistributedMultiThreadNetwork( nThreads, sp )

###################################################################
# Create the neurons and set their parameters
###################################################################

t0=datetime.today()
 
exz_nrn = net.create( HHNeuronTraubMiles91(), int( nNeurons *  Frac_EXC ) );
inh_nrn = net.create( HHNeuronTraubMiles91(), nNeurons -len(exz_nrn) );
all_nrn = list(exz_nrn) + list(inh_nrn);

print "Created",len(exz_nrn),"exz and",len(inh_nrn),"inh neurons";

###################################################################
# Create synaptic connections
###################################################################

print 'Making synaptic connections:'

t0=datetime.today()

Erev_exc = 0;
Erev_inh = -80e-3;
Wexc = 2e-9;
Winh = 33e-9;

exz_syn = SimObjectVariationFactory( StaticCondExpSynapse( W=Wexc, tau= 5e-3, delay=1e-3, Erev = Erev_exc ) );
#exz_syn.set( "tau", UniformDistribution( 4e-3, 6e-3 ) )

inh_syn = SimObjectVariationFactory( StaticCondExpSynapse( W=Winh, tau=10e-3, delay=1e-3, Erev = Erev_inh ) );
#inh_syn.set( "tau", UniformDistribution( 9e-3, 13e-3 ) )

n_exz_syn = net.connect( exz_nrn, all_nrn, exz_syn, RandomConnections( conn_prob = ConnP ) )[0]
n_inh_syn = net.connect( inh_nrn, all_nrn, inh_syn, RandomConnections( conn_prob = ConnP ) )[0]

print 'Created',n_exz_syn+n_inh_syn,'conductance based synapses in',sub_time( datetime.today(), t0 ),'seconds (nex=', n_exz_syn, ' nin=', n_inh_syn,')'
constructionTime = sub_time( datetime.today() , t0 )

###################################################################
# Create input neurons for the initial stimulus
# and connect them to random neurons in circuit
###################################################################

# Number of spikes each input neuron emits during the stimulus window.
# round() guards against float products like 28.999... truncating to 28.
nInputSpikes = int( round( inputFiringRate * Tinp ) )

# Each input neuron gets nInputSpikes spike times drawn uniformly between 0
# and Tinp. The times are sorted explicitly rather than relying on
# SpikingInputNeuron to order them.
inp_nrn = [ net.add( SpikingInputNeuron( sorted( [ random.uniform(0,Tinp) for _ in range( nInputSpikes ) ] ) ) ) for i in range(nInputNeurons) ]

# Connect the input neurons to randomly chosen network neurons (probability inpConnP).
net.connect( inp_nrn, all_nrn, StaticCondExpSynapse( W=6e-9, tau=5e-3, delay=1e-3, Erev = 0 ), RandomConnections( conn_prob = inpConnP ) )

###########################################################
# Create recorders to record spikes and voltage traces
###########################################################

# One SpikeTimeRecorder per network neuron, created at SimEngine.ID(0,0)
# (presumably so that MPI rank 0 can read them in the analysis below).
# Neuron all_nrn[k] is connected to spike_rec[k] with a 1 ms delay.
spike_rec = []
for nrn in all_nrn:
    rec = net.add( SpikeTimeRecorder(), SimEngine.ID(0,0) )
    spike_rec.append( rec )
    net.connect( nrn, rec, Time.ms(1) )

# Record the membrane potential 'Vm' of randomly chosen neurons. Never ask
# for more neurons than exist: random.sample would raise ValueError when
# --nNeurons is smaller than nRecordNeurons.
nVmRecord = min( nRecordNeurons, len(all_nrn) )
rec_nrn = random.sample( all_nrn, nVmRecord )
vm_rec = net.create( AnalogRecorder(), [ SimEngine.ID(0,0) for i in range( nVmRecord ) ] )
for i in range( nVmRecord ):
    net.connect( rec_nrn[i], 'Vm', vm_rec[i], 0, Time.ms(1) )

###########################################################
# Simulate the circuit
###########################################################

print 'Running simulation:';
t0=datetime.today()

net.reset();
net.simulate( Tsim )

simulationTime = sub_time( datetime.today(), t0 );
print 'Done.', simulationTime, 'sec CPU time for', Tsim*1000, 'ms simulation time';
totalTime = sub_time( datetime.today(), tstart );
print '==> ', totalTime, 'seconds total'


###########################################################
# Make some figures out of simulation results
###########################################################

#import matplotlib
#matplotlib.use('GTKAgg')
#from pylab import *

def diff(x):
    """Return the successive differences x[k] - x[k-1] for k = 1 .. len(x)-1.

    The result has one element fewer than x; it is empty when x has fewer
    than two elements.
    """
    deltas = []
    for k in range(1, len(x)):
        deltas.append(x[k] - x[k - 1])
    return deltas

def mean(data):
    """Return the arithmetic mean of the numbers in *data*.

    The length is converted to float so that integer input is not truncated
    by Python 2 integer division (the same convention std() uses).
    Raises ZeroDivisionError if *data* is empty.
    """
    return sum(data) / float(len(data))

def std(data):
    """Return the sample standard deviation of *data*, using an n - 1 denominator.

    Raises ZeroDivisionError for fewer than two values.
    """
    count = len(data)
    avg = sum(data) / float(count)
    squared_devs = [(value - avg) * (value - avg) for value in data]
    variance = sum(squared_devs) / (count - 1)
    return sqrt(variance)


def meanisi(spikes):
    """Return the mean inter-spike interval of the spike times in *spikes*.

    Returns None when there are fewer than two spikes, because no interval
    exists then.
    """
    if len(spikes) < 2:
        return None
    return mean(diff(spikes))

def cv(spikes):
    """Return the coefficient of variation (std / mean) of the inter-spike intervals.

    At least three spikes (two intervals) are required, because the sample
    std divides by n - 1. None is returned otherwise.
    """
    if len(spikes) <= 2:
        return None
    intervals = diff(spikes)
    return std(intervals) / mean(intervals)

if net.mpi_rank() == 0:    
    n = [ net.object(spike_rec[i]).spikeCount() for i in range( len(all_nrn) ) ]
    print "spikes total", sum(n)
    meanRate = sum(n)/Tsim/len(n)
    print "mean rate:", meanRate
    
    isi = [ meanisi(net.object(spike_rec[i]).getSpikeTimes()) for i in range( len(all_nrn) ) ]
    isi = filter( lambda x: x != None, isi);
    print "mean ISI:",sum(isi)/len(isi)
    
    
    CV = [ cv(net.object(spike_rec[i]).getSpikeTimes()) for i in range( len(all_nrn) ) ]
    CV = filter( lambda x: x != None, CV);
    print "mean CV:",mean(CV)

    """
    vm = net.object( vm_rec[0] ).getRecordedValues() 
    plot(
       net.object( vm_rec[0] ).getRecordedValues(),'b',
       net.object( vm_rec[2] ).getRecordedValues(),'g',
       net.object( vm_rec[3] ).getRecordedValues(),'r'
       )
    show()
    savefig('vm.png')
    
    x = []
    for i in range( min(len(all_nrn),250) ):
      x.extend( net.object(spike_rec[i]).getSpikeTimes() )
      
    y = []
    for i in range( min(len(all_nrn),250) ):
      y.extend( net.object(spike_rec[i]).spikeCount() * [ i ] )
    
    clf()
    print
    plot(x,y,'k.')
    show()
    """
