Coverage for heritrace / extensions.py: 99%

271 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-07-02 10:16 +0000

1# SPDX-FileCopyrightText: 2024-2026 Arcangelo Massari <arcangelo.massari@unibo.it> 

2# 

3# SPDX-License-Identifier: ISC 

4 

5import json 

6import os 

7from collections import defaultdict 

8from dataclasses import dataclass 

9from datetime import datetime, timedelta, timezone 

10from pathlib import Path 

11from typing import cast 

12from urllib.parse import urlparse, urlunparse 

13 

14import yaml 

15from flask import Flask, current_app, g, redirect, session, url_for 

16from flask_babel import Babel 

17from flask_login import LoginManager 

18from flask_login.signals import user_loaded_from_cookie 

19from rdflib import Graph 

20from rdflib_ocdm.counter_handler.counter_handler import CounterHandler 

21from redis import Redis 

22from redis.exceptions import RedisError 

23from SPARQLWrapper import JSON 

24from time_agnostic_library.support import generate_config_file 

25 

26from heritrace.models import User 

27from heritrace.services.resource_lock_manager import ResourceLockManager 

28from heritrace.sparql import SPARQLWrapperWithRetry, get_sparql_bindings, select_results 

29from heritrace.uri_generator.uri_generator import CounterBasedURIGenerator 

30from heritrace.utils.filters import Filter, split_namespace 

31 

32 

@dataclass(frozen=True)
class AppState:
    """Immutable bundle of the services and data shared by the application.

    ``init_extensions`` stores an instance in ``app.extensions["heritrace"]``
    and ``get_app_state`` reads it back; the ``get_*`` accessors in this
    module are thin wrappers over its fields.
    """

    # Dataset SPARQL endpoint URL, already passed through adjust_endpoint_url.
    dataset_endpoint: str
    # Provenance SPARQL endpoint URL, already passed through adjust_endpoint_url.
    provenance_endpoint: str
    # SPARQL client bound to dataset_endpoint.
    sparql: SPARQLWrapperWithRetry
    # SPARQL client bound to provenance_endpoint.
    provenance_sparql: SPARQLWrapperWithRetry
    # Dict returned by initialize_change_tracking_config.
    change_tracking_config: dict
    # Filter built by init_filters; a cast None placeholder during startup.
    custom_filter: Filter
    # The "rules" list read from the display rules YAML file (empty if none).
    display_rules: list[dict]
    # Form field definitions from get_form_fields_from_shacl (empty without SHACL).
    form_fields_cache: dict
    # Value of the DATASET_IS_QUADSTORE config setting.
    dataset_is_quadstore: bool
    # Parsed SHACL shapes graph (an empty Graph when no SHACL file is loaded).
    shacl_graph: Graph
    # Class URIs targeted by more than one visible SHACL shape.
    classes_with_multiple_shapes: set[str]

46 

47 

def get_app_state() -> AppState:
    """Return the ``AppState`` registered on the current Flask application.

    Must be called inside an application context; raises ``KeyError`` if
    ``init_extensions`` has not populated ``app.extensions["heritrace"]``.
    """
    registry = current_app.extensions
    return registry["heritrace"]

50 

51 

