Snap for 10127524 from 5de3963ef66e85679c6e868643b334b8c4da1515 to udc-release

Change-Id: I0dc5fe0a04dda29ceeb16b9f6bedf12f47531828
diff --git a/app/src/com/android/rkpdapp/ThreadPool.java b/app/src/com/android/rkpdapp/ThreadPool.java
index ca1fe0b..6d4b9a1 100644
--- a/app/src/com/android/rkpdapp/ThreadPool.java
+++ b/app/src/com/android/rkpdapp/ThreadPool.java
@@ -17,9 +17,7 @@
 package com.android.rkpdapp;
 
 import java.util.concurrent.ExecutorService;
-import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.ThreadPoolExecutor;
-import java.util.concurrent.TimeUnit;
+import java.util.concurrent.Executors;
 
 /**
  * This class provides a global thread pool to RKPD app.
@@ -35,8 +33,5 @@
      * Each thread has an unbounded queue. This allows RKPD to serve requests
      * asynchronously.
      */
-    public static final ExecutorService EXECUTOR =
-            new ThreadPoolExecutor(/*corePoolSize=*/ 0, /*maximumPoolSize=*/ NUMBER_OF_THREADS,
-                    /*keepAliveTime=*/ 30L, /*unit=*/ TimeUnit.SECONDS,
-                    /*workQueue=*/ new LinkedBlockingQueue<Runnable>());
+    public static final ExecutorService EXECUTOR = Executors.newFixedThreadPool(NUMBER_OF_THREADS);
 }
