Ticket #1291: sql.py

File sql.py, 16.5 KB (added by jarnavat, 13 months ago)
Line 
1import os
2import urllib
3import Orange
4import orange
5from Orange.utils import deprecated_keywords, deprecated_members
6from Orange.feature import Descriptor
7
8def _parseURI(uri):
9    """ lifted straight from sqlobject """
10    schema, rest = uri.split(':', 1)
11    assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
12    if rest.startswith('/') and not rest.startswith('//'):
13        host = None
14        rest = rest[1:]
15    elif rest.startswith('///'):
16        host = None
17        rest = rest[3:]
18    else:
19        rest = rest[2:]
20        if rest.find('/') == -1:
21            host = rest
22            rest = ''
23        else:
24            host, rest = rest.split('/', 1)
25    if host and host.find('@') != -1:
26        user, host = host.split('@', 1)
27        if user.find(':') != -1:
28            user, password = user.split(':', 1)
29        else:
30            password = None
31    else:
32        user = password = None
33    if host and host.find(':') != -1:
34        _host, port = host.split(':')
35        try:
36            port = int(port)
37        except ValueError:
38            raise ValueError, "port must be integer, got '%s' instead" % port
39        if not (1 <= port <= 65535):
40            raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
41        host = _host
42    else:
43        port = None
44    path = '/' + rest
45    if os.name == 'nt':
46        if (len(rest) > 1) and (rest[1] == '|'):
47            path = "%s:%s" % (rest[0], rest[2:])
48    args = {}
49    if path.find('?') != -1:
50        path, arglist = path.split('?', 1)
51        arglist = arglist.split('&')
52        for single in arglist:
53            argname, argvalue = single.split('=', 1)
54            argvalue = urllib.unquote(argvalue)
55            args[argname] = argvalue
56    return schema, user, password, host, port, path, args
57
58class __MySQLQuirkFix(object):
59    def __init__(self, dbmod):
60        self.dbmod = dbmod
61        self.typeDict = {
62            Descriptor.Continuous:'DOUBLE',
63            Descriptor.Discrete:'VARCHAR(250)', Descriptor.String:'VARCHAR(250)'}
64
65    def beforeWrite(self, cursor):
66        cursor.execute("SET sql_mode='ANSI_QUOTES';")
67
68    def beforeCreate(self, cursor):
69        cursor.execute("SET sql_mode='ANSI_QUOTES';")
70
71class __PostgresQuirkFix(object):
72    def __init__(self, dbmod):
73        self.dbmod = dbmod
74        self.typeDict = {
75            Descriptor.Continuous:'FLOAT',
76            Descriptor.Discrete:'VARCHAR', Descriptor.String:'VARCHAR'}
77
78    def beforeWrite(self, cursor):
79        pass
80
81    def beforeCreate(self, cursor):
82        pass
83
84def _connection(uri):
85        (schema, user, password, host, port, path, args) = _parseURI(uri)
86        argTrans = {
87            'host':'host',
88            'port':'port',
89            'user':'user',
90            'password':'passwd',
91            'database':'db'
92            }
93        if schema == 'postgres':
94            import psycopg2 as dbmod
95            argTrans["database"] = "db"
96            quirks = __PostgresQuirkFix(dbmod)
97            quirks.parameter = "%s"
98        elif schema == 'mysql':
99            import MySQLdb as dbmod
100            quirks = __MySQLQuirkFix(dbmod)
101            quirks.parameter = "%s"
102        elif schema == "sqlite":
103            import sqlite3 as dbmod
104            quirks = __PostgresQuirkFix(dbmod)
105            quirks.parameter = "?"
106            return (quirks, dbmod.connect(host))
107
108        dbArgDict = {}
109        if user:
110            dbArgDict[argTrans['user']] = user
111        if password:
112            dbArgDict[argTrans['password']] = password
113        if host:
114            dbArgDict[argTrans['host']] = host
115        if port:
116            dbArgDict[argTrans['port']] = port
117        if path:
118            dbArgDict[argTrans['database']] = path[1:]
119        if schema == 'mysql' and args and args.has_key('unix_socket'):
120            dbArgDict['unix_socket'] = args.get('unix_socket')
121        return (quirks, dbmod.connect(**dbArgDict))
122
123
124
125class SQLReader(object):
126    """
127    :obj:`~SQLReader` establishes a connection with a database and provides the methods needed
128    to fetch the data from the database into Orange.
129    """
130    @deprecated_keywords({"domainDepot":"domain_depot"})
131    def __init__(self, addr = None, domain_depot = None):
132        """
133        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters]).
134        :type uri: str
135
136        :param domain_depot: Domain depot
137        :type domain_depot: :class:`orange.DomainDepot`
138        """
139        if addr is not None:
140            self.connect(addr)
141        if domain_depot is not None:
142            self.domainDepot = domain_depot
143        else:
144            self.domainDepot = orange.DomainDepot()
145        self.exampleTable = None
146        self._dirty = True
147
148    def connect(self, uri):
149        """
150        Connect to the database.
151
152        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
153        :type uri: str
154        """
155        self._dirty = True
156        self.del_domain()
157        (self.quirks, self.conn) = _connection(uri)
158
159    def disconnect(self):
160        """
161        Disconnect from the database.
162        """
163        func = getattr(self.conn, "disconnect", None)
164        if callable(func):
165            self.conn.disconnect()
166
167    def get_class_name(self):
168        self.update()
169        return self.domain.class_var.name
170
171    def set_class_name(self, class_name):
172        self._className = class_name
173        self.del_domain()
174
175    def del_class_name(self):
176        del self._className
177
178    class_name = property(get_class_name, set_class_name, del_class_name, "Name of class variable.")
179
180    def get_metas_name(self):
181        self.update()
182        return self.domain.get_metas().values()
183
184    def set_metas_name(self, meta_names):
185        self._metaNames = meta_names
186        self.del_domain()
187
188    def del_metas_name(self):
189        del self._metaNames
190
191    meta_names = property(get_metas_name, set_metas_name, del_metas_name, "Names of meta attributes.")
192
193    def set_discrete_names(self, discrete_names):
194        self._discreteNames = discrete_names
195        self.del_domain()
196
197    def get_discrete_names(self):
198        self.update()
199        return self._discreteNames
200
201    def del_discrete_names(self):
202        del self._discreteNames
203
204    discrete_names = property(get_discrete_names, set_discrete_names, del_discrete_names, "Names of discrete attributes.")
205
206    def set_query(self, query, domain = None):
207        #sets the query, resets the internal variables, without executing the query
208        self._query = query
209        self._dirty = True
210        if domain is not None:
211            self._domain = domain
212        else:
213            self.del_domain()
214
215    def get_query(self):
216        return self._query
217
218    def del_query(self):
219        del self._query
220
221    query = property(get_query, set_query, del_query, "Query to be executed on the next execute().")
222
223    def generateDomain(self):
224        pass
225
226    def set_domain(self, domain):
227        self._domain = domain
228        self._dirty = True
229
230    def get_domain(self):
231        if not hasattr(self, '_domain'):
232            self._createDomain()
233        return self._domain
234
235    def del_domain(self):
236        if hasattr(self, '_domain'):
237            del self._domain
238
239    domain = property(get_domain, set_domain, del_domain, "Orange domain.")
240
241    def execute(self, query, domain = None):
242        """
243        Executes an sql query.
244        """
245        self.set_query(query, domain)
246        self.update()
247
248    def _createDomain(self):
249        if hasattr(self, '_domain'):
250            return
251        attrNames = []
252        if not hasattr(self, '_discreteNames'):
253            self._discreteNames = []
254        discreteNames = self._discreteNames
255        if not hasattr(self, '_metaNames'):
256            self._metaNames = []
257        metaNames = self._metaNames
258        if not hasattr(self, '_className'):
259            className = None
260        else:
261            className = self._className
262        for i in self.desc:
263            name = i[0]
264            typ = i[1]
265            if name in discreteNames:
266                attrName = 'D#' + name
267            elif typ is None or typ in [self.quirks.dbmod.STRING, self.quirks.dbmod.DATETIME]:
268                    attrName = 'S#' + name
269            else:
270                attrName = 'C#' + name
271           
272            if name == className:
273                attrName = "c" + attrName
274            elif name in metaNames:
275                attrName = "m" + attrName
276            elif not className and name == 'class':
277                attrName = "c" + attrName
278            attrNames.append(attrName)
279        (self._domain, self._metaIDs, dummy) = self.domainDepot.prepareDomain(attrNames)
280        del dummy
281
282    def update(self):
283        """
284        Execute a pending SQL query.
285        """
286        if not self._dirty and hasattr(self, '_domain'):
287            return self.exampleTable
288        self.exampleTable = None
289        try:
290            curs = self.conn.cursor()
291            try:
292                curs.execute(self.query)
293            except Exception, e:
294                self.conn.rollback()
295                raise e
296            self.desc = curs.description
297            # for reasons unknown, the attributes get reordered.
298            domainIndexes = [0] * len(self.desc)
299            self._createDomain()
300            attrNames = []
301            for i, name in enumerate(self.desc):
302            #    print name[0], '->', self.domain.index(name[0])
303                domainIndexes[self._domain.index(name[0])] = i
304                attrNames.append(name[0])
305            self.exampleTable = Orange.data.Table(self.domain)
306            r = curs.fetchone()
307            while r:
308                # for reasons unknown, domain rearranges the properties
309                example = Orange.data.Instance(self.domain)
310                for i in xrange(len(r)):
311                    val = str(r[i])
312                    var = example[attrNames[i]].variable
313                    if type(var) == Descriptor.Discrete and val not in var.values:
314                        var.values.append(val)
315                    example[attrNames[i]] = str(r[i])
316                self.exampleTable.append(example)
317                r = curs.fetchone()
318            self._dirty = False
319        except Exception, e:
320            self.domain = None
321            raise
322            #self.domain = None
323
324    def data(self):
325        """
326        Return :class:`Orange.data.Table` produced by the last executed query.
327        """
328        self.update()
329        if self.exampleTable:
330            return self.exampleTable
331        return None
332
333SQLReader = deprecated_members({"discreteNames":"discrete_names", "metaName":"meta_names"\
334    , "className":"class_name"})(SQLReader)
335
336class SQLWriter(object):
337    """
338    Establishes a connection with a database and provides the methods needed to create
339    an appropriate table in the database and/or write the data from an :class:`Orange.data.Table`
340    into the database.
341    """
342    def __init__(self, uri = None):
343        """
344        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
345        :type uri: str
346        """
347        if uri is not None:
348            self.connect(uri)
349
350    def connect(self, uri):
351        """
352        Connect to the database.
353
354        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
355        :type uri: str
356        """
357        (self.quirks, self.connection) = _connection(uri)
358
359    def __attrVal2sql(self, d):
360        if d.var_type == Descriptor.Continuous:
361            return d.value
362        elif d.var_type == Descriptor.Discrete:
363            return str(d.value)
364        else:
365            return "'%s'" % str(d.value)
366
367    def __attrName2sql(self, d):
368        return d.name
369
370    def __attrType2sql(self, d):
371        return self.quirks.typeDict[d]
372
373    @deprecated_keywords({"renameDict":"rename_dict"})
374    def write(self, table, instances, rename_dict = None):
375        """
376        Writes the data into the table.
377
378
379        :param table: Table name.
380        :type table: str
381
382        :param instances: Data to be written into the database.
383        :type instances: :class:`Orange.data.Table`
384
385        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
386            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
387        :type rename_dict: dict
388
389        """
390        l = [i.name for i in instances.domain.attributes]
391        l += [i.name for i in instances.domain.get_metas().values()]
392        if instances.domain.class_var:
393            l.append(instances.domain.class_var.name)
394        if rename_dict is None:
395            rename_dict = {}
396        colList = []
397        for i in l:
398            colList.append(rename_dict.get(str(i), str(i)))
399        try:
400            cursor=self.connection.cursor()
401            self.quirks.beforeWrite(cursor)
402            query = 'INSERT INTO "%s" (%s) VALUES (%s);'
403            for d in instances:
404                valList = []
405                colSList = []
406                for (i, name) in enumerate(colList):
407                    colSList.append('"%s"'% name)
408                    valList.append(self.__attrVal2sql(d[l[i]]))
409                d = query % (table,
410                    ", ".join(colSList),
411                    ", ".join ([self.quirks.parameter] * len(valList)))
412                cursor.execute(d, tuple(valList))
413            cursor.close()
414            self.connection.commit()
415        except Exception, e:
416            import traceback
417            traceback.print_exc()
418            self.connection.rollback()
419
420    @deprecated_keywords({"renameDict":"rename_dict", "typeDict":"type_dict"})
421    def create(self, table, instances, rename_dict = {}, type_dict = {}):
422        """
423        Create the required SQL table, then write the data into it.
424
425        :param table: Table name
426        :type table: str
427
428        :param instances: Data to be written into the database.
429        :type instances: :class:`Orange.data.Table`
430
431        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
432            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
433        :type rename_dict: dict
434
435        :param type_dict: When ``type_dict`` is provided the used variables are casted into new types.
436            The type of orange attribute "X" is casted into the database column of type rename_dict["X"].
437        :type type_dict: dict
438
439        """
440        l = [(i.name, i.var_type ) for i in instances.domain.attributes]
441        l += [(i.name, i.var_type ) for i in instances.domain.get_metas().values()]
442        if instances.domain.class_var:
443            l.append((instances.domain.class_var.name, instances.domain.class_var.var_type))
444        #if rename_dict is None:
445        #    rename_dict = {}
446        colNameList = [rename_dict.get(str(i[0]), str(i[0])) for i in l]
447        #if type_dict is None:
448        #    typeDict = {}
449        colTypeList = [type_dict.get(str(i[0]), self.__attrType2sql(i[1])) for i in l]
450        try:
451            cursor = self.connection.cursor()
452            colSList = []
453            for (i, name) in enumerate(colNameList):
454                colSList.append('"%s" %s' % (name, colTypeList[i]))
455            colStr = ", ".join(colSList)
456            query = """CREATE TABLE "%s" ( %s );""" % (table, colStr)
457            self.quirks.beforeCreate(cursor)
458            cursor.execute(query)
459            self.write(table, instances, rename_dict)
460            self.connection.commit()
461        except Exception, e:
462            self.connection.rollback()
463
464    def disconnect(self):
465        """
466        Disconnect from the database.
467        """
468        func = getattr(self.conn, "disconnect", None)
469        if callable(func):
470            self.conn.disconnect()
471
472def loadSQL(filename, dontCheckStored = False, domain = None):
473    f = open(filename)
474    lines = f.readlines()
475    queryLines = []
476    discreteNames = None
477    uri = None
478    metaNames = None
479    className = None
480    for i in lines:
481        if i.startswith("--orng"):
482            (dummy, command, line) = i.split(None, 2)
483            if command == 'uri':
484                uri = eval(line)
485            elif command == 'discrete':
486                discreteNames = eval(line)
487            elif command == 'meta':
488                metaNames = eval(line)
489            elif command == 'class':
490                className = eval(line)
491            else:
492                queryLines.append(i)
493        else:
494            queryLines.append(i)
495    query = "\n".join(queryLines)
496    r = SQLReader(uri)
497    if discreteNames:
498        r.discreteNames = discreteNames
499    if className:
500        r.className = className
501    if metaNames:
502        r.metaNames = metaNames
503    r.execute(query)
504    return r.data()