Coverage for ramose / operation.py: 96%
1076 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-07-01 13:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-07-01 13:49 +0000
1# SPDX-FileCopyrightText: 2018-2021 Silvio Peroni <silvio.peroni@unibo.it>
2# SPDX-FileCopyrightText: 2020-2021 Marilena Daquino <marilena.daquino2@unibo.it>
3# SPDX-FileCopyrightText: 2022 Davide Brembilla
4# SPDX-FileCopyrightText: 2024 Ivan Heibi <ivan.heibi2@unibo.it>
5# SPDX-FileCopyrightText: 2025 Sergei Slinkin
6# SPDX-FileCopyrightText: 2026 Arcangelo Massari <arcangelo.massari@unibo.it>
7#
8# SPDX-License-Identifier: ISC
10from __future__ import annotations
12import time
13from csv import DictReader, reader, writer
14from dataclasses import dataclass
15from dataclasses import field as dataclass_field
16from http import HTTPStatus
17from io import StringIO
18from itertools import product
19from json import dumps
20from math import ceil
21from operator import eq, gt, itemgetter, lt
22from re import error as regex_error
23from re import findall, fullmatch, match, search, sub
24from typing import TYPE_CHECKING, NoReturn, TypedDict, cast
25from urllib.parse import parse_qs, quote, urlsplit
27from requests.exceptions import RequestException
28from requests.exceptions import Timeout as RequestsTimeout
30try:
31 from pysparql_anything import SparqlAnything # pyright: ignore[reportMissingImports, reportAttributeAccessIssue]
32except ImportError:
33 SparqlAnything = None
35from ramose._constants import (
36 DEFAULT_HTTP_TIMEOUT,
37 FIELD_TYPE_RE,
38 _http_session,
39 backend_auth_header,
40 media_type_for_format,
41)
42from ramose.datatype import DataType
43from ramose.filters import apply_filters
44from ramose.paging import PaginationInfo, build_link_header, build_pagination_info
46if TYPE_CHECKING:
47 import types
48 from collections.abc import Callable, Mapping
49 from typing import Protocol
51 from requests import Response
53 from ramose.cache import ResultCache
54 from ramose.filters import FiltersConfig
56 class SparqlAnythingEngine(Protocol):
57 def select(self, output_type: type[object], **kwargs: object) -> object: ...
60_WRITE_METHODS = frozenset({"post", "put", "delete"})
61_RETRYABLE_STATUS_CODES = frozenset(
62 {
63 HTTPStatus.REQUEST_TIMEOUT,
64 HTTPStatus.TOO_MANY_REQUESTS,
65 HTTPStatus.INTERNAL_SERVER_ERROR,
66 HTTPStatus.BAD_GATEWAY,
67 HTTPStatus.SERVICE_UNAVAILABLE,
68 HTTPStatus.GATEWAY_TIMEOUT,
69 }
70)
71_SPARQL_ANYTHING_HTTP_STATUS_RE = r"\bHTTP/\d(?:\.\d)?\s+(\d{3})\b"
72_SPARQL_ANYTHING_TIMEOUT_MARKERS = frozenset(
73 {
74 "sockettimeoutexception",
75 "timed out",
76 }
77)
78_SPARQL_ANYTHING_NETWORK_MARKERS = frozenset(
79 {
80 "httphostconnectexception",
81 "connectexception",
82 "noroutetohostexception",
83 "unknownhostexception",
84 "connection refused",
85 }
86)
87_IRI_FORBIDDEN = r'[<>"{}|^`\\\x00-\x20]'
88_UNPROCESSABLE_CONTENT = 422
89_JSON_TRANSFORM_RE = r'^(?P<op_type>array|dict)\((?P<separator>"[^"]+"),(?P<entries>[^)]+)\)$'
90_DICT_TRANSFORM_MIN_FIELD_COUNT = 2
91ResultRow = list[str]
92ResultTable = list[ResultRow]
class CachedPagination(TypedDict):
    """Typed dictionary with the three integer pagination fields kept next to cached rows."""

    page: int
    page_size: int
    total_items: int
class CachedResult(TypedDict):
    """Typed dictionary pairing a table of rows with optional pagination data."""

    rows: ResultTable
    pagination: CachedPagination | None
class HttpError(Exception):
    """Exception that carries an HTTP status code alongside its message."""

    def __init__(self, status_code: int, message: str) -> None:
        """Store 'status_code' on the instance and pass 'message' to Exception."""
        self.status_code = status_code
        super().__init__(message)
@dataclass
class OperationConfig:
    """Settings copied onto an Operation instance by Operation.__init__.

    Every field has a default, so OperationConfig() is a valid configuration.
    The three retry_* values are range-checked right after initialisation.
    """

    # Endpoints and HTTP method.
    sparql_endpoint: str = ""
    update_endpoint: str = ""
    sparql_http_method: str = "get"
    # Add-on module looked up by name for converter/pre/post-processing functions.
    addon: types.ModuleType | None = None
    # Per-instance containers (a fresh object is created for each configuration).
    format_map: dict = dataclass_field(default_factory=dict)
    format_media_types: dict = dataclass_field(default_factory=dict)
    sources_map: dict = dataclass_field(default_factory=dict)
    custom_params: dict = dataclass_field(default_factory=dict)
    disabled_params: set = dataclass_field(default_factory=set)
    requires_auth: bool = False
    cache: ResultCache | None = None
    default_cache_ttl: int = 86400
    custom_param_configs: dict[str, FiltersConfig] = dataclass_field(default_factory=dict)
    public_base_url: str = ""
    # Retry settings, validated in __post_init__.
    retry_attempts: int = 3
    retry_wait: float = 0.5
    retry_backoff: float = 2.0

    def __post_init__(self) -> None:
        """Raise ValueError for the first retry setting that is below its minimum."""
        lower_bounds = (("retry_attempts", 1), ("retry_wait", 0), ("retry_backoff", 1))
        for attribute, minimum in lower_bounds:
            if getattr(self, attribute) < minimum:
                msg = f"{attribute} must be >= {minimum}"
                raise ValueError(msg)
144class Operation:
def __init__(
    self,
    op_complete_url: str,
    op_key: str,
    op_item: dict[str, str],
    config: OperationConfig | None = None,
) -> None:
    """Store the request URL, the operation key/item and the configuration values.

    When 'config' is omitted, a default OperationConfig() is used.
    """
    settings = OperationConfig() if config is None else config

    # Request-specific state.
    self.url_parsed = urlsplit(op_complete_url)
    self.op_url = self.url_parsed.path
    self.op = op_key
    self.i = op_item

    # Values copied one-to-one from the configuration object.
    self.tp = settings.sparql_endpoint
    self.update_endpoint = settings.update_endpoint
    self.sparql_http_method = settings.sparql_http_method
    self.addon = settings.addon
    self.format = settings.format_map
    self.format_media_types = settings.format_media_types
    self.sources_map = settings.sources_map
    self.custom_params = settings.custom_params
    self.disabled_params = settings.disabled_params
    self.requires_auth = settings.requires_auth
    self._sa_engine = None
    self._cache = settings.cache
    self._default_cache_ttl = settings.default_cache_ttl
    self.custom_param_configs = settings.custom_param_configs
    self.public_base_url = settings.public_base_url
    self.retry_attempts = settings.retry_attempts
    self.retry_wait = settings.retry_wait
    self.retry_backoff = settings.retry_backoff
    self.pagination_info: PaginationInfo | None = None

    # Comparison functions keyed by the operator character used in filters.
    self.operation = {"=": eq, "<": lt, ">": gt}

    self.dt = DataType()
182 def _public_request_url(self, request_url: str) -> str:
183 parsed = urlsplit(request_url)
184 if parsed.scheme and parsed.netloc:
185 return request_url
186 return f"{self.public_base_url.rstrip('/')}/{request_url.lstrip('/')}"
188 def _converter_request_url(self) -> str:
189 if self.pagination_info is not None:
190 return self._public_request_url(self.pagination_info.self_url)
191 if self.url_parsed.query:
192 return self._public_request_url(f"{self.op_url}?{self.url_parsed.query}")
193 return self._public_request_url(self.op_url)
195 def _is_builtin_param_active(self, name: str) -> bool:
196 return name not in self.disabled_params and name not in self.custom_params
@staticmethod
def _raise_unprocessable(message: str) -> NoReturn:
    """Always raise HttpError with status 422 and 'message' prefixed by the status code."""
    text = f"HTTP status code {_UNPROCESSABLE_CONTENT}: {message}"
    raise HttpError(_UNPROCESSABLE_CONTENT, text)
