source: orange/Orange/feature/selection.py @ 11646:94e12073788c

Revision 11646:94e12073788c, 8.2 KB checked in by Ales Erjavec <ales.erjavec@…>, 9 months ago (diff)

Fixed 'above_threshold' test condition.

It now tests True for scores equal to threshold (as documented).

RevLine 
# Docstrings in this module use reStructuredText markup.
__docformat__ = "restructuredtext"
2
[10523]3from operator import itemgetter
4
[8042]5import Orange.core as orange
6
[8119]7from Orange.feature.scoring import score_all
[8042]8
[11645]9
def top_rated(scores, n, highest_best=True):
    """Return the names of the ``n`` best-scored features.

    :param list scores:
        (feature, score) pairs, e.g. as returned by :func:`.score_all`.
    :param int n: How many features to keep (a larger ``n`` than there
        are scores simply returns all of them).
    :param bool highest_best:
        When true (the default) larger scores rank first; otherwise
        smaller scores rank first.
    :rtype: :obj:`list`

    """
    ranked = sorted(scores, key=itemgetter(1), reverse=highest_best)
    best = ranked[:n]
    return [feature for feature, _ in best]

bestNAtts = top_rated
[9653]25
[11645]26
def above_threshold(scores, threshold=0.0):
    """Return the names (scores dropped) of features whose score is
    greater than or equal to ``threshold``.

    :param list scores:
        (feature, score) pairs, e.g. as returned by :func:`.score_all`.
    :param float threshold: Minimal score a feature needs to be kept.
    :rtype: :obj:`list`

    """
    kept = []
    for feature, value in scores:
        if value >= threshold:
            kept.append(feature)
    return kept


attsAboveThreshold = above_threshold
41
42
def select(data, scores, n):
    """Build a new data table holding the ``n`` best-scored features of
    ``data`` together with its class.

    :param data: a data table
    :type data: :obj:`Orange.data.Table`
    :param scores: (feature, score) pairs such as the ones returned by
      :obj:`~Orange.feature.scoring.score_all`
    :type scores: list
    :param n: number of features to select
    :type n: int
    :rtype: :obj:`Orange.data.Table`
    """
    chosen = top_rated(scores, n)
    # The class variable is appended last so it is kept in the result.
    chosen.append(data.domain.classVar.name)
    return data.select(chosen)

selectBestNAtts = select
select_best_n = select
[8042]60
[9645]61
def select_above_threshold(data, scores, threshold=0.0):
    """Build a new data table holding the class and every feature whose
    score (as listed in ``scores``) is greater than or equal to
    ``threshold``.

    :param data: a data table
    :type data: :obj:`Orange.data.Table`
    :param scores: (feature, score) pairs such as the ones returned by
      :obj:`~Orange.feature.scoring.score_all`
    :type scores: list
    :param threshold: threshold for selection
    :type threshold: float
    :rtype: :obj:`Orange.data.Table`
    """
    chosen = above_threshold(scores, threshold)
    # The class variable is appended last so it is kept in the result.
    chosen.append(data.domain.classVar.name)
    return data.select(chosen)

selectAttsAboveThresh = select_above_threshold
81
82
def select_relief(data, measure=orange.MeasureAttribute_relief(k=20, m=50), margin=0):
    """Repeatedly drop the worst-scored feature while its score is below
    ``margin``.

    Intended for context-dependent measures such as Relief, where
    removing one feature can change the scores of the others; therefore
    all features are re-scored after every removal.

    :param data: a data table
    :type data: :obj:`Orange.data.Table`
    :param measure: a feature scorer
    :type measure: :obj:`Orange.feature.scoring.Score`
    :param margin: margin for removal
    :type margin: float

    """
    ranking = score_all(data, measure)
    while True:
        remaining = len(data.domain.attributes)
        # The last ranking entry is treated as the worst-scored feature.
        if remaining == 0 or not ranking[-1][1] < margin:
            break
        # Keep everything except that worst feature, then re-score.
        data = select(data, ranking, remaining - 1)
        ranking = score_all(data, measure)
    return data

filterRelieff = select_relief
[8042]105
[9645]106
class FilterAboveThreshold(object):
    """A wrapper around :obj:`select_above_threshold`; the
    constructor stores the parameters of the feature selection
    procedure that are then applied when the selection
    is called with the actual data.

    If ``data`` is given to the constructor, the filter is applied
    right away and the filtered data table is returned instead of a
    filter instance.

    :param measure: a feature scorer
    :type measure: :obj:`Orange.feature.scoring.Score`
    :param threshold: threshold for selection; features scoring greater
        than or equal to it are kept. Defaults to 0.
    :type threshold: float
    """

    def __new__(cls, data=None,
                measure=orange.MeasureAttribute_relief(k=20, m=50),
                threshold=0.0):
        if data is None:
            # Plain construction; Python then runs __init__ with the same
            # arguments.
            self = object.__new__(cls)
            return self
        else:
            # Construct-and-apply shortcut: the return value is a data
            # table (not an instance of cls), so __init__ is not run on it.
            self = cls(measure=measure, threshold=threshold)
            return self(data)

    # Note: the default scorer is created once, when the class is defined,
    # and is shared by every instance that does not pass its own measure.
    def __init__(self, measure=orange.MeasureAttribute_relief(k=20, m=50),
                 threshold=0.0):
        self.measure = measure
        self.threshold = threshold

    def __call__(self, data):
        """Return a new data table with the class and the features of
        ``data`` whose score is greater than or equal to the threshold.

        :param data: data table
        :type data: Orange.data.Table

        """
        ma = score_all(data, self.measure)
        return select_above_threshold(data, ma, self.threshold)

