Source code for asclepias_broker.search.query

# -*- coding: utf-8 -*-
#
# Copyright (C) 2018 CERN.
#
# Asclepias Broker is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
"""Search utilities."""

from typing import Dict

from elasticsearch_dsl import Q
from elasticsearch_dsl.query import Range
from flask import request
from invenio_records_rest.errors import InvalidQueryRESTError
from invenio_rest.errors import FieldError, RESTValidationError


def search_factory(self, search, query_parser=None):
    """Build the relationships search from the current request's arguments.

    The ``id``, ``scheme`` and ``relation`` request arguments are mandatory.
    Facet filters and sorting come from the default invenio factories. Results
    are grouped by ``identity`` unless ``group_by`` is supplied. An optional
    ``q`` argument is applied as a ``query_string`` query.

    :param self: REST view.
    :param search: Elastic search DSL search instance.
    :param query_parser: Accepted for signature compatibility; not used.
    :returns: Tuple with search instance and URL arguments.
    :raises RESTValidationError: If one of ``id``, ``scheme`` or ``relation``
        is missing (the first missing one, in that order, is reported).
    :raises InvalidQueryRESTError: If building the ``q`` query raises
        ``SyntaxError``.
    """
    from invenio_records_rest.facets import default_facets_factory
    from invenio_records_rest.sorter import default_sorter_factory

    index_name = search._index[0]
    args = request.values

    # TODO: make "scheme" optional?
    missing = [name for name in ('id', 'scheme', 'relation')
               if name not in args]
    if missing:
        raise RESTValidationError(
            errors=[FieldError(missing[0], 'Required field.')])

    search, urlkwargs = default_facets_factory(search, index_name)
    search, sortkwargs = default_sorter_factory(search, index_name)
    for sort_key, sort_value in sortkwargs.items():
        urlkwargs.add(sort_key, sort_value)

    # Without an explicit grouping, fall back to identity-level grouping and
    # reflect that choice in the returned URL arguments.
    if 'group_by' not in args:
        search = search.filter(Q('term', Grouping='identity'))
        urlkwargs['group_by'] = 'identity'

    try:
        user_query = args.get('q')
        if user_query:
            search = search.query(
                Q('query_string', query=user_query,
                  default_field='_search_all'))
            urlkwargs['q'] = user_query
    except SyntaxError:
        raise InvalidQueryRESTError()

    # The identifiers that were used to match the search form a large
    # aggregate, so they are left out of the returned documents.
    search = search.source(exclude=['*.SearchIdentifier'])
    return search, urlkwargs


def enum_term_filter(label: str, field: str, choices: Dict[str, str]):
    """Term filter with controlled vocabulary.

    :param label: Field label reported in validation errors.
    :param field: Index field the resulting ``term`` query targets.
    :param choices: Mapping from accepted request values to the term value
        that is searched for.
    :returns: A callable that takes the list of request values and returns a
        ``term`` query.
    """
    def inner(values):
        # Exactly one value is accepted; any other count is rejected.
        if len(values) != 1:
            raise RESTValidationError(
                errors=[FieldError(label, 'Multiple values specified.')])
        mapped = choices.get(values[0])
        if mapped:
            return Q('term', **{field: mapped})
        # Unknown (or falsy-mapped) value: list every accepted key.
        allowed = ', '.join(choices)
        raise RESTValidationError(
            errors=[FieldError(label, 'Allowed values: [{}]'.format(allowed))])

    return inner


def nested_terms_filter(field: str, path: str = None):
    """Nested terms filter.

    :param field: Dotted field name the ``terms`` query matches on.
    :param path: Nested object path; when falsy, ``field`` without its last
        dotted component is used.
    :returns: A callable that takes the list of request values and returns a
        ``nested`` query wrapping a ``terms`` query on ``field``.
    """
    nested_path = path or field.rsplit('.', 1)[0]

    def inner(values):
        terms_query = dict(terms={field: values})
        return Q('nested', path=nested_path, query=terms_query)

    return inner


def nested_range_filter(
        label: str, field: str, path: str = None, op: str = None):
    """Nested range filter.

    Build a filter that applies a single-bound range condition to a field
    inside a nested object.

    :param label: Field label reported in validation errors.
    :param field: Dotted field name the range condition applies to.
    :param path: Nested object path; when falsy, ``field`` without its last
        dotted component is used.
    :param op: Range operator, one of ``'gte'``, ``'gt'``, ``'lte'``, ``'lt'``.
        Despite the ``None`` default, a valid operator must be given.
    :returns: A callable that takes the list of request values and returns a
        ``nested`` query wrapping a ``Range`` query on ``field``.
    :raises ValueError: If ``op`` is not one of the supported operators.
    """
    path = path or field.rsplit('.', 1)[0]
    # Validate explicitly instead of with ``assert``: assertions are stripped
    # under ``python -O``. That would let an invalid operator through and
    # produce a malformed range query at request time.
    if op not in ('gte', 'gt', 'lte', 'lt'):
        raise ValueError(
            'Unsupported range operator {!r}; expected one of: '
            'gte, gt, lte, lt'.format(op))

    def inner(values):
        # A single bound value is expected for the range condition.
        if len(values) != 1:
            raise RESTValidationError(
                errors=[FieldError(label, 'Multiple values specified.')])
        return Q('nested', path=path, query=Range(**{field: {op: values[0]}}))

    return inner