def init_extensions(
    app: Flask, babel: Babel, login_manager: LoginManager, redis: Redis
) -> None:
    """Wire every HERITRACE extension and shared service into *app*.

    In order, this:
    - configures Babel, with the locale taken from ``session["lang"]`` and
      defaulting to ``"en"``;
    - sets up Flask-Login;
    - creates the SPARQL clients and the change-tracking config;
    - initializes the counters;
    - publishes a provisional ``AppState``;
    - loads display rules and SHACL;
    - registers the Jinja filters and request hooks;
    - publishes the final ``AppState``.

    The login manager and the Redis client are also exposed via
    ``app.extensions``.
    """
    translations_dir = (
        Path(__file__).resolve().parent.parent / "babel" / "translations"
    )
    babel.init_app(
        app=app,
        locale_selector=lambda: session.get("lang", "en"),
        default_translation_directories=str(translations_dir),
    )

    init_login_manager(app, login_manager)

    dataset_endpoint, provenance_endpoint, sparql, provenance_sparql, tracking = (
        init_sparql_services(app)
    )
    initialize_counter_handler(app, redis, sparql, provenance_sparql)

    # Fields that are already final before rules and shapes are loaded.
    shared_fields = {
        "dataset_endpoint": dataset_endpoint,
        "provenance_endpoint": provenance_endpoint,
        "sparql": sparql,
        "provenance_sparql": provenance_sparql,
        "change_tracking_config": tracking,
    }

    # Publish a provisional state with empty placeholders before the rest of
    # initialization runs. NOTE(review): presumably some callee below reads it
    # through get_app_state(); confirm which one.
    app.extensions["heritrace"] = AppState(
        **shared_fields,
        custom_filter=cast("Filter", None),
        display_rules=[],
        form_fields_cache={},
        dataset_is_quadstore=False,
        shacl_graph=Graph(),
        classes_with_multiple_shapes=set(),
    )

    rules, form_fields, is_quadstore, shapes_graph, multi_shape_classes = (
        initialize_global_variables(app)
    )
    jinja_filter = init_filters(app, rules, dataset_endpoint)
    init_request_handlers(app, redis)

    app.extensions["heritrace"] = AppState(
        **shared_fields,
        custom_filter=jinja_filter,
        display_rules=rules,
        form_fields_cache=form_fields,
        dataset_is_quadstore=is_quadstore,
        shacl_graph=shapes_graph,
        classes_with_multiple_shapes=multi_shape_classes,
    )
    app.extensions["login_manager"] = login_manager
    app.extensions["redis_client"] = redis

113 

114 

def init_login_manager(app: Flask, login_manager: LoginManager) -> None:
    """Attach Flask-Login to *app* and register the user loader and cookie hook.

    Unauthorized access redirects to ``auth.login``. Users are rebuilt from
    the stored user id, which doubles as the ORCID; the display name is read
    from ``session["user_name"]``. When Flask-Login restores a user from the
    remember-me cookie, the session is marked as modified.
    """
    login_manager.init_app(app)
    login_manager.login_view = "auth.login"  # type: ignore[reportAttributeAccessIssue]
    login_manager.unauthorized_handler(lambda: redirect(url_for("auth.login")))

    @login_manager.user_loader
    def load_user(user_id: str) -> User:
        user_name = session.get("user_name", "Unknown User")
        return User(user_id=user_id, name=user_name, orcid=user_id)

    def rotate_session_token(_sender: object, **_extra: object) -> None:
        # Flask-Login sends this signal as send(app, user=user), so the
        # receiver must accept keyword arguments.
        session.modified = True

    # blinker keeps receivers as weak references by default. A nested function
    # would be garbage-collected as soon as this function returns, silently
    # disconnecting it, so hold a strong reference and scope it to this app.
    user_loaded_from_cookie.connect(rotate_session_token, sender=app, weak=False)

128 

129 

