Implementation of the paper "Local Hierarchy-Aware Text-Label Association for Hierarchical Text Classification" (HTLA), accepted at the 2024 IEEE 11th International Conference on Data Science and Advanced Analytics (DSAA). [Paper link](https://doi.org/10.1109/DSAA61799.2024.10722840)
- Python >= 3.6
- torch >= 1.6.0
- transformers >= 4.30.2
- The libraries below are required only if you want to use GAT/GCN as the graph encoder (a quick version check is sketched after this list):
  - torch-geometric == 2.4.0
  - torch-sparse == 0.6.17
  - torch-scatter == 2.1.1
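As a quick sanity check (not part of the repository), the following snippet prints the installed versions; the `torch_geometric` import is optional and only matters if you plan to use GCN/GAT:

```python
# Hypothetical environment check, not part of this repo.
import torch
import transformers

print("torch:", torch.__version__)                # expect >= 1.6.0
print("transformers:", transformers.__version__)  # expect >= 4.30.2
try:
    import torch_geometric
    print("torch-geometric:", torch_geometric.__version__)  # needed only for GCN/GAT
except ImportError:
    print("torch-geometric not installed (fine unless you use GCN/GAT)")
```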
- All datasets are publicly available and can be accessed at WOS, RCV1-V2, and NYT.
- We followed the specific details mentioned in the contrastive-htc repository to obtain and preprocess the original datasets (WOS, RCV1-V2, and NYT).
- After accessing the datasets, run the scripts in the `preprocess` folder for each dataset separately to obtain the tokenized version of the dataset and the related files (a toy sketch of this tokenization is shown after this list). These will be added to the `data/x` folder, where x is the name of the dataset, with possible choices wos, rcv, and nyt.
- Detailed steps on how to obtain and preprocess each dataset are given in the README of the `preprocess` folder.
- For reference, we have added tokenized versions of the WOS and NYT datasets along with their related files in the `data` folder. The RCV1-V2 dataset exceeds 400 MB in size, so it could not be uploaded due to GitHub's file size limits.
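For intuition, here is a minimal sketch of the kind of tokenization the `preprocess` scripts produce. It assumes the `bert-base-uncased` tokenizer; the sample text is made up, and the authoritative steps are in the `preprocess` README:

```python
# Illustrative sketch only; the real logic lives in the preprocess/ scripts.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed backbone
doc = "Sample Web of Science abstract ..."
encoded = tokenizer(doc, truncation=True, max_length=512)
print(encoded["input_ids"][:10])  # token ids of the kind stored under data/wos
```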
The script `train.py` is used to train all models; different variants are obtained by setting the arguments below.
```shell
python train.py --name='ckp_htla' --batch 10 --data='wos' --graph 1 --graph_type='graphormer' --msl 1 --msl_pen 1 --mg_list 0.1 0.1
```
Some important arguments:
- `--name` name of the directory in which your model will be saved. For example, the above model will be saved in `./HTLA/data/wos/ckp_htla`
- `--data` name of the dataset directory which contains your data and related files. Possible options are 'wos', 'rcv', and 'nyt'
- `--graph` whether to use a graph encoder
- `--graph_type` type of graph encoder. Possible choices are 'graphormer', 'GCN', and 'GAT'. HTLA uses Graphormer as the graph encoder; the code for the graph encoders is in the script `graph.py`
- `--msl` whether the Margin Separation Loss (MSL) is applied or not. The code for MSL is in `criterion.py` (a rough sketch is given after this list)
- `--msl_pen` weight for the MSL component (we set it to 1 for all datasets)
- `--mg_list` margin distance for each level (we use 0.1 as the margin distance for each level in all datasets)
  - For rcv: `--mg_list 0.1 0.1 0.1`
  - For nyt: `--mg_list 0.1 0.1 0.1 0.1 0.1 0.1`
  - Note: for RCV and NYT, the last level contains only 1 and 2 labels, respectively, so MSL is not applied there.
- The node feature size is fixed at 768 to match the text feature size and is not exposed as a runtime argument.
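The exact MSL formulation is implemented in `criterion.py`; the sketch below is only one plausible reading of a per-level margin separation loss, in which the positive labels of a sample are pushed closer to its text representation than the negative labels by at least that level's margin. All tensor shapes and names here are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def margin_separation_loss(text_emb, label_emb, target, label_levels, margins):
    """Hinge-style per-level separation (illustrative, not the repo's exact code).

    text_emb:     (B, 768) text representations
    label_emb:    (L, 768) label representations (e.g., from the graph encoder)
    target:       (B, L) multi-hot gold labels
    label_levels: (L,) 0-based hierarchy level of each label
    margins:      per-level margins, i.e., the --mg_list values
    """
    # Cosine similarity between every text and every label: (B, L)
    sim = F.cosine_similarity(text_emb.unsqueeze(1), label_emb.unsqueeze(0), dim=-1)
    loss = text_emb.new_zeros(())
    for lvl, margin in enumerate(margins):
        cols = label_levels == lvl          # labels belonging to this level
        t, s = target[:, cols].float(), sim[:, cols]
        pos = (t * s).sum(1) / t.sum(1).clamp(min=1)              # mean positive similarity
        neg = ((1 - t) * s).sum(1) / (1 - t).sum(1).clamp(min=1)  # mean negative similarity
        loss = loss + F.relu(margin - (pos - neg)).mean()         # want pos - neg >= margin
    return loss
```

Because `margins` carries one entry per level, any level without a margin (the last level of RCV and NYT, per the note above) is simply skipped.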
To train BERT+Graphormer without MSL (ablation):

```shell
python train.py --name='ckp_bgrapho' --batch 10 --data='wos' --graph 1 --graph_type='graphormer' --msl 0
```
To train the BERT-only baseline (no graph encoder):

```shell
python train.py --name='ckp_bert' --batch 10 --data='wos' --graph 0
```
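The `--graph_type` flag in the commands above switches the structure encoder. A rough sketch of how such a switch might look with torch-geometric follows; the factory name is hypothetical, and the actual encoders (including Graphormer) are implemented in `graph.py`:

```python
# Illustrative sketch; the repo's encoders live in graph.py.
from torch_geometric.nn import GCNConv, GATConv  # optional dependencies (see above)

def build_graph_encoder(graph_type: str, dim: int = 768):
    """Hypothetical factory mirroring the --graph_type choices."""
    if graph_type == "GCN":
        return GCNConv(dim, dim)            # node feature size fixed at 768
    if graph_type == "GAT":
        return GATConv(dim, dim, heads=1)
    raise NotImplementedError("'graphormer' uses the custom encoder in graph.py")
```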
To evaluate a trained model on the test set, run the script `test.py`:
```shell
python test.py --name ckp_htla --data wos --extra _macro
```
Some important arguments:
- `--name` name of the directory which contains the saved checkpoint. The checkpoint is saved in `../HTLA/data/wos/` when working with the WOS dataset
- `--data` name of the dataset directory which contains your data and related files
- `--extra` two checkpoints are kept, based on the best macro-F1 and micro-F1, respectively. The possible choices are `_macro` and `_micro` to select between the two checkpoints (see the toy computation below)
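For reference, the two checkpoint criteria correspond to the standard macro- and micro-averaged F1 scores; an illustrative computation with scikit-learn (not part of the repo) is:

```python
import numpy as np
from sklearn.metrics import f1_score

y_true = np.array([[1, 0, 1], [0, 1, 0]])  # multi-hot gold labels (toy example)
y_pred = np.array([[1, 0, 0], [0, 1, 0]])  # thresholded model predictions
print(f1_score(y_true, y_pred, average="macro"))  # criterion behind `_macro`
print(f1_score(y_true, y_pred, average="micro"))  # criterion behind `_micro`
```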
If you find our work helpful, please cite it using the following BibTeX entry:
```bibtex
@INPROCEEDINGS{10722840,
  author={Kumar, Ashish and Toshniwal, Durga},
  booktitle={2024 IEEE 11th International Conference on Data Science and Advanced Analytics (DSAA)},
  title={Local Hierarchy-Aware Text-Label Association for Hierarchical Text Classification},
  year={2024},
  volume={},
  number={},
  pages={1-10},
  doi={10.1109/DSAA61799.2024.10722840}}
```