
Custom collate function for mini-batch training with varying input sizes

Cho et al. 2023. 2. 6.

import torch

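def collate_fn(batch):
    """Zero-pad (C, H, W) tensors to the largest height and width in the batch."""
    # Find the largest width and height in the batch
    max_width = max(tensor.shape[2] for tensor in batch)
    max_height = max(tensor.shape[1] for tensor in batch)

    # Zero-pad every tensor at the bottom and right up to (max_height, max_width).
    # The original values stay in the top-left corner; nothing is interpolated.
    padded_batch = []
    for tensor in batch:
        # new_zeros keeps the dtype and device of the source tensor
        padded_tensor = tensor.new_zeros((tensor.shape[0], max_height, max_width))
        padded_tensor[:, :tensor.shape[1], :tensor.shape[2]] = tensor
        padded_batch.append(padded_tensor)

    # Stack into a single (B, C, max_height, max_width) tensor
    return torch.stack(padded_batch)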
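
A minimal usage sketch, assuming every sample is a (C, H, W) tensor with the same channel count. `VariableSizeDataset` and `pad_collate` are hypothetical names introduced only for illustration. `pad_collate` restates the padding logic above with `torch.nn.functional.pad` and stacks the result, so it can be handed to `DataLoader` through the `collate_fn` argument.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class VariableSizeDataset(Dataset):
    """Hypothetical dataset yielding (C, H, W) tensors with differing H and W."""

    def __init__(self, sizes):
        self.sizes = sizes  # list of (height, width) pairs

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        h, w = self.sizes[idx]
        return torch.ones(3, h, w)


def pad_collate(batch):
    # Target size: the largest height and width in this mini-batch.
    max_h = max(t.shape[1] for t in batch)
    max_w = max(t.shape[2] for t in batch)
    # F.pad takes (left, right, top, bottom) for the last two dimensions;
    # padding only right and bottom keeps the content in the top-left corner.
    padded = [F.pad(t, (0, max_w - t.shape[2], 0, max_h - t.shape[1])) for t in batch]
    return torch.stack(padded)  # shape: (B, C, max_h, max_w)


dataset = VariableSizeDataset([(4, 6), (5, 3), (2, 2), (7, 1)])
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=pad_collate)
shapes = [tuple(b.shape) for b in loader]
print(shapes)  # [(2, 3, 5, 6), (2, 3, 7, 2)]
```

Each batch is padded only to its own maximum size, so the batch shape can change from one iteration to the next. Padded regions are zeros, which a model may need to mask out depending on the task.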
