Fixed the ability to filter by multiple class/methods combinations

bug: 16034525
Change-Id: Id8f81b09c51f66ffd6e77bb58d9684167b2e55ba
(cherry picked from commit d904d52274bafc1fd3928c6b428fed4717e30b4a)
diff --git a/support/src/android/support/test/internal/runner/TestLoader.java b/support/src/android/support/test/internal/runner/TestLoader.java
index 9fffc63..e556d9d 100644
--- a/support/src/android/support/test/internal/runner/TestLoader.java
+++ b/support/src/android/support/test/internal/runner/TestLoader.java
@@ -24,8 +24,10 @@
 import java.io.PrintStream;
 import java.lang.reflect.Method;
 import java.lang.reflect.Modifier;
-import java.util.LinkedList;
+import java.util.Collection;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 
 /**
  * A class for loading JUnit3 and JUnit4 test classes given a set of potential class names.
@@ -34,8 +36,8 @@
 
     private static final String LOG_TAG = "TestLoader";
 
-    private  List<Class<?>> mLoadedClasses = new LinkedList<Class<?>>();
-    private  List<Failure> mLoadFailures = new LinkedList<Failure>();
+    private Map<String, Class<?>> mLoadedClassesMap = new LinkedHashMap<String, Class<?>>();
+    private Map<String, Failure> mLoadFailuresMap = new LinkedHashMap<String, Failure>();
 
     private PrintStream mWriter;
 
@@ -49,7 +51,7 @@
     }
 
     /**
-     * Loads the test class from the given class name.
+     * Loads the test class from a given class name if its not already loaded.
      * <p/>
      * Will store the result internally. Successfully loaded classes can be retrieved via
      * {@link #getLoadedClasses()}, failures via {@link #getLoadFailures()}.
@@ -60,15 +62,23 @@
     public Class<?> loadClass(String className) {
         Class<?> loadedClass = doLoadClass(className);
         if (loadedClass != null) {
-            mLoadedClasses.add(loadedClass);
+            mLoadedClassesMap.put(className, loadedClass);
         }
         return loadedClass;
     }
 
     private Class<?> doLoadClass(String className) {
+        if (mLoadFailuresMap.containsKey(className)) {
+            // Don't load classes that already failed to load
+            return null;
+        } else if (mLoadedClassesMap.containsKey(className)) {
+            // Class with the same name was already loaded, return it
+            return mLoadedClassesMap.get(className);
+        }
+
         try {
-            // TODO: InstrumentationTestRunner uses Class.forName(className, false,
-            // getTargetContext().getClassLoader()
+            // TODO: InstrumentationTestRunner uses
+            // Class.forName(className, false, getTargetContext().getClassLoader());
             // Evaluate if that is needed. Initial testing indicates
             // getTargetContext().getClassLoader() == this.getClass().getClassLoader()
             ClassLoader myClassLoader = this.getClass().getClassLoader();
@@ -79,7 +89,7 @@
             mWriter.println(errMsg);
             Description description = Description.createSuiteDescription(className);
             Failure failure = new Failure(description, e);
-            mLoadFailures.add(failure);
+            mLoadFailuresMap.put(className, failure);
         }
         return null;
     }
@@ -87,7 +97,7 @@
     /**
      * Loads the test class from the given class name.
      * <p/>
-     * Similar to {@link #loadClass(String, PrintStream))}, but will ignore classes that are
+     * Similar to {@link #loadClass(String)}, but will ignore classes that are
      * not tests.
      *
      * @param className the class name to attempt to load
@@ -96,7 +106,7 @@
     public Class<?> loadIfTest(String className) {
         Class<?> loadedClass = doLoadClass(className);
         if (loadedClass != null && isTestClass(loadedClass)) {
-            mLoadedClasses.add(loadedClass);
+            mLoadedClassesMap.put(className, loadedClass);
             return loadedClass;
         }
         return null;
@@ -106,23 +116,23 @@
      * @return whether this {@link TestLoader} contains any loaded classes or load failures.
      */
     public boolean isEmpty() {
-        return mLoadedClasses.isEmpty() && mLoadFailures.isEmpty();
+        return mLoadedClassesMap.isEmpty() && mLoadFailuresMap.isEmpty();
     }
 
     /**
-     * Get the {@link List) of classes successfully loaded via
-     * {@link #loadTest(String, PrintStream)} calls.
+     * Get the {@link Collection) of classes successfully loaded via
+     * {@link #loadIfTest(String)} calls.
      */
