/*! \file  mexnetwork.cpp
**  \brief Implementation of MexNetwork (MATLAB mxArray import/export and output on top of Network)
*/

#include "mexnetwork.h"

#include <string.h>

#include "mexrecorder.h"
#include "csimlist.h"
#include "csimerror.h"
#include "csimmex.h"

// Shorthands for objectList: object(i) is the i-th registered object and
// nObjects the number of registered objects (objectList is presumably a
// member inherited from Network — not visible in this file).
#define object(i) (objectList.elements[(i)])    
#define nObjects  (objectList.n)    

//! Create an empty MexNetwork; nothing needs explicit initialisation here.
MexNetwork::MexNetwork()
{
  // intentionally empty
}


//! Destroy the MexNetwork.
MexNetwork::~MexNetwork()
{
  // Intentionally empty: the recorders collected in mexRecorderList are not
  // deleted here (they are also handed to Network::addNewObject).
}


/*! Register a new object with the network.
**
** If \a a is a MexRecorder it is additionally remembered in mexRecorderList
** so that getMexOutput() can collect its recorded data.
** \return the result of Network::addNewObject(a).
*/
int MexNetwork::addNewObject(Advancable *a)
{
  if ( MexRecorder *recorder = dynamic_cast<MexRecorder *>(a) )
    mexRecorderList.add(recorder);

  return Network::addNewObject(a);
}

/*! Collect the simulation results as a 1 x N MATLAB cell array.
**
** Cell k (k < mexRecorderList.n) holds getMxStructArray() of the k-th
** MexRecorder. If spikeOutput is non-zero, one additional last cell holds
** a struct ST with
**   ST.idx   [uint32 row vector]
**   ST.times [double row vector]
** filled by getSpikes(). When no spikes were recorded both fields are left
** unset, which MATLAB shows as [].
** \return newly created cell array (ownership passes to the caller).
*/
mxArray *MexNetwork::getMexOutput(void)
{
  // Use one flag both for sizing the cell array and for writing the spike
  // cell; if the two tests disagreed, mxSetCell would write past the end.
  const bool withSpikes = ( spikeOutput != 0 );
  const unsigned nRec   = mexRecorderList.n;

  mxArray *out_a = mxCreateCellMatrix(1,nRec+(withSpikes ? 1 : 0));

  for(unsigned i=0;i<nRec;i++)
    mxSetCell(out_a,i,mexRecorderList.elements[i]->getMxStructArray());

  if ( withSpikes ) {
    // output the spike times of all but the input neurons
    unsigned long nSpikes = totalSpikeCount();

    const char *st_fields[]= { "idx", "times" };
    mxArray *st_a = mxCreateStructMatrix(1,1,2,st_fields);

    if ( nSpikes > 0 ) {
      mxArray *idx_a   = mxCreateNumericMatrix ( 1, nSpikes, mxUINT32_CLASS, mxREAL );
      mxArray *times_a = mxCreateDoubleMatrix( 1, nSpikes, mxREAL);

      mxSetFieldByNumber(st_a,0,0,idx_a);
      mxSetFieldByNumber(st_a,0,1,times_a);

      uint32 *idx  = (uint32 *)mxGetData(idx_a);
      double *times= mxGetPr(times_a);

      getSpikes(idx,times);
    }

    // the spike struct always goes into the last cell
    mxSetCell(out_a,nRec,st_a);
  }

  return out_a;
}

/*! Copy \a sz doubles from \a p (a buffer obtained from getFieldArray())
**  into a new 1 x sz MATLAB double row vector and release \a p with free().
**
**  memcpy is only called for a non-NULL buffer and sz > 0; passing NULL
**  pointers to memcpy is undefined even for a zero length, and mxGetPr of
**  an empty matrix may be NULL.
*/
static mxArray *exportFieldArray(double *p, int sz)
{
  mxArray *a = mxCreateDoubleMatrix ( 1, sz > 0 ? sz : 0, mxREAL );
  if ( p && sz > 0 )
    memcpy(mxGetPr(a),p,sz*sizeof(double));
  if ( p ) free(p);
  return a;
}

/*! Export the whole network into a 1x1 MATLAB struct with the fields
**   globals      - double row vector of the Network's field array
**   object       - 1 x nObjects struct array with fields 'type' (class name)
**                  and 'parameter' (double row vector of its field array),
**                  or [] when there are no objects
**   dst, src     - the uint32 index arrays dst/src (or [] when empty)
**   recorderInfo - the raw RecMem bytes as a uint8 row vector (or [])
**   version      - CVS id string of this file
** The result can be read back with importNetwork().
** \return newly created struct (ownership passes to the caller).
*/
mxArray *MexNetwork::exportNetwork(void)
{
  const char *fields[]= { "globals", "object", "dst", "src", "recorderInfo", "version" };

  mxArray *mxNet = mxCreateStructMatrix(1,1,sizeof(fields)/sizeof(char *),fields);
  mxSetField(mxNet,0,"version",mxCreateString("$Id: mexnetwork.cpp,v 1.7 2003/06/03 13:41:05 haeusler Exp $"));

  mxArray *src_a;
  if ( nSrc > 0 ) {
    src_a = mxCreateNumericMatrix ( 1, nSrc, mxUINT32_CLASS, mxREAL );
    memcpy((uint32 *)mxGetData(src_a),src,nSrc*sizeof(uint32));
  } else {
    src_a = mxCreateDoubleMatrix ( 0, 0, mxREAL );
  }
  mxSetField(mxNet,0,"src",src_a);

  mxArray *dst_a;
  if ( nDst > 0 ) {
    dst_a = mxCreateNumericMatrix ( 1, nDst, mxUINT32_CLASS, mxREAL );
    memcpy((uint32 *)mxGetData(dst_a),dst,nDst*sizeof(uint32));
  } else {
    dst_a = mxCreateDoubleMatrix ( 0, 0, mxREAL );
  }
  mxSetField(mxNet,0,"dst",dst_a);

  // one reported connection per dst entry
  int nCon = nDst;

  mxArray *rec_a;
  if ( nRecMem > 0 ) {
    rec_a = mxCreateNumericMatrix ( 1, nRecMem, mxUINT8_CLASS, mxREAL );
    memcpy(mxGetData(rec_a),RecMem,nRecMem);
  } else {
    rec_a = mxCreateDoubleMatrix ( 0, 0, mxREAL );
  }
  mxSetField(mxNet,0,"recorderInfo",rec_a);

  // export the fields of the Network itself
  int sz = this->getFieldArraySize();
  mxSetField(mxNet,0,"globals",exportFieldArray(this->getFieldArray(),sz));

  mxArray *mxObjects;
  if ( nObjects > 0 ) {
    const char *objfields[]= { "type", "parameter" };
    mxObjects = mxCreateStructMatrix(1,nObjects,sizeof(objfields)/sizeof(char *),objfields);
    int type_i  = mxGetFieldNumber(mxObjects,"type");
    int param_i = mxGetFieldNumber(mxObjects,"parameter");
    for(unsigned i=0;i<nObjects;i++) {
      mxSetFieldByNumber(mxObjects,i,type_i,mxCreateString(object(i)->className()));
      int     objSz = object(i)->getFieldArraySize();
      double *objP  = object(i)->getFieldArray((char *)object(i));
      mxSetFieldByNumber(mxObjects,i,param_i,exportFieldArray(objP,objSz));
    }
  } else {
    mxObjects = mxCreateDoubleMatrix ( 0, 0, mxREAL );
  }
  csimPrintf("CSIM: %i objects and %i connections exported\n",nObjects,nCon);

  mxSetField(mxNet,0,"object",mxObjects);

  return mxNet;

}

// Two-argument maximum; as a macro it evaluates the selected argument twice,
// so do not pass expressions with side effects.
#define max(A, B) ((A) > (B) ? (A) : (B))
// Limit on class-name length; not referenced by the code visible here
// (presumably used by the included .i files — TODO confirm).
#define MAXTYPENAMELENGTH 2000

#include "classlist.i"

/*! Fetch field \a name of the 1x1 struct \a mxNet. If the field is missing
**  (or \a mxNet is not a struct), an error is added to TheCsimError and NULL
**  is returned.
*/
static mxArray *importRequiredField(const mxArray *mxNet, const char *name)
{
  mxArray *f = mxGetField(mxNet,0,name);
  if ( !f )
    TheCsimError.add("MexNetwork::importNetwork: input is no struct array with field '%s'!\n",name);
  return f;
}

/*! Rebuild the network from a struct as produced by exportNetwork().
**
** Required fields of \a mxNet:
**   globals      - double array with getFieldArraySize() entries
**   object       - struct array with fields 'type' (class name) and
**                  'parameter' (double array matching that class' field
**                  array size); may be [] for a network without objects
**   dst, src     - uint32 arrays of equal length; connect(dst[j],src[j]) is
**                  called for each j (both may be empty)
**   recorderInfo - byte stream of records
**                  [uint32 rIdx][uint32 n][uint32 objIdx[n]][NUL-terminated field name]
**                  each passed to connect(rIdx,objIdx,n,fieldName)
**
** Only possible while the network has no objects yet.
** \return 0 on success, -1 on error (details are added to TheCsimError).
*/
int MexNetwork::importNetwork(const mxArray *mxNet)
{
  if ( nObjects > 0 ) {
    TheCsimError.add("MexNetwork::importMexNetwork: can not merge networks!\n");
    return -1;
  }

  // Every field is dereferenced below, so each one is mandatory.
  mxArray *mxObjects = importRequiredField(mxNet,"object");
  if ( !mxObjects ) return -1;
  mxArray *mxDst = importRequiredField(mxNet,"dst");
  if ( !mxDst ) return -1;
  mxArray *mxSrc = importRequiredField(mxNet,"src");
  if ( !mxSrc ) return -1;
  mxArray *mxRec = importRequiredField(mxNet,"recorderInfo");
  if ( !mxRec ) return -1;
  mxArray *mxGlob = importRequiredField(mxNet,"globals");
  if ( !mxGlob ) return -1;

  // ---- global Network fields ----
  int nm = mxGetNumberOfElements(mxGlob);
  if ( nm != this->getFieldArraySize() ) {
    TheCsimError.add("MexNetwork::importNetwork: length of 'globals' (%i) != %i which is required by Network!\n\n"
                     ,nm,this->getFieldArraySize());
    return -1;
  }
  if ( nm > 0 && !mxIsDouble(mxGlob) ) {
    TheCsimError.add("MexNetwork::importNetwork: field 'globals' must be a double array!\n");
    return -1;
  }
  this->setFieldArray(mxGetPr(mxGlob));

  // ---- objects ----
  // exportNetwork() writes [] (not a struct) for a network without objects,
  // so the struct fields are only required when there is something to read.
  int nObj    = mxGetNumberOfElements(mxObjects);
  int type_i  = -1;
  int param_i = -1;
  if ( nObj > 0 ) {
    type_i  = mxGetFieldNumber(mxObjects,"type");
    param_i = mxGetFieldNumber(mxObjects,"parameter");
    if ( type_i < 0 || param_i < 0 ) {
      TheCsimError.add("MexNetwork::importNetwork: field 'object' is not a struct array with fields 'type' and 'parameter'!\n");
      return -1;
    }
  }

  for(int i=0;i<nObj;i++) {
    mxArray *mxCN = mxGetFieldByNumber(mxObjects,i,type_i);
    if ( !mxCN ) {
      TheCsimError.add("MexNetwork::importNetwork:  mxGetFieldByNumber(mxObjects,i,type_i) failed \n");
      return -1;
    }
    // NOTE(review): ownership of the buffer getString() stores in className
    // is not visible here; it is not released in this function.
    char *className;
    if ( getString(mxCN,&className) ) {
      TheCsimError.add("MexNetwork::importNetwork: mxGetString(mxCN,className,MAXTYPENAMELENGTH) failed \n");
      return -1;
    }

    // switch.i is expected to dispatch on className and expand
    // __SWITCH_COMMAND__ with TYPE set to the matching class.
    int switchError = 0;
    Advancable *a;

    #define __SWITCH_COMMAND__  { (addNewObject(a=(Advancable *)(new TYPE))); \
                                   if ( a->init(a) < 0 ) { \
                                     TheCsimError.add("MexNetwork::importNetwork: error calling init of %s!\n",className); \
                                     return -1; \
                                   } \
                                }

    #include "switch.i"

    // The network started empty, so after a successful creation object(i)
    // exists. An unrecognised class name (or an object addNewObject did not
    // store) would otherwise make object(i) index past the end of the list.
    if ( i >= (int)nObjects ) {
      TheCsimError.add("MexNetwork::importNetwork: could not create object(%i) of class '%s'!\n",i,className);
      return -1;
    }

    mxArray *mxParam = mxGetFieldByNumber(mxObjects,i,param_i);
    if ( !mxParam ) {
      TheCsimError.add("MexNetwork::importNetwork: object(%i).parameter is not set!\n",i);
      return -1;
    }
    nm = mxGetNumberOfElements(mxParam);
    if ( nm != object(i)->getFieldArraySize() ) {
      TheCsimError.add("MexNetwork::importNetwork: length of object(%i).parameter (%i) != %i which is required by %s!\n\n"
                       ,i,nm,object(i)->getFieldArraySize(),object(i)->className());
      return -1;
    }
    if ( nm > 0 && !mxIsDouble(mxParam) ) {
      TheCsimError.add("MexNetwork::importNetwork: object(%i).parameter must be a double array!\n",i);
      return -1;
    }
    object(i)->setFieldArray((char *)(object(i)),mxGetPr(mxParam));
  }
  csimPrintf("CSIM: %i objects",nObjects);

  // ---- synaptic connections ----
  unsigned nCon = mxGetNumberOfElements(mxDst);

  if ( nCon > 0 ) {
    // dst/src are reinterpreted as uint32 data, so both must be uint32 and
    // src must provide one entry per dst entry.
    if ( !mxIsUint32(mxDst) || !mxIsUint32(mxSrc) || mxGetNumberOfElements(mxSrc) != nCon ) {
      TheCsimError.add("MexNetwork::importNetwork: fields 'dst' and 'src' must be uint32 arrays of equal length!\n");
      return -1;
    }
    uint32 *dst = (uint32 *)mxGetData(mxDst);
    uint32 *src = (uint32 *)mxGetData(mxSrc);
    for(unsigned j=0;j<nCon;j++)
      connect(dst[j],src[j]);
  }
  csimPrintf(" and %i connections imported\n",nCon);

  // ---- recorder connections ----
  char  *rec  = (char *)mxGetData(mxRec);                                    // start of the record stream
  size_t nMem = mxGetNumberOfElements(mxRec)*mxGetElementSize(mxRec);        // its size in bytes
  size_t pos  = 0;                                                           // current read offset
  bool corrupt = false;

  while ( pos < nMem ) {
    uint32 rIdx,n;

    // record header: recorder index and number of object indices
    if ( nMem - pos < 2*sizeof(uint32) ) { corrupt = true; break; }
    memcpy(&rIdx,rec+pos,sizeof(uint32)); pos += sizeof(uint32);
    memcpy(&n,rec+pos,sizeof(uint32));    pos += sizeof(uint32);

    // n object indices
    // NOTE(review): objIdx may be misaligned after a previous record's
    // field name; harmless on x86, verify on strict-alignment targets.
    if ( (nMem - pos)/sizeof(uint32) < n ) { corrupt = true; break; }
    uint32 *objIdx = (uint32 *)(rec+pos);
    pos += sizeof(uint32)*(size_t)n;

    // NUL-terminated field name, which must end inside the buffer
    char *fieldName = rec+pos;
    char *nul = (char *)memchr(fieldName,'\0',nMem-pos);
    if ( !nul ) { corrupt = true; break; }
    pos += (size_t)(nul-fieldName)+1;

    connect(rIdx,objIdx,n,fieldName);
  }

  if ( corrupt ) {
    TheCsimError.add("MexNetwork::importNetwork: field 'recorderInfo' is truncated or corrupt!\n");
    return -1;
  }

  return 0;
}
