source: orange/Orange/data/sql.py @ 10580:c4cbae8dcf8b

Revision 10580:c4cbae8dcf8b, 16.4 KB checked in by markotoplak, 2 years ago (diff)

Moved deprecation functions, progress bar support and environ into Orange.utils. Orange imports cleanly, although it is not tested yet.

Line 
1import os
2import urllib
3import Orange
4import orange
5from Orange.utils import deprecated_keywords, deprecated_members
6from Orange.feature import Descriptor
7
8def _parseURI(uri):
9    """ lifted straight from sqlobject """
10    schema, rest = uri.split(':', 1)
11    assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
12    if rest.startswith('/') and not rest.startswith('//'):
13        host = None
14        rest = rest[1:]
15    elif rest.startswith('///'):
16        host = None
17        rest = rest[3:]
18    else:
19        rest = rest[2:]
20        if rest.find('/') == -1:
21            host = rest
22            rest = ''
23        else:
24            host, rest = rest.split('/', 1)
25    if host and host.find('@') != -1:
26        user, host = host.split('@', 1)
27        if user.find(':') != -1:
28            user, password = user.split(':', 1)
29        else:
30            password = None
31    else:
32        user = password = None
33    if host and host.find(':') != -1:
34        _host, port = host.split(':')
35        try:
36            port = int(port)
37        except ValueError:
38            raise ValueError, "port must be integer, got '%s' instead" % port
39        if not (1 <= port <= 65535):
40            raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
41        host = _host
42    else:
43        port = None
44    path = '/' + rest
45    if os.name == 'nt':
46        if (len(rest) > 1) and (rest[1] == '|'):
47            path = "%s:%s" % (rest[0], rest[2:])
48    args = {}
49    if path.find('?') != -1:
50        path, arglist = path.split('?', 1)
51        arglist = arglist.split('&')
52        for single in arglist:
53            argname, argvalue = single.split('=', 1)
54            argvalue = urllib.unquote(argvalue)
55            args[argname] = argvalue
56    return schema, user, password, host, port, path, args
57
58class __MySQLQuirkFix(object):
59    def __init__(self, dbmod):
60        self.dbmod = dbmod
61        self.typeDict = {
62            Descriptor.Continuous:'DOUBLE',
63            Descriptor.Discrete:'VARCHAR(250)', Descriptor.String:'VARCHAR(250)'}
64
65    def beforeWrite(self, cursor):
66        cursor.execute("SET sql_mode='ANSI_QUOTES';")
67
68    def beforeCreate(self, cursor):
69        cursor.execute("SET sql_mode='ANSI_QUOTES';")
70
71class __PostgresQuirkFix(object):
72    def __init__(self, dbmod):
73        self.dbmod = dbmod
74        self.typeDict = {
75            Descriptor.Continuous:'FLOAT',
76            Descriptor.Discrete:'VARCHAR', Descriptor.String:'VARCHAR'}
77
78    def beforeWrite(self, cursor):
79        pass
80
81    def beforeCreate(self, cursor):
82        pass
83
84def _connection(uri):
85        (schema, user, password, host, port, path, args) = _parseURI(uri)
86        argTrans = {
87            'host':'host',
88            'port':'port',
89            'user':'user',
90            'password':'passwd',
91            'database':'db'
92            }
93        if schema == 'postgres':
94            import psycopg2 as dbmod
95            argTrans["database"] = "db"
96            quirks = __PostgresQuirkFix(dbmod)
97            quirks.parameter = "%s"
98        elif schema == 'mysql':
99            import MySQLdb as dbmod
100            quirks = __MySQLQuirkFix(dbmod)
101            quirks.parameter = "%s"
102        elif schema == "sqlite":
103            import sqlite3 as dbmod
104            quirks = __PostgresQuirkFix(dbmod)
105            quirks.parameter = "?"
106            return (quirks, dbmod.connect(host))
107
108        dbArgDict = {}
109        if user:
110            dbArgDict[argTrans['user']] = user
111        if password:
112            dbArgDict[argTrans['password']] = password
113        if host:
114            dbArgDict[argTrans['host']] = host
115        if port:
116            dbArgDict[argTrans['port']] = port
117        if path:
118            dbArgDict[argTrans['database']] = path[1:]
119        return (quirks, dbmod.connect(**dbArgDict))
120
121
122
123class SQLReader(object):
124    """
125    :obj:`~SQLReader` establishes a connection with a database and provides the methods needed
126    to fetch the data from the database into Orange.
127    """
128    @deprecated_keywords({"domainDepot":"domain_depot"})
129    def __init__(self, addr = None, domain_depot = None):
130        """
131        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters]).
132        :type uri: str
133
134        :param domain_depot: Domain depot
135        :type domain_depot: :class:`orange.DomainDepot`
136        """
137        if addr is not None:
138            self.connect(addr)
139        if domain_depot is not None:
140            self.domainDepot = domain_depot
141        else:
142            self.domainDepot = orange.DomainDepot()
143        self.exampleTable = None
144        self._dirty = True
145
146    def connect(self, uri):
147        """
148        Connect to the database.
149
150        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
151        :type uri: str
152        """
153        self._dirty = True
154        self.del_domain()
155        (self.quirks, self.conn) = _connection(uri)
156
157    def disconnect(self):
158        """
159        Disconnect from the database.
160        """
161        func = getattr(self.conn, "disconnect", None)
162        if callable(func):
163            self.conn.disconnect()
164
165    def get_class_name(self):
166        self.update()
167        return self.domain.class_var.name
168
169    def set_class_name(self, class_name):
170        self._className = class_name
171        self.del_domain()
172
173    def del_class_name(self):
174        del self._className
175
176    class_name = property(get_class_name, set_class_name, del_class_name, "Name of class variable.")
177
178    def get_metas_name(self):
179        self.update()
180        return self.domain.get_metas().values()
181
182    def set_metas_name(self, meta_names):
183        self._metaNames = meta_names
184        self.del_domain()
185
186    def del_metas_name(self):
187        del self._metaNames
188
189    meta_names = property(get_metas_name, set_metas_name, del_metas_name, "Names of meta attributes.")
190
191    def set_discrete_names(self, discrete_names):
192        self._discreteNames = discrete_names
193        self.del_domain()
194
195    def get_discrete_names(self):
196        self.update()
197        return self._discreteNames
198
199    def del_discrete_names(self):
200        del self._discreteNames
201
202    discrete_names = property(get_discrete_names, set_discrete_names, del_discrete_names, "Names of discrete attributes.")
203
204    def set_query(self, query, domain = None):
205        #sets the query, resets the internal variables, without executing the query
206        self._query = query
207        self._dirty = True
208        if domain is not None:
209            self._domain = domain
210        else:
211            self.del_domain()
212
213    def get_query(self):
214        return self._query
215
216    def del_query(self):
217        del self._query
218
219    query = property(get_query, set_query, del_query, "Query to be executed on the next execute().")
220
221    def generateDomain(self):
222        pass
223
224    def set_domain(self, domain):
225        self._domain = domain
226        self._dirty = True
227
228    def get_domain(self):
229        if not hasattr(self, '_domain'):
230            self._createDomain()
231        return self._domain
232
233    def del_domain(self):
234        if hasattr(self, '_domain'):
235            del self._domain
236
237    domain = property(get_domain, set_domain, del_domain, "Orange domain.")
238
239    def execute(self, query, domain = None):
240        """
241        Executes an sql query.
242        """
243        self.set_query(query, domain)
244        self.update()
245
246    def _createDomain(self):
247        if hasattr(self, '_domain'):
248            return
249        attrNames = []
250        if not hasattr(self, '_discreteNames'):
251            self._discreteNames = []
252        discreteNames = self._discreteNames
253        if not hasattr(self, '_metaNames'):
254            self._metaNames = []
255        metaNames = self._metaNames
256        if not hasattr(self, '_className'):
257            className = None
258        else:
259            className = self._className
260        for i in self.desc:
261            name = i[0]
262            typ = i[1]
263            if name in discreteNames:
264                attrName = 'D#' + name
265            elif typ is None or typ in [self.quirks.dbmod.STRING, self.quirks.dbmod.DATETIME]:
266                    attrName = 'S#' + name
267            else:
268                attrName = 'C#' + name
269           
270            if name == className:
271                attrName = "c" + attrName
272            elif name in metaNames:
273                attrName = "m" + attrName
274            elif not className and name == 'class':
275                attrName = "c" + attrName
276            attrNames.append(attrName)
277        (self._domain, self._metaIDs, dummy) = self.domainDepot.prepareDomain(attrNames)
278        del dummy
279
280    def update(self):
281        """
282        Execute a pending SQL query.
283        """
284        if not self._dirty and hasattr(self, '_domain'):
285            return self.exampleTable
286        self.exampleTable = None
287        try:
288            curs = self.conn.cursor()
289            try:
290                curs.execute(self.query)
291            except Exception, e:
292                self.conn.rollback()
293                raise e
294            self.desc = curs.description
295            # for reasons unknown, the attributes get reordered.
296            domainIndexes = [0] * len(self.desc)
297            self._createDomain()
298            attrNames = []
299            for i, name in enumerate(self.desc):
300            #    print name[0], '->', self.domain.index(name[0])
301                domainIndexes[self._domain.index(name[0])] = i
302                attrNames.append(name[0])
303            self.exampleTable = Orange.data.Table(self.domain)
304            r = curs.fetchone()
305            while r:
306                # for reasons unknown, domain rearranges the properties
307                example = Orange.data.Instance(self.domain)
308                for i in xrange(len(r)):
309                    val = str(r[i])
310                    var = example[attrNames[i]].variable
311                    if type(var) == Descriptor.Discrete and val not in var.values:
312                        var.values.append(val)
313                    example[attrNames[i]] = str(r[i])
314                self.exampleTable.append(example)
315                r = curs.fetchone()
316            self._dirty = False
317        except Exception, e:
318            self.domain = None
319            raise
320            #self.domain = None
321
322    def data(self):
323        """
324        Return :class:`Orange.data.Table` produced by the last executed query.
325        """
326        self.update()
327        if self.exampleTable:
328            return self.exampleTable
329        return None
330
331SQLReader = deprecated_members({"discreteNames":"discrete_names", "metaName":"meta_names"\
332    , "className":"class_name"})(SQLReader)
333
334class SQLWriter(object):
335    """
336    Establishes a connection with a database and provides the methods needed to create
337    an appropriate table in the database and/or write the data from an :class:`Orange.data.Table`
338    into the database.
339    """
340    def __init__(self, uri = None):
341        """
342        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
343        :type uri: str
344        """
345        if uri is not None:
346            self.connect(uri)
347
348    def connect(self, uri):
349        """
350        Connect to the database.
351
352        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
353        :type uri: str
354        """
355        (self.quirks, self.connection) = _connection(uri)
356
357    def __attrVal2sql(self, d):
358        if d.var_type == Descriptor.Continuous:
359            return d.value
360        elif d.var_type == Descriptor.Discrete:
361            return str(d.value)
362        else:
363            return "'%s'" % str(d.value)
364
365    def __attrName2sql(self, d):
366        return d.name
367
368    def __attrType2sql(self, d):
369        return self.quirks.typeDict[d]
370
371    @deprecated_keywords({"renameDict":"rename_dict"})
372    def write(self, table, instances, rename_dict = None):
373        """
374        Writes the data into the table.
375
376
377        :param table: Table name.
378        :type table: str
379
380        :param instances: Data to be written into the database.
381        :type instances: :class:`Orange.data.Table`
382
383        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
384            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
385        :type rename_dict: dict
386
387        """
388        l = [i.name for i in instances.domain.attributes]
389        l += [i.name for i in instances.domain.get_metas().values()]
390        if instances.domain.class_var:
391            l.append(instances.domain.class_var.name)
392        if rename_dict is None:
393            rename_dict = {}
394        colList = []
395        for i in l:
396            colList.append(rename_dict.get(str(i), str(i)))
397        try:
398            cursor=self.connection.cursor()
399            self.quirks.beforeWrite(cursor)
400            query = 'INSERT INTO "%s" (%s) VALUES (%s);'
401            for d in instances:
402                valList = []
403                colSList = []
404                for (i, name) in enumerate(colList):
405                    colSList.append('"%s"'% name)
406                    valList.append(self.__attrVal2sql(d[l[i]]))
407                d = query % (table,
408                    ", ".join(colSList),
409                    ", ".join ([self.quirks.parameter] * len(valList)))
410                cursor.execute(d, tuple(valList))
411            cursor.close()
412            self.connection.commit()
413        except Exception, e:
414            import traceback
415            traceback.print_exc()
416            self.connection.rollback()
417
418    @deprecated_keywords({"renameDict":"rename_dict", "typeDict":"type_dict"})
419    def create(self, table, instances, rename_dict = {}, type_dict = {}):
420        """
421        Create the required SQL table, then write the data into it.
422
423        :param table: Table name
424        :type table: str
425
426        :param instances: Data to be written into the database.
427        :type instances: :class:`Orange.data.Table`
428
429        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
430            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
431        :type rename_dict: dict
432
433        :param type_dict: When ``type_dict`` is provided the used variables are casted into new types.
434            The type of orange attribute "X" is casted into the database column of type rename_dict["X"].
435        :type type_dict: dict
436
437        """
438        l = [(i.name, i.var_type ) for i in instances.domain.attributes]
439        l += [(i.name, i.var_type ) for i in instances.domain.get_metas().values()]
440        if instances.domain.class_var:
441            l.append((instances.domain.class_var.name, instances.domain.class_var.var_type))
442        #if rename_dict is None:
443        #    rename_dict = {}
444        colNameList = [rename_dict.get(str(i[0]), str(i[0])) for i in l]
445        #if type_dict is None:
446        #    typeDict = {}
447        colTypeList = [type_dict.get(str(i[0]), self.__attrType2sql(i[1])) for i in l]
448        try:
449            cursor = self.connection.cursor()
450            colSList = []
451            for (i, name) in enumerate(colNameList):
452                colSList.append('"%s" %s' % (name, colTypeList[i]))
453            colStr = ", ".join(colSList)
454            query = """CREATE TABLE "%s" ( %s );""" % (table, colStr)
455            self.quirks.beforeCreate(cursor)
456            cursor.execute(query)
457            self.write(table, instances, rename_dict)
458            self.connection.commit()
459        except Exception, e:
460            self.connection.rollback()
461
462    def disconnect(self):
463        """
464        Disconnect from the database.
465        """
466        func = getattr(self.conn, "disconnect", None)
467        if callable(func):
468            self.conn.disconnect()
469
470def loadSQL(filename, dontCheckStored = False, domain = None):
471    f = open(filename)
472    lines = f.readlines()
473    queryLines = []
474    discreteNames = None
475    uri = None
476    metaNames = None
477    className = None
478    for i in lines:
479        if i.startswith("--orng"):
480            (dummy, command, line) = i.split(None, 2)
481            if command == 'uri':
482                uri = eval(line)
483            elif command == 'discrete':
484                discreteNames = eval(line)
485            elif command == 'meta':
486                metaNames = eval(line)
487            elif command == 'class':
488                className = eval(line)
489            else:
490                queryLines.append(i)
491        else:
492            queryLines.append(i)
493    query = "\n".join(queryLines)
494    r = SQLReader(uri)
495    if discreteNames:
496        r.discreteNames = discreteNames
497    if className:
498        r.className = className
499    if metaNames:
500        r.metaNames = metaNames
501    r.execute(query)
502    return r.data()
Note: See TracBrowser for help on using the repository browser.