def initialize_change_tracking_config(
    app: Flask,
    adjusted_dataset_endpoint: str | None = None,
    adjusted_provenance_endpoint: str | None = None,
) -> dict:
    """Load, or generate when missing, the change-tracking configuration.

    If ``CHANGE_TRACKING_CONFIG`` is set, that path is used. When the file
    does not exist, a warning is logged and a new config is generated there
    via ``generate_config_file``. Without the setting, the config is
    generated at ``<instance_path>/change_tracking_config.json`` and the
    instance directory is created. The resolved path is written back to
    ``app.config["CHANGE_TRACKING_CONFIG"]``.

    Args:
        app: Flask application providing configuration and logger.
        adjusted_dataset_endpoint: Dataset endpoint URL to record, if any.
        adjusted_provenance_endpoint: Provenance endpoint URL to record, if any.

    Returns:
        The configuration dictionary.

    Raises:
        RuntimeError: If generation fails with an ``OSError``, or the file
            cannot be read or does not contain valid JSON.
    """
    config = None

    if "CHANGE_TRACKING_CONFIG" in app.config:
        config_path = app.config["CHANGE_TRACKING_CONFIG"]
        must_generate = not Path(config_path).exists()
        if must_generate:
            app.logger.warning(
                "Change tracking configuration file not found at specified path: %s",
                config_path,
            )
    else:
        instance_dir = Path(app.instance_path)
        config_path = str(instance_dir / "change_tracking_config.json")
        instance_dir.mkdir(parents=True, exist_ok=True)
        must_generate = True

    if must_generate:
        triplestore = app.config.get("DATASET_DB_TRIPLESTORE", "").lower()
        text_index = app.config.get("DATASET_DB_TEXT_INDEX_ENABLED", False)
        # Full-text search is switched on only for the configured triplestore.
        full_text = {
            name: triplestore == name and text_index
            for name in ("blazegraph", "fuseki", "virtuoso")
        }
        # TODO(@arcangelo-massari): Add graphdb support
        # https://github.com/opencitations/heritrace/issues/1
        graphdb_connector = ""

        try:
            config = generate_config_file(
                config_path=config_path,
                dataset_urls=(
                    [adjusted_dataset_endpoint] if adjusted_dataset_endpoint else []
                ),
                dataset_dirs=app.config.get("DATASET_DIRS", []),
                dataset_is_quadstore=app.config.get("DATASET_IS_QUADSTORE", False),
                provenance_urls=(
                    [adjusted_provenance_endpoint]
                    if adjusted_provenance_endpoint
                    else []
                ),
                provenance_is_quadstore=app.config.get(
                    "PROVENANCE_IS_QUADSTORE", False
                ),
                provenance_dirs=app.config.get("PROVENANCE_DIRS", []),
                blazegraph_full_text_search=full_text["blazegraph"],
                fuseki_full_text_search=full_text["fuseki"],
                virtuoso_full_text_search=full_text["virtuoso"],
                graphdb_connector_name=graphdb_connector,
            )
            app.logger.info(
                "Generated new change tracking configuration at: %s", config_path
            )
        except OSError as exc:
            msg = f"Failed to generate change tracking configuration: {exc!s}"
            raise RuntimeError(msg) from exc

    # Fall back to reading the file when nothing (or an empty result) was produced.
    if not config:
        try:
            with Path(config_path).open(encoding="utf8") as handle:
                config = json.load(handle)
        except json.JSONDecodeError as exc:
            msg = f"Invalid change tracking configuration JSON at {config_path}: {exc!s}"
            raise RuntimeError(msg) from exc
        except OSError as exc:
            msg = f"Error reading change tracking configuration at {config_path}: {exc!s}"
            raise RuntimeError(msg) from exc

    app.config["CHANGE_TRACKING_CONFIG"] = config_path
    return config

205 

206 

def need_initialization(app: Flask, redis: Redis) -> bool:
    """Decide whether the counters must be (re)initialized from the triplestores.

    Initialization is skipped when ``REDIS_URL`` points anywhere other than
    the default local Redis. Otherwise it is needed when:
    - no ``heritrace:last_initialization`` timestamp is stored;
    - the timestamp cannot be read or parsed;
    - the timestamp is older than ``CACHE_VALIDITY_DAYS``.

    Args:
        app: Flask application providing config and logger.
        redis: Redis client holding the last-initialization timestamp.

    Returns:
        True if ``initialize_counter_handler`` should rebuild the counters.
    """
    redis_url = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
    is_external_redis = redis_url != "redis://localhost:6379/0"

    if is_external_redis:
        app.logger.info(
            "Using external Redis at %s - skipping counter initialization", redis_url
        )
        return False

    cache_validity_days = app.config["CACHE_VALIDITY_DAYS"]

    try:
        last_init_raw = redis.get("heritrace:last_initialization")
        if not last_init_raw:
            return True

        # redis-py returns bytes unless the client was created with
        # decode_responses=True; fromisoformat() accepts only str.
        if isinstance(last_init_raw, (bytes, bytearray)):
            last_init_raw = last_init_raw.decode("utf-8")

        last_init = datetime.fromisoformat(last_init_raw)
        # A timestamp without an offset would make aware - naive subtraction
        # raise TypeError; treat it as UTC, which is what update_cache writes.
        if last_init.tzinfo is None:
            last_init = last_init.replace(tzinfo=timezone.utc)

        return datetime.now(tz=timezone.utc) - last_init > timedelta(
            days=cache_validity_days
        )
    except (RedisError, ValueError):
        # UnicodeDecodeError is a ValueError, so corrupt bytes also land here.
        return True

