다양한 input size의 미니배치 학습을 위한 custom collate function
Cho et al. 2023. 2. 6.
목차
import torch
def collate_fn(batch, *, pad_value=0):
    """Pad a batch of variable-size tensors to a common height and width.

    Each tensor is copied into the top-left corner of a new tensor whose
    height (dim 1) and width (dim 2) are the maxima over the batch; the
    remaining bottom/right region is filled with ``pad_value``. Nothing is
    rescaled or interpolated. Dim 0 of each tensor is kept as-is.

    Args:
        batch: Sequence of 3-D tensors; dim 1 is treated as height and
            dim 2 as width (presumably ``(channels, height, width)`` images).
        pad_value: Fill value for the padded region. Defaults to ``0``,
            matching the original zero-padding behavior.

    Returns:
        A list with one padded tensor per input, each of shape
        ``(tensor.shape[0], max_height, max_width)`` and with the same dtype
        and device as its source tensor. An empty ``batch`` yields ``[]``.
    """
    # len() rather than truthiness: `not batch` raises if batch is itself a
    # multi-element tensor, which the loop below would otherwise accept.
    if len(batch) == 0:
        return []

    # The largest height/width across the batch define the padded size.
    max_height = max(tensor.shape[1] for tensor in batch)
    max_width = max(tensor.shape[2] for tensor in batch)

    padded_batch = []
    for tensor in batch:
        # new_full keeps the source dtype AND device, so GPU inputs stay on
        # the GPU instead of being copied into a freshly allocated CPU tensor.
        padded = tensor.new_full(
            (tensor.shape[0], max_height, max_width), pad_value
        )
        padded[:, :tensor.shape[1], :tensor.shape[2]] = tensor
        padded_batch.append(padded)
    return padded_batch
댓글