Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Ortner, Joshua
ai4hdr_backend
Commits
301a653e
Commit
301a653e
authored
Feb 08, 2021
by
josh
Browse files
add train/test split
parent
87c93ef2
Changes
2
Show whitespace changes
Inline
Side-by-side
sample/model.py
View file @
301a653e
...
...
@@ -14,6 +14,7 @@ from tensorflow.keras.models import Model
from
tensorflow.keras.applications.vgg16
import
VGG16
from
sklearn.ensemble
import
RandomForestClassifier
from
sklearn.model_selection
import
train_test_split
class
HDRModel
:
...
...
@@ -22,6 +23,7 @@ class HDRModel:
def
__init__
(
self
,
shape
=
None
,
loadDir
=
None
):
self
.
featureModel
=
None
self
.
rfClassifer
=
None
self
.
trainTestSplit
=
0.2
self
.
inputChannels
=
3
self
.
outputChannels
=
1
...
...
@@ -65,10 +67,21 @@ class HDRModel:
# Reshape input for random forest
# TO DO: sklearn split data to train and test
xForestTrain
=
features
.
reshape
(
-
1
,
features
.
shape
[
-
1
])
yForestTrain
=
yTrain
.
reshape
(
-
1
)
xForest
=
features
.
reshape
(
-
1
,
features
.
shape
[
-
1
])
yForest
=
yTrain
.
reshape
(
-
1
)
print
(
"Forest Data X:"
,
xForest
.
shape
)
print
(
"Forest Data Y:"
,
yForest
.
shape
)
self
.
rfClassifer
.
fit
(
xForestTrain
,
yForestTrain
)
xTrain
,
xTest
,
yTrain
,
yTest
=
train_test_split
(
xForest
,
yForest
,
test_size
=
self
.
trainTestSplit
)
print
(
"Forest Train X:"
,
xTrain
.
shape
)
print
(
"Forest Train Y:"
,
yTrain
.
shape
)
print
(
"Forest Test X:"
,
xTest
.
shape
)
print
(
"Forest Test Y:"
,
yTest
.
shape
)
self
.
rfClassifer
.
fit
(
xTrain
,
yTrain
)
score
=
self
.
rfClassifer
.
score
(
xTest
,
yTest
)
print
(
"SCORE:"
,
score
)
def
predict
(
self
,
xInput
):
...
...
trainModel.py
View file @
301a653e
...
...
@@ -44,14 +44,13 @@ if __name__ == "__main__":
(
xTrain
,
yTrain
)
=
getSamples
(
sampleDir
)
xTrain
=
xTrain
[
0
:
3
,:,:,:]
yTrain
=
yTrain
[
0
:
3
,:,:]
#
xTrain = xTrain[0:
5
,:,:,:]
#
yTrain = yTrain[0:
5
,:,:]
print
(
xTrain
.
shape
)
print
(
yTrain
.
shape
)
model
=
HDRModel
(
loadDir
=
modelDir
)
model
.
featureModel
.
summary
()
model
.
fit
(
xTrain
,
yTrain
)
# Remove loaded model and save trained model
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment