144 changes: 144 additions & 0 deletions scispacy/entity_merging.py
@@ -0,0 +1,144 @@
"""
Merge entities recognized by different NER models,
optionally incorporating abbreviation long forms as entities.

Usage
-----

.. code-block:: python

import spacy
from scispacy.entity_merging import merge_entities

text = "Spinal and bulbar muscular atrophy (SBMA) is an inherited motor neuron disease."
doc = merge_entities(
text,
model_names=["en_core_sci_sm", "en_core_sci_lg"],
use_abbreviations=True,
)
print(doc.ents)

Or, at a lower level, as a function you call on a list of spans you have already collected from a processed doc:

.. code-block:: python

from scispacy.entity_merging import merge_overlapping_spans

merged = merge_overlapping_spans(all_spans, doc)
"""

from typing import List, Tuple

import spacy
from spacy.language import Language
from spacy.tokens import Doc, Span
from spacy.util import filter_spans


def merge_overlapping_spans(spans: List[Span], doc: Doc) -> List[Span]:
    """
    Given a flat list of (possibly overlapping) spans that all reference
    *doc*, return a filtered list keeping the longest non-overlapping
    spans. When overlapping spans have equal length, the one that starts
    earliest in the text wins.
    """
    if not spans:
        return []
    if any(span.doc is not doc for span in spans):
        raise ValueError("all spans must reference the same Doc")
    # filter_spans keeps the longest span when spans overlap
    return filter_spans(spans)


def _collect_entity_spans(
text: str,
model_names: List[str],
use_abbreviations: bool,
) -> Tuple[Doc, List[Span]]:
"""
Run *text* through each model in *model_names*, collect every entity
span, and (optionally) add spans for abbreviation long forms.

Returns (base_doc, all_spans) where *base_doc* is the Doc produced by
the first model and *all_spans* are Span objects that all reference
*base_doc*.
"""
if not model_names:
raise ValueError("model_names must contain at least one model")

pipelines = [spacy.load(name) for name in model_names]
base_nlp = pipelines[0]
base_doc = base_nlp(text)

all_spans: List[Span] = list(base_doc.ents)

# Entities from the remaining models need to be projected onto base_doc
# because spaCy Spans are tied to a specific Doc object.
for nlp in pipelines[1:]:
other_doc = nlp(text)
for ent in other_doc.ents:
            # char_span returns None (it does not raise) when the character
            # offsets do not align with base_doc's token boundaries.
            span = base_doc.char_span(ent.start_char, ent.end_char, label=ent.label_)
            if span is not None:
                all_spans.append(span)

if use_abbreviations:
all_spans = _add_abbreviation_spans(base_nlp, base_doc, all_spans)

return base_doc, all_spans


def _add_abbreviation_spans(
    nlp: Language, doc: Doc, spans: List[Span]
) -> List[Span]:
    """
    Use abbreviation long forms detected by AbbreviationDetector to create
    additional entity spans on *doc*, adding the detector to the pipeline
    if it is missing.
    """
    if not nlp.has_pipe("abbreviation_detector"):
        # Importing the module registers the "abbreviation_detector" factory.
        from scispacy.abbreviation import AbbreviationDetector  # noqa: F401

        nlp.add_pipe("abbreviation_detector")
        # Re-running the pipeline produces a *new* Doc, so the long forms
        # found below must be projected back onto the original doc.
        abbrev_doc = nlp(doc.text)
    else:
        abbrev_doc = doc

    for abrv in abbrev_doc._.abbreviations:
        long_form = abrv._.long_form
        if long_form is None:
            continue
        # Project the long form onto *doc* so that every span in the
        # returned list references the same Doc object.
        span = doc.char_span(
            long_form.start_char, long_form.end_char, label=long_form.label_
        )
        if span is not None:
            spans.append(span)

    return spans


def merge_entities(
text: str,
model_names: List[str],
use_abbreviations: bool = True,
) -> Doc:
"""
Run *text* through multiple spaCy NER models, collect all recognized
entities, optionally add abbreviation long forms, and return a single
Doc whose ``.ents`` contains the longest non-overlapping entity spans.

Parameters
----------
text : str
The text to process.
model_names : list of str
Names of spaCy models to use (e.g. ``["en_core_sci_sm", "en_core_sci_lg"]``).
use_abbreviations : bool, optional (default True)
Whether to incorporate abbreviation long forms as candidate entities.

Returns
-------
Doc
A spaCy Doc with merged entities set as ``doc.ents``.
"""
base_doc, all_spans = _collect_entity_spans(text, model_names, use_abbreviations)
merged = merge_overlapping_spans(all_spans, base_doc)
base_doc.ents = merged
return base_doc
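
A minimal sketch of the two behaviors the new module relies on: Doc.char_span returns None (rather than raising) for offsets that do not align with token boundaries, and AbbreviationDetector exposes each short form's expansion via the ._.long_form Span extension. Assumes en_core_sci_sm is installed:

    import spacy
    from scispacy.abbreviation import AbbreviationDetector  # registers the factory

    nlp = spacy.load("en_core_sci_sm")
    nlp.add_pipe("abbreviation_detector")

    doc = nlp("Spinal and bulbar muscular atrophy (SBMA) is an inherited motor neuron disease.")

    # Aligned offsets give a Span; misaligned offsets give None, not an exception.
    assert doc.char_span(0, 34) is not None  # "Spinal and bulbar muscular atrophy"
    assert doc.char_span(0, 3) is None       # mid-token

    for abrv in doc._.abbreviations:
        print(abrv.text, "->", abrv._.long_form.text)  # SBMA -> Spinal and bulbar muscular atrophy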
27 changes: 24 additions & 3 deletions scispacy/file_cache.py
@@ -56,8 +56,6 @@ def url_to_filename(url: str, etag: Optional[str] = None) -> str:
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""

last_part = url.split("/")[-1]
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
@@ -67,7 +65,12 @@ def url_to_filename(url: str, etag: Optional[str] = None) -> str:
etag_hash = sha256(etag_bytes)
filename += "." + etag_hash.hexdigest()

filename += "." + last_part
# Only keep the file extension to stay within filesystem NAME_MAX
# limits (e.g. 143 bytes on eCryptfs).
_, ext = os.path.splitext(url.split("/")[-1])
if ext:
filename += ext

return filename


@@ -106,6 +109,19 @@ def http_get(url: str, temp_file: IO) -> None:
pbar.close()


def _find_legacy_cache_path(
url: str, etag: Optional[str], cache_dir: str
) -> Optional[str]:
"""Check for a cached file using the old naming scheme (full trailing URL component)."""
last_part = url.split("/")[-1]
filename = sha256(url.encode("utf-8")).hexdigest()
if etag:
filename += "." + sha256(etag.encode("utf-8")).hexdigest()
filename += "." + last_part
path = os.path.join(cache_dir, filename)
return path if os.path.exists(path) else None


def get_from_cache(url: str, cache_dir: Optional[str] = None) -> str:
"""
Given a URL, look for the corresponding dataset in the local cache.
@@ -131,6 +147,11 @@ def get_from_cache(url: str, cache_dir: Optional[str] = None) -> str:
cache_path = os.path.join(cache_dir, filename)

if not os.path.exists(cache_path):
# Check for files cached under the old naming scheme, which appended
# the full trailing URL component instead of just the extension.
legacy_path = _find_legacy_cache_path(url, etag, cache_dir)
if legacy_path is not None:
return legacy_path
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file: # type: IO
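
A quick sanity check of the new naming scheme, using a hypothetical long URL: two 64-character sha256 hex digests plus a separator and the ".npz" extension come to 133 bytes, under the 143-byte eCryptfs limit the change targets.

    from scispacy.file_cache import url_to_filename

    long_url = "https://example.com/" + "a" * 300 + "/weights.npz"
    # 64 (url hash) + 1 (".") + 64 (etag hash) + 4 (".npz") == 133 bytes
    name = url_to_filename(long_url, etag="x" * 300)
    assert len(name) == 133
    assert name.endswith(".npz")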
79 changes: 79 additions & 0 deletions tests/test_entity_merging.py
@@ -0,0 +1,79 @@
import unittest

import spacy
from spacy.tokens import Doc, Span

from scispacy.entity_merging import merge_overlapping_spans


class TestMergeOverlappingSpans(unittest.TestCase):
def setUp(self):
self.nlp = spacy.blank("en")

def _make_doc(self, text):
return self.nlp(text)

def test_no_spans_returns_empty(self):
doc = self._make_doc("hello world")
result = merge_overlapping_spans([], doc)
assert result == []

def test_non_overlapping_spans_kept(self):
doc = self._make_doc("Spinal atrophy and motor neuron disease are conditions")
span_a = doc.char_span(0, 14, label="ENTITY") # "Spinal atrophy"
span_b = doc.char_span(19, 39, label="ENTITY") # "motor neuron disease"
assert span_a is not None
assert span_b is not None
result = merge_overlapping_spans([span_a, span_b], doc)
assert len(result) == 2

def test_overlapping_spans_keep_longest(self):
doc = self._make_doc("Spinal and bulbar muscular atrophy is a disease")
short = doc.char_span(0, 6, label="ENTITY") # "Spinal"
long = doc.char_span(0, 34, label="ENTITY") # "Spinal and bulbar muscular atrophy"
assert short is not None
assert long is not None
result = merge_overlapping_spans([short, long], doc)
assert len(result) == 1
assert result[0].text == "Spinal and bulbar muscular atrophy"

def test_partial_overlap_keeps_longest(self):
doc = self._make_doc("bulbar muscular atrophy is studied")
span_a = doc.char_span(0, 23, label="ENTITY") # "bulbar muscular atrophy"
span_b = doc.char_span(7, 23, label="ENTITY") # "muscular atrophy"
assert span_a is not None
assert span_b is not None
result = merge_overlapping_spans([span_a, span_b], doc)
assert len(result) == 1
assert result[0].text == "bulbar muscular atrophy"

def test_duplicate_spans_deduplicated(self):
doc = self._make_doc("motor neuron disease is common")
span_a = doc.char_span(0, 20, label="ENTITY")
span_b = doc.char_span(0, 20, label="ENTITY")
assert span_a is not None
assert span_b is not None
result = merge_overlapping_spans([span_a, span_b], doc)
assert len(result) == 1

def test_many_overlapping_spans(self):
# Simulates entities from multiple models with different granularity
doc = self._make_doc("Spinal and bulbar muscular atrophy caused by androgen receptor")
spans = []
# model A: fragments
spans.append(doc.char_span(0, 6, label="ENTITY")) # "Spinal"
spans.append(doc.char_span(11, 34, label="ENTITY")) # "bulbar muscular atrophy"
# model B: full phrase
spans.append(doc.char_span(0, 34, label="ENTITY")) # "Spinal and bulbar muscular atrophy"
# model C: second entity
spans.append(doc.char_span(45, 62, label="ENTITY")) # "androgen receptor"
# filter out any None from char_span misalignment
spans = [s for s in spans if s is not None]

result = merge_overlapping_spans(spans, doc)
texts = {s.text for s in result}
assert "Spinal and bulbar muscular atrophy" in texts
assert "androgen receptor" in texts
# fragments should be gone
assert "Spinal" not in texts
assert "bulbar muscular atrophy" not in texts
35 changes: 35 additions & 0 deletions tests/test_file_cache.py
@@ -59,3 +59,38 @@ def test_url_to_filename_with_etags_eliminates_quotes(self):
back_to_url, etag = filename_to_url(filename, cache_dir=self.TEST_DIR)
assert back_to_url == url
assert etag == "mytag"

def test_url_to_filename_stays_within_name_max(self):
# eCryptfs limits filenames to 143 bytes; make sure we stay under that
# even with a long URL and etag.
long_url = "https://s3-us-west-2.amazonaws.com/bucket/" + "a" * 300 + "/file.npz"
long_etag = "x" * 300
filename = url_to_filename(long_url, etag=long_etag)
assert len(filename) <= 143
assert filename.endswith(".npz")
# also without etag
filename_no_etag = url_to_filename(long_url)
assert len(filename_no_etag) <= 143

def test_url_to_filename_no_extension(self):
# URLs without a file extension should still produce a valid filename
filename = url_to_filename("https://example.com/data/somefile")
assert len(filename) == 64 # just the sha256 hex digest
assert "." not in filename

def test_legacy_cache_files_still_found(self):
from scispacy.file_cache import _find_legacy_cache_path
from hashlib import sha256

url = "https://example.com/data/model.bin"
etag = "some-etag"
# Create a file with the old naming scheme
last_part = url.split("/")[-1]
old_filename = sha256(url.encode("utf-8")).hexdigest()
old_filename += "." + sha256(etag.encode("utf-8")).hexdigest()
old_filename += "." + last_part
old_path = os.path.join(self.TEST_DIR, old_filename)
pathlib.Path(old_path).touch()

found = _find_legacy_cache_path(url, etag, self.TEST_DIR)
assert found == old_path