202 @staticmethod
203 def _parse_positive_int_param(params: dict[str, list[str]], name: str) -> int:
204 raw_value = params[name][0]
205 try:
206 value = int(raw_value)
207 except ValueError:
208 Operation._raise_unprocessable(f"{name} must be an integer, got {raw_value!r}")
209 if value < 1:
210 Operation._raise_unprocessable(f"{name} must be >= 1, got {value}")
211 return value
213 @staticmethod
214 def _validate_page_range(page: int, total_items: int, total_pages: int) -> None:
215 if total_items and page > total_pages:
216 Operation._raise_unprocessable(f"page {page} exceeds total pages {total_pages}")
218 @staticmethod
219 def get_content_type(ct: str) -> str:
220 """It returns the mime type of a given textual representation of a format, being it either
221 'csv' or 'json."""
222 content_type = ct
224 if ct == "csv":
225 content_type = "text/csv"
226 elif ct == "json":
227 content_type = "application/json"
229 return content_type
231 def _resolve_format(self, s: str, query_string: dict[str, list[str]]) -> tuple[str, str] | None:
232 if "format" in query_string and self._is_builtin_param_active("format"):
233 for req_format in query_string["format"]:
234 if req_format in self.format:
235 request_url = self._converter_request_url()
236 converter_func = getattr(self.addon, self.format[req_format])
237 return converter_func(s, request_url=request_url), self._media_type_for_format(req_format)
238 elif "default_format" in self.i:
239 default_fmt = self.i["default_format"].strip()
240 if default_fmt in self.format:
241 request_url = self._converter_request_url()
242 converter_func = getattr(self.addon, self.format[default_fmt])
243 return converter_func(s, request_url=request_url), self._media_type_for_format(default_fmt)
244 return None
246 def _media_type_for_format(self, fmt: str) -> str:
247 if fmt in self.format_media_types:
248 return self.format_media_types[fmt]
249 return Operation.get_content_type(fmt)
251 def _validate_format_values(self, formats: list[str]) -> None:
252 supported_formats = {"csv", "json", *self.format}
253 for req_format in formats:
254 if req_format not in supported_formats:
255 Operation._raise_unprocessable(f"unsupported format {req_format!r}")
def media_type_to_format(self) -> dict[str, str]:
    """Map each known media type to the format token that produces it.

    An empty mapping is returned when the built-in 'format' parameter is not active.
    Tokens are visited in the order: default format (or 'json'), 'json', 'csv', then the
    keys of 'self.format'; the first token seen for a media type wins. Tokens whose media
    type resolves to None are left out."""
    if not self._is_builtin_param_active("format"):
        return {}

    if "default_format" in self.i:
        default_token = self.i["default_format"].strip()
    else:
        default_token = "json"

    mapping: dict[str, str] = {}
    for token in (default_token, "json", "csv", *self.format):
        if token in self.format_media_types:
            media_type = self.format_media_types[token]
        else:
            media_type = media_type_for_format(token)
        if media_type is None or media_type in mapping:
            continue
        mapping[media_type] = token
    return mapping
def conv(self, s: str, query_string: dict[str, list[str]], c_type: str = "text/csv") -> tuple[str, str]:
    """This method takes a string representing a CSV document and converts it in the requested format
    according to what content type is specified as input.

    Requested 'format' values are validated first. If an add-on converter handles the format, its
    result is returned as is. Otherwise the output is either the CSV text itself or its JSON
    rendition; any content type other than CSV/JSON falls back to CSV."""
    content_type = Operation.get_content_type(c_type)
    format_requested = "format" in query_string and self._is_builtin_param_active("format")

    if format_requested:
        self._validate_format_values(query_string["format"])

    converted = self._resolve_format(s, query_string)
    if converted is not None:
        return converted

    if format_requested:
        # When several 'format' values are given, the last one decides.
        requested = query_string["format"]
        if requested:
            content_type = Operation.get_content_type(requested[-1])
    elif "default_format" in self.i:
        content_type = Operation.get_content_type(self.i["default_format"].strip())

    if content_type != "application/json":
        return s, "text/csv"

    with StringIO(s) as csv_text:
        records = [dict(record) for record in DictReader(csv_text)]

    if self._is_builtin_param_active("json"):
        records = Operation.structured(query_string, records)  # type: ignore[arg-type]

    return dumps(records, ensure_ascii=False, indent=4), content_type
304 @staticmethod
305 def pv(i: int | tuple[object, str], r: list[tuple[object, str]] | None = None) -> str:
306 """This method returns the plain value of a particular item 'i' of the result returned by the SPARQL query.
308 In case 'r' is specified (i.e. a row containing a set of results), then 'i' must be the index of the item
309 within that row."""
310 if r is None:
311 return i[1] # type: ignore[index]
312 return Operation.pv(r[i]) # type: ignore[index]
314 @staticmethod
315 def tv(i: int | tuple[object, str], r: list[tuple[object, str]] | None = None) -> object:
316 """This method returns the typed value of a particular item 'i' of the result returned by the SPARQL query.
317 The type associated to that value is actually specified by means of the particular configuration provided
318 in the specification file of the API - field 'field_type'.
320 In case 'r' is specified (i.e. a row containing a set of results), then 'i' must be the index of the item
321 within that row."""
322 if r is None:
323 return i[0] # type: ignore[index]
324 return Operation.tv(r[i]) # type: ignore[index]
326 @staticmethod
327 def do_overlap(r1: tuple[int, int], r2: tuple[int, int]) -> bool:
328 """This method returns a boolean that says if the two ranges (i.e. two pairs of integers) passed as inputs
329 actually overlap one with the other."""
330 r1_s, r1_e = r1
331 r2_s, r2_e = r2
333 return r1_s <= r2_s <= r1_e or r2_s <= r1_s <= r2_e
335 @staticmethod
336 def get_item_in_dict(
337 d_or_l: dict[str, object] | list[dict[str, object]], key_list: list[str], prev: list[object] | None = None
338 ) -> list[object]:
339 """This method takes as input a dictionary or a list of dictionaries and browses it until the value
340 specified following the chain indicated in 'key_list' is not found. It returns a list of all the
341 values that matched with such search."""
342 res = [] if prev is None else prev.copy()
344 d_list = [d_or_l] if isinstance(d_or_l, dict) else d_or_l
346 for d in d_list:
347 key_list_len = len(key_list)
349 if key_list_len >= 1:
350 key = key_list[0]
351 if key in d:
352 if key_list_len == 1:
353 res.append(d[key])
354 else:
355 res = Operation.get_item_in_dict(d[key], key_list[1:], res) # type: ignore[arg-type]
357 return res
359 @staticmethod
360 def add_item_in_dict(
361 d_or_l: dict[str, object] | list[dict[str, object]], key_list: list[str], item: object, idx: int
362 ) -> None:
363 """This method takes as input a dictionary or a list of dictionaries, browses it until the value
364 specified following the chain indicated in 'key_list' is not found, and then substitutes it with 'item'.
365 In case the final object retrieved is a list, it selects the object in position 'idx' before the
366 substitution."""
367 key_list_len = len(key_list)
369 if key_list_len >= 1:
370 key = key_list[0]
372 if isinstance(d_or_l, list):
373 if key_list_len == 1:
374 d_or_l[idx][key] = item
375 else:
376 for i in d_or_l:
377 Operation.add_item_in_dict(i, key_list, item, idx)
378 elif key in d_or_l:
379 if key_list_len == 1:
380 d_or_l[key] = item
381 else:
382 Operation.add_item_in_dict(d_or_l[key], key_list[1:], item, idx) # type: ignore[arg-type]
384 @staticmethod
385 def _apply_array_transform(row: dict[str, object], keys: list[str], separator: str, v_list: list[object]) -> None:
386 for idx, v in enumerate(v_list):
387 if isinstance(v, str):
388 Operation.add_item_in_dict(row, keys, v.split(separator) if v != "" else [], idx)
390 @staticmethod
391 def _apply_dict_transform(
392 row: dict[str, object], keys: list[str], separator: str, entries: list[str], v_list: list[object]
393 ) -> None:
394 new_fields = entries[1:]
395 new_fields_max_split = len(new_fields) - 1
396 for idx, v in enumerate(v_list):
397 if isinstance(v, str):
398 new_values = v.split(separator, new_fields_max_split)
399 Operation.add_item_in_dict(
400 row,
401 keys,
402 dict(zip(new_fields, new_values, strict=False)) if v != "" else {},
403 idx,
404 )
405 elif isinstance(v, list):
406 new_list = [dict(zip(new_fields, i.split(separator, new_fields_max_split), strict=False)) for i in v]
407 Operation.add_item_in_dict(row, keys, new_list, idx)
409 @staticmethod
410 def structured(params: dict[str, list[str]], json_table: list[dict[str, object]]) -> list[dict[str, object]]:
411 """This method checks if there are particular transformation rules specified in 'params' for a JSON output,
412 and convert each row of the input table ('json_table') according to these rules.
413 There are two specific rules that can be applied:
415 1. array("<separator>",<field>): it converts the string value associated to the field name '<field>' into
416 an array by splitting the various textual parts by means of '<separator>'. For instance, consider the
417 following JSON structure:
419 [
420 { "names": "Doe, John; Doe, Jane" },
421 { "names": "Doe, John; Smith, John" }
422 ]
424 Executing the rule 'array("; ",names)' returns the following new JSON structure:
426 [
427 { "names": [ "Doe, John", "Doe, Jane" ],
428 { "names": [ "Doe, John", "Smith, John" ]
429 ]
431 2. dict("separator",<field>,<new_field_1>,<new_field_2>,...): it converts the string value associated to
432 the field name '<field>' into an dictionary by splitting the various textual parts by means of
433 '<separator>' and by associating the new fields '<new_field_1>', '<new_field_2>', etc., to these new
434 parts. For instance, consider the following JSON structure:
436 [
437 { "name": "Doe, John" },
438 { "name": "Smith, John" }
439 ]
441 Executing the rule 'array(", ",name,family_name,given_name)' returns the following new JSON structure:
443 [
444 { "name": { "family_name": "Doe", "given_name: "John" } },
445 { "name": { "family_name": "Smith", "given_name: "John" } }
446 ]
448 Each of the specified rules is applied in order, and it works on the JSON structure returned after
449 the execution of the previous rule."""
450 if "json" in params:
451 fields = params["json"]
452 for field in fields:
453 op_match = fullmatch(_JSON_TRANSFORM_RE, field)
454 if op_match is None:
455 Operation._raise_unprocessable(f"invalid json transform {field!r}")
456 op_type = op_match.group("op_type")
457 quoted_separator = op_match.group("separator")
458 entries = [i.strip() for i in op_match.group("entries").split(",")]
459 separator = quoted_separator[1:-1]
460 if op_type == "array" and len(entries) != 1:
461 Operation._raise_unprocessable(f"json array transform expects one field, got {field!r}")
462 if op_type == "dict" and len(entries) < _DICT_TRANSFORM_MIN_FIELD_COUNT:
463 Operation._raise_unprocessable(f"json dict transform expects output fields, got {field!r}")
464 keys = entries[0].split(".")
466 for row in json_table:
467 v_list = Operation.get_item_in_dict(row, keys)
468 if op_type == "array":
469 Operation._apply_array_transform(row, keys, separator, v_list)
470 elif op_type == "dict":
471 Operation._apply_dict_transform(row, keys, separator, entries, v_list)
473 return json_table
def preprocess(
    self, par_dict: dict[str, object], op_item: dict[str, str], addon: types.ModuleType
) -> dict[str, object]:
    """This method takes a dictionary of parameters with their current values and the item of the API
    specification defining the behaviour of the operation, and preprocesses the parameters according
    to the functions specified in the '#preprocess' field (e.g. "#preprocess lower(doi)"). Each
    function is looked up in 'addon' and receives the values of the parameters named between
    parentheses.

    It is possible to run multiple functions sequentially by concatenating them with " --> " in the
    API specification document: every function sees the values left by the previous ones.

    All the functions specified in the "#preprocess" field must return a tuple of values, which
    replace, position by position, the values of the parameters they received. The input dictionary
    is updated in place and returned."""
    result = par_dict
    if "preprocess" not in op_item:
        return result

    # Whitespace is dropped from every step only after splitting on the " --> " marker.
    steps = [sub(r"\s+", "", chunk) for chunk in op_item["preprocess"].split(" --> ")]
    for step in steps:
        func_name = sub(r"^([^\(\)]+)\(.+$", r"\1", step).strip()
        arg_names = sub(r"^.+\(([^\(\)]+)\).*", r"\1", step).split(",")

        arg_values = tuple(result[arg_name] for arg_name in arg_names)
        new_values = getattr(addon, func_name)(*arg_values)

        # Write the returned values back over the parameters that were passed in.
        for position, new_value in enumerate(new_values):
            result[arg_names[position]] = new_value

    return result
def postprocess(
    self, res: list[list[str] | list[tuple[object, str]]], op_item: dict[str, str], addon: types.ModuleType
) -> list[list[str] | list[tuple[object, str]]]:
    """This method takes the result table returned by running the SPARQL query of an API operation and
    changes it according to the functions specified in the '#postprocess' field
    (e.g. "#postprocess remove_date("2018")"). Each function is looked up in 'addon'; it receives the
    result table as first argument, followed by the comma-separated arguments written between
    parentheses (parsed as a CSV line, so they can be quoted), and must return a pair
    (new table, flag). When the flag is true, the new table is passed again through 'type_fields'.

    Each cell of the result table is a pair holding the typed value and the plain value, e.g.:

    [
        ("id", "date"),
        [("my_id_1", "my_id_1"), (datetime(2018, 3, 2), "2018-03-02")],
        ...
    ]

    The typed value and the plain value of each cell can be selected by using the methods "tv" and
    "pv" respectively. It is possible to run multiple functions sequentially by concatenating them
    with " --> " in the API specification document: the output table of a function is the input table
    of the following one."""
    result = res
    if "postprocess" not in op_item:
        return result

    for step in (chunk.strip() for chunk in op_item["postprocess"].split(" --> ")):
        func_name = sub(r"^([^\(\)]+)\(.+$", r"\1", step).strip()
        raw_args = sub(r"^.+\(([^\(\)]*)\).*", r"\1", step)
        if raw_args == "":
            extra_args = ()
        else:
            extra_args = tuple(next(reader(raw_args.splitlines(), skipinitialspace=True)))

        func = getattr(addon, func_name)
        result, do_type_fields = func(result, *extra_args)
        if do_type_fields:
            result = self.type_fields(result, op_item)

    return result
546 @staticmethod
547 def _apply_require(
548 header: list[str], result: list[list[tuple[object, str]]], fields: list[str]
549 ) -> list[list[tuple[object, str]]]:
550 """Exclude rows with empty values in the specified fields."""
551 for field in fields:
552 if field not in header:
553 Operation._raise_unprocessable(f"require field {field!r} is not in the result header")
554 field_idx = header.index(field)
555 result = [row for row in result if Operation.pv(field_idx, row) not in (None, "")]
556 return result
558 def _apply_filter(
559 self, header: list[str], result: list[list[tuple[object, str]]], fields: list[str]
560 ) -> list[list[tuple[object, str]]]:
561 """Filter rows by comparison operators or regex patterns."""
562 for field in fields:
563 if ":" not in field:
564 Operation._raise_unprocessable(f"filter must use field:value syntax, got {field!r}")
565 field_name, field_value = field.split(":", 1)
566 if field_name not in header:
567 Operation._raise_unprocessable(f"filter field {field_name!r} is not in the result header")
568 if field_value == "":
569 Operation._raise_unprocessable(f"filter value for field {field_name!r} must not be empty")
570 field_idx = header.index(field_name)
571 if not result:
572 continue
573 flag = field_value[0]
574 if flag in ("<", ">", "="):
575 value = field_value[1:].lower()
576 if value == "":
577 Operation._raise_unprocessable(
578 f"filter comparison value for field {field_name!r} must not be empty"
579 )
580 try:
581 typed_value = self.dt.get_func(type(Operation.tv(field_idx, result[0])).__name__)(value)
582 except ValueError:
583 Operation._raise_unprocessable(f"filter value {value!r} is invalid for field {field_name!r}")
584 result = [
585 row
586 for row in result
587 if self.operation[flag](
588 Operation.tv(field_idx, row),
589 typed_value,
590 )
591 ]
592 else:
593 pattern = field_value.lower()
594 try:
595 result = [row for row in result if search(pattern, Operation.pv(field_idx, row).lower())]
596 except regex_error as exc:
597 Operation._raise_unprocessable(f"filter regex for field {field_name!r} is invalid: {exc}")
598 return result
600 @staticmethod
601 def _apply_sort(
602 header: list[str], result: list[list[tuple[object, str]]], fields: list[str]
603 ) -> list[list[tuple[object, str]]]:
604 """Sort rows by the specified fields and directions."""
605 for field in sorted(fields, reverse=True):
606 order_match = fullmatch(r"(?P<direction>desc|asc)\((?P<field_name>[^()]+)\)", field)
607 if order_match:
608 direction = order_match.group("direction")
609 field_name = order_match.group("field_name")
610 else:
611 if "(" in field or ")" in field:
612 Operation._raise_unprocessable(f"invalid sort expression {field!r}")
613 direction, field_name = "asc", field
614 if field_name not in header:
615 Operation._raise_unprocessable(f"sort field {field_name!r} is not in the result header")
616 field_idx = header.index(field_name)
617 result = sorted(result, key=itemgetter(field_idx), reverse=(direction == "desc"))
618 return result
def handling_params(
    self, params: dict[str, list[str]], table: list[list[str] | list[tuple[object, str]]]
) -> list[list[str] | list[tuple[object, str]]]:
    """This method is used for filtering the results that are returned after the post-processing
    phase. In particular, it is possible to:

    1. [require=<field_name>] exclude all the rows that have an empty value in the field specified -
    e.g. "require=doi" removes all the rows that do not have any string specified in the "doi" field.
    The parameter name "exclude" is accepted as well and, when both are given, only the values of
    "exclude" are used;

    2. [filter=<field_name>:<operator><value>] consider only the rows where the string in the input
    field is compliant with the value specified. If no operator is specified, the value is interpreted
    as a regular expression, otherwise it is compared according to the particular type associated to
    that field. Possible operators are "=", "<", and ">" - e.g. "filter=title:semantics?" returns all
    the rows that contain the string "semantic" or "semantics" in the field title, while
    "filter=date:>2016-05" returns all the rows that have a date greater than May 2016;

    3. [sort=<order>(<field_name>)] sort all the results according to the value and type of the field
    specified in input, either in ascending ("asc") or descending ("desc") order - e.g.
    "sort=desc(date)" sorts all the rows according to the field "date" in descending order.

    These operations are applied in the order presented above - first the "require", then the
    "filter", and finally the "sort". It is possible to specify one or more operations of the same
    kind (e.g. "require=doi&require=title"). An operation is skipped when its parameter name is
    listed among the custom or disabled parameters (for "require", when either "require" or
    "exclude" is).
    """
    header = table[0]
    rows = table[1:]

    overridden = set(self.custom_params) | self.disabled_params

    wants_require = "exclude" in params or "require" in params
    if wants_require and overridden.isdisjoint(("require", "exclude")):
        source = "exclude" if "exclude" in params else "require"
        rows = self._apply_require(header, rows, params[source])  # type: ignore[arg-type]

    if "filter" in params and "filter" not in overridden:
        rows = self._apply_filter(header, rows, params["filter"])  # type: ignore[arg-type]

    if "sort" in params and "sort" not in overridden:
        rows = self._apply_sort(header, rows, params["sort"])  # type: ignore[arg-type]

    return [header, *rows]
