"""
  Example use case for the PCSIM construction framework.

  A cuboid grid population is constructed from two different factories,
  one for inhibitory and one for excitatory neurons. Connection probabilities
  depend on the distance between neurons, and synapse parameters differ
  between excitatory-to-all and inhibitory-to-all connections.

"""

import sys

sys.path.append("../_build/lib")

from pypcsim import *
import random, getopt
from datetime import datetime
from math import *

# Python's own RNG drives the random input spike times and the choice of the
# neurons whose membrane potential is recorded. A fixed seed keeps runs
# reproducible; use random.seed( datetime.today().microsecond ) instead for a
# time-dependent seed.
random.seed( 134987 )
tstart=datetime.today()   # wall-clock start, used for the total-time report at the end

###################################################################
# Global parameter values
###################################################################

nNeurons        = 4000    # number of neurons
minDelay        = 1e-3    # minimum synapse delay [sec]
ConnP           = 0.02    # base probability of the distance-dependent recurrent connections
Frac_EXC        = 0.8     # fraction of excitatory neurons
Tsim            = 0.4     # duration of the simulation [sec]
DTsim           = 1e-4    # simulation time step [sec]
nRecordNeurons  = 250     # number of neurons whose membrane potential Vm is recorded
Tinp            = 50e-3   # length of the initial stimulus [sec]
nInputNeurons   = 10      # number of neurons which provide initial input (for a time span of Tinp)
inpConnP        = 0.01    # connectivity from input neurons to the excitatory network neurons
inputFiringRate = 80      # firing rate of the input neurons during the initial input [spikes/sec]

# NOTE(review): nNeurons and Frac_EXC are descriptive only. The population is
# built from a hard-coded 20x20x10 grid with a 4:1 excitatory/inhibitory
# ratio, which matches these values; keep both in sync by hand.
# nThreads is likewise not read elsewhere in this script (a
# SingleThreadNetwork is used).
nThreads = 1

def sub_time(t1, t2):
    """Return the elapsed time ``t1 - t2`` in (fractional) seconds.

    Callers pass ``datetime`` values taken with ``datetime.today()``.
    The ``days`` component of the difference is included. Intervals of a
    day or more are therefore reported correctly, and so are negative
    intervals (``t1`` earlier than ``t2``), whose timedelta is normalized
    to negative days plus positive seconds.
    """
    delta = t1 - t2
    return delta.days * 86400 + delta.seconds + delta.microseconds / 1e6

###################################################################
# Create an empty network
###################################################################

# Simulation-wide settings: integration time step, smallest allowed synaptic
# delay, and fixed seeds for the simulation-time and construction-time RNGs
# (fixed so that repeated runs build and simulate the same circuit).
sp = SimParameter(
    dt=Time.sec(DTsim),
    minDelay=Time.sec(minDelay),
    simulationRNGSeed=345678,
    constructionRNGSeed=349871,
)
net = SingleThreadNetwork(sp)

###################################################################
# Create the neuron factories and set their parameters
###################################################################

# the excitatory neuron factory
# Presumably every LifNeuron created from this factory gets each listed
# parameter drawn from the given distribution (NormalDistribution presumably
# takes (mean, std)). Values appear to be SI units: Cm ~200 pF, Rm ~100 MOhm
# (tau_m ~20 ms), voltages in V, Trefract in s. TODO confirm against the
# LifNeuron documentation.
# NOTE(review): Vresting (-49 mV) lies above Vthresh (-50 mV); for a standard
# LIF model this makes neurons fire without any input -- confirm intended.
exc_nrn_factory = SimObjectVariationFactory( LifNeuron() ) ;
exc_nrn_factory.set("Cm", NormalDistribution(2e-10, 1e-11))
exc_nrn_factory.set("Rm", NormalDistribution(1e8, 5e6))
exc_nrn_factory.set("Vthresh", ConstantNumber(-50e-3))
exc_nrn_factory.set("Vresting", ConstantNumber(-49e-3))
exc_nrn_factory.set("Vreset", ConstantNumber(-60e-3))
exc_nrn_factory.set("Trefract", UniformDistribution(4.8e-3, 5.2e-3))
exc_nrn_factory.set("Vinit", ConstantNumber(-60e-3))


