source: orange/Orange/data/sql.py @ 10089:e2d97a4e3237

Revision 10089:e2d97a4e3237, 16.3 KB checked in by crt.gorup@…, 2 years ago (diff)

Added support for sqlite.

Line 
1import os
2import urllib
3import Orange
4import orange
5from Orange.misc import deprecated_keywords, deprecated_members
6from Orange.feature import Descriptor
7
8def _parseURI(uri):
9    """ lifted straight from sqlobject """
10    schema, rest = uri.split(':', 1)
11    assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
12    if rest.startswith('/') and not rest.startswith('//'):
13        host = None
14        rest = rest[1:]
15    elif rest.startswith('///'):
16        host = None
17        rest = rest[3:]
18    else:
19        rest = rest[2:]
20        if rest.find('/') == -1:
21            host = rest
22            rest = ''
23        else:
24            host, rest = rest.split('/', 1)
25    if host and host.find('@') != -1:
26        user, host = host.split('@', 1)
27        if user.find(':') != -1:
28            user, password = user.split(':', 1)
29        else:
30            password = None
31    else:
32        user = password = None
33    if host and host.find(':') != -1:
34        _host, port = host.split(':')
35        try:
36            port = int(port)
37        except ValueError:
38            raise ValueError, "port must be integer, got '%s' instead" % port
39        if not (1 <= port <= 65535):
40            raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
41        host = _host
42    else:
43        port = None
44    path = '/' + rest
45    if os.name == 'nt':
46        if (len(rest) > 1) and (rest[1] == '|'):
47            path = "%s:%s" % (rest[0], rest[2:])
48    args = {}
49    if path.find('?') != -1:
50        path, arglist = path.split('?', 1)
51        arglist = arglist.split('&')
52        for single in arglist:
53            argname, argvalue = single.split('=', 1)
54            argvalue = urllib.unquote(argvalue)
55            args[argname] = argvalue
56    return schema, user, password, host, port, path, args
57
58class __MySQLQuirkFix(object):
59    def __init__(self, dbmod):
60        self.dbmod = dbmod
61        self.typeDict = {
62            Descriptor.Continuous:'DOUBLE',
63            Descriptor.Discrete:'VARCHAR(250)', Descriptor.String:'VARCHAR(250)'}
64
65    def beforeWrite(self, cursor):
66        cursor.execute("SET sql_mode='ANSI_QUOTES';")
67
68    def beforeCreate(self, cursor):
69        cursor.execute("SET sql_mode='ANSI_QUOTES';")
70
71class __PostgresQuirkFix(object):
72    def __init__(self, dbmod):
73        self.dbmod = dbmod
74        self.typeDict = {
75            Descriptor.Continuous:'FLOAT',
76            Descriptor.Discrete:'VARCHAR', Descriptor.String:'VARCHAR'}
77
78    def beforeWrite(self, cursor):
79        pass
80
81    def beforeCreate(self, cursor):
82        pass
83
84def _connection(uri):
85        (schema, user, password, host, port, path, args) = _parseURI(uri)
86        argTrans = {
87            'host':'host',
88            'port':'port',
89            'user':'user',
90            'password':'passwd',
91            'database':'db'
92            }
93        if schema == 'postgres':
94            import psycopg2 as dbmod
95            argTrans["database"] = "db"
96            quirks = __PostgresQuirkFix(dbmod)
97            quirks.parameter = "%s"
98        elif schema == 'mysql':
99            import MySQLdb as dbmod
100            quirks = __MySQLQuirkFix(dbmod)
101            quirks.parameter = "%s"
102        elif schema == "sqlite":
103            import sqlite3 as dbmod
104            quirks = __PostgresQuirkFix(dbmod)
105            quirks.parameter = "?"
106            return (quirks, dbmod.connect(host))
107
108        dbArgDict = {}
109        if user:
110            dbArgDict[argTrans['user']] = user
111        if password:
112            dbArgDict[argTrans['password']] = password
113        if host:
114            dbArgDict[argTrans['host']] = host
115        if port:
116            dbArgDict[argTrans['port']] = port
117        if path:
118            dbArgDict[argTrans['database']] = path[1:]
119        return (quirks, dbmod.connect(**dbArgDict))
120
121class SQLReader(object):
122    """
123    :obj:`~SQLReader` establishes a connection with a database and provides the methods needed
124    to fetch the data from the database into Orange.
125    """
126    @deprecated_keywords({"domainDepot":"domain_depot"})
127    def __init__(self, addr = None, domain_depot = None):
128        """
129        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters]).
130        :type uri: str
131
132        :param domain_depot: Domain depot
133        :type domain_depot: :class:`orange.DomainDepot`
134        """
135        if addr is not None:
136            self.connect(addr)
137        if domain_depot is not None:
138            self.domainDepot = domain_depot
139        else:
140            self.domainDepot = orange.DomainDepot()
141        self.exampleTable = None
142        self._dirty = True
143
144    def connect(self, uri):
145        """
146        Connect to the database.
147
148        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
149        :type uri: str
150        """
151        self._dirty = True
152        self.delDomain()
153        (self.quirks, self.conn) = _connection(uri)
154
155    def disconnect(self):
156        """
157        Disconnect from the database.
158        """
159        func = getattr(self.conn, "disconnect", None)
160        if callable(func):
161            self.conn.disconnect()
162
163    def getClassName(self):
164        self.update()
165        return self.domain.class_var.name
166
167    def setClassName(self, className):
168        self._className = className
169        self.delDomain()
170
171    def delClassName(self):
172        del self._className
173
174    class_name = property(getClassName, setClassName, delClassName, "Name of class variable.")
175    className = class_name
176   
177    def getMetaNames(self):
178        self.update()
179        return self.domain.get_metas().values()
180
181    def setMetaNames(self, meta_names):
182        self._metaNames = meta_names
183        self.delDomain()
184
185    def delMetaNames(self):
186        del self._metaNames
187
188    meta_names = property(getMetaNames, setMetaNames, delMetaNames, "Names of meta attributes.")
189    metaName = meta_names
190
191    def setDiscreteNames(self, discrete_names):
192        self._discreteNames = discrete_names
193        self.delDomain()
194
195    def getDiscreteNames(self):
196        self.update()
197        return self._discreteNames
198
199    def delDiscreteNames(self):
200        del self._discreteNames
201
202    discrete_names = property(getDiscreteNames, setDiscreteNames, delDiscreteNames, "Names of discrete attributes.")
203    discreteNames = discrete_names
204
205    def setQuery(self, query, domain = None):
206        #sets the query, resets the internal variables, without executing the query
207        self._query = query
208        self._dirty = True
209        if domain is not None:
210            self._domain = domain
211        else:
212            self.delDomain()
213
214    def getQuery(self):
215        return self._query
216
217    def delQuery(self):
218        del self._query
219
220    query = property(getQuery, setQuery, delQuery, "Query to be executed on the next execute().")
221
222    def generateDomain(self):
223        pass
224
225    def setDomain(self, domain):
226        self._domain = domain
227        self._dirty = True
228
229    def getDomain(self):
230        if not hasattr(self, '_domain'):
231            self._createDomain()
232        return self._domain
233
234    def delDomain(self):
235        if hasattr(self, '_domain'):
236            del self._domain
237
238    domain = property(getDomain, setDomain, delDomain, "Orange domain.")
239
240    def execute(self, query, domain = None):
241        """
242        Executes an sql query.
243        """
244        self.setQuery(query, domain)
245        self.update()
246
247    def _createDomain(self):
248        if hasattr(self, '_domain'):
249            return
250        attrNames = []
251        if not hasattr(self, '_discreteNames'):
252            self._discreteNames = []
253        discreteNames = self._discreteNames
254        if not hasattr(self, '_metaNames'):
255            self._metaNames = []
256        metaNames = self._metaNames
257        if not hasattr(self, '_className'):
258            className = None
259        else:
260            className = self._className
261        for i in self.desc:
262            name = i[0]
263            typ = i[1]
264            if name in discreteNames:
265                attrName = 'D#' + name
266            elif typ is None or typ in [self.quirks.dbmod.STRING, self.quirks.dbmod.DATETIME]:
267                    attrName = 'S#' + name
268            else:
269                attrName = 'C#' + name
270           
271            if name == className:
272                attrName = "c" + attrName
273            elif name in metaNames:
274                attrName = "m" + attrName
275            elif not className and name == 'class':
276                attrName = "c" + attrName
277            attrNames.append(attrName)
278        (self._domain, self._metaIDs, dummy) = self.domainDepot.prepareDomain(attrNames)
279        del dummy
280
281    def update(self):
282        """
283        Execute a pending SQL query.
284        """
285        if not self._dirty and hasattr(self, '_domain'):
286            return self.exampleTable
287        self.exampleTable = None
288        try:
289            curs = self.conn.cursor()
290            try:
291                curs.execute(self.query)
292            except Exception, e:
293                self.conn.rollback()
294                raise e
295            self.desc = curs.description
296            # for reasons unknown, the attributes get reordered.
297            domainIndexes = [0] * len(self.desc)
298            self._createDomain()
299            attrNames = []
300            for i, name in enumerate(self.desc):
301            #    print name[0], '->', self.domain.index(name[0])
302                domainIndexes[self._domain.index(name[0])] = i
303                attrNames.append(name[0])
304            self.exampleTable = Orange.data.Table(self.domain)
305            r = curs.fetchone()
306            while r:
307                # for reasons unknown, domain rearranges the properties
308                example = Orange.data.Instance(self.domain)
309                for i in xrange(len(r)):
310                    val = str(r[i])
311                    var = example[attrNames[i]].variable
312                    if type(var) == Descriptor.Discrete and val not in var.values:
313                        var.values.append(val)
314                    example[attrNames[i]] = str(r[i])
315                self.exampleTable.append(example)
316                r = curs.fetchone()
317            self._dirty = False
318        except Exception, e:
319            self.domain = None
320            raise
321            #self.domain = None
322
323    def data(self):
324        """
325        Return :class:`Orange.data.Table` produced by the last executed query.
326        """
327        self.update()
328        if self.exampleTable:
329            return self.exampleTable
330        return None
331
332class SQLWriter(object):
333    """
334    Establishes a connection with a database and provides the methods needed to create
335    an appropriate table in the database and/or write the data from an :class:`Orange.data.Table`
336    into the database.
337    """
338    def __init__(self, uri = None):
339        """
340        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
341        :type uri: str
342        """
343        if uri is not None:
344            self.connect(uri)
345
346    def connect(self, uri):
347        """
348        Connect to the database.
349
350        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
351        :type uri: str
352        """
353        (self.quirks, self.connection) = _connection(uri)
354
355    def __attrVal2sql(self, d):
356        if d.var_type == Descriptor.Continuous:
357            return d.value
358        elif d.var_type == Descriptor.Discrete:
359            return str(d.value)
360        else:
361            return "'%s'" % str(d.value)
362
363    def __attrName2sql(self, d):
364        return d.name
365
366    def __attrType2sql(self, d):
367        return self.quirks.typeDict[d]
368
369    @deprecated_keywords({"renameDict":"rename_dict"})
370    def write(self, table, instances, rename_dict = None):
371        """
372        Writes the data into the table.
373
374
375        :param table: Table name.
376        :type table: str
377
378        :param instances: Data to be written into the database.
379        :type instances: :class:`Orange.data.Table`
380
381        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
382            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
383        :type rename_dict: dict
384
385        """
386        l = [i.name for i in instances.domain.attributes]
387        l += [i.name for i in instances.domain.get_metas().values()]
388        if instances.domain.class_var:
389            l.append(instances.domain.class_var.name)
390        if rename_dict is None:
391            rename_dict = {}
392        colList = []
393        for i in l:
394            colList.append(rename_dict.get(str(i), str(i)))
395        try:
396            cursor=self.connection.cursor()
397            self.quirks.beforeWrite(cursor)
398            query = 'INSERT INTO "%s" (%s) VALUES (%s);'
399            for d in instances:
400                valList = []
401                colSList = []
402                for (i, name) in enumerate(colList):
403                    colSList.append('"%s"'% name)
404                    valList.append(self.__attrVal2sql(d[l[i]]))
405                d = query % (table,
406                    ", ".join(colSList),
407                    ", ".join ([self.quirks.parameter] * len(valList)))
408                cursor.execute(d, tuple(valList))
409            cursor.close()
410            self.connection.commit()
411        except Exception, e:
412            import traceback
413            traceback.print_exc()
414            self.connection.rollback()
415
416    @deprecated_keywords({"renameDict":"rename_dict", "typeDict":"type_dict"})
417    def create(self, table, instances, rename_dict = {}, type_dict = {}):
418        """
419        Create the required SQL table, then write the data into it.
420
421        :param table: Table name
422        :type table: str
423
424        :param instances: Data to be written into the database.
425        :type instances: :class:`Orange.data.Table`
426
427        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
428            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
429        :type rename_dict: dict
430
431        :param type_dict: When ``type_dict`` is provided the used variables are casted into new types.
432            The type of orange attribute "X" is casted into the database column of type rename_dict["X"].
433        :type type_dict: dict
434
435        """
436        l = [(i.name, i.var_type ) for i in instances.domain.attributes]
437        l += [(i.name, i.var_type ) for i in instances.domain.get_metas().values()]
438        if instances.domain.class_var:
439            l.append((instances.domain.class_var.name, instances.domain.class_var.var_type))
440        #if rename_dict is None:
441        #    rename_dict = {}
442        colNameList = [rename_dict.get(str(i[0]), str(i[0])) for i in l]
443        #if type_dict is None:
444        #    typeDict = {}
445        colTypeList = [type_dict.get(str(i[0]), self.__attrType2sql(i[1])) for i in l]
446        try:
447            cursor = self.connection.cursor()
448            colSList = []
449            for (i, name) in enumerate(colNameList):
450                colSList.append('"%s" %s' % (name, colTypeList[i]))
451            colStr = ", ".join(colSList)
452            query = """CREATE TABLE "%s" ( %s );""" % (table, colStr)
453            self.quirks.beforeCreate(cursor)
454            cursor.execute(query)
455            self.write(table, instances, rename_dict)
456            self.connection.commit()
457        except Exception, e:
458            self.connection.rollback()
459
460    def disconnect(self):
461        """
462        Disconnect from the database.
463        """
464        func = getattr(self.conn, "disconnect", None)
465        if callable(func):
466            self.conn.disconnect()
467
468def loadSQL(filename, dontCheckStored = False, domain = None):
469    f = open(filename)
470    lines = f.readlines()
471    queryLines = []
472    discreteNames = None
473    uri = None
474    metaNames = None
475    className = None
476    for i in lines:
477        if i.startswith("--orng"):
478            (dummy, command, line) = i.split(None, 2)
479            if command == 'uri':
480                uri = eval(line)
481            elif command == 'discrete':
482                discreteNames = eval(line)
483            elif command == 'meta':
484                metaNames = eval(line)
485            elif command == 'class':
486                className = eval(line)
487            else:
488                queryLines.append(i)
489        else:
490            queryLines.append(i)
491    query = "\n".join(queryLines)
492    r = SQLReader(uri)
493    if discreteNames:
494        r.discreteNames = discreteNames
495    if className:
496        r.className = className
497    if metaNames:
498        r.metaNames = metaNames
499    r.execute(query)
500    return r.data()
Note: See TracBrowser for help on using the repository browser.