# Copyright(C) 2019-2021 Romain Bignon
#
# This file is part of woob.
#
# woob is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# woob is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with woob. If not, see <http://www.gnu.org/licenses/>.
import os
import sqlite3
import tempfile
from collections.abc import Mapping, MutableMapping
import yaml
from .iconfig import ConfigError, IConfig
from .util import replace, time_buffer
from .yamlconfig import WoobDumper
try:
from yaml import CSafeLoader as SafeLoader
except ImportError:
from yaml import SafeLoader
# Only the configuration backend is exported on star-import; the identifier
# quoting helper and the virtual mapping views are implementation details.
__all__ = ['SQLiteConfig']
def quote(s, errors="strict"):
    """Return *s* quoted as an SQL identifier (double quotes, inner quotes doubled).

    :param s: identifier to quote (e.g. a table name)
    :param errors: UTF-8 codec error handler applied when round-tripping *s*
                   through UTF-8; with ``"strict"`` a non-encodable string
                   (e.g. lone surrogates) raises :class:`UnicodeEncodeError`
    :raises ValueError: if the identifier contains a NUL character
    """
    checked = s.encode("utf-8", errors).decode("utf-8")
    # SQLite identifiers cannot carry an embedded NUL.
    if "\x00" in checked:
        raise ValueError("invalid identifier")
    escaped = checked.replace('"', '""')
    return '"%s"' % escaped
class VirtualRootDict(Mapping):
    """Read-only mapping from table name to a :class:`VirtualDict` view.

    The set of names is read live from ``config._tables`` on every access.
    """

    def __init__(self, config):
        # Owning config object; only its ``_tables`` set is consulted here.
        self.config = config

    def __getitem__(self, base):
        if base not in self.config._tables:
            raise KeyError('%s table not found' % base)
        return VirtualDict(self.config, base)

    def __iter__(self):
        # Generator so that ``config._tables`` is looked up lazily.
        yield from self.config._tables

    def __len__(self):
        return len(self.config._tables)
class VirtualDict(MutableMapping):
    """Mutable mapping view over a single table of a config object.

    Every operation is forwarded to the config object with the table name
    (``base``) as first argument, so reads and writes hit storage directly.
    """

    def __init__(self, config, base):
        self.config = config  # owning config object
        self.base = base  # name of the table exposed by this view

    def __getitem__(self, key):
        try:
            return self.config.get(self.base, key)
        except ConfigError:
            raise KeyError('%s key in %s table not found' % (key, self.base)) from None

    def __contains__(self, key):
        return self.config.has(self.base, key)

    def __iter__(self):
        yield from self.config.keys(self.base)

    def items(self):
        return self.config.items(self.base)

    def __len__(self):
        return self.config.count(self.base)

    def __delitem__(self, key):
        try:
            # delete() takes (table, *key_parts); the view itself must not be
            # passed, otherwise it would be taken as the table name.
            self.config.delete(self.base, key)
        except ConfigError:
            raise KeyError('%s key in %s table not found' % (key, self.base)) from None

    def __setitem__(self, key, value):
        self.config.set(self.base, key, value)
