Few-shot and zero-shot#
Zero-shot and few-shot models can produce reasonably good predictions while requiring zero or only a few training samples. For zero-shot classification models, you generally only provide the candidate labels. These models are more popular for TextClassification tasks, but there are also some examples for TokenClassification. They generally perform okay out of the box, but a custom model usually performs better. That is why these kinds of models are typically used to get a head start with labeling before training a tailor-made model. This guide contains a few basic examples, and there are more examples of using GPT-3 here.
TextClassification#
Few-shot with SetFit#
A more in-depth overview can be found in our tutorial about SetFit; for now, we will just show a short version of that tutorial. We also have great dataset integrations with transformers, json and pandas; check out our Datasets features.
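As a quick illustration of that dataset integration, here is a minimal sketch, assuming a pandas DataFrame with a text column (the contents and the dataset name are made up for this example):
[ ]:
import pandas as pd
import argilla as rg

# a tiny DataFrame with a "text" column (contents made up for this sketch)
df = pd.DataFrame({"text": ["I loved this movie!", "What a waste of time."]})

# convert the DataFrame into Argilla records and log them
dataset_rg = rg.DatasetForTextClassification.from_pandas(df)
rg.log(dataset_rg, name="my-unlabelled-dataset")
With that aside, here is the short version of the SetFit workflow: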
[ ]:
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
import argilla as rg
# load a dataset from the hub
unlabelled = (
    load_dataset("imdb", split="unsupervised").shuffle(seed=42).select(range(100))
)
unlabelled = rg.DatasetForTextClassification.from_datasets(unlabelled)
rg.log(unlabelled, "imdb_unlabelled")
# Go to Argilla and label ca. 8 examples per label.
# Load the hand-labelled dataset from Argilla
train_ds = rg.load("imdb_unlabelled").prepare_for_training()
test_ds = load_dataset("imdb", split="test")
# Load SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    loss_class=CosineSimilarityLoss,
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate
)
# Train and evaluate
trainer.train()
metrics = trainer.evaluate()
# Share the model and data with the world.
train_ds.push_to_hub("setfit-mini-imdb-data")
trainer.push_to_hub("setfit-mini-imdb")
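After training, you can sanity-check the fine-tuned model by running a couple of predictions directly; a minimal sketch (the example sentences are made up for illustration):
[ ]:
# run a quick prediction with the fine-tuned SetFit model
preds = model(["I loved the special effects!", "A complete waste of two hours."])
print(preds)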
Zero-shot with transformers#
A good collection of zero-shot models can be found on the Hugging Face model page. For this example, we will use the most popular one, facebook/bart-large-mnli.
[ ]:
import argilla as rg
from transformers import pipeline
from datasets import load_dataset
# load a dataset from the hub
dataset = load_dataset("imdb", split="unsupervised")
# load a model pipeline
nlp = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    framework="pt",
)
# deploy and monitor your model
nlp = rg.monitor(nlp, dataset="transformers-mini-imdb")
dataset.map(
    lambda example: {"prediction": nlp(example["text"], ["positive", "negative"])}
)
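If you prefer to create records explicitly rather than relying on rg.monitor, the pipeline output can be mapped to TextClassificationRecords; a minimal sketch (the dataset name and record count are chosen arbitrarily):
[ ]:
# build records from the zero-shot predictions for the first 100 examples
records = []
for example in dataset.select(range(100)):
    output = nlp(example["text"], ["positive", "negative"])
    records.append(
        rg.TextClassificationRecord(
            text=example["text"],
            prediction=list(zip(output["labels"], output["scores"])),
            prediction_agent="facebook/bart-large-mnli",
        )
    )
rg.log(records, name="bart-mnli-mini-imdb")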
TokenClassification#
Few-shot with concise-concepts#
A more elaborate example of using concise-concepts can be found in our blog posts.
[ ]:
import spacy
import concise_concepts

import argilla as rg

# create some test data
data = {
    "fruit": ["apple", "pear", "orange"],
    "vegetable": ["broccoli", "spinach", "tomato"],
    "meat": ["beef", "pork", "fish"],
}
text = "Heat the oil in a large pan and add the Onion, celery and carrots."

# load a spaCy concise-concepts pipeline
nlp = spacy.load("en_core_web_lg", disable=["ner"])
nlp.add_pipe("concise_concepts", config={"data": data, "ent_score": True})

# deploy and monitor your model
nlp = rg.monitor(nlp, dataset="concise-concepts-fruits")
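In addition to rg.monitor, you could also build the records yourself from the spaCy Doc; a minimal sketch using the example sentence defined above (the dataset name is made up):
[ ]:
# run the pipeline on the example sentence and build a record from its entities
doc = nlp(text)
record = rg.TokenClassificationRecord(
    text=doc.text,
    tokens=[token.text for token in doc],
    prediction=[(ent.label_, ent.start_char, ent.end_char) for ent in doc.ents],
    prediction_agent="concise-concepts",
)
rg.log(record, name="concise-concepts-fruits")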
Zero-shot with flair#
We will use the NER dataset "WNUT 17: Emerging and Rare entity recognition", which focuses on unusual, previously unseen entities in the context of emerging discussions. This is the same dataset we use in our tutorial on flair.
[ ]:
from datasets import load_dataset
from flair.models import TARSTagger
from flair.data import Sentence

import argilla as rg
# download dataset
dataset = load_dataset("wnut_17", split="test")
labels = ["corporation", "creative-work", "group", "location", "person", "product"]
# load zero-shot NER tagger
tars = TARSTagger.load("tars-ner")
tars.add_and_switch_to_new_task("task 1", labels, label_type="ner")
# build records from the zero-shot predictions
records = []
for record in dataset.select(range(100)):
    input_text = " ".join(record["tokens"])
    sentence = Sentence(input_text)
    tars.predict(sentence)
    prediction = [
        (entity.get_labels()[0].value, entity.start_position, entity.end_position)
        for entity in sentence.get_spans("ner")
    ]
    # building TokenClassificationRecord
    records.append(
        rg.TokenClassificationRecord(
            text=input_text,
            tokens=[token.text for token in sentence],
            prediction=prediction,
            prediction_agent="tars-ner",
        )
    )
# log the records to Argilla
rg.log(records, name="tars_ner_wnut_17", metadata={"split": "test"})
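Once you have corrected and validated the zero-shot predictions in the Argilla UI, you can load the annotated records back and prepare them for training a custom NER model, just like in the SetFit example above; a minimal sketch:
[ ]:
# load the hand-annotated records from Argilla and
# turn them into a dataset ready for fine-tuning a NER model
train_ds = rg.load("tars_ner_wnut_17").prepare_for_training()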