Changeset 11750:78866659d679 in orange


Ignore:
Timestamp:
11/06/13 10:37:40 (5 months ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
Message:

Fixed NeuralNetworkClassifier interface/pickling. Added tests.

Location:
Orange
Files:
1 added
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/classification/neural.py

    r11559 r11750  
    148148    :type normalize: bool 
    149149 
    150     :rtype: :class:`Orange.multitarget.neural.neuralNetworkLearner` or  
    151             :class:`Orange.multitarget.chain.NeuralNetworkClassifier` 
     150    :rtype: :class:`~NeuralNetworkLearner` or 
     151            :class:`~NeuralNetworkClassifier` 
    152152    """ 
    153153 
     
    161161            return self(data,weight) 
    162162 
    163     def __init__(self, name="NeuralNetwork", n_mid=10, reg_fact=1, max_iter=300, normalize=True, rand=None): 
    164         """ 
    165         Current default values are the same as in the original implementation (neural_networks.py) 
    166         """ 
     163    def __init__(self, name="NeuralNetwork", n_mid=10, reg_fact=1, 
     164                 max_iter=300, normalize=True, rand=None): 
    167165        self.name = name 
    168166        self.n_mid = n_mid 
     
    181179         
    182180        :param instances: data for learning. 
    183         :type instances: class:`Orange.data.Table` 
     181        :type instances: :class:`Orange.data.Table` 
    184182 
    185183        :param weight: weight. 
    186184        :type weight: int 
    187185 
    188         :param class_order: list of descriptors of class variables 
    189         :type class_order: list of :class:`Orange.feature.Descriptor` 
    190  
    191         :rtype: :class:`Orange.multitarget.chain.NeuralNetworkClassifier` 
     186        :rtype: :class:`~NeuralNetworkClassifier` 
    192187        """ 
    193188 
    194189        #converts attribute data 
    195         X = data.to_numpy()[0]  
     190        X = data.to_numpy()[0] 
    196191 
    197192        mean = X.mean(axis=0) 
     
    201196 
    202197        #converts multi-target or single-target classes to numpy 
    203         if data.domain.class_vars: 
    204             for cv in data.domain.class_vars: 
    205                 if cv.var_type == Orange.feature.Continuous: 
    206                     raise ValueError("non-discrete classes not supported") 
    207         else: 
    208             if data.domain.class_var.var_type == Orange.feature.Continuous: 
    209                 raise ValueError("non-discrete classes not supported") 
     198        if any(isinstance(var, Orange.feature.Continuous) 
     199               for var in (data.domain.class_vars or [data.domain.class_var])): 
     200            raise ValueError("non-discrete classes not supported") 
    210201 
    211202        if data.domain.class_vars: 
     
    229220        
    230221        #initializes neural networks 
    231         self.nn =  _NeuralNetwork([len(X[0]), self.n_mid,len(Y[0])], lambda_=self.reg_fact, maxfun=self.max_iter, iprint=-1) 
    232          
    233         self.nn.fit(X,Y) 
     222        nn =  _NeuralNetwork([len(X[0]), self.n_mid,len(Y[0])], lambda_=self.reg_fact, maxfun=self.max_iter, iprint=-1) 
     223         
     224        nn.fit(X,Y) 
    234225                
    235         return NeuralNetworkClassifier(classifier=self.nn.predict, 
    236             domain=data.domain, normalize=self.normalize, mean=mean, std=std) 
    237  
    238 class NeuralNetworkClassifier(): 
    239     """     
    240     Uses the classifier induced by the :obj:`NeuralNetworkLearner`. 
    241    
    242     :param name: name of the classifier. 
    243     :type name: string 
     226        return NeuralNetworkClassifier(domain=data.domain, nn=nn, normalize=self.normalize, mean=mean, std=std) 
     227 
     228 
     229class NeuralNetworkClassifier(Orange.classification.Classifier): 
    244230    """ 
    245  
    246     def __init__(self,**kwargs): 
    247         self.__dict__.update(**kwargs) 
    248  
    249     def __call__(self,example, result_type=Orange.core.GetValue): 
     231    Classifier induced by the :obj:`NeuralNetworkLearner`. 
     232    """ 
     233 
     234    def __init__(self, domain, nn, normalize, mean, std, **kwargs): 
     235        self.domain = domain 
     236        if domain.class_vars: 
     237            self.class_vars = domain.class_vars 
     238            self.class_var = None 
     239        else: 
     240            self.class_var = domain.class_var 
     241        self.nn = nn 
     242        self.normalize = normalize 
     243        self.mean = mean 
     244        self.std = std 
     245 
     246        for name, val in kwargs.items(): 
     247            setattr(self, name, val) 
     248 
     249    def __reduce__(self): 
     250        return (type(self), 
     251                (self.domain, self.nn, self.normalize, self.mean, self.std), 
     252                dict(self.__dict__)) 
     253 
     254    def __call__(self, example, result_type=Orange.core.GetValue): 
    250255        """ 
    251         :param instance: instance to be classified. 
    252         :type instance: :class:`Orange.data.Instance` 
     256        :param example: instance to be classified. 
     257        :type example: :class:`Orange.data.Instance` 
    253258         
    254259        :param result_type: :class:`Orange.classification.Classifier.GetValue` or \ 
     
    259264              :class:`Orange.statistics.Distribution` or a tuple with both 
    260265        """ 
    261  
     266        example = Orange.data.Instance(self.domain, example) 
    262267        # transform example to numpy 
    263         if not self.domain.class_vars: example = [example[i] for i in range(len(example)-1)] 
     268        if not self.domain.class_vars: 
     269            example = [example[i] for i in range(len(example)-1)] 
    264270        input = np.array([[float(e) for e in example]]) 
    265271 
     
    268274 
    269275        # transform results from numpy 
    270         results = self.classifier(input).tolist()[0] 
     276        results = self.nn.predict(input).tolist()[0] 
    271277        if len(results) == 1: 
    272278            prob_positive = results[0] 
     
    289295                mt_prob.append(cprob) 
    290296                mt_value.append(Orange.data.Value(self.domain.class_vars[cls], cprob.values().index(max(cprob)))) 
    291                                   
    292         else: 
     297 
     298        else: 
     299            assert len(self.domain.class_var.values) == len(results) 
    293300            cprob = Orange.statistics.distribution.Discrete(results) 
     301            cprob.variable = self.domain.class_var 
    294302            cprob.normalize() 
    295  
    296303            mt_prob = cprob 
    297304            mt_value = Orange.data.Value(self.domain.class_var, cprob.values().index(max(cprob))) 
Note: See TracChangeset for help on using the changeset viewer.