# Source code for woob.tools.config.sqliteconfig

# Copyright(C) 2019-2021 Romain Bignon
#
# This file is part of woob.
#
# woob is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# woob is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with woob. If not, see <http://www.gnu.org/licenses/>.


import os
import sqlite3
import tempfile
from collections.abc import Mapping, MutableMapping

import yaml

from .iconfig import ConfigError, IConfig
from .util import replace, time_buffer
from .yamlconfig import WoobDumper

try:
    from yaml import CSafeLoader as SafeLoader
except ImportError:
    from yaml import SafeLoader


# Names exported by ``from <this module> import *``: only SQLiteConfig.
__all__ = ['SQLiteConfig']


def quote(s, errors="strict"):
    """Return *s* as a double-quoted SQL identifier.

    The string is round-tripped through UTF-8 using the *errors* handler,
    embedded double quotes are doubled, and the result is wrapped in
    double quotes.

    :raises ValueError: if the string contains a NUL character.
    """
    cleaned = s.encode("utf-8", errors).decode("utf-8")
    if "\x00" in cleaned:
        raise ValueError("invalid identifier")

    escaped = cleaned.replace('"', '""')
    return '"{}"'.format(escaped)


class VirtualRootDict(Mapping):
    """Read-only mapping from table names to per-table views.

    The set of known names is read from ``config._tables`` on every
    access, so the mapping always reflects that set's current content.
    """

    def __init__(self, config):
        self.config = config

    def __getitem__(self, base):
        # Guard clause: unknown tables are reported as missing keys.
        if base not in self.config._tables:
            raise KeyError('%s table not found' % base)
        return VirtualDict(self.config, base)

    def __iter__(self):
        yield from self.config._tables

    def __len__(self):
        known_tables = self.config._tables
        return len(known_tables)


class VirtualDict(MutableMapping):
    """Mutable mapping view over one table of a config object.

    Every operation is delegated to the ``config`` object given at
    construction time, always passing ``base`` as the table name.
    ``ConfigError`` raised by the config on lookup or deletion is
    translated into ``KeyError`` so the view behaves like a dict.
    """

    def __init__(self, config, base):
        self.config = config
        self.base = base

    def __getitem__(self, key):
        try:
            return self.config.get(self.base, key)
        except ConfigError:
            raise KeyError('%s key in %s table not found' % (key, self.base))

    def __contains__(self, key):
        return self.config.has(self.base, key)

    def __iter__(self):
        for key in self.config.keys(self.base):
            yield key

    def items(self):
        # Returns whatever config.items() yields rather than an ItemsView.
        return self.config.items(self.base)

    def __len__(self):
        return self.config.count(self.base)

    def __delitem__(self, key):
        try:
            # delete() expects (table, key); the view itself must not be
            # passed, otherwise it would be taken for the table name.
            self.config.delete(self.base, key)
        except ConfigError:
            raise KeyError('%s key in %s table not found' % (key, self.base))

    def __setitem__(self, key, value):
        self.config.set(self.base, key, value)


