Commit f1f7ca08 authored by Nicolas Joyard's avatar Nicolas Joyard

Ajout filtres API REST

parent c54b123f
......@@ -2,15 +2,27 @@
import collections
from copy import copy
from datetime import datetime
from flask import abort, jsonify, request
from flask import abort, json, jsonify, request
from sqlalchemy.inspection import inspect
from .schema import SchemaFactory
class EndpointException(Exception):
class Endpoint(object):
_casters = {
'Integer': int,
'Unicode': lambda x: x,
'Boolean': lambda x: x.lower() in ('true', 'yes', '1'),
'Date': lambda x: datetime.strptime(x, '%Y-%m-%d').date()
def __init__(self, api, ma, model, description=None, hidden=False):
self.api = api = ma
......@@ -37,6 +49,11 @@ class Endpoint(object):
return Paginated
def _make_error(self, message, code=400):
res = jsonify({ '_error': message, '_code': code })
res.status_code = code
return res
def describe(self):
mapper = inspect(self.model)
......@@ -67,6 +84,55 @@ class Endpoint(object):
return jsonify(desc)
def _cast_value(self, value, typename):
if typename not in self._casters:
raise EndpointException('No type caster for {}'.format(typename))
return self._casters[typename](value)
except ValueError:
raise EndpointException('Could not cast value to {}: {}'.format(typename, value))
def _apply_filters(self, query, args):
if not len(args):
return query
columns = { type(c.type).__name__ for c in inspect(self.model).columns }
filters = []
for colname, value in args.items():
op = 'eq'
if colname.find('__') >= 0:
colname, op = colname.rsplit('__', 1)
if colname not in columns:
raise EndpointException('Unknown column name: {}'.format(colname))
col = getattr(self.model, colname)
if op in ('eq', 'ne', 'gt', 'lt', 'gte', 'lte'):
value = self._cast_value(value, columns[colname])
if op == 'eq':
filters.append(col == value)
elif op == 'ne':
filters.append(col != value)
elif op == 'gt':
filters.append(col > value)
elif op == 'lt':
filters.append(col < value)
elif op == 'gte':
filters.append(col >= value)
elif op == 'lte':
filters.append(col <= value)
elif op == 'isnull':
null = self._cast_value(value, 'Boolean')
if null:
filters.append(col == None)
filters.append(col != None)
raise EndpointException('Unknown operator: {}'.format(op))
return query.filter(*filters)
def list(self):
args = request.args
forward_qs = {}
......@@ -75,6 +141,12 @@ class Endpoint(object):
item_schema, query = self._schema(mode='list')
list_schema = self._paginate(item_schema)
# Handle filters
query = self._apply_filters(query, { k: v for k, v in args.items() if k not in ('search', 'page', 'page_size') })
except EndpointException, e:
return self._make_error(e.message)
# Handle search
if args.get('search', None) and hasattr(query, 'search'):
query =['search'])
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment