# Time a single forward pass of `model` on `data`.
# CUDA kernels are launched asynchronously, so any pending GPU work is
# flushed before starting the clock and again before stopping it; otherwise
# the measurement would include earlier work or miss this call's kernels.
# The guard lets the same code run on CPU-only hosts, where
# torch.cuda.synchronize() would raise.
if torch.cuda.is_available():
    torch.cuda.synchronize()
# perf_counter() is monotonic and high-resolution, unlike time.time(),
# which follows the (adjustable) wall clock.
start = time.perf_counter()
# NOTE(review): return_loss=False / rescale=True presumably select
# inference-mode output rescaled to the original input size — confirm
# against the model's forward() signature.
result = model(return_loss=False, rescale=True, **data)
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.perf_counter()
# Elapsed seconds for the forward pass.
print(end - start, "s")