FilterAttsAboveThresh = FilterAboveThreshold
FilterAttsAboveThresh_Class = FilterAboveThreshold
147
148
class FilterBestN(object):
    """A wrapper around :obj:`select`; the
    constructor stores the filter parameters that are applied when the
    filter is called with data.

    If ``data`` is given to the constructor, the filter is applied
    right away and the filtered data table is returned instead of a
    filter instance.

    :param measure: a feature scorer
    :type measure: :obj:`Orange.feature.scoring.Score`
    :param n: number of features to select
    :type n: int

    """
    def __new__(cls, data=None,
                measure=orange.MeasureAttribute_relief(k=20, m=50),
                n=5):

        if data is None:
            # Plain construction; Python then runs __init__ with the same
            # arguments.
            self = object.__new__(cls)
            return self
        else:
            # Construct-and-apply shortcut: the return value is a data
            # table (not an instance of cls), so __init__ is not run on it.
            self = cls(measure=measure, n=n)
            return self(data)

    def __init__(self, measure=orange.MeasureAttribute_relief(k=20, m=50),
                 n=5):
        self.measure = measure
        self.n = n

    def __call__(self, data):
        """Return a new data table with the class and the (at most) ``n``
        best-scored features of ``data``.

        :param data: data table
        :type data: Orange.data.Table
        :rtype: :obj:`Orange.data.Table`

        """
        ma = score_all(data, self.measure)
        # Clamp into a local: writing the clamped value back to self.n
        # would permanently shrink the filter after it once saw a table
        # with fewer than n features, so a reused filter (e.g. inside
        # FilteredLearner) would later select too few features.
        n = min(self.n, len(data.domain.attributes))
        return select(data, ma, n)

FilterBestNAtts = FilterBestN
FilterBestNAtts_Class = FilterBestN
183
[9653]184
class FilterRelief(object):
    """A class wrapper around :obj:`select_relief`; the
    constructor stores the filter parameters that are applied when the
    filter is called with data.

    If ``data`` is given to the constructor, the filter is applied
    right away and the filtered data table is returned instead of a
    filter instance.

    :param measure: a feature scorer
    :type measure: :obj:`Orange.feature.scoring.Score`
    :param margin: margin for removal; the worst-scored feature is
        removed (and the rest re-scored) while its score is below this
        value
    :type margin: float

    """
    def __new__(cls, data=None,
                measure=orange.MeasureAttribute_relief(k=20, m=50),
                margin=0):

        if data is None:
            # Plain construction; Python then runs __init__ with the same
            # arguments.
            self = object.__new__(cls)
            return self
        else:
            # Construct-and-apply shortcut: the return value is a data
            # table (not an instance of cls), so __init__ is not run on it.
            self = cls(measure=measure, margin=margin)
            return self(data)

    def __init__(self, measure=orange.MeasureAttribute_relief(k=20, m=50),
                 margin=0):
        self.measure = measure
        self.margin = margin

    def __call__(self, data):
        """Apply :obj:`select_relief` to ``data`` with the stored measure
        and margin and return the resulting data table.

        :param data: data table
        :type data: Orange.data.Table

        """
        return select_relief(data, self.measure, self.margin)

FilterRelief_Class = FilterRelief
[8042]216
##############################################################################
# wrapped learner
219
[9645]220
class FilteredLearner(object):
    """Wrap a base learner with a feature-selection step. When called
    with data, the learner first passes the data through ``filter`` and
    then trains ``baseLearner`` on the filtered table.

    Giving ``data`` directly to the constructor trains immediately and
    returns the resulting :class:`FilteredClassifier` instead of a
    learner.

    Here is an example of how to build a wrapper around naive Bayesian learner
    and use it on a data set::

        nb = Orange.classification.bayes.NaiveBayesLearner()
        learner = Orange.feature.selection.FilteredLearner(nb,
            filter=Orange.feature.selection.FilterBestN(n=5), name='filtered')
        classifier = learner(data)

    """
    def __new__(cls, baseLearner, data=None, weight=0,
                filter=FilterAboveThreshold(), name='filtered'):
        if data is not None:
            # Build a learner and train it at once; the classifier returned
            # here is not an instance of cls, so __init__ is not run on it.
            learner = cls(baseLearner, filter=filter, name=name)
            return learner(data, weight)
        return object.__new__(cls)

    def __init__(self, baseLearner, filter=FilterAboveThreshold(),
                 name='filtered'):
        self.baseLearner = baseLearner
        self.filter = filter
        self.name = name

    def __call__(self, data, weight=0):
        """Filter ``data``, train the base learner on the result and wrap
        the trained model in a :class:`FilteredClassifier`."""
        reduced = self.filter(data)
        trained = self.baseLearner(reduced, weight)
        return FilteredClassifier(classifier=trained, domain=trained.domain)

FilteredLearner_Class = FilteredLearner
258
259
class FilteredClassifier:
    """A classifier returned by FilteredLearner.

    Every keyword argument given to the constructor becomes an instance
    attribute; FilteredLearner supplies ``classifier`` (the model trained
    on the filtered data) and ``domain`` (that model's domain).
    """
    def __init__(self, **kwds):
        vars(self).update(kwds)

    def __call__(self, example, resultType=orange.GetValue):
        """Delegate classification of ``example`` to the wrapped
        classifier, forwarding ``resultType`` unchanged."""
        wrapped = self.classifier
        return wrapped(example, resultType)

    def atts(self):
        """Return the features of the stored domain."""
        domain = self.domain
        return domain.attributes
Note: See TracBrowser for help on using the repository browser.