Training#
Here we describe the available trainers in Argilla:
Base Trainer: Internal mechanism to handle Trainers
SetFit Trainer: Internal mechanism for handling training logic of SetFit models
OpenAI Trainer: Internal mechanism for handling training logic of OpenAI models
PEFT (LoRA) Trainer: Internal mechanism for handling training logic of PEFT (LoRA) models
spaCy Trainer: Internal mechanism for handling training logic of spaCy models
Transformers Trainer: Internal mechanism for handling training logic of Transformers models
SpanMarker Trainer: Internal mechanism for handling training logic of SpanMarker models
TRL Trainer: Internal mechanism for handling training logic of TRL models
SentenceTransformer Trainer: Internal mechanism for handling training logic of SentenceTransformer models
Base Trainer#
- class argilla.training.base.ArgillaTrainerSkeleton(name, dataset, record_class, workspace=None, multi_label=False, settings=None, model=None, seed=None, *arg, **kwargs)#
- Parameters:
name (str) –
record_class (Union[TokenClassificationRecord, Text2TextRecord, TextClassificationRecord]) –
workspace (Optional[str]) –
multi_label (bool) –
settings (Union[TextClassificationSettings, TokenClassificationSettings]) –
model (str) –
seed (int) –
- get_model()#
Returns the model.
- get_model_card_data(card_data_kwargs)#
Generates a FrameworkCardData instance to generate a model card from.
- Parameters:
card_data_kwargs (Dict[str, Any]) –
- Return type:
FrameworkCardData
- get_model_kwargs()#
Returns the model kwargs.
- get_tokenizer()#
Returns the tokenizer.
- get_trainer()#
Returns the trainer.
- get_trainer_kwargs()#
Returns the training kwargs.
- abstract init_model()#
Initializes a model.
- abstract init_training_args()#
Initializes the training arguments.
- abstract predict(text, as_argilla_records=True, **kwargs)#
Predicts the label of the text.
- Parameters:
text (Union[List[str], str]) –
as_argilla_records (bool) –
- push_to_huggingface(repo_id, **kwargs)#
Uploads the model to the [Hugging Face Hub](https://huggingface.co/docs/hub/models-the-hub).
- Parameters:
repo_id (str) –
- Return type:
Optional[str]
- abstract save(output_dir)#
Saves the model to the specified path.
- Parameters:
output_dir (str) –
- abstract train(output_dir=None)#
Trains the model.
- Parameters:
output_dir (Optional[str]) –
- abstract update_config(*args, **kwargs)#
Updates the configuration of the trainer; the accepted parameters depend on the trainer subclass.
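The skeleton above is an abstract contract that each framework trainer fills in. A minimal sketch of the pattern, using toy names rather than the actual Argilla subclasses (EchoTrainer and its behavior are invented for illustration):

```python
from abc import ABC, abstractmethod

class TrainerSkeleton(ABC):
    """Toy stand-in for ArgillaTrainerSkeleton's abstract contract."""

    def __init__(self, model=None, seed=None):
        self._model = model
        self._seed = seed

    def get_model(self):
        return self._model

    @abstractmethod
    def init_model(self): ...

    @abstractmethod
    def init_training_args(self): ...

    @abstractmethod
    def predict(self, text, as_argilla_records=True, **kwargs): ...

    @abstractmethod
    def train(self, output_dir=None): ...

    @abstractmethod
    def save(self, output_dir): ...

    @abstractmethod
    def update_config(self, *args, **kwargs): ...

class EchoTrainer(TrainerSkeleton):
    """Trivial subclass: 'predicts' by echoing the input with no label."""

    def init_model(self):
        self._model = lambda t: {"text": t, "label": None}

    def init_training_args(self):
        self.training_args = {}

    def predict(self, text, as_argilla_records=True, **kwargs):
        # Accept a single string or a list, as the real trainers do.
        texts = [text] if isinstance(text, str) else text
        return [self._model(t) for t in texts]

    def train(self, output_dir=None):
        pass

    def save(self, output_dir):
        pass

    def update_config(self, *args, **kwargs):
        self.training_args.update(kwargs)

trainer = EchoTrainer()
trainer.init_model()
trainer.init_training_args()
print(trainer.predict("hello"))  # [{'text': 'hello', 'label': None}]
```

Concrete trainers like ArgillaSetFitTrainer override exactly these abstract methods.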
- class argilla.client.feedback.integrations.huggingface.model_card.FrameworkCardData(*args, **kwargs)#
Parent class to generate the variables to add to the ModelCard.
Each framework will inherit from here and update accordingly.
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Optional[Framework]) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
- to_dict()#
Main method to generate the variables that will be written in the model card.
- Return type:
Dict[str, Any]
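to_dict collects the card variables into a plain dictionary for the model card template. The pattern can be sketched with a standard dataclass (the fields and the None-dropping rule here are illustrative, not the real class):

```python
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional

@dataclass
class CardData:
    """Toy card-data container with a to_dict in the FrameworkCardData style."""
    model_name: Optional[str] = None
    dataset_name: Optional[str] = None
    tags: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        # Drop unset (None) fields so the card only renders known values.
        return {k: v for k, v in asdict(self).items() if v is not None}

card = CardData(model_name="my-model", tags=["text-classification"])
print(card.to_dict())  # {'model_name': 'my-model', 'tags': ['text-classification']}
```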
SetFit Trainer#
- class argilla.training.setfit.ArgillaSetFitTrainer(*args, **kwargs)#
- init_model()#
Initializes a model.
- init_training_args()#
Initializes the training arguments.
- Return type:
None
- predict(text, as_argilla_records=True, **kwargs)#
Takes a string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as plain strings. Defaults to True.
- Returns:
A list of predictions
- save(output_dir)#
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
- train(output_dir=None)#
Creates a SetFitModel from a pretrained model, wraps it in a SetFitTrainer, and trains it.
- Parameters:
output_dir (Optional[str]) –
- update_config(**kwargs)#
Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
- Return type:
None
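The save method above persists the label2id and id2label mappings next to the model weights. A sketch of just that dictionary-saving step (the file name label_mappings.json is an assumption for illustration; the actual SetFit serialization is separate):

```python
import json
import os
import tempfile

def save_label_mappings(output_dir: str, label2id: dict) -> None:
    """Persist label2id and its inverse alongside the model files."""
    os.makedirs(output_dir, exist_ok=True)
    id2label = {v: k for k, v in label2id.items()}
    with open(os.path.join(output_dir, "label_mappings.json"), "w") as f:
        json.dump({"label2id": label2id, "id2label": id2label}, f)

with tempfile.TemporaryDirectory() as d:
    save_label_mappings(d, {"negative": 0, "positive": 1})
    with open(os.path.join(d, "label_mappings.json")) as f:
        print(json.load(f)["id2label"])  # {'0': 'negative', '1': 'positive'}
```

Note that JSON serialization turns the integer id2label keys into strings, which a loader has to convert back.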
- class argilla.client.feedback.integrations.huggingface.model_card.SetFitModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.SETFIT: 'setfit'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, tokenizer: 'PreTrainedTokenizer' = '')#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
tokenizer (PreTrainedTokenizer) –
OpenAI Trainer#
- class argilla.training.openai.ArgillaOpenAITrainer(*args, **kwargs)#
- init_model()#
Initializes a model.
- Return type:
None
- init_training_args(training_file=None, validation_file=None, model='curie', n_epochs=None, batch_size=None, learning_rate_multiplier=0.1, prompt_loss_weight=0.1, compute_classification_metrics=False, classification_n_classes=None, classification_positive_class=None, classification_betas=None, suffix=None, hyperparameters=None)#
Initializes the training arguments.
- Parameters:
training_file (Optional[str]) –
validation_file (Optional[str]) –
model (str) –
n_epochs (Optional[int]) –
batch_size (Optional[int]) –
learning_rate_multiplier (float) –
prompt_loss_weight (float) –
compute_classification_metrics (bool) –
classification_n_classes (Optional[int]) –
classification_positive_class (Optional[str]) –
classification_betas (Optional[list]) –
suffix (Optional[str]) –
hyperparameters (Optional[dict]) –
- Return type:
None
- predict(text, as_argilla_records=True, **kwargs)#
Takes a string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as plain strings. Defaults to True.
- Returns:
A list of predictions
- Return type:
Union[List, str]
- save(*arg, **kwargs)#
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
- Return type:
None
- train(output_dir=None)#
Creates an openai.FineTune object from a pretrained model and sends the data to fine-tune it.
- Parameters:
output_dir (Optional[str]) –
- Return type:
None
- update_config(**kwargs)#
Updates the model_kwargs dictionaries with the keyword arguments passed to the update_config function.
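update_config merges the given keyword arguments into the stored model kwargs. A minimal sketch of that merge, keeping only keys the trainer already knows about (the attribute name model_kwargs and the filtering rule are illustrative assumptions, with defaults mirroring a few init_training_args parameters):

```python
class ConfigHolder:
    """Toy illustration of the update_config merge pattern."""

    def __init__(self):
        # Defaults echoing some of init_training_args' parameters above.
        self.model_kwargs = {
            "n_epochs": None,
            "batch_size": None,
            "learning_rate_multiplier": 0.1,
        }

    def update_config(self, **kwargs):
        # Only overwrite keys the trainer actually defines.
        self.model_kwargs.update(
            (k, v) for k, v in kwargs.items() if k in self.model_kwargs
        )

cfg = ConfigHolder()
cfg.update_config(n_epochs=4, unknown_flag=True)
print(cfg.model_kwargs["n_epochs"])          # 4
print("unknown_flag" in cfg.model_kwargs)    # False
```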
- class argilla.client.feedback.integrations.huggingface.model_card.OpenAIModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.OPENAI: 'openai'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>)#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
PEFT (LoRA) Trainer#
- class argilla.training.peft.ArgillaPeftTrainer(*args, **kwargs)#
- init_model(new=False)#
Initializes a model.
- Parameters:
new (bool) –
- init_training_args()#
Initializes the training arguments.
- predict(text, as_argilla_records=True, **kwargs)#
Takes a string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as plain strings. Defaults to True.
- Returns:
A list of predictions
- save(output_dir)#
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
- update_config(**kwargs)#
Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
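Here update_config maintains two dictionaries, routing each keyword argument to whichever one already defines it. A sketch of that routing with plain dicts (r and lora_alpha are the conventional LoRA hyperparameter names, used here only as placeholder keys, not as the real PEFT argument surface):

```python
def route_kwargs(model_kwargs: dict, trainer_kwargs: dict, **kwargs) -> None:
    """Send each kwarg to whichever config dict already defines the key."""
    for key, value in kwargs.items():
        if key in model_kwargs:
            model_kwargs[key] = value
        elif key in trainer_kwargs:
            trainer_kwargs[key] = value
        # Unknown keys are ignored in this sketch.

model_kwargs = {"r": 8, "lora_alpha": 16}
trainer_kwargs = {"num_train_epochs": 3}
route_kwargs(model_kwargs, trainer_kwargs, r=4, num_train_epochs=1)
print(model_kwargs["r"], trainer_kwargs["num_train_epochs"])  # 4 1
```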
- class argilla.client.feedback.integrations.huggingface.model_card.PeftModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.PEFT: 'peft'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, tokenizer: 'PreTrainedTokenizer' = '')#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
tokenizer (PreTrainedTokenizer) –
spaCy Trainer#
- class argilla.training.spacy.ArgillaSpaCyTrainer(freeze_tok2vec=False, **kwargs)#
- Parameters:
freeze_tok2vec (bool) –
- init_training_args()#
This method generates the spaCy configuration file used to train the pipeline.
- Return type:
None
- class argilla.training.spacy.ArgillaSpaCyTransformersTrainer(update_transformer=True, **kwargs)#
- Parameters:
update_transformer (bool) –
- init_training_args()#
This method generates the spaCy configuration file used to train the pipeline.
- Return type:
None
- class argilla.training.spacy._ArgillaSpaCyTrainerBase(language=None, gpu_id=-1, model=None, optimize='efficiency', *args, **kwargs)#
- Parameters:
language (Optional[str]) –
gpu_id (Optional[int]) –
model (Optional[str]) –
optimize (Literal['efficiency', 'accuracy']) –
- init_model()#
Initializes a model.
- predict(text, as_argilla_records=True, **kwargs)#
Predict the labels for the given text using the trained pipeline.
- Parameters:
text (Union[List[str], str]) – A str or a List[str] with the text to predict the labels for.
as_argilla_records (bool) – A bool indicating whether to return the predictions as Argilla records or as dicts. Defaults to True.
- Returns:
Either a dict or List[dict] (if as_argilla_records is False), or a BaseModel or List[BaseModel] (if as_argilla_records is True), with the predictions.
- Return type:
Union[Dict[str, Any], List[Dict[str, Any]], BaseModel, List[BaseModel]]
- save(output_dir)#
Save the trained pipeline to disk.
- Parameters:
output_dir (str) – A str with the path to save the pipeline.
- Return type:
None
- train(output_dir=None)#
Train the pipeline using spaCy.
- Parameters:
output_dir (Optional[str]) – A str with the path to save the trained pipeline. Defaults to None.
- Return type:
None
- update_config(**spacy_training_config)#
Update the spaCy training config.
Disclaimer: currently just the training config is supported, but in the future we will support all the spaCy config values for more precise control over the training process. Also, note that the arguments may differ between the CPU and GPU training.
- Parameters:
**spacy_training_config – The spaCy training config.
- Return type:
None
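As the disclaimer above notes, update_config currently touches only the training section of the generated config. A sketch of that merge with a nested dict standing in for the config (dropout and max_steps are real spaCy training options, but the plain-dict config object and the merge rule are illustrative assumptions, not spaCy's Config class):

```python
# Toy stand-in for the [training] section of a generated spaCy config file.
config = {
    "training": {
        "dropout": 0.1,
        "max_steps": 20000,
        "optimizer": {"learn_rate": 0.001},
    }
}

def update_training_config(config: dict, **spacy_training_config) -> None:
    """Merge user overrides into the training section only."""
    for key, value in spacy_training_config.items():
        if key in config["training"]:
            config["training"][key] = value

update_training_config(config, dropout=0.2, max_steps=1000)
print(config["training"]["dropout"], config["training"]["max_steps"])  # 0.2 1000
```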
- class argilla.client.feedback.integrations.huggingface.model_card.SpacyTransformersModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.SPACY_TRANSFORMERS: 'spacy-transformers'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, lang: Optional[ForwardRef('spacy.Language')] = None, gpu_id: Optional[int] = -1, optimize: Literal['efficiency', 'accuracy'] = 'efficiency', pipeline: List[str] = <factory>, update_transformer: bool = True)#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
lang (Optional[spacy.Language]) –
gpu_id (Optional[int]) –
optimize (Literal['efficiency', 'accuracy']) –
pipeline (List[str]) –
update_transformer (bool) –
Transformers Trainer#
- class argilla.training.transformers.ArgillaTransformersTrainer(*args, **kwargs)#
- init_model(new=False)#
Initializes a model.
- Parameters:
new (bool) –
- init_training_args()#
Initializes the training arguments.
- predict(text, as_argilla_records=True, **kwargs)#
Takes a string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as plain strings. Defaults to True.
- Returns:
A list of predictions
- save(output_dir)#
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
- train(output_dir)#
Trains the model.
- Parameters:
output_dir (str) –
- update_config(**kwargs)#
Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
- class argilla.client.feedback.integrations.huggingface.model_card.TransformersModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.TRANSFORMERS: 'transformers'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, tokenizer: 'PreTrainedTokenizer' = '')#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
tokenizer (PreTrainedTokenizer) –
SpanMarker Trainer#
- class argilla.training.span_marker.ArgillaSpanMarkerTrainer(*args, **kwargs)#
- init_model()#
Initializes a model.
- Return type:
None
- init_training_args()#
Initializes the training arguments.
- Return type:
None
- predict(text, as_argilla_records=True, **kwargs)#
Takes a string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the predictions are returned as Argilla records. If False, they are returned as plain strings. Defaults to True.
- Returns:
A list of predictions
- save(output_dir)#
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
- train(output_dir)#
Creates a SpanMarker model from a pretrained model, wraps it in a trainer, and trains it.
- Parameters:
output_dir (str) –
- update_config(**kwargs)#
Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
- Return type:
None
TRL Trainer#
- class argilla.client.feedback.integrations.huggingface.model_card.TRLModelCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.TRL: 'trl'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>)#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
SentenceTransformer Trainer#
- class argilla.client.feedback.training.frameworks.sentence_transformers.ArgillaSentenceTransformersTrainer(dataset, task, prepared_data=None, model=None, seed=None, train_size=1, cross_encoder=False)#
- Parameters:
dataset (FeedbackDataset) –
task (TrainingTaskForSentenceSimilarity) –
model (str) –
seed (int) –
train_size (Optional[float]) –
cross_encoder (bool) –
- get_model_card_data(**card_data_kwargs)#
Generate the card data to be used for the ArgillaModelCard.
- Parameters:
card_data_kwargs – Extra arguments provided by the user when creating the ArgillaTrainer.
- Returns:
Container for the data to be written on the ArgillaModelCard.
- Return type:
SentenceTransformerCardData
- init_model()#
Initializes a model.
- Return type:
None
- init_training_args()#
Initializes the training arguments.
- Return type:
None
- predict(text, as_argilla_records=False, **kwargs)#
Predicts the similarity of the sentences.
- Parameters:
text (Union[List[List[str]], Tuple[str, List[str]]]) – The sentences to obtain the similarity from. Allowed inputs are: a list with a single sentence (as a string) and a list of sentences to compare against, or a list of sentence pairs.
as_argilla_records (bool) – If True, the prediction is returned as an Argilla record. If False, it is returned as a plain value. Defaults to False.
- Returns:
A list of predicted similarities.
- Return type:
List[float]
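predict returns one similarity score per sentence pair. For sentence-transformer embeddings the standard score is cosine similarity, sketched here over toy vectors (in real usage the trained model would embed the sentences first; the vectors below are placeholders):

```python
import math

def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

# Toy "embeddings" for two sentence pairs -> one score per pair.
pairs = [([1.0, 0.0], [1.0, 0.0]), ([1.0, 0.0], [0.0, 1.0])]
print([cosine_similarity(u, v) for u, v in pairs])  # [1.0, 0.0]
```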
- push_to_huggingface(repo_id, **kwargs)#
Uploads the model to the [Hugging Face model hub](https://huggingface.co/models).
The full list of parameters can be seen at: [sentence-transformers API docs](https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.save_to_hub).
- Parameters:
repo_id (str) – The name of the repository you want to push your model and tokenizer to. It should contain your organization name when pushing to a given organization.
- Raises:
NotImplementedError – For CrossEncoder models, which currently aren't implemented underneath.
- Return type:
None
- save(output_dir)#
Saves the model to the specified path.
- Parameters:
output_dir (str) –
- Return type:
None
- train(output_dir=None)#
Trains the model.
- Parameters:
output_dir (Optional[str]) –
- Return type:
None
- update_config(**kwargs)#
Updates the configuration of the trainer; the accepted parameters depend on the trainer subclass.
- Return type:
None
- class argilla.client.feedback.integrations.huggingface.model_card.SentenceTransformerCardData(language: Union[str, List[str], NoneType] = None, license: Optional[str] = None, model_name: Optional[str] = None, model_id: Optional[str] = None, dataset_name: Optional[str] = None, dataset_id: Optional[str] = None, tags: Optional[List[str]] = <factory>, model_summary: Optional[str] = None, model_description: Optional[str] = None, developers: Optional[str] = None, shared_by: Optional[str] = None, model_type: Optional[str] = None, finetuned_from: Optional[str] = None, repo: Optional[str] = None, _is_on_huggingface: bool = False, framework: argilla.client.models.Framework = <Framework.SENTENCE_TRANSFORMERS: 'sentence-transformers'>, train_size: Optional[float] = None, seed: Optional[int] = None, framework_kwargs: Dict[str, Any] = <factory>, task: Union[argilla.client.feedback.training.schemas.base.TrainingTaskForTextClassification, argilla.client.feedback.training.schemas.base.TrainingTaskForSFT, argilla.client.feedback.training.schemas.base.TrainingTaskForRM, argilla.client.feedback.training.schemas.base.TrainingTaskForPPO, argilla.client.feedback.training.schemas.base.TrainingTaskForDPO, argilla.client.feedback.training.schemas.base.TrainingTaskForChatCompletion, argilla.client.feedback.training.schemas.base.TrainingTaskForSentenceSimilarity, NoneType] = None, output_dir: Optional[str] = None, library_name: Optional[str] = None, update_config_kwargs: Dict[str, Any] = <factory>, cross_encoder: bool = False, trainer_cls: Optional[Callable] = None)#
- Parameters:
language (Optional[Union[str, List[str]]]) –
license (Optional[str]) –
model_name (Optional[str]) –
model_id (Optional[str]) –
dataset_name (Optional[str]) –
dataset_id (Optional[str]) –
tags (Optional[List[str]]) –
model_summary (Optional[str]) –
model_description (Optional[str]) –
developers (Optional[str]) –
shared_by (Optional[str]) –
model_type (Optional[str]) –
finetuned_from (Optional[str]) –
repo (Optional[str]) –
_is_on_huggingface (bool) –
framework (Framework) –
train_size (Optional[float]) –
seed (Optional[int]) –
framework_kwargs (Dict[str, Any]) –
task (Optional[Union[TrainingTaskForTextClassification, TrainingTaskForSFT, TrainingTaskForRM, TrainingTaskForPPO, TrainingTaskForDPO, TrainingTaskForChatCompletion, TrainingTaskForSentenceSimilarity]]) –
output_dir (Optional[str]) –
library_name (Optional[str]) –
update_config_kwargs (Dict[str, Any]) –
cross_encoder (bool) –
trainer_cls (Optional[Callable]) –