diff --git a/app/src/com/android/rkpdapp/provisioner/Provisioner.java b/app/src/com/android/rkpdapp/provisioner/Provisioner.java
index bffe3a0..eb63b50 100644
--- a/app/src/com/android/rkpdapp/provisioner/Provisioner.java
+++ b/app/src/com/android/rkpdapp/provisioner/Provisioner.java
@@ -49,6 +49,7 @@
 public class Provisioner {
     private static final String TAG = "RkpdProvisioner";
     private static final int FAILURE_MAXIMUM = 5;
+    private static final Object provisionKeysLock = new Object();
 
     private final Context mContext;
     private final ProvisionedKeyDao mKeyDao;
@@ -74,34 +75,36 @@
      */
     public void provisionKeys(ProvisioningAttempt metrics, SystemInterface systemInterface,
             GeekResponse geekResponse) throws CborException, RkpdException, InterruptedException {
-        try {
-            int keysRequired = calculateKeysRequired(metrics, systemInterface.getServiceName());
-            Log.i(TAG, "Requested number of keys for provisioning: " + keysRequired);
-            if (keysRequired == 0) {
-                metrics.setStatus(ProvisioningAttempt.Status.NO_PROVISIONING_NEEDED);
-                return;
-            }
+        synchronized (provisionKeysLock) { // Static lock: serializes calls across all instances.
+            try {
+                int keysRequired = calculateKeysRequired(metrics, systemInterface.getServiceName());
+                Log.i(TAG, "Requested number of keys for provisioning: " + keysRequired);
+                if (keysRequired == 0) {
+                    metrics.setStatus(ProvisioningAttempt.Status.NO_PROVISIONING_NEEDED);
+                    return;
+                }

-            List<RkpKey> keysGenerated = generateKeys(metrics, keysRequired, systemInterface);
-            checkForInterrupts();
-            List<byte[]> certChains = fetchCertificates(metrics, keysGenerated, systemInterface,
-                    geekResponse);
-            checkForInterrupts();
-            List<ProvisionedKey> keys = associateCertsWithKeys(certChains, keysGenerated);
+                List<RkpKey> keysGenerated = generateKeys(metrics, keysRequired, systemInterface);
+                checkForInterrupts();
+                List<byte[]> certChains = fetchCertificates(metrics, keysGenerated, systemInterface,
+                        geekResponse);
+                checkForInterrupts();
+                List<ProvisionedKey> keys = associateCertsWithKeys(certChains, keysGenerated);

-            mKeyDao.insertKeys(keys);
-            Log.i(TAG, "Total provisioned keys: " + keys.size());
-            metrics.setStatus(ProvisioningAttempt.Status.KEYS_SUCCESSFULLY_PROVISIONED);
-        } catch (InterruptedException e) {
-            metrics.setStatus(ProvisioningAttempt.Status.INTERRUPTED);
-            throw e;
-        } catch (RkpdException e) {
-            if (Settings.getFailureCounter(mContext) > FAILURE_MAXIMUM) {
-                Log.e(TAG, "Too many failures, resetting defaults.");
-                Settings.resetDefaultConfig(mContext);
+                mKeyDao.insertKeys(keys);
+                Log.i(TAG, "Total provisioned keys: " + keys.size());
+                metrics.setStatus(ProvisioningAttempt.Status.KEYS_SUCCESSFULLY_PROVISIONED);
+            } catch (InterruptedException e) {
+                metrics.setStatus(ProvisioningAttempt.Status.INTERRUPTED);
+                throw e; // Propagate after recording the INTERRUPTED status.
+            } catch (RkpdException e) {
+                if (Settings.getFailureCounter(mContext) > FAILURE_MAXIMUM) {
+                    Log.e(TAG, "Too many failures, resetting defaults.");
+                    Settings.resetDefaultConfig(mContext);
+                }
+                // Rethrow to provide failure signal to caller
+                throw e;
             }
-            // Rethrow to provide failure signal to caller
-            throw e;
         }
     }
 
diff --git a/app/tests/stress/src/com/android/rkpdapp/stress/RegistrationBinderStressTest.java b/app/tests/stress/src/com/android/rkpdapp/stress/RegistrationBinderStressTest.java
index 40905ee..ce64d91 100644
--- a/app/tests/stress/src/com/android/rkpdapp/stress/RegistrationBinderStressTest.java
+++ b/app/tests/stress/src/com/android/rkpdapp/stress/RegistrationBinderStressTest.java
@@ -22,6 +22,7 @@
 
 import android.content.Context;
 import android.hardware.security.keymint.IRemotelyProvisionedComponent;
+import android.os.IBinder;
 import android.os.Process;
 import android.os.ServiceManager;
 import android.os.SystemProperties;
@@ -106,6 +107,11 @@
             public void onError(byte error, String description) {
                 result.complete(description);
             }
+
+            @Override
+            public IBinder asBinder() {
+                return this; // This callback object is itself the IBinder being returned.
+            }
         });
         try {
             assertThat(result.get()).isEmpty();
diff --git a/util/Android.bp b/util/Android.bp
new file mode 100644
index 0000000..bf8165e
--- /dev/null
+++ b/util/Android.bp
@@ -0,0 +1,27 @@
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package {
+    default_applicable_licenses: ["Android-Apache-2.0"],
+}
+
+java_binary { // On-device CLI; RkpRegistrationCheck.sh builds, pushes and runs it.
+    name: "RkpRegistrationCheck",
+    srcs: ["src/**/RkpRegistrationCheck.java"],
+    static_libs: [
+        "android.hardware.security.rkp-V3-java",
+        "cbor-java",
+    ],
+    main_class: "com.android.rkpdapp.RkpRegistrationCheck",
+    platform_apis: true, // Build against platform APIs (e.g. ServiceManager), not the SDK.
+}
diff --git a/util/RkpRegistrationCheck.sh b/util/RkpRegistrationCheck.sh
new file mode 100755
index 0000000..382e314
--- /dev/null
+++ b/util/RkpRegistrationCheck.sh
@@ -0,0 +1,33 @@
+#!/usr/bin/env bash
+
+# Builds, installs, then runs a small test binary on an Android device that is
+# attached to your workstation. This tool checks to see if the KeyMint
+# instances on this device have been registered with the RKP backend.
+#
+# Run the script from the root of the Android source tree, passing the desired
+# lunch target on the command-line:
+# ./packages/modules/RemoteKeyProvisioning/util/RkpRegistrationCheck.sh <aosp_arm64-userdebug>
+
+if [ -z "$1" ]; then
+  echo "Lunch target must be specified" >&2
+  exit 1
+fi
+
+if [ ! -f build/envsetup.sh ]; then
+  echo "build/envsetup.sh not found; run this script from the root of the Android tree" >&2
+  exit 1
+fi
+
+# Abort at the first failing step so a failed build never pushes or runs a stale jar.
+. build/envsetup.sh
+lunch "$1" || exit 1
+m RkpRegistrationCheck || exit 1
+
+adb push "$ANDROID_PRODUCT_OUT/system/framework/RkpRegistrationCheck.jar" \
+  /data/local/tmp || exit 1
+
+# Remove the pushed jar when the script exits, even if the check itself fails.
+trap 'adb shell "rm -f /data/local/tmp/RkpRegistrationCheck.jar"' EXIT
+
+adb shell "CLASSPATH=/data/local/tmp/RkpRegistrationCheck.jar \
+  exec app_process /system/bin com.android.rkpdapp.RkpRegistrationCheck"
diff --git a/util/src/com/android/rkpdapp/RkpRegistrationCheck.java b/util/src/com/android/rkpdapp/RkpRegistrationCheck.java
new file mode 100644
index 0000000..faa99d6
--- /dev/null
+++ b/util/src/com/android/rkpdapp/RkpRegistrationCheck.java
@@ -0,0 +1,428 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.rkpdapp;
+
+import android.hardware.security.keymint.DeviceInfo;
+import android.hardware.security.keymint.IRemotelyProvisionedComponent;
+import android.hardware.security.keymint.MacedPublicKey;
+import android.hardware.security.keymint.ProtectedData;
+import android.hardware.security.keymint.RpcHardwareInfo;
+import android.net.Uri;
+import android.os.Build;
+import android.os.RemoteException;
+import android.os.ServiceManager;
+import android.os.ServiceSpecificException;
+import android.os.SystemProperties;
+import android.util.Base64;
+import android.util.Log;
+
+import java.io.BufferedInputStream;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.net.HttpURLConnection;
+import java.net.URL;
+import java.security.cert.Certificate;
+import java.security.cert.CertificateException;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.UUID;
+
+import co.nstant.in.cbor.CborBuilder;
+import co.nstant.in.cbor.CborDecoder;
+import co.nstant.in.cbor.CborEncoder;
+import co.nstant.in.cbor.CborException;
+import co.nstant.in.cbor.model.Array;
+import co.nstant.in.cbor.model.ByteString;
+import co.nstant.in.cbor.model.DataItem;
+import co.nstant.in.cbor.model.Map;
+import co.nstant.in.cbor.model.UnicodeString;
+import co.nstant.in.cbor.model.UnsignedInteger;
+
+/**
+ * Command-line utility that verifies each KeyMint instance on this device is able to
+ * get production RKP keys.
+ *
+ * <p>For the "default" and "strongbox" {@code IRemotelyProvisionedComponent} instances (when
+ * declared), it fetches an EEK chain and challenge from the RKP server, asks the HAL for a CSR,
+ * and submits that CSR for signing. An HTTP 444 answer means the device key is not registered.
+ */
+public class RkpRegistrationCheck {
+    private static final String TAG = "RegistrationTest";
+    /** System property naming the RKP server host; empty when RKP is disabled. */
+    private static final String HOSTNAME_PROPERTY = "remote_provisioning.hostname";
+
+    // COSE header label and algorithm value for the protected headers of a CSRv1 MAC.
+    private static final int COSE_HEADER_ALGORITHM = 1;
+    private static final int COSE_ALGORITHM_HMAC_256 = 5;
+
+    // Positions of the shared and unique certificate entries in a signCertificates response.
+    private static final int SHARED_CERTIFICATES_INDEX = 0;
+    private static final int UNIQUE_CERTIFICATES_INDEX = 1;
+
+    /** HTTP status the RKP server returns when the device key is not registered. */
+    private static final int HTTP_STATUS_NOT_REGISTERED = 444;
+
+    /** Connect and read timeout for every request to the RKP server. */
+    private static final int TIMEOUT_MS = 20_000;
+    /** Random id sent as the {@code requestId} query parameter on every request. */
+    private final String mRequestId = UUID.randomUUID().toString();
+    /** Name of the HAL instance being checked, e.g. "default". */
+    private final String mInstanceName;
+
+    /** Thrown when the RKP server answers with {@link #HTTP_STATUS_NOT_REGISTERED}. */
+    private static class NotRegisteredException extends Exception {
+        NotRegisteredException(String message) {
+            super(message);
+        }
+    }
+
+    /**
+     * Parsed {@code :fetchEekChain} response: one CBOR-encoded EEK certificate chain per
+     * curve, plus the challenge that must be embedded in the CSR.
+     */
+    private static class FetchEekResponse {
+        private static final int EEK_AND_CURVE_INDEX = 0;
+        private static final int CHALLENGE_INDEX = 1;
+
+        private static final int CURVE_INDEX = 0;
+        private static final int EEK_CERT_CHAIN_INDEX = 1;
+
+        private final byte[] mChallenge;
+        private final HashMap<Integer, byte[]> mCurveToGeek = new HashMap<>();
+
+        /** Decodes a response shaped as {@code [[[curve, eekChain], ...], challenge]}. */
+        FetchEekResponse(DataItem response) throws CborException, RemoteException {
+            List<DataItem> respItems = ((Array) response).getDataItems();
+            List<DataItem> allEekChains =
+                    ((Array) respItems.get(EEK_AND_CURVE_INDEX)).getDataItems();
+            for (DataItem entry : allEekChains) {
+                List<DataItem> curveAndEekChain = ((Array) entry).getDataItems();
+                UnsignedInteger curve = (UnsignedInteger) curveAndEekChain.get(CURVE_INDEX);
+                mCurveToGeek.put(curve.getValue().intValue(),
+                        encodeCbor(curveAndEekChain.get(EEK_CERT_CHAIN_INDEX)));
+            }
+
+            mChallenge = ((ByteString) respItems.get(CHALLENGE_INDEX)).getBytes();
+        }
+
+        /** Returns the CBOR-encoded EEK chain for {@code curve}, or null if none was sent. */
+        public byte[] getEekChain(int curve) {
+            return mCurveToGeek.get(curve);
+        }
+
+        /** Returns the server-issued challenge. */
+        public byte[] getChallenge() {
+            return mChallenge;
+        }
+    }
+
+    /** Main entry point: checks the "default" and then the "strongbox" HAL instance. */
+    public static void main(String[] args) {
+        if (SystemProperties.get(HOSTNAME_PROPERTY).isEmpty()) {
+            System.out.println(
+                    "The RKP server hostname is not configured -- RKP is disabled.");
+            // Without a hostname, getBaseUri() throws an unchecked exception that checkNow()
+            // does not catch, so stop here instead of crashing.
+            return;
+        }
+
+        new RkpRegistrationCheck("default").checkNow();
+        new RkpRegistrationCheck("strongbox").checkNow();
+    }
+
+    /** Creates a checker for the named IRemotelyProvisionedComponent instance. */
+    RkpRegistrationCheck(String instanceName) {
+        mInstanceName = instanceName;
+    }
+
+    /**
+     * Runs the registration check for this instance and prints the verdict. Checked failures
+     * and HAL service-specific errors are reported rather than thrown, so the caller can go on
+     * to check the next instance.
+     */
+    void checkNow() {
+        System.out.println();
+        System.out.println("Checking to see if the device key for HAL '" + mInstanceName
+                + "' has been registered...");
+
+        if (!isValidInstance()) {
+            System.err.println("Skipping registration check for '" + mInstanceName + "'.");
+            System.err.println("The HAL does not exist.");
+            return;
+        }
+
+        try {
+            FetchEekResponse eekResponse = fetchEek();
+            String serviceName = IRemotelyProvisionedComponent.DESCRIPTOR + "/" + mInstanceName;
+            IRemotelyProvisionedComponent binder = IRemotelyProvisionedComponent.Stub.asInterface(
+                    ServiceManager.waitForDeclaredService(serviceName));
+            byte[] csr = generateCsr(binder, eekResponse);
+            X509Certificate[] certs = signCertificates(csr, eekResponse.getChallenge());
+            Log.i(TAG, "Cert chain:");
+            for (X509Certificate c : certs) {
+                Log.i(TAG, "  " + c.toString());
+            }
+            System.out.println("SUCCESS: Device key for '" + mInstanceName + "' is registered");
+        } catch (ServiceSpecificException e) {
+            Log.e(TAG, e.getMessage(), e);
+            System.err.println("Error getting CSR for '" + mInstanceName + "': '" + e
+                    + "', skipping.");
+        } catch (NotRegisteredException e) {
+            Log.e(TAG, e.getMessage(), e);
+            System.out.println("FAIL: Device key for '" + mInstanceName + "' is NOT registered");
+        } catch (IOException | CborException | RemoteException | CertificateException e) {
+            Log.e(TAG, e.getMessage(), e);
+            System.err.println("Error checking device registration for '" + mInstanceName
+                    + "': '" + e + "', skipping.");
+        }
+    }
+
+    /** Returns whether this instance name is a declared IRemotelyProvisionedComponent. */
+    private boolean isValidInstance() {
+        // The SE policy checks appear to be very strict for shell, and we'll get a security
+        // exception for any HALs not actually declared. Instead, check to see if the given
+        // instance is in the list we can query.
+        String[] instances = ServiceManager.getDeclaredInstances(
+                IRemotelyProvisionedComponent.DESCRIPTOR);
+        for (String i : instances) {
+            if (i.equals(mInstanceName)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Returns {@code https://<hostname>/v1} for the configured RKP server.
+     *
+     * @throws RuntimeException if {@link #HOSTNAME_PROPERTY} is empty
+     */
+    private Uri getBaseUri() {
+        String hostname = SystemProperties.get(HOSTNAME_PROPERTY);
+        if (hostname.isEmpty()) {
+            throw new RuntimeException("System property '" + HOSTNAME_PROPERTY + "' is empty. "
+                    + "This device does not support RKP.");
+        }
+        return new Uri.Builder().scheme("https").authority(hostname).appendPath("v1").build();
+    }
+
+    /**
+     * Fetches the EEK chains and a fresh challenge from the RKP server.
+     *
+     * @throws NotRegisteredException if the server reports this device as not registered
+     */
+    FetchEekResponse fetchEek()
+            throws IOException, CborException, RemoteException, NotRegisteredException {
+        final Uri uri = getBaseUri().buildUpon().appendEncodedPath(":fetchEekChain").build();
+
+        final ByteArrayOutputStream input = new ByteArrayOutputStream();
+        new CborEncoder(input).encode(new CborBuilder()
+                .addMap()
+                .put("fingerprint", getFingerprint())
+                .put(new UnicodeString("id"), new UnsignedInteger(0))
+                .end()
+                .build());
+
+        return new FetchEekResponse(httpPost(uri, input.toByteArray()));
+    }
+
+    /** Returns Build.FINGERPRINT rewritten to look like a production (user) build. */
+    private String getFingerprint() {
+        // Fake a user build fingerprint so that we will get 444 on unregistered devices instead
+        // of test certs.
+        Log.i(TAG, "Original fingerprint: " + Build.FINGERPRINT);
+        String fingerprint = Build.FINGERPRINT
+                .replace(":userdebug", ":user")
+                .replace(":eng", ":user")
+                .replace("cf_", "cephalopod_");
+        Log.i(TAG, "Modified (prod-like) fingerprint: " + fingerprint);
+        return fingerprint;
+    }
+
+    /**
+     * Submits {@code csr} for signing and returns the certificate chain of the first key,
+     * leaf first, followed by the shared certificates.
+     *
+     * @throws NotRegisteredException if the server reports this device as not registered
+     */
+    X509Certificate[] signCertificates(byte[] csr, byte[] challenge)
+            throws IOException, CborException, CertificateException,
+            NotRegisteredException {
+        String encodedChallenge = Base64.encodeToString(challenge,
+                Base64.URL_SAFE | Base64.NO_WRAP);
+        final Uri uri = getBaseUri().buildUpon()
+                .appendEncodedPath(":signCertificates")
+                .appendQueryParameter("challenge", encodedChallenge)
+                .build();
+        DataItem response = httpPost(uri, csr);
+        List<DataItem> dataItems = ((Array) response).getDataItems();
+        byte[] sharedCertificates = ((ByteString) dataItems.get(
+                SHARED_CERTIFICATES_INDEX)).getBytes();
+        DataItem leaf = ((Array) dataItems.get(UNIQUE_CERTIFICATES_INDEX)).getDataItems().get(0);
+
+        ByteArrayOutputStream fullChainWriter = new ByteArrayOutputStream();
+        fullChainWriter.write(((ByteString) leaf).getBytes());
+        fullChainWriter.write(sharedCertificates);
+
+        ByteArrayInputStream fullChainReader = new ByteArrayInputStream(
+                fullChainWriter.toByteArray());
+        CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
+        ArrayList<Certificate> parsedCerts = new ArrayList<>(
+                certFactory.generateCertificates(fullChainReader));
+        return parsedCerts.toArray(new X509Certificate[0]);
+    }
+
+    /**
+     * POSTs {@code input} to {@code uri} (tagged with this run's request id) and decodes the
+     * CBOR response body.
+     *
+     * @throws NotRegisteredException if the server answers with HTTP 444
+     * @throws IOException on transport errors or any other non-200 status
+     */
+    DataItem httpPost(Uri uri, byte[] input)
+            throws IOException, CborException, NotRegisteredException {
+        uri = uri.buildUpon().appendQueryParameter("requestId", mRequestId).build();
+        Log.i(TAG, "querying " + uri);
+        HttpURLConnection con = (HttpURLConnection) new URL(uri.toString()).openConnection();
+        try {
+            con.setRequestMethod("POST");
+            con.setConnectTimeout(TIMEOUT_MS);
+            con.setReadTimeout(TIMEOUT_MS);
+            con.setDoOutput(true);
+
+            try (OutputStream os = con.getOutputStream()) {
+                os.write(input, 0, input.length);
+            }
+
+            int responseCode = con.getResponseCode();
+            Log.i(TAG, "HTTP status: " + responseCode);
+
+            if (responseCode == HTTP_STATUS_NOT_REGISTERED) {
+                throw new NotRegisteredException(
+                        "RKP server returned HTTP " + responseCode + " for url: " + uri);
+            }
+
+            if (responseCode != HttpURLConnection.HTTP_OK) {
+                // A checked IOException lets checkNow() report this and move on to the next
+                // instance; an unchecked exception would abort the whole tool.
+                throw new IOException("Server connection failed for url: " + uri
+                        + ", HTTP response code: " + responseCode);
+            }
+
+            ByteArrayOutputStream cborBytes = new ByteArrayOutputStream();
+            try (BufferedInputStream inputStream = new BufferedInputStream(con.getInputStream())) {
+                byte[] buffer = new byte[1024];
+                int read;
+                while ((read = inputStream.read(buffer, 0, buffer.length)) != -1) {
+                    cborBytes.write(buffer, 0, read);
+                }
+            }
+            byte[] response = cborBytes.toByteArray();
+            Log.i(TAG, "response (CBOR): " + Base64.encodeToString(response,
+                    Base64.URL_SAFE | Base64.NO_WRAP));
+            return decodeCbor(response);
+        } finally {
+            // Release the connection's resources on every path, including error statuses.
+            con.disconnect();
+        }
+    }
+
+    /**
+     * Generates one new P-256 key pair on {@code irpc} and returns a CBOR-encoded CSR for it:
+     * CSRv1 format for HAL versions below 3, CSRv2 otherwise. Both carry an unverified
+     * device-info map holding the prod-like fingerprint.
+     */
+    byte[] generateCsr(IRemotelyProvisionedComponent irpc, FetchEekResponse eekResponse)
+            throws RemoteException, CborException {
+        Map unverifiedDeviceInfo = new Map().put(
+                new UnicodeString("fingerprint"), new UnicodeString(getFingerprint()));
+
+        RpcHardwareInfo hwInfo = irpc.getHardwareInfo();
+
+        MacedPublicKey[] macedKeysToSign = new MacedPublicKey[]{new MacedPublicKey()};
+        irpc.generateEcdsaP256KeyPair(false, macedKeysToSign[0]);
+
+        if (hwInfo.versionNumber < 3) {
+            Log.i(TAG, "Generating CSRv1");
+            DeviceInfo deviceInfo = new DeviceInfo();
+            ProtectedData protectedData = new ProtectedData();
+            byte[] geekChain = eekResponse.getEekChain(hwInfo.supportedEekCurve);
+            byte[] csrTag = irpc.generateCertificateRequest(false, macedKeysToSign, geekChain,
+                    eekResponse.getChallenge(), deviceInfo, protectedData);
+            Array mac0Message = buildMac0MessageForV1Csr(macedKeysToSign[0], csrTag);
+            return encodeCbor(new CborBuilder()
+                    .addArray()
+                    .addArray()
+                    .add(decodeCbor(deviceInfo.deviceInfo))
+                    .add(unverifiedDeviceInfo)
+                    .end()
+                    .add(eekResponse.getChallenge())
+                    .add(decodeCbor(protectedData.protectedData))
+                    .add(mac0Message)
+                    .end()
+                    .build().get(0));
+        } else {
+            Log.i(TAG, "Generating CSRv2");
+            byte[] csrBytes = irpc.generateCertificateRequestV2(macedKeysToSign,
+                    eekResponse.getChallenge());
+            // Append the unverified device info to the HAL-produced CSR array.
+            Array array = (Array) decodeCbor(csrBytes);
+            array.add(unverifiedDeviceInfo);
+            return encodeCbor(array);
+        }
+    }
+
+    /**
+     * Builds the COSE_Mac0-shaped array {@code [protected, unprotected, payload, tag]} for a
+     * CSRv1, whose payload is the array of MACed keys and whose tag is {@code csrTag}.
+     */
+    Array buildMac0MessageForV1Csr(MacedPublicKey macedKeyToSign, byte[] csrTag)
+            throws CborException {
+        // Element 2 of the MACed key holds the payload bytes, which decode to the key map.
+        DataItem macedPayload = ((Array) decodeCbor(
+                macedKeyToSign.macedKey)).getDataItems().get(2);
+        Map macedCoseKey = (Map) decodeCbor(((ByteString) macedPayload).getBytes());
+        byte[] macedKeys = encodeCbor(new Array().add(macedCoseKey));
+
+        Map protectedHeaders = new Map().put(
+                new UnsignedInteger(COSE_HEADER_ALGORITHM),
+                new UnsignedInteger(COSE_ALGORITHM_HMAC_256));
+        return new Array()
+                .add(new ByteString(encodeCbor(protectedHeaders)))
+                .add(new Map())
+                .add(new ByteString(macedKeys))
+                .add(new ByteString(csrTag));
+    }
+
+    /** Decodes the first CBOR data item in {@code encodedBytes}. */
+    static DataItem decodeCbor(byte[] encodedBytes) throws CborException {
+        ByteArrayInputStream inputStream = new ByteArrayInputStream(encodedBytes);
+        return new CborDecoder(inputStream).decode().get(0);
+    }
+
+    /** Encodes a single CBOR data item to bytes. */
+    static byte[] encodeCbor(final DataItem dataItem) throws CborException {
+        final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        new CborEncoder(outputStream).encode(dataItem);
+        return outputStream.toByteArray();
+    }
+}