230 

231 

def update_cache(_app: Flask, redis: Redis) -> None:
    """Record in Redis that counter initialization has just completed.

    Writes the current UTC time (ISO 8601) to ``heritrace:last_initialization``
    and the cache version ``"1.0"`` to ``heritrace:cache_version``.
    The ``_app`` argument is accepted for signature compatibility and unused.
    """
    timestamp = datetime.now(tz=timezone.utc).isoformat()
    entries = (
        ("heritrace:last_initialization", timestamp),
        ("heritrace:cache_version", "1.0"),
    )
    for key, value in entries:
        redis.set(key, value)

236 

237 

def initialize_counter_handler(
    app: Flask,
    redis: Redis,
    sparql: SPARQLWrapperWithRetry,
    provenance_sparql: SPARQLWrapperWithRetry,
) -> None:
    """Seed the URI and provenance counters when ``need_initialization`` says so.

    Counter-based URI generators initialize their counters from the dataset
    endpoint. For every entity in the provenance store, the number of
    distinct ``prov:Entity`` snapshots that specialize it is stored in
    ``app.config["COUNTER_HANDLER"]``. The initialization time is then
    recorded via ``update_cache``.
    """
    if not need_initialization(app, redis):
        return

    generator = app.config["URI_GENERATOR"]
    if isinstance(generator, CounterBasedURIGenerator):
        generator.initialize_counters(sparql)

    handler = app.config["COUNTER_HANDLER"]

    snapshot_count_query = """
    SELECT ?entity (COUNT(DISTINCT ?snapshot) as ?count)
    WHERE {
        ?snapshot a <http://www.w3.org/ns/prov#Entity> ;
            <http://www.w3.org/ns/prov#specializationOf> ?entity .
        OPTIONAL {
            ?snapshot <http://www.w3.org/ns/prov#wasDerivedFrom> ?prev .
        }
    }
    GROUP BY ?entity
    """

    provenance_sparql.setQuery(snapshot_count_query)
    provenance_sparql.setReturnFormat(JSON)
    response = provenance_sparql.query().convert()

    for binding in get_sparql_bindings(response):
        entity_uri = binding["entity"]["value"]
        snapshot_total = int(binding["count"]["value"])
        handler.set_counter(snapshot_total, entity_uri)

    update_cache(app, redis)

275 

276 

def identify_classes_with_multiple_shapes(
    display_rules: list[dict], shacl_graph: Graph
) -> set[str]:
    """Return the class URIs that more than one visible SHACL shape targets.

    Each display rule is resolved against ``sh:targetClass`` in
    *shacl_graph*:
    - a ``{"class": ...}`` target is expanded to all shapes that target it;
    - a ``{"shape": ...}`` target is expanded to all classes it targets.

    Only (class, shape) pairs for which ``is_entity_type_visible`` is true
    are counted. An empty rule list or an empty graph yields an empty set.
    """
    if not display_rules or not shacl_graph:
        return set()

    from heritrace.utils.display_rules_utils import (  # noqa: PLC0415
        is_entity_type_visible,
    )

    shapes_per_class: dict[str, set[str]] = {}

    def record(class_uri: str, shape_uri: str) -> None:
        if is_entity_type_visible((class_uri, shape_uri)):
            shapes_per_class.setdefault(class_uri, set()).add(shape_uri)

    for rule in display_rules:
        target = rule.get("target", {})

        if "class" in target:
            rule_class = target["class"]
            rows = shacl_graph.query(f"""
            SELECT DISTINCT ?shape WHERE {{
                ?shape <http://www.w3.org/ns/shacl#targetClass> <{rule_class}> .
            }}
            """)
            for row in select_results(rows):
                record(rule_class, str(row.shape))
        elif "shape" in target:
            rule_shape = target["shape"]
            rows = shacl_graph.query(f"""
            SELECT DISTINCT ?class WHERE {{
                <{rule_shape}> <http://www.w3.org/ns/shacl#targetClass> ?class .
            }}
            """)
            for row in select_results(rows):
                record(str(row[0]), rule_shape)

    return {
        class_uri
        for class_uri, shapes in shapes_per_class.items()
        if len(shapes) > 1
    }

