Coverage for test/group_entities_test.py: 99% (307 statements)

import csv
import os
import shutil
import unittest
from unittest.mock import MagicMock, patch

import pandas as pd
from oc_meta.run.merge.group_entities import (
    UnionFind,
    get_all_related_entities,
    get_file_path,
    group_entities,
    optimize_groups,
    save_grouped_entities,
)

BASE = os.path.join("test", "group_entities")
OUTPUT = os.path.join(BASE, "output")


class TestUnionFind(unittest.TestCase):
    """Test UnionFind data structure for correctness and edge cases"""

    def setUp(self):
        self.uf = UnionFind()

    def test_find_single_element(self):
        """Test find on a single element"""
        result = self.uf.find("entity1")
        self.assertEqual(result, "entity1")

    def test_union_two_elements(self):
        """Test union of two elements"""
        self.uf.union("entity1", "entity2")
        root1 = self.uf.find("entity1")
        root2 = self.uf.find("entity2")
        self.assertEqual(root1, root2)

    def test_union_multiple_elements(self):
        """Test that chained unions of multiple elements form a single group"""
        self.uf.union("entity1", "entity2")
        self.uf.union("entity2", "entity3")
        self.uf.union("entity3", "entity4")

        root = self.uf.find("entity1")
        self.assertEqual(self.uf.find("entity2"), root)
        self.assertEqual(self.uf.find("entity3"), root)
        self.assertEqual(self.uf.find("entity4"), root)

    def test_separate_groups(self):
        """Test that unrelated unions create separate groups"""
        self.uf.union("entity1", "entity2")
        self.uf.union("entity3", "entity4")

        root1 = self.uf.find("entity1")
        root3 = self.uf.find("entity3")
        self.assertNotEqual(root1, root3)

    def test_path_compression(self):
        """Test that path compression flattens the structure after find"""
        self.uf.union("entity1", "entity2")
        self.uf.union("entity2", "entity3")
        self.uf.union("entity3", "entity4")

        self.uf.find("entity4")

        self.assertIn("entity4", self.uf.parent)

    def test_find_long_chain(self):
        """Test find on a long chain of unions"""
        for i in range(100):
            self.uf.union(f"entity{i}", f"entity{i+1}")

        root = self.uf.find("entity0")
        for i in range(101):
            self.assertEqual(self.uf.find(f"entity{i}"), root)

    def test_circular_reference_bug(self):
        """Test that a corrupted parent mapping containing a cycle raises ValueError"""
        self.uf.parent["entity1"] = "entity2"
        self.uf.parent["entity2"] = "entity3"
        self.uf.parent["entity3"] = "entity1"

        with self.assertRaises(ValueError) as context:
            self.uf.find("entity1")

        self.assertIn("Cycle detected", str(context.exception))

    def test_self_loop(self):
        """Test handling of a self-loop (an element that is its own parent)"""
        self.uf.parent["entity1"] = "entity1"
        result = self.uf.find("entity1")
        self.assertEqual(result, "entity1")
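

# For reference while reading the tests above: a minimal UnionFind sketch that
# satisfies them (hypothetical; the real UnionFind in
# oc_meta.run.merge.group_entities may differ in detail). It illustrates the
# `parent` dict, path compression, and the "Cycle detected" ValueError that
# test_circular_reference_bug relies on.
class _UnionFindSketch:
    def __init__(self):
        self.parent = {}  # each element maps to its parent; roots map to themselves

    def find(self, x):
        path = []
        seen = set()
        while self.parent.setdefault(x, x) != x:
            if x in seen:  # a corrupted mapping would otherwise loop forever
                raise ValueError(f"Cycle detected in UnionFind at {x}")
            seen.add(x)
            path.append(x)
            x = self.parent[x]
        for node in path:  # path compression: point every visited node at the root
            self.parent[node] = x
        return x

    def union(self, a, b):
        root_a, root_b = self.find(a), self.find(b)
        if root_a != root_b:
            self.parent[root_b] = root_a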


class TestQuerySPARQL(unittest.TestCase):
    """Test SPARQL query functions"""

    @patch('oc_meta.run.merge.group_entities.SPARQLClient')
    def test_query_sparql_batch(self, mock_sparql_client):
        """Test batch querying for related entities"""
        mock_client = MagicMock()
        mock_sparql_client.return_value.__enter__.return_value = mock_client
        mock_client.query.return_value = {
            'results': {
                'bindings': [
                    {'entity': {'value': 'https://example.org/related1', 'type': 'uri'}},
                    {'entity': {'value': 'https://example.org/related2', 'type': 'uri'}}
                ]
            }
        }

        from oc_meta.run.merge.group_entities import query_sparql_batch
        result = query_sparql_batch("http://endpoint",
                                    ["https://example.org/test1", "https://example.org/test2"])
        self.assertEqual(len(result), 2)
        self.assertIn('https://example.org/related1', result)
        self.assertIn('https://example.org/related2', result)

    @patch('oc_meta.run.merge.group_entities.SPARQLClient')
    def test_query_sparql_batch_large_input(self, mock_sparql_client):
        """Test batch processing with large input (multiple batches)"""
        mock_client = MagicMock()
        mock_sparql_client.return_value.__enter__.return_value = mock_client
        mock_client.query.return_value = {'results': {'bindings': []}}

        from oc_meta.run.merge.group_entities import query_sparql_batch
        uris = [f"https://example.org/entity{i}" for i in range(25)]
        query_sparql_batch("http://endpoint", uris, batch_size=10)

        self.assertEqual(mock_client.query.call_count, 3)

    @patch('oc_meta.run.merge.group_entities.SPARQLClient')
    def test_query_sparql_batch_empty_results(self, mock_sparql_client):
        """Test handling of empty results"""
        mock_client = MagicMock()
        mock_sparql_client.return_value.__enter__.return_value = mock_client
        mock_client.query.return_value = {'results': {'bindings': []}}

        from oc_meta.run.merge.group_entities import query_sparql_batch
        result = query_sparql_batch("http://endpoint", ["https://example.org/test"])
        self.assertEqual(len(result), 0)

    @patch('oc_meta.run.merge.group_entities.SPARQLClient')
    def test_query_sparql_batch_filters_literals(self, mock_sparql_client):
        """Test that literal values are filtered out (only URIs)"""
        mock_client = MagicMock()
        mock_sparql_client.return_value.__enter__.return_value = mock_client
        mock_client.query.return_value = {
            'results': {
                'bindings': [
                    {'entity': {'value': 'https://example.org/uri1', 'type': 'uri'}},
                    {'entity': {'value': 'Some Literal', 'type': 'literal'}},
                    {'entity': {'value': 'https://example.org/uri2', 'type': 'uri'}}
                ]
            }
        }

        from oc_meta.run.merge.group_entities import query_sparql_batch
        result = query_sparql_batch("http://endpoint", ["https://example.org/test"])
        self.assertEqual(len(result), 2)
        self.assertIn('https://example.org/uri1', result)
        self.assertIn('https://example.org/uri2', result)
        self.assertNotIn('Some Literal', result)
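

# For reference: a sketch of the query_sparql_batch contract that the mocks
# above assume (hypothetical; the name suffix, the default batch_size, and the
# query shape are assumptions, not the real implementation). URIs are queried
# in VALUES batches of batch_size, and only bindings typed 'uri' are kept.
def _query_sparql_batch_sketch(endpoint, uris, batch_size=10):
    related = set()
    # SPARQLClient is the context-manager client patched in the tests above.
    with SPARQLClient(endpoint) as client:
        for i in range(0, len(uris), batch_size):  # 25 URIs / 10 per batch -> 3 queries
            values = " ".join(f"<{u}>" for u in uris[i:i + batch_size])
            query = f"SELECT ?entity WHERE {{ VALUES ?s {{ {values} }} ?s ?p ?entity }}"
            for binding in client.query(query)['results']['bindings']:
                if binding['entity']['type'] == 'uri':  # drop literal values
                    related.add(binding['entity']['value'])
    return related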


class TestGetAllRelatedEntities(unittest.TestCase):
    """Test get_all_related_entities function"""

    @patch('oc_meta.run.merge.group_entities.query_sparql_batch')
    def test_get_all_related_entities_performance_fixed(self, mock_query_batch):
        """Test that batch querying is used (performance fix)"""
        mock_query_batch.return_value = set()

        uris = [f"https://example.org/entity{i}" for i in range(10)]
        get_all_related_entities("http://endpoint", uris)

        self.assertEqual(mock_query_batch.call_count, 1)

    @patch('oc_meta.run.merge.group_entities.query_sparql_batch')
    def test_get_all_related_entities_performance_large_batch(self, mock_query_batch):
        """Test that even 100 URIs trigger a single query_sparql_batch call
        (splitting into batches of batch_size happens inside that helper)"""
        mock_query_batch.return_value = set()

        uris = [f"https://example.org/entity{i}" for i in range(100)]
        get_all_related_entities("http://endpoint", uris, batch_size=10)

        self.assertEqual(mock_query_batch.call_count, 1)

    @patch('oc_meta.run.merge.group_entities.query_sparql_batch')
    def test_get_all_related_entities_includes_input_uris(self, mock_query_batch):
        """Test that input URIs are included in results"""
        mock_query_batch.return_value = set()

        uris = ["https://example.org/entity1", "https://example.org/entity2"]
        result = get_all_related_entities("http://endpoint", uris)

        self.assertIn("https://example.org/entity1", result)
        self.assertIn("https://example.org/entity2", result)

    @patch('oc_meta.run.merge.group_entities.query_sparql_batch')
    def test_get_all_related_entities_combines_results(self, mock_query_batch):
        """Test that batch results are combined with input URIs"""
        mock_query_batch.return_value = {
            "https://example.org/related1",
            "https://example.org/related2"
        }

        result = get_all_related_entities("http://endpoint", ["https://example.org/entity1"])

        self.assertIn("https://example.org/entity1", result)
        self.assertIn("https://example.org/related1", result)
        self.assertIn("https://example.org/related2", result)
        self.assertEqual(len(result), 3)
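

# For reference: the behaviour pinned down above, as a sketch (hypothetical
# name; the local import mirrors how the tests import the helper). One
# delegated call to query_sparql_batch, with the input URIs folded in.
def _get_all_related_entities_sketch(endpoint, uris, batch_size=10):
    from oc_meta.run.merge.group_entities import query_sparql_batch

    # A single call regardless of len(uris); batching is internal to the helper.
    return set(uris) | query_sparql_batch(endpoint, uris, batch_size=batch_size)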


class TestOptimizeGroups(unittest.TestCase):
    """Test optimize_groups function"""

    def test_optimize_groups_combines_single_groups(self):
        """Test that single-row groups are combined"""
        grouped_data = {
            "group1": pd.DataFrame([{"surviving_entity": "e1", "merged_entities": "e2"}]),
            "group2": pd.DataFrame([{"surviving_entity": "e3", "merged_entities": "e4"}]),
            "group3": pd.DataFrame([{"surviving_entity": "e5", "merged_entities": "e6"}]),
        }

        result = optimize_groups(grouped_data, target_size=2)

        combined_count = sum(1 for df in result.values() if len(df) >= 2)
        self.assertGreater(combined_count, 0)

    def test_optimize_groups_preserves_multi_groups(self):
        """Test that multi-row groups are preserved and singles are combined"""
        grouped_data = {
            "group1": pd.DataFrame([
                {"surviving_entity": "e1", "merged_entities": "e2"},
                {"surviving_entity": "e3", "merged_entities": "e4"}
            ]),
            "group2": pd.DataFrame([{"surviving_entity": "e5", "merged_entities": "e6"}]),
            "group3": pd.DataFrame([{"surviving_entity": "e7", "merged_entities": "e8"}]),
        }

        result = optimize_groups(grouped_data, target_size=2)

        has_two_row_group = any(len(df) == 2 for df in result.values())
        self.assertTrue(has_two_row_group)

        total_rows = sum(len(df) for df in result.values())
        self.assertEqual(total_rows, 4)

    def test_optimize_groups_handles_empty_input(self):
        """Test handling of empty input"""
        result = optimize_groups({}, target_size=10)
        self.assertEqual(len(result), 0)

    def test_optimize_groups_data_loss_bug(self):
        """Test for the data-loss bug: leftover single-row groups that number
        fewer than target_size must still be written out, not dropped"""
        grouped_data = {}
        for i in range(35):
            grouped_data[f"group{i}"] = pd.DataFrame([{
                "surviving_entity": f"e{i}",
                "merged_entities": f"e{i+100}"
            }])

        result = optimize_groups(grouped_data, target_size=50)

        total_rows_input = sum(len(df) for df in grouped_data.values())
        total_rows_output = sum(len(df) for df in result.values())

        self.assertEqual(total_rows_input, total_rows_output,
                         "Data loss detected: not all rows preserved after optimization")

    def test_optimize_groups_no_multi_groups_edge_case(self):
        """Test edge case where there are no multi-row groups"""
        grouped_data = {}
        for i in range(25):
            grouped_data[f"group{i}"] = pd.DataFrame([{
                "surviving_entity": f"e{i}",
                "merged_entities": f"e{i+100}"
            }])

        result = optimize_groups(grouped_data, target_size=10)

        total_rows_output = sum(len(df) for df in result.values())
        self.assertEqual(total_rows_output, 25)

    def test_optimize_groups_all_multi_groups(self):
        """Test when all groups are already multi-row"""
        grouped_data = {
            "group1": pd.DataFrame([
                {"surviving_entity": "e1", "merged_entities": "e2"},
                {"surviving_entity": "e3", "merged_entities": "e4"}
            ]),
            "group2": pd.DataFrame([
                {"surviving_entity": "e5", "merged_entities": "e6"},
                {"surviving_entity": "e7", "merged_entities": "e8"}
            ]),
        }

        result = optimize_groups(grouped_data, target_size=10)

        self.assertEqual(len(result), 2)
        total_rows = sum(len(df) for df in result.values())
        self.assertEqual(total_rows, 4)
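

# For reference: a sketch of the packing behaviour asserted above
# (hypothetical; the combined-group key naming is an assumption). Multi-row
# groups pass through unchanged; single-row groups are concatenated into
# chunks of target_size; the final, possibly smaller chunk is flushed rather
# than dropped (the data-loss bug tested above).
def _optimize_groups_sketch(grouped_data, target_size=10):
    optimized = {}
    singles = []
    for key, df in grouped_data.items():
        if len(df) > 1:
            optimized[key] = df  # keep multi-row groups as-is
        else:
            singles.append(df)
    for i in range(0, len(singles), target_size):
        chunk = singles[i:i + target_size]  # the last chunk may hold < target_size rows
        optimized[f"combined_{i // target_size}"] = pd.concat(chunk, ignore_index=True)
    return optimized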


class TestGetFilePath(unittest.TestCase):
    """Test get_file_path function"""

    def test_get_file_path_basic(self):
        """Test basic file path calculation"""
        uri = "https://w3id.org/oc/meta/br/060100"
        result = get_file_path(uri, dir_split=10000, items_per_file=1000, zip_output=True)
        self.assertEqual(result, "br/060/10000/1000.zip")

    def test_get_file_path_different_number(self):
        """Test file path for different entity number"""
        uri = "https://w3id.org/oc/meta/id/0605500"
        result = get_file_path(uri, dir_split=10000, items_per_file=1000, zip_output=True)
        self.assertEqual(result, "id/060/10000/6000.zip")

    def test_get_file_path_json_output(self):
        """Test file path with JSON output (not zipped)"""
        uri = "https://w3id.org/oc/meta/br/060100"
        result = get_file_path(uri, dir_split=10000, items_per_file=1000, zip_output=False)
        self.assertEqual(result, "br/060/10000/1000.json")

    def test_get_file_path_different_supplier(self):
        """Test file path with different supplier prefix"""
        uri = "https://w3id.org/oc/meta/ra/070250"
        result = get_file_path(uri, dir_split=10000, items_per_file=1000, zip_output=True)
        self.assertEqual(result, "ra/070/10000/1000.zip")

    def test_get_file_path_large_number(self):
        """Test file path for large entity number"""
        uri = "https://w3id.org/oc/meta/br/06025000"
        result = get_file_path(uri, dir_split=10000, items_per_file=1000, zip_output=True)
        self.assertEqual(result, "br/060/30000/25000.zip")

    def test_get_file_path_invalid_uri(self):
        """Test that invalid URI returns None"""
        uri = "https://invalid.uri.com/test"
        result = get_file_path(uri, dir_split=10000, items_per_file=1000, zip_output=True)
        self.assertIsNone(result)

    def test_get_file_path_no_supplier_prefix(self):
        """Test that URI without supplier prefix returns None"""
        uri = "https://w3id.org/oc/meta/br/100"
        result = get_file_path(uri, dir_split=10000, items_per_file=1000, zip_output=True)
        self.assertIsNone(result)
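

# For reference: a sketch of the path arithmetic these tests encode
# (hypothetical; in particular the supplier-prefix regex is an assumption
# based on OC Meta's leading-zero prefixes such as "060"). Example:
# br/060100 with dir_split=10000 and items_per_file=1000 -> number 100,
# directory 10000, file 1000, i.e. "br/060/10000/1000.zip".
def _get_file_path_sketch(uri, dir_split, items_per_file, zip_output):
    import math
    import re

    match = re.match(r".*/oc/meta/([a-z]+)/(0\d+?0)(\d+)$", uri)
    if not match:
        return None  # not an OC Meta URI, or no supplier prefix (e.g. br/100)
    short_name, supplier_prefix, number = match[1], match[2], int(match[3])
    dir_number = math.ceil(number / dir_split) * dir_split  # e.g. 25000 -> 30000
    file_number = math.ceil(number / items_per_file) * items_per_file  # e.g. 5500 -> 6000
    extension = "zip" if zip_output else "json"
    return f"{short_name}/{supplier_prefix}/{dir_number}/{file_number}.{extension}"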


class TestGroupEntities(unittest.TestCase):
    """Test group_entities function"""

    @patch('oc_meta.run.merge.group_entities.get_all_related_entities')
    def test_group_entities_creates_groups(self, mock_get_related):
        """Test that group_entities creates correct groups"""
        mock_get_related.return_value = set()

        df = pd.DataFrame([
            {"surviving_entity": "https://example.org/e1", "merged_entities": "https://example.org/e2"},
            {"surviving_entity": "https://example.org/e3", "merged_entities": "https://example.org/e4"},
        ])

        result = group_entities(df, "http://endpoint")

        self.assertGreater(len(result), 0)

    @patch('oc_meta.run.merge.group_entities.get_all_related_entities')
    def test_group_entities_handles_multiple_merged_entities(self, mock_get_related):
        """Test handling of multiple merged entities (semicolon-separated)"""
        mock_get_related.return_value = set()

        df = pd.DataFrame([
            {"surviving_entity": "https://example.org/e1",
             "merged_entities": "https://example.org/e2; https://example.org/e3; https://example.org/e4"},
        ])

        result = group_entities(df, "http://endpoint")

        self.assertGreater(len(result), 0)

    @patch('oc_meta.run.merge.group_entities.get_all_related_entities')
    def test_group_entities_single_iteration(self, mock_get_related):
        """Test that a single pass is used: exactly one related-entities
        lookup per row (performance fix)"""
        mock_get_related.return_value = set()

        df = pd.DataFrame([
            {"surviving_entity": f"https://example.org/e{i}",
             "merged_entities": f"https://example.org/e{i+100}"}
            for i in range(10)
        ])

        result = group_entities(df, "http://endpoint")

        self.assertEqual(mock_get_related.call_count, 10)
        self.assertGreater(len(result), 0)

    @patch('oc_meta.run.merge.group_entities.get_all_related_entities')
    def test_group_entities_no_double_iteration(self, mock_get_related):
        """Test that DataFrame is iterated only once (not twice)"""
        mock_get_related.return_value = set()

        df_mock = MagicMock(spec=pd.DataFrame)
        df_mock.iterrows.return_value = iter([
            (0, pd.Series({"surviving_entity": "https://example.org/e1",
                           "merged_entities": "https://example.org/e2"})),
            (1, pd.Series({"surviving_entity": "https://example.org/e3",
                           "merged_entities": "https://example.org/e4"})),
        ])
        df_mock.shape = (2,)

        group_entities(df_mock, "http://endpoint")

        self.assertEqual(df_mock.iterrows.call_count, 1,
                         "DataFrame.iterrows() should be called only once")

    @patch('oc_meta.run.merge.group_entities.get_all_related_entities')
    def test_group_entities_file_range_grouping(self, mock_get_related):
        """Test that entities in same file range are grouped together"""
        mock_get_related.return_value = set()

        df = pd.DataFrame([
            {"surviving_entity": "https://w3id.org/oc/meta/br/060100",
             "merged_entities": "https://w3id.org/oc/meta/br/060200"},
            {"surviving_entity": "https://w3id.org/oc/meta/br/060300",
             "merged_entities": "https://w3id.org/oc/meta/br/060400"},
        ])

        result = group_entities(df, "http://endpoint", dir_split=10000, items_per_file=1000)

        self.assertEqual(len(result), 1, "All entities in same file should be in same group")

    @patch('oc_meta.run.merge.group_entities.get_all_related_entities')
    def test_group_entities_different_files_separate_groups(self, mock_get_related):
        """Test that entities in different files are in separate groups"""
        mock_get_related.return_value = set()

        df = pd.DataFrame([
            {"surviving_entity": "https://w3id.org/oc/meta/br/060100",
             "merged_entities": "https://w3id.org/oc/meta/br/060200"},
            {"surviving_entity": "https://w3id.org/oc/meta/br/0601500",
             "merged_entities": "https://w3id.org/oc/meta/br/0601600"},
        ])

        result = group_entities(df, "http://endpoint", dir_split=10000, items_per_file=1000)

        self.assertEqual(len(result), 2, "Entities in different files should be in different groups")
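

# For reference: a sketch of the grouping strategy the tests above imply
# (hypothetical; the helper name and details are assumptions). Each row is
# read in a single pass; the surviving entity, its semicolon-separated merged
# entities, and any related entities are unioned together, and entities whose
# get_file_path coincides are unioned too, so two groups never share a file.
def _group_entities_sketch(df, endpoint, dir_split=10000, items_per_file=1000):
    uf = UnionFind()
    file_to_entity = {}
    rows = []
    for _, row in df.iterrows():  # single pass (see test_group_entities_no_double_iteration)
        surviving = row["surviving_entity"]
        merged = [e.strip() for e in str(row["merged_entities"]).split(";")]
        rows.append((surviving, row))
        related = get_all_related_entities(endpoint, [surviving] + merged)
        for entity in set([surviving] + merged) | related:
            uf.union(surviving, entity)
            path = get_file_path(entity, dir_split, items_per_file, zip_output=True)
            if path is None:
                continue  # not an OC Meta URI: no file-based constraint
            if path in file_to_entity:
                uf.union(entity, file_to_entity[path])  # same file -> same group
            else:
                file_to_entity[path] = entity
    groups = {}
    for surviving, row in rows:
        groups.setdefault(uf.find(surviving), []).append(row)
    return {root: pd.DataFrame(members) for root, members in groups.items()}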


class TestSaveGroupedEntities(unittest.TestCase):
    """Test save_grouped_entities function"""

    def setUp(self):
        if os.path.exists(OUTPUT):
            shutil.rmtree(OUTPUT)
        os.makedirs(OUTPUT, exist_ok=True)

    def tearDown(self):
        if os.path.exists(OUTPUT):
            shutil.rmtree(OUTPUT)

    def test_save_grouped_entities_creates_files(self):
        """Test that files are created correctly"""
        grouped_data = {
            "https://example.org/e1": pd.DataFrame([
                {"surviving_entity": "e1", "merged_entities": "e2"}
            ]),
            "https://example.org/e2": pd.DataFrame([
                {"surviving_entity": "e3", "merged_entities": "e4"}
            ]),
        }

        save_grouped_entities(grouped_data, OUTPUT)

        files = os.listdir(OUTPUT)
        self.assertEqual(len(files), 2)
        self.assertTrue(all(f.endswith('.csv') for f in files))

    def test_save_grouped_entities_preserves_data(self):
        """Test that saved data matches input data"""
        grouped_data = {
            "https://example.org/e1": pd.DataFrame([
                {"surviving_entity": "e1", "merged_entities": "e2"},
                {"surviving_entity": "e3", "merged_entities": "e4"}
            ])
        }

        save_grouped_entities(grouped_data, OUTPUT)

        saved_file = os.path.join(OUTPUT, "e1.csv")
        self.assertTrue(os.path.exists(saved_file))

        loaded_df = pd.read_csv(saved_file)
        self.assertEqual(len(loaded_df), 2)
        self.assertIn("surviving_entity", loaded_df.columns)
        self.assertIn("merged_entities", loaded_df.columns)

    def test_save_grouped_entities_handles_special_characters(self):
        """Test handling of special characters in URIs"""
        grouped_data = {
            "https://example.org/e1?param=value": pd.DataFrame([
                {"surviving_entity": "e1", "merged_entities": "e2"}
            ])
        }

        save_grouped_entities(grouped_data, OUTPUT)

        files = os.listdir(OUTPUT)
        self.assertEqual(len(files), 1)

    def test_save_grouped_entities_creates_output_dir(self):
        """Test that output directory is created if it doesn't exist"""
        new_output = os.path.join(OUTPUT, "subdir", "nested")

        grouped_data = {
            "https://example.org/e1": pd.DataFrame([
                {"surviving_entity": "e1", "merged_entities": "e2"}
            ])
        }

        save_grouped_entities(grouped_data, new_output)

        self.assertTrue(os.path.exists(new_output))
        files = os.listdir(new_output)
        self.assertEqual(len(files), 1)
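

# For reference: a sketch of the saving behaviour asserted above
# (hypothetical, including the filename sanitisation rule). One CSV per
# group, named after the last segment of the group key, with characters that
# are unsafe in filenames replaced so URIs with query strings still work.
def _save_grouped_entities_sketch(grouped_data, output_dir):
    import re

    os.makedirs(output_dir, exist_ok=True)  # create the directory tree if absent
    for key, df in grouped_data.items():
        stem = key.rstrip("/").rsplit("/", 1)[-1]  # "https://example.org/e1" -> "e1"
        stem = re.sub(r"[^\w.-]", "_", stem)  # "e1?param=value" -> "e1_param_value"
        df.to_csv(os.path.join(output_dir, f"{stem}.csv"), index=False)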


class TestIntegration(unittest.TestCase):
    """Integration tests for complete workflow"""

    def setUp(self):
        if os.path.exists(BASE):
            shutil.rmtree(BASE)
        os.makedirs(BASE, exist_ok=True)
        os.makedirs(OUTPUT, exist_ok=True)

    def tearDown(self):
        if os.path.exists(BASE):
            shutil.rmtree(BASE)

    def test_missing_csv_columns_validation_bug(self):
        """Test that a CSV missing the required columns raises a clear
        ValueError (validation bug fix)"""
        from oc_meta.run.merge.group_entities import load_csv

        csv_path = os.path.join(BASE, "invalid.csv")
        with open(csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=["wrong_column"])
            writer.writeheader()
            writer.writerow({"wrong_column": "value"})

        with self.assertRaises(ValueError) as context:
            load_csv(csv_path)

        self.assertIn("missing required columns", str(context.exception))

    @patch('oc_meta.run.merge.group_entities.get_all_related_entities')
    def test_complete_workflow(self, mock_get_related):
        """Test complete workflow from CSV to grouped output"""
        mock_get_related.return_value = set()

        csv_path = os.path.join(BASE, "input.csv")
        with open(csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=["surviving_entity", "merged_entities"])
            writer.writeheader()
            for i in range(5):
                writer.writerow({
                    "surviving_entity": f"https://example.org/e{i}",
                    "merged_entities": f"https://example.org/e{i+100}"
                })

        df = pd.read_csv(csv_path)
        grouped = group_entities(df, "http://endpoint")
        optimized = optimize_groups(grouped, target_size=2)
        save_grouped_entities(optimized, OUTPUT)

        output_files = os.listdir(OUTPUT)
        self.assertGreater(len(output_files), 0)

        total_rows = 0
        for file in output_files:
            file_path = os.path.join(OUTPUT, file)
            df_saved = pd.read_csv(file_path)
            total_rows += len(df_saved)

        self.assertEqual(total_rows, 5)


if __name__ == "__main__":
    unittest.main()