update
This commit is contained in:
parent
f514f06b1e
commit
886c883561
@ -79,6 +79,9 @@ data:
|
|||||||
target: ldm.data.simple.hf_dataset_RSITMD
|
target: ldm.data.simple.hf_dataset_RSITMD
|
||||||
params:
|
params:
|
||||||
name: RSITMD-captions
|
name: RSITMD-captions
|
||||||
|
data_path: ./data/RSITMD
|
||||||
|
train_json_path: ./data/RSITMD/hf_train.json
|
||||||
|
val_json_path: ./data/RSITMD/hf_val.json
|
||||||
image_transforms:
|
image_transforms:
|
||||||
- target: torchvision.transforms.Resize
|
- target: torchvision.transforms.Resize
|
||||||
params:
|
params:
|
||||||
|
|||||||
1
data/RSITMD/hf_train.json
Normal file
1
data/RSITMD/hf_train.json
Normal file
File diff suppressed because one or more lines are too long
1
data/RSITMD/hf_val.json
Normal file
1
data/RSITMD/hf_val.json
Normal file
File diff suppressed because one or more lines are too long
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from omegaconf import DictConfig, ListConfig
|
from omegaconf import DictConfig, ListConfig
|
||||||
@ -298,6 +299,9 @@ def hf_dataset(
|
|||||||
|
|
||||||
def hf_dataset_RSITMD(
|
def hf_dataset_RSITMD(
|
||||||
name,
|
name,
|
||||||
|
data_path,
|
||||||
|
train_json_path,
|
||||||
|
val_json_path,
|
||||||
image_transforms=[],
|
image_transforms=[],
|
||||||
image_column="image",
|
image_column="image",
|
||||||
text_column="text",
|
text_column="text",
|
||||||
@ -309,8 +313,8 @@ def hf_dataset_RSITMD(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
data_files = {
|
data_files = {
|
||||||
"train": "/mmu_nlp_ssd/yuanzhiqiang/dif/data/RSITMD/hf_train.json",
|
"train": train_json_path,
|
||||||
"validation": "/mmu_nlp_ssd/yuanzhiqiang/dif/data/RSITMD/hf_val.json"
|
"validation": val_json_path
|
||||||
}
|
}
|
||||||
|
|
||||||
ds = load_dataset("json", data_files=data_files, field="data")
|
ds = load_dataset("json", data_files=data_files, field="data")
|
||||||
@ -326,7 +330,7 @@ def hf_dataset_RSITMD(
|
|||||||
assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
|
assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
|
||||||
|
|
||||||
def pre_process(examples):
|
def pre_process(examples):
|
||||||
examples['image'] = [Image.open(x) for x in examples['image']]
|
examples['image'] = [Image.open(os.path.join(data_path, x)) for x in examples['image']]
|
||||||
|
|
||||||
processed = {}
|
processed = {}
|
||||||
processed[image_key] = [tform(im) for im in examples[image_column]]
|
processed[image_key] = [tform(im) for im in examples[image_column]]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user