%--------------------------------------------------------------------------
% Desc - This file launches the experiment with multiple spiking
%             readouts.
%
%            (c), 2004, Prashant Joshi (joshi@igi.tugraz.at)
%--------------------------------------------------------------------------
% Start from a clean session: remove all workspace variables and loaded
% class definitions, and close any open figures.
clear all classes;
close all;


%-------------------------------------------------------------------------
% Load the simulation parameters and create the liquid
%-------------------------------------------------------------------------
% load_params is invoked as a script; the upper-case constants used below
% (REGEN_CIRCUIT, CIRCUIT_PATH, SET_STATIC_SYNAPSES, PLOTTING_LEVEL, ...)
% are presumably defined there -- it is not visible from this file.
load_params;

% Either rebuild the circuit by running make_liquid, or load a previously
% stored one from <CIRCUIT_PATH>/nmc.mat.
% NOTE(review): both branches are assumed to leave the circuit object in
% the variable 'nmc', which is used further down -- confirm.
if(REGEN_CIRCUIT == 1)
    make_liquid;
else
    [fn,e] = sprintf('%s/nmc.mat', CIRCUIT_PATH);
    load(fn);
end
% With SET_STATIC_SYNAPSES == 1 the modify_circuit script is run after the
% banner is printed; otherwise the circuit is left untouched and only the
% "DYNAMIC synapses" banner is printed.
if(SET_STATIC_SYNAPSES ==1)
    fprintf('\n\n***\n\nRunning the simulation with STATIC synapses\n\n***\n');
    modify_circuit;
else
    fprintf('\n\n***\n\nRunning the simulation with DYNAMIC synapses\n\n***\n');
end;

%-------------------------------------------------------------------------
% Create the stimulus object
%-------------------------------------------------------------------------
% TrainDist is the stimulus distribution object; S is one stimulus drawn
% from it. S is reused later when the trained circuit is simulated.
TrainDist = discrete_attractor_stimuli;
S = generate(TrainDist);
% Optional diagnostics: figure 1 shows the generated stimulus, figure 2 is
% filled by the plot_circuit_resp script.
if(PLOTTING_LEVEL >0)
    figure(1);
    plot(TrainDist, S);
    drawnow;
    figure(2);
    plot_circuit_resp;
    drawnow;
end
% NOTE(review): anykey is a project helper; presumably it blocks until the
% user presses a key -- confirm.
anykey;
%-------------------------------------------------------------------------
% PART 1 - TRAINING OF A LINEAR READOUT
%-------------------------------------------------------------------------

% collect the stimulus and response pairs
% Collect NTRAIN stimulus/response pairs of length TSTIM from the circuit.
[train_response,train_stimuli] = collect_sr_data(nmc, ...
                                                     TrainDist,...
                                                     NTRAIN, ...
                                                     TSTIM);

%add_linear_readouts;

% NOTE(review): the linear readout neurons are taken from pool 7; the index
% is hard-coded -- confirm against the script that builds the circuit.
pool = get(nmc, 'pool');
pool = pool(7);
linear_readout_idx = pool.neuronIdx;

%reset_linear_readouts;
for readout_idx = 1:NO_READOUTS
    %--------------------------------------------------------------------------
    % make a partial liquid state (only the state due to pre-synaptic neurons
    % of the linear readout is visible to the readout).
    [pre_syn, post_syn] = csim('get', linear_readout_idx(readout_idx), 'connections');

    % find the presynaptic neuron behind each synapse of this readout
    pre_neur = [];
    cnt1 = length(pre_syn);
    for idx1 = 1 : cnt1
        [pre_neur(idx1), post_neur] = csim('get', pre_syn(idx1), 'connections');
    end

    % Map neuron handles to channel numbers of the recorded response.
    % NOTE(review): assumes response channels are ordered like the neuron
    % handles, offset by the input and feedback channels -- confirm.
    pre_neur_idx = double(pre_neur) - NO_IP_CHANNELS - NO_FB_CHANNELS+1;

    % Keep only the channels of the presynaptic neurons in each response.
    partial_train_resp = [];
    cnt2 = length(pre_neur);
    for trainIdx = 1:NTRAIN
        for respIdx = 1:cnt2
            partial_train_resp(trainIdx).channel(respIdx) = train_response(trainIdx).channel(pre_neur_idx(respIdx));
        end
        partial_train_resp(trainIdx).Tsim = TSTIM;
    end

    %--------------------------------------------------------------------------
    % convert the spike responses to states sampled every DT
    train_val  = response2states(partial_train_resp,[], ...
                             [0:DT:TSTIM], ...
                             { 'spikes2alpha' 0.03 });

    train_states.X = vertcat(train_val(:).X);
    % Add noise to the network states: uniform in [-0.5, 0.5], scaled by 5%
    % of the mean absolute state value.
    noise_var = 0.05*mean(abs(train_states.X(:)));
    trainX = train_states.X + noise_var*(rand(size(train_states.X)) - 0.5);

    % make the target values, stacked in the same order as the states
    trainY = [];
    for idx3 = 1:NTRAIN
        target_val = find_target_val(train_stimuli(idx3));
        trainY = vertcat(trainY, target_val');
    end

    %-- start - joshi
    %pool_temp = get(nmc, 'pool');
    %pool_temp = pool_temp(6);
    %type = csim('get', pool_temp.neuronIdx, 'type');
    % Sign-constrained regression: columns of the neurons with type == 1
    % (presumably the inhibitory ones, going by the variable name) are
    % negated so that a non-negative fit corresponds to non-positive
    % weights for them.
    type = csim('get', pre_neur, 'type');
    inh_Idx = find(type == 1);
    trainX(:, inh_Idx) = -1*trainX(:,inh_Idx);
    %-- end - joshi
%    b = regress(trainY, trainX);
    b = lsqnonneg(trainX, trainY);

    % Evaluate the fit now, while trainX (negated columns) and b
    % (non-negative) still use the same sign convention. Computing it after
    % the sign flip below would negate the contribution of the inh_Idx
    % columns twice and report an error unrelated to the actual fit.
    err = mse(trainY - trainX*b);

    % Convert back to signed synaptic weights.
    b(inh_Idx) = -1*b(inh_Idx);

    % set the weights of the linear neuron to the trained weights
    %--------------------------------------------------------------------
    % Set only realistic weights: for every presynaptic neuron, locate the
    % synapse that belongs to this readout among its outgoing synapses and
    % write the trained weight to it.
    cnt3 = length(b);
    for idx4 = 1:cnt3
        [pre, post] = csim('get', pre_neur(idx4), 'connections');
        idx_syn = find(ismember(post, pre_syn));
        csim('set', post(idx_syn), 'W', b(idx4));
    end

    % report the training performance of this linear readout
    ERROR(readout_idx) = err;
    fprintf('Error = %g\n\n', err);
end

% Run the circuit with the trained linear readouts on stimulus S and plot it.
reset(nmc); R = simulate(nmc, TSTIM, S);
figure; plot_linear_readout_response(S,R);drawnow;
pause(1);

%-------------------------------------------------------------------------
% PART 2 - NOW SIMULATING USING SPIKING READOUTS
%-------------------------------------------------------------------------

%-------------------------------------------------------------------------
% Now create a pool of leaky integrate-and-fire neurons
% Add a pool of NO_READOUTS 'LifNeuron' neurons (size [1 1 NO_READOUTS]) at
% ORIG_POOL_SPIKING_READOUTS; the neuron parameters come from the
% upper-case constants. 'plif' identifies the new pool and is passed to the
% recorders below.
[nmc, plif] = add(nmc, 'Pool', 'origin', ORIG_POOL_SPIKING_READOUTS, ...
                 'size', [1 1 NO_READOUTS], 'type', 'LifNeuron', ...
                 'Neuron.Vthresh', VTHRESH, 'Neuron.Vreset', VRESET, ...
                 'Neuron.Vinit', VINIT, 'Neuron.Iinject', IINJECT, ...
                 'Neuron.Trefract', TREFRACT, 'Neuron.Inoise',  INOISE, ...
                 'Neuron.Cm', CM, 'frac_EXC', 1);

%--------------------------------------------------------------------------
% Add recorders to record the spikes in the spiking readout, as well as the
% membrane potential of the spiking readout
% Both recorders sample with time step DT.
nmc = record(nmc, 'Pool', plif, 'Field', 'spikes','dt',DT);
nmc = record(nmc, 'Pool', plif, 'Field', 'Vm','dt',DT);

reset(nmc);
 
% NOTE(review): assumes the pool created above is pool number 8 of nmc; the
% index is hard-coded -- confirm (or derive it from plif).
pool = get(nmc, 'pool');
pool = pool(8);
spiking_readout_idx = pool.neuronIdx;
 
% connect the neurons from the liquid to the spiking readouts
% For every readout, the synapses of the corresponding linear readout are
% mirrored onto the spiking readout: Wlr collects the linear readout's
% weights as read from the simulator, Wsr the same weights scaled by ALPHA,
% which are the ones written to the new synapses.
pre_syn = []; post_syn = [];
for srIdx = 1:NO_READOUTS
    Wlr.readout(srIdx).weights =[];
    Wsr.readout(srIdx).weights =[];
    % NOTE(review): the 'connections' query is assumed to return the
    % incoming and outgoing connection lists, as the variable names
    % suggest -- confirm against the CSIM documentation.
    [pre_syn, post_syn] = csim('get', linear_readout_idx(srIdx), 'connections');
    pre_neur = [];post_neur=[];idx5 = 0;
    % Resolve the neuron on the presynaptic side of each synapse in pre_syn.
    cnt4 = length(pre_syn);
    for idx5 = 1 : cnt4
        [pre_neur(idx5), post_neur] = csim('get', pre_syn(idx5), 'connections');
    end
    pre = []; post=[];
    cnt5 = length(pre_neur);
    for idx6 = 1:cnt5
        % Among the connections listed in 'post' for this neuron, pick the
        % one(s) that also appear in pre_syn, i.e. belong to this readout.
        [pre, post] = csim('get', pre_neur(idx6), 'connections');
        id = find(ismember(post, pre_syn));
        w = csim('get', post(id), 'W');
        Wlr.readout(srIdx).weights = [Wlr.readout(srIdx).weights w];
        Wsr.readout(srIdx).weights = [Wsr.readout(srIdx).weights w*ALPHA];
        % Create a new synapse of type STYPE_IP, give it the scaled weight
        % and attach it between the liquid neuron and the spiking readout.
        % syn_liq2_spk_readout is indexed by idx6 only, so its entries are
        % overwritten for each successive readout.
        % NOTE(review): 'connect' argument order is presumably
        % (destination, source, synapse) -- confirm.
        syn_liq2_spk_readout(idx6) = csim('create',STYPE_IP);
	csim('set',syn_liq2_spk_readout(idx6),'W', Wsr.readout(srIdx).weights(idx6));
        csim('connect',spiking_readout_idx(srIdx),pre_neur(idx6),syn_liq2_spk_readout(idx6));
    end
end

% connect the spiking readouts to the liquid (actual feedback)
% The outgoing synapses of each teacher feedback neuron are duplicated so
% that the corresponding spiking readout drives the same liquid neurons.
% NOTE(review): the teacher feedback neurons are taken from pool 5; the
% index is hard-coded -- confirm.
pool = get(nmc, 'pool');
pool = pool(5);
teacher_fb_neuron_idx = pool.neuronIdx;
pre=[];post=[];
pre_syn = []; post_syn = [];
for tfbIdx = 1:NO_READOUTS
    % Wtfb collects the weights read from the teacher feedback synapses;
    % Wsfb collects the weights given to the new synapses (the same values,
    % no scaling is applied here).
    Wtfb.readout(tfbIdx).weights = [];
    Wsfb.readout(tfbIdx).weights =[];
    [pre_syn, post_syn] = csim('get', teacher_fb_neuron_idx(tfbIdx), 'connections');
    pre_neur=[];post_neur=[];
    % Resolve the neuron on the postsynaptic side of each synapse in post_syn.
    cnt6 = length(post_syn);
    for idx = 1 : cnt6
        [pre_neur, post_neur(idx)] = csim('get', post_syn(idx), 'connections');
    end
    pre = []; post = [];
    cnt7 = length(post_neur);
    for idx7 = 1:cnt7
        % Among the connections listed in 'pre' for this liquid neuron, pick
        % the one(s) that also appear in post_syn, i.e. come from the
        % current teacher feedback neuron.
        [pre, post] = csim('get', post_neur(idx7), 'connections');
        id = find(ismember(pre, post_syn));
        % New synapse of type STYPE_IP that copies tau, delay and W from the
        % teacher feedback synapse. syn_spk_readout2_liq is indexed by idx7
        % only, so its entries are overwritten for each successive readout.
        syn_spk_readout2_liq(idx7) = csim('create',STYPE_IP);
        tau_tfb   = csim('get', pre(id), 'tau');
        w         = csim('get', pre(id), 'W');
        delay_tfb = csim('get', pre(id), 'delay');
        csim('set',syn_spk_readout2_liq(idx7),'tau', tau_tfb);
        csim('set',syn_spk_readout2_liq(idx7),'delay', delay_tfb);
        
        Wtfb.readout(tfbIdx).weights = [Wtfb.readout(tfbIdx).weights w];
        Wsfb.readout(tfbIdx).weights = [Wsfb.readout(tfbIdx).weights w];
        csim('set',syn_spk_readout2_liq(idx7),'W',Wsfb.readout(tfbIdx).weights(idx7)); % weight
        % NOTE(review): 'connect' argument order is presumably
        % (destination, source, synapse) -- confirm.
        csim('connect',post_neur(idx7), spiking_readout_idx(tfbIdx),syn_spk_readout2_liq(idx7));
    end
end

%best_example = plot_readouts(trained_readouts,V.test_response,V.test_stimuli,...
%                             { 'response2states' [0:DT:TSTIM] {'spikes2exp' 0.03}} );  



%add_linear_readouts;
% Train the readouts in closed loop and validate them; both are project
% scripts that operate on the variables set up above.
train_readouts_in_closed_loop;
validate;
close all;

% Performance measures of the validated readouts.
[fp_mat, fn_mat, opt] = perf_measure;

% Store the circuit, the trained readouts, the stimulus distribution and the
% feedback weights in <RESULT_PATH>/workspace.mat.
% (The comma between the sprintf arguments was missing, which is a parse
% error that kept the whole script from running.)
% NOTE(review): 'trained_readouts' is presumably created by
% train_readouts_in_closed_loop -- it is not assigned in this file.
[fn,e] = sprintf('%s/workspace.mat', RESULT_PATH);
save(fn, 'nmc', 'trained_readouts', 'TrainDist', 'Wtfb', 'Wsfb');
