From bdc3d07a4a856228656b379db0a1aeaa1d42bc56 Mon Sep 17 00:00:00 2001 From: TrifonovRoman Date: Wed, 16 Oct 2024 13:50:34 +0800 Subject: [PATCH 1/2] add create_dataset_classifier.py --- scr/create_dataset_classifier.py | 97 ++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 scr/create_dataset_classifier.py diff --git a/scr/create_dataset_classifier.py b/scr/create_dataset_classifier.py new file mode 100644 index 0000000..a8c133c --- /dev/null +++ b/scr/create_dataset_classifier.py @@ -0,0 +1,97 @@ +import json +from pager import PageModel, PageModelUnit, WordsAndStylesModel, SpGraph4NModel, WordsAndStylesToSpGraph4N +from pager.page_model.sub_models.dtype import Style, StyleWord, ImageSegment +import os +import argparse + +page_model = PageModel([ + PageModelUnit("words_and_styles", sub_model=WordsAndStylesModel(), extractors=[], converters={}), + PageModelUnit("graph", sub_model=SpGraph4NModel(), extractors=[], + converters={"words_and_styles": WordsAndStylesToSpGraph4N()}), +]) + + +def is_one_block(word1, word2, blocks): + for block in blocks: + if block.is_intersection(word1) and block.is_intersection(word2): + return 1 + return 0 + + +def get_subgraphs_from_blocks(blocks, graph): + subgraphs = [] + + words = [w.segment for w in page_model.page_units[0].sub_model.words] + + for block in blocks: + subgraph = { + "A": [[], []], + "nodes_feature": [], + "edges_feature": [], + "true_edges": [] + } + + block_nodes = [] + for i, word in enumerate(words): + if block.is_intersection(word): + block_nodes.append(i) + subgraph["nodes_feature"].append(graph["nodes_feature"][i]) + + for i, (n1, n2) in enumerate(zip(graph["A"][0], graph["A"][1])): + if n1 in block_nodes and n2 in block_nodes: + subgraph["A"][0].append(block_nodes.index(n1)) + subgraph["A"][1].append(block_nodes.index(n2)) + subgraph["edges_feature"].append(graph["edges_feature"][i]) + subgraph["true_edges"].append(is_one_block(words[n1], words[n2], blocks)) + + subgraphs.append(subgraph) + + return subgraphs + + +def get_graphs_from_file(file_name): + with open(file_name, "r") as f: + info_img = json.load(f) + + publaynet_rez = info_img["blocks"] + pager_rez = info_img["additional_info"] + + page_model.from_dict(pager_rez) + page_model.extract() + graph = page_model.to_dict() + + blocks = [ImageSegment(dict_p_size=bl) for bl in publaynet_rez] + + subgraphs = get_subgraphs_from_blocks(blocks, graph) + + return subgraphs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Create classification dataset from PubLayNet and PageR') + parser.add_argument('--path_dir_jsons', type=str, nargs='?', required=True, + help='Path to the directory with json dataset') + parser.add_argument('--path_rez_json', type=str, nargs='?', required=True, + help='Path to the resulting json file') + + args = parser.parse_args() + args.path_dir_jsons, + args.path_rez_json + + dataset = { + "dataset": [] + } + + files = os.listdir(args.path_dir_jsons) + N = len(files) + + for i, json_file in enumerate(files): + subgraphs = get_graphs_from_file(os.path.join(args.path_dir_jsons, json_file)) + + for subgraph in subgraphs: + dataset["dataset"].append(subgraph) + + print(f"{(i + 1) / N * 100:.2f} %" + 20 * " ", end='\r') + + with open(args.path_rez_json, "w") as f: + json.dump(dataset, f) From 604238653d2ca6f5c36504d09fd96307fc78b977 Mon Sep 17 00:00:00 2001 From: TrifonovRoman Date: Wed, 16 Oct 2024 19:20:57 +0800 Subject: [PATCH 2/2] fixes --- scr/create_dataset_classifier.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/scr/create_dataset_classifier.py b/scr/create_dataset_classifier.py index a8c133c..5984233 100644 --- a/scr/create_dataset_classifier.py +++ b/scr/create_dataset_classifier.py @@ -28,7 +28,7 @@ def get_subgraphs_from_blocks(blocks, graph): "A": [[], []], "nodes_feature": [], "edges_feature": [], - "true_edges": [] + "label": block.get_info('label') } block_nodes = [] @@ -42,12 +42,15 @@ def get_subgraphs_from_blocks(blocks, graph): subgraph["A"][0].append(block_nodes.index(n1)) subgraph["A"][1].append(block_nodes.index(n2)) subgraph["edges_feature"].append(graph["edges_feature"][i]) - subgraph["true_edges"].append(is_one_block(words[n1], words[n2], blocks)) subgraphs.append(subgraph) return subgraphs +def get_block(bl): + b = ImageSegment(dict_p_size=bl) + b.add_info('label', bl['label']) + return b def get_graphs_from_file(file_name): with open(file_name, "r") as f: @@ -60,7 +63,7 @@ def get_graphs_from_file(file_name): page_model.extract() graph = page_model.to_dict() - blocks = [ImageSegment(dict_p_size=bl) for bl in publaynet_rez] + blocks = [get_block(bl) for bl in publaynet_rez] subgraphs = get_subgraphs_from_blocks(blocks, graph) @@ -86,8 +89,11 @@ def get_graphs_from_file(file_name): N = len(files) for i, json_file in enumerate(files): - subgraphs = get_graphs_from_file(os.path.join(args.path_dir_jsons, json_file)) - + subgraphs = [] + try: + subgraphs = get_graphs_from_file(os.path.join(args.path_dir_jsons, json_file)) + except: + print(json_file) for subgraph in subgraphs: dataset["dataset"].append(subgraph)