💫 Zero-shot and few-shot classification with SetFit#
In this tutorial, you'll learn to use Sentence Transformer embeddings and SetFit's zero-shot and few-shot capabilities to make data labeling significantly faster. It will walk you through the following steps:
💾 Use sentence transformers to generate embeddings of a dataset with banking customer requests.
💫 Use SetFit's zero-shot classifier and upload its predictions together with the embeddings.
🏷 Leverage similarity search and pre-annotated examples to efficiently label batches of semantically related, high-impact examples.
🦾 Train a few-shot SetFit model to improve the results of the zero-shot model.
Introduction#
In this tutorial, we'll use the power of embeddings to make data labeling (and curation) more efficient. Combined with SetFit's zero- and few-shot capabilities, this approach will greatly reduce the time it takes to get a good-quality model with your own data.
Let's do it!
Running Argilla#
For this tutorial, you will need to have an Argilla server running. There are two main options for deploying and running Argilla:
Deploy Argilla on Hugging Face Spaces: If you want to run tutorials with external notebooks (e.g., Google Colab) and you have an account on Hugging Face, you can deploy Argilla on Spaces with a few clicks:
For details about configuring your deployment, check the official Hugging Face Hub guide.
Launch Argilla using Argilla's quickstart Docker image: This is the recommended option if you want Argilla running on your local machine. Note that this option will only let you run the tutorial locally and not with an external notebook service.
For more information on deployment options, please check the Deployment section of the documentation.
Tip
This tutorial is a Jupyter Notebook. There are two options to run it:
Use the Open in Colab button at the top of this page. This option allows you to run the notebook directly on Google Colab. Don't forget to change the runtime type to GPU for faster model training and inference.
Download the .ipynb file by clicking on the View source link at the top of the page. This option allows you to download the notebook and run it on your local machine or on a Jupyter notebook tool of your choice.
Setup#
For this tutorial, you'll need Argilla's Python client and a few third-party libraries that can be installed via pip:
[ ]:
%pip install argilla==1.3.0 datasets==2.8.0 sentence-transformers==2.2.2 setfit==0.6.0 plotly==4.1.0 -qqq
Let's import the Argilla module for reading and writing data:
[1]:
import argilla as rg
If you are running Argilla using the Docker quickstart image or Hugging Face Spaces, you need to init the Argilla client with the URL and API_KEY:
[2]:
# Replace api_url with your HF Spaces URL if you are using Spaces
# Replace api_key if you configured a custom API key
rg.init(
    api_url="put_here_your_url",
    api_key="admin.apikey"
)
Let's add the imports we need:
[ ]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CosineSimilarityLoss
from datasets import load_dataset
from setfit import SetFitModel, SetFitTrainer, get_templated_dataset
💾 Embedding your dataset#
The code below will load the banking customer requests dataset from the Hub, encode the text field, and create the vectors field, which will contain a single key (mini-lm-sentence-transformers). Because you'll be labeling the dataset from scratch, it will also remove the label field, which contains the original intent labels.
[ ]:
# Define a fast sentence transformers model; change device to "cpu" if no GPU is available
encoder = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")

# Load the test split of the banking77 dataset with banking customer requests
dataset = load_dataset("banking77", split="test")

# Encode the text field using batched computation
dataset = dataset.map(
    lambda batch: {"vectors": encoder.encode(batch["text"])},
    batch_size=32,
    batched=True,
)

# Remove the original labels because you'll be labeling from scratch
dataset = dataset.remove_columns("label")

# Turn vectors into a dictionary keyed by the name of the embedding model
dataset = dataset.map(
    lambda r: {"vectors": {"mini-lm-sentence-transformers": r["vectors"]}}
)
Our dataset now contains a vectors field with the embedding vector generated by the sentence transformer model.
[ ]:
dataset.to_pandas().head()
|   | text | vectors |
|---|------|---------|
| 0 | How do I locate my card? | {'mini-lm-sentence-transformers': [-0.01016701... |
| 1 | I still have not received my new card, I order... | {'mini-lm-sentence-transformers': [-0.04284121... |
| 2 | I ordered a card but it has not arrived. Help ... | {'mini-lm-sentence-transformers': [-0.03365556... |
| 3 | Is there a way to know when my card will arrive? | {'mini-lm-sentence-transformers': [0.012195922... |
| 4 | My card has not arrived yet. | {'mini-lm-sentence-transformers': [-0.04361867... |
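As a quick sanity check (not part of the original pipeline), you can verify that each record carries a single vector under the expected key and that its dimensionality matches all-MiniLM-L6-v2, which produces 384-dimensional embeddings:

[ ]:

# Each record should hold one 384-dimensional vector under the model key
vec = dataset[0]["vectors"]["mini-lm-sentence-transformers"]
print(len(vec))  # 384 for all-MiniLM-L6-v2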
💫 Zero-shot predictions with SetFit#
The original banking77 dataset is an intent classification dataset with dozens of labels (lost_card, card_arrival, etc.). To keep this tutorial simple, we define a simplified labeling scheme with higher-level classes.
Let's set up and train our zero-shot SetFit model. Please note that SetFit's approach to zero-shot is to create a synthetic dataset of training examples, which is different from other approaches (e.g., transformers zero-shot pipelines) where "templated" examples with label names are used at inference time.
[ ]:
labels = ["change details", "card", "atm", "top up", "balance", "transfer", "exchange rate", "pin"]

# Create a synthetic dataset with 8 templated examples per label
train_dataset = get_templated_dataset(
    candidate_labels=labels,
    sample_size=8,
    template="The customer request is about {}",
)

model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
)
trainer.train()
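If you're curious what the synthetic training set looks like, you can print a few rows. This is a minimal sketch; it assumes get_templated_dataset returns a datasets.Dataset with text and label columns, as in the SetFit version pinned above:

[ ]:

# Print a few of the synthetic, templated training examples
for row in train_dataset.shuffle(seed=42).select(range(3)):
    print(row)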
We can use our trained zero-shot model to predict over the dataset. We'll later load these predictions into Argilla and use them to speed up the labeling process.
[ ]:
def get_predictions(texts):
    # predict_proba returns one probability per candidate label
    probas = model.predict_proba(texts, as_numpy=True)
    for pred in probas:
        yield [{"label": label, "score": score} for label, score in zip(labels, pred)]

dataset = dataset.map(
    lambda batch: {"prediction": list(get_predictions(batch["text"]))},
    batched=True,
)
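As a quick sanity check, you can inspect the highest-scoring prediction for the first record, using the prediction field we just created:

[ ]:

# Show the top-scoring label for the first record
first = dataset[0]
top = max(first["prediction"], key=lambda p: p["score"])
print(first["text"], "->", top)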
Let's upload our dataset with vectors and predictions.
[ ]:
rg_ds = rg.DatasetForTextClassification.from_datasets(dataset)

rg.log(
    name="banking77-topics-setfit",
    records=rg_ds,
    chunk_size=50,
)
🏷 Bulk labeling with Find similar and zero-shot predictions#
Our banking77-topics-setfit dataset is now available in the Argilla UI, and you can start annotating the data by leveraging similarity search and the zero-shot predictions. After going to your Argilla UI URL, the workflow is as follows:

Label a record (e.g., label "Change my information" as change details) and then click on Find similar on the top-right of the record.

As a result, you'll get a list of the most similar records, sorted by similarity, together with their corresponding predictions.

Review the predictions and validate or correct them.

After labeling around 200 records, we're ready to evaluate our zero-shot model. Let's see how!
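If you want to track your progress from the notebook while labeling, a minimal sketch would be the following; it assumes the status:Validated query filter supported by the Argilla client:

[ ]:

# Count how many records have been validated so far
annotated = rg.load("banking77-topics-setfit", query="status:Validated")
print(f"{len(annotated)} records annotated")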
📏 Evaluate the zero-shot model#
We can use Argilla's built-in metrics to compute f1 based on (1) the predictions of the zero-shot model we stored at the beginning of the tutorial, and (2) the manual annotations. Please note that during the labeling process we've added a new label, Other, to account for examples that didn't fall into our predefined categories. This highlights the importance of iterating early on during project definition. Argilla gives users a lot of flexibility: features like predictions and similarity search can help surface potential issues and refinements much faster than traditional data annotation tools.
[3]:
from argilla.metrics.text_classification import f1
f1(name="banking77-topics-setfit").visualize()
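If you prefer raw numbers over the interactive plot, the metric summary also exposes its underlying values. This is a sketch assuming the data attribute of Argilla's metric summaries holds the raw metric dictionary:

[ ]:

# Access the raw metric values instead of the visualization
f1_summary = f1(name="banking77-topics-setfit")
print(f1_summary.data)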