Commit 301a653e authored by josh's avatar josh
Browse files

add train/test split

parent 87c93ef2
......@@ -14,14 +14,16 @@ from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
class HDRModel:
def __init__(self, shape=None, loadDir=None):
self.featureModel = None
self.rfClassifer = None
self.featureModel = None
self.rfClassifer = None
self.trainTestSplit = 0.2
self.inputChannels = 3
self.outputChannels = 1
......@@ -65,10 +67,21 @@ class HDRModel:
# Reshape input for random forest
# TO DO: sklearn split data to train and test
xForestTrain = features.reshape(-1, features.shape[-1])
yForestTrain = yTrain.reshape(-1)
xForest = features.reshape(-1, features.shape[-1])
yForest = yTrain.reshape(-1)
print("Forest Data X:", xForest.shape)
print("Forest Data Y:", yForest.shape)
self.rfClassifer.fit(xForestTrain, yForestTrain)
xTrain, xTest, yTrain, yTest = train_test_split(xForest, yForest, test_size=self.trainTestSplit)
print("Forest Train X:", xTrain.shape)
print("Forest Train Y:", yTrain.shape)
print("Forest Test X:", xTest.shape)
print("Forest Test Y:", yTest.shape)
self.rfClassifer.fit(xTrain, yTrain)
score = self.rfClassifer.score(xTest, yTest)
print("SCORE:", score)
def predict(self, xInput):
......
......@@ -44,14 +44,13 @@ if __name__ == "__main__":
(xTrain, yTrain) = getSamples(sampleDir)
xTrain = xTrain[0:3,:,:,:]
yTrain = yTrain[0:3,:,:]
#xTrain = xTrain[0:5,:,:,:]
#yTrain = yTrain[0:5,:,:]
print(xTrain.shape)
print(yTrain.shape)
model = HDRModel(loadDir=modelDir)
model.featureModel.summary()
model.fit(xTrain, yTrain)
# Remove loaded model and save trained model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment