PythonMania

普段はロボットとAIを組み合わせて色々作ってます。Python関係以外も色々投稿していくと思います。

【Python】fastai ImageDataBunchの作り方 【DeepLearning】


fastai ImageDataBunchの作成方法

多クラス分類

ファイルディレクトリを指定して ImageList.from_dfで作成する

ImageList.from_df
第一引数にcsvデータ(ラベルや提出用ファイルのパス)を指定、
第二引数(path=)で画像データのディレクトリを指定
第三引数(folder)で画像データが格納されているフォルダ名を指定

画像データとラベルデータ対になったデータ形式(ImageDataBunch)が作成される

データを複製する場合にはget_transforms()関数でそれぞれのパラメータを指定しておく

#ディレクトリの指定
data_folder = Path("../input")
train_df = pd.read_csv("../input/train.csv")
test_df = pd.read_csv("../input/sample_submit.csv")

#学習用データの読み込み
test_img = ImageList.from_df(test_df, path=data_folder/'test', folder='test')
trfm = get_transforms(do_flip=True, flip_vert=True, max_rotate=10.0, max_zoom=1.1, max_lighting=0.2, max_warp=0.2, p_affine=0.75, p_lighting=0.75)
train_img = (ImageList.from_df(train_df,path=data_folder/"train",folder="train")
            .split_by_rand_pct(0.01)
            .label_from_df()
            .add_test(test_img)
            .transform(trfm,size=128)
            .databunch(path=".",bs=64,device=torch.device('cuda:0'))
             .normalize(imagenet_stats))

0,1で分類するようなケースも基本的には同じ



今度はpandasDataFrameではなく、csvからImageDataBunchを作成してみる

使用するのはImageDataBunch.from_csv()

csvファイルの画像パスにファイルの拡張子まで含まれていないときは
「suffix=」で指定することができる(ex,suffix=".png")

#データディレクトリの指定
model_path = "."
path = "../input/"
train_folder=f'{path}train'
test_folder=f'{path}test'
train_lbl=f'{path}train.csv'
ORG_SIZE=96

#ラベルデータの読み込み
df_trn = pd.read_csv(train_lbl)

#画像複製する場合はパラメータ設定
tfms = get_transforms(do_flip=True, flip_vert=True, max_rotate=.0, max_zoom=.1,
                      max_lighting=0.05, max_warp=0.)

#画像データから入力用データ(ImageDataBunch)の作成
data = ImageDataBunch.from_csv(path,csv_labels=train_lbl,folder="train",
                              ds_tfms=tfms,size=sz,test=test_folder,bs=bs)

#データの正則化等
stats=data.batch_stats()
data.normalize(stats)


<||




またcsvのラベルファイルが別途用意されていない場合(フォルダ分けされていてフォルダ名がカテゴリ名になっている場合)は
以下のようにして読み込む

path=でトレーニングデータのディレクトリを指定することで簡単に作成可能

>|python|

path = Path('input/train/')
np.random.seed(42)

data = ImageDataBunch.from_folder(path,test='../test', ds_tfms=get_transforms(),valid_pct=0.25,size=299,bs=32,num_workers=0)
data.normalize(imagenet_stats)