323 

324 

def _load_display_rules(app: Flask) -> list[dict]:
    """Read the ``rules`` list from the YAML file at ``DISPLAY_RULES_PATH``.

    Returns ``[]`` when the setting is unset or the file does not exist
    (the latter is logged as a warning). A ``rules: null`` entry is treated
    as an empty list.

    Raises:
        RuntimeError: If the YAML is malformed, or the document is not a
            mapping with a top-level ``rules`` key.
    """
    configured = app.config.get("DISPLAY_RULES_PATH")
    if not configured:
        return []

    # Accept both str and Path values in the configuration.
    rules_path = Path(configured)
    if not rules_path.exists():
        app.logger.warning("Display rules file not found at: %s", configured)
        return []

    try:
        # YAML is UTF-8 by spec; don't depend on the platform locale encoding.
        with rules_path.open(encoding="utf-8") as f:
            yaml_content = yaml.safe_load(f)
    except yaml.YAMLError as e:
        app.logger.exception("Error loading display rules")
        msg = f"Failed to load display rules: {e!s}"
        raise RuntimeError(msg) from e

    # An empty file (None) or a document without "rules" used to escape as a
    # bare TypeError/KeyError; report it like any other invalid rules file.
    if not isinstance(yaml_content, dict) or "rules" not in yaml_content:
        msg = f"Failed to load display rules: no top-level 'rules' key in {rules_path}"
        app.logger.error(msg)
        raise RuntimeError(msg)

    return yaml_content["rules"] or []


def _load_shacl(app: Flask, display_rules: list[dict]) -> tuple[Graph, dict]:
    """Parse the Turtle file at ``SHACL_PATH`` and derive the form field cache.

    Returns an empty ``Graph`` and ``{}`` when the setting is unset or the
    file does not exist (the latter is logged as a warning).

    Raises:
        RuntimeError: If parsing or form-field extraction raises ``OSError``
            or ``ValueError``.
    """
    shacl_graph = Graph()
    configured = app.config.get("SHACL_PATH")
    if not configured:
        return shacl_graph, {}

    if not Path(configured).exists():
        app.logger.warning("SHACL file not found at: %s", configured)
        return shacl_graph, {}

    try:
        shacl_graph.parse(source=configured, format="turtle")

        from heritrace.utils.shacl_utils import (  # noqa: PLC0415
            get_form_fields_from_shacl,
        )

        form_fields_cache = get_form_fields_from_shacl(
            shacl_graph, display_rules, app=app
        )
    except (OSError, ValueError) as e:
        app.logger.exception("Error initializing form fields from SHACL")
        msg = f"Failed to initialize form fields: {e!s}"
        raise RuntimeError(msg) from e

    return shacl_graph, form_fields_cache


def initialize_global_variables(
    app: Flask,
) -> tuple[list[dict], dict, bool, Graph, set[str]]:
    """Load display rules and SHACL shapes and derive the dependent data.

    Returns:
        ``(display_rules, form_fields_cache, dataset_is_quadstore,
        shacl_graph, classes_with_multiple_shapes)``. Missing files produce
        empty defaults.

    Raises:
        RuntimeError: On any load failure. ``OSError``, ``yaml.YAMLError``
            and ``ValueError`` are wrapped; ``RuntimeError`` from the helpers
            is re-raised unchanged.
    """
    try:
        dataset_is_quadstore = app.config.get("DATASET_IS_QUADSTORE", False)
        display_rules = _load_display_rules(app)
        shacl_graph, form_fields_cache = _load_shacl(app, display_rules)
        classes_with_multiple_shapes = identify_classes_with_multiple_shapes(
            display_rules, shacl_graph
        )
        app.logger.info("Global variables initialized successfully")
    except RuntimeError:
        raise
    except (OSError, yaml.YAMLError, ValueError) as e:
        app.logger.exception("Error during global variables initialization")
        msg = f"Global variables initialization failed: {e!s}"
        raise RuntimeError(msg) from e
    else:
        return (
            display_rules,
            form_fields_cache,
            dataset_is_quadstore,
            shacl_graph,
            classes_with_multiple_shapes,
        )

391 

392 

def init_sparql_services(
    app: Flask,
) -> tuple[str, str, SPARQLWrapperWithRetry, SPARQLWrapperWithRetry, dict]:
    """Create the dataset and provenance SPARQL clients and tracking config.

    Both ``DATASET_DB_URL`` and ``PROVENANCE_DB_URL`` go through
    ``adjust_endpoint_url``. Each gets a ``SPARQLWrapperWithRetry`` with a
    30-second timeout.

    Returns:
        ``(dataset_endpoint, provenance_endpoint, sparql, provenance_sparql,
        change_tracking_config)``.
    """
    endpoints = [
        adjust_endpoint_url(app.config[setting])
        for setting in ("DATASET_DB_URL", "PROVENANCE_DB_URL")
    ]
    dataset_endpoint, provenance_endpoint = endpoints

    clients = [SPARQLWrapperWithRetry(endpoint, timeout=30.0) for endpoint in endpoints]
    dataset_client, provenance_client = clients

    tracking_config = initialize_change_tracking_config(
        app,
        adjusted_dataset_endpoint=dataset_endpoint,
        adjusted_provenance_endpoint=provenance_endpoint,
    )

    return (
        dataset_endpoint,
        provenance_endpoint,
        dataset_client,
        provenance_client,
        tracking_config,
    )

415 

416 

def init_filters(
    app: Flask, display_rules: list[dict], dataset_endpoint: str
) -> Filter:
    """Build the project's ``Filter`` and register its methods as Jinja filters.

    The JSON-LD ``@context`` is read from ``utils/context.json`` next to this
    module. An empty *display_rules* list is passed to ``Filter`` as ``None``.

    Args:
        app: Flask application whose Jinja environment receives the filters.
        display_rules: Display rules loaded from configuration.
        dataset_endpoint: Dataset SPARQL endpoint URL handed to ``Filter``.

    Returns:
        The ``Filter`` instance backing the registered template filters.
    """
    context_path = Path(__file__).parent / "utils" / "context.json"
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, matching the other JSON read in this module.
    with context_path.open(encoding="utf-8") as config_file:
        context = json.load(config_file)["@context"]

    custom_filter = Filter(context, display_rules or None, dataset_endpoint)

    app.jinja_env.filters.update(
        {
            "human_readable_predicate": custom_filter.human_readable_predicate,
            "human_readable_class": custom_filter.human_readable_class,
            "human_readable_entity": custom_filter.human_readable_entity,
            "human_readable_primary_source": (
                custom_filter.human_readable_primary_source
            ),
            "format_datetime": custom_filter.human_readable_datetime,
            "split_ns": split_namespace,
            "format_source_reference": custom_filter.format_source_reference,
            "format_agent_reference": custom_filter.format_agent_reference,
        }
    )
    return custom_filter

442 

443 

def init_request_handlers(app: Flask, redis: Redis) -> None:
    """Register hooks that manage a ``ResourceLockManager`` on ``flask.g``.

    Before each request, a ``ResourceLockManager`` bound to *redis* is
    stored as ``g.resource_lock_manager`` unless one is already present.
    When the app context tears down, that attribute is removed.
    """

    def initialize_lock_manager() -> None:
        if hasattr(g, "resource_lock_manager"):
            return
        g.resource_lock_manager = ResourceLockManager(redis)

    def close_redis_connection(_error: BaseException | None) -> None:
        # Only drops the reference held on ``g``; it does not close the
        # Redis client itself.
        if hasattr(g, "resource_lock_manager"):
            delattr(g, "resource_lock_manager")

    app.before_request(initialize_lock_manager)
    app.teardown_appcontext(close_redis_connection)

454 

455 

456def adjust_endpoint_url(url: str) -> str: 

457 if not running_in_docker(): 

458 return url 

459 

460 local_patterns = ["localhost", "127.0.0.1", "0.0.0.0"] 

461 parsed_url = urlparse(url) 

462 

463 if any(pattern in parsed_url.netloc for pattern in local_patterns): 

464 netloc_parts = parsed_url.netloc.split(":") 

465 new_netloc = ( 

466 f"host.docker.internal:{netloc_parts[1]}" 

467 if len(netloc_parts) > 1 

468 else "host.docker.internal" 

469 ) 

470 url_parts = list(parsed_url) 

471 url_parts[1] = new_netloc 

472 return urlunparse(url_parts) 

473 

474 return url 

475 

476 

def running_in_docker() -> bool:
    """Report whether a ``/.dockerenv`` marker file exists on this machine."""
    marker = Path("/") / ".dockerenv"
    return marker.exists()

479 

480 

def get_dataset_endpoint() -> str:
    """Return the (Docker-adjusted) dataset SPARQL endpoint URL."""
    state = get_app_state()
    return state.dataset_endpoint

483 

484 

def get_sparql() -> SPARQLWrapperWithRetry:
    """Return the SPARQL client bound to the dataset endpoint."""
    state = get_app_state()
    return state.sparql

487 

488 

def get_provenance_endpoint() -> str:
    """Return the (Docker-adjusted) provenance SPARQL endpoint URL."""
    state = get_app_state()
    return state.provenance_endpoint

491 

492 

def get_provenance_sparql() -> SPARQLWrapperWithRetry:
    """Return the SPARQL client bound to the provenance endpoint."""
    state = get_app_state()
    return state.provenance_sparql

495 

496 

def get_counter_handler() -> CounterHandler:
    """Return the counter handler of the configured counter-based URI generator.

    Raises:
        TypeError: If ``URI_GENERATOR`` is missing or is not a
            ``CounterBasedURIGenerator``; the problem is logged first.
    """
    generator = current_app.config.get("URI_GENERATOR")
    if isinstance(generator, CounterBasedURIGenerator):
        return generator.counter_handler

    current_app.logger.error("CounterHandler not found in URIGenerator config.")
    msg = "CounterHandler is not available. Initialization might have failed."
    raise TypeError(msg)

504 

505 

def get_custom_filter() -> Filter:
    """Return the ``Filter`` whose methods back the registered Jinja filters."""
    state = get_app_state()
    return state.custom_filter

508 

509 

def get_change_tracking_config() -> dict:
    """Return the change-tracking configuration dictionary."""
    state = get_app_state()
    return state.change_tracking_config

512 

513 

def get_display_rules() -> list[dict]:
    """Return the display rules loaded at startup (empty list when none)."""
    state = get_app_state()
    return state.display_rules

516 

517 

def get_form_fields() -> dict:
    """Return the form field definitions derived from the SHACL shapes."""
    state = get_app_state()
    return state.form_fields_cache

520 

521 

def get_dataset_is_quadstore() -> bool:
    """Return the ``DATASET_IS_QUADSTORE`` flag captured at startup."""
    state = get_app_state()
    return state.dataset_is_quadstore

524 

525 

def get_shacl_graph() -> Graph:
    """Return the parsed SHACL shapes graph (empty when none was loaded)."""
    state = get_app_state()
    return state.shacl_graph

528 

529 

def get_classes_with_multiple_shapes() -> set[str]:
    """Return the class URIs targeted by more than one visible SHACL shape."""
    state = get_app_state()
    return state.classes_with_multiple_shapes