Orange Forum • View topic - PATCH: Parallel and distributed cross-validation

PATCH: Parallel and distributed cross-validation

Report bugs (or imagined bugs).
(Archived/read-only, please use our ticketing system for reporting bugs and their discussion.)
Forum rules
Archived/read-only, please use our ticketing system for reporting bugs and their discussion.

PATCH: Parallel and distributed cross-validation

Post by yang » Thu Mar 03, 2011 5:09

Here is a module, hacked together by us, that provides parallel (multi-process) and distributed (multi-host) implementations of testWithIndices() and crossValidation(). It has some caveats (listed in its docstring), but we felt it may be useful enough to others to release now, since we most probably won't have the resources to address them any time soon. We hope it is a useful addition to the Orange project.

Code: Select all
Index: orngDist.py
===================================================================
--- orngDist.py   (revision 0)
+++ orngDist.py   (revision 0)
@@ -0,0 +1,228 @@
+"""
+Parallel (multi-process) and distributed (multi-host) implementation of
+testWithIndices() and crossValidation().
+
+testWithIndices runs each fold in its own orange-worker process. The number of
+processes allowed to run in parallel at any given time is set by the ORNG_PAR
+environment variable (default 10). The processes run either locally (default)
+or, round-robin by fold, on the whitespace-separated ssh hosts in ORNG_HOSTS.
+The same Orange version, including this module, must be available on all
+hosts; orange-worker simply imports orngDist and calls its main().
+
+Limitations and notes:
+
+    - Some code duplication from Orange.evaluation.testing.
+    - Exists in orngDist instead of Orange.evaluation.testing.
+    - Serializes/deserializes ExampleTables to/from files. Not sure how to
+      avoid this without excessive additional code.
+    - Learners, preprocessors and classifiers are pickled; not all of them can be.
+    - Requires the python-futures library (which for Python 3 is in the
+      standard library).
+
+It would be great to address these issues, but we felt this may be useful
+enough to others to be released now, since we most probably won't have the
+resources to fix these anytime soon.
+
+Contributed by Yang Zhang.
+"""
+
+from orngTest import *
+from concurrent.futures import *
+from cPickle import *
+import os, random, sys, time, zlib
+from subprocess import *
+from cStringIO import StringIO
+import itertools
+
+def savePar(path, table):
+    """Write table to path: a typed one-line header, then saveTxt's data rows."""
+    def typ(v):  # type letter, as in the attribute header: 'D' discrete, else 'C'
+        return 'D' if v.varType == orange.VarTypes.Discrete else 'C'
+    attrs = ['%s%s#%s' % ('c' if a is table.domain.classVar else '', typ(a), a.name)
+             for a in table.domain]
+    # Metas are typed like attributes; declaring them all discrete ('mD')
+    # turned continuous metas (e.g. example weights) into discrete values.
+    metas = ['m%s#%s' % (typ(m), m.name) for m in table.domain.getmetas().values()]
+    table.save(path + '.txt')  # saveTxt writes columns in the same order as our header
+    with open(path + '.txt') as g:
+        g.readline()  # skip saveTxt's own header line
+        rows = g.read()
+    os.remove(path + '.txt')  # the intermediate file is no longer needed
+    with open(path, 'w') as f: f.write('\t'.join(attrs + metas) + '\n' + rows)
+
+orange.registerFileType('par', None, savePar, '.par')  # lets table.save() write '.par' files via savePar
+
+_path = None  # this process's scratch file name, chosen on first use
+def path():
+    """Return this process's scratch .par path under /tmp (fixed after the first call)."""
+    global _path
+    if not _path:
+        _path = '/tmp/%s-%s.par' % (os.getpid(), random.Random(time.time()).randrange(sys.maxint))
+    return _path
+
+def do_fold((fold, args)):
+    """Learn and test one fold; orange-worker runs this via main().
+
+    args is the (decompressed) pickle built by testWithIndices. Returns its
+    ExperimentResults with this fold's predictions set and the fold's
+    classifiers appended, or None if the fold's learning or test set is empty.
+    """
+    (testResults, indices, examples, pps, nLrn, cache, storeclassifiers, learners, weight) = loads(args)
+    # The data arrives as the text of a .par file; reload it through a scratch file.
+    with open(path() + '.txt', 'w') as f:
+        f.write(examples)
+    examples = orange.ExampleTable(path() + '.txt')
+    os.remove(path() + '.txt')
+    # NOTE(review): weight is a meta id from the parent's domain; confirm the
+    # reloaded table gives that meta the same id in this process.
+    learnset = examples.selectref(indices, fold, negate=1)
+    testset = examples.selectref(indices, fold, negate=0)
+    if not len(learnset) or not len(testset):
+        return None  # nothing to learn from or nothing to test on
+    for pp in pps:  # "B" preprocessors go first and apply to both sets
+        if pp[0] == "B":
+            learnset, testset = pp[1](learnset), pp[1](testset)
+    for pp in pps:
+        if pp[0] == "L":
+            learnset = pp[1](learnset)
+        elif pp[0] == "T":
+            testset = pp[1](testset)
+        elif pp[0] == "LT":
+            learnset, testset = pp[1](learnset, testset)
+    if not learnset:
+        raise SystemError("no training examples after preprocessing")
+    if not testset:
+        raise SystemError("no test examples after preprocessing")
+
+    # learning (skip learners whose results were loaded from the cache)
+    classifiers = [None] * nLrn
+    for i in range(nLrn):
+        if not cache or not testResults.loaded[i]:
+            classifiers[i] = learners[i](learnset, weight)
+    testResults.classifiers.append(classifiers)
+
+    # testing: assumes testset still lists this fold's examples in their order in examples
+    tcn = 0
+    for i in range(len(examples)):
+        if indices[i] != fold:
+            continue
+        ex = orange.Example(testset[tcn])
+        ex.setclass("?")  # hide the true class so classifiers can't cheat
+        tcn += 1
+        for cl in range(nLrn):
+            if not cache or not testResults.loaded[cl]:
+                cr = classifiers[cl](ex, orange.GetBoth)
+                if cr[0].isSpecial():
+                    raise SystemError("Classifier %s returned unknown value" % (classifiers[cl].name or ("#%i" % cl)))
+                testResults.results[i].setResult(cl, cr[0], cr[1])
+    return testResults
+
+def testWithIndices(learners, examples, indices, indicesrandseed="*", pps=[], callback=None, **argkw):
+    """Parallel/distributed counterpart of orngTest.testWithIndices().
+
+    Each fold is learned and tested by its own orange-worker process (see the
+    module docstring); the workers' predictions are merged into a single
+    ExperimentResults. callback, if given, is called once per fold as that
+    fold's results are collected. Raises SystemError if a worker fails.
+    """
+    verb = argkw.get("verbose", 0)
+    cache = argkw.get("cache", 0)
+    storeclassifiers = argkw.get("storeclassifiers", 0) or argkw.get("storeClassifiers", 0)
+    cache = cache and not storeclassifiers
+
+    examples, weight = demangleExamples(examples)
+    nLrn = len(learners)
+
+    if not examples:
+        raise SystemError, "Test data set with no examples"
+    if not examples.domain.classVar:
+        raise SystemError, "Test data set without class attribute"
+
+    nIterations = max(indices)+1
+    if examples.domain.classVar.varType == orange.VarTypes.Discrete:
+        values = list(examples.domain.classVar.values)
+        basevalue = examples.domain.classVar.baseValue
+    else:
+        basevalue = values = None
+
+    conv = examples.domain.classVar.varType == orange.VarTypes.Discrete and int or float
+    testResults = ExperimentResults(nIterations, [getobjectname(l) for l in learners], values, weight!=0, basevalue)
+    testResults.results = [TestedExample(indices[i], conv(examples[i].getclass()), nLrn, examples[i].getweight(weight))
+                           for i in range(len(examples))]
+
+    if argkw.get("storeExamples", 0):
+        testResults.examples = examples
+
+    ccsum = hex(examples.checksum())[2:]
+    ppsp = encodePP(pps)
+    fnstr = "{TestWithIndices}_%s_%s%s-%s" % ("%s", indicesrandseed, ppsp, ccsum)
+    if "*" in fnstr:
+        cache = 0
+
+    if cache and testResults.loadFromFiles(learners, fnstr):
+        printVerbose("  loaded from cache", verb)
+    else:
+        # Ship the data to the workers as the text of a .par file (see savePar).
+        examples.save(path())
+        with open(path()) as f: EXAMPLES = f.read()
+        os.remove(path())
+        args = zlib.compress(dumps(
+            (testResults, indices, EXAMPLES, pps, nLrn, cache,
+             storeclassifiers, learners, weight), 2))
+
+        nprocs = int(os.environ.get('ORNG_PAR', '10'))
+        hosts = os.environ.get('ORNG_HOSTS', '').split() or ['localhost']
+        def runssh(fold):
+            host = hosts[fold % len(hosts)]  # round-robin over the hosts
+            cmd = ['orange-worker', str(fold)]
+            if host != 'localhost': cmd = ['ssh', host] + cmd
+            p = Popen(cmd, stdin=PIPE, stdout=PIPE, close_fds=True)
+            out = p.communicate(args)[0]
+            if p.returncode != 0:  # e.g. an exception in the worker; its traceback goes to our stderr
+                raise SystemError("orange-worker for fold %d on %s failed (exit status %d)"
+                                  % (fold, host, p.returncode))
+            return loads(zlib.decompress(out))
+        with ThreadPoolExecutor(nprocs) as executor:
+            rs = list(executor.map(runssh, xrange(nIterations)))
+
+        # Merge. Each worker already classified its own (possibly preprocessed)
+        # test set, so its predictions are taken over as they are.
+        for fold, r in enumerate(rs):
+            if r is None:  # empty learning or test set; nothing was tested
+                continue
+            [classifiers] = r.classifiers
+            if storeclassifiers:
+                testResults.classifiers.append(classifiers)
+            for i in xrange(len(examples)):
+                if indices[i] == fold:
+                    testResults.results[i] = r.results[i]
+            if callback:
+                callback()
+        if cache:
+            testResults.saveToFiles(learners, fnstr)
+
+    return testResults
+
+
+def crossValidation(learners, examples, folds=10,
+                    strat=orange.MakeRandomIndices.StratifiedIfPossible,
+                    pps=[], indicesrandseed="*", **argkw):
+    """Cross-validate learners in parallel: draw fold indices, then run testWithIndices()."""
+    examples, weight = demangleExamples(examples)
+    if indicesrandseed == "*":
+        gen = argkw.get("randseed", 0) or argkw.get("randomGenerator", 0)
+        indices = orange.MakeRandomIndicesCV(examples, folds, stratified=strat, randomGenerator=gen)
+    else:
+        indices = orange.MakeRandomIndicesCV(examples, folds, randseed=indicesrandseed, stratified=strat)
+    return testWithIndices(learners, (examples, weight), indices, indicesrandseed, pps, **argkw)
+
+
+def main(argv=sys.argv):
+    """orange-worker entry point: run fold argv[1] on the zlib-compressed job read from stdin."""
+    job = (int(argv[1]), zlib.decompress(sys.stdin.read()))
+    realStdout, sys.stdout = sys.stdout, StringIO()  # keep stray prints out of the reply
+    reply = do_fold(job)
+    sys.stdout = realStdout
+    realStdout.write(zlib.compress(dumps(reply, 2)))
+
+# vim:ts=4 sw=4
Index: orange-worker
===================================================================
--- orange-worker   (revision 0)
+++ orange-worker   (revision 0)
@@ -0,0 +1,3 @@
+#!/usr/bin/env python
+from orngDist import main  # fold number in argv[1], compressed job on stdin
+main()
Index: setup.py
===================================================================
--- setup.py   (revision 10307)
+++ setup.py   (working copy)
@@ -420,7 +420,7 @@
                       },
       ext_modules = [include_ext, orange_ext, orangeom_ext, orangene_ext, corn_ext, statc_ext],
       extra_path=("orange", "orange"),
-      scripts = ["orange-canvas"],
+      scripts = ["orange-canvas","orange-worker"],
       license = "GNU General Public License (GPL)",
       keywords = ["data mining", "machine learning", "artificial intelligence"],
       classifiers = ["Development Status :: 4 - Beta",

Return to Bugs