/**
 * NFoldEvaluator.java
 * jCOLIBRI2 framework.
 * @author Juan A. Recio-García.
 * GAIA - Group for Artificial Intelligence Applications
 * http://gaia.fdi.ucm.es
 * 07/05/2007
 */
package jcolibri.evaluation.evaluators;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.List;

import jcolibri.cbraplications.StandardCBRApplication;
import jcolibri.cbrcore.CBRCase;
import jcolibri.cbrcore.CBRCaseBase;
import jcolibri.evaluation.EvaluationReport;
import jcolibri.evaluation.Evaluator;
import jcolibri.exception.ExecutionException;

import org.apache.commons.logging.LogFactory;

/**
 * This evaluation method divides the case base into several random folds
 * (the number of folds is indicated by the user). For each fold, its cases
 * are used as queries and the remaining folds are used together as the case
 * base. The whole process can be repeated several times, generating a new
 * random partition on each repetition.
 *
 * @author Juan A. Recio García - GAIA http://gaia.fdi.ucm.es
 * @version 2.0
 */
public class NFoldEvaluator extends Evaluator
{

    /** Application under evaluation; assigned by {@link #init(StandardCBRApplication)}. */
    protected StandardCBRApplication app;

    /**
     * Folds generated by the last call to {@link #createFolds(Collection, int)}.
     * It is <code>null</code> until that method has been called.
     */
    protected ArrayList<ArrayList<CBRCase>> _folds;

    /**
     * Prepares the evaluator: creates a fresh evaluation report, stores the
     * application to evaluate and invokes its <code>configure()</code> method.
     * A configuration failure is logged and not rethrown.
     *
     * @param cbrApp application to evaluate.
     */
    public void init(StandardCBRApplication cbrApp)
    {
        report = new EvaluationReport();
        app = cbrApp;
        try
        {
            app.configure();
        } catch (ExecutionException e)
        {
            // Pass the throwable as second argument so the stack trace is kept.
            LogFactory.getLog(this.getClass()).error(e.getMessage(), e);
        }
    }

    /**
     * Executes the N-Fold evaluation.
     * <p>
     * The precycle is executed once to load the case base. Then, for every
     * repetition, the cases are split into <code>folds</code> random folds and
     * one cycle is run for each case of each fold, using the other folds as
     * case base. Finally the case base is restored with the cases it had
     * after the precycle, the postcycle is executed and the elapsed time (in
     * milliseconds) and the number of executed cycles are stored in the report.
     * <p>
     * Any exception raised during the process is logged and not rethrown. The
     * case base is restored even if a cycle fails; in that case the postcycle
     * is not executed and the report is not completed.
     *
     * @param folds Number of folds (randomly generated). Must be positive.
     * @param repetitions Number of repetitions
     */
    public void NFoldEvaluation(int folds, int repetitions)
    {
        try
        {
            // Remember the starting time to compute the elapsed time at the end
            long startTime = (new Date()).getTime();
            int numberOfCycles = 0;

            // Run the precycle to load the case base
            LogFactory.getLog(this.getClass()).info("Running precycle()");
            CBRCaseBase caseBase = app.preCycle();

            if (!(caseBase instanceof jcolibri.casebase.CachedLinealCaseBase))
                LogFactory
                        .getLog(this.getClass())
                        .warn(
                                "Evaluation should be executed using a cached case base");

            // Snapshot of the original cases, used to build the folds and to
            // restore the case base afterwards
            Collection<CBRCase> cases = new ArrayList<CBRCase>(caseBase.getCases());

            try
            {
                // For each repetition
                for (int r = 0; r < repetitions; r++)
                {
                    // Create a new random partition
                    createFolds(cases, folds);

                    // For each fold
                    for (int f = 0; f < folds; f++)
                    {
                        ArrayList<CBRCase> querySet = new ArrayList<CBRCase>();
                        ArrayList<CBRCase> caseBaseSet = new ArrayList<CBRCase>();
                        // Obtain the query and casebase sets
                        getFolds(f, querySet, caseBaseSet);

                        // Clear the caseBase
                        caseBase.forgetCases(cases);

                        // Set the cases that act as casebase in this cycle
                        caseBase.learnCases(caseBaseSet);

                        // Run cycle for each case in querySet (current fold)
                        for (CBRCase c : querySet)
                        {
                            LogFactory.getLog(this.getClass()).info(
                                    "Running cycle() " + numberOfCycles);
                            app.cycle(c);

                            numberOfCycles++;
                        }
                    }
                }
            } finally
            {
                // Revert case base to original state, also when a cycle failed,
                // so the application is not left with a partial case base
                caseBase.forgetCases(cases);
                caseBase.learnCases(cases);
            }

            // Run the postcycle to finish the application
            LogFactory.getLog(this.getClass()).info("Running postcycle()");
            app.postCycle();

            // Complete the evaluation result with the elapsed time, not the
            // starting timestamp
            long elapsedTime = (new Date()).getTime() - startTime;
            report.setTotalTime(elapsedTime);
            report.setNumberOfCycles(numberOfCycles);

        } catch (Exception e)
        {
            // Pass the throwable as second argument so the stack trace is kept.
            LogFactory.getLog(this.getClass()).error(e.getMessage(), e);
        }
    }

