001    /**
002     * NFoldEvaluator.java
003     * jCOLIBRI2 framework. 
004     * @author Juan A. Recio-García.
005     * GAIA - Group for Artificial Intelligence Applications
006     * http://gaia.fdi.ucm.es
007     * 07/05/2007
008     */
009    package jcolibri.evaluation.evaluators;
010    
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.List;
015    
016    import jcolibri.cbraplications.StandardCBRApplication;
017    import jcolibri.cbrcore.CBRCase;
018    import jcolibri.cbrcore.CBRCaseBase;
019    import jcolibri.evaluation.EvaluationReport;
020    import jcolibri.evaluation.Evaluator;
021    import jcolibri.exception.ExecutionException;
022    
023    import org.apache.commons.logging.LogFactory;
024    
025    /**
026     * This evaluation method divides the case base into several random folds (indicated by the user). 
027     * For each fold, their cases are used as queries and the remaining folds are used together as case base. 
028     * This process is performed several times.
029     * 
030     * @author Juan A. Recio García - GAIA http://gaia.fdi.ucm.es
031     * @version 2.0
032     */
033    public class NFoldEvaluator extends Evaluator
034    {
035    
036            protected StandardCBRApplication app;
037    
038            public void init(StandardCBRApplication cbrApp) {
039                    
040                    report = new EvaluationReport();
041                    app = cbrApp;
042                    try {
043                            app.configure();
044                    } catch (ExecutionException e) {
045                            LogFactory.getLog(this.getClass()).error(e);
046                    }
047            }
048    
049            /**
050             * Executes the N-Fold evaluation.
051             * @param folds Number of folds (randomly generated).
052             * @param repetitions Number of repetitions
053             */
054        public void NFoldEvaluation(int folds, int repetitions)
055        {   
056            try
057            {
058                //Get the time
059                long t = (new Date()).getTime();
060                int numberOfCycles = 0;
061    
062                    // Run the precycle to load the case base
063                            LogFactory.getLog(this.getClass()).info("Running precycle()");
064                            CBRCaseBase caseBase = app.preCycle();
065    
066                            if (!(caseBase instanceof jcolibri.casebase.CachedLinealCaseBase))
067                                    LogFactory
068                                                    .getLog(this.getClass())
069                                                    .warn(
070                                                                    "Evaluation should be executed using a cached case base");
071                
072                Collection<CBRCase> cases = new ArrayList<CBRCase>(caseBase.getCases());
073                
074                //For each repetition
075                for(int r=0; r<repetitions; r++)
076                {
077                    //Create the folds
078                    createFolds(cases, folds);
079                    
080                    //For each fold
081                    for(int f=0; f<folds; f++)
082                    {
083                        ArrayList<CBRCase> querySet = new ArrayList<CBRCase>();
084                        ArrayList<CBRCase> caseBaseSet = new ArrayList<CBRCase>();
085                        //Obtain the query and casebase sets
086                        getFolds(f, querySet, caseBaseSet);
087                        
088                        //Clear the caseBase
089                        caseBase.forgetCases(cases);
090                        
091                        //Set the cases that acts as casebase in this cycle
092                        caseBase.learnCases(caseBaseSet);
093                        
094                        //Run cycle for each case in querySet (current fold)
095                        for(CBRCase c: querySet)
096                        {                   
097                                                    LogFactory.getLog(this.getClass()).info(
098                                                                    "Running cycle() " + numberOfCycles);
099                                                    app.cycle(c);
100                            
101                            numberOfCycles++;
102                        }          
103                    } 
104                    
105                }
106    
107                            //Revert case base to original state
108                            caseBase.forgetCases(cases);
109                            caseBase.learnCases(cases);
110                
111                //Run the poscycle to finish the application
112                            LogFactory.getLog(this.getClass()).info("Running postcycle()");
113                            app.postCycle();
114    
115                
116                //Complete the evaluation result
117                report.setTotalTime(t);
118                report.setNumberOfCycles(numberOfCycles);
119                
120    
121            } catch (Exception e) {
122                            LogFactory.getLog(this.getClass()).error(e);
123                    }
124    
125        }
126    
127        
128        protected ArrayList<ArrayList<CBRCase>> _folds;
129        protected void createFolds(Collection<CBRCase> cases, int folds)
130        {
131            _folds = new  ArrayList<ArrayList<CBRCase>>();
132            int foldsize = cases.size() / folds;
133            ArrayList<CBRCase> copy = new ArrayList<CBRCase>(cases);
134            
135            for(int f=0; f<folds; f++)
136            {
137                ArrayList<CBRCase> fold = new ArrayList<CBRCase>();
138                for(int i=0; (i<foldsize)&&(copy.size()>0); i++)
139                {
140                    int random = (int) (Math.random() * copy.size());
141                    CBRCase _case = copy.get( random );
142                    copy.remove(random);
143                    fold.add(_case);
144                }
145                _folds.add(fold);
146            }
147        }
148        
149        protected void getFolds(int f, List<CBRCase> querySet, List<CBRCase> caseBaseSet)
150        {
151            querySet.clear();
152            caseBaseSet.clear();
153            
154            querySet.addAll(_folds.get(f));
155            
156            for(int i=0; i<_folds.size(); i++)
157                if(i!=f)
158                    caseBaseSet.addAll(_folds.get(i));
159        }
160        
161    }