# Cache缓存类 (Cache class)

import time
import queue
import sqlite3
import threading

class SQLiteConnectionPool:
    """Fixed-size pool of SQLite connections that can be shared between threads.

    All connections are opened eagerly when the pool is created, with
    ``check_same_thread=False`` (so any thread may use them) and with the WAL
    journal mode requested via ``PRAGMA journal_mode=WAL``.

    :param db_path: path passed to ``sqlite3.connect``.
    :param max_connections: number of connections to open; must be >= 1.
    :raises ValueError: if ``max_connections`` is less than 1.
    """

    def __init__(self, db_path, max_connections=5):
        if max_connections < 1:
            # queue.Queue(0) would be unbounded and the pool empty, so every
            # get_connection() call would block forever.
            raise ValueError("max_connections must be at least 1")
        self.db_path = db_path
        self.max_connections = max_connections
        self._pool = queue.Queue(max_connections)
        self._lock = threading.Lock()  # not used by this class; kept for compatibility
        self._initialize_pool()

    def _initialize_pool(self):
        """Open ``max_connections`` connections and place them in the pool.

        If opening or configuring any connection fails, every connection
        opened so far is closed before the exception propagates.
        """
        opened = []
        try:
            for _ in range(self.max_connections):
                conn = sqlite3.connect(self.db_path, check_same_thread=False)
                opened.append(conn)
                conn.execute('PRAGMA journal_mode=WAL')  # enable WAL mode
        except Exception:
            for conn in opened:
                conn.close()
            raise
        for conn in opened:
            self._pool.put(conn)

    def get_connection(self, timeout=None):
        """Take a connection out of the pool.

        :param timeout: seconds to wait for a free connection; ``None``
            (the default) waits indefinitely.
        :raises queue.Empty: if ``timeout`` elapses with no free connection.
        """
        return self._pool.get(timeout=timeout)

    def return_connection(self, conn):
        """Put a previously taken connection back into the pool."""
        self._pool.put(conn)

    def close_all(self):
        """Close every connection that is currently idle in the pool.

        Connections that are checked out when this runs are not closed here.
        Uses ``get_nowait`` so the loop never blocks: an ``empty()`` check
        followed by a blocking ``get()`` can hang if another thread takes
        the last connection between the two calls.
        """
        while True:
            try:
                conn = self._pool.get_nowait()
            except queue.Empty:
                break
            conn.close()

