trainModel.py 1.38 KB
Newer Older
josh's avatar
josh committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
import sys
import shutil
import os

import numpy as np

from pathlib import Path
from PIL import Image
from sample.model import HDRModel


def getSamples(dataDir: Path) -> "tuple[np.ndarray, np.ndarray]":
    """Load every training sample found under ``dataDir``.

    Each sub-directory of ``dataDir`` is treated as one sample and must
    contain ``input.jpg`` and ``mask.jpg``. Entries of ``dataDir`` that are
    not directories are skipped.

    Args:
        dataDir: Directory whose sub-directories each hold one sample.

    Returns:
        A pair ``(inputs, masks)``. ``inputs`` is ``np.array`` applied to the
        list of input images; ``masks`` is the same for the mask images after
        conversion to mode ``"L"``. Samples appear in sorted directory order.
        Both arrays are empty when no sample directory is found.

    Raises:
        FileNotFoundError: If ``dataDir`` does not exist or a sample
            directory lacks one of the two image files.
    """
    inputs = []
    masks = []
    # iterdir() yields entries in arbitrary order; sort so that the sample
    # order (and any slice the caller takes) is reproducible between runs.
    for sampleDir in sorted(dataDir.iterdir()):
        if not sampleDir.is_dir():
            # Stray files next to the sample directories are not samples.
            continue

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into the arrays.
        with Image.open(sampleDir.joinpath("input.jpg")) as inputImage:
            inputs.append(np.array(inputImage))
        with Image.open(sampleDir.joinpath("mask.jpg")) as maskImage:
            masks.append(np.array(maskImage.convert("L")))

    # NOTE(review): presumably all images share one size and mode so that
    # these stack into regular arrays -- confirm against the dataset.
    return (np.array(inputs), np.array(masks))



if __name__ == "__main__":
    # Train the model and save it, overwriting the loaded model.
    # Usage: trainModel.py MODEL_DIR SAMPLE_DIR
    if len(sys.argv) < 3:
        # Fail with a usage message instead of a bare IndexError.
        sys.exit("usage: {} MODEL_DIR SAMPLE_DIR".format(sys.argv[0]))

    modelDir  = Path(sys.argv[1])
    sampleDir = Path(sys.argv[2])

    print(modelDir)
    print(sampleDir)

    (xTrain, yTrain) = getSamples(sampleDir)

    # NOTE(review): only the first three samples are kept -- presumably a
    # debugging limit; confirm before training on the full data set.
    # The indexing requires xTrain to be 4-D and yTrain to be 3-D.
    xTrain = xTrain[0:3,:,:,:]
    yTrain = yTrain[0:3,:,:]

    print(xTrain.shape)
    print(yTrain.shape)

    model = HDRModel(loadDir=modelDir)
    model.featureModel.summary()
    model.fit(xTrain, yTrain)

    # Remove loaded model and save trained model. The removal is guarded so
    # that a missing directory cannot raise after training has finished and
    # discard the trained model before it is written.
    if modelDir.exists():
        shutil.rmtree(modelDir)
    modelDir.mkdir(parents=True, exist_ok=True)
    model.save(modelDir)