source: orange/orange/orngSQL.py @ 8042:ffcb93bc9028

Revision 8042:ffcb93bc9028, 14.9 KB checked in by markotoplak, 3 years ago (diff)

Hierarchical clustering: also catch RuntimeError when importing matplotlib (or the documentation could not be built on server).

Line 
1# A module to read data from an SQL database into an Orange ExampleTable.
2# The goal is to keep it compatible with PEP 249.
3# For now, the writing shall be basic, if it works at all.
4
5import orange
6import os
7import urllib
8
9def _parseURI(uri):
10    """ lifted straight from sqlobject """
11    schema, rest = uri.split(':', 1)
12    assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
13    if rest.startswith('/') and not rest.startswith('//'):
14        host = None
15        rest = rest[1:]
16    elif rest.startswith('///'):
17        host = None
18        rest = rest[3:]
19    else:
20        rest = rest[2:]
21        if rest.find('/') == -1:
22            host = rest
23            rest = ''
24        else:
25            host, rest = rest.split('/', 1)
26    if host and host.find('@') != -1:
27        user, host = host.split('@', 1)
28        if user.find(':') != -1:
29            user, password = user.split(':', 1)
30        else:
31            password = None
32    else:
33        user = password = None
34    if host and host.find(':') != -1:
35        _host, port = host.split(':')
36        try:
37            port = int(port)
38        except ValueError:
39            raise ValueError, "port must be integer, got '%s' instead" % port
40        if not (1 <= port <= 65535):
41            raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
42        host = _host
43    else:
44        port = None
45    path = '/' + rest
46    if os.name == 'nt':
47        if (len(rest) > 1) and (rest[1] == '|'):
48            path = "%s:%s" % (rest[0], rest[2:])
49    args = {}
50    if path.find('?') != -1:
51        path, arglist = path.split('?', 1)
52        arglist = arglist.split('&')
53        for single in arglist:
54            argname, argvalue = single.split('=', 1)
55            argvalue = urllib.unquote(argvalue)
56            args[argname] = argvalue
57    return schema, user, password, host, port, path, args
58
59class __DummyQuirkFix:
60    def __init__(self, dbmod):
61        self.dbmod = dbmod
62        self.typeDict = {
63            orange.VarTypes.Continuous:'FLOAT', 
64            orange.VarTypes.Discrete:'VARCHAR(250)', orange.VarTypes.String:'VARCHAR(250)'}
65    def beforeWrite(self, cursor):
66        pass
67    def beforeCreate(self, cursor):
68        pass
69    def beforeRead(self, cursor):
70        pass
71class __MySQLQuirkFix(__DummyQuirkFix):
72    def __init__(self, dbmod):
73        self.dbmod = dbmod
74        self.BOOLEAN = None
75        self.STRING = dbmod.STRING
76        self.DATETIME = dbmod.DATETIME
77        self.typeDict = {
78            orange.VarTypes.Continuous:'DOUBLE', 
79            orange.VarTypes.Discrete:'VARCHAR(250)', orange.VarTypes.String:'VARCHAR(250)'}
80    def beforeWrite(self, cursor):
81        cursor.execute("SET sql_mode='ANSI_QUOTES';")
82    def beforeCreate(self, cursor):
83        cursor.execute("SET sql_mode='ANSI_QUOTES';")
84    def beforeRead(self, cursor):
85        pass
86class __PostgresQuirkFix(__DummyQuirkFix):
87    def __init__(self, dbmod):
88        self.dbmod = dbmod
89        self.BOOLEAN = 16
90        self.STRING = dbmod.STRING
91        self.DATETIME = dbmod.DATETIME
92        self.typeDict = {
93            orange.VarTypes.Continuous:'FLOAT', 
94            orange.VarTypes.Discrete:'VARCHAR', orange.VarTypes.String:'VARCHAR'}
95    def beforeWrite(self, cursor):
96        pass
97    def beforeCreate(self, cursor):
98        pass
99    def beforeRead(self, cursor):
100        pass
101
102def _connection(uri):
103        """the uri string's syntax is the same as that of sqlobject.
104        Unfortunately, only postgres and mysql are going to be supported in
105        the near future.
106        scheme://[user[:password]@]host[:port]/database[?parameters]
107        Examples:
108        mysql://user:password@host/database
109        mysql://host/database?debug=1
110        postgres://user@host/database?debug=&cache=
111        postgres:///full/path/to/socket/database
112        postgres://host:5432/database
113        """
114        (schema, user, password, host, port, path, args) = _parseURI(uri)
115        if schema == 'postgres':
116            import psycopg2 as dbmod
117            argTrans = {
118            'host':'host',
119            'port':'port',
120            'user':'user',
121            'password':'password',
122            'database':'database'
123            }
124            quirks = __PostgresQuirkFix(dbmod)
125        elif schema == 'mysql':
126            import MySQLdb as dbmod
127            argTrans = {
128            'host':'host',
129            'port':'port',
130            'user':'user',
131            'password':'passwd',
132            'database':'db'
133            }
134            quirks = __MySQLQuirkFix(dbmod)
135        dbArgDict = {}
136        if user:
137            dbArgDict[argTrans['user']] = user
138        if password:
139            dbArgDict[argTrans['password']] = password
140        if host:
141            dbArgDict[argTrans['host']] = host
142        if port:
143            dbArgDict[argTrans['port']] = port
144        if path:
145            dbArgDict[argTrans['database']] = path[1:]
146        return (quirks, dbmod.connect(**dbArgDict))
147
148class SQLReader(object):
149    def __init__(self, addr = None, domainDepot = None):
150        if addr is not None:
151            self.connect(addr)
152        if domainDepot is not None:
153            self.domainDepot = domainDepot
154        else:
155            self.domainDepot = orange.DomainDepot()
156        self.exampleTable = None
157        self._dirty = True
158    def connect(self, uri):
159        self._dirty = True
160        self.delDomain()
161        (self.quirks, self.conn) = _connection(uri)
162    def disconnect(self):
163        self.conn.disconnect()
164    def getClassName(self):
165        self.update()
166        return self.domain.classVar.name
167    def setClassName(self, className):
168        self._className = className
169        self.delDomain()
170    def delClassName(self):
171        del self._className
172    className = property(getClassName, setClassName, delClassName, "the name of the class variable")
173
174    def getMetaNames(self):
175        self.update()
176        return self.domain.getmetas().values()
177    def setMetaNames(self, metaNames):
178        self._metaNames = metaNames
179        self.delDomain()
180    def delMetaNames(self):
181        del self._metaNames
182    metaNames = property(getMetaNames, setMetaNames, delMetaNames, "the names of the meta attributes")
183
184    def setDiscreteNames(self, discreteNames):
185        self._discreteNames = discreteNames
186        self.delDomain()
187    def getDiscreteNames(self):
188        self.update()
189        return self._discreteNames
190    def delDiscreteNames(self):
191        del self._discreteNames
192    discreteNames = property(getDiscreteNames, setDiscreteNames, delDiscreteNames, "the names of the discrete attributes")
193
194    def setQuery(self, query, domain = None):
195        """sets the query, resets the internal variables, without executing the query"""
196        self._query = query
197        self._dirty = True
198        if domain is not None:
199            self._domain = domain
200        else:
201            self.delDomain()
202    def getQuery(self):
203        return self._query
204    def delQuery(self):
205        del self._query
206    query = property(getQuery, setQuery, delQuery, "The query to be executed on the next execute()")
207    def generateDomain(self):
208        pass
209    def setDomain(self, domain):
210        self._domain = domain
211        self._dirty = True
212    def getDomain(self):
213        if not hasattr(self, '_domain'):
214            self._createDomain()
215        return self._domain
216    def delDomain(self):
217        if hasattr(self, '_domain'):
218            del self._domain
219    domain = property(getDomain, setDomain, delDomain, "the Orange domain")
220    def execute(self, query, domain = None):
221        """executes an sql query"""
222        self.setQuery(query, domain)
223        self.update()
224       
225    def _createDomain(self):
226        if hasattr(self, '_domain'):
227            return
228        attrNames = []
229        if not hasattr(self, '_discreteNames'):
230            self._discreteNames = []
231        discreteNames = self._discreteNames
232        if not hasattr(self, '_metaNames'):
233            self._metaNames = []
234        metaNames = self._metaNames
235        if not hasattr(self, '_className'):
236            className = None
237        else:
238            className = self._className
239        for i in self.desc:
240            name = i[0]
241            typ = i[1]
242            if name in discreteNames or typ == self.quirks.BOOLEAN:
243                attrName = 'D#' + name
244            elif typ == self.quirks.STRING:
245                    attrName = 'S#' + name
246            elif typ == self.quirks.DATETIME:
247                attrName = 'S#' + name
248            else:
249                attrName = 'C#' + name
250            if name == className:
251                attrName = "c" + attrName
252            elif name in metaNames:
253                attrName = "m" + attrName
254            elif not className and name == 'class':
255                attrName = "c" + attrName
256            attrNames.append(attrName)
257    #       print "NAME:", '"%s"' % name, ", t:", typ, " attrN:", '"%s"' % attrName
258        (self._domain, self._metaIDs, dummy) = self.domainDepot.prepareDomain(attrNames)
259 #           print "Created domain."
260        del dummy
261
262       
263    def update(self):
264        if not self._dirty and hasattr(self, '_domain'):
265            return self.exampleTable
266        self.exampleTable = None
267        try:
268            curs = self.conn.cursor()
269            try:
270                self.quirks.beforeRead(curs)
271                curs.execute(self.query)
272            except Exception, e:
273                self.conn.rollback()
274                raise e
275            self.desc = curs.description
276            # for reasons unknown, the attributes get reordered.
277            domainIndexes = [0] * len(self.desc)
278            self._createDomain()
279            attrNames = []
280            for i, name in enumerate(self.desc):
281            #    print name[0], '->', self.domain.index(name[0])
282                domainIndexes[self._domain.index(name[0])] = i
283                attrNames.append(name[0])
284            self.exampleTable = orange.ExampleTable(self.domain)
285            r = curs.fetchone()
286            while r:
287                # for reasons unknown, domain rearranges the properties
288                example = orange.Example(self.domain)
289                for i in xrange(len(r)):
290                    if r[i] is not None:
291                        val = str(r[i])
292                        var = example[attrNames[i]].variable
293                        if type(var) == orange.EnumVariable and val not in var.values:
294                            var.values.append(val)
295                        example[attrNames[i]] = str(r[i])
296                self.exampleTable.append(example)
297                r = curs.fetchone()
298            self._dirty = False
299        except Exception, e:
300            self.domain = None
301            raise
302            #self.domain = None
303
304    def data(self):
305        self.update()
306        if self.exampleTable:
307            return self.exampleTable
308        return None
309   
310class SQLWriter(object):
311    def __init__(self, uri = None):
312        if uri is not None:
313            self.connect(uri)
314   
315    def connect(self, uri):
316        (self.quirks, self.connection) = _connection(uri)
317    def __attrVal2sql(self, d):
318        if d.varType == orange.VarTypes.Continuous:
319            return d.value
320        elif d.varType == orange.VarTypes.Discrete:
321            return str(d.value)
322        else:
323            return "'%s'" % str(d.value)
324    def __attrName2sql(self, d):
325        return d.name
326    def __attrType2sql(self, d):
327        return self.quirks.typeDict[d]
328    def write(self, table, data, renameDict = None):
329        """if provided, renameDict maps the names in data to columns in
330        the database. For each var in data: dbColName = renameDict[var.name]"""
331        l = [i.name for i in data.domain.attributes]
332        l += [i.name for i in data.domain.getmetas().values()]
333        if data.domain.classVar:
334            l.append(data.domain.classVar.name)
335        if renameDict is None:
336            renameDict = {}
337        colList = []
338        for i in l:
339            colList.append(renameDict.get(str(i), str(i)))
340        try:
341            cursor=self.connection.cursor()
342            self.quirks.beforeWrite(cursor)
343            query = 'INSERT INTO "%s" (%s) VALUES (%s);'
344            for d in data:
345                valList = []
346                colSList = []
347                for (i, name) in enumerate(colList):
348                    colSList.append('"%s"'% name)
349                    valList.append(self.__attrVal2sql(d[l[i]]))
350                valStr = ', '.join(["%s"]*len(colList))
351                # print "exec:", query % (table, "%s ", "%s "), tuple(colList + valList)
352                cursor.execute(query % (table, 
353                    ", ".join(colSList), 
354                    ", ".join (["%s"] * len(valList))), tuple(valList))
355            cursor.close()
356            self.connection.commit()
357        except Exception, e:
358            import traceback
359        traceback.print_exc()
360            self.connection.rollback()
361
362    def create(self, table, data, renameDict = None, typeDict = None):
363        l = [(i.name, i.varType ) for i in data.domain.attributes]
364        l += [(i.name, i.varType ) for i in data.domain.getmetas().values()]
365        if data.domain.classVar:
366            l.append((data.domain.classVar.name, data.domain.classVar.varType))
367        if renameDict is None:
368            renameDict = {}
369        colNameList = [renameDict.get(str(i[0]), str(i[0])) for i in l]
370        if typeDict is None:
371            typeDict = {}
372        colTypeList = [typeDict.get(str(i[0]), self.__attrType2sql(i[1])) for i in l]
373        try:
374            cursor = self.connection.cursor()
375            colSList = []
376            for (i, name) in enumerate(colNameList):
377                colSList.append('"%s" %s' % (name, colTypeList[i]))
378            colStr = ", ".join(colSList)
379            query = """CREATE TABLE "%s" ( %s );""" % (table, colStr)
380            self.quirks.beforeCreate(cursor)
381            cursor.execute(query)
382            print query
383            self.write(table, data, renameDict)
384            self.connection.commit()
385        except Exception, e:
386            self.connection.rollback()
387   
388    def disconnect(self):
389        self.conn.disconnect()
390
391def loadSQL(filename, dontCheckStored = False, domain = None):
392    f = open(filename)
393    lines = f.readlines()
394    queryLines = []
395    discreteNames = None
396    uri = None
397    metaNames = None
398    className = None
399    for i in lines:
400        if i.startswith("--orng"):
401            (dummy, command, line) = i.split(None, 2)
402            if command == 'uri':
403                uri = eval(line)
404            elif command == 'discrete':
405                discreteNames = eval(line)
406            elif command == 'meta':
407                metaNames = eval(line)
408            elif command == 'class':
409                className = eval(line)
410            else:
411                queryLines.append(i)
412        else:
413            queryLines.append(i)
414    query = "\n".join(queryLines)
415    r = SQLReader(uri)
416    if discreteNames:
417        r.discreteNames = discreteNames
418    if className:
419        r.className = className
420    if metaNames:
421        r.metaNames = metaNames
422    r.execute(query)
423    data = r.data()
424    return data
425
426def saveSQL():
427    pass
Note: See TracBrowser for help on using the repository browser.