Ignore:
Timestamp:
03/13/12 20:37:36 (2 years ago)
Author:
martin@…
Branch:
default
Message:

Changes to rule learning.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • source/orange/rulelearner.cpp

    r8735 r10517  
    4545: weightID(0), 
    4646  quality(ILLEGAL_FLOAT), 
    47   complexity(-1), 
     47  complexity(0), 
    4848  coveredExamples(NULL), 
    4949  coveredExamplesLength(-1), 
     
    9393 
    9494TRule::~TRule() 
    95 { delete coveredExamples; } 
     95{   delete coveredExamples; } 
    9696 
    9797bool TRule::operator ()(const TExample &ex) 
     
    986986           improved/rule->classDistribution->atint(targetClass) > min_improved_perc*0.01 && 
    987987           quality > (aprioriProb + 1e-3)) 
    988     futureQuality = quality; 
    989 //    futureQuality = 1.0 + quality; 
     988//    futureQuality = quality; 
     989    futureQuality = 1.0 + quality; 
    990990  else { 
    991991    PDistribution oldRuleDist = rule->classDistribution; 
     
    10161016      futureQuality = -1.0; 
    10171017    else { 
    1018       futureQuality = 0.0; 
     1018      futureQuality = 0.0;  
    10191019      PEITERATE(ei, rule->examples) { 
    10201020        if ((*ei).getClass().intV != targetClass) 
     
    10231023          continue; 
    10241024        } 
    1025         float x = ((*ei)[probVar].floatV-quality); //*rule->classDistribution->abs; 
    1026         if ((*ei)[probVar].floatV > quality) 
     1025      /*  float x = ((*ei)[probVar].floatV-quality); //*rule->classDistribution->abs; 
     1026        if ((*ei)[probVar].floatV > quality)  
    10271027          x *= (1.0-quality)/(bestQuality-quality); 
    10281028        x /= sqrt(quality*(1.0-quality)); // rule->classDistribution->abs* 
    1029         futureQuality += log(1.0-max(1e-12,1.0-2*zprob(x))); 
    1030       } 
    1031       futureQuality = 1.0 - exp(futureQuality); 
     1029        futureQuality += log(1.0-max(1e-12,1.0-2*zprob(x))); */ 
     1030        float x; 
     1031        if ((*ei)[probVar].floatV > quality) 
     1032        { 
     1033            x = 1.0 - ((*ei)[probVar].floatV-quality) / (bestQuality-quality); 
     1034        } 
     1035        else 
     1036        { 
     1037            x = 1.0 + quality - (*ei)[probVar].floatV; 
     1038        } 
     1039        futureQuality += x; 
     1040      } 
     1041//      futureQuality = 1.0 - exp(futureQuality); 
     1042      futureQuality /= rule->classDistribution->atint(targetClass); 
    10321043    } 
    10331044  } 
     
    19591970{ 
    19601971    PITERATE(TIntList, ind, ruleIndices[rule_i]) 
    1961   { 
    1962     float bestQuality = 0.0; 
    1963         PITERATE(TIntList, fr, prefixRules) { 
    1964               if (rules->at(*fr)->call(examples->at(*ind)) && rules->at(*fr)->quality > bestQuality) { 
    1965           bestQuality = rules->at(*fr)->quality; 
    1966                   p[getClassIndex(rules->at(*fr))][*ind] = rules->at(*fr)->quality; 
    1967                   for (int ci=0; ci<examples->domain->classVar->noOfValues(); ci++) 
    1968                       if (ci!=getClassIndex(rules->at(*fr))) 
    1969                           p[ci][*ind] = (1.0-rules->at(*fr)->quality)/(examples->domain->classVar->noOfValues()-1); 
    1970                   break; 
    1971               } 
    1972       } 
    1973   } 
     1972    { 
     1973        float bestQuality = 0.0; 
     1974        PITERATE(TIntList, fr, prefixRules) { 
     1975            if (rules->at(*fr)->call(examples->at(*ind)) && rules->at(*fr)->quality > bestQuality) { 
     1976                bestQuality = rules->at(*fr)->quality; 
     1977                p[getClassIndex(rules->at(*fr))][*ind] = rules->at(*fr)->quality; 
     1978                for (int ci=0; ci<examples->domain->classVar->noOfValues(); ci++) 
     1979                if (ci!=getClassIndex(rules->at(*fr))) 
     1980                    p[ci][*ind] = (1.0-rules->at(*fr)->quality)/(examples->domain->classVar->noOfValues()-1); 
     1981                break; 
     1982            } 
     1983        } 
     1984    } 
    19741985} 
    19751986 
     
    20182029void TLogitClassifierState::setPrefixRule(int rule_i) //, int position) 
    20192030{ 
    2020     prefixRules->push_back(rule_i); 
    2021     setFixed(rule_i); 
    2022     updateFixedPs(rule_i); 
    2023     betas[rule_i] = 0.0; 
    2024   computeAvgProbs(); 
    2025   computePriorProbs(); 
     2031    prefixRules->push_back(rule_i); 
     2032    setFixed(rule_i); 
     2033    updateFixedPs(rule_i); 
     2034    betas[rule_i] = 0.0; 
     2035    computeAvgProbs(); 
     2036    computePriorProbs(); 
    20262037} 
    20272038 
     
    20302041{} 
    20312042 
    2032 TRuleClassifier_logit::TRuleClassifier_logit(PRuleList arules, const float &minSignificance, const float &minBeta, PExampleTable anexamples, const int &aweightID, const PClassifier &classifier, const PDistributionList &probList, bool setPrefixRules, bool optimizeBetasFlag) 
     2043TRuleClassifier_logit::TRuleClassifier_logit(PRuleList arules, const float &minSignificance, const float &minBeta, const float &penalty, PExampleTable anexamples, const int &aweightID, const PClassifier &classifier, const PDistributionList &probList, bool setPrefixRules, bool optimizeBetasFlag) 
    20332044: TRuleClassifier(arules, anexamples, aweightID), 
    20342045  minSignificance(minSignificance), 
     
    20362047  setPrefixRules(setPrefixRules), 
    20372048  optimizeBetasFlag(optimizeBetasFlag), 
    2038   minBeta(minBeta) 
     2049  minBeta(minBeta), 
     2050  penalty(penalty) 
    20392051{ 
    20402052  initialize(probList); 
    20412053  float step = 2.0; 
    2042   minStep = (float)0.01; 
    20432054 
    20442055  // initialize prior betas 
     
    20512062  if (setPrefixRules) 
    20522063  { 
    2053       bool changed = setBestPrefixRule(); 
    2054       while (changed) { 
    2055         if (optimizeBetasFlag) 
    2056             optimizeBetas(); 
    2057         changed = setBestPrefixRule(); 
    2058       } 
     2064    bool changed = setBestPrefixRule(); 
     2065    while (changed) { 
     2066        if (optimizeBetasFlag) 
     2067            optimizeBetas(); 
     2068        changed = setBestPrefixRule(); 
     2069    } 
    20592070  } 
    20602071 
     
    22072218  wsig = sig; 
    22082219  PITERATE(TRuleList, ri, rules) { 
    2209     float maxDiff = (*ri)->classDistribution->atint(getClassIndex(*ri))/(*ri)->classDistribution->abs; 
    2210       maxDiff -= (*ri)->quality; 
    2211       wsig->push_back(maxDiff); 
    2212    float n = (*ri)->examples->numberOfExamples(); 
     2220    float maxDiff = (*ri)->classDistribution->atint(getClassIndex(*ri))/(*ri)->classDistribution->abs; 
     2221    maxDiff -= (*ri)->quality; 
     2222    wsig->push_back(maxDiff); 
     2223    float n = (*ri)->examples->numberOfExamples(); 
    22132224    float a = n*(*ri)->quality; 
    22142225    float b = n*(1.0-(*ri)->quality); 
     
    22722283        if ((*ri)->call(*ei)) { 
    22732284          //int vv = (*ei).getClass().intV; 
    2274               //if ((*ei).getClass().intV == getClassIndex(*ri)) 
    2275                 coverages[getClassIndex(*ri)][j] += 1.0; 
     2285          //if ((*ei).getClass().intV == getClassIndex(*ri)) 
     2286          coverages[getClassIndex(*ri)][j] += 1.0; 
    22762287        } 
    2277           j++; 
     2288        j++; 
    22782289      } 
    22792290      i++; 
     
    23132324 
    23142325    float step = 0.1f; 
    2315     float gamma = 0.01f; 
    23162326    float error = 1e+20f; 
    23172327    float old_error = 1e+21f; 
    23182328 
     2329    int nsteps = 0; 
    23192330    while (old_error > error || step > 0.00001) 
    23202331    { 
    23212332        // reduce step if improvement failed 
    2322         if (old_error < error) 
     2333        nsteps++; 
     2334        if (old_error < error && nsteps > 20 || nsteps > 1000) 
    23232335        { 
     2336            nsteps = 0; 
    23242337            step /= 10; 
    23252338            finalState->copyTo(currentState); 
     
    23352348 
    23362349            float der = 0.0; 
    2337             if (currentState->avgProb->at(i) < rules->at(i)->quality) 
    2338                 der -= rules->at(i)->quality - currentState->avgProb->at(i); 
    2339             der += 2*gamma*currentState->betas[i]; 
     2350            if (currentState->avgProb->at(i) > rules->at(i)->quality) 
     2351            { 
     2352                der += pow(rules->at(i)->quality - currentState->avgProb->at(i),2); 
     2353                der -= 2 * (rules->at(i)->quality - currentState->avgProb->at(i)) * currentState->betas[i]; 
     2354            } 
     2355            else 
     2356                der -= 2 * (rules->at(i)->quality - currentState->avgProb->at(i)); 
     2357            der += 2 * penalty * currentState->betas[i]; 
     2358//            if (currentState->avgProb->at(i) > rules->at(i)->quality) 
     2359//                der = max(0.01f / step, der); 
    23402360            currentState->newBeta(i,max(0.0f, currentState->betas[i]-step*der)); 
    23412361        } 
     
    23432363        error = 0; 
    23442364        for (int i=0; i<rules->size(); i++) { 
    2345             if (currentState->avgProb->at(i) < rules->at(i)->quality) 
    2346                 error += rules->at(i)->quality - currentState->avgProb->at(i); 
    2347             error += gamma*pow(currentState->betas[i],2); 
     2365//            if (currentState->avgProb->at(i) < rules->at(i)->quality) 
     2366            if (currentState->avgProb->at(i) > rules->at(i)->quality) 
     2367                error += pow(rules->at(i)->quality - currentState->avgProb->at(i),2) * currentState->betas[i]; 
     2368            else 
     2369                error += pow(rules->at(i)->quality - currentState->avgProb->at(i),2); 
     2370            error += penalty*pow(currentState->betas[i],2); 
    23482371        } 
    23492372        //printf("error = %4.4f\n", error); 
     
    23532376        //printf("\n");  
    23542377    } 
     2378 
    23552379    finalState->copyTo(currentState); 
    23562380} 
    2357  
    2358 /* 
    2359 void TRuleClassifier_logit::updateRuleBetas2(float step_) 
    2360 { 
    2361  
    2362   stabilizeAndEvaluate(step_,-1); 
    2363   PLogitClassifierState finalState, tempState; 
    2364   currentState->copyTo(finalState); 
    2365  
    2366   float step = 2.0; 
    2367   int changed; 
    2368   float worst_underestimate, underestimate; 
    2369   float auc = currentState->getAUC(); 
    2370   float brier = currentState->getBrierScore(); 
    2371   float temp_auc, temp_brier; 
    2372   int worst_rule_index; 
    2373   while (step > 0.001) 
    2374   { 
    2375       step /= 2; 
    2376       changed = 0; 
    2377       while (changed < 100) 
    2378       { 
    2379         changed = 0; 
    2380         worst_underestimate = (float)0.01; 
    2381         worst_rule_index = -1; 
    2382         // find rule with greatest underestimate in probability 
    2383         for (int i=0; i<rules->size(); i++) { 
    2384             if (currentState->avgProb->at(i) >= rules->at(i)->quality) 
    2385                 continue; 
    2386             if (skipRule[i]) 
    2387                 continue; 
    2388  
    2389             underestimate = (rules->at(i)->quality - currentState->avgProb->at(i));//*rules->at(i)->classDistribution->abs; 
    2390             // if under estimate error is big enough 
    2391             if (underestimate > worst_underestimate) 
    2392             { 
    2393                 worst_underestimate = underestimate; 
    2394                 worst_rule_index = i; 
    2395             } 
    2396         } 
    2397         if (worst_rule_index > -1) 
    2398         { 
    2399             currentState->newBeta(worst_rule_index,currentState->betas[worst_rule_index]+step); 
    2400             if (currentState->avgProb->at(worst_rule_index) > rules->at(worst_rule_index)->quality) 
    2401             { 
    2402                 finalState->copyTo(currentState); 
    2403                 changed = 100; 
    2404             } 
    2405             else 
    2406             { 
    2407               stabilizeAndEvaluate(step,-1); 
    2408               temp_auc = currentState->getAUC(); 
    2409               temp_brier = currentState->getBrierScore(); 
    2410               if (temp_auc >= auc && temp_brier < brier) 
    2411               { 
    2412                 currentState->copyTo(finalState); 
    2413                 changed = 0; 
    2414                 auc = temp_auc; 
    2415                 brier = temp_brier; 
    2416               } 
    2417               else 
    2418                 changed ++; 
    2419             } 
    2420          // } 
    2421         } 
    2422         else 
    2423         { 
    2424           changed = 100; 
    2425           finalState->copyTo(currentState); 
    2426         } 
    2427       } 
    2428   } 
    2429   finalState->copyTo(currentState); 
    2430 } 
    2431 */ 
    2432  
    2433 void TRuleClassifier_logit::stabilizeAndEvaluate(float & step, int last_changed_rule_index) 
    2434 { 
    2435     PLogitClassifierState tempState; 
    2436     currentState->copyTo(tempState); 
    2437     bool changed = true; 
    2438     while (changed) 
    2439     { 
    2440         changed = false; 
    2441         for (int i=0; i<rules->size(); i++) 
    2442         { 
    2443             if (currentState->avgProb->at(i) > (rules->at(i)->quality + 0.01) && currentState->betas[i] > 0.0 && 
    2444                 i != last_changed_rule_index) 
    2445             { 
    2446                 float new_beta = currentState->betas[i]-step > 0 ? currentState->betas[i]-step : 0.0; 
    2447                 currentState->newBeta(i,new_beta); 
    2448                 if (currentState->avgProb->at(i) < rules->at(i)->quality + 1e-6) 
    2449                 { 
    2450                     tempState->copyTo(currentState); 
    2451                 } 
    2452                 else 
    2453                 { 
    2454                     currentState->copyTo(tempState); 
    2455                     changed = true; 
    2456                 } 
    2457             } 
    2458         } 
    2459     } 
    2460 } 
    2461  
    24622381 
    24632382void TRuleClassifier_logit::addPriorClassifier(const TExample &ex, double * priorFs) { 
     
    25162435  bool foundPrefixRule = false; 
    25172436  float bestQuality = 0.0; 
    2518   PITERATE(TRuleList, rs, prefixRules) { 
    2519       if ((*rs)->call(ex) && (*rs)->quality > bestQuality) { 
    2520       bestQuality = (*rs)->quality; 
    2521           dist->setint(getClassIndex(*rs),(*rs)->quality); 
    2522           for (int ci=0; ci<examples->domain->classVar->noOfValues(); ci++) 
    2523             if (ci!=getClassIndex(*rs)) 
    2524                 dist->setint(ci,(1.0-(*rs)->quality)/(examples->domain->classVar->noOfValues()-1)); 
    2525       foundPrefixRule = true; 
    2526       break; 
    2527       } 
     2437  PITERATE(TRuleList, rs, prefixRules) {     
     2438    if ((*rs)->call(ex) && (*rs)->quality > bestQuality) { 
     2439        bestQuality = (*rs)->quality; 
     2440        dist->setint(getClassIndex(*rs),(*rs)->quality); 
     2441            for (int ci=0; ci<examples->domain->classVar->noOfValues(); ci++) 
     2442            if (ci!=getClassIndex(*rs)) 
     2443                dist->setint(ci,(1.0-(*rs)->quality)/(examples->domain->classVar->noOfValues()-1)); 
     2444        foundPrefixRule = true; 
     2445        break; 
     2446      } 
    25282447  } 
    25292448  if (foundPrefixRule) 
     
    25452464      if ((*r)->call(cexample)) { 
    25462465        if (getClassIndex(*r) == i) 
    2547             f += (*b); 
     2466            f += (*b); 
    25482467        else if (getClassIndex(*r) == res->noOfElements()-1) 
    2549           f -= (*b); 
     2468            f -= (*b); 
    25502469      } 
    25512470    dist->addint(i,exp(f)); 
Note: See TracChangeset for help on using the changeset viewer.