d855c881b052aba07dc9c30a5bd9484c93eb023f
[tiramisu.git] / tiramisu / storage / sqlite3 / storage.py
1 # -*- coding: utf-8 -*-
2 "default plugin for cache: set it in a simple dictionary"
3 # Copyright (C) 2013 Team tiramisu (see AUTHORS for all contributors)
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 2 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
18 #
19 # ____________________________________________________________
20
21 from pickle import dumps, loads
22 from os import unlink
23 from os.path import basename, splitext, join
24 import sqlite3
25 from glob import glob
26
27
class Setting(object):
    """Module-wide configuration of the sqlite3 storage backend.

    The values below are class-level defaults; they can be overridden on the
    shared ``setting`` instance created right after this class.
    """
    # directory in which the session database files live
    dir_database = '/tmp'
    # file extension appended to every session id to build the file name
    extension = 'db'


setting = Setting()
34
35
def _gen_filename(name):
    """Return the database file path for session ``name``.

    The path is ``<setting.dir_database>/<name>.<setting.extension>``.
    """
    filename = '{0}.{1}'.format(name, setting.extension)
    return join(setting.dir_database, filename)
39
40
def list_sessions():
    """Return the ids of the sessions that have a database file on disk.

    The id is the file name without its directory and extension.
    """
    pattern = _gen_filename('*')
    return [basename(splitext(found)[0]) for found in glob(pattern)]
46
47
def delete_session(session_id):
    """Remove the database file of ``session_id``.

    ``os.unlink`` raises OSError if the file does not exist.
    """
    db_path = _gen_filename(session_id)
    unlink(db_path)
50
51
class Storage(object):
    """Connection to the SQLite database file of one session.

    :param session_id: session name; it selects the database file through
                       ``_gen_filename``
    :param persistent: when False, the database file is deleted when this
                       object is finalized (see ``__del__``)
    """
    __slots__ = ('_conn', '_cursor', 'persistent', '_session_id')
    storage = 'sqlite3'

    def __init__(self, session_id, persistent):
        self.persistent = persistent
        self._session_id = session_id
        self._conn = sqlite3.connect(_gen_filename(self._session_id))
        # return TEXT columns as ``str`` (on Python 2 the sqlite3 default
        # would be ``unicode``)
        self._conn.text_factory = str
        self._cursor = self._conn.cursor()

    def execute(self, sql, params=None, commit=True):
        """Run ``sql`` with the bound ``params`` (no parameters when None).

        When ``commit`` is False the change stays in the current transaction
        so it can be committed together with a later statement.
        """
        if params is None:
            params = ()
        self._cursor.execute(sql, params)
        if commit:
            self._conn.commit()

    def select(self, sql, params=None, only_one=True):
        """Run a query without committing.

        Return the first row (None when there is no row) if ``only_one`` is
        true, otherwise the list of every row.
        """
        self.execute(sql, params=params, commit=False)
        if only_one:
            return self._cursor.fetchone()
        return self._cursor.fetchall()

    def __del__(self):
        """Close the cursor and connection; delete the file if not persistent.

        ``__init__`` may have stopped before every slot was assigned (for
        example when ``sqlite3.connect`` raised). Unset slots are skipped
        instead of raising AttributeError from the finalizer.
        """
        cursor = getattr(self, '_cursor', None)
        if cursor is not None:
            cursor.close()
        conn = getattr(self, '_conn', None)
        if conn is None:
            # the connection never opened: do not try to delete a file this
            # object did not open
            return
        conn.close()
        if not getattr(self, 'persistent', True):
            delete_session(self._session_id)
82
83
class Cache(object):
    """Cache stored in one SQLite table per cache type.

    Table ``cache_<cache_type>`` maps a path (``None`` is stored as
    ``'_none'``) to a pickled value and the time at which it was cached.
    All SQL goes through ``storage.execute`` / ``storage.select``.
    """
    __slots__ = ('storage',)
    key_is_path = True

    def __init__(self, cache_type, storage):
        self.storage = storage
        # cache_type is interpolated into the table name (identifiers cannot
        # be bound as parameters), so it must come from trusted code.
        cache_table = ('CREATE TABLE IF NOT EXISTS cache_{0}(path text primary '
                       'key, value text, time real)').format(cache_type)
        self.storage.execute(cache_table)

    # value
    def _sqlite_decode_path(self, path):
        """Map the stored ``'_none'`` marker back to ``None``."""
        if path == '_none':
            return None
        return path

    def _sqlite_encode_path(self, path):
        """Store ``None`` as ``'_none'``: ``path = ?`` never matches NULL."""
        if path is None:
            return '_none'
        return path

    def _sqlite_decode(self, value):
        """Unpickle a stored value."""
        # NOTE(security): pickle.loads executes code embedded in the data;
        # anyone able to write the database file can exploit this. Keep the
        # database in a trusted, non-shared location.
        return loads(value)

    def _sqlite_encode(self, value):
        """Pickle ``value``; list subclasses are pickled as a plain list."""
        # presumably so unpickling never needs the subclass -- confirm
        # against callers before changing
        if isinstance(value, list):
            value = list(value)
        return dumps(value)

    def setcache(self, cache_type, path, val, time):
        """Store ``val`` for ``path`` at ``time``, replacing any old entry."""
        convert_value = self._sqlite_encode(val)
        path = self._sqlite_encode_path(path)
        # path is the primary key: INSERT OR REPLACE overwrites the old row
        # in one atomic statement, instead of an uncommitted DELETE that a
        # failing INSERT would leave pending for the next commit.
        self.storage.execute("INSERT OR REPLACE INTO cache_{0}(path, value, "
                             "time) VALUES (?, ?, ?)".format(cache_type),
                             (path, convert_value, time))

    def getcache(self, cache_type, path, exp):
        """Return ``(True, value)`` when ``path`` is cached with a time
        ``>= exp``, otherwise ``(False, None)``."""
        path = self._sqlite_encode_path(path)
        cached = self.storage.select("SELECT value FROM cache_{0} WHERE "
                                     "path = ? AND time >= ?".format(
                                         cache_type), (path, exp))
        if cached is None:
            return False, None
        return True, self._sqlite_decode(cached[0])

    def hascache(self, cache_type, path):
        """Return True when ``path`` has an entry, whatever its time."""
        path = self._sqlite_encode_path(path)
        return self.storage.select("SELECT value FROM cache_{0} WHERE "
                                   "path = ?".format(cache_type),
                                   (path,)) is not None

    def reset_expired_cache(self, cache_type, exp):
        """Delete every entry whose time is strictly lower than ``exp``."""
        self.storage.execute("DELETE FROM cache_{0} WHERE time < ?".format(
            cache_type), (exp,))

    def reset_all_cache(self, cache_type):
        """Delete every entry of this cache type."""
        self.storage.execute("DELETE FROM cache_{0}".format(cache_type))

    def get_cached(self, cache_type, context):
        """return all values in a dictionary
        example: {'path1': ('value1', 'time1'), 'path2': ('value2', 'time2')}

        ``context`` is accepted for interface compatibility and is unused.
        """
        ret = {}
        # name the columns explicitly rather than relying on SELECT * order
        rows = self.storage.select("SELECT path, value, time FROM cache_{0}"
                                   "".format(cache_type), only_one=False)
        for path, value, time in rows:
            ret[self._sqlite_decode_path(path)] = (self._sqlite_decode(value),
                                                   time)
        return ret