154 lines
4.3 KiB
Python
154 lines
4.3 KiB
Python
# Calendar.social
|
|
# Copyright (C) 2018 Gergely Polonkai
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
"""Caching functionality for Calendar.social
|
|
"""
|
|
|
|
from datetime import timedelta
|
|
import pickle
|
|
from uuid import uuid4
|
|
|
|
from flask import current_app, has_request_context, request, session
|
|
from flask.sessions import SessionInterface, SessionMixin
|
|
from flask_caching import Cache
|
|
from werkzeug.datastructures import CallbackDict
|
|
|
|
# Module-level Flask-Caching instance, created here without an application.
# ``CachedSessionInterface`` below uses it as the backing store for session data.
# NOTE(review): presumably bound to the Flask app elsewhere (e.g. via ``init_app``) — confirm.
cache = Cache() # pylint: disable=invalid-name
|
|
|
|
|
|
class CachedSession(CallbackDict, SessionMixin):  # pylint: disable=too-many-ancestors
    """Object for session data saved in the cache

    A dict-like session object.  Every update marks the session as modified and, when a request
    context is active, stores the client address under the ``'ip'`` key.

    :param initial: initial session data, passed on to ``CallbackDict``
    :param sid: the ID of the session
    :param new: ``True`` if the session was just created instead of being loaded
    """

    def __init__(self, initial=None, sid=None, new=False):
        # Re-entrancy guard: ``on_update`` writes to the session itself (the ``'ip'`` key),
        # which would otherwise trigger the callback again.
        self.__modifying = False

        def on_update(self):
            """Function to call when session data is updated
            """

            if self.__modifying:
                return

            self.__modifying = True

            try:
                if has_request_context():
                    self['ip'] = request.remote_addr

                self.modified = True
            finally:
                # Always release the guard; if it stayed set after an exception, every later
                # update would be silently ignored by the early return above.
                self.__modifying = False

        CallbackDict.__init__(self, initial, on_update)
        self.sid = sid
        self.new = new
        self.modified = False

    @property
    def user(self):
        """The user referenced by the ``'user_id'`` key of the session

        Returns ``None`` if the session has no ``'user_id'`` key, otherwise the result of
        looking that ID up with ``User.query.get``.
        """

        # Imported locally rather than at module level.
        # NOTE(review): presumably to avoid a circular import with calsocial.models — confirm.
        from calsocial.models import User

        if 'user_id' not in self:
            return None

        return User.query.get(self['user_id'])
|
|
|
|
|
|
class CachedSessionInterface(SessionInterface):
    """A session interface that loads/saves session data from the cache

    Session data is serialized and stored in the cache under the key ``prefix + sid``; only the
    session ID is sent to the client in the session cookie.
    """

    # NOTE(security): pickle can execute arbitrary code while loading.  Here it is only fed
    # values read back from the cache, so the cache backend must not be writable by untrusted
    # parties.
    serializer = pickle
    session_class = CachedSession
    global_cache = cache

    def __init__(self, prefix='session:'):
        """Create the interface

        :param prefix: string prepended to session IDs to form cache keys
        """

        self.cache = cache
        self.prefix = prefix

    @staticmethod
    def generate_sid():
        """Generate a new session ID

        :returns: a random UUID (version 4) as a string
        """

        return str(uuid4())

    @staticmethod
    def get_cache_expiration_time(app, session):
        """Get the expiration time of the cache entry

        :param app: the application; its ``permanent_session_lifetime`` is used for permanent
            sessions
        :param session: the session whose ``permanent`` flag is checked
        :returns: a ``timedelta``; one day for non-permanent sessions
        """

        if session.permanent:
            return app.permanent_session_lifetime

        return timedelta(days=1)

    def open_session(self, app, request):
        """Load the session referenced by the session cookie, or start a new one

        :param app: the application; provides the name of the session cookie
        :param request: the request whose cookies are inspected
        :returns: the stored session, or a new, empty session with a freshly generated ID
        """

        sid = request.cookies.get(app.session_cookie_name)

        if sid:
            session = self.load_session(sid)

            if session is not None:
                return session

        # Either there was no cookie, or the ID it carried is unknown to the cache (expired or
        # made up by the client).  Never adopt a client-supplied ID for a new session: doing so
        # would let an attacker pre-set the ID of a victim's session (session fixation).
        return self.session_class(sid=self.generate_sid(), new=True)

    def load_session(self, sid):
        """Load a specific session from the cache

        :param sid: the ID of the session to load
        :returns: the session, or ``None`` if the cache has no entry for ``sid``
        """

        val = self.cache.get(self.prefix + sid)

        if val is None:
            return None

        data = self.serializer.loads(val)

        return self.session_class(data, sid=sid)

    def save_session(self, app, session, response):
        """Store the session in the cache and set the session cookie on the response

        An empty session is removed from the cache instead; its cookie is deleted only if the
        session was modified during the request.

        :param app: the application
        :param session: the session to save
        :param response: the response that receives (or loses) the session cookie
        """

        domain = self.get_cookie_domain(app)

        if not session:
            self.cache.delete(self.prefix + session.sid)

            if session.modified:
                response.delete_cookie(app.session_cookie_name, domain=domain)

            return

        # The cache entry and the cookie have separate lifetimes: the former comes from
        # get_cache_expiration_time() above, the latter from the base SessionInterface.
        cache_exp = self.get_cache_expiration_time(app, session)
        cookie_exp = self.get_expiration_time(app, session)
        val = self.serializer.dumps(dict(session))
        self.cache.set(self.prefix + session.sid, val, int(cache_exp.total_seconds()))

        # NOTE(review): only expires/httponly/domain are passed; path, secure and samesite are
        # left at the defaults of set_cookie().
        response.set_cookie(app.session_cookie_name,
                            session.sid,
                            expires=cookie_exp,
                            httponly=True,
                            domain=domain)

    def delete_session(self, sid):
        """Delete a session from the cache

        :param sid: the ID of the session to delete
        :raises ValueError: if called within a request whose own session has this ID
        """

        if has_request_context() and session.sid == sid:
            raise ValueError('Will not delete the current session')

        # Use the instance's cache like every other method, not the module-level global.
        self.cache.delete(self.prefix + sid)
|