source: orange/Orange/feature/selection.py @ 11645:56cf220f5bfd

Revision 11645:56cf220f5bfd, 8.3 KB checked in by Ales Erjavec <ales.erjavec@…>, 9 months ago (diff)

Fixed 'top_rated' function (actually use the 'highest_best' parameter).

RevLine 
[8042]1__docformat__ = 'restructuredtext'
2
[10523]3from operator import itemgetter
4
[8042]5import Orange.core as orange
6
[8119]7from Orange.feature.scoring import score_all
[8042]8
[11645]9
def top_rated(scores, n, highest_best=True):
    """Return n top-rated features from the list of scores.

    :param list scores:
        A list of ``(feature, score)`` pairs such as the one returned by
        :func:`.score_all`.
    :param int n: Number of features to select. If it exceeds the number
        of scored features, all of them are returned; a negative value
        is treated as zero.
    :param bool highest_best:
        If true, the features that are scored higher are preferred.
    :rtype: :obj:`list`

    """
    if n is not None and n < 0:
        # A negative bound would make the slice below mean "everything
        # but the last -n entries" and silently return features.
        n = 0
    ranked = sorted(scores, key=itemgetter(1), reverse=highest_best)
    return [feature for feature, _ in ranked[:n]]

bestNAtts = top_rated
[9653]25
[11645]26
def above_threshold(scores, threshold=0.0):
    """Return features (without scores) whose scores are strictly
    above a specified threshold.

    A feature whose score is exactly equal to ``threshold`` is *not*
    included in the result.

    :param scores: a list of ``(feature, score)`` pairs such as one
      returned by :obj:`~Orange.feature.scoring.score_all`
    :type scores: list
    :param threshold: threshold for selection
    :type threshold: float
    :rtype: :obj:`list`

    """
    # Index-based access keeps working for entries that carry more than
    # two items; the order of ``scores`` is preserved.
    return [entry[0] for entry in scores if entry[1] > threshold]


attsAboveThreshold = above_threshold
43
44
def select(data, scores, n):
    """Build a new data table restricted to the ``n`` best features
    from ``scores`` plus the class variable.

    The names chosen by :obj:`top_rated` are extended with the name of
    ``data.domain.classVar`` and handed to ``data.select``.

    :param data: a data table
    :type data: :obj:`Orange.data.Table`
    :param scores: a list such as the one returned by
      :obj:`~Orange.feature.scoring.score_all`
    :type scores: list
    :param n: number of features to select
    :type n: int
    :rtype: :obj:`Orange.data.Table`
    """
    best = top_rated(scores, n)
    class_name = data.domain.classVar.name
    return data.select(best + [class_name])

selectBestNAtts = select
select_best_n = select
[8042]62
[9645]63
def select_above_threshold(data, scores, threshold=0.0):
    """Build a new data table restricted to the class variable and the
    features that :obj:`above_threshold` picks from ``scores``, i.e.
    those whose score exceeds ``threshold``.

    :param data: a data table
    :type data: :obj:`Orange.data.Table`
    :param scores: a list such as the one returned by
      :obj:`~Orange.feature.scoring.score_all`
    :type scores: list
    :param threshold: threshold for selection
    :type threshold: float
    :rtype: :obj:`Orange.data.Table`
    """
    passing = above_threshold(scores, threshold)
    class_name = data.domain.classVar.name
    return data.select(passing + [class_name])

selectAttsAboveThresh = select_above_threshold
83
84
def select_relief(data, measure=orange.MeasureAttribute_relief(k=20, m=50), margin=0):
    """Repeatedly drop the worst scored feature until the lowest score
    is no longer below ``margin`` (or no features remain).

    The procedure was designed for context dependent measures such as
    Relief, where removing one feature may change the scores of the
    remaining ones; the scores are therefore recomputed after every
    removal.

    :param data: a data table
    :type data: :obj:`Orange.data.Table`
    :param measure: a feature scorer
    :type measure: :obj:`Orange.feature.scoring.Score`
    :param margin: margin for removal
    :type margin: float

    """
    ranking = score_all(data, measure)
    while True:
        remaining = len(data.domain.attributes)
        # The last entry is taken as the worst feature -- this assumes
        # score_all returns its list ordered best-first (TODO confirm).
        if remaining <= 0 or not (ranking[-1][1] < margin):
            return data
        # Keep all but one feature, then rescore in the new context.
        data = select(data, ranking, remaining - 1)
        ranking = score_all(data, measure)

filterRelieff = select_relief
[8042]107
[9645]108
class FilterAboveThreshold(object):
    """Callable wrapper around :obj:`select_above_threshold`.

    The constructor stores the parameters of the feature selection
    procedure; they are applied when the instance is called with the
    actual data. When ``data`` is passed to the constructor, the filter
    is applied immediately and the filtered data is returned instead of
    a filter instance.

    :param measure: a feature scorer
    :type measure: :obj:`Orange.feature.scoring.Score`
    :param threshold: threshold for selection. Defaults to 0.
    :type threshold: float
    """

    def __new__(cls, data=None,
                measure=orange.MeasureAttribute_relief(k=20, m=50),
                threshold=0.0):
        if data is None:
            # Ordinary construction of a reusable filter object.
            return object.__new__(cls)
        # Data given: configure a filter and run it right away.
        selector = cls(measure=measure, threshold=threshold)
        return selector(data)

    def __init__(self, measure=orange.MeasureAttribute_relief(k=20, m=50),
                 threshold=0.0):
        self.threshold = threshold
        self.measure = measure

    def __call__(self, data):
        """Score all features of ``data`` with the stored measure and
        return the result of :obj:`select_above_threshold` for the
        stored threshold.

        :param data: data table
        :type data: Orange.data.Table

        """
        scores = score_all(data, self.measure)
        return select_above_threshold(data, scores, self.threshold)

