PythonMania

普段はロボットとAIを組み合わせて色々作ってます。Python関係以外も色々投稿していくと思います。

【Python】画像認識 - fastaiとDenseNetでサボテンの分類をやってみる 【DeepLearning】


今回はこちらのコンペを参考に画像認識について勉強していきたいと思います。




www.kaggle.com






Kernelを読んでいると、至るところで「fastai」という言葉を見かけました。

調べてみたところ、どうやら「fast.ai」というのはAIに関する講座(学習サイト?)のようなものらしく

そのコースの中で使用されているライブラリが「fastai」みたいです。



ライブラリ自体はPytorchベースで動いており、「fast.ai」の講座の中でKaggleに挑戦するような問題もあることから
実際のコンペでも使用されることが多いみたいです。


今回はこのfastaiを使ってDenceNetを転移学習させるものを作ってみたいと思います。



DenseNetはResNetと似た構造のモデルらしいのですが、まだちゃんと理解できていないので

まとまったら改めて記事にします。



以下コードです


#必要なライブラリのインポート
import fastai
from fastai.vision import *
from sklearn.model_selection import KFold


# 事前学習モデルから重みをコピー
!mkdir '/tmp/.torch'
!mkdir '/tmp/.torch/models/'
!cp '../input/densenet201/densenet201.pth' '/tmp/.torch/models/densenet201-c1103571.pth'


#トレーニングデータの読み込み
data_path = Path('../input/aerial-cactus-identification')
df = pd.read_csv(data_path/'train.csv')
df.head()


#提出用データの読み込み
sub_csv = pd.read_csv(data_path/'sample_submission.csv')
sub_csv.head()

#データ束の作成
def create_databunch(valid_idx):
    test = ImageList.from_df(sub_csv, path=data_path/'test', folder='test')
    data = (ImageList.from_df(df, path=data_path/'train', folder='train')
            .split_by_idx(valid_idx)
            .label_from_df()
            .add_test(test)
            .transform(get_transforms(flip_vert=True, max_rotate=20.0), size=128)
            .databunch(path='.', bs=64)
            .normalize(imagenet_stats)
           )
    return data


#KFold5つのアンサンブルで学習
kf = KFold(n_splits=5, random_state=379)
epochs = 6
lr = 1e-2
preds = []
for train_idx, valid_idx in kf.split(df):
    data = create_databunch(valid_idx)
    learn = create_cnn(data, models.densenet201, metrics=[accuracy])
    learn.fit_one_cycle(epochs, slice(lr))
    learn.unfreeze()
    learn.fit_one_cycle(epochs, slice(lr/400, lr/4))
    learn.fit_one_cycle(epochs, slice(lr/800, lr/8))
    preds.append(learn.get_preds(ds_type=DatasetType.Test))


ens = torch.cat([preds[i][0][:,1].view(-1, 1) for i in range(5)], dim=1)
ens  = (ens.mean(1)>0.5).long(); ens[:10]

sub_csv['has_cactus'] = ens
sub_csv.to_csv('submission.csv', index=False)