def type_fields(
    self, res: list[list[str] | list[tuple[object, str]] | list[str | object]], op_item: dict[str, str]
) -> list[list[str] | list[tuple[object, str]]]:
    """It creates a version of the results 'res' where each value of the fields becomes a pair holding
    the same value interpreted with the type specified in the specification file (field 'field_type')
    and the plain value. 'str' is used as default in case no further specifications are provided.
    Cells that are already pairs are typed again starting from their plain value."""
    header = res[0]

    # Every column is cast with DataType.str unless 'field_type' says otherwise.
    casters = dict.fromkeys(header, DataType.str)
    if "field_type" in op_item:
        for type_name, field_name in findall(FIELD_TYPE_RE, op_item["field_type"]):
            casters[field_name] = self.dt.get_func(type_name)

    typed_rows = []
    for row in res[1:]:
        typed_row = []
        for position, heading in enumerate(header):
            plain = row[position]
            if isinstance(plain, tuple):
                plain = plain[1]
            typed_row.append((casters[heading](plain), plain))
        typed_rows.append(typed_row)

    return [header, *typed_rows]  # type: ignore[return-value]
def remove_types(self, res: list[list[str] | list[tuple[object, str]]]) -> ResultTable:
    """This method takes the results 'res' that include also the typed value and returns a version of
    such results that keeps only the plain value of each cell. The header row is kept as it is."""
    header = cast("list[str]", res[0])
    plain_rows = [[cell[1] for cell in row] for row in res[1:]]  # type: ignore[index]
    return [header, *plain_rows]
696 @staticmethod
697 def _is_directive(line: str) -> bool:
698 return line.strip().startswith("@@")
700 @staticmethod
701 def _parse_directive_args(
702 tokens: list[str], param_names: list[str], defaults: dict[str, str] | None = None
703 ) -> dict[str, str]:
704 defaults = defaults or {}
705 all_names = set(param_names) | set(defaults)
706 result = {}
707 positional_index = 0
708 seen_keyword = False
710 for token in tokens:
711 if "=" in token:
712 key, value = token.split("=", 1)
713 if key in all_names:
714 if key in result:
715 msg = f"Duplicate parameter {key!r}"
716 raise ValueError(msg)
717 seen_keyword = True
718 result[key] = value
719 continue
720 if seen_keyword:
721 msg = f"Positional argument {token!r} cannot follow keyword argument"
722 raise ValueError(msg)
723 if positional_index >= len(param_names):
724 msg = f"Unexpected argument {token!r}"
725 raise ValueError(msg)
726 result[param_names[positional_index]] = token
727 positional_index += 1
729 for name, default in defaults.items():
730 if name not in result:
731 result[name] = default
733 missing = [name for name in param_names if name not in result]
734 if missing:
735 msg = f"Missing required parameter(s): {', '.join(missing)}"
736 raise ValueError(msg)
738 return result
def _handle_directive_with(self, parts: list[str]) -> tuple[str | None, str, None]:
    """Handle '@@with [source] [endpoint=<url>] [engine=<name>]'.

    It returns the triple (endpoint or None, engine, None). The endpoint is the explicit
    'endpoint' value when given, otherwise the entry of 'self.sources_map' for 'source', otherwise
    None. ValueError is raised when: the engine is neither 'sparql' nor 'sparql-anything'; both a
    source and an endpoint are given; 'endpoint=' is given with an empty value; the engine is
    'sparql' and neither a source nor an endpoint is given; the source is not in
    'self.sources_map'."""
    tokens = parts[1:]
    args = Operation._parse_directive_args(
        tokens,
        ["source"],
        defaults={"source": "", "endpoint": "", "engine": "sparql"},
    )
    source = args["source"].strip()
    endpoint = args["endpoint"].strip()
    engine = args["engine"].strip().lower()

    if engine not in {"sparql", "sparql-anything"}:
        msg = f"Unknown engine '{args['engine']}' in @@with"
        raise ValueError(msg)
    if source and endpoint:
        msg = "@@with cannot combine source and endpoint"
        raise ValueError(msg)
    # An empty default and an explicit but empty 'endpoint=' look the same in 'args',
    # so the raw tokens are inspected to tell them apart.
    if not endpoint and any(token.startswith("endpoint=") for token in tokens):
        msg = "@@with endpoint cannot be empty"
        raise ValueError(msg)
    if engine == "sparql" and not (source or endpoint):
        msg = "@@with source or endpoint is required when engine=sparql"
        raise ValueError(msg)
    if source and source not in self.sources_map:
        msg = f"Unknown source '{source}' in @@with; declare it in #sources."
        raise ValueError(msg)

    if endpoint:
        target = endpoint
    elif source:
        target = self.sources_map[source]
    else:
        target = None
    return target, engine, None
@staticmethod
def _handle_directive_join(parts: list[str]) -> tuple[None, None, tuple[str, str, str, str]]:
    """Handle '@@join <left_var> <right_var> [type=<how>]' ('type' defaults to 'inner').

    It returns (None, None, ("JOIN", left_var, right_var, lowercased type))."""
    args = Operation._parse_directive_args(parts[1:], ["left_var", "right_var"], defaults={"type": "inner"})
    step = ("JOIN", args["left_var"], args["right_var"], args["type"].lower())
    return None, None, step
776 @staticmethod
777 def _handle_directive_values(parts: list[str]) -> tuple[None, None, tuple[str, list[str]]]:
778 tokens = parts[1:]
779 if not tokens:
780 msg = "@@values needs at least one variable"
781 raise ValueError(msg)
782 return None, None, ("VALUES_INJECT", tokens)
@staticmethod
def _handle_directive_foreach(parts: list[str]) -> tuple[None, None, tuple[str, str, str, float]]:
    """Handle '@@foreach <?variable> <placeholder> [wait=<seconds>]' ('wait' defaults to '0').

    It returns (None, None, ("FOREACH", variable, placeholder, wait as float)). ValueError is
    raised when the variable does not start with '?' or 'wait' is not a number."""
    args = Operation._parse_directive_args(parts[1:], ["variable", "placeholder"], defaults={"wait": "0"})

    variable = args["variable"]
    if not variable.startswith("?"):
        msg = f"@@foreach variable must start with '?', got {variable!r}"
        raise ValueError(msg)

    raw_wait = args["wait"]
    try:
        delay = float(raw_wait)
    except ValueError:
        msg = f"Invalid wait value in @@foreach: {raw_wait!r}"
        raise ValueError(msg) from None

    return None, None, ("FOREACH", variable, args["placeholder"], delay)
@staticmethod
def _handle_directive_page(parts: list[str]) -> tuple[None, None, tuple[str, str, str, str]]:
    """Parse an ``@@page`` directive into a ``("PAGE", var, default_size, max_size)`` step.

    Both sizes default to the empty string and are passed through unparsed.
    Raises ``ValueError`` when the variable does not start with ``?``.
    """
    args = Operation._parse_directive_args(
        parts[1:],
        ["variable"],
        defaults={"default_size": "", "max_size": ""},
    )
    variable = args["variable"]
    if variable.startswith("?"):
        return None, None, ("PAGE", variable, args["default_size"], args["max_size"])
    msg = f"@@page variable must start with '?', got {variable!r}"
    raise ValueError(msg)
