This commit is contained in:
xiaoyuan1996 2023-05-06 22:21:59 +08:00
parent f514f06b1e
commit 886c883561
4 changed files with 12 additions and 3 deletions

View File

@ -79,6 +79,9 @@ data:
target: ldm.data.simple.hf_dataset_RSITMD
params:
name: RSITMD-captions
data_path: ./data/RSITMD
train_json_path: ./data/RSITMD/hf_train.json
val_json_path: ./data/RSITMD/hf_val.json
image_transforms:
- target: torchvision.transforms.Resize
params:

File diff suppressed because one or more lines are too long

1
data/RSITMD/hf_val.json Normal file

File diff suppressed because one or more lines are too long

View File

@ -1,3 +1,4 @@
import os
from typing import Dict
import numpy as np
from omegaconf import DictConfig, ListConfig
@ -298,6 +299,9 @@ def hf_dataset(
def hf_dataset_RSITMD(
name,
data_path,
train_json_path,
val_json_path,
image_transforms=[],
image_column="image",
text_column="text",
@ -309,8 +313,8 @@ def hf_dataset_RSITMD(
"""
data_files = {
"train": "/mmu_nlp_ssd/yuanzhiqiang/dif/data/RSITMD/hf_train.json",
"validation": "/mmu_nlp_ssd/yuanzhiqiang/dif/data/RSITMD/hf_val.json"
"train": train_json_path,
"validation": val_json_path
}
ds = load_dataset("json", data_files=data_files, field="data")
@ -326,7 +330,7 @@ def hf_dataset_RSITMD(
assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
def pre_process(examples):
examples['image'] = [Image.open(x) for x in examples['image']]
examples['image'] = [Image.open(os.path.join(data_path, x)) for x in examples['image']]
processed = {}
processed[image_key] = [tform(im) for im in examples[image_column]]