/* $Revision: 1.0 $ $Date: 2003/03/27 13:28:43 $ */ /*=============================================================================================================== * KALCVS  Kalman smoother * * The state space model is defined as follows: * z(t+1) = a+F*z(t)+eta(t) (state or transition equation) * y(t) = b+H*z(t)+eps(t) (observation or measurement equation) * * [sm, vsm] = kalcvs(data, a, F, b, H, var, pred, vpred[, un, vun]) * uses backward recursions to compute the smoothed estimate z(t|T) and its covariance matrix P(t|T), * where T is the number of observations in the complete data set. * * The inputs to the KALCVS function are as follows: * data is an Ny×T matrix containing the data (y(1), ... , y(T))'. * a is an Nz×1 vector for a time-invariant input vector in the transition equation, * or an Nz×T matrix containing T input vectors in the transition equation. * F is an Nz×Nz matrix for a time-invariant transition matrix in the transition equation, * or an Nz×Nz×T array containing T transition matrices in the transition equation. * b is an Ny×1 vector for a time-invariant input vector in the measurement equation, * or an Ny×T matrix containing T input vectors in the measurement equation. * H is an Ny×Nz matrix for a time-invariant measurement matrix in the measurement equation, * or an Ny×Nz×T array containing T measurement matrices in the measurement equation. * var is an (Ny+Nz)×(Ny+Nz) covariance matrix of the errors in the transition and measurement equations, * or an (Ny+Nz)×(Ny+Nz)×T array containing T such covariance matrices, that is, the covariance of [eta(t)', eps(t)']'. * pred is an Nz×T matrix containing the one-step forecasts (z(1|0), ... , z(T|T-1))'. * vpred is an Nz×Nz×T array containing the mean square error matrices of the predicted state vectors (P(1|0), ... , P(T|T-1))'. * un is an optional Nz×1 vector containing u(T). The returned value is u(0). * vun is an optional Nz×Nz covariance matrix containing U(T). 
The returned value is U(0). * * The KALCVS function returns the following output: * sm is an Nz×T matrix containing smoothed state vectors (z(1|T), ... , z(T|T))'. * vsm is an Nz×Nz×T matrix containing covariance matrices of smoothed state vectors (P(1|T), ... , P(T|T))'. * * This is a MEX-file for MATLAB. * Copyright 2002-2003 Federal Reserve Bank of Atlanta * Iskander Karibzhanov 3-27-03. * Master of Science in Computational Finance * Georgia Institute of Technology *=============================================================================================================== * Revision history: * * 03/27/2003 - algorithm and interface were adapted from SAS/IML KALCVS subroutine for use in MATLAB MEX file * *===============================================================================================================*/ #include #include /* NOTE(review): the two #include directives here show no header names in this copy - TODO confirm against the original source and restore them (the mx and mex API functions used in this file are declared by the MATLAB MEX headers). */ /* Kalman smoother entry point called by mexFunction: backward recursions that fill sm with z(t|T) and vsm with P(t|T) from the filter output pred/vpred. Each inc* argument is the per-period pointer step for the matching input as computed by mexFunction: 0 for a time-invariant input, otherwise the number of elements in one period's slice. un/vun may be NULL (mexFunction passes NULL when they are not supplied). */ double kalcvs(double *data, int T, int Ny, int Nz, double *a, int inca, double *F, int incF, double *b, int incb, double *H, int incH, double *var, int incvar, double *pred, double *vpred, double *un, double *vun, double *sm, double *vsm); /* Scalars and option characters handed by address to the Fortran BLAS/LAPACK routines declared below, which take every argument by reference. msg is a shared scratch buffer for formatted MATLAB error/warning text. */ static double one_d=1.0, mone_d=-1.0, zero_d=0.0; static int one=1, zero=0; static char *chl = "L", *chr = "R", *chn = "N", *cht = "T", *chv = "V", *chu = "U", *chs = "S", msg[201]; /* When calling a LAPACK or BLAS function, some platforms require an underscore character following the function name in the call statement. On the PC, IBM_RS, and HP platforms, use the function name alone, with no trailing underscore. On the SGI, LINUX, Solaris, Alpha, and Macintosh platforms, add the underscore after the function name. 
*/ #if defined(__linux__) || defined(__alpha) || defined(__sgi) #define ddot ddot_ #define dcopy dcopy_ #define daxpy daxpy_ #define dsymm dsymm_ #define dgemm dgemm_ #define dtrmm dtrmm_ #define dsyrk dsyrk_ #define dgemv dgemv_ #define dtrsv dtrsv_ #define dtrsm dtrsm_ #define dgetrf dgetrf_ #define dgetrs dgetrs_ #define dgehrd dgehrd_ #define dorghr dorghr_ #define dhseqr dhseqr_ #define dtrsyl dtrsyl_ #define dpotrf dpotrf_ #endif /* BLAS functions */ extern double ddot(int *N, double *X, int *incX, double *Y, int *incY); extern void dcopy(int *N, double *X, int *incX, double *Y, int *incY); extern void daxpy(int *N, double *alpha, double *X, int *incX, double *Y, int *incY); extern void dsymm(char *side, char *uplo, int *M, int *N, double *alpha, double *A, int *lda, double *B, int *ldb, double *beta, double *C, int *ldc); extern void dgemm(char *transA, char *transB, int *M, int *N, int *K, double *alpha, double *A, int *lda, double *B, int *ldb, double *beta, double *C, int *ldc); extern void dtrmm(char *side, char *uplo, char *transA, char *diag, int *M, int *N, double *alpha, double *A, int *lda, double *B, int *ldb); extern void dsyrk(char *uplo, char *trans, int *N, int *K, double *alpha, double *A, int *lda, double *beta, double *C, int *ldc); extern void dgemv(char *transA, int *M, int *N, double *alpha, double *A, int *lda, double *X, int *incX, double *beta, double *Y, int *incY); extern void dtrsv(char *uplo, char *transA, char *diag, int *N, double *A, int *lda, double *X, int *incX); extern void dtrsm(char *side, char *uplo, char *transA, char *diag, int *M, int *N, double *alpha, double *A, int *lda, double *B, int *ldb); /* LAPACK functions */ extern void dgetrf(int *m, int *n, double *A, int *lda, int *ipiv, int *info); extern void dgetrs(char *trans, int *n, int *nrhs, double *A, int *lda, int *ipiv, double *b, int *ldb, int *info); extern void dgehrd(int *n, int *ilo, int *ihi, double *A, int *lda, double *tau, double *work, int *lwork, int 
*info); extern void dorghr(int *n, int *ilo, int *ihi, double *A, int *lda, double *tau, double *work, int *lwork, int *info); extern void dhseqr(char *job, char *compz, int *n, int *ilo, int *ihi, double *h, int *ldh, double *wr, double *wi, double *z, int *ldz, double *work, int *lwork, int *info); extern void dtrsyl(char *trana, char *tranb, int *isgn, int *m, int *n, double *A, int *lda, double *B, int *ldb, double *C, int *ldc, double *dscale, int *info); extern void dpotrf(char *uplo, int *n, double *a, int *lda, int *info); void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { double *data, *a, *F, *b, *H, *var, *pred, *vpred, *un, *vun, *sm, *vsm; int T, Nz, Ny, rows, cols, ndims, inca=0, incF=0, incb=0, incH=0, incvar=0; const int *dims; int dim[3]; /*===== check number of input and output arguments =====*/ if ( (nrhs != 8) && (nrhs != 10) ) mexErrMsgTxt("Eight or ten input arguments required."); if ( nlhs != 2 ) mexErrMsgTxt("Two output arguments required."); /*===== check input matrix dimensions =====*/ /* data must be Ny×T matrix */ Ny = mxGetM(prhs[0]); T = mxGetN(prhs[0]); /* a must be Nz×1 vector or Nz×T matrix */ Nz = mxGetM(prhs[1]); if ( (cols=mxGetN(prhs[1])) == 1 ) inca = 0; else if ( cols == T ) inca = Nz; else { sprintf(msg, "data has Ny=%d rows and T=%d columns, a has Nz=%d rows, " "but the number of columns in a (%d) is neither one nor T=%d.", Ny, T, Nz, cols, T); mexErrMsgTxt(msg); } /* F must be Nz×Nz or Nz×Nz×T matrix */ dims = mxGetDimensions(prhs[2]); if ( (dims[0] != Nz) || (dims[1] != Nz) ) { sprintf(msg, "a and F must have the same number of rows. " "F must be square in the first two dimensions. 
" "a has Nz=%d rows, but F has %d rows and %d columns.", Nz, dims[0], dims[1]); mexErrMsgTxt(msg); } ndims = mxGetNumberOfDimensions(prhs[2]); if ( ndims == 2 ) incF = 0; else if ( ndims == 3 ) { if ( dims[2] == T ) incF = Nz*Nz; else { sprintf(msg, "F must have T=%d elements in third dimension.", T); mexErrMsgTxt(msg); } } else mexErrMsgTxt("F must be two- or three-dimensional matrix."); /* b must be Ny×1 or Ny×T vector */ if ( (rows=mxGetM(prhs[3])) != Ny ) { sprintf(msg, "data and b must have the same number of rows. " "data has Ny=%d rows, but the number of rows in b is %d.", Ny, rows); mexErrMsgTxt(msg); } if ( (cols=mxGetN(prhs[3])) == 1 ) incb = 0; else if ( cols == T ) incb = Ny; else { sprintf(msg, "data has Ny=%d rows and T=%d columns, b has Ny=%d rows, " "but the number of columns in b (%d) is neither one nor T=%d.", Ny, T, Ny, cols, T); mexErrMsgTxt(msg); } /* H must be Ny×Nz or Ny×Nz×T matrix */ dims = mxGetDimensions(prhs[4]); if ( dims[0] != Ny ) { sprintf(msg, "data and H must have the same number of rows. " "data has Ny=%d rows, but the number of rows in H is %d.", Ny, dims[0]); mexErrMsgTxt(msg); } if ( dims[1] != Nz ) { sprintf(msg, "H must have the same number of columns as rows in matrix a. " "a has Nz=%d rows, but the number of columns in H is %d.", Nz, dims[1]); mexErrMsgTxt(msg); } ndims = mxGetNumberOfDimensions(prhs[4]); if ( ndims == 2 ) incH = 0; else if ( ndims == 3 ) { if ( dims[2] == T ) incH = Ny*Nz; else { sprintf(msg, "H must have T=%d elements in third dimension.", T); mexErrMsgTxt(msg); } } else mexErrMsgTxt("H must be two- or three-dimensional matrix."); /* var must be (Ny+Nz)×(Ny+Nz) or (Ny+Nz)×(Ny+Nz)×T matrix */ dims = mxGetDimensions(prhs[5]); if ( (dims[0] != Ny+Nz) || (dims[1] != Ny+Nz) ) { sprintf(msg, "var must contain variance matrix for the errors in transition and measurement equations. " "var must be square p.d.f. 
%d×%d or %d×%d×%d matrix, but your var has %d rows and %d columns.", Ny+Nz, Ny+Nz, Ny+Nz, Ny+Nz, T, dims[0], dims[1]); mexErrMsgTxt(msg); } ndims = mxGetNumberOfDimensions(prhs[5]); if ( ndims == 2 ) incvar = 0; else if ( ndims == 3 ) { if ( dims[2] == T ) incvar = (Ny+Nz)*(Ny+Nz); else { sprintf(msg, "var must have T=%d elements in third dimension.", T); mexErrMsgTxt(msg); } } else mexErrMsgTxt("var must be two- or three-dimensional matrix."); /* pred must be Nz×T matrix */ if ( (rows=mxGetM(prhs[6]) != Nz) ) { sprintf(msg, "a and pred must have the same number of rows. " "a has %d rows, but the number of rows in pred is %d.", T, rows); mexErrMsgTxt(msg); } if ( (cols=mxGetN(prhs[6]) != T) { sprintf(msg, "data and pred must have the same number of columns. " "data has %d columns, but the number of columns in pred is %d.", T, cols); mexErrMsgTxt(msg); } /* vpred must be Nz×Nz×T matrix */ dims = mxGetDimensions(prhs[7]); if ( (dims[0] != Nz) || (dims[1] != Nz) ) { sprintf(msg, "a and vpred must have the same number of rows. " "vpred must be square in the first two dimensions. " "a has Nz=%d rows, but vpred has %d rows and %d columns.", Nz, dims[0], dims[1]); mexErrMsgTxt(msg); } ndims = mxGetNumberOfDimensions(prhs[7]); if ( ndims == 3 ) { if ( dims[2] != T ) sprintf(msg, "vpred must have T=%d elements in third dimension.", T); mexErrMsgTxt(msg); } } else mexErrMsgTxt("vpred must be three-dimensional matrix."); /* if specified, un must be Nz×1 vector and vun must be Nz×Nz matrix */ if ( nrhs==10 ) { if ( ( (rows=mxGetM(prhs[8])) != Nz ) || ( (cols=mxGetN(prhs[8])) != 1 ) ) { sprintf(msg, "a and un must have the same number of rows. un must be Nz×1 vector. " "a has Nz=%d rows, but un has %d rows and %d columns.", Nz, rows, cols); mexErrMsgTxt(msg); } if ( ( (rows=mxGetM(prhs[9])) != Nz ) || ( (cols=mxGetN(prhs[9])) != Nz ) ) { sprintf(msg, "a and vun must have the same number of rows. vun must be square Nz×Nz matrix. 
" "a has Nz=%d rows, but vun has %d rows and %d columns.", Nz, rows, cols); mexErrMsgTxt(msg); } } /*===== get pointers to input arguments =====*/ data = mxGetPr(prhs[0]); a = mxGetPr(prhs[1]); F = mxGetPr(prhs[2]); b = mxGetPr(prhs[3]); H = mxGetPr(prhs[4]); var = mxGetPr(prhs[5]); pred = mxGetPr(prhs[6]); vpred = mxGetPr(prhs[7]); if ( nrhs == 10 ) { un = mxGetPr(prhs[8]); vun = mxGetPr(prhs[9]); } else { un = NULL; vun = NULL; } /*===== get pointers to output arguments =====*/ dim[0] = Nz; dim[1] = Nz; dim[2] = T; sm = mxGetPr(plhs[0]=mxCreateDoubleMatrix(Nz, T, mxREAL)); vsm = mxGetPr(plhs[1]=mxCreateNumericArray(3, dim, mxDOUBLE_CLASS, mxREAL)); /*===== compute Kalman Filter =====*/ kalcvs(data, T, Ny, Nz, a, inca, F, incF, b, incb, H, incH, var, incvar, pred, vpred, un, vun, sm, vsm); } #define PI 3.141592653589793238 double kalcvf(double *data, int T, int Ny, int Nz, double *a_all, int inca, double *F_all, int incF, double *b_all, int incb, double *H_all, int incH, double *var_all, int incvar, double *pred, double *vpred, double *un, double *vun, double *sm, double *vsm) { double *a, *F, *b, *H, *var, *u, *vu, *V, *G, *R, *P; int Nyz, NyzNyz, NyNz, NzNz, info=0, t, i; Nyz = Ny+Nz; NyzNyz = Nyz*Nyz; NyNz = Ny*Nz; NzNz = Nz*Nz; /* If specified, pointers pred and vpred will be incremented by incpred and incvpred to save data from each iteration. 
If not specified, allocate memory for vector pred and matrix vpred, and leave zero increments */ /* NOTE(review): the comment just closed was carried over from the Kalman filter; in this routine only u and vu fall back to zero-filled mxCalloc buffers when un/vun are NULL. This copy also looks damaged from here on: the loop header reads "for (t=0; t0)", filt, vfilt, incpred and incvpred are not declared in this function's parameter list, and the end of this function together with the header of the routine that follows (which uses n, nn, x, symm, info and the _DLYAP_NOMSG_ switch and calls DGETRS, DHSEQR and DTRSYL) is missing - text between a less-than sign and the next greater-than sign appears to have been stripped. TODO restore from the original source before building. */ if (un) u = un; else u = (double *)mxCalloc(Nz, sizeof(double)); if (vun) vu = vun; else vu = (double *)mxCalloc(NzNz, sizeof(double)); /* allocate memory for temporary variables */ a = (double *) mxCalloc(Nz, sizeof(double)); F = (double *) mxCalloc(NzNz, sizeof(double)); b = (double *) mxCalloc(Ny, sizeof(double)); H = (double *) mxCalloc(NyNz, sizeof(double)); var = (double *) mxCalloc(NyzNyz, sizeof(double)); P = (double *) mxCalloc(NzNz, sizeof(double)); /* var = [V G; G' R] */ V = var; /* V(t) = Var(eta(t)) */ G = var+Nyz*Nz; /* G(t) = Cov(eta(t),eps(t)) */ R = G+Nz; /* R(t) = Var(eps(t)) */ for (t=0; t0) { /* F = F_all(:,:,t) */ dcopy(&NzNz, &F_all[t*incF], &one, F, &one); /* a = a_all(:,t) */ dcopy(&Nz, &a_all[t*inca], &one, a, &one); /* a = F*pred+a */ dgemv(chn, &Nz, &Nz, &one_d, F, &Nz, pred, &one, &one_d, a, &one); /* F = F*P */ dtrmm(chr, chl, chn, chn, &Nz, &Nz, &one_d, P, &Nz, F, &Nz); /* G = F*H'+G */ dgemm(chn, cht, &Nz, &Ny, &Nz, &one_d, F, &Nz, H, &Ny, &one_d, G, &Nyz); /* G = G/R' */ dtrsm(chr, chl, cht, chn, &Nz, &Ny, &one_d, R, &Nyz, G, &Nyz); /* a = G*b+a */ dgemv(chn, &Nz, &Ny, &one_d, G, &Nyz, b, &one, &one_d, a, &one); /* V = F*F'+V */ dsyrk(chl, chn, &Nz, &Nz, &one_d, F, &Nz, &one_d, V, &Nyz); /* V = -G*G'+V */ dsyrk(chl, chn, &Nz, &Ny, &mone_d, G, &Nyz, &one_d, V, &Nyz); } if (filt || vfilt) { /* H = H*P' */ dtrmm(chr, chl, cht, chn, &Ny, &Nz, &one_d, P, &Nz, H, &Ny); /* H = R\H */ dtrsm(chl, chl, chn, chn, &Ny, &Nz, &one_d, R, &Nyz, H, &Ny); if (filt) { /* filt = pred */ dcopy(&Nz, pred, &one, filt, &one); /* filt = H'*b+filt */ dgemv(cht, &Ny, &Nz, &one_d, H, &Ny, b, &one, &one_d, filt, &one); /* increment filt */ filt += Nz; } if (vfilt) { /* vfilt = vpred */ dcopy(&NzNz, vpred, &one, vfilt, &one); /* vfilt = -H'*H+vfilt */ dsyrk(chl, cht, &Nz, &Ny, &mone_d, H, &Ny, &one_d, vfilt, &Nz); /* increment vfilt */ vfilt += NzNz; } } if (t0) { 
/* pred(:,t+1) = a */ if (incpred) pred += incpred; dcopy(&Nz, a, &one, pred, &one); /* vpred(:,:,t+1) = V */ if (incvpred) vpred += incvpred; for (i=0; i1 && (incpred || incvpred)) { for (t=T+1; t 0) { #ifndef _DLYAP_NOMSG_ sprintf(msg, "u(%d,%d) is 0. The factorization has been completed, but U" "is exactly singular. Division by 0 will occur if you use the factor " "U for solving a system of linear equations.", info[0], info[0]); mexWarnMsgTxt(msg); #endif mxFree(a); mxFree(c); mxFree(p); return; } dgetrs(chn, &n, &n, c, &n, p, a, &n, info); if (info[0] < 0) { sprintf(msg,"The %dth parameter to DGETRS had an illegal value.",-info[0]); mexErrMsgTxt(msg); } mxFree(p); /* b = a-eye(n) */ b = (double *)mxCalloc(nn,sizeof(double)); dcopy(&nn, a, &one, b, &one); for (i=0; i 0) { #ifndef _DLYAP_NOMSG_ sprintf(msg,"The algorithm has failed to find all the eigenvalues after " "a total %d iterations. Elements 1,2, ..., %d and %d, %d," " ..., %d of wr and wi contain the real and imaginary parts of the " "eigenvalues which have been found.", 30*(ihi-ilo+1), ilo-1, info[0]+1, info[0]+2, n); mexWarnMsgTxt(msg); #endif mxFree(a); mxFree(b); mxFree(c); return; } if (info[0] < 0) { sprintf(msg,"The %dth parameter to DHSEQR had an illegal value.",-info[0]); mexErrMsgTxt(msg); } /* Transform x = b'*x*b */ /* c = x'*b */ if (symm) dsymm(chl, chu, &n, &n, &one_d, x, &n, b, &n, &zero_d, c, &n); else dgemm(cht, chn, &n, &n, &n, &one_d, x, &n, b, &n, &zero_d, c, &n); /* x = c'*b */ dgemm(cht, chn, &n, &n, &n, &one_d, c, &n, b, &n, &zero_d, x, &n); /* Solve Sylvester's equation a*x+x*a'=c for real quasi-triangular matrix a. */ dtrsyl(chn, cht, &isgn, &n, &n, a, &n, a, &n, x, &n, &scale, info); if (info[0] < 0) { sprintf(msg,"The %dth parameter to DTRSYL had an illegal value.",-info[0]); mexErrMsgTxt(msg); } mxFree(a); #ifndef _DLYAP_NOMSG_ if (info[0] == 1) { sprintf(msg,"Solution does not exist or is not unique. 
A and C have common " "or close eigenvalues perturbed values were used to solve the equation."); mexWarnMsgTxt(msg); } #endif /* Find untransformed solution x = b*x*b' */ /* c = b*x */ if (symm) dsymm(chr, chu, &n, &n, &one_d, x, &n, b, &n, &zero_d, c, &n); else dgemm(chn, chn, &n, &n, &n, &one_d, b, &n, x, &n, &zero_d, c, &n); /* x = c*b' */ dgemm(chn, cht, &n, &n, &n, &one_d, c, &n, b, &n, &zero_d, x, &n); mxFree(b); mxFree(c); }