807 def _process_directive(
808 self, line: str, directive_handlers: dict[str, Callable[..., object]]
809 ) -> tuple[str | None, str | None, tuple[str, ...] | None]:
810 body = line.strip()[2:].strip()
811 parts = body.split()
812 cmd = parts[0].lower()
814 handler = directive_handlers.get(cmd)
815 if handler is None:
816 msg = f"Unknown directive @@{cmd}"
817 raise ValueError(msg)
819 return handler(parts) # type: ignore[return-value]
821 @staticmethod
822 def _update_directive_state(
823 current_endpoint: str,
824 current_engine: str,
825 new_endpoint: str | None,
826 new_engine: str | None,
827 ) -> tuple[str, str]:
828 if new_endpoint is not None:
829 current_endpoint = new_endpoint
830 if new_engine is not None:
831 current_engine = new_engine
832 return current_endpoint, current_engine
834 def _parse_steps(self, text: str, default_endpoint: str, params: dict[str, object]) -> list[tuple[str, ...]]:
835 """
836 Returns a list of steps:
837 - ("QUERY", endpoint_url, engine, query_text)
838 - ("JOIN", left_var, right_var, how) # how in {"inner","left"}
839 - ("REMOVE", [vars])
840 - ("VALUES_INJECT", [vars]) # @@values ?var1 ?var2 ...
841 - ("FOREACH", var_name, placeholder, delay) # @@foreach ?var placeholder [wait=N]
842 - ("PAGE", var_name, default_size, max_size) # @@page ?var [default_size=N] [max_size=M]
843 """
844 for p, v in params.items():
845 text = text.replace(f"[[{p}]]", str(v))
846 steps: list[tuple[str, ...]] = []
847 cur_query: list[str] = []
848 current_endpoint = default_endpoint
849 current_engine = "sparql"
851 directive_handlers = {
852 "with": self._handle_directive_with,
853 "join": self._handle_directive_join,
854 "remove": lambda parts: (None, None, ("REMOVE", parts[1:])),
855 "values": self._handle_directive_values,
856 "foreach": self._handle_directive_foreach,
857 "page": self._handle_directive_page,
858 }
860 def flush_query() -> None:
861 if cur_query:
862 q = "\n".join(cur_query).strip()
863 if not q:
864 cur_query.clear()
865 return
866 for p, v in params.items():
867 q = q.replace(f"[[{p}]]", str(v))
868 steps.append(("QUERY", current_endpoint, current_engine, q))
869 cur_query.clear()
871 for raw in text.splitlines():
872 line = raw.rstrip("\n")
873 if not self._is_directive(line):
874 cur_query.append(line)
875 continue
877 flush_query()
879 new_endpoint, new_engine, step = self._process_directive(line, directive_handlers)
880 current_endpoint, current_engine = Operation._update_directive_state(
881 current_endpoint,
882 current_engine,
883 new_endpoint,
884 new_engine,
885 )
886 if step is not None:
887 steps.append(step)
889 flush_query()
890 return steps
def _send_sparql_csv_request(self, endpoint_url: str, query_text: str) -> Response:
    """Send one SPARQL query to ``endpoint_url`` and return the raw HTTP response.

    CSV results are requested via the ``Accept`` header. When
    ``self.sparql_http_method`` is ``"get"`` the query travels in the URL;
    otherwise it is POSTed as an ``application/sparql-query`` body. No retry
    or status handling happens here.
    """
    headers = {
        "Accept": "text/csv",
        "User-Agent": "RAMOSE/2.0.0",
        **backend_auth_header(endpoint_url),
    }
    if self.sparql_http_method == "get":
        # An endpoint URL may already carry a query string (for example a
        # default graph); in that case append with '&' rather than adding a
        # second '?', which would corrupt the last existing parameter.
        separator = "&" if "?" in endpoint_url else "?"
        return _http_session.get(
            endpoint_url + separator + "query=" + quote(query_text),
            headers=headers,
            timeout=DEFAULT_HTTP_TIMEOUT,
        )
    return _http_session.post(
        endpoint_url,
        data=query_text,
        headers={
            **headers,
            "Content-Type": "application/sparql-query",
        },
        timeout=DEFAULT_HTTP_TIMEOUT,
    )
def _request_sparql_csv(self, endpoint_url: str, query_text: str) -> Response:
    """Send a SPARQL CSV request, retrying up to ``self.retry_attempts`` times.

    A retry happens after a timeout, after any other ``RequestException``,
    or when the response status is in ``_RETRYABLE_STATUS_CODES``. The wait
    between attempts starts at ``self.retry_wait`` and is multiplied by
    ``self.retry_backoff`` after each sleep. On the last attempt the response
    is returned whatever its status; a transport failure is raised as
    ``HttpError`` (408 for timeouts, 502 otherwise).
    """
    wait = self.retry_wait
    for attempt in range(self.retry_attempts):
        is_last = attempt + 1 == self.retry_attempts
        try:
            response = self._send_sparql_csv_request(endpoint_url, query_text)
        except (RequestException, TimeoutError) as exc:
            if is_last:
                if isinstance(exc, (RequestsTimeout, TimeoutError)):
                    msg = f"HTTP status code 408: SPARQL request timeout: {exc}"
                    raise HttpError(HTTPStatus.REQUEST_TIMEOUT, msg) from exc
                msg = f"HTTP status code 502: SPARQL request failed: {exc}"
                raise HttpError(HTTPStatus.BAD_GATEWAY, msg) from exc
        else:
            response.encoding = "utf-8"
            if is_last or response.status_code not in _RETRYABLE_STATUS_CODES:
                return response

        self._sleep_before_retry(wait)
        wait *= self.retry_backoff

    # Only reachable when retry_attempts is zero (or negative).
    msg = "SPARQL request did not run"
    raise RuntimeError(msg)
943 @staticmethod
944 def _sleep_before_retry(retry_wait: float) -> None:
945 if retry_wait:
946 time.sleep(retry_wait)
948 def _run_sparql_dicts(self, endpoint_url: str, query_text: str) -> list[dict[str, object]]:
949 r = self._request_sparql_csv(endpoint_url, query_text)
950 r.encoding = "utf-8"
951 if r.status_code != HTTPStatus.OK:
952 msg = f"SPARQL {r.status_code}: {r.reason}"
953 raise RuntimeError(msg)
954 text = r.content.decode("utf-8-sig", errors="replace")
955 list_of_lines = text.splitlines()
956 return list(DictReader(list_of_lines)) # type: ignore[return-value]
958 @staticmethod
959 def _normalize_sparql_json_resultset(result: dict[str, object]) -> list[dict[str, object]]:
960 """Convert a SPARQL JSON ResultSet dict to a list of flat dicts."""
961 vars_ = result["head"].get("vars") or [] # type: ignore[union-attr]
962 return [
963 {v: (b[v].get("value") if isinstance(b.get(v), dict) else b.get(v)) for v in vars_}
964 for b in result["results"].get("bindings", []) # type: ignore[union-attr]
965 ]
967 @staticmethod
968 def _normalize_columnar_dict(result: dict[str, object]) -> list[dict[str, object]]:
969 """Convert a column-oriented dict {col: [values]} to a list of row dicts."""
970 cols = list(result.keys())
971 max_len = max((len(v) for v in result.values() if isinstance(v, (list, tuple))), default=0)
973 if not max_len:
974 return [result]
976 rows = []
977 for i in range(max_len):
978 row = {}
979 for c in cols:
980 v = result[c]
981 row[c] = (
982 v[i]
983 if isinstance(v, (list, tuple)) and i < len(v)
984 else (v if not isinstance(v, (list, tuple)) else None)
985 )
986 rows.append(row)
987 return rows
@staticmethod
def _sparql_anything_error_status(exc: Exception) -> int | None:
    """Derive an HTTP status code from the text of a SPARQL Anything exception.

    A status matched by ``_SPARQL_ANYTHING_HTTP_STATUS_RE`` wins. Otherwise
    the lower-cased message is checked against the timeout markers (408) and
    then the network markers (502). Returns None if nothing matches.
    """
    text = str(exc)
    found = search(_SPARQL_ANYTHING_HTTP_STATUS_RE, text)
    if found is not None:
        return int(found.group(1))

    lowered = text.lower()
    marker_groups = (
        (_SPARQL_ANYTHING_TIMEOUT_MARKERS, HTTPStatus.REQUEST_TIMEOUT),
        (_SPARQL_ANYTHING_NETWORK_MARKERS, HTTPStatus.BAD_GATEWAY),
    )
    for markers, status in marker_groups:
        if any(marker in lowered for marker in markers):
            return int(status)
    return None
@staticmethod
def _raise_sparql_anything_error(status_code: int, exc: Exception) -> NoReturn:
    """Always raise ``HttpError`` with ``status_code``, chained to the original exception."""
    raise HttpError(
        status_code,
        f"HTTP status code {status_code}: SPARQL Anything request failed: {exc}",
    ) from exc
def _request_sparql_anything_select(self, kwargs: dict[str, object]) -> object:
    """Run ``select`` on the SPARQL Anything engine, retrying retryable failures.

    An exception with no recognisable status is re-raised untouched. One with
    a status is turned into ``HttpError`` when the status is not retryable or
    the attempts are used up; otherwise the call is repeated after a backoff
    sleep. ``RuntimeError`` is raised if the engine has not been created.
    """
    if self._sa_engine is None:
        msg = "SPARQL Anything engine not initialized"
        raise RuntimeError(msg)
    engine = cast("SparqlAnythingEngine", self._sa_engine)

    wait = self.retry_wait
    for attempt in range(self.retry_attempts):
        try:
            return engine.select(output_type=dict, **kwargs)
        except Exception as exc:  # noqa: PERF203
            status = self._sparql_anything_error_status(exc)
            if status is None:
                raise
            retryable = status in _RETRYABLE_STATUS_CODES
            if not retryable or attempt + 1 == self.retry_attempts:
                self._raise_sparql_anything_error(status, exc)
            self._sleep_before_retry(wait)
            wait *= self.retry_backoff

    # Only reachable when retry_attempts is zero (or negative).
    msg = "SPARQL Anything request did not run"
    raise RuntimeError(msg)
