How to continue training from a checkpoint with Trainer?

See original GitHub issue

❓ Questions & Help

Details

I am trying to continue training my model (GPT-2) from a checkpoint, using Trainer. However, when I do, the model starts training from step 0 rather than from the checkpoint. I'm sharing my code because I don't know where I'm making the mistake.

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from transformers import TextDataset, DataCollatorForLanguageModeling, AutoTokenizer, GPT2LMHeadModel, Trainer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")

train_dataset = TextDataset(
          tokenizer=tokenizer,
          file_path='textfile (1).txt',
          block_size=128)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

model = GPT2LMHeadModel.from_pretrained("checkpoint-9500").to(device)  # HERE I LOAD FROM CHECKPOINT

training_args = TrainingArguments(
    output_dir='./results',         # output directory
    num_train_epochs=4,              # total # of training epochs
    per_device_train_batch_size=1,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    #eval_dataset=validation_dataset,
    prediction_loss_only=True,
)

trainer.train()



Thanks a lot for the help.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 14 (5 by maintainers)

Top GitHub Comments

55 reactions
sgugger commented, Sep 17, 2020

Hi there, you have to pass the checkpoint path to the method Trainer.train to resume training:

trainer.train("checkpoint-9500")

If you set your logging verbosity to the INFO level (transformers.logging.set_verbosity_info()) you should then see information about the training resuming and the number of steps skipped.
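Applied to the script in the question, a minimal sketch of the fix might look like this (the checkpoint directory "checkpoint-9500" and the trainer object are the ones already defined above):

import transformers

# Print checkpoint-resume information (checkpoint found, number of steps skipped, etc.)
transformers.logging.set_verbosity_info()

# Resume from the saved checkpoint instead of starting a fresh run
trainer.train("checkpoint-9500")

In more recent transformers releases the argument is named explicitly, so trainer.train(resume_from_checkpoint="checkpoint-9500") does the same thing, and trainer.train(resume_from_checkpoint=True) resumes from the latest checkpoint found in output_dir.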

8 reactions
sgugger commented, Jan 27, 2021

Trainer does not have a num_train_epochs attribute that is used; you need to set trainer.args.num_train_epochs. To be sure everything is updated, though, you probably need to re-instantiate the Trainer with the new TrainingArguments.
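As an illustration, a sketch under the question's setup might look like the following (the value 8 is arbitrary, just larger than the original 4 so there are new steps left to run):

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=8,              # raised from 4 so training does not stop immediately
    per_device_train_batch_size=1,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
)

# Re-instantiate the Trainer so the new arguments are used everywhere,
# then resume from the saved checkpoint as above
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
trainer.train("checkpoint-9500")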

