In one case you're using `regularization` while in the other you're using `sqrt(regularization)`. This is inconsistent: obtaining the same results from the two code paths requires passing different values.
Same as above: assert `regularization >= 0.0`, and add an if-statement to skip the padding when `regularization == 0.0`.
I think you got it, but shouldn't this be:
expanded_matrix = torch.cat((DR, torch.sqrt(regularization) * torch.eye(DR.size(1))))
Otherwise you can use `torch.tensor(relaxation)`, but you will be left with an unused variable `regularization`. Also, if you are using `regularization`, then you should have an if-statement to use it only when `regularization != 0.0`; otherwise you will be padding the matrix with zeros, and it's wasteful to feed those into the QR. Finally, you should add an `assert regularization >= 0.0` to make sure you are not taking the square root of a negative number.
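Putting those pieces together, a minimal sketch of the guarded padding (the function name `pad_for_qr` and the shapes are illustrative, not from the PR):

```python
import math

import torch

def pad_for_qr(DR, regularization):
    # Guard against taking the square root of a negative number.
    assert regularization >= 0.0
    # Skip the padding entirely when there is nothing to add;
    # feeding rows of zeros into the QR is wasted work.
    if regularization == 0.0:
        return DR
    # math.sqrt works on the plain Python float, so there is no need
    # to wrap regularization in torch.tensor first.
    ridge = math.sqrt(regularization) * torch.eye(DR.size(1))
    return torch.cat((DR, ridge))
```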
Merge the two lines into one operation to avoid having an incorrect intermediate state.
These are not constructors, these are factory methods. The common way to indicate a factory is the `make_` prefix:
make_sgd() ...
make_sgd_manual() ...
make_adam() ...
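For instance, with a toy stand-in class (illustrative only, not the PR's actual optimizers), the `make_` prefix makes the factory role obvious at the call site:

```python
class Optimizer:
    """Toy stand-in for a real optimizer configuration."""
    def __init__(self, kind, lr, momentum=0.0):
        self.kind = kind
        self.lr = lr
        self.momentum = momentum

def make_sgd(lr=0.01, momentum=0.9):
    # Factory method: builds and returns a configured instance.
    return Optimizer("sgd", lr=lr, momentum=momentum)

def make_adam(lr=0.001):
    return Optimizer("adam", lr=lr)
```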
Instead of a function defined inside a function, it is better to write one top-level function called `something_eval()` that takes explicit inputs, e.g., `optimizer`, `params`, and so on.
Then call `optimizer.step()` with a lambda:
optimizer.step(lambda : something_eval(optimizer, params))
That way you are explicit about what gets captured and which way the data flows. Otherwise I have to scroll up and down to the definition to see what gets captured.
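A minimal sketch of that pattern with a closure-based torch optimizer (LBFGS is the usual example; the names `loss_eval`, `model`, `inputs`, `targets` are illustrative):

```python
import torch

def loss_eval(optimizer, model, inputs, targets):
    # All inputs are explicit arguments, so nothing is captured
    # implicitly from an enclosing scope.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.LBFGS(model.parameters())
inputs, targets = torch.randn(8, 3), torch.randn(8, 1)
optimizer.step(lambda: loss_eval(optimizer, model, inputs, targets))
```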
I would not bother with Python 2 support, the code would be so much cleaner.
Using global variables is confusing, i.e., where was this defined and why? Replacing it with `args.accept` is only marginally longer to write, but everyone knows what `args` means.
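A sketch of the `args`-based version (the `--accept` flag name is taken from the comment; the helper names are illustrative):

```python
import argparse

def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--accept", action="store_true")
    return parser

def run(args):
    # args.accept makes the origin of the flag obvious at the
    # use site, unlike a module-level global.
    return "accepted" if args.accept else "rejected"
```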
Why define an internal function that you would call only once?
Why do you need to implement this? unittest already comes with various asserts: `assertEqual`, `assertNotEqual`, `assertAlmostEqual`.
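For reference, the built-in assertions cover the usual cases (a generic sketch, not the PR's actual tests):

```python
import unittest

class TestBuiltinAsserts(unittest.TestCase):
    def test_builtins(self):
        # These ship with unittest; no custom helper is needed.
        self.assertEqual(2 + 2, 4)
        self.assertNotEqual(2 + 2, 5)
        self.assertAlmostEqual(0.1 + 0.2, 0.3, places=7)
```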
Should this be an `or` statement?
Do you need this?
Do you have to cache `TEST_NUMPY`? The check is expressive, and it is called only twice in the whole code, so there are no performance concerns:
if _check_module_exists('numpy'):
import numpy
Extra variable here:
return importlib.find_loader(name) is not None
This may be OCD, but when I see this I think "why is MULTIGPU more than 2?"; the `>=` is hard to read. I think this is easier to read:
torch.cuda.device_count() > 1