1030 def _run_sparql_anything_dicts(
1031 self, query_text: str, values: dict[str, str] | None = None
1032 ) -> list[dict[str, object]]:
1033 """
1034 Execute a SPARQL Anything SELECT query via PySPARQL-Anything and return
1035 a list of dicts (one per row), in the same shape as _run_sparql_dicts.
1037 query_text: full SPARQL (Anything) query string
1038 (typically containing SERVICE <x-sparql-anything:...>).
1039 values: optional dict of template parameters for the query
1040 (name -> value), passed to SPARQL Anything's `values=`.
1041 """
1042 if self._sa_engine is None:
1043 if SparqlAnything is None:
1044 msg = "pysparql_anything not installed. Install with: pip install ramose[sparql-anything]"
1045 raise ImportError(msg)
1046 self._sa_engine = SparqlAnything()
1048 kwargs: dict[str, object] = {"query": query_text}
1049 if values:
1050 kwargs["values"] = {str(k): str(v) for k, v in values.items()} # type: ignore[assignment]
1052 result = self._request_sparql_anything_select(kwargs)
1054 # Normalize to list[dict]
1055 if isinstance(result, list):
1056 if result and isinstance(result[0], dict):
1057 return result
1058 return [dict(row) for row in result]
1060 if not isinstance(result, dict):
1061 return [{"result": result}]
1063 # Standard SPARQL JSON ResultSet shape
1064 head = result.get("head")
1065 results_obj = result.get("results")
1066 if isinstance(head, dict) and isinstance(results_obj, dict) and "bindings" in results_obj:
1067 return self._normalize_sparql_json_resultset(result)
1069 # Column-oriented dict or single-row fallback
1070 return self._normalize_columnar_dict(result)
1072 def _run_query_dicts(self, endpoint_url: str, engine: str, query_text: str) -> list[dict[str, object]]:
1073 if engine == "sparql-anything":
1074 return self._run_sparql_anything_dicts(query_text)
1075 if engine != "sparql":
1076 msg = f"Unknown query engine {engine!r}"
1077 raise ValueError(msg)
1078 return self._run_sparql_dicts(endpoint_url, query_text)
1080 def _inject_values_clause(self, query_text: str, vars_: list[str], acc_rows: list[dict[str, object]] | None) -> str:
1081 # None means no prior step ran: leave the query unrestricted.
1082 # An empty list means a prior step matched nothing: keep going so the empty
1083 # accumulator injects an empty VALUES block, which correctly yields zero solutions.
1084 if acc_rows is None:
1085 return query_text
1087 # build distinct tuples for requested vars from the accumulator
1088 cols = [v.lstrip("?") for v in vars_]
1089 tuples, seen = [], set()
1090 for row in acc_rows:
1091 tup = tuple(row.get(c, "") for c in cols)
1092 if all(tup) and tup not in seen:
1093 seen.add(tup)
1094 tuples.append(tup)
1096 # format literals vs IRIs
1097 def fmt(x: object) -> str:
1098 s = str(x)
1099 if s.startswith(("http://", "https://")):
1100 return f"<{s}>"
1101 return '"' + s.replace("\\", "\\\\").replace('"', '\\"') + '"'
1103 head = "VALUES (" + " ".join(vars_) + ") {\n"
1104 body = "\n".join(" (" + " ".join(fmt(v) for v in tup) + ")" for tup in tuples)
1105 tail = "\n}\n"
1107 i = query_text.find("{")
1108 if i == -1:
1109 # no WHERE brace: put VALUES at top (legal SPARQL)
1110 return head + body + tail + query_text
1111 j = i + 1
1112 return query_text[:j] + "\n" + head + body + tail + query_text[j:]
1114 @staticmethod
1115 def _drop_columns(rows: list[dict[str, object]], vars_: list[str]) -> list[dict[str, object]]:
1116 if not rows:
1117 return rows
1118 vars_set = {v.lstrip("?") for v in vars_}
1119 return [{k: v for k, v in r.items() if k not in vars_set and ("?" + k) not in vars_set} for r in rows]
1121 def _norm_join_key(self, v: object) -> str | None:
1122 if v is None:
1123 return None
1124 s = str(v).strip()
1125 # unify scheme for w3id IRIs (and similar)
1126 if s.startswith("http://"):
1127 s = "https://" + s[len("http://") :]
1128 # drop a single trailing slash for stability
1129 return s.removesuffix("/")
1131 @staticmethod
1132 def _merge_row(
1133 left_row: dict[str, object], right_row: dict[str, object], right_cols: list[str]
1134 ) -> dict[str, object]:
1135 merged = dict(left_row)
1136 for col in right_cols:
1137 right_val = right_row.get(col)
1138 if right_val is None:
1139 continue
1140 if col not in merged or merged[col] in ("", None):
1141 merged[col] = right_val
1142 else:
1143 alt = f"{col}_r"
1144 if alt not in merged or merged[alt] in ("", None):
1145 merged[alt] = right_val
1146 return merged
1148 def _join(
1149 self,
1150 left_rows: list[dict[str, object]] | None,
1151 right_rows: list[dict[str, object]] | None,
1152 lkey: str,
1153 rkey: str,
1154 how: str = "inner",
1155 ) -> list[dict[str, object]]:
1156 """
1157 Merge two row sets on lkey (from left_rows) and rkey (from right_rows).
1158 - lkey/rkey may be passed as '?var' or 'var' -> we normalize to bare names.
1159 - Keys are normalized with _norm_join_key (e.g., http -> https, trim slash).
1160 - When 'left', all left rows are preserved even if no match on the right.
1161 - Right-hand columns are copied into the merged row; collisions are avoided.
1162 """
1163 lcol = lkey.lstrip("?")
1164 rcol = rkey.lstrip("?")
1166 left_rows = left_rows or []
1167 right_rows = right_rows or []
1169 rindex: dict[str, list[dict[str, object]]] = {}
1170 for r in right_rows:
1171 rk = self._norm_join_key(r.get(rcol))
1172 if rk is None:
1173 continue
1174 rindex.setdefault(rk, []).append(r)
1176 right_cols = [c for c in (right_rows[0].keys() if right_rows else []) if c != rcol]
1178 out: list[dict[str, object]] = []
1179 for left_row in left_rows:
1180 lk = self._norm_join_key(left_row.get(lcol))
1181 matches = rindex.get(lk, []) # type: ignore[arg-type]
1182 if matches:
1183 out.extend(self._merge_row(left_row, r, right_cols) for r in matches)
1184 elif how == "left":
1185 out.append(dict(left_row))
1186 return out
1188 def _apply_custom_postprocess_params(self, table: ResultTable, q_string: dict[str, list[str]]) -> ResultTable:
1189 for param_name, param_conf in self.custom_params.items():
1190 if param_conf["phase"] != "postprocess":
1191 continue
1192 if param_name in q_string:
1193 handler = getattr(self.addon, param_conf["handler"])
1194 table = handler(table, q_string[param_name])
1195 return table
1197 @property
1198 def _cache_ttl(self) -> int:
1199 if "cache_duration" in self.i:
1200 return int(self.i["cache_duration"])
1201 return self._default_cache_ttl
1203 def _build_cache_key(self, q_string: dict[str, list[str]]) -> str:
1204 presentation_params = {"format", "json"}
1205 if "@@page" not in self.i["sparql"]:
1206 presentation_params |= {"page", "page_size"}
1207 data_params = sorted((name, values) for name, values in q_string.items() if name not in presentation_params)
1208 if data_params:
1209 query_string = "&".join(f"{name}={value}" for name, values in data_params for value in values)
1210 return f"{self.tp}:{self.op_url}?{query_string}"
1211 return f"{self.tp}:{self.op_url}"
1213 def _extract_pagination_params(self, q_string: dict[str, list[str]]) -> tuple[int, int] | None:
1214 page_size_active = self._is_builtin_param_active("page_size")
1215 page_active = self._is_builtin_param_active("page")
1216 if "page_size" not in q_string or not page_size_active:
1217 if page_size_active and page_active and "page" in q_string:
1218 Operation._raise_unprocessable("page requires page_size")
1219 return None
1220 page_size = Operation._parse_positive_int_param(q_string, "page_size")
1221 page = 1
1222 if "page" in q_string and page_active:
1223 page = Operation._parse_positive_int_param(q_string, "page")
1224 return page, page_size
1226 def _has_custom_converter(self, q_string: dict[str, list[str]]) -> bool:
1227 if "format" in q_string and self._is_builtin_param_active("format"):
1228 for req_format in q_string["format"]:
1229 if req_format in self.format:
1230 return True
1231 elif "default_format" in self.i and self.i["default_format"].strip() in self.format:
1232 return True
1233 return False
1235 def _paginate_and_format(
1236 self,
1237 table: ResultTable,
1238 q_string: dict[str, list[str]],
1239 content_type: str,
1240 ) -> tuple[int, str, str]:
1241 has_page_directive = "@@page" in self.i["sparql"]
1242 if self._has_custom_converter(q_string) and not has_page_directive:
1243 self.pagination_info = None
1244 # A @@page step already paginated upstream; do not paginate again here.
1245 elif not has_page_directive:
1246 page_params = self._extract_pagination_params(q_string)
1247 if page_params is not None:
1248 page, page_size = page_params
1249 total_items = len(table) - 1
1250 total_pages = ceil(total_items / page_size)
1251 Operation._validate_page_range(page, total_items, total_pages)
1252 start = (page - 1) * page_size
1253 end = start + page_size
1254 table = [table[0], *table[1 + start : 1 + end]]
1255 self.pagination_info = build_pagination_info(self.op_url, q_string, page, page_size, total_items)
1256 else:
1257 self.pagination_info = None
1259 s_res = StringIO()
1260 writer(s_res).writerows(table)
1261 body, ctype = self.conv(s_res.getvalue(), q_string, content_type)
1263 return 200, body, ctype
1265 def _cache_value(self, rows: ResultTable) -> CachedResult:
1266 pagination: CachedPagination | None = None
1267 if "@@page" in self.i["sparql"] and self.pagination_info is not None:
1268 pagination = {
1269 "page": self.pagination_info.page,
1270 "page_size": self.pagination_info.page_size,
1271 "total_items": self.pagination_info.total_items,
1272 }
1273 return {"rows": rows, "pagination": pagination}
1275 def _format_cached_result(
1276 self,
1277 cached_value: object,
1278 q_string: dict[str, list[str]],
1279 content_type: str,
1280 ) -> tuple[int, str, str]:
1281 entry = cast("CachedResult", cached_value)
1282 if entry["pagination"] is not None:
1283 pagination = entry["pagination"]
1284 self.pagination_info = build_pagination_info(
1285 self.op_url,
1286 q_string,
1287 pagination["page"],
1288 pagination["page_size"],
1289 pagination["total_items"],
1290 )
1291 return self._paginate_and_format(entry["rows"], q_string, content_type)
def _finalize_result(
    self, csv_rows: list[list[str]] | list[list[str | object]], content_type: str
) -> tuple[int, str, str]:
    """Run the shared tail of every read: type the fields, postprocess, apply the
    request parameters, strip the types, apply custom postprocess parameters,
    store the table in the cache, then paginate and format it."""
    q_string = parse_qs(quote(self.url_parsed.query, safe="&="))

    table = self.type_fields(csv_rows, self.i)  # type: ignore[arg-type]
    if self.addon is not None:
        table = self.postprocess(table, self.i, self.addon)
    table = self.handling_params(q_string, table)
    table = self.remove_types(table)
    if self.custom_params:
        table = self._apply_custom_postprocess_params(table, q_string)

    cache_enabled = self._cache is not None and "cache_disable" not in self.i
    if cache_enabled:
        key = self._build_cache_key(q_string)
        self._cache.set(key, self._cache_value(table), expire=self._cache_ttl)

    return self._paginate_and_format(table, q_string, content_type)
1309 @staticmethod
1310 def _header_from_field_type(op_item: dict[str, str], acc: list[dict[str, object]]) -> list[str]:
1311 # Respect #field_type order if provided, else derive from data
1312 if "field_type" in op_item:
1313 # FIELD_TYPE_RE is global in this file
1314 return [f for (_, f) in findall(FIELD_TYPE_RE, op_item["field_type"])]
1315 # fallback to keys of first row
1316 return list(acc[0].keys()) if acc else []
1318 @staticmethod
1319 def _to_csv_rows(header: list[str], acc: list[dict[str, object]]) -> list[list[object]]:
1320 rows: list[list[object]] = [header] # type: ignore[list-item]
1321 rows.extend([d.get(h, "") for h in header] for d in acc)
1322 return rows
1324 def _extract_params(self, body_params: Mapping[str, object] | None = None) -> dict[str, object]:
1325 """Extract URL parameters (and request body parameters, for write operations) and apply type
1326 conversions based on the operation spec."""
1327 par_dict: dict[str, object] = {}
1328 url_match = match(self.op, self.op_url)
1329 if url_match is None:
1330 msg = f"URL {self.op_url} does not match pattern {self.op}"
1331 raise ValueError(msg)
1332 par_man = url_match.groups()
1333 for idx, par in enumerate(findall("{([^{}]+)}", self.i["url"])):
1334 try:
1335 par_type = self.i[par].split("(")[0]
1336 if par_type in ("str", "iri", "literal"):
1337 par_value = par_man[idx]
1338 else:
1339 par_value = self.dt.get_func(par_type)(par_man[idx])
1340 except KeyError:
1341 par_value = par_man[idx]
1342 par_dict[par] = par_value
1343 if body_params:
1344 par_dict.update(body_params)
1345 return par_dict
1347 def _resolve_preprocess_handler(self, param_name: str, handler: str) -> Callable[[list[str]], dict[str, str]]:
1348 if param_name in self.custom_param_configs:
1349 config = self.custom_param_configs[param_name]
1350 return lambda values: apply_filters(config, values)
1351 return getattr(self.addon, handler)
1353 def _apply_custom_preprocess_params(self, par_dict: dict[str, object]) -> None:
1354 q_string = parse_qs(quote(self.url_parsed.query, safe="&="))
1355 for param_name, param_conf in self.custom_params.items():
1356 if param_conf["phase"] != "preprocess":
1357 continue
1358 if param_name in q_string:
1359 handler = self._resolve_preprocess_handler(param_name, param_conf["handler"])
1360 par_dict.update(handler(q_string[param_name]))
1361 elif param_name not in par_dict:
1362 par_dict[param_name] = ""
1363 for placeholder in findall(r"\[\[(\w+)\]\]", self.i["sparql"]):
1364 if placeholder not in par_dict:
1365 par_dict[placeholder] = ""
1367 def _exec_standard_sparql(self, par_dict: dict[str, object], content_type: str) -> tuple[int, str, str]:
1368 """Execute standard SPARQL queries, handling parameter combinations via cartesian product."""
1369 # Wrap scalar values in lists for cartesian product
1370 par_dict = {k: v if isinstance(v, list) else [v] for k, v in par_dict.items()}
1372 parameters_comb = [
1373 dict(zip(par_dict.keys(), combination, strict=False))
1374 for combination in product(*par_dict.values()) # type: ignore[arg-type]
1375 ]
1377 # Example: {"id":"5","area":["A1","A2"]} -> [{"id":"5","area":"A1"}, {"id":"5","area":"A2"}]
1379 list_of_res = []
1380 include_header_line = True
1381 for comb in parameters_comb:
1382 query = self.i["sparql"]
1383 for param, val in comb.items():
1384 query = query.replace(f"[[{param}]]", str(val))
1386 r = self._request_sparql_csv(self.tp, query)
1388 if r.status_code != HTTPStatus.OK:
1389 return r.status_code, f"HTTP status code {r.status_code}: {r.reason}", "text/plain"
1391 # Re-encode to handle non-UTF8 characters in splitlines
1392 list_of_lines = [line.decode("utf-8") for line in r.text.encode("utf-8").splitlines()]
1394 # Include the CSV header only from the first response
1395 if not include_header_line:
1396 list_of_lines = list_of_lines[1:]
1397 include_header_line = False
1399 list_of_res += list_of_lines
1401 return self._finalize_result(list(reader(list_of_res)), content_type)
1403 def _exec_foreach_query(
1404 self,
1405 endpoint_url: str,
1406 engine: str,
1407 qtxt: str,
1408 foreach: tuple[str, str, float],
1409 acc: list[dict[str, object]] | None,
1410 ) -> list[dict[str, object]]:
1411 """Run one query per distinct value collected from the accumulator (@@foreach)."""
1412 var_name, placeholder, delay = foreach
1413 column = var_name.lstrip("?")
1415 values = []
1416 seen = set()
1417 for row in acc or []:
1418 v = row.get(column)
1419 if v and v not in seen:
1420 seen.add(v)
1421 values.append(v)
1423 all_rows = []
1424 for idx_val, val in enumerate(values):
1425 q_one = qtxt.replace(f"[[{placeholder}]]", str(val))
1426 sub_rows = self._run_query_dicts(endpoint_url, engine, q_one)
1427 if sub_rows:
1428 all_rows.extend(sub_rows)
1429 if delay and idx_val + 1 < len(values):
1430 time.sleep(delay)
1432 return all_rows
1434 def _exec_multi_source_query_step(
1435 self, endpoint_url: str, engine: str, qtxt: str, state: dict[str, object]
1436 ) -> None:
1437 """Handle a QUERY step in the multi-source pipeline."""
1438 if state["pending_foreach"] is not None:
1439 rows = self._exec_foreach_query(endpoint_url, engine, qtxt, state["pending_foreach"], state["acc"]) # type: ignore[arg-type]
1440 state["pending_foreach"] = None
1441 state["pending_values_vars"] = None
1442 else:
1443 if state["pending_values_vars"]:
1444 qtxt = self._inject_values_clause(qtxt, state["pending_values_vars"], state["acc"]) # type: ignore[arg-type]
1445 state["pending_values_vars"] = None
1446 rows = self._run_query_dicts(endpoint_url, engine, qtxt)
1448 if state["acc"] is None:
1449 state["acc"] = rows
1450 elif state["pending_join"]:
1451 lvar, rvar, how = state["pending_join"] # type: ignore[misc]
1452 state["acc"] = self._join(state["acc"], rows, lvar, rvar, how) # type: ignore[arg-type]
1453 state["pending_join"] = None
1454 else:
1455 msg = "Multiple QUERY steps without an explicit @@join directive"
1456 raise ValueError(msg)
1458 def _exec_page_step(self, var: str, default_size: str, max_size: str, state: dict[str, object]) -> None:
1459 q_string = parse_qs(quote(self.url_parsed.query, safe="&="))
1461 page_size_active = self._is_builtin_param_active("page_size")
1462 page_active = self._is_builtin_param_active("page")
1463 explicit_page_size = page_size_active and "page_size" in q_string
1465 if explicit_page_size:
1466 page_size = Operation._parse_positive_int_param(q_string, "page_size")
1467 elif default_size:
1468 page_size = int(default_size)
1469 else:
1470 if page_size_active and page_active and "page" in q_string:
1471 Operation._raise_unprocessable("page requires page_size")
1472 return
1473 if page_size < 1:
1474 msg = f"page_size must be >= 1, got {page_size}"
1475 raise ValueError(msg)
1476 if max_size:
1477 max_page_size = int(max_size)
1478 if explicit_page_size and page_size > max_page_size:
1479 Operation._raise_unprocessable(f"page_size must be <= {max_page_size}, got {page_size}")
1480 page_size = min(page_size, max_page_size)
1482 page = 1
1483 if "page" in q_string and page_active:
1484 page = Operation._parse_positive_int_param(q_string, "page")
1486 column = var.lstrip("?")
1487 acc = state["acc"]
1488 rows = [] if acc is None else cast("list[dict[str, object]]", acc)
1489 distinct = list(dict.fromkeys(row[column] for row in rows if row.get(column)))
1490 total_items = len(distinct)
1491 total_pages = ceil(total_items / page_size)
1492 Operation._validate_page_range(page, total_items, total_pages)
1494 start = (page - 1) * page_size
1495 keep = set(distinct[start : start + page_size])
1496 state["acc"] = [row for row in rows if row.get(column) in keep]
1497 self.pagination_info = build_pagination_info(self.op_url, q_string, page, page_size, total_items)
1499 def _exec_multi_source(self, par_dict: dict[str, object], content_type: str) -> tuple[int, str, str]:
1500 """Execute a multi-source query pipeline with @@ directives."""
1501 steps = self._parse_steps(self.i["sparql"], self.tp, par_dict)
1503 state: dict[str, object] = {
1504 "acc": None,
1505 "pending_join": None,
1506 "pending_values_vars": None,
1507 "pending_foreach": None,
1508 }
1510 for st in steps:
1511 tag = st[0]
1513 if tag == "QUERY":
1514 self._exec_multi_source_query_step(st[1], st[2], st[3], state)
1515 elif tag == "JOIN":
1516 state["pending_join"] = (st[1], st[2], st[3])
1517 elif tag == "REMOVE":
1518 state["acc"] = self._drop_columns(state["acc"] or [], st[1]) # type: ignore[arg-type]
1519 elif tag == "VALUES_INJECT":
1520 state["pending_values_vars"] = st[1]
1521 elif tag == "FOREACH":
1522 state["pending_foreach"] = (st[1], st[2], st[3])
1523 elif tag == "PAGE":
1524 self._exec_page_step(st[1], st[2], st[3], state)
1525 else:
1526 msg = f"Unknown step tag {tag}"
1527 raise RuntimeError(msg)
1529 header = self._header_from_field_type(self.i, state["acc"] or []) # type: ignore[arg-type]
1530 csv_rows = self._to_csv_rows(header, state["acc"] or []) # type: ignore[arg-type]
1531 return self._finalize_result(csv_rows, content_type)
1533 @staticmethod
1534 def _format_error(sc: int, e: Exception, prefix: str = "") -> tuple[int, str, str]:
1535 """Format an error response tuple with traceback line info."""
1536 tb = e.__traceback__
1537 line = tb.tb_lineno if tb else "?"
1538 msg = f"HTTP status code {sc}: {prefix}{type(e).__name__}: {e} (line {line})"
1539 return sc, msg, "text/plain"
1541 def exec(
1542 self,
1543 method: str = "get",
1544 content_type: str = "application/json",
1545 body_params: Mapping[str, object] | None = None,
1546 ) -> tuple[int, str, str, dict[str, str]]:
1547 """This method takes in input the HTTP method to use for the call
1548 and the content type to return, and execute the operation as indicated
1549 in the specification file, by running (in the following order):
1551 1. the methods to preprocess the query;
1552 2. the SPARQL query related to the operation called, by using the parameters indicated in the URL;
1553 3. the specification of all the types of the various rows returned;
1554 4. the methods to postprocess the result;
1555 5. the application of the filter to remove, filter, sort the result;
1556 6. the removal of the types added at the step 3, so as to have a data structure ready to be returned;
1557 7. the conversion in the format requested by the user."""
1558 str_method = method.lower()
1559 if str_method not in self.i["method"].split():
1560 return 405, f"HTTP status code 405: '{str_method}' method not allowed", "text/plain", {}
1562 try:
1563 if self._is_write(str_method):
1564 status, body, ctype = self._exec_update(self._prepare_params(body_params), content_type)
1565 else:
1566 status, body, ctype = self._dispatch_exec(content_type, body_params)
1567 except HttpError as err:
1568 return err.status_code, str(err), "text/plain", {}
1569 except TimeoutError as e:
1570 return *self._format_error(408, e, "request timeout - "), {}
1571 except (TypeError, ValueError) as e:
1572 return *self._format_error(400, e, "parameter in the request not compliant with the type specified - "), {}
1573 except Exception as e: # noqa: BLE001
1574 return *self._format_error(500, e, "something unexpected happened - "), {}
1576 headers = {}
1577 if self.pagination_info is not None:
1578 link_header = build_link_header(self.pagination_info)
1579 if link_header:
1580 headers["Link"] = link_header
1581 return status, body, ctype, headers
1583 def _prepare_params(self, body_params: Mapping[str, object] | None = None) -> dict[str, object]:
1584 par_dict = self._extract_params(body_params)
1585 if self.addon is not None:
1586 self.preprocess(par_dict, self.i, self.addon)
1587 if self.custom_params:
1588 self._apply_custom_preprocess_params(par_dict)
1589 return par_dict
1591 def _dispatch_exec(
1592 self,
1593 content_type: str,
1594 body_params: Mapping[str, object] | None = None,
1595 ) -> tuple[int, str, str]:
1596 """Dispatch to the appropriate read execution path based on the SPARQL text content."""
1597 par_dict = self._prepare_params(body_params)
1599 if self._cache is not None and "cache_disable" not in self.i:
1600 q_string = parse_qs(quote(self.url_parsed.query, safe="&="))
1601 cached_table = self._cache.get(self._build_cache_key(q_string))
1602 if cached_table is not None:
1603 return self._format_cached_result(cached_table, q_string, content_type)
1605 sparql_text = self.i["sparql"]
1606 resolved_text = sparql_text
1607 for param, val in par_dict.items():
1608 resolved_text = resolved_text.replace(f"[[{param}]]", str(val))
1610 if "@@" not in resolved_text:
1611 return self._exec_standard_sparql(par_dict, content_type)
1613 try:
1614 return self._exec_multi_source(par_dict, content_type)
1615 except ValueError as ve:
1616 return 400, f"HTTP status code 400: {ve}", "text/plain"
1617 except RuntimeError as re_err:
1618 return 502, f"HTTP status code 502: {re_err}", "text/plain"
@staticmethod
def _is_write(method: str) -> bool:
    """Tell whether the HTTP method (case-insensitive) is one of ``_WRITE_METHODS``."""
    normalized = method.lower()
    return normalized in _WRITE_METHODS
