"""Dataset helpers: image folders, video frames, HF datasets and text-only data.

Every image-producing dataset in this module runs its images through
`make_tranforms`, which appends `ToTensor`, a `x * 2 - 1` rescale and a
'c h w -> h w c' rearrange to the user supplied transforms.
"""
import copy
import csv
import json
import os
import random
from pathlib import Path
from typing import Dict

import numpy as np
from omegaconf import DictConfig, ListConfig
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from einops import rearrange
from ldm.util import instantiate_from_config
from datasets import load_dataset
import cv2


# Some hacky things to make experimentation easier
def make_transform_multi_folder_data(paths, caption_files=None, **kwargs):
    """Build a multi-folder dataset and wrap it in a `TransformDataset`."""
    ds = make_multi_folder_data(paths, caption_files, **kwargs)
    return TransformDataset(ds)


def make_nfp_data(base_path):
    """Concatenate one `NfpDataset` per entry matched by `base_path/*/`."""
    dirs = list(Path(base_path).glob("*/"))
    print(f"Found {len(dirs)} folders")
    print(dirs)
    tforms = [transforms.Resize(512), transforms.CenterCrop(512)]
    datasets = [
        NfpDataset(
            x,
            image_transforms=copy.copy(tforms),
            default_caption="A view from a train window",
        )
        for x in dirs
    ]
    return torch.utils.data.ConcatDataset(datasets)


