001    package jcolibri.method.reuse.classification;
002    
003    import java.util.Collection;
004    import java.util.HashMap;
005    import java.util.Map;
006    
007    import jcolibri.extensions.classification.ClassificationSolution;
008    import jcolibri.method.retrieve.RetrievalResult;
009    
010    /**
011     * Provides the ability to classify a query by predicting its
012     * solution from supplied cases. Classification is done by 
013     * similarity weighted voting, where each vote is based on 
014     * the similarity of the case to the query. The class with the
015     * highest overall value is the predicted class.
016     * 
017     * @author Derek Bridge
018     * @author Lisa Cummins
019     * 16/05/07
020     */
021    public class SimilarityWeightedVotingMethod extends AbstractKNNClassificationMethod
022    {
023        /**
024         * Predicts the class that has the highest value vote
025         * among the k most similar cases, where votes are based on
026         * similarity to the query.
027         * If several classes receive the same highest vote, the class that
028         * has the lowest hash code is taken as the prediction. 
029         * @param cases
030         *            an ordered list of cases along with similarity scores.
031         * @return Returns the predicted solution.
032         */
033        public ClassificationSolution getPredictedSolution(Collection<RetrievalResult> cases)
034        {
035            Map<Object, Double> votes = new HashMap<Object, Double>();
036            Map<Object, ClassificationSolution> values = new HashMap<Object, ClassificationSolution>();
037            
038            for(RetrievalResult result: cases)
039            {   
040                ClassificationSolution solution = (ClassificationSolution)result.get_case().getSolution();
041               
042                Object solnAttVal = solution.getClassification();
043                 
044                double eval = result.getEval();
045                if (votes.containsKey(solnAttVal))
046                {   votes.put(solnAttVal, votes.get(solnAttVal) + eval);
047                }
048                else
049                {   votes.put(solnAttVal, eval);
050                    values.put(solnAttVal, solution);
051                }
052            }
053            double highestVoteSoFar = 0.0;
054            Object predictedClassVal = null;
055            for (Map.Entry<Object, Double> e : votes.entrySet())
056            {   if (e.getValue() >= highestVoteSoFar)
057                {   highestVoteSoFar = e.getValue();
058                    predictedClassVal = e.getKey();
059                }
060            }
061            return values.get(predictedClassVal);
062        }
063    }