1624 @staticmethod
1625 def _escape_literal(value: str) -> str:
1626 value = value.replace("\\", "\\\\").replace('"', '\\"')
1627 return value.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
@staticmethod
def _escape_iri(value: str) -> str:
    """Return ``value`` unchanged, rejecting it when it matches ``_IRI_FORBIDDEN``.

    Raises:
        ValueError: if ``_IRI_FORBIDDEN`` is found anywhere in ``value``.
    """
    if search(_IRI_FORBIDDEN, value) is None:
        return value
    msg = f"invalid IRI value: {value!r}"
    raise ValueError(msg)
1636 def _bind_sparql_value(self, param: str, value: object) -> str:
1637 kind = self.i[param].split("(")[0] if param in self.i else "literal"
1638 text = "" if value is None else str(value)
1639 if kind == "iri":
1640 return Operation._escape_iri(text)
1641 if kind in ("int", "float"):
1642 return str(self.dt.get_func(kind)(text))
1643 return Operation._escape_literal(text)
1645 def _format_write_success(self, content_type: str) -> tuple[int, str, str]:
1646 if content_type == "text/csv":
1647 return HTTPStatus.OK, "status,message\r\n200,operation completed\r\n", "text/csv"
1648 return HTTPStatus.OK, dumps({"status": 200, "message": "operation completed"}), "application/json"
1650 def _exec_update(self, par_dict: dict[str, object], content_type: str) -> tuple[int, str, str]:
1651 """Send a SPARQL 1.1 Update to the update endpoint and return a confirmation with no result set."""
1652 update_text = self.i["sparql"]
1653 for param, val in par_dict.items():
1654 update_text = update_text.replace(f"[[{param}]]", self._bind_sparql_value(param, val))
1656 unresolved = findall(r"\[\[(\w+)\]\]", update_text)
1657 if unresolved:
1658 missing = ", ".join(dict.fromkeys(unresolved))
1659 message = f"HTTP status code 400: missing required parameter(s): {missing}"
1660 return HTTPStatus.BAD_REQUEST, message, "text/plain"
1662 endpoint = self.update_endpoint or self.tp
1663 try:
1664 response = _http_session.post(
1665 endpoint,
1666 data={"update": update_text},
1667 headers={"Accept": "application/json", **backend_auth_header(endpoint)},
1668 timeout=DEFAULT_HTTP_TIMEOUT,
1669 )
1670 except RequestException as exc:
1671 msg = f"SPARQL update request failed: {exc}"
1672 raise RuntimeError(msg) from exc
1674 if response.status_code not in (HTTPStatus.OK, HTTPStatus.CREATED, HTTPStatus.NO_CONTENT):
1675 return response.status_code, f"HTTP status code {response.status_code}: {response.reason}", "text/plain"
1677 if self._cache is not None:
1678 self._cache.clear()
1679 return self._format_write_success(content_type)