DaSE-Computer-Vision-2021
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

121 lines
4.9 KiB

  1. import numpy as np
  2. cimport numpy as np
  3. cimport cython
# Element type accepted by the typed array arguments in this module.
# It used to be fixed to np.float64 (DTYPE = np.float64 /
# ctypedef np.float64_t DTYPE_t); as a fused type it now admits both
# float32 and float64 arrays.
ctypedef fused DTYPE_t:
    np.float32_t
    np.float64_t
  9. def im2col_cython(np.ndarray[DTYPE_t, ndim=4] x, int field_height,
  10. int field_width, int padding, int stride):
  11. cdef int N = x.shape[0]
  12. cdef int C = x.shape[1]
  13. cdef int H = x.shape[2]
  14. cdef int W = x.shape[3]
  15. cdef int HH = (H + 2 * padding - field_height) / stride + 1
  16. cdef int WW = (W + 2 * padding - field_width) / stride + 1
  17. cdef int p = padding
  18. cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.pad(x,
  19. ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
  20. cdef np.ndarray[DTYPE_t, ndim=2] cols = np.zeros(
  21. (C * field_height * field_width, N * HH * WW),
  22. dtype=x.dtype)
  23. # Moving the inner loop to a C function with no bounds checking works, but does
  24. # not seem to help performance in any measurable way.
  25. im2col_cython_inner(cols, x_padded, N, C, H, W, HH, WW,
  26. field_height, field_width, padding, stride)
  27. return cols
  28. @cython.boundscheck(False)
  29. cdef int im2col_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
  30. np.ndarray[DTYPE_t, ndim=4] x_padded,
  31. int N, int C, int H, int W, int HH, int WW,
  32. int field_height, int field_width, int padding, int stride) except? -1:
  33. cdef int c, ii, jj, row, yy, xx, i, col
  34. for c in range(C):
  35. for yy in range(HH):
  36. for xx in range(WW):
  37. for ii in range(field_height):
  38. for jj in range(field_width):
  39. row = c * field_width * field_height + ii * field_height + jj
  40. for i in range(N):
  41. col = yy * WW * N + xx * N + i
  42. cols[row, col] = x_padded[i, c, stride * yy + ii, stride * xx + jj]
  43. def col2im_cython(np.ndarray[DTYPE_t, ndim=2] cols, int N, int C, int H, int W,
  44. int field_height, int field_width, int padding, int stride):
  45. cdef np.ndarray x = np.empty((N, C, H, W), dtype=cols.dtype)
  46. cdef int HH = (H + 2 * padding - field_height) / stride + 1
  47. cdef int WW = (W + 2 * padding - field_width) / stride + 1
  48. cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros((N, C, H + 2 * padding, W + 2 * padding),
  49. dtype=cols.dtype)
  50. # Moving the inner loop to a C-function with no bounds checking improves
  51. # performance quite a bit for col2im.
  52. col2im_cython_inner(cols, x_padded, N, C, H, W, HH, WW,
  53. field_height, field_width, padding, stride)
  54. if padding > 0:
  55. return x_padded[:, :, padding:-padding, padding:-padding]
  56. return x_padded
  57. @cython.boundscheck(False)
  58. cdef int col2im_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
  59. np.ndarray[DTYPE_t, ndim=4] x_padded,
  60. int N, int C, int H, int W, int HH, int WW,
  61. int field_height, int field_width, int padding, int stride) except? -1:
  62. cdef int c, ii, jj, row, yy, xx, i, col
  63. for c in range(C):
  64. for ii in range(field_height):
  65. for jj in range(field_width):
  66. row = c * field_width * field_height + ii * field_height + jj
  67. for yy in range(HH):
  68. for xx in range(WW):
  69. for i in range(N):
  70. col = yy * WW * N + xx * N + i
  71. x_padded[i, c, stride * yy + ii, stride * xx + jj] += cols[row, col]
  72. @cython.boundscheck(False)
  73. @cython.wraparound(False)
  74. cdef col2im_6d_cython_inner(np.ndarray[DTYPE_t, ndim=6] cols,
  75. np.ndarray[DTYPE_t, ndim=4] x_padded,
  76. int N, int C, int H, int W, int HH, int WW,
  77. int out_h, int out_w, int pad, int stride):
  78. cdef int c, hh, ww, n, h, w
  79. for n in range(N):
  80. for c in range(C):
  81. for hh in range(HH):
  82. for ww in range(WW):
  83. for h in range(out_h):
  84. for w in range(out_w):
  85. x_padded[n, c, stride * h + hh, stride * w + ww] += cols[c, hh, ww, n, h, w]
  86. def col2im_6d_cython(np.ndarray[DTYPE_t, ndim=6] cols, int N, int C, int H, int W,
  87. int HH, int WW, int pad, int stride):
  88. cdef np.ndarray x = np.empty((N, C, H, W), dtype=cols.dtype)
  89. cdef int out_h = (H + 2 * pad - HH) / stride + 1
  90. cdef int out_w = (W + 2 * pad - WW) / stride + 1
  91. cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros((N, C, H + 2 * pad, W + 2 * pad),
  92. dtype=cols.dtype)
  93. col2im_6d_cython_inner(cols, x_padded, N, C, H, W, HH, WW, out_h, out_w, pad, stride)
  94. if pad > 0:
  95. return x_padded[:, :, pad:-pad, pad:-pad]
  96. return x_padded