FilterAttsAboveThresh = FilterAboveThreshold
FilterAttsAboveThresh_Class = FilterAboveThreshold
149
150
class FilterBestN(object):
    """A wrapper around :obj:`select`; the
    constructor stores the filter parameters that are applied when the
    function is called. When ``data`` is passed to the constructor, the
    filter is applied immediately and the filtered data is returned
    instead of a filter instance.

    :param measure: a feature scorer
    :type measure: :obj:`Orange.feature.scoring.Score`
    :param n: number of features to select. Defaults to 5.
    :type n: int

    """
    def __new__(cls, data=None,
                measure=orange.MeasureAttribute_relief(k=20, m=50),
                n=5):

        if data is None:
            # Ordinary construction of a reusable filter object.
            self = object.__new__(cls)
            return self
        else:
            # Data given: configure a filter and run it right away.
            self = cls(measure=measure, n=n)
            return self(data)

    def __init__(self, measure=orange.MeasureAttribute_relief(k=20, m=50),
                 n=5):
        self.measure = measure
        self.n = n

    def __call__(self, data):
        """Score all features of ``data`` with the stored measure and
        return the result of :obj:`select` for the best ``n`` of them.

        :param data: data table
        :type data: Orange.data.Table

        """
        ma = score_all(data, self.measure)
        # Cap the count for this call only. Writing the capped value
        # back to self.n would permanently shrink the configured n after
        # a single call on a table with fewer features.
        n = min(self.n, len(data.domain.attributes))
        return select(data, ma, n)

FilterBestNAtts = FilterBestN
FilterBestNAtts_Class = FilterBestN
185
[9653]186
class FilterRelief(object):
    """Callable wrapper around :obj:`select_relief`.

    The constructor stores the filter parameters, which are applied
    when the instance is called with data. When ``data`` is passed to
    the constructor, the filter is applied immediately and the filtered
    data is returned instead of a filter instance.

    :param measure: a feature scorer
    :type measure: :obj:`Orange.feature.scoring.Score`
    :param margin: margin for Relief scoring
    :type margin: float

    """
    def __new__(cls, data=None,
                measure=orange.MeasureAttribute_relief(k=20, m=50),
                margin=0):
        if data is None:
            # Ordinary construction of a reusable filter object.
            return object.__new__(cls)
        # Data given: configure a filter and run it right away.
        relief_filter = cls(measure=measure, margin=margin)
        return relief_filter(data)

    def __init__(self, measure=orange.MeasureAttribute_relief(k=20, m=50),
                 margin=0):
        self.margin = margin
        self.measure = measure

    def __call__(self, data):
        """Return the result of :obj:`select_relief` on ``data`` with
        the stored measure and margin."""
        return select_relief(data, self.measure, self.margin)

FilterRelief_Class = FilterRelief
[8042]218
219##############################################################################
220# wrapped learner
221
[9645]222
class FilteredLearner(object):
    """A learner that performs feature selection before learning.

    When given data, it first passes them through ``filter`` and then
    calls the base learner on the filtered data. When ``data`` is passed
    to the constructor, learning happens immediately and the resulting
    :obj:`FilteredClassifier` is returned instead of a learner.

    Here is an example of how to build a wrapper around naive Bayesian learner
    and use it on a data set::

        nb = Orange.classification.bayes.NaiveBayesLearner()
        learner = Orange.feature.selection.FilteredLearner(nb,
            filter=Orange.feature.selection.FilterBestN(n=5), name='filtered')
        classifier = learner(data)

    """
    def __new__(cls, baseLearner, data=None, weight=0,
                filter=FilterAboveThreshold(), name='filtered'):
        if data is None:
            # Ordinary construction of a reusable learner object.
            return object.__new__(cls)
        # Data given: configure a learner and train it right away.
        learner = cls(baseLearner, filter=filter, name=name)
        return learner(data, weight)

    def __init__(self, baseLearner, filter=FilterAboveThreshold(),
                 name='filtered'):
        self.name = name
        self.filter = filter
        self.baseLearner = baseLearner

    def __call__(self, data, weight=0):
        """Filter ``data``, train the base learner on the result and
        wrap the obtained model in a :obj:`FilteredClassifier`."""
        reduced = self.filter(data)
        model = self.baseLearner(reduced, weight)
        return FilteredClassifier(classifier=model, domain=model.domain)

FilteredLearner_Class = FilteredLearner
260
261
class FilteredClassifier:
    """A classifier returned by FilteredLearner.

    Every keyword argument given to the constructor becomes an instance
    attribute; the methods below rely on ``classifier`` and ``domain``
    being among them.
    """
    def __init__(self, **kwds):
        vars(self).update(kwds)

    def __call__(self, example, resultType=orange.GetValue):
        """Delegate the prediction to the wrapped classifier."""
        wrapped = self.classifier
        return wrapped(example, resultType)

    def atts(self):
        """Return the attributes of the stored domain."""
        domain = self.domain
        return domain.attributes
Note: See TracBrowser for help on using the repository browser.