2018年1月7日日曜日

fit_generator を使った InceptionV3 の転移学習が、どうも期待どおりに動かない。

#coding:UTF-8

from keras.applications.inception_v3 import InceptionV3,decode_predictions
import numpy as np;
import os;
import glob;
import sys;
import csv;
import cv2;
from keras.models import Sequential,Model
from keras.layers import Dense,Activation,Dropout,Dropout,Flatten,Convolution1D,MaxPooling1D,Input
from keras.callbacks import EarlyStopping,ModelCheckpoint
from keras.models import model_from_json
from keras.optimizers import SGD,Adam
from keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
import os;
import datetime;
from keras.applications.vgg16 import VGG16
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, GlobalAveragePooling2D, AveragePooling2D

# Save the model definition to a file.
def SaveModel(model, FileName):
    """Write the JSON string returned by ``model.to_json()`` to ``FileName``.

    For a Keras model, ``to_json()`` serializes the architecture only; weights
    are not included and must be saved separately.

    Args:
        model: object exposing a ``to_json()`` method (a Keras model here).
        FileName: destination path; an existing file is overwritten.
    """
    # Serialize first so a failing to_json() does not leave an empty file behind.
    json_string = model.to_json()
    # `with` guarantees the handle is closed even if write() raises.
    with open(FileName, "w") as fp:
        fp.write(json_string)
   
def CreateModel(classes, hidden_units=1024, lr=.01, momentum=.9):
    """Build a transfer-learning classifier on top of InceptionV3.

    The ImageNet-pretrained InceptionV3 base (``include_top=False``) is frozen,
    and a new head is stacked on its output: global average pooling, a ReLU
    ``Dense`` layer, and a softmax ``Dense`` layer with ``classes`` units. The
    model is compiled with SGD (with momentum), categorical cross-entropy loss
    and an accuracy metric.

    Args:
        classes: number of output classes (units of the softmax layer).
        hidden_units: width of the intermediate ReLU layer (default 1024).
        lr: SGD learning rate (default 0.01).
        momentum: SGD momentum (default 0.9).

    Returns:
        The compiled ``Model``, taking InceptionV3's input.

    Raises:
        ValueError: if ``classes`` is less than 1.
    """
    if classes < 1:
        raise ValueError("classes must be >= 1, got %r" % (classes,))

    base = InceptionV3(weights='imagenet', include_top=False)
    # Freeze the pretrained base so that only the newly added head is trained.
    for layer in base.layers:
        layer.trainable = False

    x = GlobalAveragePooling2D()(base.output)
    x = Dense(hidden_units, activation="relu")(x)
    predictions = Dense(classes, activation="softmax")(x)

    model = Model(inputs=base.input, outputs=predictions)
    # `lr=` is the Keras 2.x keyword this file targets (later Keras versions
    # rename it to `learning_rate`).
    opt = SGD(lr=lr, momentum=momentum)
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

    return model

if __name__ == '__main__':
    # InceptionV3's own input preprocessing. The ImageNet weights expect this
    # scaling; without it the generator feeds raw 0-255 pixels into the frozen
    # base, which is a likely reason training via fit_generator looks broken.
    from keras.applications.inception_v3 import preprocess_input

    WeightFile = '../model/myincep.{epoch:02d}-{loss:.2f}-{acc:.2f}.hdf5'
    ModelFile = "../model/myincep.json"
    imgsize = (299, 299)
    batch_size = 64

    # SaveModel and ModelCheckpoint both write into this directory.
    model_dir = os.path.dirname(ModelFile)
    if model_dir and not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    # Training-data augmentation (plus the model-specific preprocessing).
    train_datagen = ImageDataGenerator(
        preprocessing_function=preprocess_input,
        featurewise_center=False,
        samplewise_center=False,
        featurewise_std_normalization=False,
        samplewise_std_normalization=False,
        rotation_range=10,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        zoom_range=[.8, 1])
    train_generator = train_datagen.flow_from_directory(
        '../data', target_size=imgsize, batch_size=batch_size,
        class_mode='categorical')

    if train_generator.samples == 0:
        sys.exit("No training images found under ../data")

    # Build the model with one output per class subdirectory actually found,
    # instead of a hard-coded count that can disagree with the data.
    num_classes = len(train_generator.class_indices)
    model = CreateModel(num_classes)
    # Save the model definition.
    SaveModel(model, ModelFile)

    # One epoch = one pass over the data. A fixed steps_per_epoch=2000 with
    # batch_size=64 meant 128,000 images per "epoch" regardless of dataset size.
    steps_per_epoch = (train_generator.samples + batch_size - 1) // batch_size

    cp_cb = ModelCheckpoint(filepath=WeightFile, monitor='acc', verbose=1, save_best_only=True, mode='auto')
    model.fit_generator(train_generator, steps_per_epoch=steps_per_epoch, epochs=10, callbacks=[cp_cb])
   
   

0 件のコメント:

コメントを投稿