# mtomnet/tbd/utils/preprocess_img.py
#
# Source-viewer metadata: 37 lines, 954 B, Python.
# Last modified: 2025-01-10 15:39:20 +01:00.
import glob
import cv2
import torchvision.transforms as T
import torch
import os
from tqdm import tqdm
# Root of the raw images (main() globs <PATH_IN>/*/*/*.jpg) and root under
# which the normalised copies are written as .pt files.
PATH_IN = "/scratch/bortoletto/data/tbd/images"
PATH_OUT = "/scratch/bortoletto/data/tbd/images_norm"

# Per-image pipeline, applied in order:
#   1. HWC image array -> CHW float tensor (ToTensor),
#   2. resize to 128x128,
#   3. per-channel standardisation with the ImageNet mean/std (RGB order).
normalisation_steps = [
    T.ToTensor(),
    T.Resize(size=(128, 128)),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Single callable wrapping the steps above.
preprocess_img = T.Compose(transforms=normalisation_steps)
def main(path_in=None, path_out=None):
    """Normalise every ``*.jpg`` two levels below ``path_in`` and save it as ``.pt``.

    Each image found at ``<path_in>/<a>/<b>/<name>.jpg`` is read with OpenCV,
    converted from OpenCV's BGR channel order to RGB, passed through
    ``preprocess_img``, and the resulting NumPy array is written with
    ``torch.save`` to ``<path_out>/<a>/<b>/<name>.pt``. Missing output
    directories are created.

    Args:
        path_in: Root directory of the input images. Defaults to the
            module-level ``PATH_IN``.
        path_out: Root directory for the ``.pt`` outputs. Defaults to the
            module-level ``PATH_OUT``.

    Returns:
        List of input paths that OpenCV could not decode. These are skipped
        rather than aborting the whole run.
    """
    # Resolve defaults at call time so later changes to the module constants
    # are honoured.
    path_in = PATH_IN if path_in is None else path_in
    path_out = PATH_OUT if path_out is None else path_out

    # Escape the root so glob metacharacters in the directory name are taken
    # literally. Sort for a deterministic processing order.
    pattern = os.path.join(glob.escape(path_in), "*", "*", "*.jpg")
    print(pattern)
    all_img = sorted(glob.glob(pattern))
    print(len(all_img))

    failed = []
    for img_path in tqdm(all_img):
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals unreadable/corrupt files by returning None.
            tqdm.write(f"Could not read image, skipping: {img_path}")
            failed.append(img_path)
            continue

        # OpenCV decodes to BGR, while the Normalize mean/std in
        # preprocess_img are given in RGB order. Convert first so each channel
        # is standardised with its own statistics.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        new_img = preprocess_img(img).numpy()

        # Mirror the <a>/<b>/<name> layout under path_out, replacing the
        # .jpg extension with .pt.
        rel_stem = os.path.splitext(os.path.relpath(img_path, path_in))[0]
        out_img = os.path.join(path_out, rel_stem + ".pt")
        os.makedirs(os.path.dirname(out_img), exist_ok=True)
        torch.save(new_img, out_img)

    if failed:
        print(f"Skipped {len(failed)} unreadable image(s).")
    return failed
if __name__ == "__main__":
    # Only run the preprocessing when executed as a script, not on import.
    main()