# the inhibitory neuron factory
# Same as the excitatory factory except for a wider Cm spread (std 2e-11)
# and a reset/initial potential of -57 mV instead of -60 mV.
inh_nrn_factory = SimObjectVariationFactory( LifNeuron() ) ;
inh_nrn_factory.set("Cm", NormalDistribution(2e-10, 2e-11))
inh_nrn_factory.set("Rm", NormalDistribution(1e8, 5e6))
inh_nrn_factory.set("Vthresh", ConstantNumber(-50e-3))
inh_nrn_factory.set("Vresting", ConstantNumber(-49e-3))
inh_nrn_factory.set("Vreset", ConstantNumber(-57e-3))
inh_nrn_factory.set("Trefract", UniformDistribution(4.8e-3, 5.2e-3))
inh_nrn_factory.set("Vinit", ConstantNumber(-57e-3))

# Neurons placed on a 20 x 20 x 10 integer grid (4000 sites, matching
# nNeurons), split between the two factories in a 4:1 ratio (matching
# Frac_EXC = 0.8). Both are hard-coded here rather than derived from those
# globals.
all_nrn_popul = SpatialFamilyPopulation( net, [ exc_nrn_factory, inh_nrn_factory ], RatioBasedFamilies( [4, 1]  ), CuboidIntegerGrid3D( 20, 20, 10 ) );

print 

# Presumably splitFamilies() returns one sub-population per factory, in the
# order of the factory list above (excitatory first) -- verify.
exz_nrn_popul, inh_nrn_popul = tuple( all_nrn_popul.splitFamilies() );

print "Created population of size", all_nrn_popul.size(), ":", exz_nrn_popul.size(), "exzitatory and", inh_nrn_popul.size(), "inhibitory neurons";

###################################################################
# Create synaptic connections
###################################################################

print 'Making synaptic connections:'
t0=datetime.today()

# Synaptic weights: a driving force (reversal potential minus an assumed
# mean membrane potential Vmean) times a fixed scale factor.
# Erev_exc - Vmean = +60 mV gives positive Wexc; Erev_inh - Vmean = -20 mV
# gives negative Winh.
Erev_exc = 0;
Erev_inh = -80e-3;
Vmean    = -60e-3;
Wexc = (Erev_exc-Vmean)*0.27e-9;
Winh = (Erev_inh-Vmean)*4.5e-9;

# Excitatory -> all and inhibitory -> all projections. They differ only in
# weight and synaptic time constant (5 ms vs 10 ms). The delay is hard-coded
# to 1 ms, equal to minDelay. Connection probability depends on distance,
# with base probability ConnP; the second argument (10) is presumably a
# spatial length constant -- TODO confirm against pypcsim docs.
n_exz_syn_project = ConnectionsProjection( exz_nrn_popul, all_nrn_popul, 
                                           StaticSpikingSynapse( W=Wexc, tau=5e-3, delay=1e-3 ),
                                           EuclideanDistanceRandomConnections( ConnP, 10 ) )

n_inh_syn_project = ConnectionsProjection( inh_nrn_popul, all_nrn_popul, 
                                           StaticSpikingSynapse( W=Winh, tau=10e-3, delay=1e-3 ),
                                           EuclideanDistanceRandomConnections( ConnP, 10 ) )

print "nex=", n_exz_syn_project.size(), " nin=", n_inh_syn_project.size()

# Report the total synapse count and the wall-clock construction time.
t1= datetime.today();
print 'Created',int( n_exz_syn_project.size() + n_inh_syn_project.size()),'current synapses in',sub_time( t1, t0 ),'seconds'

###################################################################
# Create input neurons for the initial stimulus
# and connect them to random excitatory neurons in the circuit
###################################################################

# nInputNeurons input neurons, each with int(inputFiringRate*Tinp) spike
# times drawn uniformly from [0, Tinp] using Python's `random` module.
# NOTE(review): the spike-time lists are not sorted; confirm that
# SpikingInputNeuron accepts unsorted spike trains.
inp_nrn_popul = SimObjectPopulation( net, [ net.add( SpikingInputNeuron( [ random.uniform(0,Tinp) for x in range( int(inputFiringRate*Tinp) ) ] ) ) for i in range(nInputNeurons) ] );

# Inputs project only onto the excitatory population, each pair connected
# with probability inpConnP.
inp_syn_project = ConnectionsProjection( inp_nrn_popul, exz_nrn_popul, 
                          StaticSpikingSynapse( W=(Erev_exc-Vmean)*2e-9, tau=5e-3, delay=1e-3 ), 
                          RandomConnections( conn_prob = inpConnP ) );

