update
This commit is contained in:
parent
f514f06b1e
commit
886c883561
@ -79,6 +79,9 @@ data:
|
||||
target: ldm.data.simple.hf_dataset_RSITMD
|
||||
params:
|
||||
name: RSITMD-captions
|
||||
data_path: ./data/RSITMD
|
||||
train_json_path: ./data/RSITMD/hf_train.json
|
||||
val_json_path: ./data/RSITMD/hf_val.json
|
||||
image_transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
|
||||
1
data/RSITMD/hf_train.json
Normal file
1
data/RSITMD/hf_train.json
Normal file
File diff suppressed because one or more lines are too long
1
data/RSITMD/hf_val.json
Normal file
1
data/RSITMD/hf_val.json
Normal file
File diff suppressed because one or more lines are too long
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
import numpy as np
|
||||
from omegaconf import DictConfig, ListConfig
|
||||
@ -298,6 +299,9 @@ def hf_dataset(
|
||||
|
||||
def hf_dataset_RSITMD(
|
||||
name,
|
||||
data_path,
|
||||
train_json_path,
|
||||
val_json_path,
|
||||
image_transforms=[],
|
||||
image_column="image",
|
||||
text_column="text",
|
||||
@ -309,8 +313,8 @@ def hf_dataset_RSITMD(
|
||||
"""
|
||||
|
||||
data_files = {
|
||||
"train": "/mmu_nlp_ssd/yuanzhiqiang/dif/data/RSITMD/hf_train.json",
|
||||
"validation": "/mmu_nlp_ssd/yuanzhiqiang/dif/data/RSITMD/hf_val.json"
|
||||
"train": train_json_path,
|
||||
"validation": val_json_path
|
||||
}
|
||||
|
||||
ds = load_dataset("json", data_files=data_files, field="data")
|
||||
@ -326,7 +330,7 @@ def hf_dataset_RSITMD(
|
||||
assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
|
||||
|
||||
def pre_process(examples):
|
||||
examples['image'] = [Image.open(x) for x in examples['image']]
|
||||
examples['image'] = [Image.open(os.path.join(data_path, x)) for x in examples['image']]
|
||||
|
||||
processed = {}
|
||||
processed[image_key] = [tform(im) for im in examples[image_column]]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user