Quick Start#
First, pytorch-pfn-extras organizes training code written in PyTorch around the Trainer/Evaluator classes.
Next, it provides the following interfaces for training PyTorch models.
Addition of extensions for analysis and visualization
Runtime changes
Addition of custom training steps
Custom data handling
Step 1: Use Trainer#
First, pass the model and optimizer you want to train to the Trainer.
import pytorch_pfn_extras as ppe
import torch
class Model(torch.nn.Module):
    """Linear classifier that computes its own training loss.

    The forward pass maps 64 input features to 2 class scores, applies
    ``log_softmax`` over dimension 1 and evaluates the negative
    log-likelihood loss against ``target``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.linear = torch.nn.Linear(in_features=64, out_features=2)
        self.criterion = torch.nn.NLLLoss()

    def forward(self, x, target):
        """Compute the loss for one batch.

        Args:
            x: Input given to the linear layer (64 features per sample).
            target: Class indices given to ``NLLLoss``.

        Returns:
            A dict with the single key ``"loss"`` holding the loss tensor.
        """
        # Call the submodules themselves rather than their ``forward``
        # methods so that any registered hooks are executed.
        y = self.linear(x).log_softmax(dim=1)
        loss = self.criterion(y, target)
        return {"loss": loss}
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so the example also runs on machines without a GPU.  Any other PyTorch
# device ('cpu', etc.) or PPE runtime name can be used here instead.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
epochs = 3
# Create a trainer with the defined model, optimizer, and other parameters
trainer = ppe.engine.create_trainer(
    models=model,
    optimizers=optimizer,
    max_epochs=epochs,
    evaluator=ppe.engine.create_evaluator(
        models=model,
        device=device,
    ),
    device=device,
)
# Send the model to the selected device for computation
ppe.to(model, device=device)
batch_size = 10
# Create 10 batches of random training data with dimension (batch_size x 64)
training_data = [
    {
        "x": torch.rand((batch_size, 64)),
        "target": torch.ones((batch_size,), dtype=torch.long),
    }
    for _ in range(10)
]
# Create 10 batches of random validation data with dimension (batch_size x 64)
validation_data = [
    {
        "x": torch.rand((batch_size, 64)),
        "target": torch.ones((batch_size,), dtype=torch.long),
    }
    for _ in range(10)
]
# Start the training and validation of the model
trainer.run(train_loader=training_data, val_loader=validation_data)
print("Finish training!")
Step 2: Get Log#
Next, collect the logs of the training progress.
import pytorch_pfn_extras as ppe
import torch
class Model(torch.nn.Module):
    """Linear classifier that computes its own training loss.

    The forward pass maps 64 input features to 2 class scores, applies
    ``log_softmax`` over dimension 1 and evaluates the negative
    log-likelihood loss against ``target``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.linear = torch.nn.Linear(in_features=64, out_features=2)
        self.criterion = torch.nn.NLLLoss()

    def forward(self, x, target):
        """Compute the loss for one batch.

        Args:
            x: Input given to the linear layer (64 features per sample).
            target: Class indices given to ``NLLLoss``.

        Returns:
            A dict with the single key ``"loss"`` holding the loss tensor.
        """
        # Call the submodules themselves rather than their ``forward``
        # methods so that any registered hooks are executed.
        y = self.linear(x).log_softmax(dim=1)
        loss = self.criterion(y, target)
        return {"loss": loss}
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so the example also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
epochs = 3
trainer = ppe.engine.create_trainer(
    models=model,
    optimizers=optimizer,
    max_epochs=epochs,
    evaluator=ppe.engine.create_evaluator(
        models=model,
        device=device,
        options={
            # Let the value of the loss be notified to the LogReport.
            "eval_report_keys": ["loss"],
        },
    ),
    device=device,
    options={
        # Let the value of the loss be notified to the LogReport.
        "train_report_keys": ["loss"],
    },
)
# LogReport is an extension to collect parameters reported during training.
trainer.extend(ppe.training.extensions.LogReport())
ppe.to(model, device=device)
batch_size = 10
# 10 batches of random training data, each with an input of (batch_size x 64)
training_data = [
    {
        "x": torch.rand((batch_size, 64)),
        "target": torch.ones((batch_size,), dtype=torch.long),
    }
    for _ in range(10)
]
# 10 batches of random validation data, each with an input of (batch_size x 64)
validation_data = [
    {
        "x": torch.rand((batch_size, 64)),
        "target": torch.ones((batch_size,), dtype=torch.long),
    }
    for _ in range(10)
]
trainer.run(train_loader=training_data, val_loader=validation_data)
print("Finish training!")
The collected logs of the training progress are written to ./result/log.
Step 3: Display of progress#
Next, make it possible to monitor the progress of the training while it runs.
import pytorch_pfn_extras as ppe
import torch
class Model(torch.nn.Module):
    """Linear classifier that computes its own training loss.

    The forward pass maps 64 input features to 2 class scores, applies
    ``log_softmax`` over dimension 1 and evaluates the negative
    log-likelihood loss against ``target``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.linear = torch.nn.Linear(in_features=64, out_features=2)
        self.criterion = torch.nn.NLLLoss()

    def forward(self, x, target):
        """Compute the loss for one batch.

        Args:
            x: Input given to the linear layer (64 features per sample).
            target: Class indices given to ``NLLLoss``.

        Returns:
            A dict with the single key ``"loss"`` holding the loss tensor.
        """
        # Call the submodules themselves rather than their ``forward``
        # methods so that any registered hooks are executed.
        y = self.linear(x).log_softmax(dim=1)
        loss = self.criterion(y, target)
        return {"loss": loss}
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so the example also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
epochs = 3
trainer = ppe.engine.create_trainer(
    models=model,
    optimizers=optimizer,
    max_epochs=epochs,
    evaluator=ppe.engine.create_evaluator(
        models=model,
        device=device,
        options={
            "eval_report_keys": ["loss"],
        },
    ),
    device=device,
    options={
        "train_report_keys": ["loss"],
    },
)
trainer.extend(ppe.training.extensions.LogReport())
trainer.extend(ppe.training.extensions.ProgressBar())
# PrintReport displays the collected logs interactively.
trainer.extend(
    ppe.training.extensions.PrintReport(
        [
            # epoch, iteration, elapsed_time are automatically collected by
            # LogReport.
            "epoch",
            "iteration",
            "elapsed_time",
            # The parameters specified by train_report_keys are collected
            # under keys with the 'train/' prefix.
            "train/loss",
            # The parameters specified by eval_report_keys are collected
            # under keys with the 'val/' prefix.
            "val/loss",
        ],
    )
)
ppe.to(model, device=device)
batch_size = 10
# 10 batches of random training data, each with an input of (batch_size x 64)
training_data = [
    {
        "x": torch.rand((batch_size, 64)),
        "target": torch.ones((batch_size,), dtype=torch.long),
    }
    for _ in range(10)
]
# 10 batches of random validation data, each with an input of (batch_size x 64)
validation_data = [
    {
        "x": torch.rand((batch_size, 64)),
        "target": torch.ones((batch_size,), dtype=torch.long),
    }
    for _ in range(10)
]
trainer.run(train_loader=training_data, val_loader=validation_data)
print("Finish training!")
Step 4: Save Model#
Finally, save the trained model.
import pytorch_pfn_extras as ppe
import torch
class Model(torch.nn.Module):
    """Linear classifier that computes its own training loss.

    The forward pass maps 64 input features to 2 class scores, applies
    ``log_softmax`` over dimension 1 and evaluates the negative
    log-likelihood loss against ``target``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.linear = torch.nn.Linear(in_features=64, out_features=2)
        self.criterion = torch.nn.NLLLoss()

    def forward(self, x, target):
        """Compute the loss for one batch.

        Args:
            x: Input given to the linear layer (64 features per sample).
            target: Class indices given to ``NLLLoss``.

        Returns:
            A dict with the single key ``"loss"`` holding the loss tensor.
        """
        # Call the submodules themselves rather than their ``forward``
        # methods so that any registered hooks are executed.
        y = self.linear(x).log_softmax(dim=1)
        loss = self.criterion(y, target)
        return {"loss": loss}
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so the example also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
epochs = 3
trainer = ppe.engine.create_trainer(
    models=model,
    optimizers=optimizer,
    max_epochs=epochs,
    evaluator=ppe.engine.create_evaluator(
        models=model,
        device=device,
        options={
            "eval_report_keys": ["loss"],
        },
    ),
    device=device,
    options={
        "train_report_keys": ["loss"],
    },
)
trainer.extend(ppe.training.extensions.LogReport())
trainer.extend(ppe.training.extensions.ProgressBar())
# PrintReport displays the collected logs interactively.
trainer.extend(
    ppe.training.extensions.PrintReport(
        [
            # epoch, iteration, elapsed_time are automatically collected by
            # LogReport.
            "epoch",
            "iteration",
            "elapsed_time",
            # The parameters specified by train_report_keys are collected
            # under keys with the 'train/' prefix.
            "train/loss",
            # The parameters specified by eval_report_keys are collected
            # under keys with the 'val/' prefix.
            "val/loss",
        ],
    )
)
# Save the model parameters after each epoch.
trainer.extend(ppe.training.extensions.snapshot(target=model))
ppe.to(model, device=device)
batch_size = 10
# 10 batches of random training data, each with an input of (batch_size x 64)
training_data = [
    {
        "x": torch.rand((batch_size, 64)),
        "target": torch.ones((batch_size,), dtype=torch.long),
    }
    for _ in range(10)
]
# 10 batches of random validation data, each with an input of (batch_size x 64)
validation_data = [
    {
        "x": torch.rand((batch_size, 64)),
        "target": torch.ones((batch_size,), dtype=torch.long),
    }
    for _ in range(10)
]
trainer.run(train_loader=training_data, val_loader=validation_data)
print("Finish training!")
The model parameters are stored under ./result, with a file name that includes the time they were saved.
Snapshots are generated using state_dict(). Please refer to the official PyTorch docs for how to load the model.