###########################################################
# Create recorders to record spikes and voltage traces
###########################################################

# Presumably controls how the objects created below are placed on simulation
# engines -- confirm against pypcsim docs.
net.setDistributionStrategy(DistributionStrategy.ModuloOverLocalEngines())
# One spike-time recorder per network neuron.
spike_rec_popul = SimObjectPopulation(net, SpikeTimeRecorder(), all_nrn_popul.size())
    
rec_conn_project = ConnectionsProjection( all_nrn_popul, spike_rec_popul, Time.ms(1) );  # if delay is specified, wiring method is assumed to be one-to-one
    
# Membrane-potential recording: nRecordNeurons distinct neuron ids picked at
# random with Python's `random.sample`, each wired from its 'Vm' field to
# port 0 of its own AnalogRecorder (delay 1 ms). These traces are not
# analysed further in this script.
vm_rec_nrn_popul = SimObjectPopulation(net, random.sample( all_nrn_popul.idVector(), nRecordNeurons ));
vm_recorders_popul = SimObjectPopulation(net, AnalogRecorder(),  nRecordNeurons );
for i in range(vm_recorders_popul.size()):
    net.connect( vm_rec_nrn_popul[i], 'Vm', vm_recorders_popul[i], 0,  Time.ms(1) );

###########################################################
# Simulate the circuit
###########################################################

print 'Running simulation:';
t0=datetime.today()

# Reset the network state, then advance it by Tsim / DTsim time steps.
# NOTE(review): int() truncates, so a ratio that lands just below an integer
# because of floating-point error would lose one step; round() would be
# more robust.
net.reset();
net.advance( int( Tsim / DTsim ) )

# The label says "CPU time", but sub_time() of two datetime.today() stamps
# measures wall-clock time.
t1=datetime.today()
print 'Done.', sub_time(t1,t0), 'sec CPU time for', Tsim*1000, 'ms simulation time';
print '==> ', sub_time(t1,tstart), 'seconds total'

###########################################################
# Make some figures out of simulation results
###########################################################

def diff(x):
    """Return the successive differences ``x[k] - x[k-1]`` of sequence *x*.

    The result has ``len(x) - 1`` entries and is empty when *x* has fewer
    than two elements.
    """
    return [x[k] - x[k - 1] for k in range(1, len(x))]
   
def std(data):
    """Return the sample standard deviation of *data*.

    Uses the unbiased (n - 1) denominator, so *data* must hold at least two
    values; with fewer, ZeroDivisionError is raised.
    """
    count = len(data)
    centre = sum(data) / float(count)
    squared_devs = [(value - centre) * (value - centre) for value in data]
    return sqrt(sum(squared_devs) / (count - 1))

def mean(data):
    """Return the arithmetic mean of *data*.

    The length is converted to float, as in ``std``, so that integer input
    is not truncated by Python 2 integer division.

    Raises ZeroDivisionError if *data* is empty.
    """
    return sum(data) / float(len(data))

def meanisi(spikes):
    """Return the mean inter-spike interval of the spike times *spikes*.

    At least two spike times are needed to form an interval; with fewer,
    None is returned.
    """
    if len(spikes) <= 1:
        return None
    return mean(diff(spikes))

def cv(spikes):
    """Return the coefficient of variation std(ISI) / mean(ISI) of *spikes*.

    More than two spike times are required, giving at least the two
    intervals the sample standard deviation needs; otherwise None is
    returned.
    """
    if len(spikes) <= 2:
        return None
    intervals = diff(spikes)
    return std(intervals) / mean(intervals)

if net.mpi_rank() == 0:
    n = [ spike_rec_popul.object(i).spikeCount() for i in range( spike_rec_popul.size() ) ]
    print "spikes total",sum(n)
    print "mean rate:",sum(n)/Tsim/len(n)
    
    isi = [ meanisi(spike_rec_popul.object(i).getSpikeTimes()) for i in range( spike_rec_popul.size() ) ]
    isi = filter( lambda x: x != None, isi);
    print "mean ISI:",sum(isi)/len(isi)
    
    
    CV = [ cv(spike_rec_popul.object(i).getSpikeTimes()) for i in range( spike_rec_popul.size() ) ]
    CV = filter( lambda x: x != None, CV);
    print "mean CV:",mean(CV)


