Changeset 10109:0724d77247e4 in orange


Ignore:
Timestamp:
02/08/12 17:51:25 (2 years ago)
Author:
crt.gorup@…
Branch:
default
rebase_source:
74befbd95536e2109fa552f02aa5580f7c050f3c
Message:

Fixed imports.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • Orange/orng/orngSQL.py

    r9671 r10109  
    1 # A module to read data from an SQL database into an Orange ExampleTable. 
    2 # The goal is to keep it compatible with PEP 249. 
    3 # For now, the writing shall be basic, if it works at all. 
    4  
    5 import orange 
    6 import os 
    7 import urllib 
    8  
    9 def _parseURI(uri): 
    10     """ lifted straight from sqlobject """ 
    11     schema, rest = uri.split(':', 1) 
    12     assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest 
    13     if rest.startswith('/') and not rest.startswith('//'): 
    14         host = None 
    15         rest = rest[1:] 
    16     elif rest.startswith('///'): 
    17         host = None 
    18         rest = rest[3:] 
    19     else: 
    20         rest = rest[2:] 
    21         if rest.find('/') == -1: 
    22             host = rest 
    23             rest = '' 
    24         else: 
    25             host, rest = rest.split('/', 1) 
    26     if host and host.find('@') != -1: 
    27         user, host = host.split('@', 1) 
    28         if user.find(':') != -1: 
    29             user, password = user.split(':', 1) 
    30         else: 
    31             password = None 
    32     else: 
    33         user = password = None 
    34     if host and host.find(':') != -1: 
    35         _host, port = host.split(':') 
    36         try: 
    37             port = int(port) 
    38         except ValueError: 
    39             raise ValueError, "port must be integer, got '%s' instead" % port 
    40         if not (1 <= port <= 65535): 
    41             raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port 
    42         host = _host 
    43     else: 
    44         port = None 
    45     path = '/' + rest 
    46     if os.name == 'nt': 
    47         if (len(rest) > 1) and (rest[1] == '|'): 
    48             path = "%s:%s" % (rest[0], rest[2:]) 
    49     args = {} 
    50     if path.find('?') != -1: 
    51         path, arglist = path.split('?', 1) 
    52         arglist = arglist.split('&') 
    53         for single in arglist: 
    54             argname, argvalue = single.split('=', 1) 
    55             argvalue = urllib.unquote(argvalue) 
    56             args[argname] = argvalue 
    57     return schema, user, password, host, port, path, args 
    58  
    59 class __DummyQuirkFix: 
    60     def __init__(self, dbmod): 
    61         self.dbmod = dbmod 
    62         self.typeDict = { 
    63             orange.VarTypes.Continuous:'FLOAT',  
    64             orange.VarTypes.Discrete:'VARCHAR(250)', orange.VarTypes.String:'VARCHAR(250)'} 
    65     def beforeWrite(self, cursor): 
    66         pass 
    67     def beforeCreate(self, cursor): 
    68         pass 
    69     def beforeRead(self, cursor): 
    70         pass 
    71 class __MySQLQuirkFix(__DummyQuirkFix): 
    72     def __init__(self, dbmod): 
    73         self.dbmod = dbmod 
    74         self.BOOLEAN = None 
    75         self.STRING = dbmod.STRING 
    76         self.DATETIME = dbmod.DATETIME 
    77         self.typeDict = { 
    78             orange.VarTypes.Continuous:'DOUBLE',  
    79             orange.VarTypes.Discrete:'VARCHAR(250)', orange.VarTypes.String:'VARCHAR(250)'} 
    80     def beforeWrite(self, cursor): 
    81         cursor.execute("SET sql_mode='ANSI_QUOTES';") 
    82     def beforeCreate(self, cursor): 
    83         cursor.execute("SET sql_mode='ANSI_QUOTES';") 
    84     def beforeRead(self, cursor): 
    85         pass 
    86 class __PostgresQuirkFix(__DummyQuirkFix): 
    87     def __init__(self, dbmod): 
    88         self.dbmod = dbmod 
    89         self.BOOLEAN = 16 
    90         self.STRING = dbmod.STRING 
    91         self.DATETIME = dbmod.DATETIME 
    92         self.typeDict = { 
    93             orange.VarTypes.Continuous:'FLOAT',  
    94             orange.VarTypes.Discrete:'VARCHAR', orange.VarTypes.String:'VARCHAR'} 
    95     def beforeWrite(self, cursor): 
    96         pass 
    97     def beforeCreate(self, cursor): 
    98         pass 
    99     def beforeRead(self, cursor): 
    100         pass 
    101  
    102 def _connection(uri): 
    103         """the uri string's syntax is the same as that of sqlobject. 
    104         Unfortunately, only postgres and mysql are going to be supported in 
    105         the near future. 
    106         scheme://[user[:password]@]host[:port]/database[?parameters] 
    107         Examples: 
    108         mysql://user:password@host/database 
    109         mysql://host/database?debug=1 
    110         postgres://user@host/database?debug=&cache= 
    111         postgres:///full/path/to/socket/database 
    112         postgres://host:5432/database 
    113         """ 
    114         (schema, user, password, host, port, path, args) = _parseURI(uri) 
    115         if schema == 'postgres': 
    116             import psycopg2 as dbmod 
    117             argTrans = { 
    118             'host':'host', 
    119             'port':'port', 
    120             'user':'user', 
    121             'password':'password', 
    122             'database':'database' 
    123             } 
    124             quirks = __PostgresQuirkFix(dbmod) 
    125         elif schema == 'mysql': 
    126             import MySQLdb as dbmod 
    127             argTrans = { 
    128             'host':'host', 
    129             'port':'port', 
    130             'user':'user', 
    131             'password':'passwd', 
    132             'database':'db' 
    133             } 
    134             quirks = __MySQLQuirkFix(dbmod) 
    135         dbArgDict = {} 
    136         if user: 
    137             dbArgDict[argTrans['user']] = user 
    138         if password: 
    139             dbArgDict[argTrans['password']] = password 
    140         if host: 
    141             dbArgDict[argTrans['host']] = host 
    142         if port: 
    143             dbArgDict[argTrans['port']] = port 
    144         if path: 
    145             dbArgDict[argTrans['database']] = path[1:] 
    146         return (quirks, dbmod.connect(**dbArgDict)) 
    147  
    148 class SQLReader(object): 
    149     def __init__(self, addr = None, domainDepot = None): 
    150         if addr is not None: 
    151             self.connect(addr) 
    152         if domainDepot is not None: 
    153             self.domainDepot = domainDepot 
    154         else: 
    155             self.domainDepot = orange.DomainDepot() 
    156         self.exampleTable = None 
    157         self._dirty = True 
    158     def connect(self, uri): 
    159         self._dirty = True 
    160         self.delDomain() 
    161         (self.quirks, self.conn) = _connection(uri) 
    162     def disconnect(self): 
    163         self.conn.disconnect() 
    164     def getClassName(self): 
    165         self.update() 
    166         return self.domain.classVar.name 
    167     def setClassName(self, className): 
    168         self._className = className 
    169         self.delDomain() 
    170     def delClassName(self): 
    171         del self._className 
    172     className = property(getClassName, setClassName, delClassName, "the name of the class variable") 
    173  
    174     def getMetaNames(self): 
    175         self.update() 
    176         return self.domain.getmetas().values() 
    177     def setMetaNames(self, metaNames): 
    178         self._metaNames = metaNames 
    179         self.delDomain() 
    180     def delMetaNames(self): 
    181         del self._metaNames 
    182     metaNames = property(getMetaNames, setMetaNames, delMetaNames, "the names of the meta attributes") 
    183  
    184     def setDiscreteNames(self, discreteNames): 
    185         self._discreteNames = discreteNames 
    186         self.delDomain() 
    187     def getDiscreteNames(self): 
    188         self.update() 
    189         return self._discreteNames 
    190     def delDiscreteNames(self): 
    191         del self._discreteNames 
    192     discreteNames = property(getDiscreteNames, setDiscreteNames, delDiscreteNames, "the names of the discrete attributes") 
    193  
    194     def setQuery(self, query, domain = None): 
    195         """sets the query, resets the internal variables, without executing the query""" 
    196         self._query = query 
    197         self._dirty = True 
    198         if domain is not None: 
    199             self._domain = domain 
    200         else: 
    201             self.delDomain() 
    202     def getQuery(self): 
    203         return self._query 
    204     def delQuery(self): 
    205         del self._query 
    206     query = property(getQuery, setQuery, delQuery, "The query to be executed on the next execute()") 
    207     def generateDomain(self): 
    208         pass 
    209     def setDomain(self, domain): 
    210         self._domain = domain 
    211         self._dirty = True 
    212     def getDomain(self): 
    213         if not hasattr(self, '_domain'): 
    214             self._createDomain() 
    215         return self._domain 
    216     def delDomain(self): 
    217         if hasattr(self, '_domain'): 
    218             del self._domain 
    219     domain = property(getDomain, setDomain, delDomain, "the Orange domain") 
    220     def execute(self, query, domain = None): 
    221         """executes an sql query""" 
    222         self.setQuery(query, domain) 
    223         self.update() 
    224          
    225     def _createDomain(self): 
    226         if hasattr(self, '_domain'): 
    227             return 
    228         attrNames = [] 
    229         if not hasattr(self, '_discreteNames'): 
    230             self._discreteNames = [] 
    231         discreteNames = self._discreteNames 
    232         if not hasattr(self, '_metaNames'): 
    233             self._metaNames = [] 
    234         metaNames = self._metaNames 
    235         if not hasattr(self, '_className'): 
    236             className = None 
    237         else: 
    238             className = self._className 
    239         for i in self.desc: 
    240             name = i[0] 
    241             typ = i[1] 
    242             if name in discreteNames or typ == self.quirks.BOOLEAN: 
    243                 attrName = 'D#' + name 
    244             elif typ == self.quirks.STRING: 
    245                     attrName = 'S#' + name 
    246             elif typ == self.quirks.DATETIME: 
    247                 attrName = 'S#' + name 
    248             else: 
    249                 attrName = 'C#' + name 
    250             if name == className: 
    251                 attrName = "c" + attrName 
    252             elif name in metaNames: 
    253                 attrName = "m" + attrName 
    254             elif not className and name == 'class': 
    255                 attrName = "c" + attrName 
    256             attrNames.append(attrName) 
    257     #       print "NAME:", '"%s"' % name, ", t:", typ, " attrN:", '"%s"' % attrName 
    258         (self._domain, self._metaIDs, dummy) = self.domainDepot.prepareDomain(attrNames) 
    259  #           print "Created domain." 
    260         del dummy 
    261  
    262          
    263     def update(self): 
    264         if not self._dirty and hasattr(self, '_domain'): 
    265             return self.exampleTable 
    266         self.exampleTable = None 
    267         try: 
    268             curs = self.conn.cursor() 
    269             try: 
    270                 self.quirks.beforeRead(curs) 
    271                 curs.execute(self.query) 
    272             except Exception, e: 
    273                 self.conn.rollback() 
    274                 raise e 
    275             self.desc = curs.description 
    276             # for reasons unknown, the attributes get reordered. 
    277             domainIndexes = [0] * len(self.desc) 
    278             self._createDomain() 
    279             attrNames = [] 
    280             for i, name in enumerate(self.desc): 
    281             #    print name[0], '->', self.domain.index(name[0]) 
    282                 domainIndexes[self._domain.index(name[0])] = i 
    283                 attrNames.append(name[0]) 
    284             self.exampleTable = orange.ExampleTable(self.domain) 
    285             r = curs.fetchone() 
    286             while r: 
    287                 # for reasons unknown, domain rearranges the properties 
    288                 example = orange.Example(self.domain) 
    289                 for i in xrange(len(r)): 
    290                     if r[i] is not None: 
    291                         val = str(r[i]) 
    292                         var = example[attrNames[i]].variable 
    293                         if type(var) == orange.EnumVariable and val not in var.values: 
    294                             var.values.append(val) 
    295                         example[attrNames[i]] = str(r[i]) 
    296                 self.exampleTable.append(example) 
    297                 r = curs.fetchone() 
    298             self._dirty = False 
    299         except Exception, e: 
    300             self.domain = None 
    301             raise 
    302             #self.domain = None 
    303  
    304     def data(self): 
    305         self.update() 
    306         if self.exampleTable: 
    307             return self.exampleTable 
    308         return None 
    309      
    310 class SQLWriter(object): 
    311     def __init__(self, uri = None): 
    312         if uri is not None: 
    313             self.connect(uri) 
    314      
    315     def connect(self, uri): 
    316         (self.quirks, self.connection) = _connection(uri) 
    317     def __attrVal2sql(self, d): 
    318         if d.varType == orange.VarTypes.Continuous: 
    319             return d.value 
    320         elif d.varType == orange.VarTypes.Discrete: 
    321             return str(d.value) 
    322         else: 
    323             return "'%s'" % str(d.value) 
    324     def __attrName2sql(self, d): 
    325         return d.name 
    326     def __attrType2sql(self, d): 
    327         return self.quirks.typeDict[d] 
    328     def write(self, table, data, renameDict = None): 
    329         """if provided, renameDict maps the names in data to columns in 
    330         the database. For each var in data: dbColName = renameDict[var.name]""" 
    331         l = [i.name for i in data.domain.attributes] 
    332         l += [i.name for i in data.domain.getmetas().values()] 
    333         if data.domain.classVar: 
    334             l.append(data.domain.classVar.name) 
    335         if renameDict is None: 
    336             renameDict = {} 
    337         colList = [] 
    338         for i in l: 
    339             colList.append(renameDict.get(str(i), str(i))) 
    340         try: 
    341             cursor=self.connection.cursor() 
    342             self.quirks.beforeWrite(cursor) 
    343             query = 'INSERT INTO "%s" (%s) VALUES (%s);' 
    344             for d in data: 
    345                 valList = [] 
    346                 colSList = [] 
    347                 for (i, name) in enumerate(colList): 
    348                     colSList.append('"%s"'% name) 
    349                     valList.append(self.__attrVal2sql(d[l[i]])) 
    350                 valStr = ', '.join(["%s"]*len(colList)) 
    351                 # print "exec:", query % (table, "%s ", "%s "), tuple(colList + valList) 
    352                 cursor.execute(query % (table,  
    353                     ", ".join(colSList),  
    354                     ", ".join (["%s"] * len(valList))), tuple(valList)) 
    355             cursor.close() 
    356             self.connection.commit() 
    357         except Exception, e: 
    358             import traceback 
    359         traceback.print_exc() 
    360             self.connection.rollback() 
    361  
    362     def create(self, table, data, renameDict = None, typeDict = None): 
    363         l = [(i.name, i.varType ) for i in data.domain.attributes] 
    364         l += [(i.name, i.varType ) for i in data.domain.getmetas().values()] 
    365         if data.domain.classVar: 
    366             l.append((data.domain.classVar.name, data.domain.classVar.varType)) 
    367         if renameDict is None: 
    368             renameDict = {} 
    369         colNameList = [renameDict.get(str(i[0]), str(i[0])) for i in l] 
    370         if typeDict is None: 
    371             typeDict = {} 
    372         colTypeList = [typeDict.get(str(i[0]), self.__attrType2sql(i[1])) for i in l] 
    373         try: 
    374             cursor = self.connection.cursor() 
    375             colSList = [] 
    376             for (i, name) in enumerate(colNameList): 
    377                 colSList.append('"%s" %s' % (name, colTypeList[i])) 
    378             colStr = ", ".join(colSList) 
    379             query = """CREATE TABLE "%s" ( %s );""" % (table, colStr) 
    380             self.quirks.beforeCreate(cursor) 
    381             cursor.execute(query) 
    382             print query 
    383             self.write(table, data, renameDict) 
    384             self.connection.commit() 
    385         except Exception, e: 
    386             self.connection.rollback() 
    387      
    388     def disconnect(self): 
    389         self.conn.disconnect() 
    390  
    391 def loadSQL(filename, dontCheckStored = False, domain = None): 
    392     f = open(filename) 
    393     lines = f.readlines() 
    394     queryLines = [] 
    395     discreteNames = None 
    396     uri = None 
    397     metaNames = None 
    398     className = None 
    399     for i in lines: 
    400         if i.startswith("--orng"): 
    401             (dummy, command, line) = i.split(None, 2) 
    402             if command == 'uri': 
    403                 uri = eval(line) 
    404             elif command == 'discrete': 
    405                 discreteNames = eval(line) 
    406             elif command == 'meta': 
    407                 metaNames = eval(line) 
    408             elif command == 'class': 
    409                 className = eval(line) 
    410             else: 
    411                 queryLines.append(i) 
    412         else: 
    413             queryLines.append(i) 
    414     query = "\n".join(queryLines) 
    415     r = SQLReader(uri) 
    416     if discreteNames: 
    417         r.discreteNames = discreteNames 
    418     if className: 
    419         r.className = className 
    420     if metaNames: 
    421         r.metaNames = metaNames 
    422     r.execute(query) 
    423     data = r.data() 
    424     return data 
    425  
    426 def saveSQL(): 
    427     pass 
     1from Orange.data.sql import * 
Note: See TracChangeset for help on using the changeset viewer.