🦾 Train a Model#
This guide showcases how to train a model on the `Dataset` classes in the Argilla client.
The `Dataset` classes are lightweight containers for Argilla records. These classes facilitate importing from and exporting to different formats (e.g., `pandas.DataFrame`, `datasets.Dataset`), as well as sharing and versioning Argilla datasets using the Hugging Face Hub.
For each record type, there is a corresponding `Dataset` class called `DatasetFor<RecordType>`. You can look up their API in the reference section.
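For instance, a minimal sketch of working with these containers could look as follows (the dataset and workspace names are placeholders):

```python
import argilla as rg

# load an annotated dataset from Argilla into its corresponding Dataset class
dataset_rg = rg.load("<my_dataset_name>", workspace="<my_workspace_name>")

# export to other formats for inspection or further processing
df = dataset_rg.to_pandas()      # pandas.DataFrame
ds = dataset_rg.to_datasets()    # datasets.Dataset
```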
There are two ways to train custom models on top of your annotated data:
- Train models using the Argilla training module, which is quick and easy but does not offer specific customization.
- Train with a custom workflow using the `prepare_for_training` methods, which requires some configuration but also offers more flexibility to integrate with your existing training workflows.

Note

For training models with the `FeedbackDataset`, take a look here.
Train directly#
This approach is quick and easy, but it does not offer specific customizations.
The `ArgillaTrainer` is a wrapper around many of our favorite NLP libraries. It provides a very intuitive abstract workflow to facilitate simple training workflows using decent default pre-set configurations, without having to worry about any data transformations from Argilla. We plan on adding support for more tasks and frameworks, so feel free to reach out on our Slack or GitHub.
| Framework/Task | TextClassification | TokenClassification | Text2Text | Feedback |
|---|---|---|---|---|
| OpenAI | ✔️ | ✔️ |  |  |
| AutoTrain | ✔️ | ✔️ | ✔️ |  |
| SetFit | ✔️ |  |  |  |
| spaCy | ✔️ | ✔️ |  |  |
| Transformers | ✔️ | ✔️ |  |  |
| PEFT | ✔️ | ✔️ |  |  |
| SpanMarker |  | ✔️ |  |  |
The ArgillaTrainer#
We can use the `ArgillaTrainer` to train directly, using for example `spacy`, `setfit` or `transformers` as the `framework` variable.
import argilla as rg
from argilla.training import ArgillaTrainer
trainer = ArgillaTrainer(
name="<my_dataset_name>",
workspace="<my_workspace_name>",
framework="<my_framework>",
train_size=0.8
)
trainer.train(path="<my_model_path>")
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="<my_dataset_name>", workspace="<my_workspace_name>")
Update training config#
The trainer also has an `ArgillaTrainer.update_config()` method, which maps `**kwargs` to the respective framework. These keyword arguments are therefore derived from the underlying framework that was used to initialize the trainer. Below you can find an overview of these variables for the supported frameworks. Note that you don't need to pass all of them directly and that the values below are their default configurations.
# `OpenAI.FineTune`
trainer.update_config(
training_file = None,
validation_file = None,
model = "curie,
n_epochs = 2,
batch_size = None,
learning_rate_multiplier = 0.1,
prompt_loss_weight = 0.1,
compute_classification_metrics = False,
classification_n_classes = None,
classification_positive_class = None,
classification_betas = None,
suffix = None
)
# `AutoTrain.autotrain_advanced`
trainer.update_config(
model = "autotrain", # hub models like roberta-base
autotrain = [{
"source_language": "en",
"num_models": 5
}],
hub_model = [{
"learning_rate": 0.001,
"optimizer": "adam",
"scheduler": "linear",
"train_batch_size": 8,
"epochs": 10,
"percentage_warmup": 0.1,
"gradient_accumulation_steps": 1,
"weight_decay": 0.1,
"tasks": "text_binary_classification", # this is inferred from the dataset
}]
)
# `setfit.SetFitModel`
trainer.update_config(
pretrained_model_name_or_path = "all-MiniLM-L6-v2",
force_download = False,
resume_download = False,
proxies = None,
token = None,
cache_dir = None,
local_files_only = False
)
# `setfit.SetFitTrainer`
trainer.update_config(
metric = "accuracy",
num_iterations = 20,
num_epochs = 1,
learning_rate = 2e-5,
batch_size = 16,
seed = 42,
use_amp = True,
warmup_proportion = 0.1,
distance_metric = "BatchHardTripletLossDistanceFunction.cosine_distance",
margin = 0.25,
samples_per_label = 2
)
# `spacy.training`
trainer.update_config(
dev_corpus = "corpora.dev",
train_corpus = "corpora.train",
seed = 42,
gpu_allocator = 0,
accumulate_gradient = 1,
patience = 1600,
max_epochs = 0,
max_steps = 20000,
eval_frequency = 200,
frozen_components = [],
annotating_components = [],
before_to_disk = None,
before_update = None
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
pretrained_model_name_or_path = "distilbert-base-uncased",
force_download = False,
resume_download = False,
proxies = None,
token = None,
cache_dir = None,
local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.999,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
# `peft.LoraConfig`
trainer.update_config(
r=8,
target_modules=None,
lora_alpha=16,
lora_dropout=0.1,
fan_in_fan_out=False,
bias="none",
inference_mode=False,
modules_to_save=None,
init_lora_weights=True,
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
pretrained_model_name_or_path = "distilbert-base-uncased",
force_download = False,
resume_download = False,
proxies = None,
token = None,
cache_dir = None,
local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.999,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
# `SpanMarkerConfig`
trainer.update_config(
pretrained_model_name_or_path = "distilbert-base-cased"
model_max_length = 256,
marker_max_length = 128,
entity_max_length = 8,
)
# `transformers.TrainingArguments`
trainer.update_config(
per_device_train_batch_size = 8,
per_device_eval_batch_size = 8,
gradient_accumulation_steps = 1,
learning_rate = 5e-5,
weight_decay = 0,
adam_beta1 = 0.9,
adam_beta2 = 0.999,
adam_epsilon = 1e-8,
max_grad_norm = 1,
num_train_epochs = 3,
max_steps = 0,
log_level = "passive",
logging_strategy = "steps",
save_strategy = "steps",
save_steps = 500,
seed = 42,
push_to_hub = False,
hub_model_id = "user_name/output_dir_name",
hub_strategy = "every_save",
hub_token = "1234",
hub_private_repo = False
)
CLI support#
We also add CLI support for the `ArgillaTrainer`. This can be used, for example, when executing training on an external machine. Note that `--update-config-kwargs` always uses the `update_config()` method of the corresponding class. Hence, you should take this into account when configuring training via the CLI command, by passing a JSON-serializable string. An example invocation is shown below, after the available options.
Usage: python -m argilla train [OPTIONS] COMMAND [ARGS]...
Starts the ArgillaTrainer.
Options:
--name TEXT The name of the dataset to be used for training. [default: None]
--framework [transformers|setfit|spacy|span_marker|spark-nlp|openai] The framework to be used for training. [default: None]
--workspace TEXT The workspace to be used for training. [default: None]
--limit INTEGER The number of records to be used. [default: None]
--query TEXT The query to be used. [default: None]
--model TEXT The model name or path to be used for training. [default: None]
--train-size FLOAT The train split to be used. [default: 1.0]
--seed INTEGER The random seed number. [default: 42]
--device INTEGER The GPU id to be used for training. [default: -1]
--output-dir TEXT Output directory for the saved model. [default: model]
--update-config-kwargs TEXT update_config() kwargs to be passed as a dictionary. [default: {}]
--api-url TEXT The API url to be used for training. [env var: ARGILLA_API_URL] [default: None]
--api-key TEXT The API key to be used for training. [env var: ARGILLA_API_KEY] [default: None]
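As an illustration, a minimal sketch of a CLI invocation could look as follows (the dataset, workspace and output directory are placeholders, and the kwargs passed via `--update-config-kwargs` must match the `update_config()` arguments of the chosen framework, here `setfit`):

```bash
python -m argilla train \
  --name "<my_dataset_name>" \
  --workspace "<my_workspace_name>" \
  --framework setfit \
  --train-size 0.8 \
  --output-dir my_setfit_model \
  --update-config-kwargs '{"num_epochs": 2, "batch_size": 8}'
```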
An example workflow#
import argilla as rg
from argilla.training import ArgillaTrainer
from datasets import load_dataset
dataset_rg = rg.DatasetForTokenClassification.from_datasets(
dataset=load_dataset("conll2003", split="train[:100]"),
tags="ner_tags",
)
rg.log(dataset_rg, name="conll2003", workspace="argilla")
trainer = ArgillaTrainer(
name="conll2003",
workspace="argilla",
framework="spacy",
train_size=0.8
)
trainer.update_config(max_epochs=2)
trainer.train(output_dir="my_easy_model")
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="conll2003", workspace="argilla")
Train custom workflow#
Custom workflows give you more flexibility to integrate with your existing training workflows.
Prepare for training#
If you want to train a model, we provide a handy method to prepare your dataset: `DatasetFor*.prepare_for_training()`. It will return a Hugging Face dataset, a spaCy DocBin or a Spark NLP-formatted DataFrame, optimized for the training process with the Hugging Face Trainer, the spaCy CLI or the Spark NLP API. Our training tutorials show entire training workflows for your favorite packages.
Train-test split#
It is possible to directly include train-test splits in `prepare_for_training` by passing the `train_size` and `test_size` parameters.
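For example, a minimal sketch of such a split, assuming a text classification dataset prepared for `transformers` (the dataset name is a placeholder; with a split, this typically returns a `datasets.DatasetDict` with `train` and `test` splits):

```python
import argilla as rg

dataset_rg = rg.load("<my_dataset>")

# let prepare_for_training handle the 80/20 train-test split
dataset = dataset_rg.prepare_for_training(
    framework="transformers",
    train_size=0.8,
    test_size=0.2,
)
print(dataset["train"], dataset["test"])
```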
Frameworks and Tasks#
TextClassification
For text classification tasks, it flattens the inputs into separate columns of the returned dataset and converts the annotations of your records into integers, which are written to a `label` column.
This is done by passing the `framework` variable as `setfit`, `transformers`, `spark-nlp` or `spacy`. This task requires a `DatasetForTextClassification`.
TokenClassification
For token classification tasks, it converts the annotations of a record into integers representing BIO tags, which are written to a `ner_tags` column.
This is done by passing the `framework` variable as `transformers`, `spark-nlp` or `spacy`. This task requires a `DatasetForTokenClassification`.
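For illustration only, a hypothetical prepared sample could look roughly like this (the tokens and tag ids are made up, and the exact columns depend on your dataset):

```python
# hypothetical row returned by prepare_for_training(framework="transformers")
# for a DatasetForTokenClassification dataset (BIO-encoded tags)
{"tokens": ["Argilla", "is", "based", "in", "Madrid"], "ner_tags": [3, 0, 0, 0, 5]}
```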
Text2Text
For text generation tasks like summarization and translation, it converts the annotations of a record into `text` and `target` columns.
This is done by passing the `framework` variable as `transformers` or `spark-nlp`. This task requires a `DatasetForText2Text`.
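Again for illustration only, a hypothetical prepared sample might look like this (the sentences are made up):

```python
# hypothetical row returned by prepare_for_training(framework="transformers")
# for a DatasetForText2Text dataset
{"text": "My name is Sarah and I love my dog.", "target": "Mein Name ist Sarah und ich liebe meinen Hund."}
```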
Feedback
For feedback-oriented datasets, we currently rely on a fully customizable workflow, which means automation is limited and still being defined.
This task requires a `FeedbackDataset`.
| Framework/Dataset | TextClassification | TokenClassification | Text2Text | Feedback |
|---|---|---|---|---|
| OpenAI | ✔️ | ✔️ | ✔️ |  |
| AutoTrain | ✔️ | ✔️ |  |  |
| SetFit | ✔️ |  |  |  |
| spaCy | ✔️ | ✔️ |  |  |
| Transformers | ✔️ | ✔️ | ✔️ |  |
| PEFT | ✔️ | ✔️ | ✔️ |  |
| SpanMarker |  | ✔️ |  |  |
| Spark NLP | ✔️ | ✔️ | ✔️ |  |
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="openai", train_size=1)
# [{'prompt': 'My title', 'completion': ' My content'}]
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="autotrain", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="setfit", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
import spacy
nlp = spacy.blank("en")
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spacy", lang=nlp, train_size=1)
# <spacy.tokens._serialize.DocBin object at 0x280613af0>
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="transformers", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="peft", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="span_marker", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spark-nlp", train_size=1)
# <pd.DataFrame>