Author's learning notes

Code source: the most-starred BigGAN implementation on GitHub (ajbrock/BigGAN-PyTorch)

The BN-layer code written by the developers of this BigGAN implementation is very tightly coupled internally, so it needs to be understood step by step.
First, BigGAN.py contains the generator code, including the part that selects the BN layer:

    self.which_bn = functools.partial(layers.ccbn,
                          which_linear=bn_linear,
                          cross_replica=self.cross_replica,
                          mybn=self.mybn,
                          input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                      else self.n_classes),
                          norm_style=self.norm_style,
                          eps=self.BN_eps)

This snippet selects the BN layer. Its core is layers.ccbn: self.which_bn is built with functools.partial, which pre-binds most of ccbn's constructor arguments so that later callers only have to supply the number of output channels.
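As a rough, standalone illustration of the pattern (make_bn below is a hypothetical stand-in for layers.ccbn, and the bound values are made up):

    import functools

    # functools.partial pre-binds the constructor arguments; the generator blocks
    # later only need to pass the number of output channels.
    def make_bn(output_size, input_size, eps):
        return f"ccbn(out={output_size}, in={input_size}, eps={eps})"

    which_bn = functools.partial(make_bn, input_size=148, eps=1e-5)
    print(which_bn(256))   # -> ccbn(out=256, in=148, eps=1e-05)

With that in mind, let's look at the layers.ccbn code itself: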

# Class-conditional bn
# output size is the number of channels, input size is for the linear layers
# Andy's Note: this class feels messy but I'm not really sure how to clean it up
# Suggestions welcome! (By which I mean, refactor this and make a pull request
# if you want to make this more readable/usable). 
class ccbn(nn.Module):
  # ccbn takes x, normalizes it with one of several BN variants, and then combines
  # the result with a gain and bias derived from the class information
  def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1,
               cross_replica=False, mybn=False, norm_style='bn',):
    super(ccbn, self).__init__()
    self.output_size, self.input_size = output_size, input_size
    # Prepare gain and bias layers
    self.gain = which_linear(input_size, output_size)
    self.bias = which_linear(input_size, output_size)
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Momentum
    self.momentum = momentum
    # Use cross-replica batchnorm?
    # Batch norm synchronized across multiple GPUs, which helps performance
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # Norm style?
    self.norm_style = norm_style
    
    if self.cross_replica:
      self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
    elif self.mybn:
      self.bn = myBN(output_size, self.eps, self.momentum)
    elif self.norm_style in ['bn', 'in']:
      self.register_buffer('stored_mean', torch.zeros(output_size))
      self.register_buffer('stored_var',  torch.ones(output_size)) 
    
    
  def forward(self, x, y):
    # Calculate class-conditional gains and biases
    # The class information y passes through a linear mapping, which provides the
    # class-dependent gain and bias for the BN layer
    gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
    bias = self.bias(y).view(y.size(0), -1, 1, 1)
    # Several BN variants for normalizing x
    # If using my batchnorm
    if self.mybn or self.cross_replica:
      # Here gain and bias carry the class information
      return self.bn(x, gain=gain, bias=bias)
    # else:
    else:
      if self.norm_style == 'bn':
        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                          self.training, 0.1, self.eps)
      elif self.norm_style == 'in':
        out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                          self.training, 0.1, self.eps)
      elif self.norm_style == 'gn':
        out = groupnorm(x, self.norm_style)
      elif self.norm_style == 'nonorm':
        out = x
      # 标准化后的x在与基于类别信息得到的gain与bias进行计算
      return out * gain + bias
  def extra_repr(self):
    s = 'out: {output_size}, in: {input_size},'
    s +=' cross_replica={cross_replica}'
    return s.format(**self.__dict__)

In the forward function, x is the input to the BN layer (typically the output of the previous layer) and y is the class of the sample we want to generate. self.gain and self.bias are the linear layers defined when ccbn is initialized. So, for class-conditional batch normalization, ccbn first transforms the class information to obtain two class-dependent parameters, gain and bias. It then selects which kind of normalization to apply to x, with five normalization variants plus an identity mapping ('nonorm') available. Finally, the normalized result is combined with the class-dependent gain and bias, which is how class information is injected into the generated image. BigGAN uses SAGAN as its baseline, and its class information enters the model through the BN layers.
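A minimal, standalone sketch of this conditional affine step (the shapes below are made up, and plain F.batch_norm stands in for whichever normalization ccbn actually picks):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # y is a per-sample class-conditioning vector; two linear layers map it to
    # per-channel gain and bias, which rescale and shift the normalized activations.
    N, C, H, W, embed_dim = 4, 8, 16, 16, 32
    x = torch.randn(N, C, H, W)
    y = torch.randn(N, embed_dim)                       # class-conditioning vector
    gain_layer, bias_layer = nn.Linear(embed_dim, C), nn.Linear(embed_dim, C)

    gain = (1 + gain_layer(y)).view(N, -1, 1, 1)        # (N, C, 1, 1)
    bias = bias_layer(y).view(N, -1, 1, 1)
    out = F.batch_norm(x, torch.zeros(C), torch.ones(C), None, None,
                       True, 0.1, 1e-5)                 # normalize x first
    out = out * gain + bias                             # class information enters here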

Next, let's try to understand the myBN class, to get a better picture of how the class takes part in generating samples. To understand myBN, we first need to look at the fused_bn function, whose code is as follows:

# Fused batchnorm op
# In essence this is just a scale-and-shift operation
# Here gain and bias carry the class information (or, in a plain unconditional BN, they are ordinary learnable parameters)
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
  # Apply scale and shift--if gain and bias are provided, fuse them here
  # Prepare scale
  scale = torch.rsqrt(var + eps)
  # If a gain is provided, use it
  # 
  if gain is not None:
    scale = scale * gain
  # Prepare shift
  shift = mean * scale
  # If bias is provided, use it
  if bias is not None:
    shift = shift - bias
  return x * scale - shift
  #return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.

Here fused_bn's arguments mean and var are the mean and variance of the input x, computed in advance, and torch.rsqrt computes the element-wise reciprocal square root of a tensor. Without gain and bias, the function amounts to subtracting the mean from x and dividing by the standard deviation, i.e. ordinary batch normalization. With class information, gain and bias are the class-dependent values learned in ccbn, and in fused_bn they enter the operation as a scale and a shift, respectively.
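A quick standalone check (values made up) that the fused form x * scale - shift is algebraically the same as the unfused form shown in the commented-out last line of fused_bn:

    import torch

    # scale = gain / sqrt(var + eps), shift = mean * scale - bias, so
    # x * scale - shift == ((x - mean) / sqrt(var + eps)) * gain + bias.
    x = torch.randn(4, 8, 16, 16)
    mean = x.mean([0, 2, 3], keepdim=True)
    var = x.var([0, 2, 3], unbiased=False, keepdim=True)
    gain, bias, eps = torch.rand(1, 8, 1, 1), torch.rand(1, 8, 1, 1), 1e-5

    scale = torch.rsqrt(var + eps) * gain
    fused = x * scale - (mean * scale - bias)
    unfused = ((x - mean) / ((var + eps) ** 0.5)) * gain + bias
    print(torch.allclose(fused, unfused, atol=1e-5))    # True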
This explains how class information takes part in batch normalization, but the ccbn class never calls fused_bn directly, so we move on to the next function, manual_bn:

# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
# Here gain and bias carry the class information (or, in a plain unconditional BN, they are ordinary learnable parameters)
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
  # First compute the mean and variance
  # Cast x to float32 if necessary
  float_x = x.float()
  # Calculate expected value of x (m) and expected value of x**2 (m2)  
  # Mean of x
  m = torch.mean(float_x, [0, 2, 3], keepdim=True)
  # Mean of x squared
  m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
  # Calculate variance as mean of squared minus mean squared.
  var = (m2 - m **2)
  # Cast back to float 16 if necessary
  var = var.type(x.type())
  m = m.type(x.type())
  # Return mean and variance for updating stored mean/var if requested  
  if return_mean_var:
    # Used during training
    return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
  else:
    return fused_bn(x, m, var, gain, bias, eps)

Clearly, manual_bn computes the mean and variance of the input x and hands them to fused_bn. One thing to note is that manual_bn has two modes: during training, the mean and variance must also be returned so the stored statistics can be updated, whereas at test time they need not be returned.
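A standalone sanity check (made-up tensor) showing that the mean-of-squares-minus-squared-mean trick gives the same biased variance that torch.var_mean computes:

    import torch

    # manual_bn's E[x^2] - E[x]^2 equals the biased (population) variance.
    x = torch.randn(4, 8, 16, 16)
    m = x.mean([0, 2, 3], keepdim=True)
    m2 = (x ** 2).mean([0, 2, 3], keepdim=True)
    var_manual = m2 - m ** 2

    var_ref, mean_ref = torch.var_mean(x, dim=[0, 2, 3], unbiased=False, keepdim=True)
    print(torch.allclose(var_manual, var_ref, atol=1e-4))   # True
    print(torch.allclose(m, mean_ref))                      # True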

Next comes the last intermediate class, myBN:

# My batchnorm, supports standing stats    
# "My batchnorm, supports standing stats"这句注释表明MyBatchNorm支持在测试时使用先前计算的统计信息,从而避免在测试时重新计算统计信息。
# 指该 BN 实现支持固定的均值和方差,也就是说,如果在训练过程中已经计算好了某个 mini-batch 的均值和方差,并将其保存下来,
# 那么在之后的推理过程中,这个 BN 层就可以直接使用这个固定的均值和方差,而不需要重新计算。这种做法有助于提高模型的推理速度。
class myBN(nn.Module):
  def __init__(self, num_channels, eps=1e-5, momentum=0.1):
    super(myBN, self).__init__()
    # momentum for updating running stats
    self.momentum = momentum
    # epsilon to avoid dividing by 0
    self.eps = eps
    # Momentum
    self.momentum = momentum
    # Register buffers
    self.register_buffer('stored_mean', torch.zeros(num_channels))
    self.register_buffer('stored_var',  torch.ones(num_channels))
    self.register_buffer('accumulation_counter', torch.zeros(1))
    # Accumulate running means and vars
    self.accumulate_standing = False
    
  # reset standing stats
  def reset_stats(self):
    self.stored_mean[:] = 0
    self.stored_var[:] = 0
    self.accumulation_counter[:] = 0
    
  def forward(self, x, gain, bias):
    # Here gain and bias carry the class information (or, in a plain unconditional BN, they are ordinary learnable parameters)
    if self.training:
      # During training, also return the batch statistics so the stored stats can be updated (return_mean_var=True); not needed at test time
      out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
      # If accumulating standing stats, increment them
      if self.accumulate_standing:
        self.stored_mean[:] = self.stored_mean + mean.data
        self.stored_var[:] = self.stored_var + var.data
        self.accumulation_counter += 1.0
      # If not accumulating standing stats, take running averages
      else:
        self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
        self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
      return out
    # If not in training mode, use the stored statistics
    else:         
      mean = self.stored_mean.view(1, -1, 1, 1)
      var = self.stored_var.view(1, -1, 1, 1)
      # If using standing stats, divide them by the accumulation counter   
      if self.accumulate_standing:
        mean = mean / self.accumulation_counter
        var = var / self.accumulation_counter
      return fused_bn(x, mean, var, gain, bias, self.eps)

The myBN class first registers three buffers, used to update the statistics during training and to retrieve them at test time. There are two cases. When self.accumulate_standing is True, the mean and variance of every mini-batch are accumulated (summed), and at test time they are divided by the number of accumulated mini-batches to obtain an average before being used in the BN computation. When it is False, the test-time mean and variance are instead maintained as running averages during training, i.e. they are updated once per mini-batch. The two approaches each have their pros and cons and need to be tuned and validated in practice; the mini-batch size has a noticeable effect on the quality of the generated samples.
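A small standalone sketch of the two update rules with made-up per-batch means:

    import torch

    momentum = 0.1
    batch_means = [torch.tensor(0.2), torch.tensor(0.6), torch.tensor(0.4)]

    # accumulate_standing=True: sum every batch statistic, divide by the counter at test time
    stored, counter = torch.tensor(0.0), 0
    for m in batch_means:
        stored, counter = stored + m, counter + 1
    standing_mean = stored / counter                         # plain average over batches

    # accumulate_standing=False: exponential running average, updated once per batch
    running = torch.tensor(0.0)
    for m in batch_means:
        running = running * (1 - momentum) + m * momentum
    print(standing_mean.item(), running.item())              # ~0.4 vs. ~0.11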

Finally, ccbn calls myBN, and this completes the walkthrough of the main class-conditional batch normalization in BigGAN. There is also an unconditional bn class, which differs little from an ordinary BN layer; the SyncBN2d path enabled by the self.cross_replica flag, which runs BN across multiple GPUs and updates its statistics somewhat differently; and a few other normalization styles. None of these is as involved as the main part, so I may write them up when I get the chance.

Finally, thanks to GPT-3.5 for its help.
