source: orange/Orange/data/sql.py @ 11832:d3f10458113d

Revision 11832:d3f10458113d, 18.3 KB checked in by Slach@…, 3 months ago (diff)

adding ODBC support for SQL select widget

Line 
1import os
2import urllib
3import Orange
4import orange
5from Orange.utils import deprecated_keywords, deprecated_members
6from Orange.feature import Descriptor
7
8def _parseURI(uri):
9    """ lifted straight from sqlobject """
10    schema, rest = uri.split(':', 1)
11    assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
12    if rest.startswith('/') and not rest.startswith('//'):
13        host = None
14        rest = rest[1:]
15    elif rest.startswith('///'):
16        host = None
17        rest = rest[3:]
18    else:
19        rest = rest[2:]
20        if rest.find('/') == -1:
21            host = rest
22            rest = ''
23        else:
24            host, rest = rest.split('/', 1)
25    if host and host.find('@') != -1:
26        user, host = host.split('@', 1)
27        if user.find(':') != -1:
28            user, password = user.split(':', 1)
29        else:
30            password = None
31    else:
32        user = password = None
33    if host and host.find(':') != -1:
34        _host, port = host.split(':')
35        try:
36            port = int(port)
37        except ValueError:
38            raise ValueError, "port must be integer, got '%s' instead" % port
39        if not (1 <= port <= 65535):
40            raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
41        host = _host
42    else:
43        port = None
44    path = '/' + rest
45    if os.name == 'nt':
46        if (len(rest) > 1) and (rest[1] == '|'):
47            path = "%s:%s" % (rest[0], rest[2:])
48    args = {}
49    if path.find('?') != -1:
50        path, arglist = path.split('?', 1)
51        arglist = arglist.split('&')
52        for single in arglist:
53            argname, argvalue = single.split('=', 1)
54            argvalue = urllib.unquote(argvalue)
55            args[argname] = argvalue
56    return schema, user, password, host, port, path, args
57
58class __MySQLQuirkFix(object):
59    def __init__(self, dbmod):
60        self.dbmod = dbmod
61        self.typeDict = {
62            Descriptor.Continuous:'DOUBLE',
63            Descriptor.Discrete:'VARCHAR(250)', Descriptor.String:'VARCHAR(250)'}
64
65    def beforeWrite(self, cursor):
66        cursor.execute("SET sql_mode='ANSI_QUOTES';")
67
68    def beforeCreate(self, cursor):
69        cursor.execute("SET sql_mode='ANSI_QUOTES';")
70
71class __PostgresQuirkFix(object):
72    def __init__(self, dbmod):
73        self.dbmod = dbmod
74        self.typeDict = {
75            Descriptor.Continuous:'FLOAT',
76            Descriptor.Discrete:'VARCHAR', Descriptor.String:'VARCHAR'}
77
78    def beforeWrite(self, cursor):
79        pass
80
81    def beforeCreate(self, cursor):
82        pass
83
84class __ODBCQuirkFix(object):
85    def __init__(self, dbmod):
86        self.dbmod = dbmod
87        self.typeDict = {
88            Descriptor.Continuous:'FLOAT',
89            Descriptor.Discrete:'VARCHAR', Descriptor.String:'VARCHAR'}
90
91    def beforeWrite(self, cursor):
92        pass
93
94    def beforeCreate(self, cursor):
95        pass
96
97
98def _connection(uri):
99        (schema, user, password, host, port, path, args) = _parseURI(uri)
100        argTrans = {
101            'host':'host',
102            'port':'port',
103            'user':'user',
104            'password':'passwd',
105            'database':'db'
106            }
107        if schema == 'postgres':
108            argTrans["database"] = "db"
109        elif schema == 'odbc':
110            argTrans["host"] = "server"
111            argTrans["user"] = "uid"
112            argTrans["password"] = "pwd"
113            argTrans['database'] = 'database'
114
115        dbArgDict = {}
116        if user:
117            dbArgDict[argTrans['user']] = user
118        if password:
119            dbArgDict[argTrans['password']] = password
120        if host:
121            dbArgDict[argTrans['host']] = host
122        if port:
123            dbArgDict[argTrans['port']] = port
124        if path:
125            dbArgDict[argTrans['database']] = path[1:]
126
127        if schema == 'postgres':
128            import psycopg2 as dbmod
129            quirks = __PostgresQuirkFix(dbmod)
130            quirks.parameter = "%s"
131            return (quirks, dbmod.connect(**dbArgDict))
132        elif schema == 'mysql':
133            import MySQLdb as dbmod
134            quirks = __MySQLQuirkFix(dbmod)
135            quirks.parameter = "%s"
136            return (quirks, dbmod.connect(**dbArgDict))
137        elif schema == "sqlite":
138            import sqlite3 as dbmod
139            quirks = __PostgresQuirkFix(dbmod)
140            quirks.parameter = "?"
141            return (quirks, dbmod.connect(host))
142        elif schema == "odbc":
143            import pyodbc as dbmod
144            quirks = __ODBCQuirkFix(dbmod)
145            quirks.parameter = "?"
146            if args.has_key('DSN'):
147                connectionString = 'DSN=%s' % (args['DSN'])
148            elif args.has_key('Driver'):
149                connectionString = 'Driver=%s' % (args['Driver'])
150            else:
151                raise ValueError, "ODBC url schema must have DSN or Driver parameter"
152            for k in args:
153                if k not in ['DSN','Driver']:
154                    connectionString +=';%s=%s' % (k,args[k])
155            #print connectionString, dbArgDict
156            return (quirks, dbmod.connect(connectionString,**dbArgDict))
157
158class SQLReader(object):
159    """
160    :obj:`~SQLReader` establishes a connection with a database and provides the methods needed
161    to fetch the data from the database into Orange.
162    """
163    @deprecated_keywords({"domainDepot":"domain_depot"})
164    def __init__(self, addr = None, domain_depot = None):
165        """
166        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters]).
167        :type uri: str
168
169        :param domain_depot: Domain depot
170        :type domain_depot: :class:`orange.DomainDepot`
171        """
172        if addr is not None:
173            self.connect(addr)
174        if domain_depot is not None:
175            self.domainDepot = domain_depot
176        else:
177            self.domainDepot = orange.DomainDepot()
178        self.exampleTable = None
179        self._dirty = True
180
181    def connect(self, uri):
182        """
183        Connect to the database.
184
185        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
186        :type uri: str
187        """
188        self._dirty = True
189        self.del_domain()
190        (self.quirks, self.conn) = _connection(uri)
191
192    def disconnect(self):
193        """
194        Disconnect from the database.
195        """
196        func = getattr(self.conn, "disconnect", None)
197        if callable(func):
198            self.conn.disconnect()
199
200    def get_class_name(self):
201        self.update()
202        return self.domain.class_var.name
203
204    def set_class_name(self, class_name):
205        self._className = class_name
206        self.del_domain()
207
208    def del_class_name(self):
209        del self._className
210
211    class_name = property(get_class_name, set_class_name, del_class_name, "Name of class variable.")
212
213    def get_metas_name(self):
214        self.update()
215        return self.domain.get_metas().values()
216
217    def set_metas_name(self, meta_names):
218        self._metaNames = meta_names
219        self.del_domain()
220
221    def del_metas_name(self):
222        del self._metaNames
223
224    meta_names = property(get_metas_name, set_metas_name, del_metas_name, "Names of meta attributes.")
225
226    def set_discrete_names(self, discrete_names):
227        self._discreteNames = discrete_names
228        self.del_domain()
229
230    def get_discrete_names(self):
231        self.update()
232        return self._discreteNames
233
234    def del_discrete_names(self):
235        del self._discreteNames
236
237    discrete_names = property(get_discrete_names, set_discrete_names, del_discrete_names, "Names of discrete attributes.")
238
239    def set_query(self, query, domain = None):
240        #sets the query, resets the internal variables, without executing the query
241        self._query = query
242        self._dirty = True
243        if domain is not None:
244            self._domain = domain
245        else:
246            self.del_domain()
247
248    def get_query(self):
249        return self._query
250
251    def del_query(self):
252        del self._query
253
254    query = property(get_query, set_query, del_query, "Query to be executed on the next execute().")
255
256    def generateDomain(self):
257        pass
258
259    def set_domain(self, domain):
260        self._domain = domain
261        self._dirty = True
262
263    def get_domain(self):
264        if not hasattr(self, '_domain') or self._domain is None:
265            self._createDomain()
266        return self._domain
267
268    def del_domain(self):
269        if hasattr(self, '_domain'):
270            del self._domain
271
272    domain = property(get_domain, set_domain, del_domain, "Orange domain.")
273
274    def execute(self, query, domain = None):
275        """
276        Executes an sql query.
277        """
278        self.set_query(query, domain)
279        self.update()
280
281    def _createDomain(self):
282        if hasattr(self, '_domain') and not self._domain is None:
283            return
284        attrNames = []
285        if not hasattr(self, '_discreteNames'):
286            self._discreteNames = []
287        discreteNames = self._discreteNames
288        if not hasattr(self, '_metaNames'):
289            self._metaNames = []
290        metaNames = self._metaNames
291        if not hasattr(self, '_className'):
292            className = None
293        else:
294            className = self._className
295        for i in self.desc:
296            name = i[0]
297            typ = i[1]
298            if name in discreteNames:
299                attrName = 'D#' + name
300            elif typ is None or typ in [unicode, self.quirks.dbmod.STRING, self.quirks.dbmod.DATETIME]:
301                    attrName = 'S#' + name
302            else:
303                attrName = 'C#' + name
304
305            if name == className:
306                attrName = "c" + attrName
307            elif name in metaNames:
308                attrName = "m" + attrName
309            elif not className and name == 'class':
310                attrName = "c" + attrName
311            attrNames.append(attrName)
312        (self._domain, self._metaIDs, dummy) = self.domainDepot.prepareDomain(attrNames)
313        del dummy
314
315    def update(self):
316        """
317        Execute a pending SQL query.
318        """
319        if not self._dirty and hasattr(self, '_domain') and not self._domain is None:
320            return self.exampleTable
321        self.exampleTable = None
322        try:
323            curs = self.conn.cursor()
324            try:
325                curs.execute(self.query)
326            except Exception, e:
327                self.conn.rollback()
328                raise e
329            self.desc = curs.description
330            # for reasons unknown, the attributes get reordered.
331            domainIndexes = [0] * len(self.desc)
332            self._createDomain()
333            attrNames = []
334            for i, name in enumerate(self.desc):
335                #print name[0], '->', self._domain.index(name[0])
336                domainIndexes[self._domain.index(name[0])] = i
337                attrNames.append(name[0])
338            self.exampleTable = Orange.data.Table(self.domain)
339            r = curs.fetchone()
340            while r:
341                # for reasons unknown, domain rearranges the properties
342                example = Orange.data.Instance(self.domain)
343                for i in xrange(len(r)):
344                    val = str(r[i])
345                    var = example[attrNames[i]].variable
346                    if type(var) == Descriptor.Discrete and val not in var.values:
347                        var.values.append(val)
348                    example[attrNames[i]] = str(r[i])
349                self.exampleTable.append(example)
350                r = curs.fetchone()
351            self._dirty = False
352        except Exception, e:
353            self.domain = None
354            raise
355            #self.domain = None
356
357    def data(self):
358        """
359        Return :class:`Orange.data.Table` produced by the last executed query.
360        """
361        self.update()
362        if self.exampleTable:
363            return self.exampleTable
364        return None
365
366SQLReader = deprecated_members({"discreteNames":"discrete_names", "metaName":"meta_names"\
367    , "className":"class_name"})(SQLReader)
368
369class SQLWriter(object):
370    """
371    Establishes a connection with a database and provides the methods needed to create
372    an appropriate table in the database and/or write the data from an :class:`Orange.data.Table`
373    into the database.
374    """
375    def __init__(self, uri = None):
376        """
377        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
378        :type uri: str
379        """
380        if uri is not None:
381            self.connect(uri)
382
383    def connect(self, uri):
384        """
385        Connect to the database.
386
387        :param uri: Connection string (scheme://[user[:password]@]host[:port]/database[?parameters])
388        :type uri: str
389        """
390        (self.quirks, self.connection) = _connection(uri)
391
392    def __attrVal2sql(self, d):
393        if d.var_type == Descriptor.Continuous:
394            return d.value
395        elif d.var_type == Descriptor.Discrete:
396            return str(d.value)
397        else:
398            return "'%s'" % str(d.value)
399
400    def __attrName2sql(self, d):
401        return d.name
402
403    def __attrType2sql(self, d):
404        return self.quirks.typeDict[d]
405
406    @deprecated_keywords({"renameDict":"rename_dict"})
407    def write(self, table, instances, rename_dict = None):
408        """
409        Writes the data into the table.
410
411
412        :param table: Table name.
413        :type table: str
414
415        :param instances: Data to be written into the database.
416        :type instances: :class:`Orange.data.Table`
417
418        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
419            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
420        :type rename_dict: dict
421
422        """
423        l = [i.name for i in instances.domain.attributes]
424        l += [i.name for i in instances.domain.get_metas().values()]
425        if instances.domain.class_var:
426            l.append(instances.domain.class_var.name)
427        if rename_dict is None:
428            rename_dict = {}
429        colList = []
430        for i in l:
431            colList.append(rename_dict.get(str(i), str(i)))
432        try:
433            cursor=self.connection.cursor()
434            self.quirks.beforeWrite(cursor)
435            query = 'INSERT INTO "%s" (%s) VALUES (%s);'
436            for d in instances:
437                valList = []
438                colSList = []
439                for (i, name) in enumerate(colList):
440                    colSList.append('"%s"'% name)
441                    valList.append(self.__attrVal2sql(d[l[i]]))
442                d = query % (table,
443                    ", ".join(colSList),
444                    ", ".join ([self.quirks.parameter] * len(valList)))
445                cursor.execute(d, tuple(valList))
446            cursor.close()
447            self.connection.commit()
448        except Exception, e:
449            import traceback
450            traceback.print_exc()
451            self.connection.rollback()
452
453    @deprecated_keywords({"renameDict":"rename_dict", "typeDict":"type_dict"})
454    def create(self, table, instances, rename_dict = {}, type_dict = {}):
455        """
456        Create the required SQL table, then write the data into it.
457
458        :param table: Table name
459        :type table: str
460
461        :param instances: Data to be written into the database.
462        :type instances: :class:`Orange.data.Table`
463
464        :param rename_dict: When ``rename_dict`` is provided the used names are remapped.
465            The orange attribute "X" is written into the database column rename_dict["X"] of the table.
466        :type rename_dict: dict
467
468        :param type_dict: When ``type_dict`` is provided the used variables are casted into new types.
469            The type of orange attribute "X" is casted into the database column of type rename_dict["X"].
470        :type type_dict: dict
471
472        """
473        l = [(i.name, i.var_type ) for i in instances.domain.attributes]
474        l += [(i.name, i.var_type ) for i in instances.domain.get_metas().values()]
475        if instances.domain.class_var:
476            l.append((instances.domain.class_var.name, instances.domain.class_var.var_type))
477        #if rename_dict is None:
478        #    rename_dict = {}
479        colNameList = [rename_dict.get(str(i[0]), str(i[0])) for i in l]
480        #if type_dict is None:
481        #    typeDict = {}
482        colTypeList = [type_dict.get(str(i[0]), self.__attrType2sql(i[1])) for i in l]
483        try:
484            cursor = self.connection.cursor()
485            colSList = []
486            for (i, name) in enumerate(colNameList):
487                colSList.append('"%s" %s' % (name, colTypeList[i]))
488            colStr = ", ".join(colSList)
489            query = """CREATE TABLE "%s" ( %s );""" % (table, colStr)
490            self.quirks.beforeCreate(cursor)
491            cursor.execute(query)
492            self.write(table, instances, rename_dict)
493            self.connection.commit()
494        except Exception, e:
495            self.connection.rollback()
496
497    def disconnect(self):
498        """
499        Disconnect from the database.
500        """
501        func = getattr(self.conn, "disconnect", None)
502        if callable(func):
503            self.conn.disconnect()
504
505def loadSQL(filename, dontCheckStored = False, domain = None):
506    f = open(filename)
507    lines = f.readlines()
508    queryLines = []
509    discreteNames = None
510    uri = None
511    metaNames = None
512    className = None
513    for i in lines:
514        if i.startswith("--orng"):
515            (dummy, command, line) = i.split(None, 2)
516            if command == 'uri':
517                uri = eval(line)
518            elif command == 'discrete':
519                discreteNames = eval(line)
520            elif command == 'meta':
521                metaNames = eval(line)
522            elif command == 'class':
523                className = eval(line)
524            else:
525                queryLines.append(i)
526        else:
527            queryLines.append(i)
528    query = "\n".join(queryLines)
529    r = SQLReader(uri)
530    if discreteNames:
531        r.discreteNames = discreteNames
532    if className:
533        r.className = className
534    if metaNames:
535        r.metaNames = metaNames
536    r.execute(query)
537    return r.data()
Note: See TracBrowser for help on using the repository browser.