    /**
     * Randomly distributes the given cases into <code>folds</code> disjoint
     * folds and stores them in {@link #_folds}. The received collection is
     * not modified.
     * <p>
     * Every case is assigned to exactly one fold. When the number of cases is
     * not a multiple of <code>folds</code>, the first
     * <code>cases.size() % folds</code> folds receive one extra case, so fold
     * sizes differ by at most one. If there are fewer cases than folds, the
     * last folds are empty.
     *
     * @param cases cases to distribute.
     * @param folds number of folds to create. Must be positive.
     * @throws IllegalArgumentException if <code>folds</code> is not positive.
     */
    protected void createFolds(Collection<CBRCase> cases, int folds)
    {
        if (folds <= 0)
            throw new IllegalArgumentException(
                    "Number of folds must be positive: " + folds);

        _folds = new ArrayList<ArrayList<CBRCase>>();

        // Work on a copy: cases are removed from it as they are assigned
        ArrayList<CBRCase> remaining = new ArrayList<CBRCase>(cases);
        int baseSize = remaining.size() / folds;
        int extraCases = remaining.size() % folds;

        for (int f = 0; f < folds; f++)
        {
            // The first 'extraCases' folds absorb the remainder of the division
            int foldSize = baseSize + ((f < extraCases) ? 1 : 0);
            ArrayList<CBRCase> fold = new ArrayList<CBRCase>(foldSize);
            for (int i = 0; i < foldSize; i++)
            {
                // Pick (and remove) a random case among the unassigned ones
                int random = (int) (Math.random() * remaining.size());
                fold.add(remaining.remove(random));
            }
            _folds.add(fold);
        }
    }

    /**
     * Fills the query and case base sets for the given fold using the
     * partition stored in {@link #_folds}. Both lists are cleared first; then
     * the cases of fold <code>f</code> are added to <code>querySet</code> and
     * the cases of every other fold are added to <code>caseBaseSet</code>.
     *
     * @param f index of the fold whose cases are used as queries.
     * @param querySet output list that receives the cases of fold <code>f</code>.
     * @param caseBaseSet output list that receives the cases of the other folds.
     */
    protected void getFolds(int f, List<CBRCase> querySet, List<CBRCase> caseBaseSet)
    {
        querySet.clear();
        caseBaseSet.clear();

        querySet.addAll(_folds.get(f));

        for (int i = 0; i < _folds.size(); i++)
            if (i != f)
                caseBaseSet.addAll(_folds.get(i));
    }

}