From f78ab556ecaeb03fa54bf51313a781a439c259d1 Mon Sep 17 00:00:00 2001
From: mansikataria
Date: Tue, 28 Sep 2021 19:25:57 +0530
Subject: [PATCH] Initial commit for: Add training script that uses our data

---
 timm/data/dataset.py         | 36 ++++++++++++++++++++++++++++++++++--
 timm/data/dataset_factory.py | 19 +++++++++++++++++-
 2 files changed, 51 insertions(+), 4 deletions(-)

diff --git a/timm/data/dataset.py b/timm/data/dataset.py
index e719f3f6..be958638 100644
--- a/timm/data/dataset.py
+++ b/timm/data/dataset.py
@@ -2,18 +2,19 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch.utils.data as data
 import os
 import torch
 import logging
-
+import pandas as pd
+import numpy as np
 from PIL import Image
-
+from torch.utils.data import Dataset
+from sklearn import preprocessing
 from .parsers import create_parser
 
 _logger = logging.getLogger(__name__)
-
 _ERROR_RETRY = 50
 
 
 
@@ -144,3 +145,32 @@ class AugMixDataset(torch.utils.data.Dataset):
 
     def __len__(self):
         return len(self.dataset)
+
+
+class COAIImageClassDataset(Dataset):
+    """Map-style dataset over an {image_path: label} mapping.
+
+    Labels are integer-encoded with sklearn's LabelEncoder.
+    """
+
+    def __init__(self, image_label_map, base_path='', transform=None):
+        self.transform = transform
+        self.base_path = base_path
+        self.image_label_map = image_label_map
+        df = pd.DataFrame(list(image_label_map.items()), columns=['image', 'label'])
+        encoder = preprocessing.LabelEncoder()
+        encoder.fit(df['label'].unique())
+        df['encoded_label'] = encoder.transform(df['label'])
+        self.df = df
+
+    def __len__(self):
+        # One sample per dataframe row.
+        return len(self.df)
+
+    def __getitem__(self, index):
+        img_path = self.df.iloc[index]['image']
+        image = Image.open(os.path.join(self.base_path, img_path))
+        np_img = np.array(image)  # (h, w, c) for RGB images
+        if self.transform:
+            np_img = self.transform(np_img)
+        return np_img, self.df.iloc[index]['encoded_label']
diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
index ccc99d5c..82aa921e 100644
--- a/timm/data/dataset_factory.py
+++ b/timm/data/dataset_factory.py
@@ -1,6 +1,7 @@
+import csv
 import os
 
-from .dataset import IterableImageDataset, ImageDataset
+from .dataset import IterableImageDataset, ImageDataset, COAIImageClassDataset
 
 
 def _search_split(root, split):
@@ -21,6 +22,10 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
     if name.startswith('tfds'):
         ds = IterableImageDataset(
             root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
+    elif name.startswith('coaiclass'):
+        # Label map is read from csv for now; a mongodb source is planned.
+        label_map = _get_dict_from_csv(root)
+        ds = COAIImageClassDataset(image_label_map=label_map)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
         kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
@@ -28,3 +33,15 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
             root = _search_split(root, split)
         ds = ImageDataset(root, parser=name, **kwargs)
     return ds
+
+
+def _convert(rows):
+    # Each csv row is [image_path, label, ...]; keep the first two columns.
+    return {row[0]: row[1] for row in rows}
+
+
+def _get_dict_from_csv(data_folder):
+    # Reads <data_folder>/train.csv into an {image_path: label} dict.
+    with open(os.path.join(data_folder, 'train.csv'), 'r') as f:
+        reader = csv.reader(f)
+        return _convert(list(reader))