source: orange/Orange/data/sql.py @ 10438:02556e77b2d0

Revision 10438:02556e77b2d0, 16.4 KB checked in by crt.gorup@…, 2 years ago (diff)

Fixes #747.

Line 
1import os
2import urllib
3import Orange
4import orange
5from Orange.misc import deprecated_keywords, deprecated_members
6from Orange.feature import Descriptor
7
8def _parseURI(uri):
9    """ lifted straight from sqlobject """
10    schema, rest = uri.split(':', 1)
11    assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
12    if rest.startswith('/') and not rest.startswith('//'):
13        host = None
14        rest = rest[1:]
15    elif rest.startswith('///'):
16        host = None
17        rest = rest[3:]
18    else:
19        rest = rest[2:]
20        if rest.find('/') == -1:
21            host = rest
22            rest = ''
23        else:
24            host, rest = rest.split('/', 1)
25    if host and host.find('@') != -1:
26        user, host = host.split('@', 1)
27        if user.find(':') != -1:
28            user, password = user.split(':', 1)
29        else:
30            password = None
31    else:
32        user = password = None
33    if host and host.find(':') != -1:
34        _host, port = host.split(':')
35        try:
36            port = int(port)
37        except ValueError:
38            raise ValueError, "port must be integer, got '%s' instead" % port
39        if not (1 <= port <= 65535):
40            raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
41        host = _host
42    else:
43        port = None
44    path = '/' + rest
45    if os.name == 'nt':
46        if (len(rest) > 1) and (rest[1] == '|'):
47            path = "%s:%s" % (rest[0], rest[2:])
48    args = {}
49    if path.find('?') != -1:
50        path, arglist = path.split('?', 1)
51        arglist = arglist.split('&')
52        for single in arglist:
53            argname, argvalue = single.split('=', 1)
54            argvalue = urllib.unquote(argvalue)
55            args[argname] = argvalue
56    return schema, user, password, host, port, path, args
57
58class __MySQLQuirkFix(object):
59    def __init__(self, dbmod):
60        self.dbmod = dbmod
61        self.typeDict = {
62            Descriptor.Continuous:'DOUBLE',
63            Descriptor.Discrete:'VARCHAR(250)', Descriptor.String:'VARCHAR(250)'}
64
65    def beforeWrite(self, cursor):
66        cursor.execute("SET sql_mode='ANSI_QUOTES';")
67
68    def beforeCreate(self, cursor):
69        cursor.execute("SET sql_mode='ANSI_QUOTES';")
70
71class __PostgresQuirkFix(object):
72    def __init__(self, dbmod):
73        self.dbmod = dbmod
74        self.typeDict = {
75            Descriptor.Continuous:'FLOAT',
76            Descriptor.Discrete:'VARCHAR', Descriptor.String:'VARCHAR'}
77
78    def beforeWrite(self, cursor):
79        pass
80
81    def beforeCreate(self, cursor):
82        pass
83
84def _connection(uri):
85        (schema, user, password, host, port, path, args) = _parseURI(uri)
86        argTrans = {
87            'host':'host',
88            'port':'port',
89            'user':'user',
90            'password':'passwd',
91            'database':'db'
92            }
93        if schema == 'postgres':
94            import psycopg2 as dbmod
95            argTrans["database"] = "db"
96            quirks = __PostgresQuirkFix(dbmod)
97            quirks.parameter = "%s"
98        elif schema == 'mysql':
99            import MySQLdb as dbmod
100            quirks = __MySQLQuirkFix(dbmod)
101            quirks.parameter = "%s"
102        elif schema == "sqlite":
103            import sqlite3 as dbmod
104            quirks = __PostgresQuirkFix(dbmod)
105            quirks.parameter = "?"
106            return (quirks, dbmod.connect(host))
107
108        dbArgDict = {}
109        if user:
110            dbArgDict[argTrans['user']] = user
111        if password:
112            dbArgDict[argTrans['password']] = password
113        if host:
114            dbArgDict[argTrans['host']] = host
115        if port:
116            dbArgDict[argTrans['port']] = port
117        if path:
118            dbArgDict[argTrans['database']] = path[1:]
119        return (quirks, dbmod.connect(**dbArgDict))
120
121
122
123class SQLReader(object):
124    """
125    :obj:`~SQLReader` establishes a connection with a database and provides the methods needed
126    to fetch the data from the database into Orange.
127    """
128    @deprecated_keywords({"domainDepot":"domain_depot"})
129    def __init__(self, addr = None, domain_depot = None):
130        """
131        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters]).
132        :type uri: str
133
134        :param domain_depot: Domain depot
135        :type domain_depot: :class:`orange.DomainDepot`
136        """
137        if addr is not None:
138            self.connect(addr)
139        if domain_depot is not None:
140            self.domainDepot = domain_depot
141        else:
142            self.domainDepot = orange.DomainDepot()
143        self.exampleTable = None
144        self._dirty = True
145
146    def connect(self, uri):
147        """
148        Connect to the database.
149
150        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
151        :type uri: str
152        """
153        self._dirty = True
154        self.del_domain()
155        (self.quirks, self.conn) = _connection(uri)
156
157    def disconnect(self):
158        """
159        Disconnect from the database.
160        """
161        func = getattr(self.conn, "disconnect", None)
162        if callable(func):
163            self.conn.disconnect()
164
165    def get_class_name(self):
166        self.update()
167        return self.domain.class_var.name
168
169    def set_class_name(self, class_name):
170        self._className = class_name
171        self.del_domain()
172
173    def del_class_name(self):
174        del self._className
175
176    class_name = property(get_class_name, set_class_name, del_class_name, "Name of class variable.")
177
178    def get_metas_name(self):
179        self.update()
180        return self.domain.get_metas().values()
181
182    def set_metas_name(self, meta_names):
183        self._metaNames = meta_names
184        self.del_domain()
185
186    def del_metas_name(self):
187        del self._metaNames
188
189    meta_names = property(get_metas_name, set_metas_name, del_metas_name, "Names of meta attributes.")
190
191    def set_discrete_names(self, discrete_names):
192        self._discreteNames = discrete_names
193        self.del_domain()
194
195    def get_discrete_names(self):
196        self.update()
197        return self._discreteNames
198
199    def del_discrete_names(self):
200        del self._discreteNames
201
202    discrete_names = property(get_discrete_names, set_discrete_names, del_discrete_names, "Names of discrete attributes.")
203
204    def set_query(self, query, domain = None):
205        #sets the query, resets the internal variables, without executing the query
206        self._query = query
207        self._dirty = True
208        if domain is not None:
209            self._domain = domain
210        else:
211            self.del_domain()
212
213    def get_query(self):
214        return self._query
215
216    def del_query(self):
217        del self._query
218
219    query = property(get_query, set_query, del_query, "Query to be executed on the next execute().")
220
221    def generateDomain(self):
222        pass
223
224    def set_domain(self, domain):
225        self._domain = domain
226        self._dirty = True
227
228    def get_domain(self):
229        if not hasattr(self, '_domain'):
230            self._createDomain()
231        return self._domain
232
233    def del_domain(self):
234        if hasattr(self, '_domain'):
235            del self._domain
236
237    domain = property(get_domain, set_domain, del_domain, "Orange domain.")
238
239    def execute(self, query, domain = None):
240        """
241        Executes an sql query.
242        """
243        self.set_query(query, domain)
244        self.update()
245
246    def _createDomain(self):
247        if hasattr(self, '_domain'):
248            return
249        attrNames = []
250        if not hasattr(self, '_discreteNames'):
251            self._discreteNames = []
252        discreteNames = self._discreteNames
253        if not hasattr(self, '_metaNames'):
254            self._metaNames = []
255        metaNames = self._metaNames
256        if not hasattr(self, '_className'):
257            className = None
258        else:
259            className = self._className
260        for i in self.desc:
261            name = i[0]
262            typ = i[1]
263            if name in discreteNames:
264                attrName = 'D#' + name
265            elif typ is None or typ in [self.quirks.dbmod.STRING, self.quirks.dbmod.DATETIME]:
266                    attrName = 'S#' + name
267            else:
268                attrName = 'C#' + name
269           
270            if name == className:
271                attrName = "c" + attrName
272            elif name in metaNames:
273                attrName = "m" + attrName
274            elif not className and name == 'class':
275                attrName = "c" + attrName
276            attrNames.append(attrName)
277        (self._domain, self._metaIDs, dummy) = self.domainDepot.prepareDomain(attrNames)
278        del dummy
279
280    def update(self):
281        """
282        Execute a pending SQL query.
283        """
284        if not self._dirty and hasattr(self, '_domain'):
285            return self.exampleTable
286        self.exampleTable = None
287        try:
288            curs = self.conn.cursor()
289            try:
290                curs.execute(self.query)
291            except Exception, e:
292                self.conn.rollback()
293                raise e
294            self.desc = curs.description
295            # for reasons unknown, the attributes get reordered.
296            domainIndexes = [0] * len(self.desc)
297            self._createDomain()
298            attrNames = []
299            for i, name in enumerate(self.desc):
300            #    print name[0], '->', self.domain.index(name[0])
301                domainIndexes[self._domain.index(name[0])] = i
302                attrNames.append(name[0])
303            self.exampleTable = Orange.data.Table(self.domain)
304            r = curs.fetchone()
305            while r:
306                # for reasons unknown, domain rearranges the properties
307                example = Orange.data.Instance(self.domain)
308                for i in xrange(len(r)):
309                    val = str(r[i])
310                    var = example[attrNames[i]].variable
311                    if type(var) == Descriptor.Discrete and val not in var.values:
312                        var.values.append(val)
313                    example[attrNames[i]] = str(r[i])
314                self.exampleTable.append(example)
315                r = curs.fetchone()
316            self._dirty = False
317        except Exception, e:
318            self.domain = None
319            raise
320            #self.domain = None
321
322    def data(self):
323        """
324        Return :class:`Orange.data.Table` produced by the last executed query.
325        """
326        self.update()
327        if self.exampleTable:
328            return self.exampleTable
329        return None
330
331SQLReader = deprecated_members({"discreteNames":"discrete_names", "metaName":"meta_names"\
332    , "className":"class_name"})(SQLReader)
333
334class SQLWriter(object):
335    """
336    Establishes a connection with a database and provides the methods needed to create
337    an appropriate table in the database and/or write the data from an :class:`Orange.data.Table`
338    into the database.
339    """
340    def __init__(self, uri = None):
341        """
342        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
343        :type uri: str
344        """
345        if uri is not None:
346            self.connect(uri)
347
348    def connect(self, uri):
349        """
350        Connect to the database.
351
352        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
353        :type uri: str
354        """
355        (self.quirks, self.connection) = _connection(uri)
356
357    def __attrVal2sql(self, d):
358        if d.var_type == Descriptor.Continuous:
359            return d.value
360        elif d.var_type == Descriptor.Discrete:
361            return str(d.value)
362        else:
363            return "'%s'" % str(d.value)
364
365    def __attrName2sql(self, d):
366        return d.name
367
368    def __attrType2sql(self, d):
369        return self.quirks.typeDict[d]
370
371    @deprecated_keywords({"renameDict":"rename_dict"})
372    def write(self, table, instances, rename_dict = None):
373        """
374        Writes the data into the table.
375
376
377        :param table: Table name.
378        :type table: str
379
380        :param instances: Data to be written into the database.
381        :type instances: :class:`Orange.data.Table`
382
383        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
384            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
385        :type rename_dict: dict
386
387        """
388        l = [i.name for i in instances.domain.attributes]
389        l += [i.name for i in instances.domain.get_metas().values()]
390        if instances.domain.class_var:
391            l.append(instances.domain.class_var.name)
392        if rename_dict is None:
393            rename_dict = {}
394        colList = []
395        for i in l:
396            colList.append(rename_dict.get(str(i), str(i)))
397        try:
398            cursor=self.connection.cursor()
399            self.quirks.beforeWrite(cursor)
400            query = 'INSERT INTO "%s" (%s) VALUES (%s);'
401            for d in instances:
402                valList = []
403                colSList = []
404                for (i, name) in enumerate(colList):
405                    colSList.append('"%s"'% name)
406                    valList.append(self.__attrVal2sql(d[l[i]]))
407                d = query % (table,
408                    ", ".join(colSList),
409                    ", ".join ([self.quirks.parameter] * len(valList)))
410                cursor.execute(d, tuple(valList))
411            cursor.close()
412            self.connection.commit()
413        except Exception, e:
414            import traceback
415            traceback.print_exc()
416            self.connection.rollback()
417
418    @deprecated_keywords({"renameDict":"rename_dict", "typeDict":"type_dict"})
419    def create(self, table, instances, rename_dict = {}, type_dict = {}):
420        """
421        Create the required SQL table, then write the data into it.
422
423        :param table: Table name
424        :type table: str
425
426        :param instances: Data to be written into the database.
427        :type instances: :class:`Orange.data.Table`
428
429        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
430            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
431        :type rename_dict: dict
432
433        :param type_dict: When ``type_dict`` is provided the used variables are casted into new types.
434            The type of orange attribute "X" is casted into the database column of type rename_dict["X"].
435        :type type_dict: dict
436
437        """
438        l = [(i.name, i.var_type ) for i in instances.domain.attributes]
439        l += [(i.name, i.var_type ) for i in instances.domain.get_metas().values()]
440        if instances.domain.class_var:
441            l.append((instances.domain.class_var.name, instances.domain.class_var.var_type))
442        #if rename_dict is None:
443        #    rename_dict = {}
444        colNameList = [rename_dict.get(str(i[0]), str(i[0])) for i in l]
445        #if type_dict is None:
446        #    typeDict = {}
447        colTypeList = [type_dict.get(str(i[0]), self.__attrType2sql(i[1])) for i in l]
448        try:
449            cursor = self.connection.cursor()
450            colSList = []
451            for (i, name) in enumerate(colNameList):
452                colSList.append('"%s" %s' % (name, colTypeList[i]))
453            colStr = ", ".join(colSList)
454            query = """CREATE TABLE "%s" ( %s );""" % (table, colStr)
455            self.quirks.beforeCreate(cursor)
456            cursor.execute(query)
457            self.write(table, instances, rename_dict)
458            self.connection.commit()
459        except Exception, e:
460            self.connection.rollback()
461
462    def disconnect(self):
463        """
464        Disconnect from the database.
465        """
466        func = getattr(self.conn, "disconnect", None)
467        if callable(func):
468            self.conn.disconnect()
469
470def loadSQL(filename, dontCheckStored = False, domain = None):
471    f = open(filename)
472    lines = f.readlines()
473    queryLines = []
474    discreteNames = None
475    uri = None
476    metaNames = None
477    className = None
478    for i in lines:
479        if i.startswith("--orng"):
480            (dummy, command, line) = i.split(None, 2)
481            if command == 'uri':
482                uri = eval(line)
483            elif command == 'discrete':
484                discreteNames = eval(line)
485            elif command == 'meta':
486                metaNames = eval(line)
487            elif command == 'class':
488                className = eval(line)
489            else:
490                queryLines.append(i)
491        else:
492            queryLines.append(i)
493    query = "\n".join(queryLines)
494    r = SQLReader(uri)
495    if discreteNames:
496        r.discreteNames = discreteNames
497    if className:
498        r.className = className
499    if metaNames:
500        r.metaNames = metaNames
501    r.execute(query)
502    return r.data()
Note: See TracBrowser for help on using the repository browser.