-    public List<Class<?>> getLoadedClasses() {
-        return mLoadedClasses;
+    public Collection<Class<?>> getLoadedClasses() {
+        return mLoadedClassesMap.values();
     }
 
     /**
      * Get the {@link List) of {@link Failure} that occurred during
-     * {@link #loadTest(String, PrintStream)} calls.
+     * {@link #loadIfTest(String)} calls.
      */
-    public List<Failure> getLoadFailures() {
-        return mLoadFailures;
+    public Collection<Failure> getLoadFailures() {
+        return mLoadFailuresMap.values();
     }
 
     /**
diff --git a/support/src/android/support/test/internal/runner/TestRequest.java b/support/src/android/support/test/internal/runner/TestRequest.java
index ec390e9..1be44d9 100644
--- a/support/src/android/support/test/internal/runner/TestRequest.java
+++ b/support/src/android/support/test/internal/runner/TestRequest.java
@@ -18,22 +18,22 @@
 import org.junit.runner.Request;
 import org.junit.runner.notification.Failure;
 
-import java.util.List;
+import java.util.Collection;
 
 /**
  * A data structure for holding a {@link Request} and the {@link Failure}s that occurred during its
  * creation.
  */
 public class TestRequest {
-     private final List<Failure> mFailures;
+     private final Collection<Failure> mFailures;
      private final Request mRequest;
 
-     public TestRequest(List<Failure> requestBuildFailures, Request request) {
+     public TestRequest(Collection<Failure> requestBuildFailures, Request request) {
          mRequest = request;
          mFailures = requestBuildFailures;
      }
 
-     public List<Failure> getFailures() {
+     public Collection<Failure> getFailures() {
          return mFailures;
      }
 
diff --git a/support/src/android/support/test/internal/runner/TestRequestBuilder.java b/support/src/android/support/test/internal/runner/TestRequestBuilder.java
index 82f007e..c48b306 100644
--- a/support/src/android/support/test/internal/runner/TestRequestBuilder.java
+++ b/support/src/android/support/test/internal/runner/TestRequestBuilder.java
@@ -44,7 +44,9 @@
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
 import java.util.regex.Pattern;
 
@@ -64,10 +66,11 @@
 
     private String[] mApkPaths;
     private TestLoader mTestLoader;
-    private Filter mFilter = new AnnotationExclusionFilter(Suppress.class).intersect(
-            new SdkSuppressFilter()).intersect(new RequiresDeviceFilter());
-
-    private PrintStream mWriter;
+    private ClassAndMethodFilter mClassMethodFilter = new ClassAndMethodFilter();
+    private Filter mFilter = new AnnotationExclusionFilter(Suppress.class)
+            .intersect(new SdkSuppressFilter())
+            .intersect(new RequiresDeviceFilter())
+            .intersect(mClassMethodFilter);
     private boolean mSkipExecution = false;
     private String mTestPackageName = null;
     private final DeviceBuild mDeviceBuild;
@@ -394,44 +397,101 @@
     public void addTestMethod(String testClassName, String testMethodName) {
         Class<?> clazz = mTestLoader.loadClass(testClassName);
         if (clazz != null) {
-            mFilter = mFilter.intersect(matchParameterizedMethod(
-                    Description.createTestDescription(clazz, testMethodName)));
+            mClassMethodFilter.add(testClassName, testMethodName);
         }
     }
 
     /**
-     * A filter to get around the fact that parameterized tests append "[#]" at
-     * the end of the method names. For instance, "getFoo" would become
-     * "getFoo[0]".
+     * A {@link Filter} to support the ability to filter out multiple classes#methodes combinations.
      */
-    private static Filter matchParameterizedMethod(final Description target) {
-        return new Filter() {
-            Pattern pat = Pattern.compile(Pattern.quote(target.getMethodName()) + "(\\[[0-9]+\\])?");
+    private static class ClassAndMethodFilter extends Filter {
 
-            @Override
-            public boolean shouldRun(Description desc) {
-                if (desc.isTest()) {
-                    return target.getClassName().equals(desc.getClassName())
-                            && isMatch(desc.getMethodName());
+        private Map<String, MethodFilter> mClassMethodFilterMap
+                = new HashMap<String, MethodFilter>();
+
+        @Override
+        public boolean shouldRun(Description description) {
+            if (mClassMethodFilterMap.isEmpty()) {
+                return true;
+            }
+            if (description.isTest()) {
+                MethodFilter mf = mClassMethodFilterMap.get(description.getClassName());
+                if (mf != null) {
+                    return mf.shouldRun(description);
                 }
-
-                for (Description child : desc.getChildren()) {
+            } else {
+                // Check all children, if any
+                for (Description child : description.getChildren()) {
                     if (shouldRun(child)) {
                         return true;
                     }
                 }
-                return false;
             }
+            return false;
+        }
 
-            private boolean isMatch(String first) {
-                return pat.matcher(first).matches();
-            }
+        @Override
+        public String describe() {
+            return "Class and method filter";
+        }
 
-            @Override
-            public String describe() {
-                return String.format("Method %s", target.getDisplayName());
+        public void add(String className, String methodName) {
+            MethodFilter mf = mClassMethodFilterMap.get(className);
+            if (mf == null) {
+                mf = new MethodFilter(className);
+                mClassMethodFilterMap.put(className, mf);
             }
-        };
+            mf.add(methodName);
+        }
+    }
+
+    /**
+     * A {@link Filter} used to filter out desired test methods from a given class
+     */
+    private static class MethodFilter extends Filter {
+
+        private final String mClassName;
+        private Set<String> mMethodNames = new HashSet<String>();
+
+        /**
+         * Constructs a method filter for a given class
+         * @param className  name of the class the method belongs to
+         */
+        public MethodFilter(String className) {
+            mClassName = className;
+        }
+
+        @Override
+        public String describe() {
+            return "Method filter for " + mClassName + " class";
+        }
+
+        @Override
+        public boolean shouldRun(Description description) {
+            if (description.isTest()) {
+                String methodName = description.getMethodName();
+                // Parameterized tests append "[#]" at the end of the method names.
+                // For instance, "getFoo" would become "getFoo[0]".
+                methodName = stripParameterizedSuffix(methodName);
+                return mMethodNames.contains(methodName);
+            }
+            // At this point, this could only be a description of this filter
+            return true;
+
+        }
+
+        // Strips out the parameterized suffix if it exists
+        private String stripParameterizedSuffix(String name) {
+            Pattern suffixPattern = Pattern.compile(".+(\\[[0-9]+\\])$");
+            if (suffixPattern.matcher(name).matches()) {
+                name = name.substring(0, name.lastIndexOf('['));
+            }
+            return name;
+        }
+
+        public void add(String methodName) {
+            mMethodNames.add(methodName);
+        }
     }
 
     /**
@@ -564,7 +624,6 @@
         try {
             return scanner.getClassPathEntries(filter);
         } catch (IOException e) {
-            mWriter.println("failed to scan classes");
             Log.e(LOG_TAG, "Failed to scan classes", e);
         }
         return Collections.emptyList();
diff --git a/support/tests/src/android/support/test/internal/runner/TestLoaderTest.java b/support/tests/src/android/support/test/internal/runner/TestLoaderTest.java
index 53f9df4..a1cc0ff 100644
--- a/support/tests/src/android/support/test/internal/runner/TestLoaderTest.java
+++ b/support/tests/src/android/support/test/internal/runner/TestLoaderTest.java
@@ -130,4 +130,14 @@
     public void testLoadTests_junit3SubclassAbstract() {
         assertLoadTestSuccess(SubClassAbstractTest.class);
     }
+
+    /**
+     *  Verify loading a class that has already been loaded
+     */
+    @Test
+    public void testLoadTests_loadAlreadyLoadedClass() {
+        Class<?> clazz = SubClassAbstractTest.class;
+        assertLoadTestSuccess(clazz);
+        assertLoadTestSuccess(clazz);
+    }
 }
diff --git a/support/tests/src/android/support/test/internal/runner/TestRequestBuilderTest.java b/support/tests/src/android/support/test/internal/runner/TestRequestBuilderTest.java
index f28c894..ebb1fcf 100644
--- a/support/tests/src/android/support/test/internal/runner/TestRequestBuilderTest.java
+++ b/support/tests/src/android/support/test/internal/runner/TestRequestBuilderTest.java
@@ -33,13 +33,17 @@
 import org.junit.runner.Description;
 import org.junit.runner.JUnitCore;
 import org.junit.runner.Result;
+import org.junit.runner.RunWith;
 import org.junit.runner.notification.RunListener;
+import org.junit.runners.Parameterized;
 import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
 import java.io.ByteArrayOutputStream;
 import java.io.PrintStream;
+import java.util.Arrays;
+import java.util.Collection;
 
 /**
  * Unit tests for {@link TestRequestBuilder}.
@@ -257,6 +261,25 @@
         }
     }
 
+    @RunWith(value = Parameterized.class)
+    public static class ParameterizedTest {
+
+        public ParameterizedTest(int data) {
+        }
+
+        @Parameterized.Parameters
+        public static Collection<Object[]> data() {
+            Object[][] data = new Object[][]{{1}, {2}, {3}};
+            return Arrays.asList(data);
+        }
+
+        @Test
+        public void testParameterized() {
+
+        }
+    }
+
+
     @InjectInstrumentation
     public Instrumentation mInstr;
 
@@ -656,4 +679,45 @@
         Result result = testRunner.run(request.getRequest());
         Assert.assertEquals(1, result.getRunCount());
     }
+
+    /**
+     * Test filtering by two methods in single class
+     */
+    @Test
+    public void testMultipleMethodsFilter() {
+        TestRequestBuilder b = new TestRequestBuilder(new PrintStream(new ByteArrayOutputStream()));
+        b.addTestMethod(SampleJUnit3Test.class.getName(), "testSmall");
+        b.addTestMethod(SampleJUnit3Test.class.getName(), "testSmall2");
+        TestRequest request = b.build(mInstr, mBundle);
+        JUnitCore testRunner = new JUnitCore();
+        Result result = testRunner.run(request.getRequest());
+        Assert.assertEquals(2, result.getRunCount());
+    }
+
+    /**
+     * Test filtering by two methods in separate classes
+     */
+    @Test
+    public void testTwoMethodsDiffClassFilter() {
+        TestRequestBuilder b = new TestRequestBuilder(new PrintStream(new ByteArrayOutputStream()));
+        b.addTestMethod(SampleJUnit3Test.class.getName(), "testSmall");
+        b.addTestMethod(SampleTest.class.getName(), "testOther");
+        TestRequest request = b.build(mInstr, mBundle);
+        JUnitCore testRunner = new JUnitCore();
+        Result result = testRunner.run(request.getRequest());
+        Assert.assertEquals(2, result.getRunCount());
+    }
+
+    /**
+     * Test filtering a parameterized method
+     */
+    @Test
+    public void testParameterizedMethods() throws Exception {
+        TestRequestBuilder b = new TestRequestBuilder(new PrintStream(new ByteArrayOutputStream()));
+        b.addTestMethod(ParameterizedTest.class.getName(), "testParameterized");
+        TestRequest request = b.build(mInstr, mBundle);
+        JUnitCore testRunner = new JUnitCore();
+        Result result = testRunner.run(request.getRequest());
+        Assert.assertEquals(3, result.getRunCount());
+    }
 }