NoteOnMe博客平台搭建
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

168 lines
7.6 KiB

  1. import tensorflow as tf
  2. from tensorflow.contrib import slim
  3. from nets import vgg
  4. from utils.rpn_msr.anchor_target_layer import anchor_target_layer as anchor_target_layer_py
  5. def mean_image_subtraction(images, means=[123.68, 116.78, 103.94]):
  6. num_channels = images.get_shape().as_list()[-1]
  7. if len(means) != num_channels:
  8. raise ValueError('len(means) must match the number of channels')
  9. channels = tf.split(axis=3, num_or_size_splits=num_channels, value=images)
  10. for i in range(num_channels):
  11. channels[i] -= means[i]
  12. return tf.concat(axis=3, values=channels)
  13. def make_var(name, shape, initializer=None):
  14. return tf.get_variable(name, shape, initializer=initializer)
  15. def Bilstm(net, input_channel, hidden_unit_num, output_channel, scope_name):
  16. # width--->time step
  17. with tf.variable_scope(scope_name) as scope:
  18. shape = tf.shape(net)
  19. N, H, W, C = shape[0], shape[1], shape[2], shape[3]
  20. net = tf.reshape(net, [N * H, W, C])
  21. net.set_shape([None, None, input_channel])
  22. lstm_fw_cell = tf.contrib.rnn.LSTMCell(hidden_unit_num, state_is_tuple=True)
  23. lstm_bw_cell = tf.contrib.rnn.LSTMCell(hidden_unit_num, state_is_tuple=True)
  24. lstm_out, last_state = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, net, dtype=tf.float32)
  25. lstm_out = tf.concat(lstm_out, axis=-1)
  26. lstm_out = tf.reshape(lstm_out, [N * H * W, 2 * hidden_unit_num])
  27. init_weights = tf.contrib.layers.variance_scaling_initializer(factor=0.01, mode='FAN_AVG', uniform=False)
  28. init_biases = tf.constant_initializer(0.0)
  29. weights = make_var('weights', [2 * hidden_unit_num, output_channel], init_weights)
  30. biases = make_var('biases', [output_channel], init_biases)
  31. outputs = tf.matmul(lstm_out, weights) + biases
  32. outputs = tf.reshape(outputs, [N, H, W, output_channel])
  33. return outputs
  34. def lstm_fc(net, input_channel, output_channel, scope_name):
  35. with tf.variable_scope(scope_name) as scope:
  36. shape = tf.shape(net)
  37. N, H, W, C = shape[0], shape[1], shape[2], shape[3]
  38. net = tf.reshape(net, [N * H * W, C])
  39. init_weights = tf.contrib.layers.variance_scaling_initializer(factor=0.01, mode='FAN_AVG', uniform=False)
  40. init_biases = tf.constant_initializer(0.0)
  41. weights = make_var('weights', [input_channel, output_channel], init_weights)
  42. biases = make_var('biases', [output_channel], init_biases)
  43. output = tf.matmul(net, weights) + biases
  44. output = tf.reshape(output, [N, H, W, output_channel])
  45. return output
  46. def model(image,language):#改
  47. image = mean_image_subtraction(image)
  48. with slim.arg_scope(vgg.vgg_arg_scope()):
  49. conv5_3 = vgg.vgg_16(image)
  50. rpn_conv = slim.conv2d(conv5_3, 512, 3)
  51. lstm_output = Bilstm(rpn_conv, 512, 128, 512, scope_name='BiLSTM')
  52. bbox_pred = lstm_fc(lstm_output, 512, 10 * 4, scope_name="bbox_pred")
  53. cls_pred = lstm_fc(lstm_output, 512, 10 * 3, scope_name="cls_pred")#改
  54. # transpose: (1, H, W, A x d) -> (1, H, WxA, d)
  55. cls_pred_shape = tf.shape(cls_pred)
  56. cls_pred_reshape = tf.reshape(cls_pred, [cls_pred_shape[0], cls_pred_shape[1], -1, 3])#改
  57. cls_pred_reshape_shape = tf.shape(cls_pred_reshape)
  58. cls_prob = tf.reshape(tf.nn.softmax(tf.reshape(cls_pred_reshape, [-1, cls_pred_reshape_shape[3]])),
  59. [-1, cls_pred_reshape_shape[1], cls_pred_reshape_shape[2], cls_pred_reshape_shape[3]],
  60. name="cls_prob")
  61. return bbox_pred, cls_pred, cls_prob
  62. def anchor_target_layer(cls_pred, bbox, im_info, scope_name):
  63. with tf.variable_scope(scope_name) as scope:
  64. # 'rpn_cls_score', 'gt_boxes', 'im_info'
  65. rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = \
  66. tf.py_func(anchor_target_layer_py,
  67. [cls_pred, bbox, im_info, [16, ], [16]],
  68. [tf.float32, tf.float32, tf.float32, tf.float32])
  69. rpn_labels = tf.convert_to_tensor(tf.cast(rpn_labels, tf.int32),
  70. name='rpn_labels')
  71. rpn_bbox_targets = tf.convert_to_tensor(rpn_bbox_targets,
  72. name='rpn_bbox_targets')
  73. rpn_bbox_inside_weights = tf.convert_to_tensor(rpn_bbox_inside_weights,
  74. name='rpn_bbox_inside_weights')
  75. rpn_bbox_outside_weights = tf.convert_to_tensor(rpn_bbox_outside_weights,
  76. name='rpn_bbox_outside_weights')
  77. return [rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights]
  78. def smooth_l1_dist(deltas, sigma2=9.0, name='smooth_l1_dist'):
  79. with tf.name_scope(name=name) as scope:
  80. deltas_abs = tf.abs(deltas)
  81. smoothL1_sign = tf.cast(tf.less(deltas_abs, 1.0 / sigma2), tf.float32)
  82. return tf.square(deltas) * 0.5 * sigma2 * smoothL1_sign + \
  83. (deltas_abs - 0.5 / sigma2) * tf.abs(smoothL1_sign - 1)
  84. def loss(bbox_pred, cls_pred, bbox, im_info):
  85. # rpn_labels : (HxWxA, 1), for each anchor, 0 denotes bg, 1 fg, -1 dontcare
  86. #rpn_bbox_targets: (HxWxA, 4), distances of the anchors to the gt_boxes(may contains some transform)
  87. # that are the regression objectives
  88. #rpn_bbox_inside_weights: (HxWxA, 4) weights of each boxes, mainly accepts hyper param in cfg
  89. #rpn_bbox_outside_weights: (HxWxA, 4) used to balance the fg/bg,
  90. # beacuse the numbers of bgs and fgs mays significiantly different
  91. rpn_data = anchor_target_layer(cls_pred, bbox, im_info, "anchor_target_layer")#改
  92. # classification loss
  93. # transpose: (1, H, W, A x d) -> (1, H, WxA, d)
  94. cls_pred_shape = tf.shape(cls_pred)
  95. cls_pred_reshape = tf.reshape(cls_pred, [cls_pred_shape[0], cls_pred_shape[1], -1, 3])#改
  96. rpn_cls_score = tf.reshape(cls_pred_reshape, [-1, 3])#改
  97. rpn_label = tf.reshape(rpn_data[0], [-1])
  98. # ignore_label(-1)
  99. fg_keep = tf.not_equal(rpn_label, -1)&tf.not_equal(rpn_label, 0)#改
  100. rpn_keep = tf.where(tf.not_equal(rpn_label, -1))
  101. rpn_cls_score = tf.gather(rpn_cls_score, rpn_keep)
  102. rpn_label = tf.gather(rpn_label, rpn_keep)
  103. rpn_cross_entropy_n = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=rpn_label, logits=rpn_cls_score)
  104. # box loss
  105. rpn_bbox_pred = bbox_pred
  106. rpn_bbox_targets = rpn_data[1]
  107. rpn_bbox_inside_weights = rpn_data[2]
  108. rpn_bbox_outside_weights = rpn_data[3]
  109. rpn_bbox_pred = tf.gather(tf.reshape(rpn_bbox_pred, [-1, 4]), rpn_keep) # shape (N, 4)
  110. rpn_bbox_targets = tf.gather(tf.reshape(rpn_bbox_targets, [-1, 4]), rpn_keep)
  111. rpn_bbox_inside_weights = tf.gather(tf.reshape(rpn_bbox_inside_weights, [-1, 4]), rpn_keep)
  112. rpn_bbox_outside_weights = tf.gather(tf.reshape(rpn_bbox_outside_weights, [-1, 4]), rpn_keep)
  113. rpn_loss_box_n = tf.reduce_sum(rpn_bbox_outside_weights * smooth_l1_dist(
  114. rpn_bbox_inside_weights * (rpn_bbox_pred - rpn_bbox_targets)), reduction_indices=[1])
  115. rpn_loss_box = tf.reduce_sum(rpn_loss_box_n) / (tf.reduce_sum(tf.cast(fg_keep, tf.float32)) + 1)
  116. rpn_cross_entropy = tf.reduce_mean(rpn_cross_entropy_n)
  117. model_loss = rpn_cross_entropy + rpn_loss_box
  118. regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  119. total_loss = tf.add_n(regularization_losses) + model_loss
  120. tf.summary.scalar('model_loss', model_loss)
  121. tf.summary.scalar('total_loss', total_loss)
  122. tf.summary.scalar('rpn_cross_entropy', rpn_cross_entropy)
  123. tf.summary.scalar('rpn_loss_box', rpn_loss_box)
  124. return total_loss, model_loss, rpn_cross_entropy, rpn_loss_box