Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions e2e/pipelines/test_named_entity_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@pytest.fixture
def raw_texts():
def raw_texts() -> list:
return [
"My name is Clara and I live in Berkeley, California.",
"I'm Merlin, the happy pig!",
Expand All @@ -24,7 +24,7 @@ def raw_texts():


@pytest.fixture
def hf_annotations():
def hf_annotations() -> list:
return [
[
NamedEntityAnnotation(entity="PER", start=11, end=16),
Expand All @@ -38,7 +38,7 @@ def hf_annotations():


@pytest.fixture
def spacy_annotations():
def spacy_annotations() -> list:
return [
[
NamedEntityAnnotation(entity="PERSON", start=11, end=16),
Expand All @@ -51,14 +51,14 @@ def spacy_annotations():
]


def test_ner_extractor_init(del_hf_env_vars) -> None:
    """The HF-backed extractor should report itself initialized once warmed up."""
    ner = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
    ner.warm_up()
    assert ner.initialized


@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size, del_hf_env_vars):
def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size, del_hf_env_vars) -> None:
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
extractor.warm_up()

Expand All @@ -70,23 +70,23 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size, del_hf_
not os.environ.get("HF_API_TOKEN", None) and not os.environ.get("HF_TOKEN", None),
reason="Export an env var called HF_API_TOKEN or HF_TOKEN containing the Hugging Face token to run this test.",
)
def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size) -> None:
    """Run the Hugging Face backend against a gated/private model and check its predictions."""
    ner = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
    ner.warm_up()

    _extract_and_check_predictions(ner, raw_texts, hf_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size) -> None:
    """Run the spaCy backend end-to-end and compare its annotations to the expected ones."""
    ner = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf")
    ner.warm_up()

    _extract_and_check_predictions(ner, raw_texts, spacy_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size, del_hf_env_vars):
def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size, del_hf_env_vars) -> None:
pipeline = Pipeline()
pipeline.add_component(
name="ner_extractor",
Expand All @@ -100,7 +100,7 @@ def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size, del_hf
_check_predictions(predicted, hf_annotations)


def _extract_and_check_predictions(extractor, texts, expected, batch_size):
def _extract_and_check_predictions(extractor, texts, expected, batch_size) -> None:
docs = [Document(content=text) for text in texts]
outputs = extractor.run(documents=docs, batch_size=batch_size)["documents"]
for original_doc, output_doc in zip(docs, outputs):
Expand All @@ -117,7 +117,7 @@ def _extract_and_check_predictions(extractor, texts, expected, batch_size):
_check_predictions(predicted, expected)


def _check_predictions(predicted, expected):
def _check_predictions(predicted, expected) -> None:
assert len(predicted) == len(expected)
for pred, exp in zip(predicted, expected):
assert len(pred) == len(exp)
Expand Down