class SQLiteCache:
    """Key/value cache stored in an SQLite ``cache`` table, using a connection pool.

    Each entry has an optional absolute expiry timestamp (``time.time()``
    seconds). Expired entries are removed lazily when ``get`` reads them.

    :param db_path: SQLite database file.
    :param max_connections: size of the underlying ``SQLiteConnectionPool``.
    """

    def __init__(self, db_path='cache.db', max_connections=5):
        self.db_path = db_path
        self.pool = SQLiteConnectionPool(db_path, max_connections)
        self._create_table()

    def _create_table(self):
        """Create the ``cache`` table if it does not exist yet."""
        conn = self.pool.get_connection()
        try:
            # ``with conn`` commits on success and rolls back on error, so a
            # connection never goes back to the pool with an open transaction.
            with conn:
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS cache (
                        key TEXT PRIMARY KEY,
                        value TEXT,
                        expiry REAL
                    )
                ''')
        finally:
            self.pool.return_connection(conn)

    def set(self, key, value, ttl=None):
        """Store ``value`` under ``key``, replacing any existing entry.

        :param ttl: lifetime in seconds. A falsy value (``None`` or ``0``)
            stores the entry without an expiry.
        """
        conn = self.pool.get_connection()
        try:
            expiry = time.time() + ttl if ttl else None
            with conn:
                conn.execute('''
                    INSERT OR REPLACE INTO cache (key, value, expiry)
                    VALUES (?, ?, ?)
                ''', (key, value, expiry))
        finally:
            self.pool.return_connection(conn)

    def get(self, key):
        """Return the cached value for ``key``, or ``None`` if absent or expired.

        An expired entry is deleted as a side effect.
        """
        conn = self.pool.get_connection()
        try:
            with conn:
                row = conn.execute('''
                    SELECT value, expiry FROM cache WHERE key = ?
                ''', (key,)).fetchone()
                if row is None:
                    return None
                value, expiry = row
                now = time.time()
                if expiry is None or expiry > now:
                    return value
                # Delete on the connection already held. Taking a second
                # pooled connection here (e.g. via self.delete) deadlocks once
                # the pool is exhausted. The expiry guard avoids removing an
                # entry that another thread refreshed since the SELECT.
                conn.execute('''
                    DELETE FROM cache
                    WHERE key = ? AND expiry IS NOT NULL AND expiry <= ?
                ''', (key, now))
                return None
        finally:
            self.pool.return_connection(conn)

    def delete(self, key):
        """Remove the entry for ``key`` if it exists."""
        conn = self.pool.get_connection()
        try:
            with conn:
                conn.execute('''
                    DELETE FROM cache WHERE key = ?
                ''', (key,))
        finally:
            self.pool.return_connection(conn)

    def clear(self):
        """Remove every entry from the cache."""
        conn = self.pool.get_connection()
        try:
            with conn:
                conn.execute('DELETE FROM cache')
        finally:
            self.pool.return_connection(conn)

    def close(self):
        """Close the pooled connections (see ``SQLiteConnectionPool.close_all``)."""
        self.pool.close_all()

# 示例用法
if __name__ == "__main__":
    # Demo of the pooled cache: allow up to five pooled connections.
    cache = SQLiteCache(max_connections=5)

    # Store an entry that lives for ten seconds, then read it straight back.
    cache.set('my_key', 'my_value', ttl=10)
    value = cache.get('my_key')
    print(value)  # expected output: my_value

    # Sleep one second beyond the TTL; the entry is now expired.
    time.sleep(11)
    value = cache.get('my_key')
    print(value)  # expected output: None

    # Release the pooled connections.
    cache.close()

# import sqlite3
# import time
# import threading
# from queue import Queue
# import redis

# class SQLiteConnectionPool:
#     def __init__(self, db_path, max_connections=5):
#         self.db_path = db_path
#         self.max_connections = max_connections
#         self._pool = Queue(max_connections)
#         self._lock = threading.Lock()
#         self._initialize_pool()

#     def _initialize_pool(self):
#         """初始化连接池"""
#         for _ in range(self.max_connections):
#             conn = sqlite3.connect(self.db_path, check_same_thread=False)
#             conn.execute('PRAGMA journal_mode=WAL')  # 启用WAL模式
#             self._pool.put(conn)

#     def get_connection(self):
#         """从连接池中获取一个连接"""
#         return self._pool.get()

#     def return_connection(self, conn):
#         """将连接返回到连接池"""
#         self._pool.put(conn)

#     def close_all(self):
#         """关闭连接池中的所有连接"""
#         while not self._pool.empty():
#             conn = self._pool.get()
#             conn.close()

# class Cache:
#     def __init__(self, backend='sqlite', **kwargs):
#         """
#         初始化缓存类
#         :param backend: 缓存后端,支持 'sqlite' 或 'redis'
#         :param kwargs: 后端相关参数
#         """
#         self.backend = backend
#         if backend == 'sqlite':
#             self.db_path = kwargs.get('db_path', 'cache.db')
#             self.max_connections = kwargs.get('max_connections', 5)
#             self.pool = SQLiteConnectionPool(self.db_path, self.max_connections)
#             self._create_table()
#         elif backend == 'redis':
#             self.host = kwargs.get('host', 'localhost')
#             self.port = kwargs.get('port', 6379)
#             self.db = kwargs.get('db', 0)
#             self.redis_client = redis.Redis(host=self.host, port=self.port, db=self.db)
#         else:
#             raise ValueError("Unsupported backend. Choose 'sqlite' or 'redis'.")

#     def _create_table(self):
#         """创建缓存表(仅用于SQLite)"""
#         if self.backend == 'sqlite':
#             conn = self.pool.get_connection()
#             try:
#                 cursor = conn.cursor()
#                 cursor.execute('''
#                     CREATE TABLE IF NOT EXISTS cache (
#                         key TEXT PRIMARY KEY,
#                         value TEXT,
#                         expiry REAL
#                     )
#                 ''')
#                 conn.commit()
#             finally:
#                 self.pool.return_connection(conn)

#     def set(self, key, value, ttl=None):
#         """设置缓存值"""
#         if self.backend == 'sqlite':
#             conn = self.pool.get_connection()
#             try:
#                 cursor = conn.cursor()
#                 expiry = time.time() + ttl if ttl else None
#                 cursor.execute('''
#                     INSERT OR REPLACE INTO cache (key, value, expiry)
#                     VALUES (?, ?, ?)
#                 ''', (key, value, expiry))
#                 conn.commit()
#             finally:
#                 self.pool.return_connection(conn)
#         elif self.backend == 'redis':
#             if ttl:
#                 self.redis_client.setex(key, ttl, value)
#             else:
#                 self.redis_client.set(key, value)

#     def get(self, key):
#         """获取缓存值"""
#         if self.backend == 'sqlite':
#             conn = self.pool.get_connection()
#             try:
#                 cursor = conn.cursor()
#                 cursor.execute('''
#                     SELECT value, expiry FROM cache WHERE key = ?
#                 ''', (key,))
#                 result = cursor.fetchone()
#                 if result:
#                     value, expiry = result
#                     if expiry is None or expiry > time.time():
#                         return value
#                     else:
#                         # 如果缓存已过期,删除它
#                         self.delete(key)
#                 return None
#             finally:
#                 self.pool.return_connection(conn)
#         elif self.backend == 'redis':
#             value = self.redis_client.get(key)
#             return value.decode('utf-8') if value else None

#     def delete(self, key):
#         """删除缓存值"""
#         if self.backend == 'sqlite':
#             conn = self.pool.get_connection()
#             try:
#                 cursor = conn.cursor()
#                 cursor.execute('''
#                     DELETE FROM cache WHERE key = ?
#                 ''', (key,))
#                 conn.commit()
#             finally:
#                 self.pool.return_connection(conn)
#         elif self.backend == 'redis':
#             self.redis_client.delete(key)

#     def clear(self):
#         """清空缓存"""
#         if self.backend == 'sqlite':
#             conn = self.pool.get_connection()
#             try:
#                 cursor = conn.cursor()
#                 cursor.execute('DELETE FROM cache')
#                 conn.commit()
#             finally:
#                 self.pool.return_connection(conn)
#         elif self.backend == 'redis':
#             self.redis_client.flushdb()

#     def close(self):
#         """关闭连接"""
#         if self.backend == 'sqlite':
#             self.pool.close_all()
#         elif self.backend == 'redis':
#             self.redis_client.close()

# # 示例用法
# if __name__ == "__main__":
#     # 使用 SQLite 作为缓存后端
#     sqlite_cache = Cache(backend='sqlite', db_path='cache.db', max_connections=5)
#     sqlite_cache.set('sqlite_key', 'sqlite_value', ttl=10)
#     print(sqlite_cache.get('sqlite_key'))  # 输出: sqlite_value
#     sqlite_cache.close()

#     # 使用 Redis 作为缓存后端
#     redis_cache = Cache(backend='redis', host='localhost', port=6379, db=0)
#     redis_cache.set('redis_key', 'redis_value', ttl=10)
#     print(redis_cache.get('redis_key'))  # 输出: redis_value
#     redis_cache.close()

import sqlite3
import time
import threading

class SQLiteCache:
    """Key/value cache stored in an SQLite ``cache`` table over one connection.

    With ``thread_safe=True`` every operation runs while holding a single
    ``threading.Lock``; otherwise no locking is performed. Expired entries
    are removed lazily when ``get`` reads them.

    :param db_path: SQLite database file.
    :param thread_safe: serialize all operations with a lock.
    """

    def __init__(self, db_path='cache.db', thread_safe=False):
        self.db_path = db_path
        self.thread_safe = thread_safe
        self.conn = self._create_connection()
        self.cursor = self.conn.cursor()
        self.lock = threading.Lock() if thread_safe else None
        self._create_table()

    def _create_connection(self):
        """Open the database connection and request WAL journal mode."""
        # check_same_thread=False allows use from threads other than the creator.
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        # WAL mode is enabled to improve concurrency.
        conn.execute('PRAGMA journal_mode=WAL')
        return conn

    def _create_table(self):
        """Create the ``cache`` table if it does not exist yet."""
        # ``with self.conn`` commits on success and rolls back on error.
        with self._thread_safe_context(), self.conn:
            self.cursor.execute('''
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expiry REAL
                )
            ''')

    def set(self, key, value, ttl=None):
        """Store ``value`` under ``key``, replacing any existing entry.

        :param ttl: lifetime in seconds. A falsy value (``None`` or ``0``)
            stores the entry without an expiry.
        """
        with self._thread_safe_context(), self.conn:
            expiry = time.time() + ttl if ttl else None
            self.cursor.execute('''
                INSERT OR REPLACE INTO cache (key, value, expiry)
                VALUES (?, ?, ?)
            ''', (key, value, expiry))

    def get(self, key):
        """Return the cached value for ``key``, or ``None`` if absent or expired.

        An expired entry is deleted as a side effect.
        """
        with self._thread_safe_context(), self.conn:
            self.cursor.execute('''
                SELECT value, expiry FROM cache WHERE key = ?
            ''', (key,))
            result = self.cursor.fetchone()
            if result is None:
                return None
            value, expiry = result
            if expiry is None or expiry > time.time():
                return value
            # Expired: delete inline. Calling self.delete() here would try to
            # re-acquire the non-reentrant lock already held and deadlock.
            self.cursor.execute('''
                DELETE FROM cache WHERE key = ?
            ''', (key,))
            return None

    def delete(self, key):
        """Remove the entry for ``key`` if it exists."""
        with self._thread_safe_context(), self.conn:
            self.cursor.execute('''
                DELETE FROM cache WHERE key = ?
            ''', (key,))

    def clear(self):
        """Remove every entry from the cache."""
        with self._thread_safe_context(), self.conn:
            self.cursor.execute('DELETE FROM cache')

    def close(self):
        """Close the database connection."""
        with self._thread_safe_context():
            self.conn.close()

    def _thread_safe_context(self):
        """Return the lock in thread-safe mode, otherwise a no-op context manager."""
        if self.thread_safe:
            return self.lock
        else:
            return DummyContextManager()

class DummyContextManager:
    """No-op context manager used when thread safety is disabled.

    Entering yields ``None`` and exiting returns ``None`` (falsy), so
    exceptions raised inside the ``with`` body propagate unchanged.
    """

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        return None

# 示例用法
if __name__ == "__main__":
    # Demo of the single-connection cache with locking turned on.
    cache = SQLiteCache(thread_safe=True)

    # Store an entry with a ten-second lifetime and read it immediately.
    cache.set('my_key', 'my_value', ttl=10)
    value = cache.get('my_key')
    print(value)  # expected output: my_value

    # Wait until the TTL has passed; the read below sees an expired entry.
    time.sleep(11)
    value = cache.get('my_key')
    print(value)  # expected output: None

    # Close the underlying connection.
    cache.close()
# posted @ glc400  阅读(4)  评论(0)  编辑  收藏  举报
# 点击右上角即可分享
# 微信分享提示