[EN] Deep Painterly Harmonization: making a patch image blend naturally into the main image in a consistent style

Original article: sgugger.github.io

The histogram loss

Histogram matching is a technique often used to give one photograph the luminosity or shadows of another. The technique itself is explained on Wikipedia, and here is a concrete example of its application.

[Figure: histogram matching example]

In their paper, Pierre Wilmot et al. found that applying the same technique to define an additional loss could help preserve the textures of the style picture. They recommended using it on the features of the first and fourth convolutional layers, to capture both the fine details and the more general aspects of the style.

For these two layers, the idea is to compute the histogram of each channel of the style features as a reference. Then, at each step of our training, we compute a remapping of our output features so that their histogram (channel by channel) matches the style reference. The histogram loss is then the mean-squared error between the output features and their remapped version. The challenge here is to compute that remapping.
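As a rough illustration, here is a pure-Python sketch of the two ingredients for a single channel. The first is a histogram of the style features, binned like torch.histc over an assumed range [lo, hi]. The second is the loss itself. Both function names are hypothetical. The remapped values are assumed to come from the remapping step and act as a fixed target, which matches the use of x.data in the PyTorch code later in this post.

```python
# Hypothetical helpers for one channel; features are plain lists of floats.

def channel_histogram(values, n_bins, lo, hi):
    """Count values into n_bins equal-width bins over [lo, hi], similar to torch.histc."""
    counts = [0] * n_bins
    width = (hi - lo) / n_bins
    for v in values:
        if v < lo or v > hi:
            continue  # values outside [lo, hi] are ignored, as torch.histc does
        b = int((v - lo) / width) if width > 0 else 0  # degenerate range: use bin 0
        counts[min(b, n_bins - 1)] += 1  # v == hi belongs to the last bin
    return counts


def histogram_loss(x, remapped):
    """Mean-squared error between output features and their (fixed) remapped version."""
    if len(x) != len(remapped):
        raise ValueError("x and remapped must have the same length")
    return sum((a - b) ** 2 for a, b in zip(x, remapped)) / len(x)


style = [0.0, 0.1, 0.2, 0.9, 1.0]
print(channel_histogram(style, n_bins=4, lo=0.0, hi=1.0))  # [3, 0, 0, 2]
print(histogram_loss([1.0, 2.0], [1.0, 4.0]))  # 2.0
```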

Let's say we are trying to change x so that it matches a histogram hist. We sort x first, keeping track of the permutation we had to apply (we will use it at the end to put the new values we interpolate back in their right place). Then, when we treat the i-th value (counting from 1), we look for the first index idx such that hist.cumsum(idx) is greater than or equal to i, which means the i-th smallest value of x belongs in the bin with index idx. The value attributed to x[i] is basically

$$\hbox{min} + \hbox{idx} \times \frac{\hbox{max} - \hbox{min}}{n_{\hbox{bins}}}$$

where $\hbox{min}$ and $\hbox{max}$ are the minimum and maximum values of the data (here, of x in the current channel). We correct this formula slightly: if several values of x land in the same bin idx, we want them evenly distributed across the range of that bin rather than all stacked on its left edge. So we compute the ratio

$$\hbox{ratio} = \frac{i - \hbox{hist.cumsum}(\hbox{idx} - 1)}{\hbox{hist}[\hbox{idx}]}$$

and finally put

$$x[i] = \hbox{min} + (\hbox{idx} + \hbox{ratio}) \times \frac{\hbox{max} - \hbox{min}}{n_{\hbox{bins}}}.$$
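Before vectorizing, it may help to see the whole procedure as a plain loop over one channel. The sketch below is a pure-Python reference under a few assumptions:

- x is a list of floats and hist is a list of non-negative bin counts.
- min and max are taken from x.
- remap_channel is a hypothetical name.

The 1e-8 in the denominator and the clamp of ratio to [0, 1] are small safeguards against empty bins and rounding.

```python
from itertools import accumulate


def remap_channel(x, hist):
    """Loop sketch: remap one channel x so that its histogram matches hist.

    x:    list of floats (one channel of output features)
    hist: list of non-negative bin counts (the reference histogram)
    """
    n, n_bins = len(x), len(hist)
    total = sum(hist)
    hist = [h * n / total for h in hist]  # rescale so the counts sum to n
    cum = list(accumulate(hist))          # cum[b] = hist.cumsum(b)
    lo, hi = min(x), max(x)
    step = (hi - lo) / n_bins
    order = sorted(range(n), key=lambda k: x[k])  # permutation that sorts x
    new_x = [0.0] * n
    for i in range(1, n + 1):  # the i-th smallest value, counted from 1
        # first bin whose cumulative count reaches i (last bin if rounding leaves none)
        idx = next((b for b, c in enumerate(cum) if c >= i), n_bins - 1)
        prev = cum[idx - 1] if idx > 0 else 0.0  # hist.cumsum(idx - 1)
        ratio = min(max((i - prev) / (hist[idx] + 1e-8), 0.0), 1.0)
        new_x[order[i - 1]] = lo + (idx + ratio) * step  # write back to the original position
    return new_x


print([round(v, 4) for v in remap_channel([0.3, 0.1, 0.4, 0.2], [2, 0, 0, 2])])
# [0.3625, 0.1375, 0.4, 0.175]
```

Here the reference puts half of its mass in the lowest bin and half in the highest. The two smallest values of x therefore end up in [0.1, 0.175], and the two largest end up in [0.325, 0.4]. The ordering of x is preserved.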

Now we just have to do this for all possible values of i and for every channel. Of course, a simple for loop won't do if we want the GPU to handle all the computations quickly (and with 1000 iterations, we had better compute this remapping as fast as we can). Let's assume our input x has size ch (the number of channels) by n (the number of activations we keep), and that we have a variable hist_ref of size ch by n_bins (the paper uses 256 bins). Sorting x within each channel while keeping the corresponding permutation is easy with pytorch:

sorted_x, sort_idx = x.data.sort(1)

Then we have to adapt our histogram a bit, because x and the reference may not contain the same number of activations (we removed some style features, namely the ones that appeared more than once). A histogram of x would sum to n in each channel, so we rescale each row of hist_ref to have that same total.

hist = hist_ref * n/hist_ref.sum(1).unsqueeze(1)#Normalization between the different lengths of masks.
cum_ref = hist.cumsum(1)
cum_prev = torch.cat([torch.zeros(ch,1).cuda(), cum_ref[:,:-1]],1)
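With small made-up numbers, the normalization and the two cumulative sums look like this. The block is a pure-Python sketch of the three lines above for a single channel, and normalized_cumsums is a hypothetical name:

```python
from itertools import accumulate


def normalized_cumsums(hist_ref, n):
    """Rescale hist_ref so it sums to n, then return (hist, cum_ref, cum_prev)."""
    total = sum(hist_ref)
    hist = [h * n / total for h in hist_ref]
    cum_ref = list(accumulate(hist))  # running total up to and including each bin
    cum_prev = [0.0] + cum_ref[:-1]   # running total of the previous bins only
    return hist, cum_ref, cum_prev


# A reference built from 10 style activations, matched against n = 5 activations of x.
hist, cum_ref, cum_prev = normalized_cumsums([2, 6, 2], n=5)
print(hist)      # [1.0, 3.0, 1.0]
print(cum_ref)   # [1.0, 4.0, 5.0]
print(cum_prev)  # [0.0, 1.0, 4.0]
```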

The cumsums will be used later. We will need both the cumulative sums of hist (cum_ref) and the cumulative sums up to the previous index (cum_prev). To replace our for loop, we create a tensor that contains all the values i from 1 to n. To determine, for each i, the first index idx such that hist.cumsum(idx) is greater than or equal to i, I used these lines:

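rng = torch.arange(1,n+1).float().unsqueeze(0).cuda()
idx = (cum_ref.unsqueeze(1) - rng.unsqueeze(2) < 0).sum(2).long()
idx = idx.clamp(0, n_bins-1) # Rounding can leave cum_ref[:,-1] just below n; keep idx a valid bin.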
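The comparison trick can be checked on a single channel in plain Python. In the sketch below, first_bin_at_least is a hypothetical name. Counting how many cumulative sums are still below i gives the same answer as bisect.bisect_left, which returns the first position whose value is at least i. The PyTorch line performs this count for every i and every channel at once by broadcasting to a tensor of size ch × n × n_bins.

```python
import bisect


def first_bin_at_least(cum_ref, i):
    """Count the cumulative sums below i: the index of the first bin with cum_ref >= i."""
    return sum(1 for c in cum_ref if c - i < 0)


cum_ref = [1.0, 4.0, 5.0]  # cumulative counts of a 3-bin histogram summing to n = 5
idx = [first_bin_at_least(cum_ref, i) for i in range(1, 6)]
print(idx)  # [0, 1, 1, 1, 2]

# Same result as a binary search, because cum_ref is sorted in ascending order.
assert idx == [bisect.bisect_left(cum_ref, i) for i in range(1, 6)]
```

The first value goes in bin 0, the next three share bin 1, and the last lands in bin 2, which are exactly the counts [1, 3, 1]. Because the broadcast tensor has ch × n × n_bins entries, memory can become the limiting factor when n is large.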

Since each row of cum_ref is sorted in ascending order, counting the entries where cum_ref - i < 0 gives the first index where cum_ref is greater than or equal to i. We then use this idx tensor to get all the values of cum_prev and hist that we need. Since pytorch doesn't like indexing with a multi-dimensional tensor, we have to flatten everything (though that probably won't be needed anymore in pytorch 0.4):

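ymin, ymax = x.data.min(1)[0].unsqueeze(1), x.data.max(1)[0].unsqueeze(1)
step = (ymax-ymin)/n_bins
# idx holds bin indices within each channel, so offset them by channel before indexing the flattened tensors.
flat_idx = (idx + n_bins * torch.arange(ch).long().cuda().unsqueeze(1)).view(-1)
ratio = (rng - cum_prev.view(-1)[flat_idx].view(ch,-1)) / (1e-8 + hist.view(-1)[flat_idx].view(ch,-1))
ratio = ratio.clamp(0,1)
new_x = ymin + (ratio + idx.float()) * step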

At this stage, new_x contains all the values of our remapping, but in sorted order. To finish, we have to apply the inverse of the permutation we used at the beginning. To find that inverse permutation, I simply chose to take the argsort of sort_idx:

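_, remap = sort_idx.sort(1)
# remap indexes within each row, so offset it by channel before indexing the flattened tensor.
flat_remap = (remap + n * torch.arange(ch).long().cuda().unsqueeze(1)).view(-1)
new_x = new_x.view(-1)[flat_remap].view(ch,-1)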
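Why does sorting sort_idx give the inverse permutation? sort_idx[j] is the original position of the j-th smallest value. Argsorting it answers the reverse question: where did the value originally at position k end up? Here is a small pure-Python check, where argsort is a hypothetical helper playing the role of the index output of .sort():

```python
def argsort(seq):
    """Indices that sort seq in ascending order, like the second output of tensor.sort()."""
    return sorted(range(len(seq)), key=lambda k: seq[k])


x = [0.3, 0.1, 0.4, 0.2]
sort_idx = argsort(x)                    # sorted_x[j] == x[sort_idx[j]]
sorted_x = [x[k] for k in sort_idx]
remap = argsort(sort_idx)                # inverse permutation: sort_idx[remap[k]] == k
restored = [sorted_x[r] for r in remap]
print(sort_idx, remap)  # [1, 3, 0, 2] [2, 0, 3, 1]
print(restored == x)    # True
```

On a (ch, n) tensor, the same lookup has to happen row by row. new_x.gather(1, remap) does that directly, without flattening.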