Merge sc-qpr1-dev-plus-aosp-without-vendor@7810918

Bug: 205056467
Merged-In: I86752f113f98ce61995c4e56a48c39d3c6913197
Change-Id: Id8d7ebe4ef8775a165694b430a260756de188326
diff --git a/Android.bp b/Android.bp
index d50d33e..63f8f81 100644
--- a/Android.bp
+++ b/Android.bp
@@ -67,6 +67,7 @@
         "libgav1/src/decoder_settings.cc",
         "libgav1/src/dsp/arm/average_blend_neon.cc",
         "libgav1/src/dsp/arm/cdef_neon.cc",
+        "libgav1/src/dsp/arm/convolve_10bit_neon.cc",
         "libgav1/src/dsp/arm/convolve_neon.cc",
         "libgav1/src/dsp/arm/distance_weighted_blend_neon.cc",
         "libgav1/src/dsp/arm/film_grain_neon.cc",
@@ -79,6 +80,7 @@
         "libgav1/src/dsp/arm/inverse_transform_10bit_neon.cc",
         "libgav1/src/dsp/arm/inverse_transform_neon.cc",
         "libgav1/src/dsp/arm/loop_filter_neon.cc",
+        "libgav1/src/dsp/arm/loop_restoration_10bit_neon.cc",
         "libgav1/src/dsp/arm/loop_restoration_neon.cc",
         "libgav1/src/dsp/arm/mask_blend_neon.cc",
         "libgav1/src/dsp/arm/motion_field_projection_neon.cc",
diff --git a/README.version b/README.version
index 89f9d10..53d5b62 100644
--- a/README.version
+++ b/README.version
@@ -1,5 +1,5 @@
 URL: https://chromium.googlesource.com/codecs/libgav1
-Version: v0.16.3
+Version: v0.17.0
 BugComponent: 324837
 Local Modifications:
 None
diff --git a/libgav1/CMakeLists.txt b/libgav1/CMakeLists.txt
index 5e9e17a..4029de1 100644
--- a/libgav1/CMakeLists.txt
+++ b/libgav1/CMakeLists.txt
@@ -18,8 +18,10 @@
 # libgav1 requires C++11.
 set(CMAKE_CXX_STANDARD 11)
 set(ABSL_CXX_STANDARD 11)
+# libgav1 requires C99.
+set(CMAKE_C_STANDARD 99)
 
-project(libgav1 CXX)
+project(libgav1 CXX C)
 
 set(libgav1_root "${CMAKE_CURRENT_SOURCE_DIR}")
 set(libgav1_build "${CMAKE_BINARY_DIR}")
@@ -56,6 +58,12 @@
   set(CMAKE_BUILD_TYPE Release)
 endif()
 
+# Enable generators like Xcode and Visual Studio to place projects in folders.
+get_property(use_folders_is_set GLOBAL PROPERTY USE_FOLDERS SET)
+if(NOT use_folders_is_set)
+  set_property(GLOBAL PROPERTY USE_FOLDERS TRUE)
+endif()
+
 include(FindThreads)
 
 include("${libgav1_examples}/libgav1_examples.cmake")
@@ -126,6 +134,7 @@
       "    clone \\\n"
       "    https://github.com/abseil/abseil-cpp.git third_party/abseil-cpp")
 endif()
+set(ABSL_PROPAGATE_CXX_STD ON)
 add_subdirectory("${libgav1_abseil}" "${libgav1_abseil_build}" EXCLUDE_FROM_ALL)
 
 libgav1_reset_target_lists()
@@ -136,6 +145,12 @@
 libgav1_add_utils_targets()
 libgav1_setup_install_target()
 
+if(LIBGAV1_ENABLE_TESTS)
+  # include(CTest) or -DBUILD_TESTING=1 aren't used to avoid enabling abseil
+  # tests.
+  enable_testing()
+endif()
+
 if(LIBGAV1_VERBOSE)
   libgav1_dump_cmake_flag_variables()
   libgav1_dump_tracked_configuration_variables()
diff --git a/libgav1/README.md b/libgav1/README.md
index 3155970..6744291 100644
--- a/libgav1/README.md
+++ b/libgav1/README.md
@@ -92,6 +92,21 @@
     options. Note: tools like [FFmpeg](https://ffmpeg.org) can be used to
     convert other container formats to IVF.
 
+*   Unit tests are built when `LIBGAV1_ENABLE_TESTS` is set to `1`. The binaries
+    can be invoked directly or with
+    [`ctest`](https://cmake.org/cmake/help/latest/manual/ctest.1.html).
+
+    *   The test input location can be given by setting the
+        `LIBGAV1_TEST_DATA_PATH` environment variable; it defaults to
+        `<libgav1_src>/tests/data`, where `<libgav1_src>` is `/data/local/tmp`
+        on Android platforms or the source directory configured with cmake
+        otherwise.
+
+    *   Output is written to the value of the `TMPDIR` or `TEMP` environment
+        variables in that order if set, otherwise `/data/local/tmp` on Android
+        platforms, the value of `LIBGAV1_FLAGS_TMPDIR` if defined during
+        compilation or the current directory if not.
+
 ## Development
 
 ### Contributing
diff --git a/libgav1/cmake/libgav1_build_definitions.cmake b/libgav1/cmake/libgav1_build_definitions.cmake
index fc83490..0d00bb6 100644
--- a/libgav1/cmake/libgav1_build_definitions.cmake
+++ b/libgav1/cmake/libgav1_build_definitions.cmake
@@ -32,7 +32,7 @@
   #
   # We set LIBGAV1_SOVERSION = [c-a].a.r
   set(LT_CURRENT 0)
-  set(LT_REVISION 0)
+  set(LT_REVISION 1)
   set(LT_AGE 0)
   math(EXPR LIBGAV1_SOVERSION_MAJOR "${LT_CURRENT} - ${LT_AGE}")
   set(LIBGAV1_SOVERSION "${LIBGAV1_SOVERSION_MAJOR}.${LT_AGE}.${LT_REVISION}")
@@ -53,7 +53,8 @@
               "LIBGAV1_FLAGS_TMPDIR=\"/tmp\"")
 
   if(MSVC OR WIN32)
-    list(APPEND libgav1_defines "_CRT_SECURE_NO_DEPRECATE=1" "NOMINMAX=1")
+    list(APPEND libgav1_defines "_CRT_SECURE_NO_WARNINGS" "NOMINMAX"
+                "_SCL_SECURE_NO_WARNINGS")
   endif()
 
   if(ANDROID)
@@ -159,7 +160,7 @@
 
   # Source file names ending in these suffixes will have the appropriate
   # compiler flags added to their compile commands to enable intrinsics.
-  set(libgav1_avx2_source_file_suffix "avx2.cc")
-  set(libgav1_neon_source_file_suffix "neon.cc")
-  set(libgav1_sse4_source_file_suffix "sse4.cc")
+  set(libgav1_avx2_source_file_suffix "avx2(_test)?.cc")
+  set(libgav1_neon_source_file_suffix "neon(_test)?.cc")
+  set(libgav1_sse4_source_file_suffix "sse4(_test)?.cc")
 endmacro()
diff --git a/libgav1/cmake/libgav1_cpu_detection.cmake b/libgav1/cmake/libgav1_cpu_detection.cmake
index e17e27c..d79b83a 100644
--- a/libgav1/cmake/libgav1_cpu_detection.cmake
+++ b/libgav1/cmake/libgav1_cpu_detection.cmake
@@ -33,17 +33,20 @@
     list(APPEND libgav1_defines "LIBGAV1_ENABLE_AVX2=1")
   else()
     list(APPEND libgav1_defines "LIBGAV1_ENABLE_AVX2=0")
+    set(libgav1_have_avx2 OFF)
   endif()
 
   if(libgav1_have_neon AND LIBGAV1_ENABLE_NEON)
     list(APPEND libgav1_defines "LIBGAV1_ENABLE_NEON=1")
   else()
     list(APPEND libgav1_defines "LIBGAV1_ENABLE_NEON=0")
+    set(libgav1_have_neon OFF)
   endif()
 
   if(libgav1_have_sse4 AND LIBGAV1_ENABLE_SSE4_1)
     list(APPEND libgav1_defines "LIBGAV1_ENABLE_SSE4_1=1")
   else()
     list(APPEND libgav1_defines "LIBGAV1_ENABLE_SSE4_1=0")
+    set(libgav1_have_sse4 OFF)
   endif()
 endmacro()
diff --git a/libgav1/cmake/libgav1_flags.cmake b/libgav1/cmake/libgav1_flags.cmake
index a5408e2..4f2c4fd 100644
--- a/libgav1/cmake/libgav1_flags.cmake
+++ b/libgav1/cmake/libgav1_flags.cmake
@@ -259,5 +259,18 @@
   if(LIBGAV1_ENABLE_TESTS)
     set(LIBGAV1_TEST_CXX_FLAGS ${LIBGAV1_CXX_FLAGS})
     list(FILTER LIBGAV1_TEST_CXX_FLAGS EXCLUDE REGEX "-Wframe-larger-than")
+
+    if(NOT CMAKE_CXX_COMPILER_ID STREQUAL CMAKE_C_COMPILER_ID)
+      message(
+        FATAL_ERROR
+          "C/CXX compiler mismatch (${CMAKE_C_COMPILER_ID} vs"
+          " ${CMAKE_CXX_COMPILER_ID})! Compiler flags are only tested using"
+          " CMAKE_CXX_COMPILER, rerun cmake with CMAKE_C_COMPILER set to the"
+          " C compiler from the same package as CMAKE_CXX_COMPILER to ensure"
+          " the build completes successfully.")
+    endif()
+    set(LIBGAV1_TEST_C_FLAGS ${LIBGAV1_TEST_CXX_FLAGS})
+    list(FILTER LIBGAV1_TEST_C_FLAGS EXCLUDE REGEX
+         "-fvisibility-inlines-hidden")
   endif()
 endmacro()
diff --git a/libgav1/cmake/libgav1_targets.cmake b/libgav1/cmake/libgav1_targets.cmake
index 997f8bd..f8326a9 100644
--- a/libgav1/cmake/libgav1_targets.cmake
+++ b/libgav1/cmake/libgav1_targets.cmake
@@ -17,6 +17,14 @@
 endif() # LIBGAV1_CMAKE_GAV1_TARGETS_CMAKE_
 set(LIBGAV1_CMAKE_GAV1_TARGETS_CMAKE_ 1)
 
+if(LIBGAV1_IDE_FOLDER)
+  set(LIBGAV1_EXAMPLES_IDE_FOLDER "${LIBGAV1_IDE_FOLDER}/examples")
+  set(LIBGAV1_TESTS_IDE_FOLDER "${LIBGAV1_IDE_FOLDER}/tests")
+else()
+  set(LIBGAV1_EXAMPLES_IDE_FOLDER "libgav1_examples")
+  set(LIBGAV1_TESTS_IDE_FOLDER "libgav1_tests")
+endif()
+
 # Resets list variables used to track libgav1 targets.
 macro(libgav1_reset_target_lists)
   unset(libgav1_targets)
@@ -100,6 +108,13 @@
   endif()
 
   add_executable(${exe_NAME} ${exe_SOURCES})
+  if(exe_TEST)
+    add_test(NAME ${exe_NAME} COMMAND ${exe_NAME})
+    set_property(TARGET ${exe_NAME} PROPERTY FOLDER ${LIBGAV1_TESTS_IDE_FOLDER})
+  else()
+    set_property(TARGET ${exe_NAME}
+                 PROPERTY FOLDER ${LIBGAV1_EXAMPLES_IDE_FOLDER})
+  endif()
 
   if(exe_OUTPUT_NAME)
     set_target_properties(${exe_NAME} PROPERTIES OUTPUT_NAME ${exe_OUTPUT_NAME})
@@ -366,4 +381,17 @@
       libgav1_create_dummy_source_file(TARGET ${lib_NAME} BASENAME ${lib_NAME})
     endif()
   endif()
+
+  if(lib_TEST)
+    set_property(TARGET ${lib_NAME} PROPERTY FOLDER ${LIBGAV1_TESTS_IDE_FOLDER})
+  else()
+    set(sources_list ${lib_SOURCES})
+    list(FILTER sources_list INCLUDE REGEX examples)
+    if(sources_list)
+      set_property(TARGET ${lib_NAME}
+                   PROPERTY FOLDER ${LIBGAV1_EXAMPLES_IDE_FOLDER})
+    else()
+      set_property(TARGET ${lib_NAME} PROPERTY FOLDER ${LIBGAV1_IDE_FOLDER})
+    endif()
+  endif()
 endmacro()
diff --git a/libgav1/cmake/toolchains/aarch64-linux-gnu.cmake b/libgav1/cmake/toolchains/aarch64-linux-gnu.cmake
index 7ffe397..fdcb012 100644
--- a/libgav1/cmake/toolchains/aarch64-linux-gnu.cmake
+++ b/libgav1/cmake/toolchains/aarch64-linux-gnu.cmake
@@ -23,6 +23,13 @@
   set(CROSS aarch64-linux-gnu-)
 endif()
 
-set(CMAKE_CXX_COMPILER ${CROSS}g++)
+# For c_decoder_test.c and c_version_test.c.
+if(NOT CMAKE_C_COMPILER)
+  set(CMAKE_C_COMPILER ${CROSS}gcc)
+endif()
+set(CMAKE_C_FLAGS_INIT "-march=armv8-a")
+if(NOT CMAKE_CXX_COMPILER)
+  set(CMAKE_CXX_COMPILER ${CROSS}g++)
+endif()
 set(CMAKE_CXX_FLAGS_INIT "-march=armv8-a")
 set(CMAKE_SYSTEM_PROCESSOR "aarch64")
diff --git a/libgav1/cmake/toolchains/arm-linux-gnueabihf.cmake b/libgav1/cmake/toolchains/arm-linux-gnueabihf.cmake
index 8051f0d..7448f54 100644
--- a/libgav1/cmake/toolchains/arm-linux-gnueabihf.cmake
+++ b/libgav1/cmake/toolchains/arm-linux-gnueabihf.cmake
@@ -23,7 +23,14 @@
   set(CROSS arm-linux-gnueabihf-)
 endif()
 
-set(CMAKE_CXX_COMPILER ${CROSS}g++)
+# For c_decoder_test.c and c_version_test.c.
+if(NOT CMAKE_C_COMPILER)
+  set(CMAKE_C_COMPILER ${CROSS}gcc)
+endif()
+set(CMAKE_C_FLAGS_INIT "-march=armv7-a -marm")
+if(NOT CMAKE_CXX_COMPILER)
+  set(CMAKE_CXX_COMPILER ${CROSS}g++)
+endif()
 set(CMAKE_CXX_FLAGS_INIT "-march=armv7-a -marm")
 set(CMAKE_SYSTEM_PROCESSOR "armv7")
 set(LIBGAV1_NEON_INTRINSICS_FLAG "-mfpu=neon")
diff --git a/libgav1/src/buffer_pool.h b/libgav1/src/buffer_pool.h
index f35a633..d9eba6d 100644
--- a/libgav1/src/buffer_pool.h
+++ b/libgav1/src/buffer_pool.h
@@ -17,12 +17,13 @@
 #ifndef LIBGAV1_SRC_BUFFER_POOL_H_
 #define LIBGAV1_SRC_BUFFER_POOL_H_
 
+#include <algorithm>
 #include <array>
 #include <cassert>
 #include <climits>
 #include <condition_variable>  // NOLINT (unapproved c++11 header)
 #include <cstdint>
-#include <cstring>
+#include <memory>
 #include <mutex>  // NOLINT (unapproved c++11 header)
 
 #include "src/dsp/common.h"
@@ -52,7 +53,9 @@
 
 // A reference-counted frame buffer. Clients should access it via
 // RefCountedBufferPtr, which manages reference counting transparently.
-class RefCountedBuffer {
+// The alignment requirement is due to the SymbolDecoderContext member
+// frame_context_.
+class RefCountedBuffer : public MaxAlignedAllocable {
  public:
   // Not copyable or movable.
   RefCountedBuffer(const RefCountedBuffer&) = delete;
diff --git a/libgav1/src/decoder_impl.cc b/libgav1/src/decoder_impl.cc
index e23903c..dbb9e81 100644
--- a/libgav1/src/decoder_impl.cc
+++ b/libgav1/src/decoder_impl.cc
@@ -1232,7 +1232,7 @@
     LIBGAV1_DLOG(ERROR, "Failed to allocate memory for the decoder buffer.");
     return kStatusOutOfMemory;
   }
-  if (sequence_header.enable_cdef) {
+  if (frame_header.cdef.bits > 0) {
     if (!frame_scratch_buffer->cdef_index.Reset(
             DivideBy16(frame_header.rows4x4 + kMaxBlockHeight4x4),
             DivideBy16(frame_header.columns4x4 + kMaxBlockWidth4x4),
@@ -1241,6 +1241,15 @@
       return kStatusOutOfMemory;
     }
   }
+  if (do_cdef) {
+    if (!frame_scratch_buffer->cdef_skip.Reset(
+            DivideBy2(frame_header.rows4x4 + kMaxBlockHeight4x4),
+            DivideBy16(frame_header.columns4x4 + kMaxBlockWidth4x4),
+            /*zero_initialize=*/true)) {
+      LIBGAV1_DLOG(ERROR, "Failed to allocate memory for cdef skip.");
+      return kStatusOutOfMemory;
+    }
+  }
   if (!frame_scratch_buffer->inter_transform_sizes.Reset(
           frame_header.rows4x4 + kMaxBlockHeight4x4,
           frame_header.columns4x4 + kMaxBlockWidth4x4,
@@ -1364,23 +1373,39 @@
     const int pixel_size = sequence_header.color_config.bitdepth == 8
                                ? sizeof(uint8_t)
                                : sizeof(uint16_t);
+    const int coefficients_size = kSuperResFilterTaps *
+                                  Align(frame_header.upscaled_width, 16) *
+                                  pixel_size;
     if (!frame_scratch_buffer->superres_coefficients[kPlaneTypeY].Resize(
-            kSuperResFilterTaps * Align(frame_header.upscaled_width, 16) *
-            pixel_size)) {
+            coefficients_size)) {
       LIBGAV1_DLOG(ERROR,
                    "Failed to Resize superres_coefficients[kPlaneTypeY].");
       return kStatusOutOfMemory;
     }
+#if LIBGAV1_MSAN
+    // Quiet SuperRes_NEON() msan warnings.
+    memset(frame_scratch_buffer->superres_coefficients[kPlaneTypeY].get(), 0,
+           coefficients_size);
+#endif
+    const int uv_coefficients_size =
+        kSuperResFilterTaps *
+        Align(SubsampledValue(frame_header.upscaled_width, 1), 16) * pixel_size;
     if (!sequence_header.color_config.is_monochrome &&
         sequence_header.color_config.subsampling_x != 0 &&
         !frame_scratch_buffer->superres_coefficients[kPlaneTypeUV].Resize(
-            kSuperResFilterTaps *
-            Align(SubsampledValue(frame_header.upscaled_width, 1), 16) *
-            pixel_size)) {
+            uv_coefficients_size)) {
       LIBGAV1_DLOG(ERROR,
                    "Failed to Resize superres_coefficients[kPlaneTypeUV].");
       return kStatusOutOfMemory;
     }
+#if LIBGAV1_MSAN
+    if (!sequence_header.color_config.is_monochrome &&
+        sequence_header.color_config.subsampling_x != 0) {
+      // Quiet SuperRes_NEON() msan warnings.
+      memset(frame_scratch_buffer->superres_coefficients[kPlaneTypeUV].get(), 0,
+             uv_coefficients_size);
+    }
+#endif
   }
 
   if (do_superres && threading_strategy.post_filter_thread_pool() != nullptr) {
@@ -1405,10 +1430,6 @@
     }
   }
 
-  PostFilter post_filter(frame_header, sequence_header, frame_scratch_buffer,
-                         current_frame->buffer(), dsp,
-                         settings_.post_filter_mask);
-
   if (is_frame_parallel_ && !IsIntraFrame(frame_header.frame_type)) {
     // We can parse the current frame if all the reference frames have been
     // parsed.
@@ -1477,6 +1498,9 @@
     }
   }
 
+  PostFilter post_filter(frame_header, sequence_header, frame_scratch_buffer,
+                         current_frame->buffer(), dsp,
+                         settings_.post_filter_mask);
   SymbolDecoderContext saved_symbol_decoder_context;
   BlockingCounterWithStatus pending_tiles(tile_count);
   for (int tile_number = 0; tile_number < tile_count; ++tile_number) {
diff --git a/libgav1/src/dsp/arm/average_blend_neon.cc b/libgav1/src/dsp/arm/average_blend_neon.cc
index 5b4c094..3603750 100644
--- a/libgav1/src/dsp/arm/average_blend_neon.cc
+++ b/libgav1/src/dsp/arm/average_blend_neon.cc
@@ -40,17 +40,19 @@
 namespace low_bitdepth {
 namespace {
 
-inline uint8x8_t AverageBlend8Row(const int16_t* prediction_0,
-                                  const int16_t* prediction_1) {
+inline uint8x8_t AverageBlend8Row(const int16_t* LIBGAV1_RESTRICT prediction_0,
+                                  const int16_t* LIBGAV1_RESTRICT
+                                      prediction_1) {
   const int16x8_t pred0 = vld1q_s16(prediction_0);
   const int16x8_t pred1 = vld1q_s16(prediction_1);
   const int16x8_t res = vaddq_s16(pred0, pred1);
   return vqrshrun_n_s16(res, kInterPostRoundBit + 1);
 }
 
-inline void AverageBlendLargeRow(const int16_t* prediction_0,
-                                 const int16_t* prediction_1, const int width,
-                                 uint8_t* dest) {
+inline void AverageBlendLargeRow(const int16_t* LIBGAV1_RESTRICT prediction_0,
+                                 const int16_t* LIBGAV1_RESTRICT prediction_1,
+                                 const int width,
+                                 uint8_t* LIBGAV1_RESTRICT dest) {
   int x = width;
   do {
     const int16x8_t pred_00 = vld1q_s16(prediction_0);
@@ -71,8 +73,10 @@
   } while (x != 0);
 }
 
-void AverageBlend_NEON(const void* prediction_0, const void* prediction_1,
-                       const int width, const int height, void* const dest,
+void AverageBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                       const void* LIBGAV1_RESTRICT prediction_1,
+                       const int width, const int height,
+                       void* LIBGAV1_RESTRICT const dest,
                        const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
@@ -139,10 +143,10 @@
 namespace high_bitdepth {
 namespace {
 
-inline uint16x8_t AverageBlend8Row(const uint16_t* prediction_0,
-                                   const uint16_t* prediction_1,
-                                   const int32x4_t compound_offset,
-                                   const uint16x8_t v_bitdepth) {
+inline uint16x8_t AverageBlend8Row(
+    const uint16_t* LIBGAV1_RESTRICT prediction_0,
+    const uint16_t* LIBGAV1_RESTRICT prediction_1,
+    const int32x4_t compound_offset, const uint16x8_t v_bitdepth) {
   const uint16x8_t pred0 = vld1q_u16(prediction_0);
   const uint16x8_t pred1 = vld1q_u16(prediction_1);
   const uint32x4_t pred_lo =
@@ -158,9 +162,10 @@
   return vminq_u16(vcombine_u16(res_lo, res_hi), v_bitdepth);
 }
 
-inline void AverageBlendLargeRow(const uint16_t* prediction_0,
-                                 const uint16_t* prediction_1, const int width,
-                                 uint16_t* dest,
+inline void AverageBlendLargeRow(const uint16_t* LIBGAV1_RESTRICT prediction_0,
+                                 const uint16_t* LIBGAV1_RESTRICT prediction_1,
+                                 const int width,
+                                 uint16_t* LIBGAV1_RESTRICT dest,
                                  const int32x4_t compound_offset,
                                  const uint16x8_t v_bitdepth) {
   int x = width;
@@ -181,8 +186,10 @@
   } while (x != 0);
 }
 
-void AverageBlend_NEON(const void* prediction_0, const void* prediction_1,
-                       const int width, const int height, void* const dest,
+void AverageBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                       const void* LIBGAV1_RESTRICT prediction_1,
+                       const int width, const int height,
+                       void* LIBGAV1_RESTRICT const dest,
                        const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint16_t*>(dest);
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
diff --git a/libgav1/src/dsp/arm/cdef_neon.cc b/libgav1/src/dsp/arm/cdef_neon.cc
index 60c72d6..da271f2 100644
--- a/libgav1/src/dsp/arm/cdef_neon.cc
+++ b/libgav1/src/dsp/arm/cdef_neon.cc
@@ -33,7 +33,6 @@
 
 namespace libgav1 {
 namespace dsp {
-namespace low_bitdepth {
 namespace {
 
 #include "src/dsp/cdef.inc"
@@ -234,7 +233,8 @@
   *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[3], v_zero, 5));
 }
 
-LIBGAV1_ALWAYS_INLINE void AddPartial(const void* const source,
+template <int bitdepth>
+LIBGAV1_ALWAYS_INLINE void AddPartial(const void* LIBGAV1_RESTRICT const source,
                                       ptrdiff_t stride, uint16x8_t* partial_lo,
                                       uint16x8_t* partial_hi) {
   const auto* src = static_cast<const uint8_t*>(source);
@@ -249,11 +249,20 @@
   // 60 61 62 63 64 65 66 67
   // 70 71 72 73 74 75 76 77
   uint8x8_t v_src[8];
-  for (int i = 0; i < 8; ++i) {
-    v_src[i] = vld1_u8(src);
-    src += stride;
+  if (bitdepth == kBitdepth8) {
+    for (auto& v : v_src) {
+      v = vld1_u8(src);
+      src += stride;
+    }
+  } else {
+    // bitdepth - 8
+    constexpr int src_shift = (bitdepth == kBitdepth10) ? 2 : 4;
+    for (auto& v : v_src) {
+      v = vshrn_n_u16(vld1q_u16(reinterpret_cast<const uint16_t*>(src)),
+                      src_shift);
+      src += stride;
+    }
   }
-
   // partial for direction 2
   // --------------------------------------------------------------------------
   // partial[2][i] += x;
@@ -358,15 +367,19 @@
   return SumVector(c);
 }
 
-void CdefDirection_NEON(const void* const source, ptrdiff_t stride,
-                        uint8_t* const direction, int* const variance) {
+template <int bitdepth>
+void CdefDirection_NEON(const void* LIBGAV1_RESTRICT const source,
+                        ptrdiff_t stride,
+                        uint8_t* LIBGAV1_RESTRICT const direction,
+                        int* LIBGAV1_RESTRICT const variance) {
   assert(direction != nullptr);
   assert(variance != nullptr);
   const auto* src = static_cast<const uint8_t*>(source);
+
   uint32_t cost[8];
   uint16x8_t partial_lo[8], partial_hi[8];
 
-  AddPartial(src, stride, partial_lo, partial_hi);
+  AddPartial<bitdepth>(src, stride, partial_lo, partial_hi);
 
   cost[2] = SquareAccumulate(partial_lo[2]);
   cost[6] = SquareAccumulate(partial_lo[6]);
@@ -407,8 +420,9 @@
 // CdefFilter
 
 // Load 4 vectors based on the given |direction|.
-void LoadDirection(const uint16_t* const src, const ptrdiff_t stride,
-                   uint16x8_t* output, const int direction) {
+void LoadDirection(const uint16_t* LIBGAV1_RESTRICT const src,
+                   const ptrdiff_t stride, uint16x8_t* output,
+                   const int direction) {
   // Each |direction| describes a different set of source values. Expand this
   // set by negating each set. For |direction| == 0 this gives a diagonal line
   // from top right to bottom left. The first value is y, the second x. Negative
@@ -432,8 +446,9 @@
 
 // Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to
 // do 2 rows at a time.
-void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride,
-                    uint16x8_t* output, const int direction) {
+void LoadDirection4(const uint16_t* LIBGAV1_RESTRICT const src,
+                    const ptrdiff_t stride, uint16x8_t* output,
+                    const int direction) {
   const int y_0 = kCdefDirections[direction][0][0];
   const int x_0 = kCdefDirections[direction][0][1];
   const int y_1 = kCdefDirections[direction][1][0];
@@ -469,12 +484,90 @@
       vsubq_u16(veorq_u16(clamp_abs_diff, sign), sign));
 }
 
-template <int width, bool enable_primary = true, bool enable_secondary = true>
-void CdefFilter_NEON(const uint16_t* src, const ptrdiff_t src_stride,
-                     const int height, const int primary_strength,
-                     const int secondary_strength, const int damping,
-                     const int direction, void* dest,
-                     const ptrdiff_t dst_stride) {
+template <typename Pixel>
+uint16x8_t GetMaxPrimary(uint16x8_t* primary_val, uint16x8_t max,
+                         uint16x8_t cdef_large_value_mask) {
+  if (sizeof(Pixel) == 1) {
+    // The source is 16 bits, however, we only really care about the lower
+    // 8 bits.  The upper 8 bits contain the "large" flag.  After the final
+    // primary max has been calculated, zero out the upper 8 bits.  Use this
+    // to find the "16 bit" max.
+    const uint8x16_t max_p01 = vmaxq_u8(vreinterpretq_u8_u16(primary_val[0]),
+                                        vreinterpretq_u8_u16(primary_val[1]));
+    const uint8x16_t max_p23 = vmaxq_u8(vreinterpretq_u8_u16(primary_val[2]),
+                                        vreinterpretq_u8_u16(primary_val[3]));
+    const uint16x8_t max_p = vreinterpretq_u16_u8(vmaxq_u8(max_p01, max_p23));
+    max = vmaxq_u16(max, vandq_u16(max_p, cdef_large_value_mask));
+  } else {
+    // Convert kCdefLargeValue to 0 before calculating max.
+    max = vmaxq_u16(max, vandq_u16(primary_val[0], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(primary_val[1], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(primary_val[2], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(primary_val[3], cdef_large_value_mask));
+  }
+  return max;
+}
+
+template <typename Pixel>
+uint16x8_t GetMaxSecondary(uint16x8_t* secondary_val, uint16x8_t max,
+                           uint16x8_t cdef_large_value_mask) {
+  if (sizeof(Pixel) == 1) {
+    const uint8x16_t max_s01 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[0]),
+                                        vreinterpretq_u8_u16(secondary_val[1]));
+    const uint8x16_t max_s23 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[2]),
+                                        vreinterpretq_u8_u16(secondary_val[3]));
+    const uint8x16_t max_s45 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[4]),
+                                        vreinterpretq_u8_u16(secondary_val[5]));
+    const uint8x16_t max_s67 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[6]),
+                                        vreinterpretq_u8_u16(secondary_val[7]));
+    const uint16x8_t max_s = vreinterpretq_u16_u8(
+        vmaxq_u8(vmaxq_u8(max_s01, max_s23), vmaxq_u8(max_s45, max_s67)));
+    max = vmaxq_u16(max, vandq_u16(max_s, cdef_large_value_mask));
+  } else {
+    max = vmaxq_u16(max, vandq_u16(secondary_val[0], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(secondary_val[1], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(secondary_val[2], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(secondary_val[3], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(secondary_val[4], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(secondary_val[5], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(secondary_val[6], cdef_large_value_mask));
+    max = vmaxq_u16(max, vandq_u16(secondary_val[7], cdef_large_value_mask));
+  }
+  return max;
+}
+
+template <typename Pixel, int width>
+void StorePixels(void* dest, ptrdiff_t dst_stride, int16x8_t result) {
+  auto* const dst8 = static_cast<uint8_t*>(dest);
+  if (sizeof(Pixel) == 1) {
+    const uint8x8_t dst_pixel = vqmovun_s16(result);
+    if (width == 8) {
+      vst1_u8(dst8, dst_pixel);
+    } else {
+      StoreLo4(dst8, dst_pixel);
+      StoreHi4(dst8 + dst_stride, dst_pixel);
+    }
+  } else {
+    const uint16x8_t dst_pixel = vreinterpretq_u16_s16(result);
+    auto* const dst16 = reinterpret_cast<uint16_t*>(dst8);
+    if (width == 8) {
+      vst1q_u16(dst16, dst_pixel);
+    } else {
+      auto* const dst16_next_row =
+          reinterpret_cast<uint16_t*>(dst8 + dst_stride);
+      vst1_u16(dst16, vget_low_u16(dst_pixel));
+      vst1_u16(dst16_next_row, vget_high_u16(dst_pixel));
+    }
+  }
+}
+
+template <int width, typename Pixel, bool enable_primary = true,
+          bool enable_secondary = true>
+void CdefFilter_NEON(const uint16_t* LIBGAV1_RESTRICT src,
+                     const ptrdiff_t src_stride, const int height,
+                     const int primary_strength, const int secondary_strength,
+                     const int damping, const int direction,
+                     void* LIBGAV1_RESTRICT dest, const ptrdiff_t dst_stride) {
   static_assert(width == 8 || width == 4, "");
   static_assert(enable_primary || enable_secondary, "");
   constexpr bool clipping_required = enable_primary && enable_secondary;
@@ -488,22 +581,34 @@
 
   // FloorLog2() requires input to be > 0.
   // 8-bit damping range: Y: [3, 6], UV: [2, 5].
+  // 10-bit damping range: Y: [3, 6 + 2], UV: [2, 5 + 2].
   if (enable_primary) {
-    // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary
-    // for UV filtering.
+    // 8-bit primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is
+    // necessary for UV filtering.
+    // 10-bit primary_strength: [0, 15 << 2].
     primary_damping_shift =
         vdupq_n_s16(-std::max(0, damping - FloorLog2(primary_strength)));
   }
+
   if (enable_secondary) {
-    // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is
-    // necessary.
-    assert(damping - FloorLog2(secondary_strength) >= 0);
-    secondary_damping_shift =
-        vdupq_n_s16(-(damping - FloorLog2(secondary_strength)));
+    if (sizeof(Pixel) == 1) {
+      // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is
+      // necessary.
+      assert(damping - FloorLog2(secondary_strength) >= 0);
+      secondary_damping_shift =
+          vdupq_n_s16(-(damping - FloorLog2(secondary_strength)));
+    } else {
+      // secondary_strength: [0, 4 << 2]
+      secondary_damping_shift =
+          vdupq_n_s16(-std::max(0, damping - FloorLog2(secondary_strength)));
+    }
   }
 
-  const int primary_tap_0 = kCdefPrimaryTaps[primary_strength & 1][0];
-  const int primary_tap_1 = kCdefPrimaryTaps[primary_strength & 1][1];
+  constexpr int coeff_shift = (sizeof(Pixel) == 1) ? 0 : kBitdepth10 - 8;
+  const int primary_tap_0 =
+      kCdefPrimaryTaps[(primary_strength >> coeff_shift) & 1][0];
+  const int primary_tap_1 =
+      kCdefPrimaryTaps[(primary_strength >> coeff_shift) & 1][1];
 
   int y = height;
   do {
@@ -533,19 +638,7 @@
         min = vminq_u16(min, primary_val[2]);
         min = vminq_u16(min, primary_val[3]);
 
-        // The source is 16 bits, however, we only really care about the lower
-        // 8 bits.  The upper 8 bits contain the "large" flag.  After the final
-        // primary max has been calculated, zero out the upper 8 bits.  Use this
-        // to find the "16 bit" max.
-        const uint8x16_t max_p01 =
-            vmaxq_u8(vreinterpretq_u8_u16(primary_val[0]),
-                     vreinterpretq_u8_u16(primary_val[1]));
-        const uint8x16_t max_p23 =
-            vmaxq_u8(vreinterpretq_u8_u16(primary_val[2]),
-                     vreinterpretq_u8_u16(primary_val[3]));
-        const uint16x8_t max_p =
-            vreinterpretq_u16_u8(vmaxq_u8(max_p01, max_p23));
-        max = vmaxq_u16(max, vandq_u16(max_p, cdef_large_value_mask));
+        max = GetMaxPrimary<Pixel>(primary_val, max, cdef_large_value_mask);
       }
 
       sum = Constrain(primary_val[0], pixel, primary_threshold,
@@ -588,21 +681,7 @@
         min = vminq_u16(min, secondary_val[6]);
         min = vminq_u16(min, secondary_val[7]);
 
-        const uint8x16_t max_s01 =
-            vmaxq_u8(vreinterpretq_u8_u16(secondary_val[0]),
-                     vreinterpretq_u8_u16(secondary_val[1]));
-        const uint8x16_t max_s23 =
-            vmaxq_u8(vreinterpretq_u8_u16(secondary_val[2]),
-                     vreinterpretq_u8_u16(secondary_val[3]));
-        const uint8x16_t max_s45 =
-            vmaxq_u8(vreinterpretq_u8_u16(secondary_val[4]),
-                     vreinterpretq_u8_u16(secondary_val[5]));
-        const uint8x16_t max_s67 =
-            vmaxq_u8(vreinterpretq_u8_u16(secondary_val[6]),
-                     vreinterpretq_u8_u16(secondary_val[7]));
-        const uint16x8_t max_s = vreinterpretq_u16_u8(
-            vmaxq_u8(vmaxq_u8(max_s01, max_s23), vmaxq_u8(max_s45, max_s67)));
-        max = vmaxq_u16(max, vandq_u16(max_s, cdef_large_value_mask));
+        max = GetMaxSecondary<Pixel>(secondary_val, max, cdef_large_value_mask);
       }
 
       sum = vmlaq_n_s16(sum,
@@ -647,41 +726,70 @@
       result = vmaxq_s16(result, vreinterpretq_s16_u16(min));
     }
 
-    const uint8x8_t dst_pixel = vqmovun_s16(result);
-    if (width == 8) {
-      src += src_stride;
-      vst1_u8(dst, dst_pixel);
-      dst += dst_stride;
-      --y;
-    } else {
-      src += src_stride << 1;
-      StoreLo4(dst, dst_pixel);
-      dst += dst_stride;
-      StoreHi4(dst, dst_pixel);
-      dst += dst_stride;
-      y -= 2;
-    }
+    StorePixels<Pixel, width>(dst, dst_stride, result);
+
+    src += (width == 8) ? src_stride : src_stride << 1;
+    dst += (width == 8) ? dst_stride : dst_stride << 1;
+    y -= (width == 8) ? 1 : 2;
   } while (y != 0);
 }
 
+}  // namespace
+
+namespace low_bitdepth {
+namespace {
+
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
-  dsp->cdef_direction = CdefDirection_NEON;
-  dsp->cdef_filters[0][0] = CdefFilter_NEON<4>;
-  dsp->cdef_filters[0][1] =
-      CdefFilter_NEON<4, /*enable_primary=*/true, /*enable_secondary=*/false>;
-  dsp->cdef_filters[0][2] = CdefFilter_NEON<4, /*enable_primary=*/false>;
-  dsp->cdef_filters[1][0] = CdefFilter_NEON<8>;
-  dsp->cdef_filters[1][1] =
-      CdefFilter_NEON<8, /*enable_primary=*/true, /*enable_secondary=*/false>;
-  dsp->cdef_filters[1][2] = CdefFilter_NEON<8, /*enable_primary=*/false>;
+  dsp->cdef_direction = CdefDirection_NEON<kBitdepth8>;
+  dsp->cdef_filters[0][0] = CdefFilter_NEON<4, uint8_t>;  // default enable_*
+  dsp->cdef_filters[0][1] = CdefFilter_NEON<4, uint8_t, /*enable_primary=*/true,
+                                            /*enable_secondary=*/false>;
+  dsp->cdef_filters[0][2] =
+      CdefFilter_NEON<4, uint8_t, /*enable_primary=*/false>;
+  dsp->cdef_filters[1][0] = CdefFilter_NEON<8, uint8_t>;  // default enable_*
+  dsp->cdef_filters[1][1] = CdefFilter_NEON<8, uint8_t, /*enable_primary=*/true,
+                                            /*enable_secondary=*/false>;
+  dsp->cdef_filters[1][2] =
+      CdefFilter_NEON<8, uint8_t, /*enable_primary=*/false>;
 }
 
 }  // namespace
 }  // namespace low_bitdepth
 
-void CdefInit_NEON() { low_bitdepth::Init8bpp(); }
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+void Init10bpp() {  // 10-bit (uint16_t pixel) CDEF entry points.
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->cdef_direction = CdefDirection_NEON<kBitdepth10>;
+  dsp->cdef_filters[0][0] = CdefFilter_NEON<4, uint16_t>;  // default enable_*
+  dsp->cdef_filters[0][1] =
+      CdefFilter_NEON<4, uint16_t, /*enable_primary=*/true,
+                      /*enable_secondary=*/false>;
+  dsp->cdef_filters[0][2] =
+      CdefFilter_NEON<4, uint16_t, /*enable_primary=*/false>;
+  dsp->cdef_filters[1][0] = CdefFilter_NEON<8, uint16_t>;  // default enable_*
+  dsp->cdef_filters[1][1] =
+      CdefFilter_NEON<8, uint16_t, /*enable_primary=*/true,
+                      /*enable_secondary=*/false>;
+  dsp->cdef_filters[1][2] =
+      CdefFilter_NEON<8, uint16_t, /*enable_primary=*/false>;
+}
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+void CdefInit_NEON() {
+  low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+}
 
 }  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/arm/cdef_neon.h b/libgav1/src/dsp/arm/cdef_neon.h
index 53d5f86..ef8ed3c 100644
--- a/libgav1/src/dsp/arm/cdef_neon.h
+++ b/libgav1/src/dsp/arm/cdef_neon.h
@@ -33,6 +33,9 @@
 #if LIBGAV1_ENABLE_NEON
 #define LIBGAV1_Dsp8bpp_CdefDirection LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_CdefFilters LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_CdefDirection LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_CdefFilters LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_CDEF_NEON_H_
diff --git a/libgav1/src/dsp/arm/common_neon.h b/libgav1/src/dsp/arm/common_neon.h
index 05e0d05..9c46525 100644
--- a/libgav1/src/dsp/arm/common_neon.h
+++ b/libgav1/src/dsp/arm/common_neon.h
@@ -23,9 +23,13 @@
 
 #include <arm_neon.h>
 
+#include <algorithm>
+#include <cstddef>
 #include <cstdint>
 #include <cstring>
 
+#include "src/utils/compiler_attributes.h"
+
 #if 0
 #include <cstdio>
 #include <string>
@@ -183,6 +187,20 @@
 #define PD(x) PrintReg(x, #x)
 #define PX(x) PrintHex(x, #x)
 
+#if LIBGAV1_MSAN
+#include <sanitizer/msan_interface.h>
+
+inline void PrintShadow(const void* r, const char* const name,
+                        const size_t size) {
+  if (kEnablePrintRegs) {
+    fprintf(stderr, "Shadow for %s:\n", name);
+    __msan_print_shadow(r, size);
+  }
+}
+#define PS(var, N) PrintShadow(var, #var, N)  // N: size in bytes.
+
+#endif  // LIBGAV1_MSAN
+
 #endif  // 0
 
 namespace libgav1 {
@@ -210,6 +228,14 @@
       vld1_lane_u16(&temp, vreinterpret_u16_u8(val), lane));
 }
 
+template <int lane>  // 32-bit lane of |val| to fill: 0 or 1.
+inline uint16x4_t Load2(const void* const buf, uint16x4_t val) {
+  uint32_t temp;
+  memcpy(&temp, buf, 4);  // Alignment-safe read of 2 uint16_t.
+  return vreinterpret_u16_u32(
+      vld1_lane_u32(&temp, vreinterpret_u32_u16(val), lane));
+}
+
 // Load 4 uint8_t values into the low half of a uint8x8_t register. Zeros the
 // register before loading the values. Use caution when using this in loops
 // because it will re-zero the register before loading on every iteration.
@@ -229,6 +255,96 @@
       vld1_lane_u32(&temp, vreinterpret_u32_u8(val), lane));
 }
 
+// Convenience 16-bit loads of 4 or 8 lanes from an untyped (void*) source.
+inline uint16x4_t Load4U16(const void* const buf) {
+  return vld1_u16(static_cast<const uint16_t*>(buf));
+}
+
+inline uint16x8_t Load8U16(const void* const buf) {
+  return vld1q_u16(static_cast<const uint16_t*>(buf));
+}
+
+//------------------------------------------------------------------------------
+// Load functions to avoid MemorySanitizer's use-of-uninitialized-value warning.
+
+inline uint8x8_t MaskOverreads(const uint8x8_t source,
+                               const ptrdiff_t over_read_in_bytes) {
+  uint8x8_t dst = source;  // Returned unchanged unless MSan is enabled.
+#if LIBGAV1_MSAN
+  if (over_read_in_bytes > 0) {
+    uint8x8_t mask = vdup_n_u8(0);
+    uint8x8_t valid_element_mask = vdup_n_u8(-1);
+    const int valid_bytes =
+        std::min(8, 8 - static_cast<int>(over_read_in_bytes));
+    for (int i = 0; i < valid_bytes; ++i) {
+      // Shift |mask| up one lane, feeding 0xff into lane 0.
+      mask = vext_u8(valid_element_mask, mask, 7);
+    }
+    dst = vand_u8(dst, mask);  // Zero the over-read (high) lanes.
+  }
+#else
+  static_cast<void>(over_read_in_bytes);  // Unused without MSan.
+#endif
+  return dst;
+}
+
+inline uint8x16_t MaskOverreadsQ(const uint8x16_t source,
+                                 const ptrdiff_t over_read_in_bytes) {
+  uint8x16_t dst = source;  // Returned unchanged unless MSan is enabled.
+#if LIBGAV1_MSAN
+  if (over_read_in_bytes > 0) {
+    uint8x16_t mask = vdupq_n_u8(0);
+    uint8x16_t valid_element_mask = vdupq_n_u8(-1);
+    const int valid_bytes =
+        std::min(16, 16 - static_cast<int>(over_read_in_bytes));
+    for (int i = 0; i < valid_bytes; ++i) {
+      // Shift |mask| up one lane, feeding 0xff into lane 0.
+      mask = vextq_u8(valid_element_mask, mask, 15);
+    }
+    dst = vandq_u8(dst, mask);  // Zero the over-read (high) lanes.
+  }
+#else
+  static_cast<void>(over_read_in_bytes);  // Unused without MSan.
+#endif
+  return dst;
+}
+
+inline uint8x8_t Load1MsanU8(const uint8_t* const source,
+                             const ptrdiff_t over_read_in_bytes) {
+  return MaskOverreads(vld1_u8(source), over_read_in_bytes);  // 8 bytes.
+}
+
+inline uint8x16_t Load1QMsanU8(const uint8_t* const source,
+                               const ptrdiff_t over_read_in_bytes) {
+  return MaskOverreadsQ(vld1q_u8(source), over_read_in_bytes);  // 16 bytes.
+}
+
+inline uint16x8_t Load1QMsanU16(const uint16_t* const source,
+                                const ptrdiff_t over_read_in_bytes) {
+  return vreinterpretq_u16_u8(MaskOverreadsQ(
+      vreinterpretq_u8_u16(vld1q_u16(source)), over_read_in_bytes));
+}
+
+inline uint16x8x2_t Load2QMsanU16(const uint16_t* const source,
+                                  const ptrdiff_t over_read_in_bytes) {
+  // Deinterleaved 32-byte load; source element index in each lane:
+  // dst.val[0]: 00 02 04 06 08 10 12 14
+  // dst.val[1]: 01 03 05 07 09 11 13 15
+  uint16x8x2_t dst = vld2q_u16(source);
+  dst.val[0] = vreinterpretq_u16_u8(MaskOverreadsQ(
+      vreinterpretq_u8_u16(dst.val[0]), over_read_in_bytes >> 1));
+  dst.val[1] = vreinterpretq_u16_u8(
+      MaskOverreadsQ(vreinterpretq_u8_u16(dst.val[1]),
+                     (over_read_in_bytes >> 1) + (over_read_in_bytes % 4)));
+  return dst;
+}
+
+inline uint32x4_t Load1QMsanU32(const uint32_t* const source,
+                                const ptrdiff_t over_read_in_bytes) {
+  return vreinterpretq_u32_u8(MaskOverreadsQ(
+      vreinterpretq_u8_u32(vld1q_u32(source)), over_read_in_bytes));
+}
+
 //------------------------------------------------------------------------------
 // Store functions.
 
@@ -272,7 +388,7 @@
 // Store 2 uint16_t values from |lane| * 2 and |lane| * 2 + 1 of a uint16x4_t
 // register.
 template <int lane>
-inline void Store2(uint16_t* const buf, const uint16x4_t val) {
+inline void Store2(void* const buf, const uint16x4_t val) {  // Any pointer.
   ValueToMem<uint32_t>(buf, vget_lane_u32(vreinterpret_u32_u16(val), lane));
 }
 
@@ -287,6 +403,104 @@
 }
 
 //------------------------------------------------------------------------------
+// Pointer helpers.
+
+// Advances |ptr| by |stride| bytes (not elements of T) and returns it as a
+// T*. The offset is applied through a byte pointer; T's constness is kept.
+template <typename T>
+inline T* AddByteStride(T* ptr, const ptrdiff_t stride) {
+  return reinterpret_cast<T*>(
+      const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(ptr) + stride));
+}
+
+//------------------------------------------------------------------------------
+// Multiply.
+
+// Shim vmull_high_u16 for armv7: widening multiply of the high halves.
+inline uint32x4_t VMullHighU16(const uint16x8_t a, const uint16x8_t b) {
+#if defined(__aarch64__)
+  return vmull_high_u16(a, b);
+#else
+  return vmull_u16(vget_high_u16(a), vget_high_u16(b));
+#endif
+}
+
+// Shim vmull_high_s16 for armv7: widening multiply of the high halves.
+inline int32x4_t VMullHighS16(const int16x8_t a, const int16x8_t b) {
+#if defined(__aarch64__)
+  return vmull_high_s16(a, b);
+#else
+  return vmull_s16(vget_high_s16(a), vget_high_s16(b));
+#endif
+}
+
+// Shim vmlal_high_u16 for armv7: a + widened product of high halves of b, c.
+inline uint32x4_t VMlalHighU16(const uint32x4_t a, const uint16x8_t b,
+                               const uint16x8_t c) {
+#if defined(__aarch64__)
+  return vmlal_high_u16(a, b, c);
+#else
+  return vmlal_u16(a, vget_high_u16(b), vget_high_u16(c));
+#endif
+}
+
+// Shim vmlal_high_s16 for armv7: a + widened product of high halves of b, c.
+inline int32x4_t VMlalHighS16(const int32x4_t a, const int16x8_t b,
+                              const int16x8_t c) {
+#if defined(__aarch64__)
+  return vmlal_high_s16(a, b, c);
+#else
+  return vmlal_s16(a, vget_high_s16(b), vget_high_s16(c));
+#endif
+}
+
+// Shim vmul_laneq_u16 for armv7. |lane| must be in [0, 7].
+template <int lane>
+inline uint16x4_t VMulLaneQU16(const uint16x4_t a, const uint16x8_t b) {
+#if defined(__aarch64__)
+  return vmul_laneq_u16(a, b, lane);
+#else
+  if (lane < 4) return vmul_lane_u16(a, vget_low_u16(b), lane & 0x3);
+  return vmul_lane_u16(a, vget_high_u16(b), (lane - 4) & 0x3);
+#endif
+}
+
+// Shim vmulq_laneq_u16 for armv7. |lane| must be in [0, 7].
+template <int lane>
+inline uint16x8_t VMulQLaneQU16(const uint16x8_t a, const uint16x8_t b) {
+#if defined(__aarch64__)
+  return vmulq_laneq_u16(a, b, lane);
+#else
+  if (lane < 4) return vmulq_lane_u16(a, vget_low_u16(b), lane & 0x3);
+  return vmulq_lane_u16(a, vget_high_u16(b), (lane - 4) & 0x3);
+#endif
+}
+
+// Shim vmla_laneq_u16 for armv7. |lane| must be in [0, 7].
+template <int lane>
+inline uint16x4_t VMlaLaneQU16(const uint16x4_t a, const uint16x4_t b,
+                               const uint16x8_t c) {
+#if defined(__aarch64__)
+  return vmla_laneq_u16(a, b, c, lane);
+#else
+  if (lane < 4) return vmla_lane_u16(a, b, vget_low_u16(c), lane & 0x3);
+  return vmla_lane_u16(a, b, vget_high_u16(c), (lane - 4) & 0x3);
+#endif
+}
+
+// Shim vmlaq_laneq_u16 for armv7. |lane| must be in [0, 7].
+template <int lane>
+inline uint16x8_t VMlaQLaneQU16(const uint16x8_t a, const uint16x8_t b,
+                                const uint16x8_t c) {
+#if defined(__aarch64__)
+  return vmlaq_laneq_u16(a, b, c, lane);
+#else
+  if (lane < 4) return vmlaq_lane_u16(a, b, vget_low_u16(c), lane & 0x3);
+  return vmlaq_lane_u16(a, b, vget_high_u16(c), (lane - 4) & 0x3);
+#endif
+}
+
+//------------------------------------------------------------------------------
 // Bit manipulation.
 
 // vshXX_n_XX() requires an immediate.
@@ -315,6 +529,51 @@
 #endif
 }
 
+// Shim vqtbl2_u8 for armv7. Indices >= 32 yield 0.
+inline uint8x8_t VQTbl2U8(const uint8x16x2_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+  return vqtbl2_u8(a, index);
+#else
+  const uint8x8x4_t b = {vget_low_u8(a.val[0]), vget_high_u8(a.val[0]),
+                         vget_low_u8(a.val[1]), vget_high_u8(a.val[1])};
+  return vtbl4_u8(b, index);
+#endif
+}
+
+// Shim vqtbl2q_u8 for armv7. Indices >= 32 yield 0.
+inline uint8x16_t VQTbl2QU8(const uint8x16x2_t a, const uint8x16_t index) {
+#if defined(__aarch64__)
+  return vqtbl2q_u8(a, index);
+#else
+  return vcombine_u8(VQTbl2U8(a, vget_low_u8(index)),
+                     VQTbl2U8(a, vget_high_u8(index)));
+#endif
+}
+
+// Shim vqtbl3_u8 for armv7. Indices >= 48 yield 0.
+inline uint8x8_t VQTbl3U8(const uint8x16x3_t a, const uint8x8_t index) {
+#if defined(__aarch64__)
+  return vqtbl3_u8(a, index);
+#else
+  const uint8x8x4_t b = {vget_low_u8(a.val[0]), vget_high_u8(a.val[0]),
+                         vget_low_u8(a.val[1]), vget_high_u8(a.val[1])};
+  const uint8x8x2_t c = {vget_low_u8(a.val[2]), vget_high_u8(a.val[2])};
+  const uint8x8_t index_ext = vsub_u8(index, vdup_n_u8(32));
+  const uint8x8_t partial_lookup = vtbl4_u8(b, index);  // val[0..1]
+  return vtbx2_u8(partial_lookup, c, index_ext);  // val[2]
+#endif
+}
+
+// Shim vqtbl3q_u8 for armv7. Indices >= 48 yield 0.
+inline uint8x16_t VQTbl3QU8(const uint8x16x3_t a, const uint8x16_t index) {
+#if defined(__aarch64__)
+  return vqtbl3q_u8(a, index);
+#else
+  return vcombine_u8(VQTbl3U8(a, vget_low_u8(index)),
+                     VQTbl3U8(a, vget_high_u8(index)));
+#endif
+}
+
 // Shim vqtbl1_s8 for armv7.
 inline int8x8_t VQTbl1S8(const int8x16_t a, const uint8x8_t index) {
 #if defined(__aarch64__)
@@ -326,6 +585,25 @@
 }
 
 //------------------------------------------------------------------------------
+// Saturation helpers: clamp lanes to [low, high] or to a pixel range.
+
+inline int16x4_t Clip3S16(int16x4_t val, int16x4_t low, int16x4_t high) {
+  return vmin_s16(vmax_s16(val, low), high);
+}
+
+inline int16x8_t Clip3S16(const int16x8_t val, const int16x8_t low,
+                          const int16x8_t high) {
+  return vminq_s16(vmaxq_s16(val, low), high);
+}
+
+inline uint16x8_t ConvertToUnsignedPixelU16(int16x8_t val, int bitdepth) {
+  const int16x8_t low = vdupq_n_s16(0);
+  const uint16x8_t high = vdupq_n_u16((1 << bitdepth) - 1);  // Max pixel.
+
+  return vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(val, low)), high);
+}
+
+//------------------------------------------------------------------------------
 // Interleave.
 
 // vzipN is exclusive to A64.
@@ -439,6 +717,9 @@
   return vreinterpret_u8_u32(b);
 }
 
+// Swap the 64-bit halves: 00 01 02 03 04 05 06 07 -> 04 05 06 07 00 01 02 03.
+inline uint16x8_t Transpose64(const uint16x8_t a) { return vextq_u16(a, a, 4); }
+
 // Implement vtrnq_s64().
 // Input:
 // a0: 00 01 02 03 04 05 06 07
@@ -512,6 +793,108 @@
   *b = e.val[1];
 }
 
+// Transposes each 4x4 half independently. 4x8 Input:
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// 8x4 Output:
+// a[0]: 00 10 20 30 04 14 24 34
+// a[1]: 01 11 21 31 05 15 25 35
+// a[2]: 02 12 22 32 06 16 26 36
+// a[3]: 03 13 23 33 07 17 27 37
+inline void Transpose4x8(uint16x8_t a[4]) {
+  // b0.val[0]: 00 10 02 12 04 14 06 16
+  // b0.val[1]: 01 11 03 13 05 15 07 17
+  // b1.val[0]: 20 30 22 32 24 34 26 36
+  // b1.val[1]: 21 31 23 33 25 35 27 37
+  const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]);
+  const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]);
+
+  // c0.val[0]: 00 10 20 30 04 14 24 34
+  // c0.val[1]: 02 12 22 32 06 16 26 36
+  // c1.val[0]: 01 11 21 31 05 15 25 35
+  // c1.val[1]: 03 13 23 33 07 17 27 37
+  const uint32x4x2_t c0 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[0]),
+                                    vreinterpretq_u32_u16(b1.val[0]));
+  const uint32x4x2_t c1 = vtrnq_u32(vreinterpretq_u32_u16(b0.val[1]),
+                                    vreinterpretq_u32_u16(b1.val[1]));
+
+  a[0] = vreinterpretq_u16_u32(c0.val[0]);
+  a[1] = vreinterpretq_u16_u32(c1.val[0]);
+  a[2] = vreinterpretq_u16_u32(c0.val[1]);
+  a[3] = vreinterpretq_u16_u32(c1.val[1]);
+}
+
+// Special transpose for the loop filter: pairs pN with qN in each output.
+// 4x8 Input:
+// p_q:  p3 p2 p1 p0 q0 q1 q2 q3
+// a[0]: 00 01 02 03 04 05 06 07
+// a[1]: 10 11 12 13 14 15 16 17
+// a[2]: 20 21 22 23 24 25 26 27
+// a[3]: 30 31 32 33 34 35 36 37
+// 8x4 Output:
+// a[0]: 03 13 23 33 04 14 24 34  p0q0
+// a[1]: 02 12 22 32 05 15 25 35  p1q1
+// a[2]: 01 11 21 31 06 16 26 36  p2q2
+// a[3]: 00 10 20 30 07 17 27 37  p3q3
+// Direct reapplication of the function will reset the high halves, but
+// reverse the low halves:
+// p_q:  p0 p1 p2 p3 q0 q1 q2 q3
+// a[0]: 33 32 31 30 04 05 06 07
+// a[1]: 23 22 21 20 14 15 16 17
+// a[2]: 13 12 11 10 24 25 26 27
+// a[3]: 03 02 01 00 34 35 36 37
+// Simply reordering the inputs (3, 2, 1, 0) will reset the low halves, but
+// reverse the high halves.
+// The standard Transpose4x8 will produce the same reversals, but with the
+// order of the low halves also restored relative to the high halves. This is
+// preferable because it puts all values from the same source row back together,
+// but some post-processing is inevitable.
+inline void LoopFilterTranspose4x8(uint16x8_t a[4]) {
+  // b0.val[0]: 00 10 02 12 04 14 06 16
+  // b0.val[1]: 01 11 03 13 05 15 07 17
+  // b1.val[0]: 20 30 22 32 24 34 26 36
+  // b1.val[1]: 21 31 23 33 25 35 27 37
+  const uint16x8x2_t b0 = vtrnq_u16(a[0], a[1]);
+  const uint16x8x2_t b1 = vtrnq_u16(a[2], a[3]);
+
+  // Reverse odd vectors to bring the appropriate items to the front of zips.
+  // b0.val[0]: 00 10 02 12 04 14 06 16
+  // r0       : 03 13 01 11 07 17 05 15
+  // b1.val[0]: 20 30 22 32 24 34 26 36
+  // r1       : 23 33 21 31 27 37 25 35
+  const uint32x4_t r0 = vrev64q_u32(vreinterpretq_u32_u16(b0.val[1]));
+  const uint32x4_t r1 = vrev64q_u32(vreinterpretq_u32_u16(b1.val[1]));
+
+  // Zip to complete the halves.
+  // c0.val[0]: 00 10 20 30 02 12 22 32  p3p1
+  // c0.val[1]: 04 14 24 34 06 16 26 36  q0q2
+  // c1.val[0]: 03 13 23 33 01 11 21 31  p0p2
+  // c1.val[1]: 07 17 27 37 05 15 25 35  q3q1
+  const uint32x4x2_t c0 = vzipq_u32(vreinterpretq_u32_u16(b0.val[0]),
+                                    vreinterpretq_u32_u16(b1.val[0]));
+  const uint32x4x2_t c1 = vzipq_u32(r0, r1);
+
+  // d0.val[0]: 00 10 20 30 07 17 27 37  p3q3
+  // d0.val[1]: 02 12 22 32 05 15 25 35  p1q1
+  // d1.val[0]: 03 13 23 33 04 14 24 34  p0q0
+  // d1.val[1]: 01 11 21 31 06 16 26 36  p2q2
+  const uint16x8x2_t d0 = VtrnqU64(c0.val[0], c1.val[1]);
+  // c1.val[0] (the third row of c) comes first here to swap p2 with q0.
+  const uint16x8x2_t d1 = VtrnqU64(c1.val[0], c0.val[1]);
+
+  // 8x4 Output:
+  // a[0]: 03 13 23 33 04 14 24 34  p0q0
+  // a[1]: 02 12 22 32 05 15 25 35  p1q1
+  // a[2]: 01 11 21 31 06 16 26 36  p2q2
+  // a[3]: 00 10 20 30 07 17 27 37  p3q3
+  a[0] = d1.val[0];  // p0q0
+  a[1] = d0.val[1];  // p1q1
+  a[2] = d1.val[1];  // p2q2
+  a[3] = d0.val[0];  // p3q3
+}
+
 // Reversible if the x4 values are packed next to each other.
 // x4 input / x8 output:
 // a0: 00 01 02 03 40 41 42 43 44
diff --git a/libgav1/src/dsp/arm/convolve_10bit_neon.cc b/libgav1/src/dsp/arm/convolve_10bit_neon.cc
new file mode 100644
index 0000000..b7205df
--- /dev/null
+++ b/libgav1/src/dsp/arm/convolve_10bit_neon.cc
@@ -0,0 +1,3008 @@
+// Copyright 2021 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/convolve.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+// Include the constants and utility functions inside the anonymous namespace.
+#include "src/dsp/convolve.inc"
+
+// Output of ConvolveTest.ShowRange below.
+// Bitdepth: 10 Input range:            [       0,     1023]
+//   Horizontal base upscaled range:    [  -28644,    94116]
+//   Horizontal halved upscaled range:  [  -14322,    47085]
+//   Horizontal downscaled range:       [   -7161,    23529]
+//   Vertical upscaled range:           [-1317624,  2365176]
+//   Pixel output range:                [       0,     1023]
+//   Compound output range:             [    3988,    61532]
+
+template <int filter_index>
+int32x4x2_t SumOnePassTaps(const uint16x8_t* const src,
+                           const int16x4_t* const taps) {
+  const auto* ssrc = reinterpret_cast<const int16x8_t*>(src);  // Fits in s16.
+  int32x4x2_t sum;
+  if (filter_index < 2) {
+    // 6 taps.
+    sum.val[0] = vmull_s16(vget_low_s16(ssrc[0]), taps[0]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[1]), taps[1]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[2]), taps[2]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[3]), taps[3]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[4]), taps[4]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[5]), taps[5]);
+
+    sum.val[1] = vmull_s16(vget_high_s16(ssrc[0]), taps[0]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[1]), taps[1]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[2]), taps[2]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[3]), taps[3]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[4]), taps[4]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[5]), taps[5]);
+  } else if (filter_index == 2) {
+    // 8 taps.
+    sum.val[0] = vmull_s16(vget_low_s16(ssrc[0]), taps[0]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[1]), taps[1]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[2]), taps[2]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[3]), taps[3]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[4]), taps[4]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[5]), taps[5]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[6]), taps[6]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[7]), taps[7]);
+
+    sum.val[1] = vmull_s16(vget_high_s16(ssrc[0]), taps[0]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[1]), taps[1]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[2]), taps[2]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[3]), taps[3]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[4]), taps[4]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[5]), taps[5]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[6]), taps[6]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[7]), taps[7]);
+  } else if (filter_index == 3) {
+    // 2 taps.
+    sum.val[0] = vmull_s16(vget_low_s16(ssrc[0]), taps[0]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[1]), taps[1]);
+
+    sum.val[1] = vmull_s16(vget_high_s16(ssrc[0]), taps[0]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[1]), taps[1]);
+  } else {
+    // 4 taps.
+    sum.val[0] = vmull_s16(vget_low_s16(ssrc[0]), taps[0]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[1]), taps[1]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[2]), taps[2]);
+    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[3]), taps[3]);
+
+    sum.val[1] = vmull_s16(vget_high_s16(ssrc[0]), taps[0]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[1]), taps[1]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[2]), taps[2]);
+    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[3]), taps[3]);
+  }
+  return sum;
+}
+
+template <int filter_index>
+int32x4_t SumOnePassTaps(const uint16x4_t* const src,
+                         const int16x4_t* const taps) {
+  const auto* ssrc = reinterpret_cast<const int16x4_t*>(src);  // Fits in s16.
+  int32x4_t sum;
+  if (filter_index < 2) {
+    // 6 taps.
+    sum = vmull_s16(ssrc[0], taps[0]);
+    sum = vmlal_s16(sum, ssrc[1], taps[1]);
+    sum = vmlal_s16(sum, ssrc[2], taps[2]);
+    sum = vmlal_s16(sum, ssrc[3], taps[3]);
+    sum = vmlal_s16(sum, ssrc[4], taps[4]);
+    sum = vmlal_s16(sum, ssrc[5], taps[5]);
+  } else if (filter_index == 2) {
+    // 8 taps.
+    sum = vmull_s16(ssrc[0], taps[0]);
+    sum = vmlal_s16(sum, ssrc[1], taps[1]);
+    sum = vmlal_s16(sum, ssrc[2], taps[2]);
+    sum = vmlal_s16(sum, ssrc[3], taps[3]);
+    sum = vmlal_s16(sum, ssrc[4], taps[4]);
+    sum = vmlal_s16(sum, ssrc[5], taps[5]);
+    sum = vmlal_s16(sum, ssrc[6], taps[6]);
+    sum = vmlal_s16(sum, ssrc[7], taps[7]);
+  } else if (filter_index == 3) {
+    // 2 taps.
+    sum = vmull_s16(ssrc[0], taps[0]);
+    sum = vmlal_s16(sum, ssrc[1], taps[1]);
+  } else {
+    // 4 taps.
+    sum = vmull_s16(ssrc[0], taps[0]);
+    sum = vmlal_s16(sum, ssrc[1], taps[1]);
+    sum = vmlal_s16(sum, ssrc[2], taps[2]);
+    sum = vmlal_s16(sum, ssrc[3], taps[3]);
+  }
+  return sum;
+}
+
+template <int filter_index, bool is_compound, bool is_2d>
+void FilterHorizontalWidth8AndUp(const uint16_t* LIBGAV1_RESTRICT src,
+                                 const ptrdiff_t src_stride,
+                                 void* LIBGAV1_RESTRICT const dest,
+                                 const ptrdiff_t pred_stride, const int width,
+                                 const int height,
+                                 const int16x4_t* const v_tap) {
+  auto* dest16 = static_cast<uint16_t*>(dest);
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  if (is_2d) {  // Intermediate (int16) output.
+    int x = 0;
+    do {
+      const uint16_t* s = src + x;
+      int y = height;
+      do {  // x outer loop: 8-wide column strips are stored contiguously.
+        const uint16x8_t src_long = vld1q_u16(s);
+        const uint16x8_t src_long_hi = vld1q_u16(s + 8);
+        uint16x8_t v_src[8];
+        int32x4x2_t v_sum;
+        if (filter_index < 2) {
+          v_src[0] = src_long;
+          v_src[1] = vextq_u16(src_long, src_long_hi, 1);
+          v_src[2] = vextq_u16(src_long, src_long_hi, 2);
+          v_src[3] = vextq_u16(src_long, src_long_hi, 3);
+          v_src[4] = vextq_u16(src_long, src_long_hi, 4);
+          v_src[5] = vextq_u16(src_long, src_long_hi, 5);
+          v_sum = SumOnePassTaps<filter_index>(v_src, v_tap + 1);
+        } else if (filter_index == 2) {
+          v_src[0] = src_long;
+          v_src[1] = vextq_u16(src_long, src_long_hi, 1);
+          v_src[2] = vextq_u16(src_long, src_long_hi, 2);
+          v_src[3] = vextq_u16(src_long, src_long_hi, 3);
+          v_src[4] = vextq_u16(src_long, src_long_hi, 4);
+          v_src[5] = vextq_u16(src_long, src_long_hi, 5);
+          v_src[6] = vextq_u16(src_long, src_long_hi, 6);
+          v_src[7] = vextq_u16(src_long, src_long_hi, 7);
+          v_sum = SumOnePassTaps<filter_index>(v_src, v_tap);
+        } else if (filter_index == 3) {
+          v_src[0] = src_long;
+          v_src[1] = vextq_u16(src_long, src_long_hi, 1);
+          v_sum = SumOnePassTaps<filter_index>(v_src, v_tap + 3);
+        } else {  // filter_index > 3
+          v_src[0] = src_long;
+          v_src[1] = vextq_u16(src_long, src_long_hi, 1);
+          v_src[2] = vextq_u16(src_long, src_long_hi, 2);
+          v_src[3] = vextq_u16(src_long, src_long_hi, 3);
+          v_sum = SumOnePassTaps<filter_index>(v_src, v_tap + 2);
+        }
+
+        const int16x4_t d0 =
+            vqrshrn_n_s32(v_sum.val[0], kInterRoundBitsHorizontal - 1);
+        const int16x4_t d1 =
+            vqrshrn_n_s32(v_sum.val[1], kInterRoundBitsHorizontal - 1);
+        vst1_u16(&dest16[0], vreinterpret_u16_s16(d0));
+        vst1_u16(&dest16[4], vreinterpret_u16_s16(d1));
+        s += src_stride;
+        dest16 += 8;
+      } while (--y != 0);
+      x += 8;
+    } while (x < width);
+    return;
+  }
+  int y = height;
+  do {
+    int x = 0;
+    do {
+      const uint16x8_t src_long = vld1q_u16(src + x);
+      const uint16x8_t src_long_hi = vld1q_u16(src + x + 8);
+      uint16x8_t v_src[8];
+      int32x4x2_t v_sum;
+      if (filter_index < 2) {
+        v_src[0] = src_long;
+        v_src[1] = vextq_u16(src_long, src_long_hi, 1);
+        v_src[2] = vextq_u16(src_long, src_long_hi, 2);
+        v_src[3] = vextq_u16(src_long, src_long_hi, 3);
+        v_src[4] = vextq_u16(src_long, src_long_hi, 4);
+        v_src[5] = vextq_u16(src_long, src_long_hi, 5);
+        v_sum = SumOnePassTaps<filter_index>(v_src, v_tap + 1);
+      } else if (filter_index == 2) {
+        v_src[0] = src_long;
+        v_src[1] = vextq_u16(src_long, src_long_hi, 1);
+        v_src[2] = vextq_u16(src_long, src_long_hi, 2);
+        v_src[3] = vextq_u16(src_long, src_long_hi, 3);
+        v_src[4] = vextq_u16(src_long, src_long_hi, 4);
+        v_src[5] = vextq_u16(src_long, src_long_hi, 5);
+        v_src[6] = vextq_u16(src_long, src_long_hi, 6);
+        v_src[7] = vextq_u16(src_long, src_long_hi, 7);
+        v_sum = SumOnePassTaps<filter_index>(v_src, v_tap);
+      } else if (filter_index == 3) {
+        v_src[0] = src_long;
+        v_src[1] = vextq_u16(src_long, src_long_hi, 1);
+        v_sum = SumOnePassTaps<filter_index>(v_src, v_tap + 3);
+      } else {  // filter_index > 3
+        v_src[0] = src_long;
+        v_src[1] = vextq_u16(src_long, src_long_hi, 1);
+        v_src[2] = vextq_u16(src_long, src_long_hi, 2);
+        v_src[3] = vextq_u16(src_long, src_long_hi, 3);
+        v_sum = SumOnePassTaps<filter_index>(v_src, v_tap + 2);
+      }
+      if (is_compound) {
+        const int16x4_t v_compound_offset = vdup_n_s16(kCompoundOffset);
+        const int16x4_t d0 =
+            vqrshrn_n_s32(v_sum.val[0], kInterRoundBitsHorizontal - 1);
+        const int16x4_t d1 =
+            vqrshrn_n_s32(v_sum.val[1], kInterRoundBitsHorizontal - 1);
+        vst1_u16(&dest16[x],
+                 vreinterpret_u16_s16(vadd_s16(d0, v_compound_offset)));
+        vst1_u16(&dest16[x + 4],
+                 vreinterpret_u16_s16(vadd_s16(d1, v_compound_offset)));
+      } else {
+        // Normally the Horizontal pass does the downshift in two passes:
+        // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
+        // kInterRoundBitsHorizontal). Each one uses a rounding shift.
+        // Combining them requires adding the rounding offset from the skipped
+        // shift.
+        const int32x4_t v_first_shift_rounding_bit =
+            vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 2));
+        v_sum.val[0] = vaddq_s32(v_sum.val[0], v_first_shift_rounding_bit);
+        v_sum.val[1] = vaddq_s32(v_sum.val[1], v_first_shift_rounding_bit);
+        const uint16x4_t d0 = vmin_u16(
+            vqrshrun_n_s32(v_sum.val[0], kFilterBits - 1), v_max_bitdepth);
+        const uint16x4_t d1 = vmin_u16(
+            vqrshrun_n_s32(v_sum.val[1], kFilterBits - 1), v_max_bitdepth);
+        vst1_u16(&dest16[x], d0);
+        vst1_u16(&dest16[x + 4], d1);
+      }
+      x += 8;
+    } while (x < width);
+    src += src_stride;
+    dest16 += pred_stride;
+  } while (--y != 0);
+}
+
+template <int filter_index, bool is_compound, bool is_2d>
+void FilterHorizontalWidth4(const uint16_t* LIBGAV1_RESTRICT src,
+                            const ptrdiff_t src_stride,
+                            void* LIBGAV1_RESTRICT const dest,
+                            const ptrdiff_t pred_stride, const int height,
+                            const int16x4_t* const v_tap) {
+  auto* dest16 = static_cast<uint16_t*>(dest);
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  int y = height;
+  do {
+    const uint16x8_t v_zero = vdupq_n_u16(0);  // Fill for vext.
+    uint16x4_t v_src[4];  // Room for up to 4 taps (filter_index >= 3).
+    int32x4_t v_sum;
+    const uint16x8_t src_long = vld1q_u16(src);
+    v_src[0] = vget_low_u16(src_long);
+    if (filter_index == 3) {  // 2 taps.
+      v_src[1] = vget_low_u16(vextq_u16(src_long, v_zero, 1));
+      v_sum = SumOnePassTaps<filter_index>(v_src, v_tap + 3);
+    } else {
+      v_src[1] = vget_low_u16(vextq_u16(src_long, v_zero, 1));
+      v_src[2] = vget_low_u16(vextq_u16(src_long, v_zero, 2));
+      v_src[3] = vget_low_u16(vextq_u16(src_long, v_zero, 3));
+      v_sum = SumOnePassTaps<filter_index>(v_src, v_tap + 2);
+    }
+    if (is_compound || is_2d) {  // First rounding shift only.
+      const int16x4_t d0 = vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1);
+      if (is_compound && !is_2d) {
+        vst1_u16(&dest16[0], vreinterpret_u16_s16(
+                                 vadd_s16(d0, vdup_n_s16(kCompoundOffset))));
+      } else {
+        vst1_u16(&dest16[0], vreinterpret_u16_s16(d0));
+      }
+    } else {  // Both rounding shifts; clip to 10-bit range.
+      const int32x4_t v_first_shift_rounding_bit =
+          vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 2));
+      v_sum = vaddq_s32(v_sum, v_first_shift_rounding_bit);
+      const uint16x4_t d0 =
+          vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
+      vst1_u16(&dest16[0], d0);
+    }
+    src += src_stride;
+    dest16 += pred_stride;
+  } while (--y != 0);
+}
+
+// Horizontal filter for 2-pixel-wide blocks. There is no |is_compound|
+// parameter; this kernel only produces pixels or 2D intermediates.
+//
+// Two source rows are interleaved with vzipq_s16. In the low half of
+// |input.val[0]|, lanes 0 and 2 hold row 0 and lanes 1 and 3 hold row 1, so
+// each loop iteration produces a 2x2 block of output.
+//
+// Only two kernel widths are handled here:
+//   - 2-tap (|filter_index| == 3), coefficients in v_tap[3..4];
+//   - 4-tap, coefficients in v_tap[2..5].
+//
+// NOTE(review): vld1q_u16 reads 8 pixels per row, more than the taps need;
+// presumably the source buffer is padded for this - confirm.
+template <int filter_index, bool is_2d>
+void FilterHorizontalWidth2(const uint16_t* LIBGAV1_RESTRICT src,
+                            const ptrdiff_t src_stride,
+                            void* LIBGAV1_RESTRICT const dest,
+                            const ptrdiff_t pred_stride, const int height,
+                            const int16x4_t* const v_tap) {
+  auto* dest16 = static_cast<uint16_t*>(dest);
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  // Two rows per iteration. For |is_2d| the odd last row is filtered after
+  // the loop.
+  int y = height >> 1;
+  do {
+    const int16x8_t v_zero = vdupq_n_s16(0);
+    const int16x8_t input0 = vreinterpretq_s16_u16(vld1q_u16(src));
+    const int16x8_t input1 = vreinterpretq_s16_u16(vld1q_u16(src + src_stride));
+    // input.val[0] = {r0[0], r1[0], r0[1], r1[1], r0[2], r1[2], r0[3], r1[3]}
+    // input.val[1] = {r0[4], r1[4], ...}
+    const int16x8x2_t input = vzipq_s16(input0, input1);
+    int32x4_t v_sum;
+    if (filter_index == 3) {
+      // 2 taps: pixel offsets 0 and 1. A vextq by 2 lanes advances one
+      // interleaved pixel pair.
+      v_sum = vmull_s16(vget_low_s16(input.val[0]), v_tap[3]);
+      v_sum = vmlal_s16(v_sum,
+                        vget_low_s16(vextq_s16(input.val[0], input.val[1], 2)),
+                        v_tap[4]);
+    } else {
+      // 4 taps: pixel offsets 0..3. Offset 3 for the second output column
+      // needs {r0[4], r1[4]}, taken from input.val[1].
+      v_sum = vmull_s16(vget_low_s16(input.val[0]), v_tap[2]);
+      v_sum = vmlal_s16(v_sum, vget_low_s16(vextq_s16(input.val[0], v_zero, 2)),
+                        v_tap[3]);
+      v_sum = vmlal_s16(v_sum, vget_low_s16(vextq_s16(input.val[0], v_zero, 4)),
+                        v_tap[4]);
+      v_sum = vmlal_s16(v_sum,
+                        vget_low_s16(vextq_s16(input.val[0], input.val[1], 6)),
+                        v_tap[5]);
+    }
+    if (is_2d) {
+      // 2D intermediate: apply only the first rounding stage and keep the
+      // signed 16-bit result (stored as raw bits).
+      const uint16x4_t d0 = vreinterpret_u16_s16(
+          vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1));
+      // Lanes 0/2 belong to the first row, lanes 1/3 to the second.
+      dest16[0] = vget_lane_u16(d0, 0);
+      dest16[1] = vget_lane_u16(d0, 2);
+      dest16 += pred_stride;
+      dest16[0] = vget_lane_u16(d0, 1);
+      dest16[1] = vget_lane_u16(d0, 3);
+      dest16 += pred_stride;
+    } else {
+      // Normally the Horizontal pass does the downshift in two passes:
+      // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
+      // kInterRoundBitsHorizontal). Each one uses a rounding shift.
+      // Combining them requires adding the rounding offset from the skipped
+      // shift.
+      const int32x4_t v_first_shift_rounding_bit =
+          vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 2));
+      v_sum = vaddq_s32(v_sum, v_first_shift_rounding_bit);
+      const uint16x4_t d0 =
+          vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
+      dest16[0] = vget_lane_u16(d0, 0);
+      dest16[1] = vget_lane_u16(d0, 2);
+      dest16 += pred_stride;
+      dest16[0] = vget_lane_u16(d0, 1);
+      dest16[1] = vget_lane_u16(d0, 3);
+      dest16 += pred_stride;
+    }
+    src += src_stride << 1;
+  } while (--y != 0);
+
+  // The 2d filters have an odd |height| because the horizontal pass
+  // generates context for the vertical pass.
+  if (is_2d) {
+    assert(height % 2 == 1);
+    // Single last row, so there is no interleaving. Here a vextq by 1 lane
+    // advances one pixel.
+    const int16x8_t input = vreinterpretq_s16_u16(vld1q_u16(src));
+    int32x4_t v_sum;
+    if (filter_index == 3) {
+      v_sum = vmull_s16(vget_low_s16(input), v_tap[3]);
+      v_sum =
+          vmlal_s16(v_sum, vget_low_s16(vextq_s16(input, input, 1)), v_tap[4]);
+    } else {
+      v_sum = vmull_s16(vget_low_s16(input), v_tap[2]);
+      v_sum =
+          vmlal_s16(v_sum, vget_low_s16(vextq_s16(input, input, 1)), v_tap[3]);
+      v_sum =
+          vmlal_s16(v_sum, vget_low_s16(vextq_s16(input, input, 2)), v_tap[4]);
+      v_sum =
+          vmlal_s16(v_sum, vget_low_s16(vextq_s16(input, input, 3)), v_tap[5]);
+    }
+    const uint16x4_t d0 = vreinterpret_u16_s16(
+        vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1));
+    // Presumably writes lanes 0-1 (the two pixels of this row).
+    Store2<0>(dest16, d0);
+  }
+}
+
+// Routes a horizontal pass to the width-specialized kernel.
+//
+// Kernel selection by width:
+//   - |width| >= 8: |filter_index| 0..3;
+//   - |width| 4 or 2: the 2-tap and 4-tap kernels (|filter_index| 3..5).
+//
+// The checks against the template parameter look redundant. They are kept
+// deliberately, because they help the compiler generate compact code for each
+// instantiation.
+template <int filter_index, bool is_compound, bool is_2d>
+void FilterHorizontal(const uint16_t* LIBGAV1_RESTRICT const src,
+                      const ptrdiff_t src_stride,
+                      void* LIBGAV1_RESTRICT const dest,
+                      const ptrdiff_t pred_stride, const int width,
+                      const int height, const int16x4_t* const v_tap) {
+  assert(width < 8 || filter_index <= 3);
+  if (filter_index <= 3 && width >= 8) {
+    FilterHorizontalWidth8AndUp<filter_index, is_compound, is_2d>(
+        src, src_stride, dest, pred_stride, width, height, v_tap);
+    return;
+  }
+
+  // Narrow blocks (|width| <= 4) only use the 2-tap and 4-tap kernels.
+  assert(width <= 4);
+  assert(filter_index >= 3 && filter_index <= 5);
+  if (filter_index < 3 || filter_index > 5) return;
+
+  if (width != 4) {
+    assert(width == 2);
+    // There is no compound 2-wide kernel, so a compound request of width 2
+    // writes nothing.
+    if (!is_compound) {
+      FilterHorizontalWidth2<filter_index, is_2d>(src, src_stride, dest,
+                                                  pred_stride, height, v_tap);
+    }
+    return;
+  }
+  FilterHorizontalWidth4<filter_index, is_compound, is_2d>(
+      src, src_stride, dest, pred_stride, height, v_tap);
+}
+
+// Broadcasts the coefficients of kernel |filter_index| / |filter_id| and runs
+// the horizontal filter specialized for that kernel.
+//
+// Kernels are stored as 8-entry rows (kSubPixelTaps). Shorter kernels only use
+// the centre entries:
+//   - 6-tap: 1..6
+//   - 4-tap: 2..5
+//   - 2-tap: 3..4
+// |src| is therefore advanced past the unused leading taps before dispatching.
+//
+// |src| must point at the outermost tap of the 8-tap footprint.
+// |filter_id| must be non-zero.
+template <bool is_compound = false, bool is_2d = false>
+LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
+    const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
+    const int width, const int height, const int filter_id,
+    const int filter_index) {
+  // Broadcast each tap, with its sign, into its own int16x4_t. The 10-bit
+  // kernels accumulate with signed 16x16->32 multiplies (FilterHorizontalWidth2
+  // uses vmull_s16/vmlal_s16 directly; SumOnePassTaps presumably does the
+  // same), so negative taps need no separate subtract path.
+  int16x4_t v_tap[kSubPixelTaps];
+  assert(filter_id != 0);
+
+  for (int k = 0; k < kSubPixelTaps; ++k) {
+    v_tap[k] = vdup_n_s16(kHalfSubPixelFilters[filter_index][filter_id][k]);
+  }
+
+  if (filter_index == 2) {  // 8 tap.
+    FilterHorizontal<2, is_compound, is_2d>(src, src_stride, dst, dst_stride,
+                                            width, height, v_tap);
+  } else if (filter_index == 1) {  // 6 tap: skip the one unused leading tap.
+    FilterHorizontal<1, is_compound, is_2d>(src + 1, src_stride, dst,
+                                            dst_stride, width, height, v_tap);
+  } else if (filter_index == 0) {  // 6 tap: skip the one unused leading tap.
+    FilterHorizontal<0, is_compound, is_2d>(src + 1, src_stride, dst,
+                                            dst_stride, width, height, v_tap);
+  } else if (filter_index == 4) {  // 4 tap: skip two unused leading taps.
+    FilterHorizontal<4, is_compound, is_2d>(src + 2, src_stride, dst,
+                                            dst_stride, width, height, v_tap);
+  } else if (filter_index == 5) {  // 4 tap: skip two unused leading taps.
+    FilterHorizontal<5, is_compound, is_2d>(src + 2, src_stride, dst,
+                                            dst_stride, width, height, v_tap);
+  } else {  // 2 tap: skip three unused leading taps.
+    FilterHorizontal<3, is_compound, is_2d>(src + 3, src_stride, dst,
+                                            dst_stride, width, height, v_tap);
+  }
+}
+
+// Horizontal-only, non-compound 10-bit convolution.
+void ConvolveHorizontal_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int /*vertical_filter_index*/, const int horizontal_filter_id,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+  // Both strides are halved to count uint16_t pixels (presumably they arrive
+  // in bytes).
+  const ptrdiff_t src_pixel_stride = reference_stride >> 1;
+  const ptrdiff_t dst_pixel_stride = pred_stride >> 1;
+  // Step back to the outermost tap of the filter footprint.
+  const uint16_t* const first_tap =
+      static_cast<const uint16_t*>(reference) - kHorizontalOffset;
+
+  DoHorizontalPass(first_tap, src_pixel_stride,
+                   static_cast<uint16_t*>(prediction), dst_pixel_stride, width,
+                   height, horizontal_filter_id,
+                   GetFilterIndex(horizontal_filter_index, width));
+}
+
+// Horizontal-only compound 10-bit convolution. The compound output is packed,
+// so the destination row stride is |width| and |pred_stride| is ignored.
+void ConvolveCompoundHorizontal_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int /*vertical_filter_index*/, const int horizontal_filter_id,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
+  const ptrdiff_t src_pixel_stride = reference_stride >> 1;
+  const uint16_t* const first_tap =
+      static_cast<const uint16_t*>(reference) - kHorizontalOffset;
+
+  DoHorizontalPass</*is_compound=*/true>(
+      first_tap, src_pixel_stride, static_cast<uint16_t*>(prediction), width,
+      width, height, horizontal_filter_id,
+      GetFilterIndex(horizontal_filter_index, width));
+}
+
+// Vertical-only (1D) filter for blocks at least 8 wide. The block is walked
+// one 8-column strip at a time. Within a strip, |srcs| holds the |num_taps|
+// rows currently under the filter; each iteration loads one new row, emits one
+// output row, and shifts the rows down by one.
+//
+// Output modes:
+//   - compound: rounded by kInterRoundBitsHorizontal - 1, offset by
+//     kCompoundOffset, and stored as raw int16 bits;
+//   - pixel: rounded by kFilterBits - 1, saturated at 0, and clamped to
+//     (1 << kBitdepth10) - 1.
+//
+// NOTE(review): the "- 1" in each shift presumably compensates for the halved
+// coefficients in kHalfSubPixelFilters - confirm.
+template <int filter_index, bool is_compound = false>
+void FilterVertical(const uint16_t* LIBGAV1_RESTRICT const src,
+                    const ptrdiff_t src_stride,
+                    void* LIBGAV1_RESTRICT const dst,
+                    const ptrdiff_t dst_stride, const int width,
+                    const int height, const int16x4_t* const taps) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  // Index in |srcs| where each newly loaded row is placed.
+  const int next_row = num_taps - 1;
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  auto* const dst16 = static_cast<uint16_t*>(dst);
+  assert(width >= 8);
+
+  int x = 0;
+  do {
+    const uint16_t* src_x = src + x;
+    // Prime the window with the first num_taps - 1 rows of this strip.
+    uint16x8_t srcs[8];
+    srcs[0] = vld1q_u16(src_x);
+    src_x += src_stride;
+    if (num_taps >= 4) {
+      srcs[1] = vld1q_u16(src_x);
+      src_x += src_stride;
+      srcs[2] = vld1q_u16(src_x);
+      src_x += src_stride;
+      if (num_taps >= 6) {
+        srcs[3] = vld1q_u16(src_x);
+        src_x += src_stride;
+        srcs[4] = vld1q_u16(src_x);
+        src_x += src_stride;
+        if (num_taps == 8) {
+          srcs[5] = vld1q_u16(src_x);
+          src_x += src_stride;
+          srcs[6] = vld1q_u16(src_x);
+          src_x += src_stride;
+        }
+      }
+    }
+
+    // Decreasing the y loop counter produces worse code with clang.
+    // Don't unroll this loop since it generates too much code and the decoder
+    // is even slower.
+    int y = 0;
+    do {
+      srcs[next_row] = vld1q_u16(src_x);
+      src_x += src_stride;
+
+      const int32x4x2_t v_sum = SumOnePassTaps<filter_index>(srcs, taps);
+      if (is_compound) {
+        const int16x4_t v_compound_offset = vdup_n_s16(kCompoundOffset);
+        const int16x4_t d0 =
+            vqrshrn_n_s32(v_sum.val[0], kInterRoundBitsHorizontal - 1);
+        const int16x4_t d1 =
+            vqrshrn_n_s32(v_sum.val[1], kInterRoundBitsHorizontal - 1);
+        vst1_u16(dst16 + x + y * dst_stride,
+                 vreinterpret_u16_s16(vadd_s16(d0, v_compound_offset)));
+        vst1_u16(dst16 + x + 4 + y * dst_stride,
+                 vreinterpret_u16_s16(vadd_s16(d1, v_compound_offset)));
+      } else {
+        const uint16x4_t d0 = vmin_u16(
+            vqrshrun_n_s32(v_sum.val[0], kFilterBits - 1), v_max_bitdepth);
+        const uint16x4_t d1 = vmin_u16(
+            vqrshrun_n_s32(v_sum.val[1], kFilterBits - 1), v_max_bitdepth);
+        vst1_u16(dst16 + x + y * dst_stride, d0);
+        vst1_u16(dst16 + x + 4 + y * dst_stride, d1);
+      }
+
+      // Slide the window down by one row.
+      srcs[0] = srcs[1];
+      if (num_taps >= 4) {
+        srcs[1] = srcs[2];
+        srcs[2] = srcs[3];
+        if (num_taps >= 6) {
+          srcs[3] = srcs[4];
+          srcs[4] = srcs[5];
+          if (num_taps == 8) {
+            srcs[5] = srcs[6];
+            srcs[6] = srcs[7];
+          }
+        }
+      }
+    } while (++y < height);
+    x += 8;
+  } while (x < width);
+}
+
+// Vertical-only (1D) filter for 4-wide blocks. srcs[k] holds one 4-pixel row.
+// Each iteration loads two new rows and produces two output rows, so |height|
+// must be even. |srcs| has room for num_taps + 1 rows: the window for the
+// second output row is |srcs| + 1.
+//
+// Rounding, clamping and compound offset are the same as in FilterVertical.
+template <int filter_index, bool is_compound = false>
+void FilterVertical4xH(const uint16_t* LIBGAV1_RESTRICT src,
+                       const ptrdiff_t src_stride,
+                       void* LIBGAV1_RESTRICT const dst,
+                       const ptrdiff_t dst_stride, const int height,
+                       const int16x4_t* const taps) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  const int next_row = num_taps - 1;
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+
+  // Prime the window with the first num_taps - 1 rows.
+  uint16x4_t srcs[9];
+  srcs[0] = vld1_u16(src);
+  src += src_stride;
+  if (num_taps >= 4) {
+    srcs[1] = vld1_u16(src);
+    src += src_stride;
+    srcs[2] = vld1_u16(src);
+    src += src_stride;
+    if (num_taps >= 6) {
+      srcs[3] = vld1_u16(src);
+      src += src_stride;
+      srcs[4] = vld1_u16(src);
+      src += src_stride;
+      if (num_taps == 8) {
+        srcs[5] = vld1_u16(src);
+        src += src_stride;
+        srcs[6] = vld1_u16(src);
+        src += src_stride;
+      }
+    }
+  }
+
+  int y = height;
+  do {
+    srcs[next_row] = vld1_u16(src);
+    src += src_stride;
+    srcs[num_taps] = vld1_u16(src);
+    src += src_stride;
+
+    // Two overlapping windows: rows [0, num_taps) and [1, num_taps].
+    const int32x4_t v_sum = SumOnePassTaps<filter_index>(srcs, taps);
+    const int32x4_t v_sum_1 = SumOnePassTaps<filter_index>(srcs + 1, taps);
+    if (is_compound) {
+      const int16x4_t d0 = vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1);
+      const int16x4_t d1 =
+          vqrshrn_n_s32(v_sum_1, kInterRoundBitsHorizontal - 1);
+      vst1_u16(dst16,
+               vreinterpret_u16_s16(vadd_s16(d0, vdup_n_s16(kCompoundOffset))));
+      dst16 += dst_stride;
+      vst1_u16(dst16,
+               vreinterpret_u16_s16(vadd_s16(d1, vdup_n_s16(kCompoundOffset))));
+      dst16 += dst_stride;
+    } else {
+      const uint16x4_t d0 =
+          vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
+      const uint16x4_t d1 =
+          vmin_u16(vqrshrun_n_s32(v_sum_1, kFilterBits - 1), v_max_bitdepth);
+      vst1_u16(dst16, d0);
+      dst16 += dst_stride;
+      vst1_u16(dst16, d1);
+      dst16 += dst_stride;
+    }
+
+    // Slide the window down by two rows.
+    srcs[0] = srcs[2];
+    if (num_taps >= 4) {
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      if (num_taps >= 6) {
+        srcs[3] = srcs[5];
+        srcs[4] = srcs[6];
+        if (num_taps == 8) {
+          srcs[5] = srcs[7];
+          srcs[6] = srcs[8];
+        }
+      }
+    }
+    y -= 2;
+  } while (y != 0);
+}
+
+// Vertical-only (1D), non-compound filter for 2-wide blocks. Each
+// uint16x4_t holds two consecutive 2-pixel rows: lanes 0-1 hold row k and
+// lanes 2-3 hold row k+1. One SumOnePassTaps call therefore yields two output
+// rows, and |height| must be even.
+//
+// Even-indexed windows are filled by loading two rows. Odd-indexed windows are
+// built from their neighbours with vext_u16(..., 2).
+//
+// NOTE(review): Load2<n>/Store2<n> presumably move the n-th 2-pixel (32-bit)
+// group of the vector - confirm against their definitions.
+template <int filter_index>
+void FilterVertical2xH(const uint16_t* LIBGAV1_RESTRICT src,
+                       const ptrdiff_t src_stride,
+                       void* LIBGAV1_RESTRICT const dst,
+                       const ptrdiff_t dst_stride, const int height,
+                       const int16x4_t* const taps) {
+  const int num_taps = GetNumTapsInFilter(filter_index);
+  const int next_row = num_taps - 1;
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+  const uint16x4_t v_zero = vdup_n_u16(0);
+
+  // Prime the window. The last loaded window has only its low half filled;
+  // the loop completes it.
+  uint16x4_t srcs[9];
+  srcs[0] = Load2<0>(src, v_zero);
+  src += src_stride;
+  if (num_taps >= 4) {
+    srcs[0] = Load2<1>(src, srcs[0]);
+    src += src_stride;
+    srcs[2] = Load2<0>(src, v_zero);
+    src += src_stride;
+    srcs[1] = vext_u16(srcs[0], srcs[2], 2);
+    if (num_taps >= 6) {
+      srcs[2] = Load2<1>(src, srcs[2]);
+      src += src_stride;
+      srcs[4] = Load2<0>(src, v_zero);
+      src += src_stride;
+      srcs[3] = vext_u16(srcs[2], srcs[4], 2);
+      if (num_taps == 8) {
+        srcs[4] = Load2<1>(src, srcs[4]);
+        src += src_stride;
+        srcs[6] = Load2<0>(src, v_zero);
+        src += src_stride;
+        srcs[5] = vext_u16(srcs[4], srcs[6], 2);
+      }
+    }
+  }
+
+  int y = height;
+  do {
+    // Complete the pending window, start the next one, and derive the
+    // window between them.
+    srcs[next_row - 1] = Load2<1>(src, srcs[next_row - 1]);
+    src += src_stride;
+    srcs[num_taps] = Load2<0>(src, v_zero);
+    src += src_stride;
+    srcs[next_row] = vext_u16(srcs[next_row - 1], srcs[num_taps], 2);
+
+    const int32x4_t v_sum = SumOnePassTaps<filter_index>(srcs, taps);
+    const uint16x4_t d0 =
+        vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
+    Store2<0>(dst16, d0);
+    dst16 += dst_stride;
+    Store2<1>(dst16, d0);
+    dst16 += dst_stride;
+
+    // Slide the window down by two rows.
+    srcs[0] = srcs[2];
+    if (num_taps >= 4) {
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      if (num_taps >= 6) {
+        srcs[3] = srcs[5];
+        srcs[4] = srcs[6];
+        if (num_taps == 8) {
+          srcs[5] = srcs[7];
+          srcs[6] = srcs[8];
+        }
+      }
+    }
+    y -= 2;
+  } while (y != 0);
+}
+
+// Applies the vertical taps of a 2D filter to eight int16 intermediate values.
+// |src[i]| is the i-th row of the filter window. |taps| holds all eight
+// coefficients; shorter filters use the centre lanes:
+//   - 6-tap: lanes 1..6
+//   - 4-tap: lanes 2..5
+//   - 2-tap: lanes 3..4
+//
+// Return value:
+//   - compound: rounded by kInterRoundBitsCompoundVertical - 1 without
+//     saturation;
+//   - pixel: rounded by kInterRoundBitsVertical - 1 and saturated to
+//     [0, 65535], returned reinterpreted as int16x8_t. Callers apply the
+//     upper (bitdepth) clamp.
+//
+// |num_taps| must be 2, 4, 6 or 8; any other value leaves the accumulators
+// uninitialized.
+template <int num_taps, bool is_compound>
+int16x8_t SimpleSum2DVerticalTaps(const int16x8_t* const src,
+                                  const int16x8_t taps) {
+  const int16x4_t taps_lo = vget_low_s16(taps);
+  const int16x4_t taps_hi = vget_high_s16(taps);
+  int32x4_t sum_lo, sum_hi;
+  if (num_taps == 8) {
+    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 0);
+    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 0);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 1);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 1);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 2);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 2);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_lo, 3);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_lo, 3);
+
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 0);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 0);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 1);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 1);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[6]), taps_hi, 2);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[6]), taps_hi, 2);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[7]), taps_hi, 3);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[7]), taps_hi, 3);
+  } else if (num_taps == 6) {
+    // Coefficients 1..6.
+    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 1);
+    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 1);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 2);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 2);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 3);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 3);
+
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 0);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 0);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 1);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 1);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 2);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 2);
+  } else if (num_taps == 4) {
+    // Coefficients 2..5.
+    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 2);
+    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 2);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 3);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 3);
+
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_hi, 0);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_hi, 0);
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 1);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 1);
+  } else if (num_taps == 2) {
+    // Coefficients 3..4.
+    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 3);
+    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 3);
+
+    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_hi, 0);
+    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_hi, 0);
+  }
+
+  if (is_compound) {
+    // Output is compound, so leave signed and do not saturate. Offset will
+    // accurately bring the value back into positive range.
+    return vcombine_s16(
+        vrshrn_n_s32(sum_lo, kInterRoundBitsCompoundVertical - 1),
+        vrshrn_n_s32(sum_hi, kInterRoundBitsCompoundVertical - 1));
+  }
+
+  // Output is pixel, so saturate to clip at 0.
+  return vreinterpretq_s16_u16(
+      vcombine_u16(vqrshrun_n_s32(sum_lo, kInterRoundBitsVertical - 1),
+                   vqrshrun_n_s32(sum_hi, kInterRoundBitsVertical - 1)));
+}
+
+// Vertical pass of the 2D filter for blocks at least 8 wide.
+//
+// |src| is the intermediate buffer, laid out as consecutive 8-column strips.
+// Each strip holds height + num_taps - 1 rows of 8 values, and |src| keeps
+// advancing from one strip straight into the next. This layout is inferred
+// from the `src += 8` steps, which are never rewound.
+//
+// Two output rows are produced per iteration, so |height| must be even.
+// Compound results get kCompoundOffset added. Pixel results are clamped to
+// (1 << kBitdepth10) - 1.
+template <int num_taps, bool is_compound = false>
+void Filter2DVerticalWidth8AndUp(const int16_t* LIBGAV1_RESTRICT src,
+                                 void* LIBGAV1_RESTRICT const dst,
+                                 const ptrdiff_t dst_stride, const int width,
+                                 const int height, const int16x8_t taps) {
+  assert(width >= 8);
+  constexpr int next_row = num_taps - 1;
+  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
+  auto* const dst16 = static_cast<uint16_t*>(dst);
+
+  int x = 0;
+  do {
+    // Prime the window with the first num_taps - 1 rows of this strip.
+    int16x8_t srcs[9];
+    srcs[0] = vld1q_s16(src);
+    src += 8;
+    if (num_taps >= 4) {
+      srcs[1] = vld1q_s16(src);
+      src += 8;
+      srcs[2] = vld1q_s16(src);
+      src += 8;
+      if (num_taps >= 6) {
+        srcs[3] = vld1q_s16(src);
+        src += 8;
+        srcs[4] = vld1q_s16(src);
+        src += 8;
+        if (num_taps == 8) {
+          srcs[5] = vld1q_s16(src);
+          src += 8;
+          srcs[6] = vld1q_s16(src);
+          src += 8;
+        }
+      }
+    }
+
+    uint16_t* d16 = dst16 + x;
+    int y = height;
+    do {
+      srcs[next_row] = vld1q_s16(src);
+      src += 8;
+      srcs[next_row + 1] = vld1q_s16(src);
+      src += 8;
+      // Two overlapping windows: rows [0, num_taps) and [1, num_taps].
+      const int16x8_t sum0 =
+          SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs + 0, taps);
+      const int16x8_t sum1 =
+          SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs + 1, taps);
+      if (is_compound) {
+        const int16x8_t v_compound_offset = vdupq_n_s16(kCompoundOffset);
+        vst1q_u16(d16,
+                  vreinterpretq_u16_s16(vaddq_s16(sum0, v_compound_offset)));
+        d16 += dst_stride;
+        vst1q_u16(d16,
+                  vreinterpretq_u16_s16(vaddq_s16(sum1, v_compound_offset)));
+        d16 += dst_stride;
+      } else {
+        vst1q_u16(d16, vminq_u16(vreinterpretq_u16_s16(sum0), v_max_bitdepth));
+        d16 += dst_stride;
+        vst1q_u16(d16, vminq_u16(vreinterpretq_u16_s16(sum1), v_max_bitdepth));
+        d16 += dst_stride;
+      }
+      // Slide the window down by two rows.
+      srcs[0] = srcs[2];
+      if (num_taps >= 4) {
+        srcs[1] = srcs[3];
+        srcs[2] = srcs[4];
+        if (num_taps >= 6) {
+          srcs[3] = srcs[5];
+          srcs[4] = srcs[6];
+          if (num_taps == 8) {
+            srcs[5] = srcs[7];
+            srcs[6] = srcs[8];
+          }
+        }
+      }
+      y -= 2;
+    } while (y != 0);
+    x += 8;
+  } while (x < width);
+}
+
+// Vertical pass of the 2D filter for 4-wide blocks. The intermediate buffer
+// has |src_stride| == |width| == 4, which lets one int16x8_t hold two
+// consecutive rows. srcs[k] holds rows k and k+1, so each iteration produces
+// two output rows and |height| must be even.
+//
+// Compound results get kCompoundOffset added and are written with a single
+// 8-lane store. That requires a packed destination (|dst_stride| == 4).
+// Pixel results are clamped to (1 << kBitdepth10) - 1.
+template <int num_taps, bool is_compound = false>
+void Filter2DVerticalWidth4(const int16_t* LIBGAV1_RESTRICT src,
+                            void* LIBGAV1_RESTRICT const dst,
+                            const ptrdiff_t dst_stride, const int height,
+                            const int16x8_t taps) {
+  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+
+  // Prime the window. Even-indexed entries are loaded directly; odd-indexed
+  // entries straddle two loads (high half of one, low half of the next).
+  int16x8_t srcs[9];
+  srcs[0] = vld1q_s16(src);
+  src += 8;
+  if (num_taps >= 4) {
+    srcs[2] = vld1q_s16(src);
+    src += 8;
+    srcs[1] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[2]));
+    if (num_taps >= 6) {
+      srcs[4] = vld1q_s16(src);
+      src += 8;
+      srcs[3] = vcombine_s16(vget_high_s16(srcs[2]), vget_low_s16(srcs[4]));
+      if (num_taps == 8) {
+        srcs[6] = vld1q_s16(src);
+        src += 8;
+        srcs[5] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[6]));
+      }
+    }
+  }
+
+  int y = height;
+  do {
+    srcs[num_taps] = vld1q_s16(src);
+    src += 8;
+    srcs[num_taps - 1] = vcombine_s16(vget_high_s16(srcs[num_taps - 2]),
+                                      vget_low_s16(srcs[num_taps]));
+
+    const int16x8_t sum =
+        SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
+    if (is_compound) {
+      // One 8-lane store covers both output rows, which is only correct for
+      // a packed destination.
+      assert(dst_stride == 4);
+      const int16x8_t v_compound_offset = vdupq_n_s16(kCompoundOffset);
+      vst1q_u16(dst16,
+                vreinterpretq_u16_s16(vaddq_s16(sum, v_compound_offset)));
+      dst16 += dst_stride << 1;
+    } else {
+      const uint16x8_t d0 =
+          vminq_u16(vreinterpretq_u16_s16(sum), v_max_bitdepth);
+      vst1_u16(dst16, vget_low_u16(d0));
+      dst16 += dst_stride;
+      vst1_u16(dst16, vget_high_u16(d0));
+      dst16 += dst_stride;
+    }
+
+    // Slide the window down by two rows.
+    srcs[0] = srcs[2];
+    if (num_taps >= 4) {
+      srcs[1] = srcs[3];
+      srcs[2] = srcs[4];
+      if (num_taps >= 6) {
+        srcs[3] = srcs[5];
+        srcs[4] = srcs[6];
+        if (num_taps == 8) {
+          srcs[5] = srcs[7];
+          srcs[6] = srcs[8];
+        }
+      }
+    }
+    y -= 2;
+  } while (y != 0);
+}
+
+// Vertical pass of the 2D filter for 2-wide, non-compound blocks.
+//
+// Take advantage of |src_stride| == |width| to process four rows at a time.
+// One int16x8_t holds four consecutive 2-value rows, and srcs[k] holds rows
+// k..k+3. Each iteration produces four output rows. The exception is
+// |height| == 2 with <= 4 taps, which returns after writing two rows.
+//
+// NOTE(review): with |height| == 2 the srcs[next_row] load reads rows beyond
+// the height + num_taps - 1 rows the horizontal pass produced. Presumably
+// that is harmless because the intermediate buffer is much larger (see the
+// msan memset in Convolve2D_NEON) - confirm.
+template <int num_taps>
+void Filter2DVerticalWidth2(const int16_t* LIBGAV1_RESTRICT src,
+                            void* LIBGAV1_RESTRICT const dst,
+                            const ptrdiff_t dst_stride, const int height,
+                            const int16x8_t taps) {
+  // Index of the next 4-row load: srcs[4] for <= 4 taps, srcs[8] otherwise.
+  constexpr int next_row = (num_taps < 6) ? 4 : 8;
+  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
+  auto* dst16 = static_cast<uint16_t*>(dst);
+
+  int16x8_t srcs[9];
+  srcs[0] = vld1q_s16(src);
+  src += 8;
+  if (num_taps >= 6) {
+    srcs[4] = vld1q_s16(src);
+    src += 8;
+    srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
+    if (num_taps == 8) {
+      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
+      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
+    }
+  }
+
+  int y = height;
+  do {
+    srcs[next_row] = vld1q_s16(src);
+    src += 8;
+    // Build the in-between windows by shifting one row (2 lanes) at a time
+    // across adjacent 4-row loads.
+    if (num_taps == 2) {
+      srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
+    } else if (num_taps == 4) {
+      srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
+      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
+      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
+    } else if (num_taps == 6) {
+      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
+      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
+      srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
+    } else if (num_taps == 8) {
+      srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
+      srcs[6] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[8]));
+      srcs[7] = vextq_s16(srcs[4], srcs[8], 6);
+    }
+    const int16x8_t sum =
+        SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps);
+    const uint16x8_t d0 = vminq_u16(vreinterpretq_u16_s16(sum), v_max_bitdepth);
+    // Store2<n> presumably writes the n-th 2-value row of |d0|.
+    Store2<0>(dst16, d0);
+    dst16 += dst_stride;
+    Store2<1>(dst16, d0);
+    // When |height| <= 4 the taps are restricted to 2 and 4 tap variants.
+    // Therefore we don't need to check this condition when |height| > 4.
+    if (num_taps <= 4 && height == 2) return;
+    dst16 += dst_stride;
+    Store2<2>(dst16, d0);
+    dst16 += dst_stride;
+    Store2<3>(dst16, d0);
+    dst16 += dst_stride;
+
+    // Slide the window down by four rows.
+    srcs[0] = srcs[4];
+    if (num_taps == 6) {
+      srcs[1] = srcs[5];
+      srcs[4] = srcs[8];
+    } else if (num_taps == 8) {
+      srcs[1] = srcs[5];
+      srcs[2] = srcs[6];
+      srcs[3] = srcs[7];
+      srcs[4] = srcs[8];
+    }
+
+    y -= 4;
+  } while (y != 0);
+}
+
+// Routes the non-compound vertical pass of the 2D filter to the kernel for
+// this block width. Widths below 8 other than 4 are expected to be 2.
+template <int vertical_taps>
+void Filter2DVertical(const int16_t* LIBGAV1_RESTRICT const intermediate_result,
+                      const int width, const int height, const int16x8_t taps,
+                      void* LIBGAV1_RESTRICT const prediction,
+                      const ptrdiff_t pred_stride) {
+  uint16_t* const out = static_cast<uint16_t*>(prediction);
+  if (width < 8) {
+    if (width == 4) {
+      Filter2DVerticalWidth4<vertical_taps>(intermediate_result, out,
+                                            pred_stride, height, taps);
+      return;
+    }
+    assert(width == 2);
+    Filter2DVerticalWidth2<vertical_taps>(intermediate_result, out,
+                                          pred_stride, height, taps);
+    return;
+  }
+  Filter2DVerticalWidth8AndUp<vertical_taps>(intermediate_result, out,
+                                             pred_stride, width, height, taps);
+}
+
+// Non-compound 2D (horizontal, then vertical) 10-bit convolution.
+//
+// The horizontal pass writes |intermediate_height| rows of int16 values into
+// a stack buffer. The vertical pass then filters that buffer into
+// |prediction|.
+void Convolve2D_NEON(const void* LIBGAV1_RESTRICT const reference,
+                     const ptrdiff_t reference_stride,
+                     const int horizontal_filter_index,
+                     const int vertical_filter_index,
+                     const int horizontal_filter_id,
+                     const int vertical_filter_id, const int width,
+                     const int height, void* LIBGAV1_RESTRICT const prediction,
+                     const ptrdiff_t pred_stride) {
+  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
+
+  // The output of the horizontal filter is guaranteed to fit in 16 bits.
+  int16_t intermediate_result[kMaxSuperBlockSizeInPixels *
+                              (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
+#if LIBGAV1_MSAN
+  // Quiet msan warnings. Set with random non-zero value to aid in debugging.
+  memset(intermediate_result, 0x43, sizeof(intermediate_result));
+#endif
+
+  // Strides are halved to count uint16_t pixels.
+  const ptrdiff_t src_stride = reference_stride >> 1;
+  const ptrdiff_t dest_stride = pred_stride >> 1;
+  // The vertical pass needs vertical_taps - 1 extra context rows, beginning
+  // (vertical_taps / 2 - 1) rows above the block.
+  const int rows_above = vertical_taps / 2 - 1;
+  const int intermediate_height = height + vertical_taps - 1;
+  const auto* const src = static_cast<const uint16_t*>(reference) -
+                          rows_above * src_stride - kHorizontalOffset;
+
+  DoHorizontalPass</*is_compound=*/false, /*is_2d=*/true>(
+      src, src_stride, intermediate_result, width, width, intermediate_height,
+      horizontal_filter_id, horiz_filter_index);
+
+  assert(vertical_filter_id != 0);
+  const int16x8_t taps = vmovl_s8(
+      vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]));
+  switch (vertical_taps) {
+    case 8:
+      Filter2DVertical<8>(intermediate_result, width, height, taps, prediction,
+                          dest_stride);
+      break;
+    case 6:
+      Filter2DVertical<6>(intermediate_result, width, height, taps, prediction,
+                          dest_stride);
+      break;
+    case 4:
+      Filter2DVertical<4>(intermediate_result, width, height, taps, prediction,
+                          dest_stride);
+      break;
+    default:  // |vertical_taps| == 2
+      Filter2DVertical<2>(intermediate_result, width, height, taps, prediction,
+                          dest_stride);
+      break;
+  }
+}
+
+// Vertical pass of the compound 2D convolution. The destination is written
+// without a separate stride argument: |width| is passed where the helpers take
+// their output layout (presumably a packed |width|-pitch buffer; confirm
+// against the Filter2DVertical* signatures).
+template <int vertical_taps>
+void Compound2DVertical(
+    const int16_t* LIBGAV1_RESTRICT const intermediate_result, const int width,
+    const int height, const int16x8_t taps,
+    void* LIBGAV1_RESTRICT const prediction) {
+  auto* const out = static_cast<uint16_t*>(prediction);
+  if (width != 4) {
+    Filter2DVerticalWidth8AndUp<vertical_taps, /*is_compound=*/true>(
+        intermediate_result, out, width, width, height, taps);
+    return;
+  }
+  Filter2DVerticalWidth4<vertical_taps, /*is_compound=*/true>(
+      intermediate_result, out, width, height, taps);
+}
+
+void ConvolveCompound2D_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int vertical_filter_index, const int horizontal_filter_id,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
+  // The output of the horizontal filter, i.e. the intermediate_result, is
+  // guaranteed to fit in int16_t.
+  int16_t
+      intermediate_result[(kMaxSuperBlockSizeInPixels *
+                           (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1))];
+
+  // Horizontal filter.
+  // Filter types used for width <= 4 are different from those for width > 4.
+  // When width > 4, the valid filter index range is always [0, 3].
+  // When width <= 4, the valid filter index range is always [4, 5].
+  // Similarly for height.
+  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
+  const int intermediate_height = height + vertical_taps - 1;
+  const ptrdiff_t src_stride = reference_stride >> 1;
+  const auto* const src = static_cast<const uint16_t*>(reference) -
+                          (vertical_taps / 2 - 1) * src_stride -
+                          kHorizontalOffset;
+
+  DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>(
+      src, src_stride, intermediate_result, width, width, intermediate_height,
+      horizontal_filter_id, horiz_filter_index);
+
+  // Vertical filter.
+  assert(vertical_filter_id != 0);
+  const int16x8_t taps = vmovl_s8(
+      vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]));
+  if (vertical_taps == 8) {
+    Compound2DVertical<8>(intermediate_result, width, height, taps, prediction);
+  } else if (vertical_taps == 6) {
+    Compound2DVertical<6>(intermediate_result, width, height, taps, prediction);
+  } else if (vertical_taps == 4) {
+    Compound2DVertical<4>(intermediate_result, width, height, taps, prediction);
+  } else {  // |vertical_taps| == 2
+    Compound2DVertical<2>(intermediate_result, width, height, taps, prediction);
+  }
+}
+
+// Vertical-only sub-pixel convolution for 10-bit pixels. Strides are given in
+// bytes. The kernel is dispatched on |filter_index| (and, for index 1, on
+// |vertical_filter_id|) to the 6-, 8-, 2- or 4-tap instantiations noted on
+// each branch; |taps| is offset by (8 - n) / 2 so an n-tap path receives the
+// centered coefficients of the 8-entry kernel.
+void ConvolveVertical_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+  const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(filter_index);
+  const ptrdiff_t src_stride = reference_stride >> 1;
+  // Start (vertical_taps / 2 - 1) rows above the block so the filter window
+  // is centered on each output row.
+  const auto* src = static_cast<const uint16_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride;
+  auto* const dest = static_cast<uint16_t*>(prediction);
+  const ptrdiff_t dest_stride = pred_stride >> 1;
+  assert(vertical_filter_id != 0);
+
+  // Broadcast each kernel coefficient into its own vector.
+  int16x4_t taps[8];
+  for (int k = 0; k < kSubPixelTaps; ++k) {
+    taps[k] =
+        vdup_n_s16(kHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
+  }
+
+  if (filter_index == 0) {  // 6 tap.
+    if (width == 2) {
+      FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else if (width == 4) {
+      FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else {
+      FilterVertical<0>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 1);
+    }
+  } else if ((static_cast<int>(filter_index == 1) &
+              (static_cast<int>(vertical_filter_id == 1) |
+               static_cast<int>(vertical_filter_id == 7) |
+               static_cast<int>(vertical_filter_id == 8) |
+               static_cast<int>(vertical_filter_id == 9) |
+               static_cast<int>(vertical_filter_id == 15))) != 0) {  // 6 tap.
+    if (width == 2) {
+      FilterVertical2xH<1>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else if (width == 4) {
+      FilterVertical4xH<1>(src, src_stride, dest, dest_stride, height,
+                           taps + 1);
+    } else {
+      FilterVertical<1>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 1);
+    }
+  } else if (filter_index == 2) {  // 8 tap.
+    if (width == 2) {
+      FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
+    } else if (width == 4) {
+      FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
+    } else {
+      FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
+                        taps);
+    }
+  } else if (filter_index == 3) {  // 2 tap.
+    if (width == 2) {
+      FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height,
+                           taps + 3);
+    } else if (width == 4) {
+      FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height,
+                           taps + 3);
+    } else {
+      FilterVertical<3>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 3);
+    }
+  } else {
+    // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed
+    // below map to 4 tap filters. (|vertical_filter_id| == 0 is already ruled
+    // out by the assert at the top of this function.)
+    assert(filter_index == 5 || filter_index == 4 ||
+           (filter_index == 1 &&
+            (vertical_filter_id == 2 || vertical_filter_id == 3 ||
+             vertical_filter_id == 4 || vertical_filter_id == 5 ||
+             vertical_filter_id == 6 || vertical_filter_id == 10 ||
+             vertical_filter_id == 11 || vertical_filter_id == 12 ||
+             vertical_filter_id == 13 || vertical_filter_id == 14)));
+    // According to GetNumTapsInFilter() this has 6 taps but here we are
+    // treating it as though it has 4.
+    if (filter_index == 1) src += src_stride;
+    if (width == 2) {
+      FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height,
+                           taps + 2);
+    } else if (width == 4) {
+      FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height,
+                           taps + 2);
+    } else {
+      FilterVertical<5>(src, src_stride, dest, dest_stride, width, height,
+                        taps + 2);
+    }
+  }
+}
+
+// Compound vertical-only sub-pixel convolution for 10-bit pixels.
+// |reference_stride| is in bytes; |pred_stride| is unused and |width| (4 for
+// the 4-wide path) is passed to the helpers in its place. The kernel is
+// dispatched like ConvolveVertical_NEON, except that there is no 2-wide path.
+void ConvolveCompoundVertical_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
+  const int filter_index = GetFilterIndex(vertical_filter_index, height);
+  const int vertical_taps = GetNumTapsInFilter(filter_index);
+  const ptrdiff_t src_stride = reference_stride >> 1;
+  // Start (vertical_taps / 2 - 1) rows above the block so the filter window
+  // is centered on each output row.
+  const auto* src = static_cast<const uint16_t*>(reference) -
+                    (vertical_taps / 2 - 1) * src_stride;
+  auto* const dest = static_cast<uint16_t*>(prediction);
+  assert(vertical_filter_id != 0);
+
+  // Broadcast each kernel coefficient into its own vector. The n-tap paths
+  // below receive |taps| + (8 - n) / 2, i.e. the centered coefficients.
+  int16x4_t taps[8];
+  for (int k = 0; k < kSubPixelTaps; ++k) {
+    taps[k] =
+        vdup_n_s16(kHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
+  }
+
+  if (filter_index == 0) {  // 6 tap.
+    if (width == 4) {
+      FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 1);
+    } else {
+      FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 1);
+    }
+  } else if ((static_cast<int>(filter_index == 1) &
+              (static_cast<int>(vertical_filter_id == 1) |
+               static_cast<int>(vertical_filter_id == 7) |
+               static_cast<int>(vertical_filter_id == 8) |
+               static_cast<int>(vertical_filter_id == 9) |
+               static_cast<int>(vertical_filter_id == 15))) != 0) {  // 6 tap.
+    if (width == 4) {
+      FilterVertical4xH<1, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 1);
+    } else {
+      FilterVertical<1, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 1);
+    }
+  } else if (filter_index == 2) {  // 8 tap.
+    if (width == 4) {
+      FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps);
+    } else {
+      FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps);
+    }
+  } else if (filter_index == 3) {  // 2 tap.
+    if (width == 4) {
+      FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 3);
+    } else {
+      FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 3);
+    }
+  } else {
+    // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed
+    // below map to 4 tap filters.
+    assert(filter_index == 5 || filter_index == 4 ||
+           (filter_index == 1 &&
+            (vertical_filter_id == 2 || vertical_filter_id == 3 ||
+             vertical_filter_id == 4 || vertical_filter_id == 5 ||
+             vertical_filter_id == 6 || vertical_filter_id == 10 ||
+             vertical_filter_id == 11 || vertical_filter_id == 12 ||
+             vertical_filter_id == 13 || vertical_filter_id == 14)));
+    // According to GetNumTapsInFilter() this has 6 taps but here we are
+    // treating it as though it has 4.
+    if (filter_index == 1) src += src_stride;
+    if (width == 4) {
+      FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4,
+                                                 height, taps + 2);
+    } else {
+      FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width,
+                                              width, height, taps + 2);
+    }
+  }
+}
+
+// Compound copy (no sub-pixel filtering) for 10-bit pixels. Each source pixel
+// is biased by (1 << kBitdepth10) + (1 << (kBitdepth10 - 1)), shifted left by
+// kInterRoundBitsVertical - kInterRoundBitsCompoundVertical, and stored to
+// |prediction| with a row pitch of |width|; |pred_stride| is unused.
+// |reference_stride| is in bytes. The 4- and 8-wide paths consume two rows
+// per iteration, so |height| must be even there.
+void ConvolveCompoundCopy_NEON(
+    const void* const reference, const ptrdiff_t reference_stride,
+    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
+    const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
+    const int width, const int height, void* const prediction,
+    const ptrdiff_t /*pred_stride*/) {
+  constexpr int kShift =
+      kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
+  const uint16x8_t bias =
+      vdupq_n_u16((1 << kBitdepth10) + (1 << (kBitdepth10 - 1)));
+  const ptrdiff_t stride = reference_stride >> 1;
+  const auto* in = static_cast<const uint16_t*>(reference);
+  auto* out = static_cast<uint16_t*>(prediction);
+
+  if (width >= 16) {
+    // One row per iteration, 16 pixels per inner step.
+    int rows_left = height;
+    do {
+      int x = 0;
+      do {
+        const uint16x8_t a = vaddq_u16(vld1q_u16(in + x), bias);
+        const uint16x8_t b = vaddq_u16(vld1q_u16(in + x + 8), bias);
+        vst1q_u16(out + x, vshlq_n_u16(a, kShift));
+        vst1q_u16(out + x + 8, vshlq_n_u16(b, kShift));
+        x += 16;
+      } while (x != width);
+      in += stride;
+      out += width;
+    } while (--rows_left != 0);
+    return;
+  }
+
+  if (width == 8) {
+    // Two 8-wide rows per iteration, written back to back.
+    int rows_left = height;
+    do {
+      const uint16x8_t top = vaddq_u16(vld1q_u16(in), bias);
+      const uint16x8_t bottom = vaddq_u16(vld1q_u16(in + stride), bias);
+      vst1q_u16(out, vshlq_n_u16(top, kShift));
+      vst1q_u16(out + 8, vshlq_n_u16(bottom, kShift));
+      in += 2 * stride;
+      out += 16;
+      rows_left -= 2;
+    } while (rows_left != 0);
+    return;
+  }
+
+  // width == 4: two 4-wide rows per iteration, written back to back.
+  const uint16x4_t bias_half = vget_low_u16(bias);
+  int rows_left = height;
+  do {
+    const uint16x4_t top = vadd_u16(vld1_u16(in), bias_half);
+    const uint16x4_t bottom = vadd_u16(vld1_u16(in + stride), bias_half);
+    vst1_u16(out, vshl_n_u16(top, kShift));
+    vst1_u16(out + 4, vshl_n_u16(bottom, kShift));
+    in += 2 * stride;
+    out += 8;
+    rows_left -= 2;
+  } while (rows_left != 0);
+}
+
+// Writes 8 outputs dst[i] = (src[i] + src[i + 1] + 1) >> 1, the rounded
+// average of each pixel and its right-hand neighbor. Reads src[0..8].
+inline void HalfAddHorizontal(const uint16_t* LIBGAV1_RESTRICT const src,
+                              uint16_t* LIBGAV1_RESTRICT const dst) {
+  const uint16x8_t averaged = vrhaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+  vst1q_u16(dst, averaged);
+}
+
+// Applies HalfAddHorizontal to 16 consecutive pixels, as two 8-pixel halves.
+inline void HalfAddHorizontal16(const uint16_t* LIBGAV1_RESTRICT const src,
+                                uint16_t* LIBGAV1_RESTRICT const dst) {
+  for (int half = 0; half < 16; half += 8) {
+    HalfAddHorizontal(src + half, dst + half);
+  }
+}
+
+// Horizontal half-pel intra block copy for a compile-time |width| that is a
+// multiple of 16 (instantiated with 16, 32, 64 and 128). Each output pixel is
+// the rounded average of a source pixel and its right-hand neighbor.
+// |src_stride| and |dst_stride| are in uint16_t units.
+template <int width>
+inline void IntraBlockCopyHorizontal(const uint16_t* LIBGAV1_RESTRICT src,
+                                     const ptrdiff_t src_stride,
+                                     const int height,
+                                     uint16_t* LIBGAV1_RESTRICT dst,
+                                     const ptrdiff_t dst_stride) {
+  int rows = height;
+  do {
+    // |width| is a template constant, so the trip count is fixed.
+    for (int x = 0; x < width; x += 16) {
+      HalfAddHorizontal16(src + x, dst + x);
+    }
+    src += src_stride;
+    dst += dst_stride;
+  } while (--rows != 0);
+}
+
+// Intra block copy with horizontal half-pel averaging for 10-bit pixels: each
+// output pixel is the rounded average of a source pixel and its right-hand
+// neighbor. Strides are given in bytes. The 4-wide path produces two rows per
+// iteration.
+void ConvolveIntraBlockCopyHorizontal_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*subpixel_x*/,
+    const int /*subpixel_y*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+  assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
+  assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
+  const auto* src = static_cast<const uint16_t*>(reference);
+  auto* dest = static_cast<uint16_t*>(prediction);
+  const ptrdiff_t src_stride = reference_stride >> 1;
+  const ptrdiff_t dst_stride = pred_stride >> 1;
+
+  switch (width) {
+    case 128:
+      IntraBlockCopyHorizontal<128>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 64:
+      IntraBlockCopyHorizontal<64>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 32:
+      IntraBlockCopyHorizontal<32>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 16:
+      IntraBlockCopyHorizontal<16>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 8: {
+      int rows = height;
+      do {
+        HalfAddHorizontal(src, dest);
+        src += src_stride;
+        dest += dst_stride;
+      } while (--rows != 0);
+      return;
+    }
+    default:
+      break;
+  }
+
+  // width == 4: two rows per iteration.
+  int rows = height;
+  do {
+    const uint16x4_t avg0 = vrhadd_u16(vld1_u16(src), vld1_u16(src + 1));
+    const uint16x4_t avg1 =
+        vrhadd_u16(vld1_u16(src + src_stride), vld1_u16(src + src_stride + 1));
+    src += 2 * src_stride;
+    vst1_u16(dest, avg0);
+    vst1_u16(dest + dst_stride, avg1);
+    dest += 2 * dst_stride;
+    rows -= 2;
+  } while (rows != 0);
+}
+
+// Vertical half-pel intra block copy for a compile-time |width| of 8, 16, 32
+// or 64. Each output pixel is the rounded average ((a + b + 1) >> 1) of a
+// source pixel and the pixel directly below it, so |height| + 1 source rows
+// are read. The current source row is held in |row| and the next one in
+// |below|, so each source row is loaded only once. Strides are in uint16_t
+// units.
+template <int width>
+inline void IntraBlockCopyVertical(const uint16_t* LIBGAV1_RESTRICT src,
+                                   const ptrdiff_t src_stride, const int height,
+                                   uint16_t* LIBGAV1_RESTRICT dst,
+                                   const ptrdiff_t dst_stride) {
+  // Within a row the pointers advance (width - 8) elements in steps of 8; the
+  // remainder strides complete the move to the start of the next row.
+  const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
+  const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
+  uint16x8_t row[8], below[8];
+
+  // Prime |row| with the first source row.
+  row[0] = vld1q_u16(src);
+  if (width >= 16) {
+    src += 8;
+    row[1] = vld1q_u16(src);
+    if (width >= 32) {
+      src += 8;
+      row[2] = vld1q_u16(src);
+      src += 8;
+      row[3] = vld1q_u16(src);
+      if (width == 64) {
+        src += 8;
+        row[4] = vld1q_u16(src);
+        src += 8;
+        row[5] = vld1q_u16(src);
+        src += 8;
+        row[6] = vld1q_u16(src);
+        src += 8;
+        row[7] = vld1q_u16(src);
+      }
+    }
+  }
+  src += src_remainder_stride;
+
+  // Each iteration loads the next source row into |below|, stores the rounded
+  // averages of |row| and |below|, then carries |below| over into |row|.
+  int y = height;
+  do {
+    below[0] = vld1q_u16(src);
+    if (width >= 16) {
+      src += 8;
+      below[1] = vld1q_u16(src);
+      if (width >= 32) {
+        src += 8;
+        below[2] = vld1q_u16(src);
+        src += 8;
+        below[3] = vld1q_u16(src);
+        if (width == 64) {
+          src += 8;
+          below[4] = vld1q_u16(src);
+          src += 8;
+          below[5] = vld1q_u16(src);
+          src += 8;
+          below[6] = vld1q_u16(src);
+          src += 8;
+          below[7] = vld1q_u16(src);
+        }
+      }
+    }
+    src += src_remainder_stride;
+
+    vst1q_u16(dst, vrhaddq_u16(row[0], below[0]));
+    row[0] = below[0];
+    if (width >= 16) {
+      dst += 8;
+      vst1q_u16(dst, vrhaddq_u16(row[1], below[1]));
+      row[1] = below[1];
+      if (width >= 32) {
+        dst += 8;
+        vst1q_u16(dst, vrhaddq_u16(row[2], below[2]));
+        row[2] = below[2];
+        dst += 8;
+        vst1q_u16(dst, vrhaddq_u16(row[3], below[3]));
+        row[3] = below[3];
+        if (width >= 64) {
+          dst += 8;
+          vst1q_u16(dst, vrhaddq_u16(row[4], below[4]));
+          row[4] = below[4];
+          dst += 8;
+          vst1q_u16(dst, vrhaddq_u16(row[5], below[5]));
+          row[5] = below[5];
+          dst += 8;
+          vst1q_u16(dst, vrhaddq_u16(row[6], below[6]));
+          row[6] = below[6];
+          dst += 8;
+          vst1q_u16(dst, vrhaddq_u16(row[7], below[7]));
+          row[7] = below[7];
+        }
+      }
+    }
+    dst += dst_remainder_stride;
+  } while (--y != 0);
+}
+
+// Intra block copy with vertical half-pel averaging for 10-bit pixels: each
+// output pixel is the rounded average of a source pixel and the pixel below
+// it, so |height| + 1 source rows are read. Strides are given in bytes.
+void ConvolveIntraBlockCopyVertical_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+  assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
+  assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
+  const auto* src = static_cast<const uint16_t*>(reference);
+  auto* dest = static_cast<uint16_t*>(prediction);
+  const ptrdiff_t src_stride = reference_stride >> 1;
+  const ptrdiff_t dst_stride = pred_stride >> 1;
+
+  switch (width) {
+    case 128:
+      // Due to register pressure, process as two 64-wide halves.
+      IntraBlockCopyVertical<64>(src, src_stride, height, dest, dst_stride);
+      IntraBlockCopyVertical<64>(src + 64, src_stride, height, dest + 64,
+                                 dst_stride);
+      return;
+    case 64:
+      IntraBlockCopyVertical<64>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 32:
+      IntraBlockCopyVertical<32>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 16:
+      IntraBlockCopyVertical<16>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 8:
+      IntraBlockCopyVertical<8>(src, src_stride, height, dest, dst_stride);
+      return;
+    default:
+      break;
+  }
+
+  // width == 4: slide a two-row window down the block.
+  uint16x4_t above = vld1_u16(src);
+  int rows = height;
+  do {
+    src += src_stride;
+    const uint16x4_t below = vld1_u16(src);
+    vst1_u16(dest, vrhadd_u16(above, below));
+    dest += dst_stride;
+    above = below;
+  } while (--rows != 0);
+}
+
+// 2D half-pel intra block copy for a compile-time |width| of 8, 16, 32, 64 or
+// 128. Each output pixel is
+//   (p[y][x] + p[y][x + 1] + p[y + 1][x] + p[y + 1][x + 1] + 2) >> 2,
+// so |height| + 1 source rows and width + 1 source columns are read.
+// Horizontal pair sums of the previous source row are cached in |row|, so
+// each source row is loaded once. For 10-bit input a four-pixel sum is at
+// most 4092, which fits the 16-bit lanes. Strides are in uint16_t units.
+template <int width>
+inline void IntraBlockCopy2D(const uint16_t* LIBGAV1_RESTRICT src,
+                             const ptrdiff_t src_stride, const int height,
+                             uint16_t* LIBGAV1_RESTRICT dst,
+                             const ptrdiff_t dst_stride) {
+  // Within a row the pointers advance (width - 8) elements in steps of 8; the
+  // remainder strides complete the move to the start of the next row.
+  const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
+  const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
+  uint16x8_t row[16];
+  // Prime |row| with the horizontal pair sums of the first source row.
+  row[0] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+  if (width >= 16) {
+    src += 8;
+    row[1] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+    if (width >= 32) {
+      src += 8;
+      row[2] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+      src += 8;
+      row[3] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+      if (width >= 64) {
+        src += 8;
+        row[4] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+        src += 8;
+        row[5] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+        src += 8;
+        row[6] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+        src += 8;
+        row[7] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+        if (width == 128) {
+          src += 8;
+          row[8] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          src += 8;
+          row[9] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          src += 8;
+          row[10] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          src += 8;
+          row[11] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          src += 8;
+          row[12] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          src += 8;
+          row[13] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          src += 8;
+          row[14] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          src += 8;
+          row[15] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+        }
+      }
+    }
+  }
+  src += src_remainder_stride;
+
+  // Each iteration computes the pair sums of the next source row, adds them
+  // to the cached sums in |row|, rounds (+2, >> 2), stores, and caches them.
+  int y = height;
+  do {
+    const uint16x8_t below_0 = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+    vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[0], below_0), 2));
+    row[0] = below_0;
+    if (width >= 16) {
+      src += 8;
+      dst += 8;
+
+      const uint16x8_t below_1 = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+      vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[1], below_1), 2));
+      row[1] = below_1;
+      if (width >= 32) {
+        src += 8;
+        dst += 8;
+
+        const uint16x8_t below_2 =
+            vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+        vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[2], below_2), 2));
+        row[2] = below_2;
+        src += 8;
+        dst += 8;
+
+        const uint16x8_t below_3 =
+            vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+        vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[3], below_3), 2));
+        row[3] = below_3;
+        if (width >= 64) {
+          src += 8;
+          dst += 8;
+
+          const uint16x8_t below_4 =
+              vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[4], below_4), 2));
+          row[4] = below_4;
+          src += 8;
+          dst += 8;
+
+          const uint16x8_t below_5 =
+              vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[5], below_5), 2));
+          row[5] = below_5;
+          src += 8;
+          dst += 8;
+
+          const uint16x8_t below_6 =
+              vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[6], below_6), 2));
+          row[6] = below_6;
+          src += 8;
+          dst += 8;
+
+          const uint16x8_t below_7 =
+              vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+          vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[7], below_7), 2));
+          row[7] = below_7;
+          if (width == 128) {
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_8 =
+                vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+            vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[8], below_8), 2));
+            row[8] = below_8;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_9 =
+                vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+            vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[9], below_9), 2));
+            row[9] = below_9;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_10 =
+                vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+            vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[10], below_10), 2));
+            row[10] = below_10;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_11 =
+                vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+            vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[11], below_11), 2));
+            row[11] = below_11;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_12 =
+                vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+            vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[12], below_12), 2));
+            row[12] = below_12;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_13 =
+                vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+            vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[13], below_13), 2));
+            row[13] = below_13;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_14 =
+                vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+            vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[14], below_14), 2));
+            row[14] = below_14;
+            src += 8;
+            dst += 8;
+
+            const uint16x8_t below_15 =
+                vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
+            vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[15], below_15), 2));
+            row[15] = below_15;
+          }
+        }
+      }
+    }
+    src += src_remainder_stride;
+    dst += dst_remainder_stride;
+  } while (--y != 0);
+}
+
+// Intra block copy with 2D half-pel averaging for 10-bit pixels: each output
+// pixel is the rounded average of a 2x2 source neighborhood. Strides are given
+// in bytes. The 4-wide path produces two rows per iteration.
+void ConvolveIntraBlockCopy2D_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+  assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
+  assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
+  const auto* src = static_cast<const uint16_t*>(reference);
+  auto* dest = static_cast<uint16_t*>(prediction);
+  const ptrdiff_t src_stride = reference_stride >> 1;
+  const ptrdiff_t dst_stride = pred_stride >> 1;
+
+  // Note: rows up to |height| (one past the block) are read. This function
+  // only serves the u/v planes of intra block copy, where such access is
+  // guaranteed to stay within the prediction block.
+
+  switch (width) {
+    case 128:
+      IntraBlockCopy2D<128>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 64:
+      IntraBlockCopy2D<64>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 32:
+      IntraBlockCopy2D<32>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 16:
+      IntraBlockCopy2D<16>(src, src_stride, height, dest, dst_stride);
+      return;
+    case 8:
+      IntraBlockCopy2D<8>(src, src_stride, height, dest, dst_stride);
+      return;
+    default:
+      break;
+  }
+
+  // width == 4: each pass turns three rows of horizontal pair sums into two
+  // output rows; the last sum row is carried into the next pass.
+  uint16x4_t prev = vadd_u16(vld1_u16(src), vld1_u16(src + 1));
+  int rows = height;
+  do {
+    src += src_stride;
+    const uint16x4_t mid = vadd_u16(vld1_u16(src), vld1_u16(src + 1));
+    src += src_stride;
+    const uint16x4_t next = vadd_u16(vld1_u16(src), vld1_u16(src + 1));
+    vst1_u16(dest, vrshr_n_u16(vadd_u16(prev, mid), 2));
+    vst1_u16(dest + dst_stride, vrshr_n_u16(vadd_u16(mid, next), 2));
+    dest += 2 * dst_stride;
+    prev = next;
+    rows -= 2;
+  } while (rows != 0);
+}
+
+// -----------------------------------------------------------------------------
+// Scaled Convolve
+
+// There are many opportunities for overreading in scaled convolve, because the
+// range of starting points for filter windows is anywhere from 0 to 16 for 8
+// destination pixels, and the window sizes range from 2 to 8. To accommodate
+// this range concisely, |grade_x| denotes the most steps in src that can be
+// traversed in a single |step_x| increment, i.e. 1 or 2. When |grade_x| is 2,
+// every 8 |step_x| increments are guaranteed to cover more than 8 whole steps
+// in src. The first load covers the initial elements of src_x, while the final
+// load covers the taps.
+//
+// Returns the source values reinterpreted as bytes for table lookups: two
+// vectors (16 values) when |grade_x| == 1, three (24 values) otherwise. When
+// only two are loaded, val[2] is zeroed under MSAN and otherwise left unset.
+template <int grade_x>
+inline uint8x16x3_t LoadSrcVals(const uint16_t* const src_x) {
+  uint8x16x3_t vals;
+  // With a fractional step of at most 1, the rightmost filter may start at
+  // position 7 and an 8-tap window then ends at position 14, so two vectors
+  // of eight 16-bit values suffice.
+  vals.val[0] = vreinterpretq_u8_u16(vld1q_u16(src_x));
+  vals.val[1] = vreinterpretq_u8_u16(vld1q_u16(src_x + 8));
+  if (grade_x > 1) {
+    // With a fractional step of up to 2, the rightmost filter may start at
+    // position 15 and an 8-tap window then ends at position 22, so a third
+    // vector is needed.
+    vals.val[2] = vreinterpretq_u8_u16(vld1q_u16(src_x + 16));
+  } else {
+#if LIBGAV1_MSAN
+    // Unused for grade_x <= 1; initialize to quiet msan warnings.
+    vals.val[2] = vdupq_n_u8(0);
+#endif
+  }
+  return vals;
+}
+
+// Gathers four 16-bit samples for one tap position, one sample per filter
+// kernel (output lane). |indices| are byte offsets into the first two vectors
+// of |src_bytes|. Those 32 bytes suffice because this overload serves 4xH
+// blocks, where the maximum offset is 8 and only the smaller filters apply.
+inline uint16x4_t PermuteSrcVals(const uint8x16x3_t src_bytes,
+                                 const uint8x8_t indices) {
+  uint8x16x2_t table;
+  table.val[0] = src_bytes.val[0];
+  table.val[1] = src_bytes.val[1];
+  const uint8x8_t gathered = VQTbl2U8(table, indices);
+  return vreinterpret_u16_u8(gathered);
+}
+
+// Gathers eight 16-bit samples for one tap position, one sample per filter
+// kernel (output lane). |indices| are byte offsets into |src_bytes|.
+// - grade_x == 1: only the first two vectors can be referenced, so a two-table
+//   lookup is used.
+// - Otherwise: all three vectors are searched.
+// The three-table lookup needs a lot of workaround on A32, so a different
+// algorithm may be worth using on that architecture.
+template <int grade_x>
+inline uint16x8_t PermuteSrcVals(const uint8x16x3_t src_bytes,
+                                 const uint8x16_t indices) {
+  if (grade_x != 1) {
+    return vreinterpretq_u16_u8(VQTbl3QU8(src_bytes, indices));
+  }
+  uint8x16x2_t table;
+  table.val[0] = src_bytes.val[0];
+  table.val[1] = src_bytes.val[1];
+  return vreinterpretq_u16_u8(VQTbl2QU8(table, indices));
+}
+
+// Pre-transposes the 2-tap filters in |kAbsHalfSubPixelFilters|[3]: lane i of
+// the returned vector holds tap |tap_index| of the kernel for filter id i.
+// Although the taps need to be converted to 16-bit values, they must be
+// arranged by table lookup, which is more expensive for larger types than
+// lengthening in-loop. |tap_index| is the index within a kernel applied to a
+// single value and must be 0 or 1.
+inline int8x16_t GetPositive2TapFilter(const int tap_index) {
+  // Negative indices would read before the table as well.
+  assert(tap_index >= 0 && tap_index < 2);
+  alignas(
+      16) static constexpr int8_t kAbsHalfSubPixel2TapFilterColumns[2][16] = {
+      {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4},
+      {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}};
+
+  return vld1q_s8(kAbsHalfSubPixel2TapFilterColumns[tap_index]);
+}
+
+// Horizontal pass of the scaled convolution for the 2-tap filter,
+// |kAbsHalfSubPixelFilters|[3].
+//
+// Source and output layout:
+// - Rows of |src| are |src_stride| bytes apart.
+// - width <= 4: each of the |intermediate_height| rows writes four values to
+//   |intermediate|, advancing by kIntermediateStride per row.
+// - Otherwise: output is produced in 8-wide column blocks. Each block writes
+//   |intermediate_height| rows of eight values, kIntermediateStride apart,
+//   and the next block continues where the previous one stopped.
+//
+// |grade_x| (1 or 2) is the most whole source samples a single |step_x|
+// increment can traverse; it selects how much source is loaded per row.
+template <int grade_x>
+inline void ConvolveKernelHorizontal2Tap(
+    const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int width, const int subpixel_x, const int step_x,
+    const int intermediate_height, int16_t* LIBGAV1_RESTRICT intermediate) {
+  // Account for the 0-taps that precede the 2 nonzero taps in the spec.
+  const int kernel_offset = 3;
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const int step_x8 = step_x << 3;
+  const int8x16_t filter_taps0 = GetPositive2TapFilter(0);
+  const int8x16_t filter_taps1 = GetPositive2TapFilter(1);
+  const uint16x8_t index_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+
+  int p = subpixel_x;
+  if (width <= 4) {
+    // Only add steps to the 10-bit truncated p to avoid overflow.
+    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+    const uint8x8_t filter_indices =
+        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+                filter_index_mask);
+    // Each lane of taps[k] corresponds to one output value along the row,
+    // containing kSubPixelFilters[filter_index][filter_id][k], where filter_id
+    // depends on x.
+    const int16x4_t taps[2] = {
+        vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps0, filter_indices))),
+        vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps1, filter_indices)))};
+    // Lower byte of Nth value is at position 2*N.
+    // Narrowing shift is not available here because the maximum shift
+    // parameter is 8.
+    const uint8x8_t src_indices0 = vshl_n_u8(
+        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
+    // Upper byte of Nth value is at position 2*N+1.
+    const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
+    // Only 4 values needed.
+    const uint8x8_t src_indices = InterleaveLow8(src_indices0, src_indices1);
+    const uint8x8_t src_lookup[2] = {src_indices,
+                                     vadd_u8(src_indices, vdup_n_u8(2))};
+
+    // Here p == subpixel_x, so the integer offset from |ref_x| is zero. The
+    // window start is the same for every row: compute it once and advance it
+    // by |src_stride| bytes per row.
+    const uint16_t* src_x = src + kernel_offset;
+    int y = intermediate_height;
+    do {
+      // Load a pool of samples to select from using stepped indices.
+      const uint8x16x3_t src_bytes = LoadSrcVals<1>(src_x);
+      // Each lane corresponds to a different filter kernel.
+      const uint16x4_t src_vals[2] = {PermuteSrcVals(src_bytes, src_lookup[0]),
+                                      PermuteSrcVals(src_bytes, src_lookup[1])};
+
+      vst1_s16(intermediate,
+               vrshrn_n_s32(SumOnePassTaps</*filter_index=*/3>(src_vals, taps),
+                            kInterRoundBitsHorizontal - 1));
+      src_x = AddByteStride(src_x, src_stride);
+      intermediate += kIntermediateStride;
+    } while (--y != 0);
+    return;
+  }
+
+  // |width| >= 8
+  int16_t* intermediate_x = intermediate;
+  int x = 0;
+  do {
+    const uint16_t* src_x =
+        src + (p >> kScaleSubPixelBits) - ref_x + kernel_offset;
+    // Only add steps to the 10-bit truncated p to avoid overflow.
+    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
+    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+    const uint8x8_t filter_indices =
+        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+                filter_index_mask);
+    // Each lane of taps[k] corresponds to one output value along the row,
+    // containing kSubPixelFilters[filter_index][filter_id][k], where filter_id
+    // depends on x.
+    const int16x8_t taps[2] = {
+        vmovl_s8(VQTbl1S8(filter_taps0, filter_indices)),
+        vmovl_s8(VQTbl1S8(filter_taps1, filter_indices))};
+    const int16x4_t taps_low[2] = {vget_low_s16(taps[0]),
+                                   vget_low_s16(taps[1])};
+    const int16x4_t taps_high[2] = {vget_high_s16(taps[0]),
+                                    vget_high_s16(taps[1])};
+    // Lower byte of Nth value is at position 2*N.
+    const uint8x8_t src_indices0 = vshl_n_u8(
+        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
+    // Upper byte of Nth value is at position 2*N+1.
+    const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
+    const uint8x8x2_t src_indices_zip = vzip_u8(src_indices0, src_indices1);
+    const uint8x16_t src_indices =
+        vcombine_u8(src_indices_zip.val[0], src_indices_zip.val[1]);
+    const uint8x16_t src_lookup[2] = {src_indices,
+                                      vaddq_u8(src_indices, vdupq_n_u8(2))};
+
+    int y = intermediate_height;
+    do {
+      // Load a pool of samples to select from using stepped indices.
+      const uint8x16x3_t src_bytes = LoadSrcVals<grade_x>(src_x);
+      // Each lane corresponds to a different filter kernel.
+      const uint16x8_t src_vals[2] = {
+          PermuteSrcVals<grade_x>(src_bytes, src_lookup[0]),
+          PermuteSrcVals<grade_x>(src_bytes, src_lookup[1])};
+      const uint16x4_t src_low[2] = {vget_low_u16(src_vals[0]),
+                                     vget_low_u16(src_vals[1])};
+      const uint16x4_t src_high[2] = {vget_high_u16(src_vals[0]),
+                                      vget_high_u16(src_vals[1])};
+
+      vst1_s16(intermediate_x, vrshrn_n_s32(SumOnePassTaps</*filter_index=*/3>(
+                                                src_low, taps_low),
+                                            kInterRoundBitsHorizontal - 1));
+      vst1_s16(
+          intermediate_x + 4,
+          vrshrn_n_s32(SumOnePassTaps</*filter_index=*/3>(src_high, taps_high),
+                       kInterRoundBitsHorizontal - 1));
+      // Avoid right shifting the stride.
+      src_x = AddByteStride(src_x, src_stride);
+      intermediate_x += kIntermediateStride;
+    } while (--y != 0);
+    x += 8;
+    p += step_x8;
+  } while (x < width);
+}
+
+// Pre-transposes the 4-tap filters in |kAbsHalfSubPixelFilters|[5]: lane i of
+// the returned vector holds tap |tap_index| (0 to 3) of the kernel for filter
+// id i.
+inline int8x16_t GetPositive4TapFilter(const int tap_index) {
+  // Negative indices would read before the table as well.
+  assert(tap_index >= 0 && tap_index < 4);
+  alignas(
+      16) static constexpr int8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
+      {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1},
+      {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
+      {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
+      {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}};
+
+  return vld1q_s8(kSubPixel4TapPositiveFilterColumns[tap_index]);
+}
+
+// Horizontal pass of the scaled convolution for the positive 4-tap filter,
+// |kAbsHalfSubPixelFilters|[5]. This filter is only possible when width <= 4.
+// Each of the |intermediate_height| rows therefore produces one group of four
+// values, written to |intermediate| before it advances by kIntermediateStride.
+// Rows of |src| are |src_stride| bytes apart.
+inline void ConvolveKernelHorizontalPositive4Tap(
+    const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int subpixel_x, const int step_x, const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT intermediate) {
+  // Account for the 2 zero taps that precede the 4 nonzero taps in the spec.
+  const int kernel_offset = 2;
+  const int8x16_t filter_taps0 = GetPositive4TapFilter(0);
+  const int8x16_t filter_taps1 = GetPositive4TapFilter(1);
+  const int8x16_t filter_taps2 = GetPositive4TapFilter(2);
+  const int8x16_t filter_taps3 = GetPositive4TapFilter(3);
+  const uint16x8_t index_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+
+  // Only add steps to the 10-bit truncated position to avoid overflow.
+  const uint16x8_t p_fraction = vdupq_n_u16(subpixel_x & 1023);
+  const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+  const uint8x8_t filter_indices =
+      vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+              filter_index_mask);
+  // Each lane of taps[k] corresponds to one output value along the row,
+  // containing kSubPixelFilters[filter_index][filter_id][k], where filter_id
+  // depends on x.
+  const int16x4_t taps[4] = {
+      vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps0, filter_indices))),
+      vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps1, filter_indices))),
+      vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps2, filter_indices))),
+      vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps3, filter_indices)))};
+  // Lower byte of Nth value is at position 2*N.
+  // Narrowing shift is not available here because the maximum shift
+  // parameter is 8.
+  const uint8x8_t src_indices0 = vshl_n_u8(
+      vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
+  // Upper byte of Nth value is at position 2*N+1.
+  const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
+  // Only 4 values needed.
+  const uint8x8_t src_indices_base = InterleaveLow8(src_indices0, src_indices1);
+
+  // Tap k reads the sample k positions (2*k bytes) after the window start.
+  uint8x8_t src_lookup[4];
+  const uint8x8_t two = vdup_n_u8(2);
+  src_lookup[0] = src_indices_base;
+  for (int i = 1; i < 4; ++i) {
+    src_lookup[i] = vadd_u8(src_lookup[i - 1], two);
+  }
+
+  // The single group of outputs starts at |subpixel_x| itself, so its integer
+  // offset from the reference position is zero; only the kernel offset
+  // applies.
+  const uint16_t* src_y = src + kernel_offset;
+  int y = intermediate_height;
+  do {
+    // Load a pool of samples to select from using stepped indices.
+    const uint8x16x3_t src_bytes = LoadSrcVals<1>(src_y);
+    // Each lane corresponds to a different filter kernel.
+    const uint16x4_t src_vals[4] = {PermuteSrcVals(src_bytes, src_lookup[0]),
+                                    PermuteSrcVals(src_bytes, src_lookup[1]),
+                                    PermuteSrcVals(src_bytes, src_lookup[2]),
+                                    PermuteSrcVals(src_bytes, src_lookup[3])};
+
+    vst1_s16(intermediate,
+             vrshrn_n_s32(SumOnePassTaps</*filter_index=*/5>(src_vals, taps),
+                          kInterRoundBitsHorizontal - 1));
+    src_y = AddByteStride(src_y, src_stride);
+    intermediate += kIntermediateStride;
+  } while (--y != 0);
+}
+
+// Pre-transposes the signed 4-tap filters in |kAbsHalfSubPixelFilters|[4]:
+// lane i of the returned vector holds tap |tap_index| (0 to 3) of the kernel
+// for filter id i.
+inline int8x16_t GetSigned4TapFilter(const int tap_index) {
+  // Negative indices would read before the table as well.
+  assert(tap_index >= 0 && tap_index < 4);
+  alignas(16) static constexpr int8_t
+      kAbsHalfSubPixel4TapSignedFilterColumns[4][16] = {
+          {-0, -2, -4, -5, -6, -6, -7, -6, -6, -5, -5, -5, -4, -3, -2, -1},
+          {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
+          {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
+          {-0, -1, -2, -3, -4, -5, -5, -5, -6, -6, -7, -6, -6, -5, -4, -2}};
+
+  return vld1q_s8(kAbsHalfSubPixel4TapSignedFilterColumns[tap_index]);
+}
+
+// Horizontal pass of the scaled convolution for the signed 4-tap filter,
+// |kAbsHalfSubPixelFilters|[4]. This filter is only possible when width <= 4.
+// Each of the |intermediate_height| rows therefore produces one group of four
+// values, written to |intermediate| before it advances by kIntermediateStride.
+// Rows of |src| are |src_stride| bytes apart.
+inline void ConvolveKernelHorizontalSigned4Tap(
+    const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int subpixel_x, const int step_x, const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT intermediate) {
+  // Account for the 2 zero taps that precede the 4 nonzero taps in the spec.
+  const int kernel_offset = 2;
+  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
+  const int8x16_t filter_taps0 = GetSigned4TapFilter(0);
+  const int8x16_t filter_taps1 = GetSigned4TapFilter(1);
+  const int8x16_t filter_taps2 = GetSigned4TapFilter(2);
+  const int8x16_t filter_taps3 = GetSigned4TapFilter(3);
+  const uint16x8_t index_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+
+  // Only add steps to the 10-bit truncated position to avoid overflow.
+  const uint16x8_t p_fraction = vdupq_n_u16(subpixel_x & 1023);
+  const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
+  const uint8x8_t filter_indices =
+      vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
+              filter_index_mask);
+  // Each lane of taps[k] corresponds to one output value along the row,
+  // containing kSubPixelFilters[filter_index][filter_id][k], where filter_id
+  // depends on x.
+  const int16x4_t taps[4] = {
+      vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps0, filter_indices))),
+      vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps1, filter_indices))),
+      vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps2, filter_indices))),
+      vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps3, filter_indices)))};
+  // Lower byte of Nth value is at position 2*N.
+  // Narrowing shift is not available here because the maximum shift
+  // parameter is 8.
+  const uint8x8_t src_indices0 = vshl_n_u8(
+      vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
+  // Upper byte of Nth value is at position 2*N+1.
+  const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
+  // Only 4 values needed.
+  const uint8x8_t src_indices_base = InterleaveLow8(src_indices0, src_indices1);
+
+  // Tap k reads the sample k positions (2*k bytes) after the window start.
+  uint8x8_t src_lookup[4];
+  const uint8x8_t two = vdup_n_u8(2);
+  src_lookup[0] = src_indices_base;
+  for (int i = 1; i < 4; ++i) {
+    src_lookup[i] = vadd_u8(src_lookup[i - 1], two);
+  }
+
+  // The single group of outputs starts at |subpixel_x| itself, so its integer
+  // offset from the reference position is zero; only the kernel offset
+  // applies.
+  const uint16_t* src_y = src + kernel_offset;
+  int y = intermediate_height;
+  do {
+    // Load a pool of samples to select from using stepped indices.
+    const uint8x16x3_t src_bytes = LoadSrcVals<1>(src_y);
+    // Each lane corresponds to a different filter kernel.
+    const uint16x4_t src_vals[4] = {PermuteSrcVals(src_bytes, src_lookup[0]),
+                                    PermuteSrcVals(src_bytes, src_lookup[1]),
+                                    PermuteSrcVals(src_bytes, src_lookup[2]),
+                                    PermuteSrcVals(src_bytes, src_lookup[3])};
+
+    vst1_s16(intermediate,
+             vrshrn_n_s32(SumOnePassTaps</*filter_index=*/4>(src_vals, taps),
+                          kInterRoundBitsHorizontal - 1));
+    src_y = AddByteStride(src_y, src_stride);
+    intermediate += kIntermediateStride;
+  } while (--y != 0);
+}
+
+// Pre-transposes the signed 6-tap filters in |kAbsHalfSubPixelFilters|[0]:
+// lane i of the returned vector holds tap |tap_index| (0 to 5) of the kernel
+// for filter id i.
+inline int8x16_t GetSigned6TapFilter(const int tap_index) {
+  // Negative indices would read before the table as well.
+  assert(tap_index >= 0 && tap_index < 6);
+  alignas(16) static constexpr int8_t
+      kAbsHalfSubPixel6TapSignedFilterColumns[6][16] = {
+          {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0},
+          {-0, -3, -5, -6, -7, -7, -8, -7, -7, -6, -6, -6, -5, -4, -2, -1},
+          {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
+          {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
+          {-0, -1, -2, -4, -5, -6, -6, -6, -7, -7, -8, -7, -7, -6, -5, -3},
+          {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
+
+  return vld1q_s8(kAbsHalfSubPixel6TapSignedFilterColumns[tap_index]);
+}
+
+// Horizontal pass of the scaled convolution for the signed 6-tap filter,
+// |kAbsHalfSubPixelFilters|[0]. This filter is only possible when width >= 8.
+// Output is produced in 8-wide column blocks. Each block writes
+// |intermediate_height| rows of eight values, kIntermediateStride apart, and
+// the next block continues where the previous one stopped. Rows of |src| are
+// |src_stride| bytes apart.
+template <int grade_x>
+inline void ConvolveKernelHorizontalSigned6Tap(
+    const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int width, const int subpixel_x, const int step_x,
+    const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT const intermediate) {
+  constexpr int kNumTaps = 6;
+  // Skip the zero tap that precedes the 6 nonzero taps.
+  const int kernel_offset = 1;
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const int block_step = step_x << 3;
+  const uint8x8_t id_mask = vdup_n_u8(kSubPixelMask);
+  const uint16x8_t lane_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+  int8x16_t tap_columns[kNumTaps];
+  for (int k = 0; k < kNumTaps; ++k) {
+    tap_columns[k] = GetSigned6TapFilter(k);
+  }
+
+  int16_t* dst = intermediate;
+  int position = subpixel_x;
+  int column = 0;
+  do {
+    // Only add steps to the 10-bit truncated position to avoid overflow.
+    const uint16x8_t offsets =
+        vaddq_u16(lane_steps, vdupq_n_u16(position & 1023));
+    const uint8x8_t filter_ids =
+        vand_u8(vshrn_n_u16(offsets, kFilterIndexShift), id_mask);
+
+    // The low byte of lane N's first sample sits at byte 2*N of the pool and
+    // its high byte at 2*N+1.
+    const uint8x8_t low_bytes =
+        vshl_n_u8(vmovn_u16(vshrq_n_u16(offsets, kScaleSubPixelBits)), 1);
+    const uint8x8_t high_bytes = vadd_u8(low_bytes, vdup_n_u8(1));
+    const uint8x8x2_t zipped = vzip_u8(low_bytes, high_bytes);
+    const uint8x16_t first_lookup = vcombine_u8(zipped.val[0], zipped.val[1]);
+
+    // For tap k: lanes of taps_(low|high)[k] hold
+    // kSubPixelFilters[filter_index][filter_id][k] for each output's
+    // filter_id, and lookups[k] is advanced by k samples (2*k bytes).
+    int16x4_t taps_low[kNumTaps];
+    int16x4_t taps_high[kNumTaps];
+    uint8x16_t lookups[kNumTaps];
+    for (int k = 0; k < kNumTaps; ++k) {
+      const int16x8_t wide_taps =
+          vmovl_s8(VQTbl1S8(tap_columns[k], filter_ids));
+      taps_low[k] = vget_low_s16(wide_taps);
+      taps_high[k] = vget_high_s16(wide_taps);
+      lookups[k] =
+          vaddq_u8(first_lookup, vdupq_n_u8(static_cast<uint8_t>(2 * k)));
+    }
+
+    const uint16_t* src_row =
+        src + (position >> kScaleSubPixelBits) - ref_x + kernel_offset;
+    int rows_left = intermediate_height;
+    do {
+      // Load a pool of samples to select from using stepped indices.
+      const uint8x16x3_t pool = LoadSrcVals<grade_x>(src_row);
+      uint16x4_t samples_low[kNumTaps];
+      uint16x4_t samples_high[kNumTaps];
+      for (int k = 0; k < kNumTaps; ++k) {
+        const uint16x8_t gathered = PermuteSrcVals<grade_x>(pool, lookups[k]);
+        samples_low[k] = vget_low_u16(gathered);
+        samples_high[k] = vget_high_u16(gathered);
+      }
+      const int32x4_t sum_low =
+          SumOnePassTaps</*filter_index=*/0>(samples_low, taps_low);
+      const int32x4_t sum_high =
+          SumOnePassTaps</*filter_index=*/0>(samples_high, taps_high);
+      vst1_s16(dst, vrshrn_n_s32(sum_low, kInterRoundBitsHorizontal - 1));
+      vst1_s16(dst + 4, vrshrn_n_s32(sum_high, kInterRoundBitsHorizontal - 1));
+      // Step in bytes to avoid right shifting the stride.
+      src_row = AddByteStride(src_row, src_stride);
+      dst += kIntermediateStride;
+    } while (--rows_left != 0);
+    column += 8;
+    position += block_step;
+  } while (column < width);
+}
+
+// Pre-transposes the 6-tap filters in |kAbsHalfSubPixelFilters|[1]: lane i of
+// the returned vector holds tap |tap_index| (0 to 5) of the kernel for filter
+// id i. This filter has mixed positive and negative outer taps depending on
+// the filter id.
+inline int8x16_t GetMixed6TapFilter(const int tap_index) {
+  // Negative indices would read before the table as well.
+  assert(tap_index >= 0 && tap_index < 6);
+  alignas(16) static constexpr int8_t
+      kAbsHalfSubPixel6TapMixedFilterColumns[6][16] = {
+          {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0},
+          {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1},
+          {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
+          {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
+          {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14},
+          {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}};
+
+  return vld1q_s8(kAbsHalfSubPixel6TapMixedFilterColumns[tap_index]);
+}
+
+// Horizontal pass of the scaled convolution for the mixed-sign 6-tap filter,
+// |kAbsHalfSubPixelFilters|[1]. This filter is only possible when width >= 8.
+// Output is produced in 8-wide column blocks. Each block writes
+// |intermediate_height| rows of eight values, kIntermediateStride apart, and
+// the next block continues where the previous one stopped. Rows of |src| are
+// |src_stride| bytes apart.
+template <int grade_x>
+inline void ConvolveKernelHorizontalMixed6Tap(
+    const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int width, const int subpixel_x, const int step_x,
+    const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT const intermediate) {
+  constexpr int kNumTaps = 6;
+  // Skip the zero tap that precedes the 6 nonzero taps.
+  const int kernel_offset = 1;
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const int block_step = step_x << 3;
+  const uint8x8_t id_mask = vdup_n_u8(kSubPixelMask);
+  const uint16x8_t lane_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+  int8x16_t tap_columns[kNumTaps];
+  for (int k = 0; k < kNumTaps; ++k) {
+    tap_columns[k] = GetMixed6TapFilter(k);
+  }
+
+  int16_t* dst = intermediate;
+  int position = subpixel_x;
+  int column = 0;
+  do {
+    // Only add steps to the 10-bit truncated position to avoid overflow.
+    const uint16x8_t offsets =
+        vaddq_u16(lane_steps, vdupq_n_u16(position & 1023));
+    const uint8x8_t filter_ids =
+        vand_u8(vshrn_n_u16(offsets, kFilterIndexShift), id_mask);
+
+    // The low byte of lane N's first sample sits at byte 2*N of the pool and
+    // its high byte at 2*N+1.
+    const uint8x8_t low_bytes =
+        vshl_n_u8(vmovn_u16(vshrq_n_u16(offsets, kScaleSubPixelBits)), 1);
+    const uint8x8_t high_bytes = vadd_u8(low_bytes, vdup_n_u8(1));
+    const uint8x8x2_t zipped = vzip_u8(low_bytes, high_bytes);
+    const uint8x16_t first_lookup = vcombine_u8(zipped.val[0], zipped.val[1]);
+
+    // For tap k: lanes of taps_(low|high)[k] hold
+    // kSubPixelFilters[filter_index][filter_id][k] for each output's
+    // filter_id, and lookups[k] is advanced by k samples (2*k bytes).
+    int16x4_t taps_low[kNumTaps];
+    int16x4_t taps_high[kNumTaps];
+    uint8x16_t lookups[kNumTaps];
+    for (int k = 0; k < kNumTaps; ++k) {
+      const int16x8_t wide_taps =
+          vmovl_s8(VQTbl1S8(tap_columns[k], filter_ids));
+      taps_low[k] = vget_low_s16(wide_taps);
+      taps_high[k] = vget_high_s16(wide_taps);
+      lookups[k] =
+          vaddq_u8(first_lookup, vdupq_n_u8(static_cast<uint8_t>(2 * k)));
+    }
+
+    const uint16_t* src_row =
+        src + (position >> kScaleSubPixelBits) - ref_x + kernel_offset;
+    int rows_left = intermediate_height;
+    do {
+      // Load a pool of samples to select from using stepped indices.
+      const uint8x16x3_t pool = LoadSrcVals<grade_x>(src_row);
+      uint16x4_t samples_low[kNumTaps];
+      uint16x4_t samples_high[kNumTaps];
+      for (int k = 0; k < kNumTaps; ++k) {
+        const uint16x8_t gathered = PermuteSrcVals<grade_x>(pool, lookups[k]);
+        samples_low[k] = vget_low_u16(gathered);
+        samples_high[k] = vget_high_u16(gathered);
+      }
+      const int32x4_t sum_low =
+          SumOnePassTaps</*filter_index=*/0>(samples_low, taps_low);
+      const int32x4_t sum_high =
+          SumOnePassTaps</*filter_index=*/0>(samples_high, taps_high);
+      vst1_s16(dst, vrshrn_n_s32(sum_low, kInterRoundBitsHorizontal - 1));
+      vst1_s16(dst + 4, vrshrn_n_s32(sum_high, kInterRoundBitsHorizontal - 1));
+      // Step in bytes to avoid right shifting the stride.
+      src_row = AddByteStride(src_row, src_stride);
+      dst += kIntermediateStride;
+    } while (--rows_left != 0);
+    column += 8;
+    position += block_step;
+  } while (column < width);
+}
+
+// Pre-transposes the signed 8-tap filters in |kAbsHalfSubPixelFilters|[2]:
+// lane i of the returned vector holds tap |tap_index| (0 to 7) of the kernel
+// for filter id i.
+inline int8x16_t GetSigned8TapFilter(const int tap_index) {
+  // Negative indices would read before the table as well.
+  assert(tap_index >= 0 && tap_index < 8);
+  alignas(16) static constexpr int8_t
+      kAbsHalfSubPixel8TapSignedFilterColumns[8][16] = {
+          {-0, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -0},
+          {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1},
+          {-0, -3, -6, -9, -11, -11, -12, -12, -12, -11, -10, -9, -7, -5, -3,
+           -1},
+          {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4},
+          {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63},
+          {-0, -1, -3, -5, -7, -9, -10, -11, -12, -12, -12, -11, -11, -9, -6,
+           -3},
+          {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1},
+          {-0, -0, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1}};
+
+  return vld1q_s8(kAbsHalfSubPixel8TapSignedFilterColumns[tap_index]);
+}
+
+// Horizontal pass of the scaled convolution for the signed 8-tap filter,
+// |kAbsHalfSubPixelFilters|[2]. This filter is only possible when width >= 8.
+// All 8 taps are nonzero, so no kernel offset is applied to the window start.
+// Output is produced in 8-wide column blocks. Each block writes
+// |intermediate_height| rows of eight values, kIntermediateStride apart, and
+// the next block continues where the previous one stopped. Rows of |src| are
+// |src_stride| bytes apart.
+template <int grade_x>
+inline void ConvolveKernelHorizontalSigned8Tap(
+    const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int width, const int subpixel_x, const int step_x,
+    const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT const intermediate) {
+  constexpr int kNumTaps = 8;
+  const int ref_x = subpixel_x >> kScaleSubPixelBits;
+  const int block_step = step_x << 3;
+  const uint8x8_t id_mask = vdup_n_u8(kSubPixelMask);
+  const uint16x8_t lane_steps = vmulq_n_u16(
+      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+  int8x16_t tap_columns[kNumTaps];
+  for (int k = 0; k < kNumTaps; ++k) {
+    tap_columns[k] = GetSigned8TapFilter(k);
+  }
+
+  int16_t* dst = intermediate;
+  int position = subpixel_x;
+  int column = 0;
+  do {
+    // Only add steps to the 10-bit truncated position to avoid overflow.
+    const uint16x8_t offsets =
+        vaddq_u16(lane_steps, vdupq_n_u16(position & 1023));
+    const uint8x8_t filter_ids =
+        vand_u8(vshrn_n_u16(offsets, kFilterIndexShift), id_mask);
+
+    // The low byte of lane N's first sample sits at byte 2*N of the pool and
+    // its high byte at 2*N+1.
+    const uint8x8_t low_bytes =
+        vshl_n_u8(vmovn_u16(vshrq_n_u16(offsets, kScaleSubPixelBits)), 1);
+    const uint8x8_t high_bytes = vadd_u8(low_bytes, vdup_n_u8(1));
+    const uint8x8x2_t zipped = vzip_u8(low_bytes, high_bytes);
+    const uint8x16_t first_lookup = vcombine_u8(zipped.val[0], zipped.val[1]);
+
+    // For tap k: lanes of taps_(low|high)[k] hold
+    // kSubPixelFilters[filter_index][filter_id][k] for each output's
+    // filter_id, and lookups[k] is advanced by k samples (2*k bytes).
+    int16x4_t taps_low[kNumTaps];
+    int16x4_t taps_high[kNumTaps];
+    uint8x16_t lookups[kNumTaps];
+    for (int k = 0; k < kNumTaps; ++k) {
+      const int16x8_t wide_taps =
+          vmovl_s8(VQTbl1S8(tap_columns[k], filter_ids));
+      taps_low[k] = vget_low_s16(wide_taps);
+      taps_high[k] = vget_high_s16(wide_taps);
+      lookups[k] =
+          vaddq_u8(first_lookup, vdupq_n_u8(static_cast<uint8_t>(2 * k)));
+    }
+
+    const uint16_t* src_row = src + (position >> kScaleSubPixelBits) - ref_x;
+    int rows_left = intermediate_height;
+    do {
+      // Load a pool of samples to select from using stepped indices.
+      const uint8x16x3_t pool = LoadSrcVals<grade_x>(src_row);
+      uint16x4_t samples_low[kNumTaps];
+      uint16x4_t samples_high[kNumTaps];
+      for (int k = 0; k < kNumTaps; ++k) {
+        const uint16x8_t gathered = PermuteSrcVals<grade_x>(pool, lookups[k]);
+        samples_low[k] = vget_low_u16(gathered);
+        samples_high[k] = vget_high_u16(gathered);
+      }
+      const int32x4_t sum_low =
+          SumOnePassTaps</*filter_index=*/2>(samples_low, taps_low);
+      const int32x4_t sum_high =
+          SumOnePassTaps</*filter_index=*/2>(samples_high, taps_high);
+      vst1_s16(dst, vrshrn_n_s32(sum_low, kInterRoundBitsHorizontal - 1));
+      vst1_s16(dst + 4, vrshrn_n_s32(sum_high, kInterRoundBitsHorizontal - 1));
+      // Step in bytes to avoid right shifting the stride.
+      src_row = AddByteStride(src_row, src_stride);
+      dst += kIntermediateStride;
+    } while (--rows_left != 0);
+    column += 8;
+    position += block_step;
+  } while (column < width);
+}
+
+// Applies a |num_taps| vertical filter to four columns of 16-bit intermediate
+// values. |src| points to |num_taps| consecutive rows of four values each.
+// |taps| holds an 8-entry kernel in which shorter filters occupy the centered
+// lanes: 6 taps use lanes 1-6, 4 taps use lanes 2-5, 2 taps use lanes 3-4.
+// Products accumulate in 32 bits and are then rounded back to 16 bits:
+// - Compound: signed narrowing shift by kInterRoundBitsCompoundVertical - 1.
+// - Otherwise: saturating unsigned narrowing shift by
+//   kInterRoundBitsVertical - 1, reinterpreted as int16x4_t.
+template <int num_taps, bool is_compound>
+inline int16x4_t Sum2DVerticalTaps4(const int16x4_t* const src,
+                                    const int16x8_t taps) {
+  // Any other tap count would leave |sum| uninitialized.
+  static_assert(num_taps == 2 || num_taps == 4 || num_taps == 6 ||
+                    num_taps == 8,
+                "Sum2DVerticalTaps4 supports 2, 4, 6 or 8 taps.");
+  const int16x4_t taps_lo = vget_low_s16(taps);
+  const int16x4_t taps_hi = vget_high_s16(taps);
+  int32x4_t sum;
+  if (num_taps == 8) {
+    sum = vmull_lane_s16(src[0], taps_lo, 0);
+    sum = vmlal_lane_s16(sum, src[1], taps_lo, 1);
+    sum = vmlal_lane_s16(sum, src[2], taps_lo, 2);
+    sum = vmlal_lane_s16(sum, src[3], taps_lo, 3);
+    sum = vmlal_lane_s16(sum, src[4], taps_hi, 0);
+    sum = vmlal_lane_s16(sum, src[5], taps_hi, 1);
+    sum = vmlal_lane_s16(sum, src[6], taps_hi, 2);
+    sum = vmlal_lane_s16(sum, src[7], taps_hi, 3);
+  } else if (num_taps == 6) {
+    sum = vmull_lane_s16(src[0], taps_lo, 1);
+    sum = vmlal_lane_s16(sum, src[1], taps_lo, 2);
+    sum = vmlal_lane_s16(sum, src[2], taps_lo, 3);
+    sum = vmlal_lane_s16(sum, src[3], taps_hi, 0);
+    sum = vmlal_lane_s16(sum, src[4], taps_hi, 1);
+    sum = vmlal_lane_s16(sum, src[5], taps_hi, 2);
+  } else if (num_taps == 4) {
+    sum = vmull_lane_s16(src[0], taps_lo, 2);
+    sum = vmlal_lane_s16(sum, src[1], taps_lo, 3);
+    sum = vmlal_lane_s16(sum, src[2], taps_hi, 0);
+    sum = vmlal_lane_s16(sum, src[3], taps_hi, 1);
+  } else {
+    // num_taps == 2, guaranteed by the static_assert above.
+    sum = vmull_lane_s16(src[0], taps_lo, 3);
+    sum = vmlal_lane_s16(sum, src[1], taps_hi, 0);
+  }
+
+  if (is_compound) {
+    return vrshrn_n_s32(sum, kInterRoundBitsCompoundVertical - 1);
+  }
+
+  return vreinterpret_s16_u16(vqrshrun_n_s32(sum, kInterRoundBitsVertical - 1));
+}
+
+template <int num_taps, int grade_y, int width, bool is_compound>
+void ConvolveVerticalScale2Or4xH(const int16_t* LIBGAV1_RESTRICT const src,
+                                 const int subpixel_y, const int filter_index,
+                                 const int step_y, const int height,
+                                 void* LIBGAV1_RESTRICT const dest,
+                                 const ptrdiff_t dest_stride) {
+  static_assert(width == 2 || width == 4, "");
+  // Filters |height| rows (two per loop iteration) of a 2- or 4-wide column.
+  // Non-compound stores use AddByteStride since |dest_stride| is in bytes.
+  auto* dest_y = static_cast<uint16_t*>(dest);
+  // Compound output: |dest_stride| counts uint16_t elements rather than bytes,
+  // so this pointer advances with plain pointer arithmetic.
+  auto* compound_dest_y = static_cast<uint16_t*>(dest);
+  // Row stride of the int16_t intermediate buffer, in elements.
+  constexpr ptrdiff_t src_stride = kIntermediateStride;
+  const int16_t* src_y = src;
+  int16x4_t s[num_taps + grade_y];
+  // |p| >> kScaleSubPixelBits is the row offset; (p >> 6) selects the filter.
+  int p = subpixel_y & 1023;
+  int prev_p = p;
+  int y = height;
+  do {
+    for (int i = 0; i < num_taps; ++i) {
+      s[i] = vld1_s16(src_y + i * src_stride);
+    }
+    int filter_id = (p >> 6) & kSubPixelMask;
+    int16x8_t filter =
+        vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+    int16x4_t sums = Sum2DVerticalTaps4<num_taps, is_compound>(s, filter);
+    if (is_compound) {
+      assert(width != 2);
+      // Adding kCompoundOffset may carry into the int16_t sign bit; read back
+      // as uint16_t, the bits still hold the intended unsigned value.
+      const uint16x4_t result =
+          vreinterpret_u16_s16(vadd_s16(sums, vdup_n_s16(kCompoundOffset)));
+      vst1_u16(compound_dest_y, result);
+      compound_dest_y += dest_stride;
+    } else {
+      const uint16x4_t result = vmin_u16(vreinterpret_u16_s16(sums),
+                                         vdup_n_u16((1 << kBitdepth10) - 1));
+      if (width == 2) {
+        Store2<0>(dest_y, result);
+      } else {
+        vst1_u16(dest_y, result);
+      }
+      dest_y = AddByteStride(dest_y, dest_stride);
+    }
+    p += step_y;
+    const int p_diff =
+        (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
+    prev_p = p;
+    // Load the rows just past the window unconditionally; when |p_diff| == 0
+    // they go unused, but loading is cheaper than branching.
+    s[num_taps] = vld1_s16(src_y + num_taps * src_stride);
+    if (grade_y > 1) {
+      s[num_taps + 1] = vld1_s16(src_y + (num_taps + 1) * src_stride);
+    }
+    // Second row of the pair: the same window shifted down by |p_diff| rows.
+    filter_id = (p >> 6) & kSubPixelMask;
+    filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+    sums = Sum2DVerticalTaps4<num_taps, is_compound>(&s[p_diff], filter);
+    if (is_compound) {
+      assert(width != 2);
+      const uint16x4_t result =
+          vreinterpret_u16_s16(vadd_s16(sums, vdup_n_s16(kCompoundOffset)));
+      vst1_u16(compound_dest_y, result);
+      compound_dest_y += dest_stride;
+    } else {
+      const uint16x4_t result = vmin_u16(vreinterpret_u16_s16(sums),
+                                         vdup_n_u16((1 << kBitdepth10) - 1));
+      if (width == 2) {
+        Store2<0>(dest_y, result);
+      } else {
+        vst1_u16(dest_y, result);
+      }
+      dest_y = AddByteStride(dest_y, dest_stride);
+    }
+    p += step_y;
+    src_y = src + (p >> kScaleSubPixelBits) * src_stride;
+    prev_p = p;
+    y -= 2;
+  } while (y != 0);
+}
+
+template <int num_taps, int grade_y, bool is_compound>
+void ConvolveVerticalScale(const int16_t* LIBGAV1_RESTRICT const source,
+                           const int intermediate_height, const int width,
+                           const int subpixel_y, const int filter_index,
+                           const int step_y, const int height,
+                           void* LIBGAV1_RESTRICT const dest,
+                           const ptrdiff_t dest_stride) {
+  // Filters |width| columns in 8-wide strips, two output rows per iteration.
+  constexpr ptrdiff_t src_stride = kIntermediateStride;
+  // Filter window: |num_taps| rows plus up to two look-ahead rows.
+  int16x8_t s[num_taps + 2];
+  // Each strip in |source| spans |intermediate_height| intermediate rows.
+  const int16_t* src = source;
+  int x = 0;
+  do {
+    const int16_t* src_y = src;
+    int p = subpixel_y & 1023;
+    int prev_p = p;
+    // Both destination pointers start at column |x| of the current strip.
+    // Non-compound stores use AddByteStride since |dest_stride| is in bytes.
+    auto* dest_y = static_cast<uint16_t*>(dest) + x;
+    // Compound output: |dest_stride| counts uint16_t elements rather than
+    // bytes, so this pointer advances with plain pointer arithmetic.
+    auto* compound_dest_y = static_cast<uint16_t*>(dest) + x;
+    int y = height;
+    do {
+      for (int i = 0; i < num_taps; ++i) {
+        s[i] = vld1q_s16(src_y + i * src_stride);
+      }
+      int filter_id = (p >> 6) & kSubPixelMask;
+      int16x8_t filter =
+          vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+      int16x8_t sums =
+          SimpleSum2DVerticalTaps<num_taps, is_compound>(s, filter);
+      if (is_compound) {
+        // kCompoundOffset may carry into the int16_t sign bit; read back as
+        // uint16_t, the bits still hold the intended unsigned value.
+        const uint16x8_t result = vreinterpretq_u16_s16(
+            vaddq_s16(sums, vdupq_n_s16(kCompoundOffset)));
+        vst1q_u16(compound_dest_y, result);
+        compound_dest_y += dest_stride;
+      } else {
+        const uint16x8_t result = vminq_u16(
+            vreinterpretq_u16_s16(sums), vdupq_n_u16((1 << kBitdepth10) - 1));
+        vst1q_u16(dest_y, result);
+        dest_y = AddByteStride(dest_y, dest_stride);
+      }
+      p += step_y;
+      const int p_diff =
+          (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
+      prev_p = p;
+      // Load the rows just past the window unconditionally; when |p_diff| == 0
+      // they go unused, but loading is cheaper than branching.
+      s[num_taps] = vld1q_s16(src_y + num_taps * src_stride);
+      if (grade_y > 1) {
+        s[num_taps + 1] = vld1q_s16(src_y + (num_taps + 1) * src_stride);
+      }
+      // Second row of the pair: the same window shifted down by |p_diff| rows.
+      filter_id = (p >> 6) & kSubPixelMask;
+      filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
+      sums = SimpleSum2DVerticalTaps<num_taps, is_compound>(&s[p_diff], filter);
+      if (is_compound) {
+        assert(width != 2);
+        const uint16x8_t result = vreinterpretq_u16_s16(
+            vaddq_s16(sums, vdupq_n_s16(kCompoundOffset)));
+        vst1q_u16(compound_dest_y, result);
+        compound_dest_y += dest_stride;
+      } else {
+        const uint16x8_t result = vminq_u16(
+            vreinterpretq_u16_s16(sums), vdupq_n_u16((1 << kBitdepth10) - 1));
+        vst1q_u16(dest_y, result);
+        dest_y = AddByteStride(dest_y, dest_stride);
+      }
+      p += step_y;
+      src_y = src + (p >> kScaleSubPixelBits) * src_stride;
+      prev_p = p;
+
+      y -= 2;
+    } while (y != 0);
+    src += kIntermediateStride * intermediate_height;
+    x += 8;
+  } while (x < width);
+}
+
+template <bool is_compound>
+void ConvolveScale2D_NEON(const void* LIBGAV1_RESTRICT const reference,
+                          const ptrdiff_t reference_stride,
+                          const int horizontal_filter_index,
+                          const int vertical_filter_index, const int subpixel_x,
+                          const int subpixel_y, const int step_x,
+                          const int step_y, const int width, const int height,
+                          void* LIBGAV1_RESTRICT const prediction,
+                          const ptrdiff_t pred_stride) {
+  // Scaled 2D convolve: horizontal pass to int16 buffer, then vertical pass.
+  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
+  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
+  assert(step_x <= 2048);
+  assert(step_y <= 2048);
+  const int num_vert_taps = GetNumTapsInFilter(vert_filter_index);
+  const int intermediate_height =
+      (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
+       kScaleSubPixelBits) +
+      num_vert_taps;
+  int16_t intermediate_result[kIntermediateAllocWidth *
+                              (2 * kIntermediateAllocWidth + 8)];
+#if LIBGAV1_MSAN
+  // Quiet msan warnings. Set with random non-zero value to aid in debugging.
+  memset(intermediate_result, 0x54, sizeof(intermediate_result));
+#endif
+  // Horizontal filter.
+  // Filter types used for width <= 4 are different from those for width > 4.
+  // When width > 4, the valid filter index range is always [0, 3].
+  // When width <= 4, the valid filter index range is always [3, 5].
+  // The same applies to height and vertical filter index.
+  int filter_index = horiz_filter_index;
+  int16_t* intermediate = intermediate_result;
+  const ptrdiff_t src_stride = reference_stride;
+  const auto* src = static_cast<const uint16_t*>(reference);
+  const int vert_kernel_offset = (8 - num_vert_taps) / 2;
+  src = AddByteStride(src, vert_kernel_offset * src_stride);
+
+  // Derive the maximum value of |step_x| at which all source values fit in one
+  // 16-byte (8-value) load. Final index is src_x + |num_taps| - 1 < 16
+  // step_x*7 is the final base subpel index for the shuffle mask for filter
+  // inputs in each iteration on large blocks. When step_x is large, we need a
+  // larger structure and use a larger table lookup in order to gather all
+  // filter inputs.
+  const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index);
+  // |num_taps| - 1 is the shuffle index of the final filter input.
+  const int kernel_start_ceiling = 16 - num_horiz_taps;
+  // This truncated quotient |grade_x_threshold| selects |step_x| such that:
+  // (step_x * 7) >> kScaleSubPixelBits < single load limit
+  const int grade_x_threshold =
+      (kernel_start_ceiling << kScaleSubPixelBits) / 7;
+
+  switch (filter_index) {
+    case 0:
+      if (step_x > grade_x_threshold) {
+        ConvolveKernelHorizontalSigned6Tap<2>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      } else {
+        ConvolveKernelHorizontalSigned6Tap<1>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      }
+      break;
+    case 1:
+      if (step_x > grade_x_threshold) {
+        ConvolveKernelHorizontalMixed6Tap<2>(src, src_stride, width, subpixel_x,
+                                             step_x, intermediate_height,
+                                             intermediate);
+      } else {
+        ConvolveKernelHorizontalMixed6Tap<1>(src, src_stride, width, subpixel_x,
+                                             step_x, intermediate_height,
+                                             intermediate);
+      }
+      break;
+    case 2:
+      if (step_x > grade_x_threshold) {
+        ConvolveKernelHorizontalSigned8Tap<2>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      } else {
+        ConvolveKernelHorizontalSigned8Tap<1>(
+            src, src_stride, width, subpixel_x, step_x, intermediate_height,
+            intermediate);
+      }
+      break;
+    case 3:
+      if (step_x > grade_x_threshold) {
+        ConvolveKernelHorizontal2Tap<2>(src, src_stride, width, subpixel_x,
+                                        step_x, intermediate_height,
+                                        intermediate);
+      } else {
+        ConvolveKernelHorizontal2Tap<1>(src, src_stride, width, subpixel_x,
+                                        step_x, intermediate_height,
+                                        intermediate);
+      }
+      break;
+    case 4:
+      assert(width <= 4);
+      ConvolveKernelHorizontalSigned4Tap(src, src_stride, subpixel_x, step_x,
+                                         intermediate_height, intermediate);
+      break;
+    default:
+      assert(filter_index == 5);
+      ConvolveKernelHorizontalPositive4Tap(src, src_stride, subpixel_x, step_x,
+                                           intermediate_height, intermediate);
+  }
+
+  // Vertical filter.
+  filter_index = vert_filter_index;
+  intermediate = intermediate_result;
+  switch (filter_index) {
+    case 0:
+    case 1:
+      if (step_y <= 1024) {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale2Or4xH<6, 1, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale2Or4xH<6, 1, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<6, 1, is_compound>(
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
+        }
+      } else {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale2Or4xH<6, 2, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale2Or4xH<6, 2, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<6, 2, is_compound>(
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
+        }
+      }
+      break;
+    case 2:
+      if (step_y <= 1024) {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale2Or4xH<8, 1, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale2Or4xH<8, 1, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<8, 1, is_compound>(
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
+        }
+      } else {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale2Or4xH<8, 2, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale2Or4xH<8, 2, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<8, 2, is_compound>(
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
+        }
+      }
+      break;
+    case 3:
+      if (step_y <= 1024) {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale2Or4xH<2, 1, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale2Or4xH<2, 1, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<2, 1, is_compound>(
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
+        }
+      } else {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale2Or4xH<2, 2, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale2Or4xH<2, 2, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<2, 2, is_compound>(
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
+        }
+      }
+      break;
+    default:
+      assert(filter_index == 4 || filter_index == 5);
+      assert(height <= 4);
+      if (step_y <= 1024) {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale2Or4xH<4, 1, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale2Or4xH<4, 1, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<4, 1, is_compound>(
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
+        }
+      } else {
+        if (!is_compound && width == 2) {
+          ConvolveVerticalScale2Or4xH<4, 2, 2, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else if (width == 4) {
+          ConvolveVerticalScale2Or4xH<4, 2, 4, is_compound>(
+              intermediate, subpixel_y, filter_index, step_y, height,
+              prediction, pred_stride);
+        } else {
+          ConvolveVerticalScale<4, 2, is_compound>(
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
+        }
+      }
+  }
+}
+
+void Init10bpp() {
+  Dsp* const table = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(table != nullptr);
+  table->convolve[0][0][0][1] = ConvolveHorizontal_NEON;
+  table->convolve[0][0][1][0] = ConvolveVertical_NEON;
+  table->convolve[0][0][1][1] = Convolve2D_NEON;
+  // Compound entries ([0][1][*][*]), including the unfiltered copy.
+  table->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON;
+  table->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON;
+  table->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON;
+  table->convolve[0][1][1][1] = ConvolveCompound2D_NEON;
+  // Intra block copy entries ([1][0][*][*]).
+  table->convolve[1][0][0][1] = ConvolveIntraBlockCopyHorizontal_NEON;
+  table->convolve[1][0][1][0] = ConvolveIntraBlockCopyVertical_NEON;
+  table->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_NEON;
+  // Scaled prediction: index 0 is non-compound, index 1 is compound.
+  table->convolve_scale[0] = ConvolveScale2D_NEON<false>;
+  table->convolve_scale[1] = ConvolveScale2D_NEON<true>;
+}
+
+}  // namespace
+
+void ConvolveInit10bpp_NEON() { Init10bpp(); }  // Sets 10bpp convolve entries.
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !(LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10)
+// No-op stub: built when NEON is disabled or LIBGAV1_MAX_BITDEPTH < 10.
+namespace libgav1 {
+namespace dsp {
+
+void ConvolveInit10bpp_NEON() {}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
diff --git a/libgav1/src/dsp/arm/convolve_neon.cc b/libgav1/src/dsp/arm/convolve_neon.cc
index 331bfe2..5b80da2 100644
--- a/libgav1/src/dsp/arm/convolve_neon.cc
+++ b/libgav1/src/dsp/arm/convolve_neon.cc
@@ -103,9 +103,11 @@
 
 template <int filter_index, bool negative_outside_taps, bool is_2d,
           bool is_compound>
-void FilterHorizontalWidth8AndUp(const uint8_t* src, const ptrdiff_t src_stride,
-                                 void* const dest, const ptrdiff_t pred_stride,
-                                 const int width, const int height,
+void FilterHorizontalWidth8AndUp(const uint8_t* LIBGAV1_RESTRICT src,
+                                 const ptrdiff_t src_stride,
+                                 void* LIBGAV1_RESTRICT const dest,
+                                 const ptrdiff_t pred_stride, const int width,
+                                 const int height,
                                  const uint8x8_t* const v_tap) {
   auto* dest8 = static_cast<uint8_t*>(dest);
   auto* dest16 = static_cast<uint16_t*>(dest);
@@ -220,9 +222,11 @@
 }
 
 template <int filter_index, bool is_2d, bool is_compound>
-void FilterHorizontalWidth4(const uint8_t* src, const ptrdiff_t src_stride,
-                            void* const dest, const ptrdiff_t pred_stride,
-                            const int height, const uint8x8_t* const v_tap) {
+void FilterHorizontalWidth4(const uint8_t* LIBGAV1_RESTRICT src,
+                            const ptrdiff_t src_stride,
+                            void* LIBGAV1_RESTRICT const dest,
+                            const ptrdiff_t pred_stride, const int height,
+                            const uint8x8_t* const v_tap) {
   auto* dest8 = static_cast<uint8_t*>(dest);
   auto* dest16 = static_cast<uint16_t*>(dest);
   int y = height;
@@ -257,9 +261,11 @@
 }
 
 template <int filter_index, bool is_2d>
-void FilterHorizontalWidth2(const uint8_t* src, const ptrdiff_t src_stride,
-                            void* const dest, const ptrdiff_t pred_stride,
-                            const int height, const uint8x8_t* const v_tap) {
+void FilterHorizontalWidth2(const uint8_t* LIBGAV1_RESTRICT src,
+                            const ptrdiff_t src_stride,
+                            void* LIBGAV1_RESTRICT const dest,
+                            const ptrdiff_t pred_stride, const int height,
+                            const uint8x8_t* const v_tap) {
   auto* dest8 = static_cast<uint8_t*>(dest);
   auto* dest16 = static_cast<uint16_t*>(dest);
   int y = height >> 1;
@@ -345,10 +351,11 @@
 
 template <int filter_index, bool negative_outside_taps, bool is_2d,
           bool is_compound>
-void FilterHorizontal(const uint8_t* const src, const ptrdiff_t src_stride,
-                      void* const dest, const ptrdiff_t pred_stride,
-                      const int width, const int height,
-                      const uint8x8_t* const v_tap) {
+void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT const src,
+                      const ptrdiff_t src_stride,
+                      void* LIBGAV1_RESTRICT const dest,
+                      const ptrdiff_t pred_stride, const int width,
+                      const int height, const uint8x8_t* const v_tap) {
   assert(width < 8 || filter_index <= 3);
   // Don't simplify the redundant if conditions with the template parameters,
   // which helps the compiler generate compact code.
@@ -484,7 +491,8 @@
 }
 
 template <int num_taps, bool is_compound = false>
-void Filter2DVerticalWidth8AndUp(const uint16_t* src, void* const dst,
+void Filter2DVerticalWidth8AndUp(const uint16_t* LIBGAV1_RESTRICT src,
+                                 void* LIBGAV1_RESTRICT const dst,
                                  const ptrdiff_t dst_stride, const int width,
                                  const int height, const int16x8_t taps) {
   assert(width >= 8);
@@ -560,7 +568,8 @@
 
 // Take advantage of |src_stride| == |width| to process two rows at a time.
 template <int num_taps, bool is_compound = false>
-void Filter2DVerticalWidth4(const uint16_t* src, void* const dst,
+void Filter2DVerticalWidth4(const uint16_t* LIBGAV1_RESTRICT src,
+                            void* LIBGAV1_RESTRICT const dst,
                             const ptrdiff_t dst_stride, const int height,
                             const int16x8_t taps) {
   auto* dst8 = static_cast<uint8_t*>(dst);
@@ -626,7 +635,8 @@
 
 // Take advantage of |src_stride| == |width| to process four rows at a time.
 template <int num_taps>
-void Filter2DVerticalWidth2(const uint16_t* src, void* const dst,
+void Filter2DVerticalWidth2(const uint16_t* LIBGAV1_RESTRICT src,
+                            void* LIBGAV1_RESTRICT const dst,
                             const ptrdiff_t dst_stride, const int height,
                             const int16x8_t taps) {
   constexpr int next_row = (num_taps < 6) ? 4 : 8;
@@ -699,9 +709,10 @@
 
 template <bool is_2d = false, bool is_compound = false>
 LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
-    const uint8_t* const src, const ptrdiff_t src_stride, void* const dst,
-    const ptrdiff_t dst_stride, const int width, const int height,
-    const int filter_id, const int filter_index) {
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
+    const int width, const int height, const int filter_id,
+    const int filter_index) {
   // Duplicate the absolute value for each tap.  Negative taps are corrected
   // by using the vmlsl_u8 instruction.  Positive taps use vmlal_u8.
   uint8x8_t v_tap[kSubPixelTaps];
@@ -739,9 +750,10 @@
 }
 
 template <int vertical_taps>
-void Filter2DVertical(const uint16_t* const intermediate_result,
-                      const int width, const int height, const int16x8_t taps,
-                      void* const prediction, const ptrdiff_t pred_stride) {
+void Filter2DVertical(
+    const uint16_t* LIBGAV1_RESTRICT const intermediate_result, const int width,
+    const int height, const int16x8_t taps,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
   auto* const dest = static_cast<uint8_t*>(prediction);
   if (width >= 8) {
     Filter2DVerticalWidth8AndUp<vertical_taps>(
@@ -756,13 +768,13 @@
   }
 }
 
-void Convolve2D_NEON(const void* const reference,
+void Convolve2D_NEON(const void* LIBGAV1_RESTRICT const reference,
                      const ptrdiff_t reference_stride,
                      const int horizontal_filter_index,
                      const int vertical_filter_index,
                      const int horizontal_filter_id,
                      const int vertical_filter_id, const int width,
-                     const int height, void* const prediction,
+                     const int height, void* LIBGAV1_RESTRICT const prediction,
                      const ptrdiff_t pred_stride) {
   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
@@ -772,6 +784,10 @@
   uint16_t
       intermediate_result[kMaxSuperBlockSizeInPixels *
                           (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
+#if LIBGAV1_MSAN
+  // Quiet msan warnings. Set with random non-zero value to aid in debugging.
+  memset(intermediate_result, 0x33, sizeof(intermediate_result));
+#endif
   const int intermediate_height = height + vertical_taps - 1;
   const ptrdiff_t src_stride = reference_stride;
   const auto* const src = static_cast<const uint8_t*>(reference) -
@@ -815,6 +831,10 @@
   const uint8x16_t src_val = vld1q_u8(src_x);
   ret.val[0] = vget_low_u8(src_val);
   ret.val[1] = vget_high_u8(src_val);
+#if LIBGAV1_MSAN
+  // Initialize to quiet msan warnings when grade_x <= 1.
+  ret.val[2] = vdup_n_u8(0);
+#endif
   if (grade_x > 1) {
     ret.val[2] = vld1_u8(src_x + 16);
   }
@@ -833,12 +853,10 @@
 }
 
 template <int grade_x>
-inline void ConvolveKernelHorizontal2Tap(const uint8_t* const src,
-                                         const ptrdiff_t src_stride,
-                                         const int width, const int subpixel_x,
-                                         const int step_x,
-                                         const int intermediate_height,
-                                         int16_t* intermediate) {
+inline void ConvolveKernelHorizontal2Tap(
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int width, const int subpixel_x, const int step_x,
+    const int intermediate_height, int16_t* LIBGAV1_RESTRICT intermediate) {
   // Account for the 0-taps that precede the 2 nonzero taps.
   const int kernel_offset = 3;
   const int ref_x = subpixel_x >> kScaleSubPixelBits;
@@ -891,7 +909,6 @@
   do {
     const uint8_t* src_x =
         &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
-    int16_t* intermediate_x = intermediate + x;
     // Only add steps to the 10-bit truncated p to avoid overflow.
     const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
     const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
@@ -917,11 +934,11 @@
           vtbl3_u8(src_vals, src_indices),
           vtbl3_u8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))};
 
-      vst1q_s16(intermediate_x,
+      vst1q_s16(intermediate,
                 vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps),
                              kInterRoundBitsHorizontal - 1));
       src_x += src_stride;
-      intermediate_x += kIntermediateStride;
+      intermediate += kIntermediateStride;
     } while (--y != 0);
     x += 8;
     p += step_x8;
@@ -943,8 +960,9 @@
 
 // This filter is only possible when width <= 4.
 void ConvolveKernelHorizontalPositive4Tap(
-    const uint8_t* const src, const ptrdiff_t src_stride, const int subpixel_x,
-    const int step_x, const int intermediate_height, int16_t* intermediate) {
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int subpixel_x, const int step_x, const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT intermediate) {
   const int kernel_offset = 2;
   const int ref_x = subpixel_x >> kScaleSubPixelBits;
   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
@@ -1010,8 +1028,9 @@
 
 // This filter is only possible when width <= 4.
 inline void ConvolveKernelHorizontalSigned4Tap(
-    const uint8_t* const src, const ptrdiff_t src_stride, const int subpixel_x,
-    const int step_x, const int intermediate_height, int16_t* intermediate) {
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int subpixel_x, const int step_x, const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT intermediate) {
   const int kernel_offset = 2;
   const int ref_x = subpixel_x >> kScaleSubPixelBits;
   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
@@ -1085,9 +1104,10 @@
 // This filter is only possible when width >= 8.
 template <int grade_x>
 inline void ConvolveKernelHorizontalSigned6Tap(
-    const uint8_t* const src, const ptrdiff_t src_stride, const int width,
-    const int subpixel_x, const int step_x, const int intermediate_height,
-    int16_t* const intermediate) {
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int width, const int subpixel_x, const int step_x,
+    const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT const intermediate) {
   const int kernel_offset = 1;
   const uint8x8_t one = vdup_n_u8(1);
   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
@@ -1100,6 +1120,7 @@
   const uint16x8_t index_steps = vmulq_n_u16(
       vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
 
+  int16_t* intermediate_x = intermediate;
   int x = 0;
   int p = subpixel_x;
   do {
@@ -1107,7 +1128,6 @@
     // |trailing_width| can be up to 24.
     const uint8_t* src_x =
         &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
-    int16_t* intermediate_x = intermediate + x;
     // Only add steps to the 10-bit truncated p to avoid overflow.
     const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
     const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
@@ -1178,9 +1198,10 @@
 // This filter is only possible when width >= 8.
 template <int grade_x>
 inline void ConvolveKernelHorizontalMixed6Tap(
-    const uint8_t* const src, const ptrdiff_t src_stride, const int width,
-    const int subpixel_x, const int step_x, const int intermediate_height,
-    int16_t* const intermediate) {
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int width, const int subpixel_x, const int step_x,
+    const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT const intermediate) {
   const int kernel_offset = 1;
   const uint8x8_t one = vdup_n_u8(1);
   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
@@ -1198,12 +1219,12 @@
   const uint16x8_t index_steps = vmulq_n_u16(
       vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
 
+  int16_t* intermediate_x = intermediate;
   int x = 0;
   int p = subpixel_x;
   do {
     const uint8_t* src_x =
         &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
-    int16_t* intermediate_x = intermediate + x;
     // Only add steps to the 10-bit truncated p to avoid overflow.
     const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
     const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
@@ -1272,9 +1293,10 @@
 // This filter is only possible when width >= 8.
 template <int grade_x>
 inline void ConvolveKernelHorizontalSigned8Tap(
-    const uint8_t* const src, const ptrdiff_t src_stride, const int width,
-    const int subpixel_x, const int step_x, const int intermediate_height,
-    int16_t* const intermediate) {
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    const int width, const int subpixel_x, const int step_x,
+    const int intermediate_height,
+    int16_t* LIBGAV1_RESTRICT const intermediate) {
   const uint8x8_t one = vdup_n_u8(1);
   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
   const int ref_x = subpixel_x >> kScaleSubPixelBits;
@@ -1286,11 +1308,12 @@
   }
   const uint16x8_t index_steps = vmulq_n_u16(
       vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
+
+  int16_t* intermediate_x = intermediate;
   int x = 0;
   int p = subpixel_x;
   do {
     const uint8_t* src_x = &src[(p >> kScaleSubPixelBits) - ref_x];
-    int16_t* intermediate_x = intermediate + x;
     // Only add steps to the 10-bit truncated p to avoid overflow.
     const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
     const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
@@ -1336,15 +1359,16 @@
 
 // This function handles blocks of width 2 or 4.
 template <int num_taps, int grade_y, int width, bool is_compound>
-void ConvolveVerticalScale4xH(const int16_t* const src, const int subpixel_y,
-                              const int filter_index, const int step_y,
-                              const int height, void* const dest,
+void ConvolveVerticalScale4xH(const int16_t* LIBGAV1_RESTRICT const src,
+                              const int subpixel_y, const int filter_index,
+                              const int step_y, const int height,
+                              void* LIBGAV1_RESTRICT const dest,
                               const ptrdiff_t dest_stride) {
   constexpr ptrdiff_t src_stride = kIntermediateStride;
   const int16_t* src_y = src;
   // |dest| is 16-bit in compound mode, Pixel otherwise.
-  uint16_t* dest16_y = static_cast<uint16_t*>(dest);
-  uint8_t* dest_y = static_cast<uint8_t*>(dest);
+  auto* dest16_y = static_cast<uint16_t*>(dest);
+  auto* dest_y = static_cast<uint8_t*>(dest);
   int16x4_t s[num_taps + grade_y];
 
   int p = subpixel_y & 1023;
@@ -1408,10 +1432,12 @@
 }
 
 template <int num_taps, int grade_y, bool is_compound>
-inline void ConvolveVerticalScale(const int16_t* const src, const int width,
-                                  const int subpixel_y, const int filter_index,
-                                  const int step_y, const int height,
-                                  void* const dest,
+inline void ConvolveVerticalScale(const int16_t* LIBGAV1_RESTRICT const source,
+                                  const int intermediate_height,
+                                  const int width, const int subpixel_y,
+                                  const int filter_index, const int step_y,
+                                  const int height,
+                                  void* LIBGAV1_RESTRICT const dest,
                                   const ptrdiff_t dest_stride) {
   constexpr ptrdiff_t src_stride = kIntermediateStride;
   // A possible improvement is to use arithmetic to decide how many times to
@@ -1421,11 +1447,11 @@
   // |dest| is 16-bit in compound mode, Pixel otherwise.
   uint16_t* dest16_y;
   uint8_t* dest_y;
+  const int16_t* src = source;
 
   int x = 0;
   do {
-    const int16_t* const src_x = src + x;
-    const int16_t* src_y = src_x;
+    const int16_t* src_y = src;
     dest16_y = static_cast<uint16_t*>(dest) + x;
     dest_y = static_cast<uint8_t*>(dest) + x;
     int p = subpixel_y & 1023;
@@ -1466,38 +1492,43 @@
         vst1_u8(dest_y, vqmovun_s16(sum));
       }
       p += step_y;
-      src_y = src_x + (p >> kScaleSubPixelBits) * src_stride;
+      src_y = src + (p >> kScaleSubPixelBits) * src_stride;
       prev_p = p;
       dest16_y += dest_stride;
       dest_y += dest_stride;
       y -= 2;
     } while (y != 0);
+    src += kIntermediateStride * intermediate_height;
     x += 8;
   } while (x < width);
 }
 
 template <bool is_compound>
-void ConvolveScale2D_NEON(const void* const reference,
+void ConvolveScale2D_NEON(const void* LIBGAV1_RESTRICT const reference,
                           const ptrdiff_t reference_stride,
                           const int horizontal_filter_index,
                           const int vertical_filter_index, const int subpixel_x,
                           const int subpixel_y, const int step_x,
                           const int step_y, const int width, const int height,
-                          void* const prediction, const ptrdiff_t pred_stride) {
+                          void* LIBGAV1_RESTRICT const prediction,
+                          const ptrdiff_t pred_stride) {
   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
   assert(step_x <= 2048);
+  assert(step_y <= 2048);
   const int num_vert_taps = GetNumTapsInFilter(vert_filter_index);
   const int intermediate_height =
       (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
        kScaleSubPixelBits) +
       num_vert_taps;
-  assert(step_x <= 2048);
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
-  int16_t intermediate_result[kMaxSuperBlockSizeInPixels *
-                              (2 * kMaxSuperBlockSizeInPixels + 8)];
-
+  int16_t intermediate_result[kIntermediateAllocWidth *
+                              (2 * kIntermediateAllocWidth + 8)];
+#if LIBGAV1_MSAN
+  // Quiet msan warnings. Set with random non-zero value to aid in debugging.
+  memset(intermediate_result, 0x44, sizeof(intermediate_result));
+#endif
   // Horizontal filter.
   // Filter types used for width <= 4 are different from those for width > 4.
   // When width > 4, the valid filter index range is always [0, 3].
@@ -1597,8 +1628,8 @@
               prediction, pred_stride);
         } else {
           ConvolveVerticalScale<6, 1, is_compound>(
-              intermediate, width, subpixel_y, filter_index, step_y, height,
-              prediction, pred_stride);
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
         }
       } else {
         if (!is_compound && width == 2) {
@@ -1611,8 +1642,8 @@
               prediction, pred_stride);
         } else {
           ConvolveVerticalScale<6, 2, is_compound>(
-              intermediate, width, subpixel_y, filter_index, step_y, height,
-              prediction, pred_stride);
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
         }
       }
       break;
@@ -1628,8 +1659,8 @@
               prediction, pred_stride);
         } else {
           ConvolveVerticalScale<8, 1, is_compound>(
-              intermediate, width, subpixel_y, filter_index, step_y, height,
-              prediction, pred_stride);
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
         }
       } else {
         if (!is_compound && width == 2) {
@@ -1642,8 +1673,8 @@
               prediction, pred_stride);
         } else {
           ConvolveVerticalScale<8, 2, is_compound>(
-              intermediate, width, subpixel_y, filter_index, step_y, height,
-              prediction, pred_stride);
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
         }
       }
       break;
@@ -1659,8 +1690,8 @@
               prediction, pred_stride);
         } else {
           ConvolveVerticalScale<2, 1, is_compound>(
-              intermediate, width, subpixel_y, filter_index, step_y, height,
-              prediction, pred_stride);
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
         }
       } else {
         if (!is_compound && width == 2) {
@@ -1673,8 +1704,8 @@
               prediction, pred_stride);
         } else {
           ConvolveVerticalScale<2, 2, is_compound>(
-              intermediate, width, subpixel_y, filter_index, step_y, height,
-              prediction, pred_stride);
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
         }
       }
       break;
@@ -1693,8 +1724,8 @@
               prediction, pred_stride);
         } else {
           ConvolveVerticalScale<4, 1, is_compound>(
-              intermediate, width, subpixel_y, filter_index, step_y, height,
-              prediction, pred_stride);
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
         }
       } else {
         if (!is_compound && width == 2) {
@@ -1707,21 +1738,19 @@
               prediction, pred_stride);
         } else {
           ConvolveVerticalScale<4, 2, is_compound>(
-              intermediate, width, subpixel_y, filter_index, step_y, height,
-              prediction, pred_stride);
+              intermediate, intermediate_height, width, subpixel_y,
+              filter_index, step_y, height, prediction, pred_stride);
         }
       }
   }
 }
 
-void ConvolveHorizontal_NEON(const void* const reference,
-                             const ptrdiff_t reference_stride,
-                             const int horizontal_filter_index,
-                             const int /*vertical_filter_index*/,
-                             const int horizontal_filter_id,
-                             const int /*vertical_filter_id*/, const int width,
-                             const int height, void* const prediction,
-                             const ptrdiff_t pred_stride) {
+void ConvolveHorizontal_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int /*vertical_filter_index*/, const int horizontal_filter_id,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   // Set |src| to the outermost tap.
   const auto* const src =
@@ -1741,10 +1770,11 @@
 
 template <int filter_index, bool is_compound = false,
           bool negative_outside_taps = false>
-void FilterVertical(const uint8_t* const src, const ptrdiff_t src_stride,
-                    void* const dst, const ptrdiff_t dst_stride,
-                    const int width, const int height,
-                    const uint8x8_t* const taps) {
+void FilterVertical(const uint8_t* LIBGAV1_RESTRICT const src,
+                    const ptrdiff_t src_stride,
+                    void* LIBGAV1_RESTRICT const dst,
+                    const ptrdiff_t dst_stride, const int width,
+                    const int height, const uint8x8_t* const taps) {
   const int num_taps = GetNumTapsInFilter(filter_index);
   const int next_row = num_taps - 1;
   auto* const dst8 = static_cast<uint8_t*>(dst);
@@ -1814,9 +1844,11 @@
 
 template <int filter_index, bool is_compound = false,
           bool negative_outside_taps = false>
-void FilterVertical4xH(const uint8_t* src, const ptrdiff_t src_stride,
-                       void* const dst, const ptrdiff_t dst_stride,
-                       const int height, const uint8x8_t* const taps) {
+void FilterVertical4xH(const uint8_t* LIBGAV1_RESTRICT src,
+                       const ptrdiff_t src_stride,
+                       void* LIBGAV1_RESTRICT const dst,
+                       const ptrdiff_t dst_stride, const int height,
+                       const uint8x8_t* const taps) {
   const int num_taps = GetNumTapsInFilter(filter_index);
   auto* dst8 = static_cast<uint8_t*>(dst);
   auto* dst16 = static_cast<uint16_t*>(dst);
@@ -2001,9 +2033,11 @@
 }
 
 template <int filter_index, bool negative_outside_taps = false>
-void FilterVertical2xH(const uint8_t* src, const ptrdiff_t src_stride,
-                       void* const dst, const ptrdiff_t dst_stride,
-                       const int height, const uint8x8_t* const taps) {
+void FilterVertical2xH(const uint8_t* LIBGAV1_RESTRICT src,
+                       const ptrdiff_t src_stride,
+                       void* LIBGAV1_RESTRICT const dst,
+                       const ptrdiff_t dst_stride, const int height,
+                       const uint8x8_t* const taps) {
   const int num_taps = GetNumTapsInFilter(filter_index);
   auto* dst8 = static_cast<uint8_t*>(dst);
 
@@ -2205,14 +2239,12 @@
 // filtering is required.
 // The output is the single prediction of the block, clipped to valid pixel
 // range.
-void ConvolveVertical_NEON(const void* const reference,
-                           const ptrdiff_t reference_stride,
-                           const int /*horizontal_filter_index*/,
-                           const int vertical_filter_index,
-                           const int /*horizontal_filter_id*/,
-                           const int vertical_filter_id, const int width,
-                           const int height, void* const prediction,
-                           const ptrdiff_t pred_stride) {
+void ConvolveVertical_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
   const int vertical_taps = GetNumTapsInFilter(filter_index);
   const ptrdiff_t src_stride = reference_stride;
@@ -2239,8 +2271,9 @@
       FilterVertical<0>(src, src_stride, dest, dest_stride, width, height,
                         taps + 1);
     }
-  } else if ((filter_index == 1) & ((vertical_filter_id == 1) |
-                                    (vertical_filter_id == 15))) {  // 5 tap.
+  } else if ((static_cast<int>(filter_index == 1) &
+              (static_cast<int>(vertical_filter_id == 1) |
+               static_cast<int>(vertical_filter_id == 15))) != 0) {  // 5 tap.
     if (width == 2) {
       FilterVertical2xH<1>(src, src_stride, dest, dest_stride, height,
                            taps + 1);
@@ -2251,9 +2284,11 @@
       FilterVertical<1>(src, src_stride, dest, dest_stride, width, height,
                         taps + 1);
     }
-  } else if ((filter_index == 1) &
-             ((vertical_filter_id == 7) | (vertical_filter_id == 8) |
-              (vertical_filter_id == 9))) {  // 6 tap with weird negative taps.
+  } else if ((static_cast<int>(filter_index == 1) &
+              (static_cast<int>(vertical_filter_id == 7) |
+               static_cast<int>(vertical_filter_id == 8) |
+               static_cast<int>(vertical_filter_id == 9))) !=
+             0) {  // 6 tap with weird negative taps.
     if (width == 2) {
       FilterVertical2xH<1,
                         /*negative_outside_taps=*/true>(
@@ -2325,11 +2360,11 @@
 }
 
 void ConvolveCompoundCopy_NEON(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
-    const int width, const int height, void* const prediction,
-    const ptrdiff_t /*pred_stride*/) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
   const auto* src = static_cast<const uint8_t*>(reference);
   const ptrdiff_t src_stride = reference_stride;
   auto* dest = static_cast<uint16_t*>(prediction);
@@ -2381,11 +2416,11 @@
 }
 
 void ConvolveCompoundVertical_NEON(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int vertical_filter_index,
-    const int /*horizontal_filter_id*/, const int vertical_filter_id,
-    const int width, const int height, void* const prediction,
-    const ptrdiff_t /*pred_stride*/) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
   const int vertical_taps = GetNumTapsInFilter(filter_index);
   const ptrdiff_t src_stride = reference_stride;
@@ -2408,8 +2443,9 @@
       FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width,
                                               width, height, taps + 1);
     }
-  } else if ((filter_index == 1) & ((vertical_filter_id == 1) |
-                                    (vertical_filter_id == 15))) {  // 5 tap.
+  } else if ((static_cast<int>(filter_index == 1) &
+              (static_cast<int>(vertical_filter_id == 1) |
+               static_cast<int>(vertical_filter_id == 15))) != 0) {  // 5 tap.
     if (width == 4) {
       FilterVertical4xH<1, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                  height, taps + 1);
@@ -2417,9 +2453,11 @@
       FilterVertical<1, /*is_compound=*/true>(src, src_stride, dest, width,
                                               width, height, taps + 1);
     }
-  } else if ((filter_index == 1) &
-             ((vertical_filter_id == 7) | (vertical_filter_id == 8) |
-              (vertical_filter_id == 9))) {  // 6 tap with weird negative taps.
+  } else if ((static_cast<int>(filter_index == 1) &
+              (static_cast<int>(vertical_filter_id == 7) |
+               static_cast<int>(vertical_filter_id == 8) |
+               static_cast<int>(vertical_filter_id == 9))) !=
+             0) {  // 6 tap with weird negative taps.
     if (width == 4) {
       FilterVertical4xH<1, /*is_compound=*/true,
                         /*negative_outside_taps=*/true>(src, src_stride, dest,
@@ -2476,11 +2514,11 @@
 }
 
 void ConvolveCompoundHorizontal_NEON(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int horizontal_filter_index, const int /*vertical_filter_index*/,
-    const int horizontal_filter_id, const int /*vertical_filter_id*/,
-    const int width, const int height, void* const prediction,
-    const ptrdiff_t /*pred_stride*/) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int /*vertical_filter_index*/, const int horizontal_filter_id,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   const auto* const src =
       static_cast<const uint8_t*>(reference) - kHorizontalOffset;
@@ -2492,9 +2530,10 @@
 }
 
 template <int vertical_taps>
-void Compound2DVertical(const uint16_t* const intermediate_result,
-                        const int width, const int height, const int16x8_t taps,
-                        void* const prediction) {
+void Compound2DVertical(
+    const uint16_t* LIBGAV1_RESTRICT const intermediate_result, const int width,
+    const int height, const int16x8_t taps,
+    void* LIBGAV1_RESTRICT const prediction) {
   auto* const dest = static_cast<uint16_t*>(prediction);
   if (width == 4) {
     Filter2DVerticalWidth4<vertical_taps, /*is_compound=*/true>(
@@ -2505,14 +2544,12 @@
   }
 }
 
-void ConvolveCompound2D_NEON(const void* const reference,
-                             const ptrdiff_t reference_stride,
-                             const int horizontal_filter_index,
-                             const int vertical_filter_index,
-                             const int horizontal_filter_id,
-                             const int vertical_filter_id, const int width,
-                             const int height, void* const prediction,
-                             const ptrdiff_t /*pred_stride*/) {
+void ConvolveCompound2D_NEON(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int vertical_filter_index, const int horizontal_filter_id,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
   uint16_t
@@ -2551,16 +2588,18 @@
   }
 }
 
-inline void HalfAddHorizontal(const uint8_t* const src, uint8_t* const dst) {
+inline void HalfAddHorizontal(const uint8_t* LIBGAV1_RESTRICT const src,
+                              uint8_t* LIBGAV1_RESTRICT const dst) {
   const uint8x16_t left = vld1q_u8(src);
   const uint8x16_t right = vld1q_u8(src + 1);
   vst1q_u8(dst, vrhaddq_u8(left, right));
 }
 
 template <int width>
-inline void IntraBlockCopyHorizontal(const uint8_t* src,
+inline void IntraBlockCopyHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
                                      const ptrdiff_t src_stride,
-                                     const int height, uint8_t* dst,
+                                     const int height,
+                                     uint8_t* LIBGAV1_RESTRICT dst,
                                      const ptrdiff_t dst_stride) {
   const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
@@ -2601,10 +2640,13 @@
 }
 
 void ConvolveIntraBlockCopyHorizontal_NEON(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
-    const int height, void* const prediction, const ptrdiff_t pred_stride) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*subpixel_x*/,
+    const int /*subpixel_y*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+  assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
+  assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
   const auto* src = static_cast<const uint8_t*>(reference);
   auto* dest = static_cast<uint8_t*>(prediction);
 
@@ -2630,7 +2672,7 @@
       src += reference_stride;
       dest += pred_stride;
     } while (--y != 0);
-  } else if (width == 4) {
+  } else {  // width == 4
     uint8x8_t left = vdup_n_u8(0);
     uint8x8_t right = vdup_n_u8(0);
     int y = height;
@@ -2650,34 +2692,14 @@
       dest += pred_stride;
       y -= 2;
     } while (y != 0);
-  } else {
-    assert(width == 2);
-    uint8x8_t left = vdup_n_u8(0);
-    uint8x8_t right = vdup_n_u8(0);
-    int y = height;
-    do {
-      left = Load2<0>(src, left);
-      right = Load2<0>(src + 1, right);
-      src += reference_stride;
-      left = Load2<1>(src, left);
-      right = Load2<1>(src + 1, right);
-      src += reference_stride;
-
-      const uint8x8_t result = vrhadd_u8(left, right);
-
-      Store2<0>(dest, result);
-      dest += pred_stride;
-      Store2<1>(dest, result);
-      dest += pred_stride;
-      y -= 2;
-    } while (y != 0);
   }
 }
 
 template <int width>
-inline void IntraBlockCopyVertical(const uint8_t* src,
+inline void IntraBlockCopyVertical(const uint8_t* LIBGAV1_RESTRICT src,
                                    const ptrdiff_t src_stride, const int height,
-                                   uint8_t* dst, const ptrdiff_t dst_stride) {
+                                   uint8_t* LIBGAV1_RESTRICT dst,
+                                   const ptrdiff_t dst_stride) {
   const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
   uint8x16_t row[8], below[8];
@@ -2764,11 +2786,13 @@
 }
 
 void ConvolveIntraBlockCopyVertical_NEON(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
-    const int width, const int height, void* const prediction,
-    const ptrdiff_t pred_stride) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+  assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
+  assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
   const auto* src = static_cast<const uint8_t*>(reference);
   auto* dest = static_cast<uint8_t*>(prediction);
 
@@ -2799,7 +2823,7 @@
 
       row = below;
     } while (--y != 0);
-  } else if (width == 4) {
+  } else {  // width == 4
     uint8x8_t row = Load4(src);
     uint8x8_t below = vdup_n_u8(0);
     src += reference_stride;
@@ -2814,28 +2838,13 @@
 
       row = below;
     } while (--y != 0);
-  } else {
-    assert(width == 2);
-    uint8x8_t row = Load2(src);
-    uint8x8_t below = vdup_n_u8(0);
-    src += reference_stride;
-
-    int y = height;
-    do {
-      below = Load2<0>(src, below);
-      src += reference_stride;
-
-      Store2<0>(dest, vrhadd_u8(row, below));
-      dest += pred_stride;
-
-      row = below;
-    } while (--y != 0);
   }
 }
 
 template <int width>
-inline void IntraBlockCopy2D(const uint8_t* src, const ptrdiff_t src_stride,
-                             const int height, uint8_t* dst,
+inline void IntraBlockCopy2D(const uint8_t* LIBGAV1_RESTRICT src,
+                             const ptrdiff_t src_stride, const int height,
+                             uint8_t* LIBGAV1_RESTRICT dst,
                              const ptrdiff_t dst_stride) {
   const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
@@ -2996,11 +3005,13 @@
 }
 
 void ConvolveIntraBlockCopy2D_NEON(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
-    const int width, const int height, void* const prediction,
-    const ptrdiff_t pred_stride) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
+  assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
+  assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
   const auto* src = static_cast<const uint8_t*>(reference);
   auto* dest = static_cast<uint8_t*>(prediction);
   // Note: allow vertical access to height + 1. Because this function is only
@@ -3017,7 +3028,7 @@
     IntraBlockCopy2D<16>(src, reference_stride, height, dest, pred_stride);
   } else if (width == 8) {
     IntraBlockCopy2D<8>(src, reference_stride, height, dest, pred_stride);
-  } else if (width == 4) {
+  } else {  // width == 4
     uint8x8_t left = Load4(src);
     uint8x8_t right = Load4(src + 1);
     src += reference_stride;
@@ -3045,34 +3056,6 @@
       row = vget_high_u16(below);
       y -= 2;
     } while (y != 0);
-  } else {
-    uint8x8_t left = Load2(src);
-    uint8x8_t right = Load2(src + 1);
-    src += reference_stride;
-
-    uint16x4_t row = vget_low_u16(vaddl_u8(left, right));
-
-    int y = height;
-    do {
-      left = Load2<0>(src, left);
-      right = Load2<0>(src + 1, right);
-      src += reference_stride;
-      left = Load2<2>(src, left);
-      right = Load2<2>(src + 1, right);
-      src += reference_stride;
-
-      const uint16x8_t below = vaddl_u8(left, right);
-
-      const uint8x8_t result = vrshrn_n_u16(
-          vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2);
-      Store2<0>(dest, result);
-      dest += pred_stride;
-      Store2<2>(dest, result);
-      dest += pred_stride;
-
-      row = vget_high_u16(below);
-      y -= 2;
-    } while (y != 0);
   }
 }
 
diff --git a/libgav1/src/dsp/arm/convolve_neon.h b/libgav1/src/dsp/arm/convolve_neon.h
index 948ef4d..9c67bc9 100644
--- a/libgav1/src/dsp/arm/convolve_neon.h
+++ b/libgav1/src/dsp/arm/convolve_neon.h
@@ -25,6 +25,7 @@
 
 // Initializes Dsp::convolve. This function is not thread-safe.
 void ConvolveInit_NEON();
+void ConvolveInit10bpp_NEON();
 
 }  // namespace dsp
 }  // namespace libgav1
@@ -45,6 +46,22 @@
 
 #define LIBGAV1_Dsp8bpp_ConvolveScale2D LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_ConvolveCompoundScale2D LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_ConvolveHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Convolve2D LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_ConvolveCompoundCopy LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveCompoundHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveCompoundVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveCompound2D LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_ConvolveIntraBlockCopyHorizontal LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveIntraBlockCopyVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveIntraBlockCopy2D LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_ConvolveScale2D LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ConvolveCompoundScale2D LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_CONVOLVE_NEON_H_
diff --git a/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc b/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
index a0cd0ac..7d287c8 100644
--- a/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
+++ b/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc
@@ -52,11 +52,10 @@
 }
 
 template <int width, int height>
-inline void DistanceWeightedBlendSmall_NEON(const int16_t* prediction_0,
-                                            const int16_t* prediction_1,
-                                            const int16x4_t weights[2],
-                                            void* const dest,
-                                            const ptrdiff_t dest_stride) {
+inline void DistanceWeightedBlendSmall_NEON(
+    const int16_t* LIBGAV1_RESTRICT prediction_0,
+    const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x4_t weights[2],
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
   constexpr int step = 16 / width;
 
@@ -94,12 +93,11 @@
   }
 }
 
-inline void DistanceWeightedBlendLarge_NEON(const int16_t* prediction_0,
-                                            const int16_t* prediction_1,
-                                            const int16x4_t weights[2],
-                                            const int width, const int height,
-                                            void* const dest,
-                                            const ptrdiff_t dest_stride) {
+inline void DistanceWeightedBlendLarge_NEON(
+    const int16_t* LIBGAV1_RESTRICT prediction_0,
+    const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x4_t weights[2],
+    const int width, const int height, void* LIBGAV1_RESTRICT const dest,
+    const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
 
   int y = height;
@@ -127,12 +125,11 @@
   } while (--y != 0);
 }
 
-inline void DistanceWeightedBlend_NEON(const void* prediction_0,
-                                       const void* prediction_1,
-                                       const uint8_t weight_0,
-                                       const uint8_t weight_1, const int width,
-                                       const int height, void* const dest,
-                                       const ptrdiff_t dest_stride) {
+inline void DistanceWeightedBlend_NEON(
+    const void* LIBGAV1_RESTRICT prediction_0,
+    const void* LIBGAV1_RESTRICT prediction_1, const uint8_t weight_0,
+    const uint8_t weight_1, const int width, const int height,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int16x4_t weights[2] = {vdup_n_s16(weight_0), vdup_n_s16(weight_1)};
@@ -267,11 +264,12 @@
   return x;
 }
 
-void DistanceWeightedBlend_NEON(const void* prediction_0,
-                                const void* prediction_1,
+void DistanceWeightedBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
                                 const uint8_t weight_0, const uint8_t weight_1,
                                 const int width, const int height,
-                                void* const dest, const ptrdiff_t dest_stride) {
+                                void* LIBGAV1_RESTRICT const dest,
+                                const ptrdiff_t dest_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
   auto* dst = static_cast<uint16_t*>(dest);
diff --git a/libgav1/src/dsp/arm/film_grain_neon.cc b/libgav1/src/dsp/arm/film_grain_neon.cc
index 8ee3745..0b1b481 100644
--- a/libgav1/src/dsp/arm/film_grain_neon.cc
+++ b/libgav1/src/dsp/arm/film_grain_neon.cc
@@ -34,6 +34,7 @@
 #include "src/utils/common.h"
 #include "src/utils/compiler_attributes.h"
 #include "src/utils/logging.h"
+#include "src/utils/memory.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -51,6 +52,12 @@
   return ZeroExtend(vld1_u8(src));
 }
 
+inline int16x8_t GetSignedSource8Msan(const uint8_t* src, int /*valid_range*/) {
+  // TODO(b/194217060): restore |valid_range| usage after correcting call sites
+  // causing test vector failures.
+  return ZeroExtend(Load1MsanU8(src, 0));
+}
+
 inline void StoreUnsigned8(uint8_t* dest, const uint16x8_t data) {
   vst1_u8(dest, vmovn_u16(data));
 }
@@ -62,6 +69,13 @@
   return vreinterpretq_s16_u16(vld1q_u16(src));
 }
 
+inline int16x8_t GetSignedSource8Msan(const uint16_t* src,
+                                      int /*valid_range*/) {
+  // TODO(b/194217060): restore |valid_range| usage after correcting call sites
+  // causing test vector failures.
+  return vreinterpretq_s16_u16(Load1QMsanU16(src, 0));
+}
+
 inline void StoreUnsigned8(uint16_t* dest, const uint16x8_t data) {
   vst1q_u16(dest, data);
 }
@@ -84,8 +98,10 @@
 // compute pixels that come after in the row, we have to finish the calculations
 // one at a time.
 template <int bitdepth, int auto_regression_coeff_lag, int lane>
-inline void WriteFinalAutoRegression(int8_t* grain_cursor, int32x4x2_t sum,
-                                     const int8_t* coeffs, int pos, int shift) {
+inline void WriteFinalAutoRegression(int8_t* LIBGAV1_RESTRICT grain_cursor,
+                                     int32x4x2_t sum,
+                                     const int8_t* LIBGAV1_RESTRICT coeffs,
+                                     int pos, int shift) {
   int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3);
 
   for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) {
@@ -99,8 +115,10 @@
 
 #if LIBGAV1_MAX_BITDEPTH >= 10
 template <int bitdepth, int auto_regression_coeff_lag, int lane>
-inline void WriteFinalAutoRegression(int16_t* grain_cursor, int32x4x2_t sum,
-                                     const int8_t* coeffs, int pos, int shift) {
+inline void WriteFinalAutoRegression(int16_t* LIBGAV1_RESTRICT grain_cursor,
+                                     int32x4x2_t sum,
+                                     const int8_t* LIBGAV1_RESTRICT coeffs,
+                                     int pos, int shift) {
   int32_t result = vgetq_lane_s32(sum.val[lane >> 2], lane & 3);
 
   for (int delta_col = -auto_regression_coeff_lag; delta_col < 0; ++delta_col) {
@@ -117,12 +135,11 @@
 // compute pixels that come after in the row, we have to finish the calculations
 // one at a time.
 template <int bitdepth, int auto_regression_coeff_lag, int lane>
-inline void WriteFinalAutoRegressionChroma(int8_t* u_grain_cursor,
-                                           int8_t* v_grain_cursor,
-                                           int32x4x2_t sum_u, int32x4x2_t sum_v,
-                                           const int8_t* coeffs_u,
-                                           const int8_t* coeffs_v, int pos,
-                                           int shift) {
+inline void WriteFinalAutoRegressionChroma(
+    int8_t* LIBGAV1_RESTRICT u_grain_cursor,
+    int8_t* LIBGAV1_RESTRICT v_grain_cursor, int32x4x2_t sum_u,
+    int32x4x2_t sum_v, const int8_t* LIBGAV1_RESTRICT coeffs_u,
+    const int8_t* LIBGAV1_RESTRICT coeffs_v, int pos, int shift) {
   WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
       u_grain_cursor, sum_u, coeffs_u, pos, shift);
   WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
@@ -131,12 +148,11 @@
 
 #if LIBGAV1_MAX_BITDEPTH >= 10
 template <int bitdepth, int auto_regression_coeff_lag, int lane>
-inline void WriteFinalAutoRegressionChroma(int16_t* u_grain_cursor,
-                                           int16_t* v_grain_cursor,
-                                           int32x4x2_t sum_u, int32x4x2_t sum_v,
-                                           const int8_t* coeffs_u,
-                                           const int8_t* coeffs_v, int pos,
-                                           int shift) {
+inline void WriteFinalAutoRegressionChroma(
+    int16_t* LIBGAV1_RESTRICT u_grain_cursor,
+    int16_t* LIBGAV1_RESTRICT v_grain_cursor, int32x4x2_t sum_u,
+    int32x4x2_t sum_v, const int8_t* LIBGAV1_RESTRICT coeffs_u,
+    const int8_t* LIBGAV1_RESTRICT coeffs_v, int pos, int shift) {
   WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
       u_grain_cursor, sum_u, coeffs_u, pos, shift);
   WriteFinalAutoRegression<bitdepth, auto_regression_coeff_lag, lane>(
@@ -181,6 +197,20 @@
   return vmovl_u8(vld1_u8(luma));
 }
 
+inline uint16x8_t GetAverageLumaMsan(const uint8_t* const luma,
+                                     int subsampling_x, int /*valid_range*/) {
+  if (subsampling_x != 0) {
+    // TODO(b/194217060): restore |valid_range| usage after correcting call
+    // sites causing test vector failures.
+    const uint8x16_t src = Load1QMsanU8(luma, 0);
+
+    return vrshrq_n_u16(vpaddlq_u8(src), 1);
+  }
+  // TODO(b/194217060): restore |valid_range| usage after correcting call sites
+  // causing test vector failures.
+  return vmovl_u8(Load1MsanU8(luma, 0));
+}
+
 #if LIBGAV1_MAX_BITDEPTH >= 10
 // Computes subsampled luma for use with chroma, by averaging in the x direction
 // or y direction when applicable.
@@ -220,16 +250,28 @@
   }
   return vld1q_u16(luma);
 }
+
+inline uint16x8_t GetAverageLumaMsan(const uint16_t* const luma,
+                                     int subsampling_x, int /*valid_range*/) {
+  if (subsampling_x != 0) {
+    // TODO(b/194217060): restore |valid_range| usage after correcting call
+    // sites causing test vector failures.
+    const uint16x8x2_t src = Load2QMsanU16(luma, 0);
+    return vrhaddq_u16(src.val[0], src.val[1]);
+  }
+  // TODO(b/194217060): restore |valid_range| usage after correcting call sites
+  // causing test vector failures.
+  return Load1QMsanU16(luma, 0);
+}
 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
 
 template <int bitdepth, typename GrainType, int auto_regression_coeff_lag,
           bool use_luma>
-void ApplyAutoRegressiveFilterToChromaGrains_NEON(const FilmGrainParams& params,
-                                                  const void* luma_grain_buffer,
-                                                  int subsampling_x,
-                                                  int subsampling_y,
-                                                  void* u_grain_buffer,
-                                                  void* v_grain_buffer) {
+void ApplyAutoRegressiveFilterToChromaGrains_NEON(
+    const FilmGrainParams& params,
+    const void* LIBGAV1_RESTRICT luma_grain_buffer, int subsampling_x,
+    int subsampling_y, void* LIBGAV1_RESTRICT u_grain_buffer,
+    void* LIBGAV1_RESTRICT v_grain_buffer) {
   static_assert(auto_regression_coeff_lag <= 3, "Invalid autoregression lag.");
   const auto* luma_grain = static_cast<const GrainType*>(luma_grain_buffer);
   auto* u_grain = static_cast<GrainType*>(u_grain_buffer);
@@ -558,49 +600,93 @@
 #undef ACCUMULATE_WEIGHTED_GRAIN
 }
 
-void InitializeScalingLookupTable_NEON(
-    int num_points, const uint8_t point_value[], const uint8_t point_scaling[],
-    uint8_t scaling_lut[kScalingLookupTableSize]) {
+template <int bitdepth>
+void InitializeScalingLookupTable_NEON(int num_points,
+                                       const uint8_t point_value[],
+                                       const uint8_t point_scaling[],
+                                       int16_t* scaling_lut,
+                                       const int scaling_lut_length) {
+  static_assert(bitdepth < kBitdepth12,
+                "NEON Scaling lookup table only supports 8bpp and 10bpp.");
   if (num_points == 0) {
-    memset(scaling_lut, 0, sizeof(scaling_lut[0]) * kScalingLookupTableSize);
+    memset(scaling_lut, 0, sizeof(scaling_lut[0]) * scaling_lut_length);
     return;
   }
-  static_assert(sizeof(scaling_lut[0]) == 1, "");
-  memset(scaling_lut, point_scaling[0], point_value[0]);
-  const uint32x4_t steps = vmovl_u16(vcreate_u16(0x0003000200010000));
-  const uint32x4_t offset = vdupq_n_u32(32768);
+  static_assert(sizeof(scaling_lut[0]) == 2, "");
+  Memset(scaling_lut, point_scaling[0],
+         std::max(static_cast<int>(point_value[0]), 1)
+             << (bitdepth - kBitdepth8));
+  const int32x4_t steps = vmovl_s16(vcreate_s16(0x0003000200010000));
+  const int32x4_t rounding = vdupq_n_s32(32768);
   for (int i = 0; i < num_points - 1; ++i) {
     const int delta_y = point_scaling[i + 1] - point_scaling[i];
     const int delta_x = point_value[i + 1] - point_value[i];
+    // |delta| corresponds to b, for the function y = a + b*x.
     const int delta = delta_y * ((65536 + (delta_x >> 1)) / delta_x);
     const int delta4 = delta << 2;
-    const uint8x8_t base_point = vdup_n_u8(point_scaling[i]);
-    uint32x4_t upscaled_points0 = vmlaq_n_u32(offset, steps, delta);
-    const uint32x4_t line_increment4 = vdupq_n_u32(delta4);
+    // vmull_n_u16 will not work here because |delta| typically exceeds the
+    // range of uint16_t.
+    int32x4_t upscaled_points0 = vmlaq_n_s32(rounding, steps, delta);
+    const int32x4_t line_increment4 = vdupq_n_s32(delta4);
     // Get the second set of 4 points by adding 4 steps to the first set.
-    uint32x4_t upscaled_points1 = vaddq_u32(upscaled_points0, line_increment4);
+    int32x4_t upscaled_points1 = vaddq_s32(upscaled_points0, line_increment4);
     // We obtain the next set of 8 points by adding 8 steps to each of the
     // current 8 points.
-    const uint32x4_t line_increment8 = vshlq_n_u32(line_increment4, 1);
+    const int32x4_t line_increment8 = vshlq_n_s32(line_increment4, 1);
+    const int16x8_t base_point = vdupq_n_s16(point_scaling[i]);
     int x = 0;
+    // Derive and write 8 values (or 32 values, for 10bpp).
     do {
-      const uint16x4_t interp_points0 = vshrn_n_u32(upscaled_points0, 16);
-      const uint16x4_t interp_points1 = vshrn_n_u32(upscaled_points1, 16);
-      const uint8x8_t interp_points =
-          vmovn_u16(vcombine_u16(interp_points0, interp_points1));
+      const int16x4_t interp_points0 = vshrn_n_s32(upscaled_points0, 16);
+      const int16x4_t interp_points1 = vshrn_n_s32(upscaled_points1, 16);
+      const int16x8_t interp_points =
+          vcombine_s16(interp_points0, interp_points1);
       // The spec guarantees that the max value of |point_value[i]| + x is 255.
-      // Writing 8 bytes starting at the final table byte, leaves 7 bytes of
+      // Writing 8 values starting at the final table entry leaves 7 values of
       // required padding.
-      vst1_u8(&scaling_lut[point_value[i] + x],
-              vadd_u8(interp_points, base_point));
-      upscaled_points0 = vaddq_u32(upscaled_points0, line_increment8);
-      upscaled_points1 = vaddq_u32(upscaled_points1, line_increment8);
+      const int16x8_t full_interp = vaddq_s16(interp_points, base_point);
+      const int x_base = (point_value[i] + x) << (bitdepth - kBitdepth8);
+      if (bitdepth == kBitdepth10) {
+        const int16x8_t next_val = vaddq_s16(
+            base_point,
+            vdupq_n_s16((vgetq_lane_s32(upscaled_points1, 3) + delta) >> 16));
+        const int16x8_t start = full_interp;
+        const int16x8_t end = vextq_s16(full_interp, next_val, 1);
+        // lut[k << 2] = start;
+        // lut[(k << 2) + 1] = start + RightShiftWithRounding(end - start, 2);
+        // lut[(k << 2) + 2] = start +
+        //                     RightShiftWithRounding(2 * (end - start), 2);
+        // lut[(k << 2) + 3] = start +
+        //                     RightShiftWithRounding(3 * (end - start), 2);
+        const int16x8_t delta = vsubq_s16(end, start);
+        const int16x8_t double_delta = vshlq_n_s16(delta, 1);
+        const int16x8_t delta2 = vrshrq_n_s16(double_delta, 2);
+        const int16x8_t delta3 =
+            vrshrq_n_s16(vaddq_s16(delta, double_delta), 2);
+        const int16x8x4_t result = {
+            start, vaddq_s16(start, vrshrq_n_s16(delta, 2)),
+            vaddq_s16(start, delta2), vaddq_s16(start, delta3)};
+        vst4q_s16(&scaling_lut[x_base], result);
+      } else {
+        vst1q_s16(&scaling_lut[x_base], full_interp);
+      }
+      upscaled_points0 = vaddq_s32(upscaled_points0, line_increment8);
+      upscaled_points1 = vaddq_s32(upscaled_points1, line_increment8);
       x += 8;
     } while (x < delta_x);
   }
-  const uint8_t last_point_value = point_value[num_points - 1];
-  memset(&scaling_lut[last_point_value], point_scaling[num_points - 1],
-         kScalingLookupTableSize - last_point_value);
+  const int16_t last_point_value = point_value[num_points - 1];
+  const int x_base = last_point_value << (bitdepth - kBitdepth8);
+  Memset(&scaling_lut[x_base], point_scaling[num_points - 1],
+         scaling_lut_length - x_base);
+  if (bitdepth == kBitdepth10 && x_base > 0) {
+    const int start = scaling_lut[x_base - 4];
+    const int end = point_scaling[num_points - 1];
+    const int delta = end - start;
+    scaling_lut[x_base - 3] = start + RightShiftWithRounding(delta, 2);
+    scaling_lut[x_base - 2] = start + RightShiftWithRounding(2 * delta, 2);
+    scaling_lut[x_base - 1] = start + RightShiftWithRounding(3 * delta, 2);
+  }
 }
 
 inline int16x8_t Clip3(const int16x8_t value, const int16x8_t low,
@@ -611,86 +697,38 @@
 
 template <int bitdepth, typename Pixel>
 inline int16x8_t GetScalingFactors(
-    const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* source) {
+    const int16_t* scaling_lut, const Pixel* source) {
   int16_t start_vals[8];
-  if (bitdepth == 8) {
-    start_vals[0] = scaling_lut[source[0]];
-    start_vals[1] = scaling_lut[source[1]];
-    start_vals[2] = scaling_lut[source[2]];
-    start_vals[3] = scaling_lut[source[3]];
-    start_vals[4] = scaling_lut[source[4]];
-    start_vals[5] = scaling_lut[source[5]];
-    start_vals[6] = scaling_lut[source[6]];
-    start_vals[7] = scaling_lut[source[7]];
-    return vld1q_s16(start_vals);
+  static_assert(bitdepth <= kBitdepth10,
+                "NEON Film Grain is not yet implemented for 12bpp.");
+  for (int i = 0; i < 8; ++i) {
+    assert(source[i] < (kScalingLookupTableSize << (bitdepth - kBitdepth8)));
+    start_vals[i] = scaling_lut[source[i]];
   }
-  int16_t end_vals[8];
-  // TODO(petersonab): Precompute this into a larger table for direct lookups.
-  int index = source[0] >> 2;
-  start_vals[0] = scaling_lut[index];
-  end_vals[0] = scaling_lut[index + 1];
-  index = source[1] >> 2;
-  start_vals[1] = scaling_lut[index];
-  end_vals[1] = scaling_lut[index + 1];
-  index = source[2] >> 2;
-  start_vals[2] = scaling_lut[index];
-  end_vals[2] = scaling_lut[index + 1];
-  index = source[3] >> 2;
-  start_vals[3] = scaling_lut[index];
-  end_vals[3] = scaling_lut[index + 1];
-  index = source[4] >> 2;
-  start_vals[4] = scaling_lut[index];
-  end_vals[4] = scaling_lut[index + 1];
-  index = source[5] >> 2;
-  start_vals[5] = scaling_lut[index];
-  end_vals[5] = scaling_lut[index + 1];
-  index = source[6] >> 2;
-  start_vals[6] = scaling_lut[index];
-  end_vals[6] = scaling_lut[index + 1];
-  index = source[7] >> 2;
-  start_vals[7] = scaling_lut[index];
-  end_vals[7] = scaling_lut[index + 1];
-  const int16x8_t start = vld1q_s16(start_vals);
-  const int16x8_t end = vld1q_s16(end_vals);
-  int16x8_t remainder = GetSignedSource8(source);
-  remainder = vandq_s16(remainder, vdupq_n_s16(3));
-  const int16x8_t delta = vmulq_s16(vsubq_s16(end, start), remainder);
-  return vaddq_s16(start, vrshrq_n_s16(delta, 2));
+  return vld1q_s16(start_vals);
 }
 
+template <int bitdepth>
 inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling,
                             const int16x8_t scaling_shift_vect) {
-  const int16x8_t upscaled_noise = vmulq_s16(noise, scaling);
-  return vrshlq_s16(upscaled_noise, scaling_shift_vect);
+  if (bitdepth == kBitdepth8) {
+    const int16x8_t upscaled_noise = vmulq_s16(noise, scaling);
+    return vrshlq_s16(upscaled_noise, scaling_shift_vect);
+  }
+  // Scaling shift s is in the range [8, 11]. |scaling_shift_vect| holds the
+  // left shift 15 - s, so vqrdmulhq_s16, which computes a saturating
+  // RightShiftWithRounding(noise * scaling_up, 15), yields the same result as
+  // RightShiftWithRounding(noise * scaling, s).
+  const int16x8_t scaling_up = vshlq_s16(scaling, scaling_shift_vect);
+  return vqrdmulhq_s16(noise, scaling_up);
 }
 
-#if LIBGAV1_MAX_BITDEPTH >= 10
-inline int16x8_t ScaleNoise(const int16x8_t noise, const int16x8_t scaling,
-                            const int32x4_t scaling_shift_vect) {
-  // TODO(petersonab): Try refactoring scaling lookup table to int16_t and
-  // upscaling by 7 bits to permit high half multiply. This would eliminate
-  // the intermediate 32x4 registers. Also write the averaged values directly
-  // into the table so it doesn't have to be done for every pixel in
-  // the frame.
-  const int32x4_t upscaled_noise_lo =
-      vmull_s16(vget_low_s16(noise), vget_low_s16(scaling));
-  const int32x4_t upscaled_noise_hi =
-      vmull_s16(vget_high_s16(noise), vget_high_s16(scaling));
-  const int16x4_t noise_lo =
-      vmovn_s32(vrshlq_s32(upscaled_noise_lo, scaling_shift_vect));
-  const int16x4_t noise_hi =
-      vmovn_s32(vrshlq_s32(upscaled_noise_hi, scaling_shift_vect));
-  return vcombine_s16(noise_lo, noise_hi);
-}
-#endif  // LIBGAV1_MAX_BITDEPTH >= 10
-
 template <int bitdepth, typename GrainType, typename Pixel>
 void BlendNoiseWithImageLuma_NEON(
-    const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift,
-    int width, int height, int start_height,
-    const uint8_t scaling_lut_y[kScalingLookupTableSize],
-    const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y,
-    ptrdiff_t dest_stride_y) {
+    const void* LIBGAV1_RESTRICT noise_image_ptr, int min_value, int max_luma,
+    int scaling_shift, int width, int height, int start_height,
+    const int16_t* scaling_lut_y, const void* source_plane_y,
+    ptrdiff_t source_stride_y, void* dest_plane_y, ptrdiff_t dest_stride_y) {
   const auto* noise_image =
       static_cast<const Array2D<GrainType>*>(noise_image_ptr);
   const auto* in_y_row = static_cast<const Pixel*>(source_plane_y);
@@ -702,10 +740,8 @@
   // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe
   // for 16 bit signed integers. In higher bitdepths, however, we have to
   // expand to 32 to protect the sign bit.
-  const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift);
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift);
-#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+  const int16x8_t scaling_shift_vect = vdupq_n_s16(
+      (bitdepth == kBitdepth10) ? 15 - scaling_shift : -scaling_shift);
 
   int y = 0;
   do {
@@ -713,25 +749,35 @@
     do {
       // This operation on the unsigned input is safe in 8bpp because the vector
       // is widened before it is reinterpreted.
-      const int16x8_t orig = GetSignedSource8(&in_y_row[x]);
-      const int16x8_t scaling =
+      const int16x8_t orig0 = GetSignedSource8(&in_y_row[x]);
+      const int16x8_t scaling0 =
           GetScalingFactors<bitdepth, Pixel>(scaling_lut_y, &in_y_row[x]);
       int16x8_t noise =
           GetSignedSource8(&(noise_image[kPlaneY][y + start_height][x]));
 
-      if (bitdepth == 8) {
-        noise = ScaleNoise(noise, scaling, scaling_shift_vect16);
-      } else {
-#if LIBGAV1_MAX_BITDEPTH >= 10
-        noise = ScaleNoise(noise, scaling, scaling_shift_vect32);
-#endif  // LIBGAV1_MAX_BITDEPTH >= 10
-      }
-      const int16x8_t combined = vaddq_s16(orig, noise);
+      noise = ScaleNoise<bitdepth>(noise, scaling0, scaling_shift_vect);
+      const int16x8_t combined0 = vaddq_s16(orig0, noise);
       // In 8bpp, when params_.clip_to_restricted_range == false, we can replace
       // clipping with vqmovun_s16, but it's not likely to be worth copying the
       // function for just that case, though the gain would be very small.
       StoreUnsigned8(&out_y_row[x],
-                     vreinterpretq_u16_s16(Clip3(combined, floor, ceiling)));
+                     vreinterpretq_u16_s16(Clip3(combined0, floor, ceiling)));
+      x += 8;
+
+      // This operation on the unsigned input is safe in 8bpp because the vector
+      // is widened before it is reinterpreted.
+      const int16x8_t orig1 = GetSignedSource8(&in_y_row[x]);
+      const int16x8_t scaling1 = GetScalingFactors<bitdepth, Pixel>(
+          scaling_lut_y, &in_y_row[std::min(x, width)]);
+      noise = GetSignedSource8(&(noise_image[kPlaneY][y + start_height][x]));
+
+      noise = ScaleNoise<bitdepth>(noise, scaling1, scaling_shift_vect);
+      const int16x8_t combined1 = vaddq_s16(orig1, noise);
+      // In 8bpp, when params_.clip_to_restricted_range == false, we can replace
+      // clipping with vqmovun_s16, but it's not likely to be worth copying the
+      // function for just that case, though the gain would be very small.
+      StoreUnsigned8(&out_y_row[x],
+                     vreinterpretq_u16_s16(Clip3(combined1, floor, ceiling)));
       x += 8;
     } while (x < width);
     in_y_row += source_stride_y;
@@ -741,20 +787,16 @@
 
 template <int bitdepth, typename GrainType, typename Pixel>
 inline int16x8_t BlendChromaValsWithCfl(
-    const Pixel* average_luma_buffer,
-    const uint8_t scaling_lut[kScalingLookupTableSize],
-    const Pixel* chroma_cursor, const GrainType* noise_image_cursor,
-    const int16x8_t scaling_shift_vect16,
-    const int32x4_t scaling_shift_vect32) {
+    const Pixel* LIBGAV1_RESTRICT average_luma_buffer,
+    const int16_t* LIBGAV1_RESTRICT scaling_lut,
+    const Pixel* LIBGAV1_RESTRICT chroma_cursor,
+    const GrainType* LIBGAV1_RESTRICT noise_image_cursor,
+    const int16x8_t scaling_shift_vect) {
   const int16x8_t scaling =
       GetScalingFactors<bitdepth, Pixel>(scaling_lut, average_luma_buffer);
   const int16x8_t orig = GetSignedSource8(chroma_cursor);
   int16x8_t noise = GetSignedSource8(noise_image_cursor);
-  if (bitdepth == 8) {
-    noise = ScaleNoise(noise, scaling, scaling_shift_vect16);
-  } else {
-    noise = ScaleNoise(noise, scaling, scaling_shift_vect32);
-  }
+  noise = ScaleNoise<bitdepth>(noise, scaling, scaling_shift_vect);
   return vaddq_s16(orig, noise);
 }
 
@@ -763,10 +805,10 @@
     const Array2D<GrainType>& noise_image, int min_value, int max_chroma,
     int width, int height, int start_height, int subsampling_x,
     int subsampling_y, int scaling_shift,
-    const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* in_y_row,
-    ptrdiff_t source_stride_y, const Pixel* in_chroma_row,
-    ptrdiff_t source_stride_chroma, Pixel* out_chroma_row,
-    ptrdiff_t dest_stride) {
+    const int16_t* LIBGAV1_RESTRICT scaling_lut,
+    const Pixel* LIBGAV1_RESTRICT in_y_row, ptrdiff_t source_stride_y,
+    const Pixel* in_chroma_row, ptrdiff_t source_stride_chroma,
+    Pixel* out_chroma_row, ptrdiff_t dest_stride) {
   const int16x8_t floor = vdupq_n_s16(min_value);
   const int16x8_t ceiling = vdupq_n_s16(max_chroma);
   Pixel luma_buffer[16];
@@ -774,8 +816,8 @@
   // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe
   // for 16 bit signed integers. In higher bitdepths, however, we have to
   // expand to 32 to protect the sign bit.
-  const int16x8_t scaling_shift_vect16 = vdupq_n_s16(-scaling_shift);
-  const int32x4_t scaling_shift_vect32 = vdupq_n_s32(-scaling_shift);
+  const int16x8_t scaling_shift_vect = vdupq_n_s16(
+      (bitdepth == kBitdepth10) ? 15 - scaling_shift : -scaling_shift);
 
   const int chroma_height = (height + subsampling_y) >> subsampling_y;
   const int chroma_width = (width + subsampling_x) >> subsampling_x;
@@ -791,8 +833,6 @@
     int x = 0;
     do {
       const int luma_x = x << subsampling_x;
-      // TODO(petersonab): Consider specializing by subsampling_x. In the 444
-      // case &in_y_row[x] can be passed to GetScalingFactors directly.
       const uint16x8_t average_luma =
           GetAverageLuma(&in_y_row[luma_x], subsampling_x);
       StoreUnsigned8(average_luma_buffer, average_luma);
@@ -800,8 +840,7 @@
       const int16x8_t blended =
           BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>(
               average_luma_buffer, scaling_lut, &in_chroma_row[x],
-              &(noise_image[y + start_height][x]), scaling_shift_vect16,
-              scaling_shift_vect32);
+              &(noise_image[y + start_height][x]), scaling_shift_vect);
 
       // In 8bpp, when params_.clip_to_restricted_range == false, we can replace
       // clipping with vqmovun_s16, but it's not likely to be worth copying the
@@ -813,18 +852,19 @@
 
     if (x < chroma_width) {
       const int luma_x = x << subsampling_x;
-      const int valid_range = width - luma_x;
-      memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0]));
-      luma_buffer[valid_range] = in_y_row[width - 1];
-      const uint16x8_t average_luma =
-          GetAverageLuma(luma_buffer, subsampling_x);
+      const int valid_range_pixels = width - luma_x;
+      const int valid_range_bytes = valid_range_pixels * sizeof(in_y_row[0]);
+      memcpy(luma_buffer, &in_y_row[luma_x], valid_range_bytes);
+      luma_buffer[valid_range_pixels] = in_y_row[width - 1];
+      const uint16x8_t average_luma = GetAverageLumaMsan(
+          luma_buffer, subsampling_x, valid_range_bytes + sizeof(in_y_row[0]));
+
       StoreUnsigned8(average_luma_buffer, average_luma);
 
       const int16x8_t blended =
           BlendChromaValsWithCfl<bitdepth, GrainType, Pixel>(
               average_luma_buffer, scaling_lut, &in_chroma_row[x],
-              &(noise_image[y + start_height][x]), scaling_shift_vect16,
-              scaling_shift_vect32);
+              &(noise_image[y + start_height][x]), scaling_shift_vect);
       // In 8bpp, when params_.clip_to_restricted_range == false, we can replace
       // clipping with vqmovun_s16, but it's not likely to be worth copying the
       // function for just that case.
@@ -842,11 +882,11 @@
 // This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y.
 template <int bitdepth, typename GrainType, typename Pixel>
 void BlendNoiseWithImageChromaWithCfl_NEON(
-    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
-    int min_value, int max_chroma, int width, int height, int start_height,
-    int subsampling_x, int subsampling_y,
-    const uint8_t scaling_lut[kScalingLookupTableSize],
-    const void* source_plane_y, ptrdiff_t source_stride_y,
+    Plane plane, const FilmGrainParams& params,
+    const void* LIBGAV1_RESTRICT noise_image_ptr, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, const int16_t* LIBGAV1_RESTRICT scaling_lut,
+    const void* LIBGAV1_RESTRICT source_plane_y, ptrdiff_t source_stride_y,
     const void* source_plane_uv, ptrdiff_t source_stride_uv,
     void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
   const auto* noise_image =
@@ -872,12 +912,11 @@
 namespace {
 
 inline int16x8_t BlendChromaValsNoCfl(
-    const uint8_t scaling_lut[kScalingLookupTableSize],
-    const uint8_t* chroma_cursor, const int8_t* noise_image_cursor,
+    const int16_t* LIBGAV1_RESTRICT scaling_lut, const int16x8_t orig,
+    const int8_t* LIBGAV1_RESTRICT noise_image_cursor,
     const int16x8_t& average_luma, const int16x8_t& scaling_shift_vect,
     const int16x8_t& offset, int luma_multiplier, int chroma_multiplier) {
   uint8_t merged_buffer[8];
-  const int16x8_t orig = GetSignedSource8(chroma_cursor);
   const int16x8_t weighted_luma = vmulq_n_s16(average_luma, luma_multiplier);
   const int16x8_t weighted_chroma = vmulq_n_s16(orig, chroma_multiplier);
   // Maximum value of |combined_u| is 127*255 = 0x7E81.
@@ -887,9 +926,9 @@
   const uint8x8_t merged = vqshrun_n_s16(vhaddq_s16(offset, combined), 4);
   vst1_u8(merged_buffer, merged);
   const int16x8_t scaling =
-      GetScalingFactors<8, uint8_t>(scaling_lut, merged_buffer);
+      GetScalingFactors<kBitdepth8, uint8_t>(scaling_lut, merged_buffer);
   int16x8_t noise = GetSignedSource8(noise_image_cursor);
-  noise = ScaleNoise(noise, scaling, scaling_shift_vect);
+  noise = ScaleNoise<kBitdepth8>(noise, scaling, scaling_shift_vect);
   return vaddq_s16(orig, noise);
 }
 
@@ -898,10 +937,10 @@
     int width, int height, int start_height, int subsampling_x,
     int subsampling_y, int scaling_shift, int chroma_offset,
     int chroma_multiplier, int luma_multiplier,
-    const uint8_t scaling_lut[kScalingLookupTableSize], const uint8_t* in_y_row,
-    ptrdiff_t source_stride_y, const uint8_t* in_chroma_row,
-    ptrdiff_t source_stride_chroma, uint8_t* out_chroma_row,
-    ptrdiff_t dest_stride) {
+    const int16_t* LIBGAV1_RESTRICT scaling_lut,
+    const uint8_t* LIBGAV1_RESTRICT in_y_row, ptrdiff_t source_stride_y,
+    const uint8_t* in_chroma_row, ptrdiff_t source_stride_chroma,
+    uint8_t* out_chroma_row, ptrdiff_t dest_stride) {
   const int16x8_t floor = vdupq_n_s16(min_value);
   const int16x8_t ceiling = vdupq_n_s16(max_chroma);
   // In 8bpp, the maximum upscaled noise is 127*255 = 0x7E81, which is safe
@@ -913,6 +952,10 @@
   const int chroma_width = (width + subsampling_x) >> subsampling_x;
   const int safe_chroma_width = chroma_width & ~7;
   uint8_t luma_buffer[16];
+#if LIBGAV1_MSAN
+  // Quiet msan warnings.
+  memset(luma_buffer, 0, sizeof(luma_buffer));
+#endif
   const int16x8_t offset = vdupq_n_s16(chroma_offset << 5);
 
   start_height >>= subsampling_y;
@@ -921,10 +964,13 @@
     int x = 0;
     do {
       const int luma_x = x << subsampling_x;
+      const int valid_range = width - luma_x;
+
+      const int16x8_t orig_chroma = GetSignedSource8(&in_chroma_row[x]);
       const int16x8_t average_luma = vreinterpretq_s16_u16(
-          GetAverageLuma(&in_y_row[luma_x], subsampling_x));
+          GetAverageLumaMsan(&in_y_row[luma_x], subsampling_x, valid_range));
       const int16x8_t blended = BlendChromaValsNoCfl(
-          scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+          scaling_lut, orig_chroma, &(noise_image[y + start_height][x]),
           average_luma, scaling_shift_vect, offset, luma_multiplier,
           chroma_multiplier);
       // In 8bpp, when params_.clip_to_restricted_range == false, we can
@@ -940,14 +986,19 @@
       // |average_luma| computation requires a duplicated luma value at the
       // end.
       const int luma_x = x << subsampling_x;
-      const int valid_range = width - luma_x;
-      memcpy(luma_buffer, &in_y_row[luma_x], valid_range * sizeof(in_y_row[0]));
-      luma_buffer[valid_range] = in_y_row[width - 1];
+      const int valid_range_pixels = width - luma_x;
+      const int valid_range_bytes = valid_range_pixels * sizeof(in_y_row[0]);
+      memcpy(luma_buffer, &in_y_row[luma_x], valid_range_bytes);
+      luma_buffer[valid_range_pixels] = in_y_row[width - 1];
+      const int valid_range_chroma_bytes =
+          (chroma_width - x) * sizeof(in_chroma_row[0]);
 
-      const int16x8_t average_luma =
-          vreinterpretq_s16_u16(GetAverageLuma(luma_buffer, subsampling_x));
+      const int16x8_t orig_chroma =
+          GetSignedSource8Msan(&in_chroma_row[x], valid_range_chroma_bytes);
+      const int16x8_t average_luma = vreinterpretq_s16_u16(GetAverageLumaMsan(
+          luma_buffer, subsampling_x, valid_range_bytes + sizeof(in_y_row[0])));
       const int16x8_t blended = BlendChromaValsNoCfl(
-          scaling_lut, &in_chroma_row[x], &(noise_image[y + start_height][x]),
+          scaling_lut, orig_chroma, &(noise_image[y + start_height][x]),
           average_luma, scaling_shift_vect, offset, luma_multiplier,
           chroma_multiplier);
       StoreUnsigned8(&out_chroma_row[x],
@@ -963,11 +1014,11 @@
 
 // This function is for the case params_.chroma_scaling_from_luma == false.
 void BlendNoiseWithImageChroma8bpp_NEON(
-    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
-    int min_value, int max_chroma, int width, int height, int start_height,
-    int subsampling_x, int subsampling_y,
-    const uint8_t scaling_lut[kScalingLookupTableSize],
-    const void* source_plane_y, ptrdiff_t source_stride_y,
+    Plane plane, const FilmGrainParams& params,
+    const void* LIBGAV1_RESTRICT noise_image_ptr, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, const int16_t* LIBGAV1_RESTRICT scaling_lut,
+    const void* LIBGAV1_RESTRICT source_plane_y, ptrdiff_t source_stride_y,
     const void* source_plane_uv, ptrdiff_t source_stride_uv,
     void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
   assert(plane == kPlaneU || plane == kPlaneV);
@@ -989,12 +1040,11 @@
                             in_uv, source_stride_uv, out_uv, dest_stride_uv);
 }
 
-inline void WriteOverlapLine8bpp_NEON(const int8_t* noise_stripe_row,
-                                      const int8_t* noise_stripe_row_prev,
-                                      int plane_width,
-                                      const int8x8_t grain_coeff,
-                                      const int8x8_t old_coeff,
-                                      int8_t* noise_image_row) {
+inline void WriteOverlapLine8bpp_NEON(
+    const int8_t* LIBGAV1_RESTRICT noise_stripe_row,
+    const int8_t* LIBGAV1_RESTRICT noise_stripe_row_prev, int plane_width,
+    const int8x8_t grain_coeff, const int8x8_t old_coeff,
+    int8_t* LIBGAV1_RESTRICT noise_image_row) {
   int x = 0;
   do {
     // Note that these reads may exceed noise_stripe_row's width by up to 7
@@ -1009,10 +1059,10 @@
   } while (x < plane_width);
 }
 
-void ConstructNoiseImageOverlap8bpp_NEON(const void* noise_stripes_buffer,
-                                         int width, int height,
-                                         int subsampling_x, int subsampling_y,
-                                         void* noise_image_buffer) {
+void ConstructNoiseImageOverlap8bpp_NEON(
+    const void* LIBGAV1_RESTRICT noise_stripes_buffer, int width, int height,
+    int subsampling_x, int subsampling_y,
+    void* LIBGAV1_RESTRICT noise_image_buffer) {
   const auto* noise_stripes =
       static_cast<const Array2DView<int8_t>*>(noise_stripes_buffer);
   auto* noise_image = static_cast<Array2D<int8_t>*>(noise_image_buffer);
@@ -1077,41 +1127,45 @@
 
   // LumaAutoRegressionFunc
   dsp->film_grain.luma_auto_regression[0] =
-      ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 1>;
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<kBitdepth8, int8_t, 1>;
   dsp->film_grain.luma_auto_regression[1] =
-      ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 2>;
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<kBitdepth8, int8_t, 2>;
   dsp->film_grain.luma_auto_regression[2] =
-      ApplyAutoRegressiveFilterToLumaGrain_NEON<8, int8_t, 3>;
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<kBitdepth8, int8_t, 3>;
 
   // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag]
   // Chroma autoregression should never be called when lag is 0 and use_luma
   // is false.
   dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
   dsp->film_grain.chroma_auto_regression[0][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth8, int8_t, 1,
+                                                   false>;
   dsp->film_grain.chroma_auto_regression[0][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth8, int8_t, 2,
+                                                   false>;
   dsp->film_grain.chroma_auto_regression[0][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth8, int8_t, 3,
+                                                   false>;
   dsp->film_grain.chroma_auto_regression[1][0] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 0, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth8, int8_t, 0, true>;
   dsp->film_grain.chroma_auto_regression[1][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 1, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth8, int8_t, 1, true>;
   dsp->film_grain.chroma_auto_regression[1][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 2, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth8, int8_t, 2, true>;
   dsp->film_grain.chroma_auto_regression[1][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<8, int8_t, 3, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth8, int8_t, 3, true>;
 
   dsp->film_grain.construct_noise_image_overlap =
       ConstructNoiseImageOverlap8bpp_NEON;
 
-  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON;
+  dsp->film_grain.initialize_scaling_lut =
+      InitializeScalingLookupTable_NEON<kBitdepth8>;
 
   dsp->film_grain.blend_noise_luma =
-      BlendNoiseWithImageLuma_NEON<8, int8_t, uint8_t>;
+      BlendNoiseWithImageLuma_NEON<kBitdepth8, int8_t, uint8_t>;
   dsp->film_grain.blend_noise_chroma[0] = BlendNoiseWithImageChroma8bpp_NEON;
   dsp->film_grain.blend_noise_chroma[1] =
-      BlendNoiseWithImageChromaWithCfl_NEON<8, int8_t, uint8_t>;
+      BlendNoiseWithImageChromaWithCfl_NEON<kBitdepth8, int8_t, uint8_t>;
 }
 
 }  // namespace
@@ -1121,43 +1175,280 @@
 namespace high_bitdepth {
 namespace {
 
+inline void WriteOverlapLine10bpp_NEON(
+    const int16_t* LIBGAV1_RESTRICT noise_stripe_row,
+    const int16_t* LIBGAV1_RESTRICT noise_stripe_row_prev, int plane_width,
+    const int16x8_t grain_coeff, const int16x8_t old_coeff,
+    int16_t* LIBGAV1_RESTRICT noise_image_row) {
+  int x = 0;
+  do {
+    // Note that these reads may exceed noise_stripe_row's width by up to 7
+    // values.
+    const int16x8_t source_grain = vld1q_s16(noise_stripe_row + x);
+    const int16x8_t source_old = vld1q_s16(noise_stripe_row_prev + x);
+    // Maximum product is 511 * 27 = 0x35E5.
+    const int16x8_t weighted_grain = vmulq_s16(grain_coeff, source_grain);
+    // Maximum sum is 511 * (22 + 23) = 0x59D3.
+    const int16x8_t grain_sum =
+        vmlaq_s16(weighted_grain, old_coeff, source_old);
+    // Note that this write may exceed noise_image_row's width by up to 7
+    // values.
+    const int16x8_t grain = Clip3S16(vrshrq_n_s16(grain_sum, 5),
+                                     vdupq_n_s16(GetGrainMin<kBitdepth10>()),
+                                     vdupq_n_s16(GetGrainMax<kBitdepth10>()));
+    vst1q_s16(noise_image_row + x, grain);
+    x += 8;
+  } while (x < plane_width);
+}
+
+void ConstructNoiseImageOverlap10bpp_NEON(
+    const void* LIBGAV1_RESTRICT noise_stripes_buffer, int width, int height,
+    int subsampling_x, int subsampling_y,
+    void* LIBGAV1_RESTRICT noise_image_buffer) {
+  const auto* noise_stripes =
+      static_cast<const Array2DView<int16_t>*>(noise_stripes_buffer);
+  auto* noise_image = static_cast<Array2D<int16_t>*>(noise_image_buffer);
+  const int plane_width = (width + subsampling_x) >> subsampling_x;
+  const int plane_height = (height + subsampling_y) >> subsampling_y;
+  const int stripe_height = 32 >> subsampling_y;
+  const int stripe_mask = stripe_height - 1;
+  int y = stripe_height;
+  int luma_num = 1;
+  if (subsampling_y == 0) {
+    const int16x8_t first_row_grain_coeff = vdupq_n_s16(17);
+    const int16x8_t first_row_old_coeff = vdupq_n_s16(27);
+    const int16x8_t second_row_grain_coeff = first_row_old_coeff;
+    const int16x8_t second_row_old_coeff = first_row_grain_coeff;
+    for (; y < (plane_height & ~stripe_mask); ++luma_num, y += stripe_height) {
+      const int16_t* noise_stripe = (*noise_stripes)[luma_num];
+      const int16_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+      WriteOverlapLine10bpp_NEON(
+          noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width,
+          first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]);
+
+      WriteOverlapLine10bpp_NEON(&noise_stripe[plane_width],
+                                 &noise_stripe_prev[(32 + 1) * plane_width],
+                                 plane_width, second_row_grain_coeff,
+                                 second_row_old_coeff, (*noise_image)[y + 1]);
+    }
+    // Either one partial stripe remains (remaining_height > 0),
+    // OR image is less than one stripe high (remaining_height < 0),
+    // OR all stripes are completed (remaining_height == 0).
+    const int remaining_height = plane_height - y;
+    if (remaining_height <= 0) {
+      return;
+    }
+    const int16_t* noise_stripe = (*noise_stripes)[luma_num];
+    const int16_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+    WriteOverlapLine10bpp_NEON(
+        noise_stripe, &noise_stripe_prev[32 * plane_width], plane_width,
+        first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]);
+
+    if (remaining_height > 1) {
+      WriteOverlapLine10bpp_NEON(&noise_stripe[plane_width],
+                                 &noise_stripe_prev[(32 + 1) * plane_width],
+                                 plane_width, second_row_grain_coeff,
+                                 second_row_old_coeff, (*noise_image)[y + 1]);
+    }
+  } else {  // subsampling_y == 1
+    const int16x8_t first_row_grain_coeff = vdupq_n_s16(22);
+    const int16x8_t first_row_old_coeff = vdupq_n_s16(23);
+    for (; y < plane_height; ++luma_num, y += stripe_height) {
+      const int16_t* noise_stripe = (*noise_stripes)[luma_num];
+      const int16_t* noise_stripe_prev = (*noise_stripes)[luma_num - 1];
+      WriteOverlapLine10bpp_NEON(
+          noise_stripe, &noise_stripe_prev[16 * plane_width], plane_width,
+          first_row_grain_coeff, first_row_old_coeff, (*noise_image)[y]);
+    }
+  }
+}
+
+inline int16x8_t BlendChromaValsNoCfl(
+    const int16_t* LIBGAV1_RESTRICT scaling_lut, const int16x8_t orig,
+    const int16_t* LIBGAV1_RESTRICT noise_image_cursor,
+    const int16x8_t& average_luma, const int16x8_t& scaling_shift_vect,
+    const int32x4_t& offset, int luma_multiplier, int chroma_multiplier) {
+  uint16_t merged_buffer[8];
+  const int32x4_t weighted_luma_low =
+      vmull_n_s16(vget_low_s16(average_luma), luma_multiplier);
+  const int32x4_t weighted_luma_high =
+      vmull_n_s16(vget_high_s16(average_luma), luma_multiplier);
+  // Maximum value of combined is 127 * 1023 = 0x1FB81.
+  const int32x4_t combined_low =
+      vmlal_n_s16(weighted_luma_low, vget_low_s16(orig), chroma_multiplier);
+  const int32x4_t combined_high =
+      vmlal_n_s16(weighted_luma_high, vget_high_s16(orig), chroma_multiplier);
+  // Maximum value of offset is (255 << 8) = 0xFF00. Offset may be negative.
+  const uint16x4_t merged_low =
+      vqshrun_n_s32(vaddq_s32(offset, combined_low), 6);
+  const uint16x4_t merged_high =
+      vqshrun_n_s32(vaddq_s32(offset, combined_high), 6);
+  const uint16x8_t max_pixel = vdupq_n_u16((1 << kBitdepth10) - 1);
+  vst1q_u16(merged_buffer,
+            vminq_u16(vcombine_u16(merged_low, merged_high), max_pixel));
+  const int16x8_t scaling =
+      GetScalingFactors<kBitdepth10, uint16_t>(scaling_lut, merged_buffer);
+  const int16x8_t noise = GetSignedSource8(noise_image_cursor);
+  const int16x8_t scaled_noise =
+      ScaleNoise<kBitdepth10>(noise, scaling, scaling_shift_vect);
+  return vaddq_s16(orig, scaled_noise);
+}
+
+LIBGAV1_ALWAYS_INLINE void BlendChromaPlane10bpp_NEON(
+    const Array2D<int16_t>& noise_image, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, int scaling_shift, int chroma_offset,
+    int chroma_multiplier, int luma_multiplier,
+    const int16_t* LIBGAV1_RESTRICT scaling_lut,
+    const uint16_t* LIBGAV1_RESTRICT in_y_row, ptrdiff_t source_stride_y,
+    const uint16_t* in_chroma_row, ptrdiff_t source_stride_chroma,
+    uint16_t* out_chroma_row, ptrdiff_t dest_stride) {
+  const int16x8_t floor = vdupq_n_s16(min_value);
+  const int16x8_t ceiling = vdupq_n_s16(max_chroma);
+  const int16x8_t scaling_shift_vect = vdupq_n_s16(15 - scaling_shift);
+
+  const int chroma_height = (height + subsampling_y) >> subsampling_y;
+  const int chroma_width = (width + subsampling_x) >> subsampling_x;
+  const int safe_chroma_width = chroma_width & ~7;
+  uint16_t luma_buffer[16];
+#if LIBGAV1_MSAN
+  // TODO(b/194217060): This can be removed if the range calculations below are
+  // fixed.
+  memset(luma_buffer, 0, sizeof(luma_buffer));
+#endif
+  // Offset is added before downshifting in order to take advantage of
+  // saturation, so it has to be upscaled by 6 bits, plus 2 bits for 10bpp.
+  const int32x4_t offset = vdupq_n_s32(chroma_offset << (6 + 2));
+
+  start_height >>= subsampling_y;
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      const int luma_x = x << subsampling_x;
+      const int16x8_t average_luma = vreinterpretq_s16_u16(
+          GetAverageLuma(&in_y_row[luma_x], subsampling_x));
+      const int16x8_t orig_chroma = GetSignedSource8(&in_chroma_row[x]);
+      const int16x8_t blended = BlendChromaValsNoCfl(
+          scaling_lut, orig_chroma, &(noise_image[y + start_height][x]),
+          average_luma, scaling_shift_vect, offset, luma_multiplier,
+          chroma_multiplier);
+      StoreUnsigned8(&out_chroma_row[x],
+                     vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+
+      x += 8;
+    } while (x < safe_chroma_width);
+
+    if (x < chroma_width) {
+      // Begin right edge iteration. Same as the normal iterations, but the
+      // |average_luma| computation requires a duplicated luma value at the
+      // end.
+      const int luma_x = x << subsampling_x;
+      const int valid_range_pixels = width - luma_x;
+      const int valid_range_bytes = valid_range_pixels * sizeof(in_y_row[0]);
+      memcpy(luma_buffer, &in_y_row[luma_x], valid_range_bytes);
+      luma_buffer[valid_range_pixels] = in_y_row[width - 1];
+      const int valid_range_chroma_bytes =
+          (chroma_width - x) * sizeof(in_chroma_row[0]);
+      const int16x8_t orig_chroma =
+          GetSignedSource8Msan(&in_chroma_row[x], valid_range_chroma_bytes);
+
+      const int16x8_t average_luma = vreinterpretq_s16_u16(GetAverageLumaMsan(
+          luma_buffer, subsampling_x, valid_range_bytes + sizeof(in_y_row[0])));
+      const int16x8_t blended = BlendChromaValsNoCfl(
+          scaling_lut, orig_chroma, &(noise_image[y + start_height][x]),
+          average_luma, scaling_shift_vect, offset, luma_multiplier,
+          chroma_multiplier);
+      StoreUnsigned8(&out_chroma_row[x],
+                     vreinterpretq_u16_s16(Clip3(blended, floor, ceiling)));
+      // End of right edge iteration.
+    }
+
+    in_y_row = AddByteStride(in_y_row, source_stride_y << subsampling_y);
+    in_chroma_row = AddByteStride(in_chroma_row, source_stride_chroma);
+    out_chroma_row = AddByteStride(out_chroma_row, dest_stride);
+  } while (++y < chroma_height);
+}
+
+// This function is for the case params_.chroma_scaling_from_luma == false.
+void BlendNoiseWithImageChroma10bpp_NEON(
+    Plane plane, const FilmGrainParams& params,
+    const void* LIBGAV1_RESTRICT noise_image_ptr, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, const int16_t* LIBGAV1_RESTRICT scaling_lut,
+    const void* LIBGAV1_RESTRICT source_plane_y, ptrdiff_t source_stride_y,
+    const void* source_plane_uv, ptrdiff_t source_stride_uv,
+    void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
+  assert(plane == kPlaneU || plane == kPlaneV);
+  const auto* noise_image =
+      static_cast<const Array2D<int16_t>*>(noise_image_ptr);
+  const auto* in_y = static_cast<const uint16_t*>(source_plane_y);
+  const auto* in_uv = static_cast<const uint16_t*>(source_plane_uv);
+  auto* out_uv = static_cast<uint16_t*>(dest_plane_uv);
+
+  const int offset = (plane == kPlaneU) ? params.u_offset : params.v_offset;
+  const int luma_multiplier =
+      (plane == kPlaneU) ? params.u_luma_multiplier : params.v_luma_multiplier;
+  const int multiplier =
+      (plane == kPlaneU) ? params.u_multiplier : params.v_multiplier;
+  BlendChromaPlane10bpp_NEON(
+      noise_image[plane], min_value, max_chroma, width, height, start_height,
+      subsampling_x, subsampling_y, params.chroma_scaling, offset, multiplier,
+      luma_multiplier, scaling_lut, in_y, source_stride_y, in_uv,
+      source_stride_uv, out_uv, dest_stride_uv);
+}
+
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
   assert(dsp != nullptr);
 
   // LumaAutoRegressionFunc
   dsp->film_grain.luma_auto_regression[0] =
-      ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 1>;
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<kBitdepth10, int16_t, 1>;
   dsp->film_grain.luma_auto_regression[1] =
-      ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 2>;
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<kBitdepth10, int16_t, 2>;
   dsp->film_grain.luma_auto_regression[2] =
-      ApplyAutoRegressiveFilterToLumaGrain_NEON<10, int16_t, 3>;
+      ApplyAutoRegressiveFilterToLumaGrain_NEON<kBitdepth10, int16_t, 3>;
 
   // ChromaAutoRegressionFunc[use_luma][auto_regression_coeff_lag][subsampling]
   // Chroma autoregression should never be called when lag is 0 and use_luma
   // is false.
   dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
   dsp->film_grain.chroma_auto_regression[0][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth10, int16_t, 1,
+                                                   false>;
   dsp->film_grain.chroma_auto_regression[0][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth10, int16_t, 2,
+                                                   false>;
   dsp->film_grain.chroma_auto_regression[0][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth10, int16_t, 3,
+                                                   false>;
   dsp->film_grain.chroma_auto_regression[1][0] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 0, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth10, int16_t, 0,
+                                                   true>;
   dsp->film_grain.chroma_auto_regression[1][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 1, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth10, int16_t, 1,
+                                                   true>;
   dsp->film_grain.chroma_auto_regression[1][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 2, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth10, int16_t, 2,
+                                                   true>;
   dsp->film_grain.chroma_auto_regression[1][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_NEON<10, int16_t, 3, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_NEON<kBitdepth10, int16_t, 3,
+                                                   true>;
 
-  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_NEON;
+  dsp->film_grain.construct_noise_image_overlap =
+      ConstructNoiseImageOverlap10bpp_NEON;
 
-  dsp->film_grain.blend_noise_luma =
-      BlendNoiseWithImageLuma_NEON<10, int16_t, uint16_t>;
+  dsp->film_grain.initialize_scaling_lut =
+      InitializeScalingLookupTable_NEON<kBitdepth10>;
+
+  // TODO(b/194442742): reenable this function after segfault under armv7 ASan
+  // is fixed.
+  // dsp->film_grain.blend_noise_luma =
+  //     BlendNoiseWithImageLuma_NEON<kBitdepth10, int16_t, uint16_t>;
+  dsp->film_grain.blend_noise_chroma[0] = BlendNoiseWithImageChroma10bpp_NEON;
   dsp->film_grain.blend_noise_chroma[1] =
-      BlendNoiseWithImageChromaWithCfl_NEON<10, int16_t, uint16_t>;
+      BlendNoiseWithImageChromaWithCfl_NEON<kBitdepth10, int16_t, uint16_t>;
 }
 
 }  // namespace
diff --git a/libgav1/src/dsp/arm/film_grain_neon.h b/libgav1/src/dsp/arm/film_grain_neon.h
index 44b3d1d..3ba2eef 100644
--- a/libgav1/src/dsp/arm/film_grain_neon.h
+++ b/libgav1/src/dsp/arm/film_grain_neon.h
@@ -35,11 +35,15 @@
 #define LIBGAV1_Dsp8bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp10bpp_FilmGrainAutoregressionChroma LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseImageOverlap LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainConstructNoiseImageOverlap LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp10bpp_FilmGrainInitializeScalingLutFunc LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON
-#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON
+// TODO(b/194442742): reenable this function after segfault under armv7 ASan is
+// fixed.
+// #define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChroma LIBGAV1_DSP_NEON
+#define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChroma LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON
 #define LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChromaWithCfl LIBGAV1_DSP_NEON
 #endif  // LIBGAV1_ENABLE_NEON
diff --git a/libgav1/src/dsp/arm/intra_edge_neon.cc b/libgav1/src/dsp/arm/intra_edge_neon.cc
index 074283f..9b20e29 100644
--- a/libgav1/src/dsp/arm/intra_edge_neon.cc
+++ b/libgav1/src/dsp/arm/intra_edge_neon.cc
@@ -248,7 +248,8 @@
 
     vst1_u8(pixel_buffer - 1, InterleaveLow8(result, src21));
     return;
-  } else if (size == 8) {
+  }
+  if (size == 8) {
     // Likewise, one load + multiple vtbls seems preferred to multiple loads.
     const uint8x16_t src = vld1q_u8(pixel_buffer - 1);
     const uint8x8_t src0 = VQTbl1U8(src, vcreate_u8(0x0605040302010000));
diff --git a/libgav1/src/dsp/arm/intrapred_cfl_neon.cc b/libgav1/src/dsp/arm/intrapred_cfl_neon.cc
index 8d8748f..ad39947 100644
--- a/libgav1/src/dsp/arm/intrapred_cfl_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_cfl_neon.cc
@@ -76,7 +76,7 @@
 void CflSubsampler420_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, const ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride) {
   const auto* src = static_cast<const uint8_t*>(source);
   uint32_t sum;
   if (block_width == 4) {
@@ -140,7 +140,7 @@
       const uint8_t a11 = src[max_luma_width - 1 + stride];
       // Dup the 2x2 sum at the max luma offset.
       const uint16x8_t max_luma_sum =
-          vdupq_n_u16((uint16_t)((a00 + a01 + a10 + a11) << 1));
+          vdupq_n_u16(static_cast<uint16_t>((a00 + a01 + a10 + a11) << 1));
       uint16x8_t x_index = {0, 2, 4, 6, 8, 10, 12, 14};
 
       ptrdiff_t src_x_offset = 0;
@@ -173,7 +173,7 @@
 void CflSubsampler444_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, const ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride) {
   const auto* src = static_cast<const uint8_t*>(source);
   uint32_t sum;
   if (block_width == 4) {
@@ -276,7 +276,7 @@
 // uint8_t. Saturated int16_t >> 6 outranges uint8_t.
 template <int block_height>
 inline void CflIntraPredictor4xN_NEON(
-    void* const dest, const ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<uint8_t*>(dest);
@@ -295,7 +295,7 @@
 
 template <int block_height>
 inline void CflIntraPredictor8xN_NEON(
-    void* const dest, const ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<uint8_t*>(dest);
@@ -310,7 +310,7 @@
 
 template <int block_height>
 inline void CflIntraPredictor16xN_NEON(
-    void* const dest, const ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<uint8_t*>(dest);
@@ -328,7 +328,7 @@
 
 template <int block_height>
 inline void CflIntraPredictor32xN_NEON(
-    void* const dest, const ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<uint8_t*>(dest);
@@ -507,7 +507,8 @@
 template <int block_height_log2, bool is_inside>
 void CflSubsampler444_4xH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int max_luma_height, const void* const source, ptrdiff_t stride) {
+    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
+    ptrdiff_t stride) {
   static_assert(block_height_log2 <= 4, "");
   const int block_height = 1 << block_height_log2;
   const int visible_height = max_luma_height;
@@ -568,7 +569,7 @@
 void CflSubsampler444_4xH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_cast<void>(max_luma_width);
   static_cast<void>(max_luma_height);
   static_assert(block_height_log2 <= 4, "");
@@ -588,7 +589,8 @@
 template <int block_height_log2, bool is_inside>
 void CflSubsampler444_8xH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int max_luma_height, const void* const source, ptrdiff_t stride) {
+    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
+    ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const int visible_height = max_luma_height;
   const auto* src = static_cast<const uint16_t*>(source);
@@ -643,7 +645,7 @@
 void CflSubsampler444_8xH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_cast<void>(max_luma_width);
   static_cast<void>(max_luma_height);
   static_assert(block_height_log2 <= 5, "");
@@ -667,7 +669,7 @@
 void CflSubsampler444_WxH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const int visible_height = max_luma_height;
   const int block_width = 1 << block_width_log2;
@@ -751,7 +753,7 @@
 void CflSubsampler444_WxH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_assert(block_width_log2 == 4 || block_width_log2 == 5,
                 "This function will only work for block_width 16 and 32.");
   static_assert(block_height_log2 <= 5, "");
@@ -773,7 +775,7 @@
 void CflSubsampler420_4xH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int /*max_luma_width*/, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const auto* src = static_cast<const uint16_t*>(source);
   const ptrdiff_t src_stride = stride / sizeof(src[0]);
@@ -839,7 +841,8 @@
 template <int block_height_log2, int max_luma_width>
 inline void CflSubsampler420Impl_8xH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int max_luma_height, const void* const source, ptrdiff_t stride) {
+    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
+    ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const auto* src = static_cast<const uint16_t*>(source);
   const ptrdiff_t src_stride = stride / sizeof(src[0]);
@@ -944,7 +947,7 @@
 void CflSubsampler420_8xH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   if (max_luma_width == 8) {
     CflSubsampler420Impl_8xH_NEON<block_height_log2, 8>(luma, max_luma_height,
                                                         source, stride);
@@ -957,7 +960,8 @@
 template <int block_width_log2, int block_height_log2, int max_luma_width>
 inline void CflSubsampler420Impl_WxH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int max_luma_height, const void* const source, ptrdiff_t stride) {
+    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
+    ptrdiff_t stride) {
   const auto* src = static_cast<const uint16_t*>(source);
   const ptrdiff_t src_stride = stride / sizeof(src[0]);
   const int block_height = 1 << block_height_log2;
@@ -1062,7 +1066,7 @@
 void CflSubsampler420_WxH_NEON(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   switch (max_luma_width) {
     case 8:
       CflSubsampler420Impl_WxH_NEON<block_width_log2, block_height_log2, 8>(
@@ -1109,7 +1113,7 @@
 
 template <int block_height, int bitdepth = 10>
 inline void CflIntraPredictor4xN_NEON(
-    void* const dest, const ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<uint16_t*>(dest);
@@ -1133,7 +1137,7 @@
 
 template <int block_height, int bitdepth = 10>
 inline void CflIntraPredictor8xN_NEON(
-    void* const dest, const ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<uint16_t*>(dest);
@@ -1153,7 +1157,7 @@
 
 template <int block_height, int bitdepth = 10>
 inline void CflIntraPredictor16xN_NEON(
-    void* const dest, const ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<uint16_t*>(dest);
@@ -1177,7 +1181,7 @@
 
 template <int block_height, int bitdepth = 10>
 inline void CflIntraPredictor32xN_NEON(
-    void* const dest, const ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<uint16_t*>(dest);
diff --git a/libgav1/src/dsp/arm/intrapred_directional_neon.cc b/libgav1/src/dsp/arm/intrapred_directional_neon.cc
index 3f5edbd..3cad4a6 100644
--- a/libgav1/src/dsp/arm/intrapred_directional_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_directional_neon.cc
@@ -29,6 +29,7 @@
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 
 namespace libgav1 {
 namespace dsp {
@@ -40,9 +41,9 @@
                                const uint8x8_t a_weight,
                                const uint8x8_t b_weight) {
   const uint16x8_t a_product = vmull_u8(a, a_weight);
-  const uint16x8_t b_product = vmull_u8(b, b_weight);
+  const uint16x8_t sum = vmlal_u8(a_product, b, b_weight);
 
-  return vrshrn_n_u16(vaddq_u16(a_product, b_product), 5 /*log2(32)*/);
+  return vrshrn_n_u16(sum, 5 /*log2(32)*/);
 }
 
 // For vertical operations the weights are one constant value.
@@ -52,9 +53,9 @@
 }
 
 // Fill |left| and |right| with the appropriate values for a given |base_step|.
-inline void LoadStepwise(const uint8_t* const source, const uint8x8_t left_step,
-                         const uint8x8_t right_step, uint8x8_t* left,
-                         uint8x8_t* right) {
+inline void LoadStepwise(const uint8_t* LIBGAV1_RESTRICT const source,
+                         const uint8x8_t left_step, const uint8x8_t right_step,
+                         uint8x8_t* left, uint8x8_t* right) {
   const uint8x16_t mixed = vld1q_u8(source);
   *left = VQTbl1U8(mixed, left_step);
   *right = VQTbl1U8(mixed, right_step);
@@ -62,17 +63,18 @@
 
 // Handle signed step arguments by ignoring the sign. Negative values are
 // considered out of range and overwritten later.
-inline void LoadStepwise(const uint8_t* const source, const int8x8_t left_step,
-                         const int8x8_t right_step, uint8x8_t* left,
-                         uint8x8_t* right) {
+inline void LoadStepwise(const uint8_t* LIBGAV1_RESTRICT const source,
+                         const int8x8_t left_step, const int8x8_t right_step,
+                         uint8x8_t* left, uint8x8_t* right) {
   LoadStepwise(source, vreinterpret_u8_s8(left_step),
                vreinterpret_u8_s8(right_step), left, right);
 }
 
 // Process 4 or 8 |width| by any |height|.
 template <int width>
-inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride,
-                                 const int height, const uint8_t* const top,
+inline void DirectionalZone1_WxH(uint8_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t stride, const int height,
+                                 const uint8_t* LIBGAV1_RESTRICT const top,
                                  const int xstep, const bool upsampled) {
   assert(width == 4 || width == 8);
 
@@ -142,10 +144,11 @@
 
 // Process a multiple of 8 |width| by any |height|. Processes horizontally
 // before vertically in the hopes of being a little more cache friendly.
-inline void DirectionalZone1_WxH(uint8_t* dst, const ptrdiff_t stride,
-                                 const int width, const int height,
-                                 const uint8_t* const top, const int xstep,
-                                 const bool upsampled) {
+inline void DirectionalZone1_WxH(uint8_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t stride, const int width,
+                                 const int height,
+                                 const uint8_t* LIBGAV1_RESTRICT const top,
+                                 const int xstep, const bool upsampled) {
   assert(width % 8 == 0);
   const int upsample_shift = static_cast<int>(upsampled);
   const int scale_bits = 6 - upsample_shift;
@@ -203,14 +206,12 @@
   } while (++y < height);
 }
 
-void DirectionalIntraPredictorZone1_NEON(void* const dest,
-                                         const ptrdiff_t stride,
-                                         const void* const top_row,
-                                         const int width, const int height,
-                                         const int xstep,
-                                         const bool upsampled_top) {
-  const uint8_t* const top = static_cast<const uint8_t*>(top_row);
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+void DirectionalIntraPredictorZone1_NEON(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row, const int width,
+    const int height, const int xstep, const bool upsampled_top) {
+  const auto* const top = static_cast<const uint8_t*>(top_row);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   assert(xstep > 0);
 
@@ -282,11 +283,10 @@
 
 // Process 4 or 8 |width| by 4 or 8 |height|.
 template <int width>
-inline void DirectionalZone3_WxH(uint8_t* dest, const ptrdiff_t stride,
-                                 const int height,
-                                 const uint8_t* const left_column,
-                                 const int base_left_y, const int ystep,
-                                 const int upsample_shift) {
+inline void DirectionalZone3_WxH(
+    uint8_t* LIBGAV1_RESTRICT dest, const ptrdiff_t stride, const int height,
+    const uint8_t* LIBGAV1_RESTRICT const left_column, const int base_left_y,
+    const int ystep, const int upsample_shift) {
   assert(width == 4 || width == 8);
   assert(height == 4 || height == 8);
   const int scale_bits = 6 - upsample_shift;
@@ -417,12 +417,10 @@
 
 // Process 4 or 8 |width| by any |height|.
 template <int width>
-inline void DirectionalZone2FromLeftCol_WxH(uint8_t* dst,
-                                            const ptrdiff_t stride,
-                                            const int height,
-                                            const uint8_t* const left_column,
-                                            const int16x8_t left_y,
-                                            const int upsample_shift) {
+inline void DirectionalZone2FromLeftCol_WxH(
+    uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t stride, const int height,
+    const uint8_t* LIBGAV1_RESTRICT const left_column, const int16x8_t left_y,
+    const int upsample_shift) {
   assert(width == 4 || width == 8);
 
   // The shift argument must be a constant.
@@ -468,12 +466,10 @@
 
 // Process 4 or 8 |width| by any |height|.
 template <int width>
-inline void DirectionalZone1Blend_WxH(uint8_t* dest, const ptrdiff_t stride,
-                                      const int height,
-                                      const uint8_t* const top_row,
-                                      int zone_bounds, int top_x,
-                                      const int xstep,
-                                      const int upsample_shift) {
+inline void DirectionalZone1Blend_WxH(
+    uint8_t* LIBGAV1_RESTRICT dest, const ptrdiff_t stride, const int height,
+    const uint8_t* LIBGAV1_RESTRICT const top_row, int zone_bounds, int top_x,
+    const int xstep, const int upsample_shift) {
   assert(width == 4 || width == 8);
 
   const int scale_bits_x = 6 - upsample_shift;
@@ -523,12 +519,12 @@
 // then handle only blocks that take from |left_ptr|. Additionally, a fast
 // index-shuffle approach is used for pred values from |left_column| in sections
 // that permit it.
-inline void DirectionalZone2_4xH(uint8_t* dst, const ptrdiff_t stride,
-                                 const uint8_t* const top_row,
-                                 const uint8_t* const left_column,
-                                 const int height, const int xstep,
-                                 const int ystep, const bool upsampled_top,
-                                 const bool upsampled_left) {
+inline void DirectionalZone2_4xH(
+    uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t stride,
+    const uint8_t* LIBGAV1_RESTRICT const top_row,
+    const uint8_t* LIBGAV1_RESTRICT const left_column, const int height,
+    const int xstep, const int ystep, const bool upsampled_top,
+    const bool upsampled_left) {
   const int upsample_left_shift = static_cast<int>(upsampled_left);
   const int upsample_top_shift = static_cast<int>(upsampled_top);
 
@@ -564,8 +560,8 @@
   // If the 64 scaling is regarded as a decimal point, the first value of the
   // left_y vector omits the portion which is covered under the left_column
   // offset. The following values need the full ystep as a relative offset.
-  int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep);
-  left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder));
+  const int16x8_t remainder = vdupq_n_s16(-ystep_remainder);
+  const int16x8_t left_y = vmlaq_n_s16(remainder, zero_to_seven, -ystep);
 
   // This loop treats each set of 4 columns in 3 stages with y-value boundaries.
   // The first stage, before the first y-loop, covers blocks that are only
@@ -639,13 +635,12 @@
 }
 
 // Process a multiple of 8 |width|.
-inline void DirectionalZone2_8(uint8_t* const dst, const ptrdiff_t stride,
-                               const uint8_t* const top_row,
-                               const uint8_t* const left_column,
-                               const int width, const int height,
-                               const int xstep, const int ystep,
-                               const bool upsampled_top,
-                               const bool upsampled_left) {
+inline void DirectionalZone2_8(
+    uint8_t* LIBGAV1_RESTRICT const dst, const ptrdiff_t stride,
+    const uint8_t* LIBGAV1_RESTRICT const top_row,
+    const uint8_t* LIBGAV1_RESTRICT const left_column, const int width,
+    const int height, const int xstep, const int ystep,
+    const bool upsampled_top, const bool upsampled_left) {
   const int upsample_left_shift = static_cast<int>(upsampled_left);
   const int upsample_top_shift = static_cast<int>(upsampled_top);
 
@@ -668,12 +663,6 @@
   assert(xstep >= 3);
   const int min_top_only_x = std::min((height * xstep) >> 6, width);
 
-  // For steep angles, the source pixels from |left_column| may not fit in a
-  // 16-byte load for shuffling.
-  // TODO(petersonab): Find a more precise formula for this subject to x.
-  const int max_shuffle_height =
-      std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height);
-
   // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1
   int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1;
 
@@ -687,8 +676,8 @@
   // If the 64 scaling is regarded as a decimal point, the first value of the
   // left_y vector omits the portion which is covered under the left_column
   // offset. Following values need the full ystep as a relative offset.
-  int16x8_t left_y = vmulq_n_s16(zero_to_seven, -ystep);
-  left_y = vaddq_s16(left_y, vdupq_n_s16(-ystep_remainder));
+  const int16x8_t remainder = vdupq_n_s16(-ystep_remainder);
+  int16x8_t left_y = vmlaq_n_s16(remainder, zero_to_seven, -ystep);
 
   // This loop treats each set of 4 columns in 3 stages with y-value boundaries.
   // The first stage, before the first y-loop, covers blocks that are only
@@ -696,12 +685,21 @@
   // blocks that have a mixture of values computed from top or left. The final
   // stage covers blocks that are only computed from the left.
   int x = 0;
+  // For steep angles, the source pixels from |left_column| may not fit in a
+  // 16-byte load for shuffling. |d| represents the number of pixels that can
+  // fit in one contiguous vector when stepping by |ystep|. For a given x
+  // position, the left column values can be obtained by VTBL as long as the
+  // values at row[x + d] and beyond come from the top row. However, this does
+  // not guarantee that the vector will also contain all of the values needed
+  // from top row.
+  const int d = 16 / ((ystep >> 6) + 1);
   for (int left_offset = -left_base_increment; x < min_top_only_x; x += 8,
            xstep_bounds_base -= (8 << 6),
            left_y = vsubq_s16(left_y, increment_left8),
            left_offset -= left_base_increment8) {
     uint8_t* dst_x = dst + x;
-
+    const int max_shuffle_height =
+        std::min(((x + d) << 6) / xstep, height) & ~7;
     // Round down to the nearest multiple of 8.
     const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height) & ~7;
     DirectionalZone1_WxH<8>(dst_x, stride, max_top_only_y,
@@ -770,14 +768,20 @@
 }
 
 void DirectionalIntraPredictorZone2_NEON(
-    void* const dest, const ptrdiff_t stride, const void* const top_row,
-    const void* const left_column, const int width, const int height,
-    const int xstep, const int ystep, const bool upsampled_top,
-    const bool upsampled_left) {
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column, const int width,
+    const int height, const int xstep, const int ystep,
+    const bool upsampled_top, const bool upsampled_left) {
   // Increasing the negative buffer for this function allows more rows to be
   // processed at a time without branching in an inner loop to check the base.
   uint8_t top_buffer[288];
   uint8_t left_buffer[288];
+#if LIBGAV1_MSAN
+  memset(top_buffer, 0, sizeof(top_buffer));
+  memset(left_buffer, 0, sizeof(left_buffer));
+#endif  // LIBGAV1_MSAN
+
   memcpy(top_buffer + 128, static_cast<const uint8_t*>(top_row) - 16, 160);
   memcpy(left_buffer + 128, static_cast<const uint8_t*>(left_column) - 16, 160);
   const uint8_t* top_ptr = top_buffer + 144;
@@ -793,12 +797,10 @@
   }
 }
 
-void DirectionalIntraPredictorZone3_NEON(void* const dest,
-                                         const ptrdiff_t stride,
-                                         const void* const left_column,
-                                         const int width, const int height,
-                                         const int ystep,
-                                         const bool upsampled_left) {
+void DirectionalIntraPredictorZone3_NEON(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const left_column, const int width,
+    const int height, const int ystep, const bool upsampled_left) {
   const auto* const left = static_cast<const uint8_t*>(left_column);
 
   assert(ystep > 0);
@@ -819,7 +821,7 @@
     do {
       int x = 0;
       do {
-        uint8_t* dst = static_cast<uint8_t*>(dest);
+        auto* dst = static_cast<uint8_t*>(dest);
         dst += y * stride + x;
         uint8x8_t left_v[4], right_v[4], value_v[4];
         const int ystep_base = ystep * x;
@@ -886,7 +888,7 @@
     do {
       int x = 0;
       do {
-        uint8_t* dst = static_cast<uint8_t*>(dest);
+        auto* dst = static_cast<uint8_t*>(dest);
         dst += y * stride + x;
         const int ystep_base = ystep * (x + 1);
 
@@ -934,7 +936,8 @@
 }
 
 // Each element of |dest| contains values associated with one weight value.
-inline void LoadEdgeVals(uint16x4x2_t* dest, const uint16_t* const source,
+inline void LoadEdgeVals(uint16x4x2_t* dest,
+                         const uint16_t* LIBGAV1_RESTRICT const source,
                          const bool upsampled) {
   if (upsampled) {
     *dest = vld2_u16(source);
@@ -945,7 +948,8 @@
 }
 
 // Each element of |dest| contains values associated with one weight value.
-inline void LoadEdgeVals(uint16x8x2_t* dest, const uint16_t* const source,
+inline void LoadEdgeVals(uint16x8x2_t* dest,
+                         const uint16_t* LIBGAV1_RESTRICT const source,
                          const bool upsampled) {
   if (upsampled) {
     *dest = vld2q_u16(source);
@@ -956,8 +960,9 @@
 }
 
 template <bool upsampled>
-inline void DirectionalZone1_4xH(uint16_t* dst, const ptrdiff_t stride,
-                                 const int height, const uint16_t* const top,
+inline void DirectionalZone1_4xH(uint16_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t stride, const int height,
+                                 const uint16_t* LIBGAV1_RESTRICT const top,
                                  const int xstep) {
   const int upsample_shift = static_cast<int>(upsampled);
   const int index_scale_bits = 6 - upsample_shift;
@@ -1007,9 +1012,11 @@
 // Process a multiple of 8 |width| by any |height|. Processes horizontally
 // before vertically in the hopes of being a little more cache friendly.
 template <bool upsampled>
-inline void DirectionalZone1_WxH(uint16_t* dst, const ptrdiff_t stride,
-                                 const int width, const int height,
-                                 const uint16_t* const top, const int xstep) {
+inline void DirectionalZone1_WxH(uint16_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t stride, const int width,
+                                 const int height,
+                                 const uint16_t* LIBGAV1_RESTRICT const top,
+                                 const int xstep) {
   assert(width % 8 == 0);
   const int upsample_shift = static_cast<int>(upsampled);
   const int index_scale_bits = 6 - upsample_shift;
@@ -1068,10 +1075,11 @@
 
 // Process a multiple of 8 |width| by any |height|. Processes horizontally
 // before vertically in the hopes of being a little more cache friendly.
-inline void DirectionalZone1_Large(uint16_t* dst, const ptrdiff_t stride,
-                                   const int width, const int height,
-                                   const uint16_t* const top, const int xstep,
-                                   const bool upsampled) {
+inline void DirectionalZone1_Large(uint16_t* LIBGAV1_RESTRICT dst,
+                                   const ptrdiff_t stride, const int width,
+                                   const int height,
+                                   const uint16_t* LIBGAV1_RESTRICT const top,
+                                   const int xstep, const bool upsampled) {
   assert(width % 8 == 0);
   const int upsample_shift = static_cast<int>(upsampled);
   const int index_scale_bits = 6 - upsample_shift;
@@ -1156,13 +1164,12 @@
   }
 }
 
-void DirectionalIntraPredictorZone1_NEON(void* const dest, ptrdiff_t stride,
-                                         const void* const top_row,
-                                         const int width, const int height,
-                                         const int xstep,
-                                         const bool upsampled_top) {
-  const uint16_t* const top = static_cast<const uint16_t*>(top_row);
-  uint16_t* dst = static_cast<uint16_t*>(dest);
+void DirectionalIntraPredictorZone1_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row, const int width,
+    const int height, const int xstep, const bool upsampled_top) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  auto* dst = static_cast<uint16_t*>(dest);
   stride /= sizeof(top[0]);
 
   assert(xstep > 0);
@@ -1225,9 +1232,10 @@
 // 42 52 62 72             60 61 62 63
 // 43 53 63 73             70 71 72 73
 template <bool upsampled>
-inline void DirectionalZone3_4x4(uint8_t* dst, const ptrdiff_t stride,
-                                 const uint16_t* const left, const int ystep,
-                                 const int base_left_y = 0) {
+inline void DirectionalZone3_4x4(uint8_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t stride,
+                                 const uint16_t* LIBGAV1_RESTRICT const left,
+                                 const int ystep, const int base_left_y = 0) {
   const int upsample_shift = static_cast<int>(upsampled);
   const int index_scale_bits = 6 - upsample_shift;
 
@@ -1278,8 +1286,9 @@
 }
 
 template <bool upsampled>
-inline void DirectionalZone3_4xH(uint8_t* dest, const ptrdiff_t stride,
-                                 const int height, const uint16_t* const left,
+inline void DirectionalZone3_4xH(uint8_t* LIBGAV1_RESTRICT dest,
+                                 const ptrdiff_t stride, const int height,
+                                 const uint16_t* LIBGAV1_RESTRICT const left,
                                  const int ystep) {
   const int upsample_shift = static_cast<int>(upsampled);
   int y = 0;
@@ -1292,8 +1301,9 @@
 }
 
 template <bool upsampled>
-inline void DirectionalZone3_Wx4(uint8_t* dest, const ptrdiff_t stride,
-                                 const int width, const uint16_t* const left,
+inline void DirectionalZone3_Wx4(uint8_t* LIBGAV1_RESTRICT dest,
+                                 const ptrdiff_t stride, const int width,
+                                 const uint16_t* LIBGAV1_RESTRICT const left,
                                  const int ystep) {
   int x = 0;
   int base_left_y = 0;
@@ -1308,9 +1318,10 @@
 }
 
 template <bool upsampled>
-inline void DirectionalZone3_8x8(uint8_t* dest, const ptrdiff_t stride,
-                                 const uint16_t* const left, const int ystep,
-                                 const int base_left_y = 0) {
+inline void DirectionalZone3_8x8(uint8_t* LIBGAV1_RESTRICT dest,
+                                 const ptrdiff_t stride,
+                                 const uint16_t* LIBGAV1_RESTRICT const left,
+                                 const int ystep, const int base_left_y = 0) {
   const int upsample_shift = static_cast<int>(upsampled);
   const int index_scale_bits = 6 - upsample_shift;
 
@@ -1400,9 +1411,11 @@
 }
 
 template <bool upsampled>
-inline void DirectionalZone3_WxH(uint8_t* dest, const ptrdiff_t stride,
-                                 const int width, const int height,
-                                 const uint16_t* const left, const int ystep) {
+inline void DirectionalZone3_WxH(uint8_t* LIBGAV1_RESTRICT dest,
+                                 const ptrdiff_t stride, const int width,
+                                 const int height,
+                                 const uint16_t* LIBGAV1_RESTRICT const left,
+                                 const int ystep) {
   const int upsample_shift = static_cast<int>(upsampled);
   // Zone3 never runs out of left_column values.
   assert((width + height - 1) << upsample_shift >  // max_base_y
@@ -1424,14 +1437,12 @@
   } while (y < height);
 }
 
-void DirectionalIntraPredictorZone3_NEON(void* const dest,
-                                         const ptrdiff_t stride,
-                                         const void* const left_column,
-                                         const int width, const int height,
-                                         const int ystep,
-                                         const bool upsampled_left) {
-  const uint16_t* const left = static_cast<const uint16_t*>(left_column);
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+void DirectionalIntraPredictorZone3_NEON(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const left_column, const int width,
+    const int height, const int ystep, const bool upsampled_left) {
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   if (ystep == 64) {
     assert(!upsampled_left);
@@ -1472,10 +1483,672 @@
   }
 }
 
+// -----------------------------------------------------------------------------
+// Zone2
+// This function deals with cases not found in zone 1 or zone 3. The extreme
+// angles are 93, which makes for sharp ascents along |left_column| with each
+// successive dest row element until reaching |top_row|, and 177, with a shallow
+// ascent up |left_column| until reaching large jumps along |top_row|. In the
+// extremely steep cases, source vectors can only be loaded one lane at a time.
+
+// Gathers |left| and |right| from the 32 bytes at |source|, using the byte
+// offsets in |left_step| and |right_step| to pick the bytes of each lane.
+inline void LoadStepwise(const void* LIBGAV1_RESTRICT const source,
+                         const uint8x8_t left_step, const uint8x8_t right_step,
+                         uint16x4_t* left, uint16x4_t* right) {
+  const auto* const bytes = static_cast<const uint8_t*>(source);
+  const uint8x16x2_t table = {vld1q_u8(bytes), vld1q_u8(bytes + 16)};
+  *left = vreinterpret_u16_u8(VQTbl2U8(table, left_step));
+  *right = vreinterpret_u16_u8(VQTbl2U8(table, right_step));
+}
+
+// Eight-lane variant of LoadStepwise(). Both 4-lane halves of |left| and
+// |right| are gathered from the same 32 bytes at |source|: the *_0 index
+// vectors produce lanes 0-3 and the *_1 index vectors produce lanes 4-7.
+// Each index vector holds byte offsets into those 32 bytes.
+inline void LoadStepwise(const void* LIBGAV1_RESTRICT const source,
+                         const uint8x8_t left_step_0,
+                         const uint8x8_t right_step_0,
+                         const uint8x8_t left_step_1,
+                         const uint8x8_t right_step_1, uint16x8_t* left,
+                         uint16x8_t* right) {
+  const auto* const bytes = static_cast<const uint8_t*>(source);
+  const uint8x16x2_t table = {vld1q_u8(bytes), vld1q_u8(bytes + 16)};
+
+  *left = vcombine_u16(vreinterpret_u16_u8(VQTbl2U8(table, left_step_0)),
+                       vreinterpret_u16_u8(VQTbl2U8(table, left_step_1)));
+
+  *right = vcombine_u16(vreinterpret_u16_u8(VQTbl2U8(table, right_step_0)),
+                        vreinterpret_u16_u8(VQTbl2U8(table, right_step_1)));
+}
+
+// Per-lane rounded blend: (a * a_weight + b * b_weight + 16) >> 5. Each weight
+// pair is expected to sum to 32, so the shift by log2(32) = 5 returns the
+// result to the scale of |a| and |b|.
+inline uint16x4_t WeightedBlend(const uint16x4_t a, const uint16x4_t b,
+                                const uint16x4_t a_weight,
+                                const uint16x4_t b_weight) {
+  const uint16x4_t weighted_sum = vmla_u16(vmul_u16(a, a_weight), b, b_weight);
+  return vrshr_n_u16(weighted_sum, 5);
+}
+
+// Eight-lane blend: per lane (a * a_weight + b * b_weight + 16) >> 5. Each
+// weight pair is expected to sum to 32, keeping the result at input scale.
+inline uint16x8_t WeightedBlend(const uint16x8_t a, const uint16x8_t b,
+                                const uint16x8_t a_weight,
+                                const uint16x8_t b_weight) {
+  const uint16x8_t weighted_sum =
+      vmlaq_u16(vmulq_u16(a, a_weight), b, b_weight);
+  return vrshrq_n_u16(weighted_sum, 5);
+}
+
+// Because the source values "move backwards" as the row index increases, the
+// indices derived from ystep are generally negative in localized functions.
+// This is accommodated by making sure the relative indices are within [-15, 0]
+// when the function is called, and sliding them into the inclusive range
+// [0, 15], relative to a lower base address. 15 is the pixel offset; the byte
+// offset used for table lookups follows from the 16-bit pixel size.
+constexpr int kPositiveIndexOffsetPixels = 15;
+constexpr int kPositiveIndexOffsetBytes =
+    kPositiveIndexOffsetPixels * static_cast<int>(sizeof(uint16_t));
+
+// Predicts |height| rows of a 4-wide block using only |left_column|. |left_y|
+// holds each column's position along |left_column| in 1/64-pixel units; the
+// source advances by one original-resolution pixel per row.
+inline void DirectionalZone2FromLeftCol_4xH(
+    uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t stride, const int height,
+    const uint16_t* LIBGAV1_RESTRICT const left_column, const int16x4_t left_y,
+    const bool upsampled) {
+  const int upsample_shift = static_cast<int>(upsampled);
+
+  const int index_scale_bits = 6;
+  // |offset_y| gets each lane's whole-pixel offset (an arithmetic shift of the
+  // non-positive |left_y|), doubled below into a byte offset for the lookup.
+  int16x4_t offset_y;
+  int16x4_t shift_upsampled = left_y;
+  // The shift argument must be a constant, otherwise use upsample_shift
+  // directly.
+  if (upsampled) {
+    offset_y = vshr_n_s16(left_y, index_scale_bits - 1 /*upsample_shift*/);
+    shift_upsampled = vshl_n_s16(shift_upsampled, 1);
+  } else {
+    offset_y = vshr_n_s16(left_y, index_scale_bits);
+  }
+  offset_y = vshl_n_s16(offset_y, 1);
+
+  // Select values to the left of the starting point. Pixel 15 (bytes 30-31) is
+  // the last of the 16 pixels (32 bytes) fed to the table lookup; with negative
+  // offsets every other lane reads further left. This supports cumulative
+  // steps of up to 15 pixels.
+  // |sampler_0| indexes the first byte of each 16-bit value.
+  const int16x4_t sampler_0 =
+      vadd_s16(offset_y, vdup_n_s16(kPositiveIndexOffsetBytes));
+  // |sampler_1| indexes the second byte of each 16-bit value.
+  const int16x4_t sampler_1 = vadd_s16(sampler_0, vdup_n_s16(1));
+  const int16x4x2_t sampler = vzip_s16(sampler_0, sampler_1);
+  const uint8x8_t left_indices =
+      vqmovun_s16(vcombine_s16(sampler.val[0], sampler.val[1]));
+  const uint8x8_t right_indices =
+      vadd_u8(left_indices, vdup_n_u8(sizeof(uint16_t)));
+  // 6-bit fractional position halved to a weight in [0, 31]; pairs sum to 32.
+  const int16x4_t shift_masked = vand_s16(shift_upsampled, vdup_n_s16(0x3f));
+  const uint16x4_t shift_0 = vreinterpret_u16_s16(vshr_n_s16(shift_masked, 1));
+  const uint16x4_t shift_1 = vsub_u16(vdup_n_u16(32), shift_0);
+
+  int y = 0;
+  do {
+    uint16x4_t src_left, src_right;
+    LoadStepwise(
+        left_column - kPositiveIndexOffsetPixels + (y << upsample_shift),
+        left_indices, right_indices, &src_left, &src_right);
+    const uint16x4_t val = WeightedBlend(src_left, src_right, shift_1, shift_0);
+
+    Store4(dst, val);
+    dst += stride;
+  } while (++y < height);
+}
+
+// Predicts |height| rows of an 8-wide block using only |left_column|. |left_y|
+// holds each column's position along |left_column| in 1/64-pixel units.
+inline void DirectionalZone2FromLeftCol_8xH(
+    uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t stride, const int height,
+    const uint16_t* LIBGAV1_RESTRICT const left_column, const int16x8_t left_y,
+    const bool upsampled) {
+  const int upsample_shift = static_cast<int>(upsampled);
+
+  const int index_scale_bits = 6;
+  // |offset_y| gets each lane's whole-pixel offset (an arithmetic shift of the
+  // non-positive |left_y|), doubled below into a byte offset for the lookup.
+  int16x8_t offset_y = left_y;
+  int16x8_t shift_upsampled = left_y;
+  // vshrq_n_s16 needs a constant shift, so branch on |upsampled| instead.
+  if (upsampled) {
+    offset_y = vshrq_n_s16(left_y, index_scale_bits - 1);
+    shift_upsampled = vshlq_n_s16(shift_upsampled, 1);
+  } else {
+    offset_y = vshrq_n_s16(left_y, index_scale_bits);
+  }
+  offset_y = vshlq_n_s16(offset_y, 1);
+
+  // Select values to the left of the starting point. Pixel 15 (bytes 30-31) is
+  // the last of the 16 pixels (32 bytes) fed to the table lookup; with negative
+  // offsets every other lane reads further left. This supports cumulative
+  // steps of up to 15 pixels.
+  // |sampler_0| indexes the first byte of each 16-bit value.
+  const int16x8_t sampler_0 =
+      vaddq_s16(offset_y, vdupq_n_s16(kPositiveIndexOffsetBytes));
+  // |sampler_1| indexes the second byte of each 16-bit value.
+  const int16x8_t sampler_1 = vaddq_s16(sampler_0, vdupq_n_s16(1));
+  // After zipping, *_0 indices cover lanes 0-3 and *_1 indices lanes 4-7.
+  const int16x8x2_t sampler = vzipq_s16(sampler_0, sampler_1);
+  const uint8x8_t left_values_0 = vqmovun_s16(sampler.val[0]);
+  const uint8x8_t left_values_1 = vqmovun_s16(sampler.val[1]);
+  const uint8x8_t right_values_0 =
+      vadd_u8(left_values_0, vdup_n_u8(sizeof(uint16_t)));
+  const uint8x8_t right_values_1 =
+      vadd_u8(left_values_1, vdup_n_u8(sizeof(uint16_t)));
+  // 6-bit fractional position halved to a weight in [0, 31]; pairs sum to 32.
+  const int16x8_t shift_masked = vandq_s16(shift_upsampled, vdupq_n_s16(0x3f));
+  const uint16x8_t shift_0 =
+      vreinterpretq_u16_s16(vshrq_n_s16(shift_masked, 1));
+  const uint16x8_t shift_1 = vsubq_u16(vdupq_n_u16(32), shift_0);
+
+  int y = 0;
+  do {
+    uint16x8_t src_left, src_right;
+    LoadStepwise(
+        left_column - kPositiveIndexOffsetPixels + (y << upsample_shift),
+        left_values_0, right_values_0, left_values_1, right_values_1, &src_left,
+        &src_right);
+    const uint16x8_t val = WeightedBlend(src_left, src_right, shift_1, shift_0);
+
+    Store8(dst, val);
+    dst += stride;
+  } while (++y < height);
+}
+
+// For each of |height| 4-wide rows, lanes at or right of column
+// (zone_bounds >> 6) get a |top_row| prediction at |top_x| and the rest keep
+// their |dest| contents. |zone_bounds| += xstep and |top_x| -= xstep per row.
+template <bool upsampled>
+inline void DirectionalZone1Blend_4xH(
+    uint8_t* LIBGAV1_RESTRICT dest, const ptrdiff_t stride, const int height,
+    const uint16_t* LIBGAV1_RESTRICT const top_row, int zone_bounds, int top_x,
+    const int xstep) {
+  const int upsample_shift = static_cast<int>(upsampled);
+  const int scale_bits_x = 6 - upsample_shift;
+
+  // Column index of each lane, compared against |zone_bounds| >> 6.
+  const int16x4_t indices = {0, 1, 2, 3};
+  uint16x4x2_t top_vals;
+  int y = height;
+  do {
+    const uint16_t* const src = top_row + (top_x >> scale_bits_x);
+    LoadEdgeVals(&top_vals, src, upsampled);
+
+    const uint16_t shift_0 = ((top_x << upsample_shift) & 0x3f) >> 1;
+    const uint16_t shift_1 = 32 - shift_0;
+    const uint16x4_t val =
+        WeightedBlend(top_vals.val[0], top_vals.val[1], shift_1, shift_0);
+
+    const uint16x4_t dst_blend = Load4U16(dest);
+    // A negative |zone_bounds| selects |val| for every lane.
+    const uint16x4_t blend = vcge_s16(indices, vdup_n_s16(zone_bounds >> 6));
+    const uint16x4_t output = vbsl_u16(blend, val, dst_blend);
+
+    Store4(dest, output);
+    dest += stride;
+    zone_bounds += xstep;
+    top_x -= xstep;
+  } while (--y != 0);
+}
+
+// For each of |height| 8-wide rows, lanes at or right of column
+// (zone_bounds >> 6) get a |top_row| prediction at |top_x| and the rest keep
+// their |dest| contents. |zone_bounds| += xstep and |top_x| -= xstep per row.
+template <bool upsampled>
+inline void DirectionalZone1Blend_8xH(
+    uint8_t* LIBGAV1_RESTRICT dest, const ptrdiff_t stride, const int height,
+    const uint16_t* LIBGAV1_RESTRICT const top_row, int zone_bounds, int top_x,
+    const int xstep) {
+  const int upsample_shift = static_cast<int>(upsampled);
+  const int scale_bits_x = 6 - upsample_shift;
+
+  // Column index of each lane, compared against |zone_bounds| >> 6.
+  const int16x8_t indices = {0, 1, 2, 3, 4, 5, 6, 7};
+  uint16x8x2_t top_vals;
+  int y = height;
+  do {
+    const uint16_t* const src = top_row + (top_x >> scale_bits_x);
+    LoadEdgeVals(&top_vals, src, upsampled);
+
+    const uint16_t shift_0 = ((top_x << upsample_shift) & 0x3f) >> 1;
+    const uint16_t shift_1 = 32 - shift_0;
+    const uint16x8_t val =
+        WeightedBlend(top_vals.val[0], top_vals.val[1], shift_1, shift_0);
+
+    const uint16x8_t dst_blend = Load8U16(dest);
+    // A negative |zone_bounds| selects |val| for every lane.
+    const uint16x8_t blend = vcgeq_s16(indices, vdupq_n_s16(zone_bounds >> 6));
+    const uint16x8_t output = vbslq_u16(blend, val, dst_blend);
+
+    Store8(dest, output);
+    dest += stride;
+    zone_bounds += xstep;
+    top_x -= xstep;
+  } while (--y != 0);
+}
+
+// The height at which a 16-pixel (32-byte) LoadStepwise() load will not hold
+// enough source pixels from |left_column| to supply an accurate row when
+// computing 8 pixels at a time. The values are found by inspection. By
+// coincidence, all angles that satisfy (ystep >> 6) == 2 map to the same value,
+// so it is enough to look up by ystep >> 6. The largest index for this lookup
+// is 1023 >> 6 == 15. Indices that do not correspond to angle derivatives are
+// left at zero. Notably, in cases with upsampling, the shuffle-invalid height
+// is always greater than the prediction height (which is 8 at maximum).
+constexpr int kDirectionalZone2ShuffleInvalidHeight[16] = {
+    1024, 1024, 16, 16, 16, 16, 0, 0, 18, 0, 0, 0, 0, 0, 0, 40};
+
+// 7.11.2.4 (8) 90 < angle < 180
+// The strategy for these functions (4xH and 8+xH) is to know how many blocks
+// can be processed with just pixels from |top_row|, then handle mixed blocks,
+// then handle only blocks that take from |left_column|. Additionally, a fast
+// index-shuffle approach is used for pred values from |left_column| in sections
+// that permit it.
+template <bool upsampled_top, bool upsampled_left>
+inline void DirectionalZone2_4xH(
+    uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t stride,
+    const uint16_t* LIBGAV1_RESTRICT const top_row,
+    const uint16_t* LIBGAV1_RESTRICT const left_column, const int height,
+    const int xstep, const int ystep) {
+  const int upsample_left_shift = static_cast<int>(upsampled_left);
+
+  // Helper vector for index computation.
+  const int16x4_t zero_to_three = {0, 1, 2, 3};
+
+  // Loop increments for moving by block (4xN). Vertical still steps by 8. If
+  // it's only 4, it will be finished in the first iteration.
+  const ptrdiff_t stride8 = stride << 3;
+  const int xstep8 = xstep << 3;
+
+  const int min_height = (height == 4) ? 4 : 8;
+
+  // Rows are split into top-only, mixed and left-only stages below; unlike
+  // DirectionalZone2_8(), no columns are split off as top-only. This assumes
+  // |xstep| is at least 3.
+  assert(xstep >= 3);
+
+  // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1
+  int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1;
+
+  const int left_base_increment = ystep >> 6;
+  const int ystep_remainder = ystep & 0x3F;
+
+  // If the 64 scaling is regarded as a decimal point, the first value of the
+  // left_y vector omits the portion which is covered under the left_column
+  // offset. The following values need the full ystep as a relative offset.
+  const int16x4_t left_y =
+      vmla_n_s16(vdup_n_s16(-ystep_remainder), zero_to_three, -ystep);
+
+  // This loop treats the 4 columns in 3 stages with y-value boundaries.
+  // The first stage, before the first y-loop, covers blocks that are only
+  // computed from the top row. The second stage, comprising two y-loops, covers
+  // blocks that have a mixture of values computed from top or left. The final
+  // stage covers blocks that are only computed from the left.
+  // Round down to the nearest multiple of 8.
+  // TODO(petersonab): Check if rounding to the nearest 4 is okay.
+  const int max_top_only_y = std::min((1 << 6) / xstep, height) & ~7;
+  DirectionalZone1_4xH<upsampled_top>(reinterpret_cast<uint16_t*>(dst),
+                                      stride >> 1, max_top_only_y, top_row,
+                                      -xstep);
+
+  if (max_top_only_y == height) return;
+
+  int y = max_top_only_y;
+  dst += stride * y;
+  const int xstep_y = xstep * y;
+
+  // All rows from |min_left_only_y| down for this set of columns only need
+  // |left_column| to compute.
+  const int min_left_only_y = std::min((4 /*width*/ << 6) / xstep, height);
+  int xstep_bounds = xstep_bounds_base + xstep_y;
+  int top_x = -xstep - xstep_y;
+
+  // +8 increment is OK because if height is 4 this only runs once.
+  for (; y < min_left_only_y;
+       y += 8, dst += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+    DirectionalZone2FromLeftCol_4xH(
+        dst, stride, min_height,
+        left_column + ((y - left_base_increment) << upsample_left_shift),
+        left_y, upsampled_left);
+
+    DirectionalZone1Blend_4xH<upsampled_top>(dst, stride, min_height, top_row,
+                                             xstep_bounds, top_x, xstep);
+  }
+
+  // Loop over y for left-only rows.
+  for (; y < height; y += 8, dst += stride8) {
+    // Angle expected by Zone3 is flipped about the 180 degree vector, which
+    // is the x-axis.
+    DirectionalZone3_4xH<upsampled_left>(
+        dst, stride, min_height, left_column + (y << upsample_left_shift),
+        -ystep);
+  }
+}
+
+// Process 8x4 and 16x4 blocks. This avoids a lot of overhead and simplifies
+// address safety.
+template <bool upsampled_top, bool upsampled_left>
+inline void DirectionalZone2_Wx4(
+    uint8_t* LIBGAV1_RESTRICT const dst, const ptrdiff_t stride,
+    const uint16_t* LIBGAV1_RESTRICT const top_row,
+    const uint16_t* LIBGAV1_RESTRICT const left_column, const int width,
+    const int xstep, const int ystep) {
+  const int upsample_top_shift = static_cast<int>(upsampled_top);
+  // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1
+  int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1;
+
+  const int min_top_only_x = std::min((4 * xstep) >> 6, width);
+  int x = 0;
+  for (; x < min_top_only_x; x += 4, xstep_bounds_base -= (4 << 6)) {
+    uint8_t* dst_x = dst + x * sizeof(uint16_t);
+
+    // Round down to a multiple of 4; nonzero means all 4 rows are top-only.
+    const int max_top_only_y = (((x + 1) << 6) / xstep) & ~3;
+    if (max_top_only_y != 0) {
+      DirectionalZone1_4xH<upsampled_top>(
+          reinterpret_cast<uint16_t*>(dst_x), stride >> 1, 4,
+          top_row + (x << upsample_top_shift), -xstep);
+      continue;
+    }
+    // Otherwise fill from |left_column| and blend in |top_row| lanes below.
+    DirectionalZone3_4x4<upsampled_left>(dst_x, stride, left_column, -ystep,
+                                         -ystep * x);
+
+    const int min_left_only_y = ((x + 4) << 6) / xstep;
+    if (min_left_only_y != 0) {
+      const int top_x = -xstep;
+      DirectionalZone1Blend_4xH<upsampled_top>(
+          dst_x, stride, 4, top_row + (x << upsample_top_shift),
+          xstep_bounds_base, top_x, xstep);
+    }
+  }
+  // Reached |min_top_only_x|.
+  for (; x < width; x += 4) {
+    DirectionalZone1_4xH<upsampled_top>(
+        reinterpret_cast<uint16_t*>(dst) + x, stride >> 1, 4,
+        top_row + (x << upsample_top_shift), -xstep);
+  }
+}
+
+// Process a multiple of 8 |width|. Blocks of height 4 go to the Wx4 variant.
+template <bool upsampled_top, bool upsampled_left>
+inline void DirectionalZone2_8(
+    uint8_t* LIBGAV1_RESTRICT const dst, const ptrdiff_t stride,
+    const uint16_t* LIBGAV1_RESTRICT const top_row,
+    const uint16_t* LIBGAV1_RESTRICT const left_column, const int width,
+    const int height, const int xstep, const int ystep) {
+  if (height == 4) {
+    DirectionalZone2_Wx4<upsampled_top, upsampled_left>(
+        dst, stride, top_row, left_column, width, xstep, ystep);
+    return;
+  }
+  const int upsample_left_shift = static_cast<int>(upsampled_left);
+  const int upsample_top_shift = static_cast<int>(upsampled_top);
+
+  // Helper vector.
+  const int16x8_t zero_to_seven = {0, 1, 2, 3, 4, 5, 6, 7};
+
+  // Loop increments for moving by block (8x8). Blocks with height 4 never get
+  // here: they are handled by DirectionalZone2_Wx4() in the early return
+  // above.
+  const ptrdiff_t stride8 = stride << 3;
+  const int xstep8 = xstep << 3;
+  const int ystep8 = ystep << 3;
+
+  // All columns from |min_top_only_x| to the right will only need |top_row| to
+  // compute and can therefore call the Zone1 functions. This assumes |xstep| is
+  // at least 3.
+  assert(xstep >= 3);
+  const int min_top_only_x = std::min((height * xstep) >> 6, width);
+
+  // For steep angles, the source pixels from |left_column| may not fit in a
+  // 16-pixel (32-byte) load for shuffling.
+  // TODO(petersonab): Find a more precise formula for this subject to x.
+  const int max_shuffle_height =
+      std::min(kDirectionalZone2ShuffleInvalidHeight[ystep >> 6], height);
+
+  // Offsets the original zone bound value to simplify x < (y+1)*xstep/64 -1
+  int xstep_bounds_base = (xstep == 64) ? 0 : xstep - 1;
+
+  const int left_base_increment = ystep >> 6;
+  const int ystep_remainder = ystep & 0x3F;
+
+  const int left_base_increment8 = ystep8 >> 6;
+  const int ystep_remainder8 = ystep8 & 0x3F;
+  const int16x8_t increment_left8 = vdupq_n_s16(ystep_remainder8);
+
+  // If the 64 scaling is regarded as a decimal point, the first value of the
+  // left_y vector omits the portion which is covered under the left_column
+  // offset. Following values need the full ystep as a relative offset.
+  int16x8_t left_y =
+      vmlaq_n_s16(vdupq_n_s16(-ystep_remainder), zero_to_seven, -ystep);
+
+  // This loop treats each set of 8 columns in 3 stages with y-value boundaries.
+  // The first stage, before the first y-loop, covers blocks that are only
+  // computed from the top row. The second stage, comprising two y-loops, covers
+  // blocks that have a mixture of values computed from top or left. The final
+  // stage covers blocks that are only computed from the left.
+  int x = 0;
+  for (int left_offset = -left_base_increment; x < min_top_only_x; x += 8,
+           xstep_bounds_base -= (8 << 6),
+           left_y = vsubq_s16(left_y, increment_left8),
+           left_offset -= left_base_increment8) {
+    uint8_t* dst_x = dst + x * sizeof(uint16_t);
+
+    // Round down to the nearest multiple of 8.
+    const int max_top_only_y = std::min(((x + 1) << 6) / xstep, height) & ~7;
+    DirectionalZone1_WxH<upsampled_top>(
+        reinterpret_cast<uint16_t*>(dst_x), stride >> 1, 8, max_top_only_y,
+        top_row + (x << upsample_top_shift), -xstep);
+
+    if (max_top_only_y == height) continue;
+
+    int y = max_top_only_y;
+    dst_x += stride * y;
+    const int xstep_y = xstep * y;
+
+    // All rows from |min_left_only_y| down for this set of columns only need
+    // |left_column| to compute.
+    const int min_left_only_y = std::min(((x + 8) << 6) / xstep, height);
+    // At high angles such that min_left_only_y < 8, ystep is low and xstep is
+    // high. This means that max_shuffle_height is unbounded and xstep_bounds
+    // will overflow in 16 bits. This is prevented by stopping the first
+    // blending loop at min_left_only_y for such cases, which means we skip over
+    // the second blending loop as well.
+    const int left_shuffle_stop_y =
+        std::min(max_shuffle_height, min_left_only_y);
+    int xstep_bounds = xstep_bounds_base + xstep_y;
+    int top_x = -xstep - xstep_y;
+
+    for (; y < left_shuffle_stop_y;
+         y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+      DirectionalZone2FromLeftCol_8xH(
+          dst_x, stride, 8,
+          left_column + ((left_offset + y) << upsample_left_shift), left_y,
+          upsampled_left);
+
+      DirectionalZone1Blend_8xH<upsampled_top>(
+          dst_x, stride, 8, top_row + (x << upsample_top_shift), xstep_bounds,
+          top_x, xstep);
+    }
+
+    // Pick up from the last y-value, using the slower but secure method for
+    // left prediction.
+    for (; y < min_left_only_y;
+         y += 8, dst_x += stride8, xstep_bounds += xstep8, top_x -= xstep8) {
+      DirectionalZone3_8x8<upsampled_left>(
+          dst_x, stride, left_column + (y << upsample_left_shift), -ystep,
+          -ystep * x);
+
+      DirectionalZone1Blend_8xH<upsampled_top>(
+          dst_x, stride, 8, top_row + (x << upsample_top_shift), xstep_bounds,
+          top_x, xstep);
+    }
+    // Loop over y for left_only rows.
+    for (; y < height; y += 8, dst_x += stride8) {
+      DirectionalZone3_8x8<upsampled_left>(
+          dst_x, stride, left_column + (y << upsample_left_shift), -ystep,
+          -ystep * x);
+    }
+  }
+  // Reached |min_top_only_x|.
+  if (x < width) {
+    DirectionalZone1_WxH<upsampled_top>(
+        reinterpret_cast<uint16_t*>(dst) + x, stride >> 1, width - x, height,
+        top_row + (x << upsample_top_shift), -xstep);
+  }
+}
+
+// Zone 2 at exactly 135 degrees (xstep == ystep == 64); neither edge is
+// upsampled here. |min_width| is either 4 or 8.
+template <int min_width>
+void DirectionalAngle135(uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t stride,
+                         const uint16_t* LIBGAV1_RESTRICT const top,
+                         const uint16_t* LIBGAV1_RESTRICT const left,
+                         const int width, const int height) {
+  // Row 0 is a copy of top[-1 .. width - 2].
+  memcpy(dst, top - 1, width * sizeof(top[0]));
+  dst += stride;
+
+  // If |height| > |width|, then there is a point at which top_row is no longer
+  // used in each row.
+  const int min_left_only_y = std::min(width, height);
+
+  int y = 1;
+  do {
+    // Example: If y is 4 (min_width), the dest row starts with left[3],
+    // left[2], left[1], left[0], because the angle points up. Therefore, load
+    // starts at left[0] and is then reversed. If y is 2, the load starts at
+    // left[-2], and is reversed to store left[1], left[0], with negative values
+    // overwritten from |top_row|.
+    const uint16_t* const load_left = left + y - min_width;
+    auto* dst16 = reinterpret_cast<uint16_t*>(dst);
+
+    // Lanes at x >= y hold left[-1] and below; the memcpy from |top| replaces
+    // them. vrev64q reverses each 64-bit half only, so halves are stored swapped.
+    if (min_width == 4) {
+      const uint16x4_t left_toward_corner = vrev64_u16(vld1_u16(load_left));
+      vst1_u16(dst16, left_toward_corner);
+    } else {
+      int x = 0;
+      do {
+        const uint16x8_t left_toward_corner =
+            vrev64q_u16(vld1q_u16(load_left - x));
+        vst1_u16(dst16 + x, vget_high_u16(left_toward_corner));
+        vst1_u16(dst16 + x + 4, vget_low_u16(left_toward_corner));
+        x += 8;
+      } while (x < y);
+    }
+    // Entering |top|.
+    memcpy(dst16 + y, top - 1, (width - y) * sizeof(top[0]));
+    dst += stride;
+  } while (++y < min_left_only_y);
+
+  // Left only.
+  for (; y < height; ++y, dst += stride) {
+    auto* dst16 = reinterpret_cast<uint16_t*>(dst);
+    const uint16_t* const load_left = left + y - min_width;
+
+    int x = 0;
+    if (min_width == 4) {
+      const uint16x4_t left_toward_corner = vrev64_u16(vld1_u16(load_left - x));
+      vst1_u16(dst16 + x, left_toward_corner);
+    } else {
+      do {
+        const uint16x8_t left_toward_corner =
+            vrev64q_u16(vld1q_u16(load_left - x));
+        vst1_u16(dst16 + x, vget_high_u16(left_toward_corner));
+        vst1_u16(dst16 + x + 4, vget_low_u16(left_toward_corner));
+        x += 8;
+      } while (x < width);
+    }
+  }
+}
+
+void DirectionalIntraPredictorZone2_NEON(
+    void* LIBGAV1_RESTRICT dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column, const int width,
+    const int height, const int xstep, const int ystep,
+    const bool upsampled_top, const bool upsampled_left) {
+  // Increasing the negative buffer for this function allows more rows to be
+  // processed at a time without branching in an inner loop to check the base.
+  uint16_t top_buffer[288];
+  uint16_t left_buffer[288];
+#if LIBGAV1_MSAN
+  memset(top_buffer, 0, sizeof(top_buffer));
+  memset(left_buffer, 0, sizeof(left_buffer));
+#endif  // LIBGAV1_MSAN
+  // Copy edge pixels [-16, 63] so that top_ptr[i] == top_row[i], likewise left.
+  memcpy(top_buffer + 128, static_cast<const uint16_t*>(top_row) - 16, 160);
+  memcpy(left_buffer + 128, static_cast<const uint16_t*>(left_column) - 16,
+         160);
+  const uint16_t* top_ptr = top_buffer + 144;
+  const uint16_t* left_ptr = left_buffer + 144;
+  auto* dst = static_cast<uint8_t*>(dest);
+  if (width == 4) {
+    if (xstep == 64) {
+      assert(ystep == 64);
+      DirectionalAngle135<4>(dst, stride, top_ptr, left_ptr, width, height);
+      return;
+    }
+    if (upsampled_top) {
+      if (upsampled_left) {
+        DirectionalZone2_4xH<true, true>(dst, stride, top_ptr, left_ptr, height,
+                                         xstep, ystep);
+      } else {
+        DirectionalZone2_4xH<true, false>(dst, stride, top_ptr, left_ptr,
+                                          height, xstep, ystep);
+      }
+    } else if (upsampled_left) {
+      DirectionalZone2_4xH<false, true>(dst, stride, top_ptr, left_ptr, height,
+                                        xstep, ystep);
+    } else {
+      DirectionalZone2_4xH<false, false>(dst, stride, top_ptr, left_ptr, height,
+                                         xstep, ystep);
+    }
+    return;
+  }
+
+  if (xstep == 64) {
+    assert(ystep == 64);
+    DirectionalAngle135<8>(dst, stride, top_ptr, left_ptr, width, height);
+    return;
+  }
+  if (upsampled_top) {
+    if (upsampled_left) {
+      DirectionalZone2_8<true, true>(dst, stride, top_ptr, left_ptr, width,
+                                     height, xstep, ystep);
+    } else {
+      DirectionalZone2_8<true, false>(dst, stride, top_ptr, left_ptr, width,
+                                      height, xstep, ystep);
+    }
+  } else if (upsampled_left) {
+    DirectionalZone2_8<false, true>(dst, stride, top_ptr, left_ptr, width,
+                                    height, xstep, ystep);
+  } else {
+    DirectionalZone2_8<false, false>(dst, stride, top_ptr, left_ptr, width,
+                                     height, xstep, ystep);
+  }
+}
+
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
   assert(dsp != nullptr);
   dsp->directional_intra_predictor_zone1 = DirectionalIntraPredictorZone1_NEON;
+  dsp->directional_intra_predictor_zone2 = DirectionalIntraPredictorZone2_NEON;
   dsp->directional_intra_predictor_zone3 = DirectionalIntraPredictorZone3_NEON;
 }
 
diff --git a/libgav1/src/dsp/arm/intrapred_directional_neon.h b/libgav1/src/dsp/arm/intrapred_directional_neon.h
index f7d6235..310d90b 100644
--- a/libgav1/src/dsp/arm/intrapred_directional_neon.h
+++ b/libgav1/src/dsp/arm/intrapred_directional_neon.h
@@ -47,6 +47,10 @@
 #define LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone1 LIBGAV1_CPU_NEON
 #endif
 
+#ifndef LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone2
+#define LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone2 LIBGAV1_CPU_NEON
+#endif
+
 #ifndef LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone3
 #define LIBGAV1_Dsp10bpp_DirectionalIntraPredictorZone3 LIBGAV1_CPU_NEON
 #endif
diff --git a/libgav1/src/dsp/arm/intrapred_filter_neon.cc b/libgav1/src/dsp/arm/intrapred_filter_neon.cc
index bd9f61d..70bd62b 100644
--- a/libgav1/src/dsp/arm/intrapred_filter_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_filter_neon.cc
@@ -85,17 +85,18 @@
       {14, 12, 11, 10, 0, 0, 1, 1},
       {0, 0, 0, 0, 14, 12, 11, 9}}};
 
-void FilterIntraPredictor_NEON(void* const dest, ptrdiff_t stride,
-                               const void* const top_row,
-                               const void* const left_column,
+void FilterIntraPredictor_NEON(void* LIBGAV1_RESTRICT const dest,
+                               ptrdiff_t stride,
+                               const void* LIBGAV1_RESTRICT const top_row,
+                               const void* LIBGAV1_RESTRICT const left_column,
                                FilterIntraPredictor pred, int width,
                                int height) {
-  const uint8_t* const top = static_cast<const uint8_t*>(top_row);
-  const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+  const auto* const top = static_cast<const uint8_t*>(top_row);
+  const auto* const left = static_cast<const uint8_t*>(left_column);
 
   assert(width <= 32 && height <= 32);
 
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   uint8x8_t transposed_taps[7];
   for (int i = 0; i < 7; ++i) {
@@ -160,7 +161,136 @@
 }  // namespace
 }  // namespace low_bitdepth
 
-void IntraPredFilterInit_NEON() { low_bitdepth::Init8bpp(); }
+//------------------------------------------------------------------------------
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+alignas(kMaxAlignment) constexpr int16_t
+    kTransposedTaps[kNumFilterIntraPredictors][7][8] = {
+        {{-6, -5, -3, -3, -4, -3, -3, -3},
+         {10, 2, 1, 1, 6, 2, 2, 1},
+         {0, 10, 1, 1, 0, 6, 2, 2},
+         {0, 0, 10, 2, 0, 0, 6, 2},
+         {0, 0, 0, 10, 0, 0, 0, 6},
+         {12, 9, 7, 5, 2, 2, 2, 3},
+         {0, 0, 0, 0, 12, 9, 7, 5}},
+        {{-10, -6, -4, -2, -10, -6, -4, -2},
+         {16, 0, 0, 0, 16, 0, 0, 0},
+         {0, 16, 0, 0, 0, 16, 0, 0},
+         {0, 0, 16, 0, 0, 0, 16, 0},
+         {0, 0, 0, 16, 0, 0, 0, 16},
+         {10, 6, 4, 2, 0, 0, 0, 0},
+         {0, 0, 0, 0, 10, 6, 4, 2}},
+        {{-8, -8, -8, -8, -4, -4, -4, -4},
+         {8, 0, 0, 0, 4, 0, 0, 0},
+         {0, 8, 0, 0, 0, 4, 0, 0},
+         {0, 0, 8, 0, 0, 0, 4, 0},
+         {0, 0, 0, 8, 0, 0, 0, 4},
+         {16, 16, 16, 16, 0, 0, 0, 0},
+         {0, 0, 0, 0, 16, 16, 16, 16}},
+        {{-2, -1, -1, -0, -1, -1, -1, -1},
+         {8, 3, 2, 1, 4, 3, 2, 2},
+         {0, 8, 3, 2, 0, 4, 3, 2},
+         {0, 0, 8, 3, 0, 0, 4, 3},
+         {0, 0, 0, 8, 0, 0, 0, 4},
+         {10, 6, 4, 2, 3, 4, 4, 3},
+         {0, 0, 0, 0, 10, 6, 4, 3}},
+        {{-12, -10, -9, -8, -10, -9, -8, -7},
+         {14, 0, 0, 0, 12, 1, 0, 0},
+         {0, 14, 0, 0, 0, 12, 0, 0},
+         {0, 0, 14, 0, 0, 0, 12, 1},
+         {0, 0, 0, 14, 0, 0, 0, 12},
+         {14, 12, 11, 10, 0, 0, 1, 1},
+         {0, 0, 0, 0, 14, 12, 11, 9}}};
+
+void FilterIntraPredictor_NEON(void* LIBGAV1_RESTRICT const dest,
+                               ptrdiff_t stride,
+                               const void* LIBGAV1_RESTRICT const top_row,
+                               const void* LIBGAV1_RESTRICT const left_column,
+                               FilterIntraPredictor pred, int width,
+                               int height) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+
+  assert(width <= 32 && height <= 32);
+
+  auto* dst = static_cast<uint16_t*>(dest);
+
+  stride >>= 1;
+
+  int16x8_t transposed_taps[7];
+  for (int i = 0; i < 7; ++i) {
+    transposed_taps[i] = vld1q_s16(kTransposedTaps[pred][i]);
+  }
+
+  uint16_t relative_top_left = top[-1];
+  const uint16_t* relative_top = top;
+  uint16_t relative_left[2] = {left[0], left[1]};
+
+  int y = 0;
+  do {
+    uint16_t* row_dst = dst;
+    int x = 0;
+    do {
+      int16x8_t sum =
+          vmulq_s16(transposed_taps[0],
+                    vreinterpretq_s16_u16(vdupq_n_u16(relative_top_left)));
+      for (int i = 1; i < 5; ++i) {
+        sum =
+            vmlaq_s16(sum, transposed_taps[i],
+                      vreinterpretq_s16_u16(vdupq_n_u16(relative_top[i - 1])));
+      }
+      for (int i = 5; i < 7; ++i) {
+        sum =
+            vmlaq_s16(sum, transposed_taps[i],
+                      vreinterpretq_s16_u16(vdupq_n_u16(relative_left[i - 5])));
+      }
+
+      const int16x8_t sum_shifted = vrshrq_n_s16(sum, 4);
+      const uint16x8_t sum_saturated = vminq_u16(
+          vreinterpretq_u16_s16(vmaxq_s16(sum_shifted, vdupq_n_s16(0))),
+          vdupq_n_u16((1 << kBitdepth10) - 1));
+
+      vst1_u16(row_dst, vget_low_u16(sum_saturated));
+      vst1_u16(row_dst + stride, vget_high_u16(sum_saturated));
+
+      // Progress across
+      relative_top_left = relative_top[3];
+      relative_top += 4;
+      relative_left[0] = row_dst[3];
+      relative_left[1] = row_dst[3 + stride];
+      row_dst += 4;
+      x += 4;
+    } while (x < width);
+
+    // Progress down.
+    relative_top_left = left[y + 1];
+    relative_top = dst + stride;
+    relative_left[0] = left[y + 2];
+    relative_left[1] = left[y + 3];
+
+    dst += 2 * stride;
+    y += 2;
+  } while (y < height);
+}
+
+void Init10bpp() {
+  Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->filter_intra_predictor = FilterIntraPredictor_NEON;
+}
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+void IntraPredFilterInit_NEON() {
+  low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif
+}
 
 }  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/arm/intrapred_filter_neon.h b/libgav1/src/dsp/arm/intrapred_filter_neon.h
index 283c1b1..d005f4c 100644
--- a/libgav1/src/dsp/arm/intrapred_filter_neon.h
+++ b/libgav1/src/dsp/arm/intrapred_filter_neon.h
@@ -32,6 +32,8 @@
 
 #if LIBGAV1_ENABLE_NEON
 #define LIBGAV1_Dsp8bpp_FilterIntraPredictor LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_FilterIntraPredictor LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_INTRAPRED_FILTER_NEON_H_
diff --git a/libgav1/src/dsp/arm/intrapred_neon.cc b/libgav1/src/dsp/arm/intrapred_neon.cc
index c143648..cd47a22 100644
--- a/libgav1/src/dsp/arm/intrapred_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_neon.cc
@@ -26,6 +26,7 @@
 #include "src/dsp/arm/common_neon.h"
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/common.h"
 #include "src/utils/constants.h"
 
 namespace libgav1 {
@@ -56,10 +57,10 @@
 
 template <int block_width_log2, int block_height_log2, DcSumFunc sumfn,
           DcStoreFunc storefn>
-void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn,
-                      storefn>::DcTop(void* const dest, ptrdiff_t stride,
-                                      const void* const top_row,
-                                      const void* /*left_column*/) {
+void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, storefn>::
+    DcTop(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+          const void* LIBGAV1_RESTRICT const top_row,
+          const void* /*left_column*/) {
   const uint32x2_t sum = sumfn(top_row, block_width_log2, false, nullptr, 0);
   const uint32x2_t dc = vrshr_n_u32(sum, block_width_log2);
   storefn(dest, stride, dc);
@@ -67,10 +68,10 @@
 
 template <int block_width_log2, int block_height_log2, DcSumFunc sumfn,
           DcStoreFunc storefn>
-void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn,
-                      storefn>::DcLeft(void* const dest, ptrdiff_t stride,
-                                       const void* /*top_row*/,
-                                       const void* const left_column) {
+void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, storefn>::
+    DcLeft(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+           const void* /*top_row*/,
+           const void* LIBGAV1_RESTRICT const left_column) {
   const uint32x2_t sum =
       sumfn(left_column, block_height_log2, false, nullptr, 0);
   const uint32x2_t dc = vrshr_n_u32(sum, block_height_log2);
@@ -80,8 +81,9 @@
 template <int block_width_log2, int block_height_log2, DcSumFunc sumfn,
           DcStoreFunc storefn>
 void DcPredFuncs_NEON<block_width_log2, block_height_log2, sumfn, storefn>::Dc(
-    void* const dest, ptrdiff_t stride, const void* const top_row,
-    const void* const left_column) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const uint32x2_t sum =
       sumfn(top_row, block_width_log2, true, left_column, block_height_log2);
   if (block_width_log2 == block_height_log2) {
@@ -154,92 +156,116 @@
 // If |use_ref_1| is false then only sum |ref_0|.
 // For |ref[01]_size_log2| == 4 this relies on |ref_[01]| being aligned to
 // uint32_t.
-inline uint32x2_t DcSum_NEON(const void* ref_0, const int ref_0_size_log2,
-                             const bool use_ref_1, const void* ref_1,
+inline uint32x2_t DcSum_NEON(const void* LIBGAV1_RESTRICT ref_0,
+                             const int ref_0_size_log2, const bool use_ref_1,
+                             const void* LIBGAV1_RESTRICT ref_1,
                              const int ref_1_size_log2) {
   const auto* const ref_0_u8 = static_cast<const uint8_t*>(ref_0);
   const auto* const ref_1_u8 = static_cast<const uint8_t*>(ref_1);
   if (ref_0_size_log2 == 2) {
     uint8x8_t val = Load4(ref_0_u8);
     if (use_ref_1) {
-      if (ref_1_size_log2 == 2) {  // 4x4
-        val = Load4<1>(ref_1_u8, val);
-        return Sum(vpaddl_u8(val));
-      } else if (ref_1_size_log2 == 3) {  // 4x8
-        const uint8x8_t val_1 = vld1_u8(ref_1_u8);
-        const uint16x4_t sum_0 = vpaddl_u8(val);
-        const uint16x4_t sum_1 = vpaddl_u8(val_1);
-        return Sum(vadd_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 4) {  // 4x16
-        const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
-        return Sum(vaddw_u8(vpaddlq_u8(val_1), val));
+      switch (ref_1_size_log2) {
+        case 2: {  // 4x4
+          val = Load4<1>(ref_1_u8, val);
+          return Sum(vpaddl_u8(val));
+        }
+        case 3: {  // 4x8
+          const uint8x8_t val_1 = vld1_u8(ref_1_u8);
+          const uint16x4_t sum_0 = vpaddl_u8(val);
+          const uint16x4_t sum_1 = vpaddl_u8(val_1);
+          return Sum(vadd_u16(sum_0, sum_1));
+        }
+        case 4: {  // 4x16
+          const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+          return Sum(vaddw_u8(vpaddlq_u8(val_1), val));
+        }
       }
     }
     // 4x1
     const uint16x4_t sum = vpaddl_u8(val);
     return vpaddl_u16(sum);
-  } else if (ref_0_size_log2 == 3) {
+  }
+  if (ref_0_size_log2 == 3) {
     const uint8x8_t val_0 = vld1_u8(ref_0_u8);
     if (use_ref_1) {
-      if (ref_1_size_log2 == 2) {  // 8x4
-        const uint8x8_t val_1 = Load4(ref_1_u8);
-        const uint16x4_t sum_0 = vpaddl_u8(val_0);
-        const uint16x4_t sum_1 = vpaddl_u8(val_1);
-        return Sum(vadd_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 3) {  // 8x8
-        const uint8x8_t val_1 = vld1_u8(ref_1_u8);
-        const uint16x4_t sum_0 = vpaddl_u8(val_0);
-        const uint16x4_t sum_1 = vpaddl_u8(val_1);
-        return Sum(vadd_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 4) {  // 8x16
-        const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
-        return Sum(vaddw_u8(vpaddlq_u8(val_1), val_0));
-      } else if (ref_1_size_log2 == 5) {  // 8x32
-        return Sum(vaddw_u8(LoadAndAdd32(ref_1_u8), val_0));
+      switch (ref_1_size_log2) {
+        case 2: {  // 8x4
+          const uint8x8_t val_1 = Load4(ref_1_u8);
+          const uint16x4_t sum_0 = vpaddl_u8(val_0);
+          const uint16x4_t sum_1 = vpaddl_u8(val_1);
+          return Sum(vadd_u16(sum_0, sum_1));
+        }
+        case 3: {  // 8x8
+          const uint8x8_t val_1 = vld1_u8(ref_1_u8);
+          const uint16x4_t sum_0 = vpaddl_u8(val_0);
+          const uint16x4_t sum_1 = vpaddl_u8(val_1);
+          return Sum(vadd_u16(sum_0, sum_1));
+        }
+        case 4: {  // 8x16
+          const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+          return Sum(vaddw_u8(vpaddlq_u8(val_1), val_0));
+        }
+        case 5: {  // 8x32
+          return Sum(vaddw_u8(LoadAndAdd32(ref_1_u8), val_0));
+        }
       }
     }
     // 8x1
     return Sum(vpaddl_u8(val_0));
-  } else if (ref_0_size_log2 == 4) {
+  }
+  if (ref_0_size_log2 == 4) {
     const uint8x16_t val_0 = vld1q_u8(ref_0_u8);
     if (use_ref_1) {
-      if (ref_1_size_log2 == 2) {  // 16x4
-        const uint8x8_t val_1 = Load4(ref_1_u8);
-        return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1));
-      } else if (ref_1_size_log2 == 3) {  // 16x8
-        const uint8x8_t val_1 = vld1_u8(ref_1_u8);
-        return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1));
-      } else if (ref_1_size_log2 == 4) {  // 16x16
-        const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
-        return Sum(Add(val_0, val_1));
-      } else if (ref_1_size_log2 == 5) {  // 16x32
-        const uint16x8_t sum_0 = vpaddlq_u8(val_0);
-        const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8);
-        return Sum(vaddq_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 6) {  // 16x64
-        const uint16x8_t sum_0 = vpaddlq_u8(val_0);
-        const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8);
-        return Sum(vaddq_u16(sum_0, sum_1));
+      switch (ref_1_size_log2) {
+        case 2: {  // 16x4
+          const uint8x8_t val_1 = Load4(ref_1_u8);
+          return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1));
+        }
+        case 3: {  // 16x8
+          const uint8x8_t val_1 = vld1_u8(ref_1_u8);
+          return Sum(vaddw_u8(vpaddlq_u8(val_0), val_1));
+        }
+        case 4: {  // 16x16
+          const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+          return Sum(Add(val_0, val_1));
+        }
+        case 5: {  // 16x32
+          const uint16x8_t sum_0 = vpaddlq_u8(val_0);
+          const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
+        case 6: {  // 16x64
+          const uint16x8_t sum_0 = vpaddlq_u8(val_0);
+          const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
       }
     }
     // 16x1
     return Sum(vpaddlq_u8(val_0));
-  } else if (ref_0_size_log2 == 5) {
+  }
+  if (ref_0_size_log2 == 5) {
     const uint16x8_t sum_0 = LoadAndAdd32(ref_0_u8);
     if (use_ref_1) {
-      if (ref_1_size_log2 == 3) {  // 32x8
-        const uint8x8_t val_1 = vld1_u8(ref_1_u8);
-        return Sum(vaddw_u8(sum_0, val_1));
-      } else if (ref_1_size_log2 == 4) {  // 32x16
-        const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
-        const uint16x8_t sum_1 = vpaddlq_u8(val_1);
-        return Sum(vaddq_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 5) {  // 32x32
-        const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8);
-        return Sum(vaddq_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 6) {  // 32x64
-        const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8);
-        return Sum(vaddq_u16(sum_0, sum_1));
+      switch (ref_1_size_log2) {
+        case 3: {  // 32x8
+          const uint8x8_t val_1 = vld1_u8(ref_1_u8);
+          return Sum(vaddw_u8(sum_0, val_1));
+        }
+        case 4: {  // 32x16
+          const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+          const uint16x8_t sum_1 = vpaddlq_u8(val_1);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
+        case 5: {  // 32x32
+          const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
+        case 6: {  // 32x64
+          const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
       }
     }
     // 32x1
@@ -249,16 +275,20 @@
   assert(ref_0_size_log2 == 6);
   const uint16x8_t sum_0 = LoadAndAdd64(ref_0_u8);
   if (use_ref_1) {
-    if (ref_1_size_log2 == 4) {  // 64x16
-      const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
-      const uint16x8_t sum_1 = vpaddlq_u8(val_1);
-      return Sum(vaddq_u16(sum_0, sum_1));
-    } else if (ref_1_size_log2 == 5) {  // 64x32
-      const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8);
-      return Sum(vaddq_u16(sum_0, sum_1));
-    } else if (ref_1_size_log2 == 6) {  // 64x64
-      const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8);
-      return Sum(vaddq_u16(sum_0, sum_1));
+    switch (ref_1_size_log2) {
+      case 4: {  // 64x16
+        const uint8x16_t val_1 = vld1q_u8(ref_1_u8);
+        const uint16x8_t sum_1 = vpaddlq_u8(val_1);
+        return Sum(vaddq_u16(sum_0, sum_1));
+      }
+      case 5: {  // 64x32
+        const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u8);
+        return Sum(vaddq_u16(sum_0, sum_1));
+      }
+      case 6: {  // 64x64
+        const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u8);
+        return Sum(vaddq_u16(sum_0, sum_1));
+      }
     }
   }
   // 64x1
@@ -318,9 +348,10 @@
 }
 
 template <int width, int height>
-inline void Paeth4Or8xN_NEON(void* const dest, ptrdiff_t stride,
-                             const void* const top_row,
-                             const void* const left_column) {
+inline void Paeth4Or8xN_NEON(void* LIBGAV1_RESTRICT const dest,
+                             ptrdiff_t stride,
+                             const void* LIBGAV1_RESTRICT const top_row,
+                             const void* LIBGAV1_RESTRICT const left_column) {
   auto* dest_u8 = static_cast<uint8_t*>(dest);
   const auto* const top_row_u8 = static_cast<const uint8_t*>(top_row);
   const auto* const left_col_u8 = static_cast<const uint8_t*>(left_column);
@@ -425,9 +456,10 @@
       top_dist, top_left_##num##_dist_low, top_left_##num##_dist_high)
 
 template <int width, int height>
-inline void Paeth16PlusxN_NEON(void* const dest, ptrdiff_t stride,
-                               const void* const top_row,
-                               const void* const left_column) {
+inline void Paeth16PlusxN_NEON(void* LIBGAV1_RESTRICT const dest,
+                               ptrdiff_t stride,
+                               const void* LIBGAV1_RESTRICT const top_row,
+                               const void* LIBGAV1_RESTRICT const left_column) {
   auto* dest_u8 = static_cast<uint8_t*>(dest);
   const auto* const top_row_u8 = static_cast<const uint8_t*>(top_row);
   const auto* const left_col_u8 = static_cast<const uint8_t*>(left_column);
@@ -769,87 +801,111 @@
 
 // |ref_[01]| each point to 1 << |ref[01]_size_log2| packed uint16_t values.
 // If |use_ref_1| is false then only sum |ref_0|.
-inline uint32x2_t DcSum_NEON(const void* ref_0, const int ref_0_size_log2,
-                             const bool use_ref_1, const void* ref_1,
+inline uint32x2_t DcSum_NEON(const void* LIBGAV1_RESTRICT ref_0,
+                             const int ref_0_size_log2, const bool use_ref_1,
+                             const void* LIBGAV1_RESTRICT ref_1,
                              const int ref_1_size_log2) {
   const auto* ref_0_u16 = static_cast<const uint16_t*>(ref_0);
   const auto* ref_1_u16 = static_cast<const uint16_t*>(ref_1);
   if (ref_0_size_log2 == 2) {
     const uint16x4_t val_0 = vld1_u16(ref_0_u16);
     if (use_ref_1) {
-      if (ref_1_size_log2 == 2) {  // 4x4
-        const uint16x4_t val_1 = vld1_u16(ref_1_u16);
-        return Sum(vadd_u16(val_0, val_1));
-      } else if (ref_1_size_log2 == 3) {  // 4x8
-        const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
-        const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0);
-        return Sum(vaddq_u16(sum_0, val_1));
-      } else if (ref_1_size_log2 == 4) {  // 4x16
-        const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0);
-        const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
-        return Sum(vaddq_u16(sum_0, sum_1));
+      switch (ref_1_size_log2) {
+        case 2: {  // 4x4
+          const uint16x4_t val_1 = vld1_u16(ref_1_u16);
+          return Sum(vadd_u16(val_0, val_1));
+        }
+        case 3: {  // 4x8
+          const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
+          const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0);
+          return Sum(vaddq_u16(sum_0, val_1));
+        }
+        case 4: {  // 4x16
+          const uint16x8_t sum_0 = vcombine_u16(vdup_n_u16(0), val_0);
+          const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
       }
     }
     // 4x1
     return Sum(val_0);
-  } else if (ref_0_size_log2 == 3) {
+  }
+  if (ref_0_size_log2 == 3) {
     const uint16x8_t val_0 = vld1q_u16(ref_0_u16);
     if (use_ref_1) {
-      if (ref_1_size_log2 == 2) {  // 8x4
-        const uint16x4_t val_1 = vld1_u16(ref_1_u16);
-        const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1);
-        return Sum(vaddq_u16(val_0, sum_1));
-      } else if (ref_1_size_log2 == 3) {  // 8x8
-        const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
-        return Sum(vaddq_u16(val_0, val_1));
-      } else if (ref_1_size_log2 == 4) {  // 8x16
-        const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
-        return Sum(vaddq_u16(val_0, sum_1));
-      } else if (ref_1_size_log2 == 5) {  // 8x32
-        const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
-        return Sum(vaddq_u16(val_0, sum_1));
+      switch (ref_1_size_log2) {
+        case 2: {  // 8x4
+          const uint16x4_t val_1 = vld1_u16(ref_1_u16);
+          const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1);
+          return Sum(vaddq_u16(val_0, sum_1));
+        }
+        case 3: {  // 8x8
+          const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
+          return Sum(vaddq_u16(val_0, val_1));
+        }
+        case 4: {  // 8x16
+          const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+          return Sum(vaddq_u16(val_0, sum_1));
+        }
+        case 5: {  // 8x32
+          const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
+          return Sum(vaddq_u16(val_0, sum_1));
+        }
       }
     }
     // 8x1
     return Sum(val_0);
-  } else if (ref_0_size_log2 == 4) {
+  }
+  if (ref_0_size_log2 == 4) {
     const uint16x8_t sum_0 = LoadAndAdd16(ref_0_u16);
     if (use_ref_1) {
-      if (ref_1_size_log2 == 2) {  // 16x4
-        const uint16x4_t val_1 = vld1_u16(ref_1_u16);
-        const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1);
-        return Sum(vaddq_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 3) {  // 16x8
-        const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
-        return Sum(vaddq_u16(sum_0, val_1));
-      } else if (ref_1_size_log2 == 4) {  // 16x16
-        const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
-        return Sum(vaddq_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 5) {  // 16x32
-        const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
-        return Sum(vaddq_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 6) {  // 16x64
-        const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16);
-        return Sum(vaddq_u16(sum_0, sum_1));
+      switch (ref_1_size_log2) {
+        case 2: {  // 16x4
+          const uint16x4_t val_1 = vld1_u16(ref_1_u16);
+          const uint16x8_t sum_1 = vcombine_u16(vdup_n_u16(0), val_1);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
+        case 3: {  // 16x8
+          const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
+          return Sum(vaddq_u16(sum_0, val_1));
+        }
+        case 4: {  // 16x16
+          const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
+        case 5: {  // 16x32
+          const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
+        case 6: {  // 16x64
+          const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
       }
     }
     // 16x1
     return Sum(sum_0);
-  } else if (ref_0_size_log2 == 5) {
+  }
+  if (ref_0_size_log2 == 5) {
     const uint16x8_t sum_0 = LoadAndAdd32(ref_0_u16);
     if (use_ref_1) {
-      if (ref_1_size_log2 == 3) {  // 32x8
-        const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
-        return Sum(vaddq_u16(sum_0, val_1));
-      } else if (ref_1_size_log2 == 4) {  // 32x16
-        const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
-        return Sum(vaddq_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 5) {  // 32x32
-        const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
-        return Sum(vaddq_u16(sum_0, sum_1));
-      } else if (ref_1_size_log2 == 6) {  // 32x64
-        const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16);
-        return Sum(vaddq_u16(sum_0, sum_1));
+      switch (ref_1_size_log2) {
+        case 3: {  // 32x8
+          const uint16x8_t val_1 = vld1q_u16(ref_1_u16);
+          return Sum(vaddq_u16(sum_0, val_1));
+        }
+        case 4: {  // 32x16
+          const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
+        case 5: {  // 32x32
+          const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
+        case 6: {  // 32x64
+          const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16);
+          return Sum(vaddq_u16(sum_0, sum_1));
+        }
       }
     }
     // 32x1
@@ -859,15 +915,19 @@
   assert(ref_0_size_log2 == 6);
   const uint16x8_t sum_0 = LoadAndAdd64(ref_0_u16);
   if (use_ref_1) {
-    if (ref_1_size_log2 == 4) {  // 64x16
-      const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
-      return Sum(vaddq_u16(sum_0, sum_1));
-    } else if (ref_1_size_log2 == 5) {  // 64x32
-      const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
-      return Sum(vaddq_u16(sum_0, sum_1));
-    } else if (ref_1_size_log2 == 6) {  // 64x64
-      const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16);
-      return Sum(vaddq_u16(sum_0, sum_1));
+    switch (ref_1_size_log2) {
+      case 4: {  // 64x16
+        const uint16x8_t sum_1 = LoadAndAdd16(ref_1_u16);
+        return Sum(vaddq_u16(sum_0, sum_1));
+      }
+      case 5: {  // 64x32
+        const uint16x8_t sum_1 = LoadAndAdd32(ref_1_u16);
+        return Sum(vaddq_u16(sum_0, sum_1));
+      }
+      case 6: {  // 64x64
+        const uint16x8_t sum_1 = LoadAndAdd64(ref_1_u16);
+        return Sum(vaddq_u16(sum_0, sum_1));
+      }
     }
   }
   // 64x1
@@ -968,9 +1028,9 @@
 // IntraPredFuncs_NEON::Horizontal -- duplicate left column across all rows
 
 template <int block_height>
-void Horizontal4xH_NEON(void* const dest, ptrdiff_t stride,
+void Horizontal4xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
                         const void* /*top_row*/,
-                        const void* const left_column) {
+                        const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left = static_cast<const uint16_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   int y = 0;
@@ -983,9 +1043,9 @@
 }
 
 template <int block_height>
-void Horizontal8xH_NEON(void* const dest, ptrdiff_t stride,
+void Horizontal8xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
                         const void* /*top_row*/,
-                        const void* const left_column) {
+                        const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left = static_cast<const uint16_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   int y = 0;
@@ -998,9 +1058,9 @@
 }
 
 template <int block_height>
-void Horizontal16xH_NEON(void* const dest, ptrdiff_t stride,
+void Horizontal16xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
                          const void* /*top_row*/,
-                         const void* const left_column) {
+                         const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left = static_cast<const uint16_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   int y = 0;
@@ -1020,9 +1080,9 @@
 }
 
 template <int block_height>
-void Horizontal32xH_NEON(void* const dest, ptrdiff_t stride,
+void Horizontal32xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
                          const void* /*top_row*/,
-                         const void* const left_column) {
+                         const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left = static_cast<const uint16_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   int y = 0;
@@ -1048,8 +1108,8 @@
 // IntraPredFuncs_NEON::Vertical -- copy top row to all rows
 
 template <int block_height>
-void Vertical4xH_NEON(void* const dest, ptrdiff_t stride,
-                      const void* const top_row,
+void Vertical4xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT const top_row,
                       const void* const /*left_column*/) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1062,8 +1122,8 @@
 }
 
 template <int block_height>
-void Vertical8xH_NEON(void* const dest, ptrdiff_t stride,
-                      const void* const top_row,
+void Vertical8xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT const top_row,
                       const void* const /*left_column*/) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1076,8 +1136,8 @@
 }
 
 template <int block_height>
-void Vertical16xH_NEON(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
+void Vertical16xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
                        const void* const /*left_column*/) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1096,8 +1156,8 @@
 }
 
 template <int block_height>
-void Vertical32xH_NEON(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
+void Vertical32xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
                        const void* const /*left_column*/) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1122,8 +1182,8 @@
 }
 
 template <int block_height>
-void Vertical64xH_NEON(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
+void Vertical64xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
                        const void* const /*left_column*/) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1159,6 +1219,145 @@
   } while (y != 0);
 }
 
+template <int height>
+inline void Paeth4xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                          const void* LIBGAV1_RESTRICT const top_ptr,
+                          const void* LIBGAV1_RESTRICT const left_ptr) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const auto* const top_row = static_cast<const uint16_t*>(top_ptr);
+  const auto* const left_col = static_cast<const uint16_t*>(left_ptr);
+
+  const uint16x4_t top_left = vdup_n_u16(top_row[-1]);
+  const uint16x4_t top_left_x2 = vshl_n_u16(top_left, 1);
+  const uint16x4_t top = vld1_u16(top_row);
+
+  for (int y = 0; y < height; ++y) {
+    auto* dst16 = reinterpret_cast<uint16_t*>(dst);
+    const uint16x4_t left = vdup_n_u16(left_col[y]);
+
+    const uint16x4_t left_dist = vabd_u16(top, top_left);
+    const uint16x4_t top_dist = vabd_u16(left, top_left);
+    const uint16x4_t top_left_dist = vabd_u16(vadd_u16(top, left), top_left_x2);
+
+    const uint16x4_t left_le_top = vcle_u16(left_dist, top_dist);
+    const uint16x4_t left_le_top_left = vcle_u16(left_dist, top_left_dist);
+    const uint16x4_t top_le_top_left = vcle_u16(top_dist, top_left_dist);
+
+    // if (left_dist <= top_dist && left_dist <= top_left_dist)
+    const uint16x4_t left_mask = vand_u16(left_le_top, left_le_top_left);
+    //   dest[x] = left_column[y];
+    // Fill all the unused spaces with 'top'. They will be overwritten when
+    // the positions for top_left are known.
+    uint16x4_t result = vbsl_u16(left_mask, left, top);
+    // else if (top_dist <= top_left_dist)
+    //   dest[x] = top_row[x];
+    // Add these values to the mask. They were already set.
+    const uint16x4_t left_or_top_mask = vorr_u16(left_mask, top_le_top_left);
+    // else
+    //   dest[x] = top_left;
+    result = vbsl_u16(left_or_top_mask, result, top_left);
+
+    vst1_u16(dst16, result);
+    dst += stride;
+  }
+}
+
+template <int height>
+inline void Paeth8xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                          const void* LIBGAV1_RESTRICT const top_ptr,
+                          const void* LIBGAV1_RESTRICT const left_ptr) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const auto* const top_row = static_cast<const uint16_t*>(top_ptr);
+  const auto* const left_col = static_cast<const uint16_t*>(left_ptr);
+
+  const uint16x8_t top_left = vdupq_n_u16(top_row[-1]);
+  const uint16x8_t top_left_x2 = vshlq_n_u16(top_left, 1);
+  const uint16x8_t top = vld1q_u16(top_row);
+
+  for (int y = 0; y < height; ++y) {
+    auto* dst16 = reinterpret_cast<uint16_t*>(dst);
+    const uint16x8_t left = vdupq_n_u16(left_col[y]);
+
+    const uint16x8_t left_dist = vabdq_u16(top, top_left);
+    const uint16x8_t top_dist = vabdq_u16(left, top_left);
+    const uint16x8_t top_left_dist =
+        vabdq_u16(vaddq_u16(top, left), top_left_x2);
+
+    const uint16x8_t left_le_top = vcleq_u16(left_dist, top_dist);
+    const uint16x8_t left_le_top_left = vcleq_u16(left_dist, top_left_dist);
+    const uint16x8_t top_le_top_left = vcleq_u16(top_dist, top_left_dist);
+
+    // if (left_dist <= top_dist && left_dist <= top_left_dist)
+    const uint16x8_t left_mask = vandq_u16(left_le_top, left_le_top_left);
+    //   dest[x] = left_column[y];
+    // Fill all the unused spaces with 'top'. They will be overwritten when
+    // the positions for top_left are known.
+    uint16x8_t result = vbslq_u16(left_mask, left, top);
+    // else if (top_dist <= top_left_dist)
+    //   dest[x] = top_row[x];
+    // Add these values to the mask. They were already set.
+    const uint16x8_t left_or_top_mask = vorrq_u16(left_mask, top_le_top_left);
+    // else
+    //   dest[x] = top_left;
+    result = vbslq_u16(left_or_top_mask, result, top_left);
+
+    vst1q_u16(dst16, result);
+    dst += stride;
+  }
+}
+
+// For 16xH and above.
+template <int width, int height>
+inline void PaethWxH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                          const void* LIBGAV1_RESTRICT const top_ptr,
+                          const void* LIBGAV1_RESTRICT const left_ptr) {
+  auto* dst = static_cast<uint8_t*>(dest);
+  const auto* const top_row = static_cast<const uint16_t*>(top_ptr);
+  const auto* const left_col = static_cast<const uint16_t*>(left_ptr);
+
+  const uint16x8_t top_left = vdupq_n_u16(top_row[-1]);
+  const uint16x8_t top_left_x2 = vshlq_n_u16(top_left, 1);
+
+  uint16x8_t top[width >> 3];
+  for (int i = 0; i < width >> 3; ++i) {
+    top[i] = vld1q_u16(top_row + (i << 3));
+  }
+
+  for (int y = 0; y < height; ++y) {
+    auto* dst_x = reinterpret_cast<uint16_t*>(dst);
+    const uint16x8_t left = vdupq_n_u16(left_col[y]);
+    const uint16x8_t top_dist = vabdq_u16(left, top_left);
+
+    for (int i = 0; i < (width >> 3); ++i) {
+      const uint16x8_t left_dist = vabdq_u16(top[i], top_left);
+      const uint16x8_t top_left_dist =
+          vabdq_u16(vaddq_u16(top[i], left), top_left_x2);
+
+      const uint16x8_t left_le_top = vcleq_u16(left_dist, top_dist);
+      const uint16x8_t left_le_top_left = vcleq_u16(left_dist, top_left_dist);
+      const uint16x8_t top_le_top_left = vcleq_u16(top_dist, top_left_dist);
+
+      // if (left_dist <= top_dist && left_dist <= top_left_dist)
+      const uint16x8_t left_mask = vandq_u16(left_le_top, left_le_top_left);
+      //   dest[x] = left_column[y];
+      // Fill all the unused spaces with 'top'. They will be overwritten when
+      // the positions for top_left are known.
+      uint16x8_t result = vbslq_u16(left_mask, left, top[i]);
+      // else if (top_dist <= top_left_dist)
+      //   dest[x] = top_row[x];
+      // Add these values to the mask. They were already set.
+      const uint16x8_t left_or_top_mask = vorrq_u16(left_mask, top_le_top_left);
+      // else
+      //   dest[x] = top_left;
+      result = vbslq_u16(left_or_top_mask, result, top_left);
+
+      vst1q_u16(dst_x, result);
+      dst_x += 8;
+    }
+    dst += stride;
+  }
+}
+
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
   assert(dsp != nullptr);
@@ -1170,6 +1369,8 @@
       DcDefs::_4x4::Dc;
   dsp->intra_predictors[kTransformSize4x4][kIntraPredictorVertical] =
       Vertical4xH_NEON<4>;
+  dsp->intra_predictors[kTransformSize4x4][kIntraPredictorPaeth] =
+      Paeth4xH_NEON<4>;
 
   // 4x8
   dsp->intra_predictors[kTransformSize4x8][kIntraPredictorDcTop] =
@@ -1182,6 +1383,8 @@
       Horizontal4xH_NEON<8>;
   dsp->intra_predictors[kTransformSize4x8][kIntraPredictorVertical] =
       Vertical4xH_NEON<8>;
+  dsp->intra_predictors[kTransformSize4x8][kIntraPredictorPaeth] =
+      Paeth4xH_NEON<8>;
 
   // 4x16
   dsp->intra_predictors[kTransformSize4x16][kIntraPredictorDcTop] =
@@ -1194,6 +1397,8 @@
       Horizontal4xH_NEON<16>;
   dsp->intra_predictors[kTransformSize4x16][kIntraPredictorVertical] =
       Vertical4xH_NEON<16>;
+  dsp->intra_predictors[kTransformSize4x16][kIntraPredictorPaeth] =
+      Paeth4xH_NEON<16>;
 
   // 8x4
   dsp->intra_predictors[kTransformSize8x4][kIntraPredictorDcTop] =
@@ -1204,6 +1409,8 @@
       DcDefs::_8x4::Dc;
   dsp->intra_predictors[kTransformSize8x4][kIntraPredictorVertical] =
       Vertical8xH_NEON<4>;
+  dsp->intra_predictors[kTransformSize8x4][kIntraPredictorPaeth] =
+      Paeth8xH_NEON<4>;
 
   // 8x8
   dsp->intra_predictors[kTransformSize8x8][kIntraPredictorDcTop] =
@@ -1216,6 +1423,8 @@
       Horizontal8xH_NEON<8>;
   dsp->intra_predictors[kTransformSize8x8][kIntraPredictorVertical] =
       Vertical8xH_NEON<8>;
+  dsp->intra_predictors[kTransformSize8x8][kIntraPredictorPaeth] =
+      Paeth8xH_NEON<8>;
 
   // 8x16
   dsp->intra_predictors[kTransformSize8x16][kIntraPredictorDcTop] =
@@ -1226,6 +1435,8 @@
       DcDefs::_8x16::Dc;
   dsp->intra_predictors[kTransformSize8x16][kIntraPredictorVertical] =
       Vertical8xH_NEON<16>;
+  dsp->intra_predictors[kTransformSize8x16][kIntraPredictorPaeth] =
+      Paeth8xH_NEON<16>;
 
   // 8x32
   dsp->intra_predictors[kTransformSize8x32][kIntraPredictorDcTop] =
@@ -1238,6 +1449,8 @@
       Horizontal8xH_NEON<32>;
   dsp->intra_predictors[kTransformSize8x32][kIntraPredictorVertical] =
       Vertical8xH_NEON<32>;
+  dsp->intra_predictors[kTransformSize8x32][kIntraPredictorPaeth] =
+      Paeth8xH_NEON<32>;
 
   // 16x4
   dsp->intra_predictors[kTransformSize16x4][kIntraPredictorDcTop] =
@@ -1248,6 +1461,8 @@
       DcDefs::_16x4::Dc;
   dsp->intra_predictors[kTransformSize16x4][kIntraPredictorVertical] =
       Vertical16xH_NEON<4>;
+  dsp->intra_predictors[kTransformSize16x4][kIntraPredictorPaeth] =
+      PaethWxH_NEON<16, 4>;
 
   // 16x8
   dsp->intra_predictors[kTransformSize16x8][kIntraPredictorDcTop] =
@@ -1260,6 +1475,8 @@
       Horizontal16xH_NEON<8>;
   dsp->intra_predictors[kTransformSize16x8][kIntraPredictorVertical] =
       Vertical16xH_NEON<8>;
+  dsp->intra_predictors[kTransformSize16x8][kIntraPredictorPaeth] =
+      PaethWxH_NEON<16, 8>;
 
   // 16x16
   dsp->intra_predictors[kTransformSize16x16][kIntraPredictorDcTop] =
@@ -1270,6 +1487,8 @@
       DcDefs::_16x16::Dc;
   dsp->intra_predictors[kTransformSize16x16][kIntraPredictorVertical] =
       Vertical16xH_NEON<16>;
+  dsp->intra_predictors[kTransformSize16x16][kIntraPredictorPaeth] =
+      PaethWxH_NEON<16, 16>;
 
   // 16x32
   dsp->intra_predictors[kTransformSize16x32][kIntraPredictorDcTop] =
@@ -1280,6 +1499,8 @@
       DcDefs::_16x32::Dc;
   dsp->intra_predictors[kTransformSize16x32][kIntraPredictorVertical] =
       Vertical16xH_NEON<32>;
+  dsp->intra_predictors[kTransformSize16x32][kIntraPredictorPaeth] =
+      PaethWxH_NEON<16, 32>;
 
   // 16x64
   dsp->intra_predictors[kTransformSize16x64][kIntraPredictorDcTop] =
@@ -1290,6 +1511,8 @@
       DcDefs::_16x64::Dc;
   dsp->intra_predictors[kTransformSize16x64][kIntraPredictorVertical] =
       Vertical16xH_NEON<64>;
+  dsp->intra_predictors[kTransformSize16x64][kIntraPredictorPaeth] =
+      PaethWxH_NEON<16, 64>;
 
   // 32x8
   dsp->intra_predictors[kTransformSize32x8][kIntraPredictorDcTop] =
@@ -1300,6 +1523,8 @@
       DcDefs::_32x8::Dc;
   dsp->intra_predictors[kTransformSize32x8][kIntraPredictorVertical] =
       Vertical32xH_NEON<8>;
+  dsp->intra_predictors[kTransformSize32x8][kIntraPredictorPaeth] =
+      PaethWxH_NEON<32, 8>;
 
   // 32x16
   dsp->intra_predictors[kTransformSize32x16][kIntraPredictorDcTop] =
@@ -1310,6 +1535,8 @@
       DcDefs::_32x16::Dc;
   dsp->intra_predictors[kTransformSize32x16][kIntraPredictorVertical] =
       Vertical32xH_NEON<16>;
+  dsp->intra_predictors[kTransformSize32x16][kIntraPredictorPaeth] =
+      PaethWxH_NEON<32, 16>;
 
   // 32x32
   dsp->intra_predictors[kTransformSize32x32][kIntraPredictorDcTop] =
@@ -1320,6 +1547,8 @@
       DcDefs::_32x32::Dc;
   dsp->intra_predictors[kTransformSize32x32][kIntraPredictorVertical] =
       Vertical32xH_NEON<32>;
+  dsp->intra_predictors[kTransformSize32x32][kIntraPredictorPaeth] =
+      PaethWxH_NEON<32, 32>;
 
   // 32x64
   dsp->intra_predictors[kTransformSize32x64][kIntraPredictorDcTop] =
@@ -1332,6 +1561,8 @@
       Horizontal32xH_NEON<64>;
   dsp->intra_predictors[kTransformSize32x64][kIntraPredictorVertical] =
       Vertical32xH_NEON<64>;
+  dsp->intra_predictors[kTransformSize32x64][kIntraPredictorPaeth] =
+      PaethWxH_NEON<32, 64>;
 
   // 64x16
   dsp->intra_predictors[kTransformSize64x16][kIntraPredictorDcTop] =
@@ -1342,6 +1573,8 @@
       DcDefs::_64x16::Dc;
   dsp->intra_predictors[kTransformSize64x16][kIntraPredictorVertical] =
       Vertical64xH_NEON<16>;
+  dsp->intra_predictors[kTransformSize64x16][kIntraPredictorPaeth] =
+      PaethWxH_NEON<64, 16>;
 
   // 64x32
   dsp->intra_predictors[kTransformSize64x32][kIntraPredictorDcTop] =
@@ -1352,6 +1585,8 @@
       DcDefs::_64x32::Dc;
   dsp->intra_predictors[kTransformSize64x32][kIntraPredictorVertical] =
       Vertical64xH_NEON<32>;
+  dsp->intra_predictors[kTransformSize64x32][kIntraPredictorPaeth] =
+      PaethWxH_NEON<64, 32>;
 
   // 64x64
   dsp->intra_predictors[kTransformSize64x64][kIntraPredictorDcTop] =
@@ -1362,6 +1597,8 @@
       DcDefs::_64x64::Dc;
   dsp->intra_predictors[kTransformSize64x64][kIntraPredictorVertical] =
       Vertical64xH_NEON<64>;
+  dsp->intra_predictors[kTransformSize64x64][kIntraPredictorPaeth] =
+      PaethWxH_NEON<64, 64>;
 }
 
 }  // namespace
diff --git a/libgav1/src/dsp/arm/intrapred_neon.h b/libgav1/src/dsp/arm/intrapred_neon.h
index b27f29f..5a56924 100644
--- a/libgav1/src/dsp/arm/intrapred_neon.h
+++ b/libgav1/src/dsp/arm/intrapred_neon.h
@@ -152,6 +152,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 4x8
 #define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -161,6 +162,7 @@
   LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 4x16
 #define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -170,6 +172,7 @@
   LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 8x4
 #define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -177,6 +180,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 8x8
 #define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -186,6 +190,7 @@
   LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 8x16
 #define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -193,6 +198,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 8x32
 #define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -202,6 +208,7 @@
   LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 16x4
 #define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -209,6 +216,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 16x8
 #define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -218,6 +226,7 @@
   LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 16x16
 #define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -226,6 +235,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 16x32
 #define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -234,6 +244,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 16x64
 #define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -242,6 +253,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 32x8
 #define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -249,6 +261,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 32x16
 #define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -257,6 +270,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 32x32
 #define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -265,6 +279,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 32x64
 #define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -275,6 +290,7 @@
   LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 64x16
 #define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -283,6 +299,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 64x32
 #define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -291,6 +308,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorPaeth LIBGAV1_CPU_NEON
 
 // 64x64
 #define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDcTop LIBGAV1_CPU_NEON
@@ -299,6 +317,7 @@
 #define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorDc LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorVertical \
   LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorPaeth LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_INTRAPRED_NEON_H_
diff --git a/libgav1/src/dsp/arm/intrapred_smooth_neon.cc b/libgav1/src/dsp/arm/intrapred_smooth_neon.cc
index c33f333..bcda131 100644
--- a/libgav1/src/dsp/arm/intrapred_smooth_neon.cc
+++ b/libgav1/src/dsp/arm/intrapred_smooth_neon.cc
@@ -26,6 +26,7 @@
 #include "src/dsp/arm/common_neon.h"
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
+#include "src/utils/common.h"
 #include "src/utils/constants.h"
 
 namespace libgav1 {
@@ -38,24 +39,9 @@
 // to have visibility of the values. This helps reduce loads and in the
 // creation of the inverse weights.
 constexpr uint8_t kSmoothWeights[] = {
-    // block dimension = 4
-    255, 149, 85, 64,
-    // block dimension = 8
-    255, 197, 146, 105, 73, 50, 37, 32,
-    // block dimension = 16
-    255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16,
-    // block dimension = 32
-    255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74,
-    66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8,
-    // block dimension = 64
-    255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156,
-    150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73,
-    69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16,
-    15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4};
+#include "src/dsp/smooth_weights.inc"
+};
 
-// TODO(b/150459137): Keeping the intermediate values in uint16_t would allow
-// processing more values at once. At the high end, it could do 4x4 or 8x2 at a
-// time.
 inline uint16x4_t CalculatePred(const uint16x4_t weighted_top,
                                 const uint16x4_t weighted_left,
                                 const uint16x4_t weighted_bl,
@@ -66,26 +52,74 @@
   return vrshrn_n_u32(pred_2, kSmoothWeightScale + 1);
 }
 
-template <int width, int height>
-inline void Smooth4Or8xN_NEON(void* const dest, ptrdiff_t stride,
-                              const void* const top_row,
-                              const void* const left_column) {
-  const uint8_t* const top = static_cast<const uint8_t*>(top_row);
-  const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+template <int height>
+inline void Smooth4xN_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                           const void* LIBGAV1_RESTRICT const top_row,
+                           const void* LIBGAV1_RESTRICT const left_column) {
+  constexpr int width = 4;
+  const auto* const top = static_cast<const uint8_t*>(top_row);
+  const auto* const left = static_cast<const uint8_t*>(left_column);
   const uint8_t top_right = top[width - 1];
   const uint8_t bottom_left = left[height - 1];
   const uint8_t* const weights_y = kSmoothWeights + height - 4;
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
 
-  uint8x8_t top_v;
-  if (width == 4) {
-    top_v = Load4(top);
-  } else {  // width == 8
-    top_v = vld1_u8(top);
-  }
+  const uint8x8_t top_v = Load4(top);
   const uint8x8_t top_right_v = vdup_n_u8(top_right);
   const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left);
-  // Over-reads for 4xN but still within the array.
+  const uint8x8_t weights_x_v = Load4(kSmoothWeights + width - 4);
+  // 256 - weights = vneg_s8(weights)
+  const uint8x8_t scaled_weights_x =
+      vreinterpret_u8_s8(vneg_s8(vreinterpret_s8_u8(weights_x_v)));
+
+  for (int y = 0; y < height; ++y) {
+    const uint8x8_t left_v = vdup_n_u8(left[y]);
+    const uint8x8_t weights_y_v = vdup_n_u8(weights_y[y]);
+    const uint8x8_t scaled_weights_y =
+        vreinterpret_u8_s8(vneg_s8(vreinterpret_s8_u8(weights_y_v)));
+    const uint16x4_t weighted_bl =
+        vget_low_u16(vmull_u8(scaled_weights_y, bottom_left_v));
+
+    const uint16x4_t weighted_top = vget_low_u16(vmull_u8(weights_y_v, top_v));
+    const uint16x4_t weighted_left =
+        vget_low_u16(vmull_u8(weights_x_v, left_v));
+    const uint16x4_t weighted_tr =
+        vget_low_u16(vmull_u8(scaled_weights_x, top_right_v));
+    const uint16x4_t result =
+        CalculatePred(weighted_top, weighted_left, weighted_bl, weighted_tr);
+
+    StoreLo4(dst, vmovn_u16(vcombine_u16(result, result)));
+    dst += stride;
+  }
+}
+
+inline uint8x8_t CalculatePred(const uint16x8_t weighted_top,
+                               const uint16x8_t weighted_left,
+                               const uint16x8_t weighted_bl,
+                               const uint16x8_t weighted_tr) {
+  // Maximum value: 0xFF00
+  const uint16x8_t pred_0 = vaddq_u16(weighted_top, weighted_bl);
+  // Maximum value: 0xFF00
+  const uint16x8_t pred_1 = vaddq_u16(weighted_left, weighted_tr);
+  const uint16x8_t pred_2 = vhaddq_u16(pred_0, pred_1);
+  return vrshrn_n_u16(pred_2, kSmoothWeightScale);
+}
+
+template <int height>
+inline void Smooth8xN_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                           const void* LIBGAV1_RESTRICT const top_row,
+                           const void* LIBGAV1_RESTRICT const left_column) {
+  constexpr int width = 8;
+  const auto* const top = static_cast<const uint8_t*>(top_row);
+  const auto* const left = static_cast<const uint8_t*>(left_column);
+  const uint8_t top_right = top[width - 1];
+  const uint8_t bottom_left = left[height - 1];
+  const uint8_t* const weights_y = kSmoothWeights + height - 4;
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  const uint8x8_t top_v = vld1_u8(top);
+  const uint8x8_t top_right_v = vdup_n_u8(top_right);
+  const uint8x8_t bottom_left_v = vdup_n_u8(bottom_left);
   const uint8x8_t weights_x_v = vld1_u8(kSmoothWeights + width - 4);
   // 256 - weights = vneg_s8(weights)
   const uint8x8_t scaled_weights_x =
@@ -100,18 +134,10 @@
     const uint16x8_t weighted_top = vmull_u8(weights_y_v, top_v);
     const uint16x8_t weighted_left = vmull_u8(weights_x_v, left_v);
     const uint16x8_t weighted_tr = vmull_u8(scaled_weights_x, top_right_v);
-    const uint16x4_t dest_0 =
-        CalculatePred(vget_low_u16(weighted_top), vget_low_u16(weighted_left),
-                      vget_low_u16(weighted_tr), vget_low_u16(weighted_bl));
+    const uint8x8_t result =
+        CalculatePred(weighted_top, weighted_left, weighted_bl, weighted_tr);
 
-    if (width == 4) {
-      StoreLo4(dst, vmovn_u16(vcombine_u16(dest_0, dest_0)));
-    } else {  // width == 8
-      const uint16x4_t dest_1 = CalculatePred(
-          vget_high_u16(weighted_top), vget_high_u16(weighted_left),
-          vget_high_u16(weighted_tr), vget_high_u16(weighted_bl));
-      vst1_u8(dst, vmovn_u16(vcombine_u16(dest_0, dest_1)));
-    }
+    vst1_u8(dst, result);
     dst += stride;
   }
 }
@@ -124,39 +150,30 @@
   const uint16x8_t weighted_left_low = vmull_u8(vget_low_u8(weights_x), left);
   const uint16x8_t weighted_tr_low =
       vmull_u8(vget_low_u8(scaled_weights_x), top_right);
-  const uint16x4_t dest_0 = CalculatePred(
-      vget_low_u16(weighted_top_low), vget_low_u16(weighted_left_low),
-      vget_low_u16(weighted_tr_low), vget_low_u16(weighted_bl));
-  const uint16x4_t dest_1 = CalculatePred(
-      vget_high_u16(weighted_top_low), vget_high_u16(weighted_left_low),
-      vget_high_u16(weighted_tr_low), vget_high_u16(weighted_bl));
-  const uint8x8_t dest_0_u8 = vmovn_u16(vcombine_u16(dest_0, dest_1));
+  const uint8x8_t result_low = CalculatePred(
+      weighted_top_low, weighted_left_low, weighted_bl, weighted_tr_low);
 
   const uint16x8_t weighted_top_high = vmull_u8(weights_y, vget_high_u8(top));
   const uint16x8_t weighted_left_high = vmull_u8(vget_high_u8(weights_x), left);
   const uint16x8_t weighted_tr_high =
       vmull_u8(vget_high_u8(scaled_weights_x), top_right);
-  const uint16x4_t dest_2 = CalculatePred(
-      vget_low_u16(weighted_top_high), vget_low_u16(weighted_left_high),
-      vget_low_u16(weighted_tr_high), vget_low_u16(weighted_bl));
-  const uint16x4_t dest_3 = CalculatePred(
-      vget_high_u16(weighted_top_high), vget_high_u16(weighted_left_high),
-      vget_high_u16(weighted_tr_high), vget_high_u16(weighted_bl));
-  const uint8x8_t dest_1_u8 = vmovn_u16(vcombine_u16(dest_2, dest_3));
+  const uint8x8_t result_high = CalculatePred(
+      weighted_top_high, weighted_left_high, weighted_bl, weighted_tr_high);
 
-  return vcombine_u8(dest_0_u8, dest_1_u8);
+  return vcombine_u8(result_low, result_high);
 }
 
 template <int width, int height>
-inline void Smooth16PlusxN_NEON(void* const dest, ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
-  const uint8_t* const top = static_cast<const uint8_t*>(top_row);
-  const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+inline void Smooth16PlusxN_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint8_t*>(top_row);
+  const auto* const left = static_cast<const uint8_t*>(left_column);
   const uint8_t top_right = top[width - 1];
   const uint8_t bottom_left = left[height - 1];
   const uint8_t* const weights_y = kSmoothWeights + height - 4;
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   uint8x16_t top_v[4];
   top_v[0] = vld1q_u8(top);
@@ -229,14 +246,15 @@
 }
 
 template <int width, int height>
-inline void SmoothVertical4Or8xN_NEON(void* const dest, ptrdiff_t stride,
-                                      const void* const top_row,
-                                      const void* const left_column) {
-  const uint8_t* const top = static_cast<const uint8_t*>(top_row);
-  const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+inline void SmoothVertical4Or8xN_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint8_t*>(top_row);
+  const auto* const left = static_cast<const uint8_t*>(left_column);
   const uint8_t bottom_left = left[height - 1];
   const uint8_t* const weights_y = kSmoothWeights + height - 4;
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   uint8x8_t top_v;
   if (width == 4) {
@@ -279,14 +297,15 @@
 }
 
 template <int width, int height>
-inline void SmoothVertical16PlusxN_NEON(void* const dest, ptrdiff_t stride,
-                                        const void* const top_row,
-                                        const void* const left_column) {
-  const uint8_t* const top = static_cast<const uint8_t*>(top_row);
-  const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+inline void SmoothVertical16PlusxN_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint8_t*>(top_row);
+  const auto* const left = static_cast<const uint8_t*>(left_column);
   const uint8_t bottom_left = left[height - 1];
   const uint8_t* const weights_y = kSmoothWeights + height - 4;
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   uint8x16_t top_v[4];
   top_v[0] = vld1q_u8(top);
@@ -330,13 +349,14 @@
 }
 
 template <int width, int height>
-inline void SmoothHorizontal4Or8xN_NEON(void* const dest, ptrdiff_t stride,
-                                        const void* const top_row,
-                                        const void* const left_column) {
-  const uint8_t* const top = static_cast<const uint8_t*>(top_row);
-  const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+inline void SmoothHorizontal4Or8xN_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint8_t*>(top_row);
+  const auto* const left = static_cast<const uint8_t*>(left_column);
   const uint8_t top_right = top[width - 1];
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   const uint8x8_t top_right_v = vdup_n_u8(top_right);
   // Over-reads for 4xN but still within the array.
@@ -382,13 +402,14 @@
 }
 
 template <int width, int height>
-inline void SmoothHorizontal16PlusxN_NEON(void* const dest, ptrdiff_t stride,
-                                          const void* const top_row,
-                                          const void* const left_column) {
-  const uint8_t* const top = static_cast<const uint8_t*>(top_row);
-  const uint8_t* const left = static_cast<const uint8_t*>(left_column);
+inline void SmoothHorizontal16PlusxN_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint8_t*>(top_row);
+  const auto* const left = static_cast<const uint8_t*>(left_column);
   const uint8_t top_right = top[width - 1];
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   const uint8x8_t top_right_v = vdup_n_u8(top_right);
 
@@ -447,7 +468,7 @@
   assert(dsp != nullptr);
   // 4x4
   dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] =
-      Smooth4Or8xN_NEON<4, 4>;
+      Smooth4xN_NEON<4>;
   dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothVertical] =
       SmoothVertical4Or8xN_NEON<4, 4>;
   dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothHorizontal] =
@@ -455,7 +476,7 @@
 
   // 4x8
   dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmooth] =
-      Smooth4Or8xN_NEON<4, 8>;
+      Smooth4xN_NEON<8>;
   dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothVertical] =
       SmoothVertical4Or8xN_NEON<4, 8>;
   dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothHorizontal] =
@@ -463,7 +484,7 @@
 
   // 4x16
   dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmooth] =
-      Smooth4Or8xN_NEON<4, 16>;
+      Smooth4xN_NEON<16>;
   dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothVertical] =
       SmoothVertical4Or8xN_NEON<4, 16>;
   dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothHorizontal] =
@@ -471,7 +492,7 @@
 
   // 8x4
   dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmooth] =
-      Smooth4Or8xN_NEON<8, 4>;
+      Smooth8xN_NEON<4>;
   dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothVertical] =
       SmoothVertical4Or8xN_NEON<8, 4>;
   dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothHorizontal] =
@@ -479,7 +500,7 @@
 
   // 8x8
   dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmooth] =
-      Smooth4Or8xN_NEON<8, 8>;
+      Smooth8xN_NEON<8>;
   dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothVertical] =
       SmoothVertical4Or8xN_NEON<8, 8>;
   dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothHorizontal] =
@@ -487,7 +508,7 @@
 
   // 8x16
   dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmooth] =
-      Smooth4Or8xN_NEON<8, 16>;
+      Smooth8xN_NEON<16>;
   dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothVertical] =
       SmoothVertical4Or8xN_NEON<8, 16>;
   dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothHorizontal] =
@@ -495,7 +516,7 @@
 
   // 8x32
   dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmooth] =
-      Smooth4Or8xN_NEON<8, 32>;
+      Smooth8xN_NEON<32>;
   dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothVertical] =
       SmoothVertical4Or8xN_NEON<8, 32>;
   dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothHorizontal] =
@@ -601,7 +622,535 @@
 }  // namespace
 }  // namespace low_bitdepth
 
-void IntraPredSmoothInit_NEON() { low_bitdepth::Init8bpp(); }
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+// Note these constants are duplicated from intrapred.cc to allow the compiler
+// to have visibility of the values. This helps reduce loads and in the
+// creation of the inverse weights.
+constexpr uint16_t kSmoothWeights[] = {
+#include "src/dsp/smooth_weights.inc"
+};
+
+template <int height>
+inline void Smooth4xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                           const void* LIBGAV1_RESTRICT const top_row,
+                           const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  const uint16_t top_right = top[3];
+  const uint16_t bottom_left = left[height - 1];
+  const uint16_t* const weights_y = kSmoothWeights + height - 4;
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  const uint16x4_t top_v = vld1_u16(top);
+  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);
+  const uint16x4_t weights_x_v = vld1_u16(kSmoothWeights);
+  const uint16x4_t scaled_weights_x = vsub_u16(vdup_n_u16(256), weights_x_v);
+
+  // Weighted top right doesn't change with each row.
+  const uint32x4_t weighted_tr = vmull_n_u16(scaled_weights_x, top_right);
+
+  for (int y = 0; y < height; ++y) {
+    // Each variable in the running summation is named for the last item to be
+    // accumulated.
+    const uint32x4_t weighted_top =
+        vmlal_n_u16(weighted_tr, top_v, weights_y[y]);
+    const uint32x4_t weighted_left =
+        vmlal_n_u16(weighted_top, weights_x_v, left[y]);
+    const uint32x4_t weighted_bl =
+        vmlal_n_u16(weighted_left, bottom_left_v, 256 - weights_y[y]);
+
+    const uint16x4_t pred = vrshrn_n_u32(weighted_bl, kSmoothWeightScale + 1);
+    vst1_u16(reinterpret_cast<uint16_t*>(dst), pred);
+    dst += stride;
+  }
+}
+
+// Common code between 8xH and [16|32|64]xH.
+inline void CalculatePred8(uint16_t* LIBGAV1_RESTRICT dst,
+                           const uint32x4_t& weighted_corners_low,
+                           const uint32x4_t& weighted_corners_high,
+                           const uint16x4x2_t& top_vals,
+                           const uint16x4x2_t& weights_x, const uint16_t left_y,
+                           const uint16_t weight_y) {
+  // Each variable in the running summation is named for the last item to be
+  // accumulated.
+  const uint32x4_t weighted_top_low =
+      vmlal_n_u16(weighted_corners_low, top_vals.val[0], weight_y);
+  const uint32x4_t weighted_edges_low =
+      vmlal_n_u16(weighted_top_low, weights_x.val[0], left_y);
+
+  const uint16x4_t pred_low =
+      vrshrn_n_u32(weighted_edges_low, kSmoothWeightScale + 1);
+  vst1_u16(dst, pred_low);
+
+  const uint32x4_t weighted_top_high =
+      vmlal_n_u16(weighted_corners_high, top_vals.val[1], weight_y);
+  const uint32x4_t weighted_edges_high =
+      vmlal_n_u16(weighted_top_high, weights_x.val[1], left_y);
+
+  const uint16x4_t pred_high =
+      vrshrn_n_u32(weighted_edges_high, kSmoothWeightScale + 1);
+  vst1_u16(dst + 4, pred_high);
+}
+
+template <int height>
+inline void Smooth8xH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                           const void* LIBGAV1_RESTRICT const top_row,
+                           const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  const uint16_t top_right = top[7];
+  const uint16_t bottom_left = left[height - 1];
+  const uint16_t* const weights_y = kSmoothWeights + height - 4;
+
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  const uint16x4x2_t top_vals = {vld1_u16(top), vld1_u16(top + 4)};
+  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);
+  const uint16x4x2_t weights_x = {vld1_u16(kSmoothWeights + 4),
+                                  vld1_u16(kSmoothWeights + 8)};
+  // Weighted top right doesn't change with each row.
+  const uint32x4_t weighted_tr_low =
+      vmull_n_u16(vsub_u16(vdup_n_u16(256), weights_x.val[0]), top_right);
+  const uint32x4_t weighted_tr_high =
+      vmull_n_u16(vsub_u16(vdup_n_u16(256), weights_x.val[1]), top_right);
+
+  for (int y = 0; y < height; ++y) {
+    // |weighted_bl| is invariant across the row.
+    const uint32x4_t weighted_bl =
+        vmull_n_u16(bottom_left_v, 256 - weights_y[y]);
+    const uint32x4_t weighted_corners_low =
+        vaddq_u32(weighted_bl, weighted_tr_low);
+    const uint32x4_t weighted_corners_high =
+        vaddq_u32(weighted_bl, weighted_tr_high);
+    CalculatePred8(reinterpret_cast<uint16_t*>(dst), weighted_corners_low,
+                   weighted_corners_high, top_vals, weights_x, left[y],
+                   weights_y[y]);
+    dst += stride;
+  }
+}
+
+// For width 16 and above.
+template <int width, int height>
+inline void SmoothWxH_NEON(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                           const void* LIBGAV1_RESTRICT const top_row,
+                           const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  const uint16_t top_right = top[width - 1];
+  const uint16_t bottom_left = left[height - 1];
+  const uint16_t* const weights_y = kSmoothWeights + height - 4;
+
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  const uint16x4_t weight_scaling = vdup_n_u16(256);
+  // Precompute weighted values that don't vary with |y|.
+  uint32x4_t weighted_tr_low[width >> 3];
+  uint32x4_t weighted_tr_high[width >> 3];
+  for (int i = 0; i < width >> 3; ++i) {
+    const int x = i << 3;
+    const uint16x4_t weights_x_low = vld1_u16(kSmoothWeights + width - 4 + x);
+    weighted_tr_low[i] =
+        vmull_n_u16(vsub_u16(weight_scaling, weights_x_low), top_right);
+    const uint16x4_t weights_x_high = vld1_u16(kSmoothWeights + width + x);
+    weighted_tr_high[i] =
+        vmull_n_u16(vsub_u16(weight_scaling, weights_x_high), top_right);
+  }
+
+  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);
+  for (int y = 0; y < height; ++y) {
+    // |weighted_bl| is invariant across the row.
+    const uint32x4_t weighted_bl =
+        vmull_n_u16(bottom_left_v, 256 - weights_y[y]);
+    auto* dst_x = reinterpret_cast<uint16_t*>(dst);
+    for (int i = 0; i < width >> 3; ++i) {
+      const int x = i << 3;
+      const uint16x4x2_t top_vals = {vld1_u16(top + x), vld1_u16(top + x + 4)};
+      const uint32x4_t weighted_corners_low =
+          vaddq_u32(weighted_bl, weighted_tr_low[i]);
+      const uint32x4_t weighted_corners_high =
+          vaddq_u32(weighted_bl, weighted_tr_high[i]);
+      // Accumulate weighted edge values and store.
+      const uint16x4x2_t weights_x = {vld1_u16(kSmoothWeights + width - 4 + x),
+                                      vld1_u16(kSmoothWeights + width + x)};
+      CalculatePred8(dst_x, weighted_corners_low, weighted_corners_high,
+                     top_vals, weights_x, left[y], weights_y[y]);
+      dst_x += 8;
+    }
+    dst += stride;
+  }
+}
+
+template <int height>
+inline void SmoothVertical4xH_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  const uint16_t bottom_left = left[height - 1];
+  const uint16_t* const weights_y = kSmoothWeights + height - 4;
+
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  const uint16x4_t top_v = vld1_u16(top);
+  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);
+
+  for (int y = 0; y < height; ++y) {
+    auto* dst16 = reinterpret_cast<uint16_t*>(dst);
+    const uint32x4_t weighted_bl =
+        vmull_n_u16(bottom_left_v, 256 - weights_y[y]);
+    const uint32x4_t weighted_top =
+        vmlal_n_u16(weighted_bl, top_v, weights_y[y]);
+    vst1_u16(dst16, vrshrn_n_u32(weighted_top, kSmoothWeightScale));
+
+    dst += stride;
+  }
+}
+
+template <int height>
+inline void SmoothVertical8xH_NEON(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  const uint16_t bottom_left = left[height - 1];
+  const uint16_t* const weights_y = kSmoothWeights + height - 4;
+
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  const uint16x4_t top_low = vld1_u16(top);
+  const uint16x4_t top_high = vld1_u16(top + 4);
+  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);
+
+  for (int y = 0; y < height; ++y) {
+    auto* dst16 = reinterpret_cast<uint16_t*>(dst);
+    // |weighted_bl| is invariant across the row.
+    const uint32x4_t weighted_bl =
+        vmull_n_u16(bottom_left_v, 256 - weights_y[y]);
+
+    const uint32x4_t weighted_top_low =
+        vmlal_n_u16(weighted_bl, top_low, weights_y[y]);
+    vst1_u16(dst16, vrshrn_n_u32(weighted_top_low, kSmoothWeightScale));
+
+    const uint32x4_t weighted_top_high =
+        vmlal_n_u16(weighted_bl, top_high, weights_y[y]);
+    vst1_u16(dst16 + 4, vrshrn_n_u32(weighted_top_high, kSmoothWeightScale));
+    dst += stride;
+  }
+}
+
+// For width 16 and above.
+template <int width, int height>
+inline void SmoothVerticalWxH_NEON(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  const uint16_t bottom_left = left[height - 1];
+  const uint16_t* const weights_y = kSmoothWeights + height - 4;
+
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  uint16x4x2_t top_vals[width >> 3];
+  for (int i = 0; i < width >> 3; ++i) {
+    const int x = i << 3;
+    top_vals[i] = {vld1_u16(top + x), vld1_u16(top + x + 4)};
+  }
+
+  const uint16x4_t bottom_left_v = vdup_n_u16(bottom_left);
+  for (int y = 0; y < height; ++y) {
+    // |weighted_bl| is invariant across the row.
+    const uint32x4_t weighted_bl =
+        vmull_n_u16(bottom_left_v, 256 - weights_y[y]);
+
+    auto* dst_x = reinterpret_cast<uint16_t*>(dst);
+    for (int i = 0; i < width >> 3; ++i) {
+      const uint32x4_t weighted_top_low =
+          vmlal_n_u16(weighted_bl, top_vals[i].val[0], weights_y[y]);
+      vst1_u16(dst_x, vrshrn_n_u32(weighted_top_low, kSmoothWeightScale));
+
+      const uint32x4_t weighted_top_high =
+          vmlal_n_u16(weighted_bl, top_vals[i].val[1], weights_y[y]);
+      vst1_u16(dst_x + 4, vrshrn_n_u32(weighted_top_high, kSmoothWeightScale));
+      dst_x += 8;
+    }
+    dst += stride;
+  }
+}
+
+template <int height>
+inline void SmoothHorizontal4xH_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  const uint16_t top_right = top[3];
+
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  const uint16x4_t weights_x = vld1_u16(kSmoothWeights);
+  const uint16x4_t scaled_weights_x = vsub_u16(vdup_n_u16(256), weights_x);
+
+  const uint32x4_t weighted_tr = vmull_n_u16(scaled_weights_x, top_right);
+  for (int y = 0; y < height; ++y) {
+    auto* dst16 = reinterpret_cast<uint16_t*>(dst);
+    const uint32x4_t weighted_left =
+        vmlal_n_u16(weighted_tr, weights_x, left[y]);
+    vst1_u16(dst16, vrshrn_n_u32(weighted_left, kSmoothWeightScale));
+    dst += stride;
+  }
+}
+
+template <int height>
+inline void SmoothHorizontal8xH_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  const uint16_t top_right = top[7];
+
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  const uint16x4x2_t weights_x = {vld1_u16(kSmoothWeights + 4),
+                                  vld1_u16(kSmoothWeights + 8)};
+
+  const uint32x4_t weighted_tr_low =
+      vmull_n_u16(vsub_u16(vdup_n_u16(256), weights_x.val[0]), top_right);
+  const uint32x4_t weighted_tr_high =
+      vmull_n_u16(vsub_u16(vdup_n_u16(256), weights_x.val[1]), top_right);
+
+  for (int y = 0; y < height; ++y) {
+    auto* dst16 = reinterpret_cast<uint16_t*>(dst);
+    const uint16_t left_y = left[y];
+    const uint32x4_t weighted_left_low =
+        vmlal_n_u16(weighted_tr_low, weights_x.val[0], left_y);
+    vst1_u16(dst16, vrshrn_n_u32(weighted_left_low, kSmoothWeightScale));
+
+    const uint32x4_t weighted_left_high =
+        vmlal_n_u16(weighted_tr_high, weights_x.val[1], left_y);
+    vst1_u16(dst16 + 4, vrshrn_n_u32(weighted_left_high, kSmoothWeightScale));
+    dst += stride;
+  }
+}
+
+// For width 16 and above.
+template <int width, int height>
+inline void SmoothHorizontalWxH_NEON(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
+  const auto* const top = static_cast<const uint16_t*>(top_row);
+  const auto* const left = static_cast<const uint16_t*>(left_column);
+  const uint16_t top_right = top[width - 1];
+
+  auto* dst = static_cast<uint8_t*>(dest);
+
+  const uint16x4_t weight_scaling = vdup_n_u16(256);
+
+  uint16x4_t weights_x_low[width >> 3];
+  uint16x4_t weights_x_high[width >> 3];
+  uint32x4_t weighted_tr_low[width >> 3];
+  uint32x4_t weighted_tr_high[width >> 3];
+  for (int i = 0; i < width >> 3; ++i) {
+    const int x = i << 3;
+    weights_x_low[i] = vld1_u16(kSmoothWeights + width - 4 + x);
+    weighted_tr_low[i] =
+        vmull_n_u16(vsub_u16(weight_scaling, weights_x_low[i]), top_right);
+    weights_x_high[i] = vld1_u16(kSmoothWeights + width + x);
+    weighted_tr_high[i] =
+        vmull_n_u16(vsub_u16(weight_scaling, weights_x_high[i]), top_right);
+  }
+
+  for (int y = 0; y < height; ++y) {
+    auto* dst_x = reinterpret_cast<uint16_t*>(dst);
+    const uint16_t left_y = left[y];
+    for (int i = 0; i < width >> 3; ++i) {
+      const uint32x4_t weighted_left_low =
+          vmlal_n_u16(weighted_tr_low[i], weights_x_low[i], left_y);
+      vst1_u16(dst_x, vrshrn_n_u32(weighted_left_low, kSmoothWeightScale));
+
+      const uint32x4_t weighted_left_high =
+          vmlal_n_u16(weighted_tr_high[i], weights_x_high[i], left_y);
+      vst1_u16(dst_x + 4, vrshrn_n_u32(weighted_left_high, kSmoothWeightScale));
+      dst_x += 8;
+    }
+    dst += stride;
+  }
+}
+
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  // 4x4
+  dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmooth] =
+      Smooth4xH_NEON<4>;
+  dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothVertical] =
+      SmoothVertical4xH_NEON<4>;
+  dsp->intra_predictors[kTransformSize4x4][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontal4xH_NEON<4>;
+
+  // 4x8
+  dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmooth] =
+      Smooth4xH_NEON<8>;
+  dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothVertical] =
+      SmoothVertical4xH_NEON<8>;
+  dsp->intra_predictors[kTransformSize4x8][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontal4xH_NEON<8>;
+
+  // 4x16
+  dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmooth] =
+      Smooth4xH_NEON<16>;
+  dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothVertical] =
+      SmoothVertical4xH_NEON<16>;
+  dsp->intra_predictors[kTransformSize4x16][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontal4xH_NEON<16>;
+
+  // 8x4
+  dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmooth] =
+      Smooth8xH_NEON<4>;
+  dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothVertical] =
+      SmoothVertical8xH_NEON<4>;
+  dsp->intra_predictors[kTransformSize8x4][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontal8xH_NEON<4>;
+
+  // 8x8
+  dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmooth] =
+      Smooth8xH_NEON<8>;
+  dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothVertical] =
+      SmoothVertical8xH_NEON<8>;
+  dsp->intra_predictors[kTransformSize8x8][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontal8xH_NEON<8>;
+
+  // 8x16
+  dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmooth] =
+      Smooth8xH_NEON<16>;
+  dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothVertical] =
+      SmoothVertical8xH_NEON<16>;
+  dsp->intra_predictors[kTransformSize8x16][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontal8xH_NEON<16>;
+
+  // 8x32
+  dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmooth] =
+      Smooth8xH_NEON<32>;
+  dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothVertical] =
+      SmoothVertical8xH_NEON<32>;
+  dsp->intra_predictors[kTransformSize8x32][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontal8xH_NEON<32>;
+
+  // 16x4
+  dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<16, 4>;
+  dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<16, 4>;
+  dsp->intra_predictors[kTransformSize16x4][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<16, 4>;
+
+  // 16x8
+  dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<16, 8>;
+  dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<16, 8>;
+  dsp->intra_predictors[kTransformSize16x8][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<16, 8>;
+
+  // 16x16
+  dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<16, 16>;
+  dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<16, 16>;
+  dsp->intra_predictors[kTransformSize16x16][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<16, 16>;
+
+  // 16x32
+  dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<16, 32>;
+  dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<16, 32>;
+  dsp->intra_predictors[kTransformSize16x32][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<16, 32>;
+
+  // 16x64
+  dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<16, 64>;
+  dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<16, 64>;
+  dsp->intra_predictors[kTransformSize16x64][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<16, 64>;
+
+  // 32x8
+  dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<32, 8>;
+  dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<32, 8>;
+  dsp->intra_predictors[kTransformSize32x8][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<32, 8>;
+
+  // 32x16
+  dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<32, 16>;
+  dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<32, 16>;
+  dsp->intra_predictors[kTransformSize32x16][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<32, 16>;
+
+  // 32x32
+  dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<32, 32>;
+  dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<32, 32>;
+  dsp->intra_predictors[kTransformSize32x32][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<32, 32>;
+
+  // 32x64
+  dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<32, 64>;
+  dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<32, 64>;
+  dsp->intra_predictors[kTransformSize32x64][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<32, 64>;
+
+  // 64x16
+  dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<64, 16>;
+  dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<64, 16>;
+  dsp->intra_predictors[kTransformSize64x16][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<64, 16>;
+
+  // 64x32
+  dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<64, 32>;
+  dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<64, 32>;
+  dsp->intra_predictors[kTransformSize64x32][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<64, 32>;
+
+  // 64x64
+  dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmooth] =
+      SmoothWxH_NEON<64, 64>;
+  dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothVertical] =
+      SmoothVerticalWxH_NEON<64, 64>;
+  dsp->intra_predictors[kTransformSize64x64][kIntraPredictorSmoothHorizontal] =
+      SmoothHorizontalWxH_NEON<64, 64>;
+}
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+void IntraPredSmoothInit_NEON() {
+  low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif
+}
 
 }  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/arm/intrapred_smooth_neon.h b/libgav1/src/dsp/arm/intrapred_smooth_neon.h
index edd01be..28b5bd5 100644
--- a/libgav1/src/dsp/arm/intrapred_smooth_neon.h
+++ b/libgav1/src/dsp/arm/intrapred_smooth_neon.h
@@ -144,6 +144,131 @@
   LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \
   LIBGAV1_CPU_NEON
+
+// 10bpp
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x4_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x8_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize4x16_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x4_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x8_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x16_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize8x32_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x4_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x8_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorSmooth \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x16_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorSmooth \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x32_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorSmooth \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize16x64_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorSmooth LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x8_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorSmooth \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x16_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorSmooth \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x32_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorSmooth \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize32x64_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorSmooth \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x16_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorSmooth \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x32_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorSmooth \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorSmoothVertical \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_TransformSize64x64_IntraPredictorSmoothHorizontal \
+  LIBGAV1_CPU_NEON
+
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_INTRAPRED_SMOOTH_NEON_H_
diff --git a/libgav1/src/dsp/arm/inverse_transform_10bit_neon.cc b/libgav1/src/dsp/arm/inverse_transform_10bit_neon.cc
index ff184a1..617accc 100644
--- a/libgav1/src/dsp/arm/inverse_transform_10bit_neon.cc
+++ b/libgav1/src/dsp/arm/inverse_transform_10bit_neon.cc
@@ -67,7 +67,8 @@
 
 //------------------------------------------------------------------------------
 template <int store_count>
-LIBGAV1_ALWAYS_INLINE void StoreDst(int32_t* dst, int32_t stride, int32_t idx,
+LIBGAV1_ALWAYS_INLINE void StoreDst(int32_t* LIBGAV1_RESTRICT dst,
+                                    int32_t stride, int32_t idx,
                                     const int32x4_t* const s) {
   assert(store_count % 4 == 0);
   for (int i = 0; i < store_count; i += 4) {
@@ -79,8 +80,8 @@
 }
 
 template <int load_count>
-LIBGAV1_ALWAYS_INLINE void LoadSrc(const int32_t* src, int32_t stride,
-                                   int32_t idx, int32x4_t* x) {
+LIBGAV1_ALWAYS_INLINE void LoadSrc(const int32_t* LIBGAV1_RESTRICT src,
+                                   int32_t stride, int32_t idx, int32x4_t* x) {
   assert(load_count % 4 == 0);
   for (int i = 0; i < load_count; i += 4) {
     x[i] = vld1q_s32(&src[i * stride + idx]);
@@ -168,8 +169,8 @@
 }
 
 LIBGAV1_ALWAYS_INLINE void HadamardRotation(int32x4_t* a, int32x4_t* b,
-                                            bool flip, const int32x4_t* min,
-                                            const int32x4_t* max) {
+                                            bool flip, const int32x4_t min,
+                                            const int32x4_t max) {
   int32x4_t x, y;
   if (flip) {
     y = vqaddq_s32(*b, *a);
@@ -178,8 +179,8 @@
     x = vqaddq_s32(*a, *b);
     y = vqsubq_s32(*a, *b);
   }
-  *a = vmaxq_s32(vminq_s32(x, *max), *min);
-  *b = vmaxq_s32(vminq_s32(y, *max), *min);
+  *a = vmaxq_s32(vminq_s32(x, max), min);
+  *b = vmaxq_s32(vminq_s32(y, max), min);
 }
 
 using ButterflyRotationFunc = void (*)(int32x4_t* a, int32x4_t* b, int angle,
@@ -248,8 +249,8 @@
 
 template <ButterflyRotationFunc butterfly_rotation,
           bool is_fast_butterfly = false>
-LIBGAV1_ALWAYS_INLINE void Dct4Stages(int32x4_t* s, const int32x4_t* min,
-                                      const int32x4_t* max,
+LIBGAV1_ALWAYS_INLINE void Dct4Stages(int32x4_t* s, const int32x4_t min,
+                                      const int32x4_t max,
                                       const bool is_last_stage) {
   // stage 12.
   if (is_fast_butterfly) {
@@ -293,12 +294,12 @@
   s[2] = x[1];
   s[3] = x[3];
 
-  Dct4Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/true);
+  Dct4Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/true);
 
   if (is_row) {
     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
-    for (int i = 0; i < 4; ++i) {
-      s[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(s[i], v_row_shift)));
+    for (auto& i : s) {
+      i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
     }
     Transpose4x4(s, s);
   }
@@ -307,8 +308,8 @@
 
 template <ButterflyRotationFunc butterfly_rotation,
           bool is_fast_butterfly = false>
-LIBGAV1_ALWAYS_INLINE void Dct8Stages(int32x4_t* s, const int32x4_t* min,
-                                      const int32x4_t* max,
+LIBGAV1_ALWAYS_INLINE void Dct8Stages(int32x4_t* s, const int32x4_t min,
+                                      const int32x4_t max,
                                       const bool is_last_stage) {
   // stage 8.
   if (is_fast_butterfly) {
@@ -370,13 +371,13 @@
   s[6] = x[3];
   s[7] = x[7];
 
-  Dct4Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/false);
-  Dct8Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/true);
+  Dct4Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/false);
+  Dct8Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/true);
 
   if (is_row) {
     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
-    for (int i = 0; i < 8; ++i) {
-      s[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(s[i], v_row_shift)));
+    for (auto& i : s) {
+      i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
     }
     Transpose4x4(&s[0], &s[0]);
     Transpose4x4(&s[4], &s[4]);
@@ -389,8 +390,8 @@
 
 template <ButterflyRotationFunc butterfly_rotation,
           bool is_fast_butterfly = false>
-LIBGAV1_ALWAYS_INLINE void Dct16Stages(int32x4_t* s, const int32x4_t* min,
-                                       const int32x4_t* max,
+LIBGAV1_ALWAYS_INLINE void Dct16Stages(int32x4_t* s, const int32x4_t min,
+                                       const int32x4_t max,
                                        const bool is_last_stage) {
   // stage 5.
   if (is_fast_butterfly) {
@@ -487,14 +488,14 @@
   s[14] = x[7];
   s[15] = x[15];
 
-  Dct4Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/false);
-  Dct8Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/false);
-  Dct16Stages<butterfly_rotation>(s, &min, &max, /*is_last_stage=*/true);
+  Dct4Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/false);
+  Dct8Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/false);
+  Dct16Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/true);
 
   if (is_row) {
     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
-    for (int i = 0; i < 16; ++i) {
-      s[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(s[i], v_row_shift)));
+    for (auto& i : s) {
+      i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
     }
     for (int idx = 0; idx < 16; idx += 8) {
       Transpose4x4(&s[idx], &s[idx]);
@@ -509,8 +510,8 @@
 
 template <ButterflyRotationFunc butterfly_rotation,
           bool is_fast_butterfly = false>
-LIBGAV1_ALWAYS_INLINE void Dct32Stages(int32x4_t* s, const int32x4_t* min,
-                                       const int32x4_t* max,
+LIBGAV1_ALWAYS_INLINE void Dct32Stages(int32x4_t* s, const int32x4_t min,
+                                       const int32x4_t max,
                                        const bool is_last_stage) {
   // stage 3
   if (is_fast_butterfly) {
@@ -677,10 +678,10 @@
   s[30] = x[15];
   s[31] = x[31];
 
-  Dct4Stages<ButterflyRotation_4>(s, &min, &max, /*is_last_stage=*/false);
-  Dct8Stages<ButterflyRotation_4>(s, &min, &max, /*is_last_stage=*/false);
-  Dct16Stages<ButterflyRotation_4>(s, &min, &max, /*is_last_stage=*/false);
-  Dct32Stages<ButterflyRotation_4>(s, &min, &max, /*is_last_stage=*/true);
+  Dct4Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/false);
+  Dct8Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/false);
+  Dct16Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/false);
+  Dct32Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/true);
 
   if (is_row) {
     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
@@ -688,8 +689,8 @@
       int32x4_t output[8];
       Transpose4x4(&s[idx], &output[0]);
       Transpose4x4(&s[idx + 4], &output[4]);
-      for (int i = 0; i < 8; ++i) {
-        output[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(output[i], v_row_shift)));
+      for (auto& o : output) {
+        o = vmovl_s16(vqmovn_s32(vqrshlq_s32(o, v_row_shift)));
       }
       StoreDst<4>(dst, step, idx, &output[0]);
       StoreDst<4>(dst, step, idx + 4, &output[4]);
@@ -764,13 +765,13 @@
   s[62] = x[31];
 
   Dct4Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
-      s, &min, &max, /*is_last_stage=*/false);
+      s, min, max, /*is_last_stage=*/false);
   Dct8Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
-      s, &min, &max, /*is_last_stage=*/false);
+      s, min, max, /*is_last_stage=*/false);
   Dct16Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
-      s, &min, &max, /*is_last_stage=*/false);
+      s, min, max, /*is_last_stage=*/false);
   Dct32Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
-      s, &min, &max, /*is_last_stage=*/false);
+      s, min, max, /*is_last_stage=*/false);
 
   //-- start dct 64 stages
   // stage 2.
@@ -792,22 +793,22 @@
   ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false);
 
   // stage 4.
-  HadamardRotation(&s[32], &s[33], false, &min, &max);
-  HadamardRotation(&s[34], &s[35], true, &min, &max);
-  HadamardRotation(&s[36], &s[37], false, &min, &max);
-  HadamardRotation(&s[38], &s[39], true, &min, &max);
-  HadamardRotation(&s[40], &s[41], false, &min, &max);
-  HadamardRotation(&s[42], &s[43], true, &min, &max);
-  HadamardRotation(&s[44], &s[45], false, &min, &max);
-  HadamardRotation(&s[46], &s[47], true, &min, &max);
-  HadamardRotation(&s[48], &s[49], false, &min, &max);
-  HadamardRotation(&s[50], &s[51], true, &min, &max);
-  HadamardRotation(&s[52], &s[53], false, &min, &max);
-  HadamardRotation(&s[54], &s[55], true, &min, &max);
-  HadamardRotation(&s[56], &s[57], false, &min, &max);
-  HadamardRotation(&s[58], &s[59], true, &min, &max);
-  HadamardRotation(&s[60], &s[61], false, &min, &max);
-  HadamardRotation(&s[62], &s[63], true, &min, &max);
+  HadamardRotation(&s[32], &s[33], false, min, max);
+  HadamardRotation(&s[34], &s[35], true, min, max);
+  HadamardRotation(&s[36], &s[37], false, min, max);
+  HadamardRotation(&s[38], &s[39], true, min, max);
+  HadamardRotation(&s[40], &s[41], false, min, max);
+  HadamardRotation(&s[42], &s[43], true, min, max);
+  HadamardRotation(&s[44], &s[45], false, min, max);
+  HadamardRotation(&s[46], &s[47], true, min, max);
+  HadamardRotation(&s[48], &s[49], false, min, max);
+  HadamardRotation(&s[50], &s[51], true, min, max);
+  HadamardRotation(&s[52], &s[53], false, min, max);
+  HadamardRotation(&s[54], &s[55], true, min, max);
+  HadamardRotation(&s[56], &s[57], false, min, max);
+  HadamardRotation(&s[58], &s[59], true, min, max);
+  HadamardRotation(&s[60], &s[61], false, min, max);
+  HadamardRotation(&s[62], &s[63], true, min, max);
 
   // stage 7.
   ButterflyRotation_4(&s[62], &s[33], 60 - 0, true);
@@ -820,22 +821,22 @@
   ButterflyRotation_4(&s[49], &s[46], 60 - 48 + 64, true);
 
   // stage 11.
-  HadamardRotation(&s[32], &s[35], false, &min, &max);
-  HadamardRotation(&s[33], &s[34], false, &min, &max);
-  HadamardRotation(&s[36], &s[39], true, &min, &max);
-  HadamardRotation(&s[37], &s[38], true, &min, &max);
-  HadamardRotation(&s[40], &s[43], false, &min, &max);
-  HadamardRotation(&s[41], &s[42], false, &min, &max);
-  HadamardRotation(&s[44], &s[47], true, &min, &max);
-  HadamardRotation(&s[45], &s[46], true, &min, &max);
-  HadamardRotation(&s[48], &s[51], false, &min, &max);
-  HadamardRotation(&s[49], &s[50], false, &min, &max);
-  HadamardRotation(&s[52], &s[55], true, &min, &max);
-  HadamardRotation(&s[53], &s[54], true, &min, &max);
-  HadamardRotation(&s[56], &s[59], false, &min, &max);
-  HadamardRotation(&s[57], &s[58], false, &min, &max);
-  HadamardRotation(&s[60], &s[63], true, &min, &max);
-  HadamardRotation(&s[61], &s[62], true, &min, &max);
+  HadamardRotation(&s[32], &s[35], false, min, max);
+  HadamardRotation(&s[33], &s[34], false, min, max);
+  HadamardRotation(&s[36], &s[39], true, min, max);
+  HadamardRotation(&s[37], &s[38], true, min, max);
+  HadamardRotation(&s[40], &s[43], false, min, max);
+  HadamardRotation(&s[41], &s[42], false, min, max);
+  HadamardRotation(&s[44], &s[47], true, min, max);
+  HadamardRotation(&s[45], &s[46], true, min, max);
+  HadamardRotation(&s[48], &s[51], false, min, max);
+  HadamardRotation(&s[49], &s[50], false, min, max);
+  HadamardRotation(&s[52], &s[55], true, min, max);
+  HadamardRotation(&s[53], &s[54], true, min, max);
+  HadamardRotation(&s[56], &s[59], false, min, max);
+  HadamardRotation(&s[57], &s[58], false, min, max);
+  HadamardRotation(&s[60], &s[63], true, min, max);
+  HadamardRotation(&s[61], &s[62], true, min, max);
 
   // stage 16.
   ButterflyRotation_4(&s[61], &s[34], 56, true);
@@ -848,22 +849,22 @@
   ButterflyRotation_4(&s[50], &s[45], 56 - 32 + 64, true);
 
   // stage 21.
-  HadamardRotation(&s[32], &s[39], false, &min, &max);
-  HadamardRotation(&s[33], &s[38], false, &min, &max);
-  HadamardRotation(&s[34], &s[37], false, &min, &max);
-  HadamardRotation(&s[35], &s[36], false, &min, &max);
-  HadamardRotation(&s[40], &s[47], true, &min, &max);
-  HadamardRotation(&s[41], &s[46], true, &min, &max);
-  HadamardRotation(&s[42], &s[45], true, &min, &max);
-  HadamardRotation(&s[43], &s[44], true, &min, &max);
-  HadamardRotation(&s[48], &s[55], false, &min, &max);
-  HadamardRotation(&s[49], &s[54], false, &min, &max);
-  HadamardRotation(&s[50], &s[53], false, &min, &max);
-  HadamardRotation(&s[51], &s[52], false, &min, &max);
-  HadamardRotation(&s[56], &s[63], true, &min, &max);
-  HadamardRotation(&s[57], &s[62], true, &min, &max);
-  HadamardRotation(&s[58], &s[61], true, &min, &max);
-  HadamardRotation(&s[59], &s[60], true, &min, &max);
+  HadamardRotation(&s[32], &s[39], false, min, max);
+  HadamardRotation(&s[33], &s[38], false, min, max);
+  HadamardRotation(&s[34], &s[37], false, min, max);
+  HadamardRotation(&s[35], &s[36], false, min, max);
+  HadamardRotation(&s[40], &s[47], true, min, max);
+  HadamardRotation(&s[41], &s[46], true, min, max);
+  HadamardRotation(&s[42], &s[45], true, min, max);
+  HadamardRotation(&s[43], &s[44], true, min, max);
+  HadamardRotation(&s[48], &s[55], false, min, max);
+  HadamardRotation(&s[49], &s[54], false, min, max);
+  HadamardRotation(&s[50], &s[53], false, min, max);
+  HadamardRotation(&s[51], &s[52], false, min, max);
+  HadamardRotation(&s[56], &s[63], true, min, max);
+  HadamardRotation(&s[57], &s[62], true, min, max);
+  HadamardRotation(&s[58], &s[61], true, min, max);
+  HadamardRotation(&s[59], &s[60], true, min, max);
 
   // stage 25.
   ButterflyRotation_4(&s[59], &s[36], 48, true);
@@ -876,22 +877,22 @@
   ButterflyRotation_4(&s[52], &s[43], 112, true);
 
   // stage 28.
-  HadamardRotation(&s[32], &s[47], false, &min, &max);
-  HadamardRotation(&s[33], &s[46], false, &min, &max);
-  HadamardRotation(&s[34], &s[45], false, &min, &max);
-  HadamardRotation(&s[35], &s[44], false, &min, &max);
-  HadamardRotation(&s[36], &s[43], false, &min, &max);
-  HadamardRotation(&s[37], &s[42], false, &min, &max);
-  HadamardRotation(&s[38], &s[41], false, &min, &max);
-  HadamardRotation(&s[39], &s[40], false, &min, &max);
-  HadamardRotation(&s[48], &s[63], true, &min, &max);
-  HadamardRotation(&s[49], &s[62], true, &min, &max);
-  HadamardRotation(&s[50], &s[61], true, &min, &max);
-  HadamardRotation(&s[51], &s[60], true, &min, &max);
-  HadamardRotation(&s[52], &s[59], true, &min, &max);
-  HadamardRotation(&s[53], &s[58], true, &min, &max);
-  HadamardRotation(&s[54], &s[57], true, &min, &max);
-  HadamardRotation(&s[55], &s[56], true, &min, &max);
+  HadamardRotation(&s[32], &s[47], false, min, max);
+  HadamardRotation(&s[33], &s[46], false, min, max);
+  HadamardRotation(&s[34], &s[45], false, min, max);
+  HadamardRotation(&s[35], &s[44], false, min, max);
+  HadamardRotation(&s[36], &s[43], false, min, max);
+  HadamardRotation(&s[37], &s[42], false, min, max);
+  HadamardRotation(&s[38], &s[41], false, min, max);
+  HadamardRotation(&s[39], &s[40], false, min, max);
+  HadamardRotation(&s[48], &s[63], true, min, max);
+  HadamardRotation(&s[49], &s[62], true, min, max);
+  HadamardRotation(&s[50], &s[61], true, min, max);
+  HadamardRotation(&s[51], &s[60], true, min, max);
+  HadamardRotation(&s[52], &s[59], true, min, max);
+  HadamardRotation(&s[53], &s[58], true, min, max);
+  HadamardRotation(&s[54], &s[57], true, min, max);
+  HadamardRotation(&s[55], &s[56], true, min, max);
 
   // stage 30.
   ButterflyRotation_4(&s[55], &s[40], 32, true);
@@ -905,10 +906,10 @@
 
   // stage 31.
   for (int i = 0; i < 32; i += 4) {
-    HadamardRotation(&s[i], &s[63 - i], false, &min, &max);
-    HadamardRotation(&s[i + 1], &s[63 - i - 1], false, &min, &max);
-    HadamardRotation(&s[i + 2], &s[63 - i - 2], false, &min, &max);
-    HadamardRotation(&s[i + 3], &s[63 - i - 3], false, &min, &max);
+    HadamardRotation(&s[i], &s[63 - i], false, min, max);
+    HadamardRotation(&s[i + 1], &s[63 - i - 1], false, min, max);
+    HadamardRotation(&s[i + 2], &s[63 - i - 2], false, min, max);
+    HadamardRotation(&s[i + 3], &s[63 - i - 3], false, min, max);
   }
   //-- end dct 64 stages
   if (is_row) {
@@ -917,8 +918,8 @@
       int32x4_t output[8];
       Transpose4x4(&s[idx], &output[0]);
       Transpose4x4(&s[idx + 4], &output[4]);
-      for (int i = 0; i < 8; ++i) {
-        output[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(output[i], v_row_shift)));
+      for (auto& o : output) {
+        o = vmovl_s16(vqmovn_s32(vqrshlq_s32(o, v_row_shift)));
       }
       StoreDst<4>(dst, step, idx, &output[0]);
       StoreDst<4>(dst, step, idx + 4, &output[4]);
@@ -1089,20 +1090,20 @@
   butterfly_rotation(&s[6], &s[7], 60 - 48, true);
 
   // stage 3.
-  HadamardRotation(&s[0], &s[4], false, &min, &max);
-  HadamardRotation(&s[1], &s[5], false, &min, &max);
-  HadamardRotation(&s[2], &s[6], false, &min, &max);
-  HadamardRotation(&s[3], &s[7], false, &min, &max);
+  HadamardRotation(&s[0], &s[4], false, min, max);
+  HadamardRotation(&s[1], &s[5], false, min, max);
+  HadamardRotation(&s[2], &s[6], false, min, max);
+  HadamardRotation(&s[3], &s[7], false, min, max);
 
   // stage 4.
   butterfly_rotation(&s[4], &s[5], 48 - 0, true);
   butterfly_rotation(&s[7], &s[6], 48 - 32, true);
 
   // stage 5.
-  HadamardRotation(&s[0], &s[2], false, &min, &max);
-  HadamardRotation(&s[4], &s[6], false, &min, &max);
-  HadamardRotation(&s[1], &s[3], false, &min, &max);
-  HadamardRotation(&s[5], &s[7], false, &min, &max);
+  HadamardRotation(&s[0], &s[2], false, min, max);
+  HadamardRotation(&s[4], &s[6], false, min, max);
+  HadamardRotation(&s[1], &s[3], false, min, max);
+  HadamardRotation(&s[5], &s[7], false, min, max);
 
   // stage 6.
   butterfly_rotation(&s[2], &s[3], 32, true);
@@ -1120,8 +1121,8 @@
 
   if (is_row) {
     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
-    for (int i = 0; i < 8; ++i) {
-      x[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[i], v_row_shift)));
+    for (auto& i : x) {
+      i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
     }
     Transpose4x4(&x[0], &x[0]);
     Transpose4x4(&x[4], &x[4]);
@@ -1289,14 +1290,14 @@
   butterfly_rotation(&s[14], &s[15], 62 - 56, true);
 
   // stage 3.
-  HadamardRotation(&s[0], &s[8], false, &min, &max);
-  HadamardRotation(&s[1], &s[9], false, &min, &max);
-  HadamardRotation(&s[2], &s[10], false, &min, &max);
-  HadamardRotation(&s[3], &s[11], false, &min, &max);
-  HadamardRotation(&s[4], &s[12], false, &min, &max);
-  HadamardRotation(&s[5], &s[13], false, &min, &max);
-  HadamardRotation(&s[6], &s[14], false, &min, &max);
-  HadamardRotation(&s[7], &s[15], false, &min, &max);
+  HadamardRotation(&s[0], &s[8], false, min, max);
+  HadamardRotation(&s[1], &s[9], false, min, max);
+  HadamardRotation(&s[2], &s[10], false, min, max);
+  HadamardRotation(&s[3], &s[11], false, min, max);
+  HadamardRotation(&s[4], &s[12], false, min, max);
+  HadamardRotation(&s[5], &s[13], false, min, max);
+  HadamardRotation(&s[6], &s[14], false, min, max);
+  HadamardRotation(&s[7], &s[15], false, min, max);
 
   // stage 4.
   butterfly_rotation(&s[8], &s[9], 56 - 0, true);
@@ -1305,14 +1306,14 @@
   butterfly_rotation(&s[15], &s[14], 8 + 32, true);
 
   // stage 5.
-  HadamardRotation(&s[0], &s[4], false, &min, &max);
-  HadamardRotation(&s[8], &s[12], false, &min, &max);
-  HadamardRotation(&s[1], &s[5], false, &min, &max);
-  HadamardRotation(&s[9], &s[13], false, &min, &max);
-  HadamardRotation(&s[2], &s[6], false, &min, &max);
-  HadamardRotation(&s[10], &s[14], false, &min, &max);
-  HadamardRotation(&s[3], &s[7], false, &min, &max);
-  HadamardRotation(&s[11], &s[15], false, &min, &max);
+  HadamardRotation(&s[0], &s[4], false, min, max);
+  HadamardRotation(&s[8], &s[12], false, min, max);
+  HadamardRotation(&s[1], &s[5], false, min, max);
+  HadamardRotation(&s[9], &s[13], false, min, max);
+  HadamardRotation(&s[2], &s[6], false, min, max);
+  HadamardRotation(&s[10], &s[14], false, min, max);
+  HadamardRotation(&s[3], &s[7], false, min, max);
+  HadamardRotation(&s[11], &s[15], false, min, max);
 
   // stage 6.
   butterfly_rotation(&s[4], &s[5], 48 - 0, true);
@@ -1321,14 +1322,14 @@
   butterfly_rotation(&s[15], &s[14], 48 - 32, true);
 
   // stage 7.
-  HadamardRotation(&s[0], &s[2], false, &min, &max);
-  HadamardRotation(&s[4], &s[6], false, &min, &max);
-  HadamardRotation(&s[8], &s[10], false, &min, &max);
-  HadamardRotation(&s[12], &s[14], false, &min, &max);
-  HadamardRotation(&s[1], &s[3], false, &min, &max);
-  HadamardRotation(&s[5], &s[7], false, &min, &max);
-  HadamardRotation(&s[9], &s[11], false, &min, &max);
-  HadamardRotation(&s[13], &s[15], false, &min, &max);
+  HadamardRotation(&s[0], &s[2], false, min, max);
+  HadamardRotation(&s[4], &s[6], false, min, max);
+  HadamardRotation(&s[8], &s[10], false, min, max);
+  HadamardRotation(&s[12], &s[14], false, min, max);
+  HadamardRotation(&s[1], &s[3], false, min, max);
+  HadamardRotation(&s[5], &s[7], false, min, max);
+  HadamardRotation(&s[9], &s[11], false, min, max);
+  HadamardRotation(&s[13], &s[15], false, min, max);
 
   // stage 8.
   butterfly_rotation(&s[2], &s[3], 32, true);
@@ -1356,8 +1357,8 @@
 
   if (is_row) {
     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
-    for (int i = 0; i < 16; ++i) {
-      x[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(x[i], v_row_shift)));
+    for (auto& i : x) {
+      i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
     }
     for (int idx = 0; idx < 16; idx += 8) {
       Transpose4x4(&x[idx], &x[idx]);
@@ -1517,59 +1518,23 @@
 template <int identity_size>
 LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame(
     Array2DView<uint16_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int32_t* source) {
-  static_assert(identity_size == 4 || identity_size == 8 || identity_size == 16,
+    const int tx_width, const int tx_height,
+    const int32_t* LIBGAV1_RESTRICT source) {
+  static_assert(identity_size == 4 || identity_size == 8 ||
+                    identity_size == 16 || identity_size == 32,
                 "Invalid identity_size.");
   const int stride = frame.columns();
-  uint16_t* dst = frame[start_y] + start_x;
+  uint16_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
   const int32x4_t v_dual_round = vdupq_n_s32((1 + (1 << 4)) << 11);
   const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
 
-  if (tx_width == 4) {
-    int i = 0;
-    do {
-      int32x4x2_t v_src, v_dst_i, a, b;
-      v_src.val[0] = vld1q_s32(&source[i * 4]);
-      v_src.val[1] = vld1q_s32(&source[(i * 4) + 4]);
-      if (identity_size == 4) {
-        v_dst_i.val[0] =
-            vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity4Multiplier);
-        v_dst_i.val[1] =
-            vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity4Multiplier);
-        a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
-        a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
-      } else if (identity_size == 8) {
-        v_dst_i.val[0] = vaddq_s32(v_src.val[0], v_src.val[0]);
-        v_dst_i.val[1] = vaddq_s32(v_src.val[1], v_src.val[1]);
-        a.val[0] = vrshrq_n_s32(v_dst_i.val[0], 4);
-        a.val[1] = vrshrq_n_s32(v_dst_i.val[1], 4);
-      } else {  // identity_size == 16
-        v_dst_i.val[0] =
-            vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
-        v_dst_i.val[1] =
-            vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
-        a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
-        a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
-      }
-      uint16x4x2_t frame_data;
-      frame_data.val[0] = vld1_u16(dst);
-      frame_data.val[1] = vld1_u16(dst + stride);
-      b.val[0] = vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
-      b.val[1] = vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
-      vst1_u16(dst, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
-      vst1_u16(dst + stride, vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
-      dst += stride << 1;
-      i += 2;
-    } while (i < tx_height);
-  } else {
-    int i = 0;
-    do {
-      const int row = i * tx_width;
-      int j = 0;
+  if (identity_size < 32) {
+    if (tx_width == 4) {
+      int i = 0;
       do {
         int32x4x2_t v_src, v_dst_i, a, b;
-        v_src.val[0] = vld1q_s32(&source[row + j]);
-        v_src.val[1] = vld1q_s32(&source[row + j + 4]);
+        v_src.val[0] = vld1q_s32(&source[i * 4]);
+        v_src.val[1] = vld1q_s32(&source[(i * 4) + 4]);
         if (identity_size == 4) {
           v_dst_i.val[0] =
               vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity4Multiplier);
@@ -1591,13 +1556,72 @@
           a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
         }
         uint16x4x2_t frame_data;
-        frame_data.val[0] = vld1_u16(dst + j);
-        frame_data.val[1] = vld1_u16(dst + j + 4);
+        frame_data.val[0] = vld1_u16(dst);
+        frame_data.val[1] = vld1_u16(dst + stride);
         b.val[0] = vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
         b.val[1] = vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
-        vst1_u16(dst + j, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
-        vst1_u16(dst + j + 4, vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
-        j += 8;
+        vst1_u16(dst, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
+        vst1_u16(dst + stride, vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
+        dst += stride << 1;
+        i += 2;
+      } while (i < tx_height);
+    } else {
+      int i = 0;
+      do {
+        const int row = i * tx_width;
+        int j = 0;
+        do {
+          int32x4x2_t v_src, v_dst_i, a, b;
+          v_src.val[0] = vld1q_s32(&source[row + j]);
+          v_src.val[1] = vld1q_s32(&source[row + j + 4]);
+          if (identity_size == 4) {
+            v_dst_i.val[0] =
+                vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity4Multiplier);
+            v_dst_i.val[1] =
+                vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity4Multiplier);
+            a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
+            a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
+          } else if (identity_size == 8) {
+            v_dst_i.val[0] = vaddq_s32(v_src.val[0], v_src.val[0]);
+            v_dst_i.val[1] = vaddq_s32(v_src.val[1], v_src.val[1]);
+            a.val[0] = vrshrq_n_s32(v_dst_i.val[0], 4);
+            a.val[1] = vrshrq_n_s32(v_dst_i.val[1], 4);
+          } else {  // identity_size == 16
+            v_dst_i.val[0] =
+                vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
+            v_dst_i.val[1] =
+                vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
+            a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
+            a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
+          }
+          uint16x4x2_t frame_data;
+          frame_data.val[0] = vld1_u16(dst + j);
+          frame_data.val[1] = vld1_u16(dst + j + 4);
+          b.val[0] =
+              vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
+          b.val[1] =
+              vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
+          vst1_u16(dst + j, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
+          vst1_u16(dst + j + 4,
+                   vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
+          j += 8;
+        } while (j < tx_width);
+        dst += stride;
+      } while (++i < tx_height);
+    }
+  } else {
+    int i = 0;
+    do {
+      const int row = i * tx_width;
+      int j = 0;
+      do {
+        const int32x4_t v_dst_i = vld1q_s32(&source[row + j]);
+        const uint16x4_t frame_data = vld1_u16(dst + j);
+        const int32x4_t a = vrshrq_n_s32(v_dst_i, 2);
+        const int32x4_t b = vaddw_s16(a, vreinterpret_s16_u16(frame_data));
+        const uint16x4_t d = vmin_u16(vqmovun_s32(b), v_max_bitdepth);
+        vst1_u16(dst + j, d);
+        j += 4;
       } while (j < tx_width);
       dst += stride;
     } while (++i < tx_height);
@@ -1606,9 +1630,10 @@
 
 LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
     Array2DView<uint16_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int32_t* source) {
+    const int tx_width, const int tx_height,
+    const int32_t* LIBGAV1_RESTRICT source) {
   const int stride = frame.columns();
-  uint16_t* dst = frame[start_y] + start_x;
+  uint16_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
   const int32x4_t v_round = vdupq_n_s32((1 + (0)) << 11);
   const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
 
@@ -1747,6 +1772,119 @@
   return true;
 }
 
+LIBGAV1_ALWAYS_INLINE void Identity32Row16_NEON(void* dest,
+                                                const int32_t step) {
+  auto* const dst = static_cast<int32_t*>(dest);
+
+  // In place, saturating-doubles 4 rows of 32 coefficients (`step` apart).
+  // Folding the identity32 multiplier into the tx_height == 16 row shift
+  // reduces (((A * 4) + 1) >> 1) to (A * 2).
+  for (int i = 0; i < 4; ++i) {
+    for (int j = 0; j < 32; j += 4) {
+      const int32x4_t v_src = vld1q_s32(&dst[i * step + j]);
+      const int32x4_t v_dst_i = vqaddq_s32(v_src, v_src);
+      vst1q_s32(&dst[i * step + j], v_dst_i);
+    }
+  }
+}
+
+LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest,
+                                            int adjusted_tx_height) {
+  if (adjusted_tx_height > 1) return false;
+
+  auto* dst = static_cast<int32_t*>(dest);
+  const int32x2_t v_src0 = vdup_n_s32(dst[0]);
+  const int32x2_t v_src =
+      vqrdmulh_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
+  // v_src is dst[0] rounded-scaled by kTransformRowMultiplier. Combining the
+  // identity32 multiplier with the tx_height == 16 row shift simplifies
+  // (((A * 4) + 1) >> 1) to (A * 2), done here as a saturating add.
+  const int32x2_t v_dst_0 = vqadd_s32(v_src, v_src);
+  vst1_lane_s32(dst, v_dst_0, 0);
+  return true;
+}
+
+//------------------------------------------------------------------------------
+// Walsh Hadamard Transform.
+
+// 4x4 WHT (rows, then columns); adds the result to dst, clamped to 10 bits.
+LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint16_t* LIBGAV1_RESTRICT dst,
+                                     const int dst_stride,
+                                     const void* LIBGAV1_RESTRICT source,
+                                     const int adjusted_tx_height) {
+  const auto* const src = static_cast<const int32_t*>(source);
+  int32x4_t s[4];
+
+  if (adjusted_tx_height == 1) {
+    // Special case: only src[0] is nonzero.
+    //   src[0]  0   0   0
+    //       0   0   0   0
+    //       0   0   0   0
+    //       0   0   0   0
+    //
+    // After the row and column transforms are applied, we have:
+    //       f   h   h   h
+    //       g   i   i   i
+    //       g   i   i   i
+    //       g   i   i   i
+    // where f, g, h, i are computed as follows.
+    int32_t f = (src[0] >> 2) - (src[0] >> 3);
+    const int32_t g = f >> 1;
+    f = f - (f >> 1);
+    const int32_t h = (src[0] >> 3) - (src[0] >> 4);
+    const int32_t i = (src[0] >> 4);
+    s[0] = vdupq_n_s32(h);
+    s[0] = vsetq_lane_s32(f, s[0], 0);
+    s[1] = vdupq_n_s32(i);
+    s[1] = vsetq_lane_s32(g, s[1], 0);
+    s[2] = s[3] = s[1];
+  } else {
+    // Load the 4x4 source in transposed form.
+    int32x4x4_t columns = vld4q_s32(src);
+
+    // Shift right and permute the columns for the WHT.
+    s[0] = vshrq_n_s32(columns.val[0], 2);
+    s[2] = vshrq_n_s32(columns.val[1], 2);
+    s[3] = vshrq_n_s32(columns.val[2], 2);
+    s[1] = vshrq_n_s32(columns.val[3], 2);
+
+    // Row transforms.
+    s[0] = vaddq_s32(s[0], s[2]);
+    s[3] = vsubq_s32(s[3], s[1]);
+    int32x4_t e = vhsubq_s32(s[0], s[3]);  // e = (s[0] - s[3]) >> 1
+    s[1] = vsubq_s32(e, s[1]);
+    s[2] = vsubq_s32(e, s[2]);
+    s[0] = vsubq_s32(s[0], s[1]);
+    s[3] = vaddq_s32(s[3], s[2]);
+
+    int32x4_t x[4];
+    Transpose4x4(s, x);
+    // Same permutation as the row pass: s = {x[0], x[3], x[1], x[2]}.
+    s[0] = x[0];
+    s[2] = x[1];
+    s[3] = x[2];
+    s[1] = x[3];
+
+    // Column transforms.
+    s[0] = vaddq_s32(s[0], s[2]);
+    s[3] = vsubq_s32(s[3], s[1]);
+    e = vhsubq_s32(s[0], s[3]);  // e = (s[0] - s[3]) >> 1
+    s[1] = vsubq_s32(e, s[1]);
+    s[2] = vsubq_s32(e, s[2]);
+    s[0] = vsubq_s32(s[0], s[1]);
+    s[3] = vaddq_s32(s[3], s[2]);
+  }
+
+  // Add to the frame pixels and clamp to [0, (1 << kBitdepth10) - 1].
+  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+  for (int row = 0; row < 4; row += 1) {
+    const uint16x4_t frame_data = vld1_u16(dst);
+    const int32x4_t b = vaddw_s16(s[row], vreinterpret_s16_u16(frame_data));
+    vst1_u16(dst, vmin_u16(vqmovun_s32(b), v_max_bitdepth));
+    dst += dst_stride;
+  }
+}
+
 //------------------------------------------------------------------------------
 // row/column transform loops
 
@@ -1837,11 +1975,12 @@
 template <int tx_height, bool enable_flip_rows = false>
 LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
     Array2DView<uint16_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int32_t* source, TransformType tx_type) {
+    const int tx_width, const int32_t* LIBGAV1_RESTRICT source,
+    TransformType tx_type) {
   const bool flip_rows =
       enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
   const int stride = frame.columns();
-  uint16_t* dst = frame[start_y] + start_x;
+  uint16_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
 
   if (tx_width == 4) {
     for (int i = 0; i < tx_height; ++i) {
@@ -1887,7 +2026,7 @@
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_height = kTransformHeight[tx_size];
   const bool should_round = (tx_height == 8);
-  const int row_shift = (tx_height == 16);
+  const int row_shift = static_cast<int>(tx_height == 16);
 
   if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) {
     return;
@@ -1909,8 +2048,10 @@
 }
 
 void Dct4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                  int adjusted_tx_height, void* src_buffer,
-                                  int start_x, int start_y, void* dst_frame) {
+                                  int adjusted_tx_height,
+                                  void* LIBGAV1_RESTRICT src_buffer,
+                                  int start_x, int start_y,
+                                  void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -1962,8 +2103,10 @@
 }
 
 void Dct8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                  int adjusted_tx_height, void* src_buffer,
-                                  int start_x, int start_y, void* dst_frame) {
+                                  int adjusted_tx_height,
+                                  void* LIBGAV1_RESTRICT src_buffer,
+                                  int start_x, int start_y,
+                                  void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2014,8 +2157,10 @@
 }
 
 void Dct16TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2066,8 +2211,10 @@
 }
 
 void Dct32TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2117,8 +2264,10 @@
 }
 
 void Dct64TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2168,8 +2317,10 @@
 }
 
 void Adst4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2222,8 +2373,10 @@
 }
 
 void Adst8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2275,8 +2428,10 @@
 
 void Adst16TransformLoopColumn_NEON(TransformType tx_type,
                                     TransformSize tx_size,
-                                    int adjusted_tx_height, void* src_buffer,
-                                    int start_x, int start_y, void* dst_frame) {
+                                    int adjusted_tx_height,
+                                    void* LIBGAV1_RESTRICT src_buffer,
+                                    int start_x, int start_y,
+                                    void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2335,9 +2490,10 @@
 
 void Identity4TransformLoopColumn_NEON(TransformType tx_type,
                                        TransformSize tx_size,
-                                       int adjusted_tx_height, void* src_buffer,
+                                       int adjusted_tx_height,
+                                       void* LIBGAV1_RESTRICT src_buffer,
                                        int start_x, int start_y,
-                                       void* dst_frame) {
+                                       void* LIBGAV1_RESTRICT dst_frame) {
   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2416,9 +2572,10 @@
 
 void Identity8TransformLoopColumn_NEON(TransformType tx_type,
                                        TransformSize tx_size,
-                                       int adjusted_tx_height, void* src_buffer,
+                                       int adjusted_tx_height,
+                                       void* LIBGAV1_RESTRICT src_buffer,
                                        int start_x, int start_y,
-                                       void* dst_frame) {
+                                       void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2457,8 +2614,9 @@
 void Identity16TransformLoopColumn_NEON(TransformType tx_type,
                                         TransformSize tx_size,
                                         int adjusted_tx_height,
-                                        void* src_buffer, int start_x,
-                                        int start_y, void* dst_frame) {
+                                        void* LIBGAV1_RESTRICT src_buffer,
+                                        int start_x, int start_y,
+                                        void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int32_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2470,60 +2628,144 @@
                                  adjusted_tx_height, src);
 }
 
+void Identity32TransformLoopRow_NEON(TransformType /*tx_type*/,
+                                     TransformSize tx_size,
+                                     int adjusted_tx_height, void* src_buffer,
+                                     int /*start_x*/, int /*start_y*/,
+                                     void* /*dst_frame*/) {
+  const int tx_height = kTransformHeight[tx_size];
+  // Row pass of the 32-point identity transform (only 32x16 needs work).
+  // When combining the identity32 multiplier with the row shift, the
+  // calculations for tx_height == 8 and tx_height == 32 can be simplified
+  // from (((A * 4) + 2) >> 2) to A, so those sizes are a no-op here.
+  if ((tx_height & 0x28) != 0) {  // 0x28 == (8 | 32).
+    return;
+  }
+
+  // Process kTransformSize32x16. The src is always rounded before the identity
+  // transform and shifted by 1 afterwards.
+  auto* src = static_cast<int32_t*>(src_buffer);
+  if (Identity32DcOnly(src, adjusted_tx_height)) {
+    return;
+  }
+
+  assert(tx_size == kTransformSize32x16);
+  ApplyRounding<32>(src, adjusted_tx_height);
+  int i = adjusted_tx_height;  // Assumed to be a non-zero multiple of 4.
+  do {
+    Identity32Row16_NEON(src, /*step=*/32);
+    src += 128;  // Advance 4 rows of 32 coefficients (4 * step).
+    i -= 4;
+  } while (i != 0);
+}
+
+void Identity32TransformLoopColumn_NEON(TransformType /*tx_type*/,
+                                        TransformSize tx_size,
+                                        int adjusted_tx_height,
+                                        void* LIBGAV1_RESTRICT src_buffer,
+                                        int start_x, int start_y,
+                                        void* LIBGAV1_RESTRICT dst_frame) {
+  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
+  auto* src = static_cast<int32_t*>(src_buffer);
+  const int tx_width = kTransformWidth[tx_size];
+  // Column pass and store into the uint16_t frame are both done by the helper.
+  IdentityColumnStoreToFrame<32>(frame, start_x, start_y, tx_width,
+                                 adjusted_tx_height, src);
+}
+
+void Wht4TransformLoopRow_NEON(TransformType tx_type, TransformSize tx_size,
+                               int /*adjusted_tx_height*/, void* /*src_buffer*/,
+                               int /*start_x*/, int /*start_y*/,
+                               void* /*dst_frame*/) {
+  assert(tx_type == kTransformTypeDctDct);
+  assert(tx_size == kTransformSize4x4);
+  static_cast<void>(tx_type);  // Unused when asserts are compiled out.
+  static_cast<void>(tx_size);
+  // No-op: the row and column transforms are both done in the column pass.
+}
+
+void Wht4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
+                                  int adjusted_tx_height,
+                                  void* LIBGAV1_RESTRICT src_buffer,
+                                  int start_x, int start_y,
+                                  void* LIBGAV1_RESTRICT dst_frame) {
+  assert(tx_type == kTransformTypeDctDct);
+  assert(tx_size == kTransformSize4x4);
+  static_cast<void>(tx_type);  // Unused when asserts are compiled out.
+  static_cast<void>(tx_size);
+
+  // Process 4 1d wht4 rows and columns in parallel, accumulating into dst.
+  const auto* src = static_cast<int32_t*>(src_buffer);
+  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
+  uint16_t* dst = frame[start_y] + start_x;  // Top-left pixel of the block.
+  const int dst_stride = frame.columns();
+  Wht4_NEON(dst, dst_stride, src, adjusted_tx_height);
+}
+
 //------------------------------------------------------------------------------
 
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
   assert(dsp != nullptr);
   // Maximum transform size for Dct is 64.
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
       Dct4TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
       Dct4TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
       Dct8TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
       Dct8TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
       Dct16TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
       Dct16TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
       Dct32TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
       Dct32TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
       Dct64TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
       Dct64TransformLoopColumn_NEON;
 
   // Maximum transform size for Adst is 16.
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
       Adst4TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
       Adst4TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
       Adst8TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
       Adst8TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
       Adst16TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
       Adst16TransformLoopColumn_NEON;
 
   // Maximum transform size for Identity transform is 32.
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
       Identity4TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
       Identity4TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
       Identity8TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
       Identity8TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
       Identity16TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
       Identity16TransformLoopColumn_NEON;
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
+      Identity32TransformLoopRow_NEON;  // Only 32x16 needs row work.
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
+      Identity32TransformLoopColumn_NEON;
+
+  // Maximum transform size for Wht is 4.
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
+      Wht4TransformLoopRow_NEON;  // No-op; both passes run in the column.
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
+      Wht4TransformLoopColumn_NEON;
 }
 
 }  // namespace
diff --git a/libgav1/src/dsp/arm/inverse_transform_neon.cc b/libgav1/src/dsp/arm/inverse_transform_neon.cc
index 315d5e9..1c2e111 100644
--- a/libgav1/src/dsp/arm/inverse_transform_neon.cc
+++ b/libgav1/src/dsp/arm/inverse_transform_neon.cc
@@ -273,7 +273,8 @@
 
 //------------------------------------------------------------------------------
 template <int store_width, int store_count>
-LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* dst, int32_t stride, int32_t idx,
+LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* LIBGAV1_RESTRICT dst,
+                                    int32_t stride, int32_t idx,
                                     const int16x8_t* const s) {
   assert(store_count % 4 == 0);
   assert(store_width == 8 || store_width == 16);
@@ -297,8 +298,8 @@
 }
 
 template <int load_width, int load_count>
-LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* src, int32_t stride,
-                                   int32_t idx, int16x8_t* x) {
+LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* LIBGAV1_RESTRICT src,
+                                   int32_t stride, int32_t idx, int16x8_t* x) {
   assert(load_count % 4 == 0);
   assert(load_width == 8 || load_width == 16);
   // NOTE: It is expected that the compiler will unroll these loops.
@@ -388,6 +389,33 @@
                                                          int16x8_t* b,
                                                          const int angle,
                                                          const bool flip) {
+#if defined(__ARM_FEATURE_QRDMX) && defined(__aarch64__) && \
+    defined(__clang__)  // ARM v8.1-A
+  // Clang optimizes vqrdmulhq_n_s16 and vqsubq_s16 (in HadamardRotation) into
+  // vqrdmlshq_s16 resulting in an "off by one" error. For now, do not use
+  // vqrdmulhq_n_s16().
+  const int16_t cos128 = Cos128(angle);
+  const int16_t sin128 = Sin128(angle);
+  const int32x4_t x0 = vmull_n_s16(vget_low_s16(*b), -sin128);
+  const int32x4_t y0 = vmull_n_s16(vget_low_s16(*b), cos128);
+  const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
+  const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
+
+  const int32x4_t x0_hi = vmull_n_s16(vget_high_s16(*b), -sin128);
+  const int32x4_t y0_hi = vmull_n_s16(vget_high_s16(*b), cos128);
+  const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12);
+  const int16x4_t y1_hi = vqrshrn_n_s32(y0_hi, 12);
+
+  const int16x8_t x = vcombine_s16(x1, x1_hi);
+  const int16x8_t y = vcombine_s16(y1, y1_hi);
+  if (flip) {
+    *a = y;
+    *b = x;
+  } else {
+    *a = x;
+    *b = y;
+  }
+#else
   const int16_t cos128 = Cos128(angle);
   const int16_t sin128 = Sin128(angle);
   // For this function, the max value returned by Sin128() is 4091, which fits
@@ -403,12 +431,40 @@
     *a = x;
     *b = y;
   }
+#endif
 }
 
 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(int16x8_t* a,
                                                           int16x8_t* b,
                                                           const int angle,
                                                           const bool flip) {
+#if defined(__ARM_FEATURE_QRDMX) && defined(__aarch64__) && \
+    defined(__clang__)  // ARM v8.1-A
+  // Clang optimizes vqrdmulhq_n_s16 and vqsubq_s16 (in HadamardRotation) into
+  // vqrdmlshq_s16 resulting in an "off by one" error. For now, do not use
+  // vqrdmulhq_n_s16().
+  const int16_t cos128 = Cos128(angle);
+  const int16_t sin128 = Sin128(angle);
+  const int32x4_t x0 = vmull_n_s16(vget_low_s16(*a), cos128);
+  const int32x4_t y0 = vmull_n_s16(vget_low_s16(*a), sin128);
+  const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
+  const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
+
+  const int32x4_t x0_hi = vmull_n_s16(vget_high_s16(*a), cos128);
+  const int32x4_t y0_hi = vmull_n_s16(vget_high_s16(*a), sin128);
+  const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12);
+  const int16x4_t y1_hi = vqrshrn_n_s32(y0_hi, 12);
+
+  const int16x8_t x = vcombine_s16(x1, x1_hi);
+  const int16x8_t y = vcombine_s16(y1, y1_hi);
+  if (flip) {
+    *a = y;
+    *b = x;
+  } else {
+    *a = x;
+    *b = y;
+  }
+#else
   const int16_t cos128 = Cos128(angle);
   const int16_t sin128 = Sin128(angle);
   const int16x8_t x = vqrdmulhq_n_s16(*a, cos128 << 3);
@@ -420,6 +476,7 @@
     *a = x;
     *b = y;
   }
+#endif
 }
 
 LIBGAV1_ALWAYS_INLINE void HadamardRotation(int16x8_t* a, int16x8_t* b,
@@ -736,8 +793,8 @@
 
   if (is_row) {
     const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
-    for (int i = 0; i < 16; ++i) {
-      s[i] = vqrshlq_s16(s[i], v_row_shift);
+    for (auto& i : s) {
+      i = vqrshlq_s16(i, v_row_shift);
     }
   }
 
@@ -914,8 +971,8 @@
     for (int idx = 0; idx < 32; idx += 8) {
       int16x8_t output[8];
       Transpose8x8(&s[idx], output);
-      for (int i = 0; i < 8; ++i) {
-        output[i] = vqrshlq_s16(output[i], v_row_shift);
+      for (auto& o : output) {
+        o = vqrshlq_s16(o, v_row_shift);
       }
       StoreDst<16, 8>(dst, step, idx, output);
     }
@@ -1135,8 +1192,8 @@
     for (int idx = 0; idx < 64; idx += 8) {
       int16x8_t output[8];
       Transpose8x8(&s[idx], output);
-      for (int i = 0; i < 8; ++i) {
-        output[i] = vqrshlq_s16(output[i], v_row_shift);
+      for (auto& o : output) {
+        o = vqrshlq_s16(o, v_row_shift);
       }
       StoreDst<16, 8>(dst, step, idx, output);
     }
@@ -1611,13 +1668,13 @@
       const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
       int16x8_t output[4];
       Transpose4x8To8x4(x, output);
-      for (int i = 0; i < 4; ++i) {
-        output[i] = vqrshlq_s16(output[i], v_row_shift);
+      for (auto& o : output) {
+        o = vqrshlq_s16(o, v_row_shift);
       }
       StoreDst<16, 4>(dst, step, 0, output);
       Transpose4x8To8x4(&x[8], output);
-      for (int i = 0; i < 4; ++i) {
-        output[i] = vqrshlq_s16(output[i], v_row_shift);
+      for (auto& o : output) {
+        o = vqrshlq_s16(o, v_row_shift);
       }
       StoreDst<16, 4>(dst, step, 8, output);
     } else {
@@ -1629,8 +1686,8 @@
       for (int idx = 0; idx < 16; idx += 8) {
         int16x8_t output[8];
         Transpose8x8(&x[idx], output);
-        for (int i = 0; i < 8; ++i) {
-          output[i] = vqrshlq_s16(output[i], v_row_shift);
+        for (auto& o : output) {
+          o = vqrshlq_s16(o, v_row_shift);
         }
         StoreDst<16, 8>(dst, step, idx, output);
       }
@@ -1805,9 +1862,10 @@
 template <int identity_size>
 LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
+    const int tx_width, const int tx_height,
+    const int16_t* LIBGAV1_RESTRICT source) {
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
 
   if (identity_size < 32) {
     if (tx_width == 4) {
@@ -1891,9 +1949,10 @@
 
 LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
+    const int tx_width, const int tx_height,
+    const int16_t* LIBGAV1_RESTRICT source) {
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
 
   if (tx_width == 4) {
     uint8x8_t frame_data = vdup_n_u8(0);
@@ -2106,8 +2165,9 @@
 }
 
 // Process 4 wht4 rows and columns.
-LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint8_t* dst, const int dst_stride,
-                                     const void* source,
+LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint8_t* LIBGAV1_RESTRICT dst,
+                                     const int dst_stride,
+                                     const void* LIBGAV1_RESTRICT source,
                                      const int adjusted_tx_height) {
   const auto* const src = static_cast<const int16_t*>(source);
   int16x4_t s[4];
@@ -2273,11 +2333,12 @@
 template <int tx_height, bool enable_flip_rows = false>
 LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int16_t* source, TransformType tx_type) {
+    const int tx_width, const int16_t* LIBGAV1_RESTRICT source,
+    TransformType tx_type) {
   const bool flip_rows =
       enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
 
   // Enable for 4x4, 4x8, 4x16
   if (tx_height < 32 && tx_width == 4) {
@@ -2338,7 +2399,7 @@
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_height = kTransformHeight[tx_size];
   const bool should_round = (tx_height == 8);
-  const int row_shift = (tx_height == 16);
+  const int row_shift = static_cast<int>(tx_height == 16);
 
   if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) {
     return;
@@ -2368,8 +2429,10 @@
 }
 
 void Dct4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                  int adjusted_tx_height, void* src_buffer,
-                                  int start_x, int start_y, void* dst_frame) {
+                                  int adjusted_tx_height,
+                                  void* LIBGAV1_RESTRICT src_buffer,
+                                  int start_x, int start_y,
+                                  void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2435,8 +2498,10 @@
 }
 
 void Dct8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                  int adjusted_tx_height, void* src_buffer,
-                                  int start_x, int start_y, void* dst_frame) {
+                                  int adjusted_tx_height,
+                                  void* LIBGAV1_RESTRICT src_buffer,
+                                  int start_x, int start_y,
+                                  void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2497,8 +2562,10 @@
 }
 
 void Dct16TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2551,8 +2618,10 @@
 }
 
 void Dct32TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2594,8 +2663,10 @@
 }
 
 void Dct64TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2645,8 +2716,10 @@
 }
 
 void Adst4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2707,8 +2780,10 @@
 }
 
 void Adst8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                   int adjusted_tx_height, void* src_buffer,
-                                   int start_x, int start_y, void* dst_frame) {
+                                   int adjusted_tx_height,
+                                   void* LIBGAV1_RESTRICT src_buffer,
+                                   int start_x, int start_y,
+                                   void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2771,8 +2846,10 @@
 
 void Adst16TransformLoopColumn_NEON(TransformType tx_type,
                                     TransformSize tx_size,
-                                    int adjusted_tx_height, void* src_buffer,
-                                    int start_x, int start_y, void* dst_frame) {
+                                    int adjusted_tx_height,
+                                    void* LIBGAV1_RESTRICT src_buffer,
+                                    int start_x, int start_y,
+                                    void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2844,9 +2921,10 @@
 
 void Identity4TransformLoopColumn_NEON(TransformType tx_type,
                                        TransformSize tx_size,
-                                       int adjusted_tx_height, void* src_buffer,
+                                       int adjusted_tx_height,
+                                       void* LIBGAV1_RESTRICT src_buffer,
                                        int start_x, int start_y,
-                                       void* dst_frame) {
+                                       void* LIBGAV1_RESTRICT dst_frame) {
   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2919,9 +2997,10 @@
 
 void Identity8TransformLoopColumn_NEON(TransformType tx_type,
                                        TransformSize tx_size,
-                                       int adjusted_tx_height, void* src_buffer,
+                                       int adjusted_tx_height,
+                                       void* LIBGAV1_RESTRICT src_buffer,
                                        int start_x, int start_y,
-                                       void* dst_frame) {
+                                       void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2960,8 +3039,9 @@
 void Identity16TransformLoopColumn_NEON(TransformType tx_type,
                                         TransformSize tx_size,
                                         int adjusted_tx_height,
-                                        void* src_buffer, int start_x,
-                                        int start_y, void* dst_frame) {
+                                        void* LIBGAV1_RESTRICT src_buffer,
+                                        int start_x, int start_y,
+                                        void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -3007,8 +3087,9 @@
 void Identity32TransformLoopColumn_NEON(TransformType /*tx_type*/,
                                         TransformSize tx_size,
                                         int adjusted_tx_height,
-                                        void* src_buffer, int start_x,
-                                        int start_y, void* dst_frame) {
+                                        void* LIBGAV1_RESTRICT src_buffer,
+                                        int start_x, int start_y,
+                                        void* LIBGAV1_RESTRICT dst_frame) {
   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -3029,8 +3110,10 @@
 }
 
 void Wht4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
-                                  int adjusted_tx_height, void* src_buffer,
-                                  int start_x, int start_y, void* dst_frame) {
+                                  int adjusted_tx_height,
+                                  void* LIBGAV1_RESTRICT src_buffer,
+                                  int start_x, int start_y,
+                                  void* LIBGAV1_RESTRICT dst_frame) {
   assert(tx_type == kTransformTypeDctDct);
   assert(tx_size == kTransformSize4x4);
   static_cast<void>(tx_type);
@@ -3050,63 +3133,63 @@
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   // Maximum transform size for Dct is 64.
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
       Dct4TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
       Dct4TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
       Dct8TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
       Dct8TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
       Dct16TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
       Dct16TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
       Dct32TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
       Dct32TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
       Dct64TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
       Dct64TransformLoopColumn_NEON;
 
   // Maximum transform size for Adst is 16.
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
       Adst4TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
       Adst4TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
       Adst8TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
       Adst8TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
       Adst16TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
       Adst16TransformLoopColumn_NEON;
 
   // Maximum transform size for Identity transform is 32.
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
       Identity4TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
       Identity4TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
       Identity8TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
       Identity8TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
       Identity16TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
       Identity16TransformLoopColumn_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
       Identity32TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
       Identity32TransformLoopColumn_NEON;
 
   // Maximum transform size for Wht is 4.
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] =
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
       Wht4TransformLoopRow_NEON;
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
       Wht4TransformLoopColumn_NEON;
 }
 
diff --git a/libgav1/src/dsp/arm/inverse_transform_neon.h b/libgav1/src/dsp/arm/inverse_transform_neon.h
index 91e0e83..ebd7cf4 100644
--- a/libgav1/src/dsp/arm/inverse_transform_neon.h
+++ b/libgav1/src/dsp/arm/inverse_transform_neon.h
@@ -32,36 +32,39 @@
 }  // namespace libgav1
 
 #if LIBGAV1_ENABLE_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize32_Transform1dDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize64_Transform1dDct LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dAdst LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize32_Transform1dIdentity LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dWht LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformDct LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformDct LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformDct LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp10bpp_1DTransformSize32_1DTransformDct LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp10bpp_1DTransformSize64_1DTransformDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize4_Transform1dDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize8_Transform1dDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize16_Transform1dDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize32_Transform1dDct LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize64_Transform1dDct LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformAdst LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformAdst LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize4_Transform1dAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize8_Transform1dAdst LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize16_Transform1dAdst LIBGAV1_CPU_NEON
 
-#define LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_CPU_NEON
-#define LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize4_Transform1dIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize8_Transform1dIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize16_Transform1dIdentity LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_Transform1dSize32_Transform1dIdentity LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_Transform1dSize4_Transform1dWht LIBGAV1_CPU_NEON
 
 #endif  // LIBGAV1_ENABLE_NEON
 
diff --git a/libgav1/src/dsp/arm/loop_filter_neon.cc b/libgav1/src/dsp/arm/loop_filter_neon.cc
index 8d72892..8c03928 100644
--- a/libgav1/src/dsp/arm/loop_filter_neon.cc
+++ b/libgav1/src/dsp/arm/loop_filter_neon.cc
@@ -50,7 +50,7 @@
 }
 
 // abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh &&
-//   OuterThreshhold()
+//   OuterThreshold()
 inline uint8x8_t NeedsFilter4(const uint8x8_t abd_p0p1_q0q1,
                               const uint8x8_t p0q0, const uint8x8_t p1q1,
                               const uint8_t inner_thresh,
@@ -65,6 +65,7 @@
                          const uint8_t hev_thresh, const uint8_t outer_thresh,
                          const uint8_t inner_thresh, uint8x8_t* const hev_mask,
                          uint8x8_t* const needs_filter4_mask) {
+  // First half is |p0 - p1|, second half is |q0 - q1|.
   const uint8x8_t p0p1_q0q1 = vabd_u8(p0q0, p1q1);
   // This includes cases where NeedsFilter4() is not true and so Filter2() will
   // not be applied.
@@ -131,7 +132,7 @@
 void Horizontal4_NEON(void* const dest, const ptrdiff_t stride,
                       const int outer_thresh, const int inner_thresh,
                       const int hev_thresh) {
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   const uint8x8_t p1_v = Load4(dst - 2 * stride);
   const uint8x8_t p0_v = Load4(dst - stride);
@@ -180,7 +181,7 @@
 void Vertical4_NEON(void* const dest, const ptrdiff_t stride,
                     const int outer_thresh, const int inner_thresh,
                     const int hev_thresh) {
-  uint8_t* dst = static_cast<uint8_t*>(dest);
+  auto* dst = static_cast<uint8_t*>(dest);
 
   // Move |dst| to the left side of the filter window.
   dst -= 2;
@@ -256,7 +257,7 @@
 
 // abs(p2 - p1) <= inner_thresh && abs(p1 - p0) <= inner_thresh &&
 //   abs(q1 - q0) <= inner_thresh && abs(q2 - q1) <= inner_thresh &&
-//   OuterThreshhold()
+//   OuterThreshold()
 inline uint8x8_t NeedsFilter6(const uint8x8_t abd_p0p1_q0q1,
                               const uint8x8_t abd_p1p2_q1q2,
                               const uint8x8_t p0q0, const uint8x8_t p1q1,
@@ -288,26 +289,26 @@
   // Sum p1 and q1 output from opposite directions
   // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
   //      ^^^^^^^^
-  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3)
+  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2)
   //                                 ^^^^^^^^
   const uint16x8_t p2q2_double = vaddl_u8(p2q2, p2q2);
   uint16x8_t sum = vaddw_u8(p2q2_double, p2q2);
 
   // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
   //                 ^^^^^^^^
-  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3)
+  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2)
   //                      ^^^^^^^^
   sum = vaddq_u16(vaddl_u8(p1q1, p1q1), sum);
 
   // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
   //                            ^^^^^^^^
-  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3)
+  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2)
   //           ^^^^^^^^
   sum = vaddq_u16(vaddl_u8(p0q0, p0q0), sum);
 
   // p1 = (3 * p2) + (2 * p1) + (2 * p0) + q0
   //                                       ^^
-  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q3)
+  // q1 = p0 + (2 * q0) + (2 * q1) + (3 * q2)
   //      ^^
   const uint8x8_t q0p0 = Transpose32(p0q0);
   sum = vaddw_u8(sum, q0p0);
@@ -488,7 +489,7 @@
 // abs(p3 - p2) <= inner_thresh && abs(p2 - p1) <= inner_thresh &&
 //   abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh &&
 //   abs(q2 - q1) <= inner_thresh && abs(q3 - q2) <= inner_thresh
-//   OuterThreshhold()
+//   OuterThreshold()
 inline uint8x8_t NeedsFilter8(const uint8x8_t abd_p0p1_q0q1,
                               const uint8x8_t abd_p1p2_q1q2,
                               const uint8x8_t abd_p2p3_q2q3,
@@ -522,29 +523,35 @@
                     const uint8x8_t p1q1, const uint8x8_t p0q0,
                     uint8x8_t* const p2q2_output, uint8x8_t* const p1q1_output,
                     uint8x8_t* const p0q0_output) {
-  // Sum p2 and q2 output from opposite directions
+  // Sum p2 and q2 output from opposite directions.
+  // The formula is regrouped to allow 2 doubling operations to be combined.
   // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
   //      ^^^^^^^^
   // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
   //                                ^^^^^^^^
-  uint16x8_t sum = vaddw_u8(vaddl_u8(p3q3, p3q3), p3q3);
+  // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0
+  //                    ^^^^^^^^^^^
+  const uint16x8_t p23q23 = vaddl_u8(p3q3, p2q2);
 
-  // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
-  //                 ^^^^^^^^
-  // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
-  //                     ^^^^^^^^
-  sum = vaddq_u16(vaddl_u8(p2q2, p2q2), sum);
+  // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0
+  //               ^^^^^
+  uint16x8_t sum = vshlq_n_u16(p23q23, 1);
 
-  // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
-  //                            ^^^^^^^
-  // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
-  //           ^^^^^^^
-  sum = vaddq_u16(vaddl_u8(p1q1, p0q0), sum);
+  // Add two other terms to make dual issue with shift more likely.
+  // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0
+  //                                   ^^^^^^^^^^^
+  const uint16x8_t p01q01 = vaddl_u8(p0q0, p1q1);
 
-  // p2 = (3 * p3) + (2 * p2) + p1 + p0 + q0
-  //                                      ^^
-  // q2 = p0 + q0 + q1 + (2 * q2) + (3 * q3)
-  //      ^^
+  // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0
+  //                                 ^^^^^^^^^^^^^
+  sum = vaddq_u16(sum, p01q01);
+
+  // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0
+  //        ^^^^^^
+  sum = vaddw_u8(sum, p3q3);
+
+  // p2q2 = p3q3 + 2 * (p3q3 + p2q2) + p1q1 + p0q0 + q0p0
+  //                                               ^^^^^^
   const uint8x8_t q0p0 = Transpose32(p0q0);
   sum = vaddw_u8(sum, q0p0);
 
@@ -553,9 +560,9 @@
   // Convert to p1 and q1 output:
   // p1 = p2 - p3 - p2 + p1 + q1
   // q1 = q2 - q3 - q2 + q0 + p1
-  sum = vsubq_u16(sum, vaddl_u8(p3q3, p2q2));
+  sum = vsubq_u16(sum, p23q23);
   const uint8x8_t q1p1 = Transpose32(p1q1);
-  sum = vaddq_u16(vaddl_u8(p1q1, q1p1), sum);
+  sum = vaddq_u16(sum, vaddl_u8(p1q1, q1p1));
 
   *p1q1_output = vrshrn_n_u16(sum, 3);
 
@@ -564,7 +571,7 @@
   // q0 = q1 - q3 - q1 + q0 + p2
   sum = vsubq_u16(sum, vaddl_u8(p3q3, p1q1));
   const uint8x8_t q2p2 = Transpose32(p2q2);
-  sum = vaddq_u16(vaddl_u8(p0q0, q2p2), sum);
+  sum = vaddq_u16(sum, vaddl_u8(p0q0, q2p2));
 
   *p0q0_output = vrshrn_n_u16(sum, 3);
 }
@@ -1174,7 +1181,1264 @@
 }  // namespace
 }  // namespace low_bitdepth
 
-void LoopFilterInit_NEON() { low_bitdepth::Init8bpp(); }
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+// (abs(p1 - p0) > thresh) || (abs(q1 - q0) > thresh)
+inline uint16x4_t Hev(const uint16x8_t abd_p0p1_q0q1, const uint16_t thresh) {
+  const uint16x8_t a = vcgtq_u16(abd_p0p1_q0q1, vdupq_n_u16(thresh));
+  return vorr_u16(vget_low_u16(a), vget_high_u16(a));
+}
+
+// abs(p0 - q0) * 2 + abs(p1 - q1) / 2 <= outer_thresh
+inline uint16x4_t OuterThreshold(const uint16x4_t p1, const uint16x4_t p0,
+                                 const uint16x4_t q0, const uint16x4_t q1,
+                                 const uint16_t outer_thresh) {
+  const uint16x4_t abd_p0q0 = vabd_u16(p0, q0);
+  const uint16x4_t abd_p1q1 = vabd_u16(p1, q1);
+  const uint16x4_t p0q0_double = vshl_n_u16(abd_p0q0, 1);
+  const uint16x4_t p1q1_half = vshr_n_u16(abd_p1q1, 1);
+  const uint16x4_t sum = vadd_u16(p0q0_double, p1q1_half);
+  return vcle_u16(sum, vdup_n_u16(outer_thresh));
+}
+
+// abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh &&
+//   OuterThreshold()
+inline uint16x4_t NeedsFilter4(const uint16x8_t abd_p0p1_q0q1,
+                               const uint16_t inner_thresh,
+                               const uint16x4_t outer_mask) {
+  const uint16x8_t a = vcleq_u16(abd_p0p1_q0q1, vdupq_n_u16(inner_thresh));
+  const uint16x4_t inner_mask = vand_u16(vget_low_u16(a), vget_high_u16(a));
+  return vand_u16(inner_mask, outer_mask);
+}
+
+// abs(p2 - p1) <= inner_thresh && abs(p1 - p0) <= inner_thresh &&
+//   abs(q1 - q0) <= inner_thresh && abs(q2 - q1) <= inner_thresh &&
+//   OuterThreshold()
+inline uint16x4_t NeedsFilter6(const uint16x8_t abd_p0p1_q0q1,
+                               const uint16x8_t abd_p1p2_q1q2,
+                               const uint16_t inner_thresh,
+                               const uint16x4_t outer_mask) {
+  const uint16x8_t a = vmaxq_u16(abd_p0p1_q0q1, abd_p1p2_q1q2);
+  const uint16x8_t b = vcleq_u16(a, vdupq_n_u16(inner_thresh));
+  const uint16x4_t inner_mask = vand_u16(vget_low_u16(b), vget_high_u16(b));
+  return vand_u16(inner_mask, outer_mask);
+}
+
+// abs(p3 - p2) <= inner_thresh && abs(p2 - p1) <= inner_thresh &&
+//   abs(p1 - p0) <= inner_thresh && abs(q1 - q0) <= inner_thresh &&
+//   abs(q2 - q1) <= inner_thresh && abs(q3 - q2) <= inner_thresh &&
+//   OuterThreshold()
+inline uint16x4_t NeedsFilter8(const uint16x8_t abd_p0p1_q0q1,
+                               const uint16x8_t abd_p1p2_q1q2,
+                               const uint16x8_t abd_p2p3_q2q3,
+                               const uint16_t inner_thresh,
+                               const uint16x4_t outer_mask) {
+  const uint16x8_t a = vmaxq_u16(abd_p0p1_q0q1, abd_p1p2_q1q2);
+  const uint16x8_t b = vmaxq_u16(a, abd_p2p3_q2q3);
+  const uint16x8_t c = vcleq_u16(b, vdupq_n_u16(inner_thresh));
+  const uint16x4_t inner_mask = vand_u16(vget_low_u16(c), vget_high_u16(c));
+  return vand_u16(inner_mask, outer_mask);
+}
+
+// -----------------------------------------------------------------------------
+// FilterNMasks functions.
+
+inline void Filter4Masks(const uint16x8_t p0q0, const uint16x8_t p1q1,
+                         const uint16_t hev_thresh, const uint16x4_t outer_mask,
+                         const uint16_t inner_thresh,
+                         uint16x4_t* const hev_mask,
+                         uint16x4_t* const needs_filter4_mask) {
+  const uint16x8_t p0p1_q0q1 = vabdq_u16(p0q0, p1q1);
+  // This includes cases where NeedsFilter4() is not true and so Filter2() will
+  // not be applied.
+  const uint16x4_t hev_tmp_mask = Hev(p0p1_q0q1, hev_thresh);
+
+  *needs_filter4_mask = NeedsFilter4(p0p1_q0q1, inner_thresh, outer_mask);
+
+  // Filter2() will only be applied if both NeedsFilter4() and Hev() are true.
+  *hev_mask = vand_u16(hev_tmp_mask, *needs_filter4_mask);
+}
+
+// abs(p1 - p0) <= flat_thresh && abs(q1 - q0) <= flat_thresh &&
+//   abs(p2 - p0) <= flat_thresh && abs(q2 - q0) <= flat_thresh
+// |flat_thresh| == 4 for 10 bit decode.
+inline uint16x4_t IsFlat3(const uint16x8_t abd_p0p1_q0q1,
+                          const uint16x8_t abd_p0p2_q0q2) {
+  constexpr int flat_thresh = 1 << 2;
+  const uint16x8_t a = vmaxq_u16(abd_p0p1_q0q1, abd_p0p2_q0q2);
+  const uint16x8_t b = vcleq_u16(a, vdupq_n_u16(flat_thresh));
+  return vand_u16(vget_low_u16(b), vget_high_u16(b));
+}
+
+inline void Filter6Masks(const uint16x8_t p2q2, const uint16x8_t p1q1,
+                         const uint16x8_t p0q0, const uint16_t hev_thresh,
+                         const uint16x4_t outer_mask,
+                         const uint16_t inner_thresh,
+                         uint16x4_t* const needs_filter6_mask,
+                         uint16x4_t* const is_flat3_mask,
+                         uint16x4_t* const hev_mask) {
+  const uint16x8_t abd_p0p1_q0q1 = vabdq_u16(p0q0, p1q1);
+  *hev_mask = Hev(abd_p0p1_q0q1, hev_thresh);
+  *is_flat3_mask = IsFlat3(abd_p0p1_q0q1, vabdq_u16(p0q0, p2q2));
+  *needs_filter6_mask = NeedsFilter6(abd_p0p1_q0q1, vabdq_u16(p1q1, p2q2),
+                                     inner_thresh, outer_mask);
+}
+
+// IsFlat4 uses N=1, IsFlatOuter4 uses N=4.
+// abs(p[N] - p0) <= flat_thresh && abs(q[N] - q0) <= flat_thresh &&
+//   abs(p[N+1] - p0) <= flat_thresh && abs(q[N+1] - q0) <= flat_thresh &&
+//   abs(p[N+2] - p0) <= flat_thresh && abs(q[N+2] - q0) <= flat_thresh
+// |flat_thresh| == 4 for 10 bit decode.
+inline uint16x4_t IsFlat4(const uint16x8_t abd_pnp0_qnq0,
+                          const uint16x8_t abd_pn1p0_qn1q0,
+                          const uint16x8_t abd_pn2p0_qn2q0) {
+  constexpr int flat_thresh = 1 << 2;
+  const uint16x8_t a = vmaxq_u16(abd_pnp0_qnq0, abd_pn1p0_qn1q0);
+  const uint16x8_t b = vmaxq_u16(a, abd_pn2p0_qn2q0);
+  const uint16x8_t c = vcleq_u16(b, vdupq_n_u16(flat_thresh));
+  return vand_u16(vget_low_u16(c), vget_high_u16(c));
+}
+
+inline void Filter8Masks(const uint16x8_t p3q3, const uint16x8_t p2q2,
+                         const uint16x8_t p1q1, const uint16x8_t p0q0,
+                         const uint16_t hev_thresh, const uint16x4_t outer_mask,
+                         const uint16_t inner_thresh,
+                         uint16x4_t* const needs_filter8_mask,
+                         uint16x4_t* const is_flat4_mask,
+                         uint16x4_t* const hev_mask) {
+  const uint16x8_t abd_p0p1_q0q1 = vabdq_u16(p0q0, p1q1);
+  *hev_mask = Hev(abd_p0p1_q0q1, hev_thresh);
+  const uint16x4_t is_flat4 =
+      IsFlat4(abd_p0p1_q0q1, vabdq_u16(p0q0, p2q2), vabdq_u16(p0q0, p3q3));
+  *needs_filter8_mask =
+      NeedsFilter8(abd_p0p1_q0q1, vabdq_u16(p1q1, p2q2), vabdq_u16(p2q2, p3q3),
+                   inner_thresh, outer_mask);
+  // |is_flat4_mask| is used to decide where to use the result of Filter8.
+  // In rare cases, |is_flat4| can be true where |needs_filter8_mask| is false,
+  // overriding the question of whether to use Filter8. Because Filter4 doesn't
+  // apply to p2q2, |is_flat4_mask| chooses directly between Filter8 and the
+  // source value. To be correct, the mask must account for this override.
+  *is_flat4_mask = vand_u16(is_flat4, *needs_filter8_mask);
+}
+
+// -----------------------------------------------------------------------------
+// FilterN functions.
+
+// Calculate Filter4() or Filter2() based on |hev_mask|.
+inline void Filter4(const uint16x8_t p0q0, const uint16x8_t p0q1,
+                    const uint16x8_t p1q1, const uint16x4_t hev_mask,
+                    uint16x8_t* const p1q1_result,
+                    uint16x8_t* const p0q0_result) {
+  const uint16x8_t q0p1 = vextq_u16(p0q0, p1q1, 4);
+  // a = 3 * (q0 - p0) + Clip3(p1 - q1, min_signed_val, max_signed_val);
+  // q0mp0 means "q0 minus p0".
+  const int16x8_t q0mp0_p1mq1 = vreinterpretq_s16_u16(vsubq_u16(q0p1, p0q1));
+  const int16x4_t q0mp0_3 = vmul_n_s16(vget_low_s16(q0mp0_p1mq1), 3);
+
+  // If this is for Filter2() then include |p1mq1|. Otherwise zero it.
+  const int16x4_t min_signed_pixel = vdup_n_s16(-(1 << (9 /*bitdepth-1*/)));
+  const int16x4_t max_signed_pixel = vdup_n_s16((1 << (9 /*bitdepth-1*/)) - 1);
+  const int16x4_t p1mq1 = vget_high_s16(q0mp0_p1mq1);
+  const int16x4_t p1mq1_saturated =
+      Clip3S16(p1mq1, min_signed_pixel, max_signed_pixel);
+  const int16x4_t hev_option =
+      vand_s16(vreinterpret_s16_u16(hev_mask), p1mq1_saturated);
+
+  const int16x4_t a = vadd_s16(q0mp0_3, hev_option);
+
+  // NOTE(review): Some steps below may be unnecessary at 10bpp; they appear to
+  // carry over tricks needed because 8x8 is the smallest 8bpp vector.
+
+  // We cannot shift with rounding because the clamp comes *before* the shift:
+  //   a1 = Clip3(a + 4, min_signed_val, max_signed_val) >> 3;
+  //   a2 = Clip3(a + 3, min_signed_val, max_signed_val) >> 3;
+  const int16x4_t plus_four =
+      Clip3S16(vadd_s16(a, vdup_n_s16(4)), min_signed_pixel, max_signed_pixel);
+  const int16x4_t plus_three =
+      Clip3S16(vadd_s16(a, vdup_n_s16(3)), min_signed_pixel, max_signed_pixel);
+  const int16x4_t a1 = vshr_n_s16(plus_four, 3);
+  const int16x4_t a2 = vshr_n_s16(plus_three, 3);
+
+  // a3 = (a1 + 1) >> 1;
+  const int16x4_t a3 = vrshr_n_s16(a1, 1);
+
+  const int16x8_t a3_ma3 = vcombine_s16(a3, vneg_s16(a3));
+  const int16x8_t p1q1_a3 = vaddq_s16(vreinterpretq_s16_u16(p1q1), a3_ma3);
+
+  // Need to shift the second term or we end up with a2_ma2.
+  const int16x8_t a2_ma1 = vcombine_s16(a2, vneg_s16(a1));
+  const int16x8_t p0q0_a = vaddq_s16(vreinterpretq_s16_u16(p0q0), a2_ma1);
+  *p1q1_result = ConvertToUnsignedPixelU16(p1q1_a3, kBitdepth10);
+  *p0q0_result = ConvertToUnsignedPixelU16(p0q0_a, kBitdepth10);
+}
+
+void Horizontal4_NEON(void* const dest, const ptrdiff_t stride,
+                      int outer_thresh, int inner_thresh, int hev_thresh) {
+  auto* const dst = static_cast<uint8_t*>(dest);
+  auto* const dst_p1 = reinterpret_cast<uint16_t*>(dst - 2 * stride);
+  auto* const dst_p0 = reinterpret_cast<uint16_t*>(dst - stride);
+  auto* const dst_q0 = reinterpret_cast<uint16_t*>(dst);
+  auto* const dst_q1 = reinterpret_cast<uint16_t*>(dst + stride);
+
+  const uint16x4_t src[4] = {vld1_u16(dst_p1), vld1_u16(dst_p0),
+                             vld1_u16(dst_q0), vld1_u16(dst_q1)};
+
+  // Adjust thresholds to bitdepth.
+  outer_thresh <<= 2;
+  inner_thresh <<= 2;
+  hev_thresh <<= 2;
+  const uint16x4_t outer_mask =
+      OuterThreshold(src[0], src[1], src[2], src[3], outer_thresh);
+  uint16x4_t hev_mask;
+  uint16x4_t needs_filter4_mask;
+  const uint16x8_t p0q0 = vcombine_u16(src[1], src[2]);
+  const uint16x8_t p1q1 = vcombine_u16(src[0], src[3]);
+  Filter4Masks(p0q0, p1q1, hev_thresh, outer_mask, inner_thresh, &hev_mask,
+               &needs_filter4_mask);
+
+#if defined(__aarch64__)
+  // This provides a good speedup for the unit test, but may not come up often
+  // enough to warrant it.
+  if (vaddv_u16(needs_filter4_mask) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#else   // !defined(__aarch64__)
+  const uint64x1_t needs_filter4_mask64 =
+      vreinterpret_u64_u16(needs_filter4_mask);
+  if (vget_lane_u64(needs_filter4_mask64, 0) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#endif  // defined(__aarch64__)
+
+  // Copy the masks to the high bits for packed comparisons later.
+  const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask);
+  const uint16x8_t needs_filter4_mask_8 =
+      vcombine_u16(needs_filter4_mask, needs_filter4_mask);
+
+  uint16x8_t f_p1q1;
+  uint16x8_t f_p0q0;
+  const uint16x8_t p0q1 = vcombine_u16(src[1], src[3]);
+  Filter4(p0q0, p0q1, p1q1, hev_mask, &f_p1q1, &f_p0q0);
+
+  // Already integrated the Hev mask when calculating the filtered values.
+  const uint16x8_t p0q0_output = vbslq_u16(needs_filter4_mask_8, f_p0q0, p0q0);
+
+  // p1/q1 are unmodified if only Hev() is true. This works because it was and'd
+  // with |needs_filter4_mask| previously.
+  const uint16x8_t p1q1_mask = veorq_u16(hev_mask_8, needs_filter4_mask_8);
+  const uint16x8_t p1q1_output = vbslq_u16(p1q1_mask, f_p1q1, p1q1);
+
+  vst1_u16(dst_p1, vget_low_u16(p1q1_output));
+  vst1_u16(dst_p0, vget_low_u16(p0q0_output));
+  vst1_u16(dst_q0, vget_high_u16(p0q0_output));
+  vst1_u16(dst_q1, vget_high_u16(p1q1_output));
+}
+
+void Vertical4_NEON(void* const dest, const ptrdiff_t stride, int outer_thresh,
+                    int inner_thresh, int hev_thresh) {
+  // Offset by 2 uint16_t values to load from first p1 position.
+  auto* dst = static_cast<uint8_t*>(dest) - 4;
+  auto* dst_p1 = reinterpret_cast<uint16_t*>(dst);
+  auto* dst_p0 = reinterpret_cast<uint16_t*>(dst + stride);
+  auto* dst_q0 = reinterpret_cast<uint16_t*>(dst + stride * 2);
+  auto* dst_q1 = reinterpret_cast<uint16_t*>(dst + stride * 3);
+
+  uint16x4_t src[4] = {vld1_u16(dst_p1), vld1_u16(dst_p0), vld1_u16(dst_q0),
+                       vld1_u16(dst_q1)};
+  Transpose4x4(src);
+
+  // Adjust thresholds to bitdepth.
+  outer_thresh <<= 2;
+  inner_thresh <<= 2;
+  hev_thresh <<= 2;
+  const uint16x4_t outer_mask =
+      OuterThreshold(src[0], src[1], src[2], src[3], outer_thresh);
+  uint16x4_t hev_mask;
+  uint16x4_t needs_filter4_mask;
+  const uint16x8_t p0q0 = vcombine_u16(src[1], src[2]);
+  const uint16x8_t p1q1 = vcombine_u16(src[0], src[3]);
+  Filter4Masks(p0q0, p1q1, hev_thresh, outer_mask, inner_thresh, &hev_mask,
+               &needs_filter4_mask);
+
+#if defined(__aarch64__)
+  // This provides a good speedup for the unit test. Not sure how applicable it
+  // is to valid streams though.
+  // The armv7 branch below performs the same early-out by reinterpreting the
+  // mask as a single 64-bit lane.
+  if (vaddv_u16(needs_filter4_mask) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#else   // !defined(__aarch64__)
+  const uint64x1_t needs_filter4_mask64 =
+      vreinterpret_u64_u16(needs_filter4_mask);
+  if (vget_lane_u64(needs_filter4_mask64, 0) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#endif  // defined(__aarch64__)
+
+  // Copy the masks to the high bits for packed comparisons later.
+  const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask);
+  const uint16x8_t needs_filter4_mask_8 =
+      vcombine_u16(needs_filter4_mask, needs_filter4_mask);
+
+  uint16x8_t f_p1q1;
+  uint16x8_t f_p0q0;
+  const uint16x8_t p0q1 = vcombine_u16(src[1], src[3]);
+  Filter4(p0q0, p0q1, p1q1, hev_mask, &f_p1q1, &f_p0q0);
+
+  // Already integrated the Hev mask when calculating the filtered values.
+  const uint16x8_t p0q0_output = vbslq_u16(needs_filter4_mask_8, f_p0q0, p0q0);
+
+  // p1/q1 are unmodified if only Hev() is true. This works because it was and'd
+  // with |needs_filter4_mask| previously.
+  const uint16x8_t p1q1_mask = veorq_u16(hev_mask_8, needs_filter4_mask_8);
+  const uint16x8_t p1q1_output = vbslq_u16(p1q1_mask, f_p1q1, p1q1);
+
+  uint16x4_t output[4] = {
+      vget_low_u16(p1q1_output),
+      vget_low_u16(p0q0_output),
+      vget_high_u16(p0q0_output),
+      vget_high_u16(p1q1_output),
+  };
+  Transpose4x4(output);
+
+  vst1_u16(dst_p1, output[0]);
+  vst1_u16(dst_p0, output[1]);
+  vst1_u16(dst_q0, output[2]);
+  vst1_u16(dst_q1, output[3]);
+}
+
+// Computes the 6-tap (flat) filter outputs for both sides of the edge at once.
+// Every uint16x8_t packs p-side values in its low half and q-side values in
+// its high half; Transpose64() provides the half-swapped operand (q0p0 from
+// p0q0, q1p1 from p1q1). Low-half results (high half is the mirror image):
+//   p1' = (3 * p2 + 2 * p1 + 2 * p0 + q0 + 4) >> 3
+//   p0' = (p2 + 2 * p1 + 2 * p0 + 2 * q0 + q1 + 4) >> 3
+inline void Filter6(const uint16x8_t p2q2, const uint16x8_t p1q1,
+                    const uint16x8_t p0q0, uint16x8_t* const p1q1_output,
+                    uint16x8_t* const p0q0_output) {
+  const uint16x8_t q0p0 = Transpose64(p0q0);
+  const uint16x8_t q1p1 = Transpose64(p1q1);
+
+  // acc = 2 * (p2 + p1 + p0) + p2 + q0 = 3 * p2 + 2 * p1 + 2 * p0 + q0.
+  // Grouping the three doubled terms lets a single shift do the doubling.
+  const uint16x8_t doubled_terms = vaddq_u16(vaddq_u16(p2q2, p1q1), p0q0);
+  uint16x8_t acc =
+      vaddq_u16(vshlq_n_u16(doubled_terms, 1), vaddq_u16(p2q2, q0p0));
+  // Rounding shift: (acc + 4) >> 3.
+  *p1q1_output = vrshrq_n_u16(acc, 3);
+
+  // Slide the window toward the edge: remove 2 * p2, add q0 and q1.
+  acc = vsubq_u16(acc, vshlq_n_u16(p2q2, 1));
+  acc = vaddq_u16(acc, vaddq_u16(q0p0, q1p1));
+  *p0q0_output = vrshrq_n_u16(acc, 3);
+}
+
+// Applies the 6-tap loop filter across a horizontal edge for 4 columns.
+// |dest| points at the first q0 pixel and |stride| is the byte distance
+// between rows of uint16_t pixels. Rows p2..q2 are read; only rows p1, p0,
+// q0 and q1 are written back.
+void Horizontal6_NEON(void* const dest, const ptrdiff_t stride,
+                      int outer_thresh, int inner_thresh, int hev_thresh) {
+  auto* const dst = static_cast<uint8_t*>(dest);
+  auto* const dst_p2 = reinterpret_cast<uint16_t*>(dst - 3 * stride);
+  auto* const dst_p1 = reinterpret_cast<uint16_t*>(dst - 2 * stride);
+  auto* const dst_p0 = reinterpret_cast<uint16_t*>(dst - stride);
+  auto* const dst_q0 = reinterpret_cast<uint16_t*>(dst);
+  auto* const dst_q1 = reinterpret_cast<uint16_t*>(dst + stride);
+  auto* const dst_q2 = reinterpret_cast<uint16_t*>(dst + 2 * stride);
+
+  const uint16x4_t src[6] = {vld1_u16(dst_p2), vld1_u16(dst_p1),
+                             vld1_u16(dst_p0), vld1_u16(dst_q0),
+                             vld1_u16(dst_q1), vld1_u16(dst_q2)};
+
+  // Adjust thresholds to bitdepth (scaled by 4 from the 8-bit values).
+  outer_thresh <<= 2;
+  inner_thresh <<= 2;
+  hev_thresh <<= 2;
+  const uint16x4_t outer_mask =
+      OuterThreshold(src[1], src[2], src[3], src[4], outer_thresh);
+  uint16x4_t hev_mask;
+  uint16x4_t needs_filter_mask;
+  uint16x4_t is_flat3_mask;
+  const uint16x8_t p0q0 = vcombine_u16(src[2], src[3]);
+  const uint16x8_t p1q1 = vcombine_u16(src[1], src[4]);
+  const uint16x8_t p2q2 = vcombine_u16(src[0], src[5]);
+  Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh,
+               &needs_filter_mask, &is_flat3_mask, &hev_mask);
+
+#if defined(__aarch64__)
+  if (vaddv_u16(needs_filter_mask) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#else   // !defined(__aarch64__)
+  // This might be faster than vaddv (latency 3) because mov to general register
+  // has latency 2.
+  const uint64x1_t needs_filter_mask64 =
+      vreinterpret_u64_u16(needs_filter_mask);
+  if (vget_lane_u64(needs_filter_mask64, 0) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#endif  // defined(__aarch64__)
+
+  // Copy the masks to the high bits for packed comparisons later.
+  const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask);
+  const uint16x8_t needs_filter_mask_8 =
+      vcombine_u16(needs_filter_mask, needs_filter_mask);
+
+  uint16x8_t f4_p1q1;
+  uint16x8_t f4_p0q0;
+  // ZIP1 p0q0, p1q1 may perform better here.
+  const uint16x8_t p0q1 = vcombine_u16(src[2], src[4]);
+  Filter4(p0q0, p0q1, p1q1, hev_mask, &f4_p1q1, &f4_p0q0);
+  // Where Hev() is true, p1/q1 keep their original values.
+  f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1);
+
+  uint16x8_t p0q0_output, p1q1_output;
+  // Because we did not return after testing |needs_filter_mask| we know it is
+  // nonzero. |is_flat3_mask| controls whether the needed filter is Filter4 or
+  // Filter6. Therefore if it is false when |needs_filter_mask| is true, Filter6
+  // output is not used.
+  const uint64x1_t need_filter6 = vreinterpret_u64_u16(is_flat3_mask);
+  if (vget_lane_u64(need_filter6, 0) == 0) {
+    // Filter6() does not apply, but Filter4() applies to one or more values.
+    p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1);
+    p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0);
+  } else {
+    // The widened flat mask and Filter6 outputs are only needed here.
+    const uint16x8_t is_flat3_mask_8 =
+        vcombine_u16(is_flat3_mask, is_flat3_mask);
+    uint16x8_t f6_p1q1, f6_p0q0;
+    Filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0);
+    p1q1_output = vbslq_u16(is_flat3_mask_8, f6_p1q1, f4_p1q1);
+    p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1);
+    p0q0_output = vbslq_u16(is_flat3_mask_8, f6_p0q0, f4_p0q0);
+    p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0);
+  }
+
+  vst1_u16(dst_p1, vget_low_u16(p1q1_output));
+  vst1_u16(dst_p0, vget_low_u16(p0q0_output));
+  vst1_u16(dst_q0, vget_high_u16(p0q0_output));
+  vst1_u16(dst_q1, vget_high_u16(p1q1_output));
+}
+
+// Applies the 6-tap loop filter across a vertical edge for 4 rows.
+// |dest| points at the q0 pixel of the first row and |stride| is the byte
+// distance between rows of uint16_t pixels. Each row is read starting at p2
+// (8 values, i.e. 2 past q2); only p1, p0, q0 and q1 are written back.
+void Vertical6_NEON(void* const dest, const ptrdiff_t stride, int outer_thresh,
+                    int inner_thresh, int hev_thresh) {
+  // Left side of the filter window.
+  auto* const dst = static_cast<uint8_t*>(dest) - 3 * sizeof(uint16_t);
+  auto* const dst_0 = reinterpret_cast<uint16_t*>(dst);
+  auto* const dst_1 = reinterpret_cast<uint16_t*>(dst + stride);
+  auto* const dst_2 = reinterpret_cast<uint16_t*>(dst + 2 * stride);
+  auto* const dst_3 = reinterpret_cast<uint16_t*>(dst + 3 * stride);
+
+  // Overread by 2 values. These overreads become the high halves of src_raw[2]
+  // and src_raw[3] after transpose.
+  uint16x8_t src_raw[4] = {vld1q_u16(dst_0), vld1q_u16(dst_1), vld1q_u16(dst_2),
+                           vld1q_u16(dst_3)};
+  Transpose4x8(src_raw);
+  // p2, p1, p0, q0, q1, q2
+  const uint16x4_t src[6] = {
+      vget_low_u16(src_raw[0]),  vget_low_u16(src_raw[1]),
+      vget_low_u16(src_raw[2]),  vget_low_u16(src_raw[3]),
+      vget_high_u16(src_raw[0]), vget_high_u16(src_raw[1]),
+  };
+
+  // Adjust thresholds to bitdepth (scaled by 4 from the 8-bit values).
+  outer_thresh <<= 2;
+  inner_thresh <<= 2;
+  hev_thresh <<= 2;
+  const uint16x4_t outer_mask =
+      OuterThreshold(src[1], src[2], src[3], src[4], outer_thresh);
+  uint16x4_t hev_mask;
+  uint16x4_t needs_filter_mask;
+  uint16x4_t is_flat3_mask;
+  const uint16x8_t p0q0 = vcombine_u16(src[2], src[3]);
+  const uint16x8_t p1q1 = vcombine_u16(src[1], src[4]);
+  const uint16x8_t p2q2 = vcombine_u16(src[0], src[5]);
+  Filter6Masks(p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh,
+               &needs_filter_mask, &is_flat3_mask, &hev_mask);
+
+#if defined(__aarch64__)
+  if (vaddv_u16(needs_filter_mask) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#else   // !defined(__aarch64__)
+  // This might be faster than vaddv (latency 3) because mov to general register
+  // has latency 2.
+  const uint64x1_t needs_filter_mask64 =
+      vreinterpret_u64_u16(needs_filter_mask);
+  if (vget_lane_u64(needs_filter_mask64, 0) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#endif  // defined(__aarch64__)
+
+  // Copy the masks to the high bits for packed comparisons later.
+  const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask);
+  const uint16x8_t needs_filter_mask_8 =
+      vcombine_u16(needs_filter_mask, needs_filter_mask);
+
+  uint16x8_t f4_p1q1;
+  uint16x8_t f4_p0q0;
+  // ZIP1 p0q0, p1q1 may perform better here.
+  const uint16x8_t p0q1 = vcombine_u16(src[2], src[4]);
+  Filter4(p0q0, p0q1, p1q1, hev_mask, &f4_p1q1, &f4_p0q0);
+  // Where Hev() is true, p1/q1 keep their original values.
+  f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1);
+
+  uint16x8_t p0q0_output, p1q1_output;
+  // Because we did not return after testing |needs_filter_mask| we know it is
+  // nonzero. |is_flat3_mask| controls whether the needed filter is Filter4 or
+  // Filter6. Therefore if it is false when |needs_filter_mask| is true, Filter6
+  // output is not used.
+  const uint64x1_t need_filter6 = vreinterpret_u64_u16(is_flat3_mask);
+  if (vget_lane_u64(need_filter6, 0) == 0) {
+    // Filter6() does not apply, but Filter4() applies to one or more values.
+    p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1);
+    p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0);
+  } else {
+    // The widened flat mask and Filter6 outputs are only needed here.
+    const uint16x8_t is_flat3_mask_8 =
+        vcombine_u16(is_flat3_mask, is_flat3_mask);
+    uint16x8_t f6_p1q1, f6_p0q0;
+    Filter6(p2q2, p1q1, p0q0, &f6_p1q1, &f6_p0q0);
+    p1q1_output = vbslq_u16(is_flat3_mask_8, f6_p1q1, f4_p1q1);
+    p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1);
+    p0q0_output = vbslq_u16(is_flat3_mask_8, f6_p0q0, f4_p0q0);
+    p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0);
+  }
+
+  uint16x4_t output[4] = {
+      vget_low_u16(p1q1_output),
+      vget_low_u16(p0q0_output),
+      vget_high_u16(p0q0_output),
+      vget_high_u16(p1q1_output),
+  };
+  Transpose4x4(output);
+
+  // dst_n starts at p2, so adjust to p1.
+  vst1_u16(dst_0 + 1, output[0]);
+  vst1_u16(dst_1 + 1, output[1]);
+  vst1_u16(dst_2 + 1, output[2]);
+  vst1_u16(dst_3 + 1, output[3]);
+}
+
+// Computes the 8-tap (flat) filter outputs for both sides of the edge at once.
+// Every uint16x8_t packs p-side values in its low half and q-side values in
+// its high half; Transpose64() provides the half-swapped operand. Low-half
+// results (high half is the mirror image):
+//   p2' = (3 * p3 + 2 * p2 + p1 + p0 + q0 + 4) >> 3
+//   p1' = (2 * p3 + p2 + 2 * p1 + p0 + q0 + q1 + 4) >> 3
+//   p0' = (p3 + p2 + p1 + 2 * p0 + q0 + q1 + q2 + 4) >> 3
+// One running sum is reused: each step drops two taps and adds two taps.
+inline void Filter8(const uint16x8_t p3q3, const uint16x8_t p2q2,
+                    const uint16x8_t p1q1, const uint16x8_t p0q0,
+                    uint16x8_t* const p2q2_output,
+                    uint16x8_t* const p1q1_output,
+                    uint16x8_t* const p0q0_output) {
+  const uint16x8_t q0p0 = Transpose64(p0q0);
+  const uint16x8_t q1p1 = Transpose64(p1q1);
+  const uint16x8_t q2p2 = Transpose64(p2q2);
+
+  // acc = 2 * (p3 + p2) + (p1 + p0) + (p3 + q0)
+  const uint16x8_t outer_pair = vaddq_u16(p3q3, p2q2);
+  uint16x8_t acc =
+      vaddq_u16(vshlq_n_u16(outer_pair, 1), vaddq_u16(p0q0, p1q1));
+  acc = vaddq_u16(acc, vaddq_u16(p3q3, q0p0));
+  // Rounding shift: (acc + 4) >> 3.
+  *p2q2_output = vrshrq_n_u16(acc, 3);
+
+  // Drop p3 and p2; add p1 and q1.
+  acc = vsubq_u16(acc, outer_pair);
+  acc = vaddq_u16(acc, vaddq_u16(p1q1, q1p1));
+  *p1q1_output = vrshrq_n_u16(acc, 3);
+
+  // Drop p3 and p1; add p0 and q2.
+  acc = vsubq_u16(acc, vaddq_u16(p3q3, p1q1));
+  acc = vaddq_u16(acc, vaddq_u16(p0q0, q2p2));
+  *p0q0_output = vrshrq_n_u16(acc, 3);
+}
+
+// Applies the 8-tap loop filter across a horizontal edge for 4 columns.
+// |dest| points at the first q0 pixel and |stride| is the byte distance
+// between rows of uint16_t pixels. Rows p3..q3 are read; rows p2..q2 may be
+// written. Filter4 is the default; Filter8 replaces it where the block is
+// flat, and columns that need no filtering keep their source pixels.
+void Horizontal8_NEON(void* const dest, const ptrdiff_t stride,
+                      int outer_thresh, int inner_thresh, int hev_thresh) {
+  auto* const origin = static_cast<uint8_t*>(dest);
+  // Row r in [0, 8) is p3, p2, p1, p0, q0, q1, q2, q3 in that order.
+  const auto row = [origin, stride](const int r) {
+    return reinterpret_cast<uint16_t*>(origin + (r - 4) * stride);
+  };
+  uint16x4_t rows[8];
+  for (int r = 0; r < 8; ++r) rows[r] = vld1_u16(row(r));
+
+  // Scale the thresholds to the pixel bitdepth.
+  outer_thresh <<= 2;
+  inner_thresh <<= 2;
+  hev_thresh <<= 2;
+
+  const uint16x8_t p3q3 = vcombine_u16(rows[0], rows[7]);
+  const uint16x8_t p2q2 = vcombine_u16(rows[1], rows[6]);
+  const uint16x8_t p1q1 = vcombine_u16(rows[2], rows[5]);
+  const uint16x8_t p0q0 = vcombine_u16(rows[3], rows[4]);
+
+  const uint16x4_t outer_mask =
+      OuterThreshold(rows[2], rows[3], rows[4], rows[5], outer_thresh);
+  uint16x4_t hev_mask;
+  uint16x4_t needs_filter_mask;
+  uint16x4_t is_flat4_mask;
+  Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh,
+               &needs_filter_mask, &is_flat4_mask, &hev_mask);
+
+  // Bail out early when no column is selected for filtering.
+#if defined(__aarch64__)
+  if (vaddv_u16(needs_filter_mask) == 0) return;
+#else   // !defined(__aarch64__)
+  // Moving the 64-bit lane to a general register avoids a horizontal add.
+  if (vget_lane_u64(vreinterpret_u64_u16(needs_filter_mask), 0) == 0) return;
+#endif  // defined(__aarch64__)
+
+  // Duplicate the 4-lane masks so they cover both p and q halves.
+  const uint16x8_t needs_filter_mask_8 =
+      vcombine_u16(needs_filter_mask, needs_filter_mask);
+
+  // Filter4 result; p1/q1 keep their source values wherever Hev() is true.
+  uint16x8_t f4_p1q1;
+  uint16x8_t f4_p0q0;
+  const uint16x8_t p0q1 = vcombine_u16(rows[3], rows[5]);
+  Filter4(p0q0, p0q1, p1q1, hev_mask, &f4_p1q1, &f4_p0q0);
+  f4_p1q1 = vbslq_u16(vcombine_u16(hev_mask, hev_mask), p1q1, f4_p1q1);
+
+  // Start from the Filter4-only result and upgrade flat columns to Filter8.
+  uint16x8_t out_p2q2 = p2q2;
+  uint16x8_t out_p1q1 = f4_p1q1;
+  uint16x8_t out_p0q0 = f4_p0q0;
+  if (vget_lane_u64(vreinterpret_u64_u16(is_flat4_mask), 0) != 0) {
+    const uint16x8_t is_flat4_mask_8 =
+        vcombine_u16(is_flat4_mask, is_flat4_mask);
+    uint16x8_t f8_p2q2, f8_p1q1, f8_p0q0;
+    Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0);
+    out_p2q2 = vbslq_u16(is_flat4_mask_8, f8_p2q2, p2q2);
+    out_p1q1 = vbslq_u16(is_flat4_mask_8, f8_p1q1, f4_p1q1);
+    out_p0q0 = vbslq_u16(is_flat4_mask_8, f8_p0q0, f4_p0q0);
+  }
+  // Columns outside |needs_filter_mask| keep their source p1/p0/q0/q1.
+  out_p1q1 = vbslq_u16(needs_filter_mask_8, out_p1q1, p1q1);
+  out_p0q0 = vbslq_u16(needs_filter_mask_8, out_p0q0, p0q0);
+
+  // Rows 1..6 are p2, p1, p0, q0, q1, q2.
+  vst1_u16(row(1), vget_low_u16(out_p2q2));
+  vst1_u16(row(2), vget_low_u16(out_p1q1));
+  vst1_u16(row(3), vget_low_u16(out_p0q0));
+  vst1_u16(row(4), vget_high_u16(out_p0q0));
+  vst1_u16(row(5), vget_high_u16(out_p1q1));
+  vst1_u16(row(6), vget_high_u16(out_p2q2));
+}
+
+// Returns |a| with the four lanes of its low half in reverse order; the high
+// half is passed through unchanged.
+inline uint16x8_t ReverseLowHalf(const uint16x8_t a) {
+  const uint16x4_t low_reversed = vrev64_u16(vget_low_u16(a));
+  const uint16x4_t high_kept = vget_high_u16(a);
+  return vcombine_u16(low_reversed, high_kept);
+}
+
+// Applies the 8-tap loop filter across a vertical edge for 4 rows.
+// |dest| points at the q0 pixel of the first row and |stride| is the byte
+// distance between rows of uint16_t pixels. Each row's 8 pixels p3..q3 are
+// read and written back; p3 and q3 are stored unchanged.
+void Vertical8_NEON(void* const dest, const ptrdiff_t stride, int outer_thresh,
+                    int inner_thresh, int hev_thresh) {
+  auto* const dst = static_cast<uint8_t*>(dest) - 4 * sizeof(uint16_t);
+  auto* const dst_0 = reinterpret_cast<uint16_t*>(dst);
+  auto* const dst_1 = reinterpret_cast<uint16_t*>(dst + stride);
+  auto* const dst_2 = reinterpret_cast<uint16_t*>(dst + 2 * stride);
+  auto* const dst_3 = reinterpret_cast<uint16_t*>(dst + 3 * stride);
+
+  // Before the transpose, src[n] holds p3 p2 p1 p0 q0 q1 q2 q3 for row n.
+  uint16x8_t src[4] = {vld1q_u16(dst_0), vld1q_u16(dst_1), vld1q_u16(dst_2),
+                       vld1q_u16(dst_3)};
+
+  // LoopFilterTranspose4x8() regroups the columns so that the p and q values
+  // at equal distance from the edge share one vector (p in the low half):
+  // src[0] = p0q0
+  // src[1] = p1q1
+  // src[2] = p2q2
+  // src[3] = p3q3
+  LoopFilterTranspose4x8(src);
+
+  // Adjust thresholds to bitdepth (scaled by 4 from the 8-bit values).
+  outer_thresh <<= 2;
+  inner_thresh <<= 2;
+  hev_thresh <<= 2;
+  const uint16x4_t outer_mask = OuterThreshold(
+      vget_low_u16(src[1]), vget_low_u16(src[0]), vget_high_u16(src[0]),
+      vget_high_u16(src[1]), outer_thresh);
+  uint16x4_t hev_mask;
+  uint16x4_t needs_filter_mask;
+  uint16x4_t is_flat4_mask;
+  const uint16x8_t p0q0 = src[0];
+  const uint16x8_t p1q1 = src[1];
+  const uint16x8_t p2q2 = src[2];
+  const uint16x8_t p3q3 = src[3];
+  Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh,
+               &needs_filter_mask, &is_flat4_mask, &hev_mask);
+
+#if defined(__aarch64__)
+  if (vaddv_u16(needs_filter_mask) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#else   // !defined(__aarch64__)
+  // This might be faster than vaddv (latency 3) because mov to general register
+  // has latency 2.
+  const uint64x1_t needs_filter_mask64 =
+      vreinterpret_u64_u16(needs_filter_mask);
+  if (vget_lane_u64(needs_filter_mask64, 0) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#endif  // defined(__aarch64__)
+
+  // Copy the masks to the high bits for packed comparisons later.
+  const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask);
+  const uint16x8_t needs_filter_mask_8 =
+      vcombine_u16(needs_filter_mask, needs_filter_mask);
+
+  uint16x8_t f4_p1q1;
+  uint16x8_t f4_p0q0;
+  const uint16x8_t p0q1 = vcombine_u16(vget_low_u16(p0q0), vget_high_u16(p1q1));
+  Filter4(p0q0, p0q1, p1q1, hev_mask, &f4_p1q1, &f4_p0q0);
+  // Where Hev() is true, p1/q1 keep their original values.
+  f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1);
+
+  uint16x8_t p0q0_output, p1q1_output, p2q2_output;
+  // Because we did not return after testing |needs_filter_mask| we know it is
+  // nonzero. |is_flat4_mask| controls whether the needed filter is Filter4 or
+  // Filter8. Therefore if it is false when |needs_filter_mask| is true, Filter8
+  // output is not used.
+  const uint64x1_t need_filter8 = vreinterpret_u64_u16(is_flat4_mask);
+  if (vget_lane_u64(need_filter8, 0) == 0) {
+    // Filter8() does not apply, but Filter4() applies to one or more values.
+    p2q2_output = p2q2;
+    p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1);
+    p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0);
+  } else {
+    const uint16x8_t is_flat4_mask_8 =
+        vcombine_u16(is_flat4_mask, is_flat4_mask);
+    uint16x8_t f8_p2q2, f8_p1q1, f8_p0q0;
+    Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0);
+    p2q2_output = vbslq_u16(is_flat4_mask_8, f8_p2q2, p2q2);
+    p1q1_output = vbslq_u16(is_flat4_mask_8, f8_p1q1, f4_p1q1);
+    p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1);
+    p0q0_output = vbslq_u16(is_flat4_mask_8, f8_p0q0, f4_p0q0);
+    p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0);
+  }
+
+  // p3q3 is passed through unchanged so whole 8-pixel rows can be stored.
+  uint16x8_t output[4] = {p0q0_output, p1q1_output, p2q2_output, p3q3};
+  // After transpose, |output| will contain rows of the form:
+  // p0 p1 p2 p3 q0 q1 q2 q3
+  Transpose4x8(output);
+
+  // Reverse p values to produce original order:
+  // p3 p2 p1 p0 q0 q1 q2 q3
+  vst1q_u16(dst_0, ReverseLowHalf(output[0]));
+  vst1q_u16(dst_1, ReverseLowHalf(output[1]));
+  vst1q_u16(dst_2, ReverseLowHalf(output[2]));
+  vst1q_u16(dst_3, ReverseLowHalf(output[3]));
+}
+
+// Computes the 14-tap (wide) filter outputs for p5..p0 and q5..q0 at once.
+// Every uint16x8_t packs p-side values in its low half and q-side values in
+// its high half; Transpose64() supplies the half-swapped operand (e.g. q0p0
+// from p0q0), as the local names indicate. Low-half results:
+//   p5' = (7*p6 + 2*p5 + 2*p4 + p3 + p2 + p1 + p0 + q0 + 8) >> 4
+//   p4' = (5*p6 + 2*p5 + 2*p4 + 2*p3 + p2 + p1 + p0 + q0 + q1 + 8) >> 4
+//   p3' = (4*p6 + p5 + 2*p4 + 2*p3 + 2*p2 + p1 + p0 + q0 + q1 + q2 + 8) >> 4
+//   p2' = (3*p6 + p5 + p4 + 2*p3 + 2*p2 + 2*p1 + p0 + q0 + .. + q3 + 8) >> 4
+//   p1' = (2*p6 + p5 + p4 + p3 + 2*p2 + 2*p1 + 2*p0 + q0 + .. + q4 + 8) >> 4
+//   p0' = (p6 + p5 + .. + p2 + 2*p1 + 2*p0 + 2*q0 + q1 + .. + q5 + 8) >> 4
+// The high half (q side) mirrors these. A single running sum is reused: each
+// successive output removes two taps from it and adds two new ones.
+// NOTE(review): the weights sum to 16, so assuming 10-bit pixels (max 1023)
+// every final sum is at most 16 * 1023 and fits in uint16_t. Intermediate
+// subtractions wrap modulo 2^16 but the final value is still exact.
+inline void Filter14(const uint16x8_t p6q6, const uint16x8_t p5q5,
+                     const uint16x8_t p4q4, const uint16x8_t p3q3,
+                     const uint16x8_t p2q2, const uint16x8_t p1q1,
+                     const uint16x8_t p0q0, uint16x8_t* const p5q5_output,
+                     uint16x8_t* const p4q4_output,
+                     uint16x8_t* const p3q3_output,
+                     uint16x8_t* const p2q2_output,
+                     uint16x8_t* const p1q1_output,
+                     uint16x8_t* const p0q0_output) {
+  // Sum p5 and q5 output from opposite directions.
+  // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+  //      ^^^^^^^^
+  // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+  //                                                     ^^^^^^^^
+  // 7 * x is formed as (x << 3) - x.
+  const uint16x8_t p6q6_x7 = vsubq_u16(vshlq_n_u16(p6q6, 3), p6q6);
+
+  // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+  //                 ^^^^^^^^^^^^^^^^^^^
+  // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+  //                               ^^^^^^^^^^^^^^^^^^^
+  uint16x8_t sum = vshlq_n_u16(vaddq_u16(p5q5, p4q4), 1);
+  sum = vaddq_u16(sum, p6q6_x7);
+
+  // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+  //                                       ^^^^^^^
+  // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+  //                     ^^^^^^^
+  sum = vaddq_u16(vaddq_u16(p3q3, p2q2), sum);
+
+  // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+  //                                                 ^^^^^^^
+  // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+  //           ^^^^^^^
+  sum = vaddq_u16(vaddq_u16(p1q1, p0q0), sum);
+
+  // p5 = (7 * p6) + (2 * p5) + (2 * p4) + p3 + p2 + p1 + p0 + q0
+  //                                                           ^^
+  // q5 = p0 + q0 + q1 + q2 + q3 + (2 * q4) + (2 * q5) + (7 * q6)
+  //      ^^
+  const uint16x8_t q0p0 = Transpose64(p0q0);
+  sum = vaddq_u16(sum, q0p0);
+
+  // Rounding shift: (sum + 8) >> 4.
+  *p5q5_output = vrshrq_n_u16(sum, 4);
+
+  // Convert to p4 and q4 output:
+  // p4 = p5 - (2 * p6) + p3 + q1
+  // q4 = q5 - (2 * q6) + q3 + p1
+  sum = vsubq_u16(sum, vshlq_n_u16(p6q6, 1));
+  const uint16x8_t q1p1 = Transpose64(p1q1);
+  sum = vaddq_u16(vaddq_u16(p3q3, q1p1), sum);
+
+  *p4q4_output = vrshrq_n_u16(sum, 4);
+
+  // Convert to p3 and q3 output:
+  // p3 = p4 - p6 - p5 + p2 + q2
+  // q3 = q4 - q6 - q5 + q2 + p2
+  sum = vsubq_u16(sum, vaddq_u16(p6q6, p5q5));
+  const uint16x8_t q2p2 = Transpose64(p2q2);
+  sum = vaddq_u16(vaddq_u16(p2q2, q2p2), sum);
+
+  *p3q3_output = vrshrq_n_u16(sum, 4);
+
+  // Convert to p2 and q2 output:
+  // p2 = p3 - p6 - p4 + p1 + q3
+  // q2 = q3 - q6 - q4 + q1 + p3
+  sum = vsubq_u16(sum, vaddq_u16(p6q6, p4q4));
+  const uint16x8_t q3p3 = Transpose64(p3q3);
+  sum = vaddq_u16(vaddq_u16(p1q1, q3p3), sum);
+
+  *p2q2_output = vrshrq_n_u16(sum, 4);
+
+  // Convert to p1 and q1 output:
+  // p1 = p2 - p6 - p3 + p0 + q4
+  // q1 = q2 - q6 - q3 + q0 + p4
+  sum = vsubq_u16(sum, vaddq_u16(p6q6, p3q3));
+  const uint16x8_t q4p4 = Transpose64(p4q4);
+  sum = vaddq_u16(vaddq_u16(p0q0, q4p4), sum);
+
+  *p1q1_output = vrshrq_n_u16(sum, 4);
+
+  // Convert to p0 and q0 output:
+  // p0 = p1 - p6 - p2 + q0 + q5
+  // q0 = q1 - q6 - q2 + p0 + p5
+  sum = vsubq_u16(sum, vaddq_u16(p6q6, p2q2));
+  const uint16x8_t q5p5 = Transpose64(p5q5);
+  sum = vaddq_u16(vaddq_u16(q0p0, q5p5), sum);
+
+  *p0q0_output = vrshrq_n_u16(sum, 4);
+}
+
+// Applies the 14-tap loop filter across a horizontal edge for 4 columns.
+// |dest| points at the first q0 pixel and |stride| is the byte distance
+// between rows of uint16_t pixels. Rows p6..q6 are read; rows p5..q5 may be
+// written. Per column, the Filter4, Filter8 or Filter14 result is chosen from
+// the masks computed below.
+void Horizontal14_NEON(void* const dest, const ptrdiff_t stride,
+                       int outer_thresh, int inner_thresh, int hev_thresh) {
+  auto* const origin = static_cast<uint8_t*>(dest);
+  // Row r in [0, 14) is p6, p5, ..., p0, q0, ..., q6 in that order.
+  const auto row = [origin, stride](const int r) {
+    return reinterpret_cast<uint16_t*>(origin + (r - 7) * stride);
+  };
+  uint16x4_t rows[14];
+  for (int r = 0; r < 14; ++r) rows[r] = vld1_u16(row(r));
+
+  // Scale the thresholds to the pixel bitdepth.
+  outer_thresh <<= 2;
+  inner_thresh <<= 2;
+  hev_thresh <<= 2;
+
+  const uint16x8_t p0q0 = vcombine_u16(rows[6], rows[7]);
+  const uint16x8_t p1q1 = vcombine_u16(rows[5], rows[8]);
+  const uint16x8_t p2q2 = vcombine_u16(rows[4], rows[9]);
+  const uint16x8_t p3q3 = vcombine_u16(rows[3], rows[10]);
+
+  const uint16x4_t outer_mask =
+      OuterThreshold(rows[5], rows[6], rows[7], rows[8], outer_thresh);
+  uint16x4_t hev_mask;
+  uint16x4_t needs_filter_mask;
+  uint16x4_t is_flat4_mask;
+  Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh,
+               &needs_filter_mask, &is_flat4_mask, &hev_mask);
+
+  // Bail out early when no column is selected for filtering.
+#if defined(__aarch64__)
+  if (vaddv_u16(needs_filter_mask) == 0) return;
+#else   // !defined(__aarch64__)
+  // Moving the 64-bit lane to a general register avoids a horizontal add.
+  if (vget_lane_u64(vreinterpret_u64_u16(needs_filter_mask), 0) == 0) return;
+#endif  // defined(__aarch64__)
+
+  const uint16x8_t p4q4 = vcombine_u16(rows[2], rows[11]);
+  const uint16x8_t p5q5 = vcombine_u16(rows[1], rows[12]);
+  const uint16x8_t p6q6 = vcombine_u16(rows[0], rows[13]);
+  // Filter14 is only considered for columns that already qualify for Filter8.
+  const uint16x4_t is_flat4_outer_mask = vand_u16(
+      is_flat4_mask, IsFlat4(vabdq_u16(p0q0, p4q4), vabdq_u16(p0q0, p5q5),
+                             vabdq_u16(p0q0, p6q6)));
+  // Duplicate the 4-lane mask so it covers both p and q halves.
+  const uint16x8_t needs_filter_mask_8 =
+      vcombine_u16(needs_filter_mask, needs_filter_mask);
+
+  // Filter4 result; p1/q1 keep their source values wherever Hev() is true.
+  uint16x8_t f4_p1q1;
+  uint16x8_t f4_p0q0;
+  const uint16x8_t p0q1 = vcombine_u16(rows[6], rows[8]);
+  Filter4(p0q0, p0q1, p1q1, hev_mask, &f4_p1q1, &f4_p0q0);
+  f4_p1q1 = vbslq_u16(vcombine_u16(hev_mask, hev_mask), p1q1, f4_p1q1);
+
+  // Rows p5..p2 (and q2..q5) are untouched unless a wider filter applies.
+  uint16x8_t out_p5q5 = p5q5;
+  uint16x8_t out_p4q4 = p4q4;
+  uint16x8_t out_p3q3 = p3q3;
+  uint16x8_t out_p2q2 = p2q2;
+  uint16x8_t out_p1q1;
+  uint16x8_t out_p0q0;
+  if (vget_lane_u64(vreinterpret_u64_u16(is_flat4_mask), 0) == 0) {
+    // Only Filter4 applies, to at least one column.
+    out_p1q1 = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1);
+    out_p0q0 = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0);
+  } else {
+    const uint16x8_t use_filter8_mask =
+        vcombine_u16(is_flat4_mask, is_flat4_mask);
+    uint16x8_t f8_p2q2, f8_p1q1, f8_p0q0;
+    Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0);
+    if (vget_lane_u64(vreinterpret_u64_u16(is_flat4_outer_mask), 0) == 0) {
+      // Filter8 and Filter4 contribute; Filter14 does not.
+      out_p2q2 = vbslq_u16(use_filter8_mask, f8_p2q2, p2q2);
+      out_p1q1 = vbslq_u16(needs_filter_mask_8,
+                           vbslq_u16(use_filter8_mask, f8_p1q1, f4_p1q1),
+                           p1q1);
+      out_p0q0 = vbslq_u16(needs_filter_mask_8,
+                           vbslq_u16(use_filter8_mask, f8_p0q0, f4_p0q0),
+                           p0q0);
+    } else {
+      // All three filters may contribute; Filter14 wins where it applies.
+      const uint16x8_t use_filter14_mask =
+          vcombine_u16(is_flat4_outer_mask, is_flat4_outer_mask);
+      uint16x8_t f14_p5q5, f14_p4q4, f14_p3q3, f14_p2q2, f14_p1q1, f14_p0q0;
+      Filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4,
+               &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0);
+      out_p5q5 = vbslq_u16(use_filter14_mask, f14_p5q5, p5q5);
+      out_p4q4 = vbslq_u16(use_filter14_mask, f14_p4q4, p4q4);
+      out_p3q3 = vbslq_u16(use_filter14_mask, f14_p3q3, p3q3);
+      const uint16x8_t wide_p2q2 =
+          vbslq_u16(use_filter14_mask, f14_p2q2, f8_p2q2);
+      const uint16x8_t wide_p1q1 =
+          vbslq_u16(use_filter14_mask, f14_p1q1, f8_p1q1);
+      const uint16x8_t wide_p0q0 =
+          vbslq_u16(use_filter14_mask, f14_p0q0, f8_p0q0);
+      out_p2q2 = vbslq_u16(needs_filter_mask_8,
+                           vbslq_u16(use_filter8_mask, wide_p2q2, p2q2), p2q2);
+      out_p1q1 = vbslq_u16(needs_filter_mask_8,
+                           vbslq_u16(use_filter8_mask, wide_p1q1, f4_p1q1),
+                           p1q1);
+      out_p0q0 = vbslq_u16(needs_filter_mask_8,
+                           vbslq_u16(use_filter8_mask, wide_p0q0, f4_p0q0),
+                           p0q0);
+    }
+  }
+
+  // Low halves go to rows 1..6 (p5..p0); high halves go to rows 7..12
+  // (q0..q5), stored in the same p5, ..., p0, q0, ..., q5 order as before.
+  const uint16x8_t outputs[6] = {out_p5q5, out_p4q4, out_p3q3,
+                                 out_p2q2, out_p1q1, out_p0q0};
+  for (int i = 0; i < 6; ++i) vst1_u16(row(1 + i), vget_low_u16(outputs[i]));
+  for (int i = 5; i >= 0; --i) {
+    vst1_u16(row(12 - i), vget_high_u16(outputs[i]));
+  }
+}
+
+// Rearranges 64-bit halves: given |ab| = [a, b] and |cd| = [c, d], returns
+// val[0] = [a, c] and val[1] = [d, b].
+inline uint16x8x2_t PermuteACDB64(const uint16x8_t ab, const uint16x8_t cd) {
+  uint16x8x2_t result;
+#if defined(__aarch64__)
+  const uint64x2_t ab_64 = vreinterpretq_u64_u16(ab);
+  const uint64x2_t cd_64 = vreinterpretq_u64_u16(cd);
+  // TRN1 gathers lane 0 of each operand: [a, c].
+  result.val[0] = vreinterpretq_u16_u64(vtrn1q_u64(ab_64, cd_64));
+  // TRN2 gathers lane 1 of each operand: [d, b].
+  result.val[1] = vreinterpretq_u16_u64(vtrn2q_u64(cd_64, ab_64));
+#else
+  // Same permutation built directly from the 64-bit halves.
+  result.val[0] = vcombine_u16(vget_low_u16(ab), vget_low_u16(cd));
+  result.val[1] = vcombine_u16(vget_high_u16(cd), vget_high_u16(ab));
+#endif  // defined(__aarch64__)
+  return result;
+}
+
+// Applies the 14-tap loop filter across a vertical edge for 4 rows of 10-bit
+// pixels. |dest| points at q0 of the first row (the first pixel right of the
+// edge) and |stride| is the row pitch in bytes. For each row, the 8 pixels
+// left of the edge (p7..p0) and the 8 pixels right of it (q0..q7) are loaded
+// and transposed so that each pXqX vector holds column pX of the 4 rows in its
+// low half and column qX in its high half. Those vectors are filtered and
+// transposed back before storing. p7, p6, q6 and q7 are written back
+// unchanged. The function returns without storing anything when no lane needs
+// filtering.
+void Vertical14_NEON(void* const dest, const ptrdiff_t stride, int outer_thresh,
+                     int inner_thresh, int hev_thresh) {
+  auto* const dst = static_cast<uint8_t*>(dest) - 8 * sizeof(uint16_t);
+  auto* const dst_0 = reinterpret_cast<uint16_t*>(dst);
+  auto* const dst_1 = reinterpret_cast<uint16_t*>(dst + stride);
+  auto* const dst_2 = reinterpret_cast<uint16_t*>(dst + 2 * stride);
+  auto* const dst_3 = reinterpret_cast<uint16_t*>(dst + 3 * stride);
+
+  // Low halves:  p7 p6 p5 p4
+  // High halves: p3 p2 p1 p0
+  uint16x8_t src_p[4] = {vld1q_u16(dst_0), vld1q_u16(dst_1), vld1q_u16(dst_2),
+                         vld1q_u16(dst_3)};
+  // p7 will be the low half of src_p[0]. Not used until the end.
+  Transpose4x8(src_p);
+
+  // Low halves:  q0 q1 q2 q3
+  // High halves: q4 q5 q6 q7
+  uint16x8_t src_q[4] = {vld1q_u16(dst_0 + 8), vld1q_u16(dst_1 + 8),
+                         vld1q_u16(dst_2 + 8), vld1q_u16(dst_3 + 8)};
+  // q7 will be the high half of src_q[3]. Not used until the end.
+  Transpose4x8(src_q);
+
+  // Adjust thresholds to bitdepth (scale by 4 for 10-bit).
+  outer_thresh <<= 2;
+  inner_thresh <<= 2;
+  hev_thresh <<= 2;
+  // Arguments are p1, p0, q0, q1 for the 4 rows.
+  const uint16x4_t outer_mask = OuterThreshold(
+      vget_high_u16(src_p[2]), vget_high_u16(src_p[3]), vget_low_u16(src_q[0]),
+      vget_low_u16(src_q[1]), outer_thresh);
+  // Join the pX high half of src_p with the qX low half of src_q, giving the
+  // pX values in the low half and the qX values in the high half.
+  const uint16x8_t p0q0 = vextq_u16(src_p[3], src_q[0], 4);
+  const uint16x8_t p1q1 = vextq_u16(src_p[2], src_q[1], 4);
+  const uint16x8_t p2q2 = vextq_u16(src_p[1], src_q[2], 4);
+  const uint16x8_t p3q3 = vextq_u16(src_p[0], src_q[3], 4);
+  uint16x4_t hev_mask;
+  uint16x4_t needs_filter_mask;
+  uint16x4_t is_flat4_mask;
+  Filter8Masks(p3q3, p2q2, p1q1, p0q0, hev_thresh, outer_mask, inner_thresh,
+               &needs_filter_mask, &is_flat4_mask, &hev_mask);
+
+#if defined(__aarch64__)
+  if (vaddv_u16(needs_filter_mask) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#else   // !defined(__aarch64__)
+  // This might be faster than vaddv (latency 3) because mov to general register
+  // has latency 2.
+  const uint64x1_t needs_filter_mask64 =
+      vreinterpret_u64_u16(needs_filter_mask);
+  if (vget_lane_u64(needs_filter_mask64, 0) == 0) {
+    // None of the values will be filtered.
+    return;
+  }
+#endif  // defined(__aarch64__)
+  // The outer columns are only gathered once some filtering is known to be
+  // needed: low halves of src_p hold p4..p7 and high halves of src_q hold q4..q7.
+  const uint16x8_t p4q4 =
+      vcombine_u16(vget_low_u16(src_p[3]), vget_high_u16(src_q[0]));
+  const uint16x8_t p5q5 =
+      vcombine_u16(vget_low_u16(src_p[2]), vget_high_u16(src_q[1]));
+  const uint16x8_t p6q6 =
+      vcombine_u16(vget_low_u16(src_p[1]), vget_high_u16(src_q[2]));
+  const uint16x8_t p7q7 =
+      vcombine_u16(vget_low_u16(src_p[0]), vget_high_u16(src_q[3]));
+  // Mask to choose between the outputs of Filter8 and Filter14.
+  // As with the derivation of |is_flat4_mask|, the question of whether to use
+  // Filter14 is only raised where |is_flat4_mask| is true.
+  const uint16x4_t is_flat4_outer_mask = vand_u16(
+      is_flat4_mask, IsFlat4(vabdq_u16(p0q0, p4q4), vabdq_u16(p0q0, p5q5),
+                             vabdq_u16(p0q0, p6q6)));
+  // Copy the masks to the high bits for packed comparisons later.
+  const uint16x8_t hev_mask_8 = vcombine_u16(hev_mask, hev_mask);
+  const uint16x8_t needs_filter_mask_8 =
+      vcombine_u16(needs_filter_mask, needs_filter_mask);
+
+  uint16x8_t f4_p1q1;
+  uint16x8_t f4_p0q0;
+  const uint16x8_t p0q1 = vcombine_u16(vget_low_u16(p0q0), vget_high_u16(p1q1));
+  Filter4(p0q0, p0q1, p1q1, hev_mask, &f4_p1q1, &f4_p0q0);
+  // Where |hev_mask| is set, p1/q1 keep their source values.
+  f4_p1q1 = vbslq_u16(hev_mask_8, p1q1, f4_p1q1);
+
+  uint16x8_t p0q0_output, p1q1_output, p2q2_output, p3q3_output, p4q4_output,
+      p5q5_output;
+  // Because we did not return after testing |needs_filter_mask| we know it is
+  // nonzero. |is_flat4_mask| controls whether the needed filter is Filter4 or
+  // Filter8. Therefore if it is false when |needs_filter_mask| is true, Filter8
+  // output is not used.
+  uint16x8_t f8_p2q2, f8_p1q1, f8_p0q0;
+  const uint64x1_t need_filter8 = vreinterpret_u64_u16(is_flat4_mask);
+  if (vget_lane_u64(need_filter8, 0) == 0) {
+    // Filter8() and Filter14() do not apply, but Filter4() applies to one or
+    // more values.
+    p5q5_output = p5q5;
+    p4q4_output = p4q4;
+    p3q3_output = p3q3;
+    p2q2_output = p2q2;
+    p1q1_output = vbslq_u16(needs_filter_mask_8, f4_p1q1, p1q1);
+    p0q0_output = vbslq_u16(needs_filter_mask_8, f4_p0q0, p0q0);
+  } else {
+    const uint16x8_t use_filter8_mask =
+        vcombine_u16(is_flat4_mask, is_flat4_mask);
+    Filter8(p3q3, p2q2, p1q1, p0q0, &f8_p2q2, &f8_p1q1, &f8_p0q0);
+    const uint64x1_t need_filter14 = vreinterpret_u64_u16(is_flat4_outer_mask);
+    if (vget_lane_u64(need_filter14, 0) == 0) {
+      // Filter14() does not apply, but Filter8() and Filter4() apply to one or
+      // more values.
+      p5q5_output = p5q5;
+      p4q4_output = p4q4;
+      p3q3_output = p3q3;
+      p2q2_output = vbslq_u16(use_filter8_mask, f8_p2q2, p2q2);
+      p1q1_output = vbslq_u16(use_filter8_mask, f8_p1q1, f4_p1q1);
+      p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1);
+      p0q0_output = vbslq_u16(use_filter8_mask, f8_p0q0, f4_p0q0);
+      p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0);
+    } else {
+      // All filters may contribute values to final outputs.
+      // Selection order per lane: Filter14, else Filter8, else Filter4, else
+      // the source value.
+      const uint16x8_t use_filter14_mask =
+          vcombine_u16(is_flat4_outer_mask, is_flat4_outer_mask);
+      uint16x8_t f14_p5q5, f14_p4q4, f14_p3q3, f14_p2q2, f14_p1q1, f14_p0q0;
+      Filter14(p6q6, p5q5, p4q4, p3q3, p2q2, p1q1, p0q0, &f14_p5q5, &f14_p4q4,
+               &f14_p3q3, &f14_p2q2, &f14_p1q1, &f14_p0q0);
+      p5q5_output = vbslq_u16(use_filter14_mask, f14_p5q5, p5q5);
+      p4q4_output = vbslq_u16(use_filter14_mask, f14_p4q4, p4q4);
+      p3q3_output = vbslq_u16(use_filter14_mask, f14_p3q3, p3q3);
+      p2q2_output = vbslq_u16(use_filter14_mask, f14_p2q2, f8_p2q2);
+      p2q2_output = vbslq_u16(use_filter8_mask, p2q2_output, p2q2);
+      p2q2_output = vbslq_u16(needs_filter_mask_8, p2q2_output, p2q2);
+      p1q1_output = vbslq_u16(use_filter14_mask, f14_p1q1, f8_p1q1);
+      p1q1_output = vbslq_u16(use_filter8_mask, p1q1_output, f4_p1q1);
+      p1q1_output = vbslq_u16(needs_filter_mask_8, p1q1_output, p1q1);
+      p0q0_output = vbslq_u16(use_filter14_mask, f14_p0q0, f8_p0q0);
+      p0q0_output = vbslq_u16(use_filter8_mask, p0q0_output, f4_p0q0);
+      p0q0_output = vbslq_u16(needs_filter_mask_8, p0q0_output, p0q0);
+    }
+  }
+  // To get the correctly ordered rows from the transpose, we need:
+  // p7p3 p6p2 p5p1 p4p0
+  // q0q4 q1q5 q2q6 q3q7
+  // p7q7 and p6q6 are passed through unfiltered.
+  const uint16x8x2_t p7p3_q3q7 = PermuteACDB64(p7q7, p3q3_output);
+  const uint16x8x2_t p6p2_q2q6 = PermuteACDB64(p6q6, p2q2_output);
+  const uint16x8x2_t p5p1_q1q5 = PermuteACDB64(p5q5_output, p1q1_output);
+  const uint16x8x2_t p4p0_q0q4 = PermuteACDB64(p4q4_output, p0q0_output);
+  uint16x8_t output_p[4] = {p7p3_q3q7.val[0], p6p2_q2q6.val[0],
+                            p5p1_q1q5.val[0], p4p0_q0q4.val[0]};
+  Transpose4x8(output_p);
+  uint16x8_t output_q[4] = {p4p0_q0q4.val[1], p5p1_q1q5.val[1],
+                            p6p2_q2q6.val[1], p7p3_q3q7.val[1]};
+  Transpose4x8(output_q);
+
+  // After the transposes, output_p[r] holds p7..p0 and output_q[r] holds
+  // q0..q7 of row r, in their original memory order.
+  vst1q_u16(dst_0, output_p[0]);
+  vst1q_u16(dst_0 + 8, output_q[0]);
+  vst1q_u16(dst_1, output_p[1]);
+  vst1q_u16(dst_1 + 8, output_q[1]);
+  vst1q_u16(dst_2, output_p[2]);
+  vst1q_u16(dst_2 + 8, output_q[2]);
+  vst1q_u16(dst_3, output_p[3]);
+  vst1q_u16(dst_3 + 8, output_q[3]);
+}
+
+// Installs the 10bpp NEON loop filters (sizes 4, 6, 8 and 14, for both
+// horizontal and vertical edges) into the writable 10-bit DSP table.
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  auto& filters = dsp->loop_filters;
+  // Horizontal edges.
+  filters[kLoopFilterSize4][kLoopFilterTypeHorizontal] = Horizontal4_NEON;
+  filters[kLoopFilterSize6][kLoopFilterTypeHorizontal] = Horizontal6_NEON;
+  filters[kLoopFilterSize8][kLoopFilterTypeHorizontal] = Horizontal8_NEON;
+  filters[kLoopFilterSize14][kLoopFilterTypeHorizontal] = Horizontal14_NEON;
+  // Vertical edges.
+  filters[kLoopFilterSize4][kLoopFilterTypeVertical] = Vertical4_NEON;
+  filters[kLoopFilterSize6][kLoopFilterTypeVertical] = Vertical6_NEON;
+  filters[kLoopFilterSize8][kLoopFilterTypeVertical] = Vertical8_NEON;
+  filters[kLoopFilterSize14][kLoopFilterTypeVertical] = Vertical14_NEON;
+}
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+// Registers the NEON loop filters. The 8bpp set is always installed; the
+// 10bpp set is installed too when LIBGAV1_MAX_BITDEPTH >= 10.
+void LoopFilterInit_NEON() {
+  low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+}
 
 }  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/arm/loop_filter_neon.h b/libgav1/src/dsp/arm/loop_filter_neon.h
index 5f79200..540defc 100644
--- a/libgav1/src/dsp/arm/loop_filter_neon.h
+++ b/libgav1/src/dsp/arm/loop_filter_neon.h
@@ -48,6 +48,23 @@
   LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_LoopFilterSize14_LoopFilterTypeVertical LIBGAV1_CPU_NEON
 
+#define LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeHorizontal \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_LoopFilterSize4_LoopFilterTypeVertical LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeHorizontal \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_LoopFilterSize6_LoopFilterTypeVertical LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeHorizontal \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_LoopFilterSize8_LoopFilterTypeVertical LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeHorizontal \
+  LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_LoopFilterSize14_LoopFilterTypeVertical \
+  LIBGAV1_CPU_NEON
+
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_LOOP_FILTER_NEON_H_
diff --git a/libgav1/src/dsp/arm/loop_restoration_10bit_neon.cc b/libgav1/src/dsp/arm/loop_restoration_10bit_neon.cc
new file mode 100644
index 0000000..410bc20
--- /dev/null
+++ b/libgav1/src/dsp/arm/loop_restoration_10bit_neon.cc
@@ -0,0 +1,2652 @@
+// Copyright 2021 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/dsp/loop_restoration.h"
+#include "src/utils/cpu.h"
+
+#if LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
+#include <arm_neon.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+
+#include "src/dsp/arm/common_neon.h"
+#include "src/dsp/constants.h"
+#include "src/dsp/dsp.h"
+#include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
+#include "src/utils/constants.h"
+
+namespace libgav1 {
+namespace dsp {
+namespace {
+
+//------------------------------------------------------------------------------
+// Wiener
+
+// Copies the four Wiener coefficients stored for |direction| into |filter|.
+// A local copy tells the compiler the coefficients cannot overlap any other
+// buffer, which the 'const' keyword alone does not. In practice the compiler
+// does not emit an actual copy, because there are enough registers to hold
+// the values.
+inline void PopulateWienerCoefficients(
+    const RestorationUnitInfo& restoration_info, const int direction,
+    int16_t filter[4]) {
+  std::copy_n(restoration_info.wiener_info.filter[direction], 4, filter);
+}
+
+// Accumulates filter * (s0 + s1) into |sum|. The pair is added in 16 bits and
+// each lane's product is widened to 32 bits. Used for symmetric tap pairs of
+// the horizontal filter.
+inline int32x4x2_t WienerHorizontal2(const uint16x8_t s0, const uint16x8_t s1,
+                                     const int16_t filter,
+                                     const int32x4x2_t sum) {
+  const int16x8_t pair = vreinterpretq_s16_u16(vaddq_u16(s0, s1));
+  const int16x4_t pair_lo = vget_low_s16(pair);
+  const int16x4_t pair_hi = vget_high_s16(pair);
+  int32x4x2_t result = sum;
+  result.val[0] = vmlal_n_s16(result.val[0], pair_lo, filter);
+  result.val[1] = vmlal_n_s16(result.val[1], pair_hi, filter);
+  return result;
+}
+
+// Adds the inner taps (filter[2] * (s[0] + s[2]) + filter[3] * s[1]) to |sum|.
+// Each 32-bit lane is then narrowed with a saturating right shift by
+// kInterRoundBitsHorizontal and clamped to [-offset, limit - offset]. The 8
+// int16 results are stored to |wiener_buffer|.
+inline void WienerHorizontalSum(const uint16x8_t s[3], const int16_t filter[4],
+                                int32x4x2_t sum, int16_t* const wiener_buffer) {
+  constexpr int offset =
+      1 << (kBitdepth10 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
+  constexpr int limit = (offset << 2) - 1;
+  const int16x4_t lower_bound = vdup_n_s16(-offset);
+  const int16x4_t upper_bound = vdup_n_s16(limit - offset);
+  const int16x8_t outer = vreinterpretq_s16_u16(vaddq_u16(s[0], s[2]));
+  const int16x8_t center = vreinterpretq_s16_u16(s[1]);
+
+  int32x4_t lo = vmlal_n_s16(sum.val[0], vget_low_s16(outer), filter[2]);
+  lo = vmlal_n_s16(lo, vget_low_s16(center), filter[3]);
+  int32x4_t hi = vmlal_n_s16(sum.val[1], vget_high_s16(outer), filter[2]);
+  hi = vmlal_n_s16(hi, vget_high_s16(center), filter[3]);
+
+  const int16x4_t lo16 = vmin_s16(
+      vmax_s16(vqshrn_n_s32(lo, kInterRoundBitsHorizontal), lower_bound),
+      upper_bound);
+  const int16x4_t hi16 = vmin_s16(
+      vmax_s16(vqshrn_n_s32(hi, kInterRoundBitsHorizontal), lower_bound),
+      upper_bound);
+  vst1_s16(wiener_buffer, lo16);
+  vst1_s16(wiener_buffer + 4, hi16);
+}
+
+// Runs the 7-tap horizontal Wiener filter over |height| rows of |src|, with a
+// row pitch of |src_stride| pixels. Output pixel i of a row is computed from
+// source pixels i..i+6, so |src| must point 3 pixels left of the first output
+// position. Each row produces |wiener_stride| int16 values. They are written
+// at |*wiener_buffer|, which is advanced past them.
+inline void WienerHorizontalTap7(const uint16_t* src,
+                                 const ptrdiff_t src_stride,
+                                 const ptrdiff_t wiener_stride,
+                                 const ptrdiff_t width, const int height,
+                                 const int16_t filter[4],
+                                 int16_t** const wiener_buffer) {
+  const ptrdiff_t src_width =
+      width + ((kRestorationHorizontalBorder - 1) * sizeof(*src));
+  const int32x4_t rounding = vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 1));
+  for (int y = 0; y != height; ++y) {
+    const uint16_t* row = src + y * src_stride;
+    // Bytes of the row not yet consumed. The value passed to Load1QMsanU16 is
+    // presumably how far each load reaches past valid data (MSan only); TODO
+    // confirm against the helper.
+    ptrdiff_t bytes_left = src_width * 2;
+    ptrdiff_t columns_left = wiener_stride;
+    uint16x8_t s[8];
+    s[0] = vld1q_u16(row);
+    do {
+      row += 8;
+      bytes_left -= 16;
+      // s[7] is the next 8 pixels; s[1]..s[6] are windows shifted by 1..6.
+      s[7] = Load1QMsanU16(row, 16 - bytes_left);
+      s[1] = vextq_u16(s[0], s[7], 1);
+      s[2] = vextq_u16(s[0], s[7], 2);
+      s[3] = vextq_u16(s[0], s[7], 3);
+      s[4] = vextq_u16(s[0], s[7], 4);
+      s[5] = vextq_u16(s[0], s[7], 5);
+      s[6] = vextq_u16(s[0], s[7], 6);
+      int32x4x2_t sum;
+      sum.val[0] = sum.val[1] = rounding;
+      sum = WienerHorizontal2(s[0], s[6], filter[0], sum);
+      sum = WienerHorizontal2(s[1], s[5], filter[1], sum);
+      WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer);
+      *wiener_buffer += 8;
+      s[0] = s[7];
+      columns_left -= 8;
+    } while (columns_left != 0);
+  }
+}
+
+// Runs the 5-tap horizontal Wiener filter over |height| rows of |src|, with a
+// row pitch of |src_stride| pixels. Output pixel i of a row is computed from
+// source pixels i..i+4, so |src| must point 2 pixels left of the first output
+// position. Each row produces |wiener_stride| int16 values. They are written
+// at |*wiener_buffer|, which is advanced past them.
+inline void WienerHorizontalTap5(const uint16_t* src,
+                                 const ptrdiff_t src_stride,
+                                 const ptrdiff_t wiener_stride,
+                                 const ptrdiff_t width, const int height,
+                                 const int16_t filter[4],
+                                 int16_t** const wiener_buffer) {
+  const ptrdiff_t src_width =
+      width + ((kRestorationHorizontalBorder - 1) * sizeof(*src));
+  const int32x4_t rounding = vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 1));
+  for (int y = 0; y != height; ++y) {
+    const uint16_t* row = src + y * src_stride;
+    // Bytes of the row not yet consumed. The value passed to Load1QMsanU16 is
+    // presumably how far each load reaches past valid data (MSan only); TODO
+    // confirm against the helper.
+    ptrdiff_t bytes_left = src_width * 2;
+    ptrdiff_t columns_left = wiener_stride;
+    uint16x8_t s[6];
+    s[0] = vld1q_u16(row);
+    do {
+      row += 8;
+      bytes_left -= 16;
+      // s[5] is the next 8 pixels; s[1]..s[4] are windows shifted by 1..4.
+      s[5] = Load1QMsanU16(row, 16 - bytes_left);
+      s[1] = vextq_u16(s[0], s[5], 1);
+      s[2] = vextq_u16(s[0], s[5], 2);
+      s[3] = vextq_u16(s[0], s[5], 3);
+      s[4] = vextq_u16(s[0], s[5], 4);
+      int32x4x2_t sum;
+      sum.val[0] = sum.val[1] = rounding;
+      sum = WienerHorizontal2(s[0], s[4], filter[1], sum);
+      WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer);
+      *wiener_buffer += 8;
+      s[0] = s[5];
+      columns_left -= 8;
+    } while (columns_left != 0);
+  }
+}
+
+// Runs the 3-tap horizontal Wiener filter over |height| rows of |src|, with a
+// row pitch of |src_stride| pixels. Output pixel i uses source pixels i..i+2,
+// so |src| must point 1 pixel left of the first output position. Each row
+// produces |width| int16 values. They are written at |*wiener_buffer|, which
+// is advanced past them.
+inline void WienerHorizontalTap3(const uint16_t* src,
+                                 const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 const int16_t filter[4],
+                                 int16_t** const wiener_buffer) {
+  const int32x4_t rounding = vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 1));
+  for (int y = 0; y != height; ++y) {
+    const uint16_t* const row = src + y * src_stride;
+    ptrdiff_t x = 0;
+    do {
+      uint16x8_t s[3];
+      s[0] = vld1q_u16(row + x);
+      s[1] = vld1q_u16(row + x + 1);
+      s[2] = vld1q_u16(row + x + 2);
+      int32x4x2_t sum;
+      sum.val[0] = sum.val[1] = rounding;
+      WienerHorizontalSum(s, filter, sum, *wiener_buffer);
+      *wiener_buffer += 8;
+      x += 8;
+    } while (x != width);
+  }
+}
+
+// Horizontal pass when only the center tap is non-zero: each source pixel is
+// shifted left by 4 and stored as int16. Writes |width| values per row for
+// |height| rows at |*wiener_buffer| and advances it past them.
+inline void WienerHorizontalTap1(const uint16_t* src,
+                                 const ptrdiff_t src_stride,
+                                 const ptrdiff_t width, const int height,
+                                 int16_t** const wiener_buffer) {
+  for (int y = 0; y != height; ++y) {
+    const uint16_t* const row = src + y * src_stride;
+    int16_t* const out = *wiener_buffer;
+    ptrdiff_t x = 0;
+    do {
+      const uint16x8_t scaled = vshlq_n_u16(vld1q_u16(row + x), 4);
+      vst1q_s16(out + x, vreinterpretq_s16_u16(scaled));
+      x += 8;
+    } while (x < width);
+    *wiener_buffer = out + width;
+  }
+}
+
+// Returns |sum| + filter * a0 + filter * a1 per lane. Each product is widened
+// to 32 bits before it is accumulated.
+inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1,
+                                   const int16_t filter,
+                                   const int32x4x2_t sum) {
+  int32x4_t lo = vmlal_n_s16(sum.val[0], vget_low_s16(a0), filter);
+  lo = vmlal_n_s16(lo, vget_low_s16(a1), filter);
+  int32x4_t hi = vmlal_n_s16(sum.val[1], vget_high_s16(a0), filter);
+  hi = vmlal_n_s16(hi, vget_high_s16(a1), filter);
+  int32x4x2_t d;
+  d.val[0] = lo;
+  d.val[1] = hi;
+  return d;
+}
+
+// Completes a vertical filter sum. It adds filter[2] * (a[0] + a[2]) and
+// filter[3] * a[1] to |sum|. It then applies a rounding right shift by 11 with
+// saturation to unsigned 16 bits.
+inline uint16x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4],
+                                 const int32x4x2_t sum) {
+  const int32x4x2_t outer = WienerVertical2(a[0], a[2], filter[2], sum);
+  const int32x4_t lo =
+      vmlal_n_s16(outer.val[0], vget_low_s16(a[1]), filter[3]);
+  const int32x4_t hi =
+      vmlal_n_s16(outer.val[1], vget_high_s16(a[1]), filter[3]);
+  return vcombine_u16(vqrshrun_n_s32(lo, 11), vqrshrun_n_s32(hi, 11));
+}
+
+// Loads 7 rows (8 columns each, |wiener_stride| apart) into a[0..6]. Returns
+// the 7-tap vertical filter output for the middle row a[3]. The loaded rows
+// are left in |a| for reuse by the caller.
+inline uint16x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer,
+                                           const ptrdiff_t wiener_stride,
+                                           const int16_t filter[4],
+                                           int16x8_t a[7]) {
+  for (int i = 0; i < 7; ++i) {
+    a[i] = vld1q_s16(wiener_buffer + i * wiener_stride);
+  }
+  int32x4x2_t sum;
+  sum.val[0] = vdupq_n_s32(0);
+  sum.val[1] = sum.val[0];
+  sum = WienerVertical2(a[0], a[6], filter[0], sum);
+  sum = WienerVertical2(a[1], a[5], filter[1], sum);
+  return WienerVertical(a + 2, filter, sum);
+}
+
+// Computes two vertically adjacent 7-tap outputs that share six input rows.
+// val[0] comes from buffer rows 0..6 and val[1] from rows 1..7, each 8
+// columns wide starting at |wiener_buffer|.
+inline uint16x8x2_t WienerVerticalTap7Kernel2(
+    const int16_t* const wiener_buffer, const ptrdiff_t wiener_stride,
+    const int16_t filter[4]) {
+  int16x8_t rows[8];
+  uint16x8x2_t result;
+  result.val[0] =
+      WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, rows);
+  rows[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride);
+  int32x4x2_t acc;
+  acc.val[0] = vdupq_n_s32(0);
+  acc.val[1] = acc.val[0];
+  acc = WienerVertical2(rows[1], rows[7], filter[0], acc);
+  acc = WienerVertical2(rows[2], rows[6], filter[1], acc);
+  result.val[1] = WienerVertical(rows + 3, filter, acc);
+  return result;
+}
+
+// Applies the 7-tap vertical Wiener filter to |wiener_buffer| (rows of |width|
+// int16 values). It stores |height| rows of pixels, clamped to
+// (1 << kBitdepth10) - 1, to |dst| (pitch |dst_stride|). Rows are produced two
+// at a time, with a final single row when |height| is odd. Each step covers
+// 16 columns.
+inline void WienerVerticalTap7(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               const int16_t filter[4], uint16_t* dst,
+                               const ptrdiff_t dst_stride) {
+  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
+  const auto store = [v_max_bitdepth](uint16_t* const p, const uint16x8_t v) {
+    vst1q_u16(p, vminq_u16(v, v_max_bitdepth));
+  };
+  for (int pairs_left = height >> 1; pairs_left != 0; --pairs_left) {
+    ptrdiff_t x = 0;
+    do {
+      const uint16x8x2_t left =
+          WienerVerticalTap7Kernel2(wiener_buffer + x, width, filter);
+      const uint16x8x2_t right =
+          WienerVerticalTap7Kernel2(wiener_buffer + x + 8, width, filter);
+      store(dst + x, left.val[0]);
+      store(dst + x + 8, right.val[0]);
+      store(dst + dst_stride + x, left.val[1]);
+      store(dst + dst_stride + x + 8, right.val[1]);
+      x += 16;
+    } while (x != width);
+    wiener_buffer += 2 * width;
+    dst += 2 * dst_stride;
+  }
+  // Remaining single row when |height| is odd.
+  if ((height & 1) != 0) {
+    ptrdiff_t x = 0;
+    do {
+      int16x8_t a[7];
+      const uint16x8_t left =
+          WienerVerticalTap7Kernel(wiener_buffer + x, width, filter, a);
+      const uint16x8_t right =
+          WienerVerticalTap7Kernel(wiener_buffer + x + 8, width, filter, a);
+      store(dst + x, left);
+      store(dst + x + 8, right);
+      x += 16;
+    } while (x != width);
+  }
+}
+
+// Loads 5 rows (8 columns each, |wiener_stride| apart) into a[0..4]. Returns
+// the 5-tap vertical filter output for the middle row a[2]. The loaded rows
+// are left in |a| for reuse by the caller.
+inline uint16x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer,
+                                           const ptrdiff_t wiener_stride,
+                                           const int16_t filter[4],
+                                           int16x8_t a[5]) {
+  for (int i = 0; i < 5; ++i) {
+    a[i] = vld1q_s16(wiener_buffer + i * wiener_stride);
+  }
+  int32x4x2_t sum;
+  sum.val[0] = vdupq_n_s32(0);
+  sum.val[1] = sum.val[0];
+  return WienerVertical(a + 1, filter,
+                        WienerVertical2(a[0], a[4], filter[1], sum));
+}
+
+// Computes two vertically adjacent 5-tap outputs. val[0] comes from buffer
+// rows 0..4 and val[1] from rows 1..5, each 8 columns wide starting at
+// |wiener_buffer|.
+inline uint16x8x2_t WienerVerticalTap5Kernel2(
+    const int16_t* const wiener_buffer, const ptrdiff_t wiener_stride,
+    const int16_t filter[4]) {
+  int16x8_t rows[6];
+  uint16x8x2_t result;
+  result.val[0] =
+      WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, rows);
+  rows[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
+  int32x4x2_t acc;
+  acc.val[0] = vdupq_n_s32(0);
+  acc.val[1] = acc.val[0];
+  acc = WienerVertical2(rows[1], rows[5], filter[1], acc);
+  result.val[1] = WienerVertical(rows + 2, filter, acc);
+  return result;
+}
+
+// Applies the 5-tap vertical Wiener filter to |wiener_buffer| (rows of |width|
+// int16 values). It stores |height| rows of pixels, clamped to
+// (1 << kBitdepth10) - 1, to |dst| (pitch |dst_stride|). Rows are produced two
+// at a time, with a final single row when |height| is odd. Each step covers
+// 16 columns.
+inline void WienerVerticalTap5(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               const int16_t filter[4], uint16_t* dst,
+                               const ptrdiff_t dst_stride) {
+  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
+  const auto store = [v_max_bitdepth](uint16_t* const p, const uint16x8_t v) {
+    vst1q_u16(p, vminq_u16(v, v_max_bitdepth));
+  };
+  for (int pairs_left = height >> 1; pairs_left != 0; --pairs_left) {
+    ptrdiff_t x = 0;
+    do {
+      const uint16x8x2_t left =
+          WienerVerticalTap5Kernel2(wiener_buffer + x, width, filter);
+      const uint16x8x2_t right =
+          WienerVerticalTap5Kernel2(wiener_buffer + x + 8, width, filter);
+      store(dst + x, left.val[0]);
+      store(dst + x + 8, right.val[0]);
+      store(dst + dst_stride + x, left.val[1]);
+      store(dst + dst_stride + x + 8, right.val[1]);
+      x += 16;
+    } while (x != width);
+    wiener_buffer += 2 * width;
+    dst += 2 * dst_stride;
+  }
+  // Remaining single row when |height| is odd.
+  if ((height & 1) != 0) {
+    ptrdiff_t x = 0;
+    do {
+      int16x8_t a[5];
+      const uint16x8_t left =
+          WienerVerticalTap5Kernel(wiener_buffer + x, width, filter, a);
+      const uint16x8_t right =
+          WienerVerticalTap5Kernel(wiener_buffer + x + 8, width, filter, a);
+      store(dst + x, left);
+      store(dst + x + 8, right);
+      x += 16;
+    } while (x != width);
+  }
+}
+
+// Loads 3 rows (8 columns each, |wiener_stride| apart) into a[0..2]. Returns
+// the 3-tap vertical filter output for the middle row a[1]. The loaded rows
+// are left in |a| for reuse by the caller.
+inline uint16x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer,
+                                           const ptrdiff_t wiener_stride,
+                                           const int16_t filter[4],
+                                           int16x8_t a[3]) {
+  for (int i = 0; i < 3; ++i) {
+    a[i] = vld1q_s16(wiener_buffer + i * wiener_stride);
+  }
+  int32x4x2_t zero;
+  zero.val[0] = zero.val[1] = vdupq_n_s32(0);
+  return WienerVertical(a, filter, zero);
+}
+
+// Computes two vertically adjacent 3-tap outputs. val[0] comes from buffer
+// rows 0..2 and val[1] from rows 1..3, each 8 columns wide starting at
+// |wiener_buffer|.
+inline uint16x8x2_t WienerVerticalTap3Kernel2(
+    const int16_t* const wiener_buffer, const ptrdiff_t wiener_stride,
+    const int16_t filter[4]) {
+  int16x8_t rows[4];
+  const uint16x8_t first =
+      WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, rows);
+  rows[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
+  int32x4x2_t zero;
+  zero.val[0] = zero.val[1] = vdupq_n_s32(0);
+  uint16x8x2_t result;
+  result.val[0] = first;
+  result.val[1] = WienerVertical(rows + 1, filter, zero);
+  return result;
+}
+
+// Applies the 3-tap vertical Wiener filter to |wiener_buffer| (rows of |width|
+// int16 values). It stores |height| rows of pixels, clamped to
+// (1 << kBitdepth10) - 1, to |dst| (pitch |dst_stride|). Rows are produced two
+// at a time, with a final single row when |height| is odd. Each step covers
+// 16 columns.
+inline void WienerVerticalTap3(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               const int16_t filter[4], uint16_t* dst,
+                               const ptrdiff_t dst_stride) {
+  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
+  const auto store = [v_max_bitdepth](uint16_t* const p, const uint16x8_t v) {
+    vst1q_u16(p, vminq_u16(v, v_max_bitdepth));
+  };
+  for (int pairs_left = height >> 1; pairs_left != 0; --pairs_left) {
+    ptrdiff_t x = 0;
+    do {
+      const uint16x8x2_t left =
+          WienerVerticalTap3Kernel2(wiener_buffer + x, width, filter);
+      const uint16x8x2_t right =
+          WienerVerticalTap3Kernel2(wiener_buffer + x + 8, width, filter);
+      store(dst + x, left.val[0]);
+      store(dst + x + 8, right.val[0]);
+      store(dst + dst_stride + x, left.val[1]);
+      store(dst + dst_stride + x + 8, right.val[1]);
+      x += 16;
+    } while (x != width);
+    wiener_buffer += 2 * width;
+    dst += 2 * dst_stride;
+  }
+  // Remaining single row when |height| is odd.
+  if ((height & 1) != 0) {
+    ptrdiff_t x = 0;
+    do {
+      int16x8_t a[3];
+      const uint16x8_t left =
+          WienerVerticalTap3Kernel(wiener_buffer + x, width, filter, a);
+      const uint16x8_t right =
+          WienerVerticalTap3Kernel(wiener_buffer + x + 8, width, filter, a);
+      store(dst + x, left);
+      store(dst + x + 8, right);
+      x += 16;
+    } while (x != width);
+  }
+}
+
+// Converts 16 int16 values from |wiener_buffer| into output pixels. Each value
+// gets a rounding right shift by 4, is clamped to
+// [0, (1 << kBitdepth10) - 1], and is stored to |dst|. Both input vectors are
+// loaded before anything is stored.
+inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
+                                     uint16_t* const dst) {
+  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
+  const int16x8_t zero = vdupq_n_s16(0);
+  const int16x8_t in[2] = {vld1q_s16(wiener_buffer),
+                           vld1q_s16(wiener_buffer + 8)};
+  for (int i = 0; i < 2; ++i) {
+    const int16x8_t rounded = vrshrq_n_s16(in[i], 4);
+    const uint16x8_t non_negative =
+        vreinterpretq_u16_s16(vmaxq_s16(rounded, zero));
+    vst1q_u16(dst + 8 * i, vminq_u16(non_negative, v_max_bitdepth));
+  }
+}
+
+// Vertical pass when only the center tap is non-zero. Each buffer row of
+// |width| int16 values maps directly to one output row of |dst| (pitch
+// |dst_stride|). Rows are handled two at a time, with a final single row when
+// |height| is odd, and 16 columns per step.
+inline void WienerVerticalTap1(const int16_t* wiener_buffer,
+                               const ptrdiff_t width, const int height,
+                               uint16_t* dst, const ptrdiff_t dst_stride) {
+  for (int pairs_left = height >> 1; pairs_left != 0; --pairs_left) {
+    ptrdiff_t x = 0;
+    do {
+      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
+      WienerVerticalTap1Kernel(wiener_buffer + width + x,
+                               dst + dst_stride + x);
+      x += 16;
+    } while (x != width);
+    wiener_buffer += 2 * width;
+    dst += 2 * dst_stride;
+  }
+  // Remaining single row when |height| is odd.
+  if ((height & 1) != 0) {
+    ptrdiff_t x = 0;
+    do {
+      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
+      x += 16;
+    } while (x != width);
+  }
+}
+
+// For width 16 and up, store the horizontal results, and then do the vertical
+// filter row by row. This is faster than doing it column by column when
+// considering cache issues.
+void WienerFilter_NEON(
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
+    const ptrdiff_t bottom_border_stride, const int width, const int height,
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
+  const int16_t* const number_leading_zero_coefficients =
+      restoration_info.wiener_info.number_leading_zero_coefficients;
+  const int number_rows_to_skip = std::max(
+      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
+      1);
+  const ptrdiff_t wiener_stride = Align(width, 16);
+  int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer;
+  // The values are saturated to 13 bits before storing.
+  int16_t* wiener_buffer_horizontal =
+      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;
+  int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2];
+  int16_t filter_vertical[(kWienerFilterTaps + 1) / 2];
+  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal,
+                             filter_horizontal);
+  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical,
+                             filter_vertical);
+  // horizontal filtering.
+  const int height_horizontal =
+      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
+  const int height_extra = (height_horizontal - height) >> 1;
+  assert(height_extra <= 2);
+  const auto* const src = static_cast<const uint16_t*>(source);
+  const auto* const top = static_cast<const uint16_t*>(top_border);
+  const auto* const bottom = static_cast<const uint16_t*>(bottom_border);
+  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
+    WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3,
+                         top_border_stride, wiener_stride, width, height_extra,
+                         filter_horizontal, &wiener_buffer_horizontal);
+    WienerHorizontalTap7(src - 3, stride, wiener_stride, width, height,
+                         filter_horizontal, &wiener_buffer_horizontal);
+    WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride, width,
+                         height_extra, filter_horizontal,
+                         &wiener_buffer_horizontal);
+  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
+    WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2,
+                         top_border_stride, wiener_stride, width, height_extra,
+                         filter_horizontal, &wiener_buffer_horizontal);
+    WienerHorizontalTap5(src - 2, stride, wiener_stride, width, height,
+                         filter_horizontal, &wiener_buffer_horizontal);
+    WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride, width,
+                         height_extra, filter_horizontal,
+                         &wiener_buffer_horizontal);
+  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
+    WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1,
+                         top_border_stride, wiener_stride, height_extra,
+                         filter_horizontal, &wiener_buffer_horizontal);
+    WienerHorizontalTap3(src - 1, stride, wiener_stride, height,
+                         filter_horizontal, &wiener_buffer_horizontal);
+    WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride,
+                         height_extra, filter_horizontal,
+                         &wiener_buffer_horizontal);
+  } else {
+    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
+    WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride,
+                         top_border_stride, wiener_stride, height_extra,
+                         &wiener_buffer_horizontal);
+    WienerHorizontalTap1(src, stride, wiener_stride, height,
+                         &wiener_buffer_horizontal);
+    WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride,
+                         height_extra, &wiener_buffer_horizontal);
+  }
+
+  // vertical filtering.
+  auto* dst = static_cast<uint16_t*>(dest);
+  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
+    // Because the top row of |source| is a duplicate of the second row, and the
+    // bottom row of |source| is a duplicate of its above row, we can duplicate
+    // the top and bottom row of |wiener_buffer| accordingly.
+    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
+           sizeof(*wiener_buffer_horizontal) * wiener_stride);
+    memcpy(restoration_buffer->wiener_buffer,
+           restoration_buffer->wiener_buffer + wiener_stride,
+           sizeof(*restoration_buffer->wiener_buffer) * wiener_stride);
+    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
+                       filter_vertical, dst, stride);
+  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
+    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
+                       height, filter_vertical, dst, stride);
+  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
+    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
+                       wiener_stride, height, filter_vertical, dst, stride);
+  } else {
+    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
+    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
+                       wiener_stride, height, dst, stride);
+  }
+}
+
+//------------------------------------------------------------------------------
+// SGR
+
+// SIMD overreads 8 - (width % 8) - 2 * padding pixels, where padding is 3 for
+// Pass 1 and 2 for Pass 2.
+// Both constants are byte counts. They are the base of the |overread_in_bytes|
+// value that the BoxSum() helpers pass to Load1QMsanU16():
+// kOverreadInBytesPass1 is used by the combined 3/5 BoxSum() and BoxSum<5>,
+// kOverreadInBytesPass2 by BoxSum<3>.
+constexpr int kOverreadInBytesPass1 = 4;
+constexpr int kOverreadInBytesPass2 = 8;
+
+// Loads one 8-lane vector per row: dst[row] = src[row][x .. x + 7].
+inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x,
+                               uint16x8_t dst[2]) {
+  for (int row = 0; row < 2; ++row) {
+    dst[row] = vld1q_u16(src[row] + x);
+  }
+}
+
+// Same as LoadAligned16x2U16(), but every row goes through Load1QMsanU16().
+// All rows share the second argument sizeof(uint16_t) * (x + 8 - border),
+// presumably the number of trailing bytes MSan may treat as uninitialized -
+// confirm against Load1QMsanU16().
+inline void LoadAligned16x2U16Msan(const uint16_t* const src[2],
+                                   const ptrdiff_t x, const ptrdiff_t border,
+                                   uint16x8_t dst[2]) {
+  const auto over_read_in_bytes = sizeof(**src) * (x + 8 - border);
+  for (int row = 0; row < 2; ++row) {
+    dst[row] = Load1QMsanU16(src[row] + x, over_read_in_bytes);
+  }
+}
+
+// Three-row version of LoadAligned16x2U16().
+inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x,
+                               uint16x8_t dst[3]) {
+  for (int row = 0; row < 3; ++row) {
+    dst[row] = vld1q_u16(src[row] + x);
+  }
+}
+
+// Three-row version of LoadAligned16x2U16Msan().
+inline void LoadAligned16x3U16Msan(const uint16_t* const src[3],
+                                   const ptrdiff_t x, const ptrdiff_t border,
+                                   uint16x8_t dst[3]) {
+  const auto over_read_in_bytes = sizeof(**src) * (x + 8 - border);
+  for (int row = 0; row < 3; ++row) {
+    dst[row] = Load1QMsanU16(src[row] + x, over_read_in_bytes);
+  }
+}
+
+// Loads src[0 .. 7] into two 4-lane vectors.
+inline void LoadAligned32U32(const uint32_t* const src, uint32x4_t dst[2]) {
+  for (int half = 0; half < 2; ++half) {
+    dst[half] = vld1q_u32(src + 4 * half);
+  }
+}
+
+// Loads src[x .. x + 7] through Load1QMsanU32(). For the half that starts at
+// column c (c = x, then x + 4) the second argument is
+// sizeof(uint32_t) * (c + 4 - border).
+inline void LoadAligned32U32Msan(const uint32_t* const src, const ptrdiff_t x,
+                                 const ptrdiff_t border, uint32x4_t dst[2]) {
+  for (int half = 0; half < 2; ++half) {
+    const ptrdiff_t column = x + 4 * half;
+    dst[half] =
+        Load1QMsanU32(src + column, sizeof(*src) * (column + 4 - border));
+  }
+}
+
+// Loads columns x .. x + 7 of two uint32_t rows.
+inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x,
+                               uint32x4_t dst[2][2]) {
+  for (int row = 0; row < 2; ++row) {
+    LoadAligned32U32(src[row] + x, dst[row]);
+  }
+}
+
+// MSan-aware two-row load; see LoadAligned32U32Msan().
+inline void LoadAligned32x2U32Msan(const uint32_t* const src[2],
+                                   const ptrdiff_t x, const ptrdiff_t border,
+                                   uint32x4_t dst[2][2]) {
+  for (int row = 0; row < 2; ++row) {
+    LoadAligned32U32Msan(src[row], x, border, dst[row]);
+  }
+}
+
+// Loads columns x .. x + 7 of three uint32_t rows.
+inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x,
+                               uint32x4_t dst[3][2]) {
+  for (int row = 0; row < 3; ++row) {
+    LoadAligned32U32(src[row] + x, dst[row]);
+  }
+}
+
+// MSan-aware three-row load; see LoadAligned32U32Msan().
+inline void LoadAligned32x3U32Msan(const uint32_t* const src[3],
+                                   const ptrdiff_t x, const ptrdiff_t border,
+                                   uint32x4_t dst[3][2]) {
+  for (int row = 0; row < 3; ++row) {
+    LoadAligned32U32Msan(src[row], x, border, dst[row]);
+  }
+}
+
+// Writes the 16 lanes of src[0], src[1] contiguously to dst[0 .. 15].
+inline void StoreAligned32U16(uint16_t* const dst, const uint16x8_t src[2]) {
+  for (int half = 0; half < 2; ++half) {
+    vst1q_u16(dst + 8 * half, src[half]);
+  }
+}
+
+// Writes the 8 lanes of src[0], src[1] contiguously to dst[0 .. 7].
+inline void StoreAligned32U32(uint32_t* const dst, const uint32x4_t src[2]) {
+  for (int half = 0; half < 2; ++half) {
+    vst1q_u32(dst + 4 * half, src[half]);
+  }
+}
+
+// Writes the 16 lanes of src[0 .. 3] contiguously to dst[0 .. 15].
+inline void StoreAligned64U32(uint32_t* const dst, const uint32x4_t src[4]) {
+  for (int quarter = 0; quarter < 4; ++quarter) {
+    vst1q_u32(dst + 4 * quarter, src[quarter]);
+  }
+}
+
+// Returns src0 + (low 8 bytes of src1, zero-extended to 16 bits).
+inline uint16x8_t VaddwLo8(const uint16x8_t src0, const uint8x16_t src1) {
+  return vaddw_u8(src0, vget_low_u8(src1));
+}
+
+// Returns src0 + (high 8 bytes of src1, zero-extended to 16 bits).
+inline uint16x8_t VaddwHi8(const uint16x8_t src0, const uint8x16_t src1) {
+  return vaddw_u8(src0, vget_high_u8(src1));
+}
+
+// Widening product of lanes 0-3: src0[i] * src1[i] as uint32.
+inline uint32x4_t VmullLo16(const uint16x8_t src0, const uint16x8_t src1) {
+  const uint16x4_t a = vget_low_u16(src0);
+  const uint16x4_t b = vget_low_u16(src1);
+  return vmull_u16(a, b);
+}
+
+// Widening product of lanes 4-7: src0[i + 4] * src1[i + 4] as uint32.
+inline uint32x4_t VmullHi16(const uint16x8_t src0, const uint16x8_t src1) {
+  const uint16x4_t a = vget_high_u16(src0);
+  const uint16x4_t b = vget_high_u16(src1);
+  return vmull_u16(a, b);
+}
+
+// VshrU128<bytes>(pair) returns the vector that starts |bytes| bytes into the
+// concatenation pair[0]:pair[1] (a vext), i.e. a byte-granular right shift
+// across two registers. The uint16x8_t overloads shift by bytes / 2 lanes.
+// The array overloads come first so the x2-struct overloads can forward to
+// them.
+template <int bytes>
+inline uint8x8_t VshrU128(const uint8x8_t src[2]) {
+  return vext_u8(src[0], src[1], bytes);
+}
+
+template <int bytes>
+inline uint8x16_t VshrU128(const uint8x16_t src[2]) {
+  return vextq_u8(src[0], src[1], bytes);
+}
+
+template <int bytes>
+inline uint16x8_t VshrU128(const uint16x8_t src[2]) {
+  return vextq_u16(src[0], src[1], bytes / 2);
+}
+
+template <int bytes>
+inline uint8x8_t VshrU128(const uint8x8x2_t src) {
+  return VshrU128<bytes>(src.val);
+}
+
+template <int bytes>
+inline uint16x8_t VshrU128(const uint16x8x2_t src) {
+  return VshrU128<bytes>(src.val);
+}
+
+// Widening square of four lanes: s[i] * s[i] as uint32.
+inline uint32x4_t Square(uint16x4_t s) {
+  return vmull_u16(s, s);
+}
+
+// Widening square of all eight lanes of |src|. dst[0] receives lanes 0-3 and
+// dst[1] receives lanes 4-7.
+inline void Square(const uint16x8_t src, uint32x4_t dst[2]) {
+  dst[0] = Square(vget_low_u16(src));
+  dst[1] = Square(vget_high_u16(src));
+}
+
+// The PrepareN_* helpers build the N shifted windows an N-tap horizontal sum
+// needs: dst[k] starts k elements into the concatenation src[0]:src[1] (plus
+// |offset| bytes for the uint8x16_t variants).
+template <int offset>
+inline void Prepare3_8(const uint8x16_t src[2], uint8x16_t dst[3]) {
+  dst[0] = VshrU128<offset>(src);
+  dst[1] = VshrU128<offset + 1>(src);
+  dst[2] = VshrU128<offset + 2>(src);
+}
+
+inline void Prepare3_16(const uint16x8_t src[2], uint16x8_t dst[3]) {
+  dst[0] = src[0];
+  dst[1] = vextq_u16(src[0], src[1], 1);
+  dst[2] = vextq_u16(src[0], src[1], 2);
+}
+
+// Five windows = the three from Prepare3_8() plus two more.
+template <int offset>
+inline void Prepare5_8(const uint8x16_t src[2], uint8x16_t dst[5]) {
+  Prepare3_8<offset>(src, dst);
+  dst[3] = VshrU128<offset + 3>(src);
+  dst[4] = VshrU128<offset + 4>(src);
+}
+
+// Five windows = the three from Prepare3_16() plus two more.
+inline void Prepare5_16(const uint16x8_t src[2], uint16x8_t dst[5]) {
+  Prepare3_16(src, dst);
+  dst[3] = vextq_u16(src[0], src[1], 3);
+  dst[4] = vextq_u16(src[0], src[1], 4);
+}
+
+inline void Prepare3_32(const uint32x4_t src[2], uint32x4_t dst[3]) {
+  dst[0] = src[0];
+  dst[1] = vextq_u32(src[0], src[1], 1);
+  dst[2] = vextq_u32(src[0], src[1], 2);
+}
+
+// With 4-lane vectors the window shifted by 4 is exactly src[1].
+inline void Prepare5_32(const uint32x4_t src[2], uint32x4_t dst[5]) {
+  dst[0] = src[0];
+  dst[1] = vextq_u32(src[0], src[1], 1);
+  dst[2] = vextq_u32(src[0], src[1], 2);
+  dst[3] = vextq_u32(src[0], src[1], 3);
+  dst[4] = src[1];
+}
+
+// Lane-wise sum of the low 8 bytes of src[0], src[1], src[2], widened to
+// 16 bits.
+inline uint16x8_t Sum3WLo16(const uint8x16_t src[3]) {
+  const uint16x8_t pair =
+      vaddl_u8(vget_low_u8(src[0]), vget_low_u8(src[1]));
+  return vaddw_u8(pair, vget_low_u8(src[2]));
+}
+
+// Lane-wise sum of the high 8 bytes of src[0], src[1], src[2], widened to
+// 16 bits.
+inline uint16x8_t Sum3WHi16(const uint8x16_t src[3]) {
+  const uint16x8_t pair =
+      vaddl_u8(vget_high_u8(src[0]), vget_high_u8(src[1]));
+  return vaddw_u8(pair, vget_high_u8(src[2]));
+}
+
+inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1,
+                          const uint16x8_t src2) {
+  return vaddq_u16(vaddq_u16(src0, src1), src2);
+}
+
+inline uint16x8_t Sum3_16(const uint16x8_t src[3]) {
+  return Sum3_16(src[0], src[1], src[2]);
+}
+
+inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1,
+                          const uint32x4_t src2) {
+  return vaddq_u32(vaddq_u32(src0, src1), src2);
+}
+
+inline uint32x4_t Sum3_32(const uint32x4_t src[3]) {
+  return Sum3_32(src[0], src[1], src[2]);
+}
+
+// Sums three 8-lane (two-vector) values: dst[h] = src[0][h] + src[1][h] +
+// src[2][h].
+inline void Sum3_32(const uint32x4_t src[3][2], uint32x4_t dst[2]) {
+  for (int half = 0; half < 2; ++half) {
+    dst[half] = Sum3_32(src[0][half], src[1][half], src[2][half]);
+  }
+}
+
+inline uint16x8_t Sum5_16(const uint16x8_t src[5]) {
+  const uint16x8_t front = vaddq_u16(src[0], src[1]);
+  const uint16x8_t middle = vaddq_u16(src[2], src[3]);
+  return vaddq_u16(vaddq_u16(front, middle), src[4]);
+}
+
+inline uint32x4_t Sum5_32(const uint32x4_t* src0, const uint32x4_t* src1,
+                          const uint32x4_t* src2, const uint32x4_t* src3,
+                          const uint32x4_t* src4) {
+  const uint32x4_t front = vaddq_u32(*src0, *src1);
+  const uint32x4_t middle = vaddq_u32(*src2, *src3);
+  return vaddq_u32(vaddq_u32(front, middle), *src4);
+}
+
+inline uint32x4_t Sum5_32(const uint32x4_t src[5]) {
+  return Sum5_32(src + 0, src + 1, src + 2, src + 3, src + 4);
+}
+
+// Sums five 8-lane (two-vector) values: dst[h] = src[0][h] + ... + src[4][h].
+inline void Sum5_32(const uint32x4_t src[5][2], uint32x4_t dst[2]) {
+  for (int half = 0; half < 2; ++half) {
+    dst[half] = Sum5_32(&src[0][half], &src[1][half], &src[2][half],
+                        &src[3][half], &src[4][half]);
+  }
+}
+
+// 3-tap horizontal box sum of 8 lanes: out[i] = p[i] + p[i + 1] + p[i + 2],
+// where p is the concatenation src[0]:src[1].
+inline uint16x8_t Sum3Horizontal16(const uint16x8_t src[2]) {
+  uint16x8_t taps[3];
+  Prepare3_16(src, taps);
+  return Sum3_16(taps[0], taps[1], taps[2]);
+}
+
+// 32-bit 3-tap horizontal box sums for 8 output lanes: dst[0] covers the
+// windows starting in src[0], dst[1] those starting in src[1].
+inline void Sum3Horizontal32(const uint32x4_t src[3], uint32x4_t dst[2]) {
+  for (int half = 0; half < 2; ++half) {
+    uint32x4_t taps[3];
+    Prepare3_32(src + half, taps);
+    dst[half] = Sum3_32(taps);
+  }
+}
+
+// 5-tap horizontal box sum of 8 lanes: out[i] = p[i] + ... + p[i + 4].
+inline uint16x8_t Sum5Horizontal16(const uint16x8_t src[2]) {
+  uint16x8_t taps[5];
+  Prepare5_16(src, taps);
+  return Sum5_16(taps);
+}
+
+// 32-bit 5-tap horizontal box sums for 8 output lanes.
+inline void Sum5Horizontal32(const uint32x4_t src[3], uint32x4_t dst[2]) {
+  for (int half = 0; half < 2; ++half) {
+    uint32x4_t taps[5];
+    Prepare5_32(src + half, taps);
+    dst[half] = Sum5_32(taps);
+  }
+}
+
+// Computes both box sums of 8 lanes at once, sharing the inner taps:
+//   *row3 = p[i + 1] + p[i + 2] + p[i + 3]
+//   *row5 = p[i] + p[i + 4] + *row3
+void SumHorizontal16(const uint16x8_t src[2], uint16x8_t* const row3,
+                     uint16x8_t* const row5) {
+  uint16x8_t taps[5];
+  Prepare5_16(src, taps);
+  const uint16x8_t outer = vaddq_u16(taps[0], taps[4]);
+  *row3 = Sum3_16(taps[1], taps[2], taps[3]);
+  *row5 = vaddq_u16(outer, *row3);
+}
+
+// 16-lane version: the *_0 outputs use src[0]:src[1], the *_1 outputs use
+// src[1]:src[2].
+inline void SumHorizontal16(const uint16x8_t src[3], uint16x8_t* const row3_0,
+                            uint16x8_t* const row3_1, uint16x8_t* const row5_0,
+                            uint16x8_t* const row5_1) {
+  SumHorizontal16(&src[0], row3_0, row5_0);
+  SumHorizontal16(&src[1], row3_1, row5_1);
+}
+
+// Given the five shifted windows in |src|, computes
+//   *row_sq3 = src[1] + src[2] + src[3]
+//   *row_sq5 = src[0] + src[4] + *row_sq3
+void SumHorizontal32(const uint32x4_t src[5], uint32x4_t* const row_sq3,
+                     uint32x4_t* const row_sq5) {
+  const uint32x4_t outer = vaddq_u32(src[0], src[4]);
+  *row_sq3 = Sum3_32(src[1], src[2], src[3]);
+  *row_sq5 = vaddq_u32(outer, *row_sq3);
+}
+
+// 8-lane version over three 4-lane inputs: the *_0 outputs use the windows
+// starting in src[0], the *_1 outputs those starting in src[1].
+inline void SumHorizontal32(const uint32x4_t src[3],
+                            uint32x4_t* const row_sq3_0,
+                            uint32x4_t* const row_sq3_1,
+                            uint32x4_t* const row_sq5_0,
+                            uint32x4_t* const row_sq5_1) {
+  uint32x4_t* const sq3_out[2] = {row_sq3_0, row_sq3_1};
+  uint32x4_t* const sq5_out[2] = {row_sq5_0, row_sq5_1};
+  for (int half = 0; half < 2; ++half) {
+    uint32x4_t taps[5];
+    Prepare5_32(src + half, taps);
+    SumHorizontal32(taps, sq3_out[half], sq5_out[half]);
+  }
+}
+
+// "343" weighting of the low 8 lanes of |ma3|:
+// 3 * (ma3[0] + ma3[1] + ma3[2]) + ma3[1], widened to 16 bits.
+inline uint16x8_t Sum343Lo(const uint8x16_t ma3[3]) {
+  const uint16x8_t sum111 = Sum3WLo16(ma3);
+  const uint16x8_t sum333 = Sum3_16(sum111, sum111, sum111);
+  return VaddwLo8(sum333, ma3[1]);
+}
+
+// "343" weighting of the high 8 lanes of |ma3|.
+inline uint16x8_t Sum343Hi(const uint8x16_t ma3[3]) {
+  const uint16x8_t sum111 = Sum3WHi16(ma3);
+  const uint16x8_t sum333 = Sum3_16(sum111, sum111, sum111);
+  return VaddwHi8(sum333, ma3[1]);
+}
+
+// 32-bit "343" weighting: 3 * (src[0] + src[1] + src[2]) + src[1].
+inline uint32x4_t Sum343(const uint32x4_t src[3]) {
+  const uint32x4_t sum111 = Sum3_32(src);
+  const uint32x4_t sum333 = Sum3_32(sum111, sum111, sum111);
+  return vaddq_u32(sum333, src[1]);
+}
+
+// Horizontal "343" weighting for 8 output lanes over three 4-lane inputs.
+inline void Sum343(const uint32x4_t src[3], uint32x4_t dst[2]) {
+  for (int half = 0; half < 2; ++half) {
+    uint32x4_t taps[3];
+    Prepare3_32(src + half, taps);
+    dst[half] = Sum343(taps);
+  }
+}
+
+// "565" weighting of the low 8 lanes of |src|:
+// 5 * (src[0] + src[1] + src[2]) + src[1], widened to 16 bits.
+inline uint16x8_t Sum565Lo(const uint8x16_t src[3]) {
+  const uint16x8_t sum111 = Sum3WLo16(src);
+  const uint16x8_t sum555 = vaddq_u16(vshlq_n_u16(sum111, 2), sum111);
+  return VaddwLo8(sum555, src[1]);
+}
+
+// "565" weighting of the high 8 lanes of |src|.
+inline uint16x8_t Sum565Hi(const uint8x16_t src[3]) {
+  const uint16x8_t sum111 = Sum3WHi16(src);
+  const uint16x8_t sum555 = vaddq_u16(vshlq_n_u16(sum111, 2), sum111);
+  return VaddwHi8(sum555, src[1]);
+}
+
+// 32-bit "565" weighting: 5 * (src[0] + src[1] + src[2]) + src[1].
+inline uint32x4_t Sum565(const uint32x4_t src[3]) {
+  const uint32x4_t sum111 = Sum3_32(src);
+  const uint32x4_t sum555 = vaddq_u32(vshlq_n_u32(sum111, 2), sum111);
+  return vaddq_u32(sum555, src[1]);
+}
+
+// Horizontal "565" weighting for 8 output lanes over three 4-lane inputs.
+inline void Sum565(const uint32x4_t src[3], uint32x4_t dst[2]) {
+  for (int half = 0; half < 2; ++half) {
+    uint32x4_t taps[3];
+    Prepare3_32(src + half, taps);
+    dst[half] = Sum565(taps);
+  }
+}
+
+// Computes, for two consecutive rows of |src|, the horizontal 3-tap and 5-tap
+// box sums and the matching box sums of squared pixels. With p = the current
+// row of |src|, for every output column i in [0, sum_width):
+//   sum5[i] = p[i] + p[i + 1] + p[i + 2] + p[i + 3] + p[i + 4]
+//   sum3[i] = p[i + 1] + p[i + 2] + p[i + 3]
+//   square_sum5[i] / square_sum3[i]: the same sums taken over p[k] * p[k].
+// Input rows are |src_stride| pixels apart; output rows are |sum_stride|
+// entries apart. |width| is only used for the MSan overread bookkeeping.
+// NOTE(review): the inner loop steps by 16 columns and only stops when x
+// reaches exactly 0, so |sum_width| must be a positive multiple of 16.
+inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride,
+                   const ptrdiff_t width, const ptrdiff_t sum_stride,
+                   const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5,
+                   uint32_t* square_sum3, uint32_t* square_sum5) {
+  // Base of the second Load1QMsanU16() argument; each load below adds the
+  // byte offset of the loaded vector from the start of the row.
+  const ptrdiff_t overread_in_bytes =
+      kOverreadInBytesPass1 - sizeof(*src) * width;
+  int y = 2;  // Two rows per call.
+  do {
+    // s[0..2] hold 24 consecutive pixels; sq[2k] and sq[2k + 1] hold the
+    // squares of lanes 0-3 and 4-7 of s[k].
+    uint16x8_t s[3];
+    uint32x4_t sq[6];
+    s[0] = Load1QMsanU16(src, overread_in_bytes);
+    Square(s[0], sq);
+    ptrdiff_t x = sum_width;
+    do {
+      uint16x8_t row3[2], row5[2];
+      uint32x4_t row_sq3[2], row_sq5[2];
+      s[1] = Load1QMsanU16(
+          src + 8, overread_in_bytes + sizeof(*src) * (sum_width - x + 8));
+      x -= 16;
+      src += 16;
+      s[2] = Load1QMsanU16(src,
+                           overread_in_bytes + sizeof(*src) * (sum_width - x));
+      Square(s[1], sq + 2);
+      Square(s[2], sq + 4);
+      // 16 output columns: lanes 0-7 use s[0]:s[1], lanes 8-15 use
+      // s[1]:s[2].
+      SumHorizontal16(s, &row3[0], &row3[1], &row5[0], &row5[1]);
+      StoreAligned32U16(sum3, row3);
+      StoreAligned32U16(sum5, row5);
+      // Squared sums for output columns 0-7, then 8-15.
+      SumHorizontal32(sq + 0, &row_sq3[0], &row_sq3[1], &row_sq5[0],
+                      &row_sq5[1]);
+      StoreAligned32U32(square_sum3 + 0, row_sq3);
+      StoreAligned32U32(square_sum5 + 0, row_sq5);
+      SumHorizontal32(sq + 2, &row_sq3[0], &row_sq3[1], &row_sq5[0],
+                      &row_sq5[1]);
+      StoreAligned32U32(square_sum3 + 8, row_sq3);
+      StoreAligned32U32(square_sum5 + 8, row_sq5);
+      // The last loaded vector (and its squares) starts the next iteration.
+      s[0] = s[2];
+      sq[0] = sq[4];
+      sq[1] = sq[5];
+      sum3 += 16;
+      sum5 += 16;
+      square_sum3 += 16;
+      square_sum5 += 16;
+    } while (x != 0);
+    // Every pointer advanced |sum_width| entries in this row; move each one
+    // to the start of its next row.
+    src += src_stride - sum_width;
+    sum3 += sum_stride - sum_width;
+    sum5 += sum_stride - sum_width;
+    square_sum3 += sum_stride - sum_width;
+    square_sum5 += sum_stride - sum_width;
+  } while (--y != 0);
+}
+
+// Computes, for two consecutive rows of |src|, either the horizontal 3-tap
+// (size == 3) or the 5-tap (size == 5) box sums, plus the matching box sums of
+// squared pixels. With p = the current row, for each column i in
+// [0, sum_width):
+//   size == 3: sums[i] = p[i] + p[i + 1] + p[i + 2]
+//   size == 5: sums[i] = p[i] + p[i + 1] + ... + p[i + 4]
+//   square_sums[i]: the same sum taken over p[k] * p[k].
+// The MSan overread base uses kOverreadInBytesPass1 for size == 5 and
+// kOverreadInBytesPass2 for size == 3.
+// NOTE(review): the inner loop steps by 16 columns and only stops when x
+// reaches exactly 0, so |sum_width| must be a positive multiple of 16.
+template <int size>
+inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride,
+                   const ptrdiff_t width, const ptrdiff_t sum_stride,
+                   const ptrdiff_t sum_width, uint16_t* sums,
+                   uint32_t* square_sums) {
+  static_assert(size == 3 || size == 5, "");
+  // Base of the second Load1QMsanU16() argument; each load adds the byte
+  // offset of the loaded vector from the start of the row.
+  const ptrdiff_t overread_in_bytes =
+      ((size == 5) ? kOverreadInBytesPass1 : kOverreadInBytesPass2) -
+      sizeof(*src) * width;
+  int y = 2;  // Two rows per call.
+  do {
+    // s[0..2] hold 24 consecutive pixels; sq[2k] and sq[2k + 1] hold the
+    // squares of lanes 0-3 and 4-7 of s[k].
+    uint16x8_t s[3];
+    uint32x4_t sq[6];
+    s[0] = Load1QMsanU16(src, overread_in_bytes);
+    Square(s[0], sq);
+    ptrdiff_t x = sum_width;
+    do {
+      uint16x8_t row[2];
+      uint32x4_t row_sq[4];
+      s[1] = Load1QMsanU16(
+          src + 8, overread_in_bytes + sizeof(*src) * (sum_width - x + 8));
+      x -= 16;
+      src += 16;
+      s[2] = Load1QMsanU16(src,
+                           overread_in_bytes + sizeof(*src) * (sum_width - x));
+      Square(s[1], sq + 2);
+      Square(s[2], sq + 4);
+      // |size| is a template constant, so only one branch survives.
+      // row[0] / row_sq[0..1] are output columns 0-7; row[1] / row_sq[2..3]
+      // are columns 8-15.
+      if (size == 3) {
+        row[0] = Sum3Horizontal16(s + 0);
+        row[1] = Sum3Horizontal16(s + 1);
+        Sum3Horizontal32(sq + 0, row_sq + 0);
+        Sum3Horizontal32(sq + 2, row_sq + 2);
+      } else {
+        row[0] = Sum5Horizontal16(s + 0);
+        row[1] = Sum5Horizontal16(s + 1);
+        Sum5Horizontal32(sq + 0, row_sq + 0);
+        Sum5Horizontal32(sq + 2, row_sq + 2);
+      }
+      StoreAligned32U16(sums, row);
+      StoreAligned64U32(square_sums, row_sq);
+      // The last loaded vector (and its squares) starts the next iteration.
+      s[0] = s[2];
+      sq[0] = sq[4];
+      sq[1] = sq[5];
+      sums += 16;
+      square_sums += 16;
+    } while (x != 0);
+    // Every pointer advanced |sum_width| entries in this row; move each one
+    // to the start of its next row.
+    src += src_stride - sum_width;
+    sums += sum_stride - sum_width;
+    square_sums += sum_stride - sum_width;
+  } while (--y != 0);
+}
+
+// Computes four SGR lookup indices from a box sum |sum| and a box sum of
+// squares |sum_sq| over n pixels:
+//   p = max(sum_sq * n - sum * sum, 0)
+//   z = (p * scale) shifted right by kSgrProjScaleBits with rounding
+// The result is narrowed to 16 bits with vmovn_u32() (truncating).
+template <int n>
+inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq,
+                              const uint32_t scale) {
+  const uint32x4_t sum_squared = vmull_u16(sum, sum);
+  const uint32x4_t n_sum_sq = vmulq_n_u32(sum_sq, n);
+  // Saturating unsigned subtraction clamps at 0, i.e. the max(..., 0) above.
+  const uint32x4_t p = vqsubq_u32(n_sum_sq, sum_squared);
+  const uint32x4_t scaled_p = vmulq_n_u32(p, scale);
+  // vrshrn_n_u32() (narrowing shift) can only shift by up to 16 and
+  // kSgrProjScaleBits is 20, so shift at 32 bits first and narrow after.
+  return vmovn_u32(vrshrq_n_u32(scaled_p, kSgrProjScaleBits));
+}
+
+// Eight-lane version. Before the index is computed, |sum| is scaled down by
+// 2 bits and |sum_sq| by 4 bits with rounding. This presumably maps the
+// higher-bitdepth statistics back to the 8-bit range the index formula
+// expects; confirm against the spec.
+template <int n>
+inline uint16x8_t CalculateMa(const uint16x8_t sum, const uint32x4_t sum_sq[2],
+                              const uint32_t scale) {
+  static_assert(n == 9 || n == 25, "");
+  const uint16x8_t sum_scaled = vrshrq_n_u16(sum, 2);
+  const uint32x4_t sum_sq_lo = vrshrq_n_u32(sum_sq[0], 4);
+  const uint32x4_t sum_sq_hi = vrshrq_n_u32(sum_sq[1], 4);
+  return vcombine_u16(
+      CalculateMa<n>(vget_low_u16(sum_scaled), sum_sq_lo, scale),
+      CalculateMa<n>(vget_high_u16(sum_scaled), sum_sq_hi, scale));
+}
+
+// b = (ma * sum * one_over_n) shifted right by kSgrProjReciprocalBits with
+// rounding, for n == 25 (one_over_n == 164).
+// Since 164 == 41 << 2, this multiplies by 41 and shifts two bits less, which
+// gives exactly the same rounded result.
+inline void CalculateB5(const uint16x8_t sum, const uint16x8_t ma,
+                        uint32x4_t b[2]) {
+  constexpr uint32_t one_over_n =
+      ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25;
+  constexpr uint32_t one_over_n_quarter = one_over_n >> 2;  // 41
+  static_assert(one_over_n == one_over_n_quarter << 2, "");
+  // |ma| is in range [0, 255].
+  const uint32x4_t ma_x_sum_lo = VmullLo16(ma, sum);
+  const uint32x4_t ma_x_sum_hi = VmullHi16(ma, sum);
+  b[0] = vrshrq_n_u32(vmulq_n_u32(ma_x_sum_lo, one_over_n_quarter),
+                      kSgrProjReciprocalBits - 2);
+  b[1] = vrshrq_n_u32(vmulq_n_u32(ma_x_sum_hi, one_over_n_quarter),
+                      kSgrProjReciprocalBits - 2);
+}
+
+// b = (ma * sum * one_over_n) shifted right by kSgrProjReciprocalBits with
+// rounding, for n == 9 (one_over_n == 455).
+inline void CalculateB3(const uint16x8_t sum, const uint16x8_t ma,
+                        uint32x4_t b[2]) {
+  constexpr uint32_t one_over_n =
+      ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9;
+  const uint32x4_t ma_x_sum_lo = VmullLo16(ma, sum);
+  const uint32x4_t ma_x_sum_hi = VmullHi16(ma, sum);
+  b[0] = vrshrq_n_u32(vmulq_n_u32(ma_x_sum_lo, one_over_n),
+                      kSgrProjReciprocalBits);
+  b[1] = vrshrq_n_u32(vmulq_n_u32(ma_x_sum_hi, one_over_n),
+                      kSgrProjReciprocalBits);
+}
+
+// Adds the three |s3| vectors and the three |sq3| pairs lane-wise, returns the
+// total in *sum and converts it to lookup indices with CalculateMa<9>().
+inline void CalculateSumAndIndex3(const uint16x8_t s3[3],
+                                  const uint32x4_t sq3[3][2],
+                                  const uint32_t scale, uint16x8_t* const sum,
+                                  uint16x8_t* const index) {
+  *sum = Sum3_16(s3[0], s3[1], s3[2]);
+  uint32x4_t total_sq[2];
+  Sum3_32(sq3, total_sq);
+  *index = CalculateMa<9>(*sum, total_sq, scale);
+}
+
+// Five-input counterpart of CalculateSumAndIndex3(), using CalculateMa<25>().
+inline void CalculateSumAndIndex5(const uint16x8_t s5[5],
+                                  const uint32x4_t sq5[5][2],
+                                  const uint32_t scale, uint16x8_t* const sum,
+                                  uint16x8_t* const index) {
+  *sum = Sum5_16(s5);
+  uint32x4_t total_sq[2];
+  Sum5_32(sq5, total_sq);
+  *index = CalculateMa<25>(*sum, total_sq, scale);
+}
+
+// Saturates the 8 lanes of |index| to 8 bits, looks each one up in
+// kSgrMaLookup and writes the results to lanes [offset, offset + 7] of *ma.
+// The other 8 lanes of *ma are left unchanged. Then computes the matching |b|
+// values with CalculateB3() (n == 9) or CalculateB5() (n == 25).
+template <int n, int offset>
+inline void LookupIntermediate(const uint16x8_t sum, const uint16x8_t index,
+                               uint8x16_t* const ma, uint32x4_t b[2]) {
+  static_assert(n == 9 || n == 25, "");
+  static_assert(offset == 0 || offset == 8, "");
+
+  // vsetq_lane_u8() needs a compile-time lane, so the indices are spilled to
+  // memory and looked up one at a time.
+  const uint8x8_t idx = vqmovn_u16(index);
+  uint8_t temp[8];
+  vst1_u8(temp, idx);
+  *ma = vsetq_lane_u8(kSgrMaLookup[temp[0]], *ma, offset + 0);
+  *ma = vsetq_lane_u8(kSgrMaLookup[temp[1]], *ma, offset + 1);
+  *ma = vsetq_lane_u8(kSgrMaLookup[temp[2]], *ma, offset + 2);
+  *ma = vsetq_lane_u8(kSgrMaLookup[temp[3]], *ma, offset + 3);
+  *ma = vsetq_lane_u8(kSgrMaLookup[temp[4]], *ma, offset + 4);
+  *ma = vsetq_lane_u8(kSgrMaLookup[temp[5]], *ma, offset + 5);
+  *ma = vsetq_lane_u8(kSgrMaLookup[temp[6]], *ma, offset + 6);
+  *ma = vsetq_lane_u8(kSgrMaLookup[temp[7]], *ma, offset + 7);
+  // b = ma * sum * one_over_n, rounded and shifted by kSgrProjReciprocalBits.
+  // |ma| = [0, 255]
+  // |sum| is a box sum with radius 1 or 2 of uint16_t pixels.
+  // NOTE(review): the bounds below assume 10-bit input (max pixel 1023);
+  // confirm for other bitdepths.
+  // For the first pass radius is 2. Maximum value is 5x5x1023 = 25575.
+  // For the second pass radius is 1. Maximum value is 3x3x1023 = 9207.
+  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
+  // When radius is 2 |n| is 25. |one_over_n| is 164.
+  // When radius is 1 |n| is 9. |one_over_n| is 455.
+  // |kSgrProjReciprocalBits| is 12.
+  // Radius 2: 255 * 25575 * 164 >> 12 = 261119 (more than 16 bits).
+  // Radius 1: 255 * 9207 * 455 >> 12 = 260801 (more than 16 bits).
+  // Both unshifted products are below 2^32, so |b| is kept in 32-bit lanes.
+  const uint16x8_t maq =
+      vmovl_u8((offset == 0) ? vget_low_u8(*ma) : vget_high_u8(*ma));
+  if (n == 9) {
+    CalculateB3(sum, maq, b);
+  } else {
+    CalculateB5(sum, maq, b);
+  }
+}
+
+// Decrements (modulo 256) every lane of |value| whose |index| lane is greater
+// than |threshold|.
+inline uint8x8_t AdjustValue(const uint8x8_t value, const uint8x8_t index,
+                             const int threshold) {
+  // vcgt_u8() yields 0xFF where the comparison holds and 0 elsewhere. Adding
+  // 0xFF to an 8-bit lane is the same as subtracting 1.
+  const uint8x8_t above_threshold = vcgt_u8(index, vdup_n_u8(threshold));
+  return vadd_u8(value, above_threshold);
+}
+
+// Maps the 8 lanes of |index| (saturated to 8 bits) to |ma| values.
+// Indices 0-31 are looked up in |table0| and indices 32-47 in |table1|.
+// Larger indices are derived from the threshold comparisons below, which,
+// per their comments, reproduce the tail of the lookup table.
+inline uint8x8_t MaLookupAndAdjust(const uint8x8x4_t table0,
+                                   const uint8x8x2_t table1,
+                                   const uint16x8_t index) {
+  const uint8x8_t idx = vqmovn_u16(index);
+  // All elements whose indices are out of range [0, 47] are set to 0.
+  uint8x8_t val = vtbl4_u8(table0, idx);  // Range [0, 31].
+  // Subtract 32 so that indices [32, 47] select entries [0, 15] of |table1|.
+  // Indices below 32 wrap around to values >= 224 and therefore look up 0.
+  const uint8x8_t sub_idx = vsub_u8(idx, vdup_n_u8(32));
+  const uint8x8_t res = vtbl2_u8(table1, sub_idx);  // Range [32, 47].
+  // Use OR instruction to combine shuffle results together.
+  val = vorr_u8(val, res);
+
+  // For elements whose indices are larger than 47, since they seldom change
+  // values with the increase of the index, we use comparison and arithmetic
+  // operations to calculate their values.
+  // Elements whose indices are larger than 47 (with value 0) are set to 5.
+  val = vmax_u8(val, vdup_n_u8(5));
+  val = AdjustValue(val, idx, 55);   // 55 is the last index which value is 5.
+  val = AdjustValue(val, idx, 72);   // 72 is the last index which value is 4.
+  val = AdjustValue(val, idx, 101);  // 101 is the last index which value is 3.
+  val = AdjustValue(val, idx, 169);  // 169 is the last index which value is 2.
+  val = AdjustValue(val, idx, 254);  // 254 is the last index which value is 1.
+  return val;
+}
+
+// Computes 16 |ma| values from the two sets of 8 indices in |index| using
+// MaLookupAndAdjust(). It then derives the matching |b| values for the
+// radius-1 (n == 9) case: b0 from sum[0] and the low half of *ma, and b1 from
+// sum[1] and the high half.
+inline void CalculateIntermediate(const uint16x8_t sum[2],
+                                  const uint16x8_t index[2],
+                                  uint8x16_t* const ma, uint32x4_t b0[2],
+                                  uint32x4_t b1[2]) {
+  // Use table lookup to read elements whose indices are less than 48.
+  // Using one uint8x8x4_t vector and one uint8x8x2_t vector is faster than
+  // using two uint8x8x3_t vectors.
+  uint8x8x4_t table0;
+  uint8x8x2_t table1;
+  table0.val[0] = vld1_u8(kSgrMaLookup + 0 * 8);
+  table0.val[1] = vld1_u8(kSgrMaLookup + 1 * 8);
+  table0.val[2] = vld1_u8(kSgrMaLookup + 2 * 8);
+  table0.val[3] = vld1_u8(kSgrMaLookup + 3 * 8);
+  table1.val[0] = vld1_u8(kSgrMaLookup + 4 * 8);
+  table1.val[1] = vld1_u8(kSgrMaLookup + 5 * 8);
+  const uint8x8_t ma_lo = MaLookupAndAdjust(table0, table1, index[0]);
+  const uint8x8_t ma_hi = MaLookupAndAdjust(table0, table1, index[1]);
+  *ma = vcombine_u8(ma_lo, ma_hi);
+  // b = ma * sum * one_over_n, rounded and shifted by kSgrProjReciprocalBits
+  // (see CalculateB3(); only the radius-1, n == 9 case is handled here).
+  // |ma| = [0, 255]
+  // |sum| is a 3x3 box sum of uint16_t pixels. NOTE(review): assuming 10-bit
+  // input (max pixel 1023), its maximum value is 3x3x1023 = 9207.
+  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9 = 455 with
+  // |kSgrProjReciprocalBits| == 12.
+  // 255 * 9207 * 455 = 1068242175 < 2^32, and after >> 12 the result is
+  // 260801, which needs more than 16 bits, so |b| is kept in 32-bit lanes.
+  const uint16x8_t maq0 = vmovl_u8(vget_low_u8(*ma));
+  CalculateB3(sum[0], maq0, b0);
+  const uint16x8_t maq1 = vmovl_u8(vget_high_u8(*ma));
+  CalculateB3(sum[1], maq1, b1);
+}
+
+// Computes 16 new |ma| values and their |b| terms (b[0..1] for the first 8,
+// b[2..3] for the last 8), then merges them into the |ma| pair:
+//  - ma[0] keeps its own low 8 lanes and receives the first 8 new values in
+//    its high 8 lanes;
+//  - ma[1] receives the last 8 new values in its low 8 lanes and zeros above.
+inline void CalculateIntermediate(const uint16x8_t sum[2],
+                                  const uint16x8_t index[2], uint8x16_t ma[2],
+                                  uint32x4_t b[4]) {
+  uint8x16_t fresh;
+  CalculateIntermediate(sum, index, &fresh, b, b + 2);
+  ma[0] = vcombine_u8(vget_low_u8(ma[0]), vget_low_u8(fresh));
+  ma[1] = vcombine_u8(vget_high_u8(fresh), vdup_n_u8(0));
+}
+
+// 5x5 statistics: sums |s5| / |sq5|, computes the lookup indices, and writes
+// |ma| into lanes [offset, offset + 7] of *ma together with the matching |b|.
+template <int offset>
+inline void CalculateIntermediate5(const uint16x8_t s5[5],
+                                   const uint32x4_t sq5[5][2],
+                                   const uint32_t scale, uint8x16_t* const ma,
+                                   uint32x4_t b[2]) {
+  static_assert(offset == 0 || offset == 8, "");
+  uint16x8_t box_sum;
+  uint16x8_t lookup_index;
+  CalculateSumAndIndex5(s5, sq5, scale, &box_sum, &lookup_index);
+  LookupIntermediate<25, offset>(box_sum, lookup_index, ma, b);
+}
+
+// 3x3 statistics: same as CalculateIntermediate5(), always writing lanes 0-7
+// of *ma.
+inline void CalculateIntermediate3(const uint16x8_t s3[3],
+                                   const uint32x4_t sq3[3][2],
+                                   const uint32_t scale, uint8x16_t* const ma,
+                                   uint32x4_t b[2]) {
+  uint16x8_t box_sum;
+  uint16x8_t lookup_index;
+  CalculateSumAndIndex3(s3, sq3, scale, &box_sum, &lookup_index);
+  LookupIntermediate<9, 0>(box_sum, lookup_index, ma, b);
+}
+
+// For the 3-tap windows of |b3| (from Prepare3_32() over b3[0..1], then
+// b3[1..2]) computes
+//   sum_b444 = 4 * (w0 + w1 + w2)
+//   sum_b343 = 3 * (w0 + w1 + w2) + w1
+// and stores both (8 lanes each) at column |x| of |b444| and |b343|.
+inline void Store343_444(const uint32x4_t b3[3], const ptrdiff_t x,
+                         uint32x4_t sum_b343[2], uint32x4_t sum_b444[2],
+                         uint32_t* const b343, uint32_t* const b444) {
+  for (int half = 0; half < 2; ++half) {
+    uint32x4_t taps[3];
+    Prepare3_32(b3 + half, taps);
+    const uint32x4_t sum111 = Sum3_32(taps);
+    sum_b444[half] = vshlq_n_u32(sum111, 2);
+    sum_b343[half] =
+        vaddq_u32(vsubq_u32(sum_b444[half], sum111), taps[1]);
+  }
+  StoreAligned32U32(b444 + x, sum_b444);
+  StoreAligned32U32(b343 + x, sum_b343);
+}
+
+// "343" / "444" weighted sums of the low 8 lanes of |ma3| (zero-extended to
+// 16 bits):
+//   sum_ma444 = 4 * (ma3[0] + ma3[1] + ma3[2])
+//   sum_ma343 = 3 * (ma3[0] + ma3[1] + ma3[2]) + ma3[1]
+// The same weights are applied to the 3-tap windows of |b3| by Store343_444().
+// Every result is stored at column |x| of ma444/ma343/b444/b343 and also
+// returned through the sum_* outputs.
+// |b3| must hold three vectors: Store343_444() reads b3[0], b3[1] and b3[2].
+inline void Store343_444Lo(const uint8x16_t ma3[3], const uint32x4_t b3[3],
+                           const ptrdiff_t x, uint16x8_t* const sum_ma343,
+                           uint16x8_t* const sum_ma444, uint32x4_t sum_b343[2],
+                           uint32x4_t sum_b444[2], uint16_t* const ma343,
+                           uint16_t* const ma444, uint32_t* const b343,
+                           uint32_t* const b444) {
+  const uint16x8_t sum_ma111 = Sum3WLo16(ma3);
+  *sum_ma444 = vshlq_n_u16(sum_ma111, 2);
+  vst1q_u16(ma444 + x, *sum_ma444);
+  // 4 * sum - sum == 3 * sum.
+  const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
+  *sum_ma343 = VaddwLo8(sum333, ma3[1]);
+  vst1q_u16(ma343 + x, *sum_ma343);
+  Store343_444(b3, x, sum_b343, sum_b444, b343, b444);
+}
+
+// Same as Store343_444Lo(), but uses the high 8 lanes of |ma3|.
+// |b3| is declared with three elements, matching Store343_444Lo(), because
+// Store343_444() reads b3[0], b3[1] and b3[2]. The parameter type is
+// unchanged either way.
+inline void Store343_444Hi(const uint8x16_t ma3[3], const uint32x4_t b3[3],
+                           const ptrdiff_t x, uint16x8_t* const sum_ma343,
+                           uint16x8_t* const sum_ma444, uint32x4_t sum_b343[2],
+                           uint32x4_t sum_b444[2], uint16_t* const ma343,
+                           uint16_t* const ma444, uint32_t* const b343,
+                           uint32_t* const b444) {
+  const uint16x8_t sum_ma111 = Sum3WHi16(ma3);
+  *sum_ma444 = vshlq_n_u16(sum_ma111, 2);
+  vst1q_u16(ma444 + x, *sum_ma444);
+  // 4 * sum - sum == 3 * sum.
+  const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
+  *sum_ma343 = VaddwHi8(sum333, ma3[1]);
+  vst1q_u16(ma343 + x, *sum_ma343);
+  Store343_444(b3, x, sum_b343, sum_b444, b343, b444);
+}
+
+inline void Store343_444Lo(const uint8x16_t ma3[3], const uint32x4_t b3[2],
+                           const ptrdiff_t x, uint16x8_t* const sum_ma343,
+                           uint32x4_t sum_b343[2], uint16_t* const ma343,
+                           uint16_t* const ma444, uint32_t* const b343,
+                           uint32_t* const b444) {
+  uint16x8_t sum_ma444;
+  uint32x4_t sum_b444[2];
+  Store343_444Lo(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343,
+                 ma444, b343, b444);
+}
+
+// Overload that forwards to the full Store343_444Hi and throws away its 444
+// register outputs. The ma444/b444 buffers are still passed through.
+inline void Store343_444Hi(const uint8x16_t ma3[3], const uint32x4_t b3[2],
+                           const ptrdiff_t x, uint16x8_t* const sum_ma343,
+                           uint32x4_t sum_b343[2], uint16_t* const ma343,
+                           uint16_t* const ma444, uint32_t* const b343,
+                           uint32_t* const b444) {
+  uint16x8_t discarded_ma444;
+  uint32x4_t discarded_b444[2];
+  Store343_444Hi(ma3, b3, x, sum_ma343, &discarded_ma444, sum_b343,
+                 discarded_b444, ma343, ma444, b343, b444);
+}
+
+// Memory-only overload: writes the low-half 343/444 sums to the buffers and
+// drops the 343 register outputs.
+inline void Store343_444Lo(const uint8x16_t ma3[3], const uint32x4_t b3[2],
+                           const ptrdiff_t x, uint16_t* const ma343,
+                           uint16_t* const ma444, uint32_t* const b343,
+                           uint32_t* const b444) {
+  uint16x8_t discarded_ma343;
+  uint32x4_t discarded_b343[2];
+  Store343_444Lo(ma3, b3, x, &discarded_ma343, discarded_b343, ma343, ma444,
+                 b343, b444);
+}
+
+// Memory-only overload: writes the high-half 343/444 sums to the buffers and
+// drops the 343 register outputs.
+inline void Store343_444Hi(const uint8x16_t ma3[3], const uint32x4_t b3[2],
+                           const ptrdiff_t x, uint16_t* const ma343,
+                           uint16_t* const ma444, uint32_t* const b343,
+                           uint32_t* const b444) {
+  uint16x8_t discarded_ma343;
+  uint32x4_t discarded_b343[2];
+  Store343_444Hi(ma3, b3, x, &discarded_ma343, discarded_b343, ma343, ma444,
+                 b343, b444);
+}
+
+// Pass-1 (5-tap) preprocessing for the first 8 columns of a row pair.
+// s[r][0..1] are the first two 8-lane source vectors of rows r = 0 and 1.
+// The callers in this file (BoxSumFilterPreProcess5, BoxFilterPass1) have
+// already squared s[r][0] into sq[r][0..1].
+// The function stores the horizontal 5-tap sums of both rows, and of their
+// squares, at column 0 of sum5[3]/sum5[4] and square_sum5[3]/square_sum5[4].
+// It then derives |ma| and |b| for these 8 columns with
+// CalculateIntermediate5<0>.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo(
+    const uint16x8_t s[2][4], const uint32_t scale, uint16_t* const sum5[5],
+    uint32_t* const square_sum5[5], uint32x4_t sq[2][8], uint8x16_t* const ma,
+    uint32x4_t b[2]) {
+  uint16x8_t s5[2][5];
+  uint32x4_t sq5[5][2];
+  // Each Square() call fills two consecutive uint32x4_t slots of sq[r].
+  Square(s[0][1], sq[0] + 2);
+  Square(s[1][1], sq[1] + 2);
+  s5[0][3] = Sum5Horizontal16(s[0]);
+  vst1q_u16(sum5[3], s5[0][3]);
+  s5[0][4] = Sum5Horizontal16(s[1]);
+  vst1q_u16(sum5[4], s5[0][4]);
+  Sum5Horizontal32(sq[0], sq5[3]);
+  StoreAligned32U32(square_sum5[3], sq5[3]);
+  Sum5Horizontal32(sq[1], sq5[4]);
+  StoreAligned32U32(square_sum5[4], sq5[4]);
+  // Fill the remaining three rows (presumably sum5[0..2] / square_sum5[0..2])
+  // from the buffers; rows 3 and 4 were just computed above.
+  LoadAligned16x3U16(sum5, 0, s5[0]);
+  LoadAligned32x3U32(square_sum5, 0, sq5);
+  CalculateIntermediate5<0>(s5[0], sq5, scale, ma, b);
+}
+
+// Steady-state pass-1 (5-tap) preprocessing for 16 columns starting at |x|.
+// On entry, s[r][0..3] hold four consecutive 8-lane source vectors of rows
+// r = 0 and 1, and sq[r][2..3] hold squares carried over from the previous
+// call.
+// The new horizontal sums are stored at |x| and |x| + 8 of sum5[3]/sum5[4]
+// and square_sum5[3]/square_sum5[4].
+// Outputs are ma[0], ma[1] and |b| from b[2] onward; b[0..1] are not written
+// here.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
+    const uint16x8_t s[2][4], const ptrdiff_t sum_width, const ptrdiff_t x,
+    const uint32_t scale, uint16_t* const sum5[5],
+    uint32_t* const square_sum5[5], uint32x4_t sq[2][8], uint8x16_t ma[2],
+    uint32x4_t b[6]) {
+  uint16x8_t s5[2][5];
+  uint32x4_t sq5[5][2];
+  // Columns x .. x + 7.
+  Square(s[0][2], sq[0] + 4);
+  Square(s[1][2], sq[1] + 4);
+  s5[0][3] = Sum5Horizontal16(s[0] + 1);
+  s5[1][3] = Sum5Horizontal16(s[0] + 2);
+  vst1q_u16(sum5[3] + x + 0, s5[0][3]);
+  vst1q_u16(sum5[3] + x + 8, s5[1][3]);
+  s5[0][4] = Sum5Horizontal16(s[1] + 1);
+  s5[1][4] = Sum5Horizontal16(s[1] + 2);
+  vst1q_u16(sum5[4] + x + 0, s5[0][4]);
+  vst1q_u16(sum5[4] + x + 8, s5[1][4]);
+  Sum5Horizontal32(sq[0] + 2, sq5[3]);
+  StoreAligned32U32(square_sum5[3] + x, sq5[3]);
+  Sum5Horizontal32(sq[1] + 2, sq5[4]);
+  StoreAligned32U32(square_sum5[4] + x, sq5[4]);
+  LoadAligned16x3U16(sum5, x, s5[0]);
+  LoadAligned32x3U32(square_sum5, x, sq5);
+  CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], b + 2);
+
+  // Columns x + 8 .. x + 15. The *Msan loaders also receive |sum_width|,
+  // presumably so MemorySanitizer is not tripped by reads past the valid
+  // width — confirm against the loader definitions.
+  Square(s[0][3], sq[0] + 6);
+  Square(s[1][3], sq[1] + 6);
+  Sum5Horizontal32(sq[0] + 4, sq5[3]);
+  StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]);
+  Sum5Horizontal32(sq[1] + 4, sq5[4]);
+  StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]);
+  LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]);
+  LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5);
+  CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], b + 4);
+}
+
+// Last-row variant of BoxFilterPreProcess5Lo. Only one source row |s| is
+// available, so its horizontal sums stand in for both of the newest rows
+// (s5[3] == s5[4], sq5[3] == sq5[4]).
+// Nothing is written back to sum5/square_sum5.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo(
+    const uint16x8_t s[2], const uint32_t scale, const uint16_t* const sum5[5],
+    const uint32_t* const square_sum5[5], uint32x4_t sq[4],
+    uint8x16_t* const ma, uint32x4_t b[2]) {
+  uint16x8_t s5[5];
+  uint32x4_t sq5[5][2];
+  Square(s[1], sq + 2);
+  s5[3] = s5[4] = Sum5Horizontal16(s);
+  Sum5Horizontal32(sq, sq5[3]);
+  // Duplicate the single row's squared sums into the row-4 slot.
+  sq5[4][0] = sq5[3][0];
+  sq5[4][1] = sq5[3][1];
+  LoadAligned16x3U16(sum5, 0, s5);
+  LoadAligned32x3U32(square_sum5, 0, sq5);
+  CalculateIntermediate5<0>(s5, sq5, scale, ma, b);
+}
+
+// Steady-state last-row variant of BoxFilterPreProcess5 for 16 columns at
+// |x|. The single source row's horizontal sums fill both row slots 3 and 4.
+// Nothing is written back to sum5/square_sum5.
+// Outputs are ma[0], ma[1] and |b| from b[2] onward.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
+    const uint16x8_t s[4], const ptrdiff_t sum_width, const ptrdiff_t x,
+    const uint32_t scale, const uint16_t* const sum5[5],
+    const uint32_t* const square_sum5[5], uint32x4_t sq[8], uint8x16_t ma[2],
+    uint32x4_t b[6]) {
+  uint16x8_t s5[2][5];
+  uint32x4_t sq5[5][2];
+  // Columns x .. x + 7.
+  Square(s[2], sq + 4);
+  s5[0][3] = Sum5Horizontal16(s + 1);
+  s5[1][3] = Sum5Horizontal16(s + 2);
+  s5[0][4] = s5[0][3];
+  s5[1][4] = s5[1][3];
+  Sum5Horizontal32(sq + 2, sq5[3]);
+  sq5[4][0] = sq5[3][0];
+  sq5[4][1] = sq5[3][1];
+  LoadAligned16x3U16(sum5, x, s5[0]);
+  LoadAligned32x3U32(square_sum5, x, sq5);
+  CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], b + 2);
+
+  // Columns x + 8 .. x + 15.
+  Square(s[3], sq + 6);
+  Sum5Horizontal32(sq + 4, sq5[3]);
+  sq5[4][0] = sq5[3][0];
+  sq5[4][1] = sq5[3][1];
+  LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]);
+  LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5);
+  CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], b + 4);
+}
+
+// Pass-2 (3-tap) preprocessing for the first 8 columns of one row.
+// Stores the horizontal 3-tap sums of |s|, and of its squares, at column 0 of
+// sum3[2] and square_sum3[2].
+// It then derives |ma| and |b| with CalculateIntermediate3 using the two
+// previously stored rows.
+// The callers in this file have already squared s[0] into sq[0..1].
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo(
+    const uint16x8_t s[2], const uint32_t scale, uint16_t* const sum3[3],
+    uint32_t* const square_sum3[3], uint32x4_t sq[4], uint8x16_t* const ma,
+    uint32x4_t b[2]) {
+  uint16x8_t s3[3];
+  uint32x4_t sq3[3][2];
+  Square(s[1], sq + 2);
+  s3[2] = Sum3Horizontal16(s);
+  vst1q_u16(sum3[2], s3[2]);
+  Sum3Horizontal32(sq, sq3[2]);
+  StoreAligned32U32(square_sum3[2], sq3[2]);
+  // Fill rows 0 and 1 (presumably sum3[0..1]) from the buffers.
+  LoadAligned16x2U16(sum3, 0, s3);
+  LoadAligned32x2U32(square_sum3, 0, sq3);
+  CalculateIntermediate3(s3, sq3, scale, ma, b);
+}
+
+// Steady-state pass-2 (3-tap) preprocessing for 16 columns starting at |x|.
+// Stores the new row's horizontal sums into sum3[2]/square_sum3[2] at |x| and
+// |x| + 8. Index/sum pairs are computed for both 8-column halves and combined
+// by CalculateIntermediate into |ma| and |b| from b[2] onward.
+// NOTE(review): unlike the pass-1 helpers, this takes (x, sum_width) rather
+// than (sum_width, x). Both are ptrdiff_t, so swapped arguments would compile
+// silently.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
+    const uint16x8_t s[4], const ptrdiff_t x, const ptrdiff_t sum_width,
+    const uint32_t scale, uint16_t* const sum3[3],
+    uint32_t* const square_sum3[3], uint32x4_t sq[8], uint8x16_t ma[2],
+    uint32x4_t b[6]) {
+  uint16x8_t s3[4], sum[2], index[2];
+  uint32x4_t sq3[3][2];
+
+  // Columns x .. x + 7. The 16-bit sums for both halves are stored at once.
+  Square(s[2], sq + 4);
+  s3[2] = Sum3Horizontal16(s + 1);
+  s3[3] = Sum3Horizontal16(s + 2);
+  StoreAligned32U16(sum3[2] + x, s3 + 2);
+  Sum3Horizontal32(sq + 2, sq3[2]);
+  StoreAligned32U32(square_sum3[2] + x + 0, sq3[2]);
+  LoadAligned16x2U16(sum3, x, s3);
+  LoadAligned32x2U32(square_sum3, x, sq3);
+  CalculateSumAndIndex3(s3, sq3, scale, &sum[0], &index[0]);
+
+  // Columns x + 8 .. x + 15. Loading into s3 + 1 keeps s3[3] (the new row's
+  // sums for this half) as the third row of the window.
+  Square(s[3], sq + 6);
+  Sum3Horizontal32(sq + 4, sq3[2]);
+  StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]);
+  LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3 + 1);
+  LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3);
+  CalculateSumAndIndex3(s3 + 1, sq3, scale, &sum[1], &index[1]);
+  CalculateIntermediate(sum, index, ma, b + 2);
+}
+
+// Combined-filter preprocessing for the first 8 columns of a row pair.
+// Computes both the 3-tap (pass 2) and 5-tap (pass 1) horizontal sums of rows
+// s[0] and s[1], and stores them at column 0 of the sum3/sum5 and
+// square_sum3/square_sum5 buffers.
+// It then derives ma3/b3 for the two 3-row windows (s3[0..2] and s3[1..3]) and
+// ma5/b5 for the 5-row window.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo(
+    const uint16x8_t s[2][4], const uint16_t scales[2], uint16_t* const sum3[4],
+    uint16_t* const sum5[5], uint32_t* const square_sum3[4],
+    uint32_t* const square_sum5[5], uint32x4_t sq[2][8], uint8x16_t ma3[2][2],
+    uint32x4_t b3[2][6], uint8x16_t* const ma5, uint32x4_t b5[2]) {
+  uint16x8_t s3[4], s5[5], sum[2], index[2];
+  uint32x4_t sq3[4][2], sq5[5][2];
+
+  Square(s[0][1], sq[0] + 2);
+  Square(s[1][1], sq[1] + 2);
+  SumHorizontal16(s[0], &s3[2], &s5[3]);
+  SumHorizontal16(s[1], &s3[3], &s5[4]);
+  vst1q_u16(sum3[2], s3[2]);
+  vst1q_u16(sum3[3], s3[3]);
+  vst1q_u16(sum5[3], s5[3]);
+  vst1q_u16(sum5[4], s5[4]);
+  SumHorizontal32(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
+  StoreAligned32U32(square_sum3[2], sq3[2]);
+  StoreAligned32U32(square_sum5[3], sq5[3]);
+  SumHorizontal32(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
+  StoreAligned32U32(square_sum3[3], sq3[3]);
+  StoreAligned32U32(square_sum5[4], sq5[4]);
+  LoadAligned16x2U16(sum3, 0, s3);
+  LoadAligned32x2U32(square_sum3, 0, sq3);
+  LoadAligned16x3U16(sum5, 0, s5);
+  LoadAligned32x3U32(square_sum5, 0, sq5);
+  CalculateSumAndIndex3(s3 + 0, sq3 + 0, scales[1], &sum[0], &index[0]);
+  CalculateSumAndIndex3(s3 + 1, sq3 + 1, scales[1], &sum[1], &index[1]);
+  CalculateIntermediate(sum, index, &ma3[0][0], b3[0], b3[1]);
+  // ma3[1][0] receives the upper 8 bytes of ma3[0][0] in its low half, with
+  // zeros above. Presumably those are the second window's ma values — confirm
+  // against CalculateIntermediate.
+  ma3[1][0] = vextq_u8(ma3[0][0], vdupq_n_u8(0), 8);
+  CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5);
+}
+
+// Steady-state combined-filter preprocessing for 16 columns starting at |x|.
+// For both source rows, stores the 3-tap and 5-tap horizontal sums (and sums
+// of squares) at |x| and |x| + 8.
+// It then derives ma3/b3 for the two 3-row windows (b3[r] from index 2
+// onward) and ma5/b5 for the 5-row window (b5 from index 2 onward).
+// sq[r][2..3] must hold the squares carried over from the previous call.
+// sum[w][h] / index[w][h] are indexed by 3-row window w and 8-column half h.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
+    const uint16x8_t s[2][4], const ptrdiff_t x, const uint16_t scales[2],
+    uint16_t* const sum3[4], uint16_t* const sum5[5],
+    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+    const ptrdiff_t sum_width, uint32x4_t sq[2][8], uint8x16_t ma3[2][2],
+    uint32x4_t b3[2][6], uint8x16_t ma5[2], uint32x4_t b5[6]) {
+  uint16x8_t s3[2][4], s5[2][5], sum[2][2], index[2][2];
+  uint32x4_t sq3[4][2], sq5[5][2];
+
+  // Pixel sums for both 8-column halves of both rows.
+  SumHorizontal16(s[0] + 1, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]);
+  vst1q_u16(sum3[2] + x + 0, s3[0][2]);
+  vst1q_u16(sum3[2] + x + 8, s3[1][2]);
+  vst1q_u16(sum5[3] + x + 0, s5[0][3]);
+  vst1q_u16(sum5[3] + x + 8, s5[1][3]);
+  SumHorizontal16(s[1] + 1, &s3[0][3], &s3[1][3], &s5[0][4], &s5[1][4]);
+  vst1q_u16(sum3[3] + x + 0, s3[0][3]);
+  vst1q_u16(sum3[3] + x + 8, s3[1][3]);
+  vst1q_u16(sum5[4] + x + 0, s5[0][4]);
+  vst1q_u16(sum5[4] + x + 8, s5[1][4]);
+  // Columns x .. x + 7: squared sums, then 3-tap sum/index and 5-tap
+  // intermediates.
+  Square(s[0][2], sq[0] + 4);
+  Square(s[1][2], sq[1] + 4);
+  SumHorizontal32(sq[0] + 2, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
+  StoreAligned32U32(square_sum3[2] + x, sq3[2]);
+  StoreAligned32U32(square_sum5[3] + x, sq5[3]);
+  SumHorizontal32(sq[1] + 2, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
+  StoreAligned32U32(square_sum3[3] + x, sq3[3]);
+  StoreAligned32U32(square_sum5[4] + x, sq5[4]);
+  LoadAligned16x2U16(sum3, x, s3[0]);
+  LoadAligned32x2U32(square_sum3, x, sq3);
+  CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum[0][0], &index[0][0]);
+  CalculateSumAndIndex3(s3[0] + 1, sq3 + 1, scales[1], &sum[1][0],
+                        &index[1][0]);
+  LoadAligned16x3U16(sum5, x, s5[0]);
+  LoadAligned32x3U32(square_sum5, x, sq5);
+  CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], b5 + 2);
+
+  // Columns x + 8 .. x + 15.
+  Square(s[0][3], sq[0] + 6);
+  Square(s[1][3], sq[1] + 6);
+  SumHorizontal32(sq[0] + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
+  StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]);
+  StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]);
+  SumHorizontal32(sq[1] + 4, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
+  StoreAligned32U32(square_sum3[3] + x + 8, sq3[3]);
+  StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]);
+  LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3[1]);
+  LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3);
+  CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum[0][1], &index[0][1]);
+  CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, scales[1], &sum[1][1],
+                        &index[1][1]);
+  // Combine both halves of each 3-row window.
+  CalculateIntermediate(sum[0], index[0], ma3[0], b3[0] + 2);
+  CalculateIntermediate(sum[1], index[1], ma3[1], b3[1] + 2);
+  LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]);
+  LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5);
+  CalculateIntermediate5<0>(s5[1], sq5, scales[0], &ma5[1], b5 + 4);
+}
+
+// Last-row combined-filter preprocessing for the first 8 columns.
+// Only one source row |s| is available: its 5-tap sums fill both row slots 3
+// and 4, and its 3-tap sums form row 2 of the 3-row window.
+// Nothing is written back to the sum buffers.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo(
+    const uint16x8_t s[2], const uint16_t scales[2],
+    const uint16_t* const sum3[4], const uint16_t* const sum5[5],
+    const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
+    uint32x4_t sq[4], uint8x16_t* const ma3, uint8x16_t* const ma5,
+    uint32x4_t b3[2], uint32x4_t b5[2]) {
+  uint16x8_t s3[3], s5[5];
+  uint32x4_t sq3[3][2], sq5[5][2];
+
+  Square(s[1], sq + 2);
+  SumHorizontal16(s, &s3[2], &s5[3]);
+  SumHorizontal32(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
+  LoadAligned16x3U16(sum5, 0, s5);
+  // Duplicate the single row into the row-4 slot.
+  s5[4] = s5[3];
+  LoadAligned32x3U32(square_sum5, 0, sq5);
+  sq5[4][0] = sq5[3][0];
+  sq5[4][1] = sq5[3][1];
+  CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5);
+  LoadAligned16x2U16(sum3, 0, s3);
+  LoadAligned32x2U32(square_sum3, 0, sq3);
+  CalculateIntermediate3(s3, sq3, scales[1], ma3, b3);
+}
+
+// Steady-state last-row combined-filter preprocessing for 16 columns at |x|.
+// The single source row's 5-tap sums fill both row slots 3 and 4.
+// Outputs are ma5/b5 (b5 from index 2 onward) and ma3/b3 (b3 from index 2
+// onward). Nothing is written back to the sum buffers.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
+    const uint16x8_t s[4], const ptrdiff_t sum_width, const ptrdiff_t x,
+    const uint16_t scales[2], const uint16_t* const sum3[4],
+    const uint16_t* const sum5[5], const uint32_t* const square_sum3[4],
+    const uint32_t* const square_sum5[5], uint32x4_t sq[8], uint8x16_t ma3[2],
+    uint8x16_t ma5[2], uint32x4_t b3[6], uint32x4_t b5[6]) {
+  uint16x8_t s3[2][3], s5[2][5], sum[2], index[2];
+  uint32x4_t sq3[3][2], sq5[5][2];
+
+  // Columns x .. x + 7.
+  Square(s[2], sq + 4);
+  SumHorizontal16(s + 1, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]);
+  SumHorizontal32(sq + 2, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
+  LoadAligned16x3U16(sum5, x, s5[0]);
+  s5[0][4] = s5[0][3];
+  LoadAligned32x3U32(square_sum5, x, sq5);
+  sq5[4][0] = sq5[3][0];
+  sq5[4][1] = sq5[3][1];
+  CalculateIntermediate5<8>(s5[0], sq5, scales[0], ma5, b5 + 2);
+  LoadAligned16x2U16(sum3, x, s3[0]);
+  LoadAligned32x2U32(square_sum3, x, sq3);
+  CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum[0], &index[0]);
+
+  // Columns x + 8 .. x + 15.
+  Square(s[3], sq + 6);
+  SumHorizontal32(sq + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
+  LoadAligned16x3U16Msan(sum5, x + 8, sum_width, s5[1]);
+  s5[1][4] = s5[1][3];
+  LoadAligned32x3U32Msan(square_sum5, x + 8, sum_width, sq5);
+  sq5[4][0] = sq5[3][0];
+  sq5[4][1] = sq5[3][1];
+  CalculateIntermediate5<0>(s5[1], sq5, scales[0], ma5 + 1, b5 + 4);
+  LoadAligned16x2U16Msan(sum3, x + 8, sum_width, s3[1]);
+  LoadAligned32x2U32Msan(square_sum3, x + 8, sum_width, sq3);
+  CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum[1], &index[1]);
+  CalculateIntermediate(sum, index, ma3, b3 + 2);
+}
+
+// Builds the pass-1 (5-tap) intermediate rows for the source row pair
+// |src0|/|src1|. Fills sum5[3..4] and square_sum5[3..4], and writes the
+// 565-weighted ma and b sums to |ma565| and |b565|.
+// Covers |width| columns, 16 per iteration.
+inline void BoxSumFilterPreProcess5(const uint16_t* const src0,
+                                    const uint16_t* const src1, const int width,
+                                    const uint32_t scale,
+                                    uint16_t* const sum5[5],
+                                    uint32_t* const square_sum5[5],
+                                    const ptrdiff_t sum_width, uint16_t* ma565,
+                                    uint32_t* b565) {
+  // Byte offsets handed to Load1QMsanU16, presumably describing how far each
+  // load reaches beyond |width| (for MemorySanitizer) — confirm against the
+  // helper.
+  const ptrdiff_t overread_in_bytes =
+      kOverreadInBytesPass1 - sizeof(*src0) * width;
+  uint16x8_t s[2][4];
+  uint8x16_t mas[2];
+  uint32x4_t sq[2][8], bs[6];
+
+  s[0][0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
+  s[0][1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
+  s[1][0] = Load1QMsanU16(src1 + 0, overread_in_bytes + 0);
+  s[1][1] = Load1QMsanU16(src1 + 8, overread_in_bytes + 16);
+  Square(s[0][0], sq[0]);
+  Square(s[1][0], sq[1]);
+  BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], bs);
+
+  int x = 0;
+  do {
+    uint8x16_t ma5[3];
+    uint16x8_t ma[2];
+    uint32x4_t b[4];
+
+    s[0][2] = Load1QMsanU16(src0 + x + 16,
+                            overread_in_bytes + sizeof(*src0) * (x + 16));
+    s[0][3] = Load1QMsanU16(src0 + x + 24,
+                            overread_in_bytes + sizeof(*src0) * (x + 24));
+    s[1][2] = Load1QMsanU16(src1 + x + 16,
+                            overread_in_bytes + sizeof(*src1) * (x + 16));
+    s[1][3] = Load1QMsanU16(src1 + x + 24,
+                            overread_in_bytes + sizeof(*src1) * (x + 24));
+
+    BoxFilterPreProcess5(s, sum_width, x + 8, scale, sum5, square_sum5, sq, mas,
+                         bs);
+    Prepare3_8<0>(mas, ma5);
+    ma[0] = Sum565Lo(ma5);
+    ma[1] = Sum565Hi(ma5);
+    StoreAligned32U16(ma565, ma);
+    Sum565(bs + 0, b + 0);
+    Sum565(bs + 2, b + 2);
+    StoreAligned64U32(b565, b);
+    // Slide the window: the most recent vectors become the starting state for
+    // the next 16 columns.
+    s[0][0] = s[0][2];
+    s[0][1] = s[0][3];
+    s[1][0] = s[1][2];
+    s[1][1] = s[1][3];
+    sq[0][2] = sq[0][6];
+    sq[0][3] = sq[0][7];
+    sq[1][2] = sq[1][6];
+    sq[1][3] = sq[1][7];
+    mas[0] = mas[1];
+    bs[0] = bs[4];
+    bs[1] = bs[5];
+    ma565 += 16;
+    b565 += 16;
+    x += 16;
+  } while (x < width);
+}
+
+// Builds the pass-2 (3-tap) intermediate row for |src|. Fills sum3[2] and
+// square_sum3[2], and writes the 343-weighted sums to ma343/b343 over |width|
+// columns, 16 per iteration.
+// When |calculate444| is true, the 444-weighted sums are also written to
+// ma444/b444 (via Store343_444Lo/Hi). Otherwise ma444/b444 are untouched.
+template <bool calculate444>
+LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
+    const uint16_t* const src, const int width, const uint32_t scale,
+    uint16_t* const sum3[3], uint32_t* const square_sum3[3],
+    const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343,
+    uint32_t* b444) {
+  const ptrdiff_t overread_in_bytes =
+      kOverreadInBytesPass2 - sizeof(*src) * width;
+  uint16x8_t s[4];
+  uint8x16_t mas[2];
+  uint32x4_t sq[8], bs[6];
+
+  s[0] = Load1QMsanU16(src + 0, overread_in_bytes + 0);
+  s[1] = Load1QMsanU16(src + 8, overread_in_bytes + 16);
+  Square(s[0], sq);
+  // Quiet "may be used uninitialized" warning.
+  mas[0] = mas[1] = vdupq_n_u8(0);
+  BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq, &mas[0], bs);
+
+  int x = 0;
+  do {
+    s[2] = Load1QMsanU16(src + x + 16,
+                         overread_in_bytes + sizeof(*src) * (x + 16));
+    s[3] = Load1QMsanU16(src + x + 24,
+                         overread_in_bytes + sizeof(*src) * (x + 24));
+    BoxFilterPreProcess3(s, x + 8, sum_width, scale, sum3, square_sum3, sq, mas,
+                         bs);
+    uint8x16_t ma3[3];
+    Prepare3_8<0>(mas, ma3);
+    if (calculate444) {  // NOLINT(readability-simplify-boolean-expr)
+      // Column offsets are 0 and 8 because the output pointers themselves
+      // advance by 16 each iteration.
+      Store343_444Lo(ma3, bs + 0, 0, ma343, ma444, b343, b444);
+      Store343_444Hi(ma3, bs + 2, 8, ma343, ma444, b343, b444);
+      ma444 += 16;
+      b444 += 16;
+    } else {
+      uint16x8_t ma[2];
+      uint32x4_t b[4];
+      ma[0] = Sum343Lo(ma3);
+      ma[1] = Sum343Hi(ma3);
+      StoreAligned32U16(ma343, ma);
+      Sum343(bs + 0, b + 0);
+      Sum343(bs + 2, b + 2);
+      StoreAligned64U32(b343, b);
+    }
+    // Carry state to the next 16 columns; s[0] is not read again after
+    // BoxFilterPreProcess3Lo.
+    s[1] = s[3];
+    sq[2] = sq[6];
+    sq[3] = sq[7];
+    mas[0] = mas[1];
+    bs[0] = bs[4];
+    bs[1] = bs[5];
+    ma343 += 16;
+    b343 += 16;
+    x += 16;
+  } while (x < width);
+}
+
+// Builds the combined-filter intermediate rows for the source row pair
+// |src0|/|src1|, 16 columns per iteration:
+//   - 343-weighted sums of the first 3-row window go to ma343[0]/b343[0].
+//   - 343- and 444-weighted sums of the second window go to ma343[1]/b343[1]
+//     and ma444/b444.
+//   - 565-weighted 5-tap sums go to ma565/b565.
+inline void BoxSumFilterPreProcess(
+    const uint16_t* const src0, const uint16_t* const src1, const int width,
+    const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
+    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+    const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444,
+    uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444,
+    uint32_t* b565) {
+  const ptrdiff_t overread_in_bytes =
+      kOverreadInBytesPass1 - sizeof(*src0) * width;
+  uint16x8_t s[2][4];
+  uint8x16_t ma3[2][2], ma5[2];
+  uint32x4_t sq[2][8], b3[2][6], b5[6];
+
+  s[0][0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
+  s[0][1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
+  s[1][0] = Load1QMsanU16(src1 + 0, overread_in_bytes + 0);
+  s[1][1] = Load1QMsanU16(src1 + 8, overread_in_bytes + 16);
+  Square(s[0][0], sq[0]);
+  Square(s[1][0], sq[1]);
+  BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq,
+                        ma3, b3, &ma5[0], b5);
+
+  int x = 0;
+  do {
+    uint16x8_t ma[2];
+    uint32x4_t b[4];
+    uint8x16_t ma3x[3], ma5x[3];
+
+    s[0][2] = Load1QMsanU16(src0 + x + 16,
+                            overread_in_bytes + sizeof(*src0) * (x + 16));
+    s[0][3] = Load1QMsanU16(src0 + x + 24,
+                            overread_in_bytes + sizeof(*src0) * (x + 24));
+    s[1][2] = Load1QMsanU16(src1 + x + 16,
+                            overread_in_bytes + sizeof(*src1) * (x + 16));
+    s[1][3] = Load1QMsanU16(src1 + x + 24,
+                            overread_in_bytes + sizeof(*src1) * (x + 24));
+    BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5,
+                        sum_width, sq, ma3, b3, ma5, b5);
+
+    // First 3-row window: 343 sums only.
+    Prepare3_8<0>(ma3[0], ma3x);
+    ma[0] = Sum343Lo(ma3x);
+    ma[1] = Sum343Hi(ma3x);
+    StoreAligned32U16(ma343[0] + x, ma);
+    Sum343(b3[0] + 0, b + 0);
+    Sum343(b3[0] + 2, b + 2);
+    StoreAligned64U32(b343[0] + x, b);
+    // 5-tap b sums (the matching ma sums are stored below).
+    Sum565(b5 + 0, b + 0);
+    Sum565(b5 + 2, b + 2);
+    StoreAligned64U32(b565, b);
+    // Second 3-row window: both 343 and 444 sums.
+    Prepare3_8<0>(ma3[1], ma3x);
+    Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444);
+    Store343_444Hi(ma3x, b3[1] + 2, x + 8, ma343[1], ma444, b343[1], b444);
+    Prepare3_8<0>(ma5, ma5x);
+    ma[0] = Sum565Lo(ma5x);
+    ma[1] = Sum565Hi(ma5x);
+    StoreAligned32U16(ma565, ma);
+    // Carry state to the next 16 columns.
+    s[0][0] = s[0][2];
+    s[0][1] = s[0][3];
+    s[1][0] = s[1][2];
+    s[1][1] = s[1][3];
+    sq[0][2] = sq[0][6];
+    sq[0][3] = sq[0][7];
+    sq[1][2] = sq[1][6];
+    sq[1][3] = sq[1][7];
+    ma3[0][0] = ma3[0][1];
+    ma3[1][0] = ma3[1][1];
+    ma5[0] = ma5[1];
+    b3[0][0] = b3[0][4];
+    b3[0][1] = b3[0][5];
+    b3[1][0] = b3[1][4];
+    b3[1][1] = b3[1][5];
+    b5[0] = b5[4];
+    b5[1] = b5[5];
+    ma565 += 16;
+    b565 += 16;
+    x += 16;
+  } while (x < width);
+}
+
+// Returns (b - ma_x_src) narrowed to int16 with a rounding, saturating right
+// shift of kSgrProjSgrBits + shift - kSgrProjRestoreBits bits. In this file,
+// |shift| is 4 (single row) or 5 (summed rows).
+// NOTE(review): with 10-bit input the ma * src product can reach roughly
+// 8160 * 1023 (about 23 bits). Confirm that b - ma * src always fits in int32
+// and that the result stays within the expected range.
+template <int shift>
+inline int16x4_t FilterOutput(const uint32x4_t ma_x_src, const uint32x4_t b) {
+  // The subtraction is done in uint32 and reinterpreted. A negative difference
+  // wraps and reads back as the matching two's-complement int32 value.
+  const uint32x4_t difference = vsubq_u32(b, ma_x_src);
+  const int32x4_t signed_difference = vreinterpretq_s32_u32(difference);
+  return vqrshrn_n_s32(signed_difference,
+                       kSgrProjSgrBits + shift - kSgrProjRestoreBits);
+}
+
+// Multiplies |ma| by |src| lane-wise, widening to 32 bits.
+// Runs FilterOutput<shift> on the low and high four lanes (against b[0] and
+// b[1]) and joins the two halves into one 8-lane int16 vector.
+template <int shift>
+inline int16x8_t CalculateFilteredOutput(const uint16x8_t src,
+                                         const uint16x8_t ma,
+                                         const uint32x4_t b[2]) {
+  const int16x4_t lower = FilterOutput<shift>(VmullLo16(ma, src), b[0]);
+  const int16x4_t upper = FilterOutput<shift>(VmullHi16(ma, src), b[1]);
+  return vcombine_s16(lower, upper);
+}
+
+// Pass-1 output that combines two rows of ma/b values (BoxFilterPass1 passes
+// the previous and the new 565 rows). Both rows are summed and filtered with
+// shift 5.
+inline int16x8_t CalculateFilteredOutputPass1(const uint16x8_t src,
+                                              const uint16x8_t ma[2],
+                                              const uint32x4_t b[2][2]) {
+  const uint32x4_t b_total[2] = {vaddq_u32(b[0][0], b[1][0]),
+                                 vaddq_u32(b[0][1], b[1][1])};
+  return CalculateFilteredOutput<5>(src, vaddq_u16(ma[0], ma[1]), b_total);
+}
+
+// Pass-2 output that combines three rows of ma/b values (BoxFilterPass2
+// passes the 343, 444 and new 343 rows). The rows are summed and filtered with
+// shift 5.
+inline int16x8_t CalculateFilteredOutputPass2(const uint16x8_t src,
+                                              const uint16x8_t ma[3],
+                                              const uint32x4_t b[3][2]) {
+  uint32x4_t b_total[2];
+  Sum3_32(b, b_total);
+  return CalculateFilteredOutput<5>(src, Sum3_16(ma), b_total);
+}
+
+// Narrows the weighted filter sums |v| to int16 with a rounding, saturating
+// shift of kSgrProjRestoreBits + kSgrProjPrecisionBits. The result is added to
+// the source pixels, which are reinterpreted as int16.
+inline int16x8_t SelfGuidedFinal(const uint16x8_t src, const int32x4_t v[2]) {
+  const int16x8_t correction = vcombine_s16(
+      vqrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits),
+      vqrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits));
+  return vaddq_s16(vreinterpretq_s16_u16(src), correction);
+}
+
+// Blends two filter outputs as w0 * filter[0] + w2 * filter[1], accumulated
+// in int32, then finishes with SelfGuidedFinal.
+inline int16x8_t SelfGuidedDoubleMultiplier(const uint16x8_t src,
+                                            const int16x8_t filter[2],
+                                            const int w0, const int w2) {
+  int32x4_t acc[2];
+  acc[0] = vmlal_n_s16(vmull_n_s16(vget_low_s16(filter[0]), w0),
+                       vget_low_s16(filter[1]), w2);
+  acc[1] = vmlal_n_s16(vmull_n_s16(vget_high_s16(filter[0]), w0),
+                       vget_high_s16(filter[1]), w2);
+  return SelfGuidedFinal(src, acc);
+}
+
+// Scales one filter output by |w0| (widened to int32) and finishes with
+// SelfGuidedFinal.
+inline int16x8_t SelfGuidedSingleMultiplier(const uint16x8_t src,
+                                            const int16x8_t filter,
+                                            const int w0) {
+  // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
+  const int32x4_t products[2] = {vmull_n_s16(vget_low_s16(filter), w0),
+                                 vmull_n_s16(vget_high_s16(filter), w0)};
+  return SelfGuidedFinal(src, products);
+}
+
+// Clamps each lane of |val| to [0, (1 << kBitdepth10) - 1] and stores the
+// eight results to |dst|.
+inline void ClipAndStore(uint16_t* const dst, const int16x8_t val) {
+  const int16x8_t non_negative = vmaxq_s16(val, vdupq_n_s16(0));
+  const uint16x8_t pixel_max = vdupq_n_u16((1 << kBitdepth10) - 1);
+  vst1q_u16(dst, vminq_u16(vreinterpretq_u16_s16(non_negative), pixel_max));
+}
+
+// Filters two output rows, |dst| and |dst| + |stride|, with the pass-1
+// (5-tap) self-guided filter.
+// New 565 sums are computed from |src0|/|src1| and stored to ma565[1] and
+// b565[1]. The first output row combines them with the previous
+// ma565[0]/b565[0] (CalculateFilteredOutputPass1). The second row uses
+// ma565[1]/b565[1] alone (CalculateFilteredOutput<4>).
+// Results are weighted by |w0| and clamped by ClipAndStore.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPass1(
+    const uint16_t* const src, const uint16_t* const src0,
+    const uint16_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5],
+    uint32_t* const square_sum5[5], const int width, const ptrdiff_t sum_width,
+    const uint32_t scale, const int16_t w0, uint16_t* const ma565[2],
+    uint32_t* const b565[2], uint16_t* const dst) {
+  const ptrdiff_t overread_in_bytes =
+      kOverreadInBytesPass1 - sizeof(*src0) * width;
+  uint16x8_t s[2][4];
+  uint8x16_t mas[2];
+  uint32x4_t sq[2][8], bs[6];
+
+  s[0][0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
+  s[0][1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
+  s[1][0] = Load1QMsanU16(src1 + 0, overread_in_bytes + 0);
+  s[1][1] = Load1QMsanU16(src1 + 8, overread_in_bytes + 16);
+
+  Square(s[0][0], sq[0]);
+  Square(s[1][0], sq[1]);
+  BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], bs);
+
+  int x = 0;
+  do {
+    uint16x8_t ma[2];
+    uint32x4_t b[2][2];
+    uint8x16_t ma5[3];
+    int16x8_t p[2];
+
+    s[0][2] = Load1QMsanU16(src0 + x + 16,
+                            overread_in_bytes + sizeof(*src0) * (x + 16));
+    s[0][3] = Load1QMsanU16(src0 + x + 24,
+                            overread_in_bytes + sizeof(*src0) * (x + 24));
+    s[1][2] = Load1QMsanU16(src1 + x + 16,
+                            overread_in_bytes + sizeof(*src1) * (x + 16));
+    s[1][3] = Load1QMsanU16(src1 + x + 24,
+                            overread_in_bytes + sizeof(*src1) * (x + 24));
+    BoxFilterPreProcess5(s, sum_width, x + 8, scale, sum5, square_sum5, sq, mas,
+                         bs);
+    Prepare3_8<0>(mas, ma5);
+    // Columns x .. x + 7: store the new 565 row, then filter both output rows.
+    ma[1] = Sum565Lo(ma5);
+    vst1q_u16(ma565[1] + x, ma[1]);
+    Sum565(bs, b[1]);
+    StoreAligned32U32(b565[1] + x, b[1]);
+    const uint16x8_t sr0_lo = vld1q_u16(src + x + 0);
+    const uint16x8_t sr1_lo = vld1q_u16(src + stride + x + 0);
+    ma[0] = vld1q_u16(ma565[0] + x);
+    LoadAligned32U32(b565[0] + x, b[0]);
+    p[0] = CalculateFilteredOutputPass1(sr0_lo, ma, b);
+    p[1] = CalculateFilteredOutput<4>(sr1_lo, ma[1], b[1]);
+    const int16x8_t d00 = SelfGuidedSingleMultiplier(sr0_lo, p[0], w0);
+    const int16x8_t d10 = SelfGuidedSingleMultiplier(sr1_lo, p[1], w0);
+
+    // Columns x + 8 .. x + 15.
+    ma[1] = Sum565Hi(ma5);
+    vst1q_u16(ma565[1] + x + 8, ma[1]);
+    Sum565(bs + 2, b[1]);
+    StoreAligned32U32(b565[1] + x + 8, b[1]);
+    const uint16x8_t sr0_hi = vld1q_u16(src + x + 8);
+    const uint16x8_t sr1_hi = vld1q_u16(src + stride + x + 8);
+    ma[0] = vld1q_u16(ma565[0] + x + 8);
+    LoadAligned32U32(b565[0] + x + 8, b[0]);
+    p[0] = CalculateFilteredOutputPass1(sr0_hi, ma, b);
+    p[1] = CalculateFilteredOutput<4>(sr1_hi, ma[1], b[1]);
+    const int16x8_t d01 = SelfGuidedSingleMultiplier(sr0_hi, p[0], w0);
+    ClipAndStore(dst + x + 0, d00);
+    ClipAndStore(dst + x + 8, d01);
+    const int16x8_t d11 = SelfGuidedSingleMultiplier(sr1_hi, p[1], w0);
+    ClipAndStore(dst + stride + x + 0, d10);
+    ClipAndStore(dst + stride + x + 8, d11);
+    // Carry state to the next 16 columns.
+    s[0][0] = s[0][2];
+    s[0][1] = s[0][3];
+    s[1][0] = s[1][2];
+    s[1][1] = s[1][3];
+    sq[0][2] = sq[0][6];
+    sq[0][3] = sq[0][7];
+    sq[1][2] = sq[1][6];
+    sq[1][3] = sq[1][7];
+    mas[0] = mas[1];
+    bs[0] = bs[4];
+    bs[1] = bs[5];
+    x += 16;
+  } while (x < width);
+}
+
+// Filters a single final output row with pass 1. Only one new source row,
+// |src0|, is available.
+// Its 565 sums are combined with the stored |ma565|/|b565| row through
+// CalculateFilteredOutputPass1. Nothing is written back to ma565/b565.
+inline void BoxFilterPass1LastRow(
+    const uint16_t* const src, const uint16_t* const src0, const int width,
+    const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
+    uint16_t* const sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565,
+    uint32_t* b565, uint16_t* const dst) {
+  const ptrdiff_t overread_in_bytes =
+      kOverreadInBytesPass1 - sizeof(*src0) * width;
+  uint16x8_t s[4];
+  uint8x16_t mas[2];
+  uint32x4_t sq[8], bs[6];
+
+  s[0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
+  s[1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
+  Square(s[0], sq);
+  BoxFilterPreProcess5LastRowLo(s, scale, sum5, square_sum5, sq, &mas[0], bs);
+
+  int x = 0;
+  do {
+    uint16x8_t ma[2];
+    uint32x4_t b[2][2];
+    uint8x16_t ma5[3];
+
+    s[2] = Load1QMsanU16(src0 + x + 16,
+                         overread_in_bytes + sizeof(*src0) * (x + 16));
+    s[3] = Load1QMsanU16(src0 + x + 24,
+                         overread_in_bytes + sizeof(*src0) * (x + 24));
+    BoxFilterPreProcess5LastRow(s, sum_width, x + 8, scale, sum5, square_sum5,
+                                sq, mas, bs);
+    Prepare3_8<0>(mas, ma5);
+    // Columns x .. x + 7.
+    ma[1] = Sum565Lo(ma5);
+    Sum565(bs, b[1]);
+    ma[0] = vld1q_u16(ma565);
+    LoadAligned32U32(b565, b[0]);
+    const uint16x8_t sr_lo = vld1q_u16(src + x + 0);
+    int16x8_t p = CalculateFilteredOutputPass1(sr_lo, ma, b);
+    const int16x8_t d0 = SelfGuidedSingleMultiplier(sr_lo, p, w0);
+
+    // Columns x + 8 .. x + 15.
+    ma[1] = Sum565Hi(ma5);
+    Sum565(bs + 2, b[1]);
+    ma[0] = vld1q_u16(ma565 + 8);
+    LoadAligned32U32(b565 + 8, b[0]);
+    const uint16x8_t sr_hi = vld1q_u16(src + x + 8);
+    p = CalculateFilteredOutputPass1(sr_hi, ma, b);
+    const int16x8_t d1 = SelfGuidedSingleMultiplier(sr_hi, p, w0);
+    ClipAndStore(dst + x + 0, d0);
+    ClipAndStore(dst + x + 8, d1);
+    // Carry state to the next 16 columns.
+    s[1] = s[3];
+    sq[2] = sq[6];
+    sq[3] = sq[7];
+    mas[0] = mas[1];
+    bs[0] = bs[4];
+    bs[1] = bs[5];
+    ma565 += 16;
+    b565 += 16;
+    x += 16;
+  } while (x < width);
+}
+
+// Filters one output row with the pass-2 (3-tap) self-guided filter.
+// New 343 and 444 sums from |src0| are stored to ma343[2]/b343[2] and
+// ma444[1]/b444[1].
+// Each output pixel combines ma343[0], ma444[0] and the new 343 row
+// (CalculateFilteredOutputPass2), is weighted by |w0|, and is clamped.
+LIBGAV1_ALWAYS_INLINE void BoxFilterPass2(
+    const uint16_t* const src, const uint16_t* const src0, const int width,
+    const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
+    uint16_t* const sum3[3], uint32_t* const square_sum3[3],
+    uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3],
+    uint32_t* const b444[2], uint16_t* const dst) {
+  const ptrdiff_t overread_in_bytes =
+      kOverreadInBytesPass2 - sizeof(*src0) * width;
+  uint16x8_t s[4];
+  uint8x16_t mas[2];
+  uint32x4_t sq[8], bs[6];
+
+  s[0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
+  s[1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
+  Square(s[0], sq);
+  // Quiet "may be used uninitialized" warning.
+  mas[0] = mas[1] = vdupq_n_u8(0);
+  BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq, &mas[0], bs);
+
+  int x = 0;
+  do {
+    s[2] = Load1QMsanU16(src0 + x + 16,
+                         overread_in_bytes + sizeof(*src0) * (x + 16));
+    s[3] = Load1QMsanU16(src0 + x + 24,
+                         overread_in_bytes + sizeof(*src0) * (x + 24));
+    BoxFilterPreProcess3(s, x + 8, sum_width, scale, sum3, square_sum3, sq, mas,
+                         bs);
+    uint16x8_t ma[3];
+    uint32x4_t b[3][2];
+    uint8x16_t ma3[3];
+
+    Prepare3_8<0>(mas, ma3);
+    // Columns x .. x + 7. ma[2]/b[2] receive the new 343 row.
+    Store343_444Lo(ma3, bs + 0, x, &ma[2], b[2], ma343[2], ma444[1], b343[2],
+                   b444[1]);
+    const uint16x8_t sr_lo = vld1q_u16(src + x + 0);
+    ma[0] = vld1q_u16(ma343[0] + x);
+    ma[1] = vld1q_u16(ma444[0] + x);
+    LoadAligned32U32(b343[0] + x, b[0]);
+    LoadAligned32U32(b444[0] + x, b[1]);
+    const int16x8_t p0 = CalculateFilteredOutputPass2(sr_lo, ma, b);
+
+    // Columns x + 8 .. x + 15.
+    Store343_444Hi(ma3, bs + 2, x + 8, &ma[2], b[2], ma343[2], ma444[1],
+                   b343[2], b444[1]);
+    const uint16x8_t sr_hi = vld1q_u16(src + x + 8);
+    ma[0] = vld1q_u16(ma343[0] + x + 8);
+    ma[1] = vld1q_u16(ma444[0] + x + 8);
+    LoadAligned32U32(b343[0] + x + 8, b[0]);
+    LoadAligned32U32(b444[0] + x + 8, b[1]);
+    const int16x8_t p1 = CalculateFilteredOutputPass2(sr_hi, ma, b);
+    const int16x8_t d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0);
+    const int16x8_t d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0);
+    ClipAndStore(dst + x + 0, d0);
+    ClipAndStore(dst + x + 8, d1);
+    // Carry state to the next 16 columns.
+    s[1] = s[3];
+    sq[2] = sq[6];
+    sq[3] = sq[7];
+    mas[0] = mas[1];
+    bs[0] = bs[4];
+    bs[1] = bs[5];
+    x += 16;
+  } while (x < width);
+}
+
+// Filters the two output rows |dst| and |dst| + |stride| with both
+// self-guided passes (CalculateFilteredOutputPass1/Pass2 results combined by
+// SelfGuidedDoubleMultiplier with |w0| and |w2|). |src0| and |src1| are the two
+// source rows newly added to the running box sums in this call; |src| and
+// |src| + |stride| hold the pixels being filtered.
+//
+// Reads ma565[0], ma343[0..1], ma444[0] (and the matching b* rows) written by
+// earlier calls, and writes the new rows ma565[1], ma343[2..3], ma444[1..2]
+// (and the matching b* rows) for later calls.
+LIBGAV1_ALWAYS_INLINE void BoxFilter(
+    const uint16_t* const src, const uint16_t* const src0,
+    const uint16_t* const src1, const ptrdiff_t stride, const int width,
+    const uint16_t scales[2], const int16_t w0, const int16_t w2,
+    uint16_t* const sum3[4], uint16_t* const sum5[5],
+    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+    const ptrdiff_t sum_width, uint16_t* const ma343[4],
+    uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4],
+    uint32_t* const b444[3], uint32_t* const b565[2], uint16_t* const dst) {
+  // NOTE(review): presumably the number of bytes past the valid data that the
+  // 8-lane loads may read, forwarded to Load1QMsanU16 for MSan bookkeeping;
+  // confirm against its definition.
+  const ptrdiff_t overread_in_bytes =
+      kOverreadInBytesPass1 - sizeof(*src0) * width;
+  uint16x8_t s[2][4];
+  uint8x16_t ma3[2][2], ma5[2];
+  uint32x4_t sq[2][8], b3[2][6], b5[6];
+
+  // Prime the pipeline with the first 16 pixels of both new source rows.
+  s[0][0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
+  s[0][1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
+  s[1][0] = Load1QMsanU16(src1 + 0, overread_in_bytes + 0);
+  s[1][1] = Load1QMsanU16(src1 + 8, overread_in_bytes + 16);
+  Square(s[0][0], sq[0]);
+  Square(s[1][0], sq[1]);
+  BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq,
+                        ma3, b3, &ma5[0], b5);
+
+  // Each iteration produces 16 pixels of both output rows: a Lo half
+  // (x .. x + 7) followed by a Hi half (x + 8 .. x + 15).
+  int x = 0;
+  do {
+    // ma[0]/b[0]: pass 1 inputs; ma[1]/b[1]: pass 2 inputs for row 0;
+    // ma[2]/b[2]: pass 2 inputs for row 1.
+    uint16x8_t ma[3][3];
+    uint32x4_t b[3][3][2];
+    uint8x16_t ma3x[2][3], ma5x[3];
+    // p[row][pass]: per-pass outputs, combined by SelfGuidedDoubleMultiplier.
+    int16x8_t p[2][2];
+
+    // Load source pixels x + 16 .. x + 31 of both new rows.
+    s[0][2] = Load1QMsanU16(src0 + x + 16,
+                            overread_in_bytes + sizeof(*src0) * (x + 16));
+    s[0][3] = Load1QMsanU16(src0 + x + 24,
+                            overread_in_bytes + sizeof(*src0) * (x + 24));
+    s[1][2] = Load1QMsanU16(src1 + x + 16,
+                            overread_in_bytes + sizeof(*src1) * (x + 16));
+    s[1][3] = Load1QMsanU16(src1 + x + 24,
+                            overread_in_bytes + sizeof(*src1) * (x + 24));
+
+    BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5,
+                        sum_width, sq, ma3, b3, ma5, b5);
+    Prepare3_8<0>(ma3[0], ma3x[0]);
+    Prepare3_8<0>(ma3[1], ma3x[1]);
+    Prepare3_8<0>(ma5, ma5x);
+    // Lo half: store the new 343/444 rows (one set per new source row) and the
+    // new 565 row for pixels x .. x + 7.
+    Store343_444Lo(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], b[1][2], b[2][1],
+                   ma343[2], ma444[1], b343[2], b444[1]);
+    Store343_444Lo(ma3x[1], b3[1], x, &ma[2][2], b[2][2], ma343[3], ma444[2],
+                   b343[3], b444[2]);
+    ma[0][1] = Sum565Lo(ma5x);
+    vst1q_u16(ma565[1] + x, ma[0][1]);
+    Sum565(b5, b[0][1]);
+    StoreAligned32U32(b565[1] + x, b[0][1]);
+    const uint16x8_t sr0_lo = vld1q_u16(src + x);
+    const uint16x8_t sr1_lo = vld1q_u16(src + stride + x);
+    // Pass 1: row 0 uses the stored and the new 565 rows; row 1 uses only the
+    // new 565 row.
+    ma[0][0] = vld1q_u16(ma565[0] + x);
+    LoadAligned32U32(b565[0] + x, b[0][0]);
+    p[0][0] = CalculateFilteredOutputPass1(sr0_lo, ma[0], b[0]);
+    p[1][0] = CalculateFilteredOutput<4>(sr1_lo, ma[0][1], b[0][1]);
+    // Pass 2, row 0: ma343[0], ma444[0] and the new ma343[2] values.
+    ma[1][0] = vld1q_u16(ma343[0] + x);
+    ma[1][1] = vld1q_u16(ma444[0] + x);
+    LoadAligned32U32(b343[0] + x, b[1][0]);
+    LoadAligned32U32(b444[0] + x, b[1][1]);
+    p[0][1] = CalculateFilteredOutputPass2(sr0_lo, ma[1], b[1]);
+    const int16x8_t d00 = SelfGuidedDoubleMultiplier(sr0_lo, p[0], w0, w2);
+    // Pass 2, row 1: ma343[1] plus the new 444/343 values set above.
+    ma[2][0] = vld1q_u16(ma343[1] + x);
+    LoadAligned32U32(b343[1] + x, b[2][0]);
+    p[1][1] = CalculateFilteredOutputPass2(sr1_lo, ma[2], b[2]);
+    const int16x8_t d10 = SelfGuidedDoubleMultiplier(sr1_lo, p[1], w0, w2);
+
+    // Hi half: same sequence for pixels x + 8 .. x + 15.
+    Store343_444Hi(ma3x[0], b3[0] + 2, x + 8, &ma[1][2], &ma[2][1], b[1][2],
+                   b[2][1], ma343[2], ma444[1], b343[2], b444[1]);
+    Store343_444Hi(ma3x[1], b3[1] + 2, x + 8, &ma[2][2], b[2][2], ma343[3],
+                   ma444[2], b343[3], b444[2]);
+    ma[0][1] = Sum565Hi(ma5x);
+    vst1q_u16(ma565[1] + x + 8, ma[0][1]);
+    Sum565(b5 + 2, b[0][1]);
+    StoreAligned32U32(b565[1] + x + 8, b[0][1]);
+    // NOTE(review): these loads allow 4 extra bytes of overread beyond the
+    // Lo-half accounting; the reason is not evident from this block — confirm.
+    const uint16x8_t sr0_hi = Load1QMsanU16(
+        src + x + 8, overread_in_bytes + 4 + sizeof(*src) * (x + 8));
+    const uint16x8_t sr1_hi = Load1QMsanU16(
+        src + stride + x + 8, overread_in_bytes + 4 + sizeof(*src) * (x + 8));
+    ma[0][0] = vld1q_u16(ma565[0] + x + 8);
+    LoadAligned32U32(b565[0] + x + 8, b[0][0]);
+    p[0][0] = CalculateFilteredOutputPass1(sr0_hi, ma[0], b[0]);
+    p[1][0] = CalculateFilteredOutput<4>(sr1_hi, ma[0][1], b[0][1]);
+    ma[1][0] = vld1q_u16(ma343[0] + x + 8);
+    ma[1][1] = vld1q_u16(ma444[0] + x + 8);
+    LoadAligned32U32(b343[0] + x + 8, b[1][0]);
+    LoadAligned32U32(b444[0] + x + 8, b[1][1]);
+    p[0][1] = CalculateFilteredOutputPass2(sr0_hi, ma[1], b[1]);
+    const int16x8_t d01 = SelfGuidedDoubleMultiplier(sr0_hi, p[0], w0, w2);
+    ClipAndStore(dst + x + 0, d00);
+    ClipAndStore(dst + x + 8, d01);
+    ma[2][0] = vld1q_u16(ma343[1] + x + 8);
+    LoadAligned32U32(b343[1] + x + 8, b[2][0]);
+    p[1][1] = CalculateFilteredOutputPass2(sr1_hi, ma[2], b[2]);
+    const int16x8_t d11 = SelfGuidedDoubleMultiplier(sr1_hi, p[1], w0, w2);
+    ClipAndStore(dst + stride + x + 0, d10);
+    ClipAndStore(dst + stride + x + 8, d11);
+    // Slide the upper halves of the carried state down for the next iteration.
+    s[0][0] = s[0][2];
+    s[0][1] = s[0][3];
+    s[1][0] = s[1][2];
+    s[1][1] = s[1][3];
+    sq[0][2] = sq[0][6];
+    sq[0][3] = sq[0][7];
+    sq[1][2] = sq[1][6];
+    sq[1][3] = sq[1][7];
+    ma3[0][0] = ma3[0][1];
+    ma3[1][0] = ma3[1][1];
+    ma5[0] = ma5[1];
+    b3[0][0] = b3[0][4];
+    b3[0][1] = b3[0][5];
+    b3[1][0] = b3[1][4];
+    b3[1][1] = b3[1][5];
+    b5[0] = b5[4];
+    b5[1] = b5[5];
+    x += 16;
+  } while (x < width);
+}
+
+// Single-row variant of BoxFilter, used by BoxFilterProcess for the last row
+// of an odd-height block. Only one new source row (|src0|) is added to the
+// sums and only |dst| is written. The 565 and 343 values computed here are
+// used directly and never stored; |ma565|, |ma343|, |ma444| and the matching
+// b* rows are only read.
+inline void BoxFilterLastRow(
+    const uint16_t* const src, const uint16_t* const src0, const int width,
+    const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0,
+    const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5],
+    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+    uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565,
+    uint32_t* const b343, uint32_t* const b444, uint32_t* const b565,
+    uint16_t* const dst) {
+  // NOTE(review): presumably the number of bytes past the valid data that the
+  // loads may read, forwarded to Load1QMsanU16 for MSan bookkeeping; confirm.
+  const ptrdiff_t overread_in_bytes =
+      kOverreadInBytesPass1 - sizeof(*src0) * width;
+  uint16x8_t s[4];
+  uint8x16_t ma3[2], ma5[2];
+  uint32x4_t sq[8], b3[6], b5[6];
+  uint16x8_t ma[3];
+  uint32x4_t b[3][2];
+
+  // Prime the pipeline with the first 16 pixels of the new source row.
+  s[0] = Load1QMsanU16(src0 + 0, overread_in_bytes + 0);
+  s[1] = Load1QMsanU16(src0 + 8, overread_in_bytes + 16);
+  Square(s[0], sq);
+  // Quiet "may be used uninitialized" warning.
+  ma3[0] = ma3[1] = vdupq_n_u8(0);
+  BoxFilterPreProcessLastRowLo(s, scales, sum3, sum5, square_sum3, square_sum5,
+                               sq, &ma3[0], &ma5[0], b3, b5);
+
+  // Each iteration produces 16 output pixels as a Lo half (x .. x + 7) and a
+  // Hi half (x + 8 .. x + 15).
+  int x = 0;
+  do {
+    uint8x16_t ma3x[3], ma5x[3];
+    // p[0]: pass 1 output; p[1]: pass 2 output.
+    int16x8_t p[2];
+
+    // Load source pixels x + 16 .. x + 31 of the new row.
+    s[2] = Load1QMsanU16(src0 + x + 16,
+                         overread_in_bytes + sizeof(*src0) * (x + 16));
+    s[3] = Load1QMsanU16(src0 + x + 24,
+                         overread_in_bytes + sizeof(*src0) * (x + 24));
+    BoxFilterPreProcessLastRow(s, sum_width, x + 8, scales, sum3, sum5,
+                               square_sum3, square_sum5, sq, ma3, ma5, b3, b5);
+    Prepare3_8<0>(ma3, ma3x);
+    Prepare3_8<0>(ma5, ma5x);
+    // Lo half: ma[1]/b[1] get the new 565 sums and ma[2]/b[2] the new 343
+    // sums. ma[0]/b[0] are reloaded from the stored 565 row for pass 1; then
+    // ma[0..1]/b[0..1] are reloaded from the stored 343/444 rows for pass 2.
+    ma[1] = Sum565Lo(ma5x);
+    Sum565(b5, b[1]);
+    ma[2] = Sum343Lo(ma3x);
+    Sum343(b3, b[2]);
+    const uint16x8_t sr_lo = vld1q_u16(src + x + 0);
+    ma[0] = vld1q_u16(ma565 + x);
+    LoadAligned32U32(b565 + x, b[0]);
+    p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b);
+    ma[0] = vld1q_u16(ma343 + x);
+    ma[1] = vld1q_u16(ma444 + x);
+    LoadAligned32U32(b343 + x, b[0]);
+    LoadAligned32U32(b444 + x, b[1]);
+    p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b);
+    const int16x8_t d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2);
+
+    // Hi half: same sequence for pixels x + 8 .. x + 15.
+    ma[1] = Sum565Hi(ma5x);
+    Sum565(b5 + 2, b[1]);
+    ma[2] = Sum343Hi(ma3x);
+    Sum343(b3 + 2, b[2]);
+    const uint16x8_t sr_hi = Load1QMsanU16(
+        src + x + 8, overread_in_bytes + 4 + sizeof(*src) * (x + 8));
+    ma[0] = vld1q_u16(ma565 + x + 8);
+    LoadAligned32U32(b565 + x + 8, b[0]);
+    p[0] = CalculateFilteredOutputPass1(sr_hi, ma, b);
+    ma[0] = vld1q_u16(ma343 + x + 8);
+    ma[1] = vld1q_u16(ma444 + x + 8);
+    LoadAligned32U32(b343 + x + 8, b[0]);
+    LoadAligned32U32(b444 + x + 8, b[1]);
+    p[1] = CalculateFilteredOutputPass2(sr_hi, ma, b);
+    const int16x8_t d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2);
+    ClipAndStore(dst + x + 0, d0);
+    ClipAndStore(dst + x + 8, d1);
+    // Slide the upper halves of the carried state down for the next iteration.
+    s[1] = s[3];
+    sq[2] = sq[6];
+    sq[3] = sq[7];
+    ma3[0] = ma3[1];
+    ma5[0] = ma5[1];
+    b3[0] = b3[4];
+    b3[1] = b3[5];
+    b5[0] = b5[4];
+    b5[1] = b5[5];
+    x += 16;
+  } while (x < width);
+}
+
+// Runs both self-guided passes over a |width| x |height| block. The caller
+// (SelfGuidedFilter_NEON) offsets |src|, |top_border| and |bottom_border| by
+// -3 pixels, so the filter calls below receive |src| + 3 as the row to filter.
+//
+// |sgr_buffer| is carved into rotating row buffers: sum3/square_sum3 and
+// ma343/b343 get 4 rows, sum5/square_sum5 5 rows, ma444/b444 3 rows and
+// ma565/b565 2 rows. Sum rows are |sum_stride| apart and ma/b rows are
+// |temp_stride| apart. Output rows are produced two at a time by BoxFilter; an
+// odd final row is handled by BoxFilterLastRow.
+LIBGAV1_ALWAYS_INLINE void BoxFilterProcess(
+    const RestorationUnitInfo& restoration_info, const uint16_t* src,
+    const ptrdiff_t stride, const uint16_t* const top_border,
+    const ptrdiff_t top_border_stride, const uint16_t* bottom_border,
+    const ptrdiff_t bottom_border_stride, const int width, const int height,
+    SgrBuffer* const sgr_buffer, uint16_t* dst) {
+  // Row lengths rounded up to multiples of 16, the step of the row loops.
+  const auto temp_stride = Align<ptrdiff_t>(width, 16);
+  const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
+  const auto sum_stride = temp_stride + 16;
+  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+  const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
+  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
+  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
+  // w0 + w1 + w2 == 1 << kSgrProjPrecisionBits.
+  const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
+  uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
+  uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
+  // Point each ring entry at consecutive rows of the scratch buffers.
+  sum3[0] = sgr_buffer->sum3;
+  square_sum3[0] = sgr_buffer->square_sum3;
+  ma343[0] = sgr_buffer->ma343;
+  b343[0] = sgr_buffer->b343;
+  for (int i = 1; i <= 3; ++i) {
+    sum3[i] = sum3[i - 1] + sum_stride;
+    square_sum3[i] = square_sum3[i - 1] + sum_stride;
+    ma343[i] = ma343[i - 1] + temp_stride;
+    b343[i] = b343[i - 1] + temp_stride;
+  }
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+  for (int i = 1; i <= 4; ++i) {
+    sum5[i] = sum5[i - 1] + sum_stride;
+    square_sum5[i] = square_sum5[i - 1] + sum_stride;
+  }
+  ma444[0] = sgr_buffer->ma444;
+  b444[0] = sgr_buffer->b444;
+  for (int i = 1; i <= 2; ++i) {
+    ma444[i] = ma444[i - 1] + temp_stride;
+    b444[i] = b444[i - 1] + temp_stride;
+  }
+  ma565[0] = sgr_buffer->ma565;
+  ma565[1] = ma565[0] + temp_stride;
+  b565[0] = sgr_buffer->b565;
+  b565[1] = b565[0] + temp_stride;
+  assert(scales[0] != 0);
+  assert(scales[1] != 0);
+  // Sum the top border into sum3[0] and sum5[1] (and their squares). sum5[0]
+  // is temporarily aliased to sum5[1], so the 5-row window sees that row twice
+  // during the first preprocessing step; it is restored right after.
+  BoxSum(top_border, top_border_stride, width, sum_stride, sum_width, sum3[0],
+         sum5[1], square_sum3[0], square_sum5[1]);
+  sum5[0] = sum5[1];
+  square_sum5[0] = square_sum5[1];
+  // Second input row: the next image row, or the first bottom border row when
+  // the block is a single row tall.
+  const uint16_t* const s = (height > 1) ? src + stride : bottom_border;
+  BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3,
+                         square_sum5, sum_width, ma343, ma444[0], ma565[0],
+                         b343, b444[0], b565[0]);
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+
+  // Each iteration filters two rows. The new input rows are two and three
+  // rows below |src|; the ring buffers are rotated after every step.
+  for (int y = (height >> 1) - 1; y > 0; --y) {
+    Circulate4PointersBy2<uint16_t>(sum3);
+    Circulate4PointersBy2<uint32_t>(square_sum3);
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width,
+              scales, w0, w2, sum3, sum5, square_sum3, square_sum5, sum_width,
+              ma343, ma444, ma565, b343, b444, b565, dst);
+    src += 2 * stride;
+    dst += 2 * stride;
+    Circulate4PointersBy2<uint16_t>(ma343);
+    Circulate4PointersBy2<uint32_t>(b343);
+    std::swap(ma444[0], ma444[2]);
+    std::swap(b444[0], b444[2]);
+    std::swap(ma565[0], ma565[1]);
+    std::swap(b565[0], b565[1]);
+  }
+
+  Circulate4PointersBy2<uint16_t>(sum3);
+  Circulate4PointersBy2<uint32_t>(square_sum3);
+  Circulate5PointersBy2<uint16_t>(sum5);
+  Circulate5PointersBy2<uint32_t>(square_sum5);
+  // Final row pair. For even heights both new input rows come from the bottom
+  // border; for odd heights > 1 one is the last image row and one is the first
+  // bottom border row.
+  if ((height & 1) == 0 || height > 1) {
+    const uint16_t* sr[2];
+    if ((height & 1) == 0) {
+      sr[0] = bottom_border;
+      sr[1] = bottom_border + bottom_border_stride;
+    } else {
+      sr[0] = src + 2 * stride;
+      sr[1] = bottom_border;
+    }
+    BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5,
+              square_sum3, square_sum5, sum_width, ma343, ma444, ma565, b343,
+              b444, b565, dst);
+  }
+  // Odd height: filter the remaining row alone, feeding it the second bottom
+  // border row.
+  if ((height & 1) != 0) {
+    if (height > 1) {
+      src += 2 * stride;
+      dst += 2 * stride;
+      Circulate4PointersBy2<uint16_t>(sum3);
+      Circulate4PointersBy2<uint32_t>(square_sum3);
+      Circulate5PointersBy2<uint16_t>(sum5);
+      Circulate5PointersBy2<uint32_t>(square_sum5);
+      Circulate4PointersBy2<uint16_t>(ma343);
+      Circulate4PointersBy2<uint32_t>(b343);
+      std::swap(ma444[0], ma444[2]);
+      std::swap(b444[0], b444[2]);
+      std::swap(ma565[0], ma565[1]);
+      std::swap(b565[0], b565[1]);
+    }
+    BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width,
+                     sum_width, scales, w0, w2, sum3, sum5, square_sum3,
+                     square_sum5, ma343[0], ma444[0], ma565[0], b343[0],
+                     b444[0], b565[0], dst);
+  }
+}
+
+// Runs only the first self-guided pass over a |width| x |height| block, two
+// output rows at a time. The caller (SelfGuidedFilter_NEON) offsets |src|,
+// |top_border| and |bottom_border| by -3 pixels, so the filter calls receive
+// |src| + 3. Only the sum5/square_sum5 (5 rows) and ma565/b565 (2 rows) ring
+// buffers of |sgr_buffer| are used.
+inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
+                                  const uint16_t* src, const ptrdiff_t stride,
+                                  const uint16_t* const top_border,
+                                  const ptrdiff_t top_border_stride,
+                                  const uint16_t* bottom_border,
+                                  const ptrdiff_t bottom_border_stride,
+                                  const int width, const int height,
+                                  SgrBuffer* const sgr_buffer, uint16_t* dst) {
+  // Row lengths rounded up to multiples of 16, the step of the row loops.
+  const auto temp_stride = Align<ptrdiff_t>(width, 16);
+  const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
+  const auto sum_stride = temp_stride + 16;
+  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
+  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
+  uint16_t *sum5[5], *ma565[2];
+  uint32_t *square_sum5[5], *b565[2];
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+  for (int i = 1; i <= 4; ++i) {
+    sum5[i] = sum5[i - 1] + sum_stride;
+    square_sum5[i] = square_sum5[i - 1] + sum_stride;
+  }
+  ma565[0] = sgr_buffer->ma565;
+  ma565[1] = ma565[0] + temp_stride;
+  b565[0] = sgr_buffer->b565;
+  b565[1] = b565[0] + temp_stride;
+  assert(scale != 0);
+
+  // Sum the top border into sum5[1]. sum5[0] is temporarily aliased to the
+  // same row for the first preprocessing step and restored afterwards.
+  BoxSum<5>(top_border, top_border_stride, width, sum_stride, sum_width,
+            sum5[1], square_sum5[1]);
+  sum5[0] = sum5[1];
+  square_sum5[0] = square_sum5[1];
+  // Second input row: the next image row, or the first bottom border row for
+  // a single-row block.
+  const uint16_t* const s = (height > 1) ? src + stride : bottom_border;
+  BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, sum_width,
+                          ma565[0], b565[0]);
+  sum5[0] = sgr_buffer->sum5;
+  square_sum5[0] = sgr_buffer->square_sum5;
+
+  // Each iteration filters two rows, feeding in the rows two and three below
+  // |src|, then rotates the ring buffers.
+  for (int y = (height >> 1) - 1; y > 0; --y) {
+    Circulate5PointersBy2<uint16_t>(sum5);
+    Circulate5PointersBy2<uint32_t>(square_sum5);
+    BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5,
+                   square_sum5, width, sum_width, scale, w0, ma565, b565, dst);
+    src += 2 * stride;
+    dst += 2 * stride;
+    std::swap(ma565[0], ma565[1]);
+    std::swap(b565[0], b565[1]);
+  }
+
+  Circulate5PointersBy2<uint16_t>(sum5);
+  Circulate5PointersBy2<uint32_t>(square_sum5);
+  // Final row pair: even heights take both new rows from the bottom border;
+  // odd heights > 1 take the last image row and the first bottom border row.
+  if ((height & 1) == 0 || height > 1) {
+    const uint16_t* sr[2];
+    if ((height & 1) == 0) {
+      sr[0] = bottom_border;
+      sr[1] = bottom_border + bottom_border_stride;
+    } else {
+      sr[0] = src + 2 * stride;
+      sr[1] = bottom_border;
+    }
+    BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width,
+                   sum_width, scale, w0, ma565, b565, dst);
+  }
+  // Odd height: filter the remaining row alone with the second bottom border
+  // row as its new input.
+  if ((height & 1) != 0) {
+    // Skip the 3 pixels of padding, as the "src + 3" arguments above do.
+    src += 3;
+    if (height > 1) {
+      src += 2 * stride;
+      dst += 2 * stride;
+      std::swap(ma565[0], ma565[1]);
+      std::swap(b565[0], b565[1]);
+      Circulate5PointersBy2<uint16_t>(sum5);
+      Circulate5PointersBy2<uint32_t>(square_sum5);
+    }
+    BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width,
+                          sum_width, scale, w0, sum5, square_sum5, ma565[0],
+                          b565[0], dst);
+  }
+}
+
+// Runs only the second self-guided pass over a |width| x |height| block, one
+// output row at a time. The caller (SelfGuidedFilter_NEON) offsets |src|,
+// |top_border| and |bottom_border| by -2 pixels, so the filter calls receive
+// |src| + 2. Pass 1 must be disabled (multiplier[0] == 0 is asserted); the
+// remaining weight (1 << kSgrProjPrecisionBits) - w1 is passed on as w0.
+inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
+                                  const uint16_t* src, const ptrdiff_t stride,
+                                  const uint16_t* const top_border,
+                                  const ptrdiff_t top_border_stride,
+                                  const uint16_t* bottom_border,
+                                  const ptrdiff_t bottom_border_stride,
+                                  const int width, const int height,
+                                  SgrBuffer* const sgr_buffer, uint16_t* dst) {
+  assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
+  // Row lengths rounded up to multiples of 16, the step of the row loops.
+  const auto temp_stride = Align<ptrdiff_t>(width, 16);
+  const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
+  const auto sum_stride = temp_stride + 16;
+  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
+  const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
+  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
+  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
+  // Ring buffers: 3 rows of sum3/square_sum3 and ma343/b343, 2 rows of
+  // ma444/b444.
+  uint16_t *sum3[3], *ma343[3], *ma444[2];
+  uint32_t *square_sum3[3], *b343[3], *b444[2];
+  sum3[0] = sgr_buffer->sum3;
+  square_sum3[0] = sgr_buffer->square_sum3;
+  ma343[0] = sgr_buffer->ma343;
+  b343[0] = sgr_buffer->b343;
+  for (int i = 1; i <= 2; ++i) {
+    sum3[i] = sum3[i - 1] + sum_stride;
+    square_sum3[i] = square_sum3[i - 1] + sum_stride;
+    ma343[i] = ma343[i - 1] + temp_stride;
+    b343[i] = b343[i - 1] + temp_stride;
+  }
+  ma444[0] = sgr_buffer->ma444;
+  ma444[1] = ma444[0] + temp_stride;
+  b444[0] = sgr_buffer->b444;
+  b444[1] = b444[0] + temp_stride;
+  assert(scale != 0);
+  BoxSum<3>(top_border, top_border_stride, width, sum_stride, sum_width,
+            sum3[0], square_sum3[0]);
+  // First image row: only the 343 outputs are requested (444 pointers are
+  // nullptr).
+  BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3,
+                                 sum_width, ma343[0], nullptr, b343[0],
+                                 nullptr);
+  Circulate3PointersBy1<uint16_t>(sum3);
+  Circulate3PointersBy1<uint32_t>(square_sum3);
+  // Second input row: the next image row, or, for a single-row block, the
+  // first bottom border row (which is then consumed, so advance past it).
+  const uint16_t* s;
+  if (height > 1) {
+    s = src + stride;
+  } else {
+    s = bottom_border;
+    bottom_border += bottom_border_stride;
+  }
+  BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width,
+                                ma343[1], ma444[0], b343[1], b444[0]);
+
+  // Rows whose new input row (two rows below) is still inside the image.
+  for (int y = height - 2; y > 0; --y) {
+    Circulate3PointersBy1<uint16_t>(sum3);
+    Circulate3PointersBy1<uint32_t>(square_sum3);
+    BoxFilterPass2(src + 2, src + 2 * stride, width, sum_width, scale, w0, sum3,
+                   square_sum3, ma343, ma444, b343, b444, dst);
+    src += stride;
+    dst += stride;
+    Circulate3PointersBy1<uint16_t>(ma343);
+    Circulate3PointersBy1<uint32_t>(b343);
+    std::swap(ma444[0], ma444[1]);
+    std::swap(b444[0], b444[1]);
+  }
+
+  // The last min(height, 2) rows take their new input rows from the bottom
+  // border.
+  int y = std::min(height, 2);
+  src += 2;
+  do {
+    Circulate3PointersBy1<uint16_t>(sum3);
+    Circulate3PointersBy1<uint32_t>(square_sum3);
+    BoxFilterPass2(src, bottom_border, width, sum_width, scale, w0, sum3,
+                   square_sum3, ma343, ma444, b343, b444, dst);
+    src += stride;
+    dst += stride;
+    bottom_border += bottom_border_stride;
+    Circulate3PointersBy1<uint16_t>(ma343);
+    Circulate3PointersBy1<uint32_t>(b343);
+    std::swap(ma444[0], ma444[1]);
+    std::swap(b444[0], b444[1]);
+  } while (--y != 0);
+}
+
+// If |width| is not a multiple of 16, up to 15 extra pixels are written to
+// |dest| at the end of each row (the row loops step 16 pixels at a time). It
+// is safe to overwrite the output as it will not be part of the visible frame.
+void SelfGuidedFilter_NEON(
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
+    const ptrdiff_t bottom_border_stride, const int width, const int height,
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
+  // Look up the radius of each pass; a radius of 0 disables that pass.
+  const int sgr_index = restoration_info.sgr_proj_info.index;
+  const int r_pass1 = kSgrProjParams[sgr_index][0];  // 2 or 0
+  const int r_pass2 = kSgrProjParams[sgr_index][2];  // 1 or 0
+  const auto* const in = static_cast<const uint16_t*>(source);
+  const auto* above = static_cast<const uint16_t*>(top_border);
+  const auto* below = static_cast<const uint16_t*>(bottom_border);
+  auto* const out = static_cast<uint16_t*>(dest);
+  SgrBuffer* const buffer = &restoration_buffer->sgr_buffer;
+
+  if (r_pass1 != 0 && r_pass2 != 0) {
+    // Both passes: pointers are shifted left by the 3-pixel Pass 1 padding.
+    BoxFilterProcess(restoration_info, in - 3, stride, above - 3,
+                     top_border_stride, below - 3, bottom_border_stride, width,
+                     height, buffer, out);
+  } else if (r_pass2 == 0) {
+    // Pass 1 only. Both radii can never be 0 at the same time.
+    assert(r_pass1 != 0);
+    BoxFilterProcessPass1(restoration_info, in - 3, stride, above - 3,
+                          top_border_stride, below - 3, bottom_border_stride,
+                          width, height, buffer, out);
+  } else {
+    // Pass 2 only: pointers are shifted left by the 2-pixel Pass 2 padding.
+    BoxFilterProcessPass2(restoration_info, in - 2, stride, above - 2,
+                          top_border_stride, below - 2, bottom_border_stride,
+                          width, height, buffer, out);
+  }
+}
+
+// Installs the NEON loop restoration filters into the 10-bit DSP table:
+// entry 0 is the Wiener filter, entry 1 the self-guided filter.
+void Init10bpp() {
+  Dsp* const table = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(table != nullptr);
+  table->loop_restorations[1] = SelfGuidedFilter_NEON;
+  table->loop_restorations[0] = WienerFilter_NEON;
+}
+
+}  // namespace
+
+// Public entry point: registers the 10-bit NEON loop restoration filters.
+void LoopRestorationInit10bpp_NEON() {
+  Init10bpp();
+}
+
+}  // namespace dsp
+}  // namespace libgav1
+
+#else   // !(LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10)
+namespace libgav1 {
+namespace dsp {
+
+// Built without NEON or without 10-bit support: there is nothing to register.
+void LoopRestorationInit10bpp_NEON() {
+  // Intentionally empty.
+}
+
+}  // namespace dsp
+}  // namespace libgav1
+#endif  // LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
diff --git a/libgav1/src/dsp/arm/loop_restoration_neon.cc b/libgav1/src/dsp/arm/loop_restoration_neon.cc
index e6ceb66..2db137f 100644
--- a/libgav1/src/dsp/arm/loop_restoration_neon.cc
+++ b/libgav1/src/dsp/arm/loop_restoration_neon.cc
@@ -28,6 +28,7 @@
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
 
 namespace libgav1 {
@@ -491,11 +492,14 @@
 // filter row by row. This is faster than doing it column by column when
 // considering cache issues.
 void WienerFilter_NEON(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int16_t* const number_leading_zero_coefficients =
       restoration_info.wiener_info.number_leading_zero_coefficients;
   const int number_rows_to_skip = std::max(
@@ -591,6 +595,74 @@
 //------------------------------------------------------------------------------
 // SGR
 
+// SIMD overreads 8 - (width % 8) - 2 * padding pixels, where padding is 3 for
+// Pass 1 and 2 for Pass 2. Each constant below is that expression with the
+// (width % 8) term dropped, counted in bytes (one byte per 8-bit pixel).
+constexpr int kOverreadInBytesPass1 = 8 - 2 * 3;
+constexpr int kOverreadInBytesPass2 = 8 - 2 * 2;
+
+// Same for the 16-pixel loads: 16 - (width % 16) - 2 * padding pixels, with
+// the (width % 16) term dropped.
+constexpr int kWideOverreadInBytesPass1 = 16 - 2 * 3;
+constexpr int kWideOverreadInBytesPass2 = 16 - 2 * 2;
+
+// Loads 8 uint16_t values at offset |x| from each of the 2 rows in |src|.
+inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x,
+                               uint16x8_t dst[2]) {
+  for (int row = 0; row < 2; ++row) {
+    dst[row] = vld1q_u16(src[row] + x);
+  }
+}
+
+// Loads 8 uint16_t values at offset |x| from each of the 3 rows in |src|.
+inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x,
+                               uint16x8_t dst[3]) {
+  for (int row = 0; row < 3; ++row) {
+    dst[row] = vld1q_u16(src[row] + x);
+  }
+}
+
+// Loads 8 consecutive uint32_t values starting at |src| into the two halves
+// of |dst|.
+inline void LoadAligned32U32(const uint32_t* const src, uint32x4x2_t* dst) {
+  dst->val[0] = vld1q_u32(src);
+  dst->val[1] = vld1q_u32(src + 4);
+}
+
+// Loads 8 uint32_t values at offset |x| from each of the 2 rows in |src|.
+inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x,
+                               uint32x4x2_t dst[2]) {
+  for (int row = 0; row < 2; ++row) {
+    LoadAligned32U32(src[row] + x, dst + row);
+  }
+}
+
+// Loads 8 uint32_t values at offset |x| from each of the 3 rows in |src|.
+inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x,
+                               uint32x4x2_t dst[3]) {
+  for (int row = 0; row < 3; ++row) {
+    LoadAligned32U32(src[row] + x, dst + row);
+  }
+}
+
+// Stores the 16 uint16_t lanes of src[0] then src[1] contiguously at |dst|.
+inline void StoreAligned32U16(uint16_t* const dst, const uint16x8_t src[2]) {
+  for (int half = 0; half < 2; ++half) {
+    vst1q_u16(dst + 8 * half, src[half]);
+  }
+}
+
+// Stores the 8 uint32_t lanes of |src| (val[0] then val[1]) at |dst|.
+inline void StoreAligned32U32(uint32_t* const dst, const uint32x4x2_t src) {
+  for (int half = 0; half < 2; ++half) {
+    vst1q_u32(dst + 4 * half, src.val[half]);
+  }
+}
+
+// Stores 16 uint32_t values: the 8 lanes of src[0] at |dst| followed by the
+// 8 lanes of src[1] at |dst| + 8.
+inline void StoreAligned64U32(uint32_t* const dst, const uint32x4x2_t src[2]) {
+  StoreAligned32U32(dst, src[0]);
+  StoreAligned32U32(dst + 8, src[1]);
+}
+
+// Widening square of each of the 8 uint8 lanes of |src|.
+inline uint16x8_t SquareLo8(const uint8x8_t src) {
+  const uint16x8_t squares = vmull_u8(src, src);
+  return squares;
+}
+
+// Widening square of the low 8 lanes of |src|.
+inline uint16x8_t SquareLo8(const uint8x16_t src) {
+  return SquareLo8(vget_low_u8(src));
+}
+
+// Widening square of the high 8 lanes of |src|.
+inline uint16x8_t SquareHi8(const uint8x16_t src) {
+  const uint8x8_t high = vget_high_u8(src);
+  return vmull_u8(high, high);
+}
+
 inline void Prepare3_8(const uint8x8_t src[2], uint8x8_t dst[3]) {
   dst[0] = VshrU128<0>(src);
   dst[1] = VshrU128<1>(src);
@@ -904,58 +976,69 @@
 }
 
+// Computes the horizontal sums produced by SumHorizontal (row3/row5 and the
+// squared row_sq3/row_sq5) for two consecutive rows of |src|. Each row yields
+// |sum_width| outputs in steps of 8; output rows are |sum_stride| apart.
+// NOTE(review): the inner loop only terminates if |sum_width| is a positive
+// multiple of 8; callers presumably guarantee this — confirm.
 inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
-                   const ptrdiff_t sum_stride, uint16_t* sum3, uint16_t* sum5,
+                   const ptrdiff_t width, const ptrdiff_t sum_stride,
+                   const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5,
                    uint32_t* square_sum3, uint32_t* square_sum5) {
+  // NOTE(review): presumably the number of bytes past |width| that the 8-byte
+  // loads may read, used by Load1MsanU8 for MSan handling — confirm.
+  const ptrdiff_t overread_in_bytes = kOverreadInBytesPass1 - width;
   int y = 2;
   // Don't change loop width to 16, which is even slower.
   do {
     uint8x8_t s[2];
     uint16x8_t sq[2];
-    s[0] = vld1_u8(src);
-    sq[0] = vmull_u8(s[0], s[0]);
-    ptrdiff_t x = 0;
+    s[0] = Load1MsanU8(src, overread_in_bytes);
+    sq[0] = SquareLo8(s[0]);
+    ptrdiff_t x = sum_width;
     do {
       uint16x8_t row3, row5;
       uint32x4x2_t row_sq3, row_sq5;
-      s[1] = vld1_u8(src + x + 8);
-      sq[1] = vmull_u8(s[1], s[1]);
+      // |x| counts the outputs still to produce in this row; |src| advances to
+      // the next 8 bytes.
+      x -= 8;
+      src += 8;
+      s[1] = Load1MsanU8(src, sum_width - x + overread_in_bytes);
+      sq[1] = SquareLo8(s[1]);
       SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5);
       vst1q_u16(sum3, row3);
       vst1q_u16(sum5, row5);
-      vst1q_u32(square_sum3 + 0, row_sq3.val[0]);
-      vst1q_u32(square_sum3 + 4, row_sq3.val[1]);
-      vst1q_u32(square_sum5 + 0, row_sq5.val[0]);
-      vst1q_u32(square_sum5 + 4, row_sq5.val[1]);
+      StoreAligned32U32(square_sum3 + 0, row_sq3);
+      StoreAligned32U32(square_sum5 + 0, row_sq5);
       s[0] = s[1];
       sq[0] = sq[1];
       sum3 += 8;
       sum5 += 8;
       square_sum3 += 8;
       square_sum5 += 8;
-      x += 8;
-    } while (x < sum_stride);
-    src += src_stride;
+    } while (x != 0);
+    // Undo the |sum_width| advance made by the loop, then step to the next row.
+    src += src_stride - sum_width;
+    sum3 += sum_stride - sum_width;
+    sum5 += sum_stride - sum_width;
+    square_sum3 += sum_stride - sum_width;
+    square_sum5 += sum_stride - sum_width;
   } while (--y != 0);
 }
 
 template <int size>
 inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
-                   const ptrdiff_t sum_stride, uint16_t* sums,
+                   const ptrdiff_t width, const ptrdiff_t sum_stride,
+                   const ptrdiff_t sum_width, uint16_t* sums,
                    uint32_t* square_sums) {
   static_assert(size == 3 || size == 5, "");
+  const ptrdiff_t overread_in_bytes =
+      ((size == 5) ? kOverreadInBytesPass1 : kOverreadInBytesPass2) -
+      sizeof(*src) * width;
   int y = 2;
   // Don't change loop width to 16, which is even slower.
   do {
     uint8x8_t s[2];
     uint16x8_t sq[2];
-    s[0] = vld1_u8(src);
-    sq[0] = vmull_u8(s[0], s[0]);
-    ptrdiff_t x = 0;
+    s[0] = Load1MsanU8(src, overread_in_bytes);
+    sq[0] = SquareLo8(s[0]);
+    ptrdiff_t x = sum_width;
     do {
       uint16x8_t row;
       uint32x4x2_t row_sq;
-      s[1] = vld1_u8(src + x + 8);
-      sq[1] = vmull_u8(s[1], s[1]);
+      x -= 8;
+      src += 8;
+      s[1] = Load1MsanU8(src, sum_width - x + overread_in_bytes);
+      sq[1] = SquareLo8(s[1]);
       if (size == 3) {
         row = Sum3Horizontal(s);
         row_sq = Sum3WHorizontal(sq);
@@ -964,15 +1047,15 @@
         row_sq = Sum5WHorizontal(sq);
       }
       vst1q_u16(sums, row);
-      vst1q_u32(square_sums + 0, row_sq.val[0]);
-      vst1q_u32(square_sums + 4, row_sq.val[1]);
+      StoreAligned32U32(square_sums, row_sq);
       s[0] = s[1];
       sq[0] = sq[1];
       sums += 8;
       square_sums += 8;
-      x += 8;
-    } while (x < sum_stride);
-    src += src_stride;
+    } while (x != 0);
+    src += src_stride - sum_width;
+    sums += sum_stride - sum_width;
+    square_sums += sum_stride - sum_width;
   } while (--y != 0);
 }
 
@@ -1143,339 +1226,216 @@
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo(
-    const uint8_t* const src0, const uint8_t* const src1, const uint32_t scale,
-    uint8x16_t s[2][2], uint16_t* const sum5[5], uint32_t* const square_sum5[5],
-    uint16x8_t sq[2][4], uint8x16_t* const ma, uint16x8_t* const b) {
+    uint8x16_t s[2][2], const uint32_t scale, uint16_t* const sum5[5],
+    uint32_t* const square_sum5[5], uint16x8_t sq[2][4], uint8x16_t* const ma,
+    uint16x8_t* const b) {
   uint16x8_t s5[5];
   uint32x4x2_t sq5[5];
-  s[0][0] = vld1q_u8(src0);
-  s[1][0] = vld1q_u8(src1);
-  sq[0][0] = vmull_u8(vget_low_u8(s[0][0]), vget_low_u8(s[0][0]));
-  sq[1][0] = vmull_u8(vget_low_u8(s[1][0]), vget_low_u8(s[1][0]));
-  sq[0][1] = vmull_u8(vget_high_u8(s[0][0]), vget_high_u8(s[0][0]));
-  sq[1][1] = vmull_u8(vget_high_u8(s[1][0]), vget_high_u8(s[1][0]));
+  sq[0][0] = SquareLo8(s[0][0]);
+  sq[1][0] = SquareLo8(s[1][0]);
+  sq[0][1] = SquareHi8(s[0][0]);
+  sq[1][1] = SquareHi8(s[1][0]);
   s5[3] = Sum5Horizontal(s[0][0]);
   s5[4] = Sum5Horizontal(s[1][0]);
   sq5[3] = Sum5WHorizontal(sq[0]);
   sq5[4] = Sum5WHorizontal(sq[1]);
   vst1q_u16(sum5[3], s5[3]);
   vst1q_u16(sum5[4], s5[4]);
-  vst1q_u32(square_sum5[3] + 0, sq5[3].val[0]);
-  vst1q_u32(square_sum5[3] + 4, sq5[3].val[1]);
-  vst1q_u32(square_sum5[4] + 0, sq5[4].val[0]);
-  vst1q_u32(square_sum5[4] + 4, sq5[4].val[1]);
-  s5[0] = vld1q_u16(sum5[0]);
-  s5[1] = vld1q_u16(sum5[1]);
-  s5[2] = vld1q_u16(sum5[2]);
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + 0);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + 4);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + 0);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + 4);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + 0);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + 4);
+  StoreAligned32U32(square_sum5[3], sq5[3]);
+  StoreAligned32U32(square_sum5[4], sq5[4]);
+  LoadAligned16x3U16(sum5, 0, s5);
+  LoadAligned32x3U32(square_sum5, 0, sq5);
   CalculateIntermediate5<0>(s5, sq5, scale, ma, b);
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
-    const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x,
-    const uint32_t scale, uint8x16_t s[2][2], uint16_t* const sum5[5],
-    uint32_t* const square_sum5[5], uint16x8_t sq[2][4], uint8x16_t ma[2],
-    uint16x8_t b[2]) {
+    uint8x16_t s[2][2], const ptrdiff_t x, const uint32_t scale,
+    uint16_t* const sum5[5], uint32_t* const square_sum5[5],
+    uint16x8_t sq[2][4], uint8x16_t ma[2], uint16x8_t b[2]) {
   uint16x8_t s5[2][5];
   uint32x4x2_t sq5[5];
-  s[0][1] = vld1q_u8(src0 + x + 8);
-  s[1][1] = vld1q_u8(src1 + x + 8);
-  sq[0][2] = vmull_u8(vget_low_u8(s[0][1]), vget_low_u8(s[0][1]));
-  sq[1][2] = vmull_u8(vget_low_u8(s[1][1]), vget_low_u8(s[1][1]));
+  sq[0][2] = SquareLo8(s[0][1]);
+  sq[1][2] = SquareLo8(s[1][1]);
   Sum5Horizontal<8>(s[0], &s5[0][3], &s5[1][3]);
   Sum5Horizontal<8>(s[1], &s5[0][4], &s5[1][4]);
   sq5[3] = Sum5WHorizontal(sq[0] + 1);
   sq5[4] = Sum5WHorizontal(sq[1] + 1);
   vst1q_u16(sum5[3] + x, s5[0][3]);
   vst1q_u16(sum5[4] + x, s5[0][4]);
-  vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
-  vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
-  vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
-  vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
-  s5[0][0] = vld1q_u16(sum5[0] + x);
-  s5[0][1] = vld1q_u16(sum5[1] + x);
-  s5[0][2] = vld1q_u16(sum5[2] + x);
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+  StoreAligned32U32(square_sum5[3] + x, sq5[3]);
+  StoreAligned32U32(square_sum5[4] + x, sq5[4]);
+  LoadAligned16x3U16(sum5, x, s5[0]);
+  LoadAligned32x3U32(square_sum5, x, sq5);
   CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], &b[0]);
 
-  sq[0][3] = vmull_u8(vget_high_u8(s[0][1]), vget_high_u8(s[0][1]));
-  sq[1][3] = vmull_u8(vget_high_u8(s[1][1]), vget_high_u8(s[1][1]));
+  sq[0][3] = SquareHi8(s[0][1]);
+  sq[1][3] = SquareHi8(s[1][1]);
   sq5[3] = Sum5WHorizontal(sq[0] + 2);
   sq5[4] = Sum5WHorizontal(sq[1] + 2);
   vst1q_u16(sum5[3] + x + 8, s5[1][3]);
   vst1q_u16(sum5[4] + x + 8, s5[1][4]);
-  vst1q_u32(square_sum5[3] + x + 8, sq5[3].val[0]);
-  vst1q_u32(square_sum5[3] + x + 12, sq5[3].val[1]);
-  vst1q_u32(square_sum5[4] + x + 8, sq5[4].val[0]);
-  vst1q_u32(square_sum5[4] + x + 12, sq5[4].val[1]);
-  s5[1][0] = vld1q_u16(sum5[0] + x + 8);
-  s5[1][1] = vld1q_u16(sum5[1] + x + 8);
-  s5[1][2] = vld1q_u16(sum5[2] + x + 8);
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 8);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 12);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 8);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 12);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 8);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 12);
+  StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]);
+  StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]);
+  LoadAligned16x3U16(sum5, x + 8, s5[1]);
+  LoadAligned32x3U32(square_sum5, x + 8, sq5);
   CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], &b[1]);
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo(
-    const uint8_t* const src, const uint32_t scale, uint8x16_t* const s,
-    const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
-    uint16x8_t sq[2], uint8x16_t* const ma, uint16x8_t* const b) {
+    uint8x16_t* const s, const uint32_t scale, const uint16_t* const sum5[5],
+    const uint32_t* const square_sum5[5], uint16x8_t sq[2],
+    uint8x16_t* const ma, uint16x8_t* const b) {
   uint16x8_t s5[5];
   uint32x4x2_t sq5[5];
-  *s = vld1q_u8(src);
-  sq[0] = vmull_u8(vget_low_u8(*s), vget_low_u8(*s));
-  sq[1] = vmull_u8(vget_high_u8(*s), vget_high_u8(*s));
+  sq[0] = SquareLo8(s[0]);
+  sq[1] = SquareHi8(s[0]);
   s5[3] = s5[4] = Sum5Horizontal(*s);
   sq5[3] = sq5[4] = Sum5WHorizontal(sq);
-  s5[0] = vld1q_u16(sum5[0]);
-  s5[1] = vld1q_u16(sum5[1]);
-  s5[2] = vld1q_u16(sum5[2]);
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + 0);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + 4);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + 0);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + 4);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + 0);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + 4);
+  LoadAligned16x3U16(sum5, 0, s5);
+  LoadAligned32x3U32(square_sum5, 0, sq5);
   CalculateIntermediate5<0>(s5, sq5, scale, ma, b);
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
-    const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
-    uint8x16_t s[2], const uint16_t* const sum5[5],
-    const uint32_t* const square_sum5[5], uint16x8_t sq[3], uint8x16_t ma[2],
-    uint16x8_t b[2]) {
+    uint8x16_t s[2], const ptrdiff_t x, const uint32_t scale,
+    const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
+    uint16x8_t sq[3], uint8x16_t ma[2], uint16x8_t b[2]) {
   uint16x8_t s5[2][5];
   uint32x4x2_t sq5[5];
-  s[1] = vld1q_u8(src + x + 8);
-  sq[1] = vmull_u8(vget_low_u8(s[1]), vget_low_u8(s[1]));
+  sq[1] = SquareLo8(s[1]);
   Sum5Horizontal<8>(s, &s5[0][3], &s5[1][3]);
   sq5[3] = sq5[4] = Sum5WHorizontal(sq);
-  s5[0][0] = vld1q_u16(sum5[0] + x);
-  s5[0][1] = vld1q_u16(sum5[1] + x);
-  s5[0][2] = vld1q_u16(sum5[2] + x);
+  LoadAligned16x3U16(sum5, x, s5[0]);
   s5[0][4] = s5[0][3];
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+  LoadAligned32x3U32(square_sum5, x, sq5);
   CalculateIntermediate5<8>(s5[0], sq5, scale, &ma[0], &b[0]);
 
-  sq[2] = vmull_u8(vget_high_u8(s[1]), vget_high_u8(s[1]));
+  sq[2] = SquareHi8(s[1]);
   sq5[3] = sq5[4] = Sum5WHorizontal(sq + 1);
-  s5[1][0] = vld1q_u16(sum5[0] + x + 8);
-  s5[1][1] = vld1q_u16(sum5[1] + x + 8);
-  s5[1][2] = vld1q_u16(sum5[2] + x + 8);
+  LoadAligned16x3U16(sum5, x + 8, s5[1]);
   s5[1][4] = s5[1][3];
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 8);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 12);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 8);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 12);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 8);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 12);
+  LoadAligned32x3U32(square_sum5, x + 8, sq5);
   CalculateIntermediate5<0>(s5[1], sq5, scale, &ma[1], &b[1]);
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo(
-    const uint8_t* const src, const uint32_t scale, uint8x16_t* const s,
-    uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16x8_t sq[2],
-    uint8x16_t* const ma, uint16x8_t* const b) {
+    uint8x16_t* const s, const uint32_t scale, uint16_t* const sum3[3],
+    uint32_t* const square_sum3[3], uint16x8_t sq[2], uint8x16_t* const ma,
+    uint16x8_t* const b) {
   uint16x8_t s3[3];
   uint32x4x2_t sq3[3];
-  *s = vld1q_u8(src);
-  sq[0] = vmull_u8(vget_low_u8(*s), vget_low_u8(*s));
-  sq[1] = vmull_u8(vget_high_u8(*s), vget_high_u8(*s));
+  sq[0] = SquareLo8(*s);
+  sq[1] = SquareHi8(*s);
   s3[2] = Sum3Horizontal(*s);
   sq3[2] = Sum3WHorizontal(sq);
   vst1q_u16(sum3[2], s3[2]);
-  vst1q_u32(square_sum3[2] + 0, sq3[2].val[0]);
-  vst1q_u32(square_sum3[2] + 4, sq3[2].val[1]);
-  s3[0] = vld1q_u16(sum3[0]);
-  s3[1] = vld1q_u16(sum3[1]);
-  sq3[0].val[0] = vld1q_u32(square_sum3[0] + 0);
-  sq3[0].val[1] = vld1q_u32(square_sum3[0] + 4);
-  sq3[1].val[0] = vld1q_u32(square_sum3[1] + 0);
-  sq3[1].val[1] = vld1q_u32(square_sum3[1] + 4);
+  StoreAligned32U32(square_sum3[2], sq3[2]);
+  LoadAligned16x2U16(sum3, 0, s3);
+  LoadAligned32x2U32(square_sum3, 0, sq3);
   CalculateIntermediate3<0>(s3, sq3, scale, ma, b);
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
-    const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
-    uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint8x16_t s[2],
-    uint16x8_t sq[3], uint8x16_t ma[2], uint16x8_t b[2]) {
+    uint8x16_t s[2], const ptrdiff_t x, const uint32_t scale,
+    uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16x8_t sq[3],
+    uint8x16_t ma[2], uint16x8_t b[2]) {
   uint16x8_t s3[4];
   uint32x4x2_t sq3[3];
-  s[1] = vld1q_u8(src + x + 8);
-  sq[1] = vmull_u8(vget_low_u8(s[1]), vget_low_u8(s[1]));
+  sq[1] = SquareLo8(s[1]);
   Sum3Horizontal<8>(s, s3 + 2);
   sq3[2] = Sum3WHorizontal(sq);
   vst1q_u16(sum3[2] + x, s3[2]);
-  vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
-  vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
-  s3[0] = vld1q_u16(sum3[0] + x);
-  s3[1] = vld1q_u16(sum3[1] + x);
-  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
-  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
-  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
-  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
+  StoreAligned32U32(square_sum3[2] + x, sq3[2]);
+  LoadAligned16x2U16(sum3, x, s3);
+  LoadAligned32x2U32(square_sum3, x, sq3);
   CalculateIntermediate3<8>(s3, sq3, scale, &ma[0], &b[0]);
 
-  sq[2] = vmull_u8(vget_high_u8(s[1]), vget_high_u8(s[1]));
+  sq[2] = SquareHi8(s[1]);
   sq3[2] = Sum3WHorizontal(sq + 1);
   vst1q_u16(sum3[2] + x + 8, s3[3]);
-  vst1q_u32(square_sum3[2] + x + 8, sq3[2].val[0]);
-  vst1q_u32(square_sum3[2] + x + 12, sq3[2].val[1]);
-  s3[1] = vld1q_u16(sum3[0] + x + 8);
-  s3[2] = vld1q_u16(sum3[1] + x + 8);
-  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 8);
-  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 12);
-  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 8);
-  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 12);
+  StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]);
+  LoadAligned16x2U16(sum3, x + 8, s3 + 1);
+  LoadAligned32x2U32(square_sum3, x + 8, sq3);
   CalculateIntermediate3<0>(s3 + 1, sq3, scale, &ma[1], &b[1]);
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo(
-    const uint8_t* const src0, const uint8_t* const src1,
-    const uint16_t scales[2], uint8x16_t s[2][2], uint16_t* const sum3[4],
+    uint8x16_t s[2][2], const uint16_t scales[2], uint16_t* const sum3[4],
     uint16_t* const sum5[5], uint32_t* const square_sum3[4],
     uint32_t* const square_sum5[5], uint16x8_t sq[2][4], uint8x16_t ma3[2][2],
     uint16x8_t b3[2][3], uint8x16_t* const ma5, uint16x8_t* const b5) {
   uint16x8_t s3[4], s5[5];
   uint32x4x2_t sq3[4], sq5[5];
-  s[0][0] = vld1q_u8(src0);
-  s[1][0] = vld1q_u8(src1);
-  sq[0][0] = vmull_u8(vget_low_u8(s[0][0]), vget_low_u8(s[0][0]));
-  sq[1][0] = vmull_u8(vget_low_u8(s[1][0]), vget_low_u8(s[1][0]));
-  sq[0][1] = vmull_u8(vget_high_u8(s[0][0]), vget_high_u8(s[0][0]));
-  sq[1][1] = vmull_u8(vget_high_u8(s[1][0]), vget_high_u8(s[1][0]));
+  sq[0][0] = SquareLo8(s[0][0]);
+  sq[1][0] = SquareLo8(s[1][0]);
+  sq[0][1] = SquareHi8(s[0][0]);
+  sq[1][1] = SquareHi8(s[1][0]);
   SumHorizontal(s[0][0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]);
   SumHorizontal(s[1][0], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]);
   vst1q_u16(sum3[2], s3[2]);
   vst1q_u16(sum3[3], s3[3]);
-  vst1q_u32(square_sum3[2] + 0, sq3[2].val[0]);
-  vst1q_u32(square_sum3[2] + 4, sq3[2].val[1]);
-  vst1q_u32(square_sum3[3] + 0, sq3[3].val[0]);
-  vst1q_u32(square_sum3[3] + 4, sq3[3].val[1]);
+  StoreAligned32U32(square_sum3[2], sq3[2]);
+  StoreAligned32U32(square_sum3[3], sq3[3]);
   vst1q_u16(sum5[3], s5[3]);
   vst1q_u16(sum5[4], s5[4]);
-  vst1q_u32(square_sum5[3] + 0, sq5[3].val[0]);
-  vst1q_u32(square_sum5[3] + 4, sq5[3].val[1]);
-  vst1q_u32(square_sum5[4] + 0, sq5[4].val[0]);
-  vst1q_u32(square_sum5[4] + 4, sq5[4].val[1]);
-  s3[0] = vld1q_u16(sum3[0]);
-  s3[1] = vld1q_u16(sum3[1]);
-  sq3[0].val[0] = vld1q_u32(square_sum3[0] + 0);
-  sq3[0].val[1] = vld1q_u32(square_sum3[0] + 4);
-  sq3[1].val[0] = vld1q_u32(square_sum3[1] + 0);
-  sq3[1].val[1] = vld1q_u32(square_sum3[1] + 4);
-  s5[0] = vld1q_u16(sum5[0]);
-  s5[1] = vld1q_u16(sum5[1]);
-  s5[2] = vld1q_u16(sum5[2]);
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + 0);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + 4);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + 0);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + 4);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + 0);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + 4);
+  StoreAligned32U32(square_sum5[3], sq5[3]);
+  StoreAligned32U32(square_sum5[4], sq5[4]);
+  LoadAligned16x2U16(sum3, 0, s3);
+  LoadAligned32x2U32(square_sum3, 0, sq3);
+  LoadAligned16x3U16(sum5, 0, s5);
+  LoadAligned32x3U32(square_sum5, 0, sq5);
   CalculateIntermediate3<0>(s3, sq3, scales[1], ma3[0], b3[0]);
   CalculateIntermediate3<0>(s3 + 1, sq3 + 1, scales[1], ma3[1], b3[1]);
   CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5);
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
-    const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x,
-    const uint16_t scales[2], uint8x16_t s[2][2], uint16_t* const sum3[4],
-    uint16_t* const sum5[5], uint32_t* const square_sum3[4],
-    uint32_t* const square_sum5[5], uint16x8_t sq[2][4], uint8x16_t ma3[2][2],
-    uint16x8_t b3[2][3], uint8x16_t ma5[2], uint16x8_t b5[2]) {
+    const uint8x16_t s[2][2], const ptrdiff_t x, const uint16_t scales[2],
+    uint16_t* const sum3[4], uint16_t* const sum5[5],
+    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
+    uint16x8_t sq[2][4], uint8x16_t ma3[2][2], uint16x8_t b3[2][3],
+    uint8x16_t ma5[2], uint16x8_t b5[2]) {
   uint16x8_t s3[2][4], s5[2][5];
   uint32x4x2_t sq3[4], sq5[5];
-  s[0][1] = vld1q_u8(src0 + x + 8);
-  s[1][1] = vld1q_u8(src1 + x + 8);
-  sq[0][2] = vmull_u8(vget_low_u8(s[0][1]), vget_low_u8(s[0][1]));
-  sq[1][2] = vmull_u8(vget_low_u8(s[1][1]), vget_low_u8(s[1][1]));
+  sq[0][2] = SquareLo8(s[0][1]);
+  sq[1][2] = SquareLo8(s[1][1]);
   SumHorizontal<8>(s[0], &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]);
   SumHorizontal<8>(s[1], &s3[0][3], &s3[1][3], &s5[0][4], &s5[1][4]);
   SumHorizontal(sq[0] + 1, &sq3[2], &sq5[3]);
   SumHorizontal(sq[1] + 1, &sq3[3], &sq5[4]);
   vst1q_u16(sum3[2] + x, s3[0][2]);
   vst1q_u16(sum3[3] + x, s3[0][3]);
-  vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
-  vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
-  vst1q_u32(square_sum3[3] + x + 0, sq3[3].val[0]);
-  vst1q_u32(square_sum3[3] + x + 4, sq3[3].val[1]);
+  StoreAligned32U32(square_sum3[2] + x, sq3[2]);
+  StoreAligned32U32(square_sum3[3] + x, sq3[3]);
   vst1q_u16(sum5[3] + x, s5[0][3]);
   vst1q_u16(sum5[4] + x, s5[0][4]);
-  vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
-  vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
-  vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
-  vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
-  s3[0][0] = vld1q_u16(sum3[0] + x);
-  s3[0][1] = vld1q_u16(sum3[1] + x);
-  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
-  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
-  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
-  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
-  s5[0][0] = vld1q_u16(sum5[0] + x);
-  s5[0][1] = vld1q_u16(sum5[1] + x);
-  s5[0][2] = vld1q_u16(sum5[2] + x);
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+  StoreAligned32U32(square_sum5[3] + x, sq5[3]);
+  StoreAligned32U32(square_sum5[4] + x, sq5[4]);
+  LoadAligned16x2U16(sum3, x, s3[0]);
+  LoadAligned32x2U32(square_sum3, x, sq3);
+  LoadAligned16x3U16(sum5, x, s5[0]);
+  LoadAligned32x3U32(square_sum5, x, sq5);
   CalculateIntermediate3<8>(s3[0], sq3, scales[1], &ma3[0][0], &b3[0][1]);
   CalculateIntermediate3<8>(s3[0] + 1, sq3 + 1, scales[1], &ma3[1][0],
                             &b3[1][1]);
   CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], &b5[0]);
 
-  sq[0][3] = vmull_u8(vget_high_u8(s[0][1]), vget_high_u8(s[0][1]));
-  sq[1][3] = vmull_u8(vget_high_u8(s[1][1]), vget_high_u8(s[1][1]));
+  sq[0][3] = SquareHi8(s[0][1]);
+  sq[1][3] = SquareHi8(s[1][1]);
   SumHorizontal(sq[0] + 2, &sq3[2], &sq5[3]);
   SumHorizontal(sq[1] + 2, &sq3[3], &sq5[4]);
   vst1q_u16(sum3[2] + x + 8, s3[1][2]);
   vst1q_u16(sum3[3] + x + 8, s3[1][3]);
-  vst1q_u32(square_sum3[2] + x + 8, sq3[2].val[0]);
-  vst1q_u32(square_sum3[2] + x + 12, sq3[2].val[1]);
-  vst1q_u32(square_sum3[3] + x + 8, sq3[3].val[0]);
-  vst1q_u32(square_sum3[3] + x + 12, sq3[3].val[1]);
+  StoreAligned32U32(square_sum3[2] + x + 8, sq3[2]);
+  StoreAligned32U32(square_sum3[3] + x + 8, sq3[3]);
   vst1q_u16(sum5[3] + x + 8, s5[1][3]);
   vst1q_u16(sum5[4] + x + 8, s5[1][4]);
-  vst1q_u32(square_sum5[3] + x + 8, sq5[3].val[0]);
-  vst1q_u32(square_sum5[3] + x + 12, sq5[3].val[1]);
-  vst1q_u32(square_sum5[4] + x + 8, sq5[4].val[0]);
-  vst1q_u32(square_sum5[4] + x + 12, sq5[4].val[1]);
-  s3[1][0] = vld1q_u16(sum3[0] + x + 8);
-  s3[1][1] = vld1q_u16(sum3[1] + x + 8);
-  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 8);
-  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 12);
-  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 8);
-  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 12);
-  s5[1][0] = vld1q_u16(sum5[0] + x + 8);
-  s5[1][1] = vld1q_u16(sum5[1] + x + 8);
-  s5[1][2] = vld1q_u16(sum5[2] + x + 8);
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 8);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 12);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 8);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 12);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 8);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 12);
+  StoreAligned32U32(square_sum5[3] + x + 8, sq5[3]);
+  StoreAligned32U32(square_sum5[4] + x + 8, sq5[4]);
+  LoadAligned16x2U16(sum3, x + 8, s3[1]);
+  LoadAligned32x2U32(square_sum3, x + 8, sq3);
+  LoadAligned16x3U16(sum5, x + 8, s5[1]);
+  LoadAligned32x3U32(square_sum5, x + 8, sq5);
   CalculateIntermediate3<0>(s3[1], sq3, scales[1], &ma3[0][1], &b3[0][2]);
   CalculateIntermediate3<0>(s3[1] + 1, sq3 + 1, scales[1], &ma3[1][1],
                             &b3[1][2]);
@@ -1483,90 +1443,55 @@
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo(
-    const uint8_t* const src, const uint16_t scales[2],
+    uint8x16_t* const s, const uint16_t scales[2],
     const uint16_t* const sum3[4], const uint16_t* const sum5[5],
     const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
-    uint8x16_t* const s, uint16x8_t sq[2], uint8x16_t* const ma3,
-    uint8x16_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) {
+    uint16x8_t sq[2], uint8x16_t* const ma3, uint8x16_t* const ma5,
+    uint16x8_t* const b3, uint16x8_t* const b5) {
   uint16x8_t s3[3], s5[5];
   uint32x4x2_t sq3[3], sq5[5];
-  *s = vld1q_u8(src);
-  sq[0] = vmull_u8(vget_low_u8(*s), vget_low_u8(*s));
-  sq[1] = vmull_u8(vget_high_u8(*s), vget_high_u8(*s));
+  sq[0] = SquareLo8(s[0]);
+  sq[1] = SquareHi8(s[0]);
   SumHorizontal(*s, sq, &s3[2], &s5[3], &sq3[2], &sq5[3]);
-  s5[0] = vld1q_u16(sum5[0]);
-  s5[1] = vld1q_u16(sum5[1]);
-  s5[2] = vld1q_u16(sum5[2]);
+  LoadAligned16x3U16(sum5, 0, s5);
   s5[4] = s5[3];
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + 0);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + 4);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + 0);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + 4);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + 0);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + 4);
+  LoadAligned32x3U32(square_sum5, 0, sq5);
   sq5[4] = sq5[3];
   CalculateIntermediate5<0>(s5, sq5, scales[0], ma5, b5);
-  s3[0] = vld1q_u16(sum3[0]);
-  s3[1] = vld1q_u16(sum3[1]);
-  sq3[0].val[0] = vld1q_u32(square_sum3[0] + 0);
-  sq3[0].val[1] = vld1q_u32(square_sum3[0] + 4);
-  sq3[1].val[0] = vld1q_u32(square_sum3[1] + 0);
-  sq3[1].val[1] = vld1q_u32(square_sum3[1] + 4);
+  LoadAligned16x2U16(sum3, 0, s3);
+  LoadAligned32x2U32(square_sum3, 0, sq3);
   CalculateIntermediate3<0>(s3, sq3, scales[1], ma3, b3);
 }
 
 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
-    const uint8_t* const src, const ptrdiff_t x, const uint16_t scales[2],
+    uint8x16_t s[2], const ptrdiff_t x, const uint16_t scales[2],
     const uint16_t* const sum3[4], const uint16_t* const sum5[5],
     const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
-    uint8x16_t s[2], uint16x8_t sq[3], uint8x16_t ma3[2], uint8x16_t ma5[2],
-    uint16x8_t b3[2], uint16x8_t b5[2]) {
+    uint16x8_t sq[3], uint8x16_t ma3[2], uint8x16_t ma5[2], uint16x8_t b3[2],
+    uint16x8_t b5[2]) {
   uint16x8_t s3[2][3], s5[2][5];
   uint32x4x2_t sq3[3], sq5[5];
-  s[1] = vld1q_u8(src + x + 8);
-  sq[1] = vmull_u8(vget_low_u8(s[1]), vget_low_u8(s[1]));
+  sq[1] = SquareLo8(s[1]);
   SumHorizontal<8>(s, &s3[0][2], &s3[1][2], &s5[0][3], &s5[1][3]);
   SumHorizontal(sq, &sq3[2], &sq5[3]);
-  s5[0][0] = vld1q_u16(sum5[0] + x);
-  s5[0][1] = vld1q_u16(sum5[1] + x);
-  s5[0][2] = vld1q_u16(sum5[2] + x);
+  LoadAligned16x3U16(sum5, x, s5[0]);
   s5[0][4] = s5[0][3];
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
+  LoadAligned32x3U32(square_sum5, x, sq5);
   sq5[4] = sq5[3];
   CalculateIntermediate5<8>(s5[0], sq5, scales[0], &ma5[0], &b5[0]);
-  s3[0][0] = vld1q_u16(sum3[0] + x);
-  s3[0][1] = vld1q_u16(sum3[1] + x);
-  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
-  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
-  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
-  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
+  LoadAligned16x2U16(sum3, x, s3[0]);
+  LoadAligned32x2U32(square_sum3, x, sq3);
   CalculateIntermediate3<8>(s3[0], sq3, scales[1], &ma3[0], &b3[0]);
 
-  sq[2] = vmull_u8(vget_high_u8(s[1]), vget_high_u8(s[1]));
+  sq[2] = SquareHi8(s[1]);
   SumHorizontal(sq + 1, &sq3[2], &sq5[3]);
-  s5[1][0] = vld1q_u16(sum5[0] + x + 8);
-  s5[1][1] = vld1q_u16(sum5[1] + x + 8);
-  s5[1][2] = vld1q_u16(sum5[2] + x + 8);
+  LoadAligned16x3U16(sum5, x + 8, s5[1]);
   s5[1][4] = s5[1][3];
-  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 8);
-  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 12);
-  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 8);
-  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 12);
-  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 8);
-  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 12);
+  LoadAligned32x3U32(square_sum5, x + 8, sq5);
   sq5[4] = sq5[3];
   CalculateIntermediate5<0>(s5[1], sq5, scales[0], &ma5[1], &b5[1]);
-  s3[1][0] = vld1q_u16(sum3[0] + x + 8);
-  s3[1][1] = vld1q_u16(sum3[1] + x + 8);
-  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 8);
-  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 12);
-  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 8);
-  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 12);
+  LoadAligned16x2U16(sum3, x + 8, s3[1]);
+  LoadAligned32x2U32(square_sum3, x + 8, sq3);
   CalculateIntermediate3<0>(s3[1], sq3, scales[1], &ma3[1], &b3[1]);
 }
 
@@ -1576,18 +1501,23 @@
                                     uint16_t* const sum5[5],
                                     uint32_t* const square_sum5[5],
                                     uint16_t* ma565, uint32_t* b565) {
+  const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass1 - width;
   uint8x16_t s[2][2], mas[2];
   uint16x8_t sq[2][4], bs[3];
-  BoxFilterPreProcess5Lo(src0, src1, scale, s, sum5, square_sum5, sq, &mas[0],
-                         &bs[0]);
+  // TODO(b/194217060): Future msan load.
+  s[0][0] = vld1q_u8(src0);
+  s[1][0] = vld1q_u8(src1);
+
+  BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], &bs[0]);
 
   int x = 0;
   do {
     uint16x8_t ma[2];
     uint8x16_t masx[3];
     uint32x4x2_t b[2];
-    BoxFilterPreProcess5(src0, src1, x + 8, scale, s, sum5, square_sum5, sq,
-                         mas, bs + 1);
+    s[0][1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
+    s[1][1] = Load1QMsanU8(src1 + x + 16, x + 16 + overread_in_bytes);
+    BoxFilterPreProcess5(s, x + 8, scale, sum5, square_sum5, sq, mas, bs + 1);
     Prepare3_8<0>(mas, masx);
     ma[0] = Sum565<0>(masx);
     b[0] = Sum565W(bs);
@@ -1617,15 +1547,17 @@
     const uint8_t* const src, const int width, const uint32_t scale,
     uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343,
     uint16_t* ma444, uint32_t* b343, uint32_t* b444) {
+  const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass2 - width;
   uint8x16_t s[2], mas[2];
   uint16x8_t sq[4], bs[3];
-  BoxFilterPreProcess3Lo(src, scale, &s[0], sum3, square_sum3, sq, &mas[0],
-                         &bs[0]);
+  s[0] = Load1QMsanU8(src, overread_in_bytes);
+  BoxFilterPreProcess3Lo(&s[0], scale, sum3, square_sum3, sq, &mas[0], &bs[0]);
 
   int x = 0;
   do {
     uint8x16_t ma3x[3];
-    BoxFilterPreProcess3(src, x + 8, scale, sum3, square_sum3, s, sq + 1, mas,
+    s[1] = Load1QMsanU8(src + x + 16, x + 16 + overread_in_bytes);
+    BoxFilterPreProcess3(s, x + 8, scale, sum3, square_sum3, sq + 1, mas,
                          bs + 1);
     Prepare3_8<0>(mas, ma3x);
     if (calculate444) {
@@ -1664,43 +1596,43 @@
     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
     uint16_t* const ma343[4], uint16_t* const ma444, uint16_t* ma565,
     uint32_t* const b343[4], uint32_t* const b444, uint32_t* b565) {
+  const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass1 - width;
   uint8x16_t s[2][2], ma3[2][2], ma5[2];
   uint16x8_t sq[2][4], b3[2][3], b5[3];
-  BoxFilterPreProcessLo(src0, src1, scales, s, sum3, sum5, square_sum3,
-                        square_sum5, sq, ma3, b3, &ma5[0], &b5[0]);
+  // TODO(b/194217060): Future msan load.
+  s[0][0] = vld1q_u8(src0);
+  s[1][0] = vld1q_u8(src1);
+
+  BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq,
+                        ma3, b3, &ma5[0], &b5[0]);
 
   int x = 0;
   do {
     uint16x8_t ma[2];
     uint8x16_t ma3x[3], ma5x[3];
     uint32x4x2_t b[2];
-    BoxFilterPreProcess(src0, src1, x + 8, scales, s, sum3, sum5, square_sum3,
-                        square_sum5, sq, ma3, b3, ma5, b5 + 1);
+
+    s[0][1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
+    s[1][1] = Load1QMsanU8(src1 + x + 16, x + 16 + overread_in_bytes);
+    BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5,
+                        sq, ma3, b3, ma5, b5 + 1);
     Prepare3_8<0>(ma3[0], ma3x);
     ma[0] = Sum343<0>(ma3x);
     ma[1] = Sum343<8>(ma3x);
+    StoreAligned32U16(ma343[0] + x, ma);
     b[0] = Sum343W(b3[0] + 0);
     b[1] = Sum343W(b3[0] + 1);
-    vst1q_u16(ma343[0] + x, ma[0]);
-    vst1q_u16(ma343[0] + x + 8, ma[1]);
-    vst1q_u32(b343[0] + x, b[0].val[0]);
-    vst1q_u32(b343[0] + x + 4, b[0].val[1]);
-    vst1q_u32(b343[0] + x + 8, b[1].val[0]);
-    vst1q_u32(b343[0] + x + 12, b[1].val[1]);
+    StoreAligned64U32(b343[0] + x, b);
     Prepare3_8<0>(ma3[1], ma3x);
     Store343_444<0>(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444);
     Store343_444<8>(ma3x, b3[1] + 1, x + 8, ma343[1], ma444, b343[1], b444);
     Prepare3_8<0>(ma5, ma5x);
     ma[0] = Sum565<0>(ma5x);
     ma[1] = Sum565<8>(ma5x);
+    StoreAligned32U16(ma565, ma);
     b[0] = Sum565W(b5);
     b[1] = Sum565W(b5 + 1);
-    vst1q_u16(ma565, ma[0]);
-    vst1q_u16(ma565 + 8, ma[1]);
-    vst1q_u32(b565 + 0, b[0].val[0]);
-    vst1q_u32(b565 + 4, b[0].val[1]);
-    vst1q_u32(b565 + 8, b[1].val[0]);
-    vst1q_u32(b565 + 12, b[1].val[1]);
+    StoreAligned64U32(b565, b);
     s[0][0] = s[0][1];
     s[1][0] = s[1][1];
     sq[0][1] = sq[0][3];
@@ -1799,10 +1731,13 @@
     uint32_t* const square_sum5[5], const int width, const uint32_t scale,
     const int16_t w0, uint16_t* const ma565[2], uint32_t* const b565[2],
     uint8_t* const dst) {
+  const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass1 - width;
   uint8x16_t s[2][2], mas[2];
   uint16x8_t sq[2][4], bs[3];
-  BoxFilterPreProcess5Lo(src0, src1, scale, s, sum5, square_sum5, sq, &mas[0],
-                         &bs[0]);
+  s[0][0] = Load1QMsanU8(src0, overread_in_bytes);
+  s[1][0] = Load1QMsanU8(src1, overread_in_bytes);
+
+  BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq, &mas[0], &bs[0]);
 
   int x = 0;
   do {
@@ -1810,8 +1745,9 @@
     uint8x16_t masx[3];
     uint32x4x2_t b[2];
     int16x8_t p0, p1;
-    BoxFilterPreProcess5(src0, src1, x + 8, scale, s, sum5, square_sum5, sq,
-                         mas, bs + 1);
+    s[0][1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
+    s[1][1] = Load1QMsanU8(src1 + x + 16, x + 16 + overread_in_bytes);
+    BoxFilterPreProcess5(s, x + 8, scale, sum5, square_sum5, sq, mas, bs + 1);
     Prepare3_8<0>(mas, masx);
     ma[1] = Sum565<0>(masx);
     b[1] = Sum565W(bs);
@@ -1865,7 +1801,10 @@
                                   uint8_t* const dst) {
   uint8x16_t s[2], mas[2];
   uint16x8_t sq[4], bs[4];
-  BoxFilterPreProcess5LastRowLo(src0, scale, s, sum5, square_sum5, sq, &mas[0],
+  // TODO(b/194217060): Future msan load.
+  s[0] = vld1q_u8(src0);
+
+  BoxFilterPreProcess5LastRowLo(s, scale, sum5, square_sum5, sq, &mas[0],
                                 &bs[0]);
 
   int x = 0;
@@ -1873,8 +1812,11 @@
     uint16x8_t ma[2];
     uint8x16_t masx[3];
     uint32x4x2_t b[2];
-    BoxFilterPreProcess5LastRow(src0, x + 8, scale, s, sum5, square_sum5,
-                                sq + 1, mas, bs + 1);
+    // TODO(b/194217060): Future msan load.
+    s[1] = vld1q_u8(src0 + x + 16);
+
+    BoxFilterPreProcess5LastRow(s, x + 8, scale, sum5, square_sum5, sq + 1, mas,
+                                bs + 1);
     Prepare3_8<0>(mas, masx);
     ma[1] = Sum565<0>(masx);
     b[1] = Sum565W(bs);
@@ -1911,17 +1853,21 @@
     uint32_t* const square_sum3[3], uint16_t* const ma343[3],
     uint16_t* const ma444[2], uint32_t* const b343[3], uint32_t* const b444[2],
     uint8_t* const dst) {
+  const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass2 - width;
   uint8x16_t s[2], mas[2];
   uint16x8_t sq[4], bs[3];
-  BoxFilterPreProcess3Lo(src0, scale, &s[0], sum3, square_sum3, sq, &mas[0],
-                         &bs[0]);
+  // TODO(b/194217060): Future msan load.
+  s[0] = vld1q_u8(src0);
+
+  BoxFilterPreProcess3Lo(&s[0], scale, sum3, square_sum3, sq, &mas[0], &bs[0]);
 
   int x = 0;
   do {
     uint16x8_t ma[3];
     uint8x16_t ma3x[3];
     uint32x4x2_t b[3];
-    BoxFilterPreProcess3(src0, x + 8, scale, sum3, square_sum3, s, sq + 1, mas,
+    s[1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
+    BoxFilterPreProcess3(s, x + 8, scale, sum3, square_sum3, sq + 1, mas,
                          bs + 1);
     Prepare3_8<0>(mas, ma3x);
     Store343_444<0>(ma3x, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2],
@@ -1966,10 +1912,15 @@
     uint16_t* const ma343[4], uint16_t* const ma444[3],
     uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
     uint32_t* const b565[2], uint8_t* const dst) {
+  const ptrdiff_t overread_in_bytes = kWideOverreadInBytesPass1 - width;
   uint8x16_t s[2][2], ma3[2][2], ma5[2];
   uint16x8_t sq[2][4], b3[2][3], b5[3];
-  BoxFilterPreProcessLo(src0, src1, scales, s, sum3, sum5, square_sum3,
-                        square_sum5, sq, ma3, b3, &ma5[0], &b5[0]);
+  // TODO(b/194217060): Future msan load.
+  s[0][0] = vld1q_u8(src0);
+  s[1][0] = vld1q_u8(src1);
+
+  BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq,
+                        ma3, b3, &ma5[0], &b5[0]);
 
   int x = 0;
   do {
@@ -1977,8 +1928,10 @@
     uint8x16_t ma3x[2][3], ma5x[3];
     uint32x4x2_t b[3][3];
     int16x8_t p[2][2];
-    BoxFilterPreProcess(src0, src1, x + 8, scales, s, sum3, sum5, square_sum3,
-                        square_sum5, sq, ma3, b3, ma5, b5 + 1);
+    s[0][1] = Load1QMsanU8(src0 + x + 16, x + 16 + overread_in_bytes);
+    s[1][1] = Load1QMsanU8(src1 + x + 16, x + 16 + overread_in_bytes);
+    BoxFilterPreProcess(s, x + 8, scales, sum3, sum5, square_sum3, square_sum5,
+                        sq, ma3, b3, ma5, b5 + 1);
     Prepare3_8<0>(ma3[0], ma3x[0]);
     Prepare3_8<0>(ma3[1], ma3x[1]);
     Store343_444<0>(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1],
@@ -2070,17 +2023,21 @@
   uint8x16_t s[2], ma3[2], ma5[2];
   uint16x8_t sq[4], ma[3], b3[3], b5[3];
   uint32x4x2_t b[3];
-  BoxFilterPreProcessLastRowLo(src0, scales, sum3, sum5, square_sum3,
-                               square_sum5, &s[0], sq, &ma3[0], &ma5[0], &b3[0],
-                               &b5[0]);
+  // TODO(b/194217060): Future msan load.
+  s[0] = vld1q_u8(src0);
+
+  BoxFilterPreProcessLastRowLo(s, scales, sum3, sum5, square_sum3, square_sum5,
+                               sq, &ma3[0], &ma5[0], &b3[0], &b5[0]);
 
   int x = 0;
   do {
     uint8x16_t ma3x[3], ma5x[3];
     int16x8_t p[2];
-    BoxFilterPreProcessLastRow(src0, x + 8, scales, sum3, sum5, square_sum3,
-                               square_sum5, s, sq + 1, ma3, ma5, &b3[1],
-                               &b5[1]);
+    // TODO(b/194217060): Future msan load.
+    s[1] = vld1q_u8(src0 + x + 16);
+
+    BoxFilterPreProcessLastRow(s, x + 8, scales, sum3, sum5, square_sum3,
+                               square_sum5, sq + 1, ma3, ma5, &b3[1], &b5[1]);
     Prepare3_8<0>(ma5, ma5x);
     ma[1] = Sum565<0>(ma5x);
     b[1] = Sum565W(b5);
@@ -2137,6 +2094,7 @@
     const ptrdiff_t bottom_border_stride, const int width, const int height,
     SgrBuffer* const sgr_buffer, uint8_t* dst) {
   const auto temp_stride = Align<ptrdiff_t>(width, 16);
+  const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
   const ptrdiff_t sum_stride = temp_stride + 8;
   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
   const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
@@ -2173,8 +2131,8 @@
   b565[1] = b565[0] + temp_stride;
   assert(scales[0] != 0);
   assert(scales[1] != 0);
-  BoxSum(top_border, top_border_stride, sum_stride, sum3[0], sum5[1],
-         square_sum3[0], square_sum5[1]);
+  BoxSum(top_border, top_border_stride, width, sum_stride, sum_width, sum3[0],
+         sum5[1], square_sum3[0], square_sum5[1]);
   sum5[0] = sum5[1];
   square_sum5[0] = square_sum5[1];
   const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
@@ -2250,6 +2208,7 @@
                                   const int width, const int height,
                                   SgrBuffer* const sgr_buffer, uint8_t* dst) {
   const auto temp_stride = Align<ptrdiff_t>(width, 16);
+  const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
   const ptrdiff_t sum_stride = temp_stride + 8;
   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
   const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
@@ -2267,7 +2226,8 @@
   b565[0] = sgr_buffer->b565;
   b565[1] = b565[0] + temp_stride;
   assert(scale != 0);
-  BoxSum<5>(top_border, top_border_stride, sum_stride, sum5[1], square_sum5[1]);
+  BoxSum<5>(top_border, top_border_stride, width, sum_stride, sum_width,
+            sum5[1], square_sum5[1]);
   sum5[0] = sum5[1];
   square_sum5[0] = square_sum5[1];
   const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
@@ -2325,6 +2285,7 @@
                                   SgrBuffer* const sgr_buffer, uint8_t* dst) {
   assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
   const auto temp_stride = Align<ptrdiff_t>(width, 16);
+  const auto sum_width = Align<ptrdiff_t>(width + 8, 16);
   const ptrdiff_t sum_stride = temp_stride + 8;
   const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
   const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
@@ -2347,7 +2308,8 @@
   b444[0] = sgr_buffer->b444;
   b444[1] = b444[0] + temp_stride;
   assert(scale != 0);
-  BoxSum<3>(top_border, top_border_stride, sum_stride, sum3[0], square_sum3[0]);
+  BoxSum<3>(top_border, top_border_stride, width, sum_stride, sum_width,
+            sum3[0], square_sum3[0]);
   BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, ma343[0],
                                  nullptr, b343[0], nullptr);
   Circulate3PointersBy1<uint16_t>(sum3);
@@ -2396,11 +2358,14 @@
 // the end of each row. It is safe to overwrite the output as it will not be
 // part of the visible frame.
 void SelfGuidedFilter_NEON(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int index = restoration_info.sgr_proj_info.index;
   const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
   const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
@@ -2409,6 +2374,12 @@
   const auto* bottom = static_cast<const uint8_t*>(bottom_border);
   auto* const dst = static_cast<uint8_t*>(dest);
   SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
+
+#if LIBGAV1_MSAN
+  // Initialize to prevent msan warnings when intermediate overreads occur.
+  memset(sgr_buffer, 0, sizeof(SgrBuffer));
+#endif
+
   if (radius_pass_1 == 0) {
     // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
     // following assertion.
diff --git a/libgav1/src/dsp/arm/loop_restoration_neon.h b/libgav1/src/dsp/arm/loop_restoration_neon.h
index b551610..b9a4803 100644
--- a/libgav1/src/dsp/arm/loop_restoration_neon.h
+++ b/libgav1/src/dsp/arm/loop_restoration_neon.h
@@ -26,6 +26,7 @@
 // Initializes Dsp::loop_restorations, see the defines below for specifics.
 // This function is not thread-safe.
 void LoopRestorationInit_NEON();
+void LoopRestorationInit10bpp_NEON();
 
 }  // namespace dsp
 }  // namespace libgav1
@@ -35,6 +36,9 @@
 #define LIBGAV1_Dsp8bpp_WienerFilter LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_SelfGuidedFilter LIBGAV1_CPU_NEON
 
+#define LIBGAV1_Dsp10bpp_WienerFilter LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_SelfGuidedFilter LIBGAV1_CPU_NEON
+
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_LOOP_RESTORATION_NEON_H_
diff --git a/libgav1/src/dsp/arm/mask_blend_neon.cc b/libgav1/src/dsp/arm/mask_blend_neon.cc
index ee50923..853f949 100644
--- a/libgav1/src/dsp/arm/mask_blend_neon.cc
+++ b/libgav1/src/dsp/arm/mask_blend_neon.cc
@@ -79,10 +79,11 @@
   return vreinterpretq_s16_u16(vmovl_u8(mask_val));
 }
 
-inline void WriteMaskBlendLine4x2(const int16_t* const pred_0,
-                                  const int16_t* const pred_1,
+inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
+                                  const int16_t* LIBGAV1_RESTRICT const pred_1,
                                   const int16x8_t pred_mask_0,
-                                  const int16x8_t pred_mask_1, uint8_t* dst,
+                                  const int16x8_t pred_mask_1,
+                                  uint8_t* LIBGAV1_RESTRICT dst,
                                   const ptrdiff_t dst_stride) {
   const int16x8_t pred_val_0 = vld1q_s16(pred_0);
   const int16x8_t pred_val_1 = vld1q_s16(pred_1);
@@ -109,9 +110,11 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void MaskBlending4x4_NEON(const int16_t* pred_0, const int16_t* pred_1,
-                                 const uint8_t* mask,
-                                 const ptrdiff_t mask_stride, uint8_t* dst,
+inline void MaskBlending4x4_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
+                                 const int16_t* LIBGAV1_RESTRICT pred_1,
+                                 const uint8_t* LIBGAV1_RESTRICT mask,
+                                 const ptrdiff_t mask_stride,
+                                 uint8_t* LIBGAV1_RESTRICT dst,
                                  const ptrdiff_t dst_stride) {
   const int16x8_t mask_inverter = vdupq_n_s16(64);
   int16x8_t pred_mask_0 =
@@ -133,10 +136,12 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void MaskBlending4xH_NEON(const int16_t* pred_0, const int16_t* pred_1,
-                                 const uint8_t* const mask_ptr,
+inline void MaskBlending4xH_NEON(const int16_t* LIBGAV1_RESTRICT pred_0,
+                                 const int16_t* LIBGAV1_RESTRICT pred_1,
+                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                                  const ptrdiff_t mask_stride, const int height,
-                                 uint8_t* dst, const ptrdiff_t dst_stride) {
+                                 uint8_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t dst_stride) {
   const uint8_t* mask = mask_ptr;
   if (height == 4) {
     MaskBlending4x4_NEON<subsampling_x, subsampling_y>(
@@ -188,11 +193,12 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void MaskBlend_NEON(const void* prediction_0, const void* prediction_1,
+inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                           const void* LIBGAV1_RESTRICT prediction_1,
                            const ptrdiff_t /*prediction_stride_1*/,
-                           const uint8_t* const mask_ptr,
+                           const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                            const ptrdiff_t mask_stride, const int width,
-                           const int height, void* dest,
+                           const int height, void* LIBGAV1_RESTRICT dest,
                            const ptrdiff_t dst_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
@@ -302,11 +308,10 @@
   return vld1_u8(mask);
 }
 
-inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0,
-                                                uint8_t* const pred_1,
-                                                const ptrdiff_t pred_stride_1,
-                                                const uint8x8_t pred_mask_0,
-                                                const uint8x8_t pred_mask_1) {
+inline void InterIntraWriteMaskBlendLine8bpp4x2(
+    const uint8_t* LIBGAV1_RESTRICT const pred_0,
+    uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1,
+    const uint8x8_t pred_mask_0, const uint8x8_t pred_mask_1) {
   const uint8x8_t pred_val_0 = vld1_u8(pred_0);
   uint8x8_t pred_val_1 = Load4(pred_1);
   pred_val_1 = Load4<1>(pred_1 + pred_stride_1, pred_val_1);
@@ -320,11 +325,10 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void InterIntraMaskBlending8bpp4x4_NEON(const uint8_t* pred_0,
-                                               uint8_t* pred_1,
-                                               const ptrdiff_t pred_stride_1,
-                                               const uint8_t* mask,
-                                               const ptrdiff_t mask_stride) {
+inline void InterIntraMaskBlending8bpp4x4_NEON(
+    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
+    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
+    const ptrdiff_t mask_stride) {
   const uint8x8_t mask_inverter = vdup_n_u8(64);
   uint8x8_t pred_mask_1 =
       GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
@@ -344,8 +348,9 @@
 
 template <int subsampling_x, int subsampling_y>
 inline void InterIntraMaskBlending8bpp4xH_NEON(
-    const uint8_t* pred_0, uint8_t* pred_1, const ptrdiff_t pred_stride_1,
-    const uint8_t* mask, const ptrdiff_t mask_stride, const int height) {
+    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
+    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
+    const ptrdiff_t mask_stride, const int height) {
   if (height == 4) {
     InterIntraMaskBlending8bpp4x4_NEON<subsampling_x, subsampling_y>(
         pred_0, pred_1, pred_stride_1, mask, mask_stride);
@@ -369,12 +374,11 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void InterIntraMaskBlend8bpp_NEON(const uint8_t* prediction_0,
-                                         uint8_t* prediction_1,
-                                         const ptrdiff_t prediction_stride_1,
-                                         const uint8_t* const mask_ptr,
-                                         const ptrdiff_t mask_stride,
-                                         const int width, const int height) {
+inline void InterIntraMaskBlend8bpp_NEON(
+    const uint8_t* LIBGAV1_RESTRICT prediction_0,
+    uint8_t* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t prediction_stride_1,
+    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
+    const int width, const int height) {
   if (width == 4) {
     InterIntraMaskBlending8bpp4xH_NEON<subsampling_x, subsampling_y>(
         prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
@@ -427,7 +431,293 @@
 }  // namespace
 }  // namespace low_bitdepth
 
-void MaskBlendInit_NEON() { low_bitdepth::Init8bpp(); }
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+template <int subsampling_x, int subsampling_y>
+inline uint16x8_t GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    const uint8x8_t mask_val0 = vld1_u8(mask);
+    const uint8x8_t mask_val1 = vld1_u8(mask + (mask_stride << subsampling_y));
+    uint16x8_t final_val = vpaddlq_u8(vcombine_u8(mask_val0, mask_val1));
+    if (subsampling_y == 1) {
+      const uint8x8_t next_mask_val0 = vld1_u8(mask + mask_stride);
+      const uint8x8_t next_mask_val1 = vld1_u8(mask + mask_stride * 3);
+      final_val = vaddq_u16(
+          final_val, vpaddlq_u8(vcombine_u8(next_mask_val0, next_mask_val1)));
+    }
+    return vrshrq_n_u16(final_val, subsampling_y + 1);
+  }
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  const uint8x8_t mask_val0 = Load4(mask);
+  const uint8x8_t mask_val = Load4<1>(mask + mask_stride, mask_val0);
+  return vmovl_u8(mask_val);
+}
+
+template <int subsampling_x, int subsampling_y>
+inline uint16x8_t GetMask8(const uint8_t* mask, ptrdiff_t mask_stride) {
+  if (subsampling_x == 1) {
+    uint16x8_t mask_val = vpaddlq_u8(vld1q_u8(mask));
+    if (subsampling_y == 1) {
+      const uint16x8_t next_mask_val = vpaddlq_u8(vld1q_u8(mask + mask_stride));
+      mask_val = vaddq_u16(mask_val, next_mask_val);
+    }
+    return vrshrq_n_u16(mask_val, 1 + subsampling_y);
+  }
+  assert(subsampling_y == 0 && subsampling_x == 0);
+  const uint8x8_t mask_val = vld1_u8(mask);
+  return vmovl_u8(mask_val);
+}
+
+template <bool is_inter_intra>
+uint16x8_t SumWeightedPred(const uint16x8_t pred_mask_0,
+                           const uint16x8_t pred_mask_1,
+                           const uint16x8_t pred_val_0,
+                           const uint16x8_t pred_val_1) {
+  if (is_inter_intra) {
+    // dst[x] = static_cast<Pixel>(RightShiftWithRounding(
+    //     mask_value * pred_1[x] + (64 - mask_value) * pred_0[x], 6));
+    uint16x8_t sum = vmulq_u16(pred_mask_1, pred_val_0);
+    sum = vmlaq_u16(sum, pred_mask_0, pred_val_1);
+    return vrshrq_n_u16(sum, 6);
+  } else {
+    // int res = (mask_value * prediction_0[x] +
+    //      (64 - mask_value) * prediction_1[x]) >> 6;
+    const uint32x4_t weighted_pred_0_lo =
+        vmull_u16(vget_low_u16(pred_mask_0), vget_low_u16(pred_val_0));
+    const uint32x4_t weighted_pred_0_hi = VMullHighU16(pred_mask_0, pred_val_0);
+    uint32x4x2_t sum;
+    sum.val[0] = vmlal_u16(weighted_pred_0_lo, vget_low_u16(pred_mask_1),
+                           vget_low_u16(pred_val_1));
+    sum.val[1] = VMlalHighU16(weighted_pred_0_hi, pred_mask_1, pred_val_1);
+    return vcombine_u16(vshrn_n_u32(sum.val[0], 6), vshrn_n_u32(sum.val[1], 6));
+  }
+}
+
+template <bool is_inter_intra, int width, int bitdepth = 10>
+inline void StoreShiftedResult(uint8_t* dst, const uint16x8_t result,
+                               const ptrdiff_t dst_stride = 0) {
+  if (is_inter_intra) {
+    if (width == 4) {
+      // Store 2 lines of width 4.
+      assert(dst_stride != 0);
+      vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(result));
+      vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride),
+               vget_high_u16(result));
+    } else {
+      // Store 1 line of width 8.
+      vst1q_u16(reinterpret_cast<uint16_t*>(dst), result);
+    }
+  } else {
+    // res -= (bitdepth == 8) ? 0 : kCompoundOffset;
+    // dst[x] = static_cast<Pixel>(
+    //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
+    //           (1 << kBitdepth8) - 1));
+    constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
+    const uint16x8_t compound_result =
+        vminq_u16(vrshrq_n_u16(vqsubq_u16(result, vdupq_n_u16(kCompoundOffset)),
+                               inter_post_round_bits),
+                  vdupq_n_u16((1 << bitdepth) - 1));
+    if (width == 4) {
+      // Store 2 lines of width 4.
+      assert(dst_stride != 0);
+      vst1_u16(reinterpret_cast<uint16_t*>(dst), vget_low_u16(compound_result));
+      vst1_u16(reinterpret_cast<uint16_t*>(dst + dst_stride),
+               vget_high_u16(compound_result));
+    } else {
+      // Store 1 line of width 8.
+      vst1q_u16(reinterpret_cast<uint16_t*>(dst), compound_result);
+    }
+  }
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+inline void MaskBlend4x2_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
+                              const uint16_t* LIBGAV1_RESTRICT pred_1,
+                              const ptrdiff_t pred_stride_1,
+                              const uint8_t* LIBGAV1_RESTRICT mask,
+                              const uint16x8_t mask_inverter,
+                              const ptrdiff_t mask_stride,
+                              uint8_t* LIBGAV1_RESTRICT dst,
+                              const ptrdiff_t dst_stride) {
+  // This works because stride == width == 4.
+  const uint16x8_t pred_val_0 = vld1q_u16(pred_0);
+  const uint16x8_t pred_val_1 =
+      is_inter_intra
+          ? vcombine_u16(vld1_u16(pred_1), vld1_u16(pred_1 + pred_stride_1))
+          : vld1q_u16(pred_1);
+  const uint16x8_t pred_mask_0 =
+      GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
+  const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+  const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>(
+      pred_mask_0, pred_mask_1, pred_val_0, pred_val_1);
+
+  StoreShiftedResult<is_inter_intra, 4>(dst, weighted_pred_sum, dst_stride);
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+inline void MaskBlending4x4_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
+                                 const uint16_t* LIBGAV1_RESTRICT pred_1,
+                                 const ptrdiff_t pred_stride_1,
+                                 const uint8_t* LIBGAV1_RESTRICT mask,
+                                 const ptrdiff_t mask_stride,
+                                 uint8_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t dst_stride) {
+  // Double stride because the function works on 2 lines at a time.
+  const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1);
+  const ptrdiff_t dst_stride_y = dst_stride << 1;
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+
+  MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+      pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+      dst_stride);
+
+  pred_0 += 4 << 1;
+  pred_1 += pred_stride_1 << 1;
+  mask += mask_stride_y;
+  dst += dst_stride_y;
+
+  MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+      pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+      dst_stride);
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+inline void MaskBlending4xH_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
+                                 const uint16_t* LIBGAV1_RESTRICT pred_1,
+                                 const ptrdiff_t pred_stride_1,
+                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
+                                 const ptrdiff_t mask_stride, const int height,
+                                 uint8_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t dst_stride) {
+  const uint8_t* mask = mask_ptr;
+  if (height == 4) {
+    MaskBlending4x4_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+        pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride);
+    return;
+  }
+  // Double stride because the function works on 2 lines at a time.
+  const ptrdiff_t mask_stride_y = mask_stride << (subsampling_y + 1);
+  const ptrdiff_t dst_stride_y = dst_stride << 1;
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  int y = 0;
+  do {
+    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+        dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += pred_stride_1 << 1;
+    mask += mask_stride_y;
+    dst += dst_stride_y;
+
+    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+        dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += pred_stride_1 << 1;
+    mask += mask_stride_y;
+    dst += dst_stride_y;
+
+    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+        dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += pred_stride_1 << 1;
+    mask += mask_stride_y;
+    dst += dst_stride_y;
+
+    MaskBlend4x2_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+        pred_0, pred_1, pred_stride_1, mask, mask_inverter, mask_stride, dst,
+        dst_stride);
+    pred_0 += 4 << 1;
+    pred_1 += pred_stride_1 << 1;
+    mask += mask_stride_y;
+    dst += dst_stride_y;
+    y += 8;
+  } while (y < height);
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+void MaskBlend8_NEON(const uint16_t* LIBGAV1_RESTRICT pred_0,
+                     const uint16_t* LIBGAV1_RESTRICT pred_1,
+                     const uint8_t* LIBGAV1_RESTRICT mask,
+                     const uint16x8_t mask_inverter,
+                     const ptrdiff_t mask_stride,
+                     uint8_t* LIBGAV1_RESTRICT dst) {
+  const uint16x8_t pred_val_0 = vld1q_u16(pred_0);
+  const uint16x8_t pred_val_1 = vld1q_u16(pred_1);
+  const uint16x8_t pred_mask_0 =
+      GetMask8<subsampling_x, subsampling_y>(mask, mask_stride);
+  const uint16x8_t pred_mask_1 = vsubq_u16(mask_inverter, pred_mask_0);
+  const uint16x8_t weighted_pred_sum = SumWeightedPred<is_inter_intra>(
+      pred_mask_0, pred_mask_1, pred_val_0, pred_val_1);
+
+  StoreShiftedResult<is_inter_intra, 8>(dst, weighted_pred_sum);
+}
+
+template <int subsampling_x, int subsampling_y, bool is_inter_intra>
+inline void MaskBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                           const void* LIBGAV1_RESTRICT prediction_1,
+                           const ptrdiff_t prediction_stride_1,
+                           const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
+                           const ptrdiff_t mask_stride, const int width,
+                           const int height, void* LIBGAV1_RESTRICT dest,
+                           const ptrdiff_t dst_stride) {
+  if (!is_inter_intra) {
+    assert(prediction_stride_1 == width);
+  }
+  auto* dst = static_cast<uint8_t*>(dest);
+  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
+  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
+  if (width == 4) {
+    MaskBlending4xH_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+        pred_0, pred_1, prediction_stride_1, mask_ptr, mask_stride, height, dst,
+        dst_stride);
+    return;
+  }
+  const ptrdiff_t mask_stride_y = mask_stride << subsampling_y;
+  const uint8_t* mask = mask_ptr;
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  int y = 0;
+  do {
+    int x = 0;
+    do {
+      MaskBlend8_NEON<subsampling_x, subsampling_y, is_inter_intra>(
+          pred_0 + x, pred_1 + x, mask + (x << subsampling_x), mask_inverter,
+          mask_stride,
+          reinterpret_cast<uint8_t*>(reinterpret_cast<uint16_t*>(dst) + x));
+      x += 8;
+    } while (x < width);
+    dst += dst_stride;
+    pred_0 += width;
+    pred_1 += prediction_stride_1;
+    mask += mask_stride_y;
+  } while (++y < height);
+}
+
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->mask_blend[0][0] = MaskBlend_NEON<0, 0, false>;
+  dsp->mask_blend[1][0] = MaskBlend_NEON<1, 0, false>;
+  dsp->mask_blend[2][0] = MaskBlend_NEON<1, 1, false>;
+
+  dsp->mask_blend[0][1] = MaskBlend_NEON<0, 0, true>;
+  dsp->mask_blend[1][1] = MaskBlend_NEON<1, 0, true>;
+  dsp->mask_blend[2][1] = MaskBlend_NEON<1, 1, true>;
+}
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+void MaskBlendInit_NEON() {
+  low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif
+}
 
 }  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/arm/mask_blend_neon.h b/libgav1/src/dsp/arm/mask_blend_neon.h
index 3829274..c24f2f8 100644
--- a/libgav1/src/dsp/arm/mask_blend_neon.h
+++ b/libgav1/src/dsp/arm/mask_blend_neon.h
@@ -36,6 +36,13 @@
 #define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp444 LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp422 LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_InterIntraMaskBlend8bpp420 LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_MaskBlend444 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_MaskBlend422 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_MaskBlend420 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_MaskBlendInterIntra444 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_MaskBlendInterIntra422 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_MaskBlendInterIntra420 LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_MASK_BLEND_NEON_H_
diff --git a/libgav1/src/dsp/arm/motion_field_projection_neon.cc b/libgav1/src/dsp/arm/motion_field_projection_neon.cc
index 3e731b2..144adf7 100644
--- a/libgav1/src/dsp/arm/motion_field_projection_neon.cc
+++ b/libgav1/src/dsp/arm/motion_field_projection_neon.cc
@@ -356,27 +356,12 @@
   } while (++y8 < y8_end);
 }
 
-void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
-  assert(dsp != nullptr);
-  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON;
-}
-
-#if LIBGAV1_MAX_BITDEPTH >= 10
-void Init10bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
-  assert(dsp != nullptr);
-  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON;
-}
-#endif
-
 }  // namespace
 
 void MotionFieldProjectionInit_NEON() {
-  Init8bpp();
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  Init10bpp();
-#endif
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON;
 }
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/arm/motion_vector_search_neon.cc b/libgav1/src/dsp/arm/motion_vector_search_neon.cc
index da3ba17..4720879 100644
--- a/libgav1/src/dsp/arm/motion_vector_search_neon.cc
+++ b/libgav1/src/dsp/arm/motion_vector_search_neon.cc
@@ -61,8 +61,8 @@
 }
 
 inline int16x8_t MvProjectionCompoundClip(
-    const MotionVector* const temporal_mvs,
-    const int8_t* const temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT const temporal_reference_offsets,
     const int reference_offsets[2]) {
   const auto* const tmvs = reinterpret_cast<const int32_t*>(temporal_mvs);
   const int32x2_t temporal_mv = vld1_s32(tmvs);
@@ -76,9 +76,9 @@
 }
 
 inline int16x8_t MvProjectionSingleClip(
-    const MotionVector* const temporal_mvs,
-    const int8_t* const temporal_reference_offsets, const int reference_offset,
-    int16x4_t* const lookup) {
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT const temporal_reference_offsets,
+    const int reference_offset, int16x4_t* const lookup) {
   const auto* const tmvs = reinterpret_cast<const int16_t*>(temporal_mvs);
   const int16x8_t temporal_mv = vld1q_s16(tmvs);
   *lookup = vld1_lane_s16(
@@ -116,9 +116,10 @@
 }
 
 void MvProjectionCompoundLowPrecision_NEON(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
     const int reference_offsets[2], const int count,
-    CompoundMotionVector* candidate_mvs) {
+    CompoundMotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // |reference_offsets| non-zero check usually equals true and is ignored.
   // To facilitate the compilers, make a local copy of |reference_offsets|.
   const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
@@ -131,13 +132,14 @@
     temporal_mvs += 2;
     temporal_reference_offsets += 2;
     candidate_mvs += 2;
-  } while (--loop_count);
+  } while (--loop_count != 0);
 }
 
 void MvProjectionCompoundForceInteger_NEON(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
     const int reference_offsets[2], const int count,
-    CompoundMotionVector* candidate_mvs) {
+    CompoundMotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // |reference_offsets| non-zero check usually equals true and is ignored.
   // To facilitate the compilers, make a local copy of |reference_offsets|.
   const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
@@ -150,13 +152,14 @@
     temporal_mvs += 2;
     temporal_reference_offsets += 2;
     candidate_mvs += 2;
-  } while (--loop_count);
+  } while (--loop_count != 0);
 }
 
 void MvProjectionCompoundHighPrecision_NEON(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
     const int reference_offsets[2], const int count,
-    CompoundMotionVector* candidate_mvs) {
+    CompoundMotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // |reference_offsets| non-zero check usually equals true and is ignored.
   // To facilitate the compilers, make a local copy of |reference_offsets|.
   const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
@@ -169,12 +172,14 @@
     temporal_mvs += 2;
     temporal_reference_offsets += 2;
     candidate_mvs += 2;
-  } while (--loop_count);
+  } while (--loop_count != 0);
 }
 
 void MvProjectionSingleLowPrecision_NEON(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
-    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
+    const int reference_offset, const int count,
+    MotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // Up to three more elements could be calculated.
   int loop_count = (count + 3) >> 2;
   int16x4_t lookup = vdup_n_s16(0);
@@ -185,12 +190,14 @@
     temporal_mvs += 4;
     temporal_reference_offsets += 4;
     candidate_mvs += 4;
-  } while (--loop_count);
+  } while (--loop_count != 0);
 }
 
 void MvProjectionSingleForceInteger_NEON(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
-    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
+    const int reference_offset, const int count,
+    MotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // Up to three more elements could be calculated.
   int loop_count = (count + 3) >> 2;
   int16x4_t lookup = vdup_n_s16(0);
@@ -201,12 +208,14 @@
     temporal_mvs += 4;
     temporal_reference_offsets += 4;
     candidate_mvs += 4;
-  } while (--loop_count);
+  } while (--loop_count != 0);
 }
 
 void MvProjectionSingleHighPrecision_NEON(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
-    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
+    const int reference_offset, const int count,
+    MotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // Up to three more elements could be calculated.
   int loop_count = (count + 3) >> 2;
   int16x4_t lookup = vdup_n_s16(0);
@@ -217,10 +226,12 @@
     temporal_mvs += 4;
     temporal_reference_offsets += 4;
     candidate_mvs += 4;
-  } while (--loop_count);
+  } while (--loop_count != 0);
 }
 
-void Init8bpp() {
+}  // namespace
+
+void MotionVectorSearchInit_NEON() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON;
@@ -231,28 +242,6 @@
   dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON;
 }
 
-#if LIBGAV1_MAX_BITDEPTH >= 10
-void Init10bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
-  assert(dsp != nullptr);
-  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_NEON;
-  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_NEON;
-  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_NEON;
-  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_NEON;
-  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_NEON;
-  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_NEON;
-}
-#endif
-
-}  // namespace
-
-void MotionVectorSearchInit_NEON() {
-  Init8bpp();
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  Init10bpp();
-#endif
-}
-
 }  // namespace dsp
 }  // namespace libgav1
 
diff --git a/libgav1/src/dsp/arm/obmc_neon.cc b/libgav1/src/dsp/arm/obmc_neon.cc
index 1111a90..659ed8e 100644
--- a/libgav1/src/dsp/arm/obmc_neon.cc
+++ b/libgav1/src/dsp/arm/obmc_neon.cc
@@ -33,10 +33,15 @@
 namespace libgav1 {
 namespace dsp {
 namespace {
-
 #include "src/dsp/obmc.inc"
 
-inline void WriteObmcLine4(uint8_t* const pred, const uint8_t* const obmc_pred,
+}  // namespace
+
+namespace low_bitdepth {
+namespace {
+
+inline void WriteObmcLine4(uint8_t* LIBGAV1_RESTRICT const pred,
+                           const uint8_t* LIBGAV1_RESTRICT const obmc_pred,
                            const uint8x8_t pred_mask,
                            const uint8x8_t obmc_pred_mask) {
   const uint8x8_t pred_val = Load4(pred);
@@ -47,35 +52,17 @@
   StoreLo4(pred, result);
 }
 
-template <bool from_left>
-inline void OverlapBlend2xH_NEON(uint8_t* const prediction,
-                                 const ptrdiff_t prediction_stride,
-                                 const int height,
-                                 const uint8_t* const obmc_prediction,
-                                 const ptrdiff_t obmc_prediction_stride) {
-  uint8_t* pred = prediction;
+inline void OverlapBlendFromLeft2xH_NEON(
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
+    const ptrdiff_t obmc_prediction_stride) {
   const uint8x8_t mask_inverter = vdup_n_u8(64);
-  const uint8_t* obmc_pred = obmc_prediction;
-  uint8x8_t pred_mask;
-  uint8x8_t obmc_pred_mask;
-  int compute_height;
-  const int mask_offset = height - 2;
-  if (from_left) {
-    pred_mask = Load2(kObmcMask);
-    obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
-    compute_height = height;
-  } else {
-    // Weights for the last line are all 64, which is a no-op.
-    compute_height = height - 1;
-  }
+  const uint8x8_t pred_mask = Load2(kObmcMask);
+  const uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
   uint8x8_t pred_val = vdup_n_u8(0);
   uint8x8_t obmc_pred_val = vdup_n_u8(0);
   int y = 0;
   do {
-    if (!from_left) {
-      pred_mask = vdup_n_u8(kObmcMask[mask_offset + y]);
-      obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
-    }
     pred_val = Load2<0>(pred, pred_val);
     const uint16x8_t weighted_pred = vmull_u8(pred_mask, pred_val);
     obmc_pred_val = Load2<0>(obmc_pred, obmc_pred_val);
@@ -85,16 +72,13 @@
 
     pred += prediction_stride;
     obmc_pred += obmc_prediction_stride;
-  } while (++y != compute_height);
+  } while (++y != height);
 }
 
 inline void OverlapBlendFromLeft4xH_NEON(
-    uint8_t* const prediction, const ptrdiff_t prediction_stride,
-    const int height, const uint8_t* const obmc_prediction,
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
     const ptrdiff_t obmc_prediction_stride) {
-  uint8_t* pred = prediction;
-  const uint8_t* obmc_pred = obmc_prediction;
-
   const uint8x8_t mask_inverter = vdup_n_u8(64);
   const uint8x8_t pred_mask = Load4(kObmcMask + 2);
   // 64 - mask
@@ -114,11 +98,9 @@
 }
 
 inline void OverlapBlendFromLeft8xH_NEON(
-    uint8_t* const prediction, const ptrdiff_t prediction_stride,
-    const int height, const uint8_t* const obmc_prediction,
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
     const ptrdiff_t obmc_prediction_stride) {
-  uint8_t* pred = prediction;
-  const uint8_t* obmc_pred = obmc_prediction;
   const uint8x8_t mask_inverter = vdup_n_u8(64);
   const uint8x8_t pred_mask = vld1_u8(kObmcMask + 6);
   // 64 - mask
@@ -137,17 +119,19 @@
   } while (++y != height);
 }
 
-void OverlapBlendFromLeft_NEON(void* const prediction,
-                               const ptrdiff_t prediction_stride,
-                               const int width, const int height,
-                               const void* const obmc_prediction,
-                               const ptrdiff_t obmc_prediction_stride) {
+void OverlapBlendFromLeft_NEON(
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
+    const int width, const int height,
+    const void* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
   auto* pred = static_cast<uint8_t*>(prediction);
   const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+  assert(width >= 2);
+  assert(height >= 4);
 
   if (width == 2) {
-    OverlapBlend2xH_NEON<true>(pred, prediction_stride, height, obmc_pred,
-                               obmc_prediction_stride);
+    OverlapBlendFromLeft2xH_NEON(pred, prediction_stride, height, obmc_pred,
+                                 obmc_prediction_stride);
     return;
   }
   if (width == 4) {
@@ -194,13 +178,10 @@
   } while (x < width);
 }
 
-inline void OverlapBlendFromTop4x4_NEON(uint8_t* const prediction,
-                                        const ptrdiff_t prediction_stride,
-                                        const uint8_t* const obmc_prediction,
-                                        const ptrdiff_t obmc_prediction_stride,
-                                        const int height) {
-  uint8_t* pred = prediction;
-  const uint8_t* obmc_pred = obmc_prediction;
+inline void OverlapBlendFromTop4x4_NEON(
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const uint8_t* LIBGAV1_RESTRICT obmc_pred,
+    const ptrdiff_t obmc_prediction_stride, const int height) {
   uint8x8_t pred_mask = vdup_n_u8(kObmcMask[height - 2]);
   const uint8x8_t mask_inverter = vdup_n_u8(64);
   uint8x8_t obmc_pred_mask = vsub_u8(mask_inverter, pred_mask);
@@ -224,16 +205,14 @@
 }
 
 inline void OverlapBlendFromTop4xH_NEON(
-    uint8_t* const prediction, const ptrdiff_t prediction_stride,
-    const int height, const uint8_t* const obmc_prediction,
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
     const ptrdiff_t obmc_prediction_stride) {
   if (height < 8) {
-    OverlapBlendFromTop4x4_NEON(prediction, prediction_stride, obmc_prediction,
+    OverlapBlendFromTop4x4_NEON(pred, prediction_stride, obmc_pred,
                                 obmc_prediction_stride, height);
     return;
   }
-  uint8_t* pred = prediction;
-  const uint8_t* obmc_pred = obmc_prediction;
   const uint8_t* mask = kObmcMask + height - 2;
   const uint8x8_t mask_inverter = vdup_n_u8(64);
   int y = 0;
@@ -282,11 +261,9 @@
 }
 
 inline void OverlapBlendFromTop8xH_NEON(
-    uint8_t* const prediction, const ptrdiff_t prediction_stride,
-    const int height, const uint8_t* const obmc_prediction,
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
     const ptrdiff_t obmc_prediction_stride) {
-  uint8_t* pred = prediction;
-  const uint8_t* obmc_pred = obmc_prediction;
   const uint8x8_t mask_inverter = vdup_n_u8(64);
   const uint8_t* mask = kObmcMask + height - 2;
   const int compute_height = height - (height >> 2);
@@ -307,19 +284,16 @@
   } while (++y != compute_height);
 }
 
-void OverlapBlendFromTop_NEON(void* const prediction,
-                              const ptrdiff_t prediction_stride,
-                              const int width, const int height,
-                              const void* const obmc_prediction,
-                              const ptrdiff_t obmc_prediction_stride) {
+void OverlapBlendFromTop_NEON(
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
+    const int width, const int height,
+    const void* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
   auto* pred = static_cast<uint8_t*>(prediction);
   const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+  assert(width >= 4);
+  assert(height >= 2);
 
-  if (width == 2) {
-    OverlapBlend2xH_NEON<false>(pred, prediction_stride, height, obmc_pred,
-                                obmc_prediction_stride);
-    return;
-  }
   if (width == 4) {
     OverlapBlendFromTop4xH_NEON(pred, prediction_stride, height, obmc_pred,
                                 obmc_prediction_stride);
@@ -374,8 +348,582 @@
 }
 
 }  // namespace
+}  // namespace low_bitdepth
 
-void ObmcInit_NEON() { Init8bpp(); }
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+// This is a flat array of masks for each block dimension from 2 to 32. The
+// starting index for each length is length-2. The value 64 leaves the result
+// equal to |pred| and may be ignored if convenient. Vector loads may overread
+// values meant for larger sizes, but these values will be unused.
+constexpr uint16_t kObmcMask[62] = {
+    // Obmc Mask 2
+    45, 64,
+    // Obmc Mask 4
+    39, 50, 59, 64,
+    // Obmc Mask 8
+    36, 42, 48, 53, 57, 61, 64, 64,
+    // Obmc Mask 16
+    34, 37, 40, 43, 46, 49, 52, 54, 56, 58, 60, 61, 64, 64, 64, 64,
+    // Obmc Mask 32
+    33, 35, 36, 38, 40, 41, 43, 44, 45, 47, 48, 50, 51, 52, 53, 55, 56, 57, 58,
+    59, 60, 60, 61, 62, 64, 64, 64, 64, 64, 64, 64, 64};
+
+inline uint16x4_t BlendObmc2Or4(uint8_t* LIBGAV1_RESTRICT const pred,
+                                const uint8_t* LIBGAV1_RESTRICT const obmc_pred,
+                                const uint16x4_t pred_mask,
+                                const uint16x4_t obmc_pred_mask) {
+  const uint16x4_t pred_val = vld1_u16(reinterpret_cast<uint16_t*>(pred));
+  const uint16x4_t obmc_pred_val =
+      vld1_u16(reinterpret_cast<const uint16_t*>(obmc_pred));
+  const uint16x4_t weighted_pred = vmul_u16(pred_mask, pred_val);
+  const uint16x4_t result =
+      vrshr_n_u16(vmla_u16(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+  return result;
+}
+
+inline uint16x8_t BlendObmc8(uint8_t* LIBGAV1_RESTRICT const pred,
+                             const uint8_t* LIBGAV1_RESTRICT const obmc_pred,
+                             const uint16x8_t pred_mask,
+                             const uint16x8_t obmc_pred_mask) {
+  const uint16x8_t pred_val = vld1q_u16(reinterpret_cast<uint16_t*>(pred));
+  const uint16x8_t obmc_pred_val =
+      vld1q_u16(reinterpret_cast<const uint16_t*>(obmc_pred));
+  const uint16x8_t weighted_pred = vmulq_u16(pred_mask, pred_val);
+  const uint16x8_t result =
+      vrshrq_n_u16(vmlaq_u16(weighted_pred, obmc_pred_mask, obmc_pred_val), 6);
+  return result;
+}
+
+inline void OverlapBlendFromLeft2xH_NEON(
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
+    const ptrdiff_t obmc_prediction_stride) {
+  const uint16x4_t mask_inverter = vdup_n_u16(64);
+  // Only the low two lanes are used; the upper two lanes are ignored.
+  const uint16x4_t pred_mask = vld1_u16(kObmcMask);
+  const uint16x4_t obmc_pred_mask = vsub_u16(mask_inverter, pred_mask);
+  int y = 0;
+  do {
+    const uint16x4_t result_0 =
+        BlendObmc2Or4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    Store2<0>(reinterpret_cast<uint16_t*>(pred), result_0);
+
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    const uint16x4_t result_1 =
+        BlendObmc2Or4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    Store2<0>(reinterpret_cast<uint16_t*>(pred), result_1);
+
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    y += 2;
+  } while (y != height);
+}
+
+inline void OverlapBlendFromLeft4xH_NEON(
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
+    const ptrdiff_t obmc_prediction_stride) {
+  const uint16x4_t mask_inverter = vdup_n_u16(64);
+  const uint16x4_t pred_mask = vld1_u16(kObmcMask + 2);
+  // 64 - mask
+  const uint16x4_t obmc_pred_mask = vsub_u16(mask_inverter, pred_mask);
+  int y = 0;
+  do {
+    const uint16x4_t result_0 =
+        BlendObmc2Or4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    vst1_u16(reinterpret_cast<uint16_t*>(pred), result_0);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    const uint16x4_t result_1 =
+        BlendObmc2Or4(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    vst1_u16(reinterpret_cast<uint16_t*>(pred), result_1);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    y += 2;
+  } while (y != height);
+}
+
+void OverlapBlendFromLeft_NEON(
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
+    const int width, const int height,
+    const void* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  auto* pred = static_cast<uint8_t*>(prediction);
+  const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+  assert(width >= 2);
+  assert(height >= 4);
+
+  if (width == 2) {
+    OverlapBlendFromLeft2xH_NEON(pred, prediction_stride, height, obmc_pred,
+                                 obmc_prediction_stride);
+    return;
+  }
+  if (width == 4) {
+    OverlapBlendFromLeft4xH_NEON(pred, prediction_stride, height, obmc_pred,
+                                 obmc_prediction_stride);
+    return;
+  }
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  const uint16_t* mask = kObmcMask + width - 2;
+  int x = 0;
+  do {
+    pred = reinterpret_cast<uint8_t*>(static_cast<uint16_t*>(prediction) + x);
+    obmc_pred = reinterpret_cast<const uint8_t*>(
+        static_cast<const uint16_t*>(obmc_prediction) + x);
+    const uint16x8_t pred_mask = vld1q_u16(mask + x);
+    // 64 - mask
+    const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
+    int y = 0;
+    do {
+      const uint16x8_t result =
+          BlendObmc8(pred, obmc_pred, pred_mask, obmc_pred_mask);
+      vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+
+      pred += prediction_stride;
+      obmc_pred += obmc_prediction_stride;
+    } while (++y < height);
+    x += 8;
+  } while (x < width);
+}
+
+template <int lane>
+inline uint16x4_t BlendObmcFromTop4(
+    uint8_t* LIBGAV1_RESTRICT const pred,
+    const uint8_t* LIBGAV1_RESTRICT const obmc_pred, const uint16x8_t pred_mask,
+    const uint16x8_t obmc_pred_mask) {
+  const uint16x4_t pred_val = vld1_u16(reinterpret_cast<uint16_t*>(pred));
+  const uint16x4_t obmc_pred_val =
+      vld1_u16(reinterpret_cast<const uint16_t*>(obmc_pred));
+  const uint16x4_t weighted_pred = VMulLaneQU16<lane>(pred_val, pred_mask);
+  const uint16x4_t result = vrshr_n_u16(
+      VMlaLaneQU16<lane>(weighted_pred, obmc_pred_val, obmc_pred_mask), 6);
+  return result;
+}
+
+template <int lane>
+inline uint16x8_t BlendObmcFromTop8(
+    uint8_t* LIBGAV1_RESTRICT const pred,
+    const uint8_t* LIBGAV1_RESTRICT const obmc_pred, const uint16x8_t pred_mask,
+    const uint16x8_t obmc_pred_mask) {
+  const uint16x8_t pred_val = vld1q_u16(reinterpret_cast<uint16_t*>(pred));
+  const uint16x8_t obmc_pred_val =
+      vld1q_u16(reinterpret_cast<const uint16_t*>(obmc_pred));
+  const uint16x8_t weighted_pred = VMulQLaneQU16<lane>(pred_val, pred_mask);
+  const uint16x8_t result = vrshrq_n_u16(
+      VMlaQLaneQU16<lane>(weighted_pred, obmc_pred_val, obmc_pred_mask), 6);
+  return result;
+}
+
+inline void OverlapBlendFromTop4x2Or4_NEON(
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const uint8_t* LIBGAV1_RESTRICT obmc_pred,
+    const ptrdiff_t obmc_prediction_stride, const int height) {
+  const uint16x8_t pred_mask = vld1q_u16(&kObmcMask[height - 2]);
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
+  uint16x4_t result =
+      BlendObmcFromTop4<0>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  if (height == 2) {
+    // Mask value is 64, meaning |pred| is unchanged.
+    return;
+  }
+
+  result = BlendObmcFromTop4<1>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop4<2>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1_u16(reinterpret_cast<uint16_t*>(pred), result);
+}
+
+inline void OverlapBlendFromTop4xH_NEON(
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const int height, const uint8_t* LIBGAV1_RESTRICT obmc_pred,
+    const ptrdiff_t obmc_prediction_stride) {
+  if (height < 8) {
+    OverlapBlendFromTop4x2Or4_NEON(pred, prediction_stride, obmc_pred,
+                                   obmc_prediction_stride, height);
+    return;
+  }
+  const uint16_t* mask = kObmcMask + height - 2;
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  int y = 0;
+  // Compute 6 lines for height 8, or 12 lines for height 16. The remaining
+  // lines are unchanged as the corresponding mask value is 64.
+  do {
+    const uint16x8_t pred_mask = vld1q_u16(&mask[y]);
+    const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
+    uint16x4_t result =
+        BlendObmcFromTop4<0>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    vst1_u16(reinterpret_cast<uint16_t*>(pred), result);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    result = BlendObmcFromTop4<1>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    vst1_u16(reinterpret_cast<uint16_t*>(pred), result);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    result = BlendObmcFromTop4<2>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    vst1_u16(reinterpret_cast<uint16_t*>(pred), result);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    result = BlendObmcFromTop4<3>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    vst1_u16(reinterpret_cast<uint16_t*>(pred), result);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    result = BlendObmcFromTop4<4>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    vst1_u16(reinterpret_cast<uint16_t*>(pred), result);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    result = BlendObmcFromTop4<5>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+    vst1_u16(reinterpret_cast<uint16_t*>(pred), result);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    // Advance to the mask entries for the next six rows.
+    y += 6;
+  } while (y < height - 4);
+}
+
+inline void OverlapBlendFromTop8xH_NEON(
+    uint8_t* LIBGAV1_RESTRICT pred, const ptrdiff_t prediction_stride,
+    const uint8_t* LIBGAV1_RESTRICT obmc_pred,
+    const ptrdiff_t obmc_prediction_stride, const int height) {
+  const uint16_t* mask = kObmcMask + height - 2;
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  uint16x8_t pred_mask = vld1q_u16(mask);
+  uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
+  uint16x8_t result =
+      BlendObmcFromTop8<0>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  if (height == 2) return;
+
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<1>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<2>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<3>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  if (height == 4) return;
+
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<4>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<5>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+
+  if (height == 8) return;
+
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<6>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<7>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  pred_mask = vld1q_u16(&mask[8]);
+  obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
+
+  result = BlendObmcFromTop8<0>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<1>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<2>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<3>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+
+  if (height == 16) return;
+
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<4>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<5>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<6>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<7>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  pred_mask = vld1q_u16(&mask[16]);
+  obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
+
+  result = BlendObmcFromTop8<0>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<1>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<2>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<3>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<4>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<5>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<6>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+  pred += prediction_stride;
+  obmc_pred += obmc_prediction_stride;
+
+  result = BlendObmcFromTop8<7>(pred, obmc_pred, pred_mask, obmc_pred_mask);
+  vst1q_u16(reinterpret_cast<uint16_t*>(pred), result);
+}
+
+void OverlapBlendFromTop_NEON(
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
+    const int width, const int height,
+    const void* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
+  auto* pred = static_cast<uint8_t*>(prediction);
+  const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+  assert(width >= 4);
+  assert(height >= 2);
+
+  if (width == 4) {
+    OverlapBlendFromTop4xH_NEON(pred, prediction_stride, height, obmc_pred,
+                                obmc_prediction_stride);
+    return;
+  }
+
+  if (width == 8) {
+    OverlapBlendFromTop8xH_NEON(pred, prediction_stride, obmc_pred,
+                                obmc_prediction_stride, height);
+    return;
+  }
+
+  const uint16_t* mask = kObmcMask + height - 2;
+  const uint16x8_t mask_inverter = vdupq_n_u16(64);
+  const uint16x8_t pred_mask = vld1q_u16(mask);
+  // obmc_pred_mask = 64 - pred_mask, the complementary blend weight.
+  const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
+#define OBMC_ROW_FROM_TOP(n)                                                 \
+  do {                                                                       \
+    int x = 0;                                                               \
+    do {                                                                     \
+      const uint16x8_t result = BlendObmcFromTop8<n>(                        \
+          reinterpret_cast<uint8_t*>(reinterpret_cast<uint16_t*>(pred) + x), \
+          reinterpret_cast<const uint8_t*>(                                  \
+              reinterpret_cast<const uint16_t*>(obmc_pred) + x),             \
+          pred_mask, obmc_pred_mask);                                        \
+      vst1q_u16(reinterpret_cast<uint16_t*>(pred) + x, result);              \
+                                                                             \
+      x += 8;                                                                \
+    } while (x < width);                                                     \
+  } while (false)
+
+  // Height 2: blend only the first row.
+  if (height == 2) {
+    OBMC_ROW_FROM_TOP(0);
+    return;
+  }
+
+  // Height 4: blend the first 3 rows.
+  if (height == 4) {
+    OBMC_ROW_FROM_TOP(0);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(1);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(2);
+    return;
+  }
+
+  // Height 8: blend the first 6 rows.
+  if (height == 8) {
+    OBMC_ROW_FROM_TOP(0);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(1);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(2);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(3);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(4);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(5);
+    return;
+  }
+
+  // Height 16: blend the first 12 rows; rows 8-11 use mask[8..11].
+  if (height == 16) {
+    OBMC_ROW_FROM_TOP(0);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(1);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(2);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(3);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(4);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(5);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(6);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(7);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    const uint16x8_t pred_mask = vld1q_u16(&mask[8]);
+    // Complementary weights for rows 8-11: 64 - mask.
+    const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
+    OBMC_ROW_FROM_TOP(0);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(1);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(2);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(3);
+    return;
+  }
+
+  // Heights 32 and 64: stop when the mask value becomes 64, i.e. after
+  // height - height / 4 rows (24 or 48), which is a multiple of 8.
+  const int compute_height = height - (height >> 2);
+  int y = 0;
+  do {
+    const uint16x8_t pred_mask = vld1q_u16(&mask[y]);
+    // Weights for the next 8 rows: mask[y..y + 7] and 64 - mask.
+    const uint16x8_t obmc_pred_mask = vsubq_u16(mask_inverter, pred_mask);
+    OBMC_ROW_FROM_TOP(0);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(1);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(2);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(3);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(4);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(5);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(6);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+    OBMC_ROW_FROM_TOP(7);
+    pred += prediction_stride;
+    obmc_pred += obmc_prediction_stride;
+
+    y += 8;
+  } while (y < compute_height);
+}
+
+void Init10bpp() {  // Installs the 10bpp NEON OBMC blend functions.
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_NEON;
+  dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_NEON;
+}
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+void ObmcInit_NEON() {  // Registers the NEON OBMC functions per bitdepth.
+  low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+}
 
 }  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/arm/obmc_neon.h b/libgav1/src/dsp/arm/obmc_neon.h
index d5c9d9c..788017e 100644
--- a/libgav1/src/dsp/arm/obmc_neon.h
+++ b/libgav1/src/dsp/arm/obmc_neon.h
@@ -33,6 +33,9 @@
 #if LIBGAV1_ENABLE_NEON
 #define LIBGAV1_Dsp8bpp_ObmcVertical LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_ObmcHorizontal LIBGAV1_CPU_NEON
+// 10bpp OBMC blends also have NEON implementations.
+#define LIBGAV1_Dsp10bpp_ObmcVertical LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_ObmcHorizontal LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_OBMC_NEON_H_
diff --git a/libgav1/src/dsp/arm/super_res_neon.cc b/libgav1/src/dsp/arm/super_res_neon.cc
index 91537c4..2f8dde6 100644
--- a/libgav1/src/dsp/arm/super_res_neon.cc
+++ b/libgav1/src/dsp/arm/super_res_neon.cc
@@ -23,6 +23,7 @@
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
 
 namespace libgav1 {
@@ -81,19 +82,27 @@
   return vqrshrn_n_u16(res, kFilterBits);
 }
 
-void SuperRes_NEON(const void* const coefficients, void* const source,
+void SuperRes_NEON(const void* LIBGAV1_RESTRICT const coefficients,
+                   void* LIBGAV1_RESTRICT const source,
                    const ptrdiff_t source_stride, const int height,
                    const int downscaled_width, const int upscaled_width,
                    const int initial_subpixel_x, const int step,
-                   void* const dest, const ptrdiff_t dest_stride) {
+                   void* LIBGAV1_RESTRICT const dest,
+                   const ptrdiff_t dest_stride) {
   auto* src = static_cast<uint8_t*>(source) - DivideBy2(kSuperResFilterTaps);
   auto* dst = static_cast<uint8_t*>(dest);
   int y = height;
   do {
     const auto* filter = static_cast<const uint8_t*>(coefficients);
     uint8_t* dst_ptr = dst;
+#if LIBGAV1_MSAN
+    // Initialize the padding area to prevent msan warnings.
+    const int super_res_right_border = kSuperResHorizontalPadding;
+#else
+    const int super_res_right_border = kSuperResHorizontalBorder;
+#endif
     ExtendLine<uint8_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width,
-                        kSuperResHorizontalBorder, kSuperResHorizontalBorder);
+                        kSuperResHorizontalBorder, super_res_right_border);
     int subpixel_x = initial_subpixel_x;
     uint8x8_t sr[8];
     uint8x16_t s[8];
@@ -234,19 +243,27 @@
 }
 
 template <int bitdepth>
-void SuperRes_NEON(const void* const coefficients, void* const source,
+void SuperRes_NEON(const void* LIBGAV1_RESTRICT const coefficients,
+                   void* LIBGAV1_RESTRICT const source,
                    const ptrdiff_t source_stride, const int height,
                    const int downscaled_width, const int upscaled_width,
                    const int initial_subpixel_x, const int step,
-                   void* const dest, const ptrdiff_t dest_stride) {
+                   void* LIBGAV1_RESTRICT const dest,
+                   const ptrdiff_t dest_stride) {
   auto* src = static_cast<uint16_t*>(source) - DivideBy2(kSuperResFilterTaps);
   auto* dst = static_cast<uint16_t*>(dest);
   int y = height;
   do {
     const auto* filter = static_cast<const uint16_t*>(coefficients);
     uint16_t* dst_ptr = dst;
+#if LIBGAV1_MSAN
+    // Initialize the padding area to prevent msan warnings.
+    const int super_res_right_border = kSuperResHorizontalPadding;
+#else
+    const int super_res_right_border = kSuperResHorizontalBorder;
+#endif
     ExtendLine<uint16_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width,
-                         kSuperResHorizontalBorder, kSuperResHorizontalBorder);
+                         kSuperResHorizontalBorder, super_res_right_border);
     int subpixel_x = initial_subpixel_x;
     uint16x8_t sr[8];
     int x = RightShiftWithCeiling(upscaled_width, 3);
diff --git a/libgav1/src/dsp/arm/warp_neon.cc b/libgav1/src/dsp/arm/warp_neon.cc
index c7fb739..71e0a43 100644
--- a/libgav1/src/dsp/arm/warp_neon.cc
+++ b/libgav1/src/dsp/arm/warp_neon.cc
@@ -34,11 +34,16 @@
 
 namespace libgav1 {
 namespace dsp {
-namespace low_bitdepth {
 namespace {
 
 // Number of extra bits of precision in warped filtering.
 constexpr int kWarpedDiffPrecisionBits = 10;
+
+}  // namespace
+
+namespace low_bitdepth {
+namespace {
+
 constexpr int kFirstPassOffset = 1 << 14;
 constexpr int kOffsetRemoval =
     (kFirstPassOffset >> kInterRoundBitsHorizontal) * 128;
@@ -54,10 +59,10 @@
                       int16_t intermediate_result_row[8]) {
   int sx = sx4 - MultiplyBy4(alpha);
   int8x8_t filter[8];
-  for (int x = 0; x < 8; ++x) {
+  for (auto& f : filter) {
     const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
                        kWarpedPixelPrecisionShifts;
-    filter[x] = vld1_s8(kWarpedFilters8[offset]);
+    f = vld1_s8(kWarpedFilters8[offset]);
     sx += alpha;
   }
   Transpose8x8(filter);
@@ -103,13 +108,15 @@
 }
 
 template <bool is_compound>
-void Warp_NEON(const void* const source, const ptrdiff_t source_stride,
-               const int source_width, const int source_height,
-               const int* const warp_params, const int subsampling_x,
-               const int subsampling_y, const int block_start_x,
-               const int block_start_y, const int block_width,
-               const int block_height, const int16_t alpha, const int16_t beta,
-               const int16_t gamma, const int16_t delta, void* dest,
+void Warp_NEON(const void* LIBGAV1_RESTRICT const source,
+               const ptrdiff_t source_stride, const int source_width,
+               const int source_height,
+               const int* LIBGAV1_RESTRICT const warp_params,
+               const int subsampling_x, const int subsampling_y,
+               const int block_start_x, const int block_start_y,
+               const int block_width, const int block_height,
+               const int16_t alpha, const int16_t beta, const int16_t gamma,
+               const int16_t delta, void* LIBGAV1_RESTRICT dest,
                const ptrdiff_t dest_stride) {
   constexpr int kRoundBitsVertical =
       is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
@@ -393,11 +400,11 @@
       for (int y = 0; y < 8; ++y) {
         int sy = sy4 - MultiplyBy4(gamma);
         int16x8_t filter[8];
-        for (int x = 0; x < 8; ++x) {
+        for (auto& f : filter) {
           const int offset =
               RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
               kWarpedPixelPrecisionShifts;
-          filter[x] = vld1q_s16(kWarpedFilters[offset]);
+          f = vld1q_s16(kWarpedFilters[offset]);
           sy += gamma;
         }
         Transpose8x8(filter);
@@ -438,7 +445,453 @@
 }  // namespace
 }  // namespace low_bitdepth
 
-void WarpInit_NEON() { low_bitdepth::Init8bpp(); }
+//------------------------------------------------------------------------------
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+LIBGAV1_ALWAYS_INLINE uint16x8x2_t LoadSrcRow(uint16_t const* ptr) {
+  uint16x8x2_t x;
+  // Loads 16 consecutive pixels from |ptr|. Clang/gcc uses ldp here.
+  x.val[0] = vld1q_u16(ptr);
+  x.val[1] = vld1q_u16(ptr + 8);
+  return x;
+}
+
+LIBGAV1_ALWAYS_INLINE void HorizontalFilter(
+    const int sx4, const int16_t alpha, const uint16x8x2_t src_row,
+    int16_t intermediate_result_row[8]) {
+  int sx = sx4 - MultiplyBy4(alpha);
+  int8x8_t filter8[8];
+  for (auto& f : filter8) {
+    const int offset = RightShiftWithRounding(sx, kWarpedDiffPrecisionBits) +
+                       kWarpedPixelPrecisionShifts;
+    f = vld1_s8(kWarpedFilters8[offset]);
+    sx += alpha;
+  }
+  // Transpose so that filter8[k] holds tap k for each of the 8 pixels.
+  Transpose8x8(filter8);
+  // Widen the signed 8-bit taps to 16 bits.
+  int16x8_t filter[8];
+  for (int i = 0; i < 8; ++i) {
+    filter[i] = vmovl_s8(filter8[i]);
+  }
+
+  int32x4x2_t sum;
+  int16x8_t src_row_window;
+  // k = 0. Tap k multiplies pixels [k, k + 8) of the 16-pixel |src_row|.
+  src_row_window = vreinterpretq_s16_u16(src_row.val[0]);
+  sum.val[0] = vmull_s16(vget_low_s16(filter[0]), vget_low_s16(src_row_window));
+  sum.val[1] = VMullHighS16(filter[0], src_row_window);
+  // k = 1.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 1));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[1]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[1], src_row_window);
+  // k = 2.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 2));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[2]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[2], src_row_window);
+  // k = 3.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 3));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[3]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[3], src_row_window);
+  // k = 4.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 4));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[4]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[4], src_row_window);
+  // k = 5.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 5));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[5]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[5], src_row_window);
+  // k = 6.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 6));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[6]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[6], src_row_window);
+  // k = 7.
+  src_row_window =
+      vreinterpretq_s16_u16(vextq_u16(src_row.val[0], src_row.val[1], 7));
+  sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(filter[7]),
+                         vget_low_s16(src_row_window));
+  sum.val[1] = VMlalHighS16(sum.val[1], filter[7], src_row_window);
+  // End of unrolled k = 0..7 loop. Round, narrow and store 8 results.
+
+  vst1_s16(intermediate_result_row,
+           vrshrn_n_s32(sum.val[0], kInterRoundBitsHorizontal));
+  vst1_s16(intermediate_result_row + 4,
+           vrshrn_n_s32(sum.val[1], kInterRoundBitsHorizontal));
+}
+
+template <bool is_compound>
+void Warp_NEON(const void* LIBGAV1_RESTRICT const source,
+               const ptrdiff_t source_stride, const int source_width,
+               const int source_height,
+               const int* LIBGAV1_RESTRICT const warp_params,
+               const int subsampling_x, const int subsampling_y,
+               const int block_start_x, const int block_start_y,
+               const int block_width, const int block_height,
+               const int16_t alpha, const int16_t beta, const int16_t gamma,
+               const int16_t delta, void* LIBGAV1_RESTRICT dest,
+               const ptrdiff_t dest_stride) {
+  constexpr int kRoundBitsVertical =
+      is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
+  union {
+    // Intermediate_result is the output of the horizontal filtering and
+    // rounding. The range is within (bitdepth + kFilterBits + 1 -
+    // kInterRoundBitsHorizontal) bits (unsigned); presumably 15 at 10bpp,
+    // not the 13 quoted for 8bpp. We use the signed int16_t type so that we
+    // can multiply it by kWarpedFilters (signed values) using vmlal_s16().
+    int16_t intermediate_result[15][8];  // 15 rows, 8 columns.
+    // In the simple special cases where the samples in each row are all the
+    // same, store one sample per row in a column vector.
+    int16_t intermediate_result_column[15];
+  };
+
+  const auto* const src = static_cast<const uint16_t*>(source);
+  const ptrdiff_t src_stride = source_stride >> 1;
+  using DestType =
+      typename std::conditional<is_compound, int16_t, uint16_t>::type;
+  auto* dst = static_cast<DestType*>(dest);
+  const ptrdiff_t dst_stride = is_compound ? dest_stride : dest_stride >> 1;
+  assert(block_width >= 8);
+  assert(block_height >= 8);
+
+  // The warp is applied to each 8x8 block of the destination in turn.
+  int start_y = block_start_y;
+  do {
+    int start_x = block_start_x;
+    do {
+      const int src_x = (start_x + 4) << subsampling_x;
+      const int src_y = (start_y + 4) << subsampling_y;
+      const int dst_x =
+          src_x * warp_params[2] + src_y * warp_params[3] + warp_params[0];
+      const int dst_y =
+          src_x * warp_params[4] + src_y * warp_params[5] + warp_params[1];
+      const int x4 = dst_x >> subsampling_x;
+      const int y4 = dst_y >> subsampling_y;
+      const int ix4 = x4 >> kWarpedModelPrecisionBits;
+      const int iy4 = y4 >> kWarpedModelPrecisionBits;
+      // A prediction block may fall outside the frame's boundaries. If a
+      // prediction block is calculated using only samples outside the frame's
+      // boundary, the filtering can be simplified. We can divide the plane
+      // into several regions and handle them differently.
+      //
+      //                |           |
+      //            1   |     3     |   1
+      //                |           |
+      //         -------+-----------+-------
+      //                |***********|
+      //            2   |*****4*****|   2
+      //                |***********|
+      //         -------+-----------+-------
+      //                |           |
+      //            1   |     3     |   1
+      //                |           |
+      //
+      // At the center, region 4 represents the frame and is the general case.
+      //
+      // In regions 1 and 2, the prediction block is outside the frame's
+      // boundary horizontally. Therefore the horizontal filtering can be
+      // simplified. Furthermore, in the region 1 (at the four corners), the
+      // prediction is outside the frame's boundary both horizontally and
+      // vertically, so we get a constant prediction block.
+      //
+      // In region 3, the prediction block is outside the frame's boundary
+      // vertically. Unfortunately because we apply the horizontal filters
+      // first, by the time we apply the vertical filters, they no longer see
+      // simple inputs. So the only simplification is that all the rows are
+      // the same, but we still need to apply all the horizontal and vertical
+      // filters.
+
+      // Check for two simple special cases, where the horizontal filter can
+      // be significantly simplified.
+      //
+      // In general, for each row, the horizontal filter is calculated as
+      // follows:
+      //   for (int x = -4; x < 4; ++x) {
+      //     const int offset = ...;
+      //     int sum = first_pass_offset;
+      //     for (int k = 0; k < 8; ++k) {
+      //       const int column = Clip3(ix4 + x + k - 3, 0, source_width - 1);
+      //       sum += kWarpedFilters[offset][k] * src_row[column];
+      //     }
+      //     ...
+      //   }
+      // The column index before clipping, ix4 + x + k - 3, varies in the range
+      // ix4 - 7 <= ix4 + x + k - 3 <= ix4 + 7. If ix4 - 7 >= source_width - 1
+      // or ix4 + 7 <= 0, then all the column indexes are clipped to the same
+      // border index (source_width - 1 or 0, respectively). Then for each x,
+      // the inner for loop of the horizontal filter is reduced to multiplying
+      // the border pixel by the sum of the filter coefficients.
+      if (ix4 - 7 >= source_width - 1 || ix4 + 7 <= 0) {
+        // Regions 1 and 2.
+        // Points to the left or right border of the first row of |src|.
+        const uint16_t* first_row_border =
+            (ix4 + 7 <= 0) ? src : src + source_width - 1;
+        // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+        //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+        // In two special cases, iy4 + y is clipped to either 0 or
+        // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+        // bounded and we can avoid clipping iy4 + y by relying on a reference
+        // frame's boundary extension on the top and bottom.
+        if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+          // Region 1.
+          // Every sample used to calculate the prediction block has the same
+          // value. So the whole prediction block has the same value.
+          const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+          const uint16_t row_border_pixel = first_row_border[row * src_stride];
+
+          DestType* dst_row = dst + start_x - block_start_x;
+          for (int y = 0; y < 8; ++y) {
+            if (is_compound) {
+              const int16x8_t sum =
+                  vdupq_n_s16(row_border_pixel << (kInterRoundBitsVertical -
+                                                   kRoundBitsVertical));
+              vst1q_s16(reinterpret_cast<int16_t*>(dst_row),
+                        vaddq_s16(sum, vdupq_n_s16(kCompoundOffset)));
+            } else {
+              vst1q_u16(reinterpret_cast<uint16_t*>(dst_row),
+                        vdupq_n_u16(row_border_pixel));
+            }
+            dst_row += dst_stride;
+          }
+          // End of region 1. Continue the |start_x| do-while loop.
+          start_x += 8;
+          continue;
+        }
+
+        // Region 2.
+        // Horizontal filter.
+        // The input values in this region are generated by extending the border
+        // which makes them identical in the horizontal direction. This
+        // computation could be inlined in the vertical pass but most
+        // implementations will need a transpose of some sort.
+        // It is not necessary to use the offset values here because the
+        // horizontal pass is a simple shift and the vertical pass will always
+        // require using 32 bits.
+        for (int y = -7; y < 8; ++y) {
+          // We may over-read up to 13 pixels above the top source row, or up
+          // to 13 pixels below the bottom source row. This is proved in
+          // warp.cc.
+          const int row = iy4 + y;
+          int sum = first_row_border[row * src_stride];
+          sum <<= (kFilterBits - kInterRoundBitsHorizontal);
+          intermediate_result_column[y + 7] = sum;
+        }
+        // Vertical filter.
+        DestType* dst_row = dst + start_x - block_start_x;
+        int sy4 =
+            (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+        for (int y = 0; y < 8; ++y) {
+          int sy = sy4 - MultiplyBy4(gamma);
+#if defined(__aarch64__)
+          const int16x8_t intermediate =
+              vld1q_s16(&intermediate_result_column[y]);
+          int16_t tmp[8];
+          for (int x = 0; x < 8; ++x) {
+            const int offset =
+                RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+                kWarpedPixelPrecisionShifts;
+            const int16x8_t filter = vld1q_s16(kWarpedFilters[offset]);
+            const int32x4_t product_low =
+                vmull_s16(vget_low_s16(filter), vget_low_s16(intermediate));
+            const int32x4_t product_high =
+                vmull_s16(vget_high_s16(filter), vget_high_s16(intermediate));
+            // vaddvq_s32 is only available on __aarch64__.
+            const int32_t sum =
+                vaddvq_s32(product_low) + vaddvq_s32(product_high);
+            const int16_t sum_descale =
+                RightShiftWithRounding(sum, kRoundBitsVertical);
+            if (is_compound) {
+              dst_row[x] = sum_descale + kCompoundOffset;
+            } else {
+              tmp[x] = sum_descale;
+            }
+            sy += gamma;
+          }
+          if (!is_compound) {
+            const uint16x8_t v_max_bitdepth =
+                vdupq_n_u16((1 << kBitdepth10) - 1);
+            const int16x8_t sum = vld1q_s16(tmp);
+            const uint16x8_t d0 =
+                vminq_u16(vreinterpretq_u16_s16(vmaxq_s16(sum, vdupq_n_s16(0))),
+                          v_max_bitdepth);
+            vst1q_u16(reinterpret_cast<uint16_t*>(dst_row), d0);
+          }
+#else   // !defined(__aarch64__)
+          int16x8_t filter[8];
+          for (int x = 0; x < 8; ++x) {
+            const int offset =
+                RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+                kWarpedPixelPrecisionShifts;
+            filter[x] = vld1q_s16(kWarpedFilters[offset]);
+            sy += gamma;
+          }
+          Transpose8x8(filter);
+          int32x4_t sum_low = vdupq_n_s32(0);
+          int32x4_t sum_high = sum_low;
+          for (int k = 0; k < 8; ++k) {
+            const int16_t intermediate = intermediate_result_column[y + k];
+            sum_low =
+                vmlal_n_s16(sum_low, vget_low_s16(filter[k]), intermediate);
+            sum_high =
+                vmlal_n_s16(sum_high, vget_high_s16(filter[k]), intermediate);
+          }
+          if (is_compound) {
+            const int16x8_t sum =
+                vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
+                             vrshrn_n_s32(sum_high, kRoundBitsVertical));
+            vst1q_s16(reinterpret_cast<int16_t*>(dst_row),
+                      vaddq_s16(sum, vdupq_n_s16(kCompoundOffset)));
+          } else {
+            const uint16x4_t v_max_bitdepth =
+                vdup_n_u16((1 << kBitdepth10) - 1);
+            const uint16x4_t d0 = vmin_u16(
+                vqrshrun_n_s32(sum_low, kRoundBitsVertical), v_max_bitdepth);
+            const uint16x4_t d1 = vmin_u16(
+                vqrshrun_n_s32(sum_high, kRoundBitsVertical), v_max_bitdepth);
+            vst1_u16(reinterpret_cast<uint16_t*>(dst_row), d0);
+            vst1_u16(reinterpret_cast<uint16_t*>(dst_row + 4), d1);
+          }
+#endif  // defined(__aarch64__)
+          dst_row += dst_stride;
+          sy4 += delta;
+        }
+        // End of region 2. Continue the |start_x| do-while loop.
+        start_x += 8;
+        continue;
+      }
+
+      // Regions 3 and 4.
+      // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
+
+      // In general, for y in [-7, 8), the row number iy4 + y is clipped:
+      //   const int row = Clip3(iy4 + y, 0, source_height - 1);
+      // In two special cases, iy4 + y is clipped to either 0 or
+      // source_height - 1 for all y. In the rest of the cases, iy4 + y is
+      // bounded and we can avoid clipping iy4 + y by relying on a reference
+      // frame's boundary extension on the top and bottom.
+      if (iy4 - 7 >= source_height - 1 || iy4 + 7 <= 0) {
+        // Region 3.
+        // Horizontal filter.
+        const int row = (iy4 + 7 <= 0) ? 0 : source_height - 1;
+        const uint16_t* const src_row = src + row * src_stride;
+        // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+        // read but is ignored.
+        //
+        // NOTE: This may read up to 13 pixels before src_row[0] or up to 14
+        // pixels after src_row[source_width - 1]. We assume the source frame
+        // has left and right borders of at least 13 pixels that extend the
+        // frame boundary pixels. We also assume there is at least one extra
+        // padding pixel after the right border of the last source row.
+        const uint16x8x2_t src_row_v = LoadSrcRow(&src_row[ix4 - 7]);
+        int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+        for (int y = -7; y < 8; ++y) {
+          HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]);
+          sx4 += beta;
+        }
+      } else {
+        // Region 4.
+        // Horizontal filter.
+        int sx4 = (x4 & ((1 << kWarpedModelPrecisionBits) - 1)) - beta * 7;
+        for (int y = -7; y < 8; ++y) {
+          // We may over-read up to 13 pixels above the top source row, or up
+          // to 13 pixels below the bottom source row. This is proved in
+          // warp.cc.
+          const int row = iy4 + y;
+          const uint16_t* const src_row = src + row * src_stride;
+          // Read 15 samples from &src_row[ix4 - 7]. The 16th sample is also
+          // read but is ignored.
+          //
+          // NOTE: This may read up to 13 pixels before src_row[0] or up to 14
+          // pixels after src_row[source_width - 1]. We assume the source frame
+          // has left and right borders of at least 13 pixels that extend the
+          // frame boundary pixels. We also assume there is at least one extra
+          // padding pixel after the right border of the last source row.
+          const uint16x8x2_t src_row_v = LoadSrcRow(&src_row[ix4 - 7]);
+          HorizontalFilter(sx4, alpha, src_row_v, intermediate_result[y + 7]);
+          sx4 += beta;
+        }
+      }
+
+      // Regions 3 and 4.
+      // Vertical filter.
+      DestType* dst_row = dst + start_x - block_start_x;
+      int sy4 =
+          (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
+      for (int y = 0; y < 8; ++y) {
+        int sy = sy4 - MultiplyBy4(gamma);
+        int16x8_t filter[8];
+        for (auto& f : filter) {
+          const int offset =
+              RightShiftWithRounding(sy, kWarpedDiffPrecisionBits) +
+              kWarpedPixelPrecisionShifts;
+          f = vld1q_s16(kWarpedFilters[offset]);
+          sy += gamma;
+        }
+        Transpose8x8(filter);
+        int32x4_t sum_low = vdupq_n_s32(0);
+        int32x4_t sum_high = sum_low;
+        for (int k = 0; k < 8; ++k) {
+          const int16x8_t intermediate = vld1q_s16(intermediate_result[y + k]);
+          sum_low = vmlal_s16(sum_low, vget_low_s16(filter[k]),
+                              vget_low_s16(intermediate));
+          sum_high = vmlal_s16(sum_high, vget_high_s16(filter[k]),
+                               vget_high_s16(intermediate));
+        }
+        if (is_compound) {
+          const int16x8_t sum =
+              vcombine_s16(vrshrn_n_s32(sum_low, kRoundBitsVertical),
+                           vrshrn_n_s32(sum_high, kRoundBitsVertical));
+          vst1q_s16(reinterpret_cast<int16_t*>(dst_row),
+                    vaddq_s16(sum, vdupq_n_s16(kCompoundOffset)));
+        } else {
+          const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
+          const uint16x4_t d0 = vmin_u16(
+              vqrshrun_n_s32(sum_low, kRoundBitsVertical), v_max_bitdepth);
+          const uint16x4_t d1 = vmin_u16(
+              vqrshrun_n_s32(sum_high, kRoundBitsVertical), v_max_bitdepth);
+          vst1_u16(reinterpret_cast<uint16_t*>(dst_row), d0);
+          vst1_u16(reinterpret_cast<uint16_t*>(dst_row + 4), d1);
+        }
+        dst_row += dst_stride;
+        sy4 += delta;
+      }
+      start_x += 8;
+    } while (start_x < block_start_x + block_width);
+    dst += 8 * dst_stride;
+    start_y += 8;
+  } while (start_y < block_start_y + block_height);
+}
+
+void Init10bpp() {  // Installs the 10bpp NEON warp functions.
+  Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  dsp->warp = Warp_NEON</*is_compound=*/false>;
+  dsp->warp_compound = Warp_NEON</*is_compound=*/true>;
+}
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+void WarpInit_NEON() {  // Registers the NEON warp functions per bitdepth.
+  low_bitdepth::Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+}
 
 }  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/arm/warp_neon.h b/libgav1/src/dsp/arm/warp_neon.h
index dbcaa23..cd60602 100644
--- a/libgav1/src/dsp/arm/warp_neon.h
+++ b/libgav1/src/dsp/arm/warp_neon.h
@@ -32,6 +32,9 @@
 #if LIBGAV1_ENABLE_NEON
 #define LIBGAV1_Dsp8bpp_Warp LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_WarpCompound LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_Warp LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WarpCompound LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_WARP_NEON_H_
diff --git a/libgav1/src/dsp/arm/weight_mask_neon.cc b/libgav1/src/dsp/arm/weight_mask_neon.cc
index 7e5bff0..5ad6b97 100644
--- a/libgav1/src/dsp/arm/weight_mask_neon.cc
+++ b/libgav1/src/dsp/arm/weight_mask_neon.cc
@@ -32,20 +32,51 @@
 
 namespace libgav1 {
 namespace dsp {
-namespace low_bitdepth {
 namespace {
 
-constexpr int kRoundingBits8bpp = 4;
+inline int16x8x2_t LoadPred(const int16_t* LIBGAV1_RESTRICT prediction_0,
+                            const int16_t* LIBGAV1_RESTRICT prediction_1) {
+  const int16x8x2_t pred = {vld1q_s16(prediction_0), vld1q_s16(prediction_1)};
+  return pred;
+}
 
-template <bool mask_is_inverse>
-inline void WeightMask8_NEON(const int16_t* prediction_0,
-                             const int16_t* prediction_1, uint8_t* mask) {
-  const int16x8_t pred_0 = vld1q_s16(prediction_0);
-  const int16x8_t pred_1 = vld1q_s16(prediction_1);
+#if LIBGAV1_MAX_BITDEPTH >= 10
+inline uint16x8x2_t LoadPred(const uint16_t* LIBGAV1_RESTRICT prediction_0,
+                             const uint16_t* LIBGAV1_RESTRICT prediction_1) {
+  const uint16x8x2_t pred = {vld1q_u16(prediction_0), vld1q_u16(prediction_1)};
+  return pred;
+}
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+
+template <int bitdepth>
+inline uint16x8_t AbsolutePredDifference(const int16x8x2_t pred) {
+  static_assert(bitdepth == 8, "Signed predictions are only used for 8bpp.");
+  constexpr int rounding_bits = bitdepth - 8 + ((bitdepth == 12) ? 2 : 4);
+  return vrshrq_n_u16(
+      vreinterpretq_u16_s16(vabdq_s16(pred.val[0], pred.val[1])),
+      rounding_bits);
+}
+
+template <int bitdepth>
+inline uint16x8_t AbsolutePredDifference(const uint16x8x2_t pred) {
+  constexpr int rounding_bits = bitdepth - 8 + ((bitdepth == 12) ? 2 : 4);
+  return vrshrq_n_u16(vabdq_u16(pred.val[0], pred.val[1]), rounding_bits);
+}
+
+template <bool mask_is_inverse, int bitdepth>
+inline void WeightMask8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                             const void* LIBGAV1_RESTRICT prediction_1,
+                             uint8_t* LIBGAV1_RESTRICT mask) {
+  using PredType =
+      typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
+  using PredTypeVecx2 =
+      typename std::conditional<bitdepth == 8, int16x8x2_t, uint16x8x2_t>::type;
+  const PredTypeVecx2 pred =
+      LoadPred(static_cast<const PredType*>(prediction_0),
+               static_cast<const PredType*>(prediction_1));
+  const uint16x8_t difference = AbsolutePredDifference<bitdepth>(pred);
   const uint8x8_t difference_offset = vdup_n_u8(38);
   const uint8x8_t mask_ceiling = vdup_n_u8(64);
-  const uint16x8_t difference = vrshrq_n_u16(
-      vreinterpretq_u16_s16(vabdq_s16(pred_0, pred_1)), kRoundingBits8bpp);
   const uint8x8_t adjusted_difference =
       vqadd_u8(vqshrn_n_u16(difference, 4), difference_offset);
   const uint8x8_t mask_value = vmin_u8(adjusted_difference, mask_ceiling);
@@ -58,7 +89,7 @@
 }
 
 #define WEIGHT8_WITHOUT_STRIDE \
-  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask)
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask)
 
 #define WEIGHT8_AND_STRIDE \
   WEIGHT8_WITHOUT_STRIDE;  \
@@ -66,9 +97,12 @@
   pred_1 += 8;             \
   mask += mask_stride
 
-template <bool mask_is_inverse>
-void WeightMask8x8_NEON(const void* prediction_0, const void* prediction_1,
-                        uint8_t* mask, ptrdiff_t mask_stride) {
+// |pred_0| and |pred_1| are cast as int16_t* for the sake of pointer math. They
+// are uint16_t* for 10bpp and 12bpp, and this is handled in WeightMask8_NEON.
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask8x8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                        const void* LIBGAV1_RESTRICT prediction_1,
+                        uint8_t* LIBGAV1_RESTRICT mask, ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y = 0;
@@ -78,9 +112,11 @@
   WEIGHT8_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask8x16_NEON(const void* prediction_0, const void* prediction_1,
-                         uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask8x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                         const void* LIBGAV1_RESTRICT prediction_1,
+                         uint8_t* LIBGAV1_RESTRICT mask,
+                         ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -92,9 +128,11 @@
   WEIGHT8_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask8x32_NEON(const void* prediction_0, const void* prediction_1,
-                         uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask8x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                         const void* LIBGAV1_RESTRICT prediction_1,
+                         uint8_t* LIBGAV1_RESTRICT mask,
+                         ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y5 = 0;
@@ -109,9 +147,9 @@
   WEIGHT8_WITHOUT_STRIDE;
 }
 
-#define WEIGHT16_WITHOUT_STRIDE                            \
-  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask); \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8)
+#define WEIGHT16_WITHOUT_STRIDE                                      \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask); \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 8, pred_1 + 8, mask + 8)
 
 #define WEIGHT16_AND_STRIDE \
   WEIGHT16_WITHOUT_STRIDE;  \
@@ -119,9 +157,11 @@
   pred_1 += 16;             \
   mask += mask_stride
 
-template <bool mask_is_inverse>
-void WeightMask16x8_NEON(const void* prediction_0, const void* prediction_1,
-                         uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask16x8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                         const void* LIBGAV1_RESTRICT prediction_1,
+                         uint8_t* LIBGAV1_RESTRICT mask,
+                         ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y = 0;
@@ -131,9 +171,11 @@
   WEIGHT16_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask16x16_NEON(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask16x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -145,9 +187,11 @@
   WEIGHT16_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask16x32_NEON(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask16x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y5 = 0;
@@ -162,9 +206,11 @@
   WEIGHT16_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask16x64_NEON(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask16x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -176,11 +222,14 @@
   WEIGHT16_WITHOUT_STRIDE;
 }
 
-#define WEIGHT32_WITHOUT_STRIDE                                           \
-  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask);                \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8);    \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24)
+#define WEIGHT32_WITHOUT_STRIDE                                         \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask);    \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 8, pred_1 + 8,   \
+                                              mask + 8);                \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 16, pred_1 + 16, \
+                                              mask + 16);               \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 24, pred_1 + 24, \
+                                              mask + 24)
 
 #define WEIGHT32_AND_STRIDE \
   WEIGHT32_WITHOUT_STRIDE;  \
@@ -188,9 +237,11 @@
   pred_1 += 32;             \
   mask += mask_stride
 
-template <bool mask_is_inverse>
-void WeightMask32x8_NEON(const void* prediction_0, const void* prediction_1,
-                         uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask32x8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                         const void* LIBGAV1_RESTRICT prediction_1,
+                         uint8_t* LIBGAV1_RESTRICT mask,
+                         ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   WEIGHT32_AND_STRIDE;
@@ -203,9 +254,11 @@
   WEIGHT32_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask32x16_NEON(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask32x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -217,9 +270,11 @@
   WEIGHT32_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask32x32_NEON(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask32x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y5 = 0;
@@ -234,9 +289,11 @@
   WEIGHT32_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask32x64_NEON(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask32x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -248,15 +305,22 @@
   WEIGHT32_WITHOUT_STRIDE;
 }
 
-#define WEIGHT64_WITHOUT_STRIDE                                           \
-  WeightMask8_NEON<mask_is_inverse>(pred_0, pred_1, mask);                \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 8, pred_1 + 8, mask + 8);    \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 16, pred_1 + 16, mask + 16); \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 24, pred_1 + 24, mask + 24); \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 32, pred_1 + 32, mask + 32); \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 40, pred_1 + 40, mask + 40); \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 48, pred_1 + 48, mask + 48); \
-  WeightMask8_NEON<mask_is_inverse>(pred_0 + 56, pred_1 + 56, mask + 56)
+#define WEIGHT64_WITHOUT_STRIDE                                         \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask);    \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 8, pred_1 + 8,   \
+                                              mask + 8);                \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 16, pred_1 + 16, \
+                                              mask + 16);               \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 24, pred_1 + 24, \
+                                              mask + 24);               \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 32, pred_1 + 32, \
+                                              mask + 32);               \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 40, pred_1 + 40, \
+                                              mask + 40);               \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 48, pred_1 + 48, \
+                                              mask + 48);               \
+  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 56, pred_1 + 56, \
+                                              mask + 56)
 
 #define WEIGHT64_AND_STRIDE \
   WEIGHT64_WITHOUT_STRIDE;  \
@@ -264,9 +328,11 @@
   pred_1 += 64;             \
   mask += mask_stride
 
-template <bool mask_is_inverse>
-void WeightMask64x16_NEON(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask64x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -278,9 +344,11 @@
   WEIGHT64_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask64x32_NEON(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask64x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y5 = 0;
@@ -295,9 +363,11 @@
   WEIGHT64_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask64x64_NEON(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask64x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -309,9 +379,11 @@
   WEIGHT64_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask64x128_NEON(const void* prediction_0, const void* prediction_1,
-                           uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask64x128_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                           const void* LIBGAV1_RESTRICT prediction_1,
+                           uint8_t* LIBGAV1_RESTRICT mask,
+                           ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -324,9 +396,11 @@
   WEIGHT64_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask128x64_NEON(const void* prediction_0, const void* prediction_1,
-                           uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask128x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                           const void* LIBGAV1_RESTRICT prediction_1,
+                           uint8_t* LIBGAV1_RESTRICT mask,
+                           ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -366,9 +440,11 @@
   WEIGHT64_WITHOUT_STRIDE;
 }
 
-template <bool mask_is_inverse>
-void WeightMask128x128_NEON(const void* prediction_0, const void* prediction_1,
-                            uint8_t* mask, ptrdiff_t mask_stride) {
+template <bool mask_is_inverse, int bitdepth>
+void WeightMask128x128_NEON(const void* LIBGAV1_RESTRICT prediction_0,
+                            const void* LIBGAV1_RESTRICT prediction_1,
+                            uint8_t* LIBGAV1_RESTRICT mask,
+                            ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -416,11 +492,20 @@
   mask += 64;
   WEIGHT64_WITHOUT_STRIDE;
 }
+#undef WEIGHT8_WITHOUT_STRIDE
+#undef WEIGHT8_AND_STRIDE
+#undef WEIGHT16_WITHOUT_STRIDE
+#undef WEIGHT16_AND_STRIDE
+#undef WEIGHT32_WITHOUT_STRIDE
+#undef WEIGHT32_AND_STRIDE
+#undef WEIGHT64_WITHOUT_STRIDE
+#undef WEIGHT64_AND_STRIDE
 
 #define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \
   dsp->weight_mask[w_index][h_index][0] =                      \
-      WeightMask##width##x##height##_NEON<0>;                  \
-  dsp->weight_mask[w_index][h_index][1] = WeightMask##width##x##height##_NEON<1>
+      WeightMask##width##x##height##_NEON<0, 8>;               \
+  dsp->weight_mask[w_index][h_index][1] =                      \
+      WeightMask##width##x##height##_NEON<1, 8>
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
@@ -442,11 +527,51 @@
   INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3);
   INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4);
 }
+#undef INIT_WEIGHT_MASK_8BPP
 
 }  // namespace
-}  // namespace low_bitdepth
 
-void WeightMaskInit_NEON() { low_bitdepth::Init8bpp(); }
+#if LIBGAV1_MAX_BITDEPTH >= 10
+namespace high_bitdepth {
+namespace {
+
+#define INIT_WEIGHT_MASK_10BPP(width, height, w_index, h_index) \
+  dsp->weight_mask[w_index][h_index][0] =                       \
+      WeightMask##width##x##height##_NEON<0, 10>;               \
+  dsp->weight_mask[w_index][h_index][1] =                       \
+      WeightMask##width##x##height##_NEON<1, 10>
+void Init10bpp() {
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
+  assert(dsp != nullptr);
+  INIT_WEIGHT_MASK_10BPP(8, 8, 0, 0);
+  INIT_WEIGHT_MASK_10BPP(8, 16, 0, 1);
+  INIT_WEIGHT_MASK_10BPP(8, 32, 0, 2);
+  INIT_WEIGHT_MASK_10BPP(16, 8, 1, 0);
+  INIT_WEIGHT_MASK_10BPP(16, 16, 1, 1);
+  INIT_WEIGHT_MASK_10BPP(16, 32, 1, 2);
+  INIT_WEIGHT_MASK_10BPP(16, 64, 1, 3);
+  INIT_WEIGHT_MASK_10BPP(32, 8, 2, 0);
+  INIT_WEIGHT_MASK_10BPP(32, 16, 2, 1);
+  INIT_WEIGHT_MASK_10BPP(32, 32, 2, 2);
+  INIT_WEIGHT_MASK_10BPP(32, 64, 2, 3);
+  INIT_WEIGHT_MASK_10BPP(64, 16, 3, 1);
+  INIT_WEIGHT_MASK_10BPP(64, 32, 3, 2);
+  INIT_WEIGHT_MASK_10BPP(64, 64, 3, 3);
+  INIT_WEIGHT_MASK_10BPP(64, 128, 3, 4);
+  INIT_WEIGHT_MASK_10BPP(128, 64, 4, 3);
+  INIT_WEIGHT_MASK_10BPP(128, 128, 4, 4);
+}
+#undef INIT_WEIGHT_MASK_10BPP
+
+}  // namespace
+}  // namespace high_bitdepth
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+void WeightMaskInit_NEON() {
+  Init8bpp();
+#if LIBGAV1_MAX_BITDEPTH >= 10
+  high_bitdepth::Init10bpp();
+#endif  // LIBGAV1_MAX_BITDEPTH >= 10
+}
 
 }  // namespace dsp
 }  // namespace libgav1
diff --git a/libgav1/src/dsp/arm/weight_mask_neon.h b/libgav1/src/dsp/arm/weight_mask_neon.h
index b4749ec..573f7de 100644
--- a/libgav1/src/dsp/arm/weight_mask_neon.h
+++ b/libgav1/src/dsp/arm/weight_mask_neon.h
@@ -47,6 +47,24 @@
 #define LIBGAV1_Dsp8bpp_WeightMask_64x128 LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_WeightMask_128x64 LIBGAV1_CPU_NEON
 #define LIBGAV1_Dsp8bpp_WeightMask_128x128 LIBGAV1_CPU_NEON
+
+#define LIBGAV1_Dsp10bpp_WeightMask_8x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_8x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_8x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_16x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_16x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_16x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_16x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_32x8 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_32x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_32x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_32x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_64x16 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_64x32 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_64x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_64x128 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_128x64 LIBGAV1_CPU_NEON
+#define LIBGAV1_Dsp10bpp_WeightMask_128x128 LIBGAV1_CPU_NEON
 #endif  // LIBGAV1_ENABLE_NEON
 
 #endif  // LIBGAV1_SRC_DSP_ARM_WEIGHT_MASK_NEON_H_
diff --git a/libgav1/src/dsp/average_blend.cc b/libgav1/src/dsp/average_blend.cc
index d3ec21f..273b355 100644
--- a/libgav1/src/dsp/average_blend.cc
+++ b/libgav1/src/dsp/average_blend.cc
@@ -27,8 +27,9 @@
 namespace {
 
 template <int bitdepth, typename Pixel>
-void AverageBlend_C(const void* prediction_0, const void* prediction_1,
-                    const int width, const int height, void* const dest,
+void AverageBlend_C(const void* LIBGAV1_RESTRICT prediction_0,
+                    const void* LIBGAV1_RESTRICT prediction_1, const int width,
+                    const int height, void* const dest,
                     const ptrdiff_t dest_stride) {
   // 7.11.3.2 Rounding variables derivation process
   //   2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7))
diff --git a/libgav1/src/dsp/cdef.cc b/libgav1/src/dsp/cdef.cc
index 0b50517..ca2adfd 100644
--- a/libgav1/src/dsp/cdef.cc
+++ b/libgav1/src/dsp/cdef.cc
@@ -40,8 +40,10 @@
 int32_t Square(int32_t x) { return x * x; }
 
 template <int bitdepth, typename Pixel>
-void CdefDirection_C(const void* const source, ptrdiff_t stride,
-                     uint8_t* const direction, int* const variance) {
+void CdefDirection_C(const void* LIBGAV1_RESTRICT const source,
+                     ptrdiff_t stride,
+                     uint8_t* LIBGAV1_RESTRICT const direction,
+                     int* LIBGAV1_RESTRICT const variance) {
   assert(direction != nullptr);
   assert(variance != nullptr);
   const auto* src = static_cast<const Pixel*>(source);
@@ -121,10 +123,11 @@
 // constant large value (kCdefLargeValue) if at the boundary.
 template <int block_width, int bitdepth, typename Pixel,
           bool enable_primary = true, bool enable_secondary = true>
-void CdefFilter_C(const uint16_t* src, const ptrdiff_t src_stride,
-                  const int block_height, const int primary_strength,
-                  const int secondary_strength, const int damping,
-                  const int direction, void* const dest,
+void CdefFilter_C(const uint16_t* LIBGAV1_RESTRICT src,
+                  const ptrdiff_t src_stride, const int block_height,
+                  const int primary_strength, const int secondary_strength,
+                  const int damping, const int direction,
+                  void* LIBGAV1_RESTRICT const dest,
                   const ptrdiff_t dest_stride) {
   static_assert(block_width == 4 || block_width == 8, "Invalid CDEF width.");
   static_assert(enable_primary || enable_secondary, "");
diff --git a/libgav1/src/dsp/convolve.cc b/libgav1/src/dsp/convolve.cc
index 727b4af..f11b45e 100644
--- a/libgav1/src/dsp/convolve.cc
+++ b/libgav1/src/dsp/convolve.cc
@@ -33,34 +33,39 @@
 constexpr int kVerticalOffset = 3;
 
 // Compound prediction output ranges from ConvolveTest.ShowRange.
+// In some cases the horizontal or vertical filter is omitted. This table shows
+// the general case, in which the downscaled horizontal output feeds the
+// vertical filter through the |intermediate_result| array. The final output is
+// either Pixel or compound values, depending on the |is_compound| variable.
 // Bitdepth:  8 Input range:            [       0,      255]
-//   intermediate range:                [   -7140,    23460]
-//   first pass output range:           [   -1785,     5865]
-//   intermediate range:                [ -328440,   589560]
-//   second pass output range:          [       0,      255]
-//   compound second pass output range: [   -5132,     9212]
+//   Horizontal upscaled range:         [   -7140,    23460]
+//   Horizontal downscaled range:       [   -1785,     5865]
+//   Vertical upscaled range:           [ -328440,   589560]
+//   Pixel output range:                [       0,      255]
+//   Compound output range:             [   -5132,     9212]
 //
 // Bitdepth: 10 Input range:            [       0,     1023]
-//   intermediate range:                [  -28644,    94116]
-//   first pass output range:           [   -7161,    23529]
-//   intermediate range:                [-1317624,  2365176]
-//   second pass output range:          [       0,     1023]
-//   compound second pass output range: [    3988,    61532]
+//   Horizontal upscaled range:         [  -28644,    94116]
+//   Horizontal downscaled range:       [   -7161,    23529]
+//   Vertical upscaled range:           [-1317624,  2365176]
+//   Pixel output range:                [       0,     1023]
+//   Compound output range:             [    3988,    61532]
 //
 // Bitdepth: 12 Input range:            [       0,     4095]
-//   intermediate range:                [ -114660,   376740]
-//   first pass output range:           [   -7166,    23546]
-//   intermediate range:                [-1318560,  2366880]
-//   second pass output range:          [       0,     4095]
-//   compound second pass output range: [    3974,    61559]
+//   Horizontal upscaled range:         [ -114660,   376740]
+//   Horizontal downscaled range:       [   -7166,    23546]
+//   Vertical upscaled range:           [-1318560,  2366880]
+//   Pixel output range:                [       0,     4095]
+//   Compound output range:             [    3974,    61559]
 
 template <int bitdepth, typename Pixel>
-void ConvolveScale2D_C(const void* const reference,
+void ConvolveScale2D_C(const void* LIBGAV1_RESTRICT const reference,
                        const ptrdiff_t reference_stride,
                        const int horizontal_filter_index,
                        const int vertical_filter_index, const int subpixel_x,
                        const int subpixel_y, const int step_x, const int step_y,
-                       const int width, const int height, void* prediction,
+                       const int width, const int height,
+                       void* LIBGAV1_RESTRICT prediction,
                        const ptrdiff_t pred_stride) {
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
@@ -137,14 +142,12 @@
 }
 
 template <int bitdepth, typename Pixel>
-void ConvolveCompoundScale2D_C(const void* const reference,
-                               const ptrdiff_t reference_stride,
-                               const int horizontal_filter_index,
-                               const int vertical_filter_index,
-                               const int subpixel_x, const int subpixel_y,
-                               const int step_x, const int step_y,
-                               const int width, const int height,
-                               void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveCompoundScale2D_C(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int vertical_filter_index, const int subpixel_x, const int subpixel_y,
+    const int step_x, const int step_y, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   // All compound functions output to the predictor buffer with |pred_stride|
   // equal to |width|.
   assert(pred_stride == width);
@@ -223,13 +226,13 @@
 }
 
 template <int bitdepth, typename Pixel>
-void ConvolveCompound2D_C(const void* const reference,
+void ConvolveCompound2D_C(const void* LIBGAV1_RESTRICT const reference,
                           const ptrdiff_t reference_stride,
                           const int horizontal_filter_index,
                           const int vertical_filter_index,
                           const int horizontal_filter_id,
                           const int vertical_filter_id, const int width,
-                          const int height, void* prediction,
+                          const int height, void* LIBGAV1_RESTRICT prediction,
                           const ptrdiff_t pred_stride) {
   // All compound functions output to the predictor buffer with |pred_stride|
   // equal to |width|.
@@ -307,11 +310,13 @@
 // The output is the single prediction of the block, clipped to valid pixel
 // range.
 template <int bitdepth, typename Pixel>
-void Convolve2D_C(const void* const reference, const ptrdiff_t reference_stride,
+void Convolve2D_C(const void* LIBGAV1_RESTRICT const reference,
+                  const ptrdiff_t reference_stride,
                   const int horizontal_filter_index,
                   const int vertical_filter_index,
                   const int horizontal_filter_id, const int vertical_filter_id,
-                  const int width, const int height, void* prediction,
+                  const int width, const int height,
+                  void* LIBGAV1_RESTRICT prediction,
                   const ptrdiff_t pred_stride) {
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
@@ -385,13 +390,13 @@
 // The output is the single prediction of the block, clipped to valid pixel
 // range.
 template <int bitdepth, typename Pixel>
-void ConvolveHorizontal_C(const void* const reference,
+void ConvolveHorizontal_C(const void* LIBGAV1_RESTRICT const reference,
                           const ptrdiff_t reference_stride,
                           const int horizontal_filter_index,
                           const int /*vertical_filter_index*/,
                           const int horizontal_filter_id,
                           const int /*vertical_filter_id*/, const int width,
-                          const int height, void* prediction,
+                          const int height, void* LIBGAV1_RESTRICT prediction,
                           const ptrdiff_t pred_stride) {
   constexpr int kRoundBitsHorizontal = (bitdepth == 12)
                                            ? kInterRoundBitsHorizontal12bpp
@@ -427,13 +432,13 @@
 // The output is the single prediction of the block, clipped to valid pixel
 // range.
 template <int bitdepth, typename Pixel>
-void ConvolveVertical_C(const void* const reference,
+void ConvolveVertical_C(const void* LIBGAV1_RESTRICT const reference,
                         const ptrdiff_t reference_stride,
                         const int /*horizontal_filter_index*/,
                         const int vertical_filter_index,
                         const int /*horizontal_filter_id*/,
                         const int vertical_filter_id, const int width,
-                        const int height, void* prediction,
+                        const int height, void* LIBGAV1_RESTRICT prediction,
                         const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
   const ptrdiff_t src_stride = reference_stride / sizeof(Pixel);
@@ -464,13 +469,13 @@
 }
 
 template <int bitdepth, typename Pixel>
-void ConvolveCopy_C(const void* const reference,
+void ConvolveCopy_C(const void* LIBGAV1_RESTRICT const reference,
                     const ptrdiff_t reference_stride,
                     const int /*horizontal_filter_index*/,
                     const int /*vertical_filter_index*/,
                     const int /*horizontal_filter_id*/,
                     const int /*vertical_filter_id*/, const int width,
-                    const int height, void* prediction,
+                    const int height, void* LIBGAV1_RESTRICT prediction,
                     const ptrdiff_t pred_stride) {
   const auto* src = static_cast<const uint8_t*>(reference);
   auto* dest = static_cast<uint8_t*>(prediction);
@@ -483,13 +488,13 @@
 }
 
 template <int bitdepth, typename Pixel>
-void ConvolveCompoundCopy_C(const void* const reference,
+void ConvolveCompoundCopy_C(const void* LIBGAV1_RESTRICT const reference,
                             const ptrdiff_t reference_stride,
                             const int /*horizontal_filter_index*/,
                             const int /*vertical_filter_index*/,
                             const int /*horizontal_filter_id*/,
                             const int /*vertical_filter_id*/, const int width,
-                            const int height, void* prediction,
+                            const int height, void* LIBGAV1_RESTRICT prediction,
                             const ptrdiff_t pred_stride) {
   // All compound functions output to the predictor buffer with |pred_stride|
   // equal to |width|.
@@ -523,11 +528,11 @@
 // blended with another predictor to generate the final prediction of the block.
 template <int bitdepth, typename Pixel>
 void ConvolveCompoundHorizontal_C(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int horizontal_filter_index, const int /*vertical_filter_index*/,
-    const int horizontal_filter_id, const int /*vertical_filter_id*/,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t pred_stride) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int /*vertical_filter_index*/, const int horizontal_filter_id,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   // All compound functions output to the predictor buffer with |pred_stride|
   // equal to |width|.
   assert(pred_stride == width);
@@ -567,14 +572,12 @@
 // The output is not clipped to valid pixel range. Its output will be
 // blended with another predictor to generate the final prediction of the block.
 template <int bitdepth, typename Pixel>
-void ConvolveCompoundVertical_C(const void* const reference,
-                                const ptrdiff_t reference_stride,
-                                const int /*horizontal_filter_index*/,
-                                const int vertical_filter_index,
-                                const int /*horizontal_filter_id*/,
-                                const int vertical_filter_id, const int width,
-                                const int height, void* prediction,
-                                const ptrdiff_t pred_stride) {
+void ConvolveCompoundVertical_C(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   // All compound functions output to the predictor buffer with |pred_stride|
   // equal to |width|.
   assert(pred_stride == width);
@@ -615,14 +618,12 @@
 // The output is the single prediction of the block, clipped to valid pixel
 // range.
 template <int bitdepth, typename Pixel>
-void ConvolveIntraBlockCopy2D_C(const void* const reference,
-                                const ptrdiff_t reference_stride,
-                                const int /*horizontal_filter_index*/,
-                                const int /*vertical_filter_index*/,
-                                const int /*horizontal_filter_id*/,
-                                const int /*vertical_filter_id*/,
-                                const int width, const int height,
-                                void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveIntraBlockCopy2D_C(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
   assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
   const auto* src = static_cast<const Pixel*>(reference);
@@ -670,14 +671,12 @@
 // The filtering of intra block copy is simply the average of current and
 // the next pixel.
 template <int bitdepth, typename Pixel, bool is_horizontal>
-void ConvolveIntraBlockCopy1D_C(const void* const reference,
-                                const ptrdiff_t reference_stride,
-                                const int /*horizontal_filter_index*/,
-                                const int /*vertical_filter_index*/,
-                                const int /*horizontal_filter_id*/,
-                                const int /*vertical_filter_id*/,
-                                const int width, const int height,
-                                void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveIntraBlockCopy1D_C(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
   assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
   const auto* src = static_cast<const Pixel*>(reference);
diff --git a/libgav1/src/dsp/convolve.inc b/libgav1/src/dsp/convolve.inc
index 140648b..e0f755e 100644
--- a/libgav1/src/dsp/convolve.inc
+++ b/libgav1/src/dsp/convolve.inc
@@ -45,6 +45,7 @@
   return 4;
 }
 
-constexpr int kIntermediateStride = kMaxSuperBlockSizeInPixels;
+constexpr int kIntermediateAllocWidth = kMaxSuperBlockSizeInPixels;
+constexpr int kIntermediateStride = 8;
 constexpr int kHorizontalOffset = 3;
 constexpr int kFilterIndexShift = 6;
diff --git a/libgav1/src/dsp/distance_weighted_blend.cc b/libgav1/src/dsp/distance_weighted_blend.cc
index a035fbe..34d10fc 100644
--- a/libgav1/src/dsp/distance_weighted_blend.cc
+++ b/libgav1/src/dsp/distance_weighted_blend.cc
@@ -27,10 +27,12 @@
 namespace {
 
 template <int bitdepth, typename Pixel>
-void DistanceWeightedBlend_C(const void* prediction_0, const void* prediction_1,
+void DistanceWeightedBlend_C(const void* LIBGAV1_RESTRICT prediction_0,
+                             const void* LIBGAV1_RESTRICT prediction_1,
                              const uint8_t weight_0, const uint8_t weight_1,
                              const int width, const int height,
-                             void* const dest, const ptrdiff_t dest_stride) {
+                             void* LIBGAV1_RESTRICT const dest,
+                             const ptrdiff_t dest_stride) {
   // 7.11.3.2 Rounding variables derivation process
   //   2 * FILTER_BITS(7) - (InterRound0(3|5) + InterRound1(7))
   constexpr int inter_post_round_bits = (bitdepth == 12) ? 2 : 4;
diff --git a/libgav1/src/dsp/dsp.cc b/libgav1/src/dsp/dsp.cc
index a3d7701..aac0ca0 100644
--- a/libgav1/src/dsp/dsp.cc
+++ b/libgav1/src/dsp/dsp.cc
@@ -155,7 +155,9 @@
     WarpInit_NEON();
     WeightMaskInit_NEON();
 #if LIBGAV1_MAX_BITDEPTH >= 10
+    ConvolveInit10bpp_NEON();
     InverseTransformInit10bpp_NEON();
+    LoopRestorationInit10bpp_NEON();
 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
 #endif  // LIBGAV1_ENABLE_NEON
   });
diff --git a/libgav1/src/dsp/dsp.h b/libgav1/src/dsp/dsp.h
index 153db7f..f9e6b22 100644
--- a/libgav1/src/dsp/dsp.h
+++ b/libgav1/src/dsp/dsp.h
@@ -50,23 +50,23 @@
 };
 
 // List of valid 1D transforms.
-enum Transform1D : uint8_t {
-  k1DTransformDct,   // Discrete Cosine Transform.
-  k1DTransformAdst,  // Asymmetric Discrete Sine Transform.
-  k1DTransformIdentity,
-  k1DTransformWht,  // Walsh Hadamard Transform.
-  kNum1DTransforms
+enum Transform1d : uint8_t {
+  kTransform1dDct,   // Discrete Cosine Transform.
+  kTransform1dAdst,  // Asymmetric Discrete Sine Transform.
+  kTransform1dIdentity,
+  kTransform1dWht,  // Walsh Hadamard Transform.
+  kNumTransform1ds
 };
 
 // List of valid 1D transform sizes. Not all transforms may be available for all
 // the sizes.
-enum TransformSize1D : uint8_t {
-  k1DTransformSize4,
-  k1DTransformSize8,
-  k1DTransformSize16,
-  k1DTransformSize32,
-  k1DTransformSize64,
-  kNum1DTransformSizes
+enum Transform1dSize : uint8_t {
+  kTransform1dSize4,
+  kTransform1dSize8,
+  kTransform1dSize16,
+  kTransform1dSize32,
+  kTransform1dSize64,
+  kNumTransform1dSizes
 };
 
 // The maximum width of the loop filter, fewer pixels may be filtered depending
@@ -120,36 +120,36 @@
   abort();
 }
 
-inline const char* ToString(const Transform1D transform) {
+inline const char* ToString(const Transform1d transform) {
   switch (transform) {
-    case k1DTransformDct:
-      return "k1DTransformDct";
-    case k1DTransformAdst:
-      return "k1DTransformAdst";
-    case k1DTransformIdentity:
-      return "k1DTransformIdentity";
-    case k1DTransformWht:
-      return "k1DTransformWht";
-    case kNum1DTransforms:
-      return "kNum1DTransforms";
+    case kTransform1dDct:
+      return "kTransform1dDct";
+    case kTransform1dAdst:
+      return "kTransform1dAdst";
+    case kTransform1dIdentity:
+      return "kTransform1dIdentity";
+    case kTransform1dWht:
+      return "kTransform1dWht";
+    case kNumTransform1ds:
+      return "kNumTransform1ds";
   }
   abort();
 }
 
-inline const char* ToString(const TransformSize1D transform_size) {
+inline const char* ToString(const Transform1dSize transform_size) {
   switch (transform_size) {
-    case k1DTransformSize4:
-      return "k1DTransformSize4";
-    case k1DTransformSize8:
-      return "k1DTransformSize8";
-    case k1DTransformSize16:
-      return "k1DTransformSize16";
-    case k1DTransformSize32:
-      return "k1DTransformSize32";
-    case k1DTransformSize64:
-      return "k1DTransformSize64";
-    case kNum1DTransformSizes:
-      return "kNum1DTransformSizes";
+    case kTransform1dSize4:
+      return "kTransform1dSize4";
+    case kTransform1dSize8:
+      return "kTransform1dSize8";
+    case kTransform1dSize16:
+      return "kTransform1dSize16";
+    case kTransform1dSize32:
+      return "kTransform1dSize32";
+    case kTransform1dSize64:
+      return "kTransform1dSize64";
+    case kNumTransform1dSizes:
+      return "kNumTransform1dSizes";
   }
   abort();
 }
@@ -194,6 +194,7 @@
 // by bitdepth with |stride| given in bytes. |top| is an unaligned pointer to
 // the row above |dst|. |left| is an aligned vector of the column to the left
 // of |dst|. top-left and bottom-left may be accessed.
+// The pointer arguments do not alias one another.
 using IntraPredictorFunc = void (*)(void* dst, ptrdiff_t stride,
                                     const void* top, const void* left);
 using IntraPredictorFuncs =
@@ -209,6 +210,7 @@
 // |top| has been upsampled as described in '7.11.2.11. Intra edge upsample
 // process'. This can occur in cases with |width| + |height| <= 16. top-right
 // is accessed.
+// The pointer arguments do not alias one another.
 using DirectionalIntraPredictorZone1Func = void (*)(void* dst, ptrdiff_t stride,
                                                     const void* top, int width,
                                                     int height, int xstep,
@@ -226,6 +228,7 @@
 // described in '7.11.2.11. Intra edge upsample process'. This can occur in
 // cases with |width| + |height| <= 16. top-left and upper-left are accessed,
 // up to [-2] in each if |upsampled_top/left| are set.
+// The pointer arguments do not alias one another.
 using DirectionalIntraPredictorZone2Func = void (*)(
     void* dst, ptrdiff_t stride, const void* top, const void* left, int width,
     int height, int xstep, int ystep, bool upsampled_top, bool upsampled_left);
@@ -240,6 +243,7 @@
 // |left| has been upsampled as described in '7.11.2.11. Intra edge upsample
 // process'. This can occur in cases with |width| + |height| <= 16. bottom-left
 // is accessed.
+// The pointer arguments do not alias one another.
 using DirectionalIntraPredictorZone3Func = void (*)(void* dst, ptrdiff_t stride,
                                                     const void* left, int width,
                                                     int height, int ystep,
@@ -250,6 +254,7 @@
 // by bitdepth with |stride| given in bytes. |top| is an unaligned pointer to
 // the row above |dst|. |left| is an aligned vector of the column to the left
 // of |dst|. |width| and |height| are the size of the block in pixels.
+// The pointer arguments do not alias one another.
 using FilterIntraPredictorFunc = void (*)(void* dst, ptrdiff_t stride,
                                           const void* top, const void* left,
                                           FilterIntraPredictor pred, int width,
@@ -303,11 +308,14 @@
 // 7.13.3).
 // Apply the inverse transforms and add the residual to the destination frame
 // for the transform type and block size |tx_size| starting at position
-// |start_x| and |start_y|. |dst_frame| is a pointer to an Array2D.
-// |adjusted_tx_height| is the number of rows to process based on the non-zero
-// coefficient count in the block. It will be 1 (non-zero coefficient count ==
-// 1), 4 or a multiple of 8 up to 32 or the original transform height,
-// whichever is less.
+// |start_x| and |start_y|. |dst_frame| is a pointer to an Array2D of Pixel
+// values. |adjusted_tx_height| is the number of rows to process based on the
+// non-zero coefficient count in the block. It will be 1 (non-zero coefficient
+// count == 1), 4 or a multiple of 8 up to 32 or the original transform height,
+// whichever is less. |src_buffer| is a pointer to an Array2D of Residual
+// values. On input |src_buffer| contains the dequantized values, on output it
+// contains the residual.
+// The pointer arguments do not alias one another.
 using InverseTransformAddFunc = void (*)(TransformType tx_type,
                                          TransformSize tx_size,
                                          int adjusted_tx_height,
@@ -316,7 +324,7 @@
 // The final dimension holds row and column transforms indexed with kRow and
 // kColumn.
 using InverseTransformAddFuncs =
-    InverseTransformAddFunc[kNum1DTransforms][kNum1DTransformSizes][2];
+    InverseTransformAddFunc[kNumTransform1ds][kNumTransform1dSizes][2];
 
 //------------------------------------------------------------------------------
 // Post processing.
@@ -324,6 +332,13 @@
 // Loop filter function signature. Section 7.14.
 // |dst| is an unaligned pointer to the output block. Pixel size is determined
 // by bitdepth with |stride| given in bytes.
+// <threshold param> <spec name> <range>
+// |outer_thresh|    blimit      [7, 193]
+// |inner_thresh|    limit       [1, 63]
+// |hev_thresh|      thresh      [0, 63]
+// These are scaled by the implementation by 'bitdepth - 8' to produce
+// the spec variables blimitBd, limitBd and threshBd.
+// Note these functions are not called when the loop filter level is 0.
 using LoopFilterFunc = void (*)(void* dst, ptrdiff_t stride, int outer_thresh,
                                 int inner_thresh, int hev_thresh);
 using LoopFilterFuncs =
@@ -333,6 +348,7 @@
 // |src| is a pointer to the source block. Pixel size is determined by bitdepth
 // with |stride| given in bytes. |direction| and |variance| are output
 // parameters and must not be nullptr.
+// The pointer arguments do not alias one another.
 using CdefDirectionFunc = void (*)(const void* src, ptrdiff_t stride,
                                    uint8_t* direction, int* variance);
 
@@ -344,6 +360,7 @@
 // parameters.
 // |direction| is the filtering direction.
 // |dest| is the output buffer. |dest_stride| is given in bytes.
+// The pointer arguments do not alias one another.
 using CdefFilteringFunc = void (*)(const uint16_t* source,
                                    ptrdiff_t source_stride, int block_height,
                                    int primary_strength, int secondary_strength,
@@ -381,6 +398,7 @@
 // |step| is the number of subpixels to move the kernel for the next destination
 // pixel.
 // |initial_subpixel_x| is a base offset from which |step| increments.
+// The pointer arguments do not alias one another.
 using SuperResFunc = void (*)(const void* coefficients, void* source,
                               ptrdiff_t source_stride, int height,
                               int downscaled_width, int upscaled_width,
@@ -397,6 +415,7 @@
 // |top_border_stride| and |bottom_border_stride| are given in pixels.
 // |restoration_buffer| contains buffers required for self guided filter and
 // wiener filter. They must be initialized before calling.
+// The pointer arguments do not alias one another.
 using LoopRestorationFunc = void (*)(
     const RestorationUnitInfo& restoration_info, const void* source,
     ptrdiff_t stride, const void* top_border, ptrdiff_t top_border_stride,
@@ -425,6 +444,7 @@
 // used. For compound vertical filtering kInterRoundBitsCompoundVertical will be
 // used. Otherwise kInterRoundBitsVertical & kInterRoundBitsVertical12bpp will
 // be used.
+// The pointer arguments do not alias one another.
 using ConvolveFunc = void (*)(const void* reference, ptrdiff_t reference_stride,
                               int horizontal_filter_index,
                               int vertical_filter_index,
@@ -462,6 +482,7 @@
 // used. For compound vertical filtering kInterRoundBitsCompoundVertical will be
 // used. Otherwise kInterRoundBitsVertical & kInterRoundBitsVertical12bpp will
 // be used.
+// The pointer arguments do not alias one another.
 using ConvolveScaleFunc = void (*)(const void* reference,
                                    ptrdiff_t reference_stride,
                                    int horizontal_filter_index,
@@ -482,6 +503,7 @@
 // The stride for the input buffers is equal to |width|.
 // The valid range of block size is [8x8, 128x128] for the luma plane.
 // |mask| is the output buffer. |mask_stride| is the output buffer stride.
+// The pointer arguments do not alias one another.
 using WeightMaskFunc = void (*)(const void* prediction_0,
                                 const void* prediction_1, uint8_t* mask,
                                 ptrdiff_t mask_stride);
@@ -504,6 +526,7 @@
 // The stride for the input buffers is equal to |width|.
 // The valid range of block size is [8x8, 128x128] for the luma plane.
 // |dest| is the output buffer. |dest_stride| is the output buffer stride.
+// The pointer arguments do not alias one another.
 using AverageBlendFunc = void (*)(const void* prediction_0,
                                   const void* prediction_1, int width,
                                   int height, void* dest,
@@ -525,6 +548,7 @@
 // The stride for the input buffers is equal to |width|.
 // The valid range of block size is [8x8, 128x128] for the luma plane.
 // |dest| is the output buffer. |dest_stride| is the output buffer stride.
+// The pointer arguments do not alias one another.
 using DistanceWeightedBlendFunc = void (*)(const void* prediction_0,
                                            const void* prediction_1,
                                            uint8_t weight_0, uint8_t weight_1,
@@ -550,17 +574,18 @@
 // |mask_stride| is corresponding stride.
 // |width|, |height| are the same for both input blocks.
 // If it's inter_intra (or wedge_inter_intra), the valid range of block size is
-// [8x8, 32x32]. Otherwise (including difference weighted prediction and
-// compound average prediction), the valid range is [8x8, 128x128].
+// [8x8, 32x32], no 4:1/1:4 blocks (Section 5.11.28). Otherwise (including
+// difference weighted prediction and compound average prediction), the valid
+// range is [8x8, 128x128].
 // If there's subsampling, the corresponding width and height are halved for
 // chroma planes.
-// |subsampling_x|, |subsampling_y| are the subsampling factors.
 // |is_inter_intra| stands for the prediction mode. If it is true, one of the
 // prediction blocks is from intra prediction of current frame. Otherwise, two
 // prediction blocks are both inter frame predictions.
 // |is_wedge_inter_intra| indicates if the mask is for the wedge prediction.
 // |dest| is the output block.
 // |dest_stride| is the corresponding stride for dest.
+// The pointer arguments do not alias one another.
 using MaskBlendFunc = void (*)(const void* prediction_0,
                                const void* prediction_1,
                                ptrdiff_t prediction_stride_1,
@@ -577,6 +602,7 @@
 // |is_inter_intra| is true and |bitdepth| == 8.
 // |prediction_[01]| are Pixel values (uint8_t).
 // |prediction_1| is also the output buffer.
+// The pointer arguments do not alias one another.
 using InterIntraMaskBlendFunc8bpp = void (*)(const uint8_t* prediction_0,
                                              uint8_t* prediction_1,
                                              ptrdiff_t prediction_stride_1,
@@ -600,9 +626,12 @@
 // clipped. Therefore obmc blending process doesn't need to clip the output.
 // |prediction| is the first input block, which will be overwritten.
 // |prediction_stride| is the stride, given in bytes.
-// |width|, |height| are the same for both input blocks.
+// |width|, |height| are the same for both input blocks. The range is [4x2,
+// 32x32] for kObmcDirectionVertical and [2x4, 32x32] for
+// kObmcDirectionHorizontal, see Section 7.11.3.9.
 // |obmc_prediction| is the second input block.
 // |obmc_prediction_stride| is its stride, given in bytes.
+// The pointer arguments do not alias one another.
 using ObmcBlendFunc = void (*)(void* prediction, ptrdiff_t prediction_stride,
                                int width, int height,
                                const void* obmc_prediction,
@@ -645,6 +674,7 @@
 //   Therefore, there must be at least one extra padding byte after the right
 //   border of the last row in the source buffer.
 // * The top and bottom borders must be at least 13 pixels high.
+// The pointer arguments do not alias one another.
 using WarpFunc = void (*)(const void* source, ptrdiff_t source_stride,
                           int source_width, int source_height,
                           const int* warp_params, int subsampling_x,
@@ -686,6 +716,7 @@
 // from frame header, mainly providing auto_regression_coeff_u and
 // auto_regression_coeff_v for each chroma plane's filter, and
 // auto_regression_shift to right shift the filter sums by.
+// The pointer arguments do not alias one another.
 using ChromaAutoRegressionFunc = void (*)(const FilmGrainParams& params,
                                           const void* luma_grain_buffer,
                                           int subsampling_x, int subsampling_y,
@@ -704,6 +735,7 @@
 // Because this function treats all planes identically and independently, it is
 // simplified to take one grain buffer at a time. This means duplicating some
 // random number generations, but that work can be reduced in other ways.
+// The pointer arguments do not alias one another.
 using ConstructNoiseStripesFunc = void (*)(const void* grain_buffer,
                                            int grain_seed, int width,
                                            int height, int subsampling_x,
@@ -720,6 +752,7 @@
 // Array2D containing the allocated plane for this frame. Because this function
 // treats all planes identically and independently, it is simplified to take one
 // grain buffer at a time.
+// The pointer arguments do not alias one another.
 using ConstructNoiseImageOverlapFunc =
     void (*)(const void* noise_stripes_buffer, int width, int height,
              int subsampling_x, int subsampling_y, void* noise_image_buffer);
@@ -730,9 +763,12 @@
 // |num_points| can be between 0 and 15. When 0, the lookup table is set to
 // zero.
 // |point_value| and |point_scaling| have |num_points| valid elements.
-using InitializeScalingLutFunc = void (*)(
-    int num_points, const uint8_t point_value[], const uint8_t point_scaling[],
-    uint8_t scaling_lut[kScalingLookupTableSize]);
+// The pointer arguments do not alias one another.
+using InitializeScalingLutFunc = void (*)(int num_points,
+                                          const uint8_t point_value[],
+                                          const uint8_t point_scaling[],
+                                          int16_t* scaling_lut,
+                                          const int scaling_lut_length);
 
 // Blend noise with image. Section 7.18.3.5, third code block.
 // |width| is the width of each row, while |height| is how many rows to compute.
@@ -749,18 +785,19 @@
 // |scaling_shift| is applied as a right shift after scaling, so that scaling
 // down is possible. It is found in FilmGrainParams, but supplied directly to
 // BlendNoiseWithImageLumaFunc because it's the only member used.
-using BlendNoiseWithImageLumaFunc =
-    void (*)(const void* noise_image_ptr, int min_value, int max_value,
-             int scaling_shift, int width, int height, int start_height,
-             const uint8_t scaling_lut_y[kScalingLookupTableSize],
-             const void* source_plane_y, ptrdiff_t source_stride_y,
-             void* dest_plane_y, ptrdiff_t dest_stride_y);
+// The dest plane may point to the source plane, depending on the value of
+// frame_header.show_existing_frame. |noise_image_ptr| and scaling_lut.* do not
+// alias other arguments.
+using BlendNoiseWithImageLumaFunc = void (*)(
+    const void* noise_image_ptr, int min_value, int max_value,
+    int scaling_shift, int width, int height, int start_height,
+    const int16_t* scaling_lut_y, const void* source_plane_y,
+    ptrdiff_t source_stride_y, void* dest_plane_y, ptrdiff_t dest_stride_y);
 
 using BlendNoiseWithImageChromaFunc = void (*)(
     Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
     int min_value, int max_value, int width, int height, int start_height,
-    int subsampling_x, int subsampling_y,
-    const uint8_t scaling_lut[kScalingLookupTableSize],
+    int subsampling_x, int subsampling_y, const int16_t* scaling_lut,
     const void* source_plane_y, ptrdiff_t source_stride_y,
     const void* source_plane_uv, ptrdiff_t source_stride_uv,
     void* dest_plane_uv, ptrdiff_t dest_stride_uv);
@@ -790,6 +827,8 @@
 // tile.
 // |motion_field| is the output which saves the projected motion field
 // information.
+// Note: Only the entry from the 8-bit Dsp table is used as this function is
+// bitdepth agnostic.
 using MotionFieldProjectionKernelFunc = void (*)(
     const ReferenceInfo& reference_info, int reference_to_current_with_sign,
     int dst_sign, int y8_start, int y8_end, int x8_start, int x8_end,
@@ -797,13 +836,16 @@
 
 // Compound temporal motion vector projection function signature.
 // Section 7.9.3 and 7.10.2.10.
-// |temporal_mvs| is the set of temporal reference motion vectors.
+// |temporal_mvs| is the aligned set of temporal reference motion vectors.
 // |temporal_reference_offsets| specifies the number of frames covered by the
 // original motion vector.
 // |reference_offsets| specifies the number of frames to be covered by the
 // projected motion vector.
 // |count| is the number of the temporal motion vectors.
-// |candidate_mvs| is the set of projected motion vectors.
+// |candidate_mvs| is the aligned set of projected motion vectors.
+// The pointer arguments do not alias one another.
+// Note: Only the entry from the 8-bit Dsp table is used as this function is
+// bitdepth agnostic.
 using MvProjectionCompoundFunc = void (*)(
     const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
     const int reference_offsets[2], int count,
@@ -811,13 +853,16 @@
 
 // Single temporal motion vector projection function signature.
 // Section 7.9.3 and 7.10.2.10.
-// |temporal_mvs| is the set of temporal reference motion vectors.
+// |temporal_mvs| is the aligned set of temporal reference motion vectors.
 // |temporal_reference_offsets| specifies the number of frames covered by the
 // original motion vector.
 // |reference_offset| specifies the number of frames to be covered by the
 // projected motion vector.
 // |count| is the number of the temporal motion vectors.
-// |candidate_mvs| is the set of projected motion vectors.
+// |candidate_mvs| is the aligned set of projected motion vectors.
+// The pointer arguments do not alias one another.
+// Note: Only the entry from the 8-bit Dsp table is used as this function is
+// bitdepth agnostic.
 using MvProjectionSingleFunc = void (*)(
     const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
     int reference_offset, int count, MotionVector* candidate_mvs);
diff --git a/libgav1/src/dsp/film_grain.cc b/libgav1/src/dsp/film_grain.cc
index 41d1dd0..fa12b69 100644
--- a/libgav1/src/dsp/film_grain.cc
+++ b/libgav1/src/dsp/film_grain.cc
@@ -29,29 +29,26 @@
 #include "src/utils/common.h"
 #include "src/utils/compiler_attributes.h"
 #include "src/utils/logging.h"
+#include "src/utils/memory.h"
 
 namespace libgav1 {
 namespace dsp {
 namespace film_grain {
 namespace {
 
-// Making this a template function prevents it from adding to code size when it
-// is not placed in the DSP table. Most functions in the dsp directory change
-// behavior by bitdepth, but because this one doesn't, it receives a dummy
-// parameter with one enforced value, ensuring only one copy is made.
-template <int singleton>
-void InitializeScalingLookupTable_C(
-    int num_points, const uint8_t point_value[], const uint8_t point_scaling[],
-    uint8_t scaling_lut[kScalingLookupTableSize]) {
-  static_assert(singleton == 0,
-                "Improper instantiation of InitializeScalingLookupTable_C. "
-                "There should be only one copy of this function.");
+template <int bitdepth>
+void InitializeScalingLookupTable_C(int num_points, const uint8_t point_value[],
+                                    const uint8_t point_scaling[],
+                                    int16_t* scaling_lut,
+                                    const int scaling_lut_length) {
   if (num_points == 0) {
-    memset(scaling_lut, 0, sizeof(scaling_lut[0]) * kScalingLookupTableSize);
+    memset(scaling_lut, 0, sizeof(scaling_lut[0]) * scaling_lut_length);
     return;
   }
-  static_assert(sizeof(scaling_lut[0]) == 1, "");
-  memset(scaling_lut, point_scaling[0], point_value[0]);
+  constexpr int index_shift = bitdepth - kBitdepth8;
+  static_assert(sizeof(scaling_lut[0]) == 2, "");
+  Memset(scaling_lut, point_scaling[0],
+         std::max(static_cast<int>(point_value[0]), 1) << index_shift);
   for (int i = 0; i < num_points - 1; ++i) {
     const int delta_y = point_scaling[i + 1] - point_scaling[i];
     const int delta_x = point_value[i + 1] - point_value[i];
@@ -59,25 +56,38 @@
     for (int x = 0; x < delta_x; ++x) {
       const int v = point_scaling[i] + ((x * delta + 32768) >> 16);
       assert(v >= 0 && v <= UINT8_MAX);
-      scaling_lut[point_value[i] + x] = v;
+      const int lut_index = (point_value[i] + x) << index_shift;
+      scaling_lut[lut_index] = v;
     }
   }
-  const uint8_t last_point_value = point_value[num_points - 1];
-  memset(&scaling_lut[last_point_value], point_scaling[num_points - 1],
-         kScalingLookupTableSize - last_point_value);
+  const int16_t last_point_value = point_value[num_points - 1];
+  const int x_base = last_point_value << index_shift;
+  Memset(&scaling_lut[x_base], point_scaling[num_points - 1],
+         scaling_lut_length - x_base);
+  // Fill in the gaps.
+  if (bitdepth == kBitdepth10) {
+    for (int x = 4; x < x_base + 4; x += 4) {
+      const int start = scaling_lut[x - 4];
+      const int end = scaling_lut[x];
+      const int delta = end - start;
+      scaling_lut[x - 3] = start + RightShiftWithRounding(delta, 2);
+      scaling_lut[x - 2] = start + RightShiftWithRounding(2 * delta, 2);
+      scaling_lut[x - 1] = start + RightShiftWithRounding(3 * delta, 2);
+    }
+  }
 }
 
 // Section 7.18.3.5.
-// Performs a piecewise linear interpolation into the scaling table.
 template <int bitdepth>
-int ScaleLut(const uint8_t scaling_lut[kScalingLookupTableSize], int index) {
-  const int shift = bitdepth - 8;
+int ScaleLut(const int16_t* scaling_lut, int index) {
+  if (bitdepth <= kBitdepth10) {
+    assert(index < kScalingLookupTableSize << (bitdepth - 2));
+    return scaling_lut[index];
+  }
+  // Performs a piecewise linear interpolation into the scaling table.
+  const int shift = bitdepth - kBitdepth8;
   const int quotient = index >> shift;
   const int remainder = index - (quotient << shift);
-  if (bitdepth == 8) {
-    assert(quotient < kScalingLookupTableSize);
-    return scaling_lut[quotient];
-  }
   assert(quotient + 1 < kScalingLookupTableSize);
   const int start = scaling_lut[quotient];
   const int end = scaling_lut[quotient + 1];
@@ -153,12 +163,11 @@
 
 template <int bitdepth, typename GrainType, int auto_regression_coeff_lag,
           bool use_luma>
-void ApplyAutoRegressiveFilterToChromaGrains_C(const FilmGrainParams& params,
-                                               const void* luma_grain_buffer,
-                                               int subsampling_x,
-                                               int subsampling_y,
-                                               void* u_grain_buffer,
-                                               void* v_grain_buffer) {
+void ApplyAutoRegressiveFilterToChromaGrains_C(
+    const FilmGrainParams& params,
+    const void* LIBGAV1_RESTRICT luma_grain_buffer, int subsampling_x,
+    int subsampling_y, void* LIBGAV1_RESTRICT u_grain_buffer,
+    void* LIBGAV1_RESTRICT v_grain_buffer) {
   static_assert(
       auto_regression_coeff_lag >= 0 && auto_regression_coeff_lag <= 3,
       "Unsupported autoregression lag for chroma.");
@@ -227,9 +236,10 @@
 
 // This implementation is for the condition overlap_flag == false.
 template <int bitdepth, typename GrainType>
-void ConstructNoiseStripes_C(const void* grain_buffer, int grain_seed,
-                             int width, int height, int subsampling_x,
-                             int subsampling_y, void* noise_stripes_buffer) {
+void ConstructNoiseStripes_C(const void* LIBGAV1_RESTRICT grain_buffer,
+                             int grain_seed, int width, int height,
+                             int subsampling_x, int subsampling_y,
+                             void* LIBGAV1_RESTRICT noise_stripes_buffer) {
   auto* noise_stripes =
       static_cast<Array2DView<GrainType>*>(noise_stripes_buffer);
   const auto* grain = static_cast<const GrainType*>(grain_buffer);
@@ -272,8 +282,6 @@
         // Writes beyond the width of each row could happen below. To
         // prevent those writes, we clip the number of pixels to copy against
         // the remaining width.
-        // TODO(petersonab): Allocate aligned stripes with extra width to cover
-        // the size of the final stripe block, then remove this call to min.
         const int copy_size =
             std::min(kNoiseStripeHeight >> subsampling_x,
                      plane_width - (x << (1 - subsampling_x)));
@@ -291,10 +299,10 @@
 
 // This implementation is for the condition overlap_flag == true.
 template <int bitdepth, typename GrainType>
-void ConstructNoiseStripesWithOverlap_C(const void* grain_buffer,
-                                        int grain_seed, int width, int height,
-                                        int subsampling_x, int subsampling_y,
-                                        void* noise_stripes_buffer) {
+void ConstructNoiseStripesWithOverlap_C(
+    const void* LIBGAV1_RESTRICT grain_buffer, int grain_seed, int width,
+    int height, int subsampling_x, int subsampling_y,
+    void* LIBGAV1_RESTRICT noise_stripes_buffer) {
   auto* noise_stripes =
       static_cast<Array2DView<GrainType>*>(noise_stripes_buffer);
   const auto* grain = static_cast<const GrainType*>(grain_buffer);
@@ -326,8 +334,6 @@
     // The overlap computation only occurs when x > 0, so it is omitted here.
     int i = 0;
     do {
-      // TODO(petersonab): Allocate aligned stripes with extra width to cover
-      // the size of the final stripe block, then remove this call to min.
       const int copy_size =
           std::min(kNoiseStripeHeight >> subsampling_x, plane_width);
       memcpy(&noise_stripe[i * plane_width],
@@ -399,8 +405,6 @@
         // Writes beyond the width of each row could happen below. To
         // prevent those writes, we clip the number of pixels to copy against
         // the remaining width.
-        // TODO(petersonab): Allocate aligned stripes with extra width to cover
-        // the size of the final stripe block, then remove this call to min.
         const int copy_size =
             std::min(kNoiseStripeHeight >> subsampling_x,
                      plane_width - (x << (1 - subsampling_x))) -
@@ -417,10 +421,11 @@
 }
 
 template <int bitdepth, typename GrainType>
-inline void WriteOverlapLine_C(const GrainType* noise_stripe_row,
-                               const GrainType* noise_stripe_row_prev,
-                               int plane_width, int grain_coeff, int old_coeff,
-                               GrainType* noise_image_row) {
+inline void WriteOverlapLine_C(
+    const GrainType* LIBGAV1_RESTRICT noise_stripe_row,
+    const GrainType* LIBGAV1_RESTRICT noise_stripe_row_prev, int plane_width,
+    int grain_coeff, int old_coeff,
+    GrainType* LIBGAV1_RESTRICT noise_image_row) {
   int x = 0;
   do {
     int grain = noise_stripe_row[x];
@@ -433,9 +438,10 @@
 }
 
 template <int bitdepth, typename GrainType>
-void ConstructNoiseImageOverlap_C(const void* noise_stripes_buffer, int width,
-                                  int height, int subsampling_x,
-                                  int subsampling_y, void* noise_image_buffer) {
+void ConstructNoiseImageOverlap_C(
+    const void* LIBGAV1_RESTRICT noise_stripes_buffer, int width, int height,
+    int subsampling_x, int subsampling_y,
+    void* LIBGAV1_RESTRICT noise_image_buffer) {
   const auto* noise_stripes =
       static_cast<const Array2DView<GrainType>*>(noise_stripes_buffer);
   auto* noise_image = static_cast<Array2D<GrainType>*>(noise_image_buffer);
@@ -495,12 +501,13 @@
 }
 
 template <int bitdepth, typename GrainType, typename Pixel>
-void BlendNoiseWithImageLuma_C(
-    const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift,
-    int width, int height, int start_height,
-    const uint8_t scaling_lut_y[kScalingLookupTableSize],
-    const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y,
-    ptrdiff_t dest_stride_y) {
+void BlendNoiseWithImageLuma_C(const void* LIBGAV1_RESTRICT noise_image_ptr,
+                               int min_value, int max_luma, int scaling_shift,
+                               int width, int height, int start_height,
+                               const int16_t* scaling_lut_y,
+                               const void* source_plane_y,
+                               ptrdiff_t source_stride_y, void* dest_plane_y,
+                               ptrdiff_t dest_stride_y) {
   const auto* noise_image =
       static_cast<const Array2D<GrainType>*>(noise_image_ptr);
   const auto* in_y = static_cast<const Pixel*>(source_plane_y);
@@ -524,10 +531,10 @@
 // This function is for the case params_.chroma_scaling_from_luma == false.
 template <int bitdepth, typename GrainType, typename Pixel>
 void BlendNoiseWithImageChroma_C(
-    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
-    int min_value, int max_chroma, int width, int height, int start_height,
-    int subsampling_x, int subsampling_y,
-    const uint8_t scaling_lut_uv[kScalingLookupTableSize],
+    Plane plane, const FilmGrainParams& params,
+    const void* LIBGAV1_RESTRICT noise_image_ptr, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, const int16_t* scaling_lut_uv,
     const void* source_plane_y, ptrdiff_t source_stride_y,
     const void* source_plane_uv, ptrdiff_t source_stride_uv,
     void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
@@ -571,7 +578,7 @@
       const int orig = in_uv[y * source_stride_uv + x];
       const int combined = average_luma * luma_multiplier + orig * multiplier;
       const int merged =
-          Clip3((combined >> 6) + LeftShift(offset, bitdepth - 8), 0,
+          Clip3((combined >> 6) + LeftShift(offset, bitdepth - kBitdepth8), 0,
                 (1 << bitdepth) - 1);
       int noise = noise_image[plane][y + start_height][x];
       noise = RightShiftWithRounding(
@@ -586,13 +593,12 @@
 // This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y.
 template <int bitdepth, typename GrainType, typename Pixel>
 void BlendNoiseWithImageChromaWithCfl_C(
-    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
-    int min_value, int max_chroma, int width, int height, int start_height,
-    int subsampling_x, int subsampling_y,
-    const uint8_t scaling_lut[kScalingLookupTableSize],
-    const void* source_plane_y, ptrdiff_t source_stride_y,
-    const void* source_plane_uv, ptrdiff_t source_stride_uv,
-    void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
+    Plane plane, const FilmGrainParams& params,
+    const void* LIBGAV1_RESTRICT noise_image_ptr, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, const int16_t* scaling_lut, const void* source_plane_y,
+    ptrdiff_t source_stride_y, const void* source_plane_uv,
+    ptrdiff_t source_stride_uv, void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
   const auto* noise_image =
       static_cast<const Array2D<GrainType>*>(noise_image_ptr);
   const auto* in_y = static_cast<const Pixel*>(source_plane_y);
@@ -639,106 +645,108 @@
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   // LumaAutoRegressionFunc
   dsp->film_grain.luma_auto_regression[0] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth8, int8_t>;
   dsp->film_grain.luma_auto_regression[1] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth8, int8_t>;
   dsp->film_grain.luma_auto_regression[2] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth8, int8_t>;
 
   // ChromaAutoRegressionFunc
   // Chroma autoregression should never be called when lag is 0 and use_luma is
   // false.
   dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
   dsp->film_grain.chroma_auto_regression[0][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 1, false>;
   dsp->film_grain.chroma_auto_regression[0][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 2, false>;
   dsp->film_grain.chroma_auto_regression[0][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 3, false>;
   dsp->film_grain.chroma_auto_regression[1][0] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 0, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 0, true>;
   dsp->film_grain.chroma_auto_regression[1][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 1, true>;
   dsp->film_grain.chroma_auto_regression[1][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 2, true>;
   dsp->film_grain.chroma_auto_regression[1][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 3, true>;
 
   // ConstructNoiseStripesFunc
   dsp->film_grain.construct_noise_stripes[0] =
-      ConstructNoiseStripes_C<8, int8_t>;
+      ConstructNoiseStripes_C<kBitdepth8, int8_t>;
   dsp->film_grain.construct_noise_stripes[1] =
-      ConstructNoiseStripesWithOverlap_C<8, int8_t>;
+      ConstructNoiseStripesWithOverlap_C<kBitdepth8, int8_t>;
 
   // ConstructNoiseImageOverlapFunc
   dsp->film_grain.construct_noise_image_overlap =
-      ConstructNoiseImageOverlap_C<8, int8_t>;
+      ConstructNoiseImageOverlap_C<kBitdepth8, int8_t>;
 
   // InitializeScalingLutFunc
-  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>;
+  dsp->film_grain.initialize_scaling_lut =
+      InitializeScalingLookupTable_C<kBitdepth8>;
 
   // BlendNoiseWithImageLumaFunc
   dsp->film_grain.blend_noise_luma =
-      BlendNoiseWithImageLuma_C<8, int8_t, uint8_t>;
+      BlendNoiseWithImageLuma_C<kBitdepth8, int8_t, uint8_t>;
 
   // BlendNoiseWithImageChromaFunc
   dsp->film_grain.blend_noise_chroma[0] =
-      BlendNoiseWithImageChroma_C<8, int8_t, uint8_t>;
+      BlendNoiseWithImageChroma_C<kBitdepth8, int8_t, uint8_t>;
   dsp->film_grain.blend_noise_chroma[1] =
-      BlendNoiseWithImageChromaWithCfl_C<8, int8_t, uint8_t>;
+      BlendNoiseWithImageChromaWithCfl_C<kBitdepth8, int8_t, uint8_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp8bpp_FilmGrainAutoregressionLuma
   dsp->film_grain.luma_auto_regression[0] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth8, int8_t>;
   dsp->film_grain.luma_auto_regression[1] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth8, int8_t>;
   dsp->film_grain.luma_auto_regression[2] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<8, int8_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth8, int8_t>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_FilmGrainAutoregressionChroma
   // Chroma autoregression should never be called when lag is 0 and use_luma is
   // false.
   dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
   dsp->film_grain.chroma_auto_regression[0][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 1, false>;
   dsp->film_grain.chroma_auto_regression[0][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 2, false>;
   dsp->film_grain.chroma_auto_regression[0][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 3, false>;
   dsp->film_grain.chroma_auto_regression[1][0] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 0, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 0, true>;
   dsp->film_grain.chroma_auto_regression[1][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 1, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 1, true>;
   dsp->film_grain.chroma_auto_regression[1][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 2, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 2, true>;
   dsp->film_grain.chroma_auto_regression[1][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<8, int8_t, 3, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth8, int8_t, 3, true>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseStripes
   dsp->film_grain.construct_noise_stripes[0] =
-      ConstructNoiseStripes_C<8, int8_t>;
+      ConstructNoiseStripes_C<kBitdepth8, int8_t>;
   dsp->film_grain.construct_noise_stripes[1] =
-      ConstructNoiseStripesWithOverlap_C<8, int8_t>;
+      ConstructNoiseStripesWithOverlap_C<kBitdepth8, int8_t>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_FilmGrainConstructNoiseImageOverlap
   dsp->film_grain.construct_noise_image_overlap =
-      ConstructNoiseImageOverlap_C<8, int8_t>;
+      ConstructNoiseImageOverlap_C<kBitdepth8, int8_t>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_FilmGrainInitializeScalingLutFunc
-  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>;
+  dsp->film_grain.initialize_scaling_lut =
+      InitializeScalingLookupTable_C<kBitdepth8>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseLuma
   dsp->film_grain.blend_noise_luma =
-      BlendNoiseWithImageLuma_C<8, int8_t, uint8_t>;
+      BlendNoiseWithImageLuma_C<kBitdepth8, int8_t, uint8_t>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChroma
   dsp->film_grain.blend_noise_chroma[0] =
-      BlendNoiseWithImageChroma_C<8, int8_t, uint8_t>;
+      BlendNoiseWithImageChroma_C<kBitdepth8, int8_t, uint8_t>;
 #endif
 #ifndef LIBGAV1_Dsp8bpp_FilmGrainBlendNoiseChromaWithCfl
   dsp->film_grain.blend_noise_chroma[1] =
-      BlendNoiseWithImageChromaWithCfl_C<8, int8_t, uint8_t>;
+      BlendNoiseWithImageChromaWithCfl_C<kBitdepth8, int8_t, uint8_t>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
@@ -751,106 +759,108 @@
 
   // LumaAutoRegressionFunc
   dsp->film_grain.luma_auto_regression[0] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth10, int16_t>;
   dsp->film_grain.luma_auto_regression[1] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth10, int16_t>;
   dsp->film_grain.luma_auto_regression[2] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth10, int16_t>;
 
   // ChromaAutoRegressionFunc
   // Chroma autoregression should never be called when lag is 0 and use_luma is
   // false.
   dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
   dsp->film_grain.chroma_auto_regression[0][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 1, false>;
   dsp->film_grain.chroma_auto_regression[0][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 2, false>;
   dsp->film_grain.chroma_auto_regression[0][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 3, false>;
   dsp->film_grain.chroma_auto_regression[1][0] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 0, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 0, true>;
   dsp->film_grain.chroma_auto_regression[1][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 1, true>;
   dsp->film_grain.chroma_auto_regression[1][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 2, true>;
   dsp->film_grain.chroma_auto_regression[1][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 3, true>;
 
   // ConstructNoiseStripesFunc
   dsp->film_grain.construct_noise_stripes[0] =
-      ConstructNoiseStripes_C<10, int16_t>;
+      ConstructNoiseStripes_C<kBitdepth10, int16_t>;
   dsp->film_grain.construct_noise_stripes[1] =
-      ConstructNoiseStripesWithOverlap_C<10, int16_t>;
+      ConstructNoiseStripesWithOverlap_C<kBitdepth10, int16_t>;
 
   // ConstructNoiseImageOverlapFunc
   dsp->film_grain.construct_noise_image_overlap =
-      ConstructNoiseImageOverlap_C<10, int16_t>;
+      ConstructNoiseImageOverlap_C<kBitdepth10, int16_t>;
 
   // InitializeScalingLutFunc
-  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>;
+  dsp->film_grain.initialize_scaling_lut =
+      InitializeScalingLookupTable_C<kBitdepth10>;
 
   // BlendNoiseWithImageLumaFunc
   dsp->film_grain.blend_noise_luma =
-      BlendNoiseWithImageLuma_C<10, int16_t, uint16_t>;
+      BlendNoiseWithImageLuma_C<kBitdepth10, int16_t, uint16_t>;
 
   // BlendNoiseWithImageChromaFunc
   dsp->film_grain.blend_noise_chroma[0] =
-      BlendNoiseWithImageChroma_C<10, int16_t, uint16_t>;
+      BlendNoiseWithImageChroma_C<kBitdepth10, int16_t, uint16_t>;
   dsp->film_grain.blend_noise_chroma[1] =
-      BlendNoiseWithImageChromaWithCfl_C<10, int16_t, uint16_t>;
+      BlendNoiseWithImageChromaWithCfl_C<kBitdepth10, int16_t, uint16_t>;
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   static_cast<void>(dsp);
 #ifndef LIBGAV1_Dsp10bpp_FilmGrainAutoregressionLuma
   dsp->film_grain.luma_auto_regression[0] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth10, int16_t>;
   dsp->film_grain.luma_auto_regression[1] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth10, int16_t>;
   dsp->film_grain.luma_auto_regression[2] =
-      ApplyAutoRegressiveFilterToLumaGrain_C<10, int16_t>;
+      ApplyAutoRegressiveFilterToLumaGrain_C<kBitdepth10, int16_t>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_FilmGrainAutoregressionChroma
   // Chroma autoregression should never be called when lag is 0 and use_luma is
   // false.
   dsp->film_grain.chroma_auto_regression[0][0] = nullptr;
   dsp->film_grain.chroma_auto_regression[0][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 1, false>;
   dsp->film_grain.chroma_auto_regression[0][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 2, false>;
   dsp->film_grain.chroma_auto_regression[0][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, false>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 3, false>;
   dsp->film_grain.chroma_auto_regression[1][0] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 0, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 0, true>;
   dsp->film_grain.chroma_auto_regression[1][1] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 1, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 1, true>;
   dsp->film_grain.chroma_auto_regression[1][2] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 2, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 2, true>;
   dsp->film_grain.chroma_auto_regression[1][3] =
-      ApplyAutoRegressiveFilterToChromaGrains_C<10, int16_t, 3, true>;
+      ApplyAutoRegressiveFilterToChromaGrains_C<kBitdepth10, int16_t, 3, true>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_FilmGrainConstructNoiseStripes
   dsp->film_grain.construct_noise_stripes[0] =
-      ConstructNoiseStripes_C<10, int16_t>;
+      ConstructNoiseStripes_C<kBitdepth10, int16_t>;
   dsp->film_grain.construct_noise_stripes[1] =
-      ConstructNoiseStripesWithOverlap_C<10, int16_t>;
+      ConstructNoiseStripesWithOverlap_C<kBitdepth10, int16_t>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_FilmGrainConstructNoiseImageOverlap
   dsp->film_grain.construct_noise_image_overlap =
-      ConstructNoiseImageOverlap_C<10, int16_t>;
+      ConstructNoiseImageOverlap_C<kBitdepth10, int16_t>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_FilmGrainInitializeScalingLutFunc
-  dsp->film_grain.initialize_scaling_lut = InitializeScalingLookupTable_C<0>;
+  dsp->film_grain.initialize_scaling_lut =
+      InitializeScalingLookupTable_C<kBitdepth10>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseLuma
   dsp->film_grain.blend_noise_luma =
-      BlendNoiseWithImageLuma_C<10, int16_t, uint16_t>;
+      BlendNoiseWithImageLuma_C<kBitdepth10, int16_t, uint16_t>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChroma
   dsp->film_grain.blend_noise_chroma[0] =
-      BlendNoiseWithImageChroma_C<10, int16_t, uint16_t>;
+      BlendNoiseWithImageChroma_C<kBitdepth10, int16_t, uint16_t>;
 #endif
 #ifndef LIBGAV1_Dsp10bpp_FilmGrainBlendNoiseChromaWithCfl
   dsp->film_grain.blend_noise_chroma[1] =
-      BlendNoiseWithImageChromaWithCfl_C<10, int16_t, uint16_t>;
+      BlendNoiseWithImageChromaWithCfl_C<kBitdepth10, int16_t, uint16_t>;
 #endif
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
 }
diff --git a/libgav1/src/dsp/film_grain_common.h b/libgav1/src/dsp/film_grain_common.h
index 64e3e8e..2e6ad45 100644
--- a/libgav1/src/dsp/film_grain_common.h
+++ b/libgav1/src/dsp/film_grain_common.h
@@ -59,15 +59,16 @@
   // The two possible heights of the chroma noise array.
   kMinChromaHeight = 38,
   kMaxChromaHeight = 73,
-  // The scaling lookup table maps bytes to bytes, so only uses 256 elements,
-  // plus one for overflow in 10bit lookups.
+  // The standard scaling lookup table maps bytes to bytes, so only uses 256
+  // elements, plus one for overflow in 12bpp lookups. The size is scaled up for
+  // 10bpp.
   kScalingLookupTableSize = 257,
   // Padding is added to the scaling lookup table to permit overwrites by
   // InitializeScalingLookupTable_NEON.
   kScalingLookupTablePadding = 6,
   // Padding is added to each row of the noise image to permit overreads by
   // BlendNoiseWithImageLuma_NEON and overwrites by WriteOverlapLine8bpp_NEON.
-  kNoiseImagePadding = 7,
+  kNoiseImagePadding = 15,
   // Padding is added to the end of the |noise_stripes_| buffer to permit
   // overreads by WriteOverlapLine8bpp_NEON.
   kNoiseStripePadding = 7,
diff --git a/libgav1/src/dsp/intrapred.cc b/libgav1/src/dsp/intrapred.cc
index 4520c2c..75af279 100644
--- a/libgav1/src/dsp/intrapred.cc
+++ b/libgav1/src/dsp/intrapred.cc
@@ -63,8 +63,8 @@
 
 template <int block_width, int block_height, typename Pixel>
 void IntraPredFuncs_C<block_width, block_height, Pixel>::DcTop(
-    void* const dest, ptrdiff_t stride, const void* const top_row,
-    const void* /*left_column*/) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row, const void* /*left_column*/) {
   int sum = block_width >> 1;  // rounder
   const auto* const top = static_cast<const Pixel*>(top_row);
   for (int x = 0; x < block_width; ++x) sum += top[x];
@@ -80,8 +80,8 @@
 
 template <int block_width, int block_height, typename Pixel>
 void IntraPredFuncs_C<block_width, block_height, Pixel>::DcLeft(
-    void* const dest, ptrdiff_t stride, const void* /*top_row*/,
-    const void* const left_column) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* /*top_row*/, const void* LIBGAV1_RESTRICT const left_column) {
   int sum = block_height >> 1;  // rounder
   const auto* const left = static_cast<const Pixel*>(left_column);
   for (int y = 0; y < block_height; ++y) sum += left[y];
@@ -132,8 +132,9 @@
 
 template <int block_width, int block_height, typename Pixel>
 void IntraPredFuncs_C<block_width, block_height, Pixel>::Dc(
-    void* const dest, ptrdiff_t stride, const void* const top_row,
-    const void* const left_column) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const int divisor = block_width + block_height;
   int sum = divisor >> 1;  // rounder
 
@@ -158,8 +159,8 @@
 // IntraPredFuncs_C::Vertical -- apply top row vertically
 template <int block_width, int block_height, typename Pixel>
 void IntraPredFuncs_C<block_width, block_height, Pixel>::Vertical(
-    void* const dest, ptrdiff_t stride, const void* const top_row,
-    const void* /*left_column*/) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row, const void* /*left_column*/) {
   auto* dst = static_cast<uint8_t*>(dest);
   for (int y = 0; y < block_height; ++y) {
     memcpy(dst, top_row, block_width * sizeof(Pixel));
@@ -170,8 +171,8 @@
 // IntraPredFuncs_C::Horizontal -- apply left column horizontally
 template <int block_width, int block_height, typename Pixel>
 void IntraPredFuncs_C<block_width, block_height, Pixel>::Horizontal(
-    void* const dest, ptrdiff_t stride, const void* /*top_row*/,
-    const void* const left_column) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* /*top_row*/, const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left = static_cast<const Pixel*>(left_column);
   auto* dst = static_cast<Pixel*>(dest);
   stride /= sizeof(Pixel);
@@ -184,8 +185,9 @@
 // IntraPredFuncs_C::Paeth
 template <int block_width, int block_height, typename Pixel>
 void IntraPredFuncs_C<block_width, block_height, Pixel>::Paeth(
-    void* const dest, ptrdiff_t stride, const void* const top_row,
-    const void* const left_column) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const Pixel*>(top_row);
   const auto* const left = static_cast<const Pixel*>(left_column);
   const Pixel top_left = top[-1];
diff --git a/libgav1/src/dsp/intrapred_cfl.cc b/libgav1/src/dsp/intrapred_cfl.cc
index 948c0c0..0f7f4f2 100644
--- a/libgav1/src/dsp/intrapred_cfl.cc
+++ b/libgav1/src/dsp/intrapred_cfl.cc
@@ -41,7 +41,7 @@
 // |alpha| can be -16 to 16 (inclusive).
 template <int block_width, int block_height, int bitdepth, typename Pixel>
 void CflIntraPredictor_C(
-    void* const dest, ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<Pixel*>(dest);
@@ -66,7 +66,8 @@
           int subsampling_x, int subsampling_y>
 void CflSubsampler_C(int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
                      const int max_luma_width, const int max_luma_height,
-                     const void* const source, ptrdiff_t stride) {
+                     const void* LIBGAV1_RESTRICT const source,
+                     ptrdiff_t stride) {
   assert(max_luma_width >= 4);
   assert(max_luma_height >= 4);
   const auto* src = static_cast<const Pixel*>(source);
diff --git a/libgav1/src/dsp/intrapred_directional.cc b/libgav1/src/dsp/intrapred_directional.cc
index e670769..21a40b5 100644
--- a/libgav1/src/dsp/intrapred_directional.cc
+++ b/libgav1/src/dsp/intrapred_directional.cc
@@ -33,11 +33,10 @@
 // 7.11.2.4. Directional intra prediction process
 
 template <typename Pixel>
-void DirectionalIntraPredictorZone1_C(void* const dest, ptrdiff_t stride,
-                                      const void* const top_row,
-                                      const int width, const int height,
-                                      const int xstep,
-                                      const bool upsampled_top) {
+void DirectionalIntraPredictorZone1_C(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row, const int width,
+    const int height, const int xstep, const bool upsampled_top) {
   const auto* const top = static_cast<const Pixel*>(top_row);
   auto* dst = static_cast<Pixel*>(dest);
   stride /= sizeof(Pixel);
@@ -96,13 +95,12 @@
 }
 
 template <typename Pixel>
-void DirectionalIntraPredictorZone2_C(void* const dest, ptrdiff_t stride,
-                                      const void* const top_row,
-                                      const void* const left_column,
-                                      const int width, const int height,
-                                      const int xstep, const int ystep,
-                                      const bool upsampled_top,
-                                      const bool upsampled_left) {
+void DirectionalIntraPredictorZone2_C(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column, const int width,
+    const int height, const int xstep, const int ystep,
+    const bool upsampled_top, const bool upsampled_left) {
   const auto* const top = static_cast<const Pixel*>(top_row);
   const auto* const left = static_cast<const Pixel*>(left_column);
   auto* dst = static_cast<Pixel*>(dest);
@@ -146,11 +144,10 @@
 }
 
 template <typename Pixel>
-void DirectionalIntraPredictorZone3_C(void* const dest, ptrdiff_t stride,
-                                      const void* const left_column,
-                                      const int width, const int height,
-                                      const int ystep,
-                                      const bool upsampled_left) {
+void DirectionalIntraPredictorZone3_C(
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const left_column, const int width,
+    const int height, const int ystep, const bool upsampled_left) {
   const auto* const left = static_cast<const Pixel*>(left_column);
   stride /= sizeof(Pixel);
 
diff --git a/libgav1/src/dsp/intrapred_filter.cc b/libgav1/src/dsp/intrapred_filter.cc
index f4bd296..9a45eff 100644
--- a/libgav1/src/dsp/intrapred_filter.cc
+++ b/libgav1/src/dsp/intrapred_filter.cc
@@ -40,9 +40,9 @@
 // adjacent to the |top_row| or |left_column|. The set of 8 filters is selected
 // according to |pred|.
 template <int bitdepth, typename Pixel>
-void FilterIntraPredictor_C(void* const dest, ptrdiff_t stride,
-                            const void* const top_row,
-                            const void* const left_column,
+void FilterIntraPredictor_C(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                            const void* LIBGAV1_RESTRICT const top_row,
+                            const void* LIBGAV1_RESTRICT const left_column,
                             const FilterIntraPredictor pred, const int width,
                             const int height) {
   const int kMaxPixel = (1 << bitdepth) - 1;
diff --git a/libgav1/src/dsp/intrapred_smooth.cc b/libgav1/src/dsp/intrapred_smooth.cc
index 83c005e..0c7f272 100644
--- a/libgav1/src/dsp/intrapred_smooth.cc
+++ b/libgav1/src/dsp/intrapred_smooth.cc
@@ -42,26 +42,15 @@
 };
 
 constexpr uint8_t kSmoothWeights[] = {
-    // block dimension = 4
-    255, 149, 85, 64,
-    // block dimension = 8
-    255, 197, 146, 105, 73, 50, 37, 32,
-    // block dimension = 16
-    255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16,
-    // block dimension = 32
-    255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74,
-    66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8,
-    // block dimension = 64
-    255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156,
-    150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73,
-    69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16,
-    15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4};
+#include "src/dsp/smooth_weights.inc"
+};
 
 // SmoothFuncs_C::Smooth
 template <int block_width, int block_height, typename Pixel>
 void SmoothFuncs_C<block_width, block_height, Pixel>::Smooth(
-    void* const dest, ptrdiff_t stride, const void* const top_row,
-    const void* const left_column) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const Pixel*>(top_row);
   const auto* const left = static_cast<const Pixel*>(left_column);
   const Pixel top_right = top[block_width - 1];
@@ -94,8 +83,9 @@
 // SmoothFuncs_C::SmoothVertical
 template <int block_width, int block_height, typename Pixel>
 void SmoothFuncs_C<block_width, block_height, Pixel>::SmoothVertical(
-    void* const dest, ptrdiff_t stride, const void* const top_row,
-    const void* const left_column) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const Pixel*>(top_row);
   const auto* const left = static_cast<const Pixel*>(left_column);
   const Pixel bottom_left = left[block_height - 1];
@@ -121,8 +111,9 @@
 // SmoothFuncs_C::SmoothHorizontal
 template <int block_width, int block_height, typename Pixel>
 void SmoothFuncs_C<block_width, block_height, Pixel>::SmoothHorizontal(
-    void* const dest, ptrdiff_t stride, const void* const top_row,
-    const void* const left_column) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const Pixel*>(top_row);
   const auto* const left = static_cast<const Pixel*>(left_column);
   const Pixel top_right = top[block_width - 1];
diff --git a/libgav1/src/dsp/inverse_transform.cc b/libgav1/src/dsp/inverse_transform.cc
index ed984d8..1b0064f 100644
--- a/libgav1/src/dsp/inverse_transform.cc
+++ b/libgav1/src/dsp/inverse_transform.cc
@@ -42,8 +42,8 @@
 #if defined(LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK) && \
     LIBGAV1_ENABLE_TRANSFORM_RANGE_CHECK
   assert(range <= 32);
-  const int32_t min = -(1 << (range - 1));
-  const int32_t max = (1 << (range - 1)) - 1;
+  const auto min = static_cast<int32_t>(-(uint32_t{1} << (range - 1)));
+  const auto max = static_cast<int32_t>((uint32_t{1} << (range - 1)) - 1);
   if (min > value || value > max) {
     LIBGAV1_DLOG(ERROR, "coeff out of bit range, value: %d bit range %d\n",
                  value, range);
@@ -140,7 +140,7 @@
 // For e.g. index (2, 3) will be computed as follows:
 //   * bitreverse(3) = bitreverse(..000011) = 110000...
 //   * interpreting that as an integer with bit-length 2+2 = 4 will be 1100 = 12
-constexpr uint8_t kBitReverseLookup[kNum1DTransformSizes][64] = {
+constexpr uint8_t kBitReverseLookup[kNumTransform1dSizes][64] = {
     {0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2,
      1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3,
      0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3, 0, 2, 1, 3},
@@ -532,8 +532,8 @@
 }
 
 template <typename Residual>
-void AdstInputPermutation(int32_t* const dst, const Residual* const src,
-                          int n) {
+void AdstInputPermutation(int32_t* LIBGAV1_RESTRICT const dst,
+                          const Residual* LIBGAV1_RESTRICT const src, int n) {
   assert(n == 8 || n == 16);
   for (int i = 0; i < n; ++i) {
     dst[i] = src[((i & 1) == 0) ? n - i - 1 : i - 1];
@@ -544,8 +544,8 @@
     0, 8, 12, 4, 6, 14, 10, 2, 3, 11, 15, 7, 5, 13, 9, 1};
 
 template <typename Residual>
-void AdstOutputPermutation(Residual* const dst, const int32_t* const src,
-                           int n) {
+void AdstOutputPermutation(Residual* LIBGAV1_RESTRICT const dst,
+                           const int32_t* LIBGAV1_RESTRICT const src, int n) {
   assert(n == 8 || n == 16);
   const auto shift = static_cast<int8_t>(n == 8);
   for (int i = 0; i < n; ++i) {
@@ -1096,20 +1096,21 @@
 //------------------------------------------------------------------------------
 // row/column transform loop
 
-using InverseTransform1DFunc = void (*)(void* dst, int8_t range);
+using InverseTransform1dFunc = void (*)(void* dst, int8_t range);
 using InverseTransformDcOnlyFunc = void (*)(void* dest, int8_t range,
                                             bool should_round, int row_shift,
                                             bool is_row);
 
 template <int bitdepth, typename Residual, typename Pixel,
-          Transform1D transform1d_type,
+          Transform1d transform1d_type,
           InverseTransformDcOnlyFunc dconly_transform1d,
-          InverseTransform1DFunc transform1d_func, bool is_row>
+          InverseTransform1dFunc transform1d_func, bool is_row>
 void TransformLoop_C(TransformType tx_type, TransformSize tx_size,
-                     int adjusted_tx_height, void* src_buffer, int start_x,
-                     int start_y, void* dst_frame) {
-  constexpr bool lossless = transform1d_type == k1DTransformWht;
-  constexpr bool is_identity = transform1d_type == k1DTransformIdentity;
+                     int adjusted_tx_height, void* LIBGAV1_RESTRICT src_buffer,
+                     int start_x, int start_y,
+                     void* LIBGAV1_RESTRICT dst_frame) {
+  constexpr bool lossless = transform1d_type == kTransform1dWht;
+  constexpr bool is_identity = transform1d_type == kTransform1dIdentity;
   // The transform size of the WHT is always 4x4. Setting tx_width and
   // tx_height to the constant 4 for the WHT speeds the code up.
   assert(!lossless || tx_size == kTransformSize4x4);
@@ -1127,7 +1128,7 @@
   if (is_row) {
     // Row transform.
     const uint8_t row_shift = lossless ? 0 : kTransformRowShift[tx_size];
-    // This is the |range| parameter of the InverseTransform1DFunc.  For lossy
+    // This is the |range| parameter of the InverseTransform1dFunc.  For lossy
     // transforms, this will be equal to the clamping range.
     const int8_t row_clamp_range = lossless ? 2 : (bitdepth + 8);
     // If the width:height ratio of the transform size is 2:1 or 1:2, multiply
@@ -1170,10 +1171,10 @@
 
   assert(!is_row);
   constexpr uint8_t column_shift = lossless ? 0 : kTransformColumnShift;
-  // This is the |range| parameter of the InverseTransform1DFunc.  For lossy
+  // This is the |range| parameter of the InverseTransform1dFunc.  For lossy
   // transforms, this will be equal to the clamping range.
   const int8_t column_clamp_range = lossless ? 0 : std::max(bitdepth + 6, 16);
-  const bool flip_rows = transform1d_type == k1DTransformAdst &&
+  const bool flip_rows = transform1d_type == kTransform1dAdst &&
                          kTransformFlipRowsMask.Contains(tx_type);
   const bool flip_columns =
       !lossless && kTransformFlipColumnsMask.Contains(tx_type);
@@ -1216,114 +1217,114 @@
 template <int bitdepth, typename Residual, typename Pixel>
 void InitAll(Dsp* const dsp) {
   // Maximum transform size for Dct is 64.
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 2>, Dct_C<Residual, 2>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 2>, Dct_C<Residual, 2>,
                       /*is_row=*/false>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 3>, Dct_C<Residual, 3>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 3>, Dct_C<Residual, 3>,
                       /*is_row=*/false>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 4>, Dct_C<Residual, 4>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 4>, Dct_C<Residual, 4>,
                       /*is_row=*/false>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 5>, Dct_C<Residual, 5>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 5>, Dct_C<Residual, 5>,
                       /*is_row=*/false>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 6>, Dct_C<Residual, 6>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dDct,
                       DctDcOnly_C<bitdepth, Residual, 6>, Dct_C<Residual, 6>,
                       /*is_row=*/false>;
 
   // Maximum transform size for Adst is 16.
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dAdst,
                       Adst4DcOnly_C<bitdepth, Residual>, Adst4_C<Residual>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dAdst,
                       Adst4DcOnly_C<bitdepth, Residual>, Adst4_C<Residual>,
                       /*is_row=*/false>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dAdst,
                       Adst8DcOnly_C<bitdepth, Residual>, Adst8_C<Residual>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dAdst,
                       Adst8DcOnly_C<bitdepth, Residual>, Adst8_C<Residual>,
                       /*is_row=*/false>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dAdst,
                       Adst16DcOnly_C<bitdepth, Residual>, Adst16_C<Residual>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dAdst,
                       Adst16DcOnly_C<bitdepth, Residual>, Adst16_C<Residual>,
                       /*is_row=*/false>;
 
   // Maximum transform size for Identity transform is 32.
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dIdentity,
                       Identity4DcOnly_C<bitdepth, Residual>,
                       Identity4Row_C<Residual>, /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dIdentity,
                       Identity4DcOnly_C<bitdepth, Residual>,
                       Identity4Column_C<Residual>, /*is_row=*/false>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dIdentity,
                       Identity8DcOnly_C<bitdepth, Residual>,
                       Identity8Row_C<Residual>, /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dIdentity,
                       Identity8DcOnly_C<bitdepth, Residual>,
                       Identity8Column_C<Residual>, /*is_row=*/false>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dIdentity,
                       Identity16DcOnly_C<bitdepth, Residual>,
                       Identity16Row_C<Residual>, /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dIdentity,
                       Identity16DcOnly_C<bitdepth, Residual>,
                       Identity16Column_C<Residual>, /*is_row=*/false>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dIdentity,
                       Identity32DcOnly_C<bitdepth, Residual>,
                       Identity32Row_C<Residual>, /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dIdentity,
                       Identity32DcOnly_C<bitdepth, Residual>,
                       Identity32Column_C<Residual>, /*is_row=*/false>;
 
   // Maximum transform size for Wht is 4.
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformWht,
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dWht,
                       Wht4DcOnly_C<bitdepth, Residual>, Wht4_C<Residual>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] =
-      TransformLoop_C<bitdepth, Residual, Pixel, k1DTransformWht,
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
+      TransformLoop_C<bitdepth, Residual, Pixel, kTransform1dWht,
                       Wht4DcOnly_C<bitdepth, Residual>, Wht4_C<Residual>,
                       /*is_row=*/false>;
 }
@@ -1332,142 +1333,137 @@
 void Init8bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
   assert(dsp != nullptr);
-  for (auto& inverse_transform_by_size : dsp->inverse_transforms) {
-    for (auto& inverse_transform : inverse_transform_by_size) {
-      inverse_transform[kRow] = nullptr;
-      inverse_transform[kColumn] = nullptr;
-    }
-  }
+  static_cast<void>(dsp);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   InitAll<8, int16_t, uint8_t>(dsp);
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 2>, Dct_C<int16_t, 2>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 2>, Dct_C<int16_t, 2>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 3>, Dct_C<int16_t, 3>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 3>, Dct_C<int16_t, 3>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 4>, Dct_C<int16_t, 4>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 4>, Dct_C<int16_t, 4>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize32_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 5>, Dct_C<int16_t, 5>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 5>, Dct_C<int16_t, 5>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize64_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 6>, Dct_C<int16_t, 6>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dDct,
                       DctDcOnly_C<8, int16_t, 6>, Dct_C<int16_t, 6>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dAdst
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dAdst,
                       Adst4DcOnly_C<8, int16_t>, Adst4_C<int16_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dAdst,
                       Adst4DcOnly_C<8, int16_t>, Adst4_C<int16_t>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dAdst
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dAdst,
                       Adst8DcOnly_C<8, int16_t>, Adst8_C<int16_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dAdst,
                       Adst8DcOnly_C<8, int16_t>, Adst8_C<int16_t>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dAdst
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dAdst,
                       Adst16DcOnly_C<8, int16_t>, Adst16_C<int16_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dAdst,
                       Adst16DcOnly_C<8, int16_t>, Adst16_C<int16_t>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dIdentity
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dIdentity,
                       Identity4DcOnly_C<8, int16_t>, Identity4Row_C<int16_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dIdentity,
                       Identity4DcOnly_C<8, int16_t>, Identity4Column_C<int16_t>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dIdentity
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dIdentity,
                       Identity8DcOnly_C<8, int16_t>, Identity8Row_C<int16_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dIdentity,
                       Identity8DcOnly_C<8, int16_t>, Identity8Column_C<int16_t>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dIdentity
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dIdentity,
                       Identity16DcOnly_C<8, int16_t>, Identity16Row_C<int16_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dIdentity,
                       Identity16DcOnly_C<8, int16_t>,
                       Identity16Column_C<int16_t>, /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize32_Transform1dIdentity
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dIdentity,
                       Identity32DcOnly_C<8, int16_t>, Identity32Row_C<int16_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dIdentity,
                       Identity32DcOnly_C<8, int16_t>,
                       Identity32Column_C<int16_t>, /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformWht,
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dWht
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dWht,
                       Wht4DcOnly_C<8, int16_t>, Wht4_C<int16_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] =
-      TransformLoop_C<8, int16_t, uint8_t, k1DTransformWht,
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
+      TransformLoop_C<8, int16_t, uint8_t, kTransform1dWht,
                       Wht4DcOnly_C<8, int16_t>, Wht4_C<int16_t>,
                       /*is_row=*/false>;
 #endif
@@ -1478,142 +1474,137 @@
 void Init10bpp() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
   assert(dsp != nullptr);
-  for (auto& inverse_transform_by_size : dsp->inverse_transforms) {
-    for (auto& inverse_transform : inverse_transform_by_size) {
-      inverse_transform[kRow] = nullptr;
-      inverse_transform[kColumn] = nullptr;
-    }
-  }
+  static_cast<void>(dsp);
 #if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
   InitAll<10, int32_t, uint16_t>(dsp);
 #else  // !LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize4_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 2>, Dct_C<int32_t, 2>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 2>, Dct_C<int32_t, 2>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize8_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 3>, Dct_C<int32_t, 3>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 3>, Dct_C<int32_t, 3>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize16_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 4>, Dct_C<int32_t, 4>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 4>, Dct_C<int32_t, 4>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize32_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize32_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 5>, Dct_C<int32_t, 5>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 5>, Dct_C<int32_t, 5>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize64_1DTransformDct
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize64_Transform1dDct
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 6>, Dct_C<int32_t, 6>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformDct,
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dDct,
                       DctDcOnly_C<10, int32_t, 6>, Dct_C<int32_t, 6>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformAdst
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize4_Transform1dAdst
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dAdst,
                       Adst4DcOnly_C<10, int32_t>, Adst4_C<int32_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dAdst,
                       Adst4DcOnly_C<10, int32_t>, Adst4_C<int32_t>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformAdst
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize8_Transform1dAdst
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dAdst,
                       Adst8DcOnly_C<10, int32_t>, Adst8_C<int32_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dAdst,
                       Adst8DcOnly_C<10, int32_t>, Adst8_C<int32_t>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformAdst
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize16_Transform1dAdst
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dAdst,
                       Adst16DcOnly_C<10, int32_t>, Adst16_C<int32_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformAdst,
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dAdst,
                       Adst16DcOnly_C<10, int32_t>, Adst16_C<int32_t>,
                       /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformIdentity
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize4_Transform1dIdentity
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dIdentity,
                       Identity4DcOnly_C<10, int32_t>, Identity4Row_C<int32_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dIdentity,
                       Identity4DcOnly_C<10, int32_t>,
                       Identity4Column_C<int32_t>, /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize8_1DTransformIdentity
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize8_Transform1dIdentity
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dIdentity,
                       Identity8DcOnly_C<10, int32_t>, Identity8Row_C<int32_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dIdentity,
                       Identity8DcOnly_C<10, int32_t>,
                       Identity8Column_C<int32_t>, /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize16_1DTransformIdentity
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize16_Transform1dIdentity
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dIdentity,
                       Identity16DcOnly_C<10, int32_t>, Identity16Row_C<int32_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dIdentity,
                       Identity16DcOnly_C<10, int32_t>,
                       Identity16Column_C<int32_t>, /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize32_1DTransformIdentity
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize32_Transform1dIdentity
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dIdentity,
                       Identity32DcOnly_C<10, int32_t>, Identity32Row_C<int32_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformIdentity,
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dIdentity,
                       Identity32DcOnly_C<10, int32_t>,
                       Identity32Column_C<int32_t>, /*is_row=*/false>;
 #endif
-#ifndef LIBGAV1_Dsp10bpp_1DTransformSize4_1DTransformWht
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformWht,
+#ifndef LIBGAV1_Dsp10bpp_Transform1dSize4_Transform1dWht
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dWht,
                       Wht4DcOnly_C<10, int32_t>, Wht4_C<int32_t>,
                       /*is_row=*/true>;
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] =
-      TransformLoop_C<10, int32_t, uint16_t, k1DTransformWht,
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
+      TransformLoop_C<10, int32_t, uint16_t, kTransform1dWht,
                       Wht4DcOnly_C<10, int32_t>, Wht4_C<int32_t>,
                       /*is_row=*/false>;
 #endif
diff --git a/libgav1/src/dsp/libgav1_dsp.cmake b/libgav1/src/dsp/libgav1_dsp.cmake
index a28334d..4bd1443 100644
--- a/libgav1/src/dsp/libgav1_dsp.cmake
+++ b/libgav1/src/dsp/libgav1_dsp.cmake
@@ -66,6 +66,7 @@
             "${libgav1_source}/dsp/obmc.cc"
             "${libgav1_source}/dsp/obmc.h"
             "${libgav1_source}/dsp/obmc.inc"
+            "${libgav1_source}/dsp/smooth_weights.inc"
             "${libgav1_source}/dsp/super_res.cc"
             "${libgav1_source}/dsp/super_res.h"
             "${libgav1_source}/dsp/warp.cc"
@@ -90,6 +91,7 @@
             "${libgav1_source}/dsp/arm/cdef_neon.cc"
             "${libgav1_source}/dsp/arm/cdef_neon.h"
             "${libgav1_source}/dsp/arm/common_neon.h"
+            "${libgav1_source}/dsp/arm/convolve_10bit_neon.cc"
             "${libgav1_source}/dsp/arm/convolve_neon.cc"
             "${libgav1_source}/dsp/arm/convolve_neon.h"
             "${libgav1_source}/dsp/arm/distance_weighted_blend_neon.cc"
@@ -113,6 +115,7 @@
             "${libgav1_source}/dsp/arm/inverse_transform_neon.h"
             "${libgav1_source}/dsp/arm/loop_filter_neon.cc"
             "${libgav1_source}/dsp/arm/loop_filter_neon.h"
+            "${libgav1_source}/dsp/arm/loop_restoration_10bit_neon.cc"
             "${libgav1_source}/dsp/arm/loop_restoration_neon.cc"
             "${libgav1_source}/dsp/arm/loop_restoration_neon.h"
             "${libgav1_source}/dsp/arm/mask_blend_neon.cc"
diff --git a/libgav1/src/dsp/loop_filter.cc b/libgav1/src/dsp/loop_filter.cc
index 6cad97d..14d47bf 100644
--- a/libgav1/src/dsp/loop_filter.cc
+++ b/libgav1/src/dsp/loop_filter.cc
@@ -56,6 +56,9 @@
 
 inline void AdjustThresholds(const int bitdepth, int* const outer_thresh,
                              int* const inner_thresh, int* const hev_thresh) {
+  assert(*outer_thresh >= 7 && *outer_thresh <= 3 * kMaxLoopFilterValue + 4);
+  assert(*inner_thresh >= 1 && *inner_thresh <= kMaxLoopFilterValue);
+  assert(*hev_thresh >= 0 && *hev_thresh <= 3);
   *outer_thresh <<= bitdepth - 8;
   *inner_thresh <<= bitdepth - 8;
   *hev_thresh <<= bitdepth - 8;
diff --git a/libgav1/src/dsp/loop_restoration.cc b/libgav1/src/dsp/loop_restoration.cc
index 1a15d90..2301a3e 100644
--- a/libgav1/src/dsp/loop_restoration.cc
+++ b/libgav1/src/dsp/loop_restoration.cc
@@ -144,11 +144,14 @@
 // Thus in libaom's computation, an offset of 128 is needed for filter[3].
 template <int bitdepth, typename Pixel>
 void WienerFilter_C(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   constexpr int kCenterTap = kWienerFilterTaps / 2;
   const int16_t* const number_leading_zero_coefficients =
       restoration_info.wiener_info.number_leading_zero_coefficients;
@@ -867,11 +870,14 @@
 
 template <int bitdepth, typename Pixel>
 void SelfGuidedFilter_C(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int index = restoration_info.sgr_proj_info.index;
   const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
   const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
diff --git a/libgav1/src/dsp/mask_blend.cc b/libgav1/src/dsp/mask_blend.cc
index 15ef821..207fde0 100644
--- a/libgav1/src/dsp/mask_blend.cc
+++ b/libgav1/src/dsp/mask_blend.cc
@@ -25,7 +25,8 @@
 namespace dsp {
 namespace {
 
-uint8_t GetMaskValue(const uint8_t* mask, const uint8_t* mask_next_row, int x,
+uint8_t GetMaskValue(const uint8_t* LIBGAV1_RESTRICT mask,
+                     const uint8_t* LIBGAV1_RESTRICT mask_next_row, int x,
                      int subsampling_x, int subsampling_y) {
   if ((subsampling_x | subsampling_y) == 0) {
     return mask[x];
@@ -43,10 +44,12 @@
 
 template <int bitdepth, typename Pixel, bool is_inter_intra, int subsampling_x,
           int subsampling_y>
-void MaskBlend_C(const void* prediction_0, const void* prediction_1,
-                 const ptrdiff_t prediction_stride_1, const uint8_t* mask,
+void MaskBlend_C(const void* LIBGAV1_RESTRICT prediction_0,
+                 const void* LIBGAV1_RESTRICT prediction_1,
+                 const ptrdiff_t prediction_stride_1,
+                 const uint8_t* LIBGAV1_RESTRICT mask,
                  const ptrdiff_t mask_stride, const int width, const int height,
-                 void* dest, const ptrdiff_t dest_stride) {
+                 void* LIBGAV1_RESTRICT dest, const ptrdiff_t dest_stride) {
   static_assert(!(bitdepth == 8 && is_inter_intra), "");
   assert(mask != nullptr);
   using PredType =
@@ -85,11 +88,12 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-void InterIntraMaskBlend8bpp_C(const uint8_t* prediction_0,
-                               uint8_t* prediction_1,
+void InterIntraMaskBlend8bpp_C(const uint8_t* LIBGAV1_RESTRICT prediction_0,
+                               uint8_t* LIBGAV1_RESTRICT prediction_1,
                                const ptrdiff_t prediction_stride_1,
-                               const uint8_t* mask, const ptrdiff_t mask_stride,
-                               const int width, const int height) {
+                               const uint8_t* LIBGAV1_RESTRICT mask,
+                               const ptrdiff_t mask_stride, const int width,
+                               const int height) {
   assert(mask != nullptr);
   constexpr int step_y = subsampling_y ? 2 : 1;
   const uint8_t* mask_next_row = mask + mask_stride;
diff --git a/libgav1/src/dsp/motion_field_projection.cc b/libgav1/src/dsp/motion_field_projection.cc
index b51ec8f..7c17b8e 100644
--- a/libgav1/src/dsp/motion_field_projection.cc
+++ b/libgav1/src/dsp/motion_field_projection.cc
@@ -31,10 +31,8 @@
 
 // Silence unused function warnings when MotionFieldProjectionKernel_C is
 // not used.
-#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||                      \
-    !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel) || \
-    (LIBGAV1_MAX_BITDEPTH >= 10 &&                           \
-     !defined(LIBGAV1_Dsp10bpp_MotionFieldProjectionKernel))
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
+    !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel)
 
 // 7.9.2.
 void MotionFieldProjectionKernel_C(const ReferenceInfo& reference_info,
@@ -101,36 +99,16 @@
 }
 
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||
-        // !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel) ||
-        // (LIBGAV1_MAX_BITDEPTH >= 10 &&
-        //  !defined(LIBGAV1_Dsp10bpp_MotionFieldProjectionKernel))
-
-void Init8bpp() {
-#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
-    !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel)
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
-  assert(dsp != nullptr);
-  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_C;
-#endif
-}
-
-#if LIBGAV1_MAX_BITDEPTH >= 10
-void Init10bpp() {
-#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
-    !defined(LIBGAV1_Dsp10bpp_MotionFieldProjectionKernel)
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
-  assert(dsp != nullptr);
-  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_C;
-#endif
-}
-#endif
+        // !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel)
 
 }  // namespace
 
 void MotionFieldProjectionInit_C() {
-  Init8bpp();
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  Init10bpp();
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
+    !defined(LIBGAV1_Dsp8bpp_MotionFieldProjectionKernel)
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_C;
 #endif
 }
 
diff --git a/libgav1/src/dsp/motion_vector_search.cc b/libgav1/src/dsp/motion_vector_search.cc
index 9402302..205a1b6 100644
--- a/libgav1/src/dsp/motion_vector_search.cc
+++ b/libgav1/src/dsp/motion_vector_search.cc
@@ -29,16 +29,14 @@
 namespace {
 
 // Silence unused function warnings when the C functions are not used.
-#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||             \
-    !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch) || \
-    (LIBGAV1_MAX_BITDEPTH >= 10 &&                  \
-     !defined(LIBGAV1_Dsp10bpp_MotionVectorSearch))
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
+    !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch)
 
 void MvProjectionCompoundLowPrecision_C(
-    const MotionVector* const temporal_mvs,
-    const int8_t* const temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT const temporal_reference_offsets,
     const int reference_offsets[2], const int count,
-    CompoundMotionVector* const candidate_mvs) {
+    CompoundMotionVector* LIBGAV1_RESTRICT const candidate_mvs) {
   // To facilitate the compilers, make a local copy of |reference_offsets|.
   const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
   int index = 0;
@@ -62,10 +60,10 @@
 }
 
 void MvProjectionCompoundForceInteger_C(
-    const MotionVector* const temporal_mvs,
-    const int8_t* const temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT const temporal_reference_offsets,
     const int reference_offsets[2], const int count,
-    CompoundMotionVector* const candidate_mvs) {
+    CompoundMotionVector* LIBGAV1_RESTRICT const candidate_mvs) {
   // To facilitate the compilers, make a local copy of |reference_offsets|.
   const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
   int index = 0;
@@ -91,10 +89,10 @@
 }
 
 void MvProjectionCompoundHighPrecision_C(
-    const MotionVector* const temporal_mvs,
-    const int8_t* const temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT const temporal_reference_offsets,
     const int reference_offsets[2], const int count,
-    CompoundMotionVector* const candidate_mvs) {
+    CompoundMotionVector* LIBGAV1_RESTRICT const candidate_mvs) {
   // To facilitate the compilers, make a local copy of |reference_offsets|.
   const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
   int index = 0;
@@ -113,9 +111,10 @@
 }
 
 void MvProjectionSingleLowPrecision_C(
-    const MotionVector* const temporal_mvs,
-    const int8_t* const temporal_reference_offsets, const int reference_offset,
-    const int count, MotionVector* const candidate_mvs) {
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT const temporal_reference_offsets,
+    const int reference_offset, const int count,
+    MotionVector* LIBGAV1_RESTRICT const candidate_mvs) {
   int index = 0;
   do {
     GetMvProjection(
@@ -131,9 +130,10 @@
 }
 
 void MvProjectionSingleForceInteger_C(
-    const MotionVector* const temporal_mvs,
-    const int8_t* const temporal_reference_offsets, const int reference_offset,
-    const int count, MotionVector* const candidate_mvs) {
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT const temporal_reference_offsets,
+    const int reference_offset, const int count,
+    MotionVector* LIBGAV1_RESTRICT const candidate_mvs) {
   int index = 0;
   do {
     GetMvProjection(
@@ -151,9 +151,10 @@
 }
 
 void MvProjectionSingleHighPrecision_C(
-    const MotionVector* const temporal_mvs,
-    const int8_t* const temporal_reference_offsets, const int reference_offset,
-    const int count, MotionVector* const candidate_mvs) {
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT const temporal_reference_offsets,
+    const int reference_offset, const int count,
+    MotionVector* LIBGAV1_RESTRICT const candidate_mvs) {
   int index = 0;
   do {
     GetMvProjection(
@@ -164,46 +165,21 @@
 }
 
 #endif  // LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS ||
-        // !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch) ||
-        // (LIBGAV1_MAX_BITDEPTH >= 10 &&
-        //  !defined(LIBGAV1_Dsp10bpp_MotionVectorSearch))
-
-void Init8bpp() {
-#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
-    !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch)
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
-  assert(dsp != nullptr);
-  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_C;
-  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_C;
-  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_C;
-  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_C;
-  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_C;
-  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_C;
-#endif
-}
-
-#if LIBGAV1_MAX_BITDEPTH >= 10
-void Init10bpp() {
-#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
-    !defined(LIBGAV1_Dsp10bpp_MotionVectorSearch)
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(10);
-  assert(dsp != nullptr);
-  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_C;
-  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_C;
-  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_C;
-  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_C;
-  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_C;
-  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_C;
-#endif
-}
-#endif
+        // !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch)
 
 }  // namespace
 
 void MotionVectorSearchInit_C() {
-  Init8bpp();
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  Init10bpp();
+#if LIBGAV1_ENABLE_ALL_DSP_FUNCTIONS || \
+    !defined(LIBGAV1_Dsp8bpp_MotionVectorSearch)
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_C;
+  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_C;
+  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_C;
+  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_C;
+  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_C;
+  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_C;
 #endif
 }
 
diff --git a/libgav1/src/dsp/obmc.cc b/libgav1/src/dsp/obmc.cc
index 46d1b5b..6b5c6e3 100644
--- a/libgav1/src/dsp/obmc.cc
+++ b/libgav1/src/dsp/obmc.cc
@@ -30,15 +30,18 @@
 
 // 7.11.3.10 (from top samples).
 template <typename Pixel>
-void OverlapBlendVertical_C(void* const prediction,
+void OverlapBlendVertical_C(void* LIBGAV1_RESTRICT const prediction,
                             const ptrdiff_t prediction_stride, const int width,
-                            const int height, const void* const obmc_prediction,
+                            const int height,
+                            const void* LIBGAV1_RESTRICT const obmc_prediction,
                             const ptrdiff_t obmc_prediction_stride) {
   auto* pred = static_cast<Pixel*>(prediction);
   const ptrdiff_t pred_stride = prediction_stride / sizeof(Pixel);
   const auto* obmc_pred = static_cast<const Pixel*>(obmc_prediction);
   const ptrdiff_t obmc_pred_stride = obmc_prediction_stride / sizeof(Pixel);
   const uint8_t* const mask = kObmcMask + height - 2;
+  assert(width >= 4);
+  assert(height >= 2);
 
   for (int y = 0; y < height; ++y) {
     const uint8_t mask_value = mask[y];
@@ -53,16 +56,19 @@
 
 // 7.11.3.10 (from left samples).
 template <typename Pixel>
-void OverlapBlendHorizontal_C(void* const prediction,
-                              const ptrdiff_t prediction_stride,
-                              const int width, const int height,
-                              const void* const obmc_prediction,
-                              const ptrdiff_t obmc_prediction_stride) {
+void OverlapBlendHorizontal_C(
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
+    const int width, const int height,
+    const void* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
   auto* pred = static_cast<Pixel*>(prediction);
   const ptrdiff_t pred_stride = prediction_stride / sizeof(Pixel);
   const auto* obmc_pred = static_cast<const Pixel*>(obmc_prediction);
   const ptrdiff_t obmc_pred_stride = obmc_prediction_stride / sizeof(Pixel);
   const uint8_t* const mask = kObmcMask + width - 2;
+  assert(width >= 2);
+  assert(height >= 4);
+
   for (int y = 0; y < height; ++y) {
     for (int x = 0; x < width; ++x) {
       const uint8_t mask_value = mask[x];
diff --git a/libgav1/src/dsp/smooth_weights.inc b/libgav1/src/dsp/smooth_weights.inc
new file mode 100644
index 0000000..d4ee8a6
--- /dev/null
+++ b/libgav1/src/dsp/smooth_weights.inc
@@ -0,0 +1,35 @@
+// Copyright 2021 The libgav1 Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Each row below contains weights used for a corresponding block size. Because
+// they are adjacent powers of 2, the index of each row is the sum of the sizes
+// of preceding rows, minus 4.
+// The weights need to be declared as uint8_t or uint16_t, depending on the
+// bitdepth, so the values are held in a single canonical place.
+// clang-format off
+    // block dimension = 4
+    255, 149, 85, 64,
+    // block dimension = 8
+    255, 197, 146, 105, 73, 50, 37, 32,
+    // block dimension = 16
+    255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16,
+    // block dimension = 32
+    255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74,
+    66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8,
+    // block dimension = 64
+    255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156,
+    150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73,
+    69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16,
+    15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4
+    // clang-format on
diff --git a/libgav1/src/dsp/super_res.cc b/libgav1/src/dsp/super_res.cc
index abb01a1..570ba73 100644
--- a/libgav1/src/dsp/super_res.cc
+++ b/libgav1/src/dsp/super_res.cc
@@ -25,11 +25,12 @@
 namespace {
 
 template <int bitdepth, typename Pixel>
-void SuperRes_C(const void* /*coefficients*/, void* const source,
+void SuperRes_C(const void* /*coefficients*/,
+                void* LIBGAV1_RESTRICT const source,
                 const ptrdiff_t source_stride, const int height,
                 const int downscaled_width, const int upscaled_width,
-                const int initial_subpixel_x, const int step, void* const dest,
-                ptrdiff_t dest_stride) {
+                const int initial_subpixel_x, const int step,
+                void* LIBGAV1_RESTRICT const dest, ptrdiff_t dest_stride) {
   assert(step <= 1 << kSuperResScaleBits);
   auto* src = static_cast<Pixel*>(source) - DivideBy2(kSuperResFilterTaps);
   auto* dst = static_cast<Pixel*>(dest);
diff --git a/libgav1/src/dsp/warp.cc b/libgav1/src/dsp/warp.cc
index fbde65a..dd467ea 100644
--- a/libgav1/src/dsp/warp.cc
+++ b/libgav1/src/dsp/warp.cc
@@ -59,14 +59,14 @@
 //   compound second pass output range: [    8129,    57403]
 
 template <bool is_compound, int bitdepth, typename Pixel>
-void Warp_C(const void* const source, ptrdiff_t source_stride,
+void Warp_C(const void* LIBGAV1_RESTRICT const source, ptrdiff_t source_stride,
             const int source_width, const int source_height,
-            const int* const warp_params, const int subsampling_x,
-            const int subsampling_y, const int block_start_x,
-            const int block_start_y, const int block_width,
-            const int block_height, const int16_t alpha, const int16_t beta,
-            const int16_t gamma, const int16_t delta, void* dest,
-            ptrdiff_t dest_stride) {
+            const int* LIBGAV1_RESTRICT const warp_params,
+            const int subsampling_x, const int subsampling_y,
+            const int block_start_x, const int block_start_y,
+            const int block_width, const int block_height, const int16_t alpha,
+            const int16_t beta, const int16_t gamma, const int16_t delta,
+            void* LIBGAV1_RESTRICT dest, ptrdiff_t dest_stride) {
   assert(block_width >= 8 && block_height >= 8);
   if (is_compound) {
     assert(dest_stride == block_width);
diff --git a/libgav1/src/dsp/weight_mask.cc b/libgav1/src/dsp/weight_mask.cc
index 15d6bc6..41f4c70 100644
--- a/libgav1/src/dsp/weight_mask.cc
+++ b/libgav1/src/dsp/weight_mask.cc
@@ -29,8 +29,9 @@
 namespace {
 
 template <int width, int height, int bitdepth, bool mask_is_inverse>
-void WeightMask_C(const void* prediction_0, const void* prediction_1,
-                  uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask_C(const void* LIBGAV1_RESTRICT prediction_0,
+                  const void* LIBGAV1_RESTRICT prediction_1,
+                  uint8_t* LIBGAV1_RESTRICT mask, ptrdiff_t mask_stride) {
   using PredType =
       typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
   const auto* pred_0 = static_cast<const PredType*>(prediction_0);
diff --git a/libgav1/src/dsp/x86/average_blend_sse4.cc b/libgav1/src/dsp/x86/average_blend_sse4.cc
index ec9f589..911c5a9 100644
--- a/libgav1/src/dsp/x86/average_blend_sse4.cc
+++ b/libgav1/src/dsp/x86/average_blend_sse4.cc
@@ -35,8 +35,9 @@
 
 constexpr int kInterPostRoundBit = 4;
 
-inline void AverageBlend4Row(const int16_t* prediction_0,
-                             const int16_t* prediction_1, uint8_t* dest) {
+inline void AverageBlend4Row(const int16_t* LIBGAV1_RESTRICT prediction_0,
+                             const int16_t* LIBGAV1_RESTRICT prediction_1,
+                             uint8_t* LIBGAV1_RESTRICT dest) {
   const __m128i pred_0 = LoadLo8(prediction_0);
   const __m128i pred_1 = LoadLo8(prediction_1);
   __m128i res = _mm_add_epi16(pred_0, pred_1);
@@ -44,8 +45,9 @@
   Store4(dest, _mm_packus_epi16(res, res));
 }
 
-inline void AverageBlend8Row(const int16_t* prediction_0,
-                             const int16_t* prediction_1, uint8_t* dest) {
+inline void AverageBlend8Row(const int16_t* LIBGAV1_RESTRICT prediction_0,
+                             const int16_t* LIBGAV1_RESTRICT prediction_1,
+                             uint8_t* LIBGAV1_RESTRICT dest) {
   const __m128i pred_0 = LoadAligned16(prediction_0);
   const __m128i pred_1 = LoadAligned16(prediction_1);
   __m128i res = _mm_add_epi16(pred_0, pred_1);
@@ -53,9 +55,10 @@
   StoreLo8(dest, _mm_packus_epi16(res, res));
 }
 
-inline void AverageBlendLargeRow(const int16_t* prediction_0,
-                                 const int16_t* prediction_1, const int width,
-                                 uint8_t* dest) {
+inline void AverageBlendLargeRow(const int16_t* LIBGAV1_RESTRICT prediction_0,
+                                 const int16_t* LIBGAV1_RESTRICT prediction_1,
+                                 const int width,
+                                 uint8_t* LIBGAV1_RESTRICT dest) {
   int x = 0;
   do {
     const __m128i pred_00 = LoadAligned16(&prediction_0[x]);
@@ -71,8 +74,10 @@
   } while (x < width);
 }
 
-void AverageBlend_SSE4_1(const void* prediction_0, const void* prediction_1,
-                         const int width, const int height, void* const dest,
+void AverageBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
+                         const void* LIBGAV1_RESTRICT prediction_1,
+                         const int width, const int height,
+                         void* LIBGAV1_RESTRICT const dest,
                          const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
@@ -148,11 +153,11 @@
 constexpr int kInterPostRoundBitPlusOne = 5;
 
 template <const int width, const int offset>
-inline void AverageBlendRow(const uint16_t* prediction_0,
-                            const uint16_t* prediction_1,
+inline void AverageBlendRow(const uint16_t* LIBGAV1_RESTRICT prediction_0,
+                            const uint16_t* LIBGAV1_RESTRICT prediction_1,
                             const __m128i& compound_offset,
                             const __m128i& round_offset, const __m128i& max,
-                            const __m128i& zero, uint16_t* dst,
+                            const __m128i& zero, uint16_t* LIBGAV1_RESTRICT dst,
                             const ptrdiff_t dest_stride) {
   // pred_0/1 max range is 16b.
   const __m128i pred_0 = LoadUnaligned16(prediction_0 + offset);
@@ -182,9 +187,10 @@
   StoreHi8(dst + dest_stride, result);
 }
 
-void AverageBlend10bpp_SSE4_1(const void* prediction_0,
-                              const void* prediction_1, const int width,
-                              const int height, void* const dest,
+void AverageBlend10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
+                              const void* LIBGAV1_RESTRICT prediction_1,
+                              const int width, const int height,
+                              void* LIBGAV1_RESTRICT const dest,
                               const ptrdiff_t dst_stride) {
   auto* dst = static_cast<uint16_t*>(dest);
   const ptrdiff_t dest_stride = dst_stride / sizeof(dst[0]);
diff --git a/libgav1/src/dsp/x86/cdef_avx2.cc b/libgav1/src/dsp/x86/cdef_avx2.cc
index d41dc38..01a2b9f 100644
--- a/libgav1/src/dsp/x86/cdef_avx2.cc
+++ b/libgav1/src/dsp/x86/cdef_avx2.cc
@@ -269,8 +269,8 @@
       _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_pair_add[3], 10));
 }
 
-LIBGAV1_ALWAYS_INLINE void AddPartial(const uint8_t* src, ptrdiff_t stride,
-                                      __m256i* partial) {
+LIBGAV1_ALWAYS_INLINE void AddPartial(const uint8_t* LIBGAV1_RESTRICT src,
+                                      ptrdiff_t stride, __m256i* partial) {
   // 8x8 input
   // 00 01 02 03 04 05 06 07
   // 10 11 12 13 14 15 16 17
@@ -451,8 +451,10 @@
   cost[6] = _mm_cvtsi128_si32(_mm_srli_si128(sums, 8));
 }
 
-void CdefDirection_AVX2(const void* const source, ptrdiff_t stride,
-                        uint8_t* const direction, int* const variance) {
+void CdefDirection_AVX2(const void* LIBGAV1_RESTRICT const source,
+                        ptrdiff_t stride,
+                        uint8_t* LIBGAV1_RESTRICT const direction,
+                        int* LIBGAV1_RESTRICT const variance) {
   assert(direction != nullptr);
   assert(variance != nullptr);
   const auto* src = static_cast<const uint8_t*>(source);
@@ -500,8 +502,9 @@
 // CdefFilter
 
 // Load 4 vectors based on the given |direction|.
-inline void LoadDirection(const uint16_t* const src, const ptrdiff_t stride,
-                          __m128i* output, const int direction) {
+inline void LoadDirection(const uint16_t* LIBGAV1_RESTRICT const src,
+                          const ptrdiff_t stride, __m128i* output,
+                          const int direction) {
   // Each |direction| describes a different set of source values. Expand this
   // set by negating each set. For |direction| == 0 this gives a diagonal line
   // from top right to bottom left. The first value is y, the second x. Negative
@@ -525,8 +528,9 @@
 
 // Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to
 // do 2 rows at a time.
-void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride,
-                    __m128i* output, const int direction) {
+void LoadDirection4(const uint16_t* LIBGAV1_RESTRICT const src,
+                    const ptrdiff_t stride, __m128i* output,
+                    const int direction) {
   const int y_0 = kCdefDirections[direction][0][0];
   const int x_0 = kCdefDirections[direction][0][1];
   const int y_1 = kCdefDirections[direction][1][0];
@@ -569,11 +573,11 @@
 }
 
 template <int width, bool enable_primary = true, bool enable_secondary = true>
-void CdefFilter_AVX2(const uint16_t* src, const ptrdiff_t src_stride,
-                     const int height, const int primary_strength,
-                     const int secondary_strength, const int damping,
-                     const int direction, void* dest,
-                     const ptrdiff_t dst_stride) {
+void CdefFilter_AVX2(const uint16_t* LIBGAV1_RESTRICT src,
+                     const ptrdiff_t src_stride, const int height,
+                     const int primary_strength, const int secondary_strength,
+                     const int damping, const int direction,
+                     void* LIBGAV1_RESTRICT dest, const ptrdiff_t dst_stride) {
   static_assert(width == 8 || width == 4, "Invalid CDEF width.");
   static_assert(enable_primary || enable_secondary, "");
   constexpr bool clipping_required = enable_primary && enable_secondary;
diff --git a/libgav1/src/dsp/x86/cdef_sse4.cc b/libgav1/src/dsp/x86/cdef_sse4.cc
index 6ede778..6c48844 100644
--- a/libgav1/src/dsp/x86/cdef_sse4.cc
+++ b/libgav1/src/dsp/x86/cdef_sse4.cc
@@ -241,8 +241,8 @@
   *partial_hi = _mm_add_epi16(*partial_hi, _mm_srli_si128(v_pair_add[3], 10));
 }
 
-LIBGAV1_ALWAYS_INLINE void AddPartial(const uint8_t* src, ptrdiff_t stride,
-                                      __m128i* partial_lo,
+LIBGAV1_ALWAYS_INLINE void AddPartial(const uint8_t* LIBGAV1_RESTRICT src,
+                                      ptrdiff_t stride, __m128i* partial_lo,
                                       __m128i* partial_hi) {
   // 8x8 input
   // 00 01 02 03 04 05 06 07
@@ -395,8 +395,10 @@
   return SumVector_S32(square);
 }
 
-void CdefDirection_SSE4_1(const void* const source, ptrdiff_t stride,
-                          uint8_t* const direction, int* const variance) {
+void CdefDirection_SSE4_1(const void* LIBGAV1_RESTRICT const source,
+                          ptrdiff_t stride,
+                          uint8_t* LIBGAV1_RESTRICT const direction,
+                          int* LIBGAV1_RESTRICT const variance) {
   assert(direction != nullptr);
   assert(variance != nullptr);
   const auto* src = static_cast<const uint8_t*>(source);
@@ -438,8 +440,9 @@
 // CdefFilter
 
 // Load 4 vectors based on the given |direction|.
-inline void LoadDirection(const uint16_t* const src, const ptrdiff_t stride,
-                          __m128i* output, const int direction) {
+inline void LoadDirection(const uint16_t* LIBGAV1_RESTRICT const src,
+                          const ptrdiff_t stride, __m128i* output,
+                          const int direction) {
   // Each |direction| describes a different set of source values. Expand this
   // set by negating each set. For |direction| == 0 this gives a diagonal line
   // from top right to bottom left. The first value is y, the second x. Negative
@@ -463,8 +466,9 @@
 
 // Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to
 // do 2 rows at a time.
-void LoadDirection4(const uint16_t* const src, const ptrdiff_t stride,
-                    __m128i* output, const int direction) {
+void LoadDirection4(const uint16_t* LIBGAV1_RESTRICT const src,
+                    const ptrdiff_t stride, __m128i* output,
+                    const int direction) {
   const int y_0 = kCdefDirections[direction][0][0];
   const int x_0 = kCdefDirections[direction][0][1];
   const int y_1 = kCdefDirections[direction][1][0];
@@ -507,10 +511,11 @@
 }
 
 template <int width, bool enable_primary = true, bool enable_secondary = true>
-void CdefFilter_SSE4_1(const uint16_t* src, const ptrdiff_t src_stride,
-                       const int height, const int primary_strength,
-                       const int secondary_strength, const int damping,
-                       const int direction, void* dest,
+void CdefFilter_SSE4_1(const uint16_t* LIBGAV1_RESTRICT src,
+                       const ptrdiff_t src_stride, const int height,
+                       const int primary_strength, const int secondary_strength,
+                       const int damping, const int direction,
+                       void* LIBGAV1_RESTRICT dest,
                        const ptrdiff_t dst_stride) {
   static_assert(width == 8 || width == 4, "Invalid CDEF width.");
   static_assert(enable_primary || enable_secondary, "");
diff --git a/libgav1/src/dsp/x86/convolve_avx2.cc b/libgav1/src/dsp/x86/convolve_avx2.cc
index 2ecb77c..4126ca9 100644
--- a/libgav1/src/dsp/x86/convolve_avx2.cc
+++ b/libgav1/src/dsp/x86/convolve_avx2.cc
@@ -127,10 +127,11 @@
 // Filter 2xh sizes.
 template <int num_taps, int filter_index, bool is_2d = false,
           bool is_compound = false>
-void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride,
-                      void* const dest, const ptrdiff_t pred_stride,
-                      const int /*width*/, const int height,
-                      const __m128i* const v_tap) {
+void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
+                      const ptrdiff_t src_stride,
+                      void* LIBGAV1_RESTRICT const dest,
+                      const ptrdiff_t pred_stride, const int /*width*/,
+                      const int height, const __m128i* const v_tap) {
   auto* dest8 = static_cast<uint8_t*>(dest);
   auto* dest16 = static_cast<uint16_t*>(dest);
 
@@ -195,10 +196,11 @@
 // Filter widths >= 4.
 template <int num_taps, int filter_index, bool is_2d = false,
           bool is_compound = false>
-void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride,
-                      void* const dest, const ptrdiff_t pred_stride,
-                      const int width, const int height,
-                      const __m256i* const v_tap) {
+void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
+                      const ptrdiff_t src_stride,
+                      void* LIBGAV1_RESTRICT const dest,
+                      const ptrdiff_t pred_stride, const int width,
+                      const int height, const __m256i* const v_tap) {
   auto* dest8 = static_cast<uint8_t*>(dest);
   auto* dest16 = static_cast<uint16_t*>(dest);
 
@@ -467,7 +469,8 @@
 }
 
 template <int num_taps, bool is_compound = false>
-void Filter2DVertical16xH(const uint16_t* src, void* const dst,
+void Filter2DVertical16xH(const uint16_t* LIBGAV1_RESTRICT src,
+                          void* LIBGAV1_RESTRICT const dst,
                           const ptrdiff_t dst_stride, const int width,
                           const int height, const __m256i* const taps) {
   assert(width >= 8);
@@ -542,9 +545,10 @@
 
 template <bool is_2d = false, bool is_compound = false>
 LIBGAV1_ALWAYS_INLINE void DoHorizontalPass2xH(
-    const uint8_t* const src, const ptrdiff_t src_stride, void* const dst,
-    const ptrdiff_t dst_stride, const int width, const int height,
-    const int filter_id, const int filter_index) {
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
+    const int width, const int height, const int filter_id,
+    const int filter_index) {
   assert(filter_id != 0);
   __m128i v_tap[4];
   const __m128i v_horizontal_filter =
@@ -567,9 +571,10 @@
 
 template <bool is_2d = false, bool is_compound = false>
 LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
-    const uint8_t* const src, const ptrdiff_t src_stride, void* const dst,
-    const ptrdiff_t dst_stride, const int width, const int height,
-    const int filter_id, const int filter_index) {
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
+    const int width, const int height, const int filter_id,
+    const int filter_index) {
   assert(filter_id != 0);
   __m256i v_tap[4];
   const __m128i v_horizontal_filter =
@@ -602,13 +607,13 @@
   }
 }
 
-void Convolve2D_AVX2(const void* const reference,
+void Convolve2D_AVX2(const void* LIBGAV1_RESTRICT const reference,
                      const ptrdiff_t reference_stride,
                      const int horizontal_filter_index,
                      const int vertical_filter_index,
                      const int horizontal_filter_id,
                      const int vertical_filter_id, const int width,
-                     const int height, void* prediction,
+                     const int height, void* LIBGAV1_RESTRICT prediction,
                      const ptrdiff_t pred_stride) {
   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
@@ -774,10 +779,11 @@
 }
 
 template <int filter_index, bool is_compound = false>
-void FilterVertical32xH(const uint8_t* src, const ptrdiff_t src_stride,
-                        void* const dst, const ptrdiff_t dst_stride,
-                        const int width, const int height,
-                        const __m256i* const v_tap) {
+void FilterVertical32xH(const uint8_t* LIBGAV1_RESTRICT src,
+                        const ptrdiff_t src_stride,
+                        void* LIBGAV1_RESTRICT const dst,
+                        const ptrdiff_t dst_stride, const int width,
+                        const int height, const __m256i* const v_tap) {
   const int num_taps = GetNumTapsInFilter(filter_index);
   const int next_row = num_taps - 1;
   auto* dst8 = static_cast<uint8_t*>(dst);
@@ -856,10 +862,11 @@
 }
 
 template <int filter_index, bool is_compound = false>
-void FilterVertical16xH(const uint8_t* src, const ptrdiff_t src_stride,
-                        void* const dst, const ptrdiff_t dst_stride,
-                        const int /*width*/, const int height,
-                        const __m256i* const v_tap) {
+void FilterVertical16xH(const uint8_t* LIBGAV1_RESTRICT src,
+                        const ptrdiff_t src_stride,
+                        void* LIBGAV1_RESTRICT const dst,
+                        const ptrdiff_t dst_stride, const int /*width*/,
+                        const int height, const __m256i* const v_tap) {
   const int num_taps = GetNumTapsInFilter(filter_index);
   const int next_row = num_taps;
   auto* dst8 = static_cast<uint8_t*>(dst);
@@ -958,10 +965,11 @@
 }
 
 template <int filter_index, bool is_compound = false>
-void FilterVertical8xH(const uint8_t* src, const ptrdiff_t src_stride,
-                       void* const dst, const ptrdiff_t dst_stride,
-                       const int /*width*/, const int height,
-                       const __m256i* const v_tap) {
+void FilterVertical8xH(const uint8_t* LIBGAV1_RESTRICT src,
+                       const ptrdiff_t src_stride,
+                       void* LIBGAV1_RESTRICT const dst,
+                       const ptrdiff_t dst_stride, const int /*width*/,
+                       const int height, const __m256i* const v_tap) {
   const int num_taps = GetNumTapsInFilter(filter_index);
   const int next_row = num_taps;
   auto* dst8 = static_cast<uint8_t*>(dst);
@@ -1055,10 +1063,11 @@
 }
 
 template <int filter_index, bool is_compound = false>
-void FilterVertical8xH(const uint8_t* src, const ptrdiff_t src_stride,
-                       void* const dst, const ptrdiff_t dst_stride,
-                       const int /*width*/, const int height,
-                       const __m128i* const v_tap) {
+void FilterVertical8xH(const uint8_t* LIBGAV1_RESTRICT src,
+                       const ptrdiff_t src_stride,
+                       void* LIBGAV1_RESTRICT const dst,
+                       const ptrdiff_t dst_stride, const int /*width*/,
+                       const int height, const __m128i* const v_tap) {
   const int num_taps = GetNumTapsInFilter(filter_index);
   const int next_row = num_taps - 1;
   auto* dst8 = static_cast<uint8_t*>(dst);
@@ -1119,13 +1128,13 @@
   } while (--y != 0);
 }
 
-void ConvolveVertical_AVX2(const void* const reference,
+void ConvolveVertical_AVX2(const void* LIBGAV1_RESTRICT const reference,
                            const ptrdiff_t reference_stride,
                            const int /*horizontal_filter_index*/,
                            const int vertical_filter_index,
                            const int /*horizontal_filter_id*/,
                            const int vertical_filter_id, const int width,
-                           const int height, void* prediction,
+                           const int height, void* LIBGAV1_RESTRICT prediction,
                            const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
   const int vertical_taps = GetNumTapsInFilter(filter_index);
@@ -1257,11 +1266,11 @@
 }
 
 void ConvolveCompoundVertical_AVX2(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int vertical_filter_index,
-    const int /*horizontal_filter_id*/, const int vertical_filter_id,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t /*pred_stride*/) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t /*pred_stride*/) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
   const int vertical_taps = GetNumTapsInFilter(filter_index);
   const ptrdiff_t src_stride = reference_stride;
@@ -1366,14 +1375,12 @@
   }
 }
 
-void ConvolveHorizontal_AVX2(const void* const reference,
-                             const ptrdiff_t reference_stride,
-                             const int horizontal_filter_index,
-                             const int /*vertical_filter_index*/,
-                             const int horizontal_filter_id,
-                             const int /*vertical_filter_id*/, const int width,
-                             const int height, void* prediction,
-                             const ptrdiff_t pred_stride) {
+void ConvolveHorizontal_AVX2(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int /*vertical_filter_index*/, const int horizontal_filter_id,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   // Set |src| to the outermost tap.
   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
@@ -1390,11 +1397,11 @@
 }
 
 void ConvolveCompoundHorizontal_AVX2(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int horizontal_filter_index, const int /*vertical_filter_index*/,
-    const int horizontal_filter_id, const int /*vertical_filter_id*/,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t pred_stride) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int /*vertical_filter_index*/, const int horizontal_filter_id,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   // Set |src| to the outermost tap.
   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
@@ -1415,14 +1422,12 @@
       filter_index);
 }
 
-void ConvolveCompound2D_AVX2(const void* const reference,
-                             const ptrdiff_t reference_stride,
-                             const int horizontal_filter_index,
-                             const int vertical_filter_index,
-                             const int horizontal_filter_id,
-                             const int vertical_filter_id, const int width,
-                             const int height, void* prediction,
-                             const ptrdiff_t pred_stride) {
+void ConvolveCompound2D_AVX2(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int vertical_filter_index, const int horizontal_filter_id,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
   const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
diff --git a/libgav1/src/dsp/x86/convolve_sse4.cc b/libgav1/src/dsp/x86/convolve_sse4.cc
index 9b72fe4..f7e5a71 100644
--- a/libgav1/src/dsp/x86/convolve_sse4.cc
+++ b/libgav1/src/dsp/x86/convolve_sse4.cc
@@ -37,7 +37,7 @@
 #include "src/dsp/x86/convolve_sse4.inc"
 
 template <int filter_index>
-__m128i SumHorizontalTaps(const uint8_t* const src,
+__m128i SumHorizontalTaps(const uint8_t* LIBGAV1_RESTRICT const src,
                           const __m128i* const v_tap) {
   __m128i v_src[4];
   const __m128i src_long = LoadUnaligned16(src);
@@ -68,7 +68,7 @@
 }
 
 template <int filter_index>
-__m128i SimpleHorizontalTaps(const uint8_t* const src,
+__m128i SimpleHorizontalTaps(const uint8_t* LIBGAV1_RESTRICT const src,
                              const __m128i* const v_tap) {
   __m128i sum = SumHorizontalTaps<filter_index>(src, v_tap);
 
@@ -84,7 +84,7 @@
 }
 
 template <int filter_index>
-__m128i HorizontalTaps8To16(const uint8_t* const src,
+__m128i HorizontalTaps8To16(const uint8_t* LIBGAV1_RESTRICT const src,
                             const __m128i* const v_tap) {
   const __m128i sum = SumHorizontalTaps<filter_index>(src, v_tap);
 
@@ -93,10 +93,11 @@
 
 template <int num_taps, int filter_index, bool is_2d = false,
           bool is_compound = false>
-void FilterHorizontal(const uint8_t* src, const ptrdiff_t src_stride,
-                      void* const dest, const ptrdiff_t pred_stride,
-                      const int width, const int height,
-                      const __m128i* const v_tap) {
+void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
+                      const ptrdiff_t src_stride,
+                      void* LIBGAV1_RESTRICT const dest,
+                      const ptrdiff_t pred_stride, const int width,
+                      const int height, const __m128i* const v_tap) {
   auto* dest8 = static_cast<uint8_t*>(dest);
   auto* dest16 = static_cast<uint16_t*>(dest);
 
@@ -206,9 +207,10 @@
 
 template <bool is_2d = false, bool is_compound = false>
 LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
-    const uint8_t* const src, const ptrdiff_t src_stride, void* const dst,
-    const ptrdiff_t dst_stride, const int width, const int height,
-    const int filter_id, const int filter_index) {
+    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
+    void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
+    const int width, const int height, const int filter_id,
+    const int filter_index) {
   assert(filter_id != 0);
   __m128i v_tap[4];
   const __m128i v_horizontal_filter =
@@ -241,13 +243,13 @@
   }
 }
 
-void Convolve2D_SSE4_1(const void* const reference,
+void Convolve2D_SSE4_1(const void* LIBGAV1_RESTRICT const reference,
                        const ptrdiff_t reference_stride,
                        const int horizontal_filter_index,
                        const int vertical_filter_index,
                        const int horizontal_filter_id,
                        const int vertical_filter_id, const int width,
-                       const int height, void* prediction,
+                       const int height, void* LIBGAV1_RESTRICT prediction,
                        const ptrdiff_t pred_stride) {
   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
@@ -328,10 +330,11 @@
 }
 
 template <int filter_index, bool is_compound = false>
-void FilterVertical(const uint8_t* src, const ptrdiff_t src_stride,
-                    void* const dst, const ptrdiff_t dst_stride,
-                    const int width, const int height,
-                    const __m128i* const v_tap) {
+void FilterVertical(const uint8_t* LIBGAV1_RESTRICT src,
+                    const ptrdiff_t src_stride,
+                    void* LIBGAV1_RESTRICT const dst,
+                    const ptrdiff_t dst_stride, const int width,
+                    const int height, const __m128i* const v_tap) {
   const int num_taps = GetNumTapsInFilter(filter_index);
   const int next_row = num_taps - 1;
   auto* dst8 = static_cast<uint8_t*>(dst);
@@ -400,14 +403,12 @@
   } while (x < width);
 }
 
-void ConvolveVertical_SSE4_1(const void* const reference,
-                             const ptrdiff_t reference_stride,
-                             const int /*horizontal_filter_index*/,
-                             const int vertical_filter_index,
-                             const int /*horizontal_filter_id*/,
-                             const int vertical_filter_id, const int width,
-                             const int height, void* prediction,
-                             const ptrdiff_t pred_stride) {
+void ConvolveVertical_SSE4_1(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
   const int vertical_taps = GetNumTapsInFilter(filter_index);
   const ptrdiff_t src_stride = reference_stride;
@@ -477,14 +478,12 @@
   }
 }
 
-void ConvolveCompoundCopy_SSE4(const void* const reference,
-                               const ptrdiff_t reference_stride,
-                               const int /*horizontal_filter_index*/,
-                               const int /*vertical_filter_index*/,
-                               const int /*horizontal_filter_id*/,
-                               const int /*vertical_filter_id*/,
-                               const int width, const int height,
-                               void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveCompoundCopy_SSE4(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   const auto* src = static_cast<const uint8_t*>(reference);
   const ptrdiff_t src_stride = reference_stride;
   auto* dest = static_cast<uint16_t*>(prediction);
@@ -539,11 +538,11 @@
 }
 
 void ConvolveCompoundVertical_SSE4_1(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int vertical_filter_index,
-    const int /*horizontal_filter_id*/, const int vertical_filter_id,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t /*pred_stride*/) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int vertical_filter_index, const int /*horizontal_filter_id*/,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t /*pred_stride*/) {
   const int filter_index = GetFilterIndex(vertical_filter_index, height);
   const int vertical_taps = GetNumTapsInFilter(filter_index);
   const ptrdiff_t src_stride = reference_stride;
@@ -608,14 +607,12 @@
   }
 }
 
-void ConvolveHorizontal_SSE4_1(const void* const reference,
-                               const ptrdiff_t reference_stride,
-                               const int horizontal_filter_index,
-                               const int /*vertical_filter_index*/,
-                               const int horizontal_filter_id,
-                               const int /*vertical_filter_id*/,
-                               const int width, const int height,
-                               void* prediction, const ptrdiff_t pred_stride) {
+void ConvolveHorizontal_SSE4_1(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int /*vertical_filter_index*/, const int horizontal_filter_id,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   // Set |src| to the outermost tap.
   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
@@ -626,11 +623,11 @@
 }
 
 void ConvolveCompoundHorizontal_SSE4_1(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int horizontal_filter_index, const int /*vertical_filter_index*/,
-    const int horizontal_filter_id, const int /*vertical_filter_id*/,
-    const int width, const int height, void* prediction,
-    const ptrdiff_t /*pred_stride*/) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int /*vertical_filter_index*/, const int horizontal_filter_id,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t /*pred_stride*/) {
   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
   auto* dest = static_cast<uint16_t*>(prediction);
@@ -640,14 +637,12 @@
       filter_index);
 }
 
-void ConvolveCompound2D_SSE4_1(const void* const reference,
-                               const ptrdiff_t reference_stride,
-                               const int horizontal_filter_index,
-                               const int vertical_filter_index,
-                               const int horizontal_filter_id,
-                               const int vertical_filter_id, const int width,
-                               const int height, void* prediction,
-                               const ptrdiff_t /*pred_stride*/) {
+void ConvolveCompound2D_SSE4_1(
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int horizontal_filter_index,
+    const int vertical_filter_index, const int horizontal_filter_id,
+    const int vertical_filter_id, const int width, const int height,
+    void* LIBGAV1_RESTRICT prediction, const ptrdiff_t /*pred_stride*/) {
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
   alignas(16) uint16_t
@@ -835,7 +830,8 @@
 // exceed 4 when width <= 4, |grade_x| is set to 1 regardless of the value of
 // |step_x|.
 template <int num_taps, int grade_x>
-inline void PrepareSourceVectors(const uint8_t* src, const __m128i src_indices,
+inline void PrepareSourceVectors(const uint8_t* LIBGAV1_RESTRICT src,
+                                 const __m128i src_indices,
                                  __m128i* const source /*[num_taps >> 1]*/) {
   // |used_bytes| is only computed in msan builds. Mask away unused bytes for
   // msan because it incorrectly models the outcome of the shuffles in some
@@ -900,10 +896,11 @@
 }
 
 template <int grade_x, int filter_index, int num_taps>
-inline void ConvolveHorizontalScale(const uint8_t* src, ptrdiff_t src_stride,
-                                    int width, int subpixel_x, int step_x,
+inline void ConvolveHorizontalScale(const uint8_t* LIBGAV1_RESTRICT src,
+                                    ptrdiff_t src_stride, int width,
+                                    int subpixel_x, int step_x,
                                     int intermediate_height,
-                                    int16_t* intermediate) {
+                                    int16_t* LIBGAV1_RESTRICT intermediate) {
   // Account for the 0-taps that precede the 2 nonzero taps.
   const int kernel_offset = (8 - num_taps) >> 1;
   const int ref_x = subpixel_x >> kScaleSubPixelBits;
@@ -946,11 +943,11 @@
   }
 
   // |width| >= 8
+  int16_t* intermediate_x = intermediate;
   int x = 0;
   do {
     const uint8_t* src_x =
         &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
-    int16_t* intermediate_x = intermediate + x;
     // Only add steps to the 10-bit truncated p to avoid overflow.
     const __m128i p_fraction = _mm_set1_epi16(p & 1023);
     const __m128i subpel_indices = _mm_add_epi16(index_steps, p_fraction);
@@ -976,7 +973,8 @@
 }
 
 template <int num_taps>
-inline void PrepareVerticalTaps(const int8_t* taps, __m128i* output) {
+inline void PrepareVerticalTaps(const int8_t* LIBGAV1_RESTRICT taps,
+                                __m128i* output) {
   // Avoid overreading the filter due to starting at kernel_offset.
   // The only danger of overread is in the final filter, which has 4 taps.
   const __m128i filter =
@@ -1072,10 +1070,12 @@
 // |width_class| is 2, 4, or 8, according to the Store function that should be
 // used.
 template <int num_taps, int width_class, bool is_compound>
-inline void ConvolveVerticalScale(const int16_t* src, const int width,
-                                  const int subpixel_y, const int filter_index,
-                                  const int step_y, const int height,
-                                  void* dest, const ptrdiff_t dest_stride) {
+inline void ConvolveVerticalScale(const int16_t* LIBGAV1_RESTRICT src,
+                                  const int intermediate_height,
+                                  const int width, const int subpixel_y,
+                                  const int filter_index, const int step_y,
+                                  const int height, void* LIBGAV1_RESTRICT dest,
+                                  const ptrdiff_t dest_stride) {
   constexpr ptrdiff_t src_stride = kIntermediateStride;
   constexpr int kernel_offset = (8 - num_taps) / 2;
   const int16_t* src_y = src;
@@ -1138,15 +1138,19 @@
 
   // |width_class| >= 8
   __m128i filter_taps[num_taps >> 1];
-  do {  // y > 0
-    src_y = src + (p >> kScaleSubPixelBits) * src_stride;
-    const int filter_id = (p >> 6) & kSubPixelMask;
-    const int8_t* filter =
-        kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
-    PrepareVerticalTaps<num_taps>(filter, filter_taps);
+  int x = 0;
+  do {  // x < width
+    auto* dest_y = static_cast<uint8_t*>(dest) + x;
+    auto* dest16_y = static_cast<uint16_t*>(dest) + x;
+    int p = subpixel_y & 1023;
+    int y = height;
+    do {  // y > 0
+      const int filter_id = (p >> 6) & kSubPixelMask;
+      const int8_t* filter =
+          kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
+      PrepareVerticalTaps<num_taps>(filter, filter_taps);
 
-    int x = 0;
-    do {  // x < width
+      src_y = src + (p >> kScaleSubPixelBits) * src_stride;
       for (int i = 0; i < num_taps; ++i) {
         s[i] = LoadUnaligned16(src_y + i * src_stride);
       }
@@ -1154,38 +1158,36 @@
       const __m128i sums =
           Sum2DVerticalTaps<num_taps, is_compound>(s, filter_taps);
       if (is_compound) {
-        StoreUnaligned16(dest16_y + x, sums);
+        StoreUnaligned16(dest16_y, sums);
       } else {
-        StoreLo8(dest_y + x, _mm_packus_epi16(sums, sums));
+        StoreLo8(dest_y, _mm_packus_epi16(sums, sums));
       }
-      x += 8;
-      src_y += 8;
-    } while (x < width);
-    p += step_y;
-    dest_y += dest_stride;
-    dest16_y += dest_stride;
-  } while (--y != 0);
+      p += step_y;
+      dest_y += dest_stride;
+      dest16_y += dest_stride;
+    } while (--y != 0);
+    src += kIntermediateStride * intermediate_height;
+    x += 8;
+  } while (x < width);
 }
 
 template <bool is_compound>
-void ConvolveScale2D_SSE4_1(const void* const reference,
+void ConvolveScale2D_SSE4_1(const void* LIBGAV1_RESTRICT const reference,
                             const ptrdiff_t reference_stride,
                             const int horizontal_filter_index,
                             const int vertical_filter_index,
                             const int subpixel_x, const int subpixel_y,
                             const int step_x, const int step_y, const int width,
-                            const int height, void* prediction,
+                            const int height, void* LIBGAV1_RESTRICT prediction,
                             const ptrdiff_t pred_stride) {
   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
   assert(step_x <= 2048);
   // The output of the horizontal filter, i.e. the intermediate_result, is
   // guaranteed to fit in int16_t.
-  // TODO(petersonab): Reduce intermediate block stride to width to make smaller
-  // blocks faster.
   alignas(16) int16_t
-      intermediate_result[kMaxSuperBlockSizeInPixels *
-                          (2 * kMaxSuperBlockSizeInPixels + kSubPixelTaps)];
+      intermediate_result[kIntermediateAllocWidth *
+                          (2 * kIntermediateAllocWidth + kSubPixelTaps)];
   const int num_vert_taps = GetNumTapsInFilter(vert_filter_index);
   const int intermediate_height =
       (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
@@ -1282,76 +1284,78 @@
     case 1:
       if (!is_compound && width == 2) {
         ConvolveVerticalScale<6, 2, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       } else if (width == 4) {
         ConvolveVerticalScale<6, 4, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       } else {
         ConvolveVerticalScale<6, 8, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       }
       break;
     case 2:
       if (!is_compound && width == 2) {
         ConvolveVerticalScale<8, 2, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       } else if (width == 4) {
         ConvolveVerticalScale<8, 4, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       } else {
         ConvolveVerticalScale<8, 8, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       }
       break;
     case 3:
       if (!is_compound && width == 2) {
         ConvolveVerticalScale<2, 2, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       } else if (width == 4) {
         ConvolveVerticalScale<2, 4, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       } else {
         ConvolveVerticalScale<2, 8, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       }
       break;
     default:
       assert(vert_filter_index == 4 || vert_filter_index == 5);
       if (!is_compound && width == 2) {
         ConvolveVerticalScale<4, 2, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       } else if (width == 4) {
         ConvolveVerticalScale<4, 4, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       } else {
         ConvolveVerticalScale<4, 8, is_compound>(
-            intermediate, width, subpixel_y, vert_filter_index, step_y, height,
-            prediction, pred_stride);
+            intermediate, intermediate_height, width, subpixel_y,
+            vert_filter_index, step_y, height, prediction, pred_stride);
       }
   }
 }
 
-inline void HalfAddHorizontal(const uint8_t* src, uint8_t* dst) {
+inline void HalfAddHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
+                              uint8_t* LIBGAV1_RESTRICT dst) {
   const __m128i left = LoadUnaligned16(src);
   const __m128i right = LoadUnaligned16(src + 1);
   StoreUnaligned16(dst, _mm_avg_epu8(left, right));
 }
 
 template <int width>
-inline void IntraBlockCopyHorizontal(const uint8_t* src,
+inline void IntraBlockCopyHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
                                      const ptrdiff_t src_stride,
-                                     const int height, uint8_t* dst,
+                                     const int height,
+                                     uint8_t* LIBGAV1_RESTRICT dst,
                                      const ptrdiff_t dst_stride) {
   const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
@@ -1392,10 +1396,11 @@
 }
 
 void ConvolveIntraBlockCopyHorizontal_SSE4_1(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*subpixel_x*/, const int /*subpixel_y*/, const int width,
-    const int height, void* const prediction, const ptrdiff_t pred_stride) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*subpixel_x*/,
+    const int /*subpixel_y*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
   const auto* src = static_cast<const uint8_t*>(reference);
   auto* dest = static_cast<uint8_t*>(prediction);
 
@@ -1464,9 +1469,10 @@
 }
 
 template <int width>
-inline void IntraBlockCopyVertical(const uint8_t* src,
+inline void IntraBlockCopyVertical(const uint8_t* LIBGAV1_RESTRICT src,
                                    const ptrdiff_t src_stride, const int height,
-                                   uint8_t* dst, const ptrdiff_t dst_stride) {
+                                   uint8_t* LIBGAV1_RESTRICT dst,
+                                   const ptrdiff_t dst_stride) {
   const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
   __m128i row[8], below[8];
@@ -1553,11 +1559,11 @@
 }
 
 void ConvolveIntraBlockCopyVertical_SSE4_1(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
-    const int width, const int height, void* const prediction,
-    const ptrdiff_t pred_stride) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
   const auto* src = static_cast<const uint8_t*>(reference);
   auto* dest = static_cast<uint8_t*>(prediction);
 
@@ -1622,7 +1628,8 @@
 }
 
 // Load then add two uint8_t vectors. Return the uint16_t vector result.
-inline __m128i LoadU8AndAddLong(const uint8_t* src, const uint8_t* src1) {
+inline __m128i LoadU8AndAddLong(const uint8_t* LIBGAV1_RESTRICT src,
+                                const uint8_t* LIBGAV1_RESTRICT src1) {
   const __m128i a = _mm_cvtepu8_epi16(LoadLo8(src));
   const __m128i b = _mm_cvtepu8_epi16(LoadLo8(src1));
   return _mm_add_epi16(a, b);
@@ -1637,8 +1644,9 @@
 }
 
 template <int width>
-inline void IntraBlockCopy2D(const uint8_t* src, const ptrdiff_t src_stride,
-                             const int height, uint8_t* dst,
+inline void IntraBlockCopy2D(const uint8_t* LIBGAV1_RESTRICT src,
+                             const ptrdiff_t src_stride, const int height,
+                             uint8_t* LIBGAV1_RESTRICT dst,
                              const ptrdiff_t dst_stride) {
   const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
@@ -1793,11 +1801,11 @@
 }
 
 void ConvolveIntraBlockCopy2D_SSE4_1(
-    const void* const reference, const ptrdiff_t reference_stride,
-    const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
-    const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
-    const int width, const int height, void* const prediction,
-    const ptrdiff_t pred_stride) {
+    const void* LIBGAV1_RESTRICT const reference,
+    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
+    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
+    const int /*vertical_filter_id*/, const int width, const int height,
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
   const auto* src = static_cast<const uint8_t*>(reference);
   auto* dest = static_cast<uint8_t*>(prediction);
   // Note: allow vertical access to height + 1. Because this function is only
diff --git a/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc
index 3c29b19..c813df4 100644
--- a/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc
+++ b/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc
@@ -54,8 +54,10 @@
 
 template <int height>
 inline void DistanceWeightedBlend4xH_SSE4_1(
-    const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
-    const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+    const int16_t* LIBGAV1_RESTRICT pred_0,
+    const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
+    const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
 
@@ -98,8 +100,10 @@
 
 template <int height>
 inline void DistanceWeightedBlend8xH_SSE4_1(
-    const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
-    const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+    const int16_t* LIBGAV1_RESTRICT pred_0,
+    const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
+    const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
 
@@ -125,9 +129,10 @@
 }
 
 inline void DistanceWeightedBlendLarge_SSE4_1(
-    const int16_t* pred_0, const int16_t* pred_1, const uint8_t weight_0,
-    const uint8_t weight_1, const int width, const int height, void* const dest,
-    const ptrdiff_t dest_stride) {
+    const int16_t* LIBGAV1_RESTRICT pred_0,
+    const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, const int width, const int height,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
 
@@ -154,11 +159,12 @@
   } while (--y != 0);
 }
 
-void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
-                                  const void* prediction_1,
+void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
+                                  const void* LIBGAV1_RESTRICT prediction_1,
                                   const uint8_t weight_0,
                                   const uint8_t weight_1, const int width,
-                                  const int height, void* const dest,
+                                  const int height,
+                                  void* LIBGAV1_RESTRICT const dest,
                                   const ptrdiff_t dest_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
@@ -257,8 +263,10 @@
 
 template <int height>
 inline void DistanceWeightedBlend4xH_SSE4_1(
-    const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
-    const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+    const uint16_t* LIBGAV1_RESTRICT pred_0,
+    const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
+    const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint16_t*>(dest);
   const __m128i weight0 = _mm_set1_epi32(weight_0);
   const __m128i weight1 = _mm_set1_epi32(weight_1);
@@ -301,8 +309,10 @@
 
 template <int height>
 inline void DistanceWeightedBlend8xH_SSE4_1(
-    const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
-    const uint8_t weight_1, void* const dest, const ptrdiff_t dest_stride) {
+    const uint16_t* LIBGAV1_RESTRICT pred_0,
+    const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
+    const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint16_t*>(dest);
   const __m128i weight0 = _mm_set1_epi32(weight_0);
   const __m128i weight1 = _mm_set1_epi32(weight_1);
@@ -332,9 +342,10 @@
 }
 
 inline void DistanceWeightedBlendLarge_SSE4_1(
-    const uint16_t* pred_0, const uint16_t* pred_1, const uint8_t weight_0,
-    const uint8_t weight_1, const int width, const int height, void* const dest,
-    const ptrdiff_t dest_stride) {
+    const uint16_t* LIBGAV1_RESTRICT pred_0,
+    const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
+    const uint8_t weight_1, const int width, const int height,
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint16_t*>(dest);
   const __m128i weight0 = _mm_set1_epi32(weight_0);
   const __m128i weight1 = _mm_set1_epi32(weight_1);
@@ -364,11 +375,12 @@
   } while (--y != 0);
 }
 
-void DistanceWeightedBlend_SSE4_1(const void* prediction_0,
-                                  const void* prediction_1,
+void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
+                                  const void* LIBGAV1_RESTRICT prediction_1,
                                   const uint8_t weight_0,
                                   const uint8_t weight_1, const int width,
-                                  const int height, void* const dest,
+                                  const int height,
+                                  void* LIBGAV1_RESTRICT const dest,
                                   const ptrdiff_t dest_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
diff --git a/libgav1/src/dsp/x86/film_grain_sse4.cc b/libgav1/src/dsp/x86/film_grain_sse4.cc
index 745c1ca..9ece947 100644
--- a/libgav1/src/dsp/x86/film_grain_sse4.cc
+++ b/libgav1/src/dsp/x86/film_grain_sse4.cc
@@ -126,30 +126,16 @@
 }
 
 template <int bitdepth, typename Pixel>
-inline __m128i GetScalingFactors(
-    const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* source) {
+inline __m128i GetScalingFactors(const int16_t* scaling_lut,
+                                 const Pixel* source) {
   alignas(16) int16_t start_vals[8];
-  if (bitdepth == 8) {
-    // TODO(petersonab): Speed this up by creating a uint16_t scaling_lut.
-    // Currently this code results in a series of movzbl.
-    for (int i = 0; i < 8; ++i) {
-      start_vals[i] = scaling_lut[source[i]];
-    }
-    return LoadAligned16(start_vals);
-  }
-  alignas(16) int16_t end_vals[8];
-  // TODO(petersonab): Precompute this into a larger table for direct lookups.
+  static_assert(bitdepth <= kBitdepth10,
+                "SSE4 Film Grain is not yet implemented for 12bpp.");
   for (int i = 0; i < 8; ++i) {
-    const int index = source[i] >> 2;
-    start_vals[i] = scaling_lut[index];
-    end_vals[i] = scaling_lut[index + 1];
+    assert(source[i] < kScalingLookupTableSize << (bitdepth - 2));
+    start_vals[i] = scaling_lut[source[i]];
   }
-  const __m128i start = LoadAligned16(start_vals);
-  const __m128i end = LoadAligned16(end_vals);
-  __m128i remainder = LoadSource(source);
-  remainder = _mm_srli_epi16(_mm_slli_epi16(remainder, 14), 1);
-  const __m128i delta = _mm_mulhrs_epi16(_mm_sub_epi16(end, start), remainder);
-  return _mm_add_epi16(start, delta);
+  return LoadAligned16(start_vals);
 }
 
 // |scaling_shift| is in range [8,11].
@@ -162,11 +148,10 @@
 
 template <int bitdepth, typename GrainType, typename Pixel>
 void BlendNoiseWithImageLuma_SSE4_1(
-    const void* noise_image_ptr, int min_value, int max_luma, int scaling_shift,
-    int width, int height, int start_height,
-    const uint8_t scaling_lut_y[kScalingLookupTableSize],
-    const void* source_plane_y, ptrdiff_t source_stride_y, void* dest_plane_y,
-    ptrdiff_t dest_stride_y) {
+    const void* LIBGAV1_RESTRICT noise_image_ptr, int min_value, int max_luma,
+    int scaling_shift, int width, int height, int start_height,
+    const int16_t* scaling_lut_y, const void* source_plane_y,
+    ptrdiff_t source_stride_y, void* dest_plane_y, ptrdiff_t dest_stride_y) {
   const auto* noise_image =
       static_cast<const Array2D<GrainType>*>(noise_image_ptr);
   const auto* in_y_row = static_cast<const Pixel*>(source_plane_y);
@@ -181,7 +166,6 @@
   do {
     int x = 0;
     for (; x < safe_width; x += 8) {
-      // TODO(b/133525232): Make 16-pixel version of loop body.
       const __m128i orig = LoadSource(&in_y_row[x]);
       const __m128i scaling =
           GetScalingFactors<bitdepth, Pixel>(scaling_lut_y, &in_y_row[x]);
@@ -216,9 +200,9 @@
 
 template <int bitdepth, typename GrainType, typename Pixel>
 inline __m128i BlendChromaValsWithCfl(
-    const Pixel* average_luma_buffer,
-    const uint8_t scaling_lut[kScalingLookupTableSize],
-    const Pixel* chroma_cursor, const GrainType* noise_image_cursor,
+    const Pixel* LIBGAV1_RESTRICT average_luma_buffer,
+    const int16_t* scaling_lut, const Pixel* LIBGAV1_RESTRICT chroma_cursor,
+    const GrainType* LIBGAV1_RESTRICT noise_image_cursor,
     const __m128i scaling_shift) {
   const __m128i scaling =
       GetScalingFactors<bitdepth, Pixel>(scaling_lut, average_luma_buffer);
@@ -232,11 +216,10 @@
 LIBGAV1_ALWAYS_INLINE void BlendChromaPlaneWithCfl_SSE4_1(
     const Array2D<GrainType>& noise_image, int min_value, int max_chroma,
     int width, int height, int start_height, int subsampling_x,
-    int subsampling_y, int scaling_shift,
-    const uint8_t scaling_lut[kScalingLookupTableSize], const Pixel* in_y_row,
-    ptrdiff_t source_stride_y, const Pixel* in_chroma_row,
-    ptrdiff_t source_stride_chroma, Pixel* out_chroma_row,
-    ptrdiff_t dest_stride) {
+    int subsampling_y, int scaling_shift, const int16_t* scaling_lut,
+    const Pixel* LIBGAV1_RESTRICT in_y_row, ptrdiff_t source_stride_y,
+    const Pixel* in_chroma_row, ptrdiff_t source_stride_chroma,
+    Pixel* out_chroma_row, ptrdiff_t dest_stride) {
   const __m128i floor = _mm_set1_epi16(min_value);
   const __m128i ceiling = _mm_set1_epi16(max_chroma);
   alignas(16) Pixel luma_buffer[16];
@@ -258,8 +241,6 @@
     int x = 0;
     for (; x < safe_chroma_width; x += 8) {
       const int luma_x = x << subsampling_x;
-      // TODO(petersonab): Consider specializing by subsampling_x. In the 444
-      // case &in_y_row[x] can be passed to GetScalingFactors directly.
       const __m128i average_luma =
           GetAverageLuma(&in_y_row[luma_x], subsampling_x);
       StoreUnsigned(average_luma_buffer, average_luma);
@@ -277,7 +258,7 @@
       // Prevent huge indices from entering GetScalingFactors due to
       // uninitialized values. This is not a problem in 8bpp because the table
       // is made larger than 255 values.
-      if (bitdepth > 8) {
+      if (bitdepth > kBitdepth8) {
         memset(luma_buffer, 0, sizeof(luma_buffer));
       }
       const int luma_x = x << subsampling_x;
@@ -306,11 +287,11 @@
 // This further implies that scaling_lut_u == scaling_lut_v == scaling_lut_y.
 template <int bitdepth, typename GrainType, typename Pixel>
 void BlendNoiseWithImageChromaWithCfl_SSE4_1(
-    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
-    int min_value, int max_chroma, int width, int height, int start_height,
-    int subsampling_x, int subsampling_y,
-    const uint8_t scaling_lut[kScalingLookupTableSize],
-    const void* source_plane_y, ptrdiff_t source_stride_y,
+    Plane plane, const FilmGrainParams& params,
+    const void* LIBGAV1_RESTRICT noise_image_ptr, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, const int16_t* scaling_lut,
+    const void* LIBGAV1_RESTRICT source_plane_y, ptrdiff_t source_stride_y,
     const void* source_plane_uv, ptrdiff_t source_stride_uv,
     void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
   const auto* noise_image =
@@ -335,10 +316,10 @@
 
 // |offset| is 32x4 packed to add with the result of _mm_madd_epi16.
 inline __m128i BlendChromaValsNoCfl8bpp(
-    const uint8_t scaling_lut[kScalingLookupTableSize], const __m128i& orig,
-    const int8_t* noise_image_cursor, const __m128i& average_luma,
-    const __m128i& scaling_shift, const __m128i& offset,
-    const __m128i& weights) {
+    const int16_t* scaling_lut, const __m128i& orig,
+    const int8_t* LIBGAV1_RESTRICT noise_image_cursor,
+    const __m128i& average_luma, const __m128i& scaling_shift,
+    const __m128i& offset, const __m128i& weights) {
   uint8_t merged_buffer[8];
   const __m128i combined_lo =
       _mm_madd_epi16(_mm_unpacklo_epi16(average_luma, orig), weights);
@@ -351,9 +332,9 @@
 
   StoreLo8(merged_buffer, _mm_packus_epi16(merged, merged));
   const __m128i scaling =
-      GetScalingFactors<8, uint8_t>(scaling_lut, merged_buffer);
+      GetScalingFactors<kBitdepth8, uint8_t>(scaling_lut, merged_buffer);
   __m128i noise = LoadSource(noise_image_cursor);
-  noise = ScaleNoise<8>(noise, scaling, scaling_shift);
+  noise = ScaleNoise<kBitdepth8>(noise, scaling, scaling_shift);
   return _mm_add_epi16(orig, noise);
 }
 
@@ -361,11 +342,10 @@
     const Array2D<int8_t>& noise_image, int min_value, int max_chroma,
     int width, int height, int start_height, int subsampling_x,
     int subsampling_y, int scaling_shift, int chroma_offset,
-    int chroma_multiplier, int luma_multiplier,
-    const uint8_t scaling_lut[kScalingLookupTableSize], const uint8_t* in_y_row,
-    ptrdiff_t source_stride_y, const uint8_t* in_chroma_row,
-    ptrdiff_t source_stride_chroma, uint8_t* out_chroma_row,
-    ptrdiff_t dest_stride) {
+    int chroma_multiplier, int luma_multiplier, const int16_t* scaling_lut,
+    const uint8_t* LIBGAV1_RESTRICT in_y_row, ptrdiff_t source_stride_y,
+    const uint8_t* in_chroma_row, ptrdiff_t source_stride_chroma,
+    uint8_t* out_chroma_row, ptrdiff_t dest_stride) {
   const __m128i floor = _mm_set1_epi16(min_value);
   const __m128i ceiling = _mm_set1_epi16(max_chroma);
 
@@ -432,11 +412,11 @@
 
 // This function is for the case params_.chroma_scaling_from_luma == false.
 void BlendNoiseWithImageChroma8bpp_SSE4_1(
-    Plane plane, const FilmGrainParams& params, const void* noise_image_ptr,
-    int min_value, int max_chroma, int width, int height, int start_height,
-    int subsampling_x, int subsampling_y,
-    const uint8_t scaling_lut[kScalingLookupTableSize],
-    const void* source_plane_y, ptrdiff_t source_stride_y,
+    Plane plane, const FilmGrainParams& params,
+    const void* LIBGAV1_RESTRICT noise_image_ptr, int min_value, int max_chroma,
+    int width, int height, int start_height, int subsampling_x,
+    int subsampling_y, const int16_t* scaling_lut,
+    const void* LIBGAV1_RESTRICT source_plane_y, ptrdiff_t source_stride_y,
     const void* source_plane_uv, ptrdiff_t source_stride_uv,
     void* dest_plane_uv, ptrdiff_t dest_stride_uv) {
   assert(plane == kPlaneU || plane == kPlaneV);
@@ -463,10 +443,10 @@
   assert(dsp != nullptr);
 
   dsp->film_grain.blend_noise_luma =
-      BlendNoiseWithImageLuma_SSE4_1<8, int8_t, uint8_t>;
+      BlendNoiseWithImageLuma_SSE4_1<kBitdepth8, int8_t, uint8_t>;
   dsp->film_grain.blend_noise_chroma[0] = BlendNoiseWithImageChroma8bpp_SSE4_1;
   dsp->film_grain.blend_noise_chroma[1] =
-      BlendNoiseWithImageChromaWithCfl_SSE4_1<8, int8_t, uint8_t>;
+      BlendNoiseWithImageChromaWithCfl_SSE4_1<kBitdepth8, int8_t, uint8_t>;
 }
 
 }  // namespace
@@ -481,9 +461,9 @@
   assert(dsp != nullptr);
 
   dsp->film_grain.blend_noise_luma =
-      BlendNoiseWithImageLuma_SSE4_1<10, int16_t, uint16_t>;
+      BlendNoiseWithImageLuma_SSE4_1<kBitdepth10, int16_t, uint16_t>;
   dsp->film_grain.blend_noise_chroma[1] =
-      BlendNoiseWithImageChromaWithCfl_SSE4_1<10, int16_t, uint16_t>;
+      BlendNoiseWithImageChromaWithCfl_SSE4_1<kBitdepth10, int16_t, uint16_t>;
 }
 
 }  // namespace
diff --git a/libgav1/src/dsp/x86/intra_edge_sse4.cc b/libgav1/src/dsp/x86/intra_edge_sse4.cc
index d6af907..967be06 100644
--- a/libgav1/src/dsp/x86/intra_edge_sse4.cc
+++ b/libgav1/src/dsp/x86/intra_edge_sse4.cc
@@ -41,7 +41,8 @@
 // This function applies the kernel [0, 4, 8, 4, 0] to 12 values.
 // Assumes |edge| has 16 packed byte values. Produces 12 filter outputs to
 // write as overlapping sets of 8-bytes.
-inline void ComputeKernel1Store12(uint8_t* dest, const uint8_t* source) {
+inline void ComputeKernel1Store12(uint8_t* LIBGAV1_RESTRICT dest,
+                                  const uint8_t* LIBGAV1_RESTRICT source) {
   const __m128i edge_lo = LoadUnaligned16(source);
   const __m128i edge_hi = _mm_srli_si128(edge_lo, 6);
   // Samples matched with the '4' tap, expanded to 16-bit.
@@ -77,7 +78,8 @@
 // This function applies the kernel [0, 5, 6, 5, 0] to 12 values.
 // Assumes |edge| has 8 packed byte values, and that the 2 invalid values will
 // be overwritten or safely discarded.
-inline void ComputeKernel2Store12(uint8_t* dest, const uint8_t* source) {
+inline void ComputeKernel2Store12(uint8_t* LIBGAV1_RESTRICT dest,
+                                  const uint8_t* LIBGAV1_RESTRICT source) {
   const __m128i edge_lo = LoadUnaligned16(source);
   const __m128i edge_hi = _mm_srli_si128(edge_lo, 6);
   const __m128i outers_lo = _mm_cvtepu8_epi16(edge_lo);
@@ -115,7 +117,8 @@
 }
 
 // This function applies the kernel [2, 4, 4, 4, 2] to 8 values.
-inline void ComputeKernel3Store8(uint8_t* dest, const uint8_t* source) {
+inline void ComputeKernel3Store8(uint8_t* LIBGAV1_RESTRICT dest,
+                                 const uint8_t* LIBGAV1_RESTRICT source) {
   const __m128i edge_lo = LoadUnaligned16(source);
   const __m128i edge_hi = _mm_srli_si128(edge_lo, 4);
   // Finish |edge_lo| life cycle quickly.
diff --git a/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc b/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
index f2dcfdb..eb7e466 100644
--- a/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_cfl_sse4.cc
@@ -88,7 +88,7 @@
 
 template <int width, int height>
 void CflIntraPredictor_SSE4_1(
-    void* const dest, ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   auto* dst = static_cast<uint8_t*>(dest);
@@ -127,7 +127,8 @@
 template <int block_height_log2, bool is_inside>
 void CflSubsampler444_4xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int max_luma_height, const void* const source, ptrdiff_t stride) {
+    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
+    ptrdiff_t stride) {
   static_assert(block_height_log2 <= 4, "");
   const int block_height = 1 << block_height_log2;
   const int visible_height = max_luma_height;
@@ -189,7 +190,7 @@
 void CflSubsampler444_4xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_assert(block_height_log2 <= 4, "");
   assert(max_luma_width >= 4);
   assert(max_luma_height >= 4);
@@ -209,7 +210,7 @@
 void CflSubsampler444_8xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_assert(block_height_log2 <= 5, "");
   const int block_height = 1 << block_height_log2, block_width = 8;
   const int visible_height = max_luma_height;
@@ -292,7 +293,7 @@
 void CflSubsampler444_8xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_assert(block_height_log2 <= 5, "");
   assert(max_luma_width >= 4);
   assert(max_luma_height >= 4);
@@ -315,7 +316,7 @@
 void CflSubsampler444_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_assert(block_width_log2 == 4 || block_width_log2 == 5, "");
   static_assert(block_height_log2 <= 5, "");
   assert(max_luma_width >= 4);
@@ -418,7 +419,7 @@
 void CflSubsampler444_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_assert(block_width_log2 == 4 || block_width_log2 == 5, "");
   static_assert(block_height_log2 <= 5, "");
   assert(max_luma_width >= 4);
@@ -441,7 +442,7 @@
 void CflSubsampler420_4xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int /*max_luma_width*/, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const auto* src = static_cast<const uint8_t*>(source);
   int16_t* luma_ptr = luma[0];
@@ -511,7 +512,7 @@
 inline void CflSubsampler420Impl_8xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int /*max_luma_width*/, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const auto* src = static_cast<const uint8_t*>(source);
   const __m128i zero = _mm_setzero_si128();
@@ -620,7 +621,7 @@
 void CflSubsampler420_8xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   if (max_luma_width == 8) {
     CflSubsampler420Impl_8xH_SSE4_1<block_height_log2, 8>(
         luma, max_luma_width, max_luma_height, source, stride);
@@ -634,7 +635,7 @@
 inline void CflSubsampler420Impl_WxH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int /*max_luma_width*/, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   const auto* src = static_cast<const uint8_t*>(source);
   const __m128i zero = _mm_setzero_si128();
   __m128i final_sum = zero;
@@ -751,7 +752,7 @@
 void CflSubsampler420_WxH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   switch (max_luma_width) {
     case 8:
       CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 8>(
@@ -968,7 +969,7 @@
 
 template <int width, int height>
 void CflIntraPredictor_10bpp_SSE4_1(
-    void* const dest, ptrdiff_t stride,
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
     const int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int alpha) {
   constexpr int kCflLumaBufferStrideLog2_16i = 5;
@@ -1018,7 +1019,8 @@
 template <int block_height_log2, bool is_inside>
 void CflSubsampler444_4xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int max_luma_height, const void* const source, ptrdiff_t stride) {
+    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
+    ptrdiff_t stride) {
   static_assert(block_height_log2 <= 4, "");
   const int block_height = 1 << block_height_log2;
   const int visible_height = max_luma_height;
@@ -1079,7 +1081,7 @@
 void CflSubsampler444_4xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_cast<void>(max_luma_width);
   static_cast<void>(max_luma_height);
   static_assert(block_height_log2 <= 4, "");
@@ -1099,7 +1101,8 @@
 template <int block_height_log2, bool is_inside>
 void CflSubsampler444_8xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int max_luma_height, const void* const source, ptrdiff_t stride) {
+    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
+    ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const int visible_height = max_luma_height;
   const __m128i dup16 = _mm_set1_epi32(0x01000100);
@@ -1158,7 +1161,7 @@
 void CflSubsampler444_8xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_cast<void>(max_luma_width);
   static_cast<void>(max_luma_height);
   static_assert(block_height_log2 <= 5, "");
@@ -1182,7 +1185,7 @@
 void CflSubsampler444_WxH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const int visible_height = max_luma_height;
   const int block_width = 1 << block_width_log2;
@@ -1278,7 +1281,7 @@
 void CflSubsampler444_WxH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   static_assert(block_width_log2 == 4 || block_width_log2 == 5,
                 "This function will only work for block_width 16 and 32.");
   static_assert(block_height_log2 <= 5, "");
@@ -1300,7 +1303,7 @@
 void CflSubsampler420_4xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int /*max_luma_width*/, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const auto* src = static_cast<const uint16_t*>(source);
   const ptrdiff_t src_stride = stride / sizeof(src[0]);
@@ -1371,7 +1374,8 @@
 template <int block_height_log2, int max_luma_width>
 inline void CflSubsampler420Impl_8xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int max_luma_height, const void* const source, ptrdiff_t stride) {
+    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
+    ptrdiff_t stride) {
   const int block_height = 1 << block_height_log2;
   const auto* src = static_cast<const uint16_t*>(source);
   const ptrdiff_t src_stride = stride / sizeof(src[0]);
@@ -1483,7 +1487,7 @@
 void CflSubsampler420_8xH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   if (max_luma_width == 8) {
     CflSubsampler420Impl_8xH_SSE4_1<block_height_log2, 8>(luma, max_luma_height,
                                                           source, stride);
@@ -1496,7 +1500,8 @@
 template <int block_width_log2, int block_height_log2, int max_luma_width>
 inline void CflSubsampler420Impl_WxH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
-    const int max_luma_height, const void* const source, ptrdiff_t stride) {
+    const int max_luma_height, const void* LIBGAV1_RESTRICT const source,
+    ptrdiff_t stride) {
   const auto* src = static_cast<const uint16_t*>(source);
   const ptrdiff_t src_stride = stride / sizeof(src[0]);
   const __m128i zero = _mm_setzero_si128();
@@ -1615,7 +1620,7 @@
 void CflSubsampler420_WxH_SSE4_1(
     int16_t luma[kCflLumaBufferStride][kCflLumaBufferStride],
     const int max_luma_width, const int max_luma_height,
-    const void* const source, ptrdiff_t stride) {
+    const void* LIBGAV1_RESTRICT const source, ptrdiff_t stride) {
   switch (max_luma_width) {
     case 8:
       CflSubsampler420Impl_WxH_SSE4_1<block_width_log2, block_height_log2, 8>(
diff --git a/libgav1/src/dsp/x86/intrapred_filter_sse4.cc b/libgav1/src/dsp/x86/intrapred_filter_sse4.cc
index 022af8d..a43a5cf 100644
--- a/libgav1/src/dsp/x86/intrapred_filter_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_filter_sse4.cc
@@ -64,10 +64,10 @@
 // at zero to preserve the sum.
 // |pixels| contains p0-p7 in order as shown above.
 // |taps_0_1| contains the filter kernels used to predict f0 and f1, and so on.
-inline void Filter4x2_SSE4_1(uint8_t* dst, const ptrdiff_t stride,
-                             const __m128i& pixels, const __m128i& taps_0_1,
-                             const __m128i& taps_2_3, const __m128i& taps_4_5,
-                             const __m128i& taps_6_7) {
+inline void Filter4x2_SSE4_1(uint8_t* LIBGAV1_RESTRICT dst,
+                             const ptrdiff_t stride, const __m128i& pixels,
+                             const __m128i& taps_0_1, const __m128i& taps_2_3,
+                             const __m128i& taps_4_5, const __m128i& taps_6_7) {
   const __m128i mul_0_01 = _mm_maddubs_epi16(pixels, taps_0_1);
   const __m128i mul_0_23 = _mm_maddubs_epi16(pixels, taps_2_3);
   // |output_half| contains 8 partial sums for f0-f7.
@@ -93,10 +93,10 @@
 // for successive blocks. This implementation takes advantage of the fact
 // that the p5 and p6 for each sub-block come solely from the |left_ptr| buffer,
 // using shifts to arrange things to fit reusable shuffle vectors.
-inline void Filter4xH(uint8_t* dest, ptrdiff_t stride,
-                      const uint8_t* const top_ptr,
-                      const uint8_t* const left_ptr, FilterIntraPredictor pred,
-                      const int height) {
+inline void Filter4xH(uint8_t* LIBGAV1_RESTRICT dest, ptrdiff_t stride,
+                      const uint8_t* LIBGAV1_RESTRICT const top_ptr,
+                      const uint8_t* LIBGAV1_RESTRICT const left_ptr,
+                      FilterIntraPredictor pred, const int height) {
   // Two filter kernels per vector.
   const __m128i taps_0_1 = LoadAligned16(kFilterIntraTaps[pred][0]);
   const __m128i taps_2_3 = LoadAligned16(kFilterIntraTaps[pred][2]);
@@ -271,9 +271,10 @@
   }
 }
 
-void FilterIntraPredictor_SSE4_1(void* const dest, ptrdiff_t stride,
-                                 const void* const top_row,
-                                 const void* const left_column,
+void FilterIntraPredictor_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                                 ptrdiff_t stride,
+                                 const void* LIBGAV1_RESTRICT const top_row,
+                                 const void* LIBGAV1_RESTRICT const left_column,
                                  FilterIntraPredictor pred, const int width,
                                  const int height) {
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
diff --git a/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc b/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc
index de9f551..b53ee8c 100644
--- a/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_smooth_sse4.cc
@@ -38,23 +38,12 @@
 // to have visibility of the values. This helps reduce loads and in the
 // creation of the inverse weights.
 constexpr uint8_t kSmoothWeights[] = {
-    // block dimension = 4
-    255, 149, 85, 64,
-    // block dimension = 8
-    255, 197, 146, 105, 73, 50, 37, 32,
-    // block dimension = 16
-    255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16,
-    // block dimension = 32
-    255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74,
-    66, 59, 52, 45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8,
-    // block dimension = 64
-    255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156,
-    150, 144, 138, 133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73,
-    69, 65, 61, 57, 54, 50, 47, 44, 41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16,
-    15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4};
+#include "src/dsp/smooth_weights.inc"
+};
 
 template <int y_mask>
-inline void WriteSmoothHorizontalSum4(void* const dest, const __m128i& left,
+inline void WriteSmoothHorizontalSum4(void* LIBGAV1_RESTRICT const dest,
+                                      const __m128i& left,
                                       const __m128i& weights,
                                       const __m128i& scaled_top_right,
                                       const __m128i& round) {
@@ -77,7 +66,8 @@
   return _mm_add_epi16(scaled_corner, weighted_px);
 }
 
-inline void WriteSmoothDirectionalSum8(uint8_t* dest, const __m128i& pixels,
+inline void WriteSmoothDirectionalSum8(uint8_t* LIBGAV1_RESTRICT dest,
+                                       const __m128i& pixels,
                                        const __m128i& weights,
                                        const __m128i& scaled_corner,
                                        const __m128i& round) {
@@ -91,13 +81,11 @@
 // For Horizontal, pixels1 and pixels2 are the same repeated value. For
 // Vertical, weights1 and weights2 are the same, and scaled_corner1 and
 // scaled_corner2 are the same.
-inline void WriteSmoothDirectionalSum16(uint8_t* dest, const __m128i& pixels1,
-                                        const __m128i& pixels2,
-                                        const __m128i& weights1,
-                                        const __m128i& weights2,
-                                        const __m128i& scaled_corner1,
-                                        const __m128i& scaled_corner2,
-                                        const __m128i& round) {
+inline void WriteSmoothDirectionalSum16(
+    uint8_t* LIBGAV1_RESTRICT dest, const __m128i& pixels1,
+    const __m128i& pixels2, const __m128i& weights1, const __m128i& weights2,
+    const __m128i& scaled_corner1, const __m128i& scaled_corner2,
+    const __m128i& round) {
   const __m128i weighted_px1 = _mm_mullo_epi16(pixels1, weights1);
   const __m128i weighted_px2 = _mm_mullo_epi16(pixels2, weights2);
   const __m128i pred_sum1 = _mm_add_epi16(scaled_corner1, weighted_px1);
@@ -109,8 +97,9 @@
 }
 
 template <int y_mask>
-inline void WriteSmoothPredSum4(uint8_t* const dest, const __m128i& top,
-                                const __m128i& left, const __m128i& weights_x,
+inline void WriteSmoothPredSum4(uint8_t* LIBGAV1_RESTRICT const dest,
+                                const __m128i& top, const __m128i& left,
+                                const __m128i& weights_x,
                                 const __m128i& weights_y,
                                 const __m128i& scaled_bottom_left,
                                 const __m128i& scaled_top_right,
@@ -135,7 +124,8 @@
 // pixels[0]: above and below_pred interleave vector
 // pixels[1]: left vector
 // pixels[2]: right_pred vector
-inline void LoadSmoothPixels4(const uint8_t* above, const uint8_t* left,
+inline void LoadSmoothPixels4(const uint8_t* LIBGAV1_RESTRICT above,
+                              const uint8_t* LIBGAV1_RESTRICT left,
                               const int height, __m128i* pixels) {
   if (height == 4) {
     pixels[1] = Load4(left);
@@ -156,8 +146,9 @@
 // weight_h[2]: same as [0], second half for height = 16 only
 // weight_h[3]: same as [1], second half for height = 16 only
 // weight_w[0]: weights_w and scale - weights_w interleave vector
-inline void LoadSmoothWeights4(const uint8_t* weight_array, const int height,
-                               __m128i* weight_h, __m128i* weight_w) {
+inline void LoadSmoothWeights4(const uint8_t* LIBGAV1_RESTRICT weight_array,
+                               const int height, __m128i* weight_h,
+                               __m128i* weight_w) {
   const __m128i scale = _mm_set1_epi16(256);
   const __m128i x_weights = Load4(weight_array);
   weight_h[0] = _mm_cvtepu8_epi16(x_weights);
@@ -179,7 +170,8 @@
 }
 
 inline void WriteSmoothPred4x8(const __m128i* pixel, const __m128i* weights_y,
-                               const __m128i* weight_x, uint8_t* dst,
+                               const __m128i* weight_x,
+                               uint8_t* LIBGAV1_RESTRICT dst,
                                const ptrdiff_t stride,
                                const bool use_second_half) {
   const __m128i round = _mm_set1_epi32(256);
@@ -215,8 +207,9 @@
 
 // The interleaving approach has some overhead that causes it to underperform in
 // the 4x4 case.
-void Smooth4x4_SSE4_1(void* const dest, const ptrdiff_t stride,
-                      const void* top_row, const void* left_column) {
+void Smooth4x4_SSE4_1(void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT top_row,
+                      const void* LIBGAV1_RESTRICT left_column) {
   const __m128i top = _mm_cvtepu8_epi32(Load4(top_row));
   const __m128i left = _mm_cvtepu8_epi32(Load4(left_column));
   const __m128i weights = _mm_cvtepu8_epi32(Load4(kSmoothWeights));
@@ -247,8 +240,9 @@
                             scaled_bottom_left, scaled_top_right, scale);
 }
 
-void Smooth4x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                      const void* top_row, const void* left_column) {
+void Smooth4x8_SSE4_1(void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT top_row,
+                      const void* LIBGAV1_RESTRICT left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   __m128i weights_x[1];
@@ -260,8 +254,10 @@
   WriteSmoothPred4x8(pixels, weights_y, weights_x, dst, stride, false);
 }
 
-void Smooth4x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                       const void* top_row, const void* left_column) {
+void Smooth4x16_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                       const ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT top_row,
+                       const void* LIBGAV1_RESTRICT left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   __m128i weights_x[1];
@@ -283,7 +279,8 @@
 // pixels[5]: above and below_pred interleave vector, second half
 // pixels[6]: left vector + 16
 // pixels[7]: right_pred vector
-inline void LoadSmoothPixels8(const uint8_t* above, const uint8_t* left,
+inline void LoadSmoothPixels8(const uint8_t* LIBGAV1_RESTRICT above,
+                              const uint8_t* LIBGAV1_RESTRICT left,
                               const int height, __m128i* pixels) {
   const __m128i bottom_left = _mm_set1_epi16(left[height - 1]);
   __m128i top_row = _mm_cvtepu8_epi16(LoadLo8(above));
@@ -317,8 +314,9 @@
 // weight_h[7]: same as [1], offset 24
 // weight_w[0]: weights_w and scale - weights_w interleave vector, first half
 // weight_w[1]: weights_w and scale - weights_w interleave vector, second half
-inline void LoadSmoothWeights8(const uint8_t* weight_array, const int height,
-                               __m128i* weight_w, __m128i* weight_h) {
+inline void LoadSmoothWeights8(const uint8_t* LIBGAV1_RESTRICT weight_array,
+                               const int height, __m128i* weight_w,
+                               __m128i* weight_h) {
   const int offset = (height < 8) ? 0 : 4;
   __m128i loaded_weights = LoadUnaligned16(&weight_array[offset]);
   weight_h[0] = _mm_cvtepu8_epi16(loaded_weights);
@@ -360,7 +358,8 @@
 
 inline void WriteSmoothPred8xH(const __m128i* pixels, const __m128i* weights_x,
                                const __m128i* weights_y, const int height,
-                               uint8_t* dst, const ptrdiff_t stride,
+                               uint8_t* LIBGAV1_RESTRICT dst,
+                               const ptrdiff_t stride,
                                const bool use_second_half) {
   const __m128i round = _mm_set1_epi32(256);
   const __m128i mask_increment = _mm_set1_epi16(0x0202);
@@ -405,8 +404,9 @@
   }
 }
 
-void Smooth8x4_SSE4_1(void* const dest, const ptrdiff_t stride,
-                      const void* top_row, const void* left_column) {
+void Smooth8x4_SSE4_1(void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT top_row,
+                      const void* LIBGAV1_RESTRICT left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   __m128i pixels[4];
@@ -419,8 +419,9 @@
   WriteSmoothPred8xH(pixels, weights_x, weights_y, 4, dst, stride, false);
 }
 
-void Smooth8x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                      const void* top_row, const void* left_column) {
+void Smooth8x8_SSE4_1(void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT top_row,
+                      const void* LIBGAV1_RESTRICT left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
 
@@ -434,8 +435,10 @@
   WriteSmoothPred8xH(pixels, weights_x, weights_y, 8, dst, stride, false);
 }
 
-void Smooth8x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                       const void* top_row, const void* left_column) {
+void Smooth8x16_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                       const ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT top_row,
+                       const void* LIBGAV1_RESTRICT left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   __m128i pixels[4];
@@ -450,8 +453,10 @@
   WriteSmoothPred8xH(pixels, weights_x, &weights_y[2], 8, dst, stride, true);
 }
 
-void Smooth8x32_SSE4_1(void* const dest, const ptrdiff_t stride,
-                       const void* top_row, const void* left_column) {
+void Smooth8x32_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                       const ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT top_row,
+                       const void* LIBGAV1_RESTRICT left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   __m128i pixels[8];
@@ -473,8 +478,9 @@
 }
 
 template <int width, int height>
-void SmoothWxH(void* const dest, const ptrdiff_t stride,
-               const void* const top_row, const void* const left_column) {
+void SmoothWxH(void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+               const void* LIBGAV1_RESTRICT const top_row,
+               const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   const uint8_t* const sm_weights_h = kSmoothWeights + height - 4;
@@ -532,8 +538,10 @@
   }
 }
 
-void SmoothHorizontal4x4_SSE4_1(void* dest, const ptrdiff_t stride,
-                                const void* top_row, const void* left_column) {
+void SmoothHorizontal4x4_SSE4_1(void* LIBGAV1_RESTRICT dest,
+                                const ptrdiff_t stride,
+                                const void* LIBGAV1_RESTRICT top_row,
+                                const void* LIBGAV1_RESTRICT left_column) {
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi32(top_ptr[3]);
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
@@ -553,9 +561,10 @@
   WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale);
 }
 
-void SmoothHorizontal4x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothHorizontal4x8_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi32(top[3]);
   const __m128i weights = _mm_cvtepu8_epi32(Load4(kSmoothWeights));
@@ -585,9 +594,10 @@
   WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale);
 }
 
-void SmoothHorizontal4x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                 const void* const top_row,
-                                 const void* const left_column) {
+void SmoothHorizontal4x16_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi32(top[3]);
   const __m128i weights = _mm_cvtepu8_epi32(Load4(kSmoothWeights));
@@ -637,9 +647,10 @@
   WriteSmoothHorizontalSum4<0xFF>(dst, left, weights, scaled_top_right, scale);
 }
 
-void SmoothHorizontal8x4_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothHorizontal8x4_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[7]);
   const __m128i left = _mm_cvtepu8_epi16(Load4(left_column));
@@ -666,9 +677,10 @@
   WriteSmoothDirectionalSum8(dst, left_y, weights, scaled_top_right, scale);
 }
 
-void SmoothHorizontal8x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothHorizontal8x8_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[7]);
   const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column));
@@ -686,9 +698,10 @@
   }
 }
 
-void SmoothHorizontal8x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                 const void* const top_row,
-                                 const void* const left_column) {
+void SmoothHorizontal8x16_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[7]);
   const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4));
@@ -714,9 +727,10 @@
   }
 }
 
-void SmoothHorizontal8x32_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                 const void* const top_row,
-                                 const void* const left_column) {
+void SmoothHorizontal8x32_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[7]);
   const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4));
@@ -756,9 +770,10 @@
   }
 }
 
-void SmoothHorizontal16x4_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                 const void* const top_row,
-                                 const void* const left_column) {
+void SmoothHorizontal16x4_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[15]);
   const __m128i left = _mm_cvtepu8_epi16(Load4(left_column));
@@ -795,9 +810,10 @@
                               scaled_top_right1, scaled_top_right2, scale);
 }
 
-void SmoothHorizontal16x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                 const void* const top_row,
-                                 const void* const left_column) {
+void SmoothHorizontal16x8_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[15]);
   const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column));
@@ -822,9 +838,10 @@
   }
 }
 
-void SmoothHorizontal16x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                  const void* const top_row,
-                                  const void* const left_column) {
+void SmoothHorizontal16x16_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[15]);
   const __m128i weights = LoadUnaligned16(kSmoothWeights + 12);
@@ -858,9 +875,10 @@
   }
 }
 
-void SmoothHorizontal16x32_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                  const void* const top_row,
-                                  const void* const left_column) {
+void SmoothHorizontal16x32_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[15]);
   const __m128i weights = LoadUnaligned16(kSmoothWeights + 12);
@@ -910,9 +928,10 @@
   }
 }
 
-void SmoothHorizontal16x64_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                  const void* const top_row,
-                                  const void* const left_column) {
+void SmoothHorizontal16x64_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[15]);
   const __m128i weights = LoadUnaligned16(kSmoothWeights + 12);
@@ -940,9 +959,10 @@
   }
 }
 
-void SmoothHorizontal32x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                 const void* const top_row,
-                                 const void* const left_column) {
+void SmoothHorizontal32x8_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[31]);
   const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column));
@@ -978,9 +998,10 @@
   }
 }
 
-void SmoothHorizontal32x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                  const void* const top_row,
-                                  const void* const left_column) {
+void SmoothHorizontal32x16_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[31]);
   const __m128i left1 = _mm_cvtepu8_epi16(LoadLo8(left_column));
@@ -1027,9 +1048,10 @@
   }
 }
 
-void SmoothHorizontal32x32_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                  const void* const top_row,
-                                  const void* const left_column) {
+void SmoothHorizontal32x32_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[31]);
   const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28);
@@ -1096,9 +1118,10 @@
   }
 }
 
-void SmoothHorizontal32x64_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                  const void* const top_row,
-                                  const void* const left_column) {
+void SmoothHorizontal32x64_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[31]);
   const __m128i weights_lo = LoadUnaligned16(kSmoothWeights + 28);
@@ -1137,9 +1160,10 @@
   }
 }
 
-void SmoothHorizontal64x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                  const void* const top_row,
-                                  const void* const left_column) {
+void SmoothHorizontal64x16_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[63]);
   const __m128i left1 = _mm_cvtepu8_epi16(LoadLo8(left_column));
@@ -1212,9 +1236,10 @@
   }
 }
 
-void SmoothHorizontal64x32_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                  const void* const top_row,
-                                  const void* const left_column) {
+void SmoothHorizontal64x32_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[63]);
   const __m128i left1 = _mm_cvtepu8_epi16(LoadLo8(left_column));
@@ -1315,9 +1340,10 @@
   }
 }
 
-void SmoothHorizontal64x64_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                  const void* const top_row,
-                                  const void* const left_column) {
+void SmoothHorizontal64x64_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const top = static_cast<const uint8_t*>(top_row);
   const __m128i top_right = _mm_set1_epi16(top[63]);
   const __m128i weights_lolo = LoadUnaligned16(kSmoothWeights + 60);
@@ -1378,7 +1404,8 @@
   }
 }
 
-inline void LoadSmoothVerticalPixels4(const uint8_t* above, const uint8_t* left,
+inline void LoadSmoothVerticalPixels4(const uint8_t* LIBGAV1_RESTRICT above,
+                                      const uint8_t* LIBGAV1_RESTRICT left,
                                       const int height, __m128i* pixels) {
   __m128i top = Load4(above);
   const __m128i bottom_left = _mm_set1_epi16(left[height - 1]);
@@ -1390,7 +1417,8 @@
 // (256-w) counterparts. This is precomputed by the compiler when the weights
 // table is visible to this module. Removing this visibility can cut speed by up
 // to half in both 4xH and 8xH transforms.
-inline void LoadSmoothVerticalWeights4(const uint8_t* weight_array,
+inline void LoadSmoothVerticalWeights4(const uint8_t* LIBGAV1_RESTRICT
+                                           weight_array,
                                        const int height, __m128i* weights) {
   const __m128i inverter = _mm_set1_epi16(256);
 
@@ -1413,7 +1441,8 @@
 }
 
 inline void WriteSmoothVertical4xH(const __m128i* pixel, const __m128i* weight,
-                                   const int height, uint8_t* dst,
+                                   const int height,
+                                   uint8_t* LIBGAV1_RESTRICT dst,
                                    const ptrdiff_t stride) {
   const __m128i pred_round = _mm_set1_epi32(128);
   const __m128i mask_increment = _mm_set1_epi16(0x0202);
@@ -1438,9 +1467,10 @@
   }
 }
 
-void SmoothVertical4x4_SSE4_1(void* const dest, const ptrdiff_t stride,
-                              const void* const top_row,
-                              const void* const left_column) {
+void SmoothVertical4x4_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              const ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const top_row,
+                              const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left = static_cast<const uint8_t*>(left_column);
   const auto* const above = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1453,9 +1483,10 @@
   WriteSmoothVertical4xH(&pixels, weights, 4, dst, stride);
 }
 
-void SmoothVertical4x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                              const void* const top_row,
-                              const void* const left_column) {
+void SmoothVertical4x8_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              const ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const top_row,
+                              const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left = static_cast<const uint8_t*>(left_column);
   const auto* const above = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1468,9 +1499,10 @@
   WriteSmoothVertical4xH(&pixels, weights, 8, dst, stride);
 }
 
-void SmoothVertical4x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                               const void* const top_row,
-                               const void* const left_column) {
+void SmoothVertical4x16_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                               const ptrdiff_t stride,
+                               const void* LIBGAV1_RESTRICT const top_row,
+                               const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left = static_cast<const uint8_t*>(left_column);
   const auto* const above = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1485,9 +1517,10 @@
   WriteSmoothVertical4xH(&pixels, &weights[2], 8, dst, stride);
 }
 
-void SmoothVertical8x4_SSE4_1(void* const dest, const ptrdiff_t stride,
-                              const void* const top_row,
-                              const void* const left_column) {
+void SmoothVertical8x4_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              const ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const top_row,
+                              const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const __m128i bottom_left = _mm_set1_epi16(left_ptr[3]);
   const __m128i weights = _mm_cvtepu8_epi16(Load4(kSmoothWeights));
@@ -1520,9 +1553,10 @@
   WriteSmoothDirectionalSum8(dst, top, weights_y, scaled_bottom_left_y, scale);
 }
 
-void SmoothVertical8x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                              const void* const top_row,
-                              const void* const left_column) {
+void SmoothVertical8x8_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              const ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const top_row,
+                              const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const __m128i bottom_left = _mm_set1_epi16(left_ptr[7]);
   const __m128i weights = _mm_cvtepu8_epi16(LoadLo8(kSmoothWeights + 4));
@@ -1544,9 +1578,10 @@
   }
 }
 
-void SmoothVertical8x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                               const void* const top_row,
-                               const void* const left_column) {
+void SmoothVertical8x16_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                               const ptrdiff_t stride,
+                               const void* LIBGAV1_RESTRICT const top_row,
+                               const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const __m128i bottom_left = _mm_set1_epi16(left_ptr[15]);
   const __m128i weights = LoadUnaligned16(kSmoothWeights + 12);
@@ -1583,9 +1618,10 @@
   }
 }
 
-void SmoothVertical8x32_SSE4_1(void* const dest, const ptrdiff_t stride,
-                               const void* const top_row,
-                               const void* const left_column) {
+void SmoothVertical8x32_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                               const ptrdiff_t stride,
+                               const void* LIBGAV1_RESTRICT const top_row,
+                               const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const __m128i zero = _mm_setzero_si128();
   const __m128i bottom_left = _mm_set1_epi16(left_ptr[31]);
@@ -1649,9 +1685,10 @@
   }
 }
 
-void SmoothVertical16x4_SSE4_1(void* const dest, const ptrdiff_t stride,
-                               const void* const top_row,
-                               const void* const left_column) {
+void SmoothVertical16x4_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                               const ptrdiff_t stride,
+                               const void* LIBGAV1_RESTRICT const top_row,
+                               const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const __m128i bottom_left = _mm_set1_epi16(left_ptr[3]);
@@ -1694,9 +1731,10 @@
                               scale);
 }
 
-void SmoothVertical16x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                               const void* const top_row,
-                               const void* const left_column) {
+void SmoothVertical16x8_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                               const ptrdiff_t stride,
+                               const void* LIBGAV1_RESTRICT const top_row,
+                               const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const __m128i bottom_left = _mm_set1_epi16(left_ptr[7]);
@@ -1722,9 +1760,10 @@
   }
 }
 
-void SmoothVertical16x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothVertical16x16_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const __m128i bottom_left = _mm_set1_epi16(left_ptr[15]);
@@ -1766,9 +1805,10 @@
   }
 }
 
-void SmoothVertical16x32_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothVertical16x32_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const __m128i bottom_left = _mm_set1_epi16(left_ptr[31]);
@@ -1839,9 +1879,10 @@
   }
 }
 
-void SmoothVertical16x64_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothVertical16x64_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const __m128i bottom_left = _mm_set1_epi16(left_ptr[63]);
@@ -1887,9 +1928,10 @@
   }
 }
 
-void SmoothVertical32x8_SSE4_1(void* const dest, const ptrdiff_t stride,
-                               const void* const top_row,
-                               const void* const left_column) {
+void SmoothVertical32x8_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                               const ptrdiff_t stride,
+                               const void* LIBGAV1_RESTRICT const top_row,
+                               const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1922,9 +1964,10 @@
   }
 }
 
-void SmoothVertical32x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothVertical32x16_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   auto* dst = static_cast<uint8_t*>(dest);
@@ -1975,9 +2018,10 @@
   }
 }
 
-void SmoothVertical32x32_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothVertical32x32_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -2063,9 +2107,10 @@
   }
 }
 
-void SmoothVertical32x64_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothVertical32x64_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -2120,9 +2165,10 @@
   }
 }
 
-void SmoothVertical64x16_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothVertical64x16_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -2192,9 +2238,10 @@
   }
 }
 
-void SmoothVertical64x32_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothVertical64x32_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -2311,9 +2358,10 @@
   }
 }
 
-void SmoothVertical64x64_SSE4_1(void* const dest, const ptrdiff_t stride,
-                                const void* const top_row,
-                                const void* const left_column) {
+void SmoothVertical64x64_SSE4_1(
+    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_row,
+    const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* dst = static_cast<uint8_t*>(dest);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
diff --git a/libgav1/src/dsp/x86/intrapred_sse4.cc b/libgav1/src/dsp/x86/intrapred_sse4.cc
index 063929d..556afed 100644
--- a/libgav1/src/dsp/x86/intrapred_sse4.cc
+++ b/libgav1/src/dsp/x86/intrapred_sse4.cc
@@ -90,11 +90,11 @@
 
 template <int width_log2, int height_log2, DcSumFunc top_sumfn,
           DcSumFunc left_sumfn, DcStoreFunc storefn, int shiftk, int dc_mult>
-void DcPredFuncs_SSE4_1<width_log2, height_log2, top_sumfn, left_sumfn, storefn,
-                        shiftk, dc_mult>::DcTop(void* const dest,
-                                                ptrdiff_t stride,
-                                                const void* const top_row,
-                                                const void* /*left_column*/) {
+void DcPredFuncs_SSE4_1<
+    width_log2, height_log2, top_sumfn, left_sumfn, storefn, shiftk,
+    dc_mult>::DcTop(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                    const void* LIBGAV1_RESTRICT const top_row,
+                    const void* /*left_column*/) {
   const __m128i rounder = _mm_set1_epi32(1 << (width_log2 - 1));
   const __m128i sum = top_sumfn(top_row);
   const __m128i dc = _mm_srli_epi32(_mm_add_epi32(sum, rounder), width_log2);
@@ -103,11 +103,11 @@
 
 template <int width_log2, int height_log2, DcSumFunc top_sumfn,
           DcSumFunc left_sumfn, DcStoreFunc storefn, int shiftk, int dc_mult>
-void DcPredFuncs_SSE4_1<width_log2, height_log2, top_sumfn, left_sumfn, storefn,
-                        shiftk,
-                        dc_mult>::DcLeft(void* const dest, ptrdiff_t stride,
-                                         const void* /*top_row*/,
-                                         const void* const left_column) {
+void DcPredFuncs_SSE4_1<
+    width_log2, height_log2, top_sumfn, left_sumfn, storefn, shiftk,
+    dc_mult>::DcLeft(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                     const void* /*top_row*/,
+                     const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i rounder = _mm_set1_epi32(1 << (height_log2 - 1));
   const __m128i sum = left_sumfn(left_column);
   const __m128i dc = _mm_srli_epi32(_mm_add_epi32(sum, rounder), height_log2);
@@ -116,10 +116,11 @@
 
 template <int width_log2, int height_log2, DcSumFunc top_sumfn,
           DcSumFunc left_sumfn, DcStoreFunc storefn, int shiftk, int dc_mult>
-void DcPredFuncs_SSE4_1<width_log2, height_log2, top_sumfn, left_sumfn, storefn,
-                        shiftk, dc_mult>::Dc(void* const dest, ptrdiff_t stride,
-                                             const void* const top_row,
-                                             const void* const left_column) {
+void DcPredFuncs_SSE4_1<
+    width_log2, height_log2, top_sumfn, left_sumfn, storefn, shiftk,
+    dc_mult>::Dc(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                 const void* LIBGAV1_RESTRICT const top_row,
+                 const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i rounder =
       _mm_set1_epi32((1 << (width_log2 - 1)) + (1 << (height_log2 - 1)));
   const __m128i sum_top = top_sumfn(top_row);
@@ -141,8 +142,8 @@
 
 template <ColumnStoreFunc col_storefn>
 void DirectionalPredFuncs_SSE4_1<col_storefn>::Horizontal(
-    void* const dest, ptrdiff_t stride, const void* /*top_row*/,
-    const void* const left_column) {
+    void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+    const void* /*top_row*/, const void* LIBGAV1_RESTRICT const left_column) {
   col_storefn(dest, stride, left_column);
 }
 
@@ -384,8 +385,9 @@
 // ColStoreN<height> copies each of the |height| values in |column| across its
 // corresponding in dest.
 template <WriteDuplicateFunc writefn>
-inline void ColStore4_SSE4_1(void* const dest, ptrdiff_t stride,
-                             const void* const column) {
+inline void ColStore4_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                             ptrdiff_t stride,
+                             const void* LIBGAV1_RESTRICT const column) {
   const __m128i col_data = Load4(column);
   const __m128i col_dup16 = _mm_unpacklo_epi8(col_data, col_data);
   const __m128i col_dup32 = _mm_unpacklo_epi16(col_dup16, col_dup16);
@@ -393,8 +395,9 @@
 }
 
 template <WriteDuplicateFunc writefn>
-inline void ColStore8_SSE4_1(void* const dest, ptrdiff_t stride,
-                             const void* const column) {
+inline void ColStore8_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                             ptrdiff_t stride,
+                             const void* LIBGAV1_RESTRICT const column) {
   const ptrdiff_t stride4 = stride << 2;
   const __m128i col_data = LoadLo8(column);
   const __m128i col_dup16 = _mm_unpacklo_epi8(col_data, col_data);
@@ -407,8 +410,9 @@
 }
 
 template <WriteDuplicateFunc writefn>
-inline void ColStore16_SSE4_1(void* const dest, ptrdiff_t stride,
-                              const void* const column) {
+inline void ColStore16_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const column) {
   const ptrdiff_t stride4 = stride << 2;
   const __m128i col_data = _mm_loadu_si128(static_cast<const __m128i*>(column));
   const __m128i col_dup16_lo = _mm_unpacklo_epi8(col_data, col_data);
@@ -428,8 +432,9 @@
 }
 
 template <WriteDuplicateFunc writefn>
-inline void ColStore32_SSE4_1(void* const dest, ptrdiff_t stride,
-                              const void* const column) {
+inline void ColStore32_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const column) {
   const ptrdiff_t stride4 = stride << 2;
   auto* dst = static_cast<uint8_t*>(dest);
   for (int y = 0; y < 32; y += 16) {
@@ -457,8 +462,9 @@
 }
 
 template <WriteDuplicateFunc writefn>
-inline void ColStore64_SSE4_1(void* const dest, ptrdiff_t stride,
-                              const void* const column) {
+inline void ColStore64_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const column) {
   const ptrdiff_t stride4 = stride << 2;
   auto* dst = static_cast<uint8_t*>(dest);
   for (int y = 0; y < 64; y += 16) {
@@ -574,7 +580,7 @@
 };
 
 template <int y_mask>
-inline void WritePaethLine4(uint8_t* dst, const __m128i& top,
+inline void WritePaethLine4(uint8_t* LIBGAV1_RESTRICT dst, const __m128i& top,
                             const __m128i& left, const __m128i& top_lefts,
                             const __m128i& top_dists, const __m128i& left_dists,
                             const __m128i& top_left_diffs) {
@@ -614,7 +620,7 @@
 // could pay off to accommodate top_left_dists for cmpgt, and repack into epi8
 // for the blends.
 template <int y_mask>
-inline void WritePaethLine8(uint8_t* dst, const __m128i& top,
+inline void WritePaethLine8(uint8_t* LIBGAV1_RESTRICT dst, const __m128i& top,
                             const __m128i& left, const __m128i& top_lefts,
                             const __m128i& top_dists, const __m128i& left_dists,
                             const __m128i& top_left_diffs) {
@@ -658,7 +664,7 @@
 // |left_dists| is provided alongside its spread out version because it doesn't
 // change between calls and interacts with both kinds of packing.
 template <int y_mask>
-inline void WritePaethLine16(uint8_t* dst, const __m128i& top,
+inline void WritePaethLine16(uint8_t* LIBGAV1_RESTRICT dst, const __m128i& top,
                              const __m128i& left, const __m128i& top_lefts,
                              const __m128i& top_dists,
                              const __m128i& left_dists,
@@ -712,8 +718,9 @@
   _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), pred);
 }
 
-void Paeth4x4_SSE4_1(void* const dest, ptrdiff_t stride,
-                     const void* const top_row, const void* const left_column) {
+void Paeth4x4_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                     const void* LIBGAV1_RESTRICT const top_row,
+                     const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = _mm_cvtepu8_epi32(Load4(left_column));
   const __m128i top = _mm_cvtepu8_epi32(Load4(top_row));
 
@@ -742,8 +749,9 @@
                         top_left_diff);
 }
 
-void Paeth4x8_SSE4_1(void* const dest, ptrdiff_t stride,
-                     const void* const top_row, const void* const left_column) {
+void Paeth4x8_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                     const void* LIBGAV1_RESTRICT const top_row,
+                     const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = LoadLo8(left_column);
   const __m128i left_lo = _mm_cvtepu8_epi32(left);
   const __m128i left_hi = _mm_cvtepu8_epi32(_mm_srli_si128(left, 4));
@@ -787,9 +795,9 @@
                         top_left_diff);
 }
 
-void Paeth4x16_SSE4_1(void* const dest, ptrdiff_t stride,
-                      const void* const top_row,
-                      const void* const left_column) {
+void Paeth4x16_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT const top_row,
+                      const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = LoadUnaligned16(left_column);
   const __m128i left_0 = _mm_cvtepu8_epi32(left);
   const __m128i left_1 = _mm_cvtepu8_epi32(_mm_srli_si128(left, 4));
@@ -862,8 +870,9 @@
                         top_left_diff);
 }
 
-void Paeth8x4_SSE4_1(void* const dest, ptrdiff_t stride,
-                     const void* const top_row, const void* const left_column) {
+void Paeth8x4_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                     const void* LIBGAV1_RESTRICT const top_row,
+                     const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = _mm_cvtepu8_epi16(Load4(left_column));
   const __m128i top = _mm_cvtepu8_epi16(LoadLo8(top_row));
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -891,8 +900,9 @@
                               top_left_diff);
 }
 
-void Paeth8x8_SSE4_1(void* const dest, ptrdiff_t stride,
-                     const void* const top_row, const void* const left_column) {
+void Paeth8x8_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                     const void* LIBGAV1_RESTRICT const top_row,
+                     const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = _mm_cvtepu8_epi16(LoadLo8(left_column));
   const __m128i top = _mm_cvtepu8_epi16(LoadLo8(top_row));
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -932,9 +942,9 @@
                               top_left_diff);
 }
 
-void Paeth8x16_SSE4_1(void* const dest, ptrdiff_t stride,
-                      const void* const top_row,
-                      const void* const left_column) {
+void Paeth8x16_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT const top_row,
+                      const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = LoadUnaligned16(left_column);
   const __m128i left_lo = _mm_cvtepu8_epi16(left);
   const __m128i left_hi = _mm_cvtepu8_epi16(_mm_srli_si128(left, 8));
@@ -1001,18 +1011,18 @@
                               left_dists, top_left_diff);
 }
 
-void Paeth8x32_SSE4_1(void* const dest, ptrdiff_t stride,
-                      const void* const top_row,
-                      const void* const left_column) {
+void Paeth8x32_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT const top_row,
+                      const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   auto* const dst = static_cast<uint8_t*>(dest);
   Paeth8x16_SSE4_1(dst, stride, top_row, left_column);
   Paeth8x16_SSE4_1(dst + (stride << 4), stride, top_row, left_ptr + 16);
 }
 
-void Paeth16x4_SSE4_1(void* const dest, ptrdiff_t stride,
-                      const void* const top_row,
-                      const void* const left_column) {
+void Paeth16x4_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT const top_row,
+                      const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = Load4(left_column);
   const __m128i top = LoadUnaligned16(top_row);
   const __m128i top_lo = _mm_cvtepu8_epi16(top);
@@ -1057,7 +1067,7 @@
 
 // Inlined for calling with offsets in larger transform sizes, mainly to
 // preserve top_left.
-inline void WritePaeth16x8(void* const dest, ptrdiff_t stride,
+inline void WritePaeth16x8(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
                            const uint8_t top_left, const __m128i top,
                            const __m128i left) {
   const __m128i top_lo = _mm_cvtepu8_epi16(top);
@@ -1115,9 +1125,9 @@
                                top_left_diff_lo, top_left_diff_hi);
 }
 
-void Paeth16x8_SSE4_1(void* const dest, ptrdiff_t stride,
-                      const void* const top_row,
-                      const void* const left_column) {
+void Paeth16x8_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT const top_row,
+                      const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i top = LoadUnaligned16(top_row);
   const __m128i left = LoadLo8(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -1213,18 +1223,18 @@
                                top_left_diff_lo, top_left_diff_hi);
 }
 
-void Paeth16x16_SSE4_1(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
-                       const void* const left_column) {
+void Paeth16x16_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
+                       const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = LoadUnaligned16(left_column);
   const __m128i top = LoadUnaligned16(top_row);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   WritePaeth16x16(static_cast<uint8_t*>(dest), stride, top_ptr[-1], top, left);
 }
 
-void Paeth16x32_SSE4_1(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
-                       const void* const left_column) {
+void Paeth16x32_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
+                       const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left_0 = LoadUnaligned16(left_column);
   const __m128i top = LoadUnaligned16(top_row);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -1236,9 +1246,9 @@
   WritePaeth16x16(dst + (stride << 4), stride, top_left, top, left_1);
 }
 
-void Paeth16x64_SSE4_1(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
-                       const void* const left_column) {
+void Paeth16x64_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
+                       const void* LIBGAV1_RESTRICT const left_column) {
   const ptrdiff_t stride16 = stride << 4;
   const __m128i left_0 = LoadUnaligned16(left_column);
   const __m128i top = LoadUnaligned16(top_row);
@@ -1258,9 +1268,9 @@
   WritePaeth16x16(dst, stride, top_left, top, left_3);
 }
 
-void Paeth32x8_SSE4_1(void* const dest, ptrdiff_t stride,
-                      const void* const top_row,
-                      const void* const left_column) {
+void Paeth32x8_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                      const void* LIBGAV1_RESTRICT const top_row,
+                      const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = LoadLo8(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   const __m128i top_0 = LoadUnaligned16(top_row);
@@ -1271,9 +1281,9 @@
   WritePaeth16x8(dst + 16, stride, top_left, top_1, left);
 }
 
-void Paeth32x16_SSE4_1(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
-                       const void* const left_column) {
+void Paeth32x16_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
+                       const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = LoadUnaligned16(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   const __m128i top_0 = LoadUnaligned16(top_row);
@@ -1284,9 +1294,9 @@
   WritePaeth16x16(dst + 16, stride, top_left, top_1, left);
 }
 
-void Paeth32x32_SSE4_1(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
-                       const void* const left_column) {
+void Paeth32x32_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
+                       const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const __m128i left_0 = LoadUnaligned16(left_ptr);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -1302,9 +1312,9 @@
   WritePaeth16x16(dst + 16, stride, top_left, top_1, left_1);
 }
 
-void Paeth32x64_SSE4_1(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
-                       const void* const left_column) {
+void Paeth32x64_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
+                       const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const __m128i left_0 = LoadUnaligned16(left_ptr);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
@@ -1328,9 +1338,9 @@
   WritePaeth16x16(dst + 16, stride, top_left, top_1, left_3);
 }
 
-void Paeth64x16_SSE4_1(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
-                       const void* const left_column) {
+void Paeth64x16_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
+                       const void* LIBGAV1_RESTRICT const left_column) {
   const __m128i left = LoadUnaligned16(left_column);
   const auto* const top_ptr = static_cast<const uint8_t*>(top_row);
   const __m128i top_0 = LoadUnaligned16(top_ptr);
@@ -1345,9 +1355,9 @@
   WritePaeth16x16(dst + 48, stride, top_left, top_3, left);
 }
 
-void Paeth64x32_SSE4_1(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
-                       const void* const left_column) {
+void Paeth64x32_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
+                       const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const __m128i left_0 = LoadUnaligned16(left_ptr);
   const __m128i left_1 = LoadUnaligned16(left_ptr + 16);
@@ -1369,9 +1379,9 @@
   WritePaeth16x16(dst + 48, stride, top_left, top_3, left_1);
 }
 
-void Paeth64x64_SSE4_1(void* const dest, ptrdiff_t stride,
-                       const void* const top_row,
-                       const void* const left_column) {
+void Paeth64x64_SSE4_1(void* LIBGAV1_RESTRICT const dest, ptrdiff_t stride,
+                       const void* LIBGAV1_RESTRICT const top_row,
+                       const void* LIBGAV1_RESTRICT const left_column) {
   const auto* const left_ptr = static_cast<const uint8_t*>(left_column);
   const __m128i left_0 = LoadUnaligned16(left_ptr);
   const __m128i left_1 = LoadUnaligned16(left_ptr + 16);
@@ -1793,7 +1803,6 @@
       DirDefs::_64x64::Horizontal;
 #endif
 }  // NOLINT(readability/fn_size)
-// TODO(petersonab): Split Init8bpp function into family-specific files.
 
 }  // namespace
 }  // namespace low_bitdepth
@@ -1937,16 +1946,18 @@
 // ColStoreN<height> copies each of the |height| values in |column| across its
 // corresponding row in dest.
 template <WriteDuplicateFunc writefn>
-inline void ColStore4_SSE4_1(void* const dest, ptrdiff_t stride,
-                             const void* const column) {
+inline void ColStore4_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                             ptrdiff_t stride,
+                             const void* LIBGAV1_RESTRICT const column) {
   const __m128i col_data = LoadLo8(column);
   const __m128i col_dup32 = _mm_unpacklo_epi16(col_data, col_data);
   writefn(dest, stride, col_dup32);
 }
 
 template <WriteDuplicateFunc writefn>
-inline void ColStore8_SSE4_1(void* const dest, ptrdiff_t stride,
-                             const void* const column) {
+inline void ColStore8_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                             ptrdiff_t stride,
+                             const void* LIBGAV1_RESTRICT const column) {
   const __m128i col_data = LoadUnaligned16(column);
   const __m128i col_dup32_lo = _mm_unpacklo_epi16(col_data, col_data);
   const __m128i col_dup32_hi = _mm_unpackhi_epi16(col_data, col_data);
@@ -1958,8 +1969,9 @@
 }
 
 template <WriteDuplicateFunc writefn>
-inline void ColStore16_SSE4_1(void* const dest, ptrdiff_t stride,
-                              const void* const column) {
+inline void ColStore16_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const column) {
   const ptrdiff_t stride4 = stride << 2;
   auto* dst = static_cast<uint8_t*>(dest);
   for (int y = 0; y < 32; y += 16) {
@@ -1975,8 +1987,9 @@
 }
 
 template <WriteDuplicateFunc writefn>
-inline void ColStore32_SSE4_1(void* const dest, ptrdiff_t stride,
-                              const void* const column) {
+inline void ColStore32_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const column) {
   const ptrdiff_t stride4 = stride << 2;
   auto* dst = static_cast<uint8_t*>(dest);
   for (int y = 0; y < 64; y += 16) {
@@ -1992,8 +2005,9 @@
 }
 
 template <WriteDuplicateFunc writefn>
-inline void ColStore64_SSE4_1(void* const dest, ptrdiff_t stride,
-                              const void* const column) {
+inline void ColStore64_SSE4_1(void* LIBGAV1_RESTRICT const dest,
+                              ptrdiff_t stride,
+                              const void* LIBGAV1_RESTRICT const column) {
   const ptrdiff_t stride4 = stride << 2;
   auto* dst = static_cast<uint8_t*>(dest);
   for (int y = 0; y < 128; y += 16) {
diff --git a/libgav1/src/dsp/x86/inverse_transform_sse4.cc b/libgav1/src/dsp/x86/inverse_transform_sse4.cc
index 12c008f..e9ceb87 100644
--- a/libgav1/src/dsp/x86/inverse_transform_sse4.cc
+++ b/libgav1/src/dsp/x86/inverse_transform_sse4.cc
@@ -41,7 +41,8 @@
 #include "src/dsp/inverse_transform.inc"
 
 template <int store_width, int store_count>
-LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* dst, int32_t stride, int32_t idx,
+LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* LIBGAV1_RESTRICT dst,
+                                    int32_t stride, int32_t idx,
                                     const __m128i* s) {
   // NOTE: It is expected that the compiler will unroll these loops.
   if (store_width == 16) {
@@ -63,8 +64,8 @@
 }
 
 template <int load_width, int load_count>
-LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* src, int32_t stride,
-                                   int32_t idx, __m128i* x) {
+LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* LIBGAV1_RESTRICT src,
+                                   int32_t stride, int32_t idx, __m128i* x) {
   // NOTE: It is expected that the compiler will unroll these loops.
   if (load_width == 16) {
     for (int i = 0; i < load_count; i += 4) {
@@ -1638,9 +1639,10 @@
 
 LIBGAV1_ALWAYS_INLINE void Identity4ColumnStoreToFrame(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
+    const int tx_width, const int tx_height,
+    const int16_t* LIBGAV1_RESTRICT source) {
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
 
   const __m128i v_multiplier_fraction =
       _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 3));
@@ -1685,9 +1687,10 @@
 
 LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
+    const int tx_width, const int tx_height,
+    const int16_t* LIBGAV1_RESTRICT source) {
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
 
   const __m128i v_multiplier_fraction =
       _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 3));
@@ -1789,9 +1792,10 @@
 
 LIBGAV1_ALWAYS_INLINE void Identity8ColumnStoreToFrame_SSE4_1(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
+    const int tx_width, const int tx_height,
+    const int16_t* LIBGAV1_RESTRICT source) {
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
   const __m128i v_eight = _mm_set1_epi16(8);
   if (tx_width == 4) {
     int i = 0;
@@ -1883,9 +1887,10 @@
 
 LIBGAV1_ALWAYS_INLINE void Identity16ColumnStoreToFrame_SSE4_1(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
+    const int tx_width, const int tx_height,
+    const int16_t* LIBGAV1_RESTRICT source) {
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
   const __m128i v_eight = _mm_set1_epi16(8);
   const __m128i v_multiplier =
       _mm_set1_epi16(static_cast<int16_t>(kIdentity4MultiplierFraction << 4));
@@ -1966,9 +1971,10 @@
 
 LIBGAV1_ALWAYS_INLINE void Identity32ColumnStoreToFrame(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source) {
+    const int tx_width, const int tx_height,
+    const int16_t* LIBGAV1_RESTRICT source) {
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
   const __m128i v_two = _mm_set1_epi16(2);
 
   int i = 0;
@@ -1995,7 +2001,7 @@
 // Process 4 wht4 rows and columns.
 LIBGAV1_ALWAYS_INLINE void Wht4_SSE4_1(Array2DView<uint8_t> frame,
                                        const int start_x, const int start_y,
-                                       const void* source,
+                                       const void* LIBGAV1_RESTRICT source,
                                        const int adjusted_tx_height) {
   const auto* const src = static_cast<const int16_t*>(source);
   __m128i s[4], x[4];
@@ -2058,12 +2064,11 @@
 
   // Store to frame.
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
   for (int row = 0; row < 4; ++row) {
     const __m128i frame_data = Load4(dst);
     const __m128i a = _mm_cvtepu8_epi16(frame_data);
-    // Saturate to prevent overflowing int16_t
-    const __m128i b = _mm_adds_epi16(a, s[row]);
+    const __m128i b = _mm_add_epi16(a, s[row]);
     Store4(dst, _mm_packus_epi16(b, b));
     dst += stride;
   }
@@ -2075,13 +2080,13 @@
 template <bool enable_flip_rows = false>
 LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
     Array2DView<uint8_t> frame, const int start_x, const int start_y,
-    const int tx_width, const int tx_height, const int16_t* source,
-    TransformType tx_type) {
+    const int tx_width, const int tx_height,
+    const int16_t* LIBGAV1_RESTRICT source, TransformType tx_type) {
   const bool flip_rows =
       enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
   const __m128i v_eight = _mm_set1_epi16(8);
   const int stride = frame.columns();
-  uint8_t* dst = frame[start_y] + start_x;
+  uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
   if (tx_width == 4) {
     for (int i = 0; i < tx_height; ++i) {
       const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4;
@@ -2262,8 +2267,10 @@
 
 void Dct4TransformLoopColumn_SSE4_1(TransformType tx_type,
                                     TransformSize tx_size,
-                                    int adjusted_tx_height, void* src_buffer,
-                                    int start_x, int start_y, void* dst_frame) {
+                                    int adjusted_tx_height,
+                                    void* LIBGAV1_RESTRICT src_buffer,
+                                    int start_x, int start_y,
+                                    void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2325,8 +2332,10 @@
 
 void Dct8TransformLoopColumn_SSE4_1(TransformType tx_type,
                                     TransformSize tx_size,
-                                    int adjusted_tx_height, void* src_buffer,
-                                    int start_x, int start_y, void* dst_frame) {
+                                    int adjusted_tx_height,
+                                    void* LIBGAV1_RESTRICT src_buffer,
+                                    int start_x, int start_y,
+                                    void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2386,9 +2395,10 @@
 
 void Dct16TransformLoopColumn_SSE4_1(TransformType tx_type,
                                      TransformSize tx_size,
-                                     int adjusted_tx_height, void* src_buffer,
+                                     int adjusted_tx_height,
+                                     void* LIBGAV1_RESTRICT src_buffer,
                                      int start_x, int start_y,
-                                     void* dst_frame) {
+                                     void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2441,9 +2451,10 @@
 
 void Dct32TransformLoopColumn_SSE4_1(TransformType tx_type,
                                      TransformSize tx_size,
-                                     int adjusted_tx_height, void* src_buffer,
+                                     int adjusted_tx_height,
+                                     void* LIBGAV1_RESTRICT src_buffer,
                                      int start_x, int start_y,
-                                     void* dst_frame) {
+                                     void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2486,9 +2497,10 @@
 
 void Dct64TransformLoopColumn_SSE4_1(TransformType tx_type,
                                      TransformSize tx_size,
-                                     int adjusted_tx_height, void* src_buffer,
+                                     int adjusted_tx_height,
+                                     void* LIBGAV1_RESTRICT src_buffer,
                                      int start_x, int start_y,
-                                     void* dst_frame) {
+                                     void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2535,9 +2547,10 @@
 
 void Adst4TransformLoopColumn_SSE4_1(TransformType tx_type,
                                      TransformSize tx_size,
-                                     int adjusted_tx_height, void* src_buffer,
+                                     int adjusted_tx_height,
+                                     void* LIBGAV1_RESTRICT src_buffer,
                                      int start_x, int start_y,
-                                     void* dst_frame) {
+                                     void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2594,9 +2607,10 @@
 
 void Adst8TransformLoopColumn_SSE4_1(TransformType tx_type,
                                      TransformSize tx_size,
-                                     int adjusted_tx_height, void* src_buffer,
+                                     int adjusted_tx_height,
+                                     void* LIBGAV1_RESTRICT src_buffer,
                                      int start_x, int start_y,
-                                     void* dst_frame) {
+                                     void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2658,9 +2672,10 @@
 
 void Adst16TransformLoopColumn_SSE4_1(TransformType tx_type,
                                       TransformSize tx_size,
-                                      int adjusted_tx_height, void* src_buffer,
+                                      int adjusted_tx_height,
+                                      void* LIBGAV1_RESTRICT src_buffer,
                                       int start_x, int start_y,
-                                      void* dst_frame) {
+                                      void* LIBGAV1_RESTRICT dst_frame) {
   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2727,8 +2742,9 @@
 void Identity4TransformLoopColumn_SSE4_1(TransformType tx_type,
                                          TransformSize tx_size,
                                          int adjusted_tx_height,
-                                         void* src_buffer, int start_x,
-                                         int start_y, void* dst_frame) {
+                                         void* LIBGAV1_RESTRICT src_buffer,
+                                         int start_x, int start_y,
+                                         void* LIBGAV1_RESTRICT dst_frame) {
   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2799,8 +2815,9 @@
 void Identity8TransformLoopColumn_SSE4_1(TransformType tx_type,
                                          TransformSize tx_size,
                                          int adjusted_tx_height,
-                                         void* src_buffer, int start_x,
-                                         int start_y, void* dst_frame) {
+                                         void* LIBGAV1_RESTRICT src_buffer,
+                                         int start_x, int start_y,
+                                         void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2839,8 +2856,9 @@
 void Identity16TransformLoopColumn_SSE4_1(TransformType tx_type,
                                           TransformSize tx_size,
                                           int adjusted_tx_height,
-                                          void* src_buffer, int start_x,
-                                          int start_y, void* dst_frame) {
+                                          void* LIBGAV1_RESTRICT src_buffer,
+                                          int start_x, int start_y,
+                                          void* LIBGAV1_RESTRICT dst_frame) {
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
 
@@ -2884,8 +2902,9 @@
 void Identity32TransformLoopColumn_SSE4_1(TransformType /*tx_type*/,
                                           TransformSize tx_size,
                                           int adjusted_tx_height,
-                                          void* src_buffer, int start_x,
-                                          int start_y, void* dst_frame) {
+                                          void* LIBGAV1_RESTRICT src_buffer,
+                                          int start_x, int start_y,
+                                          void* LIBGAV1_RESTRICT dst_frame) {
   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
   auto* src = static_cast<int16_t*>(src_buffer);
   const int tx_width = kTransformWidth[tx_size];
@@ -2907,8 +2926,10 @@
 
 void Wht4TransformLoopColumn_SSE4_1(TransformType tx_type,
                                     TransformSize tx_size,
-                                    int adjusted_tx_height, void* src_buffer,
-                                    int start_x, int start_y, void* dst_frame) {
+                                    int adjusted_tx_height,
+                                    void* LIBGAV1_RESTRICT src_buffer,
+                                    int start_x, int start_y,
+                                    void* LIBGAV1_RESTRICT dst_frame) {
   assert(tx_type == kTransformTypeDctDct);
   assert(tx_size == kTransformSize4x4);
   static_cast<void>(tx_type);
@@ -2928,88 +2949,88 @@
   assert(dsp != nullptr);
 
   // Maximum transform size for Dct is 64.
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformDct)
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize4_Transform1dDct)
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
       Dct4TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
       Dct4TransformLoopColumn_SSE4_1;
 #endif
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize8_1DTransformDct)
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize8_Transform1dDct)
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
       Dct8TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize8][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
       Dct8TransformLoopColumn_SSE4_1;
 #endif
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize16_1DTransformDct)
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize16_Transform1dDct)
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
       Dct16TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize16][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
       Dct16TransformLoopColumn_SSE4_1;
 #endif
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize32_1DTransformDct)
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize32_Transform1dDct)
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
       Dct32TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize32][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
       Dct32TransformLoopColumn_SSE4_1;
 #endif
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize64_1DTransformDct)
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize64_Transform1dDct)
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
       Dct64TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformDct][k1DTransformSize64][kColumn] =
+  dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
       Dct64TransformLoopColumn_SSE4_1;
 #endif
 
   // Maximum transform size for Adst is 16.
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformAdst)
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize4_Transform1dAdst)
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
       Adst4TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
       Adst4TransformLoopColumn_SSE4_1;
 #endif
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize8_1DTransformAdst)
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize8_Transform1dAdst)
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
       Adst8TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize8][kColumn] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
       Adst8TransformLoopColumn_SSE4_1;
 #endif
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize16_1DTransformAdst)
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize16_Transform1dAdst)
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
       Adst16TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformAdst][k1DTransformSize16][kColumn] =
+  dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
       Adst16TransformLoopColumn_SSE4_1;
 #endif
 
   // Maximum transform size for Identity transform is 32.
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformIdentity)
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize4_Transform1dIdentity)
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
       Identity4TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
       Identity4TransformLoopColumn_SSE4_1;
 #endif
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize8_1DTransformIdentity)
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize8_Transform1dIdentity)
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
       Identity8TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize8][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
       Identity8TransformLoopColumn_SSE4_1;
 #endif
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize16_1DTransformIdentity)
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize16_Transform1dIdentity)
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
       Identity16TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize16][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
       Identity16TransformLoopColumn_SSE4_1;
 #endif
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize32_1DTransformIdentity)
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize32_Transform1dIdentity)
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
       Identity32TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformIdentity][k1DTransformSize32][kColumn] =
+  dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
       Identity32TransformLoopColumn_SSE4_1;
 #endif
 
   // Maximum transform size for Wht is 4.
-#if DSP_ENABLED_8BPP_SSE4_1(1DTransformSize4_1DTransformWht)
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kRow] =
+#if DSP_ENABLED_8BPP_SSE4_1(Transform1dSize4_Transform1dWht)
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
       Wht4TransformLoopRow_SSE4_1;
-  dsp->inverse_transforms[k1DTransformWht][k1DTransformSize4][kColumn] =
+  dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
       Wht4TransformLoopColumn_SSE4_1;
 #endif
 }
diff --git a/libgav1/src/dsp/x86/inverse_transform_sse4.h b/libgav1/src/dsp/x86/inverse_transform_sse4.h
index 106084b..c31e88b 100644
--- a/libgav1/src/dsp/x86/inverse_transform_sse4.h
+++ b/libgav1/src/dsp/x86/inverse_transform_sse4.h
@@ -34,56 +34,56 @@
 // optimization being enabled, signal the sse4 implementation should be used.
 #if LIBGAV1_TARGETING_SSE4_1
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformDct LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dDct
+#define LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dDct LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformDct LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dDct
+#define LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dDct LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformDct LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dDct
+#define LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dDct LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformDct LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize32_Transform1dDct
+#define LIBGAV1_Dsp8bpp_Transform1dSize32_Transform1dDct LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct
-#define LIBGAV1_Dsp8bpp_1DTransformSize64_1DTransformDct LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize64_Transform1dDct
+#define LIBGAV1_Dsp8bpp_Transform1dSize64_Transform1dDct LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformAdst LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dAdst
+#define LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dAdst LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformAdst LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dAdst
+#define LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dAdst LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformAdst LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dAdst
+#define LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dAdst LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformIdentity LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dIdentity
+#define LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dIdentity LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity
-#define LIBGAV1_Dsp8bpp_1DTransformSize8_1DTransformIdentity LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dIdentity
+#define LIBGAV1_Dsp8bpp_Transform1dSize8_Transform1dIdentity LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity
-#define LIBGAV1_Dsp8bpp_1DTransformSize16_1DTransformIdentity LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dIdentity
+#define LIBGAV1_Dsp8bpp_Transform1dSize16_Transform1dIdentity LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity
-#define LIBGAV1_Dsp8bpp_1DTransformSize32_1DTransformIdentity LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize32_Transform1dIdentity
+#define LIBGAV1_Dsp8bpp_Transform1dSize32_Transform1dIdentity LIBGAV1_CPU_SSE4_1
 #endif
 
-#ifndef LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht
-#define LIBGAV1_Dsp8bpp_1DTransformSize4_1DTransformWht LIBGAV1_CPU_SSE4_1
+#ifndef LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dWht
+#define LIBGAV1_Dsp8bpp_Transform1dSize4_Transform1dWht LIBGAV1_CPU_SSE4_1
 #endif
 #endif  // LIBGAV1_TARGETING_SSE4_1
 #endif  // LIBGAV1_SRC_DSP_X86_INVERSE_TRANSFORM_SSE4_H_
diff --git a/libgav1/src/dsp/x86/loop_restoration_10bit_avx2.cc b/libgav1/src/dsp/x86/loop_restoration_10bit_avx2.cc
index b38f322..daf5c42 100644
--- a/libgav1/src/dsp/x86/loop_restoration_10bit_avx2.cc
+++ b/libgav1/src/dsp/x86/loop_restoration_10bit_avx2.cc
@@ -472,11 +472,14 @@
 }
 
 void WienerFilter_AVX2(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int16_t* const number_leading_zero_coefficients =
       restoration_info.wiener_info.number_leading_zero_coefficients;
   const int number_rows_to_skip = std::max(
@@ -3097,11 +3100,14 @@
 // in the end of each row. It is safe to overwrite the output as it will not be
 // part of the visible frame.
 void SelfGuidedFilter_AVX2(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int index = restoration_info.sgr_proj_info.index;
   const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
   const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
diff --git a/libgav1/src/dsp/x86/loop_restoration_10bit_sse4.cc b/libgav1/src/dsp/x86/loop_restoration_10bit_sse4.cc
index 96380e3..6625d51 100644
--- a/libgav1/src/dsp/x86/loop_restoration_10bit_sse4.cc
+++ b/libgav1/src/dsp/x86/loop_restoration_10bit_sse4.cc
@@ -429,11 +429,14 @@
 }
 
 void WienerFilter_SSE4_1(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int16_t* const number_leading_zero_coefficients =
       restoration_info.wiener_info.number_leading_zero_coefficients;
   const int number_rows_to_skip = std::max(
@@ -2465,11 +2468,14 @@
 // in the end of each row. It is safe to overwrite the output as it will not be
 // part of the visible frame.
 void SelfGuidedFilter_SSE4_1(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int index = restoration_info.sgr_proj_info.index;
   const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
   const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
diff --git a/libgav1/src/dsp/x86/loop_restoration_avx2.cc b/libgav1/src/dsp/x86/loop_restoration_avx2.cc
index 351a324..30e8a22 100644
--- a/libgav1/src/dsp/x86/loop_restoration_avx2.cc
+++ b/libgav1/src/dsp/x86/loop_restoration_avx2.cc
@@ -483,11 +483,14 @@
 }
 
 void WienerFilter_AVX2(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int16_t* const number_leading_zero_coefficients =
       restoration_info.wiener_info.number_leading_zero_coefficients;
   const int number_rows_to_skip = std::max(
@@ -2880,11 +2883,14 @@
 // in the end of each row. It is safe to overwrite the output as it will not be
 // part of the visible frame.
 void SelfGuidedFilter_AVX2(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int index = restoration_info.sgr_proj_info.index;
   const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
   const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
diff --git a/libgav1/src/dsp/x86/loop_restoration_sse4.cc b/libgav1/src/dsp/x86/loop_restoration_sse4.cc
index 273bcc8..3363f0e 100644
--- a/libgav1/src/dsp/x86/loop_restoration_sse4.cc
+++ b/libgav1/src/dsp/x86/loop_restoration_sse4.cc
@@ -482,11 +482,14 @@
 }
 
 void WienerFilter_SSE4_1(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int16_t* const number_leading_zero_coefficients =
       restoration_info.wiener_info.number_leading_zero_coefficients;
   const int number_rows_to_skip = std::max(
@@ -2510,11 +2513,14 @@
 // in the end of each row. It is safe to overwrite the output as it will not be
 // part of the visible frame.
 void SelfGuidedFilter_SSE4_1(
-    const RestorationUnitInfo& restoration_info, const void* const source,
-    const ptrdiff_t stride, const void* const top_border,
-    const ptrdiff_t top_border_stride, const void* const bottom_border,
+    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
+    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
+    const void* LIBGAV1_RESTRICT const top_border,
+    const ptrdiff_t top_border_stride,
+    const void* LIBGAV1_RESTRICT const bottom_border,
     const ptrdiff_t bottom_border_stride, const int width, const int height,
-    RestorationBuffer* const restoration_buffer, void* const dest) {
+    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
+    void* LIBGAV1_RESTRICT const dest) {
   const int index = restoration_info.sgr_proj_info.index;
   const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
   const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
diff --git a/libgav1/src/dsp/x86/mask_blend_sse4.cc b/libgav1/src/dsp/x86/mask_blend_sse4.cc
index 2e836af..a18444b 100644
--- a/libgav1/src/dsp/x86/mask_blend_sse4.cc
+++ b/libgav1/src/dsp/x86/mask_blend_sse4.cc
@@ -36,7 +36,8 @@
 // Width can only be 4 when it is subsampled from a block of width 8, hence
 // subsampling_x is always 1 when this function is called.
 template <int subsampling_x, int subsampling_y>
-inline __m128i GetMask4x2(const uint8_t* mask, ptrdiff_t mask_stride) {
+inline __m128i GetMask4x2(const uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   if (subsampling_x == 1) {
     const __m128i mask_val_0 = _mm_cvtepu8_epi16(LoadLo8(mask));
     const __m128i mask_val_1 =
@@ -62,7 +63,8 @@
 // 16-bit is also the lowest packing for hadd, but without subsampling there is
 // an unfortunate conversion required.
 template <int subsampling_x, int subsampling_y>
-inline __m128i GetMask8(const uint8_t* mask, ptrdiff_t stride) {
+inline __m128i GetMask8(const uint8_t* LIBGAV1_RESTRICT mask,
+                        ptrdiff_t stride) {
   if (subsampling_x == 1) {
     const __m128i row_vals = LoadUnaligned16(mask);
 
@@ -89,7 +91,8 @@
 // when is_inter_intra is true, the prediction values are brought to 8-bit
 // packing as well.
 template <int subsampling_x, int subsampling_y>
-inline __m128i GetInterIntraMask8(const uint8_t* mask, ptrdiff_t stride) {
+inline __m128i GetInterIntraMask8(const uint8_t* LIBGAV1_RESTRICT mask,
+                                  ptrdiff_t stride) {
   if (subsampling_x == 1) {
     const __m128i row_vals = LoadUnaligned16(mask);
 
@@ -116,10 +119,11 @@
   return mask_val;
 }
 
-inline void WriteMaskBlendLine4x2(const int16_t* const pred_0,
-                                  const int16_t* const pred_1,
+inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
+                                  const int16_t* LIBGAV1_RESTRICT const pred_1,
                                   const __m128i pred_mask_0,
-                                  const __m128i pred_mask_1, uint8_t* dst,
+                                  const __m128i pred_mask_1,
+                                  uint8_t* LIBGAV1_RESTRICT dst,
                                   const ptrdiff_t dst_stride) {
   const __m128i pred_val_0 = LoadAligned16(pred_0);
   const __m128i pred_val_1 = LoadAligned16(pred_1);
@@ -145,9 +149,11 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void MaskBlending4x4_SSE4(const int16_t* pred_0, const int16_t* pred_1,
-                                 const uint8_t* mask,
-                                 const ptrdiff_t mask_stride, uint8_t* dst,
+inline void MaskBlending4x4_SSE4(const int16_t* LIBGAV1_RESTRICT pred_0,
+                                 const int16_t* LIBGAV1_RESTRICT pred_1,
+                                 const uint8_t* LIBGAV1_RESTRICT mask,
+                                 const ptrdiff_t mask_stride,
+                                 uint8_t* LIBGAV1_RESTRICT dst,
                                  const ptrdiff_t dst_stride) {
   const __m128i mask_inverter = _mm_set1_epi16(64);
   __m128i pred_mask_0 =
@@ -167,10 +173,12 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void MaskBlending4xH_SSE4(const int16_t* pred_0, const int16_t* pred_1,
-                                 const uint8_t* const mask_ptr,
+inline void MaskBlending4xH_SSE4(const int16_t* LIBGAV1_RESTRICT pred_0,
+                                 const int16_t* LIBGAV1_RESTRICT pred_1,
+                                 const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                                  const ptrdiff_t mask_stride, const int height,
-                                 uint8_t* dst, const ptrdiff_t dst_stride) {
+                                 uint8_t* LIBGAV1_RESTRICT dst,
+                                 const ptrdiff_t dst_stride) {
   const uint8_t* mask = mask_ptr;
   if (height == 4) {
     MaskBlending4x4_SSE4<subsampling_x, subsampling_y>(
@@ -222,11 +230,12 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void MaskBlend_SSE4(const void* prediction_0, const void* prediction_1,
+inline void MaskBlend_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                           const void* LIBGAV1_RESTRICT prediction_1,
                            const ptrdiff_t /*prediction_stride_1*/,
-                           const uint8_t* const mask_ptr,
+                           const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                            const ptrdiff_t mask_stride, const int width,
-                           const int height, void* dest,
+                           const int height, void* LIBGAV1_RESTRICT dest,
                            const ptrdiff_t dst_stride) {
   auto* dst = static_cast<uint8_t*>(dest);
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
@@ -277,11 +286,10 @@
   } while (++y < height);
 }
 
-inline void InterIntraWriteMaskBlendLine8bpp4x2(const uint8_t* const pred_0,
-                                                uint8_t* const pred_1,
-                                                const ptrdiff_t pred_stride_1,
-                                                const __m128i pred_mask_0,
-                                                const __m128i pred_mask_1) {
+inline void InterIntraWriteMaskBlendLine8bpp4x2(
+    const uint8_t* LIBGAV1_RESTRICT const pred_0,
+    uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1,
+    const __m128i pred_mask_0, const __m128i pred_mask_1) {
   const __m128i pred_mask = _mm_unpacklo_epi8(pred_mask_0, pred_mask_1);
 
   const __m128i pred_val_0 = LoadLo8(pred_0);
@@ -301,11 +309,10 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void InterIntraMaskBlending8bpp4x4_SSE4(const uint8_t* pred_0,
-                                               uint8_t* pred_1,
-                                               const ptrdiff_t pred_stride_1,
-                                               const uint8_t* mask,
-                                               const ptrdiff_t mask_stride) {
+inline void InterIntraMaskBlending8bpp4x4_SSE4(
+    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
+    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
+    const ptrdiff_t mask_stride) {
   const __m128i mask_inverter = _mm_set1_epi8(64);
   const __m128i pred_mask_u16_first =
       GetMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
@@ -328,12 +335,11 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void InterIntraMaskBlending8bpp4xH_SSE4(const uint8_t* pred_0,
-                                               uint8_t* pred_1,
-                                               const ptrdiff_t pred_stride_1,
-                                               const uint8_t* const mask_ptr,
-                                               const ptrdiff_t mask_stride,
-                                               const int height) {
+inline void InterIntraMaskBlending8bpp4xH_SSE4(
+    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
+    const ptrdiff_t pred_stride_1,
+    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
+    const int height) {
   const uint8_t* mask = mask_ptr;
   if (height == 4) {
     InterIntraMaskBlending8bpp4x4_SSE4<subsampling_x, subsampling_y>(
@@ -358,12 +364,11 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-void InterIntraMaskBlend8bpp_SSE4(const uint8_t* prediction_0,
-                                  uint8_t* prediction_1,
-                                  const ptrdiff_t prediction_stride_1,
-                                  const uint8_t* const mask_ptr,
-                                  const ptrdiff_t mask_stride, const int width,
-                                  const int height) {
+void InterIntraMaskBlend8bpp_SSE4(
+    const uint8_t* LIBGAV1_RESTRICT prediction_0,
+    uint8_t* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t prediction_stride_1,
+    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
+    const int width, const int height) {
   if (width == 4) {
     InterIntraMaskBlending8bpp4xH_SSE4<subsampling_x, subsampling_y>(
         prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
@@ -503,10 +508,11 @@
 }
 
 inline void WriteMaskBlendLine10bpp4x2_SSE4_1(
-    const uint16_t* pred_0, const uint16_t* pred_1,
-    const ptrdiff_t pred_stride_1, const __m128i& pred_mask_0,
-    const __m128i& pred_mask_1, const __m128i& offset, const __m128i& max,
-    const __m128i& shift4, uint16_t* dst, const ptrdiff_t dst_stride) {
+    const uint16_t* LIBGAV1_RESTRICT pred_0,
+    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
+    const __m128i& pred_mask_0, const __m128i& pred_mask_1,
+    const __m128i& offset, const __m128i& max, const __m128i& shift4,
+    uint16_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
   const __m128i pred_val_0 = LoadUnaligned16(pred_0);
   const __m128i pred_val_1 = LoadHi8(LoadLo8(pred_1), pred_1 + pred_stride_1);
 
@@ -544,11 +550,12 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void MaskBlend10bpp4x4_SSE4_1(const uint16_t* pred_0,
-                                     const uint16_t* pred_1,
+inline void MaskBlend10bpp4x4_SSE4_1(const uint16_t* LIBGAV1_RESTRICT pred_0,
+                                     const uint16_t* LIBGAV1_RESTRICT pred_1,
                                      const ptrdiff_t pred_stride_1,
-                                     const uint8_t* mask,
-                                     const ptrdiff_t mask_stride, uint16_t* dst,
+                                     const uint8_t* LIBGAV1_RESTRICT mask,
+                                     const ptrdiff_t mask_stride,
+                                     uint16_t* LIBGAV1_RESTRICT dst,
                                      const ptrdiff_t dst_stride) {
   const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
   const __m128i zero = _mm_setzero_si128();
@@ -575,13 +582,12 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void MaskBlend10bpp4xH_SSE4_1(const uint16_t* pred_0,
-                                     const uint16_t* pred_1,
-                                     const ptrdiff_t pred_stride_1,
-                                     const uint8_t* const mask_ptr,
-                                     const ptrdiff_t mask_stride,
-                                     const int height, uint16_t* dst,
-                                     const ptrdiff_t dst_stride) {
+inline void MaskBlend10bpp4xH_SSE4_1(
+    const uint16_t* LIBGAV1_RESTRICT pred_0,
+    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
+    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
+    const int height, uint16_t* LIBGAV1_RESTRICT dst,
+    const ptrdiff_t dst_stride) {
   const uint8_t* mask = mask_ptr;
   if (height == 4) {
     MaskBlend10bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
@@ -648,13 +654,13 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void MaskBlend10bpp_SSE4_1(const void* prediction_0,
-                                  const void* prediction_1,
-                                  const ptrdiff_t prediction_stride_1,
-                                  const uint8_t* const mask_ptr,
-                                  const ptrdiff_t mask_stride, const int width,
-                                  const int height, void* dest,
-                                  const ptrdiff_t dest_stride) {
+inline void MaskBlend10bpp_SSE4_1(
+    const void* LIBGAV1_RESTRICT prediction_0,
+    const void* LIBGAV1_RESTRICT prediction_1,
+    const ptrdiff_t prediction_stride_1,
+    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
+    const int width, const int height, void* LIBGAV1_RESTRICT dest,
+    const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint16_t*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]);
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
@@ -725,10 +731,11 @@
 }
 
 inline void InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(
-    const uint16_t* prediction_0, const uint16_t* prediction_1,
+    const uint16_t* LIBGAV1_RESTRICT prediction_0,
+    const uint16_t* LIBGAV1_RESTRICT prediction_1,
     const ptrdiff_t pred_stride_1, const __m128i& pred_mask_0,
-    const __m128i& pred_mask_1, const __m128i& shift6, uint16_t* dst,
-    const ptrdiff_t dst_stride) {
+    const __m128i& pred_mask_1, const __m128i& shift6,
+    uint16_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
   const __m128i pred_val_0 = LoadUnaligned16(prediction_0);
   const __m128i pred_val_1 =
       LoadHi8(LoadLo8(prediction_1), prediction_1 + pred_stride_1);
@@ -751,9 +758,10 @@
 
 template <int subsampling_x, int subsampling_y>
 inline void InterIntraMaskBlend10bpp4x4_SSE4_1(
-    const uint16_t* pred_0, const uint16_t* pred_1,
-    const ptrdiff_t pred_stride_1, const uint8_t* mask,
-    const ptrdiff_t mask_stride, uint16_t* dst, const ptrdiff_t dst_stride) {
+    const uint16_t* LIBGAV1_RESTRICT pred_0,
+    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
+    const uint8_t* LIBGAV1_RESTRICT mask, const ptrdiff_t mask_stride,
+    uint16_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
   const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
   const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1);
   const __m128i zero = _mm_setzero_si128();
@@ -777,13 +785,12 @@
 }
 
 template <int subsampling_x, int subsampling_y>
-inline void InterIntraMaskBlend10bpp4xH_SSE4_1(const uint16_t* pred_0,
-                                               const uint16_t* pred_1,
-                                               const ptrdiff_t pred_stride_1,
-                                               const uint8_t* const mask_ptr,
-                                               const ptrdiff_t mask_stride,
-                                               const int height, uint16_t* dst,
-                                               const ptrdiff_t dst_stride) {
+inline void InterIntraMaskBlend10bpp4xH_SSE4_1(
+    const uint16_t* LIBGAV1_RESTRICT pred_0,
+    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
+    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
+    const int height, uint16_t* LIBGAV1_RESTRICT dst,
+    const ptrdiff_t dst_stride) {
   const uint8_t* mask = mask_ptr;
   if (height == 4) {
     InterIntraMaskBlend10bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
@@ -848,9 +855,11 @@
 
 template <int subsampling_x, int subsampling_y>
 inline void InterIntraMaskBlend10bpp_SSE4_1(
-    const void* prediction_0, const void* prediction_1,
-    const ptrdiff_t prediction_stride_1, const uint8_t* const mask_ptr,
-    const ptrdiff_t mask_stride, const int width, const int height, void* dest,
+    const void* LIBGAV1_RESTRICT prediction_0,
+    const void* LIBGAV1_RESTRICT prediction_1,
+    const ptrdiff_t prediction_stride_1,
+    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
+    const int width, const int height, void* LIBGAV1_RESTRICT dest,
     const ptrdiff_t dest_stride) {
   auto* dst = static_cast<uint16_t*>(dest);
   const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]);
diff --git a/libgav1/src/dsp/x86/motion_field_projection_sse4.cc b/libgav1/src/dsp/x86/motion_field_projection_sse4.cc
index e3f2cce..5641531 100644
--- a/libgav1/src/dsp/x86/motion_field_projection_sse4.cc
+++ b/libgav1/src/dsp/x86/motion_field_projection_sse4.cc
@@ -360,27 +360,12 @@
   } while (++y8 < y8_end);
 }
 
-void Init8bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
-  assert(dsp != nullptr);
-  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_SSE4_1;
-}
-
-#if LIBGAV1_MAX_BITDEPTH >= 10
-void Init10bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
-  assert(dsp != nullptr);
-  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_SSE4_1;
-}
-#endif
-
 }  // namespace
 
 void MotionFieldProjectionInit_SSE4_1() {
-  Init8bpp();
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  Init10bpp();
-#endif
+  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
+  assert(dsp != nullptr);
+  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_SSE4_1;
 }
 
 }  // namespace dsp
diff --git a/libgav1/src/dsp/x86/motion_vector_search_sse4.cc b/libgav1/src/dsp/x86/motion_vector_search_sse4.cc
index 7f5f035..dacc6ec 100644
--- a/libgav1/src/dsp/x86/motion_vector_search_sse4.cc
+++ b/libgav1/src/dsp/x86/motion_vector_search_sse4.cc
@@ -64,7 +64,7 @@
 }
 
 inline __m128i MvProjectionCompoundClip(
-    const MotionVector* const temporal_mvs,
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
     const int8_t temporal_reference_offsets[2],
     const int reference_offsets[2]) {
   const auto* const tmvs = reinterpret_cast<const int32_t*>(temporal_mvs);
@@ -83,8 +83,8 @@
 }
 
 inline __m128i MvProjectionSingleClip(
-    const MotionVector* const temporal_mvs,
-    const int8_t* const temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT const temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT const temporal_reference_offsets,
     const int reference_offset) {
   const auto* const tmvs = reinterpret_cast<const int16_t*>(temporal_mvs);
   const __m128i temporal_mv = LoadAligned16(tmvs);
@@ -126,9 +126,10 @@
 }
 
 void MvProjectionCompoundLowPrecision_SSE4_1(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
     const int reference_offsets[2], const int count,
-    CompoundMotionVector* candidate_mvs) {
+    CompoundMotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // |reference_offsets| non-zero check usually equals true and is ignored.
   // To facilitate the compilers, make a local copy of |reference_offsets|.
   const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
@@ -143,9 +144,10 @@
 }
 
 void MvProjectionCompoundForceInteger_SSE4_1(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
     const int reference_offsets[2], const int count,
-    CompoundMotionVector* candidate_mvs) {
+    CompoundMotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // |reference_offsets| non-zero check usually equals true and is ignored.
   // To facilitate the compilers, make a local copy of |reference_offsets|.
   const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
@@ -160,9 +162,10 @@
 }
 
 void MvProjectionCompoundHighPrecision_SSE4_1(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
     const int reference_offsets[2], const int count,
-    CompoundMotionVector* candidate_mvs) {
+    CompoundMotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // |reference_offsets| non-zero check usually equals true and is ignored.
   // To facilitate the compilers, make a local copy of |reference_offsets|.
   const int offsets[2] = {reference_offsets[0], reference_offsets[1]};
@@ -177,8 +180,10 @@
 }
 
 void MvProjectionSingleLowPrecision_SSE4_1(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
-    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
+    const int reference_offset, const int count,
+    MotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // Up to three more elements could be calculated.
   int i = 0;
   do {
@@ -190,8 +195,10 @@
 }
 
 void MvProjectionSingleForceInteger_SSE4_1(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
-    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
+    const int reference_offset, const int count,
+    MotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // Up to three more elements could be calculated.
   int i = 0;
   do {
@@ -203,8 +210,10 @@
 }
 
 void MvProjectionSingleHighPrecision_SSE4_1(
-    const MotionVector* temporal_mvs, const int8_t* temporal_reference_offsets,
-    const int reference_offset, const int count, MotionVector* candidate_mvs) {
+    const MotionVector* LIBGAV1_RESTRICT temporal_mvs,
+    const int8_t* LIBGAV1_RESTRICT temporal_reference_offsets,
+    const int reference_offset, const int count,
+    MotionVector* LIBGAV1_RESTRICT candidate_mvs) {
   // Up to three more elements could be calculated.
   int i = 0;
   do {
@@ -215,7 +224,9 @@
   } while (i < count);
 }
 
-void Init8bpp() {
+}  // namespace
+
+void MotionVectorSearchInit_SSE4_1() {
   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
   assert(dsp != nullptr);
   dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_SSE4_1;
@@ -226,28 +237,6 @@
   dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_SSE4_1;
 }
 
-#if LIBGAV1_MAX_BITDEPTH >= 10
-void Init10bpp() {
-  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
-  assert(dsp != nullptr);
-  dsp->mv_projection_compound[0] = MvProjectionCompoundLowPrecision_SSE4_1;
-  dsp->mv_projection_compound[1] = MvProjectionCompoundForceInteger_SSE4_1;
-  dsp->mv_projection_compound[2] = MvProjectionCompoundHighPrecision_SSE4_1;
-  dsp->mv_projection_single[0] = MvProjectionSingleLowPrecision_SSE4_1;
-  dsp->mv_projection_single[1] = MvProjectionSingleForceInteger_SSE4_1;
-  dsp->mv_projection_single[2] = MvProjectionSingleHighPrecision_SSE4_1;
-}
-#endif
-
-}  // namespace
-
-void MotionVectorSearchInit_SSE4_1() {
-  Init8bpp();
-#if LIBGAV1_MAX_BITDEPTH >= 10
-  Init10bpp();
-#endif
-}
-
 }  // namespace dsp
 }  // namespace libgav1
 
diff --git a/libgav1/src/dsp/x86/obmc_sse4.cc b/libgav1/src/dsp/x86/obmc_sse4.cc
index c34a7f7..8ce23b4 100644
--- a/libgav1/src/dsp/x86/obmc_sse4.cc
+++ b/libgav1/src/dsp/x86/obmc_sse4.cc
@@ -37,8 +37,9 @@
 #include "src/dsp/obmc.inc"
 
 inline void OverlapBlendFromLeft2xH_SSE4_1(
-    uint8_t* const prediction, const ptrdiff_t prediction_stride,
-    const int height, const uint8_t* const obmc_prediction,
+    uint8_t* LIBGAV1_RESTRICT const prediction,
+    const ptrdiff_t prediction_stride, const int height,
+    const uint8_t* LIBGAV1_RESTRICT const obmc_prediction,
     const ptrdiff_t obmc_prediction_stride) {
   uint8_t* pred = prediction;
   const uint8_t* obmc_pred = obmc_prediction;
@@ -68,8 +69,9 @@
 }
 
 inline void OverlapBlendFromLeft4xH_SSE4_1(
-    uint8_t* const prediction, const ptrdiff_t prediction_stride,
-    const int height, const uint8_t* const obmc_prediction,
+    uint8_t* LIBGAV1_RESTRICT const prediction,
+    const ptrdiff_t prediction_stride, const int height,
+    const uint8_t* LIBGAV1_RESTRICT const obmc_prediction,
     const ptrdiff_t obmc_prediction_stride) {
   uint8_t* pred = prediction;
   const uint8_t* obmc_pred = obmc_prediction;
@@ -106,8 +108,9 @@
 }
 
 inline void OverlapBlendFromLeft8xH_SSE4_1(
-    uint8_t* const prediction, const ptrdiff_t prediction_stride,
-    const int height, const uint8_t* const obmc_prediction,
+    uint8_t* LIBGAV1_RESTRICT const prediction,
+    const ptrdiff_t prediction_stride, const int height,
+    const uint8_t* LIBGAV1_RESTRICT const obmc_prediction,
     const ptrdiff_t obmc_prediction_stride) {
   uint8_t* pred = prediction;
   const uint8_t* obmc_pred = obmc_prediction;
@@ -130,13 +133,15 @@
   } while (--y != 0);
 }
 
-void OverlapBlendFromLeft_SSE4_1(void* const prediction,
-                                 const ptrdiff_t prediction_stride,
-                                 const int width, const int height,
-                                 const void* const obmc_prediction,
-                                 const ptrdiff_t obmc_prediction_stride) {
+void OverlapBlendFromLeft_SSE4_1(
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
+    const int width, const int height,
+    const void* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
   auto* pred = static_cast<uint8_t*>(prediction);
   const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+  assert(width >= 2);
+  assert(height >= 4);
 
   if (width == 2) {
     OverlapBlendFromLeft2xH_SSE4_1(pred, prediction_stride, height, obmc_pred,
@@ -185,8 +190,9 @@
 }
 
 inline void OverlapBlendFromTop4xH_SSE4_1(
-    uint8_t* const prediction, const ptrdiff_t prediction_stride,
-    const int height, const uint8_t* const obmc_prediction,
+    uint8_t* LIBGAV1_RESTRICT const prediction,
+    const ptrdiff_t prediction_stride, const int height,
+    const uint8_t* LIBGAV1_RESTRICT const obmc_prediction,
     const ptrdiff_t obmc_prediction_stride) {
   uint8_t* pred = prediction;
   const uint8_t* obmc_pred = obmc_prediction;
@@ -227,8 +233,9 @@
 }
 
 inline void OverlapBlendFromTop8xH_SSE4_1(
-    uint8_t* const prediction, const ptrdiff_t prediction_stride,
-    const int height, const uint8_t* const obmc_prediction,
+    uint8_t* LIBGAV1_RESTRICT const prediction,
+    const ptrdiff_t prediction_stride, const int height,
+    const uint8_t* LIBGAV1_RESTRICT const obmc_prediction,
     const ptrdiff_t obmc_prediction_stride) {
   uint8_t* pred = prediction;
   const uint8_t* obmc_pred = obmc_prediction;
@@ -253,15 +260,17 @@
   } while (--y != 0);
 }
 
-void OverlapBlendFromTop_SSE4_1(void* const prediction,
-                                const ptrdiff_t prediction_stride,
-                                const int width, const int height,
-                                const void* const obmc_prediction,
-                                const ptrdiff_t obmc_prediction_stride) {
+void OverlapBlendFromTop_SSE4_1(
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
+    const int width, const int height,
+    const void* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
   auto* pred = static_cast<uint8_t*>(prediction);
   const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
+  assert(width >= 4);
+  assert(height >= 2);
 
-  if (width <= 4) {
+  if (width == 4) {
     OverlapBlendFromTop4xH_SSE4_1(pred, prediction_stride, height, obmc_pred,
                                   obmc_prediction_stride);
     return;
@@ -323,8 +332,9 @@
 constexpr int kRoundBitsObmcBlend = 6;
 
 inline void OverlapBlendFromLeft2xH_SSE4_1(
-    uint16_t* const prediction, const ptrdiff_t pred_stride, const int height,
-    const uint16_t* const obmc_prediction, const ptrdiff_t obmc_pred_stride) {
+    uint16_t* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride,
+    const int height, const uint16_t* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_pred_stride) {
   uint16_t* pred = prediction;
   const uint16_t* obmc_pred = obmc_prediction;
   const ptrdiff_t pred_stride2 = pred_stride << 1;
@@ -353,8 +363,9 @@
 }
 
 inline void OverlapBlendFromLeft4xH_SSE4_1(
-    uint16_t* const prediction, const ptrdiff_t pred_stride, const int height,
-    const uint16_t* const obmc_prediction, const ptrdiff_t obmc_pred_stride) {
+    uint16_t* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride,
+    const int height, const uint16_t* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_pred_stride) {
   uint16_t* pred = prediction;
   const uint16_t* obmc_pred = obmc_prediction;
   const ptrdiff_t pred_stride2 = pred_stride << 1;
@@ -385,16 +396,18 @@
   } while (y != 0);
 }
 
-void OverlapBlendFromLeft10bpp_SSE4_1(void* const prediction,
-                                      const ptrdiff_t prediction_stride,
-                                      const int width, const int height,
-                                      const void* const obmc_prediction,
-                                      const ptrdiff_t obmc_prediction_stride) {
+void OverlapBlendFromLeft10bpp_SSE4_1(
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
+    const int width, const int height,
+    const void* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
   auto* pred = static_cast<uint16_t*>(prediction);
   const auto* obmc_pred = static_cast<const uint16_t*>(obmc_prediction);
   const ptrdiff_t pred_stride = prediction_stride / sizeof(pred[0]);
   const ptrdiff_t obmc_pred_stride =
       obmc_prediction_stride / sizeof(obmc_pred[0]);
+  assert(width >= 2);
+  assert(height >= 4);
 
   if (width == 2) {
     OverlapBlendFromLeft2xH_SSE4_1(pred, pred_stride, height, obmc_pred,
@@ -437,54 +450,10 @@
   } while (x < width);
 }
 
-inline void OverlapBlendFromTop2xH_SSE4_1(uint16_t* const prediction,
-                                          const ptrdiff_t pred_stride,
-                                          const int height,
-                                          const uint16_t* const obmc_prediction,
-                                          const ptrdiff_t obmc_pred_stride) {
-  uint16_t* pred = prediction;
-  const uint16_t* obmc_pred = obmc_prediction;
-  const __m128i mask_inverter = _mm_set1_epi16(64);
-  const __m128i mask_shuffler = _mm_set_epi32(0x01010101, 0x01010101, 0, 0);
-  const __m128i mask_preinverter = _mm_set1_epi16(-256 | 1);
-  const uint8_t* mask = kObmcMask + height - 2;
-  const int compute_height =
-      height - (height >> 2);  // compute_height based on 8-bit opt
-  const ptrdiff_t pred_stride2 = pred_stride << 1;
-  const ptrdiff_t obmc_pred_stride2 = obmc_pred_stride << 1;
-  int y = 0;
-  do {
-    // First mask in the first half, second mask in the second half.
-    const __m128i mask_val = _mm_shuffle_epi8(Load4(mask + y), mask_shuffler);
-    const __m128i masks =
-        _mm_sub_epi8(mask_inverter, _mm_sign_epi8(mask_val, mask_preinverter));
-    const __m128i masks_lo = _mm_cvtepi8_epi16(masks);
-    const __m128i masks_hi = _mm_cvtepi8_epi16(_mm_srli_si128(masks, 8));
-
-    const __m128i pred_val = LoadHi8(LoadLo8(pred), pred + pred_stride);
-    const __m128i obmc_pred_val =
-        LoadHi8(LoadLo8(obmc_pred), obmc_pred + obmc_pred_stride);
-    const __m128i terms_lo = _mm_unpacklo_epi16(obmc_pred_val, pred_val);
-    const __m128i terms_hi = _mm_unpackhi_epi16(obmc_pred_val, pred_val);
-    const __m128i result_lo = RightShiftWithRounding_U32(
-        _mm_madd_epi16(terms_lo, masks_lo), kRoundBitsObmcBlend);
-    const __m128i result_hi = RightShiftWithRounding_U32(
-        _mm_madd_epi16(terms_hi, masks_hi), kRoundBitsObmcBlend);
-    const __m128i packed_result = _mm_packus_epi32(result_lo, result_hi);
-
-    Store4(pred, packed_result);
-    Store4(pred + pred_stride, _mm_srli_si128(packed_result, 8));
-    pred += pred_stride2;
-    obmc_pred += obmc_pred_stride2;
-    y += 2;
-  } while (y < compute_height);
-}
-
-inline void OverlapBlendFromTop4xH_SSE4_1(uint16_t* const prediction,
-                                          const ptrdiff_t pred_stride,
-                                          const int height,
-                                          const uint16_t* const obmc_prediction,
-                                          const ptrdiff_t obmc_pred_stride) {
+inline void OverlapBlendFromTop4xH_SSE4_1(
+    uint16_t* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride,
+    const int height, const uint16_t* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_pred_stride) {
   uint16_t* pred = prediction;
   const uint16_t* obmc_pred = obmc_prediction;
   const __m128i mask_inverter = _mm_set1_epi16(64);
@@ -522,22 +491,19 @@
   } while (y < compute_height);
 }
 
-void OverlapBlendFromTop10bpp_SSE4_1(void* const prediction,
-                                     const ptrdiff_t prediction_stride,
-                                     const int width, const int height,
-                                     const void* const obmc_prediction,
-                                     const ptrdiff_t obmc_prediction_stride) {
+void OverlapBlendFromTop10bpp_SSE4_1(
+    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
+    const int width, const int height,
+    const void* LIBGAV1_RESTRICT const obmc_prediction,
+    const ptrdiff_t obmc_prediction_stride) {
   auto* pred = static_cast<uint16_t*>(prediction);
   const auto* obmc_pred = static_cast<const uint16_t*>(obmc_prediction);
   const ptrdiff_t pred_stride = prediction_stride / sizeof(pred[0]);
   const ptrdiff_t obmc_pred_stride =
       obmc_prediction_stride / sizeof(obmc_pred[0]);
+  assert(width >= 4);
+  assert(height >= 2);
 
-  if (width == 2) {
-    OverlapBlendFromTop2xH_SSE4_1(pred, pred_stride, height, obmc_pred,
-                                  obmc_pred_stride);
-    return;
-  }
   if (width == 4) {
     OverlapBlendFromTop4xH_SSE4_1(pred, pred_stride, height, obmc_pred,
                                   obmc_pred_stride);
diff --git a/libgav1/src/dsp/x86/super_res_sse4.cc b/libgav1/src/dsp/x86/super_res_sse4.cc
index 85d05bc..458d94e 100644
--- a/libgav1/src/dsp/x86/super_res_sse4.cc
+++ b/libgav1/src/dsp/x86/super_res_sse4.cc
@@ -90,11 +90,13 @@
   } while (--x != 0);
 }
 
-void SuperRes_SSE4_1(const void* const coefficients, void* const source,
+void SuperRes_SSE4_1(const void* LIBGAV1_RESTRICT const coefficients,
+                     void* LIBGAV1_RESTRICT const source,
                      const ptrdiff_t source_stride, const int height,
                      const int downscaled_width, const int upscaled_width,
                      const int initial_subpixel_x, const int step,
-                     void* const dest, const ptrdiff_t dest_stride) {
+                     void* LIBGAV1_RESTRICT const dest,
+                     const ptrdiff_t dest_stride) {
   auto* src = static_cast<uint8_t*>(source) - DivideBy2(kSuperResFilterTaps);
   auto* dst = static_cast<uint8_t*>(dest);
   int y = height;
@@ -227,11 +229,13 @@
 }
 
 template <int bitdepth>
-void SuperRes_SSE4_1(const void* const coefficients, void* const source,
+void SuperRes_SSE4_1(const void* LIBGAV1_RESTRICT const coefficients,
+                     void* LIBGAV1_RESTRICT const source,
                      const ptrdiff_t source_stride, const int height,
                      const int downscaled_width, const int upscaled_width,
                      const int initial_subpixel_x, const int step,
-                     void* const dest, const ptrdiff_t dest_stride) {
+                     void* LIBGAV1_RESTRICT const dest,
+                     const ptrdiff_t dest_stride) {
   auto* src = static_cast<uint16_t*>(source) - DivideBy2(kSuperResFilterTaps);
   auto* dst = static_cast<uint16_t*>(dest);
   int y = height;
diff --git a/libgav1/src/dsp/x86/warp_sse4.cc b/libgav1/src/dsp/x86/warp_sse4.cc
index 9ddfeac..5830894 100644
--- a/libgav1/src/dsp/x86/warp_sse4.cc
+++ b/libgav1/src/dsp/x86/warp_sse4.cc
@@ -101,7 +101,7 @@
 template <bool is_compound>
 inline void WriteVerticalFilter(const __m128i filter[8],
                                 const int16_t intermediate_result[15][8], int y,
-                                void* dst_row) {
+                                void* LIBGAV1_RESTRICT dst_row) {
   constexpr int kRoundBitsVertical =
       is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
   __m128i sum_low = _mm_set1_epi32(kOffsetRemoval);
@@ -136,8 +136,9 @@
 
 template <bool is_compound>
 inline void WriteVerticalFilter(const __m128i filter[8],
-                                const int16_t* intermediate_result_column,
-                                void* dst_row) {
+                                const int16_t* LIBGAV1_RESTRICT
+                                    intermediate_result_column,
+                                void* LIBGAV1_RESTRICT dst_row) {
   constexpr int kRoundBitsVertical =
       is_compound ? kInterRoundBitsCompoundVertical : kInterRoundBitsVertical;
   __m128i sum_low = _mm_setzero_si128();
@@ -167,7 +168,7 @@
 
 template <bool is_compound, typename DestType>
 inline void VerticalFilter(const int16_t source[15][8], int y4, int gamma,
-                           int delta, DestType* dest_row,
+                           int delta, DestType* LIBGAV1_RESTRICT dest_row,
                            ptrdiff_t dest_stride) {
   int sy4 = (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
   for (int y = 0; y < 8; ++y) {
@@ -187,8 +188,9 @@
 }
 
 template <bool is_compound, typename DestType>
-inline void VerticalFilter(const int16_t* source_cols, int y4, int gamma,
-                           int delta, DestType* dest_row,
+inline void VerticalFilter(const int16_t* LIBGAV1_RESTRICT source_cols, int y4,
+                           int gamma, int delta,
+                           DestType* LIBGAV1_RESTRICT dest_row,
                            ptrdiff_t dest_stride) {
   int sy4 = (y4 & ((1 << kWarpedModelPrecisionBits) - 1)) - MultiplyBy4(delta);
   for (int y = 0; y < 8; ++y) {
@@ -208,9 +210,11 @@
 }
 
 template <bool is_compound, typename DestType>
-inline void WarpRegion1(const uint8_t* src, ptrdiff_t source_stride,
-                        int source_width, int source_height, int ix4, int iy4,
-                        DestType* dst_row, ptrdiff_t dest_stride) {
+inline void WarpRegion1(const uint8_t* LIBGAV1_RESTRICT src,
+                        ptrdiff_t source_stride, int source_width,
+                        int source_height, int ix4, int iy4,
+                        DestType* LIBGAV1_RESTRICT dst_row,
+                        ptrdiff_t dest_stride) {
   // Region 1
   // Points to the left or right border of the first row of |src|.
   const uint8_t* first_row_border =
@@ -244,10 +248,12 @@
 }
 
 template <bool is_compound, typename DestType>
-inline void WarpRegion2(const uint8_t* src, ptrdiff_t source_stride,
-                        int source_width, int y4, int ix4, int iy4, int gamma,
-                        int delta, int16_t intermediate_result_column[15],
-                        DestType* dst_row, ptrdiff_t dest_stride) {
+inline void WarpRegion2(const uint8_t* LIBGAV1_RESTRICT src,
+                        ptrdiff_t source_stride, int source_width, int y4,
+                        int ix4, int iy4, int gamma, int delta,
+                        int16_t intermediate_result_column[15],
+                        DestType* LIBGAV1_RESTRICT dst_row,
+                        ptrdiff_t dest_stride) {
   // Region 2.
   // Points to the left or right border of the first row of |src|.
   const uint8_t* first_row_border =
@@ -283,9 +289,10 @@
 }
 
 template <bool is_compound, typename DestType>
-inline void WarpRegion3(const uint8_t* src, ptrdiff_t source_stride,
-                        int source_height, int alpha, int beta, int x4, int ix4,
-                        int iy4, int16_t intermediate_result[15][8]) {
+inline void WarpRegion3(const uint8_t* LIBGAV1_RESTRICT src,
+                        ptrdiff_t source_stride, int source_height, int alpha,
+                        int beta, int x4, int ix4, int iy4,
+                        int16_t intermediate_result[15][8]) {
   // Region 3
   // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
 
@@ -315,9 +322,9 @@
 }
 
 template <bool is_compound, typename DestType>
-inline void WarpRegion4(const uint8_t* src, ptrdiff_t source_stride, int alpha,
-                        int beta, int x4, int ix4, int iy4,
-                        int16_t intermediate_result[15][8]) {
+inline void WarpRegion4(const uint8_t* LIBGAV1_RESTRICT src,
+                        ptrdiff_t source_stride, int alpha, int beta, int x4,
+                        int ix4, int iy4, int16_t intermediate_result[15][8]) {
   // Region 4.
   // At this point, we know ix4 - 7 < source_width - 1 and ix4 + 7 > 0.
 
@@ -351,12 +358,14 @@
 }
 
 template <bool is_compound, typename DestType>
-inline void HandleWarpBlock(const uint8_t* src, ptrdiff_t source_stride,
-                            int source_width, int source_height,
-                            const int* warp_params, int subsampling_x,
-                            int subsampling_y, int src_x, int src_y,
-                            int16_t alpha, int16_t beta, int16_t gamma,
-                            int16_t delta, DestType* dst_row,
+inline void HandleWarpBlock(const uint8_t* LIBGAV1_RESTRICT src,
+                            ptrdiff_t source_stride, int source_width,
+                            int source_height,
+                            const int* LIBGAV1_RESTRICT warp_params,
+                            int subsampling_x, int subsampling_y, int src_x,
+                            int src_y, int16_t alpha, int16_t beta,
+                            int16_t gamma, int16_t delta,
+                            DestType* LIBGAV1_RESTRICT dst_row,
                             ptrdiff_t dest_stride) {
   union {
     // Intermediate_result is the output of the horizontal filtering and
@@ -460,11 +469,12 @@
 }
 
 template <bool is_compound>
-void Warp_SSE4_1(const void* source, ptrdiff_t source_stride, int source_width,
-                 int source_height, const int* warp_params, int subsampling_x,
+void Warp_SSE4_1(const void* LIBGAV1_RESTRICT source, ptrdiff_t source_stride,
+                 int source_width, int source_height,
+                 const int* LIBGAV1_RESTRICT warp_params, int subsampling_x,
                  int subsampling_y, int block_start_x, int block_start_y,
                  int block_width, int block_height, int16_t alpha, int16_t beta,
-                 int16_t gamma, int16_t delta, void* dest,
+                 int16_t gamma, int16_t delta, void* LIBGAV1_RESTRICT dest,
                  ptrdiff_t dest_stride) {
   const auto* const src = static_cast<const uint8_t*>(source);
   using DestType =
diff --git a/libgav1/src/dsp/x86/weight_mask_sse4.cc b/libgav1/src/dsp/x86/weight_mask_sse4.cc
index 08a1739..69cb784 100644
--- a/libgav1/src/dsp/x86/weight_mask_sse4.cc
+++ b/libgav1/src/dsp/x86/weight_mask_sse4.cc
@@ -37,8 +37,9 @@
 constexpr int kRoundingBits8bpp = 4;
 
 template <bool mask_is_inverse, bool is_store_16>
-inline void WeightMask16_SSE4(const int16_t* prediction_0,
-                              const int16_t* prediction_1, uint8_t* mask,
+inline void WeightMask16_SSE4(const int16_t* LIBGAV1_RESTRICT prediction_0,
+                              const int16_t* LIBGAV1_RESTRICT prediction_1,
+                              uint8_t* LIBGAV1_RESTRICT mask,
                               ptrdiff_t mask_stride) {
   const __m128i pred_00 = LoadAligned16(prediction_0);
   const __m128i pred_10 = LoadAligned16(prediction_1);
@@ -86,8 +87,9 @@
   mask += mask_stride << 1
 
 template <bool mask_is_inverse>
-void WeightMask8x8_SSE4(const void* prediction_0, const void* prediction_1,
-                        uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask8x8_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                        const void* LIBGAV1_RESTRICT prediction_1,
+                        uint8_t* LIBGAV1_RESTRICT mask, ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
 
@@ -98,8 +100,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask8x16_SSE4(const void* prediction_0, const void* prediction_1,
-                         uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask8x16_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                         const void* LIBGAV1_RESTRICT prediction_1,
+                         uint8_t* LIBGAV1_RESTRICT mask,
+                         ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 3;
@@ -112,8 +116,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask8x32_SSE4(const void* prediction_0, const void* prediction_1,
-                         uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask8x32_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                         const void* LIBGAV1_RESTRICT prediction_1,
+                         uint8_t* LIBGAV1_RESTRICT mask,
+                         ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y5 = 5;
@@ -135,8 +141,10 @@
   mask += mask_stride
 
 template <bool mask_is_inverse>
-void WeightMask16x8_SSE4(const void* prediction_0, const void* prediction_1,
-                         uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask16x8_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                         const void* LIBGAV1_RESTRICT prediction_1,
+                         uint8_t* LIBGAV1_RESTRICT mask,
+                         ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y = 7;
@@ -147,8 +155,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask16x16_SSE4(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask16x16_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 5;
@@ -161,8 +171,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask16x32_SSE4(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask16x32_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y5 = 6;
@@ -178,8 +190,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask16x64_SSE4(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask16x64_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 21;
@@ -203,8 +217,10 @@
   mask += mask_stride
 
 template <bool mask_is_inverse>
-void WeightMask32x8_SSE4(const void* prediction_0, const void* prediction_1,
-                         uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask32x8_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                         const void* LIBGAV1_RESTRICT prediction_1,
+                         uint8_t* LIBGAV1_RESTRICT mask,
+                         ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   WEIGHT32_AND_STRIDE;
@@ -218,8 +234,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask32x16_SSE4(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask32x16_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 5;
@@ -232,8 +250,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask32x32_SSE4(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask32x32_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y5 = 6;
@@ -249,8 +269,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask32x64_SSE4(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask32x64_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 21;
@@ -278,8 +300,10 @@
   mask += mask_stride
 
 template <bool mask_is_inverse>
-void WeightMask64x16_SSE4(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask64x16_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -292,8 +316,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask64x32_SSE4(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask64x32_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y5 = 0;
@@ -309,8 +335,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask64x64_SSE4(const void* prediction_0, const void* prediction_1,
-                          uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask64x64_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                          const void* LIBGAV1_RESTRICT prediction_1,
+                          uint8_t* LIBGAV1_RESTRICT mask,
+                          ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -323,8 +351,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask64x128_SSE4(const void* prediction_0, const void* prediction_1,
-                           uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask64x128_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                           const void* LIBGAV1_RESTRICT prediction_1,
+                           uint8_t* LIBGAV1_RESTRICT mask,
+                           ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -338,8 +368,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask128x64_SSE4(const void* prediction_0, const void* prediction_1,
-                           uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask128x64_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                           const void* LIBGAV1_RESTRICT prediction_1,
+                           uint8_t* LIBGAV1_RESTRICT mask,
+                           ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -380,8 +412,10 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask128x128_SSE4(const void* prediction_0, const void* prediction_1,
-                            uint8_t* mask, ptrdiff_t mask_stride) {
+void WeightMask128x128_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                            const void* LIBGAV1_RESTRICT prediction_1,
+                            uint8_t* LIBGAV1_RESTRICT mask,
+                            ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
   int y3 = 0;
@@ -467,9 +501,10 @@
 constexpr int kScaledDiffShift = 4;
 
 template <bool mask_is_inverse, bool is_store_16>
-inline void WeightMask16_10bpp_SSE4(const uint16_t* prediction_0,
-                                    const uint16_t* prediction_1, uint8_t* mask,
-                                    ptrdiff_t mask_stride) {
+inline void WeightMask16_10bpp_SSE4(
+    const uint16_t* LIBGAV1_RESTRICT prediction_0,
+    const uint16_t* LIBGAV1_RESTRICT prediction_1,
+    uint8_t* LIBGAV1_RESTRICT mask, ptrdiff_t mask_stride) {
   const __m128i diff_offset = _mm_set1_epi8(38);
   const __m128i mask_ceiling = _mm_set1_epi8(64);
   const __m128i zero = _mm_setzero_si128();
@@ -538,8 +573,9 @@
   mask += mask_stride << 1
 
 template <bool mask_is_inverse>
-void WeightMask8x8_10bpp_SSE4(const void* prediction_0,
-                              const void* prediction_1, uint8_t* mask,
+void WeightMask8x8_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                              const void* LIBGAV1_RESTRICT prediction_1,
+                              uint8_t* LIBGAV1_RESTRICT mask,
                               ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -551,8 +587,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask8x16_10bpp_SSE4(const void* prediction_0,
-                               const void* prediction_1, uint8_t* mask,
+void WeightMask8x16_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                               const void* LIBGAV1_RESTRICT prediction_1,
+                               uint8_t* LIBGAV1_RESTRICT mask,
                                ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -566,8 +603,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask8x32_10bpp_SSE4(const void* prediction_0,
-                               const void* prediction_1, uint8_t* mask,
+void WeightMask8x32_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                               const void* LIBGAV1_RESTRICT prediction_1,
+                               uint8_t* LIBGAV1_RESTRICT mask,
                                ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -591,8 +629,9 @@
   mask += mask_stride
 
 template <bool mask_is_inverse>
-void WeightMask16x8_10bpp_SSE4(const void* prediction_0,
-                               const void* prediction_1, uint8_t* mask,
+void WeightMask16x8_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                               const void* LIBGAV1_RESTRICT prediction_1,
+                               uint8_t* LIBGAV1_RESTRICT mask,
                                ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -604,8 +643,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask16x16_10bpp_SSE4(const void* prediction_0,
-                                const void* prediction_1, uint8_t* mask,
+void WeightMask16x16_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
+                                uint8_t* LIBGAV1_RESTRICT mask,
                                 ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -619,8 +659,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask16x32_10bpp_SSE4(const void* prediction_0,
-                                const void* prediction_1, uint8_t* mask,
+void WeightMask16x32_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
+                                uint8_t* LIBGAV1_RESTRICT mask,
                                 ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -637,8 +678,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask16x64_10bpp_SSE4(const void* prediction_0,
-                                const void* prediction_1, uint8_t* mask,
+void WeightMask16x64_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
+                                uint8_t* LIBGAV1_RESTRICT mask,
                                 ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -664,8 +706,9 @@
   mask += mask_stride
 
 template <bool mask_is_inverse>
-void WeightMask32x8_10bpp_SSE4(const void* prediction_0,
-                               const void* prediction_1, uint8_t* mask,
+void WeightMask32x8_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                               const void* LIBGAV1_RESTRICT prediction_1,
+                               uint8_t* LIBGAV1_RESTRICT mask,
                                ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -680,8 +723,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask32x16_10bpp_SSE4(const void* prediction_0,
-                                const void* prediction_1, uint8_t* mask,
+void WeightMask32x16_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
+                                uint8_t* LIBGAV1_RESTRICT mask,
                                 ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -695,8 +739,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask32x32_10bpp_SSE4(const void* prediction_0,
-                                const void* prediction_1, uint8_t* mask,
+void WeightMask32x32_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
+                                uint8_t* LIBGAV1_RESTRICT mask,
                                 ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -713,8 +758,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask32x64_10bpp_SSE4(const void* prediction_0,
-                                const void* prediction_1, uint8_t* mask,
+void WeightMask32x64_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
+                                uint8_t* LIBGAV1_RESTRICT mask,
                                 ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -744,8 +790,9 @@
   mask += mask_stride
 
 template <bool mask_is_inverse>
-void WeightMask64x16_10bpp_SSE4(const void* prediction_0,
-                                const void* prediction_1, uint8_t* mask,
+void WeightMask64x16_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
+                                uint8_t* LIBGAV1_RESTRICT mask,
                                 ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -759,8 +806,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask64x32_10bpp_SSE4(const void* prediction_0,
-                                const void* prediction_1, uint8_t* mask,
+void WeightMask64x32_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
+                                uint8_t* LIBGAV1_RESTRICT mask,
                                 ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -777,8 +825,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask64x64_10bpp_SSE4(const void* prediction_0,
-                                const void* prediction_1, uint8_t* mask,
+void WeightMask64x64_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                const void* LIBGAV1_RESTRICT prediction_1,
+                                uint8_t* LIBGAV1_RESTRICT mask,
                                 ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -792,8 +841,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask64x128_10bpp_SSE4(const void* prediction_0,
-                                 const void* prediction_1, uint8_t* mask,
+void WeightMask64x128_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                 const void* LIBGAV1_RESTRICT prediction_1,
+                                 uint8_t* LIBGAV1_RESTRICT mask,
                                  ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -808,8 +858,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask128x64_10bpp_SSE4(const void* prediction_0,
-                                 const void* prediction_1, uint8_t* mask,
+void WeightMask128x64_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                 const void* LIBGAV1_RESTRICT prediction_1,
+                                 uint8_t* LIBGAV1_RESTRICT mask,
                                  ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
@@ -851,8 +902,9 @@
 }
 
 template <bool mask_is_inverse>
-void WeightMask128x128_10bpp_SSE4(const void* prediction_0,
-                                  const void* prediction_1, uint8_t* mask,
+void WeightMask128x128_10bpp_SSE4(const void* LIBGAV1_RESTRICT prediction_0,
+                                  const void* LIBGAV1_RESTRICT prediction_1,
+                                  uint8_t* LIBGAV1_RESTRICT mask,
                                   ptrdiff_t mask_stride) {
   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
diff --git a/libgav1/src/film_grain.cc b/libgav1/src/film_grain.cc
index dac37b5..5c64ff2 100644
--- a/libgav1/src/film_grain.cc
+++ b/libgav1/src/film_grain.cc
@@ -24,6 +24,7 @@
 #include "src/dsp/common.h"
 #include "src/dsp/constants.h"
 #include "src/dsp/dsp.h"
+#include "src/dsp/film_grain_common.h"
 #include "src/utils/array_2d.h"
 #include "src/utils/blocking_counter.h"
 #include "src/utils/common.h"
@@ -318,10 +319,14 @@
   //
   // Note: Although it does not seem to make sense, there are test vectors
   // with chroma_scaling_from_luma=true and params_.num_y_points=0.
+#if LIBGAV1_MSAN
+  // Quiet film grain / md5 msan warnings.
+  memset(scaling_lut_y_, 0, sizeof(scaling_lut_y_));
+#endif
   if (use_luma || params_.chroma_scaling_from_luma) {
     dsp.film_grain.initialize_scaling_lut(
         params_.num_y_points, params_.point_y_value, params_.point_y_scaling,
-        scaling_lut_y_);
+        scaling_lut_y_, kScalingLutLength);
   } else {
     ASAN_POISON_MEMORY_REGION(scaling_lut_y_, sizeof(scaling_lut_y_));
   }
@@ -331,25 +336,28 @@
       scaling_lut_v_ = scaling_lut_y_;
     } else if (params_.num_u_points > 0 || params_.num_v_points > 0) {
       const size_t buffer_size =
-          (kScalingLookupTableSize + kScalingLookupTablePadding) *
-          (static_cast<int>(params_.num_u_points > 0) +
-           static_cast<int>(params_.num_v_points > 0));
-      scaling_lut_chroma_buffer_.reset(new (std::nothrow) uint8_t[buffer_size]);
+          kScalingLutLength * (static_cast<int>(params_.num_u_points > 0) +
+                               static_cast<int>(params_.num_v_points > 0));
+      scaling_lut_chroma_buffer_.reset(new (std::nothrow) int16_t[buffer_size]);
       if (scaling_lut_chroma_buffer_ == nullptr) return false;
 
-      uint8_t* buffer = scaling_lut_chroma_buffer_.get();
+      int16_t* buffer = scaling_lut_chroma_buffer_.get();
+#if LIBGAV1_MSAN
+      // Quiet film grain / md5 msan warnings.
+      memset(buffer, 0, buffer_size * 2);
+#endif
       if (params_.num_u_points > 0) {
         scaling_lut_u_ = buffer;
         dsp.film_grain.initialize_scaling_lut(
             params_.num_u_points, params_.point_u_value,
-            params_.point_u_scaling, scaling_lut_u_);
-        buffer += kScalingLookupTableSize + kScalingLookupTablePadding;
+            params_.point_u_scaling, scaling_lut_u_, kScalingLutLength);
+        buffer += kScalingLutLength;
       }
       if (params_.num_v_points > 0) {
         scaling_lut_v_ = buffer;
         dsp.film_grain.initialize_scaling_lut(
             params_.num_v_points, params_.point_v_value,
-            params_.point_v_scaling, scaling_lut_v_);
+            params_.point_v_scaling, scaling_lut_v_, kScalingLutLength);
       }
     }
   }
@@ -364,7 +372,7 @@
   // 7.18.3.3 says luma_grain "will never be read in this case". So we don't
   // call GenerateLumaGrain if params.num_y_points is equal to 0.
   assert(params.num_y_points > 0);
-  const int shift = 12 - bitdepth + params.grain_scale_shift;
+  const int shift = kBitdepth12 - bitdepth + params.grain_scale_shift;
   uint16_t seed = params.grain_seed;
   GrainType* luma_grain_row = luma_grain;
   for (int y = 0; y < kLumaHeight; ++y) {
@@ -382,7 +390,7 @@
                                                int chroma_height,
                                                GrainType* u_grain,
                                                GrainType* v_grain) {
-  const int shift = 12 - bitdepth + params.grain_scale_shift;
+  const int shift = kBitdepth12 - bitdepth + params.grain_scale_shift;
   if (params.num_u_points == 0 && !params.chroma_scaling_from_luma) {
     memset(u_grain, 0, chroma_height * chroma_width * sizeof(*u_grain));
   } else {
@@ -460,22 +468,25 @@
 
 template <int bitdepth>
 bool FilmGrain<bitdepth>::AllocateNoiseImage() {
+  // When LIBGAV1_MSAN is enabled, zero initialize to quiet optimized film grain
+  // msan warnings.
+  constexpr bool zero_initialize = LIBGAV1_MSAN == 1;
   if (params_.num_y_points > 0 &&
       !noise_image_[kPlaneY].Reset(height_, width_ + kNoiseImagePadding,
-                                   /*zero_initialize=*/false)) {
+                                   zero_initialize)) {
     return false;
   }
   if (!is_monochrome_) {
     if (!noise_image_[kPlaneU].Reset(
             (height_ + subsampling_y_) >> subsampling_y_,
             ((width_ + subsampling_x_) >> subsampling_x_) + kNoiseImagePadding,
-            /*zero_initialize=*/false)) {
+            zero_initialize)) {
       return false;
     }
     if (!noise_image_[kPlaneV].Reset(
             (height_ + subsampling_y_) >> subsampling_y_,
             ((width_ + subsampling_x_) >> subsampling_x_) + kNoiseImagePadding,
-            /*zero_initialize=*/false)) {
+            zero_initialize)) {
       return false;
     }
   }
@@ -556,7 +567,7 @@
 
     const auto* source_cursor_y = reinterpret_cast<const Pixel*>(
         source_plane_y + start_height * source_stride_y);
-    const uint8_t* scaling_lut_uv;
+    const int16_t* scaling_lut_uv;
     const uint8_t* source_plane_uv;
     uint8_t* dest_plane_uv;
 
@@ -689,16 +700,16 @@
   int max_luma;
   int max_chroma;
   if (params_.clip_to_restricted_range) {
-    min_value = 16 << (bitdepth - 8);
-    max_luma = 235 << (bitdepth - 8);
+    min_value = 16 << (bitdepth - kBitdepth8);
+    max_luma = 235 << (bitdepth - kBitdepth8);
     if (color_matrix_is_identity_) {
       max_chroma = max_luma;
     } else {
-      max_chroma = 240 << (bitdepth - 8);
+      max_chroma = 240 << (bitdepth - kBitdepth8);
     }
   } else {
     min_value = 0;
-    max_luma = (256 << (bitdepth - 8)) - 1;
+    max_luma = (256 << (bitdepth - kBitdepth8)) - 1;
     max_chroma = max_luma;
   }
 
@@ -809,9 +820,9 @@
 }
 
 // Explicit instantiations.
-template class FilmGrain<8>;
+template class FilmGrain<kBitdepth8>;
 #if LIBGAV1_MAX_BITDEPTH >= 10
-template class FilmGrain<10>;
+template class FilmGrain<kBitdepth10>;
 #endif
 
 }  // namespace libgav1
diff --git a/libgav1/src/film_grain.h b/libgav1/src/film_grain.h
index b588f6d..f2c1e93 100644
--- a/libgav1/src/film_grain.h
+++ b/libgav1/src/film_grain.h
@@ -103,6 +103,8 @@
  private:
   using Pixel =
       typename std::conditional<bitdepth == 8, uint8_t, uint16_t>::type;
+  static constexpr int kScalingLutLength =
+      (kScalingLookupTableSize + kScalingLookupTablePadding) << (bitdepth - 8);
 
   bool Init();
 
@@ -156,13 +158,13 @@
   GrainType u_grain_[kMaxChromaHeight * kMaxChromaWidth];
   GrainType v_grain_[kMaxChromaHeight * kMaxChromaWidth];
   // Scaling lookup tables.
-  uint8_t scaling_lut_y_[kScalingLookupTableSize + kScalingLookupTablePadding];
-  uint8_t* scaling_lut_u_ = nullptr;
-  uint8_t* scaling_lut_v_ = nullptr;
-  // If allocated, this buffer is 256 * 2 bytes long and scaling_lut_u_ and
+  int16_t scaling_lut_y_[kScalingLutLength];
+  int16_t* scaling_lut_u_ = nullptr;
+  int16_t* scaling_lut_v_ = nullptr;
+  // If allocated, this buffer is 256 * 2 values long and scaling_lut_u_ and
   // scaling_lut_v_ point into this buffer. Otherwise, scaling_lut_u_ and
   // scaling_lut_v_ point to scaling_lut_y_.
-  std::unique_ptr<uint8_t[]> scaling_lut_chroma_buffer_;
+  std::unique_ptr<int16_t[]> scaling_lut_chroma_buffer_;
 
   // A two-dimensional array of noise data for each plane. Generated for each 32
   // luma sample high stripe of the image. The first dimension is called
diff --git a/libgav1/src/frame_scratch_buffer.h b/libgav1/src/frame_scratch_buffer.h
index 90c3bb8..1b0d2e0 100644
--- a/libgav1/src/frame_scratch_buffer.h
+++ b/libgav1/src/frame_scratch_buffer.h
@@ -17,10 +17,13 @@
 #ifndef LIBGAV1_SRC_FRAME_SCRATCH_BUFFER_H_
 #define LIBGAV1_SRC_FRAME_SCRATCH_BUFFER_H_
 
+#include <array>
 #include <condition_variable>  // NOLINT (unapproved c++11 header)
 #include <cstdint>
 #include <memory>
 #include <mutex>  // NOLINT (unapproved c++11 header)
+#include <new>
+#include <utility>
 
 #include "src/loop_restoration_info.h"
 #include "src/residual_buffer_pool.h"
@@ -46,9 +49,24 @@
 
 // Buffer to facilitate decoding a frame. This struct is used only within
 // DecoderImpl::DecodeTiles().
-struct FrameScratchBuffer {
+// The alignment requirement is due to the SymbolDecoderContext member
+// symbol_decoder_context and the TileScratchBufferPool member
+// tile_scratch_buffer_pool.
+struct FrameScratchBuffer : public MaxAlignedAllocable {
   LoopRestorationInfo loop_restoration_info;
-  Array2D<int16_t> cdef_index;
+  Array2D<int8_t> cdef_index;
+  // Encodes the block skip information as a bitmask for the entire frame which
+  // will be used by the cdef process.
+  //
+  // * The size of this array is rows4x4 / 2 * column4x4 / 16.
+  // * Each row of the bitmasks array (cdef_skip) stores the bitmask for 2 rows
+  // of 4x4 blocks.
+  // * Each entry in the row will store the skip information for 16 4x4 blocks
+  // (8 bits).
+  // * If any of the four 4x4 blocks in the 8x8 block is not a skip block, then
+  // the corresponding bit (as described below) will be set to 1.
+  // * For the 4x4 block at column4x4 the bit index is (column4x4 >> 1).
+  Array2D<uint8_t> cdef_skip;
   Array2D<TransformSize> inter_transform_sizes;
   BlockParametersHolder block_parameters_holder;
   TemporalMotionField motion_field;
diff --git a/libgav1/src/gav1/decoder_buffer.h b/libgav1/src/gav1/decoder_buffer.h
index 37bcb29..880c320 100644
--- a/libgav1/src/gav1/decoder_buffer.h
+++ b/libgav1/src/gav1/decoder_buffer.h
@@ -129,24 +129,17 @@
   Libgav1TransferCharacteristics transfer_characteristics;
   Libgav1MatrixCoefficients matrix_coefficients;
 
-  // Image storage dimensions.
-  // NOTE: These fields are named w and h in vpx_image_t and aom_image_t.
-  // uint32_t width;  // Stored image width.
-  // uint32_t height;  // Stored image height.
   int bitdepth;  // Stored image bitdepth.
 
-  // Image display dimensions.
-  // NOTES:
-  // 1. These fields are named d_w and d_h in vpx_image_t and aom_image_t.
-  // 2. libvpx and libaom clients use d_w and d_h much more often than w and h.
-  // 3. These fields can just be stored for the Y plane and the clients can
-  //    calculate the values for the U and V planes if the image format or
-  //    subsampling is exposed.
+  // Image display dimensions in Y/U/V order.
   int displayed_width[3];   // Displayed image width.
   int displayed_height[3];  // Displayed image height.
 
-  int stride[3];
-  uint8_t* plane[3];
+  // Values are given in Y/U/V order.
+  int stride[3];      // The width in bytes of one row of the |plane| buffer.
+                      // This may include padding bytes for alignment or
+                      // internal use by the decoder.
+  uint8_t* plane[3];  // The reconstructed image plane(s).
 
   // Spatial id of this frame.
   int spatial_id;
diff --git a/libgav1/src/gav1/version.h b/libgav1/src/gav1/version.h
index c018928..9bdc630 100644
--- a/libgav1/src/gav1/version.h
+++ b/libgav1/src/gav1/version.h
@@ -23,8 +23,8 @@
 // (https://semver.org).
 
 #define LIBGAV1_MAJOR_VERSION 0
-#define LIBGAV1_MINOR_VERSION 16
-#define LIBGAV1_PATCH_VERSION 3
+#define LIBGAV1_MINOR_VERSION 17
+#define LIBGAV1_PATCH_VERSION 0
 
 #define LIBGAV1_VERSION                                           \
   ((LIBGAV1_MAJOR_VERSION << 16) | (LIBGAV1_MINOR_VERSION << 8) | \
diff --git a/libgav1/src/loop_restoration_info.cc b/libgav1/src/loop_restoration_info.cc
index 2dba57d..8c17711 100644
--- a/libgav1/src/loop_restoration_info.cc
+++ b/libgav1/src/loop_restoration_info.cc
@@ -133,7 +133,7 @@
 }
 
 void LoopRestorationInfo::ReadUnitCoefficients(
-    DaalaBitReader* const reader,
+    EntropyDecoder* const reader,
     SymbolDecoderContext* const symbol_decoder_context, Plane plane,
     int unit_id,
     std::array<RestorationUnitInfo, kMaxPlanes>* const reference_unit_info) {
@@ -161,7 +161,7 @@
 }
 
 void LoopRestorationInfo::ReadWienerInfo(
-    DaalaBitReader* const reader, Plane plane, int unit_id,
+    EntropyDecoder* const reader, Plane plane, int unit_id,
     std::array<RestorationUnitInfo, kMaxPlanes>* const reference_unit_info) {
   for (int i = WienerInfo::kVertical; i <= WienerInfo::kHorizontal; ++i) {
     if (plane != kPlaneY) {
@@ -198,7 +198,7 @@
 }
 
 void LoopRestorationInfo::ReadSgrProjInfo(
-    DaalaBitReader* const reader, Plane plane, int unit_id,
+    EntropyDecoder* const reader, Plane plane, int unit_id,
     std::array<RestorationUnitInfo, kMaxPlanes>* const reference_unit_info) {
   const int sgr_proj_index =
       static_cast<int>(reader->ReadLiteral(kSgrProjParamsBits));
diff --git a/libgav1/src/loop_restoration_info.h b/libgav1/src/loop_restoration_info.h
index f174b89..bff6746 100644
--- a/libgav1/src/loop_restoration_info.h
+++ b/libgav1/src/loop_restoration_info.h
@@ -19,8 +19,6 @@
 
 #include <array>
 #include <cstdint>
-#include <memory>
-#include <vector>
 
 #include "src/dsp/common.h"
 #include "src/symbol_decoder_context.h"
@@ -58,16 +56,16 @@
                                      uint8_t superres_scale_denominator,
                                      int row4x4, int column4x4,
                                      LoopRestorationUnitInfo* unit_info) const;
-  void ReadUnitCoefficients(DaalaBitReader* reader,
+  void ReadUnitCoefficients(EntropyDecoder* reader,
                             SymbolDecoderContext* symbol_decoder_context,
                             Plane plane, int unit_id,
                             std::array<RestorationUnitInfo, kMaxPlanes>*
                                 reference_unit_info);  // 5.11.58.
   void ReadWienerInfo(
-      DaalaBitReader* reader, Plane plane, int unit_id,
+      EntropyDecoder* reader, Plane plane, int unit_id,
       std::array<RestorationUnitInfo, kMaxPlanes>* reference_unit_info);
   void ReadSgrProjInfo(
-      DaalaBitReader* reader, Plane plane, int unit_id,
+      EntropyDecoder* reader, Plane plane, int unit_id,
       std::array<RestorationUnitInfo, kMaxPlanes>* reference_unit_info);
 
   // Getters.
diff --git a/libgav1/src/motion_vector.cc b/libgav1/src/motion_vector.cc
index fdb1875..36018ab 100644
--- a/libgav1/src/motion_vector.cc
+++ b/libgav1/src/motion_vector.cc
@@ -83,14 +83,12 @@
                  (gm.params[5] - (1 << kWarpedModelPrecisionBits)) * y +
                  gm.params[1];
   if (frame_header.allow_high_precision_mv) {
-    mv->mv[MotionVector::kRow] =
-        RightShiftWithRoundingSigned(yc, kWarpedModelPrecisionBits - 3);
-    mv->mv[MotionVector::kColumn] =
-        RightShiftWithRoundingSigned(xc, kWarpedModelPrecisionBits - 3);
+    mv->mv[0] = RightShiftWithRoundingSigned(yc, kWarpedModelPrecisionBits - 3);
+    mv->mv[1] = RightShiftWithRoundingSigned(xc, kWarpedModelPrecisionBits - 3);
   } else {
-    mv->mv[MotionVector::kRow] = MultiplyBy2(
+    mv->mv[0] = MultiplyBy2(
         RightShiftWithRoundingSigned(yc, kWarpedModelPrecisionBits - 2));
-    mv->mv[MotionVector::kColumn] = MultiplyBy2(
+    mv->mv[1] = MultiplyBy2(
         RightShiftWithRoundingSigned(xc, kWarpedModelPrecisionBits - 2));
     LowerMvPrecision(frame_header, mv);
   }
@@ -115,7 +113,7 @@
   // LowerMvPrecision() is not necessary, since the values in
   // |prediction_parameters.global_mv| and |mv_bp.mv| were generated by it.
   const auto global_motion_type = global_motion[bp.reference_frame[0]].type;
-  if (IsGlobalMvBlock(mv_bp.is_global_mv_block, global_motion_type)) {
+  if (IsGlobalMvBlock(mv_bp, global_motion_type)) {
     candidate_mv = prediction_parameters.global_mv[0];
   } else {
     candidate_mv = mv_bp.mv.mv[index];
@@ -126,7 +124,7 @@
   const int num_found = *num_mv_found;
   const auto result = std::find_if(ref_mv_stack, ref_mv_stack + num_found,
                                    [&candidate_mv](const MotionVector& ref_mv) {
-                                     return ref_mv == candidate_mv;
+                                     return ref_mv.mv32 == candidate_mv.mv32;
                                    });
   if (result != ref_mv_stack + num_found) {
     prediction_parameters.IncreaseWeight(std::distance(ref_mv_stack, result),
@@ -152,7 +150,7 @@
   CompoundMotionVector candidate_mv = mv_bp.mv;
   for (int i = 0; i < 2; ++i) {
     const auto global_motion_type = global_motion[bp.reference_frame[i]].type;
-    if (IsGlobalMvBlock(mv_bp.is_global_mv_block, global_motion_type)) {
+    if (IsGlobalMvBlock(mv_bp, global_motion_type)) {
       candidate_mv.mv[i] = prediction_parameters.global_mv[i];
     }
   }
@@ -164,7 +162,7 @@
   const auto result =
       std::find_if(compound_ref_mv_stack, compound_ref_mv_stack + num_found,
                    [&candidate_mv](const CompoundMotionVector& ref_mv) {
-                     return ref_mv == candidate_mv;
+                     return ref_mv.mv64 == candidate_mv.mv64;
                    });
   if (result != compound_ref_mv_stack + num_found) {
     prediction_parameters.IncreaseWeight(
@@ -172,7 +170,7 @@
     return;
   }
   if (num_found >= kMaxRefMvStackSize) return;
-  compound_ref_mv_stack[num_found] = candidate_mv;
+  compound_ref_mv_stack[num_found].mv64 = candidate_mv.mv64;
   prediction_parameters.SetWeightIndexStackEntry(num_found, weight);
   ++*num_mv_found;
 }
@@ -284,7 +282,8 @@
       frame_header.allow_high_precision_mv ? 2 : frame_header.force_integer_mv;
   const MotionVector* const global_mv = prediction_parameters->global_mv;
   if (is_compound) {
-    CompoundMotionVector candidate_mvs[kMaxTemporalMvCandidatesWithPadding];
+    alignas(kMaxAlignment)
+        CompoundMotionVector candidate_mvs[kMaxTemporalMvCandidatesWithPadding];
     const dsp::Dsp& dsp = *dsp::GetDspTable(8);
     dsp.mv_projection_compound[mv_projection_function_index](
         temporal_mvs, temporal_reference_offsets, reference_offsets, count,
@@ -310,7 +309,7 @@
       const auto result =
           std::find_if(compound_ref_mv_stack, compound_ref_mv_stack + num_found,
                        [&candidate_mv](const CompoundMotionVector& ref_mv) {
-                         return ref_mv == candidate_mv;
+                         return ref_mv.mv64 == candidate_mv.mv64;
                        });
       if (result != compound_ref_mv_stack + num_found) {
         prediction_parameters->IncreaseWeight(
@@ -318,7 +317,7 @@
         continue;
       }
       if (num_found >= kMaxRefMvStackSize) continue;
-      compound_ref_mv_stack[num_found] = candidate_mv;
+      compound_ref_mv_stack[num_found].mv64 = candidate_mv.mv64;
       prediction_parameters->SetWeightIndexStackEntry(num_found, 2);
       ++num_found;
     } while (++index < count);
@@ -337,7 +336,7 @@
     const auto result =
         std::find_if(ref_mv_stack, ref_mv_stack + num_found,
                      [&candidate_mv](const MotionVector& ref_mv) {
-                       return ref_mv == candidate_mv;
+                       return ref_mv.mv32 == candidate_mv.mv32;
                      });
     if (result != ref_mv_stack + num_found) {
       prediction_parameters->IncreaseWeight(std::distance(ref_mv_stack, result),
@@ -369,7 +368,7 @@
     const auto result =
         std::find_if(ref_mv_stack, ref_mv_stack + num_found,
                      [&candidate_mv](const MotionVector& ref_mv) {
-                       return ref_mv == candidate_mv;
+                       return ref_mv.mv32 == candidate_mv.mv32;
                      });
     if (result != ref_mv_stack + num_found) {
       prediction_parameters->IncreaseWeight(std::distance(ref_mv_stack, result),
@@ -563,8 +562,8 @@
       candidate_mv.mv[1] *= -1;
     }
     assert(num_found <= 2);
-    if ((num_found != 0 && ref_mv_stack[0] == candidate_mv) ||
-        (num_found == 2 && ref_mv_stack[1] == candidate_mv)) {
+    if ((num_found != 0 && ref_mv_stack[0].mv32 == candidate_mv.mv32) ||
+        (num_found == 2 && ref_mv_stack[1].mv32 == candidate_mv.mv32)) {
       continue;
     }
     ref_mv_stack[num_found] = candidate_mv;
@@ -624,16 +623,16 @@
       }
     }
     if (*num_mv_found == 1) {
-      if (combined_mvs[0] == compound_ref_mv_stack[0]) {
-        compound_ref_mv_stack[1] = combined_mvs[1];
+      if (combined_mvs[0].mv64 == compound_ref_mv_stack[0].mv64) {
+        compound_ref_mv_stack[1].mv64 = combined_mvs[1].mv64;
       } else {
-        compound_ref_mv_stack[1] = combined_mvs[0];
+        compound_ref_mv_stack[1].mv64 = combined_mvs[0].mv64;
       }
       prediction_parameters.SetWeightIndexStackEntry(1, 0);
     } else {
       assert(*num_mv_found == 0);
       for (int i = 0; i < 2; ++i) {
-        compound_ref_mv_stack[i] = combined_mvs[i];
+        compound_ref_mv_stack[i].mv64 = combined_mvs[i].mv64;
         prediction_parameters.SetWeightIndexStackEntry(i, 0);
       }
     }
diff --git a/libgav1/src/motion_vector.h b/libgav1/src/motion_vector.h
index d739e80..68d14fe 100644
--- a/libgav1/src/motion_vector.h
+++ b/libgav1/src/motion_vector.h
@@ -30,9 +30,11 @@
 
 namespace libgav1 {
 
-constexpr bool IsGlobalMvBlock(bool is_global_mv_block,
+constexpr bool IsGlobalMvBlock(const BlockParameters& bp,
                                GlobalMotionTransformationType type) {
-  return is_global_mv_block &&
+  return (bp.y_mode == kPredictionModeGlobalMv ||
+          bp.y_mode == kPredictionModeGlobalGlobalMv) &&
+         !IsBlockDimension4(bp.size) &&
          type > kGlobalMotionTransformationTypeTranslation;
 }
 
diff --git a/libgav1/src/obu_parser.cc b/libgav1/src/obu_parser.cc
index 69480d7..445450b 100644
--- a/libgav1/src/obu_parser.cc
+++ b/libgav1/src/obu_parser.cc
@@ -140,10 +140,10 @@
   int64_t scratch;
   ColorConfig* const color_config = &sequence_header->color_config;
   OBU_READ_BIT_OR_FAIL;
-  const auto high_bitdepth = static_cast<bool>(scratch);
+  const bool high_bitdepth = scratch != 0;
   if (sequence_header->profile == kProfile2 && high_bitdepth) {
     OBU_READ_BIT_OR_FAIL;
-    const auto is_twelve_bit = static_cast<bool>(scratch);
+    const bool is_twelve_bit = scratch != 0;
     color_config->bitdepth = is_twelve_bit ? 12 : 10;
   } else {
     color_config->bitdepth = high_bitdepth ? 10 : 8;
@@ -152,10 +152,10 @@
     color_config->is_monochrome = false;
   } else {
     OBU_READ_BIT_OR_FAIL;
-    color_config->is_monochrome = static_cast<bool>(scratch);
+    color_config->is_monochrome = scratch != 0;
   }
   OBU_READ_BIT_OR_FAIL;
-  const auto color_description_present_flag = static_cast<bool>(scratch);
+  const bool color_description_present_flag = scratch != 0;
   if (color_description_present_flag) {
     OBU_READ_LITERAL_OR_FAIL(8);
     color_config->color_primary = static_cast<ColorPrimary>(scratch);
@@ -230,7 +230,7 @@
       }
     }
     OBU_READ_BIT_OR_FAIL;
-    color_config->separate_uv_delta_q = static_cast<bool>(scratch);
+    color_config->separate_uv_delta_q = scratch != 0;
   }
   if (color_config->matrix_coefficients == kMatrixCoefficientsIdentity &&
       (color_config->subsampling_x != 0 || color_config->subsampling_y != 0)) {
@@ -246,7 +246,7 @@
 bool ObuParser::ParseTimingInfo(ObuSequenceHeader* sequence_header) {
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
-  sequence_header->timing_info_present_flag = static_cast<bool>(scratch);
+  sequence_header->timing_info_present_flag = scratch != 0;
   if (!sequence_header->timing_info_present_flag) return true;
   TimingInfo* const info = &sequence_header->timing_info;
   OBU_READ_LITERAL_OR_FAIL(32);
@@ -262,7 +262,7 @@
     return false;
   }
   OBU_READ_BIT_OR_FAIL;
-  info->equal_picture_interval = static_cast<bool>(scratch);
+  info->equal_picture_interval = scratch != 0;
   if (info->equal_picture_interval) {
     OBU_READ_UVLC_OR_FAIL(info->num_ticks_per_picture);
     ++info->num_ticks_per_picture;
@@ -274,7 +274,7 @@
   if (!sequence_header->timing_info_present_flag) return true;
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
-  sequence_header->decoder_model_info_present_flag = static_cast<bool>(scratch);
+  sequence_header->decoder_model_info_present_flag = scratch != 0;
   if (!sequence_header->decoder_model_info_present_flag) return true;
   DecoderModelInfo* const info = &sequence_header->decoder_model_info;
   OBU_READ_LITERAL_OR_FAIL(5);
@@ -293,7 +293,7 @@
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
   sequence_header->decoder_model_present_for_operating_point[index] =
-      static_cast<bool>(scratch);
+      scratch != 0;
   if (!sequence_header->decoder_model_present_for_operating_point[index]) {
     return true;
   }
@@ -305,7 +305,7 @@
       sequence_header->decoder_model_info.encoder_decoder_buffer_delay_length);
   params->encoder_buffer_delay[index] = static_cast<uint32_t>(scratch);
   OBU_READ_BIT_OR_FAIL;
-  params->low_delay_mode_flag[index] = static_cast<bool>(scratch);
+  params->low_delay_mode_flag[index] = scratch != 0;
   return true;
 }
 
@@ -319,9 +319,9 @@
   }
   sequence_header.profile = static_cast<BitstreamProfile>(scratch);
   OBU_READ_BIT_OR_FAIL;
-  sequence_header.still_picture = static_cast<bool>(scratch);
+  sequence_header.still_picture = scratch != 0;
   OBU_READ_BIT_OR_FAIL;
-  sequence_header.reduced_still_picture_header = static_cast<bool>(scratch);
+  sequence_header.reduced_still_picture_header = scratch != 0;
   if (sequence_header.reduced_still_picture_header) {
     if (!sequence_header.still_picture) {
       LIBGAV1_DLOG(
@@ -338,7 +338,7 @@
       return false;
     }
     OBU_READ_BIT_OR_FAIL;
-    const auto initial_display_delay_present_flag = static_cast<bool>(scratch);
+    const bool initial_display_delay_present_flag = scratch != 0;
     OBU_READ_LITERAL_OR_FAIL(5);
     sequence_header.operating_points = static_cast<int>(1 + scratch);
     if (operating_point_ >= sequence_header.operating_points) {
@@ -374,7 +374,7 @@
       }
       if (initial_display_delay_present_flag) {
         OBU_READ_BIT_OR_FAIL;
-        if (static_cast<bool>(scratch)) {
+        if (scratch != 0) {
           OBU_READ_LITERAL_OR_FAIL(4);
           sequence_header.initial_display_delay[i] = 1 + scratch;
         }
@@ -391,7 +391,7 @@
   sequence_header.max_frame_height = static_cast<int32_t>(1 + scratch);
   if (!sequence_header.reduced_still_picture_header) {
     OBU_READ_BIT_OR_FAIL;
-    sequence_header.frame_id_numbers_present = static_cast<bool>(scratch);
+    sequence_header.frame_id_numbers_present = scratch != 0;
   }
   if (sequence_header.frame_id_numbers_present) {
     OBU_READ_LITERAL_OR_FAIL(4);
@@ -409,33 +409,33 @@
     }
   }
   OBU_READ_BIT_OR_FAIL;
-  sequence_header.use_128x128_superblock = static_cast<bool>(scratch);
+  sequence_header.use_128x128_superblock = scratch != 0;
   OBU_READ_BIT_OR_FAIL;
-  sequence_header.enable_filter_intra = static_cast<bool>(scratch);
+  sequence_header.enable_filter_intra = scratch != 0;
   OBU_READ_BIT_OR_FAIL;
-  sequence_header.enable_intra_edge_filter = static_cast<bool>(scratch);
+  sequence_header.enable_intra_edge_filter = scratch != 0;
   if (sequence_header.reduced_still_picture_header) {
     sequence_header.force_screen_content_tools = kSelectScreenContentTools;
     sequence_header.force_integer_mv = kSelectIntegerMv;
   } else {
     OBU_READ_BIT_OR_FAIL;
-    sequence_header.enable_interintra_compound = static_cast<bool>(scratch);
+    sequence_header.enable_interintra_compound = scratch != 0;
     OBU_READ_BIT_OR_FAIL;
-    sequence_header.enable_masked_compound = static_cast<bool>(scratch);
+    sequence_header.enable_masked_compound = scratch != 0;
     OBU_READ_BIT_OR_FAIL;
-    sequence_header.enable_warped_motion = static_cast<bool>(scratch);
+    sequence_header.enable_warped_motion = scratch != 0;
     OBU_READ_BIT_OR_FAIL;
-    sequence_header.enable_dual_filter = static_cast<bool>(scratch);
+    sequence_header.enable_dual_filter = scratch != 0;
     OBU_READ_BIT_OR_FAIL;
-    sequence_header.enable_order_hint = static_cast<bool>(scratch);
+    sequence_header.enable_order_hint = scratch != 0;
     if (sequence_header.enable_order_hint) {
       OBU_READ_BIT_OR_FAIL;
-      sequence_header.enable_jnt_comp = static_cast<bool>(scratch);
+      sequence_header.enable_jnt_comp = scratch != 0;
       OBU_READ_BIT_OR_FAIL;
-      sequence_header.enable_ref_frame_mvs = static_cast<bool>(scratch);
+      sequence_header.enable_ref_frame_mvs = scratch != 0;
     }
     OBU_READ_BIT_OR_FAIL;
-    sequence_header.choose_screen_content_tools = static_cast<bool>(scratch);
+    sequence_header.choose_screen_content_tools = scratch != 0;
     if (sequence_header.choose_screen_content_tools) {
       sequence_header.force_screen_content_tools = kSelectScreenContentTools;
     } else {
@@ -444,7 +444,7 @@
     }
     if (sequence_header.force_screen_content_tools > 0) {
       OBU_READ_BIT_OR_FAIL;
-      sequence_header.choose_integer_mv = static_cast<bool>(scratch);
+      sequence_header.choose_integer_mv = scratch != 0;
       if (sequence_header.choose_integer_mv) {
         sequence_header.force_integer_mv = kSelectIntegerMv;
       } else {
@@ -462,14 +462,14 @@
     }
   }
   OBU_READ_BIT_OR_FAIL;
-  sequence_header.enable_superres = static_cast<bool>(scratch);
+  sequence_header.enable_superres = scratch != 0;
   OBU_READ_BIT_OR_FAIL;
-  sequence_header.enable_cdef = static_cast<bool>(scratch);
+  sequence_header.enable_cdef = scratch != 0;
   OBU_READ_BIT_OR_FAIL;
-  sequence_header.enable_restoration = static_cast<bool>(scratch);
+  sequence_header.enable_restoration = scratch != 0;
   if (!ParseColorConfig(&sequence_header)) return false;
   OBU_READ_BIT_OR_FAIL;
-  sequence_header.film_grain_params_present = static_cast<bool>(scratch);
+  sequence_header.film_grain_params_present = scratch != 0;
   // Compare new sequence header with old sequence header.
   if (has_sequence_header_ &&
       sequence_header.ParametersChanged(sequence_header_)) {
@@ -546,7 +546,7 @@
 
   // Render Size.
   OBU_READ_BIT_OR_FAIL;
-  frame_header_.render_and_frame_size_different = static_cast<bool>(scratch);
+  frame_header_.render_and_frame_size_different = scratch != 0;
   if (frame_header_.render_and_frame_size_different) {
     OBU_READ_LITERAL_OR_FAIL(16);
     frame_header_.render_width = static_cast<int32_t>(1 + scratch);
@@ -567,7 +567,7 @@
   frame_header_.use_superres = false;
   if (sequence_header_.enable_superres) {
     OBU_READ_BIT_OR_FAIL;
-    frame_header_.use_superres = static_cast<bool>(scratch);
+    frame_header_.use_superres = scratch != 0;
   }
   if (frame_header_.use_superres) {
     OBU_READ_LITERAL_OR_FAIL(3);
@@ -878,14 +878,14 @@
   OBU_READ_LITERAL_OR_FAIL(3);
   loop_filter->sharpness = scratch;
   OBU_READ_BIT_OR_FAIL;
-  loop_filter->delta_enabled = static_cast<bool>(scratch);
+  loop_filter->delta_enabled = scratch != 0;
   if (loop_filter->delta_enabled) {
     OBU_READ_BIT_OR_FAIL;
-    loop_filter->delta_update = static_cast<bool>(scratch);
+    loop_filter->delta_update = scratch != 0;
     if (loop_filter->delta_update) {
       for (auto& ref_delta : loop_filter->ref_deltas) {
         OBU_READ_BIT_OR_FAIL;
-        const auto update_ref_delta = static_cast<bool>(scratch);
+        const bool update_ref_delta = scratch != 0;
         if (update_ref_delta) {
           int scratch_int;
           if (!bit_reader_->ReadInverseSignedLiteral(6, &scratch_int)) {
@@ -897,7 +897,7 @@
       }
       for (auto& mode_delta : loop_filter->mode_deltas) {
         OBU_READ_BIT_OR_FAIL;
-        const auto update_mode_delta = static_cast<bool>(scratch);
+        const bool update_mode_delta = scratch != 0;
         if (update_mode_delta) {
           int scratch_int;
           if (!bit_reader_->ReadInverseSignedLiteral(6, &scratch_int)) {
@@ -918,7 +918,7 @@
   int64_t scratch;
   *delta = 0;
   OBU_READ_BIT_OR_FAIL;
-  const auto delta_coded = static_cast<bool>(scratch);
+  const bool delta_coded = scratch != 0;
   if (delta_coded) {
     int scratch_int;
     if (!bit_reader_->ReadInverseSignedLiteral(6, &scratch_int)) {
@@ -940,7 +940,7 @@
     bool diff_uv_delta = false;
     if (sequence_header_.color_config.separate_uv_delta_q) {
       OBU_READ_BIT_OR_FAIL;
-      diff_uv_delta = static_cast<bool>(scratch);
+      diff_uv_delta = scratch != 0;
     }
     if (!ParseDeltaQuantizer(&quantizer->delta_dc[kPlaneU]) ||
         !ParseDeltaQuantizer(&quantizer->delta_ac[kPlaneU])) {
@@ -957,7 +957,7 @@
     }
   }
   OBU_READ_BIT_OR_FAIL;
-  quantizer->use_matrix = static_cast<bool>(scratch);
+  quantizer->use_matrix = scratch != 0;
   if (quantizer->use_matrix) {
     OBU_READ_LITERAL_OR_FAIL(4);
     quantizer->matrix_level[kPlaneY] = scratch;
@@ -987,20 +987,20 @@
   int64_t scratch;
   Segmentation* const segmentation = &frame_header_.segmentation;
   OBU_READ_BIT_OR_FAIL;
-  segmentation->enabled = static_cast<bool>(scratch);
+  segmentation->enabled = scratch != 0;
   if (!segmentation->enabled) return true;
   if (frame_header_.primary_reference_frame == kPrimaryReferenceNone) {
     segmentation->update_map = true;
     segmentation->update_data = true;
   } else {
     OBU_READ_BIT_OR_FAIL;
-    segmentation->update_map = static_cast<bool>(scratch);
+    segmentation->update_map = scratch != 0;
     if (segmentation->update_map) {
       OBU_READ_BIT_OR_FAIL;
-      segmentation->temporal_update = static_cast<bool>(scratch);
+      segmentation->temporal_update = scratch != 0;
     }
     OBU_READ_BIT_OR_FAIL;
-    segmentation->update_data = static_cast<bool>(scratch);
+    segmentation->update_data = scratch != 0;
     if (!segmentation->update_data) {
       // Part of the load_previous() function in the spec.
       const int prev_frame_index =
@@ -1014,7 +1014,7 @@
   for (int8_t i = 0; i < kMaxSegments; ++i) {
     for (int8_t j = 0; j < kSegmentFeatureMax; ++j) {
       OBU_READ_BIT_OR_FAIL;
-      segmentation->feature_enabled[i][j] = static_cast<bool>(scratch);
+      segmentation->feature_enabled[i][j] = scratch != 0;
       if (segmentation->feature_enabled[i][j]) {
         if (Segmentation::FeatureSigned(static_cast<SegmentFeature>(j))) {
           int scratch_int;
@@ -1049,7 +1049,7 @@
   int64_t scratch;
   if (frame_header_.quantizer.base_index > 0) {
     OBU_READ_BIT_OR_FAIL;
-    frame_header_.delta_q.present = static_cast<bool>(scratch);
+    frame_header_.delta_q.present = scratch != 0;
     if (frame_header_.delta_q.present) {
       OBU_READ_LITERAL_OR_FAIL(2);
       frame_header_.delta_q.scale = scratch;
@@ -1063,13 +1063,13 @@
   if (frame_header_.delta_q.present) {
     if (!frame_header_.allow_intrabc) {
       OBU_READ_BIT_OR_FAIL;
-      frame_header_.delta_lf.present = static_cast<bool>(scratch);
+      frame_header_.delta_lf.present = scratch != 0;
     }
     if (frame_header_.delta_lf.present) {
       OBU_READ_LITERAL_OR_FAIL(2);
       frame_header_.delta_lf.scale = scratch;
       OBU_READ_BIT_OR_FAIL;
-      frame_header_.delta_lf.multi = static_cast<bool>(scratch);
+      frame_header_.delta_lf.multi = scratch != 0;
     }
   }
   return true;
@@ -1193,7 +1193,7 @@
   int64_t scratch;
   if (!IsIntraFrame(frame_header_.frame_type)) {
     OBU_READ_BIT_OR_FAIL;
-    frame_header_.reference_mode_select = static_cast<bool>(scratch);
+    frame_header_.reference_mode_select = scratch != 0;
   }
   return true;
 }
@@ -1276,7 +1276,7 @@
   if (!IsSkipModeAllowed()) return true;
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
-  frame_header_.skip_mode_present = static_cast<bool>(scratch);
+  frame_header_.skip_mode_present = scratch != 0;
   return true;
 }
 
@@ -1348,15 +1348,15 @@
     GlobalMotion* const global_motion = &frame_header_.global_motion[ref];
     int64_t scratch;
     OBU_READ_BIT_OR_FAIL;
-    const auto is_global = static_cast<bool>(scratch);
+    const bool is_global = scratch != 0;
     if (is_global) {
       OBU_READ_BIT_OR_FAIL;
-      const auto is_rot_zoom = static_cast<bool>(scratch);
+      const bool is_rot_zoom = scratch != 0;
       if (is_rot_zoom) {
         global_motion->type = kGlobalMotionTransformationTypeRotZoom;
       } else {
         OBU_READ_BIT_OR_FAIL;
-        const auto is_translation = static_cast<bool>(scratch);
+        const bool is_translation = scratch != 0;
         global_motion->type = is_translation
                                   ? kGlobalMotionTransformationTypeTranslation
                                   : kGlobalMotionTransformationTypeAffine;
@@ -1399,7 +1399,7 @@
   FilmGrainParams& film_grain_params = frame_header_.film_grain_params;
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
-  film_grain_params.apply_grain = static_cast<bool>(scratch);
+  film_grain_params.apply_grain = scratch != 0;
   if (!film_grain_params.apply_grain) {
     // film_grain_params is already zero-initialized.
     return true;
@@ -1410,7 +1410,7 @@
   film_grain_params.update_grain = true;
   if (frame_header_.frame_type == kFrameInter) {
     OBU_READ_BIT_OR_FAIL;
-    film_grain_params.update_grain = static_cast<bool>(scratch);
+    film_grain_params.update_grain = scratch != 0;
   }
   if (!film_grain_params.update_grain) {
     OBU_READ_LITERAL_OR_FAIL(3);
@@ -1481,7 +1481,7 @@
     film_grain_params.chroma_scaling_from_luma = false;
   } else {
     OBU_READ_BIT_OR_FAIL;
-    film_grain_params.chroma_scaling_from_luma = static_cast<bool>(scratch);
+    film_grain_params.chroma_scaling_from_luma = scratch != 0;
   }
   if (sequence_header_.color_config.is_monochrome ||
       film_grain_params.chroma_scaling_from_luma ||
@@ -1597,9 +1597,9 @@
     film_grain_params.v_offset = static_cast<int16_t>(scratch - 256);
   }
   OBU_READ_BIT_OR_FAIL;
-  film_grain_params.overlap_flag = static_cast<bool>(scratch);
+  film_grain_params.overlap_flag = scratch != 0;
   OBU_READ_BIT_OR_FAIL;
-  film_grain_params.clip_to_restricted_range = static_cast<bool>(scratch);
+  film_grain_params.clip_to_restricted_range = scratch != 0;
   return true;
 }
 
@@ -1626,7 +1626,7 @@
       minlog2_tile_columns, TileLog2(sb_max_tile_area, sb_rows * sb_columns));
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
-  tile_info->uniform_spacing = static_cast<bool>(scratch);
+  tile_info->uniform_spacing = scratch != 0;
   if (tile_info->uniform_spacing) {
     // Read tile columns.
     tile_info->tile_columns_log2 = minlog2_tile_columns;
@@ -1759,7 +1759,7 @@
   }
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
-  frame_header_.allow_warped_motion = static_cast<bool>(scratch);
+  frame_header_.allow_warped_motion = scratch != 0;
   return true;
 }
 
@@ -1774,7 +1774,7 @@
     }
   } else {
     OBU_READ_BIT_OR_FAIL;
-    frame_header_.show_existing_frame = static_cast<bool>(scratch);
+    frame_header_.show_existing_frame = scratch != 0;
     if (frame_header_.show_existing_frame) {
       OBU_READ_LITERAL_OR_FAIL(3);
       frame_header_.frame_to_show = scratch;
@@ -1849,7 +1849,7 @@
     frame_header_.frame_type = static_cast<FrameType>(scratch);
     current_frame_->set_frame_type(frame_header_.frame_type);
     OBU_READ_BIT_OR_FAIL;
-    frame_header_.show_frame = static_cast<bool>(scratch);
+    frame_header_.show_frame = scratch != 0;
     if (frame_header_.show_frame &&
         sequence_header_.decoder_model_info_present_flag &&
         !sequence_header_.timing_info.equal_picture_interval) {
@@ -1861,7 +1861,7 @@
       frame_header_.showable_frame = (frame_header_.frame_type != kFrameKey);
     } else {
       OBU_READ_BIT_OR_FAIL;
-      frame_header_.showable_frame = static_cast<bool>(scratch);
+      frame_header_.showable_frame = scratch != 0;
     }
     current_frame_->set_showable_frame(frame_header_.showable_frame);
     if (frame_header_.frame_type == kFrameSwitch ||
@@ -1869,7 +1869,7 @@
       frame_header_.error_resilient_mode = true;
     } else {
       OBU_READ_BIT_OR_FAIL;
-      frame_header_.error_resilient_mode = static_cast<bool>(scratch);
+      frame_header_.error_resilient_mode = scratch != 0;
     }
   }
   if (frame_header_.frame_type == kFrameKey && frame_header_.show_frame) {
@@ -1877,14 +1877,14 @@
     decoder_state_.reference_frame.fill(nullptr);
   }
   OBU_READ_BIT_OR_FAIL;
-  frame_header_.enable_cdf_update = !static_cast<bool>(scratch);
+  frame_header_.enable_cdf_update = scratch == 0;
   if (sequence_header_.force_screen_content_tools ==
       kSelectScreenContentTools) {
     OBU_READ_BIT_OR_FAIL;
-    frame_header_.allow_screen_content_tools = static_cast<bool>(scratch);
+    frame_header_.allow_screen_content_tools = scratch != 0;
   } else {
     frame_header_.allow_screen_content_tools =
-        static_cast<bool>(sequence_header_.force_screen_content_tools);
+        sequence_header_.force_screen_content_tools != 0;
   }
   if (frame_header_.allow_screen_content_tools) {
     if (sequence_header_.force_integer_mv == kSelectIntegerMv) {
@@ -1934,7 +1934,7 @@
     frame_header_.frame_size_override_flag = true;
   } else if (!sequence_header_.reduced_still_picture_header) {
     OBU_READ_BIT_OR_FAIL;
-    frame_header_.frame_size_override_flag = static_cast<bool>(scratch);
+    frame_header_.frame_size_override_flag = scratch != 0;
   }
   if (sequence_header_.order_hint_bits > 0) {
     OBU_READ_LITERAL_OR_FAIL(sequence_header_.order_hint_bits);
@@ -1950,7 +1950,7 @@
   }
   if (sequence_header_.decoder_model_info_present_flag) {
     OBU_READ_BIT_OR_FAIL;
-    const auto buffer_removal_time_present = static_cast<bool>(scratch);
+    const bool buffer_removal_time_present = scratch != 0;
     if (buffer_removal_time_present) {
       for (int i = 0; i < sequence_header_.operating_points; ++i) {
         if (!sequence_header_.decoder_model_present_for_operating_point[i]) {
@@ -1992,14 +1992,14 @@
     if (frame_header_.allow_screen_content_tools &&
         frame_header_.width == frame_header_.upscaled_width) {
       OBU_READ_BIT_OR_FAIL;
-      frame_header_.allow_intrabc = static_cast<bool>(scratch);
+      frame_header_.allow_intrabc = scratch != 0;
     }
   } else {
     if (!sequence_header_.enable_order_hint) {
       frame_header_.frame_refs_short_signaling = false;
     } else {
       OBU_READ_BIT_OR_FAIL;
-      frame_header_.frame_refs_short_signaling = static_cast<bool>(scratch);
+      frame_header_.frame_refs_short_signaling = scratch != 0;
       if (frame_header_.frame_refs_short_signaling) {
         OBU_READ_LITERAL_OR_FAIL(3);
         const int8_t last_frame_idx = scratch;
@@ -2054,7 +2054,7 @@
       // Section 5.9.7.
       for (int index : frame_header_.reference_frame_index) {
         OBU_READ_BIT_OR_FAIL;
-        frame_header_.found_reference = static_cast<bool>(scratch);
+        frame_header_.found_reference = scratch != 0;
         if (frame_header_.found_reference) {
           const RefCountedBuffer* reference_frame =
               decoder_state_.reference_frame[index].get();
@@ -2079,10 +2079,10 @@
       frame_header_.allow_high_precision_mv = false;
     } else {
       OBU_READ_BIT_OR_FAIL;
-      frame_header_.allow_high_precision_mv = static_cast<bool>(scratch);
+      frame_header_.allow_high_precision_mv = scratch != 0;
     }
     OBU_READ_BIT_OR_FAIL;
-    const auto is_filter_switchable = static_cast<bool>(scratch);
+    const bool is_filter_switchable = scratch != 0;
     if (is_filter_switchable) {
       frame_header_.interpolation_filter = kInterpolationFilterSwitchable;
     } else {
@@ -2091,13 +2091,13 @@
           static_cast<InterpolationFilter>(scratch);
     }
     OBU_READ_BIT_OR_FAIL;
-    frame_header_.is_motion_mode_switchable = static_cast<bool>(scratch);
+    frame_header_.is_motion_mode_switchable = scratch != 0;
     if (frame_header_.error_resilient_mode ||
         !sequence_header_.enable_ref_frame_mvs) {
       frame_header_.use_ref_frame_mvs = false;
     } else {
       OBU_READ_BIT_OR_FAIL;
-      frame_header_.use_ref_frame_mvs = static_cast<bool>(scratch);
+      frame_header_.use_ref_frame_mvs = scratch != 0;
     }
   }
   // At this point, we have parsed the frame and render sizes and computed
@@ -2151,7 +2151,7 @@
   if (frame_header_.enable_cdf_update &&
       !sequence_header_.reduced_still_picture_header) {
     OBU_READ_BIT_OR_FAIL;
-    frame_header_.enable_frame_end_update_cdf = !static_cast<bool>(scratch);
+    frame_header_.enable_frame_end_update_cdf = scratch == 0;
   } else {
     frame_header_.enable_frame_end_update_cdf = false;
   }
@@ -2189,7 +2189,7 @@
   if (!status) return false;
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
-  frame_header_.reduced_tx_set = static_cast<bool>(scratch);
+  frame_header_.reduced_tx_set = scratch != 0;
   status = ParseGlobalMotionParameters();
   if (!status) return false;
   current_frame_->SetGlobalMotions(frame_header_.global_motion);
@@ -2236,16 +2236,13 @@
     const auto spatial_layers_count = static_cast<int>(scratch) + 1;
     // spatial_layer_dimensions_present_flag
     OBU_READ_BIT_OR_FAIL;
-    const auto spatial_layer_dimensions_present_flag =
-        static_cast<bool>(scratch);
+    const auto spatial_layer_dimensions_present_flag = scratch != 0;
     // spatial_layer_description_present_flag
     OBU_READ_BIT_OR_FAIL;
-    const auto spatial_layer_description_present_flag =
-        static_cast<bool>(scratch);
+    const auto spatial_layer_description_present_flag = scratch != 0;
     // temporal_group_description_present_flag
     OBU_READ_BIT_OR_FAIL;
-    const auto temporal_group_description_present_flag =
-        static_cast<bool>(scratch);
+    const auto temporal_group_description_present_flag = scratch != 0;
     // scalability_structure_reserved_3bits
     OBU_READ_LITERAL_OR_FAIL(3);
     if (scratch != 0) {
@@ -2297,7 +2294,7 @@
   OBU_READ_LITERAL_OR_FAIL(5);
   // full_timestamp_flag
   OBU_READ_BIT_OR_FAIL;
-  const auto full_timestamp_flag = static_cast<bool>(scratch);
+  const bool full_timestamp_flag = scratch != 0;
   // discontinuity_flag
   OBU_READ_BIT_OR_FAIL;
   // cnt_dropped_flag
@@ -2329,7 +2326,7 @@
   } else {
     // seconds_flag
     OBU_READ_BIT_OR_FAIL;
-    const auto seconds_flag = static_cast<bool>(scratch);
+    const bool seconds_flag = scratch != 0;
     if (seconds_flag) {
       // seconds_value
       OBU_READ_LITERAL_OR_FAIL(6);
@@ -2340,7 +2337,7 @@
       }
       // minutes_flag
       OBU_READ_BIT_OR_FAIL;
-      const auto minutes_flag = static_cast<bool>(scratch);
+      const bool minutes_flag = scratch != 0;
       if (minutes_flag) {
         // minutes_value
         OBU_READ_LITERAL_OR_FAIL(6);
@@ -2351,7 +2348,7 @@
         }
         // hours_flag
         OBU_READ_BIT_OR_FAIL;
-        const auto hours_flag = static_cast<bool>(scratch);
+        const bool hours_flag = scratch != 0;
         if (hours_flag) {
           // hours_value
           OBU_READ_LITERAL_OR_FAIL(5);
@@ -2560,7 +2557,7 @@
   }
   int64_t scratch;
   OBU_READ_BIT_OR_FAIL;
-  const auto tile_start_and_end_present_flag = static_cast<bool>(scratch);
+  const bool tile_start_and_end_present_flag = scratch != 0;
   if (!tile_start_and_end_present_flag) {
     if (!bit_reader_->AlignToNextByte()) {
       LIBGAV1_DLOG(ERROR, "Byte alignment has non zero bits.");
@@ -2600,9 +2597,9 @@
   OBU_READ_LITERAL_OR_FAIL(4);
   obu_header.type = static_cast<libgav1::ObuType>(scratch);
   OBU_READ_BIT_OR_FAIL;
-  const auto extension_flag = static_cast<bool>(scratch);
+  const bool extension_flag = scratch != 0;
   OBU_READ_BIT_OR_FAIL;
-  obu_header.has_size_field = static_cast<bool>(scratch);
+  obu_header.has_size_field = scratch != 0;
   OBU_READ_BIT_OR_FAIL;  // reserved.
   if (scratch != 0) {
     LIBGAV1_DLOG(WARNING, "obu_reserved_1bit is not zero.");
diff --git a/libgav1/src/obu_parser.h b/libgav1/src/obu_parser.h
index c4619ed..3f452ef 100644
--- a/libgav1/src/obu_parser.h
+++ b/libgav1/src/obu_parser.h
@@ -22,6 +22,7 @@
 #include <cstdint>
 #include <memory>
 #include <type_traits>
+#include <utility>
 
 #include "src/buffer_pool.h"
 #include "src/decoder_state.h"
diff --git a/libgav1/src/post_filter.h b/libgav1/src/post_filter.h
index dfcd08e..a247075 100644
--- a/libgav1/src/post_filter.h
+++ b/libgav1/src/post_filter.h
@@ -160,7 +160,7 @@
             frame_header.cdef.uv_secondary_strength[0] > 0) &&
            (do_post_filter_mask & 0x02) != 0;
   }
-  bool DoCdef() const { return DoCdef(frame_header_, do_post_filter_mask_); }
+  bool DoCdef() const { return do_cdef_; }
   // If filter levels for Y plane (0 for vertical, 1 for horizontal),
   // are all zero, deblock filter will not be applied.
   static bool DoDeblock(const ObuFrameHeader& frame_header,
@@ -169,9 +169,7 @@
             frame_header.loop_filter.level[1] > 0) &&
            (do_post_filter_mask & 0x01) != 0;
   }
-  bool DoDeblock() const {
-    return DoDeblock(frame_header_, do_post_filter_mask_);
-  }
+  bool DoDeblock() const { return do_deblock_; }
 
   uint8_t GetZeroDeltaDeblockFilterLevel(int segment_id, int level_index,
                                          ReferenceFrameType type,
@@ -197,9 +195,7 @@
             loop_restoration.type[kPlaneV] != kLoopRestorationTypeNone) &&
            (do_post_filter_mask & 0x08) != 0;
   }
-  bool DoRestoration() const {
-    return DoRestoration(loop_restoration_, do_post_filter_mask_, planes_);
-  }
+  bool DoRestoration() const { return do_restoration_; }
 
   // Returns a pointer to the unfiltered buffer. This is used by the Tile class
   // to determine where to write the output of the tile decoding process taking
@@ -214,9 +210,7 @@
     return frame_header.width != frame_header.upscaled_width &&
            (do_post_filter_mask & 0x04) != 0;
   }
-  bool DoSuperRes() const {
-    return DoSuperRes(frame_header_, do_post_filter_mask_);
-  }
+  bool DoSuperRes() const { return do_superres_; }
   LoopRestorationInfo* restoration_info() const { return restoration_info_; }
   uint8_t* GetBufferOffset(uint8_t* base_buffer, int stride, Plane plane,
                            int row, int column) const {
@@ -244,13 +238,9 @@
  private:
   // The type of the HorizontalDeblockFilter and VerticalDeblockFilter member
   // functions.
-  using DeblockFilter = void (PostFilter::*)(int row4x4_start,
-                                             int column4x4_start);
-  // The lookup table for picking the deblock filter, according to deblock
-  // filter type.
-  const DeblockFilter deblock_filter_func_[2] = {
-      &PostFilter::VerticalDeblockFilter, &PostFilter::HorizontalDeblockFilter};
-
+  using DeblockFilter = void (PostFilter::*)(int row4x4_start, int row4x4_end,
+                                             int column4x4_start,
+                                             int column4x4_end);
   // Functions common to all post filters.
 
   // Extends the frame by setting the border pixel values to the one from its
@@ -308,13 +298,6 @@
 
   // Functions for the Deblocking filter.
 
-  static int GetIndex(int row4x4) { return DivideBy4(row4x4); }
-  static int GetShift(int row4x4, int column4x4) {
-    return ((row4x4 & 3) << 4) | column4x4;
-  }
-  int GetDeblockUnitId(int row_unit, int column_unit) const {
-    return row_unit * num_64x64_blocks_per_row_ + column_unit;
-  }
   bool GetHorizontalDeblockFilterEdgeInfo(int row4x4, int column4x4,
                                           uint8_t* level, int* step,
                                           int* filter_length) const;
@@ -330,8 +313,10 @@
                                           BlockParameters* const* bp_ptr,
                                           uint8_t* level_u, uint8_t* level_v,
                                           int* step, int* filter_length) const;
-  void HorizontalDeblockFilter(int row4x4_start, int column4x4_start);
-  void VerticalDeblockFilter(int row4x4_start, int column4x4_start);
+  void HorizontalDeblockFilter(int row4x4_start, int row4x4_end,
+                               int column4x4_start, int column4x4_end);
+  void VerticalDeblockFilter(int row4x4_start, int row4x4_end,
+                             int column4x4_start, int column4x4_end);
   // HorizontalDeblockFilter and VerticalDeblockFilter must have the correct
   // signature.
   static_assert(std::is_same<decltype(&PostFilter::HorizontalDeblockFilter),
@@ -340,9 +325,6 @@
   static_assert(std::is_same<decltype(&PostFilter::VerticalDeblockFilter),
                              DeblockFilter>::value,
                 "");
-  // Applies deblock filtering for the superblock row starting at |row4x4| with
-  // a height of 4*|sb4x4|.
-  void ApplyDeblockFilterForOneSuperBlockRow(int row4x4, int sb4x4);
   // Worker function used for multi-threaded deblocking.
   template <LoopFilterType loop_filter_type>
   void DeblockFilterWorker(std::atomic<int>* row4x4_atomic);
@@ -465,13 +447,13 @@
                              WorkerFunction>::value,
                 "");
 
+  // The lookup table for picking the deblock filter, according to deblock
+  // filter type.
+  const DeblockFilter deblock_filter_func_[2] = {
+      &PostFilter::VerticalDeblockFilter, &PostFilter::HorizontalDeblockFilter};
   const ObuFrameHeader& frame_header_;
   const LoopRestoration& loop_restoration_;
   const dsp::Dsp& dsp_;
-  const int num_64x64_blocks_per_row_;
-  const int upscaled_width_;
-  const int width_;
-  const int height_;
   const int8_t bitdepth_;
   const int8_t subsampling_x_[kMaxPlanes];
   const int8_t subsampling_y_[kMaxPlanes];
@@ -480,6 +462,10 @@
   const uint8_t* const inner_thresh_;
   const uint8_t* const outer_thresh_;
   const bool needs_chroma_deblock_;
+  const bool do_cdef_;
+  const bool do_deblock_;
+  const bool do_restoration_;
+  const bool do_superres_;
   // This stores the deblocking filter levels assuming that the delta is zero.
   // This will be used by all superblocks whose delta is zero (without having to
   // recompute them). The dimensions (in order) are: segment_id, level_index
@@ -492,7 +478,8 @@
     int initial_subpixel_x;
     int step;
   } super_res_info_[kMaxPlanes];
-  const Array2D<int16_t>& cdef_index_;
+  const Array2D<int8_t>& cdef_index_;
+  const Array2D<uint8_t>& cdef_skip_;
   const Array2D<TransformSize>& inter_transform_sizes_;
   LoopRestorationInfo* const restoration_info_;
   uint8_t* const superres_coefficients_[kNumPlaneTypes];
@@ -528,7 +515,6 @@
   //   (1). Loop Restoration is on.
   //   (2). Cdef is on, or multi-threading is enabled for post filter.
   YuvBuffer& loop_restoration_border_;
-  const uint8_t do_post_filter_mask_;
   ThreadPool* const thread_pool_;
 
   // Tracks the progress of the post filters.
diff --git a/libgav1/src/post_filter/cdef.cc b/libgav1/src/post_filter/cdef.cc
index f32b0a0..037fc17 100644
--- a/libgav1/src/post_filter/cdef.cc
+++ b/libgav1/src/post_filter/cdef.cc
@@ -126,8 +126,8 @@
   const int8_t subsampling_y = y_plane ? 0 : subsampling_y_[kPlaneU];
   const int start_x = MultiplyBy4(column4x4) >> subsampling_x;
   const int start_y = MultiplyBy4(row4x4) >> subsampling_y;
-  const int plane_width = SubsampledValue(width_, subsampling_x);
-  const int plane_height = SubsampledValue(height_, subsampling_y);
+  const int plane_width = SubsampledValue(frame_header_.width, subsampling_x);
+  const int plane_height = SubsampledValue(frame_header_.height, subsampling_y);
   const int block_width = MultiplyBy4(block_width4x4) >> subsampling_x;
   const int block_height = MultiplyBy4(block_height4x4) >> subsampling_y;
   // unit_width, unit_height are the same as block_width, block_height unless
@@ -319,7 +319,7 @@
   }
 
   const bool is_frame_right =
-      MultiplyBy4(column4x4_start) + MultiplyBy4(block_width4x4) >= width_;
+      MultiplyBy4(column4x4_start + block_width4x4) >= frame_header_.width;
   if (!is_frame_right && thread_pool_ != nullptr) {
     // Backup the last 2 columns for use in the next iteration.
     use_border_columns[border_columns_dst_index][0] = true;
@@ -356,104 +356,111 @@
 
   const bool compute_direction_and_variance =
       (y_primary_strength | frame_header_.cdef.uv_primary_strength[index]) != 0;
-  BlockParameters* const* bp_row0_base =
-      block_parameters_.Address(row4x4_start, column4x4_start);
-  BlockParameters* const* bp_row1_base =
-      bp_row0_base + block_parameters_.columns4x4();
-  const int bp_stride = MultiplyBy2(block_parameters_.columns4x4());
+  const uint8_t* skip_row =
+      &cdef_skip_[row4x4_start >> 1][column4x4_start >> 4];
+  const int skip_stride = cdef_skip_.columns();
   int row4x4 = row4x4_start;
   do {
     uint8_t* cdef_buffer_base = cdef_buffer_row_base[kPlaneY];
     const uint8_t* src_buffer_base = src_buffer_row_base[kPlaneY];
     const uint16_t* cdef_src_base = cdef_src_row_base[kPlaneY];
-    BlockParameters* const* bp0 = bp_row0_base;
-    BlockParameters* const* bp1 = bp_row1_base;
     int column4x4 = column4x4_start;
-    do {
-      const int block_width = kStep;
-      const int block_height = kStep;
-      const int cdef_stride = frame_buffer_.stride(kPlaneY);
-      uint8_t* const cdef_buffer = cdef_buffer_base;
-      const uint16_t* const cdef_src = cdef_src_base;
-      const int src_stride = frame_buffer_.stride(kPlaneY);
-      const uint8_t* const src_buffer = src_buffer_base;
 
-      const bool skip = (*bp0)->skip && (*(bp0 + 1))->skip && (*bp1)->skip &&
-                        (*(bp1 + 1))->skip;
-
-      if (skip) {  // No cdef filtering.
+    if (*skip_row == 0) {
+      for (int i = 0; i < DivideBy2(block_width4x4); ++i, ++y_index) {
         direction_y[y_index] = kCdefSkip;
-        if (thread_pool_ == nullptr) {
-          CopyPixels(src_buffer, src_stride, cdef_buffer, cdef_stride,
-                     block_width, block_height, sizeof(Pixel));
-        }
-      } else {
-        // Zero out residual skip flag.
-        direction_y[y_index] = 0;
+      }
+      if (thread_pool_ == nullptr) {
+        CopyPixels(src_buffer_base, frame_buffer_.stride(kPlaneY),
+                   cdef_buffer_base, frame_buffer_.stride(kPlaneY), 64, kStep,
+                   sizeof(Pixel));
+      }
+    } else {
+      do {
+        const int block_width = kStep;
+        const int block_height = kStep;
+        const int cdef_stride = frame_buffer_.stride(kPlaneY);
+        uint8_t* const cdef_buffer = cdef_buffer_base;
+        const uint16_t* const cdef_src = cdef_src_base;
+        const int src_stride = frame_buffer_.stride(kPlaneY);
+        const uint8_t* const src_buffer = src_buffer_base;
 
-        int variance = 0;
-        if (compute_direction_and_variance) {
-          if (thread_pool_ == nullptr ||
-              row4x4 + kStep4x4 < row4x4_start + block_height4x4) {
-            dsp_.cdef_direction(src_buffer, src_stride, &direction_y[y_index],
-                                &variance);
-          } else if (sizeof(Pixel) == 2) {
-            dsp_.cdef_direction(cdef_src, kCdefUnitSizeWithBorders * 2,
-                                &direction_y[y_index], &variance);
-          } else {
-            // If we are in the last row4x4 for this unit, then the last two
-            // input rows have to come from |cdef_border_|. Since we already
-            // have |cdef_src| populated correctly, use that as the input
-            // for the direction process.
-            uint8_t direction_src[8][8];
-            const uint16_t* cdef_src_line = cdef_src;
-            for (auto& direction_src_line : direction_src) {
-              for (int i = 0; i < 8; ++i) {
-                direction_src_line[i] = cdef_src_line[i];
-              }
-              cdef_src_line += kCdefUnitSizeWithBorders;
-            }
-            dsp_.cdef_direction(direction_src, 8, &direction_y[y_index],
-                                &variance);
-          }
-        }
-        const int direction =
-            (y_primary_strength == 0) ? 0 : direction_y[y_index];
-        const int variance_strength =
-            ((variance >> 6) != 0) ? std::min(FloorLog2(variance >> 6), 12) : 0;
-        const uint8_t primary_strength =
-            (variance != 0)
-                ? (y_primary_strength * (4 + variance_strength) + 8) >> 4
-                : 0;
-        if ((primary_strength | y_secondary_strength) == 0) {
+        const uint8_t skip_shift = (column4x4 >> 1) & 0x7;
+        const bool skip = ((*skip_row >> skip_shift) & 1) == 0;
+        if (skip) {  // No cdef filtering.
+          direction_y[y_index] = kCdefSkip;
           if (thread_pool_ == nullptr) {
             CopyPixels(src_buffer, src_stride, cdef_buffer, cdef_stride,
                        block_width, block_height, sizeof(Pixel));
           }
         } else {
-          const int strength_index =
-              y_strength_index | (static_cast<int>(primary_strength == 0) << 1);
-          dsp_.cdef_filters[1][strength_index](
-              cdef_src, kCdefUnitSizeWithBorders, block_height,
-              primary_strength, y_secondary_strength,
-              frame_header_.cdef.damping, direction, cdef_buffer, cdef_stride);
-        }
-      }
-      cdef_buffer_base += column_step[kPlaneY];
-      src_buffer_base += column_step[kPlaneY];
-      cdef_src_base += column_step[kPlaneY] / sizeof(Pixel);
+          // Zero out residual skip flag.
+          direction_y[y_index] = 0;
 
-      bp0 += kStep4x4;
-      bp1 += kStep4x4;
-      column4x4 += kStep4x4;
-      y_index++;
-    } while (column4x4 < column4x4_start + block_width4x4);
+          int variance = 0;
+          if (compute_direction_and_variance) {
+            if (thread_pool_ == nullptr ||
+                row4x4 + kStep4x4 < row4x4_start + block_height4x4) {
+              dsp_.cdef_direction(src_buffer, src_stride, &direction_y[y_index],
+                                  &variance);
+            } else if (sizeof(Pixel) == 2) {
+              dsp_.cdef_direction(cdef_src, kCdefUnitSizeWithBorders * 2,
+                                  &direction_y[y_index], &variance);
+            } else {
+              // If we are in the last row4x4 for this unit, then the last two
+              // input rows have to come from |cdef_border_|. Since we already
+              // have |cdef_src| populated correctly, use that as the input
+              // for the direction process.
+              uint8_t direction_src[8][8];
+              const uint16_t* cdef_src_line = cdef_src;
+              for (auto& direction_src_line : direction_src) {
+                for (int i = 0; i < 8; ++i) {
+                  direction_src_line[i] = cdef_src_line[i];
+                }
+                cdef_src_line += kCdefUnitSizeWithBorders;
+              }
+              dsp_.cdef_direction(direction_src, 8, &direction_y[y_index],
+                                  &variance);
+            }
+          }
+          const int direction =
+              (y_primary_strength == 0) ? 0 : direction_y[y_index];
+          const int variance_strength =
+              ((variance >> 6) != 0) ? std::min(FloorLog2(variance >> 6), 12)
+                                     : 0;
+          const uint8_t primary_strength =
+              (variance != 0)
+                  ? (y_primary_strength * (4 + variance_strength) + 8) >> 4
+                  : 0;
+          if ((primary_strength | y_secondary_strength) == 0) {
+            if (thread_pool_ == nullptr) {
+              CopyPixels(src_buffer, src_stride, cdef_buffer, cdef_stride,
+                         block_width, block_height, sizeof(Pixel));
+            }
+          } else {
+            const int strength_index =
+                y_strength_index |
+                (static_cast<int>(primary_strength == 0) << 1);
+            dsp_.cdef_filters[1][strength_index](
+                cdef_src, kCdefUnitSizeWithBorders, block_height,
+                primary_strength, y_secondary_strength,
+                frame_header_.cdef.damping, direction, cdef_buffer,
+                cdef_stride);
+          }
+        }
+        cdef_buffer_base += column_step[kPlaneY];
+        src_buffer_base += column_step[kPlaneY];
+        cdef_src_base += column_step[kPlaneY] / sizeof(Pixel);
+
+        column4x4 += kStep4x4;
+        y_index++;
+      } while (column4x4 < column4x4_start + block_width4x4);
+    }
 
     cdef_buffer_row_base[kPlaneY] += cdef_buffer_row_base_stride[kPlaneY];
     src_buffer_row_base[kPlaneY] += src_buffer_row_base_stride[kPlaneY];
     cdef_src_row_base[kPlaneY] += cdef_src_row_base_stride[kPlaneY];
-    bp_row0_base += bp_stride;
-    bp_row1_base += bp_stride;
+    skip_row += skip_stride;
     row4x4 += kStep4x4;
   } while (row4x4 < row4x4_start + block_height4x4);
 
@@ -591,9 +598,12 @@
     uint16_t* cdef_block, uint8_t border_columns[2][kMaxPlanes][256],
     int row4x4, int block_height4x4) {
   bool use_border_columns[2][2] = {};
-  for (int column4x4 = 0; column4x4 < frame_header_.columns4x4;
-       column4x4 += kStep64x64) {
-    const int index = cdef_index_[DivideBy16(row4x4)][DivideBy16(column4x4)];
+  const bool non_zero_index = frame_header_.cdef.bits > 0;
+  const int8_t* cdef_index =
+      non_zero_index ? cdef_index_[DivideBy16(row4x4)] : nullptr;
+  int column4x4 = 0;
+  do {
+    const int index = non_zero_index ? *cdef_index++ : 0;
     const int block_width4x4 =
         std::min(kStep64x64, frame_header_.columns4x4 - column4x4);
 
@@ -602,29 +612,32 @@
       ApplyCdefForOneUnit<uint16_t>(cdef_block, index, block_width4x4,
                                     block_height4x4, row4x4, column4x4,
                                     border_columns, use_border_columns);
-      continue;
+    } else  // NOLINT
+#endif      // LIBGAV1_MAX_BITDEPTH >= 10
+    {
+      ApplyCdefForOneUnit<uint8_t>(cdef_block, index, block_width4x4,
+                                   block_height4x4, row4x4, column4x4,
+                                   border_columns, use_border_columns);
     }
-#endif  // LIBGAV1_MAX_BITDEPTH >= 10
-    ApplyCdefForOneUnit<uint8_t>(cdef_block, index, block_width4x4,
-                                 block_height4x4, row4x4, column4x4,
-                                 border_columns, use_border_columns);
-  }
+    column4x4 += kStep64x64;
+  } while (column4x4 < frame_header_.columns4x4);
 }
 
 void PostFilter::ApplyCdefForOneSuperBlockRow(int row4x4_start, int sb4x4,
                                               bool is_last_row) {
   assert(row4x4_start >= 0);
   assert(DoCdef());
-  for (int y = 0; y < sb4x4; y += kStep64x64) {
-    const int row4x4 = row4x4_start + y;
+  int row4x4 = row4x4_start;
+  const int row4x4_limit = row4x4_start + sb4x4;
+  do {
     if (row4x4 >= frame_header_.rows4x4) return;
 
     // Apply cdef for the last 8 rows of the previous superblock row.
     // One exception: If the superblock size is 128x128 and is_last_row is true,
     // then we simply apply cdef for the entire superblock row without any lag.
     // In that case, apply cdef for the previous superblock row only during the
-    // first iteration (y == 0).
-    if (row4x4 > 0 && (!is_last_row || y == 0)) {
+    // first iteration (row4x4 == row4x4_start).
+    if (row4x4 > 0 && (!is_last_row || row4x4 == row4x4_start)) {
       assert(row4x4 >= 16);
       ApplyCdefForOneSuperBlockRowHelper(cdef_block_, nullptr, row4x4 - 2, 2);
     }
@@ -639,7 +652,8 @@
       ApplyCdefForOneSuperBlockRowHelper(cdef_block_, nullptr, row4x4,
                                          height4x4);
     }
-  }
+    row4x4 += kStep64x64;
+  } while (row4x4 < row4x4_limit);
 }
 
 void PostFilter::ApplyCdefWorker(std::atomic<int>* row4x4_atomic) {
diff --git a/libgav1/src/post_filter/deblock.cc b/libgav1/src/post_filter/deblock.cc
index 9b5ed0f..48ad823 100644
--- a/libgav1/src/post_filter/deblock.cc
+++ b/libgav1/src/post_filter/deblock.cc
@@ -101,9 +101,9 @@
     uint8_t deblock_filter_levels[kMaxSegments][kFrameLfCount]
                                  [kNumReferenceFrameTypes][2]) const {
   if (!DoDeblock()) return;
-  for (int segment_id = 0;
-       segment_id < (frame_header_.segmentation.enabled ? kMaxSegments : 1);
-       ++segment_id) {
+  const int num_segments =
+      frame_header_.segmentation.enabled ? kMaxSegments : 1;
+  for (int segment_id = 0; segment_id < num_segments; ++segment_id) {
     int level_index = 0;
     for (; level_index < 2; ++level_index) {
       ComputeDeblockFilterLevelsHelper(
@@ -295,8 +295,13 @@
   *filter_length = std::min(*step, step_prev);
 }
 
-void PostFilter::HorizontalDeblockFilter(int row4x4_start,
-                                         int column4x4_start) {
+void PostFilter::HorizontalDeblockFilter(int row4x4_start, int row4x4_end,
+                                         int column4x4_start,
+                                         int column4x4_end) {
+  const int height4x4 = row4x4_end - row4x4_start;
+  const int width4x4 = column4x4_end - column4x4_start;
+  if (height4x4 <= 0 || width4x4 <= 0) return;
+
   const int column_step = 1;
   const int src_step = 4 << pixel_size_log2_;
   const ptrdiff_t src_stride = frame_buffer_.stride(kPlaneY);
@@ -305,17 +310,20 @@
   uint8_t level;
   int filter_length;
 
-  for (int column4x4 = 0; column4x4 < kNum4x4InLoopFilterUnit &&
-                          MultiplyBy4(column4x4_start + column4x4) < width_;
+  const int width = frame_header_.width;
+  const int height = frame_header_.height;
+  for (int column4x4 = 0;
+       column4x4 < width4x4 && MultiplyBy4(column4x4_start + column4x4) < width;
        column4x4 += column_step, src += src_step) {
     uint8_t* src_row = src;
-    for (int row4x4 = 0; row4x4 < kNum4x4InLoopFilterUnit &&
-                         MultiplyBy4(row4x4_start + row4x4) < height_;
+    for (int row4x4 = 0;
+         row4x4 < height4x4 && MultiplyBy4(row4x4_start + row4x4) < height;
          row4x4 += row_step) {
       const bool need_filter = GetHorizontalDeblockFilterEdgeInfo(
           row4x4_start + row4x4, column4x4_start + column4x4, &level, &row_step,
           &filter_length);
       if (need_filter) {
+        assert(level > 0 && level <= kMaxLoopFilterValue);
         const dsp::LoopFilterSize size = GetLoopFilterSizeY(filter_length);
         dsp_.loop_filters[size][kLoopFilterTypeHorizontal](
             src_row, src_stride, outer_thresh_[level], inner_thresh_[level],
@@ -340,13 +348,13 @@
     uint8_t level_v;
     int filter_length;
 
-    for (int column4x4 = 0; column4x4 < kNum4x4InLoopFilterUnit &&
-                            MultiplyBy4(column4x4_start + column4x4) < width_;
+    for (int column4x4 = 0; column4x4 < width4x4 &&
+                            MultiplyBy4(column4x4_start + column4x4) < width;
          column4x4 += column_step, src_u += src_step, src_v += src_step) {
       uint8_t* src_row_u = src_u;
       uint8_t* src_row_v = src_v;
-      for (int row4x4 = 0; row4x4 < kNum4x4InLoopFilterUnit &&
-                           MultiplyBy4(row4x4_start + row4x4) < height_;
+      for (int row4x4 = 0;
+           row4x4 < height4x4 && MultiplyBy4(row4x4_start + row4x4) < height;
            row4x4 += row_step) {
         GetHorizontalDeblockFilterEdgeInfoUV(
             row4x4_start + row4x4, column4x4_start + column4x4, &level_u,
@@ -371,7 +379,12 @@
   }
 }
 
-void PostFilter::VerticalDeblockFilter(int row4x4_start, int column4x4_start) {
+void PostFilter::VerticalDeblockFilter(int row4x4_start, int row4x4_end,
+                                       int column4x4_start, int column4x4_end) {
+  const int height4x4 = row4x4_end - row4x4_start;
+  const int width4x4 = column4x4_end - column4x4_start;
+  if (height4x4 <= 0 || width4x4 <= 0) return;
+
   const ptrdiff_t row_stride = MultiplyBy4(frame_buffer_.stride(kPlaneY));
   const ptrdiff_t src_stride = frame_buffer_.stride(kPlaneY);
   uint8_t* src = GetSourceBuffer(kPlaneY, row4x4_start, column4x4_start);
@@ -383,18 +396,21 @@
       block_parameters_.Address(row4x4_start, column4x4_start);
   const int bp_stride = block_parameters_.columns4x4();
   const int column_step_shift = pixel_size_log2_;
-  for (int row4x4 = 0; row4x4 < kNum4x4InLoopFilterUnit &&
-                       MultiplyBy4(row4x4_start + row4x4) < height_;
+  const int width = frame_header_.width;
+  const int height = frame_header_.height;
+  for (int row4x4 = 0;
+       row4x4 < height4x4 && MultiplyBy4(row4x4_start + row4x4) < height;
        ++row4x4, src += row_stride, bp_row_base += bp_stride) {
     uint8_t* src_row = src;
     BlockParameters* const* bp = bp_row_base;
-    for (int column4x4 = 0; column4x4 < kNum4x4InLoopFilterUnit &&
-                            MultiplyBy4(column4x4_start + column4x4) < width_;
+    for (int column4x4 = 0; column4x4 < width4x4 &&
+                            MultiplyBy4(column4x4_start + column4x4) < width;
          column4x4 += column_step, bp += column_step) {
       const bool need_filter = GetVerticalDeblockFilterEdgeInfo(
           row4x4_start + row4x4, column4x4_start + column4x4, bp, &level,
           &column_step, &filter_length);
       if (need_filter) {
+        assert(level > 0 && level <= kMaxLoopFilterValue);
         const dsp::LoopFilterSize size = GetLoopFilterSizeY(filter_length);
         dsp_.loop_filters[size][kLoopFilterTypeVertical](
             src_row, src_stride, outer_thresh_[level], inner_thresh_[level],
@@ -425,15 +441,15 @@
         GetDeblockPosition(row4x4_start, subsampling_y),
         GetDeblockPosition(column4x4_start, subsampling_x));
     const int bp_stride = block_parameters_.columns4x4() << subsampling_y;
-    for (int row4x4 = 0; row4x4 < kNum4x4InLoopFilterUnit &&
-                         MultiplyBy4(row4x4_start + row4x4) < height_;
+    for (int row4x4 = 0;
+         row4x4 < height4x4 && MultiplyBy4(row4x4_start + row4x4) < height;
          row4x4 += row_step, src_u += row_stride_u, src_v += row_stride_v,
              bp_row_base += bp_stride) {
       uint8_t* src_row_u = src_u;
       uint8_t* src_row_v = src_v;
       BlockParameters* const* bp = bp_row_base;
-      for (int column4x4 = 0; column4x4 < kNum4x4InLoopFilterUnit &&
-                              MultiplyBy4(column4x4_start + column4x4) < width_;
+      for (int column4x4 = 0; column4x4 < width4x4 &&
+                              MultiplyBy4(column4x4_start + column4x4) < width;
            column4x4 += column_step, bp += column_step) {
         GetVerticalDeblockFilterEdgeInfoUV(column4x4_start + column4x4, bp,
                                            &level_u, &level_v, &column_step,
@@ -458,39 +474,15 @@
   }
 }
 
-void PostFilter::ApplyDeblockFilterForOneSuperBlockRow(int row4x4_start,
-                                                       int sb4x4) {
-  assert(row4x4_start >= 0);
-  assert(DoDeblock());
-  for (int y = 0; y < sb4x4; y += 16) {
-    const int row4x4 = row4x4_start + y;
-    if (row4x4 >= frame_header_.rows4x4) break;
-    int column4x4;
-    for (column4x4 = 0; column4x4 < frame_header_.columns4x4;
-         column4x4 += kNum4x4InLoopFilterUnit) {
-      // First apply vertical filtering
-      VerticalDeblockFilter(row4x4, column4x4);
-
-      // Delay one superblock to apply horizontal filtering.
-      if (column4x4 != 0) {
-        HorizontalDeblockFilter(row4x4, column4x4 - kNum4x4InLoopFilterUnit);
-      }
-    }
-    // Horizontal filtering for the last 64x64 block.
-    HorizontalDeblockFilter(row4x4, column4x4 - kNum4x4InLoopFilterUnit);
-  }
-}
-
 template <LoopFilterType loop_filter_type>
 void PostFilter::DeblockFilterWorker(std::atomic<int>* row4x4_atomic) {
+  const int rows4x4 = frame_header_.rows4x4;
+  const int columns4x4 = frame_header_.columns4x4;
   int row4x4;
-  while ((row4x4 = row4x4_atomic->fetch_add(kNum4x4InLoopFilterUnit,
-                                            std::memory_order_relaxed)) <
-         frame_header_.rows4x4) {
-    for (int column4x4 = 0; column4x4 < frame_header_.columns4x4;
-         column4x4 += kNum4x4InLoopFilterUnit) {
-      (this->*deblock_filter_func_[loop_filter_type])(row4x4, column4x4);
-    }
+  while ((row4x4 = row4x4_atomic->fetch_add(
+              kNum4x4InLoopFilterUnit, std::memory_order_relaxed)) < rows4x4) {
+    (this->*deblock_filter_func_[loop_filter_type])(
+        row4x4, row4x4 + kNum4x4InLoopFilterUnit, 0, columns4x4);
   }
 }
 
@@ -504,20 +496,12 @@
                                     int column4x4_end, int sb4x4) {
   assert(row4x4_start >= 0);
   assert(DoDeblock());
-
-  column4x4_end = std::min(column4x4_end, frame_header_.columns4x4);
+  column4x4_end =
+      std::min(Align(column4x4_end, static_cast<int>(kNum4x4InLoopFilterUnit)),
+               frame_header_.columns4x4);
   if (column4x4_start >= column4x4_end) return;
-
-  const DeblockFilter deblock_filter = deblock_filter_func_[loop_filter_type];
-  const int sb_height4x4 =
-      std::min(sb4x4, frame_header_.rows4x4 - row4x4_start);
-  for (int y = 0; y < sb_height4x4; y += kNum4x4InLoopFilterUnit) {
-    const int row4x4 = row4x4_start + y;
-    for (int column4x4 = column4x4_start; column4x4 < column4x4_end;
-         column4x4 += kNum4x4InLoopFilterUnit) {
-      (this->*deblock_filter)(row4x4, column4x4);
-    }
-  }
+  (this->*deblock_filter_func_[loop_filter_type])(
+      row4x4_start, row4x4_start + sb4x4, column4x4_start, column4x4_end);
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/post_filter/loop_restoration.cc b/libgav1/src/post_filter/loop_restoration.cc
index 826ef48..2e6982c 100644
--- a/libgav1/src/post_filter/loop_restoration.cc
+++ b/libgav1/src/post_filter/loop_restoration.cc
@@ -101,6 +101,8 @@
   assert(row4x4_start >= 0);
   assert(DoRestoration());
   int plane = kPlaneY;
+  const int upscaled_width = frame_header_.upscaled_width;
+  const int height = frame_header_.height;
   do {
     if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
       continue;
@@ -108,9 +110,9 @@
     const ptrdiff_t stride = frame_buffer_.stride(plane) / sizeof(Pixel);
     const int unit_height_offset =
         kRestorationUnitOffset >> subsampling_y_[plane];
-    const int plane_height = SubsampledValue(height_, subsampling_y_[plane]);
+    const int plane_height = SubsampledValue(height, subsampling_y_[plane]);
     const int plane_width =
-        SubsampledValue(upscaled_width_, subsampling_x_[plane]);
+        SubsampledValue(upscaled_width, subsampling_x_[plane]);
     const int plane_unit_size = 1 << loop_restoration_.unit_size_log2[plane];
     const int plane_process_unit_height =
         kRestorationUnitHeight >> subsampling_y_[plane];
diff --git a/libgav1/src/post_filter/post_filter.cc b/libgav1/src/post_filter/post_filter.cc
index 7671f01..bc71410 100644
--- a/libgav1/src/post_filter/post_filter.cc
+++ b/libgav1/src/post_filter/post_filter.cc
@@ -26,6 +26,7 @@
 #include "src/utils/array_2d.h"
 #include "src/utils/blocking_counter.h"
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
 #include "src/utils/memory.h"
 #include "src/utils/types.h"
@@ -43,6 +44,131 @@
 
 }  // namespace
 
+PostFilter::PostFilter(const ObuFrameHeader& frame_header,
+                       const ObuSequenceHeader& sequence_header,
+                       FrameScratchBuffer* const frame_scratch_buffer,
+                       YuvBuffer* const frame_buffer, const dsp::Dsp* dsp,
+                       int do_post_filter_mask)
+    : frame_header_(frame_header),
+      loop_restoration_(frame_header.loop_restoration),
+      dsp_(*dsp),
+      bitdepth_(sequence_header.color_config.bitdepth),
+      subsampling_x_{0, sequence_header.color_config.subsampling_x,
+                     sequence_header.color_config.subsampling_x},
+      subsampling_y_{0, sequence_header.color_config.subsampling_y,
+                     sequence_header.color_config.subsampling_y},
+      planes_(sequence_header.color_config.is_monochrome ? kMaxPlanesMonochrome
+                                                         : kMaxPlanes),
+      pixel_size_log2_(static_cast<int>((bitdepth_ == 8) ? sizeof(uint8_t)
+                                                         : sizeof(uint16_t)) -
+                       1),
+      inner_thresh_(kInnerThresh[frame_header.loop_filter.sharpness]),
+      outer_thresh_(kOuterThresh[frame_header.loop_filter.sharpness]),
+      needs_chroma_deblock_(frame_header.loop_filter.level[kPlaneU + 1] != 0 ||
+                            frame_header.loop_filter.level[kPlaneV + 1] != 0),
+      do_cdef_(DoCdef(frame_header, do_post_filter_mask)),
+      do_deblock_(DoDeblock(frame_header, do_post_filter_mask)),
+      do_restoration_(
+          DoRestoration(loop_restoration_, do_post_filter_mask, planes_)),
+      do_superres_(DoSuperRes(frame_header, do_post_filter_mask)),
+      cdef_index_(frame_scratch_buffer->cdef_index),
+      cdef_skip_(frame_scratch_buffer->cdef_skip),
+      inter_transform_sizes_(frame_scratch_buffer->inter_transform_sizes),
+      restoration_info_(&frame_scratch_buffer->loop_restoration_info),
+      superres_coefficients_{
+          frame_scratch_buffer->superres_coefficients[kPlaneTypeY].get(),
+          frame_scratch_buffer
+              ->superres_coefficients
+                  [(sequence_header.color_config.is_monochrome ||
+                    sequence_header.color_config.subsampling_x == 0)
+                       ? kPlaneTypeY
+                       : kPlaneTypeUV]
+              .get()},
+      superres_line_buffer_(frame_scratch_buffer->superres_line_buffer),
+      block_parameters_(frame_scratch_buffer->block_parameters_holder),
+      frame_buffer_(*frame_buffer),
+      cdef_border_(frame_scratch_buffer->cdef_border),
+      loop_restoration_border_(frame_scratch_buffer->loop_restoration_border),
+      thread_pool_(
+          frame_scratch_buffer->threading_strategy.post_filter_thread_pool()) {
+  const int8_t zero_delta_lf[kFrameLfCount] = {};
+  ComputeDeblockFilterLevels(zero_delta_lf, deblock_filter_levels_);
+  if (DoSuperRes()) {
+    int plane = kPlaneY;
+    const int width = frame_header_.width;
+    const int upscaled_width_fh = frame_header_.upscaled_width;
+    do {
+      const int downscaled_width =
+          SubsampledValue(width, subsampling_x_[plane]);
+      const int upscaled_width =
+          SubsampledValue(upscaled_width_fh, subsampling_x_[plane]);
+      const int superres_width = downscaled_width << kSuperResScaleBits;
+      super_res_info_[plane].step =
+          (superres_width + upscaled_width / 2) / upscaled_width;
+      const int error =
+          super_res_info_[plane].step * upscaled_width - superres_width;
+      super_res_info_[plane].initial_subpixel_x =
+          ((-((upscaled_width - downscaled_width) << (kSuperResScaleBits - 1)) +
+            DivideBy2(upscaled_width)) /
+               upscaled_width +
+           (1 << (kSuperResExtraBits - 1)) - error / 2) &
+          kSuperResScaleMask;
+      super_res_info_[plane].upscaled_width = upscaled_width;
+    } while (++plane < planes_);
+    if (dsp->super_res_coefficients != nullptr) {
+      int plane = kPlaneY;
+      const int number_loops = (superres_coefficients_[kPlaneTypeY] ==
+                                superres_coefficients_[kPlaneTypeUV])
+                                   ? kMaxPlanesMonochrome
+                                   : static_cast<int>(kNumPlaneTypes);
+      do {
+        dsp->super_res_coefficients(super_res_info_[plane].upscaled_width,
+                                    super_res_info_[plane].initial_subpixel_x,
+                                    super_res_info_[plane].step,
+                                    superres_coefficients_[plane]);
+      } while (++plane < number_loops);
+    }
+  }
+  int plane = kPlaneY;
+  do {
+    loop_restoration_buffer_[plane] = frame_buffer_.data(plane);
+    cdef_buffer_[plane] = frame_buffer_.data(plane);
+    superres_buffer_[plane] = frame_buffer_.data(plane);
+    source_buffer_[plane] = frame_buffer_.data(plane);
+  } while (++plane < planes_);
+  if (DoCdef() || DoRestoration() || DoSuperRes()) {
+    plane = kPlaneY;
+    const int pixel_size_log2 = pixel_size_log2_;
+    do {
+      int horizontal_shift = 0;
+      int vertical_shift = 0;
+      if (DoRestoration() &&
+          loop_restoration_.type[plane] != kLoopRestorationTypeNone) {
+        horizontal_shift += frame_buffer_.alignment();
+        if (!DoCdef() && thread_pool_ == nullptr) {
+          vertical_shift += kRestorationVerticalBorder;
+        }
+        superres_buffer_[plane] +=
+            vertical_shift * frame_buffer_.stride(plane) +
+            (horizontal_shift << pixel_size_log2);
+      }
+      if (DoSuperRes()) {
+        vertical_shift += kSuperResVerticalBorder;
+      }
+      cdef_buffer_[plane] += vertical_shift * frame_buffer_.stride(plane) +
+                             (horizontal_shift << pixel_size_log2);
+      if (DoCdef() && thread_pool_ == nullptr) {
+        horizontal_shift += frame_buffer_.alignment();
+        vertical_shift += kCdefBorder;
+      }
+      assert(horizontal_shift <= frame_buffer_.right_border(plane));
+      assert(vertical_shift <= frame_buffer_.bottom_border(plane));
+      source_buffer_[plane] += vertical_shift * frame_buffer_.stride(plane) +
+                               (horizontal_shift << pixel_size_log2);
+    } while (++plane < planes_);
+  }
+}
+
 // The following example illustrates how ExtendFrame() extends a frame.
 // Suppose the frame width is 8 and height is 4, and left, right, top, and
 // bottom are all equal to 3.
@@ -138,129 +264,6 @@
     const int bottom);
 #endif
 
-PostFilter::PostFilter(const ObuFrameHeader& frame_header,
-                       const ObuSequenceHeader& sequence_header,
-                       FrameScratchBuffer* const frame_scratch_buffer,
-                       YuvBuffer* const frame_buffer, const dsp::Dsp* dsp,
-                       int do_post_filter_mask)
-    : frame_header_(frame_header),
-      loop_restoration_(frame_header.loop_restoration),
-      dsp_(*dsp),
-      // Deblocking filter always uses 64x64 as step size.
-      num_64x64_blocks_per_row_(DivideBy64(frame_header.width + 63)),
-      upscaled_width_(frame_header.upscaled_width),
-      width_(frame_header.width),
-      height_(frame_header.height),
-      bitdepth_(sequence_header.color_config.bitdepth),
-      subsampling_x_{0, sequence_header.color_config.subsampling_x,
-                     sequence_header.color_config.subsampling_x},
-      subsampling_y_{0, sequence_header.color_config.subsampling_y,
-                     sequence_header.color_config.subsampling_y},
-      planes_(sequence_header.color_config.is_monochrome ? kMaxPlanesMonochrome
-                                                         : kMaxPlanes),
-      pixel_size_log2_(static_cast<int>((bitdepth_ == 8) ? sizeof(uint8_t)
-                                                         : sizeof(uint16_t)) -
-                       1),
-      inner_thresh_(kInnerThresh[frame_header.loop_filter.sharpness]),
-      outer_thresh_(kOuterThresh[frame_header.loop_filter.sharpness]),
-      needs_chroma_deblock_(frame_header.loop_filter.level[kPlaneU + 1] != 0 ||
-                            frame_header.loop_filter.level[kPlaneV + 1] != 0),
-      cdef_index_(frame_scratch_buffer->cdef_index),
-      inter_transform_sizes_(frame_scratch_buffer->inter_transform_sizes),
-      restoration_info_(&frame_scratch_buffer->loop_restoration_info),
-      superres_coefficients_{
-          frame_scratch_buffer->superres_coefficients[kPlaneTypeY].get(),
-          frame_scratch_buffer
-              ->superres_coefficients
-                  [(sequence_header.color_config.is_monochrome ||
-                    sequence_header.color_config.subsampling_x == 0)
-                       ? kPlaneTypeY
-                       : kPlaneTypeUV]
-              .get()},
-      superres_line_buffer_(frame_scratch_buffer->superres_line_buffer),
-      block_parameters_(frame_scratch_buffer->block_parameters_holder),
-      frame_buffer_(*frame_buffer),
-      cdef_border_(frame_scratch_buffer->cdef_border),
-      loop_restoration_border_(frame_scratch_buffer->loop_restoration_border),
-      do_post_filter_mask_(do_post_filter_mask),
-      thread_pool_(
-          frame_scratch_buffer->threading_strategy.post_filter_thread_pool()) {
-  const int8_t zero_delta_lf[kFrameLfCount] = {};
-  ComputeDeblockFilterLevels(zero_delta_lf, deblock_filter_levels_);
-  if (DoSuperRes()) {
-    int plane = kPlaneY;
-    do {
-      const int downscaled_width =
-          SubsampledValue(width_, subsampling_x_[plane]);
-      const int upscaled_width =
-          SubsampledValue(upscaled_width_, subsampling_x_[plane]);
-      const int superres_width = downscaled_width << kSuperResScaleBits;
-      super_res_info_[plane].step =
-          (superres_width + upscaled_width / 2) / upscaled_width;
-      const int error =
-          super_res_info_[plane].step * upscaled_width - superres_width;
-      super_res_info_[plane].initial_subpixel_x =
-          ((-((upscaled_width - downscaled_width) << (kSuperResScaleBits - 1)) +
-            DivideBy2(upscaled_width)) /
-               upscaled_width +
-           (1 << (kSuperResExtraBits - 1)) - error / 2) &
-          kSuperResScaleMask;
-      super_res_info_[plane].upscaled_width = upscaled_width;
-    } while (++plane < planes_);
-    if (dsp->super_res_coefficients != nullptr) {
-      int plane = kPlaneY;
-      const int number_loops = (superres_coefficients_[kPlaneTypeY] ==
-                                superres_coefficients_[kPlaneTypeUV])
-                                   ? kMaxPlanesMonochrome
-                                   : static_cast<int>(kNumPlaneTypes);
-      do {
-        dsp->super_res_coefficients(
-            SubsampledValue(upscaled_width_, subsampling_x_[plane]),
-            super_res_info_[plane].initial_subpixel_x,
-            super_res_info_[plane].step, superres_coefficients_[plane]);
-      } while (++plane < number_loops);
-    }
-  }
-  int plane = kPlaneY;
-  do {
-    loop_restoration_buffer_[plane] = frame_buffer_.data(plane);
-    cdef_buffer_[plane] = frame_buffer_.data(plane);
-    superres_buffer_[plane] = frame_buffer_.data(plane);
-    source_buffer_[plane] = frame_buffer_.data(plane);
-  } while (++plane < planes_);
-  if (DoCdef() || DoRestoration() || DoSuperRes()) {
-    plane = kPlaneY;
-    const int pixel_size_log2 = pixel_size_log2_;
-    do {
-      int horizontal_shift = 0;
-      int vertical_shift = 0;
-      if (DoRestoration() &&
-          loop_restoration_.type[plane] != kLoopRestorationTypeNone) {
-        horizontal_shift += frame_buffer_.alignment();
-        if (!DoCdef() && thread_pool_ == nullptr) {
-          vertical_shift += kRestorationVerticalBorder;
-        }
-        superres_buffer_[plane] +=
-            vertical_shift * frame_buffer_.stride(plane) +
-            (horizontal_shift << pixel_size_log2);
-      }
-      if (DoSuperRes()) {
-        vertical_shift += kSuperResVerticalBorder;
-      }
-      cdef_buffer_[plane] += vertical_shift * frame_buffer_.stride(plane) +
-                             (horizontal_shift << pixel_size_log2);
-      if (DoCdef() && thread_pool_ == nullptr) {
-        horizontal_shift += frame_buffer_.alignment();
-        vertical_shift += kCdefBorder;
-      }
-      assert(horizontal_shift <= frame_buffer_.right_border(plane));
-      assert(vertical_shift <= frame_buffer_.bottom_border(plane));
-      source_buffer_[plane] += vertical_shift * frame_buffer_.stride(plane) +
-                               (horizontal_shift << pixel_size_log2);
-    } while (++plane < planes_);
-  }
-}
-
 void PostFilter::ExtendFrameBoundary(uint8_t* const frame_start,
                                      const int width, const int height,
                                      const ptrdiff_t stride, const int left,
@@ -269,8 +272,7 @@
 #if LIBGAV1_MAX_BITDEPTH >= 10
   if (bitdepth_ >= 10) {
     ExtendFrame<uint16_t>(reinterpret_cast<uint16_t*>(frame_start), width,
-                          height, stride / sizeof(uint16_t), left, right, top,
-                          bottom);
+                          height, stride >> 1, left, right, top, bottom);
     return;
   }
 #endif
@@ -280,11 +282,13 @@
 
 void PostFilter::ExtendBordersForReferenceFrame() {
   if (frame_header_.refresh_frame_flags == 0) return;
+  const int upscaled_width = frame_header_.upscaled_width;
+  const int height = frame_header_.height;
   int plane = kPlaneY;
   do {
     const int plane_width =
-        SubsampledValue(upscaled_width_, subsampling_x_[plane]);
-    const int plane_height = SubsampledValue(height_, subsampling_y_[plane]);
+        SubsampledValue(upscaled_width, subsampling_x_[plane]);
+    const int plane_height = SubsampledValue(height, subsampling_y_[plane]);
     assert(frame_buffer_.left_border(plane) >= kMinLeftBorderPixels &&
            frame_buffer_.right_border(plane) >= kMinRightBorderPixels &&
            frame_buffer_.top_border(plane) >= kMinTopBorderPixels &&
@@ -343,11 +347,13 @@
   // needs 2 extra rows for the bottom border in each plane.
   const int extra_rows =
       (for_loop_restoration && thread_pool_ == nullptr && !DoCdef()) ? 2 : 0;
+  const int upscaled_width = frame_header_.upscaled_width;
+  const int height = frame_header_.height;
   int plane = kPlaneY;
   do {
     const int plane_width =
-        SubsampledValue(upscaled_width_, subsampling_x_[plane]);
-    const int plane_height = SubsampledValue(height_, subsampling_y_[plane]);
+        SubsampledValue(upscaled_width, subsampling_x_[plane]);
+    const int plane_height = SubsampledValue(height, subsampling_y_[plane]);
     const int row = (MultiplyBy4(row4x4) - row_offset) >> subsampling_y_[plane];
     assert(row >= 0);
     if (row >= plane_height) break;
@@ -362,16 +368,25 @@
       progress_row_ = row + num_rows;
     }
     const bool copy_bottom = row + num_rows == plane_height;
-    const int stride = frame_buffer_.stride(plane);
+    const ptrdiff_t stride = frame_buffer_.stride(plane);
     uint8_t* const start = (for_loop_restoration ? superres_buffer_[plane]
                                                  : frame_buffer_.data(plane)) +
                            row * stride;
     const int left_border = for_loop_restoration
                                 ? kRestorationHorizontalBorder
                                 : frame_buffer_.left_border(plane);
+#if LIBGAV1_MSAN
+    // The optimized loop restoration code will overread the visible frame
+    // buffer into the right border. Extend the right boundary further to
+    // prevent msan warnings.
+    const int right_border = for_loop_restoration
+                                 ? kRestorationHorizontalBorder + 16
+                                 : frame_buffer_.right_border(plane);
+#else
     const int right_border = for_loop_restoration
                                  ? kRestorationHorizontalBorder
                                  : frame_buffer_.right_border(plane);
+#endif
     const int top_border =
         (row == 0) ? (for_loop_restoration ? kRestorationVerticalBorder
                                            : frame_buffer_.top_border(plane))
@@ -390,6 +405,8 @@
   assert(row4x4 >= 0);
   assert(!DoCdef());
   assert(DoRestoration());
+  const int upscaled_width = frame_header_.upscaled_width;
+  const int height = frame_header_.height;
   int plane = kPlaneY;
   do {
     if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
@@ -397,9 +414,9 @@
     }
     const int row_offset = DivideBy4(row4x4);
     const int num_pixels =
-        SubsampledValue(upscaled_width_, subsampling_x_[plane]);
+        SubsampledValue(upscaled_width, subsampling_x_[plane]);
     const int row_width = num_pixels << pixel_size_log2_;
-    const int plane_height = SubsampledValue(height_, subsampling_y_[plane]);
+    const int plane_height = SubsampledValue(height, subsampling_y_[plane]);
     const int row = kLoopRestorationBorderRows[subsampling_y_[plane]];
     const int absolute_row =
         (MultiplyBy4(row4x4) >> subsampling_y_[plane]) + row;
@@ -437,30 +454,33 @@
     const int row_offset_start = DivideBy4(row4x4);
     const std::array<uint8_t*, kMaxPlanes> dst = {
         loop_restoration_border_.data(kPlaneY) +
-            row_offset_start * loop_restoration_border_.stride(kPlaneY),
+            row_offset_start * static_cast<ptrdiff_t>(
+                                   loop_restoration_border_.stride(kPlaneY)),
         loop_restoration_border_.data(kPlaneU) +
-            row_offset_start * loop_restoration_border_.stride(kPlaneU),
+            row_offset_start * static_cast<ptrdiff_t>(
+                                   loop_restoration_border_.stride(kPlaneU)),
         loop_restoration_border_.data(kPlaneV) +
-            row_offset_start * loop_restoration_border_.stride(kPlaneV)};
+            row_offset_start * static_cast<ptrdiff_t>(
+                                   loop_restoration_border_.stride(kPlaneV))};
     // If SuperRes is enabled, then we apply SuperRes for the rows to be copied
     // directly with |loop_restoration_border_| as the destination. Otherwise,
     // we simply copy the rows.
     if (DoSuperRes()) {
       std::array<uint8_t*, kMaxPlanes> src;
       std::array<int, kMaxPlanes> rows;
+      const int height = frame_header_.height;
       int plane = kPlaneY;
       do {
         if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
           rows[plane] = 0;
           continue;
         }
-        const int plane_height =
-            SubsampledValue(frame_header_.height, subsampling_y_[plane]);
+        const int plane_height = SubsampledValue(height, subsampling_y_[plane]);
         const int row = kLoopRestorationBorderRows[subsampling_y_[plane]];
         const int absolute_row =
             (MultiplyBy4(row4x4) >> subsampling_y_[plane]) + row;
         src[plane] = GetSourceBuffer(static_cast<Plane>(plane), row4x4, 0) +
-                     row * frame_buffer_.stride(plane);
+                     row * static_cast<ptrdiff_t>(frame_buffer_.stride(plane));
         rows[plane] = Clip3(plane_height - absolute_row, 0, 4);
       } while (++plane < planes_);
       ApplySuperRes(src, rows, /*line_buffer_row=*/-1, dst,
@@ -487,6 +507,7 @@
       } while (++plane < planes_);
     }
     // Extend the left and right boundaries needed for loop restoration.
+    const int upscaled_width = frame_header_.upscaled_width;
     int plane = kPlaneY;
     do {
       if (loop_restoration_.type[plane] == kLoopRestorationTypeNone) {
@@ -494,7 +515,7 @@
       }
       uint8_t* dst_line = dst[plane];
       const int plane_width =
-          SubsampledValue(upscaled_width_, subsampling_x_[plane]);
+          SubsampledValue(upscaled_width, subsampling_x_[plane]);
       for (int i = 0; i < 4; ++i) {
 #if LIBGAV1_MAX_BITDEPTH >= 10
         if (bitdepth_ >= 10) {
@@ -567,7 +588,9 @@
                                                   bool do_deblock) {
   if (row4x4 < 0) return -1;
   if (DoDeblock() && do_deblock) {
-    ApplyDeblockFilterForOneSuperBlockRow(row4x4, sb4x4);
+    VerticalDeblockFilter(row4x4, row4x4 + sb4x4, 0, frame_header_.columns4x4);
+    HorizontalDeblockFilter(row4x4, row4x4 + sb4x4, 0,
+                            frame_header_.columns4x4);
   }
   if (DoRestoration() && DoCdef()) {
     SetupLoopRestorationBorder(row4x4, sb4x4);
@@ -597,7 +620,7 @@
   if (is_last_row && !DoBorderExtensionInLoop()) {
     ExtendBordersForReferenceFrame();
   }
-  return is_last_row ? height_ : progress_row_;
+  return is_last_row ? frame_header_.height : progress_row_;
 }
 
 }  // namespace libgav1
diff --git a/libgav1/src/post_filter/super_res.cc b/libgav1/src/post_filter/super_res.cc
index 554e537..2133a8a 100644
--- a/libgav1/src/post_filter/super_res.cc
+++ b/libgav1/src/post_filter/super_res.cc
@@ -149,16 +149,17 @@
   int num_threads = thread_pool_->num_threads() + 1;
   // The number of rows that will be processed by each thread in the thread pool
   // (other than the current thread).
-  int thread_pool_rows = height_ / num_threads;
+  int thread_pool_rows = frame_header_.height / num_threads;
   thread_pool_rows = std::max(thread_pool_rows, 1);
   // Make rows of Y plane even when there is subsampling for the other planes.
   if ((thread_pool_rows & 1) != 0 && subsampling_y_[kPlaneU] != 0) {
     ++thread_pool_rows;
   }
   // Adjust the number of threads to what we really need.
-  num_threads = Clip3(height_ / thread_pool_rows, 1, num_threads);
+  num_threads = Clip3(frame_header_.height / thread_pool_rows, 1, num_threads);
   // For the current thread, we round up to process all the remaining rows.
-  int current_thread_rows = height_ - thread_pool_rows * (num_threads - 1);
+  int current_thread_rows =
+      frame_header_.height - thread_pool_rows * (num_threads - 1);
   // Make rows of Y plane even when there is subsampling for the other planes.
   if ((current_thread_rows & 1) != 0 && subsampling_y_[kPlaneU] != 0) {
     ++current_thread_rows;
diff --git a/libgav1/src/prediction_mask.h b/libgav1/src/prediction_mask.h
index 0134a0d..827a0fa 100644
--- a/libgav1/src/prediction_mask.h
+++ b/libgav1/src/prediction_mask.h
@@ -17,9 +17,6 @@
 #ifndef LIBGAV1_SRC_PREDICTION_MASK_H_
 #define LIBGAV1_SRC_PREDICTION_MASK_H_
 
-#include <cstddef>
-#include <cstdint>
-
 #include "src/utils/bit_mask_set.h"
 #include "src/utils/types.h"
 
diff --git a/libgav1/src/quantizer.h b/libgav1/src/quantizer.h
index 00c53ab..c60756c 100644
--- a/libgav1/src/quantizer.h
+++ b/libgav1/src/quantizer.h
@@ -17,6 +17,7 @@
 #ifndef LIBGAV1_SRC_QUANTIZER_H_
 #define LIBGAV1_SRC_QUANTIZER_H_
 
+#include <array>
 #include <cstdint>
 
 #include "src/utils/constants.h"
diff --git a/libgav1/src/reconstruction.cc b/libgav1/src/reconstruction.cc
index 1aa1233..bf48137 100644
--- a/libgav1/src/reconstruction.cc
+++ b/libgav1/src/reconstruction.cc
@@ -23,30 +23,30 @@
 namespace libgav1 {
 namespace {
 
-// Maps TransformType to dsp::Transform1D for the row transforms.
-constexpr dsp::Transform1D kRowTransform[kNumTransformTypes] = {
-    dsp::k1DTransformDct,      dsp::k1DTransformAdst,
-    dsp::k1DTransformDct,      dsp::k1DTransformAdst,
-    dsp::k1DTransformAdst,     dsp::k1DTransformDct,
-    dsp::k1DTransformAdst,     dsp::k1DTransformAdst,
-    dsp::k1DTransformAdst,     dsp::k1DTransformIdentity,
-    dsp::k1DTransformIdentity, dsp::k1DTransformDct,
-    dsp::k1DTransformIdentity, dsp::k1DTransformAdst,
-    dsp::k1DTransformIdentity, dsp::k1DTransformAdst};
+// Maps TransformType to dsp::Transform1d for the row transforms.
+constexpr dsp::Transform1d kRowTransform[kNumTransformTypes] = {
+    dsp::kTransform1dDct,      dsp::kTransform1dAdst,
+    dsp::kTransform1dDct,      dsp::kTransform1dAdst,
+    dsp::kTransform1dAdst,     dsp::kTransform1dDct,
+    dsp::kTransform1dAdst,     dsp::kTransform1dAdst,
+    dsp::kTransform1dAdst,     dsp::kTransform1dIdentity,
+    dsp::kTransform1dIdentity, dsp::kTransform1dDct,
+    dsp::kTransform1dIdentity, dsp::kTransform1dAdst,
+    dsp::kTransform1dIdentity, dsp::kTransform1dAdst};
 
-// Maps TransformType to dsp::Transform1D for the column transforms.
-constexpr dsp::Transform1D kColumnTransform[kNumTransformTypes] = {
-    dsp::k1DTransformDct,  dsp::k1DTransformDct,
-    dsp::k1DTransformAdst, dsp::k1DTransformAdst,
-    dsp::k1DTransformDct,  dsp::k1DTransformAdst,
-    dsp::k1DTransformAdst, dsp::k1DTransformAdst,
-    dsp::k1DTransformAdst, dsp::k1DTransformIdentity,
-    dsp::k1DTransformDct,  dsp::k1DTransformIdentity,
-    dsp::k1DTransformAdst, dsp::k1DTransformIdentity,
-    dsp::k1DTransformAdst, dsp::k1DTransformIdentity};
+// Maps TransformType to dsp::Transform1d for the column transforms.
+constexpr dsp::Transform1d kColumnTransform[kNumTransformTypes] = {
+    dsp::kTransform1dDct,  dsp::kTransform1dDct,
+    dsp::kTransform1dAdst, dsp::kTransform1dAdst,
+    dsp::kTransform1dDct,  dsp::kTransform1dAdst,
+    dsp::kTransform1dAdst, dsp::kTransform1dAdst,
+    dsp::kTransform1dAdst, dsp::kTransform1dIdentity,
+    dsp::kTransform1dDct,  dsp::kTransform1dIdentity,
+    dsp::kTransform1dAdst, dsp::kTransform1dIdentity,
+    dsp::kTransform1dAdst, dsp::kTransform1dIdentity};
 
-dsp::TransformSize1D Get1DTransformSize(int size_log2) {
-  return static_cast<dsp::TransformSize1D>(size_log2 - 2);
+dsp::Transform1dSize GetTransform1dSize(int size_log2) {
+  return static_cast<dsp::Transform1dSize>(size_log2 - 2);
 }
 
 // Returns the number of rows to process based on |non_zero_coeff_count|. The
@@ -150,10 +150,10 @@
   assert(tx_height <= 32);
 
   // Row transform.
-  const dsp::TransformSize1D row_transform_size =
-      Get1DTransformSize(tx_width_log2);
-  const dsp::Transform1D row_transform =
-      lossless ? dsp::k1DTransformWht : kRowTransform[tx_type];
+  const dsp::Transform1dSize row_transform_size =
+      GetTransform1dSize(tx_width_log2);
+  const dsp::Transform1d row_transform =
+      lossless ? dsp::kTransform1dWht : kRowTransform[tx_type];
   const dsp::InverseTransformAddFunc row_transform_func =
       dsp.inverse_transforms[row_transform][row_transform_size][dsp::kRow];
   assert(row_transform_func != nullptr);
@@ -162,10 +162,10 @@
                      frame);
 
   // Column transform.
-  const dsp::TransformSize1D column_transform_size =
-      Get1DTransformSize(tx_height_log2);
-  const dsp::Transform1D column_transform =
-      lossless ? dsp::k1DTransformWht : kColumnTransform[tx_type];
+  const dsp::Transform1dSize column_transform_size =
+      GetTransform1dSize(tx_height_log2);
+  const dsp::Transform1d column_transform =
+      lossless ? dsp::kTransform1dWht : kColumnTransform[tx_type];
   const dsp::InverseTransformAddFunc column_transform_func =
       dsp.inverse_transforms[column_transform][column_transform_size]
                             [dsp::kColumn];
diff --git a/libgav1/src/tile.h b/libgav1/src/tile.h
index 6bae2a0..83c3423 100644
--- a/libgav1/src/tile.h
+++ b/libgav1/src/tile.h
@@ -65,7 +65,9 @@
   kProcessingModeParseAndDecode,
 };
 
-class Tile : public Allocable {
+// The alignment requirement is due to the SymbolDecoderContext member
+// symbol_decoder_context_.
+class Tile : public MaxAlignedAllocable {
  public:
   static std::unique_ptr<Tile> Create(
       int tile_number, const uint8_t* const data, size_t size,
@@ -320,7 +322,7 @@
   bool ReadSegmentId(const Block& block);       // 5.11.9.
   bool ReadIntraSegmentId(const Block& block);  // 5.11.8.
   void ReadSkip(const Block& block);            // 5.11.11.
-  void ReadSkipMode(const Block& block);        // 5.11.10.
+  bool ReadSkipMode(const Block& block);        // 5.11.10.
   void ReadCdef(const Block& block);            // 5.11.56.
   // Returns the new value. |cdf| is an array of size kDeltaSymbolCount + 1.
   int ReadAndClipDelta(uint16_t* cdf, int delta_small, int scale, int min_value,
@@ -330,6 +332,7 @@
   // Populates |BlockParameters::deblock_filter_level| for the given |block|
   // using |deblock_filter_levels_|.
   void PopulateDeblockFilterLevel(const Block& block);
+  void PopulateCdefSkip(const Block& block);
   void ReadPredictionModeY(const Block& block, bool intra_y_mode);
   void ReadIntraAngleInfo(const Block& block,
                           PlaneType plane_type);  // 5.11.42 and 5.11.43.
@@ -346,36 +349,41 @@
   bool DecodeIntraModeInfo(const Block& block);                // 5.11.7.
   int8_t ComputePredictedSegmentId(const Block& block) const;  // 5.11.21.
   bool ReadInterSegmentId(const Block& block, bool pre_skip);  // 5.11.19.
-  void ReadIsInter(const Block& block);                        // 5.11.20.
+  void ReadIsInter(const Block& block, bool skip_mode);        // 5.11.20.
   bool ReadIntraBlockModeInfo(const Block& block,
                               bool intra_y_mode);  // 5.11.22.
   CompoundReferenceType ReadCompoundReferenceType(const Block& block);
   template <bool is_single, bool is_backward, int index>
   uint16_t* GetReferenceCdf(const Block& block, CompoundReferenceType type =
                                                     kNumCompoundReferenceTypes);
-  void ReadReferenceFrames(const Block& block);  // 5.11.25.
+  void ReadReferenceFrames(const Block& block, bool skip_mode);  // 5.11.25.
   void ReadInterPredictionModeY(const Block& block,
-                                const MvContexts& mode_contexts);
+                                const MvContexts& mode_contexts,
+                                bool skip_mode);
   void ReadRefMvIndex(const Block& block);
-  void ReadInterIntraMode(const Block& block, bool is_compound);  // 5.11.28.
+  void ReadInterIntraMode(const Block& block, bool is_compound,
+                          bool skip_mode);        // 5.11.28.
   bool IsScaled(ReferenceFrameType type) const {  // Part of 5.11.27.
     const int index =
         frame_header_.reference_frame_index[type - kReferenceFrameLast];
     return reference_frames_[index]->upscaled_width() != frame_header_.width ||
            reference_frames_[index]->frame_height() != frame_header_.height;
   }
-  void ReadMotionMode(const Block& block, bool is_compound);  // 5.11.27.
+  void ReadMotionMode(const Block& block, bool is_compound,
+                      bool skip_mode);  // 5.11.27.
   uint16_t* GetIsExplicitCompoundTypeCdf(const Block& block);
   uint16_t* GetIsCompoundTypeAverageCdf(const Block& block);
-  void ReadCompoundType(const Block& block, bool is_compound);  // 5.11.29.
+  void ReadCompoundType(const Block& block, bool is_compound, bool skip_mode,
+                        bool* is_explicit_compound_type,
+                        bool* is_compound_type_average);  // 5.11.29.
   uint16_t* GetInterpolationFilterCdf(const Block& block, int direction);
-  void ReadInterpolationFilter(const Block& block);
-  bool ReadInterBlockModeInfo(const Block& block);             // 5.11.23.
-  bool DecodeInterModeInfo(const Block& block);                // 5.11.18.
-  bool DecodeModeInfo(const Block& block);                     // 5.11.6.
-  bool IsMvValid(const Block& block, bool is_compound) const;  // 6.10.25.
-  bool AssignInterMv(const Block& block, bool is_compound);    // 5.11.26.
-  bool AssignIntraMv(const Block& block);                      // 5.11.26.
+  void ReadInterpolationFilter(const Block& block, bool skip_mode);
+  bool ReadInterBlockModeInfo(const Block& block, bool skip_mode);  // 5.11.23.
+  bool DecodeInterModeInfo(const Block& block);                     // 5.11.18.
+  bool DecodeModeInfo(const Block& block);                          // 5.11.6.
+  bool IsMvValid(const Block& block, bool is_compound) const;       // 6.10.25.
+  bool AssignInterMv(const Block& block, bool is_compound);         // 5.11.26.
+  bool AssignIntraMv(const Block& block);                           // 5.11.26.
   int GetTopTransformWidth(const Block& block, int row4x4, int column4x4,
                            bool ignore_skip);
   int GetLeftTransformHeight(const Block& block, int row4x4, int column4x4,
@@ -541,7 +549,6 @@
                        bool has_left, bool has_top, bool has_top_right,
                        bool has_bottom_left, PredictionMode mode,
                        TransformSize tx_size);
-  bool IsSmoothPrediction(int row, int column, Plane plane) const;
   int GetIntraEdgeFilterType(const Block& block,
                              Plane plane) const;  // 7.11.2.8.
   template <typename Pixel>
@@ -563,6 +570,17 @@
   // for the given |block| and stores them into |current_frame_|.
   void StoreMotionFieldMvsIntoCurrentFrame(const Block& block);
 
+  // SetCdfContext*() functions will populate the |left_context_| and
+  // |top_context_| for the |block|.
+  void SetCdfContextUsePredictedSegmentId(const Block& block,
+                                          bool use_predicted_segment_id);
+  void SetCdfContextCompoundType(const Block& block,
+                                 bool is_explicit_compound_type,
+                                 bool is_compound_type_average);
+  void SetCdfContextSkipMode(const Block& block, bool skip_mode);
+  void SetCdfContextPaletteSize(const Block& block);
+  void SetCdfContextUVMode(const Block& block);
+
   // Returns the zero-based index of the super block that contains |row4x4|
   // relative to the start of this tile.
   int SuperBlockRowIndex(int row4x4) const {
@@ -577,6 +595,16 @@
            (sequence_header_.use_128x128_superblock ? 5 : 4);
   }
 
+  // Returns the zero-based index of the block that starts at row4x4 or
+  // column4x4 relative to the start of the superblock that contains the block.
+  // This is used to index into the members of |left_context_| and
+  // |top_context_|.
+  int CdfContextIndex(int row_or_column4x4) const {
+    return row_or_column4x4 -
+           (row_or_column4x4 &
+            (sequence_header_.use_128x128_superblock ? ~31 : ~15));
+  }
+
   BlockSize SuperBlockSize() const {
     return sequence_header_.use_128x128_superblock ? kBlock128x128
                                                    : kBlock64x64;
@@ -600,8 +628,6 @@
   bool read_deltas_;
   const int8_t subsampling_x_[kMaxPlanes];
   const int8_t subsampling_y_[kMaxPlanes];
-  int deblock_row_limit_[kMaxPlanes];
-  int deblock_column_limit_[kMaxPlanes];
 
   // The dimensions (in order) are: segment_id, level_index (based on plane and
   // direction), reference_frame and mode_id.
@@ -649,7 +675,7 @@
   const std::array<uint8_t, kNumReferenceFrameTypes>& reference_order_hint_;
   const WedgeMaskArray& wedge_masks_;
   const QuantizerMatrix& quantizer_matrix_;
-  DaalaBitReader reader_;
+  EntropyDecoder reader_;
   SymbolDecoderContext symbol_decoder_context_;
   SymbolDecoderContext* const saved_symbol_decoder_context_;
   const SegmentationMap* prev_segment_ids_;
@@ -712,7 +738,8 @@
   Array2DView<uint8_t> buffer_[kMaxPlanes];
   RefCountedBuffer& current_frame_;
 
-  Array2D<int16_t>& cdef_index_;
+  Array2D<int8_t>& cdef_index_;
+  Array2D<uint8_t>& cdef_skip_;
   Array2D<TransformSize>& inter_transform_sizes_;
   std::array<RestorationUnitInfo, kMaxPlanes> reference_unit_info_;
   // If |thread_pool_| is nullptr, the calling thread will do the parsing and
@@ -746,12 +773,19 @@
   // Stores the progress of the reference frames. This will be used to avoid
   // unnecessary calls into RefCountedBuffer::WaitUntil().
   std::array<int, kNumReferenceFrameTypes> reference_frame_progress_cache_;
+  // Stores the CDF contexts necessary for the "left" block.
+  BlockCdfContext left_context_;
+  // Stores the CDF contexts necessary for the "top" block. The size of this
+  // buffer is the number of superblock columns in this tile. For each block,
+  // the access index will be the corresponding SuperBlockColumnIndex()'th
+  // entry.
+  DynamicBuffer<BlockCdfContext> top_context_;
 };
 
 struct Tile::Block {
-  Block(const Tile& tile, BlockSize size, int row4x4, int column4x4,
+  Block(Tile* tile_ptr, BlockSize size, int row4x4, int column4x4,
         TileScratchBuffer* const scratch_buffer, ResidualPtr* residual)
-      : tile(tile),
+      : tile(*tile_ptr),
         size(size),
         row4x4(row4x4),
         column4x4(column4x4),
@@ -760,7 +794,11 @@
         width4x4(width >> 2),
         height4x4(height >> 2),
         scratch_buffer(scratch_buffer),
-        residual(residual) {
+        residual(residual),
+        top_context(tile.top_context_.get() +
+                    tile.SuperBlockColumnIndex(column4x4)),
+        top_context_index(tile.CdfContextIndex(column4x4)),
+        left_context_index(tile.CdfContextIndex(row4x4)) {
     assert(size != kBlockInvalid);
     residual_size[kPlaneY] = kPlaneResidualSize[size][0][0];
     residual_size[kPlaneU] = residual_size[kPlaneV] =
@@ -881,7 +919,7 @@
     return false;
   }
 
-  const Tile& tile;
+  Tile& tile;
   bool has_chroma;
   const BlockSize size;
   bool top_available[kMaxPlanes];
@@ -898,6 +936,9 @@
   BlockParameters* bp;
   TileScratchBuffer* const scratch_buffer;
   ResidualPtr* const residual;
+  BlockCdfContext* const top_context;
+  const int top_context_index;
+  const int left_context_index;
 };
 
 extern template bool
diff --git a/libgav1/src/tile/bitstream/mode_info.cc b/libgav1/src/tile/bitstream/mode_info.cc
index 0b22eb0..cb7b311 100644
--- a/libgav1/src/tile/bitstream/mode_info.cc
+++ b/libgav1/src/tile/bitstream/mode_info.cc
@@ -185,19 +185,22 @@
 }  // namespace
 
 bool Tile::ReadSegmentId(const Block& block) {
+  // These two asserts ensure that current_frame_.segmentation_map() is not
+  // nullptr.
+  assert(frame_header_.segmentation.enabled);
+  assert(frame_header_.segmentation.update_map);
+  const SegmentationMap& map = *current_frame_.segmentation_map();
   int top_left = -1;
   if (block.top_available[kPlaneY] && block.left_available[kPlaneY]) {
-    top_left =
-        block_parameters_holder_.Find(block.row4x4 - 1, block.column4x4 - 1)
-            ->segment_id;
+    top_left = map.segment_id(block.row4x4 - 1, block.column4x4 - 1);
   }
   int top = -1;
   if (block.top_available[kPlaneY]) {
-    top = block.bp_top->segment_id;
+    top = map.segment_id(block.row4x4 - 1, block.column4x4);
   }
   int left = -1;
   if (block.left_available[kPlaneY]) {
-    left = block.bp_left->segment_id;
+    left = map.segment_id(block.row4x4, block.column4x4 - 1);
   }
   int pred;
   if (top == -1) {
@@ -209,7 +212,7 @@
   }
   BlockParameters& bp = *block.bp;
   if (bp.skip) {
-    bp.segment_id = pred;
+    bp.prediction_parameters->segment_id = pred;
     return true;
   }
   int context = 0;
@@ -224,17 +227,18 @@
       symbol_decoder_context_.segment_id_cdf[context];
   const int encoded_segment_id =
       reader_.ReadSymbol<kMaxSegments>(segment_id_cdf);
-  bp.segment_id =
+  bp.prediction_parameters->segment_id =
       DecodeSegmentId(encoded_segment_id, pred,
                       frame_header_.segmentation.last_active_segment_id + 1);
   // Check the bitstream conformance requirement in Section 6.10.8 of the spec.
-  if (bp.segment_id < 0 ||
-      bp.segment_id > frame_header_.segmentation.last_active_segment_id) {
+  if (bp.prediction_parameters->segment_id < 0 ||
+      bp.prediction_parameters->segment_id >
+          frame_header_.segmentation.last_active_segment_id) {
     LIBGAV1_DLOG(
         ERROR,
         "Corrupted segment_ids: encoded %d, last active %d, postprocessed %d",
         encoded_segment_id, frame_header_.segmentation.last_active_segment_id,
-        bp.segment_id);
+        bp.prediction_parameters->segment_id);
     return false;
   }
   return true;
@@ -243,7 +247,7 @@
 bool Tile::ReadIntraSegmentId(const Block& block) {
   BlockParameters& bp = *block.bp;
   if (!frame_header_.segmentation.enabled) {
-    bp.segment_id = 0;
+    bp.prediction_parameters->segment_id = 0;
     return true;
   }
   return ReadSegmentId(block);
@@ -252,8 +256,8 @@
 void Tile::ReadSkip(const Block& block) {
   BlockParameters& bp = *block.bp;
   if (frame_header_.segmentation.segment_id_pre_skip &&
-      frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureSkip)) {
+      frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id, kSegmentFeatureSkip)) {
     bp.skip = true;
     return;
   }
@@ -268,51 +272,53 @@
   bp.skip = reader_.ReadSymbol(skip_cdf);
 }
 
-void Tile::ReadSkipMode(const Block& block) {
+bool Tile::ReadSkipMode(const Block& block) {
   BlockParameters& bp = *block.bp;
   if (!frame_header_.skip_mode_present ||
-      frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureSkip) ||
-      frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureReferenceFrame) ||
-      frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureGlobalMv) ||
+      frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id, kSegmentFeatureSkip) ||
+      frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id,
+          kSegmentFeatureReferenceFrame) ||
+      frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id, kSegmentFeatureGlobalMv) ||
       IsBlockDimension4(block.size)) {
-    bp.skip_mode = false;
-    return;
+    return false;
   }
   const int context =
       (block.left_available[kPlaneY]
-           ? static_cast<int>(block.bp_left->skip_mode)
+           ? static_cast<int>(left_context_.skip_mode[block.left_context_index])
            : 0) +
-      (block.top_available[kPlaneY] ? static_cast<int>(block.bp_top->skip_mode)
-                                    : 0);
-  bp.skip_mode =
-      reader_.ReadSymbol(symbol_decoder_context_.skip_mode_cdf[context]);
+      (block.top_available[kPlaneY]
+           ? static_cast<int>(
+                 block.top_context->skip_mode[block.top_context_index])
+           : 0);
+  return reader_.ReadSymbol(symbol_decoder_context_.skip_mode_cdf[context]);
 }
 
 void Tile::ReadCdef(const Block& block) {
   BlockParameters& bp = *block.bp;
   if (bp.skip || frame_header_.coded_lossless ||
-      !sequence_header_.enable_cdef || frame_header_.allow_intrabc) {
+      !sequence_header_.enable_cdef || frame_header_.allow_intrabc ||
+      frame_header_.cdef.bits == 0) {
     return;
   }
-  const int cdef_size4x4 = kNum4x4BlocksWide[kBlock64x64];
-  const int cdef_mask4x4 = ~(cdef_size4x4 - 1);
-  const int row4x4 = block.row4x4 & cdef_mask4x4;
-  const int column4x4 = block.column4x4 & cdef_mask4x4;
-  const int row = DivideBy16(row4x4);
-  const int column = DivideBy16(column4x4);
-  if (cdef_index_[row][column] == -1) {
-    cdef_index_[row][column] =
-        frame_header_.cdef.bits > 0
-            ? static_cast<int16_t>(reader_.ReadLiteral(frame_header_.cdef.bits))
-            : 0;
-    for (int i = row4x4; i < row4x4 + block.height4x4; i += cdef_size4x4) {
-      for (int j = column4x4; j < column4x4 + block.width4x4;
-           j += cdef_size4x4) {
-        cdef_index_[DivideBy16(i)][DivideBy16(j)] = cdef_index_[row][column];
-      }
+  int8_t* const cdef_index =
+      &cdef_index_[DivideBy16(block.row4x4)][DivideBy16(block.column4x4)];
+  int stride = cdef_index_.columns();
+  if (cdef_index[0] == -1) {
+    cdef_index[0] =
+        static_cast<int8_t>(reader_.ReadLiteral(frame_header_.cdef.bits));
+    if (block.size == kBlock128x128) {
+      // This condition is shorthand for block.width4x4 > 16 && block.height4x4
+      // > 16.
+      cdef_index[1] = cdef_index[0];
+      cdef_index[stride] = cdef_index[0];
+      cdef_index[stride + 1] = cdef_index[0];
+    } else if (block.width4x4 > 16) {
+      cdef_index[1] = cdef_index[0];
+    } else if (block.height4x4 > 16) {
+      cdef_index[stride] = cdef_index[0];
     }
   }
 }
@@ -328,7 +334,7 @@
     abs = abs_remaining_bits + (1 << remaining_bit_count) + 1;
   }
   if (abs != 0) {
-    const bool sign = static_cast<bool>(reader_.ReadBit());
+    const bool sign = reader_.ReadBit() != 0;
     const int scaled_abs = abs << scale;
     const int reduced_delta = sign ? -scaled_abs : scaled_abs;
     value += reduced_delta;
@@ -404,8 +410,9 @@
   PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
   prediction_parameters.angle_delta[plane_type] = 0;
-  const PredictionMode mode =
-      (plane_type == kPlaneTypeY) ? bp.y_mode : bp.uv_mode;
+  const PredictionMode mode = (plane_type == kPlaneTypeY)
+                                  ? bp.y_mode
+                                  : bp.prediction_parameters->uv_mode;
   if (IsBlockSmallerThan8x8(block.size) || !IsDirectionalMode(mode)) return;
   uint16_t* const cdf =
       symbol_decoder_context_.angle_delta_cdf[mode - kPredictionModeVertical];
@@ -445,7 +452,8 @@
 void Tile::ReadPredictionModeUV(const Block& block) {
   BlockParameters& bp = *block.bp;
   bool chroma_from_luma_allowed;
-  if (frame_header_.segmentation.lossless[bp.segment_id]) {
+  if (frame_header_.segmentation
+          .lossless[bp.prediction_parameters->segment_id]) {
     chroma_from_luma_allowed = block.residual_size[kPlaneU] == kBlock4x4;
   } else {
     chroma_from_luma_allowed = IsBlockDimensionLessThan64(block.size);
@@ -454,10 +462,10 @@
       symbol_decoder_context_
           .uv_mode_cdf[static_cast<int>(chroma_from_luma_allowed)][bp.y_mode];
   if (chroma_from_luma_allowed) {
-    bp.uv_mode = static_cast<PredictionMode>(
+    bp.prediction_parameters->uv_mode = static_cast<PredictionMode>(
         reader_.ReadSymbol<kIntraPredictionModesUV>(cdf));
   } else {
-    bp.uv_mode = static_cast<PredictionMode>(
+    bp.prediction_parameters->uv_mode = static_cast<PredictionMode>(
         reader_.ReadSymbol<kIntraPredictionModesUV - 1>(cdf));
   }
 }
@@ -528,7 +536,7 @@
       *block.bp->prediction_parameters;
   prediction_parameters.use_filter_intra = false;
   if (!sequence_header_.enable_filter_intra || bp.y_mode != kPredictionModeDc ||
-      bp.palette_mode_info.size[kPlaneTypeY] != 0 ||
+      bp.prediction_parameters->palette_mode_info.size[kPlaneTypeY] != 0 ||
       !IsBlockDimensionLessThan64(block.size)) {
     return;
   }
@@ -548,7 +556,7 @@
       !ReadIntraSegmentId(block)) {
     return false;
   }
-  bp.skip_mode = false;
+  SetCdfContextSkipMode(block, false);
   ReadSkip(block);
   if (!frame_header_.segmentation.segment_id_pre_skip &&
       !ReadIntraSegmentId(block)) {
@@ -572,12 +580,14 @@
     bp.reference_frame[0] = kReferenceFrameIntra;
     bp.reference_frame[1] = kReferenceFrameNone;
     bp.y_mode = kPredictionModeDc;
-    bp.uv_mode = kPredictionModeDc;
+    bp.prediction_parameters->uv_mode = kPredictionModeDc;
+    SetCdfContextUVMode(block);
     prediction_parameters.motion_mode = kMotionModeSimple;
     prediction_parameters.compound_prediction_type =
         kCompoundPredictionTypeAverage;
-    bp.palette_mode_info.size[kPlaneTypeY] = 0;
-    bp.palette_mode_info.size[kPlaneTypeUV] = 0;
+    bp.prediction_parameters->palette_mode_info.size[kPlaneTypeY] = 0;
+    bp.prediction_parameters->palette_mode_info.size[kPlaneTypeUV] = 0;
+    SetCdfContextPaletteSize(block);
     bp.interpolation_filter[0] = kInterpolationFilterBilinear;
     bp.interpolation_filter[1] = kInterpolationFilterBilinear;
     MvContexts dummy_mode_contexts;
@@ -608,59 +618,73 @@
   return id;
 }
 
+void Tile::SetCdfContextUsePredictedSegmentId(const Block& block,
+                                              bool use_predicted_segment_id) {
+  memset(left_context_.use_predicted_segment_id + block.left_context_index,
+         static_cast<int>(use_predicted_segment_id), block.height4x4);
+  memset(block.top_context->use_predicted_segment_id + block.top_context_index,
+         static_cast<int>(use_predicted_segment_id), block.width4x4);
+}
+
 bool Tile::ReadInterSegmentId(const Block& block, bool pre_skip) {
   BlockParameters& bp = *block.bp;
   if (!frame_header_.segmentation.enabled) {
-    bp.segment_id = 0;
+    bp.prediction_parameters->segment_id = 0;
     return true;
   }
   if (!frame_header_.segmentation.update_map) {
-    bp.segment_id = ComputePredictedSegmentId(block);
+    bp.prediction_parameters->segment_id = ComputePredictedSegmentId(block);
     return true;
   }
   if (pre_skip) {
     if (!frame_header_.segmentation.segment_id_pre_skip) {
-      bp.segment_id = 0;
+      bp.prediction_parameters->segment_id = 0;
       return true;
     }
   } else if (bp.skip) {
-    bp.use_predicted_segment_id = false;
+    SetCdfContextUsePredictedSegmentId(block, false);
     return ReadSegmentId(block);
   }
   if (frame_header_.segmentation.temporal_update) {
     const int context =
         (block.left_available[kPlaneY]
-             ? static_cast<int>(block.bp_left->use_predicted_segment_id)
+             ? static_cast<int>(
+                   left_context_
+                       .use_predicted_segment_id[block.left_context_index])
              : 0) +
         (block.top_available[kPlaneY]
-             ? static_cast<int>(block.bp_top->use_predicted_segment_id)
+             ? static_cast<int>(
+                   block.top_context
+                       ->use_predicted_segment_id[block.top_context_index])
              : 0);
-    bp.use_predicted_segment_id = reader_.ReadSymbol(
+    const bool use_predicted_segment_id = reader_.ReadSymbol(
         symbol_decoder_context_.use_predicted_segment_id_cdf[context]);
-    if (bp.use_predicted_segment_id) {
-      bp.segment_id = ComputePredictedSegmentId(block);
+    SetCdfContextUsePredictedSegmentId(block, use_predicted_segment_id);
+    if (use_predicted_segment_id) {
+      bp.prediction_parameters->segment_id = ComputePredictedSegmentId(block);
       return true;
     }
   }
   return ReadSegmentId(block);
 }
 
-void Tile::ReadIsInter(const Block& block) {
+void Tile::ReadIsInter(const Block& block, bool skip_mode) {
   BlockParameters& bp = *block.bp;
-  if (bp.skip_mode) {
+  if (skip_mode) {
     bp.is_inter = true;
     return;
   }
-  if (frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureReferenceFrame)) {
-    bp.is_inter =
-        frame_header_.segmentation
-            .feature_data[bp.segment_id][kSegmentFeatureReferenceFrame] !=
-        kReferenceFrameIntra;
+  if (frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id,
+          kSegmentFeatureReferenceFrame)) {
+    bp.is_inter = frame_header_.segmentation
+                      .feature_data[bp.prediction_parameters->segment_id]
+                                   [kSegmentFeatureReferenceFrame] !=
+                  kReferenceFrameIntra;
     return;
   }
-  if (frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureGlobalMv)) {
+  if (frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id, kSegmentFeatureGlobalMv)) {
     bp.is_inter = true;
     return;
   }
@@ -678,6 +702,49 @@
       reader_.ReadSymbol(symbol_decoder_context_.is_inter_cdf[context]);
 }
 
+void Tile::SetCdfContextPaletteSize(const Block& block) {
+  const PaletteModeInfo& palette_mode_info =
+      block.bp->prediction_parameters->palette_mode_info;
+  for (int plane_type = kPlaneTypeY; plane_type <= kPlaneTypeUV; ++plane_type) {
+    memset(left_context_.palette_size[plane_type] + block.left_context_index,
+           palette_mode_info.size[plane_type], block.height4x4);
+    memset(
+        block.top_context->palette_size[plane_type] + block.top_context_index,
+        palette_mode_info.size[plane_type], block.width4x4);
+    if (palette_mode_info.size[plane_type] == 0) continue;
+    for (int i = block.left_context_index;
+         i < block.left_context_index + block.height4x4; ++i) {
+      memcpy(left_context_.palette_color[i][plane_type],
+             palette_mode_info.color[plane_type],
+             kMaxPaletteSize * sizeof(palette_mode_info.color[0][0]));
+    }
+    for (int i = block.top_context_index;
+         i < block.top_context_index + block.width4x4; ++i) {
+      memcpy(block.top_context->palette_color[i][plane_type],
+             palette_mode_info.color[plane_type],
+             kMaxPaletteSize * sizeof(palette_mode_info.color[0][0]));
+    }
+  }
+}
+
+void Tile::SetCdfContextUVMode(const Block& block) {
+  // BlockCdfContext.uv_mode is only used to compute is_smooth_prediction for
+  // the intra edge upsamplers in the subsequent blocks. They have some special
+  // rules for subsampled UV planes. For subsampled UV planes, update left
+  // context only if current block contains the last odd column and update top
+  // context only if current block contains the last odd row.
+  if (subsampling_x_[kPlaneU] == 0 || (block.column4x4 & 1) == 1 ||
+      block.width4x4 > 1) {
+    memset(left_context_.uv_mode + block.left_context_index,
+           block.bp->prediction_parameters->uv_mode, block.height4x4);
+  }
+  if (subsampling_y_[kPlaneU] == 0 || (block.row4x4 & 1) == 1 ||
+      block.height4x4 > 1) {
+    memset(block.top_context->uv_mode + block.top_context_index,
+           block.bp->prediction_parameters->uv_mode, block.width4x4);
+  }
+}
+
 bool Tile::ReadIntraBlockModeInfo(const Block& block, bool intra_y_mode) {
   BlockParameters& bp = *block.bp;
   bp.reference_frame[0] = kReferenceFrameIntra;
@@ -686,12 +753,39 @@
   ReadIntraAngleInfo(block, kPlaneTypeY);
   if (block.HasChroma()) {
     ReadPredictionModeUV(block);
-    if (bp.uv_mode == kPredictionModeChromaFromLuma) {
+    if (bp.prediction_parameters->uv_mode == kPredictionModeChromaFromLuma) {
       ReadCflAlpha(block);
     }
+    if (block.left_available[kPlaneU]) {
+      const int smooth_row =
+          block.row4x4 + (~block.row4x4 & subsampling_y_[kPlaneU]);
+      const int smooth_column =
+          block.column4x4 - 1 - (block.column4x4 & subsampling_x_[kPlaneU]);
+      const BlockParameters& bp_left =
+          *block_parameters_holder_.Find(smooth_row, smooth_column);
+      bp.prediction_parameters->chroma_left_uses_smooth_prediction =
+          (bp_left.reference_frame[0] <= kReferenceFrameIntra) &&
+          kPredictionModeSmoothMask.Contains(
+              left_context_.uv_mode[CdfContextIndex(smooth_row)]);
+    }
+    if (block.top_available[kPlaneU]) {
+      const int smooth_row =
+          block.row4x4 - 1 - (block.row4x4 & subsampling_y_[kPlaneU]);
+      const int smooth_column =
+          block.column4x4 + (~block.column4x4 & subsampling_x_[kPlaneU]);
+      const BlockParameters& bp_top =
+          *block_parameters_holder_.Find(smooth_row, smooth_column);
+      bp.prediction_parameters->chroma_top_uses_smooth_prediction =
+          (bp_top.reference_frame[0] <= kReferenceFrameIntra) &&
+          kPredictionModeSmoothMask.Contains(
+              top_context_.get()[SuperBlockColumnIndex(smooth_column)]
+                  .uv_mode[CdfContextIndex(smooth_column)]);
+    }
+    SetCdfContextUVMode(block);
     ReadIntraAngleInfo(block, kPlaneTypeUV);
   }
   ReadPaletteModeInfo(block);
+  SetCdfContextPaletteSize(block);
   ReadFilterIntraModeInfo(block);
   return true;
 }
@@ -808,25 +902,27 @@
   return symbol_decoder_context_.compound_reference_cdf[type][context][index];
 }
 
-void Tile::ReadReferenceFrames(const Block& block) {
+void Tile::ReadReferenceFrames(const Block& block, bool skip_mode) {
   BlockParameters& bp = *block.bp;
-  if (bp.skip_mode) {
+  if (skip_mode) {
     bp.reference_frame[0] = frame_header_.skip_mode_frame[0];
     bp.reference_frame[1] = frame_header_.skip_mode_frame[1];
     return;
   }
-  if (frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureReferenceFrame)) {
+  if (frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id,
+          kSegmentFeatureReferenceFrame)) {
     bp.reference_frame[0] = static_cast<ReferenceFrameType>(
         frame_header_.segmentation
-            .feature_data[bp.segment_id][kSegmentFeatureReferenceFrame]);
+            .feature_data[bp.prediction_parameters->segment_id]
+                         [kSegmentFeatureReferenceFrame]);
     bp.reference_frame[1] = kReferenceFrameNone;
     return;
   }
-  if (frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureSkip) ||
-      frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureGlobalMv)) {
+  if (frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id, kSegmentFeatureSkip) ||
+      frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id, kSegmentFeatureGlobalMv)) {
     bp.reference_frame[0] = kReferenceFrameLast;
     bp.reference_frame[1] = kReferenceFrameNone;
     return;
@@ -927,16 +1023,17 @@
 }
 
 void Tile::ReadInterPredictionModeY(const Block& block,
-                                    const MvContexts& mode_contexts) {
+                                    const MvContexts& mode_contexts,
+                                    bool skip_mode) {
   BlockParameters& bp = *block.bp;
-  if (bp.skip_mode) {
+  if (skip_mode) {
     bp.y_mode = kPredictionModeNearestNearestMv;
     return;
   }
-  if (frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureSkip) ||
-      frame_header_.segmentation.FeatureActive(bp.segment_id,
-                                               kSegmentFeatureGlobalMv)) {
+  if (frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id, kSegmentFeatureSkip) ||
+      frame_header_.segmentation.FeatureActive(
+          bp.prediction_parameters->segment_id, kSegmentFeatureGlobalMv)) {
     bp.y_mode = kPredictionModeGlobalMv;
     return;
   }
@@ -995,13 +1092,14 @@
   }
 }
 
-void Tile::ReadInterIntraMode(const Block& block, bool is_compound) {
+void Tile::ReadInterIntraMode(const Block& block, bool is_compound,
+                              bool skip_mode) {
   BlockParameters& bp = *block.bp;
   PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
   prediction_parameters.inter_intra_mode = kNumInterIntraModes;
   prediction_parameters.is_wedge_inter_intra = false;
-  if (bp.skip_mode || !sequence_header_.enable_interintra_compound ||
+  if (skip_mode || !sequence_header_.enable_interintra_compound ||
       is_compound || !kIsInterIntraModeAllowedMask.Contains(block.size)) {
     return;
   }
@@ -1031,13 +1129,14 @@
   prediction_parameters.wedge_sign = 0;
 }
 
-void Tile::ReadMotionMode(const Block& block, bool is_compound) {
+void Tile::ReadMotionMode(const Block& block, bool is_compound,
+                          bool skip_mode) {
   BlockParameters& bp = *block.bp;
   PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
   const auto global_motion_type =
       frame_header_.global_motion[bp.reference_frame[0]].type;
-  if (bp.skip_mode || !frame_header_.is_motion_mode_switchable ||
+  if (skip_mode || !frame_header_.is_motion_mode_switchable ||
       IsBlockDimension4(block.size) ||
       (frame_header_.force_integer_mv == 0 &&
        (bp.y_mode == kPredictionModeGlobalMv ||
@@ -1073,14 +1172,17 @@
   int context = 0;
   if (block.top_available[kPlaneY]) {
     if (!block.IsTopSingle()) {
-      context += static_cast<int>(block.bp_top->is_explicit_compound_type);
+      context += static_cast<int>(
+          block.top_context
+              ->is_explicit_compound_type[block.top_context_index]);
     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
       context += 3;
     }
   }
   if (block.left_available[kPlaneY]) {
     if (!block.IsLeftSingle()) {
-      context += static_cast<int>(block.bp_left->is_explicit_compound_type);
+      context += static_cast<int>(
+          left_context_.is_explicit_compound_type[block.left_context_index]);
     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
       context += 3;
     }
@@ -1099,14 +1201,16 @@
   int context = (forward == backward) ? 3 : 0;
   if (block.top_available[kPlaneY]) {
     if (!block.IsTopSingle()) {
-      context += static_cast<int>(block.bp_top->is_compound_type_average);
+      context += static_cast<int>(
+          block.top_context->is_compound_type_average[block.top_context_index]);
     } else if (block.TopReference(0) == kReferenceFrameAlternate) {
       ++context;
     }
   }
   if (block.left_available[kPlaneY]) {
     if (!block.IsLeftSingle()) {
-      context += static_cast<int>(block.bp_left->is_compound_type_average);
+      context += static_cast<int>(
+          left_context_.is_compound_type_average[block.left_context_index]);
     } else if (block.LeftReference(0) == kReferenceFrameAlternate) {
       ++context;
     }
@@ -1114,23 +1218,25 @@
   return symbol_decoder_context_.is_compound_type_average_cdf[context];
 }
 
-void Tile::ReadCompoundType(const Block& block, bool is_compound) {
-  BlockParameters& bp = *block.bp;
-  bp.is_explicit_compound_type = false;
-  bp.is_compound_type_average = true;
+void Tile::ReadCompoundType(const Block& block, bool is_compound,
+                            bool skip_mode,
+                            bool* const is_explicit_compound_type,
+                            bool* const is_compound_type_average) {
+  *is_explicit_compound_type = false;
+  *is_compound_type_average = true;
   PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
-  if (bp.skip_mode) {
+  if (skip_mode) {
     prediction_parameters.compound_prediction_type =
         kCompoundPredictionTypeAverage;
     return;
   }
   if (is_compound) {
     if (sequence_header_.enable_masked_compound) {
-      bp.is_explicit_compound_type =
+      *is_explicit_compound_type =
           reader_.ReadSymbol(GetIsExplicitCompoundTypeCdf(block));
     }
-    if (bp.is_explicit_compound_type) {
+    if (*is_explicit_compound_type) {
       if (kIsWedgeCompoundModeAllowed.Contains(block.size)) {
         // Only kCompoundPredictionTypeWedge and
         // kCompoundPredictionTypeDiffWeighted are signaled explicitly.
@@ -1143,11 +1249,11 @@
       }
     } else {
       if (sequence_header_.enable_jnt_comp) {
-        bp.is_compound_type_average =
+        *is_compound_type_average =
             reader_.ReadSymbol(GetIsCompoundTypeAverageCdf(block));
         prediction_parameters.compound_prediction_type =
-            bp.is_compound_type_average ? kCompoundPredictionTypeAverage
-                                        : kCompoundPredictionTypeDistance;
+            *is_compound_type_average ? kCompoundPredictionTypeAverage
+                                      : kCompoundPredictionTypeDistance;
       } else {
         prediction_parameters.compound_prediction_type =
             kCompoundPredictionTypeAverage;
@@ -1162,8 +1268,7 @@
       prediction_parameters.wedge_sign = static_cast<int>(reader_.ReadBit());
     } else if (prediction_parameters.compound_prediction_type ==
                kCompoundPredictionTypeDiffWeighted) {
-      prediction_parameters.mask_is_inverse =
-          static_cast<bool>(reader_.ReadBit());
+      prediction_parameters.mask_is_inverse = reader_.ReadBit() != 0;
     }
     return;
   }
@@ -1209,7 +1314,7 @@
   return symbol_decoder_context_.interpolation_filter_cdf[context];
 }
 
-void Tile::ReadInterpolationFilter(const Block& block) {
+void Tile::ReadInterpolationFilter(const Block& block, bool skip_mode) {
   BlockParameters& bp = *block.bp;
   if (frame_header_.interpolation_filter != kInterpolationFilterSwitchable) {
     static_assert(
@@ -1222,7 +1327,7 @@
     return;
   }
   bool interpolation_filter_present = true;
-  if (bp.skip_mode ||
+  if (skip_mode ||
       block.bp->prediction_parameters->motion_mode == kMotionModeLocalWarp) {
     interpolation_filter_present = false;
   } else if (!IsBlockDimension4(block.size) &&
@@ -1251,31 +1356,58 @@
   }
 }
 
-bool Tile::ReadInterBlockModeInfo(const Block& block) {
+void Tile::SetCdfContextCompoundType(const Block& block,  // Stores both compound-type flags in the left/top contexts.
+                                     bool is_explicit_compound_type,
+                                     bool is_compound_type_average) {  // Later blocks read these when deriving their compound-type CDF context.
+  memset(left_context_.is_explicit_compound_type + block.left_context_index,
+         static_cast<int>(is_explicit_compound_type), block.height4x4);  // Left: one entry per 4x4 row.
+  memset(left_context_.is_compound_type_average + block.left_context_index,
+         static_cast<int>(is_compound_type_average), block.height4x4);
+  memset(block.top_context->is_explicit_compound_type + block.top_context_index,
+         static_cast<int>(is_explicit_compound_type), block.width4x4);  // Top: one entry per 4x4 column.
+  memset(block.top_context->is_compound_type_average + block.top_context_index,
+         static_cast<int>(is_compound_type_average), block.width4x4);
+}
+
+bool Tile::ReadInterBlockModeInfo(const Block& block, bool skip_mode) {  // Parses inter mode info; returns false only if AssignInterMv() fails.
   BlockParameters& bp = *block.bp;
-  bp.palette_mode_info.size[kPlaneTypeY] = 0;
-  bp.palette_mode_info.size[kPlaneTypeUV] = 0;
-  ReadReferenceFrames(block);
+  bp.prediction_parameters->palette_mode_info.size[kPlaneTypeY] = 0;  // Inter blocks have no palette.
+  bp.prediction_parameters->palette_mode_info.size[kPlaneTypeUV] = 0;
+  SetCdfContextPaletteSize(block);  // Publish the zero palette sizes to the left/top contexts.
+  ReadReferenceFrames(block, skip_mode);
   const bool is_compound = bp.reference_frame[1] > kReferenceFrameIntra;
   MvContexts mode_contexts;
   FindMvStack(block, is_compound, &mode_contexts);
-  ReadInterPredictionModeY(block, mode_contexts);
+  ReadInterPredictionModeY(block, mode_contexts, skip_mode);
   ReadRefMvIndex(block);
   if (!AssignInterMv(block, is_compound)) return false;
-  ReadInterIntraMode(block, is_compound);
-  ReadMotionMode(block, is_compound);
-  ReadCompoundType(block, is_compound);
-  ReadInterpolationFilter(block);
+  ReadInterIntraMode(block, is_compound, skip_mode);
+  ReadMotionMode(block, is_compound, skip_mode);
+  bool is_explicit_compound_type;  // Both flags are always assigned by ReadCompoundType().
+  bool is_compound_type_average;
+  ReadCompoundType(block, is_compound, skip_mode, &is_explicit_compound_type,
+                   &is_compound_type_average);
+  SetCdfContextCompoundType(block, is_explicit_compound_type,
+                            is_compound_type_average);  // Needed by neighbours' compound-type contexts.
+  ReadInterpolationFilter(block, skip_mode);
   return true;
 }
 
+void Tile::SetCdfContextSkipMode(const Block& block, bool skip_mode) {  // Records |skip_mode| in the left (per 4x4 row) and top (per 4x4 column) contexts.
+  memset(left_context_.skip_mode + block.left_context_index,
+         static_cast<int>(skip_mode), block.height4x4);
+  memset(block.top_context->skip_mode + block.top_context_index,
+         static_cast<int>(skip_mode), block.width4x4);
+}
+
 bool Tile::DecodeInterModeInfo(const Block& block) {
   BlockParameters& bp = *block.bp;
   block.bp->prediction_parameters->use_intra_block_copy = false;
   bp.skip = false;
   if (!ReadInterSegmentId(block, /*pre_skip=*/true)) return false;
-  ReadSkipMode(block);
-  if (bp.skip_mode) {
+  bool skip_mode = ReadSkipMode(block);
+  SetCdfContextSkipMode(block, skip_mode);
+  if (skip_mode) {
     bp.skip = true;
   } else {
     ReadSkip(block);
@@ -1290,8 +1422,8 @@
     ReadLoopFilterDelta(block);
     read_deltas_ = false;
   }
-  ReadIsInter(block);
-  return bp.is_inter ? ReadInterBlockModeInfo(block)
+  ReadIsInter(block, skip_mode);
+  return bp.is_inter ? ReadInterBlockModeInfo(block, skip_mode)
                      : ReadIntraBlockModeInfo(block, /*intra_y_mode=*/false);
 }
 
diff --git a/libgav1/src/tile/bitstream/palette.cc b/libgav1/src/tile/bitstream/palette.cc
index 41b42d6..27e5110 100644
--- a/libgav1/src/tile/bitstream/palette.cc
+++ b/libgav1/src/tile/bitstream/palette.cc
@@ -35,20 +35,23 @@
                           uint16_t* const cache) {
   const int top_size =
       (block.top_available[kPlaneY] && Mod64(MultiplyBy4(block.row4x4)) != 0)
-          ? block.bp_top->palette_mode_info.size[plane_type]
+          ? block.top_context->palette_size[plane_type][block.top_context_index]
           : 0;
-  const int left_size = block.left_available[kPlaneY]
-                            ? block.bp_left->palette_mode_info.size[plane_type]
-                            : 0;
+  const int left_size =
+      block.left_available[kPlaneY]
+          ? left_context_.palette_size[plane_type][block.left_context_index]
+          : 0;
   if (left_size == 0 && top_size == 0) return 0;
   // Merge the left and top colors in sorted order and store them in |cache|.
-  uint16_t dummy[1];
-  const uint16_t* top = (top_size > 0)
-                            ? block.bp_top->palette_mode_info.color[plane_type]
-                            : dummy;
+  uint16_t empty_palette[1];
+  const uint16_t* top =
+      (top_size > 0) ? block.top_context
+                           ->palette_color[block.top_context_index][plane_type]
+                     : empty_palette;
   const uint16_t* left =
-      (left_size > 0) ? block.bp_left->palette_mode_info.color[plane_type]
-                      : dummy;
+      (left_size > 0)
+          ? left_context_.palette_color[block.left_context_index][plane_type]
+          : empty_palette;
   std::merge(top, top + top_size, left, left + left_size, cache);
   // Deduplicate the entries in |cache| and return the number of unique
   // entries.
@@ -61,8 +64,10 @@
   uint16_t cache[2 * kMaxPaletteSize];
   const int n = GetPaletteCache(block, plane_type, cache);
   BlockParameters& bp = *block.bp;
-  const uint8_t palette_size = bp.palette_mode_info.size[plane_type];
-  uint16_t* const palette_color = bp.palette_mode_info.color[plane];
+  const uint8_t palette_size =
+      bp.prediction_parameters->palette_mode_info.size[plane_type];
+  uint16_t* const palette_color =
+      bp.prediction_parameters->palette_mode_info.color[plane];
   const int8_t bitdepth = sequence_header_.color_config.bitdepth;
   int index = 0;
   for (int i = 0; i < n && index < palette_size; ++i) {
@@ -101,7 +106,8 @@
   std::inplace_merge(palette_color, palette_color + merge_pivot,
                      palette_color + palette_size);
   if (plane_type == kPlaneTypeUV) {
-    uint16_t* const palette_color_v = bp.palette_mode_info.color[kPlaneV];
+    uint16_t* const palette_color_v =
+        bp.prediction_parameters->palette_mode_info.color[kPlaneV];
     if (reader_.ReadBit() != 0) {  // delta_encode_palette_colors_v.
       const int bits = bitdepth - 4 + static_cast<int>(reader_.ReadLiteral(2));
       palette_color_v[0] = reader_.ReadLiteral(bitdepth);
@@ -130,8 +136,8 @@
 
 void Tile::ReadPaletteModeInfo(const Block& block) {
   BlockParameters& bp = *block.bp;
-  bp.palette_mode_info.size[kPlaneTypeY] = 0;
-  bp.palette_mode_info.size[kPlaneTypeUV] = 0;
+  bp.prediction_parameters->palette_mode_info.size[kPlaneTypeY] = 0;
+  bp.prediction_parameters->palette_mode_info.size[kPlaneTypeUV] = 0;
   if (IsBlockSmallerThan8x8(block.size) || block.size > kBlock64x64 ||
       !frame_header_.allow_screen_content_tools) {
     return;
@@ -140,29 +146,32 @@
       k4x4WidthLog2[block.size] + k4x4HeightLog2[block.size] - 2;
   if (bp.y_mode == kPredictionModeDc) {
     const int context =
-        static_cast<int>(block.top_available[kPlaneY] &&
-                         block.bp_top->palette_mode_info.size[kPlaneTypeY] >
-                             0) +
-        static_cast<int>(block.left_available[kPlaneY] &&
-                         block.bp_left->palette_mode_info.size[kPlaneTypeY] >
-                             0);
+        static_cast<int>(
+            block.top_available[kPlaneY] &&
+            block.top_context
+                    ->palette_size[kPlaneTypeY][block.top_context_index] > 0) +
+        static_cast<int>(
+            block.left_available[kPlaneY] &&
+            left_context_.palette_size[kPlaneTypeY][block.left_context_index] >
+                0);
     const bool has_palette_y = reader_.ReadSymbol(
         symbol_decoder_context_.has_palette_y_cdf[block_size_context][context]);
     if (has_palette_y) {
-      bp.palette_mode_info.size[kPlaneTypeY] =
+      bp.prediction_parameters->palette_mode_info.size[kPlaneTypeY] =
           kMinPaletteSize +
           reader_.ReadSymbol<kPaletteSizeSymbolCount>(
               symbol_decoder_context_.palette_y_size_cdf[block_size_context]);
       ReadPaletteColors(block, kPlaneY);
     }
   }
-  if (block.HasChroma() && bp.uv_mode == kPredictionModeDc) {
-    const int context =
-        static_cast<int>(bp.palette_mode_info.size[kPlaneTypeY] > 0);
+  if (block.HasChroma() &&
+      bp.prediction_parameters->uv_mode == kPredictionModeDc) {
+    const int context = static_cast<int>(
+        bp.prediction_parameters->palette_mode_info.size[kPlaneTypeY] > 0);
     const bool has_palette_uv =
         reader_.ReadSymbol(symbol_decoder_context_.has_palette_uv_cdf[context]);
     if (has_palette_uv) {
-      bp.palette_mode_info.size[kPlaneTypeUV] =
+      bp.prediction_parameters->palette_mode_info.size[kPlaneTypeUV] =
           kMinPaletteSize +
           reader_.ReadSymbol<kPaletteSizeSymbolCount>(
               symbol_decoder_context_.palette_uv_size_cdf[block_size_context]);
@@ -244,7 +253,8 @@
 }
 
 bool Tile::ReadPaletteTokens(const Block& block) {
-  const PaletteModeInfo& palette_mode_info = block.bp->palette_mode_info;
+  const PaletteModeInfo& palette_mode_info =
+      block.bp->prediction_parameters->palette_mode_info;
   PredictionParameters& prediction_parameters =
       *block.bp->prediction_parameters;
   for (int plane_type = kPlaneTypeY;
diff --git a/libgav1/src/tile/bitstream/transform_size.cc b/libgav1/src/tile/bitstream/transform_size.cc
index b79851d..7197400 100644
--- a/libgav1/src/tile/bitstream/transform_size.cc
+++ b/libgav1/src/tile/bitstream/transform_size.cc
@@ -95,7 +95,8 @@
 
 TransformSize Tile::ReadFixedTransformSize(const Block& block) {
   BlockParameters& bp = *block.bp;
-  if (frame_header_.segmentation.lossless[bp.segment_id]) {
+  if (frame_header_.segmentation
+          .lossless[bp.prediction_parameters->segment_id]) {
     return kTransformSize4x4;
   }
   const TransformSize max_rect_tx_size = kMaxTransformSizeRectangle[block.size];
@@ -189,8 +190,6 @@
       memset(&inter_transform_sizes_[node.y + i][node.x], node.tx_size,
              tx_width4x4);
     }
-    block_parameters_holder_.Find(node.y, node.x)->transform_size =
-        node.tx_size;
   } while (!stack.Empty());
 }
 
@@ -198,7 +197,8 @@
   BlockParameters& bp = *block.bp;
   if (frame_header_.tx_mode == kTxModeSelect && block.size > kBlock4x4 &&
       bp.is_inter && !bp.skip &&
-      !frame_header_.segmentation.lossless[bp.segment_id]) {
+      !frame_header_.segmentation
+           .lossless[bp.prediction_parameters->segment_id]) {
     const TransformSize max_tx_size = kMaxTransformSizeRectangle[block.size];
     const int tx_width4x4 = kTransformWidth4x4[max_tx_size];
     const int tx_height4x4 = kTransformHeight4x4[max_tx_size];
@@ -210,10 +210,10 @@
       }
     }
   } else {
-    bp.transform_size = ReadFixedTransformSize(block);
+    const TransformSize transform_size = ReadFixedTransformSize(block);
     for (int row = block.row4x4; row < block.row4x4 + block.height4x4; ++row) {
       static_assert(sizeof(TransformSize) == 1, "");
-      memset(&inter_transform_sizes_[row][block.column4x4], bp.transform_size,
+      memset(&inter_transform_sizes_[row][block.column4x4], transform_size,
              block.width4x4);
     }
   }
diff --git a/libgav1/src/tile/prediction.cc b/libgav1/src/tile/prediction.cc
index c5560a6..bba5a69 100644
--- a/libgav1/src/tile/prediction.cc
+++ b/libgav1/src/tile/prediction.cc
@@ -226,8 +226,8 @@
                            bool has_left, bool has_top, bool has_top_right,
                            bool has_bottom_left, PredictionMode mode,
                            TransformSize tx_size) {
-  const int width = 1 << kTransformWidthLog2[tx_size];
-  const int height = 1 << kTransformHeightLog2[tx_size];
+  const int width = kTransformWidth[tx_size];
+  const int height = kTransformHeight[tx_size];
   const int x_shift = subsampling_x_[plane];
   const int y_shift = subsampling_y_[plane];
   const int max_x = (MultiplyBy4(frame_header_.columns4x4) >> x_shift) - 1;
@@ -386,36 +386,21 @@
                                               TransformSize tx_size);
 #endif
 
-constexpr BitMaskSet kPredictionModeSmoothMask(kPredictionModeSmooth,
-                                               kPredictionModeSmoothHorizontal,
-                                               kPredictionModeSmoothVertical);
-
-bool Tile::IsSmoothPrediction(int row, int column, Plane plane) const {
-  const BlockParameters& bp = *block_parameters_holder_.Find(row, column);
-  PredictionMode mode;
-  if (plane == kPlaneY) {
-    mode = bp.y_mode;
-  } else {
-    if (bp.reference_frame[0] > kReferenceFrameIntra) return false;
-    mode = bp.uv_mode;
-  }
-  return kPredictionModeSmoothMask.Contains(mode);
-}
-
 int Tile::GetIntraEdgeFilterType(const Block& block, Plane plane) const {
-  const int subsampling_x = subsampling_x_[plane];
-  const int subsampling_y = subsampling_y_[plane];
-  if (block.top_available[plane]) {
-    const int row = block.row4x4 - 1 - (block.row4x4 & subsampling_y);
-    const int column = block.column4x4 + (~block.column4x4 & subsampling_x);
-    if (IsSmoothPrediction(row, column, plane)) return 1;
+  bool top;  // Top neighbour is available and smooth-predicted.
+  bool left;  // Left neighbour is available and smooth-predicted.
+  if (plane == kPlaneY) {  // Luma: check the neighbouring blocks' y_mode.
+    top = block.top_available[kPlaneY] &&
+          kPredictionModeSmoothMask.Contains(block.bp_top->y_mode);
+    left = block.left_available[kPlaneY] &&
+           kPredictionModeSmoothMask.Contains(block.bp_left->y_mode);
+  } else {  // Chroma: use flags precomputed while parsing this block's mode info.
+    top = block.top_available[plane] &&
+          block.bp->prediction_parameters->chroma_top_uses_smooth_prediction;
+    left = block.left_available[plane] &&
+           block.bp->prediction_parameters->chroma_left_uses_smooth_prediction;
   }
-  if (block.left_available[plane]) {
-    const int row = block.row4x4 + (~block.row4x4 & subsampling_y);
-    const int column = block.column4x4 - 1 - (block.column4x4 & subsampling_x);
-    if (IsSmoothPrediction(row, column, plane)) return 1;
-  }
-  return 0;
+  return static_cast<int>(top || left);  // 1 if either neighbour is smooth, else 0.
 }
 
 template <typename Pixel>
@@ -510,7 +495,8 @@
                              const int y, const TransformSize tx_size) {
   const int tx_width = kTransformWidth[tx_size];
   const int tx_height = kTransformHeight[tx_size];
-  const uint16_t* const palette = block.bp->palette_mode_info.color[plane];
+  const uint16_t* const palette =
+      block.bp->prediction_parameters->palette_mode_info.color[plane];
   const PlaneType plane_type = GetPlaneType(plane);
   const int x4 = MultiplyBy4(x);
   const int y4 = MultiplyBy4(y);
@@ -695,7 +681,7 @@
             ? global_motion_params->type
             : kNumGlobalMotionTransformationTypes;
     const bool is_global_valid =
-        IsGlobalMvBlock(block.bp->is_global_mv_block, global_motion_type) &&
+        IsGlobalMvBlock(*block.bp, global_motion_type) &&
         SetupShear(global_motion_params);
     // Valid global motion type implies reference type can't be intra.
     assert(!is_global_valid || reference_type != kReferenceFrameIntra);
@@ -1028,6 +1014,7 @@
         (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
          kScaleSubPixelBits) +
         kSubPixelTaps;
+    *ref_block_end_x += kConvolveScaleBorderRight - kConvolveBorderRight;
     ref_block_end_y = *ref_block_start_y + block_height - 1;
   }
   // Determines if we need to extend beyond the left/right/top/bottom border.
@@ -1206,11 +1193,12 @@
                     (ref_block_start_x + kConvolveBorderLeftTop) * pixel_size;
     }
   } else {
+    const int border_right =
+        is_scaled ? kConvolveScaleBorderRight : kConvolveBorderRight;
     // The block width can be at most 2 times as much as current
     // block's width because of scaling.
     auto block_extended_width = Align<ptrdiff_t>(
-        (2 * width + kConvolveBorderLeftTop + kConvolveBorderRight) *
-            pixel_size,
+        (2 * width + kConvolveBorderLeftTop + border_right) * pixel_size,
         kMaxAlignment);
     convolve_buffer_stride = block.scratch_buffer->convolve_block_buffer_stride;
 #if LIBGAV1_MAX_BITDEPTH >= 10
diff --git a/libgav1/src/tile/tile.cc b/libgav1/src/tile/tile.cc
index 9699517..5070bb6 100644
--- a/libgav1/src/tile/tile.cc
+++ b/libgav1/src/tile/tile.cc
@@ -463,6 +463,7 @@
               : 1),
       current_frame_(*current_frame),
       cdef_index_(frame_scratch_buffer->cdef_index),
+      cdef_skip_(frame_scratch_buffer->cdef_skip),
       inter_transform_sizes_(frame_scratch_buffer->inter_transform_sizes),
       thread_pool_(thread_pool),
       residual_buffer_pool_(frame_scratch_buffer->residual_buffer_pool.get()),
@@ -541,16 +542,6 @@
     buffer_[plane].Reset(Align(buffer.height(plane), max_tx_length),
                          buffer.stride(plane),
                          post_filter_.GetUnfilteredBuffer(plane));
-    const int plane_height =
-        SubsampledValue(frame_header_.height, subsampling_y_[plane]);
-    deblock_row_limit_[plane] =
-        std::min(frame_header_.rows4x4, DivideBy4(plane_height + 3)
-                                            << subsampling_y_[plane]);
-    const int plane_width =
-        SubsampledValue(frame_header_.width, subsampling_x_[plane]);
-    deblock_column_limit_[plane] =
-        std::min(frame_header_.columns4x4, DivideBy4(plane_width + 3)
-                                               << subsampling_x_[plane]);
   }
 }
 
@@ -598,6 +589,10 @@
                      column4x4_end_, &motion_field_);
   }
   ResetLoopRestorationParams();
+  if (!top_context_.Resize(superblock_columns_)) {
+    LIBGAV1_DLOG(ERROR, "Allocation of top_context_ failed.");
+    return false;
+  }
   return true;
 }
 
@@ -1019,7 +1014,8 @@
                                          int block_y) {
   const BlockParameters& bp = *block.bp;
   const TransformSize tx_size_square_max = kTransformSizeSquareMax[tx_size];
-  if (frame_header_.segmentation.lossless[bp.segment_id] ||
+  if (frame_header_.segmentation
+          .lossless[bp.prediction_parameters->segment_id] ||
       tx_size_square_max == kTransformSize64x64) {
     return kTransformTypeDctDct;
   }
@@ -1034,7 +1030,7 @@
     const int y4 = std::max(block.row4x4, block_y << subsampling_y_[kPlaneU]);
     tx_type = transform_types_[y4 - block.row4x4][x4 - block.column4x4];
   } else {
-    tx_type = kModeToTransformType[bp.uv_mode];
+    tx_type = kModeToTransformType[bp.prediction_parameters->uv_mode];
   }
   return kTransformTypeInSetMask[tx_set].Contains(tx_type)
              ? tx_type
@@ -1048,7 +1044,8 @@
 
   TransformType tx_type = kTransformTypeDctDct;
   if (tx_set != kTransformSetDctOnly &&
-      frame_header_.segmentation.qindex[bp.segment_id] > 0) {
+      frame_header_.segmentation.qindex[bp.prediction_parameters->segment_id] >
+          0) {
     const int cdf_index = SymbolDecoderContext::TxTypeIndex(tx_set);
     const int cdf_tx_size_index =
         TransformSizeToSquareTransformIndex(kTransformSizeSquareMin[tx_size]);
@@ -1309,7 +1306,7 @@
     int length = 0;
     bool golomb_length_bit = false;
     do {
-      golomb_length_bit = static_cast<bool>(reader_.ReadBit());
+      golomb_length_bit = reader_.ReadBit() != 0;
       ++length;
       if (length > 20) {
         LIBGAV1_DLOG(ERROR, "Invalid golomb_length %d", length);
@@ -1454,7 +1451,7 @@
     for (int i = 1; i < eob_pt - 2; ++i) {
       assert(eob_pt - i >= 3);
       assert(eob_pt <= kEobPt1024SymbolCount);
-      if (static_cast<bool>(reader_.ReadBit())) {
+      if (reader_.ReadBit() != 0) {
         eob += 1 << (eob_pt - i - 3);
       }
     }
@@ -1500,15 +1497,17 @@
         coeff_base_range_cdf, residual, level_buffer);
   }
   const int max_value = (1 << (7 + sequence_header_.color_config.bitdepth)) - 1;
-  const int current_quantizer_index = GetQIndex(
-      frame_header_.segmentation, bp.segment_id, current_quantizer_index_);
+  const int current_quantizer_index =
+      GetQIndex(frame_header_.segmentation,
+                bp.prediction_parameters->segment_id, current_quantizer_index_);
   const int dc_q_value = quantizer_.GetDcValue(plane, current_quantizer_index);
   const int ac_q_value = quantizer_.GetAcValue(plane, current_quantizer_index);
   const int shift = kQuantizationShift[tx_size];
   const uint8_t* const quantizer_matrix =
       (frame_header_.quantizer.use_matrix &&
        *tx_type < kTransformTypeIdentityIdentity &&
-       !frame_header_.segmentation.lossless[bp.segment_id] &&
+       !frame_header_.segmentation
+            .lossless[bp.prediction_parameters->segment_id] &&
        frame_header_.quantizer.matrix_level[plane] < 15)
           ? quantizer_matrix_[frame_header_.quantizer.matrix_level[plane]]
                              [plane_type][adjusted_tx_size]
@@ -1587,15 +1586,17 @@
   const bool do_decode = mode == kProcessingModeDecodeOnly ||
                          mode == kProcessingModeParseAndDecode;
   if (do_decode && !bp.is_inter) {
-    if (bp.palette_mode_info.size[GetPlaneType(plane)] > 0) {
+    if (bp.prediction_parameters->palette_mode_info.size[GetPlaneType(plane)] >
+        0) {
       CALL_BITDEPTH_FUNCTION(PalettePrediction, block, plane, start_x, start_y,
                              x, y, tx_size);
     } else {
       const PredictionMode mode =
-          (plane == kPlaneY)
-              ? bp.y_mode
-              : (bp.uv_mode == kPredictionModeChromaFromLuma ? kPredictionModeDc
-                                                             : bp.uv_mode);
+          (plane == kPlaneY) ? bp.y_mode
+                             : (bp.prediction_parameters->uv_mode ==
+                                        kPredictionModeChromaFromLuma
+                                    ? kPredictionModeDc
+                                    : bp.prediction_parameters->uv_mode);
       const int tr_row4x4 = (sub_block_row4x4 >> subsampling_y);
       const int tr_column4x4 =
           (sub_block_column4x4 >> subsampling_x) + step_x + 1;
@@ -1609,7 +1610,8 @@
           block.scratch_buffer->block_decoded[plane][tr_row4x4][tr_column4x4],
           block.scratch_buffer->block_decoded[plane][bl_row4x4][bl_column4x4],
           mode, tx_size);
-      if (plane != kPlaneY && bp.uv_mode == kPredictionModeChromaFromLuma) {
+      if (plane != kPlaneY &&
+          bp.prediction_parameters->uv_mode == kPredictionModeChromaFromLuma) {
         CALL_BITDEPTH_FUNCTION(ChromaFromLumaPrediction, block, plane, start_x,
                                start_y, tx_size);
       }
@@ -1738,14 +1740,16 @@
         buffer_[plane].rows(), buffer_[plane].columns() / sizeof(uint16_t),
         reinterpret_cast<uint16_t*>(&buffer_[plane][0][0]));
     Reconstruct(dsp_, tx_type, tx_size,
-                frame_header_.segmentation.lossless[block.bp->segment_id],
+                frame_header_.segmentation
+                    .lossless[block.bp->prediction_parameters->segment_id],
                 reinterpret_cast<int32_t*>(*block.residual), start_x, start_y,
                 &buffer, non_zero_coeff_count);
   } else  // NOLINT
 #endif
   {
     Reconstruct(dsp_, tx_type, tx_size,
-                frame_header_.segmentation.lossless[block.bp->segment_id],
+                frame_header_.segmentation
+                    .lossless[block.bp->prediction_parameters->segment_id],
                 reinterpret_cast<int16_t*>(*block.residual), start_x, start_y,
                 &buffer_[plane], non_zero_coeff_count);
   }
@@ -1772,12 +1776,15 @@
         // kTransformSize4x4. So we can simply use |bp.transform_size| here as
         // the Y plane's transform size (part of Section 5.11.37 in the spec).
         const TransformSize tx_size =
-            (plane == kPlaneY) ? bp.transform_size : bp.uv_transform_size;
+            (plane == kPlaneY)
+                ? inter_transform_sizes_[block.row4x4][block.column4x4]
+                : bp.uv_transform_size;
         const BlockSize plane_size =
             kPlaneResidualSize[size_chunk4x4][subsampling_x][subsampling_y];
         assert(plane_size != kBlockInvalid);
         if (bp.is_inter &&
-            !frame_header_.segmentation.lossless[bp.segment_id] &&
+            !frame_header_.segmentation
+                 .lossless[bp.prediction_parameters->segment_id] &&
             plane == kPlaneY) {
           const int row_chunk4x4 = block.row4x4 + MultiplyBy16(chunk_y);
           const int column_chunk4x4 = block.column4x4 + MultiplyBy16(chunk_x);
@@ -2112,15 +2119,53 @@
   for (int i = 0; i < kFrameLfCount; ++i) {
     if (delta_lf_all_zero_) {
       bp.deblock_filter_level[i] = post_filter_.GetZeroDeltaDeblockFilterLevel(
-          bp.segment_id, i, bp.reference_frame[0], mode_id);
+          bp.prediction_parameters->segment_id, i, bp.reference_frame[0],
+          mode_id);
     } else {
       bp.deblock_filter_level[i] =
-          deblock_filter_levels_[bp.segment_id][i][bp.reference_frame[0]]
-                                [mode_id];
+          deblock_filter_levels_[bp.prediction_parameters->segment_id][i]
+                                [bp.reference_frame[0]][mode_id];
     }
   }
 }
 
+void Tile::PopulateCdefSkip(const Block& block) {
+  if (!post_filter_.DoCdef() || block.bp->skip ||
+      (frame_header_.cdef.bits > 0 &&
+       cdef_index_[DivideBy16(block.row4x4)][DivideBy16(block.column4x4)] ==
+           -1)) {
+    return;
+  }
+  // The rest of this function is an efficient version of the following code:
+  // for (int y = block.row4x4; y < block.row4x4 + block.height4x4; y++) {
+  //   for (int x = block.column4x4; x < block.column4x4 + block.width4x4;
+  //        x++) {
+  //     const uint8_t mask = uint8_t{1} << ((x >> 1) & 0x7);
+  //     cdef_skip_[y >> 1][x >> 4] |= mask;
+  //   }
+  // }
+
+  // For all values of block.width4x4 other than 32, the mask fits in a
+  // uint8_t. When block.width4x4 == 32, the mask spans two bytes (0xFFFF).
+  const int bw4 =
+      std::max(DivideBy2(block.width4x4) + (block.column4x4 & 1), 1);
+  const uint8_t mask = (block.width4x4 == 32)
+                           ? 0xFF
+                           : (uint8_t{0xFF} >> (8 - bw4))
+                                 << (DivideBy2(block.column4x4) & 0x7);
+  uint8_t* cdef_skip = &cdef_skip_[block.row4x4 >> 1][block.column4x4 >> 4];
+  const int stride = cdef_skip_.columns();
+  int row = 0;
+  do {
+    *cdef_skip |= mask;
+    if (block.width4x4 == 32) {
+      *(cdef_skip + 1) = 0xFF;
+    }
+    cdef_skip += stride;
+    row += 2;
+  } while (row < block.height4x4);
+}
+
 bool Tile::ProcessBlock(int row4x4, int column4x4, BlockSize block_size,
                         TileScratchBuffer* const scratch_buffer,
                         ResidualPtr* residual) {
@@ -2150,7 +2195,7 @@
     return false;
   }
   BlockParameters& bp = *bp_ptr;
-  Block block(*this, block_size, row4x4, column4x4, scratch_buffer, residual);
+  Block block(this, block_size, row4x4, column4x4, scratch_buffer, residual);
   bp.size = block_size;
   bp.prediction_parameters =
       split_parse_and_decode_ ? std::unique_ptr<PredictionParameters>(
@@ -2158,17 +2203,16 @@
                               : std::move(prediction_parameters_);
   if (bp.prediction_parameters == nullptr) return false;
   if (!DecodeModeInfo(block)) return false;
-  bp.is_global_mv_block = (bp.y_mode == kPredictionModeGlobalMv ||
-                           bp.y_mode == kPredictionModeGlobalGlobalMv) &&
-                          !IsBlockDimension4(bp.size);
   PopulateDeblockFilterLevel(block);
   if (!ReadPaletteTokens(block)) return false;
   DecodeTransformSize(block);
   // Part of Section 5.11.37 in the spec (implemented as a simple lookup).
-  bp.uv_transform_size = frame_header_.segmentation.lossless[bp.segment_id]
-                             ? kTransformSize4x4
-                             : kUVTransformSize[block.residual_size[kPlaneU]];
+  bp.uv_transform_size =
+      frame_header_.segmentation.lossless[bp.prediction_parameters->segment_id]
+          ? kTransformSize4x4
+          : kUVTransformSize[block.residual_size[kPlaneU]];
   if (bp.skip) ResetEntropyContext(block);
+  PopulateCdefSkip(block);
   if (split_parse_and_decode_) {
     if (!Residual(block, kProcessingModeParseOnly)) return false;
   } else {
@@ -2177,22 +2221,24 @@
       return false;
     }
   }
-  // If frame_header_.segmentation.enabled is false, bp.segment_id is 0 for all
-  // blocks. We don't need to call save bp.segment_id in the current frame
-  // because the current frame's segmentation map will be cleared to all 0s.
+  // If frame_header_.segmentation.enabled is false,
+  // bp.prediction_parameters->segment_id is 0 for all blocks. We don't need to
+  // save bp.prediction_parameters->segment_id in the current frame because
+  // the current frame's segmentation map will be cleared to all 0s.
   //
   // If frame_header_.segmentation.enabled is true and
   // frame_header_.segmentation.update_map is false, we will copy the previous
   // frame's segmentation map to the current frame. So we don't need to call
-  // save bp.segment_id in the current frame.
+  // save bp.prediction_parameters->segment_id in the current frame.
   if (frame_header_.segmentation.enabled &&
       frame_header_.segmentation.update_map) {
     const int x_limit = std::min(frame_header_.columns4x4 - column4x4,
                                  static_cast<int>(block.width4x4));
     const int y_limit = std::min(frame_header_.rows4x4 - row4x4,
                                  static_cast<int>(block.height4x4));
-    current_frame_.segmentation_map()->FillBlock(row4x4, column4x4, x_limit,
-                                                 y_limit, bp.segment_id);
+    current_frame_.segmentation_map()->FillBlock(
+        row4x4, column4x4, x_limit, y_limit,
+        bp.prediction_parameters->segment_id);
   }
   StoreMotionFieldMvsIntoCurrentFrame(block);
   if (!split_parse_and_decode_) {
@@ -2208,7 +2254,7 @@
       column4x4 >= frame_header_.columns4x4) {
     return true;
   }
-  Block block(*this, block_size, row4x4, column4x4, scratch_buffer, residual);
+  Block block(this, block_size, row4x4, column4x4, scratch_buffer, residual);
   if (!ComputePrediction(block) ||
       !Residual(block, kProcessingModeDecodeOnly)) {
     return false;
@@ -2382,7 +2428,7 @@
 }
 
 void Tile::ResetCdef(const int row4x4, const int column4x4) {
-  if (!sequence_header_.enable_cdef) return;
+  if (frame_header_.cdef.bits == 0) return;
   const int row = DivideBy16(row4x4);
   const int column = DivideBy16(column4x4);
   cdef_index_[row][column] = -1;
@@ -2562,8 +2608,8 @@
     // Must make a local copy so that StoreMotionFieldMvs() knows there is no
     // overlap between load and store.
     const MotionVector mv_to_store = bp.mv.mv[i];
-    const int mv_row = std::abs(mv_to_store.mv[MotionVector::kRow]);
-    const int mv_column = std::abs(mv_to_store.mv[MotionVector::kColumn]);
+    const int mv_row = std::abs(mv_to_store.mv[0]);
+    const int mv_column = std::abs(mv_to_store.mv[1]);
     if (reference_frame_to_store > kReferenceFrameIntra &&
         // kRefMvsLimit equals 0x07FF, so we can first bitwise OR the two
         // absolute values and then compare with kRefMvsLimit to save a branch.
diff --git a/libgav1/src/tile_scratch_buffer.h b/libgav1/src/tile_scratch_buffer.h
index 3eaf8b8..828f550 100644
--- a/libgav1/src/tile_scratch_buffer.h
+++ b/libgav1/src/tile_scratch_buffer.h
@@ -17,8 +17,13 @@
 #ifndef LIBGAV1_SRC_TILE_SCRATCH_BUFFER_H_
 #define LIBGAV1_SRC_TILE_SCRATCH_BUFFER_H_
 
+#include <cstddef>
 #include <cstdint>
+#include <cstring>
+#include <memory>
 #include <mutex>  // NOLINT (unapproved c++11 header)
+#include <new>
+#include <utility>
 
 #include "src/dsp/constants.h"
 #include "src/utils/common.h"
@@ -42,9 +47,10 @@
     const int pixel_size = 1;
 #endif
 
+    static_assert(kConvolveScaleBorderRight >= kConvolveBorderRight, "");
     constexpr int unaligned_convolve_buffer_stride =
         kMaxScaledSuperBlockSizeInPixels + kConvolveBorderLeftTop +
-        kConvolveBorderRight;
+        kConvolveScaleBorderRight;
     convolve_block_buffer_stride = Align<ptrdiff_t>(
         unaligned_convolve_buffer_stride * pixel_size, kMaxAlignment);
     constexpr int convolve_buffer_height = kMaxScaledSuperBlockSizeInPixels +
@@ -53,6 +59,13 @@
 
     convolve_block_buffer = MakeAlignedUniquePtr<uint8_t>(
         kMaxAlignment, convolve_buffer_height * convolve_block_buffer_stride);
+#if LIBGAV1_MSAN
+    // Quiet msan warnings in ConvolveScale2D_NEON(). Fill with an arbitrary
+    // non-zero value to aid in future debugging.
+    memset(convolve_block_buffer.get(), 0x66,
+           convolve_buffer_height * convolve_block_buffer_stride);
+#endif
+
     return convolve_block_buffer != nullptr;
   }
 
diff --git a/libgav1/src/utils/common.h b/libgav1/src/utils/common.h
index 2e599f0..f75ace8 100644
--- a/libgav1/src/utils/common.h
+++ b/libgav1/src/utils/common.h
@@ -21,15 +21,17 @@
 #include <intrin.h>
 #pragma intrinsic(_BitScanForward)
 #pragma intrinsic(_BitScanReverse)
-#if defined(_M_X64) || defined(_M_ARM) || defined(_M_ARM64)
+#if defined(_M_X64) || defined(_M_ARM64)
 #pragma intrinsic(_BitScanReverse64)
 #define HAVE_BITSCANREVERSE64
-#endif  // defined(_M_X64) || defined(_M_ARM) || defined(_M_ARM64)
+#endif  // defined(_M_X64) || defined(_M_ARM64)
 #endif  // defined(_MSC_VER)
 
+#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
+#include <cstdlib>
 #include <cstring>
 #include <type_traits>
 
@@ -40,6 +42,26 @@
 
 namespace libgav1 {
 
+// LIBGAV1_RESTRICT
+// Declares a pointer with the restrict type qualifier if available.
+// This allows code to hint to the compiler that only this pointer references a
+// particular object or memory region within the scope of the block in which it
+// is declared. This may allow for improved optimizations due to the lack of
+// pointer aliasing. See also:
+// https://en.cppreference.com/w/c/language/restrict
+// Note a template alias is not used for compatibility with older compilers
+// (e.g., gcc < 10) that do not expand the type when instantiating a template
+// function, either explicitly or in an assignment to a function pointer as is
+// done within the dsp code. RestrictPtr<T>::type is an alternative to this,
+// similar to std::add_const, but for conciseness the macro is preferred.
+#ifdef __GNUC__
+#define LIBGAV1_RESTRICT __restrict__
+#elif defined(_MSC_VER)
+#define LIBGAV1_RESTRICT __restrict
+#else
+#define LIBGAV1_RESTRICT
+#endif
+
 // Aligns |value| to the desired |alignment|. |alignment| must be a power of 2.
 template <typename T>
 inline T Align(T value, T alignment) {
diff --git a/libgav1/src/utils/compiler_attributes.h b/libgav1/src/utils/compiler_attributes.h
index e122426..09f0035 100644
--- a/libgav1/src/utils/compiler_attributes.h
+++ b/libgav1/src/utils/compiler_attributes.h
@@ -165,7 +165,7 @@
 //     int p1_ LIBGAV1_GUARDED_BY(mu_);
 //     ...
 //   };
-// TODO(b/132506370): this can be reenabled after a local MutexLock
+// TODO(b/133245043): this can be reenabled after a local MutexLock
 // implementation is added with proper thread annotations.
 #if 0  // LIBGAV1_HAS_ATTRIBUTE(guarded_by)
 #define LIBGAV1_GUARDED_BY(x) __attribute__((guarded_by(x)))
diff --git a/libgav1/src/utils/constants.h b/libgav1/src/utils/constants.h
index a2076c5..1126ad6 100644
--- a/libgav1/src/utils/constants.h
+++ b/libgav1/src/utils/constants.h
@@ -71,6 +71,7 @@
   // but was increased to simplify the SIMD loads in
   // ConvolveCompoundScale2D_NEON() and ConvolveScale2D_NEON().
   kConvolveBorderRight = 8,
+  kConvolveScaleBorderRight = 15,
   kConvolveBorderBottom = 4,
   kSubPixelTaps = 8,
   kWienerFilterBits = 7,
@@ -523,6 +524,10 @@
   kObuPadding = 15,
 };
 
+constexpr BitMaskSet kPredictionModeSmoothMask(kPredictionModeSmooth,
+                                               kPredictionModeSmoothHorizontal,
+                                               kPredictionModeSmoothVertical);
+
 //------------------------------------------------------------------------------
 // ToString()
 //
diff --git a/libgav1/src/utils/dynamic_buffer.h b/libgav1/src/utils/dynamic_buffer.h
index 40ece26..0694980 100644
--- a/libgav1/src/utils/dynamic_buffer.h
+++ b/libgav1/src/utils/dynamic_buffer.h
@@ -17,6 +17,7 @@
 #ifndef LIBGAV1_SRC_UTILS_DYNAMIC_BUFFER_H_
 #define LIBGAV1_SRC_UTILS_DYNAMIC_BUFFER_H_
 
+#include <cstddef>
 #include <memory>
 #include <new>
 
diff --git a/libgav1/src/utils/entropy_decoder.cc b/libgav1/src/utils/entropy_decoder.cc
index bf21199..3d97e69 100644
--- a/libgav1/src/utils/entropy_decoder.cc
+++ b/libgav1/src/utils/entropy_decoder.cc
@@ -60,7 +60,8 @@
          (kMinimumProbabilityPerSymbol * (symbol_count - index));
 }
 
-void UpdateCdf(uint16_t* const cdf, const int symbol_count, const int symbol) {
+void UpdateCdf(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol_count,
+               const int symbol) {
   const uint16_t count = cdf[symbol_count];
   // rate is computed in the spec as:
   //  3 + ( cdf[N] > 15 ) + ( cdf[N] > 31 ) + Min(FloorLog2(N), 2)
@@ -168,7 +169,7 @@
 //    the cdf array. Since an invalid CDF value is written into cdf[7], the
 //    count in cdf[7] needs to be fixed up after the vectorized code.
 
-void UpdateCdf5(uint16_t* const cdf, const int symbol) {
+void UpdateCdf5(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   uint16x4_t cdf_vec = vld1_u16(cdf);
   const uint16_t count = cdf[5];
   const int rate = (count >> 4) + 5;
@@ -195,7 +196,7 @@
 // This version works for |symbol_count| = 7, 8, or 9.
 // See UpdateCdf5 for implementation details.
 template <int symbol_count>
-void UpdateCdf7To9(uint16_t* const cdf, const int symbol) {
+void UpdateCdf7To9(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   static_assert(symbol_count >= 7 && symbol_count <= 9, "");
   uint16x8_t cdf_vec = vld1q_u16(cdf);
   const uint16_t count = cdf[symbol_count];
@@ -229,7 +230,7 @@
 }
 
 // See UpdateCdf5 for implementation details.
-void UpdateCdf11(uint16_t* const cdf, const int symbol) {
+void UpdateCdf11(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   uint16x8_t cdf_vec = vld1q_u16(cdf + 2);
   const uint16_t count = cdf[11];
   cdf[11] = count + static_cast<uint16_t>(count < 32);
@@ -266,7 +267,7 @@
 }
 
 // See UpdateCdf5 for implementation details.
-void UpdateCdf13(uint16_t* const cdf, const int symbol) {
+void UpdateCdf13(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   uint16x8_t cdf_vec0 = vld1q_u16(cdf);
   uint16x8_t cdf_vec1 = vld1q_u16(cdf + 4);
   const uint16_t count = cdf[13];
@@ -299,7 +300,7 @@
 }
 
 // See UpdateCdf5 for implementation details.
-void UpdateCdf16(uint16_t* const cdf, const int symbol) {
+void UpdateCdf16(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   uint16x8_t cdf_vec = vld1q_u16(cdf);
   const uint16_t count = cdf[16];
   const int rate = (count >> 4) + 5;
@@ -351,7 +352,7 @@
   _mm_storeu_si128(static_cast<__m128i*>(a), v);
 }
 
-void UpdateCdf5(uint16_t* const cdf, const int symbol) {
+void UpdateCdf5(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   __m128i cdf_vec = LoadLo8(cdf);
   const uint16_t count = cdf[5];
   const int rate = (count >> 4) + 5;
@@ -379,7 +380,7 @@
 // This version works for |symbol_count| = 7, 8, or 9.
 // See UpdateCdf5 for implementation details.
 template <int symbol_count>
-void UpdateCdf7To9(uint16_t* const cdf, const int symbol) {
+void UpdateCdf7To9(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   static_assert(symbol_count >= 7 && symbol_count <= 9, "");
   __m128i cdf_vec = LoadUnaligned16(cdf);
   const uint16_t count = cdf[symbol_count];
@@ -412,7 +413,7 @@
 }
 
 // See UpdateCdf5 for implementation details.
-void UpdateCdf11(uint16_t* const cdf, const int symbol) {
+void UpdateCdf11(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   __m128i cdf_vec = LoadUnaligned16(cdf + 2);
   const uint16_t count = cdf[11];
   cdf[11] = count + static_cast<uint16_t>(count < 32);
@@ -447,7 +448,7 @@
 }
 
 // See UpdateCdf5 for implementation details.
-void UpdateCdf13(uint16_t* const cdf, const int symbol) {
+void UpdateCdf13(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   __m128i cdf_vec0 = LoadLo8(cdf);
   __m128i cdf_vec1 = LoadUnaligned16(cdf + 4);
   const uint16_t count = cdf[13];
@@ -478,7 +479,7 @@
   cdf[13] = count + static_cast<uint16_t>(count < 32);
 }
 
-void UpdateCdf16(uint16_t* const cdf, const int symbol) {
+void UpdateCdf16(uint16_t* LIBGAV1_RESTRICT const cdf, const int symbol) {
   __m128i cdf_vec0 = LoadUnaligned16(cdf);
   const uint16_t count = cdf[16];
   const int rate = (count >> 4) + 5;
@@ -543,8 +544,8 @@
 #endif  // LIBGAV1_ENTROPY_DECODER_ENABLE_SSE2
 #endif  // LIBGAV1_ENTROPY_DECODER_ENABLE_NEON
 
-inline DaalaBitReader::WindowSize HostToBigEndian(
-    const DaalaBitReader::WindowSize x) {
+inline EntropyDecoder::WindowSize HostToBigEndian(
+    const EntropyDecoder::WindowSize x) {
   static_assert(sizeof(x) == 4 || sizeof(x) == 8, "");
 #if defined(__GNUC__)
 #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
@@ -554,7 +555,7 @@
 #endif
 #elif defined(_WIN32)
   // Note Windows targets are assumed to be little endian.
-  return static_cast<DaalaBitReader::WindowSize>(
+  return static_cast<EntropyDecoder::WindowSize>(
       (sizeof(x) == 8) ? _byteswap_uint64(static_cast<unsigned __int64>(x))
                        : _byteswap_ulong(static_cast<unsigned long>(x)));
 #else
@@ -565,10 +566,10 @@
 }  // namespace
 
 #if !LIBGAV1_CXX17
-constexpr int DaalaBitReader::kWindowSize;  // static.
+constexpr int EntropyDecoder::kWindowSize;  // static.
 #endif
 
-DaalaBitReader::DaalaBitReader(const uint8_t* data, size_t size,
+EntropyDecoder::EntropyDecoder(const uint8_t* data, size_t size,
                                bool allow_update_cdf)
     : data_(data),
       data_end_(data + size),
@@ -607,7 +608,7 @@
 //   * The probability is fixed at half. So some multiplications can be replaced
 //     with bit operations.
 //   * Symbol count is fixed at 2.
-int DaalaBitReader::ReadBit() {
+int EntropyDecoder::ReadBit() {
   const uint32_t curr =
       ((values_in_range_ & kReadBitMask) >> 1) + kMinimumProbabilityPerSymbol;
   const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
@@ -623,7 +624,7 @@
   return bit;
 }
 
-int64_t DaalaBitReader::ReadLiteral(int num_bits) {
+int64_t EntropyDecoder::ReadLiteral(int num_bits) {
   assert(num_bits <= 32);
   assert(num_bits > 0);
   uint32_t literal = 0;
@@ -643,7 +644,8 @@
   return literal;
 }
 
-int DaalaBitReader::ReadSymbol(uint16_t* const cdf, int symbol_count) {
+int EntropyDecoder::ReadSymbol(uint16_t* LIBGAV1_RESTRICT const cdf,
+                               int symbol_count) {
   const int symbol = ReadSymbolImpl(cdf, symbol_count);
   if (allow_update_cdf_) {
     UpdateCdf(cdf, symbol_count, symbol);
@@ -651,7 +653,7 @@
   return symbol;
 }
 
-bool DaalaBitReader::ReadSymbol(uint16_t* cdf) {
+bool EntropyDecoder::ReadSymbol(uint16_t* LIBGAV1_RESTRICT cdf) {
   assert(cdf[1] == 0);
   const bool symbol = ReadSymbolImpl(cdf[0]) != 0;
   if (allow_update_cdf_) {
@@ -681,12 +683,12 @@
   return symbol;
 }
 
-bool DaalaBitReader::ReadSymbolWithoutCdfUpdate(uint16_t cdf) {
+bool EntropyDecoder::ReadSymbolWithoutCdfUpdate(uint16_t cdf) {
   return ReadSymbolImpl(cdf) != 0;
 }
 
 template <int symbol_count>
-int DaalaBitReader::ReadSymbol(uint16_t* const cdf) {
+int EntropyDecoder::ReadSymbol(uint16_t* LIBGAV1_RESTRICT const cdf) {
   static_assert(symbol_count >= 3 && symbol_count <= 16, "");
   if (symbol_count == 3 || symbol_count == 4) {
     return ReadSymbol3Or4(cdf, symbol_count);
@@ -721,7 +723,7 @@
   return symbol;
 }
 
-int DaalaBitReader::ReadSymbolImpl(const uint16_t* const cdf,
+int EntropyDecoder::ReadSymbolImpl(const uint16_t* LIBGAV1_RESTRICT const cdf,
                                    int symbol_count) {
   assert(cdf[symbol_count - 1] == 0);
   --symbol_count;
@@ -744,8 +746,8 @@
   return symbol;
 }
 
-int DaalaBitReader::ReadSymbolImplBinarySearch(const uint16_t* const cdf,
-                                               int symbol_count) {
+int EntropyDecoder::ReadSymbolImplBinarySearch(
+    const uint16_t* LIBGAV1_RESTRICT const cdf, int symbol_count) {
   assert(cdf[symbol_count - 1] == 0);
   assert(symbol_count > 1 && symbol_count <= 16);
   --symbol_count;
@@ -787,7 +789,7 @@
   return low;
 }
 
-int DaalaBitReader::ReadSymbolImpl(uint16_t cdf) {
+int EntropyDecoder::ReadSymbolImpl(uint16_t cdf) {
   const auto symbol_value = static_cast<uint16_t>(window_diff_ >> bits_);
   const uint32_t curr =
       (((values_in_range_ >> 8) * (cdf >> kCdfPrecision)) >> 1) +
@@ -805,7 +807,7 @@
 
 // Equivalent to ReadSymbol(cdf, [3,4]), with the ReadSymbolImpl and UpdateCdf
 // calls inlined.
-int DaalaBitReader::ReadSymbol3Or4(uint16_t* const cdf,
+int EntropyDecoder::ReadSymbol3Or4(uint16_t* LIBGAV1_RESTRICT const cdf,
                                    const int symbol_count) {
   assert(cdf[symbol_count - 1] == 0);
   uint32_t curr = values_in_range_;
@@ -970,7 +972,8 @@
   return symbol;
 }
 
-int DaalaBitReader::ReadSymbolImpl8(const uint16_t* const cdf) {
+int EntropyDecoder::ReadSymbolImpl8(
+    const uint16_t* LIBGAV1_RESTRICT const cdf) {
   assert(cdf[7] == 0);
   uint32_t curr = values_in_range_;
   uint32_t prev;
@@ -1033,7 +1036,7 @@
   return symbol;
 }
 
-void DaalaBitReader::PopulateBits() {
+void EntropyDecoder::PopulateBits() {
   constexpr int kMaxCachedBits = kWindowSize - 16;
 #if defined(__aarch64__)
   // Fast path: read eight bytes and add the first six bytes to window_diff_.
@@ -1092,7 +1095,7 @@
   window_diff_ = window_diff;
 }
 
-void DaalaBitReader::NormalizeRange() {
+void EntropyDecoder::NormalizeRange() {
   const int bits_used = 15 ^ FloorLog2(values_in_range_);
   bits_ -= bits_used;
   values_in_range_ <<= bits_used;
@@ -1100,18 +1103,18 @@
 }
 
 // Explicit instantiations.
-template int DaalaBitReader::ReadSymbol<3>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<4>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<5>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<6>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<7>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<8>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<9>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<10>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<11>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<12>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<13>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<14>(uint16_t* cdf);
-template int DaalaBitReader::ReadSymbol<16>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<3>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<4>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<5>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<6>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<7>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<8>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<9>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<10>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<11>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<12>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<13>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<14>(uint16_t* cdf);
+template int EntropyDecoder::ReadSymbol<16>(uint16_t* cdf);
 
 }  // namespace libgav1
diff --git a/libgav1/src/utils/entropy_decoder.h b/libgav1/src/utils/entropy_decoder.h
index c066b98..8eeaef4 100644
--- a/libgav1/src/utils/entropy_decoder.h
+++ b/libgav1/src/utils/entropy_decoder.h
@@ -25,20 +25,20 @@
 
 namespace libgav1 {
 
-class DaalaBitReader : public BitReader {
+class EntropyDecoder final : public BitReader {
  public:
   // WindowSize must be an unsigned integer type with at least 32 bits. Use the
   // largest type with fast arithmetic. size_t should meet these requirements.
   using WindowSize = size_t;
 
-  DaalaBitReader(const uint8_t* data, size_t size, bool allow_update_cdf);
-  ~DaalaBitReader() override = default;
+  EntropyDecoder(const uint8_t* data, size_t size, bool allow_update_cdf);
+  ~EntropyDecoder() override = default;
 
   // Move only.
-  DaalaBitReader(DaalaBitReader&& rhs) noexcept;
-  DaalaBitReader& operator=(DaalaBitReader&& rhs) noexcept;
+  EntropyDecoder(EntropyDecoder&& rhs) noexcept;
+  EntropyDecoder& operator=(EntropyDecoder&& rhs) noexcept;
 
-  int ReadBit() final;
+  int ReadBit() override;
   int64_t ReadLiteral(int num_bits) override;
   // ReadSymbol() calls for which the |symbol_count| is only known at runtime
   // will use this variant.
@@ -104,19 +104,19 @@
   WindowSize window_diff_;
 };
 
-extern template int DaalaBitReader::ReadSymbol<3>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<4>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<5>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<6>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<7>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<8>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<9>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<10>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<11>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<12>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<13>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<14>(uint16_t* cdf);
-extern template int DaalaBitReader::ReadSymbol<16>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<3>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<4>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<5>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<6>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<7>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<8>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<9>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<10>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<11>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<12>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<13>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<14>(uint16_t* cdf);
+extern template int EntropyDecoder::ReadSymbol<16>(uint16_t* cdf);
 
 }  // namespace libgav1
 
diff --git a/libgav1/src/utils/memory.h b/libgav1/src/utils/memory.h
index a8da53b..d1762a2 100644
--- a/libgav1/src/utils/memory.h
+++ b/libgav1/src/utils/memory.h
@@ -17,7 +17,7 @@
 #ifndef LIBGAV1_SRC_UTILS_MEMORY_H_
 #define LIBGAV1_SRC_UTILS_MEMORY_H_
 
-#if defined(__ANDROID__) || defined(_MSC_VER)
+#if defined(__ANDROID__) || defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h>
 #endif
 
@@ -55,7 +55,7 @@
 // void AlignedFree(void* aligned_memory);
 //   Free aligned memory.
 
-#if defined(_MSC_VER)  // MSVC
+#if defined(_MSC_VER) || defined(__MINGW32__)
 
 inline void* AlignedAlloc(size_t alignment, size_t size) {
   return _aligned_malloc(size, alignment);
@@ -63,7 +63,7 @@
 
 inline void AlignedFree(void* aligned_memory) { _aligned_free(aligned_memory); }
 
-#else  // !defined(_MSC_VER)
+#else  // !(defined(_MSC_VER) || defined(__MINGW32__))
 
 inline void* AlignedAlloc(size_t alignment, size_t size) {
 #if defined(__ANDROID__)
@@ -89,7 +89,7 @@
 
 inline void AlignedFree(void* aligned_memory) { free(aligned_memory); }
 
-#endif  // defined(_MSC_VER)
+#endif  // defined(_MSC_VER) || defined(__MINGW32__)
 
 inline void Memset(uint8_t* const dst, int value, size_t count) {
   memset(dst, value, count);
@@ -101,6 +101,12 @@
   }
 }
 
+inline void Memset(int16_t* const dst, int value, size_t count) {
+  for (size_t i = 0; i < count; ++i) {
+    dst[i] = static_cast<int16_t>(value);
+  }
+}
+
 struct MallocDeleter {
   void operator()(void* ptr) const { free(ptr); }
 };
diff --git a/libgav1/src/utils/queue.h b/libgav1/src/utils/queue.h
index cffb9ca..fcc7bfe 100644
--- a/libgav1/src/utils/queue.h
+++ b/libgav1/src/utils/queue.h
@@ -21,6 +21,7 @@
 #include <cstddef>
 #include <memory>
 #include <new>
+#include <utility>
 
 #include "src/utils/compiler_attributes.h"
 
diff --git a/libgav1/src/utils/raw_bit_reader.h b/libgav1/src/utils/raw_bit_reader.h
index 7d8ce8f..da770d1 100644
--- a/libgav1/src/utils/raw_bit_reader.h
+++ b/libgav1/src/utils/raw_bit_reader.h
@@ -25,7 +25,7 @@
 
 namespace libgav1 {
 
-class RawBitReader : public BitReader, public Allocable {
+class RawBitReader final : public BitReader, public Allocable {
  public:
   RawBitReader(const uint8_t* data, size_t size);
   ~RawBitReader() override = default;
diff --git a/libgav1/src/utils/reference_info.h b/libgav1/src/utils/reference_info.h
index a660791..73c32d9 100644
--- a/libgav1/src/utils/reference_info.h
+++ b/libgav1/src/utils/reference_info.h
@@ -21,6 +21,7 @@
 #include <cstdint>
 
 #include "src/utils/array_2d.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/constants.h"
 #include "src/utils/types.h"
 
diff --git a/libgav1/src/utils/types.h b/libgav1/src/utils/types.h
index eba13b7..0dd6360 100644
--- a/libgav1/src/utils/types.h
+++ b/libgav1/src/utils/types.h
@@ -28,45 +28,20 @@
 
 namespace libgav1 {
 
-struct MotionVector : public Allocable {
-  static constexpr int kRow = 0;
-  static constexpr int kColumn = 1;
-
-  MotionVector() = default;
-  MotionVector(const MotionVector& mv) = default;
-
-  MotionVector& operator=(const MotionVector& rhs) {
-    mv32 = rhs.mv32;
-    return *this;
-  }
-
-  bool operator==(const MotionVector& rhs) const { return mv32 == rhs.mv32; }
-
-  union {
-    // Motion vectors will always fit in int16_t and using int16_t here instead
-    // of int saves significant memory since some of the frame sized structures
-    // store motion vectors.
-    int16_t mv[2];
-    // A uint32_t view into the |mv| array. Useful for cases where both the
-    // motion vectors have to be copied or compared with a single 32 bit
-    // instruction.
-    uint32_t mv32;
-  };
+union MotionVector {
+  // Motion vectors will always fit in int16_t and using int16_t here instead
+  // of int saves significant memory since some of the frame sized structures
+  // store motion vectors.
+  // Index 0 is the entry for row (vertical direction) motion vector.
+  // Index 1 is the entry for column (horizontal direction) motion vector.
+  int16_t mv[2];
+  // A uint32_t view into the |mv| array. Useful for cases where both the
+  // motion vectors have to be copied or compared with a single 32 bit
+  // instruction.
+  uint32_t mv32;
 };
 
 union CompoundMotionVector {
-  CompoundMotionVector() = default;
-  CompoundMotionVector(const CompoundMotionVector& mv) = default;
-
-  CompoundMotionVector& operator=(const CompoundMotionVector& rhs) {
-    mv64 = rhs.mv64;
-    return *this;
-  }
-
-  bool operator==(const CompoundMotionVector& rhs) const {
-    return mv64 == rhs.mv64;
-  }
-
   MotionVector mv[2];
   // A uint64_t view into the |mv| array. Useful for cases where all the motion
   // vectors have to be copied or compared with a single 64 bit instruction.
@@ -163,6 +138,11 @@
   MotionVector global_mv[2];
   int num_warp_samples;
   int warp_estimate_candidates[kMaxLeastSquaresSamples][4];
+  PaletteModeInfo palette_mode_info;
+  int8_t segment_id;  // segment_id is in the range [0, 7].
+  PredictionMode uv_mode;
+  bool chroma_top_uses_smooth_prediction;
+  bool chroma_left_uses_smooth_prediction;
 };
 
 // A lot of BlockParameters objects are created, so the smallest type is used
@@ -171,19 +151,8 @@
 struct BlockParameters : public Allocable {
   BlockSize size;
   bool skip;
-  // True means that this block will use some default settings (that
-  // correspond to compound prediction) and so most of the mode info is
-  // skipped. False means that the mode info is not skipped.
-  bool skip_mode;
   bool is_inter;
-  bool is_explicit_compound_type;  // comp_group_idx in the spec.
-  bool is_compound_type_average;   // compound_idx in the spec.
-  bool is_global_mv_block;
-  bool use_predicted_segment_id;  // only valid with temporal update enabled.
-  int8_t segment_id;              // segment_id is in the range [0, 7].
   PredictionMode y_mode;
-  PredictionMode uv_mode;
-  TransformSize transform_size;
   TransformSize uv_transform_size;
   InterpolationFilter interpolation_filter[2];
   ReferenceFrameType reference_frame[2];
@@ -194,7 +163,6 @@
   //  3 - V plane (both directions).
   uint8_t deblock_filter_level[kFrameLfCount];
   CompoundMotionVector mv;
-  PaletteModeInfo palette_mode_info;
   // When |Tile::split_parse_and_decode_| is true, each block gets its own
   // instance of |prediction_parameters|. When it is false, all the blocks point
   // to |Tile::prediction_parameters_|. This field is valid only as long as the
@@ -203,6 +171,18 @@
   std::unique_ptr<PredictionParameters> prediction_parameters;
 };
 
+// Used to store the left and top block parameters that are used for computing
+// the cdf context of the subsequent blocks.
+struct BlockCdfContext {
+  bool use_predicted_segment_id[32];
+  bool is_explicit_compound_type[32];  // comp_group_idx in the spec.
+  bool is_compound_type_average[32];   // compound_idx in the spec.
+  bool skip_mode[32];
+  uint8_t palette_size[kNumPlaneTypes][32];
+  uint16_t palette_color[32][kNumPlaneTypes][kMaxPaletteSize];
+  PredictionMode uv_mode[32];
+};
+
 // A five dimensional array used to store the wedge masks. The dimensions are:
 //   - block_size_index (returned by GetWedgeBlockSizeIndex() in prediction.cc).
 //   - flip_sign (0 or 1).
diff --git a/libgav1/src/utils/vector.h b/libgav1/src/utils/vector.h
index e211240..9a21aeb 100644
--- a/libgav1/src/utils/vector.h
+++ b/libgav1/src/utils/vector.h
@@ -24,6 +24,7 @@
 #include <cstdlib>
 #include <cstring>
 #include <iterator>
+#include <new>
 #include <type_traits>
 #include <utility>
 
diff --git a/libgav1/src/warp_prediction.cc b/libgav1/src/warp_prediction.cc
index dd06317..69b40e8 100644
--- a/libgav1/src/warp_prediction.cc
+++ b/libgav1/src/warp_prediction.cc
@@ -153,10 +153,8 @@
   const int mid_x = MultiplyBy4(column4x4) + MultiplyBy2(block_width4x4) - 1;
   const int subpixel_mid_y = MultiplyBy8(mid_y);
   const int subpixel_mid_x = MultiplyBy8(mid_x);
-  const int reference_subpixel_mid_y =
-      subpixel_mid_y + mv.mv[MotionVector::kRow];
-  const int reference_subpixel_mid_x =
-      subpixel_mid_x + mv.mv[MotionVector::kColumn];
+  const int reference_subpixel_mid_y = subpixel_mid_y + mv.mv[0];
+  const int reference_subpixel_mid_x = subpixel_mid_x + mv.mv[1];
 
   for (int i = 0; i < num_samples; ++i) {
     // candidates[][0] and candidates[][1] are the row/column coordinates of the
@@ -223,14 +221,12 @@
   params[4] = NonDiagonalClamp(params[4]);
   params[5] = DiagonalClamp(params[5]);
 
-  const int vx =
-      mv.mv[MotionVector::kColumn] * (1 << (kWarpedModelPrecisionBits - 3)) -
-      (mid_x * (params[2] - (1 << kWarpedModelPrecisionBits)) +
-       mid_y * params[3]);
-  const int vy =
-      mv.mv[MotionVector::kRow] * (1 << (kWarpedModelPrecisionBits - 3)) -
-      (mid_x * params[4] +
-       mid_y * (params[5] - (1 << kWarpedModelPrecisionBits)));
+  const int vx = mv.mv[1] * (1 << (kWarpedModelPrecisionBits - 3)) -
+                 (mid_x * (params[2] - (1 << kWarpedModelPrecisionBits)) +
+                  mid_y * params[3]);
+  const int vy = mv.mv[0] * (1 << (kWarpedModelPrecisionBits - 3)) -
+                 (mid_x * params[4] +
+                  mid_y * (params[5] - (1 << kWarpedModelPrecisionBits)));
   params[0] =
       Clip3(vx, -kWarpModelTranslationClamp, kWarpModelTranslationClamp - 1);
   params[1] =
diff --git a/libgav1/src/yuv_buffer.cc b/libgav1/src/yuv_buffer.cc
index c74e140..efb8016 100644
--- a/libgav1/src/yuv_buffer.cc
+++ b/libgav1/src/yuv_buffer.cc
@@ -20,6 +20,7 @@
 
 #include "src/frame_buffer_utils.h"
 #include "src/utils/common.h"
+#include "src/utils/compiler_attributes.h"
 #include "src/utils/logging.h"
 
 namespace libgav1 {
@@ -195,6 +196,60 @@
   assert(!is_monochrome || buffer_[kPlaneU] == nullptr);
   assert(!is_monochrome || buffer_[kPlaneV] == nullptr);
 
+#if LIBGAV1_MSAN
+  const int pixel_size = (bitdepth == 8) ? sizeof(uint8_t) : sizeof(uint16_t);
+  int width_in_bytes = width * pixel_size;
+  // The optimized loop restoration code will overread the visible frame buffer
+  // into the right border. The optimized cfl subsampler uses the right border
+  // as well. Initialize the right border and padding to prevent msan warnings.
+  int right_border_size_in_bytes = right_border * pixel_size;
+  // Calculate the padding bytes for the buffer. Note: The stride of the buffer
+  // is always a multiple of 16. (see yuv_buffer.h)
+  const int right_padding_in_bytes =
+      stride_[kPlaneY] - (pixel_size * (width + left_border + right_border));
+  const int padded_right_border_size =
+      right_border_size_in_bytes + right_padding_in_bytes;
+  constexpr uint8_t right_val = 0x55;
+  uint8_t* rb = buffer_[kPlaneY] + width_in_bytes;
+  for (int i = 0; i < height + bottom_border; ++i) {
+    memset(rb, right_val, padded_right_border_size);
+    rb += stride_[kPlaneY];
+  }
+  if (!is_monochrome) {
+    int uv_width_in_bytes = uv_width * pixel_size;
+    int uv_right_border_size_in_bytes = uv_right_border * pixel_size;
+    const int u_right_padding_in_bytes =
+        stride_[kPlaneU] -
+        (pixel_size * (uv_width + uv_left_border + uv_right_border));
+    const int u_padded_right_border_size =
+        uv_right_border_size_in_bytes + u_right_padding_in_bytes;
+    rb = buffer_[kPlaneU] + uv_width_in_bytes;
+    for (int i = 0; i < uv_height; ++i) {
+      memset(rb, right_val, u_padded_right_border_size);
+      rb += stride_[kPlaneU];
+    }
+    const int v_right_padding_in_bytes =
+        stride_[kPlaneV] -
+        ((uv_width + uv_left_border + uv_right_border) * pixel_size);
+    const int v_padded_right_border_size =
+        uv_right_border_size_in_bytes + v_right_padding_in_bytes;
+    rb = buffer_[kPlaneV] + uv_width_in_bytes;
+    for (int i = 0; i < uv_height; ++i) {
+      memset(rb, right_val, v_padded_right_border_size);
+      rb += stride_[kPlaneV];
+    }
+  }
+
+  // The optimized cfl subsampler will overread (to the right of the current
+  // block) into the uninitialized visible area. The cfl subsampler can overread
+  // into the bottom border as well. Initialize both to quiet msan warnings.
+  uint8_t* y_visible = buffer_[kPlaneY];
+  for (int i = 0; i < height + bottom_border; ++i) {
+    memset(y_visible, right_val, width_in_bytes);
+    y_visible += stride_[kPlaneY];
+  }
+#endif
+
   return true;
 }