/* * C/Matlab/mex Viterbi routine for ISI * * Note that the index origin is 0 here in C, but is 1 in Matlab. * * R. Perry, 9 April 1999 * * updated with GET_COST_M compile option to use m file for get_cost() * based on modifications by Hai Chen. 12 May 2000 * * updated for get_cost() here to handle direct ML, as an alternative to * the GET_COST_M option, to speed up run-times. 14 May 2000 * * updated to handle pilot bits in get_cost() here. 29 May 2000 */ #include /* for memset() */ #include "mex.h" /* for Matlab/mex stuff */ /* #include -- for log(), but conflict with HUGE_VAL here */ extern double log( double); #define S(i,j) states[(i)+(j)*Memory] static const double *states; /* state array */ static const double *z; /* input data vector */ static int M; /* number of states */ static int N; /* number of samples */ static int Memory; /* channel memory size */ static int L; /* number of channel coefficients */ static double HUGE_VAL; /* for "infinite" costs */ #ifndef GET_COST_M static const double *trans; /* state transition matrix */ static const double *FIR; /* channel coefficients */ static const double *g; /* and covariance */ static const mxArray *mx_trans; static const mxArray *mx_FIR; static const mxArray *mx_g; static int GAUSS, GM, ML; /* algorithm type flags */ static double var_noise; static const double *Rh, *avg_h; static int avg_h_N; static const mxArray *mx_Rh, *mx_avg_h; static const mxArray *mx_tmp; static const double *pilot; /* for pilot bits */ #endif static const mxArray *mx_zstate, *mx_zstate_0, *mx_w0; static int zstate, zstate_0; static double *w0; static double *w; /* cost array */ #ifndef GET_COST_M static double *y, *t; /* allocated in vit(), used in get_cost() */ static double *s_cur, *s_next; #endif #if DEBUG static double *V; #endif static double *path_index; /* optimal path */ #ifndef GET_COST_M static double get_cost( double zz, int cur, int next, int j) { double w, s2; const double *pg; int a, i, k; #if DEBUG > 2 mexPrintf( "get_cost: zz = %g, cur = %d, next = %d, j = %d\n", zz, cur, next, j); #endif w = HUGE_VAL; /* default return value */ /* If we have a non-zero pilot bit for time j, and it's not equal to S(0,next), * then the state is not reachable. */ if( pilot && pilot[j] && pilot[j] != S(0,next) ) return w; if( zstate_0 && j+1 <= Memory) { /* handle initial times when the channel contains 0's */ for( i = 0; i < j; ++i) /* s_cur = [states(1:j-1,cur)' zeros(1,Memory-j+1)]; */ s_cur[i] = S(i,cur); for( ; i < Memory; ++i) s_cur[i] = 0; for( i = 0; i <= j; ++i) /* s_next = [states(1:j,next)' zeros(1,Memory-j)]; */ s_next[i] = S(i,next); for( ; i < Memory; ++i) s_next[i] = 0; #if DEBUG > 1 mexPrintf( "s_cur:"); for( i = 0; i < Memory; ++i) mexPrintf( " %g", s_cur[i]); mexPrintf( "\n"); mexPrintf( "s_next:"); for( i = 0; i < Memory; ++i) mexPrintf( " %g", s_next[i]); mexPrintf( "\n"); #endif if( Memory > 1) { /* a = all( s_cur(1:Memory-1) == s_next(2:Memory)); */ a = 1; for( i = 0; i < Memory-1; ++i) if( s_cur[i] != s_next[i+1]) { a = 0; break; } } else a = 0; if( a || Memory == 1) { for( i = 0; i < Memory; ++i) /* y = [s_next s_cur(Memory)]; */ y[i] = s_next[i]; y[Memory] = s_cur[Memory-1]; w = 0; for( i = 0; i < L; ++i) /* w = y * FIR; */ w += y[i] * FIR[i]; } } else { w = trans[cur+next*M]; if( w != HUGE_VAL && (GAUSS || GM || ML) ) { for( i = 0; i < Memory; ++i) /* y = [states(:,next)' states(Memory,cur)]; */ y[i] = S(i,next); y[Memory] = S(Memory-1,cur); } } if( w != HUGE_VAL) { if( ML) { s2 = var_noise; #if 0 #if DEBUG > 2 mexPrintf( "get_cost: begin s2 += y*Rh*y', s2 = %g\n", s2); #endif #endif /* s2 = s2 + y*Rh*y'; */ pg = Rh; for( i = 0; i < L; ++i) { t[i] = 0; for( k = 0; k < L; ++k) t[i] += y[k] * *pg++; } #if 0 #if DEBUG > 2 mexPrintf( "get_cost: almost done s2 += y*Rh*y'\n"); #endif #endif for( i = 0; i < L; ++i) s2 += t[i] * y[i]; #if 0 #if DEBUG > 2 mexPrintf( "get_cost: begin w = zz - ..., s2 = %g\n", s2); #endif #endif /* w = zz - y*avg_h(:,j); */ w = zz; if( avg_h_N == 1) /* Gauss-Markov h -- this is wrong, but just prevents crashing */ pg = avg_h; else pg = avg_h + j*L; for( i = 0; i < L; ++i) w -= y[i] * *pg++; #if 0 #if DEBUG > 2 mexPrintf( "get_cost: begin w = zz - ..., s2 = %g\n", s2); #endif #endif w = (w*w)/s2 + log(s2); } else { /* !ML */ w = zz - w; w = w * w; if( g) { /* w = w + y*G*y'; */ pg = g; for( i = 0; i < L; ++i) { t[i] = 0; for( k = 0; k < L; ++k) t[i] += y[k] * *pg++; } for( i = 0; i < L; ++i) w += t[i] * y[i]; } } /* if ML */ } /* w != HUGE_VAL */ return w; } #endif /* GET_COST_M */ static void vit( void) { int i, j, k, i_min; double p, w_min; size_t v_size; int *v, *pv; double *w_new; #ifdef GET_COST_M double *p_zz, *p_i, *p_k, *p_j; mxArray *mx_input[4]; mxArray *mx_ww; mx_input[0] = mxCreateDoubleMatrix(1,1,mxREAL); mx_input[1] = mxCreateDoubleMatrix(1,1,mxREAL); mx_input[2] = mxCreateDoubleMatrix(1,1,mxREAL); mx_input[3] = mxCreateDoubleMatrix(1,1,mxREAL); mx_ww = mxCreateDoubleMatrix(1,1,mxREAL); p_zz = mxGetPr(mx_input[0]); p_i = mxGetPr(mx_input[1]); p_k = mxGetPr(mx_input[2]); p_j = mxGetPr(mx_input[3]); #else #ifdef COST_UPDATE mxArray *mx_j = mxCreateDoubleMatrix( 1, 1, mxREAL); double *p_j = mxGetPr( mx_j); #endif #endif /* initialize stuff for Viterbi */ v_size = M * N * sizeof(int); v = mxMalloc( v_size); memset( v, -1, v_size); w_new = mxMalloc( M * sizeof(double)); if( zstate > 0) { /* initial state is known */ for( i = 0; i < M; ++i) w[i] = HUGE_VAL; w[zstate-1] = 0; } else if( zstate == 0) { /* initial state is unknown */ for( i = 0; i < M; ++i) w[i] = 0; } else { /* zstate < 0, initial state is known statistically */ for( i = 0; i < M; ++i) /* with initial cost vector w0 */ w[i] = w0[i]; } /* initialize stuff for get_cost() */ #ifndef GET_COST_M s_cur = mxMalloc( Memory * sizeof(double)); s_next = mxMalloc( Memory * sizeof(double)); y = mxMalloc( L * sizeof(double)); #endif #ifdef COST_UPDATE #if DEBUG > 2 mexPrintf("vit: calling cost_init\n"); #endif if( mexCallMATLAB( 0, NULL, 0, NULL, "cost_init")) mexErrMsgTxt("mexCallMATLAB cost_init failed"); #if DEBUG > 2 mexPrintf("vit: return from cost_init\n"); #endif #endif #ifndef GET_COST_M trans = mxGetPr( mx_trans); FIR = mxGetPr( mx_FIR); g = mxGetPr( mx_g); #if DEBUG > 1 mexPrintf( "trans = %p, FIR = %p, g = %p\n", trans, FIR, g); #endif t = mxMalloc( L * sizeof(double)); #endif /*** main Viterbi loop ***/ pv = v; for( j = 0; j < N; ++j) { /* for each column, i.e. time index */ for( i = 0; i < M; ++i) /* set big values to start with */ w_new[i] = HUGE_VAL; for( i = 0; i < M; ++i) { /* for each row, i.e. state index */ if( w[i] == HUGE_VAL) /* HUGE_VAL marks unreachable states */ continue; for( k = 0; k < M; ++k) { /* for each possible next state */ #ifdef GET_COST_M *p_zz = z[j]; *p_i = i+1; *p_k = k+1; *p_j = j+1; if( mexCallMATLAB( 1, &mx_ww, 4, mx_input, "get_cost")) mexErrMsgTxt("mexCallMATLAB get_cost failed"); p = mxGetScalar(mx_ww); #else p = get_cost( z[j], i, k, j); #endif #if DEBUG > 1 mexPrintf( "j = %d, i = %d, k = %d, p = %g\n", j, i, k, p); #endif if( p == HUGE_VAL) /* unreachable state */ continue; if( p + w[i] < w_new[k]) { /* found better path */ w_new[k] = p + w[i]; pv[k] = i; } } } for( i = 0; i < M; ++i) /* save new costs */ w[i] = w_new[i]; #if DEBUG > 1 mexPrintf( "w ="); for( i = 0; i < M; ++i) mexPrintf( " %g", w[i]); mexPrintf( "\n"); #endif pv += M; #ifdef COST_UPDATE #if DEBUG > 2 mexPrintf("vit: calling cost_update\n"); #endif *p_j = j + 1; #ifdef GET_COST_M if( mexCallMATLAB( 0, NULL, 1, &mx_input[3], "cost_update")) mexErrMsgTxt("mexCallMATLAB cost_update failed"); #else if( mexCallMATLAB( 0, NULL, 1, &mx_j, "cost_update")) mexErrMsgTxt("mexCallMATLAB cost_update failed"); trans = mxGetPr( mx_trans); FIR = mxGetPr( mx_FIR); g = mxGetPr( mx_g); #endif #if DEBUG > 2 mexPrintf("vit: return from cost_update\n"); #endif #endif } /*** get optimal path ***/ i_min = 0; w_min = w[0]; for( i = 0; i < M; ++i) if( w[i] < w_min) { i_min = i; w_min = w[i]; } for( j = N-1; j >= 0; --j) { path_index[j] = i_min + 1; i_min = v[i_min+M*j]; } #if DEBUG /* return v */ { int n = M*N; double *pV = V; /* v has to be copied, since v is int here but is double in vit.m */ /* v in vit.m has an extra initial column */ if( zstate > 0) { for( i = 0; i < M; ++i) *pV++ = 0; V[zstate-1] = zstate; } else { for( i = 0; i < M; ++i) *pV++ = i+1; } pv = v; while( n) { *pV++ = (*pv++) + 1; --n; } } #endif mxFree( v); mxFree( w_new); #ifndef GET_COST_M mxFree(s_cur); mxFree(s_next); mxFree(y); mxFree(t); #endif } void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { const mxArray *mx_states; const mxArray *mx_z; const mxArray *mx_HUGE_VAL; mxArray *mx_path_index; mxArray *mx_w; #if DEBUG mxArray *mx_V; #endif /* get states, Memory, M */ if( (mx_states = mexGetArrayPtr( "states", "global")) == NULL) mexErrMsgTxt("global variable ``states'' not found"); states = mxGetPr( mx_states); Memory = mxGetM( mx_states); M = mxGetN( mx_states); L = Memory + 1; /* get z, N */ if( (mx_z = mexGetArrayPtr( "z", "global")) == NULL) mexErrMsgTxt("global variable ``z'' not found"); z = mxGetPr( mx_z); N = mxGetN( mx_z); #ifndef GET_COST_M /* get trans */ if( (mx_trans = mexGetArrayPtr( "trans", "global")) == NULL) mexErrMsgTxt("global variable ``trans'' not found"); /* get FIR, g */ if( (mx_FIR = mexGetArrayPtr( "FIR", "global")) == NULL) mexErrMsgTxt("global variable ``FIR'' not found"); if( (mx_g = mexGetArrayPtr( "G", "global")) == NULL) mexErrMsgTxt("global variable ``G'' not found"); #if DEBUG > 1 mexPrintf( "mx_FIR = %p, mx_g = %p\n", mx_FIR, mx_g); #endif /* get GAUSS, GM, and ML (these might not exist in non-time-varying isi/) */ GAUSS = GM = ML = 0; if( (mx_tmp = mexGetArrayPtr( "GAUSS", "global")) == NULL) mexErrMsgTxt("global variable ``GAUSS'' not found"); GAUSS = mxGetScalar( mx_tmp); if( (mx_tmp = mexGetArrayPtr( "GM", "global")) == NULL) mexErrMsgTxt("global variable ``GM'' not found"); GM = mxGetScalar( mx_tmp); if( (mx_tmp = mexGetArrayPtr( "ML", "global")) == NULL) mexErrMsgTxt("global variable ``ML'' not found"); ML = mxGetScalar( mx_tmp); #if DEBUG > 1 mexPrintf( "GAUSS = %d, GM = %d, ML = %d\n", GAUSS, GM, ML); #endif /* get var_noise */ if( (mx_tmp = mexGetArrayPtr( "var_noise", "global")) == NULL) mexErrMsgTxt("global variable ``var_noise'' not found"); var_noise = mxGetScalar( mx_tmp); #if DEBUG > 1 mexPrintf( "var_noise = %g\n", var_noise); #endif /* get Rh and avg_h */ if( (mx_Rh = mexGetArrayPtr( "Rh", "global")) == NULL) mexErrMsgTxt("global variable ``Rh'' not found"); Rh = mxGetPr( mx_Rh); if( (mx_avg_h = mexGetArrayPtr( "avg_h", "global")) == NULL) mexErrMsgTxt("global variable ``avg_h'' not found"); avg_h = mxGetPr( mx_avg_h); avg_h_N = mxGetN( mx_avg_h); #if DEBUG > 1 mexPrintf( "Rh(%d x %d) = %p, avg_h = %p\n", mxGetM(mx_Rh), mxGetN(mx_Rh), Rh, avg_h); #endif /* get pilot bit array */ if( (mx_tmp = mexGetArrayPtr( "pilot", "global")) == NULL) mexErrMsgTxt("global variable ``pilot'' not found"); pilot = mxGetPr( mx_tmp); if( pilot && !( (mxGetM(mx_tmp) == 1 && mxGetN(mx_tmp) == N) || (mxGetM(mx_tmp) == N && mxGetN(mx_tmp) == 1) ) ) mexErrMsgTxt("global variable ``pilot'' has wrong size."); #endif /* ifndef GET_COST_M */ /* get HUGE_VAL */ if( (mx_HUGE_VAL = (mxArray *) mexGetArrayPtr( "HUGE_VAL", "global")) == NULL) mexErrMsgTxt("global variable ``HUGE_VAL'' not found"); HUGE_VAL = mxGetScalar( mx_HUGE_VAL); /* get zstate, zstate_0, and w0 */ if( (mx_zstate = (mxArray *) mexGetArrayPtr( "zstate", "global")) == NULL) mexErrMsgTxt("global variable ``zstate'' not found"); zstate = mxGetScalar( mx_zstate); if( (mx_zstate_0 = (mxArray *) mexGetArrayPtr( "zstate_0", "global")) == NULL) mexErrMsgTxt("global variable ``zstate_0'' not found"); zstate_0 = mxGetScalar( mx_zstate_0); if( (mx_w0 = (mxArray *) mexGetArrayPtr( "w0", "global")) == NULL) mexErrMsgTxt("global variable ``w0'' not found"); w0 = mxGetPr( mx_w0); if( zstate < 0) { if( w0 == NULL) mexErrMsgTxt("global variable ``w0'' is empty"); else if( mxGetM(mx_w0) != 1 || mxGetN(mx_w0) != M) mexErrMsgTxt("global variable ``w0'' has wrong size"); } /* get/set path_index */ if( (mx_path_index = (mxArray *) mexGetArrayPtr( "path_index", "global")) == NULL) mexErrMsgTxt("global variable ``path_index'' not found"); if( mxGetM( mx_path_index) != N) { mxFree( mxGetPr( mx_path_index)); mxSetPr( mx_path_index, mxMalloc( N*sizeof(double))); mxSetM( mx_path_index, N); } mxSetN( mx_path_index, 1); mxFree( mxGetPi( mx_path_index)); mxSetPi( mx_path_index, NULL); path_index = mxGetPr( mx_path_index); #if DEBUG /* get/set V */ if( (mx_V = (mxArray *) mexGetArrayPtr( "v", "global")) == NULL) mexErrMsgTxt("global variable ``v'' not found"); if( mxGetN( mx_V) != M*(N+1)) { mxFree( mxGetPr( mx_V)); mxSetPr( mx_V, mxMalloc( M*(N+1)*sizeof(double))); mxSetN( mx_V, M*(N+1)); } mxSetM( mx_V, 1); mxFree( mxGetPi( mx_V)); mxSetPi( mx_V, NULL); V = mxGetPr( mx_V); #endif /* get/set w */ if( (mx_w = (mxArray *) mexGetArrayPtr( "w", "global")) == NULL) mexErrMsgTxt("global variable ``w'' not found"); if( mxGetN( mx_w) != M) { mxFree( mxGetPr( mx_w)); mxSetPr( mx_w, mxMalloc( M*sizeof(double))); mxSetN( mx_w, M); } mxSetM( mx_w, 1); mxFree( mxGetPi( mx_w)); mxSetPi( mx_w, NULL); w = mxGetPr( mx_w); #if DEBUG > 1 mexPrintf( "Memory = %d, M = %d\n", Memory, M); mexPrintf( "N = %d\n", N); #endif /* run viterbi */ vit(); }