diff --git a/README.md b/README.md
index 36bf36c..6ff53d4 100644
--- a/README.md
+++ b/README.md
@@ -4,9 +4,9 @@ This repository shows three examples of how [Prithvi](https://huggingface.co/ibm
 ## The approach
 ### Background
-To finetune for these tasks in this repository, we make use of [MMSegmentation](https://mmsegmentation.readthedocs.io/en/0.x/), which provides an extensible framework for segmentation tasks.
+To finetune for these tasks in this repository, we make use of [MMSegmentation](https://mmsegmentation.readthedocs.io/en/0.x/), which provides an extensible framework for segmentation tasks.

-[MMSegmentation](https://mmsegmentation.readthedocs.io/en/0.x/) allows us to concatenate necks and heads appropriate for any segmentation downstream task to the encoder, and then perform the finetuning. This only requires setting up a config file detailing the desired model architecture, dataset setup and training strategy.
+[MMSegmentation](https://mmsegmentation.readthedocs.io/en/0.x/) allows us to concatenate necks and heads appropriate for any segmentation downstream task to the encoder, and then perform the finetuning. This only requires the setup of a config file detailing the desired model architecture, dataset setup and training strategy.

 We build extensions on top of [MMSegmentation](https://mmsegmentation.readthedocs.io/en/0.x/) to support our encoder and provide classes to read and augment remote sensing data (from .tiff files) using [MMSegmentation](https://mmsegmentation.readthedocs.io/en/0.x/) data pipelines. These extensions can be found in the [geospatial_fm](./geospatial_fm/) directory, and they are installed as a package on the top of [MMSegmentation](https://mmsegmentation.readthedocs.io/en/0.x/) for ease of use. If more advanced functionality is necessary, it should be added there.
@@ -39,33 +39,27 @@ We reccomend implementing the change after the `ToTensor` operation (which is al
 ### Data

-The flood detection dataset can be downloaded from [Sen1Floods11](https://github.com/cloudtostreet/Sen1Floods11). Splits in the `mmsegmentation` format are available in the `data_splits` folders.
+- Download the flood detection dataset from [Sen1Floods11](https://github.com/cloudtostreet/Sen1Floods11). Splits in the `mmsegmentation` format are available in the `data_splits` folders.
+- Download the [NASA HLS fire scars dataset](https://huggingface.co/datasets/nasa-impact/hls_burn_scars) from Hugging Face.
+- Download the [NASA HLS multi-temporal crop classification dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/multi-temporal-crop-classification) from Hugging Face.
+- Download the [NASA HLS irrigation_scenes dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_irrigation_scenes) from Hugging Face.
+## Running the finetuning
+1. In the `configs` folder there are config examples for each of the segmentation tasks. Complete the configs with your setup specifications. Parts that must be completed are marked with `#TO BE DEFINED BY USER`. They relate to the location where you downloaded the dataset, the pretrained model weights, the test set (e.g. regular one or Bolivia out of bag data) and where you are going to save the experiment outputs.

-The [NASA HLS fire scars dataset](https://huggingface.co/datasets/nasa-impact/hls_burn_scars) can be downloaded from Hugging Face.
+2.
+    a. With the conda env created above activated, run one of the commands below:

-The [NASA HLS multi-temporal crop classification dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/multi-temporal-crop-classification) can be downloaded from Hugging Face.

+        mim train mmsegmentation --launcher pytorch configs/sen1floods11_config.py
+        mim train mmsegmentation --launcher pytorch configs/burn_scars_config.py
+        mim train mmsegmentation --launcher pytorch configs/multi_temporal_crop_classification.py
+        mim train mmsegmentation --launcher pytorch configs/irrigation_scenes_config.py
+    b. To run testing:
-## Running the finetuning
-1. In the `configs` folder there are three config examples for the three segmentation tasks. Complete the configs with your setup specifications. Parts that must be completed are marked with `#TO BE DEFINED BY USER`. They relate to the location where you downloaded the dataset, pretrained model weights, the test set (e.g. regular one or Bolivia out of bag data) and where you are going to save the experiment outputs.
-
-2.
-    a. With the conda env created above activated, run:
-
-    `mim train mmsegmentation --launcher pytorch configs/sen1floods11_config.py` or
-
-    `mim train mmsegmentation --launcher pytorch configs/burn_scars.py` or
-
-    `mim train mmsegmentation --launcher pytorch configs/multi_temporal_crop_classification.py`
-
-    b. To run testing:
-
-    `mim test mmsegmentation configs/sen1floods11_config.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU"` or
-
-    `mim test mmsegmentation configs/burn_scars.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU"` or
-
-    `mim test mmsegmentation configs/multi_temporal_crop_classification.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU"`
+        mim test mmsegmentation configs/sen1floods11_config.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU"
+        mim test mmsegmentation configs/burn_scars_config.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU"
+        mim test mmsegmentation configs/multi_temporal_crop_classification.py --checkpoint /path/to/best/checkpoint/model.pth --eval "mIoU"

 ## Checkpoints on Hugging Face
 We also provide checkpoints on Hugging Face for the [burn scars detection](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar) and the [multi temporal crop classification tasks](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).
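Note on the environment overrides used by `configs/irrigation_scenes_config.py`: the config reads `NUM_FRAMES` and `IMG_SIZE` from the environment and derives the reshape target, the neck embedding size and the token grid from them. The following is a minimal sketch of that derivation, not the config itself. The helper `int_from_env` is a hypothetical name introduced here for illustration, and the constants (6 bands, `embed_dim` 768, `patch_size` 16) are copied from the config in this diff.

```python
import os


def int_from_env(name: str, default: int) -> int:
    """Read an integer setting from the environment, falling back to a default.

    Hypothetical helper; the config in the diff calls int(os.getenv(...)) inline.
    """
    return int(os.getenv(name, str(default)))


# Simulate `export NUM_FRAMES=3`; leave IMG_SIZE unset so its default applies.
os.environ["NUM_FRAMES"] = "3"
os.environ.pop("IMG_SIZE", None)

num_frames = int_from_env("NUM_FRAMES", 3)
img_size = int_from_env("IMG_SIZE", 224)

num_bands = 6    # len(bands) in the config
embed_dim = 768  # backbone embedding size in the config
patch_size = 16  # patch size in the config

# Values the config derives from these settings.
reshape_target = (num_bands, num_frames, img_size, img_size)
output_embed_dim = embed_dim * num_frames
tokens_per_side = img_size // patch_size

print(reshape_target, output_embed_dim, tokens_per_side)
```

One caveat that follows from the loader code in this diff: `LoadSpatioTemporalImagesFromFile` builds exactly three filenames and asserts a leading dimension of 3, so setting `NUM_FRAMES` to anything other than 3 would likely fail at the `Reshape` step unless the loader is changed as well.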
diff --git a/configs/irrigation_scenes_config.py b/configs/irrigation_scenes_config.py
new file mode 100644
index 0000000..ae16488
--- /dev/null
+++ b/configs/irrigation_scenes_config.py
@@ -0,0 +1,361 @@
+import os
+
+# base options
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+load_from = None
+resume_from = None
+cudnn_benchmark = True
+
+custom_imports = dict(imports=["geospatial_fm"])
+
+
+### Configs
+# Data
+# TO BE DEFINED BY USER: Data root to the downloaded irrigation scenes dataset
+data_root = "data/irrigation_scenes/"
+
+dataset_type = "SpatioTemporalDataset"
+num_classes = 1
+num_frames = int(os.getenv("NUM_FRAMES", 3))
+img_size = int(os.getenv("IMG_SIZE", 224))
+num_workers = int(os.getenv("DATA_LOADER_NUM_WORKERS", 2))
+samples_per_gpu = 1
+CLASSES = (0, 1)
+
+img_norm_cfg = dict(
+    means=[0.166, 0.166, 0.166, 0.166, 0.166, 0.166],
+    stds=[0.114, 0.114, 0.114, 0.114, 0.114, 0.114],
+)
+# Sentinel-2 Bands 2,3,4,8A,11,12 (Blue, Green, Red, NIR_Narrow, SWIR1, SWIR2)
+bands = [0, 1, 2, 3, 4, 5]
+
+tile_size = img_size
+orig_nsize = 512
+crop_size = (tile_size, tile_size)
+
+img_suffix = ".tif"
+seg_map_suffix = ".tif"
+
+
+# ignore_index = -1
+# image_nodata = -9999
+# image_nodata_replace = 0
+image_to_float32 = True
+
+# Model
+# TO BE DEFINED BY USER: path to pretrained backbone weights
+pretrained_weights_path = "pretrain_ckpts/Prithvi_100M.pt"
+num_layers = 12
+patch_size = 16
+embed_dim = 768
+num_heads = 12
+tubelet_size = 1
+
+# TRAINING
+# epochs=50
+# eval_epoch_interval = 5
+
+# TO BE DEFINED BY USER: Save directory
+experiment = "test_1"
+project_dir = "finetune_weights/irrigation_scenes"
+work_dir = os.path.join(project_dir, experiment)
+save_path = work_dir
+
+gpu_ids = [0]
+
+splits = {
+    "train": data_root + "training_chips/training_data.txt",
+    "val": data_root + "validation_chips/validation_data.txt",
+    "test": data_root + "validation_chips/validation_data.txt",
+}
+
+# Pipelines
+train_pipeline = [
+    dict(
+        type="LoadSpatioTemporalImagesFromFile",
+        to_float32=image_to_float32,
+        channels_last=True,
+    ),
+    dict(
+        type="LoadGeospatialAnnotations",
+        reduce_zero_label=False,
+        nodata=255,
+        nodata_replace=2,
+    ),
+    dict(type="RandomFlip", prob=0.5),  # flip on axis 1, assume channel last NHWC
+    dict(type="ToTensor", keys=["img", "gt_semantic_seg"]),
+    dict(
+        type="TorchPermute",
+        keys=["img"],
+        order=(0, 3, 1, 2),  # channel last to channels first NCHW
+    ),
+    dict(type="TorchNormalize", **img_norm_cfg),
+    dict(type="TorchRandomCrop", crop_size=crop_size),
+    dict(
+        type="Reshape",
+        keys=["img"],
+        new_shape=(len(bands), num_frames, tile_size, tile_size),
+    ),
+    dict(type="Reshape", keys=["gt_semantic_seg"], new_shape=(1, tile_size, tile_size)),
+    dict(type="CastTensor", keys=["gt_semantic_seg"], new_type="torch.LongTensor"),
+    dict(type="Collect", keys=["img", "gt_semantic_seg"]),
+]
+
+val_pipeline = [
+    dict(
+        type="LoadSpatioTemporalImagesFromFile",
+        to_float32=image_to_float32,
+        channels_last=True,
+    ),
+    dict(
+        type="LoadGeospatialAnnotations",
+        reduce_zero_label=False,
+        nodata=255,
+        nodata_replace=2,
+    ),
+    dict(type="ToTensor", keys=["img", "gt_semantic_seg"]),
+    dict(
+        type="TorchPermute",
+        keys=["img"],
+        order=(0, 3, 1, 2),  # channel last to channels first NCHW
+    ),
+    dict(type="TorchNormalize", **img_norm_cfg),
+    dict(type="TorchRandomCrop", crop_size=crop_size),
+    dict(
+        type="Reshape",
+        keys=["img"],
+        new_shape=(len(bands), num_frames, tile_size, tile_size),
+    ),
+    dict(type="Reshape", keys=["gt_semantic_seg"], new_shape=(1, tile_size, tile_size)),
+    dict(type="CastTensor", keys=["gt_semantic_seg"], new_type="torch.LongTensor"),
+    dict(
+        type="Collect",
+        keys=["img", "gt_semantic_seg"],
+        meta_keys=[
+            "img_info",
+            "ann_info",
+            "seg_fields",
+            "img_prefix",
+            "seg_prefix",
+            "filename",
+            "ori_filename",
+            "img",
+            "img_shape",
+            "ori_shape",
+            "pad_shape",
+            "scale_factor",
+            "img_norm_cfg",
+            "gt_semantic_seg",
+        ],
+    ),
+]
+
+test_pipeline = [
+    dict(
+        type="LoadSpatioTemporalImagesFromFile",
+        to_float32=image_to_float32,
+        channels_last=True,
+    ),
+    dict(type="ToTensor", keys=["img"]),
+    dict(
+        type="TorchPermute",
+        keys=["img"],
+        order=(0, 3, 1, 2),  # channel last to channels first NCHW
+    ),
+    dict(type="TorchNormalize", **img_norm_cfg),
+    dict(type="TorchRandomCrop", crop_size=crop_size),  # TODO remove hardcoded 224 size
+    dict(
+        type="Reshape",
+        keys=["img"],
+        new_shape=(len(bands), num_frames, tile_size, tile_size),
+    ),
+    dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"),
+    dict(
+        type="CollectTestList",
+        keys=["img"],
+        meta_keys=[
+            "img_info",
+            "seg_fields",
+            "img_prefix",
+            "seg_prefix",
+            "filename",
+            "ori_filename",
+            "img",
+            "img_shape",
+            "ori_shape",
+            "pad_shape",
+            "scale_factor",
+            "img_norm_cfg",
+        ],
+    ),
+]
+
+CLASSES = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
+
+data = dict(
+    samples_per_gpu=samples_per_gpu,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        # CLASSES=CLASSES,
+        reduce_zero_label=True,
+        data_root=data_root,
+        img_dir="month1",
+        ann_dir="masks",
+        pipeline=train_pipeline,
+        img_suffix=img_suffix,
+        seg_map_suffix=seg_map_suffix,
+        # split=splits["train"],
+    ),
+    val=dict(
+        type=dataset_type,
+        # CLASSES=CLASSES,
+        reduce_zero_label=True,
+        data_root=data_root,
+        img_dir="month1",
+        ann_dir="masks",
+        pipeline=val_pipeline,
+        img_suffix=img_suffix,
+        seg_map_suffix=seg_map_suffix,
+        # split=splits["val"],
+    ),
+    test=dict(
+        type=dataset_type,
+        # CLASSES=CLASSES,
+        reduce_zero_label=True,
+        data_root=data_root,
+        img_dir="month1",
+        ann_dir="masks",
+        pipeline=test_pipeline,
+        img_suffix=img_suffix,
+        seg_map_suffix=seg_map_suffix,
+        # split=splits["test"],
+    ),
+)
+# gt_seg_map_loader_cfg=dict(nodata=-1, nodata_replace=2)))
+
+# Adam optimizer; weight decay is applied uniformly (no paramwise_cfg is set)
+optimizer = dict(type="Adam", lr=1.5e-5, betas=(0.9, 0.999), weight_decay=0.05)
+optimizer_config = dict(grad_clip=None)
+lr_config = dict(
+    policy="poly",
+    warmup="linear",
+    warmup_iters=1500,
+    warmup_ratio=1e-6,
+    power=1.0,
+    min_lr=0.0,
+    by_epoch=False,
+)
+
+log_config = dict(
+    interval=20,
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type="TensorboardLoggerHook", by_epoch=False),
+    ],
+)
+
+checkpoint_config = dict(by_epoch=True, interval=10, out_dir=save_path)
+
+evaluation = dict(
+    interval=1180, metric="mIoU", pre_eval=True, save_best="mIoU", by_epoch=False
+)
+reduce_train_set = dict(reduce_train_set=False)
+reduce_factor = dict(reduce_factor=1)
+
+optimizer_config = dict(grad_clip=None)
+
+runner = dict(type="IterBasedRunner", max_iters=10000)
+workflow = [("train", 1)]
+
+norm_cfg = dict(type="BN", requires_grad=True)
+
+loss_weights_multi = [
+    1.5652886,
+    0.46067129,
+    0.59387921,
+    0.48431193,
+    0.65555127,
+    0.73865282,
+    0.77616475,
+    3.46336277,
+    1.01650963,
+    1.87640752,
+    1.52960976,
+    1.49788817,
+    57.55048277,
+    1.97697006,
+    2.34793961,
+    0.83456613,
+]
+
+# loss_func = dict(type='DiceLoss', use_sigmoid=False, loss_weight=1, class_weight=loss_weights_multi)
+loss_func = dict(
+    type="CrossEntropyLoss",
+    use_sigmoid=False,
+    class_weight=loss_weights_multi,
+    avg_non_ignore=True,
+)
+
+
+output_embed_dim = embed_dim * num_frames
+
+model = dict(
+    type="TemporalEncoderDecoder",
+    frozen_backbone=False,
+    backbone=dict(
+        type="TemporalViTEncoder",
+        pretrained=pretrained_weights_path,
+        img_size=img_size,
+        patch_size=patch_size,
+        num_frames=num_frames,
+        tubelet_size=1,
+        in_chans=len(bands),
+        embed_dim=embed_dim,
+        depth=num_layers,
+        num_heads=num_heads,
+        mlp_ratio=4.0,
+        norm_pix_loss=False,
+    ),
+    neck=dict(
+        type="ConvTransformerTokensToEmbeddingNeck",
+        embed_dim=embed_dim * num_frames,
+        output_embed_dim=output_embed_dim,
+        drop_cls_token=True,
+        Hp=img_size // patch_size,
+        Wp=img_size // patch_size,
+    ),
+    decode_head=dict(
+        num_classes=len(loss_weights_multi),
+        in_channels=output_embed_dim,
+        type="FCNHead",
+        in_index=-1,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=loss_func,
+    ),
+    auxiliary_head=dict(
+        num_classes=len(loss_weights_multi),
+        in_channels=output_embed_dim,
+        type="FCNHead",
+        in_index=-1,
+        channels=256,
+        num_convs=2,
+        concat_input=False,
+        dropout_ratio=0.1,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=loss_func,
+    ),
+    train_cfg=dict(),
+    test_cfg=dict(
+        mode="slide",
+        stride=(int(tile_size / 2), int(tile_size / 2)),
+        crop_size=(tile_size, tile_size),
+    ),
+)
diff --git a/geospatial_fm/datasets.py b/geospatial_fm/datasets.py
index 76a63eb..c3b1173 100644
--- a/geospatial_fm/datasets.py
+++ b/geospatial_fm/datasets.py
@@ -22,4 +22,56 @@ def __init__(self, CLASSES=(0, 1), PALETTE=None, **kwargs):
             # ignore_index=2,
             **kwargs)

-        self.gt_seg_map_loader = LoadGeospatialAnnotations(reduce_zero_label=reduce_zero_label, **gt_seg_map_loader_cfg)
\ No newline at end of file
+        self.gt_seg_map_loader = LoadGeospatialAnnotations(
+            reduce_zero_label=reduce_zero_label, **gt_seg_map_loader_cfg
+        )
+
+
+@DATASETS.register_module()
+class SpatioTemporalDataset(GeospatialDataset):
+    """
+    Time-series dataset for irrigation data at
+    https://huggingface.co/datasets/ibm-nasa-geospatial/hls_irrigation_scenes
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, split):
+        """Load annotation from directory.
+
+        Args:
+            img_dir (str): Path to image directory
+            img_suffix (str): Suffix of images.
+            ann_dir (str|None): Path to annotation directory.
+            seg_map_suffix (str|None): Suffix of segmentation maps.
+            split (str|None): Split txt file. If split is specified, only file
+                with suffix in the splits will be loaded. Otherwise, all images
+                in img_dir/ann_dir will be loaded. Default: None
+
+        Returns:
+            list[dict]: All image info of dataset.
+        """
+
+        img_infos = []
+        if split is not None:
+            raise NotImplementedError
+        else:
+            for img in self.file_client.list_dir_or_file(
+                dir_path=img_dir, list_dir=False, suffix=img_suffix, recursive=True
+            ):
+                # Get 'T10SFG_chip22.tif' basename from 'scene_m01_T10SFG_chip22.tif'
+                basename = "_".join(img.split(sep="_")[2:])
+                img_info = dict(
+                    filename_t1=f"scene_m01_{basename}",
+                    filename_t2=f"scene_m02_{basename}",
+                    filename_t3=f"scene_m03_{basename}",
+                    filename_t4=f"scene_m04_{basename}",
+                )
+                if ann_dir is not None:
+                    seg_map = f"mask_{basename.replace(img_suffix, seg_map_suffix)}"
+                    img_info["ann"] = dict(seg_map=seg_map)
+                img_infos.append(img_info)
+            img_infos = sorted(img_infos, key=lambda x: x["filename_t1"])
+
+        return img_infos
diff --git a/geospatial_fm/geospatial_pipelines.py b/geospatial_fm/geospatial_pipelines.py
index 62633b1..ca55e6b 100644
--- a/geospatial_fm/geospatial_pipelines.py
+++ b/geospatial_fm/geospatial_pipelines.py
@@ -315,19 +315,101 @@ def __repr__(self):
         return repr_str


+@PIPELINES.register_module()
+class LoadSpatioTemporalImagesFromFile(LoadGeospatialImageFromFile):
+    """
+    Load a time-series dataset from multiple files.
+
+    Currently hardcoded to assume that GeoTIFF files are structured in four
+    different 'monthX' folders like so:
+
+    - month1/
+      - scene_m01_XXXXXX_chip01.tif
+      - scene_m01_XXXXXX_chip02.tif
+    - month2/
+      - scene_m02_XXXXXX_chip01.tif
+      - scene_m02_XXXXXX_chip02.tif
+    - month3/
+      - scene_m03_XXXXXX_chip01.tif
+      - scene_m03_XXXXXX_chip02.tif
+    - month4/
+      - scene_m04_XXXXXX_chip01.tif
+      - scene_m04_XXXXXX_chip02.tif
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def __call__(self, results):
+        """
+        Call functions to load image and get image meta information.
+
+        Args:
+            results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
+        """
+        if results.get("img_prefix") is not None:
+            img_prefix = results["img_prefix"]
+            assert img_prefix.endswith("month1")
+            filenames = [
+                osp.join(img_prefix, results["img_info"]["filename_t1"]),  # June
+                osp.join(
+                    img_prefix.replace("month1", "month2"),  # July
+                    results["img_info"]["filename_t2"],
+                ),
+                osp.join(
+                    img_prefix.replace("month1", "month3"),  # August
+                    results["img_info"]["filename_t3"],
+                ),
+                # osp.join(
+                #     img_prefix.replace("month1", "month4"),  # September
+                #     results["img_info"]["filename_t4"],
+                # ),
+            ]
+        else:
+            raise NotImplementedError
+
+        img = np.stack(arrays=list(map(open_tiff, filenames)), axis=0)
+        assert img.shape == (3, 512, 512, 6)  # Time, Height, Width, Channels
+        if not self.channels_last:
+            img = np.transpose(a=img, axes=(0, 3, 1, 2))
+            assert img.shape == (3, 6, 512, 512)  # Time, Channels, Height, Width
+        if self.to_float32:
+            img = img.astype(dtype=np.float32)
+        if self.nodata is not None:
+            img = np.where(img == self.nodata, self.nodata_replace, img)
+
+        results["filename"] = filenames[0]
+        results["ori_filename"] = results["img_info"]["filename_t1"]
+        results["img"] = img
+        results["img_shape"] = img.shape
+        results["ori_shape"] = img.shape
+        # Set initial values for default meta_keys
+        results["pad_shape"] = img.shape
+        results["scale_factor"] = 1.0
+        results["flip"] = False
+        num_channels = 1 if len(img.shape) < 3 else img.shape[0]
+        results["img_norm_cfg"] = dict(
+            mean=np.zeros(num_channels, dtype=np.float32),
+            std=np.ones(num_channels, dtype=np.float32),
+            to_rgb=False,
+        )
+        return results
+
+
+@PIPELINES.register_module()
 class LoadGeospatialAnnotations(object):
     """Load annotations for semantic segmentation.

     Args:
-        to_uint8 (bool): Whether to convert the loaded label to a uint8
         reduce_zero_label (bool): Whether reduce all label value by 1.
             Usually used for datasets where 0 is background label.
             Default: False.
         nodata (float/int): no data value to substitute to nodata_replace
-        nodata_replace (float/int): value to use to replace no data
-
-
+        nodata_replace (float/int): The value used to replace nodata values
+            with. Default: -1.
     """

     def __init__(
@@ -341,7 +423,10 @@ def __init__(
         self.nodata_replace = nodata_replace

     def __call__(self, results):
-        if results.get("seg_prefix", None) is not None:
+        if results.get("ann_info", {}).get("seg_map") is None:
+            results["ann_info"] = {"seg_map": results["img_info"]["ann"]["seg_map"]}
+
+        if results.get("seg_prefix") is not None:
             filename = osp.join(results["seg_prefix"], results["ann_info"]["seg_map"])
         else:
             filename = results["ann_info"]["seg_map"]
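
Note on the time-series loading path: `SpatioTemporalDataset.load_annotations` derives the per-month filenames from the month-1 chip name, and `LoadSpatioTemporalImagesFromFile` stacks the frames as (Time, Height, Width, Channels). The sketch below reproduces both steps with the standard library only, so the naming convention and the axis order can be checked without GeoTIFF data. The helper names `monthly_filenames` and `permuted_shape` are hypothetical, introduced for illustration; `permuted_shape` relies on the rule that a transpose with `axes` yields `out_shape[i] = in_shape[axes[i]]`, which is how `np.transpose` behaves.

```python
from typing import List, Sequence, Tuple


def monthly_filenames(first_month_file: str, num_months: int = 3) -> List[str]:
    """Derive per-month filenames from a month-1 chip name.

    Assumes the 'scene_mNN_<tile>_<chip>.tif' convention shown in the diff.
    """
    # Drop the 'scene' and 'm01' tokens, keeping e.g. 'T10SFG_chip22.tif'.
    basename = "_".join(first_month_file.split("_")[2:])
    return [f"scene_m{month:02d}_{basename}" for month in range(1, num_months + 1)]


def permuted_shape(shape: Sequence[int], axes: Sequence[int]) -> Tuple[int, ...]:
    """Shape produced by a transpose with `axes` (out[i] = shape[axes[i]])."""
    return tuple(shape[axis] for axis in axes)


names = monthly_filenames("scene_m01_T10SFG_chip22.tif")

# Three stacked frames of 512 x 512 pixels with 6 bands each.
stacked = (len(names), 512, 512, 6)  # Time, Height, Width, Channels

# (0, 3, 1, 2) moves the channel axis directly after the time axis.
channels_first = permuted_shape(stacked, (0, 3, 1, 2))  # Time, Channels, Height, Width

print(names)
print(channels_first)
```

The same order, `(0, 3, 1, 2)`, is the one the `TorchPermute` step uses in the config pipelines. An order of `(0, 2, 3, 1)` applied to a (Time, Height, Width, Channels) array would instead give (Time, Width, Channels, Height).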