Changeset 7981:7262d5f0eb72 in orange


Ignore:
Timestamp:
06/03/11 16:10:02 (3 years ago)
Author:
ales_erjavec <ales.erjavec@…>
Branch:
default
Convert:
99c55d19fe6bb87f4e1e211113d7e800a1695f79
Message:

Added init_kmeans.
Bug fixes.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • orange/Orange/clustering/mixture.py

    r7885 r7981  
    1313import sys, os 
    1414import numpy 
     15import random 
    1516import Orange.data 
    1617 
     
    2526         
    2627    def __call__(self, instance): 
    27         """ Return the conditional probability of instance. 
     28        """ Return the probability of instance. 
    2829        """ 
    2930        return numpy.sum(prob_est([instance], self.weights, self.means, self.covariances)) 
    3031         
    3132    def __getitem__(self, index): 
    32         """ Return the index-th gaussian 
     33        """ Return the index-th gaussian. 
    3334        """  
    3435        return GMModel([1.0], self.means[index: index + 1], self.covariances[index: index + 1]) 
    35      
    36 #    def __getslice__(self, slice): 
    37 #        pass 
    3836 
    3937    def __len__(self): 
     
    4139     
    4240     
    43 def init_random(array, n_centers, *args, **kwargs): 
    44     """ Init random means 
    45     """ 
     41def init_random(data, n_centers, *args, **kwargs): 
     42    """ Init random means and correlations from a data table. 
     43     
     44    :param data: data table 
     45    :type data: :class:`Orange.data.Table` 
     46    :param n_centers: Number of centers and correlations to return. 
     47    :type n_centers: int 
     48     
     49    """ 
     50    if isinstance(data, Orange.data.Table): 
     51        array, w, c = data.toNumpyMA() 
     52    else: 
     53        array = numpy.asarray(data) 
     54         
    4655    min, max = array.max(0), array.min(0) 
    4756    dim = array.shape[1] 
     
    5261    correlations = [numpy.asmatrix(numpy.eye(dim)) for i in range(n_centers)] 
    5362    return means, correlations 
    54      
     63 
     64def init_kmeans(data, n_centers, *args, **kwargs): 
     65    """ Init with k-means algorithm. 
     66     
     67    :param data: data table 
     68    :type data: :class:`Orange.data.Table` 
     69    :param n_centers: Number of centers and correlations to return. 
     70    :type n_centers: int 
     71     
     72    """ 
     73    if not isinstance(data, Orange.data.Table): 
     74        raise TypeError("Orange.data.Table instance expected!") 
     75    from Orange.clustering.kmeans import Clustering 
     76    km = Clustering(data, centroids=n_centers, maxiters=20, nstart=3) 
     77    centers = Orange.data.Table(km.centroids) 
     78    centers, w, c = centers.toNumpyMA() 
     79    dim = len(data.domain.attributes) 
     80    correlations = [numpy.asmatrix(numpy.eye(dim)) for i in range(n_centers)] 
     81    return centers, correlations 
    5582     
    5683def prob_est1(data, mean, covariance, inv_covariance=None): 
    57     """ Return the probability of data given mean and covariance matrix  
     84    """ Return the probability of data given mean and covariance matrix 
    5885    """ 
    5986    data = numpy.asmatrix(data) 
     
    6289        inv_covariance = numpy.linalg.pinv(covariance) 
    6390         
    64     inv_covariance = numpy.asmatrix(inv_covariance)     
     91    inv_covariance = numpy.asmatrix(inv_covariance) 
    6592     
    6693    diff = data - mean 
     
    76103    assert(det != 0.0) 
    77104    p /= det 
    78 #    if det != 0.0: 
    79 #        p /= det 
    80 #    else: 
    81 #        p = numpy.ones(p.shape) / p.shape[0] 
    82105    return p 
    83106 
    84107 
    85108def prob_est(data, weights, means, covariances, inv_covariances=None): 
    86     """ Return the probability estimation of data given weighted, means and 
     109    """ Return the probability estimation of data given weights, means and 
    87110    covariances. 
    88111       
     
    103126    """ An EM solver for gaussian mixture model 
    104127    """ 
     128    _TRACE_MEAN = False 
    105129    def __init__(self, data, weights, means, covariances): 
    106130        self.data = data 
     
    182206        """ Run the EM algorithm. 
    183207        """ 
    184          
    185 #        from pylab import plot, show, draw, ion 
    186 #        ion() 
    187 #        plot(self.data[:, 0], self.data[:, 1], "ro") 
    188 #        vec_plot = plot(self.means[:, 0], self.means[:, 1], "bo")[0] 
     208        if self._TRACE_MEAN: 
     209            from pylab import plot, show, draw, ion 
     210            ion() 
     211            plot(self.data[:, 0], self.data[:, 1], "ro") 
     212            vec_plot = plot(self.means[:, 0], self.means[:, 1], "bo")[0] 
     213         
    189214        curr_iter = 0 
    190215         
     
    193218            self.one_step() 
    194219             
    195 #            vec_plot.set_xdata(self.means[:, 0]) 
    196 #            vec_plot.set_ydata(self.means[:, 1]) 
    197 #            draw() 
     220            if self._TRACE_MEAN: 
     221                vec_plot.set_xdata(self.means[:, 0]) 
     222                vec_plot.set_ydata(self.means[:, 1]) 
     223                draw() 
    198224             
    199225            curr_iter += 1 
    200             print curr_iter 
    201             print abs(old_objective - self.log_likelihood) 
     226#            print curr_iter 
     227#            print abs(old_objective - self.log_likelihood) 
    202228            if abs(old_objective - self.log_likelihood) < eps or curr_iter > max_iter: 
    203229                break 
    204230         
    205231         
    206 class GASolver(object): 
    207     """ A toy genetic algorithm solver  
    208     """ 
    209     def __init__(self, data, weights, means, covariances): 
    210         raise NotImplementedError 
    211  
    212  
    213 class PSSolver(object): 
    214     """ A toy particle swarm solver 
    215     """ 
    216     def __init__(self, data, weights, means, covariances): 
    217         raise NotImplementedError 
    218  
    219 class HybridSolver(object): 
    220     """ A hybrid solver 
    221     """ 
    222     def __init__(self, data, weights, means, covariances): 
    223         raise NotImplementedError 
     232#class GASolver(object): 
     233#    """ A toy genetic algorithm solver  
     234#    """ 
     235#    def __init__(self, data, weights, means, covariances): 
     236#        raise NotImplementedError 
     237# 
     238# 
     239#class PSSolver(object): 
     240#    """ A toy particle swarm solver 
     241#    """ 
     242#    def __init__(self, data, weights, means, covariances): 
     243#        raise NotImplementedError 
     244# 
     245#class HybridSolver(object): 
     246#    """ A hybrid solver 
     247#    """ 
     248#    def __init__(self, data, weights, means, covariances): 
     249#        raise NotImplementedError 
    224250     
    225251     
    226252class GaussianMixture(object): 
     253    """ Computes the gaussian mixture model from an Orange data-set. 
     254    """ 
    227255    def __new__(cls, data=None, weightId=None, **kwargs): 
    228256        self = object.__new__(cls) 
     
    233261            return self 
    234262         
    235     def __init__(self, n_centers=3, init_function=init_random): 
     263    def __init__(self, n_centers=3, init_function=init_kmeans): 
    236264        self.n_centers = n_centers 
    237265        self.init_function = init_function 
    238266         
    239267    def __call__(self, data, weightId=None): 
     268        means, correlations = self.init_function(data, self.n_centers) 
     269        means = numpy.asmatrix(means) 
    240270        array, _, _ = data.to_numpy_MA() 
    241271        solver = EMSolver(array, numpy.ones((self.n_centers)) / self.n_centers, 
    242                           *self.init_function(array, self.n_centers)) 
     272                          means, correlations) 
    243273        solver.run() 
    244274        return GMModel(solver.weights, solver.means, solver.covariances) 
     
    246276         
    247277def plot_model(data_array, mixture, axis=(0, 1), samples=20, contour_lines=20): 
    248      
     278    """ Plot the scaterplot of data_array and the contour lines of the 
     279    probability for the mixture. 
     280      
     281    """ 
    249282    import matplotlib 
    250283    import matplotlib.pylab as plt 
     
    257290     
    258291    weights = mixture.weights 
    259     means = [m[axis] for m in mixture.means] 
     292    means = mixture.means[:, axis] 
    260293     
    261294    covariances = [cov[axis,:][:, axis] for cov in mixture.covariances]  
     
    283316                cmap=cm.gray, extent=extent) 
    284317     
     318    plt.plot(means[:, 0], means[:, 1], "b+") 
    285319    plt.show() 
    286320     
    287 def test(): 
     321def test(seed=0): 
    288322#    data = Orange.data.Table(os.path.expanduser("../../doc/datasets/brown-selected.tab")) 
    289     data = Orange.data.Table(os.path.expanduser("~/Documents/brown-selected-fss.tab")) 
    290 #    data = Orange.data.Table("../../doc/datasets/iris.tab") 
     323#    data = Orange.data.Table(os.path.expanduser("~/Documents/brown-selected-fss.tab")) 
     324    data = Orange.data.Table(os.path.expanduser("~/Documents/brown-selected-fss-1.tab")) 
     325    data = Orange.data.Table("../../doc/datasets/iris.tab") 
    291326#    data = Orange.data.Table(Orange.data.Domain(data.domain[:2], None), data) 
    292     numpy.random.seed(0) 
    293     gmm = GaussianMixture(data, n_centers=3) 
    294     plot_model(data, gmm, axis=(0,1), samples=40, contour_lines=20) 
     327    numpy.random.seed(seed) 
     328    random.seed(seed) 
     329    gmm = GaussianMixture(data, n_centers=3, init_function=init_kmeans) 
     330    plot_model(data, gmm, axis=(0, 1), samples=40, contour_lines=100) 
    295331 
    296332     
Note: See TracChangeset for help on using the changeset viewer.