class SQLiteConfig(IConfig):
    """Configuration backend storing its values in an SQLite database.

    Every top-level section is an SQLite table with two text columns, ``key``
    and ``value``. Values are serialized to YAML on write and parsed back with
    a safe YAML loader on read. Nested key parts are joined with dots into a
    single ``key`` string (``set('tbl', 'a', 'b', value)`` stores key ``'a.b'``).
    """

    # Delay handed to time_buffer() for commit(); a falsy value leaves
    # commit() unwrapped.
    commit_since_seconds = 3600
    # Delay handed to time_buffer() for dump(); a falsy value leaves dump()
    # unwrapped and makes save() skip dumping entirely.
    dump_since_seconds = 600

    def __init__(self, path, commit_since_seconds=None, dump_since_seconds=None, last_run=True, logger=None):
        """
        :param path: path of the SQLite database file
        :param commit_since_seconds: overrides the class default when truthy
        :param dump_since_seconds: overrides the class default when truthy
        :param last_run: forwarded to time_buffer()
        :param logger: forwarded to time_buffer()
        """
        self.path = path
        if commit_since_seconds:
            self.commit_since_seconds = commit_since_seconds
        if dump_since_seconds:
            self.dump_since_seconds = dump_since_seconds
        # Replace the bound methods by time_buffer() wrappers when a delay is set.
        if self.commit_since_seconds:
            self.commit = time_buffer(since_seconds=self.commit_since_seconds, last_run=last_run, logger=logger)(self.commit)
        if self.dump_since_seconds:
            self.dump = time_buffer(since_seconds=self.dump_since_seconds, last_run=last_run, logger=logger)(self.dump)

    def load(self, default=None, optimize=True):
        """Open the database and cache the names of its existing tables.

        :param default: unused; kept for interface compatibility
        :param optimize: run ``VACUUM`` and ``REINDEX`` after opening
        """
        self.storage = sqlite3.connect(self.path)
        # NOTE(review): on an existing database a new page_size presumably only
        # takes effect through the following VACUUM — see SQLite PRAGMA docs.
        self.storage.execute('PRAGMA page_size = 4096')
        if optimize:
            self.storage.execute('VACUUM')
            self.storage.execute('REINDEX')
        self._tables = set(self.tables())
        self.values = VirtualRootDict(self)

    def save(self, commit_since_seconds=None, dump_since_seconds=None):
        """Commit pending changes and, when dumps are enabled, write an SQL dump.

        The ``*_since_seconds`` arguments are forwarded as ``since_seconds`` to
        commit() and dump() (possibly wrapped by time_buffer()).
        """
        self.commit(since_seconds=commit_since_seconds)
        # No one would want immediate dumps: a falsy delay means "never dump".
        if self.dump_since_seconds:
            self.dump(since_seconds=dump_since_seconds)

    def force_save(self):
        """Save, passing ``since_seconds=False`` to both commit() and dump()."""
        # NOTE(review): presumably False makes time_buffer() run immediately — confirm.
        self.save(commit_since_seconds=False, dump_since_seconds=False)

    def __exit__(self, t, v, tb):
        # Force a save, then delegate to the base class.
        self.force_save()
        super().__exit__(t, v, tb)

    def commit(self, **kwargs):
        """Commit the current SQLite transaction.

        ``since_seconds`` is accepted and ignored so that save() can call the
        method the same way whether or not it was wrapped by time_buffer().
        """
        kwargs.pop('since_seconds', None)
        self.storage.commit()

    def dump(self, **kwargs):
        """Write a plain-text SQL dump of the database to ``<path stem>.sql``.

        The dump is first written to a temporary file in the database's
        directory and then moved onto the target with replace(). On failure
        the temporary file is removed and the exception is re-raised.
        ``since_seconds`` is accepted and ignored (see commit()).
        """
        kwargs.pop('since_seconds', None)
        target = os.path.splitext(self.path)[0] + '.sql'
        tmp = tempfile.NamedTemporaryFile(dir=os.path.dirname(self.path), delete=False)
        try:
            with tmp as f:
                for line in self.storage.iterdump():
                    f.write(str(line).encode('utf-8'))
                    f.write(b'\n')
            # The file is closed before being moved into place.
            replace(tmp.name, target)
        except BaseException:
            # delete=False means nobody else will clean this file up.
            try:
                os.unlink(tmp.name)
            except OSError:
                pass
            raise

    def ensure_table(self, name):
        """Create table *name* (``key`` text primary key, ``value`` text) if unknown."""
        if name in self._tables:
            return
        self.storage.execute('''CREATE TABLE IF NOT EXISTS %s (
key text PRIMARY KEY,
value text
);''' % quote(name))  # nosec
        self._tables.add(name)

    def tables(self):
        """Return the names of all tables except SQLite-internal ``sqlite_*`` ones."""
        cur = self.storage.cursor()
        # String literals must be single-quoted: double quotes denote
        # identifiers in SQL, and SQLite builds without DQS support reject them
        # as literals. '_' is a LIKE wildcard, hence the ESCAPE clause.
        cur.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\';"
        )
        return [row[0] for row in cur.fetchall()]

    def items(self, table, size=100):
        """
        Low-memory way of listing all ``(key, value)`` pairs of *table*.

        Rows are fetched *size* at a time; values are YAML-decoded as they
        are yielded.
        """
        cur = self.storage.cursor()
        cur.execute('SELECT key, value FROM %s;' % quote(table))  # nosec
        rows = cur.fetchmany(size)
        while rows:
            for key, strvalue in rows:
                yield key, yaml.load(strvalue, Loader=SafeLoader)
            rows = cur.fetchmany(size)

    def keys(self, table, size=200):
        """
        Low-memory way of listing all keys of *table*.

        Rows are fetched *size* at a time.
        """
        cur = self.storage.cursor()
        cur.execute('SELECT key FROM %s;' % quote(table))  # nosec
        rows = cur.fetchmany(size)
        while rows:
            for row in rows:
                yield row[0]
            rows = cur.fetchmany(size)

    def count(self, table):
        """Return the number of rows in *table*."""
        cur = self.storage.cursor()
        cur.execute('SELECT count(*) FROM %s;' % quote(table))  # nosec
        return cur.fetchone()[0]

    def get(self, *args, **kwargs):
        """Return the value of key ``args[1:]`` (dot-joined) in table ``args[0]``.

        With only a table name, return the mutable mapping view of that table.
        The table is created if it does not exist yet.

        :param default: keyword argument returned when the key is missing
        :raises ConfigError: when the key is missing and no default was given
        """
        table = args[0]
        key = '.'.join(args[1:])
        self.ensure_table(table)
        if not key:
            return self.values[table]
        try:
            cur = self.storage.cursor()
            cur.execute('SELECT value FROM %s WHERE key=?;' % quote(table), (key, ))  # nosec
            row = cur.fetchone()
            if row is None:
                if 'default' in kwargs:
                    value = kwargs.get('default')
                else:
                    raise ConfigError()
            else:
                value = yaml.load(row[0], Loader=SafeLoader)
        except TypeError:
            raise ConfigError()
        return value

    def set(self, *args):
        """Store ``args[-1]`` (YAML-serialized) under key ``args[1:-1]`` of table ``args[0]``.

        :raises ConfigError: if fewer than two levels (table + key) are given,
                             or if serialization/storage raises KeyError/TypeError
        """
        table = args[0]
        key = '.'.join(args[1:-1])
        if not key:
            raise ConfigError('A minimum of two levels are required.')
        value = args[-1]
        self.ensure_table(table)
        try:
            strvalue = yaml.dump(value, None, Dumper=WoobDumper, default_flow_style=False)
            cur = self.storage.cursor()
            cur.execute('''INSERT OR REPLACE INTO %s VALUES (?, ?)''' % quote(table), (key, strvalue))  # nosec
        except KeyError:
            raise ConfigError()
        except TypeError:
            raise ConfigError()

    def delete(self, *args):
        """Delete key ``args[1:]`` from table ``args[0]``, or drop the whole table.

        :raises ConfigError: when the table (table-only form) or the key does not exist
        """
        table = args[0]
        key = '.'.join(args[1:])
        if not key:
            if table not in self._tables:
                raise ConfigError()
            cur = self.storage.cursor()
            cur.execute('DROP TABLE %s;' % quote(table))  # nosec
            # Keep the cache in sync: ensure_table() must recreate the table on
            # next use, and has()/values must stop reporting it.
            self._tables.discard(table)
            return
        self.ensure_table(table)
        cur = self.storage.cursor()
        cur.execute('DELETE FROM %s WHERE key=?;' % quote(table), (key, ))  # nosec
        if not cur.rowcount:
            raise ConfigError()

    def has(self, *args):
        """Tell whether table ``args[0]`` exists, or whether it holds key ``args[1:]``."""
        table = args[0]
        key = '.'.join(args[1:])
        if table not in self._tables:
            # A missing table holds no key; querying it would raise
            # sqlite3.OperationalError.
            return False
        if not key:
            return True
        cur = self.storage.cursor()
        cur.execute('SELECT count(*) FROM %s WHERE key=?;' % quote(table), (key, ))  # nosec
        return cur.fetchone()[0] > 0