class VideoDataset(Dataset):
    """Random frame from an mp4 plus `n` earlier frames and a caption.

    Videos are every `*.mp4` found recursively under `root_dir`. The caption
    file is a CSV whose rows are turned into a dict and looked up by the
    video's file name (so each row is expected to be `file name, caption`).
    """

    # Number of times __getitem__ retries a failing sample before giving up.
    _MAX_ATTEMPTS = 10

    def __init__(self, root_dir, image_transforms, caption_file, offset=8, n=2):
        self.root_dir = Path(root_dir)
        self.caption_file = caption_file
        self.n = n
        ext = "mp4"
        self.paths = sorted(self.root_dir.rglob(f"*.{ext}"))
        self.offset = offset

        # Shared helper: same pipeline as before, without mutating the
        # caller's list of transforms.
        self.tform = make_tranforms(image_transforms)

        with open(self.caption_file, newline="") as f:
            # Skip blank rows, which csv.reader yields as empty lists and
            # which would make dict() fail.
            self.captions = dict(row for row in csv.reader(f) if row)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Load sample `index`, retrying because the frame choice is random.

        Raises:
            RuntimeError: if every attempt fails (chained to the last error),
                instead of silently returning None.
        """
        last_error = None
        for _ in range(self._MAX_ATTEMPTS):
            try:
                return self._load_sample(index)
            except Exception as err:
                # Not really good enough but...
                last_error = err
                print("uh oh")
        raise RuntimeError(
            f"Failed to load sample {index} after {self._MAX_ATTEMPTS} attempts"
        ) from last_error

    @staticmethod
    def _read_frame(vid, frame_n, filename):
        """Seek `vid` to `frame_n` and return the decoded frame array."""
        vid.set(cv2.CAP_PROP_POS_FRAMES, frame_n)
        ok, frame = vid.read()
        if not ok or frame is None:
            raise IOError(f"Could not read frame {frame_n} from {filename}")
        return frame

    def _load_sample(self, index):
        """Return a dict with keys "image", "prev" and "txt" for one video.

        "prev" is the `n` earlier frames (each `offset` frames further back)
        concatenated along the last axis.
        """
        n = self.n
        filename = self.paths[index]
        # Earliest frame touched is curr - n * offset; keep the original
        # 2 * offset + 2 floor for n <= 2 and scale it for larger n so the
        # index never goes negative.
        min_frame = max(2, n) * self.offset + 2

        vid = cv2.VideoCapture(str(filename))
        try:
            max_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
            # Frames are indexed 0..max_frames-1 and randint includes both
            # endpoints, so the upper bound must be max_frames - 1.
            last_frame = max_frames - 1
            if last_frame < min_frame:
                raise ValueError(
                    f"{filename} has {max_frames} frames, need more than {min_frame}"
                )
            curr_frame_n = random.randint(min_frame, last_frame)
            curr_frame = self._read_frame(vid, curr_frame_n, filename)

            prev_frames = []
            for i in range(n):
                prev_frame_n = curr_frame_n - (i + 1) * self.offset
                prev_frame = self._read_frame(vid, prev_frame_n, filename)
                # OpenCV decodes as BGR; reverse the channel axis for PIL.
                prev_frames.append(self.tform(Image.fromarray(prev_frame[..., ::-1])))
        finally:
            # Always free the capture, even when a read fails.
            vid.release()

        caption = self.captions[filename.name]
        data = {
            "image": self.tform(Image.fromarray(curr_frame[..., ::-1])),
            "prev": torch.cat(prev_frames, dim=-1),
            "txt": caption,
        }
        return data

# end hacky things


def make_tranforms(image_transforms):
    """Compose user transforms with the module's standard tensor conversion.

    Args:
        image_transforms: None, a `ListConfig` of configs to instantiate, or
            an iterable of ready transform callables.

    Returns:
        A `transforms.Compose` of the given transforms followed by
        `ToTensor` and a Lambda computing `x * 2 - 1` rearranged from
        'c h w' to 'h w c'.

    The input is never modified: the previous in-place `extend` polluted
    callers' lists (and shared default arguments) on every call.
    """
    if image_transforms is None:
        steps = []
    elif isinstance(image_transforms, ListConfig):
        steps = [instantiate_from_config(tt) for tt in image_transforms]
    else:
        steps = list(image_transforms)
    steps.extend([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c')),
    ])
    return transforms.Compose(steps)


def make_multi_folder_data(paths, caption_files=None, **kwargs):
    """Make a concat dataset from multiple folders
    Don't support captions yet

    If paths is a list, that's ok, if it's a Dict interpret it as:
    k=folder v=n_times to repeat that
    """
    list_of_paths = []
    if isinstance(paths, (Dict, DictConfig)):
        assert caption_files is None, \
            "Caption files not yet supported for repeats"
        for folder_path, repeats in paths.items():
            list_of_paths.extend([folder_path]*repeats)
        paths = list_of_paths

    if caption_files is not None:
        datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
    else:
        datasets = [FolderData(p, **kwargs) for p in paths]
    return torch.utils.data.ConcatDataset(datasets)


class NfpDataset(Dataset):
    """Pairs of consecutive images (sorted by path) from a folder."""

    def __init__(self,
        root_dir,
        image_transforms=None,
        ext="jpg",
        default_caption="",
        ) -> None:
        """assume sequential frames and a deterministic transform"""
        self.root_dir = Path(root_dir)
        self.default_caption = default_caption

        self.paths = sorted(self.root_dir.rglob(f"*.{ext}"))
        self.tform = make_tranforms(image_transforms)

    def __len__(self):
        # Each item needs a successor, so the last path has no sample.
        return len(self.paths) - 1

    def __getitem__(self, index):
        prev = self.paths[index]
        curr = self.paths[index+1]
        data = {}
        data["image"] = self._load_im(curr)
        data["prev"] = self._load_im(prev)
        data["txt"] = self.default_caption
        return data

    def _load_im(self, filename):
        """Open `filename`, convert to RGB and apply the transform."""
        im = Image.open(filename).convert("RGB")
        return self.tform(im)


class FolderData(Dataset):
    """Images from a folder, optionally paired with captions from a file."""

    def __init__(self,
        root_dir,
        caption_file=None,
        image_transforms=None,
        ext="jpg",
        default_caption="",
        postprocess=None,
        return_paths=False,
        ) -> None:
        """Create a dataset from a folder of images.
        If you pass in a root directory it will be searched for images
        ending in ext (ext can be a list)

        If `caption_file` is given it must be `.json` (loaded as-is and used
        as a mapping of relative file name to caption) or `.jsonl` (one
        object per line with "file_name" and "text" keys); samples are then
        indexed by caption key order.

        Raises:
            ValueError: if `caption_file` has any other suffix.
        """
        self.root_dir = Path(root_dir)
        self.default_caption = default_caption
        self.return_paths = return_paths
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess
        if caption_file is not None:
            # Keep the suffix in its own variable: reusing `ext` here used to
            # clobber the image extension and break the glob below.
            caption_ext = Path(caption_file).suffix.lower()
            with open(caption_file, "rt") as f:
                if caption_ext == ".json":
                    captions = json.load(f)
                elif caption_ext == ".jsonl":
                    # Tolerate blank lines (e.g. a trailing newline).
                    lines = [json.loads(x) for x in f if x.strip()]
                    captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
                else:
                    raise ValueError(f"Unrecognised format: {caption_ext}")
            self.captions = captions
            # Cache the key order once instead of rebuilding it per item.
            self._caption_keys = list(captions)
        else:
            self.captions = None
            self._caption_keys = None

        if not isinstance(ext, (tuple, list, ListConfig)):
            ext = [ext]

        # Only used for indexing if there is no caption file
        self.paths = []
        for e in ext:
            self.paths.extend(sorted(self.root_dir.rglob(f"*.{e}")))
        self.tform = make_tranforms(image_transforms)

    def __len__(self):
        if self.captions is not None:
            return len(self._caption_keys)
        return len(self.paths)

    def _sample_path(self, index):
        """Return the image path that `__getitem__(index)` loads."""
        if self.captions is not None:
            return self.root_dir / self._caption_keys[index]
        return self.paths[index]

    def __getitem__(self, index):
        """Return a dict with "image", "txt" and optionally "path"."""
        data = {}
        filename = self._sample_path(index)
        if self.captions is not None:
            caption = self.captions.get(self._caption_keys[index], None)
            if caption is None:
                caption = self.default_caption
        else:
            caption = self.default_caption

        if self.return_paths:
            data["path"] = str(filename)

        # process_im converts to RGB (loading the pixels) before the file
        # handle is closed.
        with Image.open(filename) as im:
            data["image"] = self.process_im(im)

        data["txt"] = caption

        if self.postprocess is not None:
            data = self.postprocess(data)

        return data

    def process_im(self, im):
        """Convert a PIL image to RGB and apply the transform."""
        im = im.convert("RGB")
        return self.tform(im)


class TransformDataset():
    """Wrap a dataset, applying a random crop/resize and tagging the caption.

    Items of `ds` are expected to be dicts with an "image" tensor that is
    permuted (2, 0, 1) before the transforms and back afterwards, and a
    "txt" string.
    """

    def __init__(self, ds, extra_label="sksbspic"):
        self.ds = ds
        self.extra_label = extra_label
        self.transforms = {
            "align": transforms.Resize(768),
            "centerzoom": transforms.CenterCrop(768),
            "randzoom": transforms.RandomCrop(768),
        }
        # In case data is smaller than expected
        self._upscale = transforms.Resize(1024)

    def __getitem__(self, index):
        data = self.ds[index]

        im = data['image']
        im = im.permute(2,0,1)
        im = self._upscale(im)

        tform_name = random.choice(list(self.transforms.keys()))
        im = self.transforms[tform_name](im)

        im = im.permute(1,2,0)

        data['image'] = im
        # The chosen transform name is appended to the caption.
        data['txt'] = data['txt'] + f" {self.extra_label} {tform_name}"

        return data

    def __len__(self):
        return len(self.ds)


def hf_dataset(
    name,
    image_transforms=None,
    image_column="image",
    text_column="text",
    split='train',
    image_key='image',
    caption_key='txt',
    ):
    """Make huggingface dataset with appropriate list of transforms applied

    Loads `name`/`split` with `load_dataset` and installs a transform that
    returns `{image_key: transformed images, caption_key: texts}`.
    """
    ds = load_dataset(name, split=split)
    tform = make_tranforms(image_transforms)

    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
    assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"

    def pre_process(examples):
        processed = {}
        processed[image_key] = [tform(im) for im in examples[image_column]]
        processed[caption_key] = examples[text_column]
        return processed

    ds.set_transform(pre_process)
    return ds


def hf_dataset_RSITMD(
    name,
    data_path,
    train_json_path,
    val_json_path,
    image_transforms=None,
    image_column="image",
    text_column="text",
    split='train',
    image_key='image',
    caption_key='txt',
    ):
    """Make huggingface dataset with appropriate list of transforms applied

    Loads the train/validation JSON files (records under the "data" field)
    and installs a transform that opens each `image_column` entry relative
    to `data_path`. Any `split` other than 'train' selects validation.
    `name` is accepted but unused.
    """
    data_files = {
        "train": train_json_path,
        "validation": val_json_path
    }
    ds = load_dataset("json", data_files=data_files, field="data")
    if split == 'train':
        ds = ds['train']
    else:
        ds = ds['validation']
    tform = make_tranforms(image_transforms)

    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
    assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"

    def pre_process(examples):
        # Read paths from the configured column (previously hard-coded to
        # 'image') and leave the incoming batch untouched.
        images = [Image.open(os.path.join(data_path, x)) for x in examples[image_column]]
        processed = {}
        processed[image_key] = [tform(im) for im in images]
        processed[caption_key] = examples[text_column]
        return processed

    ds.set_transform(pre_process)
    return ds


class TextOnly(Dataset):
    def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
        """Returns only captions with dummy images

        `captions` is either a `Path` to a text file with one caption per
        line, or an already loaded sequence of captions.
        """
        self.output_size = output_size
        self.image_key = image_key
        self.caption_key = caption_key
        if isinstance(captions, Path):
            self.captions = self._load_caption_file(captions)
        else:
            self.captions = captions

        if n_gpus > 1:
            # hack to make sure that all the captions appear on each gpu
            self.captions = [x for x in self.captions for _ in range(n_gpus)]

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        # zeros * 2 - 1 gives an all -1 image, rearranged to 'h w c'.
        dummy_im = torch.zeros(3, self.output_size, self.output_size)
        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
        return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

    def _load_caption_file(self, filename):
        """Read one caption per line, dropping the newline characters."""
        with open(filename, 'rt') as f:
            captions = f.readlines()
        return [x.strip('\n') for x in captions]


class IdRetreivalDataset(FolderData):
    """`FolderData` that also returns a randomly chosen retrieved match.

    `ret_file` is a JSON mapping of image file name to a list of candidate
    file names relative to `root_dir`.
    """

    def __init__(self, ret_file, *args, **kwargs):
        super().__init__(*args, **kwargs)
        with open(ret_file, "rt") as f:
            self.ret = json.load(f)

    def __getitem__(self, index):
        data = super().__getitem__(index)
        # Use the same file the parent just loaded so the lookup key stays
        # in sync with the sample whether or not captions drive the index.
        sample_path = Path(self._sample_path(index))
        key = sample_path.name
        matches = self.ret[key]
        if len(matches) > 0:
            filename = self.root_dir/random.choice(matches)
        else:
            # No matches: fall back to the sample itself (its real path, so
            # images in sub-folders resolve too).
            filename = sample_path
        im = Image.open(filename).convert("RGB")
        im = self.process_im(im)
        # data["match"] = im
        data["match"] = torch.cat((data["image"], im), dim=-1)
        return data