class SQLiteConfig(IConfig):
    """Configuration backend storing YAML-serialized values in SQLite.

    The first path component given to :meth:`get`, :meth:`set`,
    :meth:`delete` and :meth:`has` is a table name; the remaining
    components are joined with ``.`` to form the row key. Each table has
    two text columns, ``key`` (primary key) and ``value`` (YAML text).

    :meth:`load` must be called before any other operation: it opens the
    connection (``self.storage``) and fills the table-name cache
    (``self._tables``).
    """

    # Default buffering delays passed to time_buffer() for commit()/dump().
    commit_since_seconds = 3600
    dump_since_seconds = 600

    def __init__(self, path, commit_since_seconds=None, dump_since_seconds=None, last_run=True, logger=None):
        """
        :param path: path of the SQLite database file
        :param commit_since_seconds: overrides the class default when truthy
        :param dump_since_seconds: overrides the class default when truthy
        :param last_run: forwarded to ``time_buffer``
        :param logger: forwarded to ``time_buffer``
        """
        self.path = path
        if commit_since_seconds:
            self.commit_since_seconds = commit_since_seconds
        if dump_since_seconds:
            self.dump_since_seconds = dump_since_seconds

        # Wrap the bound methods per instance; a falsy delay leaves the
        # plain method in place.
        # NOTE(review): time_buffer presumably rate-limits calls - confirm
        # against woob.tools.config.util.
        if self.commit_since_seconds:
            self.commit = time_buffer(
                since_seconds=self.commit_since_seconds,
                last_run=last_run,
                logger=logger,
            )(self.commit)
        if self.dump_since_seconds:
            self.dump = time_buffer(
                since_seconds=self.dump_since_seconds,
                last_run=last_run,
                logger=logger,
            )(self.dump)

    def load(self, default={}, optimize=True):
        """Open the database and initialise the table cache.

        :param default: not used by this backend
        :param optimize: when true, run ``VACUUM`` and ``REINDEX`` after
            opening the connection
        """
        self.storage = sqlite3.connect(self.path)
        self.storage.execute('PRAGMA page_size = 4096')
        if optimize:
            self.storage.execute('VACUUM')
            self.storage.execute('REINDEX')
        self._tables = set(self.tables())
        self.values = VirtualRootDict(self)

    def save(self, commit_since_seconds=None, dump_since_seconds=None):
        """Commit pending changes and, if dumps are enabled, dump to SQL."""
        self.commit(since_seconds=commit_since_seconds)
        # No one would want immediate dumps, assume it means no dumps
        if self.dump_since_seconds:
            self.dump(since_seconds=dump_since_seconds)

    def force_save(self):
        """Call :meth:`save` with both delays set to ``False``."""
        self.save(commit_since_seconds=False, dump_since_seconds=False)

    def __exit__(self, t, v, tb):
        self.force_save()
        super(SQLiteConfig, self).__exit__(t, v, tb)

    def commit(self, **kwargs):
        """Commit the current transaction; ``since_seconds`` is ignored."""
        kwargs.pop('since_seconds', None)
        self.storage.commit()

    def dump(self, **kwargs):
        """Write a SQL text dump next to the database, as ``<path stem>.sql``.

        The dump is written to a temporary file in the database directory
        and then moved over the target with ``replace``. ``since_seconds``
        is ignored.
        """
        kwargs.pop('since_seconds', None)
        target = os.path.splitext(self.path)[0] + '.sql'
        with tempfile.NamedTemporaryFile(dir=os.path.dirname(self.path), delete=False) as f:
            for line in self.storage.iterdump():
                f.write(str(line).encode('utf-8'))
                f.write(b'\n')
        # Move only once the temporary file has been closed.
        replace(f.name, target)

    def ensure_table(self, name):
        """Create table *name* if it is not in the table cache yet."""
        if name not in self._tables:
            self.storage.execute('''CREATE TABLE IF NOT EXISTS %s (
                key text PRIMARY KEY,
                value text
            );''' % quote(name))  # nosec
            self._tables.add(name)

    def tables(self):
        """Return the names of all non-internal tables in the database."""
        cur = self.storage.cursor()
        # Single quotes: double-quoted string literals are a legacy SQLite
        # quirk and are rejected when it is disabled. The underscore is
        # escaped because it is a single-character wildcard in LIKE, which
        # would otherwise hide user tables such as "sqlitefoo".
        cur.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\';"
        )
        return [k[0] for k in cur.fetchall()]

    def items(self, table, size=100):
        """
        Low memory way of listing all items.
        The size parameters alters how many items are fetched at a time.
        """
        cur = self.storage.cursor()
        cur.execute('SELECT key, value FROM %s;' % quote(table))  # nosec
        items = cur.fetchmany(size)
        while items:
            for key, strvalue in items:
                yield key, yaml.load(strvalue, Loader=SafeLoader)
            items = cur.fetchmany(size)

    def keys(self, table, size=200):
        """
        Low memory way of listing all keys.
        The size parameters alters how many items are fetched at a time.
        """
        cur = self.storage.cursor()
        cur.execute('SELECT key FROM %s;' % quote(table))  # nosec
        items = cur.fetchmany(size)
        while items:
            for item in items:
                yield item[0]
            items = cur.fetchmany(size)

    def count(self, table):
        """Return the number of rows stored in *table*."""
        cur = self.storage.cursor()
        cur.execute('SELECT count(*) FROM %s;' % quote(table))  # nosec
        return cur.fetchone()[0]

    def get(self, *args, **kwargs):
        """Return the value stored at ``(table, *key_parts)``.

        The table is created if needed. With no key parts, the mapping
        view of the whole table is returned. A ``default`` keyword
        argument, when given, is returned for a missing key.

        :raises ConfigError: if the key is missing and no ``default`` was
            given, or if a ``TypeError`` occurs while reading the value
        """
        table = args[0]
        key = '.'.join(args[1:])
        self.ensure_table(table)
        if not key:
            return self.values[table]

        try:
            cur = self.storage.cursor()
            cur.execute('SELECT value FROM %s WHERE key=?;' % quote(table), (key, ))  # nosec
            row = cur.fetchone()
            if row is None:
                if 'default' in kwargs:
                    value = kwargs.get('default')
                else:
                    raise ConfigError()
            else:
                strvalue = row[0]
                value = yaml.load(strvalue, Loader=SafeLoader)
        except TypeError:
            raise ConfigError()
        return value

    def set(self, *args):
        """Store a value: ``set(table, *key_parts, value)``.

        The table is created if needed and an existing row with the same
        key is replaced.

        :raises ConfigError: if no key part is given, or if serializing or
            storing the value raises ``KeyError`` or ``TypeError``
        """
        table = args[0]
        key = '.'.join(args[1:-1])
        if not key:
            raise ConfigError('A minimum of two levels are required.')
        value = args[-1]

        self.ensure_table(table)
        try:
            strvalue = yaml.dump(value, None, Dumper=WoobDumper, default_flow_style=False)
            cur = self.storage.cursor()
            cur.execute('''INSERT OR REPLACE INTO %s VALUES (?, ?)''' % quote(table), (key, strvalue))  # nosec
        except (KeyError, TypeError):
            raise ConfigError()

    def delete(self, *args):
        """Delete a key, or drop a whole table when no key part is given.

        :raises ConfigError: if the table (when dropping) or the key (when
            deleting a row) does not exist
        """
        table = args[0]
        key = '.'.join(args[1:])

        if not key:
            if table not in self._tables:
                raise ConfigError()
            cur = self.storage.cursor()
            cur.execute('DROP TABLE %s;' % quote(table))  # nosec
            # Keep the cache in sync with the database; otherwise
            # ensure_table() would never recreate the dropped table and
            # later accesses would fail with "no such table".
            self._tables.discard(table)
        else:
            self.ensure_table(table)
            cur = self.storage.cursor()
            cur.execute('DELETE FROM %s WHERE key=?;' % quote(table), (key, ))  # nosec
            if not cur.rowcount:
                raise ConfigError()

    def has(self, *args):
        """Tell whether a table (no key part) or a key in a table exists."""
        table = args[0]
        if table not in self._tables:
            # Unknown table: nothing can be stored in it, and querying it
            # would raise sqlite3.OperationalError.
            return False

        key = '.'.join(args[1:])
        if not key:
            return True

        cur = self.storage.cursor()
        cur.execute('SELECT count(*) FROM %s WHERE key=?;' % quote(table), (key, ))  # nosec
        return cur.fetchone()[0] > 0