Initial commit for: Add training script that uses our data

pull/891/head
mansikataria 4 years ago
parent 9ea8242729
commit f78ab556ec

@@ -2,18 +2,21 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import csv
import torch.utils.data as data
import os
import torch
import logging

import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from sklearn import preprocessing

from .parsers import create_parser

_logger = logging.getLogger(__name__)

_ERROR_RETRY = 50
@@ -144,3 +147,32 @@ class AugMixDataset(torch.utils.data.Dataset):
    def __len__(self):
        return len(self.dataset)


class COAIImageClassDataset(Dataset):
    def __init__(self, label_dict, base_path='', transform=None):
        # 'label_dict' maps image paths (relative to base_path) to string labels;
        # renamed from 'dict' to avoid shadowing the builtin
        self.transform = transform
        self.base_path = base_path
        self.label_dict = label_dict
        # Build a DataFrame of (image, label) pairs and add an integer-encoded label column
        df = pd.DataFrame(list(label_dict.items()), columns=['image', 'label'])
        classes = df['label'].unique()
        le = preprocessing.LabelEncoder()
        le.fit(classes)
        df['encoded_label'] = le.transform(df['label'])
        self.df = df

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, index):
        img_path = self.df.iloc[index]['image']
        image = Image.open(os.path.join(self.base_path, img_path))
        np_img = np.array(image)
        # np_img.shape is (H, W, C), e.g. (512, 512, 3)
        if self.transform:
            np_img = self.transform(np_img)
        # after the transform the shape may change, e.g. to (1, 256, 256)
        return np_img, self.df.iloc[index]['encoded_label']
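
A minimal usage sketch for the new dataset class; the label dict, paths, and torchvision transform below are hypothetical examples, not part of this commit:

# Hypothetical usage: keys are image paths relative to base_path, values are string labels
from torchvision import transforms

labels = {'imgs/0001.png': 'cat', 'imgs/0002.png': 'dog'}
# __getitem__ hands the transform a numpy (H, W, C) array, so a torchvision
# pipeline needs a ToPILImage (or ToTensor) step first
tfm = transforms.Compose([
    transforms.ToPILImage(),    # numpy (H, W, C) -> PIL image
    transforms.Grayscale(),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),      # -> float tensor of shape (1, 256, 256)
])
ds = COAIImageClassDataset(label_dict=labels, base_path='/data/', transform=tfm)
img, target = ds[0]
loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)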

@@ -1,6 +1,7 @@
import csv
import os
from .dataset import IterableImageDataset, ImageDataset
from .dataset import IterableImageDataset, ImageDataset, COAIImageClassDataset
def _search_split(root, split):
@@ -21,6 +22,10 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
    if name.startswith('tfds'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
    elif name.startswith('coaiclass'):
        # Get the image -> label dict from csv (current implementation); mongodb support needs to be added
        label_dict = _get_dict_from_csv(root)
        ds = COAIImageClassDataset(label_dict=label_dict)
    else:
        # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
        kwargs.pop('repeats', 0)  # FIXME currently only the Iterable dataset supports the repeat multiplier
@@ -28,3 +33,15 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
            root = _search_split(root, split)
        ds = ImageDataset(root, parser=name, **kwargs)
    return ds

def _convert(lst):
    # Turn a list of [image_path, label] rows into an image -> label dict
    return {row[0]: row[1] for row in lst}


def _get_dict_from_csv(data_folder):
    with open(os.path.join(data_folder, 'train.csv'), 'r') as f:
        reader = csv.reader(f)
        data = [row for row in reader]
    return _convert(data)
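
For reference, a sketch of the train.csv layout these helpers expect and how the new factory branch would be invoked; the paths and labels are made up. Since csv.reader does not skip a header line, every row of train.csv should be a data row, and because the factory constructs COAIImageClassDataset without base_path or transform, the paths in the csv must resolve from the working directory:

# train.csv (hypothetical contents, one "image_path,label" pair per line, no header):
#   imgs/0001.png,cat
#   imgs/0002.png,dog
ds = create_dataset('coaiclass', root='/path/to/data_folder')
print(len(ds))  # one entry per row of train.csv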