source: orange/source/orange/earth.cpp @ 10952:bd8e4d8e590b

Revision 10952:bd8e4d8e590b, 128.4 KB checked in by Ales Erjavec <ales.erjavec@…>, 21 months ago (diff)

Return an error code if trying to delete the intercept in EvalSubsetsUsingXtx.

Line 
1// This code is derived from code in the Rational Fortran file dmarss.r which is
2// part of the R and S mda package by Hastie and Tibshirani.
3// Comments containing "TODO" mark known issues
4//
5// See the R earth documentation for descriptions of the principal data structures.
6// See also www.milbo.users.sonic.net.
7//
8// Stephen Milborrow Feb 2007 Petaluma
9//
10//-----------------------------------------------------------------------------
11// ...
12//-----------------------------------------------------------------------------
13// References:
14//
15// HastieTibs: Trevor Hastie and Robert Tibshirani
16//      S library mda version 0.3.2 dmarss.r Ratfor code
17//      Modifications for R by Kurt Hornik, Friedrich Leisch, Brian Ripley
18//
19// FriedmanMars: Multivariate Adaptive Regression Splines (with discussion)
20//      Annals of Statistics 19/1, 1--141, 1991
21//
22// FriedmanFastMars: Friedman "Fast MARS"
23//      Dep. of Stats. Stanford, Tech Report 110, May 1993
24//
25// Miller: Alan Miller (2nd ed. 2002) Subset Selection in Regression
26//
27//-----------------------------------------------------------------------------
28// This program is free software; you can redistribute it and/or modify
29// it under the terms of the GNU General Public License as published by
30// the Free Software Foundation; either version 2 of the License, or
31// (at your option) any later version.
32//
33// This program is distributed in the hope that it will be useful,
34// but WITHOUT ANY WARRANTY; without even the implied warranty of
35// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
36// GNU General Public License for more details.
37//
38// A copy of the GNU General Public License is available at
39// http://www.r-project.org/Licenses
40//
41//-----------------------------------------------------------------------------
42
43/*
44    This file is part of Orange.
45
46    Copyright 1996-2011 Faculty of Computer and Information Science, University of Ljubljana
47    Contact: janez.demsar@fri.uni-lj.si
48
49    Orange is free software: you can redistribute it and/or modify
50    it under the terms of the GNU General Public License as published by
51    the Free Software Foundation, either version 3 of the License, or
52    (at your option) any later version.
53
54    Orange is distributed in the hope that it will be useful,
55    but WITHOUT ANY WARRANTY; without even the implied warranty of
56    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
57    GNU General Public License for more details.
58
59    You should have received a copy of the GNU General Public License
60    along with Orange.  If not, see <http://www.gnu.org/licenses/>.
61*/
62
63/*
64     Changes to earth.c from earth R package:
65     - Added defines for STANDALONE, USING_BLAS, _DEBUG
66     - Removed  #include <crtdbg.h> for windows
67     - Fix defines for ISNAN and FINITE to work on non MSC compilers
68     - Removed debugging code for windows
69     - Removed definitions of bool, true false
70     - Define _C_ as "C" for all compilers
71     - Define c linkage for error, xerbla
72     - Replaced POS_INF static global variable with numeric_limits<double>::infinity()
73     - Added #include <limits>
74     - Changed include of earth.h to earth.ppp and moved it before the module level defines
75     - Changed EvalSubsetsUsingXtX to return an error code if lin. dep. terms in bx
76
77    - TODO: Move global vars inside the functions using them (most are local)
78 */
79
80#include <stdlib.h>
81#include <stdio.h>
82#include <stdarg.h>
83#include <string.h>
84#include <float.h>
85#include <math.h>
86#include <limits>
87
88#include "earth.ppp"
89
90#define STANDALONE 1
91#define USING_BLAS 1
92#define _DEBUG 0
93
94#if !STANDALONE
95#define USING_R 1
96#endif // STANDALONE
97
98
99#if _MSC_VER && _DEBUG
100    #include <crtdbg.h> // microsoft malloc debugging library
101#endif
102
103#if _MSC_VER            // microsoft
104    #define _C_ "C"
105    #if _DEBUG          // debugging enabled?
106        // disable warning: too many actual params for macro (for malloc1 and calloc1)
107        #pragma warning(disable: 4002)
108    #endif
109#else
110    #define _C_ "C"
111//    #ifndef bool
112//        typedef int bool;
113//        #define false 0
114//        #define true  1
115//    #endif
116#endif
117
118#if USING_R             // R with gcc
119    #include "R.h"
120    #include "Rinternals.h" // needed for Allowed function handling
121    #include "allowed.h"
122    #define printf Rprintf
123    #define FINITE(x) R_FINITE(x)
124    #define ASSERT(x)   \
125        if (!(x)) error("internal assertion failed in file %s line %d: %s\n", \
126                        __FILE__, __LINE__, #x)
127#else
128    #define warning printf
129    extern "C" { void error(const char *args, ...); }
130    #ifdef _MSC_VER
131        #define ISNAN(x)  _isnan(x)
132        #define FINITE(x) _finite(x)
133    #else
134        #define ISNAN(x)  isnan(x)
135        #define FINITE(x) finite(x)
136    #endif // _MSC_VER
137
138    #define ASSERT(x)   \
139        if (!(x)) error("internal assertion failed in file %s line %d: %s\n", \
140                        __FILE__, __LINE__, #x)
141#endif
142
143//#include "earth.h"
144
145extern _C_ int dqrdc2_(double *x, int *ldx, int *n, int *p,
146                        double *tol, int *rank,
147                        double *qraux, int *pivot, double *work);
148
149extern _C_ int dqrsl_(double *x, int *ldx, int *n, int *k,
150                        double *qraux, double *y,
151                        double *qy, double *qty, double *b,
152                        double *rsd, double *xb, int *job, int *info);
153
154extern _C_ void dtrsl_(double *t, int *ldt, int *n, double *b, int *job, int *info);
155
156extern _C_ void daxpy_(const int *n, const double *alpha,
157                        const double *dx, const int *incx,
158                        double *dy, const int *incy);
159
160extern _C_ double ddot_(const int *n,
161                        const double *dx, const int *incx,
162                        const double *dy, const int *incy);
163
164#define sq(x)       ((x) * (x))
165#ifndef max
166#define max(a,b)    (((a) > (b)) ? (a) : (b))
167#endif
168#ifndef min
169#define min(a,b)    (((a) < (b)) ? (a) : (b))
170#endif
171
172#define INLINE      inline
173#define USE_BLAS    1     // 1 is faster (tested on Windows XP Pentium with R BLAS)
174                          // also, need USE_BLAS to use bxOrthCenteredT
175
176#define FAST_MARS   1     // 1 to use techniques in FriedmanFastMars (see refs)
177
178#define IOFFSET     1     // printfs only: 1 to convert 0-based indices to 1-based in printfs
179                          // use 0 for C style indices in messages to the user
180
181static const char   *VERSION    = "version 3.2-0"; // change if you modify this file!
182static const double BX_TOL      = 0.01;
183static const double QR_TOL      = 0.01;
184static const double MIN_GRSQ    = -10.0;
185static const double ALMOST_ZERO = 1e-10;
186static const int    ONE         = 1;        // parameter for BLAS routines
187#if _MSC_VER                                // microsoft compiler
188static const double ZERO        = 0.0;
189//static const double POS_INF     = (1.0 / ZERO);
190static const double POS_INF     = std::numeric_limits<double>::infinity();
191#else
192//static const double POS_INF     = (1.0 / 0.0);
193static const double POS_INF     = std::numeric_limits<double>::infinity();
194#endif
195static const int    MAX_DEGREE  = 100;
196
197// Poor man's array indexing -- not pretty, but pretty useful.
198//
199// Note that we use column major ordering. C programs usually use row major
200// ordering but we don't here because the functions in this file are called
201// by R and call Fortran routines which use column major ordering.
202
203#define Dirs_(iTerm,iPred)      Dirs[(iTerm) + (iPred)*(nMaxTerms)]
204#define Cuts_(iTerm,iPred)      Cuts[(iTerm) + (iPred)*(nMaxTerms)]
205
206#define bx_(iCase,iTerm)                bx             [(iCase) + (iTerm)*(nCases)]
207#define bxOrth_(iCase,iTerm)            bxOrth         [(iCase) + (iTerm)*(nCases)]
208#define bxOrthCenteredT_(iTerm,iCase)   bxOrthCenteredT[(iTerm) + (iCase)*(nMaxTerms)]
209#define x_(iCase,iPred)                 x              [(iCase) + (iPred)*(nCases)]
210#define xOrder_(iCase,iPred)            xOrder         [(iCase) + (iPred)*(nCases)]
211#define y_(iCase,iResp)                 y              [(iCase) + (iResp)*(nCases)]
212#define Residuals_(iCase,iResp)         Residuals      [(iCase) + (iResp)*(nCases)]
213#define ycboSum_(iTerm,iResp)           ycboSum        [(iTerm) + (iResp)*(nMaxTerms)]
214#define Betas_(iTerm,iResp)             Betas          [(iTerm) + (iResp)*(nUsedCols)]
215
216// Global copies of some input parameters.  These stay constant for the entire MARS fit.
217static double TraceGlobal;      // copy of Trace parameter
218static int nMinSpanGlobal;      // copy of nMinSpan parameter
219
220static void FreeBetaCache(void);
221static char *sFormatMemSize(const unsigned MemSize, const bool Align);
222
223//-----------------------------------------------------------------------------
224// malloc and its friends are redefined (a) so under Microsoft C using
225// crtdbg.h we can easily track alloc errors and (b) so FreeR() doesn't
226// re-free any freed blocks and (c) so out of memory conditions are
227// immediately detected.
228// So DON'T USE free, malloc, and calloc.  Use free1, malloc1, and calloc1 instead.
229
230// free1 is a macro so we can zero p
231#define free1(p) { if (p) free(p); p = NULL; }
232
233#if _MSC_VER && _DEBUG  // microsoft C and debugging enabled?
234
235#define malloc1(size) _malloc_dbg((size), _NORMAL_BLOCK, __FILE__, __LINE__)
236#define calloc1(num, size) \
237                      _calloc_dbg((num), (size), _NORMAL_BLOCK, __FILE__, __LINE__)
238#else
239static void *malloc1(size_t size, const char *args, ...)
240{
241    void *p = malloc(size);
242    if (!p || TraceGlobal == 1.5) {
243        if (args == NULL)
244            printf("malloc %s\n", sFormatMemSize(size, true));
245        else {
246            char s[100];
247            va_list p;
248            va_start(p, args);
249            vsprintf(s, args, p);
250            va_end(p);
251            printf("malloc %s: %s\n", sFormatMemSize(size, true), s);
252        }
253        fflush(stdout);
254    }
255    if (!p)
256        error("Out of memory (could not allocate %s)", sFormatMemSize(size, false));
257    return p;
258}
259
260static void *calloc1(size_t num, size_t size, const char *args, ...)
261{
262    void *p = calloc(num, size);
263    if (!p || TraceGlobal == 1.5) {
264        if (args == NULL)
265            printf("calloc %s\n", sFormatMemSize(size, true));
266        else {
267            char s[100];
268            va_list p;
269            va_start(p, args);
270            vsprintf(s, args, p);
271            va_end(p);
272            printf("calloc %s: %s\n", sFormatMemSize(size, true), s);
273        }
274        fflush(stdout);
275    }
276    if (!p)
277        error("Out of memory (could not allocate %s)", sFormatMemSize(size, false));
278    return p;
279}
280#endif
281
282// After calling this, on program termination we will get a report if there are
283// writes outside the borders of allocated blocks or if there are non-freed blocks.
284
285#if _MSC_VER && _DEBUG          // microsoft C and debugging enabled?
286static void InitMallocTracking(void)
287{
288    _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_WNDW);
289    _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE);
290    _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDOUT);
291    int Flag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG);
292    Flag |= (_CRTDBG_ALLOC_MEM_DF|_CRTDBG_DELAY_FREE_MEM_DF|_CRTDBG_LEAK_CHECK_DF);
293    _CrtSetDbgFlag(Flag);
294}
295#endif
296
297//-----------------------------------------------------------------------------
298// These are malloced blocks.  They unfortunately have to be declared globally so
299// under R if the user interrupts we can free them using on.exit(.C("FreeR"))
300
301static int    *xOrder;      // local to FindTerm
302static bool   *WorkingSet;  // local to FindTerm and EvalSubsets
303static double *xbx;         // local to FindTerm
304static double *CovSx;       // local to FindTerm
305static double *CovCol;      // local to FindTerm
306static double *ycboSum;     // local to FindTerm (used to be called CovSy)
307static double *bxOrth;      // local to ForwardPass
308static double *yMean;       // local to ForwardPass
309static double *Weights;     // local to ForwardPass and EvalSubsetsUsingXtx
310
311// Transposed and mean centered copy of bxOrth, for fast update in FindKnot.
312// It's faster because there is better data locality as iTerm increases, so
313// better L1 cache use.  This is used only if USE_BLAS is true.
314
315static double *bxOrthCenteredT; // local to ForwardPass
316
317static double *bxOrthMean;      // local to ForwardPass
318static int  *nFactorsInTerm;    // local to Earth or ForwardPassR
319static int  *nUses;             // local to Earth or ForwardPassR
320#if USING_R
321static int *iDirs;              // local to ForwardPassR
322static bool *BoolFullSet;       // local to ForwardPassR
323#endif
324#if FAST_MARS
325static void FreeQ(void);
326#endif
327
328#if USING_R
329void FreeR(void)                // for use by R
330{
331    free1(WorkingSet);
332    free1(CovSx);
333    free1(CovCol);
334    free1(ycboSum);
335    free1(xOrder);
336    free1(bxOrthMean);
337    free1(bxOrthCenteredT);
338    free1(bxOrth);
339    free1(yMean);
340    free1(Weights);
341    free1(BoolFullSet);
342    free1(iDirs);
343    free1(nUses);
344    free1(nFactorsInTerm);
345    FreeBetaCache();
346#if FAST_MARS
347    FreeQ();
348#endif
349}
350#endif
351
352//-----------------------------------------------------------------------------
353static char *sFormatMemSize(const unsigned MemSize, const bool Align)
354{
355    static char s[100];
356    double Size = (double)MemSize;
357    if(Size >= 1e9)
358        sprintf(s, Align? "%6.3f GB": "%.3g GB", Size / 1e9);
359    else if(Size >= 1e6)
360        sprintf(s, Align? "%6.0f MB": "%.3g MB", Size / 1e6);
361    else if(Size >= 1e3)
362        sprintf(s, Align? "%6.0f kB": "%.3g kB", Size / 1e3);
363    else
364        sprintf(s, Align? "%6.0f  B": "%g B", Size);
365    return s;
366}
367
368//-----------------------------------------------------------------------------
369// Gets called periodically to service the R framework.
370// Will not return if the user interrupts.
371
372#if USING_R
373
374static INLINE void ServiceR(void)
375{
376    R_FlushConsole();
377    R_CheckUserInterrupt();     // may never return
378}
379
380#endif
381
382//-----------------------------------------------------------------------------
383#if FAST_MARS
384
385typedef struct tQueue {
386    int     iParent;            // parent term
387    double  RssDelta;
388    int     nTermsForRssDelta;  // number of terms when RssDelta was calculated
389    double  AgedRank;
390} tQueue;
391
392static tQueue *Q;       // indexed on iTerm (this Q is used for queue updates)
393static tQueue *SortedQ; // indexed on iParent rank (this Q is used to get next iParent)
394static int    nQMax;    // number of elements in Q
395
396static void InitQ(const int nMaxTerms)
397{
398    int i;
399    nQMax = 0;
400    Q       = (tQueue *)malloc1(nMaxTerms * sizeof(tQueue),
401                            "Q\t\t\tnMaxTerms %f sizeof(tQueue) %d",
402                            nMaxTerms, sizeof(tQueue));
403    SortedQ = (tQueue *)malloc1(nMaxTerms * sizeof(tQueue),
404                            "SortedQ\t\tnMaxTerms %f sizeof(tQueue) %d",
405                            nMaxTerms, sizeof(tQueue));
406    for (i = 0; i < nMaxTerms; i++) {
407        Q[i].iParent = i;
408        Q[i].nTermsForRssDelta = -99;   // not strictly needed, nice for debugging
409        Q[i].RssDelta = -1;
410        Q[i].AgedRank = -1;
411    }
412}
413
414static void FreeQ(void)
415{
416    free1(SortedQ);
417    free1(Q);
418}
419
420static void PrintSortedQ(int nFastK)     // for debugging
421{
422    printf("\n\nSortedQ  QIndex Parent nTermsForRssDelta AgedRank  RssDelta\n");
423    for (int i = 0; i < nQMax; i++) {
424        printf("            %3d    %3d   %15d    %5.1f  %g\n",
425            i+IOFFSET,
426            SortedQ[i].iParent+IOFFSET,
427            SortedQ[i].nTermsForRssDelta+IOFFSET,
428            SortedQ[i].AgedRank,
429            SortedQ[i].RssDelta);
430        if (i == nFastK-1)
431            printf("FastK %d ----------------------------------------------------\n",
432                nFastK);
433    }
434}
435
436// Sort so highest RssDeltas are at low indices.
437// Secondary sort key is iParent.  Not strictly needed, but removes
438// possible differences in qsort implementations (which "sort"
439// identical keys unpredictably).
440
441static int CompareQ(const void *p1, const void *p2)     // for qsort
442{
443    double Diff = ((tQueue*)p2)->RssDelta - ((tQueue*)p1)->RssDelta;
444    if (Diff < 0)
445        return -1;
446    else if (Diff > 0)
447        return 1;
448
449    // Diff is 0, so sort now on iParent
450
451    int iDiff = ((tQueue*)p1)->iParent - ((tQueue*)p2)->iParent;
452    if (iDiff < 0)
453        return -1;
454    else if (iDiff > 0)
455        return 1;
456    return 0;
457}
458
459// Sort so lowest AgedRanks are at low indices.
460// If AgedRanks are the same then sort on RssDelta and iParent.
461
462static int CompareAgedQ(const void *p1, const void *p2) // for qsort
463{
464    double Diff = ((tQueue*)p1)->AgedRank - ((tQueue*)p2)->AgedRank;
465    if (Diff < 0)
466        return -1;
467    else if (Diff > 0)
468        return 1;
469
470    // Diff is 0, so sort now on RssDelta
471
472    Diff = ((tQueue*)p2)->RssDelta - ((tQueue*)p1)->RssDelta;
473    if (Diff < 0)
474        return -1;
475    else if (Diff > 0)
476        return 1;
477
478    // Diff is still 0, so sort now on iParent
479
480    int iDiff = ((tQueue*)p1)->iParent - ((tQueue*)p2)->iParent;
481    if (iDiff < 0)
482        return -1;
483    else if (iDiff > 0)
484        return 1;
485    return 0;
486}
487
488static void AddTermToQ(
489    const int iTerm,        // in
490    const int nTerms,       // in
491    const double RssDelta,  // in
492    const bool Sort,        // in
493    const int nMaxTerms,    // in
494    const double FastBeta)  // in: ageing Coef, 0 is no ageing, FastMARS recommends 1
495{
496    ASSERT(iTerm < nMaxTerms);
497    ASSERT(nQMax < nMaxTerms);
498    Q[nQMax].nTermsForRssDelta = nTerms;
499    Q[nQMax].RssDelta = max(Q[iTerm].RssDelta, RssDelta);
500    nQMax++;
501    if (Sort) {
502        memcpy(SortedQ, Q, nQMax * sizeof(tQueue));
503        qsort(SortedQ, nQMax, sizeof(tQueue), CompareQ);         // sort on RssDelta
504        if (FastBeta != 0) {
505            for (int iRank = 0; iRank < nQMax; iRank++)
506                SortedQ[iRank].AgedRank =
507                    iRank + FastBeta * (nTerms - SortedQ[iRank].nTermsForRssDelta);
508            qsort(SortedQ, nQMax, sizeof(tQueue), CompareAgedQ); // sort on aged rank
509        }
510    }
511}
512
513static void UpdateRssDeltaInQ(const int iParent, const int nTermsForRssDelta,
514                              const double RssDelta)
515{
516    ASSERT(iParent == Q[iParent].iParent);
517    ASSERT(iParent < nQMax);
518    Q[iParent].nTermsForRssDelta = nTermsForRssDelta;
519    Q[iParent].RssDelta = RssDelta;
520}
521
522static int GetNextParent(   // returns -1 if no more parents
523    const bool InitFlag,    // use true to init, thereafter false
524    const int  nFastK)
525{
526    static int iQ;          // index into sorted queue
527    int iParent = -1;
528    if (InitFlag) {
529        if (TraceGlobal == 6)
530            printf("\n|Considering parents ");
531        iQ = 0;
532    } else {
533        if (iQ < min(nQMax, nFastK)) {
534            iParent = SortedQ[iQ].iParent;
535            iQ++;
536        }
537        if (TraceGlobal == 6 && iParent >= 0)
538            printf("%d [%g] ", iParent+IOFFSET, SortedQ[iQ].RssDelta);
539    }
540    return iParent;
541}
542
543#endif // FAST_MARS
544
545//-----------------------------------------------------------------------------
546// Order() gets the sort indices of vector x, so x[sorted[i]] <= x[sorted[i+1]].
547// Ties may be reordered. The returned indices are 0 based (as in C not as in R).
548//
549// This function is similar to the R library function rsort_with_index(),
550// but is defined here to minimize R dependencies.
551// Informal tests show that this is faster than rsort_with_index().
552
553static const double *pxGlobal;
554
555static int Compare(const void *p1, const void *p2)  // for qsort
556{
557    const int i1 = *(int *)p1;
558    const int i2 = *(int *)p2;
559    double Diff = pxGlobal[i1] - pxGlobal[i2];
560    if (Diff < 0)
561        return -1;
562    else if (Diff > 0)
563        return 1;
564    else
565        return 0;
566}
567
568static void Order(int sorted[],                     // out: vector with nx elements
569                  const double x[], const int nx)   // in: x is a vector with nx elems
570{
571    for (int i = 0; i < nx; i++)
572        sorted[i] = i;
573    pxGlobal = x;
574    qsort(sorted, nx, sizeof(int), Compare);
575}
576
577
578//-----------------------------------------------------------------------------
579// Get order indices for an x array of dimensions nRows x nCols.
580//
581// Returns an nRows x nCols integer array of indices, where each column
582// corresponds to a column of x.  See Order() for ordering details.
583//
584// Caller must free the returned array.
585
586static int *OrderArray(const double x[], const int nRows, const int nCols)
587{
588    int *xOrder = (int *)malloc1(nRows * nCols * sizeof(int),
589                            "xOrder\t\tnRows %d nCols %d sizeof(int) %d",
590                            nRows, nCols, sizeof(int));
591
592    for (int iCol = 0; iCol < nCols; iCol++) {
593        Order(xOrder + iCol*nRows, x + iCol*nRows, nRows);
594#if USING_R
595        if (nRows > 10000)
596            ServiceR();
597#endif
598    }
599    return xOrder;
600}
601
602//-----------------------------------------------------------------------------
603// return the number of TRUEs in the boolean vector UsedCols
604
605static int GetNbrUsedCols(const bool UsedCols[], const int nLen)
606{
607    int nTrue = 0;
608
609    for (int iCol = 0; iCol < nLen; iCol++)
610        if (UsedCols[iCol])
611            nTrue++;
612
613    return nTrue;
614}
615
616//-----------------------------------------------------------------------------
617// Copy used columns in x to *pxUsed and return the number of used columns
618// UsedCols[i] is true for each each used column index in x
619// Caller must free *pxUsed
620
621static int CopyUsedCols(double **pxUsed,                // out
622                    const double x[],                   // in: nCases x nCols
623                    const int nCases, const int nCols,  // in
624                    const bool UsedCols[])              // in
625{
626    int nUsedCols = GetNbrUsedCols(UsedCols, nCols);
627    double *xUsed = (double *)malloc1(nCases * nUsedCols * sizeof(double),
628                        "xUsed\t\t\tnCases %d nUsedCols %d sizeof(double) %d",
629                        nCases, nUsedCols, sizeof(double));
630
631    int iUsed = 0;
632    for (int iCol = 0; iCol < nCols; iCol++)
633        if (UsedCols[iCol]) {
634            memcpy(xUsed + iUsed * nCases,
635                x + iCol * nCases, nCases * sizeof(double));
636            iUsed++;
637        }
638    *pxUsed = xUsed;
639    return nUsedCols;
640}
641
642//-----------------------------------------------------------------------------
643// Print a summary of the model, for debug tracing
644
645#if STANDALONE
646static void PrintSummary(
647    const int    nMaxTerms,         // in
648    const int    nTerms,            // in: number of cols in bx, some may be unused
649    const int    nPreds,            // in: number of predictors
650    const int    nResp,             // in: number of cols in y
651    const bool   UsedCols[],        // in: specifies used colums in bx
652    const int    Dirs[],            // in
653    const double Cuts[],            // in
654    const double Betas[],           // in: if NULL will print zeroes
655    const int    nFactorsInTerm[])  // in: number of hinge funcs in basis term
656{
657    printf("   nFacs       Beta\n");
658
659    int nUsedCols = GetNbrUsedCols(UsedCols, nTerms);
660    int iUsed = -1;
661    for (int iTerm = 0; iTerm < nTerms; iTerm++) {
662        if (UsedCols[iTerm]) {
663            iUsed++;
664            printf("%2.2d  %2d    ", iTerm, nFactorsInTerm[iTerm]);
665            for (int iResp = 0; iResp < nResp; iResp++)
666                printf("%9.3g ", (Betas? Betas_(iUsed, iResp): 0));
667            printf("| ");
668            }
669        else {
670            printf("%2.2d  --    ", iTerm);
671            for (int iResp = 0; iResp < nResp; iResp++)
672                printf("%9s ", "--");
673            printf("| ");
674        }
675        int iPred;
676        for (iPred = 0; iPred < nPreds; iPred++)
677            if (Dirs_(iTerm,iPred) == 0)
678                printf(" . ");
679            else
680                printf("%2d ", Dirs_(iTerm,iPred));
681
682        printf("|");
683
684        for (iPred = 0; iPred < nPreds; iPred++)
685            if (Dirs_(iTerm,iPred) == 0)
686                printf("    .    ");
687            else if (Dirs_(iTerm,iPred) == 2)
688                printf("  linear ");
689            else
690                printf("%8.3g ", Cuts_(iTerm,iPred));
691
692        printf("\n");
693    }
694    printf("\n");
695}
696#endif // STANDALONE
697
698//-----------------------------------------------------------------------------
699// Set Diags to the diagonal values of inverse(X'X),
700// where X is referenced via the matrix R, from a previous call to dqrsl
701// with (in practice) bx.  The net result is that Diags is the diagonal
702// values of inverse(bx'bx).  We assume that R is created from a full rank X.
703//
704// TODO This could be simplified
705
706static void CalcDiags(
707    double Diags[],     // out: nCols x 1
708    const double R[],   // in: nCases x nCols, QR from prev call to dqrsl
709    const int nCases,   // in
710    const int nCols)    // in
711{
712    #define R_(i,j)     R [(i) + (j) * nCases]
713    #define R1_(i,j)    R1[(i) + (j) * nCols]
714    #define B_(i,j)     B [(i) + (j) * nCols]
715
716    double *R1 = (double *)malloc1(nCols * nCols * sizeof(double),  // nCols rows of R
717                            "R1\t\t\tnCols %d nCols %d sizeof(double) %d",
718                            nCols, nCols, sizeof(double));
719
720    double *B =  (double *)calloc1(nCols * nCols, sizeof(double),   // rhs of R1 * x = B
721                            "B\t\t\tnCols %d nCols %d sizeof(double) %d",
722                            nCols, nCols, sizeof(double));
723    int i, j;
724    for (i = 0; i < nCols; i++) {   // copy nCols rows of R into R1
725        for (j =  0; j < nCols; j++)
726            R1_(i,j) = R_(i,j);
727        B_(i,i) = 1;                // set diag of B to 1
728    }
729    int job = 1;            // 1 means solve R1 * x = B where R1 is upper triangular
730    int info = 0;
731    for (i = 0; i < nCols; i++) {
732        dtrsl_(             // LINPACK function
733            R1,             // in: t, matrix of the system, untouched
734            (int *)&nCols,  // in: ldt (typecast discards const)
735            (int *)&nCols,  // in: n
736            &B_(0,i),       // io: b, on return has solution x
737            &job,           // in:
738            &info);         // io:
739
740        ASSERT(info == 0);
741    }
742    // B is now inverse(R1).  Calculate B x B.
743
744    for (i = 0; i < nCols; i++)
745        for (j =  0; j < nCols; j++) {
746            double Sum = 0;
747            for (int k = max(i,j); k < nCols; k++)
748                Sum += B_(i,k) * B_(j,k);
749            B_(i,j) = B_(j,i) = Sum;
750        }
751    for (i = 0; i < nCols; i++)
752         Diags[i] = B_(i,i);
753    free1(B);
754    free1(R1);
755}
756
757//-----------------------------------------------------------------------------
758// Regress y on the used columns of x, in the standard way (using QR).
759// UsedCols[i] is true for each each used col i in x; unused cols are ignored.
760//
761// The returned Betas argument is computed from, and is indexed on,
762// the compacted x vector, not on the original x.
763//
764// The returned iPivots should only be used if *pnRank != nUsedCols.
765// The entries of iPivots refer to columns in the full x (and are 0 based).
766// Entries in iPivots at *pnRank and above specify linearly dependent columns in x.
767//
768// To maximize compatibility we call the same routines as the R function lm.
769
770static void Regress(
771    double       Betas[],       // out: nUsedCols * nResp, can be NULL
772    double       Residuals[],   // out: nCases * nResp, can be NULL
773    double       *pRss,         // out: RSS, summed over all nResp, can be NULL
774    double       Diags[],       // out: diags of inv(transpose(x) * x), can be NULL
775    int          *pnRank,       // out: nbr of indep cols in x
776    int          iPivots[],     // out: nCols, can be NULL
777    const double x[],           // in: nCases x nCols, must include intercept
778    const double y[],           // in: nCases x nResp
779    const double Weights[],     // in: nCases x 1, can be NULL
780    const int    nCases,        // in: number of rows in x and in y
781    const int    nResp,         // in: number of cols in y
782    int          nCols,         // in: number of columns in x, some may not be used
783    const bool   UsedCols[])    // in: specifies used columns in x
784{
785    double *xUsed;
786    int nUsedCols = CopyUsedCols(&xUsed, x, nCases, nCols, UsedCols);
787
788    bool MustFreeBetas = false;
789    if (Betas == NULL) {
790        Betas = (double *)malloc1(nUsedCols * nResp * sizeof(double),
791                            "Betas\t\t\tnUsedCols %d nResp %d sizeof(double) %d",
792                            nUsedCols, nResp, sizeof(double));
793        MustFreeBetas = true;
794    }
795    bool MustFreeResiduals = false;
796    if (Residuals == NULL) {
797        Residuals = (double *)malloc1(nCases * nResp * sizeof(double),
798                                "Residuals\t\tnCases %d nResp %d sizeof(double) %d",
799                                nCases, nResp, sizeof(double));
800        MustFreeResiduals = true;
801    }
802    bool MustFreePivots = false;
803    if (iPivots == NULL) {
804        iPivots = (int *)malloc1(nUsedCols * sizeof(int),
805                            "iPivots\t\tnUsedCols %d sizeof(int) %d",
806                            nUsedCols, sizeof(int));
807        MustFreePivots = true;
808    }
809    int iCol;
810    for (iCol = 0; iCol < nUsedCols; iCol++)
811        iPivots[iCol] = iCol+1;
812
813    // apply weights to x and y if Weights is not NULL
814
815    double *wx = xUsed;
816    double *wy = (double *)y;       // cast discards "const" else compiler warning
817    double *Weightss = NULL;        // sqrt of Weights
818    if (Weights) {
819        // wx is xUsed but with each element multiplied by the sqrt of
820        // the corresponding element of Weights.  Ditto for wy.
821
822        int iCase;
823        Weightss = (double *)malloc1(nCases * sizeof(double),
824                                "Weightss\t\t\tnCases %d sizeof(double) %d",
825                                nCases, sizeof(double));
826
827        for (iCase = 0; iCase < nCases; iCase++)
828            Weightss[iCase] = sqrt(Weights[iCase]);
829
830        wx = (double *)malloc1(nCases * nUsedCols * sizeof(double),
831                        "wx\t\t\tnCases %d nUsedCols %d sizeof(double) %d",
832                        nCases, nUsedCols, sizeof(double));
833
834        wy = (double *)malloc1(nCases * nResp * sizeof(double),
835                        "wy\t\t\tnCases %d nResp %d sizeof(double) %d",
836                        nCases, nResp, sizeof(double));
837
838        for (iCase = 0; iCase < nCases; iCase++)
839            for (int iCol = 0; iCol < nUsedCols; iCol++)
840                wx[iCase + iCol * nCases] =
841                    Weightss[iCase] * xUsed[iCase + iCol * nCases];
842
843        for (iCase = 0; iCase < nCases; iCase++)
844            for (int iResp = 0; iResp < nResp; iResp++)
845                wy[iCase + iResp * nCases] =
846                    Weightss[iCase] * y[iCase + iResp * nCases];
847    }
848    // compute Betas and yHat (use Residuals as a temporary buffer to store yHat)
849
850    double *qraux = (double *)malloc1(nUsedCols * sizeof(double),
851                                "qraux\t\t\tnUsedCols %d sizeof(double) %d",
852                                nUsedCols, sizeof(double));
853
854    double *work = (double *)malloc1(nCases * nUsedCols * sizeof(double),
855                                "work\t\t\tnCases %d nUsedCols %d sizeof(double) %d",
856                                nCases, nUsedCols, sizeof(double));
857
858    dqrdc2_(                // R function, QR decomp based on LINPACK dqrdc
859        wx,                 // io:  x, on return upper tri of x is R of QR
860        (int *)&nCases,     // in:  ldx (typecast discards const)
861        (int *)&nCases,     // in:  n
862        &nUsedCols,         // in:  p
863        (double*)&QR_TOL,   // in:  tol
864        pnRank,             // out: k, num of indep cols of x
865        qraux,              // out: qraux
866        iPivots,            // out: jpvt
867        work);              // work
868
869    double Rss = 0;
870    int job = 101;          // specify c=1 e=1 to compute qty, b, yHat
871    int info;
872    for (int iResp = 0; iResp < nResp; iResp++) {
873        dqrsl_(                                 // LINPACK function
874            wx,                                 // in:  x, generated by dqrdc2
875            (int *)&nCases,                     // in:  ldx (typecast discards const)
876            (int *)&nCases,                     // in:  n
877            pnRank,                             // in:  k
878            qraux,                              // in:  qraux
879            (double *)(wy + iResp * nCases),    // in:  y
880            work,                               // out: qy, unused here
881            work,                               // out: qty, unused here
882            (double *)(&Betas_(0,iResp)),       // out: b
883            work,                               // out: rsd, unused here
884            (double *)(&Residuals_(0,iResp)),   // out: xb = yHat = ls approx of x*b
885            &job,                               // in:  job
886            &info);                             // in:  info
887
888        ASSERT(info == 0);
889
890        // compute Residuals and Rss (sum over all responses)
891
892        if (Weightss)
893            for (int iCase = 0; iCase < nCases; iCase++) {
894                Residuals_(iCase,iResp) /= Weightss[iCase];
895                Residuals_(iCase,iResp) = (y_(iCase,iResp) - Residuals_(iCase,iResp));
896                Rss += sq(Residuals_(iCase, iResp));
897            }
898        else
899            for (int iCase = 0; iCase < nCases; iCase++) {
900                Residuals_(iCase,iResp) = (y_(iCase,iResp) - Residuals_(iCase,iResp));
901                Rss += sq(Residuals_(iCase, iResp));
902            }
903    }
904    if (pRss)
905        *pRss = Rss;
906
907    if (*pnRank != nUsedCols) {
908        // adjust iPivots for missing cols in UsedCols and for 1 offset
909
910        int *PivotOffset = (int *)malloc1(nCols * sizeof(int),
911                                    "PivotOffset\t\t\tnCols %d sizeof(int) %d",
912                                    nCols, sizeof(int));
913        int nOffset = 0, iOld = 0;
914        for (iCol = 0; iCol < nCols; iCol++) {
915            if (!UsedCols[iCol])
916                nOffset++;
917            else {
918                PivotOffset[iOld] = nOffset;
919                if (++iOld > nUsedCols)
920                    break;
921            }
922        }
923        for (iCol = 0; iCol < nUsedCols; iCol++)
924            iPivots[iCol] = iPivots[iCol] - 1 + PivotOffset[iPivots[iCol] - 1];
925        free1(PivotOffset);
926    }
927    if (Diags)
928        CalcDiags(Diags, wx, nCases, nUsedCols);
929    if (MustFreePivots)
930        free1(iPivots);
931    if (MustFreeResiduals)
932        free1(Residuals);
933    if (MustFreeBetas)
934        free1(Betas);
935    if (Weightss) {
936        free1(Weightss);
937        free1(wx);
938        free1(wy);
939    }
940    free1(xUsed);
941    free1(qraux);
942    free1(work);
943}
944
945//-----------------------------------------------------------------------------
946// This routine is for testing Regress from R, to compare results to R's lm().
947
948#if USING_R
949void RegressR(
950    double       Betas[],       // out: (nUsedCols+1) * nResp, +1 is for intercept
951    double       Residuals[],   // out: nCases * nResp
952    double       Rss[],         // out: RSS, summed over all nResp
953    double       Diags[],       // out: diags of inv(transpose(x) * x)
954    int          *pnRank,       // out: nbr of indep cols in x
955    int          iPivots[],     // out: nCols, can be R_NilValue
956    const double x[],           // in: nCases x nCols
957    const double y[],           // in: nCases x nResp
958    const double Weights[],     // in: nCases x 1, can be R_NilValue
959    const int    *pnCases,      // in: number of rows in x and in y
960    const int    *pnResp,       // in: number of cols in y
961    int          *pnCols,       // in: number of columns in x, some may not be used
962    const bool   UsedCols[])    // in: specifies used columns in x
963{
964    if ((void *)Weights == (void *)R_NilValue)
965        Weights = NULL;
966
967    Regress(Betas, Residuals, Rss, Diags, pnRank, iPivots,
968        x, y, Weights, *pnCases, *pnResp, *pnCols, UsedCols);
969}
970#endif
971
972//-----------------------------------------------------------------------------
973// Regress y on bx to get Residuals and Betas.  If bx isn't of full rank,
974// remove dependent cols, update UsedCols, and regress again on the bx with
975// removed cols.
976
977static void RegressAndFix(
978    double Betas[],         // out: nMaxTerms x nResp, can be NULL
979    double Residuals[],     // out: nCases x nResp, can be NULL
980    double Diags[],         // out: if !NULL set to diags of inv(transpose(bx) * bx)
981    bool   UsedCols[],      // io:  will remove cols if necessary, nMaxTerms x 1
982    const  double bx[],     // in:  nCases x nMaxTerms
983    const double y[],       // in:  nCases x nResp
984    const double Weights[], // in: nCases x 1, can be NULL
985    const int nCases,       // in
986    const int nResp,        // in: number of cols in y
987    const int nTerms)       // in: number of cols in bx, some may not be used
988{
989    int nRank;
990    int *iPivots = (int *)malloc1(nTerms * sizeof(int),
991                            "iPivots\t\tnTerms %d sizeof(int) %d",
992                            nTerms, sizeof(int));
993    Regress(Betas, Residuals, NULL, Diags, &nRank, iPivots,
994        bx, y, Weights, nCases, nResp, nTerms, UsedCols);
995    int nUsedCols = GetNbrUsedCols(UsedCols, nTerms);
996    int nDeficient = nUsedCols - nRank;
997    if (nDeficient) {           // rank deficient?
998        // Remove linearly dependent columns.
999        // The lin dep columns are at index nRank and higher in iPivots.
1000
1001        for (int iCol = nRank; iCol < nUsedCols; iCol++)
1002            UsedCols[iPivots[iCol]] = false;
1003
1004        Regress(Betas, Residuals, NULL, Diags, &nRank, NULL,
1005            bx, y, Weights, nCases, nResp, nTerms, UsedCols);
1006        nUsedCols = nUsedCols - nDeficient;
1007        if (nRank != nUsedCols)
1008            warning("Could not fix rank deficient bx: nUsedCols %d nRank %d",
1009                nUsedCols,  nRank);
1010        else if (TraceGlobal >= 1)
1011            printf("Fixed rank deficient bx by removing %d term%s, %d term%s remain%s\n",
1012                nDeficient, ((nDeficient==1)? "": "s"),
1013                nUsedCols,  ((nUsedCols==1)? "": "s"), ((nUsedCols==1)? "s": ""));
1014    }
1015    free1(iPivots);
1016}
1017
1018//-----------------------------------------------------------------------------
1019static INLINE double Mean(const double x[], int n)
1020{
1021    double mean = 0;
1022    for (int i = 0; i < n; i++)
1023        mean += x[i] / n;
1024    return mean;
1025}
1026
1027//-----------------------------------------------------------------------------
1028// get mean centered sum of squares
1029
1030static INLINE double SumOfSquares(const double x[], const double mean, int n)
1031{
1032    double ss = 0;
1033    for (int i = 0; i < n; i++)
1034        ss += sq(x[i] - mean);
1035    return ss;
1036}
1037
1038//-----------------------------------------------------------------------------
1039static INLINE double GetGcv(const int nTerms, // nbr basis terms including intercept
1040                const int nCases, double Rss, const double Penalty)
1041{
1042    double cost;
1043    if (Penalty == -1)  // special case: terms and knots are free
1044        cost = 0;
1045    else {
1046        const double nKnots = ((double)nTerms-1) / 2;
1047        cost = (nTerms + Penalty * nKnots) / nCases;
1048    }
1049    // test against cost ensures that GCVs are non-decreasing as nbr of terms increases
1050    return cost >= 1? POS_INF : Rss / (nCases * sq(1 - cost));
1051}
1052
1053//-----------------------------------------------------------------------------
1054// We only consider knots that are nMinSpan distance apart, to increase resistance
1055// to runs of correlated noise.  This function determines that distance.
1056// It implements eqn 43 FriedmanMars (see refs), but with an extension for nMinSpan.
1057// If bx==NULL then instead of counting valid entries in bx we use nCases,
1058// and ignore the term index iTerm.
1059//
1060// nMinSpan: if =0, use internally calculated min span
1061//           if >0, use instead of internally calculated min span
1062//           if <0, use old (incorrect) method of calculating minspan
1063//                  this was the method used prior to earth 2.4-0
1064
1065static INLINE int GetMinSpan(int nCases, int nPreds, const double *bx,
1066                             const int iTerm)
1067{
1068    if (nMinSpanGlobal > 0)                     // user specified a fixed span?
1069        return nMinSpanGlobal;
1070
1071    int nUsed = 0;                              // Nm in Friedman's notation
1072    if (bx == NULL)
1073        nUsed = nCases;
1074    else for (int iCase = 0; iCase < nCases; iCase++)
1075        if (bx_(iCase,iTerm) > 0)
1076            nUsed++;
1077
1078    static const double temp1 = 2.9702;         // -log(-log(0.95)
1079    static const double temp2 = 1.7329;         // 2.5 * log(2)
1080    const double n = (nMinSpanGlobal < 0) ? nCases: nPreds; // CHANGED earth 2.4
1081    return (int)((temp1 + log(n * nUsed)) / temp2);
1082}
1083
1084//-----------------------------------------------------------------------------
1085// We don't consider knots that are too close to the ends.
1086// This function determines how close to an end we can get.
1087// It implements eqn 45 FriedmanMars (see refs), re-expressed
1088// for efficient computation
1089
1090static INLINE int GetEndSpan(const int nCases, const int nPreds)
1091{
1092    static const double log_2 = 0.69315;            // log(2)
1093    static const double temp1 = 7.32193;            // 3 + log(20)/log(2);
1094    const double n = (nMinSpanGlobal < 0) ? nCases: nPreds; // CHANGED earth 2.4
1095    return (int)(temp1 + log(n) / log_2);
1096}
1097
1098//-----------------------------------------------------------------------------
1099// Return true if model term type is not already in model
1100// i.e. if the hockey stick functions in a pre-existing term do not use the same
1101// predictors (ignoring knot values).
1102//
1103// In practice this nearly always returns true.
1104
1105static bool GetNewFormFlag(const int iPred, const int iTerm,
1106                        const int Dirs[], const bool UsedCols[],
1107                        const int nTerms, const int nPreds, const int nMaxTerms)
1108{
1109    bool IsNewForm = true;
1110    for (int iTerm1 = 1; iTerm1 < nTerms; iTerm1++) // start at 1 to skip intercept
1111        if (UsedCols[iTerm1]) {
1112            IsNewForm = false;
1113            if (Dirs_(iTerm1,iPred) == 0)
1114                return true;
1115            for (int iPred1 = 0; iPred1 < nPreds; iPred1++)
1116                if (iPred1 != iPred && Dirs_(iTerm1,iPred1) != Dirs_(iTerm,iPred1))
1117                    return true;
1118        }
1119    return IsNewForm;
1120}
1121
1122//-----------------------------------------------------------------------------
1123static double GetCut(int iCase, const int iPred, const int nCases,
1124                        const double x[], const int xOrder[])
1125{
1126    if (iCase < 0 || iCase >= nCases)
1127        error("GetCut iCase %d: iCase < 0 || iCase >= nCases", iCase);
1128    const int ix = xOrder_(iCase,iPred);
1129    ASSERT(ix >= 0 && ix < nCases);
1130    return x_(ix,iPred);
1131}
1132
1133//-----------------------------------------------------------------------------
1134// The BetaCache is used when searching for a new term pair, via FindTerm.
1135// Most of the calculation for the orthogonal regression betas is repeated
1136// with the same data, and thus we can save time by caching betas.
1137// (The "Betas" are the regression coefficients.)
1138//
1139// iParent    is the term that forms the base for the new term
1140// iPred      is the predictor for the new term
1141// iOrthCol   is the column index in the bxOrth matrix
1142
1143static double *BetaCacheGlobal; // [iOrthCol,iParent,iPred]
1144                                // dim nPreds x nMaxTerms x nMaxTerms
1145
1146static void InitBetaCache(const bool UseBetaCache,
1147                          const int nMaxTerms, const int nPreds)
1148{
1149    int nCache =  nMaxTerms * nMaxTerms * nPreds;
1150    if (!UseBetaCache) {
1151        BetaCacheGlobal = NULL;
1152    // 3e9 below is somewhat arbitrary but seems about right (in 2011)
1153    } else if (nCache * sizeof(double) > 3e9) {
1154            printf(
1155"\nNote: earth's beta cache would require %s, so forcing Use.beta.cache=FALSE.\n"
1156"      Invoke earth with Use.beta.cache=FALSE to make this message go away.\n\n",
1157                sFormatMemSize(nCache * sizeof(double), false));
1158            fflush(stdout);
1159            BetaCacheGlobal = NULL;
1160    } else {
1161       if (TraceGlobal >= 5)    // print cache size
1162            printf("BetaCache %s\n",
1163                sFormatMemSize(nCache * sizeof(double), false));
1164
1165        BetaCacheGlobal = (double *)malloc1(nCache * sizeof(double),
1166            "BetaCacheGlobal\tnMaxTerms %d nMaxTerms %d nPreds %d sizeof(double) %d",
1167            nMaxTerms, nMaxTerms, nPreds, sizeof(double));
1168
1169        for (int i = 0; i < nCache; i++)    // mark all entries as uninited
1170            BetaCacheGlobal[i] = POS_INF;
1171    }
1172}
1173
1174static void FreeBetaCache(void)
1175{
1176    if (BetaCacheGlobal)
1177        free1(BetaCacheGlobal);
1178}
1179
1180//-----------------------------------------------------------------------------
1181// Init a new bxOrthCol to the residuals from regressing y on the used columns
1182// of the orthogonal matrix bxOrth.  The length (i.e. sum of sqaures divided
1183// by nCases) of each column of bxOrth must be 1 with mean 0 (except the
1184// first column which is the intercept).
1185//
1186// In practice this function is called with the params shown in {braces}
1187// and is called only by InitBxOrthCol.
1188//
1189// This function must be fast.
1190//
1191// In calculation of Beta, we used to have
1192//     xty += pbxOrth[iCase] * y[iCase];
1193// and now we have
1194//    xty += pbxOrth[iCase] * bxOrthCol[iCase];
1195// i.e. we use the "modified" instead of the "classic" Gram Schmidt.
1196// This is supposedly less susceptible to round off errors, although I haven't
1197// seen it have any effect on any of the data sets we have tested.
1198
1199static INLINE void OrthogResiduals(
1200    double bxOrthCol[],     // out: nCases x 1      { bxOrth[,nTerms] }
1201    const double y[],       // in:  nCases x nResp  { bx[,nTerms], xbx }
1202    const double bxOrth[],  // in:  nTerms x nPreds { bxOrth }
1203    const int nCases,       // in
1204    const int nTerms,       // in: nTerms in model, i.e. number of used cols in bxOrth
1205    const bool UsedTerms[], // in: UsedTerms[i] is true if col is used, unused cols ignored
1206                            //     Following parameters are only for the beta cache
1207    const int iParent,      // in: if >= 0, use BetaCacheGlobal {FindTerm iTerm, addTermP -1}
1208    const int iPred,        // in: predictor index i.e. col index in input matrix x
1209    const int nMaxTerms)    // in:
1210{
1211    double *pCache;
1212    if (iParent >= 0 && BetaCacheGlobal)
1213        pCache = BetaCacheGlobal + iParent*nMaxTerms + iPred*sq(nMaxTerms);
1214    else
1215        pCache = NULL;
1216
1217    memcpy(bxOrthCol, y, nCases * sizeof(double));
1218
1219    for (int iTerm = 0; iTerm < nTerms; iTerm++)
1220        if (UsedTerms[iTerm]) {
1221            const double *pbxOrth = &bxOrth_(0, iTerm);
1222            double Beta;
1223            if (pCache && pCache[iTerm] != POS_INF)
1224                Beta = pCache[iTerm];
1225            else {
1226                double xty = 0;
1227                for (int iCase = 0; iCase < nCases; iCase++)
1228                    xty += pbxOrth[iCase] * bxOrthCol[iCase]; // see header comment
1229                Beta = xty;  // no need to divide by xtx, it is 1
1230                ASSERT(FINITE(Beta));
1231                if (pCache)
1232                    pCache[iTerm] = Beta;
1233            }
1234#if USE_BLAS
1235            const double NegBeta = -Beta;
1236            daxpy_(&nCases, &NegBeta, pbxOrth, &ONE, bxOrthCol, &ONE);
1237#else
1238            for (int iCase = 0; iCase < nCases; iCase++)
1239                bxOrthCol[iCase] -= Beta * pbxOrth[iCase];
1240#endif
1241        }
1242}
1243
1244//-----------------------------------------------------------------------------
1245// Init the rightmost column of bxOrth i.e. the column indexed by nTerms.
1246// The new col is the normalized residuals from regressing y on the
1247// lower (i.e. already existing) cols of bxOrth.
1248// Also updates bxOrthCenteredT and bxOrthMean.
1249//
1250// In practice this function is called only with the params shown in {braces}
1251
1252static INLINE void InitBxOrthCol(
1253    double bxOrth[],         // io: col nTerms is changed, other cols not touched
1254    double bxOrthCenteredT[],// io: kept in sync with bxOrth
1255    double bxOrthMean[],     // io: element at nTerms is updated
1256    bool   *pGoodCol,        // io: set to false if col sum-of-squares is under BX_TOL
1257    const double *y,         // in: { AddCandLinTerm xbx, addTermPair bx[,nTerms] }
1258    const int nTerms,        // in: column goes in at index nTerms, 0 is the intercept
1259    const bool WorkingSet[], // in
1260    const int nCases,        // in
1261    const int nMaxTerms,     // in
1262    const int iCacheTerm,    // in: if >= 0, use BetaCacheGlobal {FindTerm iTerm, AddTermP -1}
1263                             //     if < 0 then recalc Betas from scratch
1264    const int iPred,         // in: predictor index i.e. col index in input matrix x
1265    const double Weights[])  // in:
1266{
1267    int iCase;
1268    *pGoodCol = true;
1269    Weights = Weights; // prevent compiler warning: unused parameter 'Weights'
1270
1271    if (nTerms == 0) {          // column 0, the intercept
1272        double len = 1 / sqrt((double) nCases);
1273        bxOrthMean[0] = len;
1274        for (iCase = 0; iCase < nCases; iCase++)
1275            bxOrth_(iCase,0) = len;
1276    } else if (nTerms == 1) {   // column 1, the first basis function
1277        double yMean = Mean(y, nCases);
1278        for (iCase = 0; iCase < nCases; iCase++)
1279            bxOrth_(iCase,1) = y[iCase] - yMean;
1280    } else
1281        OrthogResiduals(&bxOrth_(0,nTerms), // resids go in rightmost col of bxOrth at nTerms
1282            y, bxOrth, nCases, nTerms, WorkingSet, iCacheTerm, iPred, nMaxTerms);
1283
1284    if (nTerms > 0) {
1285        // normalize the column to length 1 and init bxOrthMean[nTerms]
1286
1287        double bxOrthSS = SumOfSquares(&bxOrth_(0,nTerms), 0, nCases);
1288        const double Tol = (iCacheTerm < 0? 0: BX_TOL);
1289        if (bxOrthSS > Tol) {
1290            bxOrthMean[nTerms] = Mean(&bxOrth_(0,nTerms), nCases);
1291            const double len = sqrt(bxOrthSS);
1292            for (iCase = 0; iCase < nCases; iCase++)
1293                bxOrth_(iCase,nTerms) /= len;
1294        } else {
1295            *pGoodCol = false;
1296            bxOrthMean[nTerms] = 0;
1297            memset(&bxOrth_(0,nTerms), 0, nCases * sizeof(double));
1298        }
1299    }
1300    for (iCase = 0; iCase < nCases; iCase++)        // keep bxOrthCenteredT in sync
1301        bxOrthCenteredT_(nTerms,iCase) = bxOrth_(iCase,nTerms) - bxOrthMean[nTerms];
1302}
1303
1304//-----------------------------------------------------------------------------
1305// Add a new term pair to the arrays.
1306// Each term in the new term pair is a copy of an existing parent term but extended
1307// by multiplying it by a new hockey stick function at the selected knot.
1308// If the upper term in the term pair is invalid then we still add the upper
1309// term but mark it as false in FullSet.
1310
1311static void AddTermPair(
1312    int    Dirs[],              // io
1313    double Cuts[],              // io
1314    double bx[],                // io: MARS basis matrix
1315    double bxOrth[],            // io
1316    double bxOrthCenteredT[],   // io
1317    double bxOrthMean[],        // io
1318    bool   FullSet[],           // io
1319    int    nFactorsInTerm[],    // io
1320    int    nUses[],             // io: nbr of times each predictor is used in the model
1321    const int nTerms,           // in: new term pair goes in at index nTerms and nTerms1
1322    const int iBestParent,      // in: parent term
1323    const int iBestCase,        // in
1324    const int iBestPred,        // in
1325    const int nPreds,           // in
1326    const int nCases,           // in
1327    const int nMaxTerms,        // in
1328    const bool IsNewForm,       // in
1329    const bool IsLinPred,       // in: pred was discovered by search to be linear
1330    const int LinPreds[],       // in: user specified preds which must enter linearly
1331    const double x[],           // in
1332    const int xOrder[],         // in
1333    const double Weights[])     // in:
1334{
1335    const double BestCut = GetCut(iBestCase, iBestPred, nCases, x, xOrder);
1336    ASSERT(IsLinPred || iBestCase != 0);
1337    const int nTerms1 = nTerms+1;
1338
1339    // copy the parent term to the new term pair
1340
1341    int iPred;
1342    bool PrintedParent = false;
1343    for (iPred = 0; iPred < nPreds; iPred++) {
1344        Dirs_(nTerms, iPred) =
1345        Dirs_(nTerms1,iPred) = Dirs_(iBestParent,iPred);
1346
1347        Cuts_(nTerms, iPred) =
1348        Cuts_(nTerms1,iPred) = Cuts_(iBestParent,iPred);
1349
1350        if (TraceGlobal >= 2 && !PrintedParent && Dirs_(iBestParent,iPred)) {
1351            // print parent term (this appends to prints by PrintForwardStep)
1352            printf("%-3d ", iBestParent+IOFFSET);
1353            PrintedParent = true;
1354        }
1355    }
1356    // incorporate the new hockey stick function
1357
1358    nFactorsInTerm[nTerms]  =
1359    nFactorsInTerm[nTerms1] = nFactorsInTerm[iBestParent] + 1;
1360
1361    int DirEntry = 1;
1362    if (LinPreds[iBestPred]) {
1363        ASSERT(IsLinPred);
1364        DirEntry = 2;
1365    }
1366    Dirs_(nTerms, iBestPred) = DirEntry;
1367    Dirs_(nTerms1,iBestPred) = -1; // will be ignored if adding only one hinge
1368
1369    Cuts_(nTerms, iBestPred) =
1370    Cuts_(nTerms1,iBestPred) = BestCut;
1371
1372    FullSet[nTerms] = true;
1373    if (!IsLinPred && IsNewForm)
1374        FullSet[nTerms1] = true;
1375
1376    // If the term is not valid, then we don't wan't to use it as the base for
1377    // a new term later (in FindTerm).  Enforce this by setting
1378    // nFactorsInTerm to a value greater than any posssible nMaxDegree.
1379
1380    if (!FullSet[nTerms1])
1381        nFactorsInTerm[nTerms1] = MAX_DEGREE + 1;
1382
1383    // fill in new columns of bx, at nTerms and nTerms+1 (left and right hinges)
1384
1385    int iCase;
1386    if (DirEntry == 2) {
1387        for (iCase = 0; iCase < nCases; iCase++)
1388            bx_(iCase,nTerms) = bx_(iCase,iBestParent) * x_(iCase,iBestPred);
1389    } else for (iCase = 0; iCase < nCases; iCase++) {
1390        if (x_(iCase,iBestPred) - BestCut > 0)
1391            bx_(iCase,nTerms) =
1392                bx_(iCase,iBestParent) * (x_(iCase,iBestPred) - BestCut);
1393        else
1394            bx_(iCase,nTerms1) =
1395                bx_(iCase,iBestParent) * (BestCut - x_(iCase,iBestPred));
1396    }
1397    nUses[iBestPred]++;
1398
1399    // init the col in bxOrth at nTerms and init bxOrthMean[nTerms]
1400
1401    bool GoodCol;
1402    InitBxOrthCol(bxOrth, bxOrthCenteredT, bxOrthMean, &GoodCol,
1403        &bx_(0,nTerms), nTerms, FullSet, nCases, nMaxTerms, -1, nPreds, Weights);
1404                            // -1 means don't use BetaCacheGlobal, calc Betas afresh
1405
1406    // init the col in bxOrth at nTerms1 and init bxOrthMean[nTerms1]
1407
1408    if (FullSet[nTerms1]) {
1409        InitBxOrthCol(bxOrth, bxOrthCenteredT, bxOrthMean, &GoodCol,
1410            &bx_(0,nTerms1), nTerms1, FullSet, nCases, nMaxTerms, -1, iPred, Weights);
1411    } else {
1412        memset(&bxOrth_(0,nTerms1), 0, nCases * sizeof(double));
1413        bxOrthMean[nTerms1] = 0;
1414        for (iCase = 0; iCase < nCases; iCase++)    // keep bxOrthCenteredT in sync
1415            bxOrthCenteredT_(nTerms1,iCase) = 0;
1416    }
1417}
1418
1419//-----------------------------------------------------------------------------
1420// The caller has selected a candidate predictor iPred and a candidate iParent.
1421// This function now selects a knot.  If it finds a knot it will
1422// update *piBestCase and pRssDeltaForParentPredPair.
1423//
1424// The general idea: scan backwards through all (ordered) values (i.e. potential
1425// knots) for the given predictor iPred, calculating RssDelta.
1426// If RssDelta > *pRssDeltaForParentPredPair (and all else is ok), then
1427// select the knot (by updating *piBestCase and *pRssDeltaForParentPredPair).
1428//
1429// There are currently nTerms in the model. We want to add a term pair
1430// at index nTerms and nTerms+1.
1431//
1432// This function must be fast.
1433//
1434// A note on the iSpan variable. We used to have
1435//     if (iCase % nMinSpan == 0)
1436// now we have
1437//     if (iSpan-- == 1)
1438// which is measurably faster, at least on a Pentium D.
1439// When we init iSpan (before the loop) we have a bit of code to
1440// initialize to an offset that puts the knots as the same positions as
1441// previous versions of earth.
1442
1443static INLINE void FindKnot(
1444    int    *piBestCase,         // out: possibly updated, index into the ORDERED x's
1445    double *pRssDeltaForParentPredPair, // io: updated if knot is better
1446    double CovCol[],            // scratch buffer, overwritten, nTerms x 1
1447    double ycboSum[],           // scratch buffer, overwritten, nMaxTerms x nResp
1448    double CovSx[],             // scratch buffer, overwritten, nTerms x 1
1449    double *ybxSum,             // scratch buffer, overwritten, nResp x 1
1450    const int nTerms,           // in
1451    const int iParent,          // in: parent term
1452    const int iPred,            // in: predictor index
1453    const int nCases,           // in
1454    const int nResp,            // in: number of cols in y
1455    const int nMaxTerms,        // in
1456    const double RssDeltaLin,   // in: change in RSS if predictor iPred enters linearly
1457    const double MaxAllowedRssDelta, // in: FindKnot rejects any changes in Rss greater than this
1458    const double bx[],          // in: MARS basis matrix
1459    const double bxOrth[],      // in
1460    const double bxOrthCenteredT[], // in
1461    const double bxOrthMean[],  // in
1462    const double x[],           // in: nCases x nPreds
1463    const double y[],           // in: nCases x nResp
1464    const double Weights[],     // in: nCases x 1, must not be NULL
1465    const int xOrder[],         // in
1466    const double yMean[],       // in: vector nResp x 1
1467    const int nMinSpan,
1468    const int nEndSpan,
1469    const double NewVarAdjust)  // in: 1 if not a new var, 1+NewVarPenalty if new var
1470{
1471    Weights = Weights; // prevent compiler warning: unused parameter 'Weights'
1472    ASSERT(MaxAllowedRssDelta > 0);
1473#if USE_BLAS
1474    double Dummy = bxOrth[0];   // prevent compiler warning: unused parameter
1475    Dummy = bxOrthMean[0];
1476#else
1477    double Dummy = bxOrthCenteredT[0];
1478    Dummy = nMaxTerms;
1479#endif
1480    const int nCases_nEndSpan = nCases - nEndSpan;
1481    ASSERT(nMinSpan > 0);
1482    int iSpan = (nCases - 1) % nMinSpan;
1483    if (iSpan == 0)
1484        iSpan = nMinSpan;
1485
1486    int iResp;
1487    for (iResp = 0; iResp < nResp; iResp++)
1488        ycboSum_(nTerms, iResp) = 0;
1489    memset(CovCol, 0, (nTerms+1) * sizeof(double));
1490    memset(CovSx,  0, (nTerms+1) * sizeof(double));
1491    memset(ybxSum, 0, nResp * sizeof(double));
1492    double bxSum = 0, bxSqSum = 0, bxSqxSum = 0, bxxSum = 0, st = 0;
1493
1494    for (int iCase = nCases - 2; iCase >= nEndSpan; iCase--) { // -2 allows for ix1
1495        // may Mars have mercy on the poor soul who enters here
1496
1497        const int    ix0 = xOrder_(iCase,  iPred); // get the x's in descending order
1498        const double x0  = x_(ix0,iPred);
1499        const int    ix1 = xOrder_(iCase+1,iPred);
1500        const double x1  = x_(ix1,iPred);
1501        const double bx1 = bx_(ix1,iParent);
1502        const double xDelta = x1 - x0;
1503        const double bxSq = sq(bx1);
1504
1505#if USE_BLAS
1506        daxpy_(&nTerms, &bx1, &bxOrthCenteredT_(0,ix1), &ONE, CovSx,  &ONE);
1507        daxpy_(&nTerms, &xDelta, CovSx, &ONE, CovCol, &ONE);
1508#else
1509        int it;
1510        for (it = 0; it < nTerms; it++) {
1511            CovSx[it]  += (bxOrth_(ix1,it) - bxOrthMean[it]) * bx1;
1512            CovCol[it] += xDelta * CovSx[it];
1513        }
1514#endif
1515        bxSum    += bx1;
1516        bxSqSum  += bxSq;
1517        bxxSum   += bx1 * x1;
1518        bxSqxSum += bxSq * x1;
1519        const double su = st;
1520        st = bxxSum - bxSum * x0;
1521
1522        CovCol[nTerms] += xDelta * (2 * bxSqxSum - bxSqSum * (x0 + x1)) +
1523                          (sq(su) - sq(st)) / nCases;
1524
1525        if (nResp == 1) {    // treat nResp==1 as a special case, for speed
1526            ybxSum[0] += (y_(ix1, 0) - yMean[0]) * bx1;
1527            ycboSum_(nTerms, 0) += xDelta * ybxSum[0];
1528        } else for (iResp = 0; iResp < nResp; iResp++) {
1529            ybxSum[iResp] += (y_(ix1, iResp) - yMean[iResp]) * bx1;
1530            ycboSum_(nTerms, iResp) += xDelta * ybxSum[iResp];
1531        }
1532        if (iSpan-- == 1) {
1533            iSpan = nMinSpan;
1534            if (CovCol[nTerms] > 0) {
1535                // calculate RssDelta and see if this knot beats the previous best
1536
1537                double RssDelta = RssDeltaLin;
1538                for (iResp = 0; iResp < nResp; iResp++) {
1539#if USE_BLAS
1540                    const double temp1 =
1541                        ycboSum_(nTerms,iResp) -
1542                        ddot_(&nTerms, &ycboSum_(0,iResp), &ONE, CovCol, &ONE);
1543
1544                    const double temp2 =
1545                        CovCol[nTerms] - ddot_(&nTerms, CovCol, &ONE, CovCol, &ONE);
1546#else
1547                    double temp1 = ycboSum_(nTerms,iResp);
1548                    double temp2 = CovCol[nTerms];
1549                    int it;
1550                    for (it = 0; it < nTerms; it++) {
1551                        temp1 -= ycboSum_(it,iResp) * CovCol[it];
1552                        temp2 -= sq(CovCol[it]);
1553                    }
1554#endif
1555                    if (temp2 / CovCol[nTerms] > BX_TOL)
1556                        RssDelta += sq(temp1) / temp2;
1557                }
1558                RssDelta /= NewVarAdjust;
1559
1560                // TODO HastieTibs code had an extra test here, seems unnecessary
1561                // !(iCase > 0 && x_(ix0,iPred) == x_(xOrder_(iCase-1,iPred),iPred))
1562
1563                if (RssDelta > *pRssDeltaForParentPredPair &&
1564                        RssDelta < MaxAllowedRssDelta      &&
1565                        iCase < nCases_nEndSpan            &&
1566                        bx1 > 0) {
1567                    *piBestCase = iCase;
1568                    *pRssDeltaForParentPredPair = RssDelta;
1569                }
1570            }
1571        }
1572    }
1573}
1574
1575//-----------------------------------------------------------------------------
1576// Add a candidate term at bx[,nTerms], with the parent term multiplied by
1577// the predictor iPred entering linearly.  Do this by setting the knot at
1578// the lowest value xMin of x, since max(0,x-xMin)==x-xMin for all x.  The
1579// change in RSS caused by adding this term forms the base RSS delta which
1580// we will try to beat in the search in FindKnot.
1581//
1582// This also initializes CovCol, bxOrth[,nTerms], and ycboSum[nTerms,]
1583
1584static INLINE void AddCandidateLinearTerm(
1585    double *pRssDeltaLin,       // out: change to RSS caused by adding new term
1586    bool   *pIsNewForm,         // io:
1587    double xbx[],               // out: nCases x 1
1588    double CovCol[],            // out: nMaxTerms x 1
1589    double ycboSum[],           // io: nMaxTerms x nResp
1590    double bxOrth[],            // io
1591    double bxOrthCenteredT[],   // io
1592    double bxOrthMean[],        // io
1593    const int iPred,            // in
1594    const int iParent,          // in
1595    const double x[],           // in: nCases x nPreds
1596    const double y[],           // in: nCases x nResp
1597    const double Weights[],     // in: nCases x 1, must not be NULL
1598    const int nCases,           // in
1599    const int nResp,            // in: number of cols in y
1600    const int nTerms,           // in
1601    const int nMaxTerms,        // in
1602    const double yMean[],       // in: vector nResp x 1
1603    const double bx[],          // in: MARS basis matrix
1604    const bool FullSet[],       // in
1605    const double NewVarAdjust)  // in
1606{
1607    // set xbx to x[,iPred] * bx[,iParent]
1608
1609    int iCase;
1610    for (iCase = 0; iCase < nCases; iCase++)
1611        xbx[iCase] = x_(iCase,iPred) * bx_(iCase,iParent);
1612
1613    // init bxOrth[,nTerms] and bxOrthMean[nTerms] for the candidate term
1614    // TODO look into *pIsNewForm handling here, it's confusing
1615
1616    InitBxOrthCol(bxOrth, bxOrthCenteredT, bxOrthMean, pIsNewForm,
1617        xbx, nTerms, FullSet, nCases, nMaxTerms, iParent, iPred, Weights);
1618
1619    // init CovCol and ycboSum[nTerms], for use by FindKnot later
1620
1621    memset(CovCol, 0, (nTerms-1) * sizeof(double));
1622    CovCol[nTerms] = 1;
1623    int iResp;
1624    for (iResp = 0; iResp < nResp; iResp++) {
1625        ycboSum_(nTerms, iResp) = 0;
1626        for (iCase = 0; iCase < nCases; iCase++)
1627            ycboSum_(nTerms, iResp) += (y_(iCase, iResp) - yMean[iResp]) *
1628                                       bxOrth_(iCase,nTerms);
1629    }
1630    // calculate change to RSS caused by adding candidate new term
1631
1632    *pRssDeltaLin = 0;
1633    for (iResp = 0; iResp < nResp; iResp++) {
1634        double yboSum = 0;
1635        for (iCase = 0; iCase < nCases; iCase++)
1636            yboSum += y_(iCase,iResp) * bxOrth_(iCase,nTerms);
1637        *pRssDeltaLin += sq(yboSum) / NewVarAdjust;
1638    }
1639    if (TraceGlobal >= 7)
1640        printf("Case %4d Cut % 12.4g< RssDelta %-12.5g ",
1641            0+IOFFSET, GetCut(0, iPred, nCases, x, xOrder), *pRssDeltaLin);
1642}
1643
1644//-----------------------------------------------------------------------------
1645// The caller has selected a candidate parent term iParent.
1646// This function now selects a predictor, and a knot for that predictor.
1647//
1648// TODO These functions have a ridiculous number of parameters, I know.
1649//
1650// TODO A note on the comparison against ALMOST_ZERO below:
1651// It's not a clean solution but seems to work ok.
1652// It was added after we saw different results on different
1653// machines for certain datasets e.g. (tested on earth 1.4.0)
1654// ldose  <- rep(0:5, 2) - 2
1655// ldose1 <- c(0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 0.3, 1.4, 2.5, 3.6, 4.7, 5.8)
1656// sex3 <- factor(rep(c("male", "female", "andro"), times=c(6,4,2)))
1657// fac3 <- factor(c("lev2", "lev2", "lev1", "lev1", "lev3", "lev3",
1658//                  "lev2", "lev2", "lev1", "lev1", "lev3", "lev3"))
1659// numdead <- c(1,4,9,13,18,20,0,2,6,10,12,16)
1660// numdead2 <- c(2,3,10,13,19,20,0,3,7,11,13,17)
1661// pair <- cbind(numdead, numdead2)
1662// df <- data.frame(sex3, ldose, ldose1, fac3)
1663// am <-  earth(df, pair, trace=6, pmethod="none", degree=2)
1664
1665static INLINE void FindPred(
1666    int    *piBestCase,         // out: return -1 if no new term available
1667                                //      else return an index into the ORDERED x's
1668    int    *piBestPred,         // out
1669    int    *piBestParent,       // out: existing term on which we are basing the new term
1670    double *pBestRssDeltaForTerm,   // io: updated if new predictor is better
1671    double *pBestRssDeltaForParent, // io: used only by FAST_MARS
1672    bool   *pIsNewForm,         // out
1673    bool   *pIsLinPred,         // out: true if knot is at min x val so x enters linearly
1674    double MaxRssPerPred[],     // io: nPreds x 1, max RSS for each predictor over all parents
1675    double xbx[],               // io: nCases x 1
1676    double CovSx[],             // io
1677    double CovCol[],            // io
1678    double ycboSum[],           // io: nMaxTerms x nResp
1679    double bxOrth[],            // io
1680    double bxOrthCenteredT[],   // io
1681    double bxOrthMean[],        // io
1682    const int iBestPred,        // in: if -1 then search for best predictor, else use this predictor
1683    const int iParent,          // in
1684    const double x[],           // in: nCases x nPreds
1685    const double y[],           // in: nCases x nResp
1686    const double Weights[],     // in: nCases x 1
1687    const int nCases,           // in
1688    const int nResp,            // in: number of cols in y
1689    const int nPreds,           // in
1690    const int nTerms,           // in
1691    const int nMaxTerms,        // in
1692    const double yMean[],       // in: vector nResp x 1
1693    const double MaxAllowedRssDelta, // in: FindKnot rejects any changes in Rss greater than this
1694    const double bx[],          // in: MARS basis matrix
1695    const bool FullSet[],       // in
1696    const int xOrder[],         // in
1697    const int nUses[],          // in: nbr of times each predictor is used in the model
1698    const int Dirs[],           // in
1699    const double NewVarPenalty, // in: penalty for adding a new variable (default is 0)
1700    const int LinPreds[])       // in: nPreds x 1, 1 if predictor must enter linearly
1701{
1702#if USING_R
1703    const int nServiceR = 1000000 / nCases;
1704#endif
1705    double *ybxSum = (double *)malloc1(nResp * sizeof(double),  // working var for FindKnot
1706                        "ybxSum\t\tnResp %d sizeof(double) %d",
1707                        nResp, sizeof(double));
1708    bool UpdatedBestRssDelta = false;
1709    int iFirstPred = 0;
1710    int iLastPred = nPreds - 1;
1711    if (iBestPred >= 0) {
1712        // we already know the best predictor to use so don't iterate over all preds
1713        iFirstPred = iBestPred;
1714        iLastPred = iBestPred;
1715    }
1716    for (int iPred = iFirstPred; iPred <= iLastPred; iPred++) {
1717        if (Dirs_(iParent,iPred) != 0) {    // predictor is in parent term?
1718            if (TraceGlobal >= 7)
1719                printf("|Parent %-2d Pred %-2d"
1720                    "                                   "
1721                    "                skip (pred is in parent)\n",
1722                    iParent+IOFFSET, iPred+IOFFSET);
1723#if USING_R
1724        } else if (!IsAllowed(iPred, iParent, Dirs, nPreds, nMaxTerms)) {
1725            if (TraceGlobal >= 7)
1726                printf("|Parent %-2d Pred %-2d"
1727                    "                                   "
1728                    "                skip (not allowed by \"allowed\" func)\n",
1729                    iParent+IOFFSET, iPred+IOFFSET);
1730#endif
1731        } else {
1732#if USING_R
1733            static int iServiceR = 0;
1734            if (++iServiceR > nServiceR) {
1735                ServiceR();
1736                iServiceR = 0;
1737            }
1738#endif
1739            if (TraceGlobal >= 7)
1740                printf("|Parent %-2d Pred %-2d ", iParent+IOFFSET, iPred+IOFFSET);
1741            const double NewVarAdjust = 1 + (nUses[iPred] == 0? NewVarPenalty: 0);
1742            double RssDeltaLin = 0;    // change in RSS for iPred entering linearly
1743            bool IsNewForm = GetNewFormFlag(iPred, iParent, Dirs,
1744                                FullSet, nTerms, nPreds, nMaxTerms);
1745            if (IsNewForm) {
1746                // create a candidate term at bx[,nTerms],
1747                // with iParent and iPred entering linearly
1748
1749                AddCandidateLinearTerm(&RssDeltaLin, &IsNewForm,
1750                    xbx, CovCol, ycboSum, bxOrth, bxOrthCenteredT, bxOrthMean,
1751                    iPred, iParent, x, y, Weights,
1752                    nCases, nResp, nTerms, nMaxTerms,
1753                    yMean, bx, FullSet, NewVarAdjust);
1754
1755                if (fabs(RssDeltaLin - *pBestRssDeltaForTerm) < ALMOST_ZERO)
1756                    RssDeltaLin = *pBestRssDeltaForTerm;        // see header note
1757                if (RssDeltaLin > *pBestRssDeltaForParent)
1758                    *pBestRssDeltaForParent = RssDeltaLin;
1759                if (RssDeltaLin > *pBestRssDeltaForTerm) {
1760                    // The new term (with predictor entering linearly) beats other
1761                    // candidate terms so far.
1762
1763                    if (TraceGlobal >= 7)
1764                        printf("BestRssDeltaForTermSoFar %g (lin pred) ",
1765                            RssDeltaLin - *pBestRssDeltaForTerm);
1766
1767                    UpdatedBestRssDelta = true;
1768                    *pBestRssDeltaForTerm = RssDeltaLin;
1769                    *pIsLinPred    = true;
1770                    *piBestCase    = 0;         // knot is at the lowest value of x
1771                    *piBestPred    = iPred;
1772                    *piBestParent  = iParent;
1773                }
1774            }
1775            double RssDeltaForParentPredPair = RssDeltaLin;
1776            if (TraceGlobal >= 7)
1777                printf("\n");
1778            if (!LinPreds[iPred]) {
1779                const int nMinSpan = GetMinSpan(nCases, nPreds, bx, iParent);
1780                const int nEndSpan = GetEndSpan(nCases, nPreds);
1781                int iBestCase = -1;
1782                FindKnot(&iBestCase, &RssDeltaForParentPredPair,
1783                        CovCol, ycboSum, CovSx, ybxSum,
1784                        (IsNewForm? nTerms + 1: nTerms),
1785                        iParent, iPred, nCases, nResp, nMaxTerms,
1786                        RssDeltaLin, MaxAllowedRssDelta,
1787                        bx, bxOrth, bxOrthCenteredT, bxOrthMean,
1788                        x, y, Weights, xOrder, yMean,
1789                        nMinSpan, nEndSpan, NewVarAdjust);
1790
1791                if (RssDeltaForParentPredPair > *pBestRssDeltaForParent)
1792                    *pBestRssDeltaForParent = RssDeltaForParentPredPair;
1793                if (RssDeltaForParentPredPair > *pBestRssDeltaForTerm) {
1794                    UpdatedBestRssDelta = true;
1795                    *pBestRssDeltaForTerm = RssDeltaForParentPredPair;
1796                    *pIsLinPred    = false;
1797                    *piBestCase    = iBestCase;
1798                    *piBestPred    = iPred;
1799                    *piBestParent  = iParent;
1800                    *pIsNewForm    = IsNewForm;
1801                    if (TraceGlobal >= 7)
1802                        printf("|                  "
1803                            "Case %4d Cut % 12.4g  "
1804                            "RssDelta %-12.5g BestRssDeltaForTermSoFar\n",
1805                            iBestCase+IOFFSET,
1806                            GetCut(iBestCase, iPred, nCases, x, xOrder),
1807                            *pBestRssDeltaForTerm);
1808                }
1809            }
1810            if (MaxRssPerPred && RssDeltaForParentPredPair > MaxRssPerPred[iPred])
1811                MaxRssPerPred[iPred] = RssDeltaForParentPredPair;
1812        } // else
1813    } // for iPred
1814    free1(ybxSum);
1815    if (UpdatedBestRssDelta && nUses[*piBestPred] == 0) {
1816        // de-adjust for NewVarPenalty (only makes a difference if NewVarPenalty != 0)
1817        const double NewVarAdjust = 1 + NewVarPenalty;
1818        *pBestRssDeltaForTerm *= NewVarAdjust;
1819    }
1820}
1821
1822//-----------------------------------------------------------------------------
1823// Find a new term to add to the model, if possible, and return the
1824// selected case (i.e. knot), predictor, and parent term indices.
1825//
1826// The new term is a copy of an existing parent term but extended
1827// by multiplying the parent by a new hockey stick function at the selected knot.
1828//
1829// Actually, this usually finds a term _pair_, with left and right hockey sticks.
1830//
1831// There are currently nTerms in the model. We want to add a term at index nTerms.
1832
1833static void FindTerm(
1834    int    *piBestCase,         // out: return -1 if no new term available
1835                                //      else return an index into the ORDERED x's
1836    int    *piBestPred,         // out:
1837    int    *piBestParent,       // out: existing term on which we are basing the new term
1838    double *pBestRssDeltaForTerm, // out: adding new term reduces RSS this much
1839                                  //      will be set to 0 if no possible new term
1840    bool   *pIsNewForm,         // out
1841    bool   *pIsLinPred,         // out: true if knot is at min x val so x enters linearly
1842    double MaxRssPerPred[],     // io: nPreds x 1, max RSS for each predictor over all parents
1843    double bxOrth[],            // io: column nTerms overwritten
1844    double bxOrthCenteredT[],   // io: kept in sync with bxOrth
1845    double bxOrthMean[],        // io: element nTerms overwritten
1846    const int iBestPred,        // in: if -1 then search for best predictor, else use this predictor
1847    const double x[],           // in: nCases x nPreds
1848    const double y[],           // in: nCases x nResp
1849    const double Weights[],     // in: nCases x 1
1850    const int nCases,           // in:
1851    const int nResp,            // in: number of cols in y
1852    const int nPreds,           // in:
1853    const int nTerms,           // in:
1854    const int nMaxDegree,       // in:
1855    const int nMaxTerms,        // in:
1856    const double yMean[],       // in: vector nResp x 1
1857    const double MaxAllowedRssDelta, // in: FindKnot rejects any changes in Rss greater than this
1858    const double bx[],          // in: MARS basis matrix
1859    const bool FullSet[],       // in:
1860    const int xOrder[],         // in:
1861    const int nFactorsInTerm[], // in:
1862    const int nUses[],          // in: nbr of times each predictor is used in the model
1863    const int Dirs[],           // in:
1864    const int nFastK,           // in: Fast MARS K
1865    const double NewVarPenalty, // in: penalty for adding a new variable (default is 0)
1866    const int LinPreds[])       // in: nPreds x 1, 1 if predictor must enter linearly
1867{
1868#if !FAST_MARS
1869    int Dummy = nFastK;             // prevent compiler warning: unused parameter
1870    Dummy = 0;
1871#endif
1872    if (TraceGlobal >= 7)
1873        printf("\n|Searching for new term %-3d                    "
1874               "RssDelta 0\n",
1875               nTerms+IOFFSET);
1876
1877    *piBestCase = -1;
1878    *pBestRssDeltaForTerm = 0;
1879    *pIsLinPred = false;
1880    *pIsNewForm = false;
1881    int iCase;
1882
1883    xbx = (double *)malloc1(nCases * sizeof(double),
1884                "xbx\t\t\tnCases %d sizeof(double) %d",
1885                nCases, sizeof(double));
1886    CovSx  = (double *)malloc1(nMaxTerms * sizeof(double),
1887                "CovSx\t\t\tnMaxTerms %d sizeof(double) %d",
1888                nMaxTerms, sizeof(double));
1889    CovCol = (double *)calloc1(nMaxTerms, sizeof(double),
1890                "CovCol\t\tnMaxTerms %d sizeof(double) %d",
1891                nMaxTerms, sizeof(double));
1892    ycboSum  = (double *)calloc1(nMaxTerms * nResp, sizeof(double),
1893                "ycbpSum\t\tnMaxTerms %d nResp %d sizeof(double) %d",
1894                nMaxTerms, nResp, sizeof(double));
1895
1896    for (int iResp = 0; iResp < nResp; iResp++)
1897        for (int iTerm = 0; iTerm < nTerms; iTerm++)
1898            for (iCase = 0; iCase < nCases; iCase++)
1899                ycboSum_(iTerm,iResp) +=
1900                    (y_(iCase,iResp) - yMean[iResp]) * bxOrth_(iCase,iTerm);
1901
1902    int iParent;
1903#if FAST_MARS
1904    GetNextParent(true, nFastK); // init queue iterator
1905    while ((iParent = GetNextParent(false, nFastK)) > -1) {
1906#else
1907    for (iParent = 0; iParent < nTerms; iParent++) {
1908#endif
1909        // Assume a bad RssDelta for iParent.  This pushes parent terms that
1910        // can't be used to the bottom of the FastMARS queue.  (A parent can't be
1911        // used if nFactorsInTerm is too big or all predictors are in the parent).
1912
1913        double BestRssDeltaForParent = -1;    // used only by FAST_MARS
1914
1915        if (nFactorsInTerm[iParent] >= nMaxDegree) {
1916            if (TraceGlobal >= 7)
1917                printf("|Parent %-2d"
1918                    "                                                      "
1919                    "     skip (nFactorsInTerm %d)\n",
1920                    iParent+IOFFSET, nFactorsInTerm[iParent]);
1921        } else {
1922            FindPred(piBestCase, piBestPred, piBestParent, pBestRssDeltaForTerm,
1923                &BestRssDeltaForParent, pIsNewForm, pIsLinPred, MaxRssPerPred,
1924                xbx, CovSx, CovCol, ycboSum, bxOrth, bxOrthCenteredT, bxOrthMean,
1925                iBestPred, iParent, x, y, Weights,
1926                nCases, nResp, nPreds, nTerms, nMaxTerms, yMean, MaxAllowedRssDelta,
1927                bx, FullSet, xOrder, nUses, Dirs, NewVarPenalty,
1928                LinPreds);
1929        }
1930#if FAST_MARS
1931        UpdateRssDeltaInQ(iParent, nTerms, BestRssDeltaForParent);
1932#endif
1933    } // iParent
1934    if (TraceGlobal >= 7)
1935        printf("\n");
1936    free1(ycboSum);
1937    free1(CovCol);
1938    free1(CovSx);
1939    free1(xbx);
1940}
1941
1942//-----------------------------------------------------------------------------
1943static void PrintForwardProlog(const int nCases, const int nPreds,
1944        const char *sPredNames[])   // in: predictor names, can be NULL
1945{
1946    if (TraceGlobal == 1)
1947        printf("Forward pass term 1");
1948    else if(TraceGlobal == 1.5)
1949        printf("Forward pass term 1\n");
1950    else if (TraceGlobal >= 2) {
1951        const char *sMinSpan = (nMinSpanGlobal < 0)? " (old minspan calculation)": "";
1952        printf("Forward pass: minspan %d endspan %d%s\n\n",
1953            GetMinSpan(nCases, nPreds, NULL, 0),
1954            GetEndSpan(nCases, nPreds), sMinSpan);
1955
1956        printf("         GRSq    RSq     DeltaRSq Pred ");
1957        if (sPredNames)
1958            printf("    PredName  ");
1959        printf("       Cut  Terms   ParentTerm\n");
1960
1961        printf("1      0.0000 0.0000                               %s%d\n",
1962            (sPredNames? "              ":""), IOFFSET);
1963    }
1964}
1965
1966//-----------------------------------------------------------------------------
1967static void PrintForwardStep(
1968        const int nTerms,
1969        const int nUsedTerms,
1970        const int iBestCase,
1971        const int iBestPred,
1972        const double RSq,
1973        const double RSqDelta,
1974        const double Gcv,
1975        const double GcvNull,
1976        const int nCases,
1977        const int xOrder[],
1978        const double x[],
1979        const bool IsLinPred,
1980        const bool IsNewForm,
1981        const char *sPredNames[])   // in: predictor names, can be NULL
1982{
1983    if (TraceGlobal == 6)
1984        printf("\n\n");
1985    if (TraceGlobal == 1) {
1986        printf(", ");
1987        if (nTerms % 30 == 29)
1988            printf("\n     ");
1989        printf("%d", nTerms+1);
1990    } else if (TraceGlobal == 1.5)
1991        printf("Forward pass term %d\n", nTerms+1);
1992    else if (TraceGlobal >= 2) {
1993        printf("%-4d%9.4f %6.4f %12.4g  ",
1994            nTerms+IOFFSET, 1-Gcv/GcvNull, RSq, RSqDelta);
1995        if (iBestPred < 0)
1996            printf("  -                                ");
1997        else {
1998            printf("%3d", iBestPred+IOFFSET);
1999            if (sPredNames) {
2000                if (sPredNames[iBestPred] && sPredNames[iBestPred][0])
2001                    printf(" %12.12s ", sPredNames[iBestPred]);
2002                else
2003                    printf(" %12.12s ", " ");
2004            }
2005            if (iBestCase == -1)
2006                printf("       none  ");
2007            else
2008                printf("% 11.5g%c ",
2009                    GetCut(iBestCase, iBestPred, nCases, x, xOrder),
2010                        (IsLinPred? '<': ' '));
2011            if (!IsLinPred && IsNewForm)  // two new used terms?
2012                printf("%-3d %-3d ", nUsedTerms-2+IOFFSET, nUsedTerms-1+IOFFSET);
2013            else
2014                printf("%-3d     ", nUsedTerms-1+IOFFSET);
2015            // AddTermPair will print the parents shortly, if any
2016        }
2017    }
2018    if (TraceGlobal != 0)
2019        fflush(stdout);
2020}
2021
2022//-----------------------------------------------------------------------------
2023static void PrintForwardEpilog(
2024            const int nTerms, const int nMaxTerms,
2025            const double Thresh,
2026            const double RSq, const double RSqDelta,
2027            const double Gcv, const double GcvNull,
2028            const int iBestCase,
2029            const bool FullSet[])
2030{
2031    if (TraceGlobal >= 1) {
2032        double GRSq = 1-Gcv/GcvNull;
2033
2034        // print reason why we stopped adding terms
2035        // NOTE: this code must match the loop termination conditions in ForwardPass
2036
2037        // treat very low nMaxTerms as a special case
2038        // because RSDelta etc. not yet completely initialized
2039
2040        if (nMaxTerms < 3)
2041            printf("\nReached max number of terms %d", nMaxTerms);
2042
2043        else if (Thresh != 0 && GRSq < MIN_GRSQ) {
2044            if(GRSq < -1000)
2045                printf("\nReached GRSq = -Inf at %d terms\n", nTerms);
2046            else
2047                printf("\nReached min GRSq (GRSq %g < %g) at %d terms\n",
2048                    GRSq, MIN_GRSQ, nTerms);
2049        }
2050        else if (Thresh != 0 && RSqDelta < Thresh)
2051            printf("\nReached delta RSq threshold (DeltaRSq %g < %g) at %d terms\n",
2052                RSqDelta, Thresh, nTerms);
2053
2054        else if (RSq > 1-Thresh)
2055            printf("\nReached max RSq (RSq %g > %g) at %d terms\n",
2056                RSq, 1-Thresh, nTerms);
2057
2058        else if (iBestCase < 0)
2059            printf("\nNo new term increases RSq (reached numerical limits) at %d terms\n",
2060                nTerms);
2061
2062        else {
2063            printf("\nReached max number of terms %d", nMaxTerms);
2064            if (nTerms < nMaxTerms)
2065                printf(" (no room for another term pair)");
2066            printf("\n");
2067        }
2068        printf("After forward pass GRSq %.4g RSq %.4g\n", GRSq, RSq);
2069    }
2070    if (TraceGlobal >= 2) {
2071        printf("Forward pass complete: %d terms", nTerms);
2072        int nUsed = GetNbrUsedCols(FullSet, nMaxTerms);
2073        if (nUsed != nTerms)
2074            printf(" (%d terms used)", nUsed);
2075        printf("\n");
2076    }
2077    if (TraceGlobal >= 3)
2078        printf("\n");
2079}
2080
2081//-----------------------------------------------------------------------------
2082static void CheckVec(const double x[], int nCases, int nCols, const char sVecName[])
2083{
2084    int iCol, iCase;
2085
2086    for (iCol = 0; iCol < nCols; iCol++)
2087        for (iCase = 0; iCase < nCases; iCase++) {
2088#if USING_R
2089             if (ISNA(x[iCase + iCol * nCases])) {
2090                 if (nCols > 1)
2091                     error("%s[%d,%d] is NA",
2092                         sVecName, iCase+IOFFSET, iCol+IOFFSET);
2093                 else
2094                     error("%s[%d] is NA", sVecName, iCase+IOFFSET);
2095             }
2096#endif
2097             if (ISNAN(x[iCase + iCol * nCases])) {
2098                 if (nCols > 1)
2099                     error("%s[%d,%d] is NaN",
2100                         sVecName, iCase+IOFFSET, iCol+IOFFSET);
2101                 else
2102                     error("%s[%d] is NaN", sVecName, iCase+IOFFSET);
2103             }
2104             if (!FINITE(x[iCase + iCol * nCases])) {
2105                 if (nCols > 1)
2106                     error("%s[%d,%d] is not finite",
2107                         sVecName, iCase+IOFFSET, iCol+IOFFSET);
2108                 else
2109                     error("%s[%d] is not finite", sVecName, iCase+IOFFSET);
2110             }
2111    }
2112}
2113
2114//-----------------------------------------------------------------------------
2115static void CheckRssNull(double RssNull, const double y[], int iResp, int nCases)
2116{
2117    if (RssNull / nCases < 1e-8) {
2118        if (iResp)
2119            error("variance of y[,%d] is zero (values are all equal to %g)",
2120                iResp+IOFFSET, &y_(0,iResp));
2121        else
2122            error("variance of y is zero (values are all equal to %g)",
2123                &y_(0,iResp));
2124    }
2125}
2126
2127//-----------------------------------------------------------------------------
2128static double *pInitWeights(const double WeightsArg[], int nCases)
2129{
2130    Weights = (double *)malloc1(nCases * sizeof(double),
2131                    "Weights\t\tnCases %d sizeof(double) %d",
2132                    nCases, sizeof(double));
2133    if (!WeightsArg)
2134        for (int iCase = 0; iCase < nCases; iCase++)
2135            Weights[iCase] = 1;
2136    else for (int iCase = 0; iCase < nCases; iCase++) {
2137        Weights[iCase] = WeightsArg[iCase];
2138#if USING_R
2139        if (ISNA(Weights[iCase]))
2140            error("Weights[%d] is NA");
2141#endif
2142        if (ISNAN(Weights[iCase]))
2143            error("Weights[%d] is NaN");
2144        if (!FINITE(Weights[iCase]))
2145            error("Weights[%d] is not finite");
2146        if (Weights[iCase] < ALMOST_ZERO)
2147            error("Weights[%d] is less than or equal to zero");
2148    }
2149    return Weights;
2150}
2151
2152//-----------------------------------------------------------------------------
2153// Forward pass
2154//
2155// After initializing the intercept term, the main for loop adds terms in pairs.
2156// In the for loop, nTerms is the index of the potential new term; nTerms+1
2157// the index of its partner.
2158// The upper term in the term pair may not be useable.  If so we still
2159// increment nTerms by 2 but don't set the flag in FullSet.
2160//
2161// TODO feature: add option to prescale x and y
2162
2163static void ForwardPass(
2164    int    *pnTerms,            // out: highest used term number in full model
2165    bool   FullSet[],           // out: 1 * nMaxTerms, indices of lin indep cols of bx
2166    double bx[],                // out: MARS basis matrix, nCases * nMaxTerms
2167    int    Dirs[],              // out: nMaxTerms * nPreds, -1,0,1,2 for iTerm, iPred
2168    double Cuts[],              // out: nMaxTerms * nPreds, cut for iTerm, iPred
2169    int    nFactorsInTerm[],    // out: number of hockey stick funcs in each MARS term
2170    int    nUses[],             // out: nbr of times each predictor is used in the model
2171    const double x[],           // in: nCases x nPreds
2172    const double y[],           // in: nCases x nResp
2173    const double WeightsArg[],  // in: nCases x 1, can be NULL, currently ignored
2174    const int nCases,           // in: number of rows in x and elements in y
2175    const int nResp,            // in: number of cols in y
2176    const int nPreds,           // in:
2177    const int nMaxDegree,       // in:
2178    const int nMaxTerms,        // in:
2179    const double Penalty,       // in: GCV penalty per knot
2180    double Thresh,              // in: forward step threshold
2181    int nFastK,                 // in: Fast MARS K
2182    const double FastBeta,      // in: Fast MARS ageing coef
2183    const double NewVarPenalty, // in: penalty for adding a new variable (default is 0)
2184    const int  LinPreds[],      // in: nPreds x 1, 1 if predictor must enter linearly
2185    const bool UseBetaCache,    // in: true to use the beta cache, for speed
2186    const char *sPredNames[])   // in: predictor names, can be NULL
2187{
2188    if (TraceGlobal >= 5)
2189        printf("earth.c %s\n", VERSION);
2190
2191    // The limits below are somewhat arbitrary and generous.
2192    // They are intended to catch gross errors on the part of the
2193    // caller, and to prevent crashes because of 0 sizes etc.
2194    // We use error rather than ASSERT because these are user settable params
2195    // and we want to be informative from the user's perspective.
2196    // The errors are reported using the variable names in the R code.
2197
2198    // prevent possible minspan range problems, also prevent crash when nCases==0
2199    if (nCases < 8)
2200        error("need at least 8 rows in x, you have %d", nCases);
2201#if 0 // removed for earth 2.0-6
2202    if (nCases < nPreds)    // (this check may not actually be necessary)
2203        warning("Need as many rows as columns in x (nrow %d ncol %d)",
2204                 nCases, nPreds);
2205#endif
2206    if (nCases > 1e8)
2207        error("too many rows %d in input matrix, max is 1e8", nCases);
2208    if (nResp < 1)
2209        error("nResp %d < 1", nResp);
2210    if (nResp > 1e6)
2211        error("nResp %d > 1e6", nResp);
2212    if (nPreds < 1)
2213        error("nPreds %d < 1", nPreds);
2214    if (nPreds > 1e5)
2215        error("nPreds %d > 1e5", nPreds);
2216    if (nMaxDegree <= 0)
2217        error("degree %d <= 0", nMaxDegree);
2218    if (nMaxDegree > MAX_DEGREE)
2219        error("degree %d > %d", nMaxDegree, MAX_DEGREE);
2220    if (nMaxTerms < 3)      // prevent internal misbehaviour
2221        error("nk %d < 3", nMaxTerms);
2222    if (nMaxTerms > 10000)
2223        error("nk %d > 10000", nMaxTerms);
2224    if (nFastK <= 0)
2225        nFastK = 10000+1;   // bigger than any nMaxTerms
2226    if (nFastK < 3)         // avoid possible queue boundary conditions
2227        nFastK = 3;
2228    if (Penalty < 0 && Penalty != -1)
2229        error("penalty %g < 0, the only legal value less than 0 is -1 "
2230            "(meaning terms and knots are free)", Penalty);
2231    if (Penalty > 1000)
2232        error("penalty %g > 1000", Penalty);
2233    if (Thresh < 0)
2234        error("thresh %g < 0", Thresh);
2235    if (Thresh >= 1)
2236        error("thresh %g >= 1", Thresh);
2237    if (nMinSpanGlobal > nCases/2)
2238        error("minspan %d > nrow(x)/2 %d", nMinSpanGlobal, nCases/2);
2239    if (FastBeta < 0)
2240        error("fast.beta %g < 0", FastBeta);
2241    if (FastBeta > 1000)
2242        error("fast.beta %g > 1000", FastBeta);
2243    if (TraceGlobal < 0)
2244        warning("trace %g < 0", TraceGlobal);
2245    if (TraceGlobal > 10)
2246        warning("trace %g > 10", TraceGlobal);
2247    if (NewVarPenalty < 0)
2248        warning("newvar.penalty %g < 0", NewVarPenalty);
2249    if (NewVarPenalty > 10)
2250        warning("newvar.penalty %g > 10", NewVarPenalty);
2251    if (UseBetaCache != 0 && UseBetaCache != 1)
2252        warning("Use.Beta.Cache is neither TRUE nor FALSE");
2253
2254    CheckVec(x, nCases, nPreds, "x");
2255    CheckVec(y, nCases, nResp,  "y");
2256
2257    bxOrth          = (double *)malloc1(nCases * nMaxTerms * sizeof(double),
2258                        "bxOrth\t\tnCases %d nMaxTerms %d  sizeof(double) %d",
2259                        nCases, nMaxTerms, sizeof(double));
2260
2261    bxOrthCenteredT = (double *)malloc1(nMaxTerms * nCases * sizeof(double),
2262                        "bxOrthCenteredT\tnMaxTerms %d nCases %d  sizeof(double) %d",
2263                        nMaxTerms, nCases, sizeof(double));
2264
2265    bxOrthMean      = (double *)malloc1(nMaxTerms * nResp * sizeof(double),
2266                        "bxOrthMean\t\tnMaxTerms %d nResp %d  sizeof(double) %d",
2267                        nMaxTerms, nResp, sizeof(double));
2268
2269    yMean           = (double *)malloc1(nResp * sizeof(double),
2270                        "yMean\t\t\tnResp %d sizeof(double) %d",
2271                        nResp, sizeof(double));
2272
2273    memset(FullSet,        0, nMaxTerms * sizeof(bool));
2274    memset(Dirs,           0, nMaxTerms * nPreds * sizeof(int));
2275    memset(Cuts,           0, nMaxTerms * nPreds * sizeof(double));
2276    memset(nFactorsInTerm, 0, nMaxTerms * sizeof(int));
2277    memset(nUses,          0, nPreds * sizeof(int));
2278    memset(bx,             0, nCases * nMaxTerms * sizeof(double));
2279    Weights = pInitWeights(WeightsArg, nCases);
2280    xOrder = OrderArray(x, nCases, nPreds);
2281    InitBetaCache(UseBetaCache, nMaxTerms, nPreds);
2282    FullSet[0] = true;  // intercept
2283    bool GoodCol;
2284    InitBxOrthCol(bxOrth, bxOrthCenteredT, bxOrthMean, &GoodCol,   // intercept col 0
2285        &bx_(0,0), 0 /*nTerms*/, FullSet, nCases, nMaxTerms, -1, -1, Weights);
2286    ASSERT(GoodCol);
2287
2288    for (int iCase = 0; iCase < nCases; iCase++)
2289        bx_(iCase,0) = 1;
2290    double RssNull = 0;
2291    for (int iResp = 0; iResp < nResp; iResp++) {
2292        yMean[iResp] = Mean(&y_(0,iResp), nCases);
2293        RssNull += SumOfSquares(&y_(0,iResp), yMean[iResp], nCases);
2294        CheckRssNull(RssNull, y, iResp, nCases);
2295    }
2296    double Rss = RssNull, RssDelta = RssNull, RSq = 0, RSqDelta = 0;
2297    int nUsedTerms = 1;     // number of used basis terms including intercept, for GCV calc
2298    double Gcv = 0, GcvNull = GetGcv(nUsedTerms, nCases, RssNull, Penalty);
2299    PrintForwardProlog(nCases, nPreds, sPredNames);
2300#if FAST_MARS
2301    InitQ(nMaxTerms);
2302    AddTermToQ(0, 1, RssNull, true, nMaxTerms, FastBeta); // intercept term into Q
2303#endif
2304    int nTerms = -1, iBestCase = -1;
2305    for (nTerms = 1;                                    // start after intercept
2306            nTerms < nMaxTerms-1 && RSq < 1-Thresh;     // -1 allows for upper term in pair
2307            nTerms += 2) {                              // add terms in pairs
2308        int iBestPred = -1, iBestParent = -1;
2309        bool IsNewForm, IsLinPred;
2310#if USING_R
2311        ServiceR();
2312#endif
2313        if (Rss <= 0)
2314            error("assertion failed: Rss <= 0 (y is all const?)");
2315        ASSERT(RssDelta > 0);
2316        const double MaxAllowedRssDelta = min(1.01 * Rss, 2 * RssDelta);
2317
2318        FindTerm(&iBestCase, &iBestPred, &iBestParent,
2319            &RssDelta, &IsNewForm, &IsLinPred, NULL /* MaxRssPerPred */,
2320            bxOrth, bxOrthCenteredT, bxOrthMean, -1, x, y, Weights,
2321            nCases, nResp, nPreds, nTerms, nMaxDegree, nMaxTerms,
2322            yMean, MaxAllowedRssDelta,
2323            bx, FullSet, xOrder, nFactorsInTerm, nUses, Dirs,
2324            nFastK, NewVarPenalty, LinPreds);
2325
2326        nUsedTerms++;
2327        if (!IsLinPred && IsNewForm)    // add paired term too?
2328            nUsedTerms++;
2329        Rss -= RssDelta;
2330        if (Rss < ALMOST_ZERO)          // RSS can go slightly neg due to rounding
2331            Rss = 0;                    // or can have very small values
2332        Gcv = GetGcv(nUsedTerms, nCases, Rss, Penalty);
2333        const double OldRSq = RSq;
2334        RSq = 1-Rss/RssNull;
2335        RSqDelta = RSq - OldRSq;
2336        if (RSqDelta < ALMOST_ZERO) // for consistent results with different
2337            RSqDelta = 0;           // float hardware else print nbrs like -2e-18
2338
2339        PrintForwardStep(nTerms, nUsedTerms, iBestCase, iBestPred, RSq, RSqDelta,
2340            Gcv, GcvNull, nCases, xOrder, x, IsLinPred, IsNewForm, sPredNames);
2341
2342        if (iBestCase < 0 ||
2343                (Thresh != 0 && ((1-Gcv/GcvNull) < MIN_GRSQ || RSqDelta < Thresh))) {
2344            if (TraceGlobal >= 2)
2345                printf("reject term\n");
2346            break;                      // NOTE break
2347        }
2348        AddTermPair(Dirs, Cuts, bx, bxOrth, bxOrthCenteredT, bxOrthMean,
2349            FullSet, nFactorsInTerm, nUses,
2350            nTerms, iBestParent, iBestCase, iBestPred, nPreds, nCases,
2351            nMaxTerms, IsNewForm, IsLinPred, LinPreds, x, xOrder, Weights);
2352
2353#if FAST_MARS
2354        if (!IsLinPred && IsNewForm) {  // good upper term?
2355            AddTermToQ(nTerms,   nTerms, POS_INF, false, nMaxTerms, FastBeta);
2356            AddTermToQ(nTerms+1, nTerms, POS_INF, true,  nMaxTerms, FastBeta);
2357        } else
2358            AddTermToQ(nTerms,   nTerms, POS_INF, true,  nMaxTerms, FastBeta);
2359        if (TraceGlobal == 6)
2360            PrintSortedQ(nFastK);
2361#endif
2362        if (TraceGlobal >= 2)
2363            printf("\n");
2364    }
2365    PrintForwardEpilog(nTerms, nMaxTerms, Thresh, RSq, RSqDelta,
2366                       Gcv, GcvNull, iBestCase, FullSet);
2367    *pnTerms = nTerms;
2368    FreeBetaCache();
2369#if FAST_MARS
2370    FreeQ();
2371#endif
2372    free1(xOrder);
2373    free1(Weights);
2374    Weights = NULL;
2375    free1(yMean);
2376    free1(bxOrthMean);
2377    free1(bxOrthCenteredT);
2378    free1(bxOrth);
2379}
2380
2381//-----------------------------------------------------------------------------
2382// This is an interface from R to the C routine ForwardPass
2383
2384#if USING_R
2385void ForwardPassR(              // for use by R
2386    int    FullSet[],           // out: nMaxTerms x 1, bool vec of lin indep cols of bx
2387    double bx[],                // out: MARS basis matrix, nCases x nMaxTerms
2388    double Dirs[],              // out: nMaxTerms x nPreds, elements are -1,0,1,2
2389    double Cuts[],              // out: nMaxTerms x nPreds, cut for iTerm,iPred
2390    const double x[],           // in: nCases x nPreds
2391    const double y[],           // in: nCases x nResp
2392    const double WeightsArg[],  // in: nCases x 1, can be R_NilValue, currently ignored
2393    const int *pnCases,         // in: number of rows in x and elements in y
2394    const int *pnResp,          // in: number of cols in y
2395    const int *pnPreds,         // in: number of cols in x
2396    const int *pnMaxDegree,     // in:
2397    const int *pnMaxTerms,      // in:
2398    const double *pPenalty,     // in:
2399    double *pThresh,            // in: forward step threshold
2400    const int *pnMinSpan,       // in:
2401    const int *pnFastK,         // in: Fast MARS K
2402    const double *pFastBeta,    // in: Fast MARS ageing coef
2403    const double *pNewVarPenalty, // in: penalty for adding a new variable (default is 0)
2404    const int  LinPreds[],        // in: nPreds x 1, 1 if predictor must enter linearly
2405    const SEXP Allowed,           // in: constraints function
2406    const int *pnAllowedFuncArgs, // in: number of arguments to Allowed function, 3 or 4
2407    const SEXP Env,               // in: environment for Allowed function
2408    const int *pnUseBetaCache,    // in: 1 to use the beta cache, for speed
2409    const double *pTrace,         // in: 0 none 1 overview 2 forward 3 pruning 4 more pruning
2410    const char *sPredNames[])     // in: predictor names in trace printfs, can be R_NilValue
2411{
2412    TraceGlobal = *pTrace;
2413    nMinSpanGlobal = *pnMinSpan;
2414
2415    const int nCases = *pnCases;
2416    const int nResp = *pnResp;
2417    const int nPreds = *pnPreds;
2418    const int nMaxTerms = *pnMaxTerms;
2419
2420    // nUses is the number of time each predictor is used in the model
2421    nUses = (int *)malloc1(*pnPreds * sizeof(int),
2422                    "nUses\t\t\t*pnPreds %d sizeof(int)",
2423                    *pnPreds, sizeof(int));
2424
2425    // nFactorsInTerm is number of hockey stick functions in basis term
2426    nFactorsInTerm = (int *)malloc1(nMaxTerms * sizeof(int),
2427                        "nFactorsInTerm\tnMaxTerms %d sizeof(int) %d",
2428                        nMaxTerms, sizeof(int));
2429
2430    iDirs = (int *)calloc1(nMaxTerms * nPreds, sizeof(int),
2431                        "iDirs\t\t\tnMaxTerms %d nPreds %d sizeof(int) %d",
2432                        nMaxTerms, nPreds, sizeof(int));
2433
2434    // convert int to bool (may be redundant, depending on compiler)
2435    BoolFullSet = (int *)malloc1(nMaxTerms * sizeof(bool),
2436                        "BoolFullSet\t\tnMaxTerms %d sizeof(bool) %d",
2437                        nMaxTerms, sizeof(bool));
2438
2439    int iTerm;
2440    for (iTerm = 0; iTerm < nMaxTerms; iTerm++)
2441        BoolFullSet[iTerm] = FullSet[iTerm];
2442
2443    // convert R NULL to C NULL
2444    if ((void *)sPredNames == (void *)R_NilValue)
2445        sPredNames = NULL;
2446    if ((void *)WeightsArg == (void *)R_NilValue)
2447        WeightsArg = NULL;
2448
2449    InitAllowedFunc(Allowed, *pnAllowedFuncArgs, Env, sPredNames, nPreds);
2450
2451    int nTerms;
2452    ForwardPass(&nTerms, BoolFullSet, bx, iDirs, Cuts, nFactorsInTerm, nUses,
2453            x, y, WeightsArg, nCases, nResp, nPreds, *pnMaxDegree, nMaxTerms,
2454            *pPenalty, *pThresh, *pnFastK, *pFastBeta, *pNewVarPenalty,
2455            LinPreds, (bool)(*pnUseBetaCache), sPredNames);
2456
2457    FreeAllowedFunc();
2458
2459    // remove linearly independent columns if necessary -- this updates BoolFullSet
2460
2461    RegressAndFix(NULL, NULL, NULL, BoolFullSet,
2462        bx, y, WeightsArg, nCases, nResp, nMaxTerms);
2463
2464    for (iTerm = 0; iTerm < nMaxTerms; iTerm++)     // convert int to double
2465        for (int iPred = 0; iPred < nPreds; iPred++)
2466            Dirs[iTerm + iPred * nMaxTerms] =
2467                iDirs[iTerm + iPred * nMaxTerms];
2468
2469    for (iTerm = 0; iTerm < nMaxTerms; iTerm++)     // convert bool to int
2470        FullSet[iTerm] = BoolFullSet[iTerm];
2471
2472    free1(BoolFullSet);
2473    free1(iDirs);
2474    free1(nFactorsInTerm);
2475    free1(nUses);
2476}
2477#endif // USING_R
2478
2479//-----------------------------------------------------------------------------
2480// Step backwards through the terms, at each step deleting the term that
2481// causes the least RSS increase.  The subset of terms and RSS of each subset are
2482// saved in PruneTerms and RssVec (which are indexed on subset size).
2483//
2484// The crux of the method used here is that the change in RSS (for nResp=1)
2485// caused by removing predictor iPred is DeltaRss = sq(Betas[iPred]) / Diags[iPred]
2486// where Diags is the diagonal elements of the inverse of X'X.
2487// See for example Miller (see refs in file header) section 3.4 p44.
2488//
2489// For multiple responses we sum the above DeltaRss over all responses.
2490//
2491// This method is fast and simple but accuracy can be poor if inv(X'X) is
2492// ill conditioned.  The Miller code in the R package "leaps" uses a more
2493// stable method, but does not support multiple responses.
2494//
2495// The "Xtx" in the name refers to the X'X matrix.
2496
2497static int EvalSubsetsUsingXtx(
2498    bool   PruneTerms[],    // out: nMaxTerms x nMaxTerms
2499    double RssVec[],        // out: nMaxTerms x 1, RSS of each subset
2500    const int    nCases,    // in
2501    const int    nResp,     // in: number of cols in y
2502    const int    nMaxTerms, // in: number of MARS terms in full model
2503    const double bx[],      // in: nCases x nMaxTerms, all cols must be indep
2504    const double y[],       // in: nCases * nResp
2505    const double WeightsArg[]) // in: nCases x 1, can be NULL
2506{
2507    double *Betas = (double *)malloc1(nMaxTerms * nResp * sizeof(double),
2508                        "Betas\t\t\tnMaxTerms %d nResp %d sizeof(double) %d",
2509                        nMaxTerms, nResp, sizeof(double));
2510
2511    double *Diags = (double *)malloc1(nMaxTerms * sizeof(double),
2512                        "Diags\t\t\tnMaxTerms %d sizeof(double) %d",
2513                        nMaxTerms, sizeof(double));
2514
2515    if (WeightsArg)
2516        Weights = pInitWeights(WeightsArg, nCases);
2517
2518    WorkingSet = (bool *)malloc1(nMaxTerms * sizeof(bool),
2519                        "WorkingSet\t\tnMaxTerms %d sizeof(bool) %d",
2520                        nMaxTerms, sizeof(bool));
2521
2522    for (int i = 0; i < nMaxTerms; i++)
2523        WorkingSet[i] = true;
2524
2525    int error_code = 0;
2526    for (int nUsedCols = nMaxTerms; nUsedCols > 0; nUsedCols--) {
2527        int nRank;
2528        double Rss;
2529        Regress(Betas, NULL, &Rss, Diags, &nRank, NULL,
2530            bx, y, Weights, nCases, nResp, nMaxTerms, WorkingSet);
2531
2532        if(nRank != nUsedCols)
2533        {
2534            error_code = 1;
2535            break;
2536        }
2537//            error("nRank %d != nUsedCols %d "
2538//                "(probably because of lin dep terms in bx)\n",
2539//                nRank, nUsedCols);
2540
2541        RssVec[nUsedCols-1] = Rss;
2542        memcpy(PruneTerms + (nUsedCols-1) * nMaxTerms, WorkingSet,
2543            nMaxTerms * sizeof(bool));
2544
2545        if (nUsedCols == 1)
2546            break;
2547
2548        // set iDelete to the best term for deletion
2549
2550        int iDelete = -1;   // term to be deleted
2551        int iTerm1 = 0;     // index taking into account false vals in WorkingSet
2552        double MinDeltaRss = POS_INF;
2553        for (int iTerm = 0; iTerm < nMaxTerms; iTerm++) {
2554            if (WorkingSet[iTerm]) {
2555                double DeltaRss = 0;
2556                for (int iResp = 0; iResp < nResp; iResp++)
2557                    DeltaRss += sq(Betas_(iTerm1, iResp)) / Diags[iTerm1];
2558                if (iTerm > 0 && DeltaRss < MinDeltaRss) {   // new minimum?
2559                    MinDeltaRss = DeltaRss;
2560                    iDelete = iTerm;
2561                }
2562                iTerm1++;
2563            }
2564        }
2565        //ASSERT(iDelete > 0);
2566        if (iDelete < 1)
2567        {
2568            // Trying to delete the intercept.
2569            error_code = 2;
2570            break;
2571        }
2572        WorkingSet[iDelete] = false;
2573    }
2574    if (WeightsArg)
2575        free1(Weights);
2576    free1(WorkingSet);
2577    free1(Diags);
2578    free1(Betas);
2579
2580    return error_code;
2581}
2582
2583//-----------------------------------------------------------------------------
2584// This is invoked from R if y has multiple columns i.e. a multiple response model.
2585// It is needed because the alternative (the leaps package) supports
2586// only one response.
2587
2588#if USING_R
2589void EvalSubsetsUsingXtxR(      // for use by R
2590    double       PruneTerms[],  // out: specifies which cols in bx are in best set
2591    double       RssVec[],      // out: nTerms x 1
2592    const int    *pnCases,      // in
2593    const int    *pnResp,       // in: number of cols in y
2594    const int    *pnMaxTerms,   // in
2595    const double bx[],          // in: MARS basis matrix, all cols must be indep
2596    const double y[],           // in: nCases * nResp
2597    const double WeightsArg[])  // in: nCases x 1, can be R_NilValue
2598{
2599    const int nMaxTerms = *pnMaxTerms;
2600    bool *BoolPruneTerms = (int *)malloc1(nMaxTerms * nMaxTerms * sizeof(bool),
2601                                "BoolPruneTerms\tMaxTerms %d nMaxTerms %d sizeof(bool) %d",
2602                                nMaxTerms, nMaxTerms, sizeof(bool));
2603
2604    if ((void *)WeightsArg == (void *)R_NilValue)
2605        WeightsArg = NULL;
2606
2607    EvalSubsetsUsingXtx(BoolPruneTerms, RssVec, *pnCases, *pnResp,
2608                        nMaxTerms, bx, y, WeightsArg);
2609
2610    // convert BoolPruneTerms to upper triangular matrix PruneTerms
2611
2612    for (int iModel = 0; iModel < nMaxTerms; iModel++) {
2613        int iPrune = 0;
2614        for (int iTerm = 0; iTerm < nMaxTerms; iTerm++)
2615            if (BoolPruneTerms[iTerm + iModel * nMaxTerms])
2616                PruneTerms[iModel + iPrune++ * nMaxTerms] = iTerm + 1;
2617    }
2618    free1(BoolPruneTerms);
2619}
2620#endif
2621
2622//-----------------------------------------------------------------------------
2623#if STANDALONE
2624static void BackwardPass(
2625    double *pBestGcv,       // out: GCV of the best model i.e. BestSet columns of bx
2626    bool   BestSet[],       // out: nMaxTerms x 1, indices of best set of cols of bx
2627    double Residuals[],     // out: nCases x nResp
2628    double Betas[],         // out: nMaxTerms x nResp
2629    const double bx[],      // in: nCases x nMaxTerms
2630    const double y[],       // in: nCases x nResp
2631    const double WeightsArg[], // in: nCases x 1, can be NILL
2632    const int nCases,       // in: number of rows in bx and elements in y
2633    const int nResp,        // in: number of cols in y
2634    const int nMaxTerms,    // in: number of cols in bx
2635    const double Penalty)   // in: GCV penalty per knot
2636{
2637    double *RssVec = (double *)malloc1(nMaxTerms * sizeof(double),
2638                        "RssVec\t\tnMaxTerms %d sizeof(double) %d",
2639                        nMaxTerms, sizeof(double));
2640
2641    bool *PruneTerms = (bool *)malloc1(nMaxTerms * nMaxTerms * sizeof(bool),
2642                        "PruneTerms\t\tnMaxTerms %d nMaxTerms %d sizeof(bool) %d",
2643                        nMaxTerms, nMaxTerms, sizeof(bool));
2644
2645    EvalSubsetsUsingXtx(PruneTerms, RssVec, nCases, nResp,
2646                        nMaxTerms, bx, y, WeightsArg);
2647
2648    // now we have the RSS for each model, so find the iModel which has the best GCV
2649
2650    if (TraceGlobal >= 3)
2651        printf("Backward pass:\nSubsetSize         GRSq          RSq\n");
2652    int iBestModel = -1;
2653    double GcvNull = GetGcv(1, nCases, RssVec[0], Penalty);
2654    double BestGcv = POS_INF;
2655    for (int iModel = 0; iModel < nMaxTerms; iModel++) {
2656        const double Gcv = GetGcv(iModel+1, nCases, RssVec[iModel], Penalty);
2657        if(Gcv < BestGcv) {
2658            iBestModel = iModel;
2659            BestGcv = Gcv;
2660        }
2661        if (TraceGlobal >= 3)
2662            printf("%10d %12.4f %12.4f\n", iModel+IOFFSET,
2663                1 - BestGcv/GcvNull, 1 - RssVec[iModel]/RssVec[0]);
2664    }
2665    if (TraceGlobal >= 3)
2666        printf("\nBackward pass complete: selected %d terms of %d, GRSq %g RSq %g\n\n",
2667            iBestModel+IOFFSET, nMaxTerms,
2668            1 - BestGcv/GcvNull, 1 - RssVec[iBestModel]/RssVec[0]);
2669
2670    // set BestSet to the model which has the best GCV
2671
2672    ASSERT(iBestModel >= 0);
2673    memcpy(BestSet, PruneTerms + iBestModel * nMaxTerms, nMaxTerms * sizeof(bool));
2674    free1(PruneTerms);
2675    free1(RssVec);
2676    *pBestGcv = BestGcv;
2677
2678    // get final model Betas, Residuals, Rss
2679
2680    RegressAndFix(Betas, Residuals, NULL, BestSet,
2681        bx, y, WeightsArg, nCases, nResp, nMaxTerms);
2682
2683}
2684#endif // STANDALONE
2685
2686//-----------------------------------------------------------------------------
2687#if STANDALONE
2688static int DiscardUnusedTerms(
2689    double bx[],             // io: nCases x nMaxTerms
2690    int    Dirs[],           // io: nMaxTerms x nPreds
2691    double Cuts[],           // io: nMaxTerms x nPreds
2692    bool   WhichSet[],       // io: tells us which terms to discard
2693    int    nFactorsInTerm[], // io
2694    const int nMaxTerms,
2695    const int nPreds,
2696    const int nCases)
2697{
2698    int nUsed = 0, iTerm;
2699    for (iTerm = 0; iTerm < nMaxTerms; iTerm++)
2700        if (WhichSet[iTerm]) {
2701            memcpy(bx + nUsed * nCases, bx + iTerm * nCases, nCases * sizeof(double));
2702            for (int iPred = 0; iPred < nPreds; iPred++) {
2703                Dirs_(nUsed, iPred) = Dirs_(iTerm, iPred);
2704                Cuts_(nUsed, iPred) = Cuts_(iTerm, iPred);
2705            }
2706            nFactorsInTerm[nUsed] = nFactorsInTerm[iTerm];
2707            nUsed++;
2708        }
2709    memset(WhichSet, 0, nMaxTerms * sizeof(bool));
2710    for (iTerm = 0; iTerm < nUsed; iTerm++)
2711        WhichSet[iTerm] = true;
2712    return nUsed;
2713}
2714#endif // STANDALONE
2715
2716//-----------------------------------------------------------------------------
2717#if STANDALONE
2718void Earth(
2719    double *pBestGcv,       // out: GCV of the best model i.e. BestSet columns of bx
2720    int    *pnTerms,        // out: max term nbr in final model, after removing lin dep terms
2721    bool   BestSet[],       // out: nMaxTerms x 1, indices of best set of cols of bx
2722    double bx[],            // out: nCases x nMaxTerms
2723    int    Dirs[],          // out: nMaxTerms x nPreds, -1,0,1,2 for iTerm, iPred
2724    double Cuts[],          // out: nMaxTerms x nPreds, cut for iTerm, iPred
2725    double Residuals[],     // out: nCases x nResp
2726    double Betas[],         // out: nMaxTerms x nResp
2727    const double x[],       // in: nCases x nPreds
2728    const double y[],       // in: nCases x nResp
2729    const double WeightsArg[], // in: nCases x 1, can be NULL, currently ignored
2730    const int nCases,       // in: number of rows in x and elements in y
2731    const int nResp,        // in: number of cols in y
2732    const int nPreds,       // in: number of cols in x
2733    const int nMaxDegree,   // in: Friedman's mi
2734    const int nMaxTerms,    // in: includes the intercept term
2735    const double Penalty,   // in: GCV penalty per knot
2736    double Thresh,          // in: forward step threshold
2737    const int nMinSpan,     // in: set to non zero to override internal calculation
2738    const bool Prune,       // in: do backward pass
2739    const int nFastK,       // in: Fast MARS K
2740    const double FastBeta,  // in: Fast MARS ageing coef
2741    const double NewVarPenalty, // in: penalty for adding a new variable
2742    const int LinPreds[],       // in: nPreds x 1, 1 if predictor must enter linearly
2743    const bool UseBetaCache,    // in: 1 to use the beta cache, for speed
2744    const double Trace,         // in: 0 none 1 overview 2 forward 3 pruning 4 more pruning
2745    const char *sPredNames[])   // in: predictor names in trace printfs, can be NULL
2746{
2747#if _MSC_VER && _DEBUG
2748    InitMallocTracking();
2749#endif
2750    TraceGlobal = Trace;
2751    nMinSpanGlobal = nMinSpan;
2752
2753    // nUses is the number of time each predictor is used in the model
2754    nUses = (int *)malloc1(nPreds * sizeof(int),
2755                        "nUses\t\t\tnPreds %d sizeof(int) %d",
2756                        nPreds, sizeof(int));
2757
2758    // nFactorsInTerm is number of hockey stick functions in basis term
2759    nFactorsInTerm = (int *)malloc1(nMaxTerms * sizeof(int),
2760                            "nFactorsInTerm\tnMaxTerms %d sizeof(int) %d",
2761                            nMaxTerms, sizeof(int));
2762
2763    int nTerms;
2764    ForwardPass(&nTerms, BestSet, bx, Dirs, Cuts, nFactorsInTerm, nUses,
2765        x, y, WeightsArg, nCases, nResp, nPreds, nMaxDegree, nMaxTerms,
2766        Penalty, Thresh, nFastK, FastBeta, NewVarPenalty,
2767        LinPreds, UseBetaCache, sPredNames);
2768
2769    // ensure bx is full rank by updating BestSet, and get Residuals and Betas
2770
2771    RegressAndFix(Betas, Residuals, NULL, BestSet,
2772        bx, y, WeightsArg, nCases, nResp, nMaxTerms);
2773
2774    if (TraceGlobal >= 6)
2775        PrintSummary(nMaxTerms, nTerms, nPreds, nResp,
2776            BestSet, Dirs, Cuts, Betas, nFactorsInTerm);
2777
2778    int nMaxTerms1 = DiscardUnusedTerms(bx, Dirs, Cuts, BestSet, nFactorsInTerm,
2779                        nMaxTerms, nPreds, nCases);
2780    if (Prune)
2781        BackwardPass(pBestGcv, BestSet, Residuals, Betas,
2782            bx, y, WeightsArg, nCases, nResp, nMaxTerms1, Penalty);
2783
2784    if (TraceGlobal >= 6)
2785        PrintSummary(nMaxTerms, nMaxTerms1, nPreds, nResp,
2786            BestSet, Dirs, Cuts, Betas, nFactorsInTerm);
2787
2788    *pnTerms = nMaxTerms1;
2789    free1(nFactorsInTerm);
2790    free1(nUses);
2791}
2792#endif // STANDALONE
2793
2794//-----------------------------------------------------------------------------
2795// Return the max number of knots in any term.
2796// Lin dep factors are considered as having one knot (at the min value of the predictor)
2797
2798#if STANDALONE
2799static int GetMaxKnotsPerTerm(
2800    const bool   UsedCols[],    // in
2801    const int    Dirs[],        // in
2802    const int    nPreds,        // in
2803    const int    nTerms,        // in
2804    const int    nMaxTerms)     // in
2805{
2806    int nKnotsMax = 0;
2807    for (int iTerm = 1; iTerm < nTerms; iTerm++)
2808        if (UsedCols[iTerm]) {
2809            int nKnots = 0; // number of knots in this term
2810            for (int iPred = 0; iPred < nPreds; iPred++)
2811                if (Dirs_(iTerm, iPred) != 0)
2812                    nKnots++;
2813            if (nKnots > nKnotsMax)
2814                nKnotsMax = nKnots;
2815        }
2816    return nKnotsMax;
2817}
2818#endif // STANDALONE
2819
2820//-----------------------------------------------------------------------------
2821// print a string representing the earth expresssion, one term per line
2822// TODO spacing is not quite right and is overly complicated
2823
2824#if STANDALONE
2825static void FormatOneResponse(
2826    const bool   UsedCols[],// in: nMaxTerms x 1, indices of best set of cols of bx
2827    const int    Dirs[],    // in: nMaxTerms x nPreds, -1,0,1,2 for iTerm, iPred
2828    const double Cuts[],    // in: nMaxTerms x nPreds, cut for iTerm, iPred
2829    const double Betas[],   // in: nMaxTerms x nResp
2830    const int    nPreds,
2831    const int    iResp,
2832    const int    nTerms,
2833    const int    nMaxTerms,
2834    const int    nDigits,   // number of significant digits to print
2835    const double MinBeta)   // terms with fabs(betas) less than this are not printed, 0 for all
2836{
2837    int iBestTerm = 0;
2838    int nKnotsMax = GetMaxKnotsPerTerm(UsedCols, Dirs, nPreds, nTerms, nMaxTerms);
2839    int nKnots = 0;
2840    char s[1000];
2841    ASSERT(nDigits >= 0);
2842    char sFormat[50];  sprintf(sFormat,  "%%-%d.%dg", nDigits+6, nDigits);
2843    char sFormat1[50]; sprintf(sFormat1, "%%%d.%dg",  nDigits+6, nDigits);
2844    int nPredWidth;
2845    if (nPreds > 100)
2846        nPredWidth = 3;
2847    else if (nPreds > 10)
2848        nPredWidth = 2;
2849    else
2850        nPredWidth = 1;
2851    char sPredFormat[20]; sprintf(sPredFormat, "%%%dd", nPredWidth);
2852    char sPad[500]; sprintf(sPad, "%*s", 28+nDigits+nPredWidth, " ");    // comment pad
2853    const int nUsedCols = nTerms;       // nUsedCols is needed for the Betas_ macro
2854    printf(sFormat, Betas_(0, iResp));  // intercept
2855    while (nKnots++ < nKnotsMax)
2856        printf(sPad);
2857    printf(" // 0\n");
2858
2859    for (int iTerm = 1; iTerm < nTerms; iTerm++)
2860        if (UsedCols[iTerm]) {
2861            iBestTerm++;
2862            if (fabs(Betas_(iBestTerm, iResp)) >= MinBeta) {
2863                printf("%+-9.3g", Betas_(iBestTerm, iResp));
2864                nKnots = 0;
2865                for (int iPred = 0; iPred < nPreds; iPred++) {
2866                    switch(Dirs_(iTerm, iPred)) {
2867                        case  0:
2868                            break;
2869                        case -1:
2870                            sprintf(s, " * max(0, %s - %*sx[%s])",
2871                                sFormat, nDigits+2, " ", sPredFormat);
2872                            printf(s, Cuts_(iTerm, iPred), iPred);
2873                            nKnots++;
2874                            break;
2875                        case  1:
2876                            sprintf(s, " * max(0, x[%s]%*s-  %s)",
2877                                sPredFormat,  nDigits+2, " ", sFormat1);
2878                            printf(s, iPred, Cuts_(iTerm, iPred));
2879                            nKnots++;
2880                            break;
2881                        case  2:
2882                            sprintf(s, " * x[%s]%*s                    ",
2883                                sPredFormat,  nDigits+2, " ");
2884                            printf(s, iPred);
2885                            nKnots++;
2886                            break;
2887                        default:
2888                            ASSERT(false);
2889                            break;
2890                    }
2891                }
2892                while (nKnots++ < nKnotsMax)
2893                    printf(sPad);
2894                printf(" // %d\n", iBestTerm);
2895            }
2896        }
2897}
2898
2899void FormatEarth(
2900    const bool   UsedCols[],// in: nMaxTerms x 1, indices of best set of cols of bx
2901    const int    Dirs[],    // in: nMaxTerms x nPreds, -1,0,1,2 for iTerm, iPred
2902    const double Cuts[],    // in: nMaxTerms x nPreds, cut for iTerm, iPred
2903    const double Betas[],   // in: nMaxTerms x nResp
2904    const int    nPreds,
2905    const int    nResp,     // in: number of cols in y
2906    const int    nTerms,
2907    const int    nMaxTerms,
2908    const int    nDigits,   // number of significant digits to print
2909    const double MinBeta)   // terms with fabs(betas) less than this are not printed, 0 for all
2910{
2911    for (int iResp = 0; iResp < nResp; iResp++) {
2912        if (nResp > 1)
2913            printf("Response %d:\n", iResp+IOFFSET);
2914        FormatOneResponse(UsedCols, Dirs, Cuts, Betas, nPreds, iResp,
2915            nTerms, nMaxTerms, nDigits, MinBeta);
2916    }
2917}
2918#endif // STANDALONE
2919
2920//-----------------------------------------------------------------------------
2921// return the value predicted by an earth model, given  a vector of inputs x
2922
2923#if STANDALONE
2924static double PredictOneResponse(
2925    const double x[],        // in: vector nPreds x 1 of input values
2926    const bool   UsedCols[], // in: nMaxTerms x 1, indices of best set of cols of bx
2927    const int    Dirs[],     // in: nMaxTerms x nPreds, -1,0,1,2 for iTerm, iPred
2928    const double Cuts[],     // in: nMaxTerms x nPreds, cut for iTerm, iPred
2929    const double Betas[],    // in: nMaxTerms x 1
2930    const int    nPreds,     // in: number of cols in x
2931    const int    nTerms,
2932    const int    nMaxTerms)
2933{
2934    double yHat = Betas[0];
2935    int iTerm1 = 0;
2936    for (int iTerm = 1; iTerm < nTerms; iTerm++)
2937        if (UsedCols[iTerm]) {
2938            iTerm1++;
2939            double Term = Betas[iTerm1];
2940            for (int iPred = 0; iPred < nPreds; iPred++)
2941                switch(Dirs_(iTerm, iPred)) {
2942                    case  0: break;
2943                    case -1: Term *= max(0, Cuts_(iTerm, iPred) - x[iPred]); break;
2944                    case  1: Term *= max(0, x[iPred] - Cuts_(iTerm, iPred)); break;
2945                    case  2: Term *= x[iPred]; break;
2946                    default: ASSERT("bad direction" == NULL); break;
2947                }
2948            yHat += Term;
2949        }
2950    return yHat;
2951}
2952
2953void PredictEarth(
2954    double       y[],        // out: vector nResp
2955    const double x[],        // in: vector nPreds x 1 of input values
2956    const bool   UsedCols[], // in: nMaxTerms x 1, indices of best set of cols of bx
2957    const int    Dirs[],     // in: nMaxTerms x nPreds, -1,0,1,2 for iTerm, iPred
2958    const double Cuts[],     // in: nMaxTerms x nPreds, cut for iTerm, iPred
2959    const double Betas[],    // in: nMaxTerms x nResp
2960    const int    nPreds,     // in: number of cols in x
2961    const int    nResp,      // in: number of cols in y
2962    const int    nTerms,
2963    const int    nMaxTerms)
2964{
2965    for (int iResp = 0; iResp < nResp; iResp++)
2966        y[iResp] = PredictOneResponse(x, UsedCols, Dirs, Cuts,
2967                       Betas + iResp * nTerms, nPreds, nTerms, nMaxTerms);
2968}
2969#endif // STANDALONE
2970
2971//-----------------------------------------------------------------------------
2972// Example main routine
2973// See earth/src/tests/test.earthc.c for another example
2974
2975#if STANDALONE
2976extern "C"{
2977
2978    void error(const char *args, ...)       // params like printf
2979    {
2980        char s[1000];
2981        va_list p;
2982        va_start(p, args);
2983        vsprintf(s, args, p);
2984        va_end(p);
2985        printf("\nError: %s\n", s);
2986        exit(-1);
2987    }
2988
2989    void xerbla_(char *srname, int *info)   // needed by BLAS and LAPACK routines
2990    {
2991        char buf[7];
2992        strncpy(buf, srname, 6);
2993        buf[6] = 0;
2994        error("BLAS/LAPACK routine %6s gave error code %d", buf, -(*info));
2995    }
2996
2997}
2998#endif
2999
3000
3001/*
3002 * Extern interface for ctypes
3003 */
3004
3005extern "C" void EarthForwardPass(
3006    int    *pnTerms,            // out: highest used term number in full model
3007    bool   FullSet[],           // out: 1 * nMaxTerms, indices of lin indep cols of bx
3008    double bx[],                // out: MARS basis matrix, nCases * nMaxTerms
3009    int    Dirs[],              // out: nMaxTerms * nPreds, -1,0,1,2 for iTerm, iPred
3010    double Cuts[],              // out: nMaxTerms * nPreds, cut for iTerm, iPred
3011    int    nFactorsInTerm[],    // out: number of hockey stick funcs in each MARS term
3012    int    nUses[],             // out: nbr of times each predictor is used in the model
3013    const double x[],           // in: nCases x nPreds
3014    const double y[],           // in: nCases x nResp
3015    const double WeightsArg[],  // in: nCases x 1, can be NULL, currently ignored
3016    const int nCases,           // in: number of rows in x and elements in y
3017    const int nResp,            // in: number of cols in y
3018    const int nPreds,           // in:
3019    const int nMaxDegree,       // in:
3020    const int nMaxTerms,        // in:
3021    const double Penalty,       // in: GCV penalty per knot
3022    double Thresh,              // in: forward step threshold
3023    int nFastK,                 // in: Fast MARS K
3024    const double FastBeta,      // in: Fast MARS ageing coef
3025    const double NewVarPenalty, // in: penalty for adding a new variable (default is 0)
3026    const int  LinPreds[],      // in: nPreds x 1, 1 if predictor must enter linearly
3027    const bool UseBetaCache,    // in: true to use the beta cache, for speed
3028    const char *sPredNames[])   // in: predictor names, can be NULL
3029{
3030    ForwardPass(pnTerms, FullSet, bx, Dirs, Cuts, nFactorsInTerm, nUses,
3031            x, y, WeightsArg, nCases, nResp, nPreds, nMaxDegree,
3032            nMaxTerms, Penalty, Thresh, nFastK, FastBeta, NewVarPenalty,
3033            LinPreds, UseBetaCache, sPredNames);
3034}
3035
3036extern "C" int EarthEvalSubsetsUsingXtx(
3037    bool   PruneTerms[],    // out: nMaxTerms x nMaxTerms
3038    double RssVec[],        // out: nMaxTerms x 1, RSS of each subset
3039    const int    nCases,    // in
3040    const int    nResp,     // in: number of cols in y
3041    const int    nMaxTerms, // in: number of MARS terms in full model
3042    const double bx[],      // in: nCases x nMaxTerms, all cols must be indep
3043    const double y[],       // in: nCases * nResp
3044    const double WeightsArg[]) // in: nCases x 1, can be NULL
3045{
3046    return EvalSubsetsUsingXtx(PruneTerms, RssVec, nCases, nResp, nMaxTerms, bx, y, WeightsArg);
3047}
3048
3049/*
3050 * ORANGE INTERFACE
3051 */
3052
3053TEarthLearner::TEarthLearner()
3054{
3055    max_terms = 21;
3056    max_degree = 1;
3057    penalty = (max_degree > 1)? 3.0: 2.0;
3058    threshold = 0.001;
3059    prune = true;
3060    trace = 0.0;
3061    min_span = 0;
3062    fast_k = 20;
3063    fast_beta = 0.0;
3064    new_var_penalty = 0.0;
3065    use_beta_cache = true;
3066}
3067
3068PClassifier TEarthLearner::operator() (PExampleGenerator examples, const int & weight_id)
3069{
3070    TDomain& domain = examples->domain.getReference();
3071    int num_preds = domain.attributes->size();
3072    int num_cases = examples->numberOfExamples();
3073    int num_responses = 1;
3074    if (num_cases < 0){
3075        raiseError("Cannot learn from an example generator of unknown size.");
3076    }
3077
3078    // TODO: Check for classVar, assert all attributes are continuous
3079
3080//  num_preds = 1;
3081//  num_cases = 100;
3082
3083    double best_gcv;
3084    int num_terms;
3085
3086    double *x = (double *) calloc(num_preds * num_cases, sizeof(double));
3087    double *y = (double *) calloc(num_cases * num_responses, sizeof(double));
3088    double *bx = (double *) calloc(num_cases * max_terms, sizeof(double));
3089    bool *best_set = (bool *) calloc(max_terms, sizeof(bool));
3090    int *dirs = (int *) calloc(max_terms * num_preds, sizeof(int));
3091    double *cuts = (double *) calloc(max_terms * num_preds, sizeof(double));
3092    double *residuals = (double *) calloc(num_cases * num_responses, sizeof(double));
3093    double *betas = (double *) calloc(max_terms * num_responses, sizeof(double));
3094    int * lin_preds = (int *) calloc(num_preds, sizeof(int));
3095    double *weights = NULL;
3096
3097    // Redefine x indexing
3098    #undef x_
3099    #define x_(i, j) x[i + j * num_cases]
3100
3101    TExampleGenerator::iterator ex_iter = examples->begin();
3102    for (int i=0; i<num_cases; i++, ++ex_iter)
3103    {
3104        TExample &example = *ex_iter;
3105        for (int j=0; j<num_preds; j++)
3106        {
3107            double tempx;
3108            TValue &value = example[j];
3109            if (value.isSpecial())
3110                tempx = 0.0;
3111            else
3112                if (value.varType == TValue::INTVAR)
3113                    tempx = (double) value.intV;
3114                else
3115                    tempx = (double) value.floatV;
3116            x_(i, j) = tempx;
3117        }
3118        double tempy;
3119        TValue &class_value = example.getClass();
3120        if (class_value.varType == TValue::INTVAR)
3121            tempy = (double) class_value.intV;
3122        else
3123            tempy = (double) class_value.floatV;
3124        y[i] = tempy;
3125    }
3126//  for (int i = 0; i < num_cases; i++) {
3127//          const double x0 = (double)i / num_cases;
3128//          x[i] = x0;
3129//          y[i] = sin(4 * x0);     // target function, change this to whatever you want
3130//      }
3131
3132
3133    const char **preds_names = NULL; // Used for trace only.
3134//  preds_names = (char **) malloc(num_preds * sizeof(char *));
3135//  for (int i=0; i<num_preds; i++){
3136//      preds_names[i] = NULL;
3137//  }
3138
3139    Earth(&best_gcv, &num_terms, best_set, bx, dirs, cuts, residuals, betas,
3140            x, y, weights, num_cases, num_responses, num_preds, max_degree,
3141            max_terms, penalty, threshold, min_span, prune,
3142            fast_k, fast_beta, new_var_penalty, lin_preds, use_beta_cache, trace, preds_names);
3143
3144    PEarthClassifier classifier = mlnew TEarthClassifier(examples->domain, best_set, dirs, cuts, betas, num_preds, num_responses, num_terms, max_terms);
3145//  std::string str = classifier->format_earth();
3146
3147    // Free memory
3148    free((void *)x);
3149    free((void *)y);
3150    free((void *)bx);
3151    free((void *)residuals);
3152//  free((void *)weights);
3153    free((void *)lin_preds);
3154
3155    return classifier;
3156}
3157
3158TEarthClassifier::TEarthClassifier(PDomain _domain, bool * best_set, int * dirs, double * cuts, double *betas, int _num_preds, int _num_responses, int _num_terms, int _max_terms)
3159{
3160    domain = _domain;
3161    classVar = domain->classVar;
3162    _best_set = best_set;
3163    _dirs = dirs;
3164    _cuts = cuts;
3165    _betas = betas;
3166    num_preds = _num_preds;
3167    num_responses = _num_responses;
3168    num_terms = _num_terms;
3169    max_terms = _max_terms;
3170    computesProbabilities = false;
3171    init_members();
3172}
3173
3174TEarthClassifier::TEarthClassifier()
3175{
3176    domain = NULL;
3177    classVar = NULL;
3178    _best_set = NULL;
3179    _dirs = NULL;
3180    _cuts = NULL;
3181    _betas = NULL;
3182    num_preds = 0;
3183    num_responses = 0;
3184    num_terms = 0;
3185    max_terms = 0;
3186    computesProbabilities = false;
3187}
3188
3189TEarthClassifier::TEarthClassifier(const TEarthClassifier & other)
3190{
3191    raiseError("Not implemented");
3192}
3193
3194TEarthClassifier::~TEarthClassifier()
3195{
3196    if (_best_set)
3197        free(_best_set);
3198    if (_dirs)
3199        free(_dirs);
3200    if (_cuts)
3201        free(_cuts);
3202    if (_betas)
3203        free(_betas);
3204}
3205
3206TValue TEarthClassifier::operator()(const TExample& example)
3207{
3208    double *x = to_xvector(example);
3209    double y = 0.0;
3210    PredictEarth(&y, x, _best_set, _dirs, _cuts, _betas, num_preds, num_responses, num_terms, max_terms);
3211    free(x);
3212    if (classVar->varType == TValue::INTVAR)
3213        return TValue((int) std::max<float>(0.0, floor(y + 0.5)));
3214    else
3215        return TValue((float) y);
3216}
3217
3218std::string TEarthClassifier::format_earth(){
3219    FormatEarth(_best_set, _dirs, _cuts, _betas, num_preds, 1, num_terms, max_terms, 3, 0.0);
3220    // TODO: FormatEarth to a string.
3221    return "";
3222}
3223
3224double* TEarthClassifier::to_xvector(const TExample& example)
3225{
3226//  TAttributeList &attributes = example.domain->attributes.getReference();
3227    double *x = (double *) calloc(num_preds, sizeof(double));
3228    for (int i=0; i<num_preds; i++){
3229        const TValue &val = example[i];
3230        if (val.isSpecial())
3231            x[i] = 0.0;
3232        else
3233            if (val.varType == TValue::INTVAR)
3234                x[i] = (double) val.intV;
3235            else
3236                x[i] = (double) val.floatV;
3237    }
3238    return x;
3239}
3240
3241PBoolList TEarthClassifier::get_best_set()
3242{
3243    PBoolList list = mlnew TBoolList();
3244    for (bool * p=_best_set; p < _best_set + max_terms; p++)
3245         list->push_back(*p);
3246    return list;
3247}
3248
3249PFloatListList TEarthClassifier::get_dirs()
3250{
3251    PFloatListList list = mlnew TFloatListList();
3252    for (int i=0; i<max_terms; i++)
3253    {
3254        TFloatList * inner_list = mlnew TFloatList();
3255        for(int j=0; j<num_preds; j++)
3256            inner_list->push_back(_dirs[i + j*max_terms]);
3257        list->push_back(inner_list);
3258    }
3259    return list;
3260}
3261
3262PFloatListList TEarthClassifier::get_cuts()
3263{
3264    PFloatListList list = mlnew TFloatListList();
3265    for (int i=0; i<max_terms; i++)
3266    {
3267        TFloatList * inner_list = mlnew TFloatList();
3268        for (int j=0; j<num_preds; j++)
3269            inner_list->push_back(_cuts[i + j*max_terms]);
3270        list->push_back(inner_list);
3271    }
3272    return list;
3273}
3274
3275PFloatList TEarthClassifier::get_betas()
3276{
3277    PFloatList list = mlnew TFloatList();
3278    for (double * p=_betas; p < _betas + max_terms; p++)
3279        list->push_back((float)*p);
3280    return list;
3281}
3282
3283void TEarthClassifier::init_members()
3284{
3285    best_set = get_best_set();
3286    dirs = get_dirs();
3287    cuts = get_cuts();
3288    betas = get_betas();
3289
3290}
3291
3292void TEarthClassifier::save_model(TCharBuffer& buffer)
3293{
3294    buffer.writeInt(max_terms);
3295    buffer.writeInt(num_terms);
3296    buffer.writeInt(num_preds);
3297    buffer.writeInt(num_responses);
3298    buffer.writeBuf((void *) _best_set, sizeof(bool) * max_terms);
3299    buffer.writeBuf((void *) _dirs, sizeof(int) * max_terms * num_preds);
3300    buffer.writeBuf((void *) _cuts, sizeof(double) * max_terms * num_preds);
3301    buffer.writeBuf((void *) _betas, sizeof(double) * max_terms * num_responses);
3302}
3303
3304void TEarthClassifier::load_model(TCharBuffer& buffer)
3305{
3306    if (max_terms)
3307        raiseError("Cannot overwrite a model");
3308
3309    max_terms = buffer.readInt();
3310    num_terms = buffer.readInt();
3311    num_preds = buffer.readInt();
3312    num_responses = buffer.readInt();
3313
3314    _best_set = (bool *) calloc(max_terms, sizeof(bool));
3315    _dirs = (int *) calloc(max_terms * num_preds, sizeof(int));
3316    _cuts = (double *) calloc(max_terms * num_preds, sizeof(double));
3317    _betas = (double *) calloc(max_terms * num_responses, sizeof(double));
3318
3319    buffer.readBuf((void *) _best_set, sizeof(bool) * max_terms);
3320    buffer.readBuf((void *) _dirs, sizeof(int) * max_terms * num_preds);
3321    buffer.readBuf((void *) _cuts, sizeof(double) * max_terms * num_preds);
3322    buffer.readBuf((void *) _betas, sizeof(double) * max_terms * num_responses);
3323    init_members();
3324}
3325
3326
3327
Note: See TracBrowser for help on using the repository browser.