8c5bc8677d9c3ae5b5dc913d1ef20d73b4129879
[tiramisu.git] / tiramisu / storage / sqlite3 / storage.py
1 # -*- coding: utf-8 -*-
"sqlite3 storage plugin: keep sessions and their caches in SQLite database files"
3 # Copyright (C) 2013 Team tiramisu (see AUTHORS for all contributors)
4 #
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 2 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
18 #
19 # ____________________________________________________________
20
21 from pickle import dumps, loads
22 from os import unlink
23 from os.path import basename, splitext, join
24 import sqlite3
25 from glob import glob
26
27
# File-name suffix given to every session database file.
extension = "db"
# Directory in which session database files are created and searched for.
# NOTE(review): /tmp is world-writable and the cache layer unpickles values
# read back from these files; consider a private directory if other local
# users are not trusted.
dir_database = "/tmp"
30
31
def _gen_filename(name):
    """Return the path of the database file for session *name*.

    The file lives in ``dir_database`` and carries the ``extension`` suffix.
    *name* may also be a glob pattern such as ``'*'``.
    """
    filename = '{base}.{ext}'.format(base=name, ext=extension)
    return join(dir_database, filename)
34
35
def enumerate():
    """Return the names of all sessions that have a database file.

    Every file matching ``_gen_filename('*')`` is reported by its base name
    with the directory and extension stripped, in ``glob`` order.

    NOTE: this module-level function shadows the ``enumerate`` builtin
    inside this module.
    """
    pattern = _gen_filename('*')
    return [basename(splitext(path)[0]) for path in glob(pattern)]
41
42
def delete(session_id):
    """Remove the database file belonging to *session_id*.

    Any ``OSError`` from ``os.unlink`` propagates, for example when the file
    does not exist.
    """
    target = _gen_filename(session_id)
    unlink(target)
45
46
class Storage(object):
    """Hold one SQLite connection and cursor on a session's database file.

    The file is ``_gen_filename(session_id)``. When *is_persistent* is false,
    the file is deleted when the object is finalised.
    """
    __slots__ = ('_conn', '_cursor', 'is_persistent', '_session_id')
    storage = 'sqlite3'

    def __init__(self, session_id, is_persistent):
        self.is_persistent = is_persistent
        self._session_id = session_id
        self._conn = sqlite3.connect(_gen_filename(self._session_id))
        self._conn.text_factory = str
        self._cursor = self._conn.cursor()

    def execute(self, sql, params=None, commit=True):
        """Run *sql* with bound *params* on the shared cursor.

        Commits right away unless *commit* is false. That lets a caller
        group several statements so that a later committing call commits
        them together.
        """
        if params is None:
            params = ()
        self._cursor.execute(sql, params)
        if commit:
            self._conn.commit()

    def select(self, sql, params=None, only_one=True):
        """Run a query without committing and return its result.

        Returns the first row (or ``None`` when there is none) if *only_one*
        is true, otherwise the list of all rows.
        """
        self.execute(sql, params=params, commit=False)
        if only_one:
            return self._cursor.fetchone()
        return self._cursor.fetchall()

    def __del__(self):
        """Close the cursor and connection, then drop a non-persistent file.

        ``__init__`` may have raised before every slot was assigned (for
        example if ``sqlite3.connect`` failed). Reading an unassigned slot
        raises AttributeError, so each attribute is looked up defensively.
        """
        cursor = getattr(self, '_cursor', None)
        if cursor is not None:
            cursor.close()
        conn = getattr(self, '_conn', None)
        if conn is None:
            # No connection was ever opened: nothing to close or delete.
            return
        conn.close()
        if not self.is_persistent:
            delete(self._session_id)
75         if not self.is_persistent:
76             delete(self._session_id)
77
78
class Cache(object):
    """SQLite-backed cache using one ``cache_<cache_type>`` table per type.

    Each row maps a path to a pickled value and the time it was stored. A
    ``None`` path is stored as the text marker ``'_none'``.

    NOTE(review): *cache_type* is formatted into the SQL as part of a table
    name, which cannot be a bound parameter. It must come from trusted code,
    never from user input.
    """
    __slots__ = ('storage',)
    key_is_path = True

    def __init__(self, cache_type, storage):
        self.storage = storage
        self.storage.execute('CREATE TABLE IF NOT EXISTS cache_{0}('
                             'path text primary key, value text, '
                             'time real)'.format(cache_type))

    # value
    def _sqlite_decode_path(self, path):
        """Map the stored ``'_none'`` marker back to ``None``."""
        if path == '_none':
            return None
        return path

    def _sqlite_encode_path(self, path):
        """Store a ``None`` path as the ``'_none'`` marker."""
        if path is None:
            return '_none'
        return path

    def _sqlite_decode(self, value):
        """Unpickle a value read from the database.

        NOTE(review): ``pickle.loads`` can execute arbitrary code from
        crafted input. If anyone else can write the database file (for
        example because it lives in a shared directory), this is a
        code-execution risk.
        """
        return loads(value)

    def _sqlite_encode(self, value):
        """Pickle *value*; list subclasses are stored as plain lists."""
        if isinstance(value, list):
            value = list(value)
        return dumps(value)

    def setcache(self, cache_type, path, val, time):
        """Store *val* for *path* with timestamp *time*, replacing any entry."""
        convert_value = self._sqlite_encode(val)
        path = self._sqlite_encode_path(path)
        # path is the primary key, so INSERT OR REPLACE swaps an existing row
        # in one statement. A failed write therefore cannot leave a lone
        # DELETE pending in the open transaction for a later commit.
        self.storage.execute("INSERT OR REPLACE INTO cache_{0}"
                             "(path, value, time) "
                             "VALUES (?, ?, ?)".format(cache_type),
                             (path, convert_value, time))

    def getcache(self, cache_type, path, exp):
        """Return ``(True, value)`` if *path* was stored at or after *exp*.

        Otherwise (missing or older entry) return ``(False, None)``.
        """
        path = self._sqlite_encode_path(path)
        cached = self.storage.select("SELECT value FROM cache_{0} WHERE "
                                     "path = ? AND time >= ?".format(
                                         cache_type), (path, exp))
        if cached is None:
            return False, None
        return True, self._sqlite_decode(cached[0])

    def hascache(self, cache_type, path):
        """Return whether an entry exists for *path*, whatever its time."""
        path = self._sqlite_encode_path(path)
        return self.storage.select("SELECT value FROM cache_{0} WHERE "
                                   "path = ?".format(cache_type),
                                   (path,)) is not None

    def reset_expired_cache(self, cache_type, exp):
        """Delete every entry stored strictly before *exp*."""
        self.storage.execute("DELETE FROM cache_{0} WHERE time < ?".format(
            cache_type), (exp,))

    def reset_all_cache(self, cache_type):
        """Delete every entry of this cache type."""
        self.storage.execute("DELETE FROM cache_{0}".format(cache_type))

    def get_cached(self, cache_type, context):
        """return all values in a dictionary
        example: {'path1': ('value1', 'time1'), 'path2': ('value2', 'time2')}

        *context* is not used by this implementation.
        """
        # Name the columns explicitly: the 3-way unpacking below must not
        # depend on the table's implicit column order.
        rows = self.storage.select("SELECT path, value, time FROM "
                                   "cache_{0}".format(cache_type),
                                   only_one=False)
        ret = {}
        for path, value, time in rows:
            ret[self._sqlite_decode_path(path)] = (self._sqlite_decode(value),
                                                   time)
        return ret
154         return ret