# trainModel.py
# (Repository web-view residue removed: file-size header, commit author "josh",
# and the line-number gutter 1-64. None of it was part of the program.)
import sys
import shutil
import os

import numpy as np

from pathlib import Path
from PIL import Image
from sample.model import HDRModel


def getSamples(dataDir: Path) -> "tuple[np.ndarray, np.ndarray]":
    """Load every (input, mask) image pair found under ``dataDir``.

    Each sub-directory of ``dataDir`` is treated as one sample and must
    contain ``input.jpg`` and ``mask.jpg``. The input image is converted to
    an array as-is; the mask is first converted to single-channel ("L") mode.

    Args:
        dataDir: Directory whose sub-directories each hold one sample.

    Returns:
        A tuple ``(inputs, masks)`` of two arrays built by stacking the
        per-sample arrays; index ``i`` of both arrays refers to the same
        sample. Both arrays are empty when no sample directory is found.

    Raises:
        FileNotFoundError: If ``dataDir`` does not exist or a sample
            directory lacks ``input.jpg`` / ``mask.jpg``.
    """
    inputArrays = []
    maskArrays = []

    # iterdir() yields entries in arbitrary order; sorting makes the sample
    # order (and therefore any slice a caller takes) reproducible.
    for sampleDir in sorted(dataDir.iterdir()):
        # Skip stray files (e.g. .DS_Store): only directories are samples.
        if not sampleDir.is_dir():
            continue

        inputPath = sampleDir.joinpath("input.jpg")
        maskPath = sampleDir.joinpath("mask.jpg")

        # Context managers release the image files deterministically, even
        # if conversion of one of the images fails.
        with Image.open(inputPath) as inputImage, \
                Image.open(maskPath) as maskImage:
            inputArrays.append(np.array(inputImage))
            maskArrays.append(np.array(maskImage.convert("L")))

    # The per-sample arrays are stacked with np.array, so all images need
    # identical dimensions to form a regular N-d array.
    return (np.array(inputArrays), np.array(maskArrays))



if __name__ == "__main__":
    # Train the model stored in <modelDir> on the samples in <sampleDir>,
    # then save it, overwriting the loaded model.
    if len(sys.argv) < 3:
        # Exit with a usage hint instead of an IndexError traceback.
        sys.exit("usage: {} <modelDir> <sampleDir>".format(sys.argv[0]))

    modelDir  = Path(sys.argv[1])
    sampleDir = Path(sys.argv[2])

    print(modelDir)
    print(sampleDir)

    (xTrain, yTrain) = getSamples(sampleDir)

    # NOTE(review): only the first three samples are kept -- this looks like
    # a debugging limit; confirm before training on a real data set.
    # The explicit axes also require xTrain to be 4-D and yTrain to be 3-D;
    # arrays of lower rank raise IndexError here.
    xTrain = xTrain[0:3,:,:,:]
    yTrain = yTrain[0:3,:,:]

    print(xTrain.shape)
    print(yTrain.shape)

    model = HDRModel(loadDir=modelDir)
    model.featureModel.summary()
    model.fit(xTrain, yTrain)

    # Remove loaded model and save trained model.
    # The existence check keeps a missing directory from raising after
    # training has already finished, which would discard the trained model.
    # NOTE(review): the old model is deleted before save() runs, so a failing
    # save leaves no model on disk -- consider saving to a temporary
    # directory first and swapping it in afterwards.
    if modelDir.exists():
        shutil.rmtree(modelDir)
    os.mkdir(modelDir)
    model.save(modelDir)