diff --git a/photo_wct.py b/photo_wct.py index 2fb6fda..0ab33f3 100644 --- a/photo_wct.py +++ b/photo_wct.py @@ -79,8 +79,8 @@ def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg): else: target_feature = cont_feat.view(cont_c, -1).clone() - t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST)) - t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST)) + t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='I').resize((cont_w, cont_h), Image.NEAREST)) + t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='I').resize((styl_w, styl_h), Image.NEAREST)) for l in self.label_set: if self.label_indicator[l] == 0: