213 lines
6.7 KiB
Python
213 lines
6.7 KiB
Python
from flask import current_app, Blueprint, render_template, request, jsonify
|
|
|
|
|
|
class FlaskSQLAlchemyWebquery(object):
    """Flask extension that serves a small web UI for querying SQLAlchemy models.

    ``self.models`` maps each discovered model class to a dict with the keys
    ``'table'`` (the model's ``__table__``), ``'columns'`` (the mapper's
    column collection) and ``'related_tables'`` (a dict mapping a related
    model class to the name of the many-to-one relationship attribute that
    leads to it).
    """

    # Maximum number of rows ``update_view`` returns per request.  Override on
    # the class or on an instance to change it.
    result_limit = 10

    def __init__(self, app=None):
        """Create the extension, optionally binding it to ``app`` right away.

        :param app: Flask application to pass to :meth:`init_app`, or ``None``
            to defer initialisation.
        """
        self.models = {}
        self.blueprint = None
        self.db = None

        if app is not None:
            self.init_app(app)

    def list_models(self):
        """Render the model listing page."""
        return render_template('flask_sqlalchemy_webquery/list-models.html')

    def get_joins(self, from_model, to_model):
        """Return the shortest relationship path between two models.

        The search follows the directed ``'related_tables'`` edges stored in
        ``self.models``; every edge has weight 1.

        :param from_model: model the path starts at.
        :param to_model: model the path must end at.
        :returns: ``[from_model, ..., to_model]`` for a reachable target, or
            an empty list when both models are the same, when ``to_model`` is
            not registered, or when no path exists.
        """
        if from_model == to_model:
            return []

        from math import inf

        dist = {model: inf for model in self.models}
        prev = {model: None for model in self.models}
        unvisited = set(self.models)

        dist[from_model] = 0

        while unvisited:
            # min() picks the same element as sorted(...)[0] without sorting
            # the whole frontier on every iteration.
            u = min(unvisited, key=dist.__getitem__)

            # Stop once the target is settled, or once only unreachable
            # models (distance still infinite) remain.
            if u == to_model or dist[u] == inf:
                break

            unvisited.remove(u)

            for v in self.models[u]['related_tables']:
                if v not in unvisited:
                    continue

                alt = dist[u] + 1

                if alt < dist[v]:
                    dist[v] = alt
                    prev[v] = u

        # An infinite (or missing) distance means there is no path; returning
        # [to_model] here would look like a valid one-element path.
        if dist.get(to_model, inf) == inf:
            return []

        path = []
        u = to_model

        while u is not None:
            path.append(u)
            u = prev[u]

        path.reverse()

        return path

    def _reachable_models(self, start):
        """Return the ``self.models`` entries reachable from ``start`` (inclusive)."""
        reachable = {}
        pending = [start]

        while pending:
            model = pending.pop()

            if model in reachable:
                continue

            reachable[model] = self.models[model]
            pending.extend(self.models[model]['related_tables'])

        return reachable

    @staticmethod
    def _read_value(row, attr_chain, column_name):
        """Follow ``attr_chain`` from ``row`` and read ``column_name`` at the end.

        Returns ``None`` as soon as an attribute along the chain is ``None``.
        """
        value = row

        for attr_name in attr_chain:
            value = getattr(value, attr_name)

            if value is None:
                return None

        return getattr(value, column_name)

    def update_view(self):
        """Handle the POST endpoint that lists columns and fetches result rows.

        The JSON body must be an object with a ``'model'`` key (a table name)
        and a ``'columns'`` key (a list of ``{'model': ..., 'column': ...}``
        objects; may be empty).

        :returns: a JSON response with the keys ``'results'`` (list of row
            dicts keyed ``'<table>:<column>'``, or ``None`` when no columns
            were requested), ``'columns'`` (table name -> column names for
            every model reachable from the selected one) and ``'count'``.
            Invalid input yields ``{'error': ...}`` with HTTP status 400.
        """
        payload = request.get_json(silent=True)

        if not isinstance(payload, dict) or \
                'model' not in payload or \
                'columns' not in payload:
            return jsonify({'error': 'Missing input'}), 400

        model_class = self.get_model_by_table_name(payload['model'])

        if model_class is None:
            return jsonify({'error': 'Unknown model'}), 400

        available_models = self._reachable_models(model_class)

        # Collect all columns reachable from the selected model
        columns = {
            model_data['table'].name:
                [column_name for column_name, _ in model_data['columns'].items()]
            for model_data in available_models.values()
        }

        requested_columns = payload['columns']

        if not requested_columns:
            return jsonify({
                'results': None,
                'columns': columns,
                'count': 0
            })

        if not isinstance(requested_columns, list):
            return jsonify({'error': 'Invalid columns'}), 400

        query = self.db.session.query(model_class)
        joins = []

        # One (result key, relationship attribute chain, column name) entry
        # per requested column, resolved once instead of once per row.
        plan = []

        for data in requested_columns:
            try:
                model_name, column_name = data['model'], data['column']
            except (KeyError, TypeError):
                return jsonify({'error': 'Invalid columns'}), 400

            if not isinstance(model_name, str) or \
                    not isinstance(column_name, str):
                return jsonify({'error': 'Invalid columns'}), 400

            target_model = self.get_model_by_table_name(model_name)

            # The column name comes from the client and ends up in getattr(),
            # so only accept mapped columns of models reachable from the
            # selected one.
            if target_model not in available_models or \
                    column_name not in available_models[target_model]['columns']:
                return jsonify({'error': 'Unknown column'}), 400

            attr_chain = []
            current = model_class

            for cls in self.get_joins(model_class, target_model):
                if cls == model_class:
                    continue

                if cls not in joins:
                    query = query.join(cls)
                    joins.append(cls)

                attr_chain.append(self.models[current]['related_tables'][cls])
                current = cls

            plan.append((':'.join((model_name, column_name)), attr_chain, column_name))

        # Fetch the results
        result_count = query.count()

        results = [
            {
                res_key: self._read_value(row, attr_chain, column_name)
                for res_key, attr_chain, column_name in plan
            }
            for row in query.limit(self.result_limit)
        ]

        return jsonify({
            'results': results,
            'columns': columns,
            'count': result_count
        })

    def get_model_by_table_name(self, table_name):
        """Return the first registered model whose table is named ``table_name``.

        :returns: the model class, or ``None`` when no table has that name.
        """
        for model, model_data in self.models.items():
            if model_data['table'].name == table_name:
                return model

        return None

    def init_app(self, app):
        """Discover the application's models and register the blueprint.

        Rebuilds ``self.models`` from the declarative class registry of
        ``self.db.Model``, records every many-to-one relationship in
        ``'related_tables'``, then creates the blueprint and registers it on
        ``app``.

        :param app: Flask application that already has the ``'sqlalchemy'``
            extension set up.
        """
        from sqlalchemy.orm.base import MANYTOONE

        # NOTE(review): relies on ``app.extensions['sqlalchemy']`` exposing
        # ``.db`` and on ``Model._decl_class_registry``; confirm both exist in
        # the Flask-SQLAlchemy / SQLAlchemy versions this project targets.
        self.db = app.extensions['sqlalchemy'].db

        self.models = {}

        for cls in self.db.Model._decl_class_registry.values():
            if isinstance(cls, type) and issubclass(cls, self.db.Model):
                self.models[cls] = {
                    'table': cls.__table__,
                    'related_tables': {},
                    'columns': cls.__mapper__.columns
                }

        for model in self.models:
            for r in model.__mapper__.relationships:
                # Only support many to one relationships (e.g. foreign keys)
                if r.direction != MANYTOONE:
                    continue

                for local, remote in r.local_remote_pairs:
                    # Exactly one registered model must map the remote table;
                    # anything else raises ValueError from the unpacking.
                    (remote_model,) = [
                        candidate
                        for candidate, model_data in self.models.items()
                        if model_data['table'] == remote.table
                    ]

                    # ``r.key`` is the relationship's attribute name on the
                    # model, including relationships inherited from a parent.
                    self.models[model]['related_tables'][remote_model] = r.key

        # NOTE(review): there is no '/' between ``app.static_url_path`` and
        # the suffix, so the two are concatenated directly - confirm that the
        # resulting URL prefix is intended.
        self.blueprint = Blueprint('flask_sqlalchemy_webquery',
                                   __name__,
                                   template_folder='templates',
                                   static_folder='static',
                                   static_url_path=app.static_url_path + 'flask_sqlalchemy_webquery')

        self.blueprint.add_url_rule('/list-models', 'list_models', self.list_models)
        self.blueprint.add_url_rule('/get-columns', 'update_view', self.update_view, methods=['POST'])

        @self.blueprint.context_processor
        def inject_variables():
            return {
                'models': self.models
            }

        app.register_blueprint(self.blueprint)
|