Loading src/tgreft/train/generic.py +8 −4 Original line number Diff line number Diff line Loading @@ -117,18 +117,22 @@ def visualize_single_epoch( # reshape preds = preds.reshape(-1, model.output_dim) refs = refs.reshape(-1, model.output_dim) # # TODO: need to find better way to auto set the labels without hardcoding labels = [ "electolyte_sld", "electolyte_roughness", "sei_sld", "sei_thickness", "sei_roughness", "si_sld", "si_thickness", "si_roughness", "material_sld", "material_thickness", "material_roughness", "cu_sld", "cu_thickness", "cu_roughness", "ti_sld", "ti_thickness", "ti_roughness", "oxide_sld", "oxide_thickness", "oxide_roughness", Loading Loading
src/tgreft/train/generic.py +8 −4 Original line number Diff line number Diff line Loading @@ -117,18 +117,22 @@ def visualize_single_epoch( # reshape preds = preds.reshape(-1, model.output_dim) refs = refs.reshape(-1, model.output_dim) # # TODO: need to find better way to auto set the labels without hardcoding labels = [ "electolyte_sld", "electolyte_roughness", "sei_sld", "sei_thickness", "sei_roughness", "si_sld", "si_thickness", "si_roughness", "material_sld", "material_thickness", "material_roughness", "cu_sld", "cu_thickness", "cu_roughness", "ti_sld", "ti_thickness", "ti_roughness", "oxide_sld", "oxide_thickness", "oxide_roughness", Loading