diff --git a/preprocess.py b/preprocess.py index 5e337b8..2173c1a 100644 --- a/preprocess.py +++ b/preprocess.py @@ -236,13 +236,10 @@ def blip_captioning_dataset( ) caption = processor.decode(out[0], skip_special_tokens=True) - # BLIP 2 lowercases all caps tokens. This should properly replace them w/o messing up subwords. I'm sure there's a better way to do this. + # BLIP 2 lowercases all caps tokens, so do a quick find and replace for token in substitution_tokens: - print(token) - sub_cap = " " + caption + " " - print(sub_cap) - sub_cap = sub_cap.replace(" " + token.lower() + " ", " " + token + " ") - caption = sub_cap.strip() + pattern = r"\b" + re.escape(token.lower()) + r"\b" + caption = re.sub(pattern, token, caption) captions.append(caption) print("Generated captions", captions)