Changeset 10316:cae20013e4c8 in orange


Ignore:
Timestamp:
02/20/12 11:12:48 (2 years ago)
Author:
Ales Erjavec <ales.erjavec@…>
Branch:
default
rebase_source:
91e72e60ba323705a153ede1c39c70a5d1ce0161
Message:

Added support for continuous class in SOMSupervisedLearner (fixes #1101).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/projection/som.py

    r10315 r10316  
    160160import random 
    161161 
    162 random.seed(42) 
     162import Orange 
     163from Orange.misc import deprecated_keywords, \ 
     164                        deprecated_attribute 
     165 
    163166 
    164167HexagonalTopology = 0 
     
    181184########################################################################## 
    182185# Inference of Self-Organizing Maps  
    183  
    184 from Orange.misc import deprecated_keywords, \ 
    185                         deprecated_attribute 
    186186 
    187187class Solver(object): 
     
    542542    def __call__(self, data, weight_id=0, progress_callback=None): 
    543543        array, classes, w = data.toNumpyMA() 
    544         nval = len(data.domain.class_var.values) 
    545         ext = ma.zeros((len(array), nval)) 
    546         ext[([i for i, m in enumerate(classes.mask) if m], [int(c) for c, m in zip(classes, classes.mask) if m])] = 1.0 
     544        domain = data.domain 
     545        if isinstance(domain.class_var, Orange.feature.Discrete): 
     546            # Discrete class (extend the data with class indicator matrix) 
     547            nval = len(data.domain.class_var.values) 
     548            ext = ma.zeros((len(array), nval)) 
     549            ext[([i for i, m in enumerate(classes.mask) if m], 
     550                 [int(c) for c, m in zip(classes, classes.mask) if m])] = 1.0 
     551        elif isinstance(domain.class_var, Orange.feature.Continuous): 
     552            # Continuous class, just add the one column (what about multitarget) 
     553            nval = 1 
     554            ext = ma.zeros((len(array), nval)) 
     555            ext[:,0] = classes 
     556        elif domain.class_var is None: 
     557            # No class var 
     558            nval = 0 
     559            ext = ma.zeros((len(array), nval)) 
     560        else: 
     561            raise TypeError("Unsuported `class_var` %r" % domain.class_var)  
    547562        array = ma.hstack((array, ext)) 
     563         
    548564        map = Map(self.map_shape, topology=self.topology) 
    549565        if self.initialize == Map.InitializeLinear: 
     
    554570                     radius_ini=self.radius_ini, radius_fin=self.radius_fin, learning_rate=self.learning_rate, 
    555571                     epoch=self.epochs)(array, map, progress_callback=progress_callback) 
     572        # Remove class columns from the vectors  
    556573        for node in map: 
    557574            node.vector = node.vector[:-nval] 
Note: See TracChangeset for help on using the changeset viewer.