请分析并解释一下以下代码实现的功能和类中方法的实现逻辑,尽可能详细:
class TrainerBase:
"""Base class for iterative trainer."""
def __init__(self):
self._models = OrderedDict()
self._optims = OrderedDict()
self._scheds = OrderedDict()
self._writer = None
def register_model(self, name="model", model=None, optim=None, sched=None):
    """Register a model together with its optimizer and LR scheduler.

    Args:
        name: key shared by the model, optimizer and scheduler registries;
            must not already be registered.
        model: the model object.
        optim: optimizer for ``model``, or ``None``.
        sched: LR scheduler, or ``None``.

    Raises:
        AttributeError: if called before ``TrainerBase.__init__`` created
            the registries (e.g. a subclass forgot ``super().__init__()``).
        AssertionError: if ``name`` is already registered.
    """
    for attr, label in (("_models", "model"), ("_optims", "optim"), ("_scheds", "sched")):
        if self.__dict__.get(attr) is None:
            raise AttributeError(
                f"Cannot assign {label} before super().__init__() call"
            )
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if name in self._models:
        raise AssertionError("Found duplicate model names")
    self._models[name] = model
    self._optims[name] = optim
    self._scheds[name] = sched
def get_model_names(self, names=None):
    """Return all registered model names, or validate and return a subset.

    Args:
        names: ``None`` for every registered name (registration order), or
            a name / list of names passed through ``tolist_if_not``
            (presumably wraps a single name into a list — confirm).

    Returns:
        A list of model names.

    Raises:
        AssertionError: if a requested name is not registered. Raised
            explicitly so the check is not stripped by ``python -O``.
    """
    names_real = list(self._models.keys())
    if names is None:
        return names_real
    names = tolist_if_not(names)
    for name in names:
        if name not in names_real:
            raise AssertionError(f"Unknown model name: {name!r}")
    return names
def save_model(self, epoch, directory, is_best=False, val_result=None, model_name=""):
    """Write one checkpoint per registered model under ``directory/<name>``.

    Args:
        epoch: the epoch that just finished; ``epoch + 1`` is stored.
        directory: root directory; each model gets its own subdirectory.
        is_best: flag forwarded unchanged to ``save_checkpoint``.
        val_result: validation score stored alongside the weights.
        model_name: checkpoint file name forwarded to ``save_checkpoint``
            (e.g. ``"model-best.pth.tar"``); ``""`` lets it pick its default.
    """
    for name in self.get_model_names():
        optim = self._optims[name]
        sched = self._scheds[name]
        state = {
            "state_dict": self._models[name].state_dict(),
            "epoch": epoch + 1,
            # Optimizer / scheduler are optional at registration time.
            "optimizer": None if optim is None else optim.state_dict(),
            "scheduler": None if sched is None else sched.state_dict(),
            "val_result": val_result,
        }
        save_checkpoint(
            state,
            osp.join(directory, name),
            is_best=is_best,
            model_name=model_name,
        )
def resume_model_if_exist(self, directory):
    """Resume every registered model from ``directory`` if all checkpoints exist.

    Each model ``name`` is looked up at ``directory/<name>``. Resuming is
    all-or-nothing: if any of those paths is missing, nothing is loaded.

    Returns:
        The value ``resume_from_checkpoint`` returns for the last model
        (used by callers as the start epoch), or 0 when nothing is resumed.
    """
    names = self.get_model_names()
    paths = [osp.join(directory, name) for name in names]
    # With no registered models the previous code printed "Found checkpoint"
    # and then hit UnboundLocalError on ``start_epoch``; treat it as no resume.
    if not names or not all(osp.exists(path) for path in paths):
        print("No checkpoint found, train from scratch")
        return 0
    print(f"Found checkpoint at {directory} (will resume training)")
    start_epoch = 0
    for name, path in zip(names, paths):
        start_epoch = resume_from_checkpoint(
            path, self._models[name], self._optims[name], self._scheds[name]
        )
    return start_epoch
def load_model(self, directory, epoch=None):
    """Load model weights (only ``state_dict``) for every registered model.

    Args:
        directory: checkpoint root; each model is read from
            ``directory/<name>/<file>``. A falsy value skips loading.
        epoch: if given, load ``model.pth.tar-<epoch>``; otherwise load
            ``model-best.pth.tar``.

    Raises:
        FileNotFoundError: if a model's checkpoint file does not exist.
    """
    if not directory:
        print(
            "Note that load_model() is skipped as no pretrained "
            "model is given (ignore this if it's done on purpose)"
        )
        return
    # By default, the best model is loaded
    if epoch is None:
        model_file = "model-best.pth.tar"
    else:
        model_file = "model.pth.tar-" + str(epoch)
    for name in self.get_model_names():
        model_path = osp.join(directory, name, model_file)
        if not osp.exists(model_path):
            raise FileNotFoundError(f"No model at {model_path}")
        checkpoint = load_checkpoint(model_path)
        # Only the weights are needed; the former unused reads of "epoch" and
        # "val_result" overwrote the ``epoch`` argument and raised KeyError on
        # checkpoints without a "val_result" entry.
        state_dict = checkpoint["state_dict"]
        print(
            f"Load {model_path} to {name}"
        )
        self._models[name].load_state_dict(state_dict)
def set_model_mode(self, mode="train", names=None):
    """Put the selected models into training or evaluation mode.

    Args:
        mode: ``"train"`` calls ``model.train()``; ``"test"`` or ``"eval"``
            calls ``model.eval()``.
        names: a model name or list of names; ``None`` selects all.

    Raises:
        KeyError: for any other ``mode``. It is checked before touching any
            model, so it is raised even when no model is selected.
    """
    if mode == "train":
        switch = "train"
    elif mode in ("test", "eval"):
        switch = "eval"
    else:
        raise KeyError(
            f"Unsupported mode {mode!r}; expected 'train', 'test' or 'eval'"
        )
    for name in self.get_model_names(names):
        getattr(self._models[name], switch)()
def update_lr(self, names=None):
    """Call ``step()`` on the LR scheduler of each selected model that has one."""
    schedulers = (self._scheds[n] for n in self.get_model_names(names))
    for sched in schedulers:
        if sched is None:
            continue
        sched.step()
def detect_anomaly(self, loss):
    """Raise ``FloatingPointError`` if any element of ``loss`` is inf or NaN."""
    all_finite = torch.isfinite(loss).all()
    if all_finite:
        return
    raise FloatingPointError("Loss is infinite or NaN!")
def init_writer(self, log_dir):
    """Create the TensorBoard ``SummaryWriter`` unless one already exists.

    Args:
        log_dir: directory passed to ``SummaryWriter(log_dir=...)``.
    """
    already_open = (
        self.__dict__.get("_writer") is not None and self._writer is not None
    )
    if already_open:
        return
    print(f"Initialize tensorboard (log_dir={log_dir})")
    self._writer = SummaryWriter(log_dir=log_dir)
def close_writer(self):
    """Close the TensorBoard writer if one was created (the reference is kept)."""
    writer = self._writer
    if writer is None:
        return
    writer.close()
def write_scalar(self, tag, scalar_value, global_step=None):
    """Log one scalar to TensorBoard; silently skipped when no writer exists.

    The writer is created by ``init_writer`` only when training runs, so
    evaluation-only use reaches this method with ``self._writer is None``.
    """
    writer = self._writer
    if writer is not None:
        writer.add_scalar(tag, scalar_value, global_step)
def train(self, start_epoch, max_epoch):
    """Generic training loop.

    Calls ``before_train``, then ``before_epoch`` / ``run_epoch`` /
    ``after_epoch`` for each epoch in ``[self.start_epoch, self.max_epoch)``,
    then ``after_train``. The epoch range is read after ``before_train``, so
    that hook may adjust ``self.start_epoch`` (e.g. when resuming).
    """
    self.start_epoch, self.max_epoch = start_epoch, max_epoch
    self.before_train()
    for epoch in range(self.start_epoch, self.max_epoch):
        self.epoch = epoch
        self.before_epoch()
        self.run_epoch()
        self.after_epoch()
    self.after_train()
def before_train(self):
    """Hook run once before the first epoch; does nothing by default."""
def after_train(self):
    """Hook run once after the last epoch; does nothing by default."""
def before_epoch(self):
    """Hook run at the start of every epoch; does nothing by default."""
def after_epoch(self):
    """Hook run at the end of every epoch; does nothing by default."""
def run_epoch(self):
    """Train for one epoch. Subclasses must implement this."""
    raise NotImplementedError()
def test(self):
raise NotImplementedError
def parse_batch_train(self, batch):
    """Unpack a training batch. Subclasses must implement this."""
    raise NotImplementedError()
def parse_batch_test(self, batch):
    """Unpack an evaluation batch. Subclasses must implement this."""
    raise NotImplementedError()
def forward_backward(self, batch):
    """Run one forward/backward step on ``batch``. Subclasses must implement this."""
    raise NotImplementedError()
def model_inference(self, input):
    """Compute model outputs for ``input``. Subclasses must implement this."""
    raise NotImplementedError()
def model_zero_grad(self, names=None):
    """Call ``zero_grad()`` on the optimizer of each selected model that has one."""
    optimizers = (self._optims[n] for n in self.get_model_names(names))
    for optim in optimizers:
        if optim is None:
            continue
        optim.zero_grad()
def model_backward(self, loss):
    """Check ``loss`` for inf/NaN, then backpropagate it.

    Raises:
        FloatingPointError: from ``detect_anomaly``, before ``backward()``
            is called.
    """
    check = self.detect_anomaly
    check(loss)
    loss.backward()
def model_update(self, names=None):
    """Call ``step()`` on the optimizer of each selected model that has one."""
    optimizers = (self._optims[n] for n in self.get_model_names(names))
    for optim in optimizers:
        if optim is None:
            continue
        optim.step()
def model_backward_and_update(self, loss, names=None):
    """Run one optimization step: zero grads, backpropagate ``loss``, step optimizers.

    Args:
        loss: the loss to backpropagate.
        names: models whose optimizers are zeroed and stepped (``None``: all).
    """
    self.model_zero_grad(names)
    self.model_backward(loss)  # includes the inf/NaN check
    self.model_update(names)
该代码实现了一个基础的迭代训练器的类TrainerBase,具有以下功能和方法:
__init__(self): 初始化方法,初始化了_models、_optims、_scheds和_writer这几个成员变量。
register_model(self, name="model", model=None, optim=None, sched=None): 注册模型的方法,可以为训练器添加一个模型、优化器和调度器。通过name参数指定模型的名称,model参数指定模型对象,optim参数指定优化器对象,sched参数指定调度器对象。
get_model_names(self, names=None): 获取注册的模型名称的方法。如果names参数不为空,则返回names参数中指定的模型名称列表;否则返回所有注册的模型名称列表。
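save_model(self, epoch, directory, is_best=False, val_result=None, model_name=""): 保存模型的方法。它为每个已注册的模型分别调用 save_checkpoint,保存到 directory/<模型名> 子目录下。保存的内容包括模型参数(state_dict)、epoch + 1、优化器状态和调度器状态(若存在),以及验证结果。参数 epoch 表示当前训练的轮数,directory 表示保存模型的根目录,is_best 是"是否为最佳模型"的标志(原样传给 save_checkpoint),val_result 表示验证结果。model_name 是传给 save_checkpoint 的检查点文件名(例如 "model-best.pth.tar"),而不是模型本身的名称。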
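resume_model_if_exist(self, directory): 从之前保存的检查点恢复训练的方法。只有当每个已注册模型在 directory/<模型名> 下都存在检查点时,才会调用 resume_from_checkpoint 恢复模型、优化器和调度器。此时返回最后一个模型的 resume_from_checkpoint 的返回值,作为训练的起始轮数。只要有任一检查点缺失,就打印提示、从头训练并返回 0。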
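load_model(self, directory, epoch=None): 加载模型的方法,从 directory/<模型名>/ 下读取检查点。它只加载模型参数(state_dict),不恢复优化器和调度器。epoch 为 None 时加载 model-best.pth.tar,否则加载 model.pth.tar-<epoch>。directory 为空时直接跳过;文件不存在时抛出 FileNotFoundError。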
set_model_mode(self, mode="train", names=None): 设置模型的模式,可以是训练模式("train"),测试模式("test")或评估模式("eval")。参数names表示要设置模式的模型名称列表,默认为所有注册的模型。
update_lr(self, names=None): 更新学习率的方法,更新所有注册模型的学习率。参数names表示要更新学习率的模型名称列表,默认为所有注册的模型。
detect_anomaly(self, loss): 检测损失是否为无穷大或NaN的方法,如果是则抛出异常。
init_writer(self, log_dir): 初始化TensorBoard的方法,创建一个SummaryWriter对象用于记录训练过程中的日志。参数log_dir表示日志保存的目录。
close_writer(self): 关闭TensorBoard的方法,关闭SummaryWriter对象。
write_scalar(self, tag, scalar_value, global_step=None): 向TensorBoard中写入标量数据的方法。参数tag表示数据的标签,scalar_value表示标量值,global_step表示当前的全局步数。
train(self, start_epoch, max_epoch): 训练方法,开始训练模型。参数start_epoch表示开始的轮数,max_epoch表示最大的轮数。
before_train(self), after_train(self), before_epoch(self), after_epoch(self): 训练过程中的回调方法,在训练开始前、结束后、每轮开始前、结束后调用。
run_epoch(self): 执行每一轮训练的方法,需要子类实现。
test(self): 测试方法,需要子类实现。
parse_batch_train(self, batch), parse_batch_test(self, batch), forward_backward(self, batch), model_inference(self, input): 这几个方法在基类中都只抛出 NotImplementedError,需要子类实现。它们分别用于解析训练批次、解析测试批次、执行一次前向与反向传播,以及对输入进行模型推断。
model_zero_grad(self, names=None): 将模型的梯度清零的方法,参数names表示要清零梯度的模型名称列表,默认为所有注册的模型。
model_backward(self, loss): 模型的反向传播方法。它先调用 detect_anomaly 检查损失中是否有 inf 或 NaN(有则抛出 FloatingPointError),再调用 loss.backward() 计算损失关于模型参数的梯度。
model_update(self, names=None): 模型参数的更新方法,参数names表示要更新参数的模型名称列表,默认为所有注册的模型。
model_backward_and_update(self, loss, names=None): 模型的反向传播和参数更新方法。它依次调用 model_zero_grad、model_backward 和 model_update,即先清零梯度、再反向传播、最后由优化器更新参数。参数 loss 表示损失值,参数 names 表示要处理的模型名称列表,默认为所有注册的模型。
以同样的方法分析以下代码,包括方法的功能和实现逻辑。必要时结合上一个问题中的类来分析:
class SimpleTrainer(TrainerBase):
"""A simple trainer class implementing generic functions."""
def __init__(self, cfg):
    """Build a trainer from ``cfg``: device, data loaders, model and evaluator.

    Order matters: ``build_model`` reads ``self.num_classes`` and the
    evaluator receives ``self.lab2cname``, both set by ``build_data_loader``.
    """
    super().__init__()
    self.check_cfg(cfg)
    use_gpu = torch.cuda.is_available() and cfg.USE_CUDA
    self.device = torch.device("cuda" if use_gpu else "cpu")
    # Frequently used settings, cached as attributes.
    self.start_epoch = self.epoch = 0
    self.max_epoch = cfg.OPTIM.MAX_EPOCH
    self.output_dir = cfg.OUTPUT_DIR
    self.cfg = cfg
    self.build_data_loader()
    self.build_model()
    self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)
    # Start below any finite score so the first validation result becomes the best.
    self.best_result = -np.inf
def check_cfg(self, cfg):
    """Optionally validate ``cfg`` for this trainer; the default accepts anything.

    Subclasses can override this to enforce requirements. For example, a
    trainer that needs a particular training sampler could check::

        assert cfg.DATALOADER.SAMPLER_TRAIN == 'RandomDomainSampler'
    """
    return None
def build_data_loader(self):
    """Create the essential data-related attributes from a ``DataManager``.

    A re-implementation must create the same attributes (``self.dm`` is
    optional): ``train_loader_x``, ``train_loader_u`` (may be None),
    ``val_loader`` (may be None), ``test_loader``, ``num_classes``,
    ``num_source_domains`` and ``lab2cname`` (dict label -> class name).
    """
    dm = DataManager(self.cfg)
    for attr in (
        "train_loader_x",
        "train_loader_u",
        "val_loader",
        "test_loader",
        "num_classes",
        "num_source_domains",
        "lab2cname",
    ):
        setattr(self, attr, getattr(dm, attr))
    self.dm = dm
def build_model(self):
    """Build and register the model, optimizer and LR scheduler.

    The default builds a ``SimpleNet`` classifier. It optionally loads
    ``cfg.MODEL.INIT_WEIGHTS``, moves the model to ``self.device`` and
    registers it under the name ``"model"``. Custom trainers can
    re-implement this method if necessary.
    """
    cfg = self.cfg
    print("Building model")
    self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)
    if cfg.MODEL.INIT_WEIGHTS:
        load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
    self.model.to(self.device)
    print(f"# params: {count_num_param(self.model):,}")
    self.optim = build_optimizer(self.model, cfg.OPTIM)
    self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
    # The unwrapped model is registered, so checkpoints written by
    # save_model() have no "module." prefix even if DataParallel is used below.
    self.register_model("model", self.model, self.optim, self.sched)
    # nn.DataParallel needs the module's parameters on a CUDA device; when
    # cfg.USE_CUDA forces CPU on a multi-GPU machine, wrapping would make the
    # forward pass fail, so only wrap when actually running on CUDA.
    device_count = torch.cuda.device_count()
    if self.device.type == "cuda" and device_count > 1:
        print(f"Detected {device_count} GPUs (use nn.DataParallel)")
        self.model = nn.DataParallel(self.model)
def train(self):
    """Run TrainerBase's generic loop over ``[self.start_epoch, self.max_epoch)``."""
    first, last = self.start_epoch, self.max_epoch
    super().train(first, last)
def before_train(self):
    """Resume if possible, open the TensorBoard writer and start the clock."""
    # cfg.RESUME, when set, replaces cfg.OUTPUT_DIR as the checkpoint source.
    resume_dir = self.cfg.RESUME if self.cfg.RESUME else self.cfg.OUTPUT_DIR
    self.start_epoch = self.resume_model_if_exist(resume_dir)
    tb_dir = osp.join(self.output_dir, "tensorboard")
    mkdir_if_missing(tb_dir)
    self.init_writer(tb_dir)
    # Wall-clock start, used by after_train() to report the elapsed time.
    self.time_start = time.time()
def after_train(self):
    """Optionally evaluate the deployed model, report elapsed time, close the writer.

    With ``cfg.TEST.FINAL_MODEL == "best_val"`` the best-validation
    checkpoint is loaded from ``output_dir`` before testing; otherwise the
    weights from the last epoch are tested as they are.
    """
    print("Finish training")
    if not self.cfg.TEST.NO_TEST:
        if self.cfg.TEST.FINAL_MODEL != "best_val":
            print("Deploy the last-epoch model")
        else:
            print("Deploy the model with the best val performance")
            self.load_model(self.output_dir)
        self.test()
    elapsed_seconds = round(time.time() - self.time_start)
    elapsed = str(datetime.timedelta(seconds=elapsed_seconds))
    print(f"Elapsed: {elapsed}")
    self.close_writer()
def after_epoch(self):
    """Track the best validation model and write periodic/final checkpoints.

    In ``best_val`` mode (with testing enabled) the model is evaluated on the
    val split; a strictly better score is saved as ``model-best.pth.tar``.
    A regular checkpoint is saved every ``cfg.TRAIN.CHECKPOINT_FREQ`` epochs
    (disabled when <= 0) and after the last epoch.
    """
    epoch_no = self.epoch + 1
    freq = self.cfg.TRAIN.CHECKPOINT_FREQ
    is_last = epoch_no == self.max_epoch
    periodic = freq > 0 and epoch_no % freq == 0

    if not self.cfg.TEST.NO_TEST and self.cfg.TEST.FINAL_MODEL == "best_val":
        score = self.test(split="val")
        if score > self.best_result:
            self.best_result = score
            self.save_model(
                self.epoch,
                self.output_dir,
                val_result=score,
                model_name="model-best.pth.tar",
            )

    if periodic or is_last:
        self.save_model(self.epoch, self.output_dir)
@torch.no_grad()
def test(self, split=None):
    """Evaluate on the val or test split and return the first metric.

    Args:
        split: ``"val"`` or ``"test"``; ``None`` uses ``cfg.TEST.SPLIT``.
            Anything other than ``"val"``, or ``"val"`` without a
            ``val_loader``, evaluates on the test split.

    Returns:
        The first value of the evaluator's result dict. Every metric is also
        written to TensorBoard as ``<split>/<metric>`` at ``self.epoch``.
    """
    self.set_model_mode("eval")
    self.evaluator.reset()
    split = self.cfg.TEST.SPLIT if split is None else split
    if split != "val" or self.val_loader is None:
        split = "test"  # also covers a missing val_loader
        data_loader = self.test_loader
    else:
        data_loader = self.val_loader
    print(f"Evaluate on the *{split}* set")
    for batch in tqdm(data_loader):
        images, labels = self.parse_batch_test(batch)
        self.evaluator.process(self.model_inference(images), labels)
    results = self.evaluator.evaluate()
    for metric, value in results.items():
        self.write_scalar(f"{split}/{metric}", value, self.epoch)
    return list(results.values())[0]
def model_inference(self, input):
    """Return ``self.model(input)`` (the model may be DataParallel-wrapped)."""
    model = self.model
    return model(input)
def parse_batch_test(self, batch):
    """Return ``(batch["img"], batch["label"])``, each moved to ``self.device``."""
    return tuple(batch[key].to(self.device) for key in ("img", "label"))
def get_current_lr(self, names=None):
    """Return the learning rate of param group 0 of the first selected model's optimizer."""
    first = self.get_model_names(names)[0]
    optimizer = self._optims[first]
    return optimizer.param_groups[0]["lr"]
该代码实现了一个简单的训练器类SimpleTrainer,继承自TrainerBase,具有以下功能和方法:
__init__(self, cfg): 初始化方法。
1. 先调用父类初始化和 check_cfg,再根据 CUDA 是否可用以及 cfg.USE_CUDA 选择设备 device。
2. 设置起始轮数 start_epoch、当前轮数 epoch、最大轮数 max_epoch、输出目录 output_dir 和配置信息 cfg。
3. 调用 build_data_loader,创建数据加载器 train_loader_x、train_loader_u、val_loader 和 test_loader,以及类别数 num_classes、源域数 num_source_domains、标签到类名的映射 lab2cname 和数据管理器 dm。
4. 调用 build_model,创建模型 model、优化器 optim 和学习率调度器 sched。
5. 创建评估器 evaluator,并将最佳结果 best_result 初始化为负无穷。
check_cfg(self, cfg): 检查配置信息的方法,可以在子类中重新实现。
build_data_loader(self): 创建数据加载器的方法,创建了DataManager对象,并将其属性赋值给相应的属性。
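build_model(self): 构建和注册模型的方法。流程如下:
1. 创建 SimpleNet 模型;若配置了 cfg.MODEL.INIT_WEIGHTS,则加载预训练权重。
2. 将模型移到 device 上,并打印参数量。
3. 构建优化器和学习率调度器,并通过 register_model 以名称 "model" 注册。
4. 若检测到多块 GPU,则用 nn.DataParallel 包装 self.model。注册的仍是未包装的模型。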
train(self): 训练方法,调用了父类的train方法。
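before_train(self): 训练开始前的回调方法。若设置了 cfg.RESUME,则从该目录尝试恢复检查点,否则从 cfg.OUTPUT_DIR 恢复,并据此设置起始轮数 start_epoch。然后在 output_dir/tensorboard 下创建 TensorBoard 摘要写入器,并记录开始时间 time_start,用于最后计算耗时。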
after_train(self): 训练结束后的回调方法,进行一些收尾工作,包括进行测试、计算耗时、关闭摘要写入器。
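after_epoch(self): 每个轮次结束后的回调方法。当启用测试且 cfg.TEST.FINAL_MODEL 为 "best_val" 时,先在验证集上评估;若结果优于 best_result,则更新 best_result 并把模型保存为 model-best.pth.tar。此外,当轮次满足 cfg.TRAIN.CHECKPOINT_FREQ 或为最后一轮时,保存常规检查点。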
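test(self, split=None): 测试方法,在 torch.no_grad() 下运行。流程如下:
1. 将模型设为评估模式,并重置评估器。
2. split 为 None 时使用 cfg.TEST.SPLIT;若请求验证集但 val_loader 为 None,则改用测试集。
3. 逐批解析数据、进行推断,并将结果交给评估器。
4. 把各项指标写入 TensorBoard,并返回第一项指标的值。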
model_inference(self, input): 模型推断的方法,调用模型的前向传播。
parse_batch_test(self, batch): 解析测试数据批次的方法,将输入和标签转移到指定的设备上。
get_current_lr(self, names=None): 获取当前学习率的方法,返回第一个模型的当前学习率。
以同样的方法分析以下代码,包括方法的功能和实现逻辑。必要时结合之前的问题中的类来分析:
class CausalTrainer(SimpleTrainer):
"""A base trainer using labeled data only."""
def __init__(self, cfg):
    """Initialize SimpleTrainer, then keep the data manager's extra random loaders."""
    super().__init__(cfg)
    dm = self.dm
    # run_epoch() draws one batch from each of the first three of these per step.
    self.random_train_loader_x_list = dm.random_train_loader_x_list
def run_epoch(self):
    """Train for one epoch over ``self.train_loader_x``.

    Each labeled batch is passed to ``forward_backward`` together with one
    batch from each of the first three loaders in
    ``self.random_train_loader_x_list``. When one of those loaders runs out,
    only that loader is restarted. The previous code restarted all three,
    dropping batches already drawn from the others, so the longer loaders
    could never get past the shortest one.

    The LR schedulers are stepped after every batch.
    """
    self.set_model_mode("train")
    losses = MetricMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    self.num_batches = len(self.train_loader_x)
    # Fresh iterators every epoch, kept on the instance as before.
    self.random_train_loader_x_list_iteration = [
        iter(train_loader) for train_loader in self.random_train_loader_x_list
    ]

    end = time.time()
    for self.batch_idx, batch in enumerate(self.train_loader_x):
        data_time.update(time.time() - end)
        random_batch = self._next_random_batch(0)
        random_batch2 = self._next_random_batch(1)
        random_batch3 = self._next_random_batch(2)
        loss_summary = self.forward_backward(
            batch, random_batch, random_batch2, random_batch3
        )
        batch_time.update(time.time() - end)
        losses.update(loss_summary)
        self._log_train_progress(losses, batch_time, data_time)
        end = time.time()
        # The old per-epoch condition "(self.batch_idx + 1) == self.num_batches"
        # had been replaced by "if True", i.e. the scheduler steps every batch.
        # NOTE(review): the pasted source lost its indentation; this assumes
        # update_lr() ran inside the batch loop — confirm.
        self.update_lr()

def _next_random_batch(self, idx):
    """Return the next batch of random loader ``idx``, restarting only that loader when exhausted."""
    iterators = self.random_train_loader_x_list_iteration
    try:
        return next(iterators[idx])
    except StopIteration:
        iterators[idx] = iter(self.random_train_loader_x_list[idx])
        # An empty loader still raises StopIteration here, as before.
        return next(iterators[idx])

def _log_train_progress(self, losses, batch_time, data_time):
    """Print a progress line every PRINT_FREQ batches and log train scalars.

    NOTE(review): indentation was lost in the pasted source; scalars are
    written every batch here, following the usual Dassl layout — confirm.
    """
    print_freq = self.cfg.TRAIN.PRINT_FREQ
    meet_freq = (self.batch_idx + 1) % print_freq == 0
    only_few_batches = self.num_batches < print_freq
    if meet_freq or only_few_batches:
        # Remaining batches in this epoch plus every batch of later epochs.
        nb_remain = self.num_batches - self.batch_idx - 1
        nb_remain += (self.max_epoch - self.epoch - 1) * self.num_batches
        eta_seconds = batch_time.avg * nb_remain
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        info = [
            f"epoch [{self.epoch + 1}/{self.max_epoch}]",
            f"batch [{self.batch_idx + 1}/{self.num_batches}]",
            f"time {batch_time.val:.3f} ({batch_time.avg:.3f})",
            f"data {data_time.val:.3f} ({data_time.avg:.3f})",
            f"{losses}",
            f"lr {self.get_current_lr():.4e}",
            f"eta {eta}",
        ]
        print(" ".join(info))

    n_iter = self.epoch * self.num_batches + self.batch_idx
    for name, meter in losses.meters.items():
        self.write_scalar("train/" + name, meter.avg, n_iter)
    self.write_scalar("train/lr", self.get_current_lr(), n_iter)
class TrainerStyleTransfer(SimpleTrainer):
    """Trainer that pairs every ``train_loader_x`` batch with one extra batch.

    Each batch of ``train_loader_x`` is passed to ``forward_backward``
    together with one batch from ``dm.random_train_loader_x_1``.
    NOTE(review): ``forward_backward``, ``model_st``, ``iters`` and
    ``save_neural_transfer_state`` are not defined here; presumably a
    subclass provides them — confirm.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Secondary loader consumed alongside train_loader_x in run_epoch().
        self.random_train_loader_x = self.dm.random_train_loader_x_1

    def run_epoch(self):
        """Train for one epoch over ``self.train_loader_x``.

        The secondary loader is restarted whenever it is exhausted, so it may
        be shorter than ``train_loader_x``. Previously its StopIteration
        escaped from run_epoch and aborted training mid-epoch.
        NOTE(review): there is no ``update_lr()`` call here; presumably the
        LR is stepped elsewhere (e.g. in forward_backward) — confirm.
        """
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        self.num_batches = len(self.train_loader_x)
        self.random_train_loader_x_iteration = iter(self.random_train_loader_x)

        end = time.time()
        for self.batch_idx, batch in enumerate(self.train_loader_x):
            data_time.update(time.time() - end)
            random_batch = self._next_random_batch()
            loss_summary = self.forward_backward(batch, random_batch)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)
            self._log_train_progress(losses, batch_time, data_time)
            end = time.time()

    def _next_random_batch(self):
        """Return the next secondary batch, starting a new pass when the loader is exhausted."""
        try:
            return next(self.random_train_loader_x_iteration)
        except StopIteration:
            self.random_train_loader_x_iteration = iter(self.random_train_loader_x)
            # An empty loader still raises StopIteration here.
            return next(self.random_train_loader_x_iteration)

    def _log_train_progress(self, losses, batch_time, data_time):
        """Print a progress line every PRINT_FREQ batches and log train scalars.

        NOTE(review): indentation was lost in the pasted source; scalars are
        written every batch here, following the usual Dassl layout — confirm.
        """
        print_freq = self.cfg.TRAIN.PRINT_FREQ
        meet_freq = (self.batch_idx + 1) % print_freq == 0
        only_few_batches = self.num_batches < print_freq
        if meet_freq or only_few_batches:
            # Remaining batches in this epoch plus every batch of later epochs.
            nb_remain = self.num_batches - self.batch_idx - 1
            nb_remain += (self.max_epoch - self.epoch - 1) * self.num_batches
            eta_seconds = batch_time.avg * nb_remain
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))
            info = [
                f"epoch [{self.epoch + 1}/{self.max_epoch}]",
                f"batch [{self.batch_idx + 1}/{self.num_batches}]",
                f"time {batch_time.val:.3f} ({batch_time.avg:.3f})",
                f"data {data_time.val:.3f} ({data_time.avg:.3f})",
                f"{losses}",
                f"lr {self.get_current_lr():.4e}",
                f"eta {eta}",
            ]
            print(" ".join(info))

        n_iter = self.epoch * self.num_batches + self.batch_idx
        for name, meter in losses.meters.items():
            self.write_scalar("train/" + name, meter.avg, n_iter)
        self.write_scalar("train/lr", self.get_current_lr(), n_iter)

    def after_train(self):
        """Save the style-transfer state, then run SimpleTrainer.after_train()."""
        self.save_neural_transfer_state(
            self.model_st.encoder.vgg19.state_dict(),
            self.model_st.decoder.state_dict(),
            self.model_st.optimiser.state_dict(),
            self.iters,
            self.output_dir,
        )
        super().after_train()
该代码中定义了两个类 CausalTrainer 和 TrainerStyleTransfer,它们都继承自 SimpleTrainer,并重写了 __init__ 和 run_epoch 方法。其中只有 TrainerStyleTransfer 还重写了 after_train 方法。
CausalTrainer类:__init__(self, cfg): 初始化方法,调用了父类的初始化方法,并添加了random_train_loader_x_list属性。
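run_epoch(self): 重写了父类的 run_epoch 方法,用于执行每个轮次的训练。流程如下:
1. 设置模型为训练模式,并为 random_train_loader_x_list 中的每个数据加载器创建迭代器。
2. 对 train_loader_x 的每个批次 batch,再从前三个随机加载器中各取一个批次,即 random_batch、random_batch2 和 random_batch3。只要其中任一迭代器耗尽,三个迭代器都会被全部重新创建。
3. 把这四个批次一起传给 forward_backward,计算损失并更新参数。
4. 记录耗时和损失,按 PRINT_FREQ 打印包含预计剩余时间的日志,并把损失和学习率写入 TensorBoard。
5. 每个批次都会调用 update_lr。原先"在 epoch 最后一个批次才更新"的条件已被注释掉,并替换为 if True。

注意:这些随机批次来自名为 random_train_loader_x* 的加载器,从代码中无法看出它们是无标签数据。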
after_train(self): CausalTrainer 并没有重写 after_train,训练结束后的操作直接沿用 SimpleTrainer.after_train。
TrainerStyleTransfer类:__init__(self, cfg): 初始化方法,调用了父类的初始化方法,并添加了random_train_loader_x属性。
run_epoch(self): 重写了父类的 run_epoch 方法,流程与 CausalTrainer 类似,有以下不同:
1. 对每个 batch,只从 random_train_loader_x(即 dm.random_train_loader_x_1)中额外取一个批次 random_batch,并与 batch 一起传给 forward_backward。
2. 这里的迭代器耗尽时不会重新创建,next 会直接抛出 StopIteration。
3. 这里没有调用 update_lr。

训练过程中同样会记录损失、耗时和日志信息。
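after_train(self): 重写了父类的 after_train 方法。它先调用 save_neural_transfer_state,把 model_st(从类名看应是风格迁移模型)的编码器 vgg19 参数、解码器参数、优化器状态以及 iters 保存到 output_dir。然后调用父类 SimpleTrainer.after_train,完成测试、计时和关闭写入器。注意 model_st、iters 和 save_neural_transfer_state 都不在该类中定义,需要由子类提供。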
以上是对CausalTrainer和TrainerStyleTransfer类的方法功能和实现逻辑的分析。
如果要换一种方式实现CausalTrainer和TrainerStyleTransfer类的功能和方法,该如何实现?请给出具体代码
以下是使用不同的方式实现CausalTrainer和TrainerStyleTransfer类的代码:
class CausalTrainer(SimpleTrainer):
    """Trainer that pairs each labeled batch with three extra random batches.

    For every batch of ``train_loader_x``, one batch is drawn from each of the
    first three loaders in ``dm.random_train_loader_x_list``, and all four are
    passed to ``forward_backward`` (expected to be defined by a subclass).
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.random_train_loader_x_list = self.dm.random_train_loader_x_list

    @staticmethod
    def _endless_batches(loader):
        """Yield batches from ``loader`` indefinitely, starting a new pass when one ends.

        Each pass calls ``iter(loader)`` afresh instead of using
        ``itertools.cycle``, which would replay items saved from the first
        pass; a loader that shuffles per pass therefore keeps doing so. An
        empty loader ends the generator instead of looping forever.
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                return

    def run_epoch(self):
        """Train for one epoch over ``self.train_loader_x``.

        Each random loader is consumed independently. One loader running out
        restarts only that loader; previously all three were restarted and
        already-drawn batches were discarded. The LR schedulers are stepped
        after every batch.
        """
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        self.num_batches = len(self.train_loader_x)
        random_streams = [
            self._endless_batches(loader)
            for loader in self.random_train_loader_x_list
        ]

        end = time.time()
        for self.batch_idx, batch in enumerate(self.train_loader_x):
            data_time.update(time.time() - end)
            random_batch = next(random_streams[0])
            random_batch2 = next(random_streams[1])
            random_batch3 = next(random_streams[2])
            loss_summary = self.forward_backward(
                batch, random_batch, random_batch2, random_batch3
            )
            batch_time.update(time.time() - end)
            losses.update(loss_summary)
            self._log_train_progress(losses, batch_time, data_time)
            end = time.time()
            # Replaces the former "if True:" guard: the scheduler steps per batch.
            # NOTE(review): indentation was lost in the pasted source; this
            # assumes update_lr() ran inside the batch loop — confirm.
            self.update_lr()

    def _log_train_progress(self, losses, batch_time, data_time):
        """Print a progress line every PRINT_FREQ batches and log train scalars.

        NOTE(review): indentation was lost in the pasted source; scalars are
        written every batch here, following the usual Dassl layout — confirm.
        """
        print_freq = self.cfg.TRAIN.PRINT_FREQ
        meet_freq = (self.batch_idx + 1) % print_freq == 0
        only_few_batches = self.num_batches < print_freq
        if meet_freq or only_few_batches:
            # Remaining batches in this epoch plus every batch of later epochs.
            nb_remain = self.num_batches - self.batch_idx - 1
            nb_remain += (self.max_epoch - self.epoch - 1) * self.num_batches
            eta_seconds = batch_time.avg * nb_remain
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))
            info = [
                f"epoch [{self.epoch + 1}/{self.max_epoch}]",
                f"batch [{self.batch_idx + 1}/{self.num_batches}]",
                f"time {batch_time.val:.3f} ({batch_time.avg:.3f})",
                f"data {data_time.val:.3f} ({data_time.avg:.3f})",
                f"{losses}",
                f"lr {self.get_current_lr():.4e}",
                f"eta {eta}",
            ]
            print(" ".join(info))

        n_iter = self.epoch * self.num_batches + self.batch_idx
        for name, meter in losses.meters.items():
            self.write_scalar("train/" + name, meter.avg, n_iter)
        self.write_scalar("train/lr", self.get_current_lr(), n_iter)
class TrainerStyleTransfer(SimpleTrainer):
    """Trainer that pairs every ``train_loader_x`` batch with one extra batch.

    Each batch of ``train_loader_x`` is passed to ``forward_backward``
    together with one batch from ``dm.random_train_loader_x_1``.
    NOTE(review): ``forward_backward``, ``model_st``, ``iters`` and
    ``save_neural_transfer_state`` are not defined here; presumably a
    subclass provides them — confirm.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.random_train_loader_x = self.dm.random_train_loader_x_1

    @staticmethod
    def _endless_batches(loader):
        """Yield batches from ``loader`` indefinitely, starting a new pass when one ends.

        Each pass calls ``iter(loader)`` afresh instead of using
        ``itertools.cycle``, which would replay items saved from the first
        pass. An empty loader ends the generator instead of looping forever.
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                return

    def run_epoch(self):
        """Train for one epoch over ``self.train_loader_x``.

        The secondary loader restarts when exhausted, so a shorter loader no
        longer raises StopIteration mid-epoch.
        NOTE(review): there is no ``update_lr()`` call here; presumably the
        LR is stepped elsewhere (e.g. in forward_backward) — confirm.
        """
        self.set_model_mode("train")
        losses = MetricMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        self.num_batches = len(self.train_loader_x)
        random_stream = self._endless_batches(self.random_train_loader_x)

        end = time.time()
        for self.batch_idx, batch in enumerate(self.train_loader_x):
            data_time.update(time.time() - end)
            random_batch = next(random_stream)
            loss_summary = self.forward_backward(batch, random_batch)
            batch_time.update(time.time() - end)
            losses.update(loss_summary)
            self._log_train_progress(losses, batch_time, data_time)
            end = time.time()

    def _log_train_progress(self, losses, batch_time, data_time):
        """Print a progress line every PRINT_FREQ batches and log train scalars.

        NOTE(review): indentation was lost in the pasted source; scalars are
        written every batch here, following the usual Dassl layout — confirm.
        """
        print_freq = self.cfg.TRAIN.PRINT_FREQ
        meet_freq = (self.batch_idx + 1) % print_freq == 0
        only_few_batches = self.num_batches < print_freq
        if meet_freq or only_few_batches:
            # Remaining batches in this epoch plus every batch of later epochs.
            nb_remain = self.num_batches - self.batch_idx - 1
            nb_remain += (self.max_epoch - self.epoch - 1) * self.num_batches
            eta_seconds = batch_time.avg * nb_remain
            eta = str(datetime.timedelta(seconds=int(eta_seconds)))
            info = [
                f"epoch [{self.epoch + 1}/{self.max_epoch}]",
                f"batch [{self.batch_idx + 1}/{self.num_batches}]",
                f"time {batch_time.val:.3f} ({batch_time.avg:.3f})",
                f"data {data_time.val:.3f} ({data_time.avg:.3f})",
                f"{losses}",
                f"lr {self.get_current_lr():.4e}",
                f"eta {eta}",
            ]
            print(" ".join(info))

        n_iter = self.epoch * self.num_batches + self.batch_idx
        for name, meter in losses.meters.items():
            self.write_scalar("train/" + name, meter.avg, n_iter)
        self.write_scalar("train/lr", self.get_current_lr(), n_iter)

    def after_train(self):
        """Save the style-transfer state, then run SimpleTrainer.after_train()."""
        self.save_neural_transfer_state(
            self.model_st.encoder.vgg19.state_dict(),
            self.model_st.decoder.state_dict(),
            self.model_st.optimiser.state_dict(),
            self.iters,
            self.output_dir,
        )
        super().after_train()
这段重写与上一版相比,功能和方法基本相同,主要区别只有两点:
1. 随机训练数据的迭代器不再保存为实例属性(self.random_train_loader_x_list_iteration / self.random_train_loader_x_iteration),而是改为 run_epoch 中的局部变量 random_train_loader_x_list_iteration / random_train_loader_x_iteration。
2. 删除了被注释掉的代码。

除此之外,这段重写并没有实质性地"换一种方式"实现,其余逻辑保持不变。