source: orange/Orange/data/sql.py @ 9993:087d1df6ace9

Revision 9993:087d1df6ace9, 15.9 KB checked in by crt.gorup@…, 2 years ago (diff)

Code for Orange.data.sql.

Line 
1import os
2import urllib
3import Orange
4from Orange.misc import deprecated_keywords, deprecated_members
5from Orange.feature import Descriptor
6
7def _parseURI(uri):
8    """ lifted straight from sqlobject """
9    schema, rest = uri.split(':', 1)
10    assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
11    if rest.startswith('/') and not rest.startswith('//'):
12        host = None
13        rest = rest[1:]
14    elif rest.startswith('///'):
15        host = None
16        rest = rest[3:]
17    else:
18        rest = rest[2:]
19        if rest.find('/') == -1:
20            host = rest
21            rest = ''
22        else:
23            host, rest = rest.split('/', 1)
24    if host and host.find('@') != -1:
25        user, host = host.split('@', 1)
26        if user.find(':') != -1:
27            user, password = user.split(':', 1)
28        else:
29            password = None
30    else:
31        user = password = None
32    if host and host.find(':') != -1:
33        _host, port = host.split(':')
34        try:
35            port = int(port)
36        except ValueError:
37            raise ValueError, "port must be integer, got '%s' instead" % port
38        if not (1 <= port <= 65535):
39            raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
40        host = _host
41    else:
42        port = None
43    path = '/' + rest
44    if os.name == 'nt':
45        if (len(rest) > 1) and (rest[1] == '|'):
46            path = "%s:%s" % (rest[0], rest[2:])
47    args = {}
48    if path.find('?') != -1:
49        path, arglist = path.split('?', 1)
50        arglist = arglist.split('&')
51        for single in arglist:
52            argname, argvalue = single.split('=', 1)
53            argvalue = urllib.unquote(argvalue)
54            args[argname] = argvalue
55    return schema, user, password, host, port, path, args
56
57class __MySQLQuirkFix(object):
58    def __init__(self, dbmod):
59        self.dbmod = dbmod
60        self.typeDict = {
61            Descriptor.Continuous:'DOUBLE',
62            Descriptor.Discrete:'VARCHAR(250)', Descriptor.String:'VARCHAR(250)'}
63
64    def beforeWrite(self, cursor):
65        cursor.execute("SET sql_mode='ANSI_QUOTES';")
66
67    def beforeCreate(self, cursor):
68        cursor.execute("SET sql_mode='ANSI_QUOTES';")
69
70class __PostgresQuirkFix(object):
71    def __init__(self, dbmod):
72        self.dbmod = dbmod
73        self.typeDict = {
74            Descriptor.Continuous:'FLOAT',
75            Descriptor.Discrete:'VARCHAR', Descriptor.String:'VARCHAR'}
76
77    def beforeWrite(self, cursor):
78        pass
79
80    def beforeCreate(self, cursor):
81        pass
82
83def _connection(uri):
84        (schema, user, password, host, port, path, args) = _parseURI(uri)
85        argTrans = {
86            'host':'host',
87            'port':'port',
88            'user':'user',
89            'password':'passwd',
90            'database':'db'
91            }
92        if schema == 'postgres':
93            import psycopg2 as dbmod
94            argTrans["database"] = "db"
95            quirks = __PostgresQuirkFix(dbmod)
96        elif schema == 'mysql':
97            import MySQLdb as dbmod
98            quirks = __MySQLQuirkFix(dbmod)
99
100        dbArgDict = {}
101
102        if user:
103            dbArgDict[argTrans['user']] = user
104        if password:
105            dbArgDict[argTrans['password']] = password
106        if host:
107            dbArgDict[argTrans['host']] = host
108        if port:
109            dbArgDict[argTrans['port']] = port
110        if path:
111            dbArgDict[argTrans['database']] = path[1:]
112        return (quirks, dbmod.connect(**dbArgDict))
113
114class SQLReader(object):
115    """
116    :obj:`~SQLReader` establishes a connection with a database and provides the methods needed
117    to fetch the data from the database into Orange.
118    """
119    @deprecated_keywords({"domainDepot":"domain_depot"})
120    def __init__(self, addr = None, domain_depot = None):
121        """
122        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters]).
123        :type uri: str
124
125        :param domain_depot: Domain depot
126        :type domain_depot: :class:`orange.DomainDepot`
127        """
128        if uri is not None:
129            self.connect(uri)
130        if domain_depot is not None:
131            self.domainDepot = domain_depot
132        else:
133            self.domainDepot = orange.DomainDepot()
134        self.exampleTable = None
135        self._dirty = True
136
137    def connect(self, uri):
138        """
139        Connect to the database.
140
141        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
142        :type uri: str
143        """
144        self._dirty = True
145        self.delDomain()
146        (self.quirks, self.conn) = _connection(uri)
147
148    def disconnect(self):
149        """
150        Disconnect from the database.
151        """
152        self.conn.disconnect()
153
154    def getClassName(self):
155        self.update()
156        return self.domain.class_var.name
157
158    def setClassName(self, className):
159        self._className = className
160        self.delDomain()
161
162    def delClassName(self):
163        del self._className
164
165    class_name = property(getClassName, setClassName, delClassName, "Name of class variable.")
166    className = class_name
167   
168    def getMetaNames(self):
169        self.update()
170        return self.domain.get_metas().values()
171
172    def setMetaNames(self, meta_names):
173        self._metaNames = meta_names
174        self.delDomain()
175
176    def delMetaNames(self):
177        del self._metaNames
178
179    meta_names = property(getMetaNames, setMetaNames, delMetaNames, "Names of meta attributes.")
180    metaName = meta_names
181
182    def setDiscreteNames(self, discrete_names):
183        self._discreteNames = discrete_names
184        self.delDomain()
185
186    def getDiscreteNames(self):
187        self.update()
188        return self._discreteNames
189
190    def delDiscreteNames(self):
191        del self._discreteNames
192
193    discrete_names = property(getDiscreteNames, setDiscreteNames, delDiscreteNames, "Names of discrete attributes.")
194    discreteNames = discrete_names
195
196    def setQuery(self, query, domain = None):
197        #sets the query, resets the internal variables, without executing the query
198        self._query = query
199        self._dirty = True
200        if domain is not None:
201            self._domain = domain
202        else:
203            self.delDomain()
204
205    def getQuery(self):
206        return self._query
207
208    def delQuery(self):
209        del self._query
210
211    query = property(getQuery, setQuery, delQuery, "Query to be executed on the next execute().")
212
213    def generateDomain(self):
214        pass
215
216    def setDomain(self, domain):
217        self._domain = domain
218        self._dirty = True
219
220    def getDomain(self):
221        if not hasattr(self, '_domain'):
222            self._createDomain()
223        return self._domain
224
225    def delDomain(self):
226        if hasattr(self, '_domain'):
227            del self._domain
228
229    domain = property(getDomain, setDomain, delDomain, "Orange domain.")
230
231    def execute(self, query, domain = None):
232        """
233        Executes an sql query.
234        """
235        self.setQuery(query, domain)
236        self.update()
237
238    def _createDomain(self):
239        if hasattr(self, '_domain'):
240            return
241        attrNames = []
242        if not hasattr(self, '_discreteNames'):
243            self._discreteNames = []
244        discreteNames = self._discreteNames
245        if not hasattr(self, '_metaNames'):
246            self._metaNames = []
247        metaNames = self._metaNames
248        if not hasattr(self, '_className'):
249            className = None
250        else:
251            className = self._className
252        for i in self.desc:
253            name = i[0]
254            typ = i[1]
255            if name in discreteNames:
256                attrName = 'D#' + name
257            elif typ == self.quirks.dbmod.STRING:
258                    attrName = 'S#' + name
259            elif typ == self.quirks.dbmod.DATETIME:
260                attrName = 'S#' + name
261            else:
262                attrName = 'C#' + name
263           
264            if name == className:
265                attrName = "c" + attrName
266            elif name in metaNames:
267                attrName = "m" + attrName
268            elif not className and name == 'class':
269                attrName = "c" + attrName
270            attrNames.append(attrName)
271        (self._domain, self._metaIDs, dummy) = self.domainDepot.prepareDomain(attrNames)
272        del dummy
273
274    def update(self):
275        """
276        Execute a pending SQL query.
277        """
278        if not self._dirty and hasattr(self, '_domain'):
279            return self.exampleTable
280        self.exampleTable = None
281        try:
282            curs = self.conn.cursor()
283            try:
284                curs.execute(self.query)
285            except Exception, e:
286                self.conn.rollback()
287                raise e
288            self.desc = curs.description
289            # for reasons unknown, the attributes get reordered.
290            domainIndexes = [0] * len(self.desc)
291            self._createDomain()
292            attrNames = []
293            for i, name in enumerate(self.desc):
294            #    print name[0], '->', self.domain.index(name[0])
295                domainIndexes[self._domain.index(name[0])] = i
296                attrNames.append(name[0])
297            self.exampleTable = Orange.data.Table(self.domain)
298            r = curs.fetchone()
299            while r:
300                # for reasons unknown, domain rearranges the properties
301                example = Orange.data.Instance(self.domain)
302                for i in xrange(len(r)):
303                    val = str(r[i])
304                    var = example[attrNames[i]].variable
305                    if type(var) == Descriptor.Discrete and val not in var.values:
306                        var.values.append(val)
307                    example[attrNames[i]] = str(r[i])
308                self.exampleTable.append(example)
309                r = curs.fetchone()
310            self._dirty = False
311        except Exception, e:
312            self.domain = None
313            raise
314            #self.domain = None
315
316    def data(self):
317        """
318        Return :class:`Orange.data.Table` produced by the last executed query.
319        """
320        self.update()
321        if self.exampleTable:
322            return self.exampleTable
323        return None
324
325class SQLWriter(object):
326    """
327    Establishes a connection with a database and provides the methods needed to create
328    an appropriate table in the database and/or write the data from an :class:`Orange.data.Table`
329    into the database.
330    """
331    def __init__(self, uri = None):
332        """
333        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
334        :type uri: str
335        """
336        if uri is not None:
337            self.connect(uri)
338
339    def connect(self, uri):
340        """
341        Connect to the database.
342
343        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
344        :type uri: str
345        """
346        (self.quirks, self.connection) = _connection(uri)
347
348    def __attrVal2sql(self, d):
349        if d.var_type == Descriptor.Continuous:
350            return d.value
351        elif d.var_type == Descriptor.Discrete:
352            return str(d.value)
353        else:
354            return "'%s'" % str(d.value)
355
356    def __attrName2sql(self, d):
357        return d.name
358
359    def __attrType2sql(self, d):
360        return self.quirks.typeDict[d]
361
362    @deprecated_keywords({"renameDict":"rename_dict"})
363    def write(self, table, instances, rename_dict = None):
364        """
365        Writes the data into the table.
366
367
368        :param table: Table name.
369        :type table: str
370
371        :param instances: Data to be written into the database.
372        :type instances: :class:`Orange.data.Table`
373
374        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
375            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
376        :type rename_dict: dict
377
378        """
379        l = [i.name for i in instances.domain.attributes]
380        l += [i.name for i in instances.domain.get_metas().values()]
381        if instances.domain.class_var:
382            l.append(instances.domain.class_var.name)
383        if rename_dict is None:
384            rename_dict = {}
385        colList = []
386        for i in l:
387            colList.append(rename_dict.get(str(i), str(i)))
388        try:
389            cursor=self.connection.cursor()
390            self.quirks.beforeWrite(cursor)
391            query = 'INSERT INTO "%s" (%s) VALUES (%s);'
392            for d in instances:
393                valList = []
394                colSList = []
395                for (i, name) in enumerate(colList):
396                    colSList.append('"%s"'% name)
397                    valList.append(self.__attrVal2sql(d[l[i]]))
398                valStr = ', '.join(["%s"]*len(colList))
399                cursor.execute(query % (table,
400                    ", ".join(colSList),
401                    ", ".join (["%s"] * len(valList))), tuple(valList))
402            cursor.close()
403            self.connection.commit()
404        except Exception, e:
405            import traceback
406            traceback.print_exc()
407            self.connection.rollback()
408
409    @deprecated_keywords({"renameDict":"rename_dict", "typeDict":"type_dict"})
410    def create(self, table, instances, rename_dict = None, type_dict = None):
411        """
412        Create the required SQL table, then write the data into it.
413
414        :param table: Table name
415        :type table: str
416
417        :param instances: Data to be written into the database.
418        :type instances: :class:`Orange.data.Table`
419
420        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
421            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
422        :type rename_dict: dict
423
424        :param type_dict: When ``type_dict`` is provided the used variables are casted into new types.
425            The type of orange attribute "X" is casted into the database column of type rename_dict["X"].
426        :type type_dict: dict
427
428        """
429        l = [(i.name, i.var_type ) for i in instances.domain.attributes]
430        l += [(i.name, i.var_type ) for i in instances.domain.get_metas().values()]
431        if instances.domain.class_var:
432            l.append((instances.domain.class_var.name, instances.domain.class_var.var_type))
433        if rename_dict is None:
434            renameDict = {}
435        colNameList = [rename_dict.get(str(i[0]), str(i[0])) for i in l]
436        if type_dict is None:
437            typeDict = {}
438        colTypeList = [type_dict.get(str(i[0]), self.__attrType2sql(i[1])) for i in l]
439        try:
440            cursor = self.connection.cursor()
441            colSList = []
442            for (i, name) in enumerate(colNameList):
443                colSList.append('"%s" %s' % (name, colTypeList[i]))
444            colStr = ", ".join(colSList)
445            query = """CREATE TABLE "%s" ( %s );""" % (table, colStr)
446            self.quirks.beforeCreate(cursor)
447            cursor.execute(query)
448            print query
449            self.write(table, instances, rename_dict)
450            self.connection.commit()
451        except Exception, e:
452            self.connection.rollback()
453
454    def disconnect(self):
455        """
456        Disconnect from the database.
457        """
458        self.conn.disconnect()
459
460def loadSQL(filename, dontCheckStored = False, domain = None):
461    f = open(filename)
462    lines = f.readlines()
463    queryLines = []
464    discreteNames = None
465    uri = None
466    metaNames = None
467    className = None
468    for i in lines:
469        if i.startswith("--orng"):
470            (dummy, command, line) = i.split(None, 2)
471            if command == 'uri':
472                uri = eval(line)
473            elif command == 'discrete':
474                discreteNames = eval(line)
475            elif command == 'meta':
476                metaNames = eval(line)
477            elif command == 'class':
478                className = eval(line)
479            else:
480                queryLines.append(i)
481        else:
482            queryLines.append(i)
483    query = "\n".join(queryLines)
484    r = SQLReader(uri)
485    if discreteNames:
486        r.discreteNames = discreteNames
487    if className:
488        r.className = className
489    if metaNames:
490        r.metaNames = metaNames
491    r.execute(query)
492    return r.data()
Note: See TracBrowser for help on using the repository browser.