/* GENERATED SOURCE. DO NOT MODIFY. */
// © 2021 and later: Unicode, Inc. and others.
// License & terms of use: http://www.unicode.org/copyright.html
//
/**
 * An LSTMBreakEngine: finds word boundaries with a bidirectional LSTM model.
 */
package android.icu.impl.breakiter;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.text.CharacterIterator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import android.icu.impl.ICUData;
import android.icu.impl.ICUResourceBundle;
import android.icu.lang.UCharacter;
import android.icu.lang.UProperty;
import android.icu.lang.UScript;
import android.icu.text.BreakIterator;
import android.icu.text.UnicodeSet;
import android.icu.util.UResourceBundle;
/**
* @hide Only a subset of ICU is exposed in Android
* @hide draft / provisional / internal are hidden on Android
*/
public class LSTMBreakEngine extends DictionaryBreakEngine {
/**
 * How input text is tokenized before the embedding lookup.
 * @hide Only a subset of ICU is exposed in Android
 */
public enum EmbeddingType {
UNKNOWN,
CODE_POINTS,
GRAPHEME_CLUSTER
}
/**
 * Output classes of the model, in BIES order: Begin, Inside, End, Single of a word.
 * @hide Only a subset of ICU is exposed in Android
 */
public enum LSTMClass {
BEGIN,
INSIDE,
END,
SINGLE,
}
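// Reads d1 * d2 ints starting at 'start', reinterpreting each int's bits as an IEEE-754
// float (equivalent to Float.intBitsToFloat), and reshapes them into a d1 x d2 matrix.
// The resource bundle stores all model weights this way, as raw float bits in an int vector.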
private static float[][] make2DArray(int[] data, int start, int d1, int d2) {
byte[] bytes = new byte[4];
float [][] result = new float[d1][d2];
for (int i = 0; i < d1 ; i++) {
for (int j = 0; j < d2 ; j++) {
int d = data[start++];
bytes[0] = (byte) (d >> 24);
bytes[1] = (byte) (d >> 16);
bytes[2] = (byte) (d >> 8);
bytes[3] = (byte) (d /*>> 0*/);
result[i][j] = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getFloat();
}
}
return result;
}
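// 1-D variant of the above: reads d1 ints starting at 'start' and reinterprets each as a float.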
private static float[] make1DArray(int[] data, int start, int d1) {
byte[] bytes = new byte[4];
float [] result = new float[d1];
for (int i = 0; i < d1 ; i++) {
int d = data[start++];
bytes[0] = (byte) (d >> 24);
bytes[1] = (byte) (d >> 16);
bytes[2] = (byte) (d >> 8);
bytes[3] = (byte) (d /*>> 0*/);
result[i] = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getFloat();
}
return result;
}
/** The weights and metadata of one LSTM segmentation model, loaded from resource data.
 * @hide Only a subset of ICU is exposed in Android
 * @hide draft / provisional / internal are hidden on Android*/
public static class LSTMData {
private LSTMData() {
}
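/**
 * Loads one model from its resource bundle. The "data" int vector packs, in this order:
 * the embedding matrix ((dict size + 1) rows, the extra row for out-of-dictionary tokens),
 * the forward LSTM weights W, U and bias B, the backward LSTM weights W, U and bias B,
 * and the output layer weights and bias. Each LSTM weight matrix has 4 * hunits columns,
 * one block per gate (input, forget, cell candidate, output).
 */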
public LSTMData(UResourceBundle rb) {
int embeddings = rb.get("embeddings").getInt();
int hunits = rb.get("hunits").getInt();
this.fType = EmbeddingType.UNKNOWN;
this.fName = rb.get("model").getString();
String typeString = rb.get("type").getString();
if (typeString.equals("codepoints")) {
this.fType = EmbeddingType.CODE_POINTS;
} else if (typeString.equals("graphclust")) {
this.fType = EmbeddingType.GRAPHEME_CLUSTER;
}
String[] dict = rb.get("dict").getStringArray();
int[] data = rb.get("data").getIntVector();
int dataLen = data.length;
int numIndex = dict.length;
fDict = new HashMap<String, Integer>(numIndex + 1);
int idx = 0;
for (String embedding : dict){
fDict.put(embedding, idx++);
}
int mat1Size = (numIndex + 1) * embeddings;
int mat2Size = embeddings * 4 * hunits;
int mat3Size = hunits * 4 * hunits;
int mat4Size = 4 * hunits;
int mat5Size = mat2Size;
int mat6Size = mat3Size;
int mat7Size = mat4Size;
int mat8Size = 2 * hunits * 4;
int mat9Size = 4;
assert dataLen == mat1Size + mat2Size + mat3Size + mat4Size + mat5Size + mat6Size + mat7Size + mat8Size + mat9Size;
int start = 0;
this.fEmbedding = make2DArray(data, start, (numIndex+1), embeddings);
start += mat1Size;
this.fForwardW = make2DArray(data, start, embeddings, 4 * hunits);
start += mat2Size;
this.fForwardU = make2DArray(data, start, hunits, 4 * hunits);
start += mat3Size;
this.fForwardB = make1DArray(data, start, 4 * hunits);
start += mat4Size;
this.fBackwardW = make2DArray(data, start, embeddings, 4 * hunits);
start += mat5Size;
this.fBackwardU = make2DArray(data, start, hunits, 4 * hunits);
start += mat6Size;
this.fBackwardB = make1DArray(data, start, 4 * hunits);
start += mat7Size;
this.fOutputW = make2DArray(data, start, 2 * hunits, 4);
start += mat8Size;
this.fOutputB = make1DArray(data, start, 4);
}
public EmbeddingType fType;
public String fName;
public Map<String, Integer> fDict;
public float[][] fEmbedding;
public float[][] fForwardW;
public float[][] fForwardU;
public float[] fForwardB;
public float[][] fBackwardW;
public float[][] fBackwardU;
public float[] fBackwardB;
public float[][] fOutputW;
public float[] fOutputB;
}
// Minimum word size
private static final byte MIN_WORD = 2;
// Minimum number of characters for two words
private static final byte MIN_WORD_SPAN = MIN_WORD * 2;
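// Converts a text range into model input: for each token it records the token's start offset
// in the text and the token's row index in the embedding matrix. Tokens not present in the
// dictionary map to the last embedding row, fDict.size().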
abstract class Vectorizer {
public Vectorizer(Map<String, Integer> dict) {
this.fDict = dict;
}
abstract public void vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd,
List<Integer> offsets, List<Integer> indicies);
protected int getIndex(String token) {
Integer res = fDict.get(token);
return (res == null) ? fDict.size() : res;
}
private Map<String, Integer> fDict;
}
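// Vectorizer for EmbeddingType.CODE_POINTS: emits one token per char (UTF-16 code unit) of the range.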
class CodePointsVectorizer extends Vectorizer {
public CodePointsVectorizer(Map<String, Integer> dict) {
super(dict);
}
public void vectorize(CharacterIterator fIter, int rangeStart, int rangeEnd,
List<Integer> offsets, List<Integer> indicies) {
fIter.setIndex(rangeStart);
for (char c = fIter.current();
c != CharacterIterator.DONE && fIter.getIndex() < rangeEnd;
c = fIter.next()) {
offsets.add(fIter.getIndex());
indicies.add(getIndex(String.valueOf(c)));
}
}
}
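// Vectorizer for EmbeddingType.GRAPHEME_CLUSTER: emits one token per grapheme cluster, as
// segmented by BreakIterator.getCharacterInstance().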
class GraphemeClusterVectorizer extends Vectorizer {
public GraphemeClusterVectorizer(Map<String, Integer> dict) {
super(dict);
}
private String substring(CharacterIterator text, int startPos, int endPos) {
int saved = text.getIndex();
text.setIndex(startPos);
StringBuilder sb = new StringBuilder();
for (char c = text.current();
c != CharacterIterator.DONE && text.getIndex() < endPos;
c = text.next()) {
sb.append(c);
}
text.setIndex(saved);
return sb.toString();
}
public void vectorize(CharacterIterator text, int startPos, int endPos,
List<Integer> offsets, List<Integer> indicies) {
BreakIterator iter = BreakIterator.getCharacterInstance();
iter.setText(text);
int last = iter.next(startPos);
for (int curr = iter.next(); curr != BreakIterator.DONE && curr <= endPos; curr = iter.next()) {
offsets.add(last);
String segment = substring(text, last, curr);
int index = getIndex(segment);
indicies.add(index);
last = curr;
}
}
}
private final LSTMData fData;
private int fScript;
private final Vectorizer fVectorizer;
private Vectorizer makeVectorizer(LSTMData data) {
switch(data.fType) {
case CODE_POINTS:
return new CodePointsVectorizer(data.fDict);
case GRAPHEME_CLUSTER:
return new GraphemeClusterVectorizer(data.fDict);
default:
return null;
}
}
public LSTMBreakEngine(int script, UnicodeSet set, LSTMData data) {
setCharacters(set);
this.fScript = script;
this.fData = data;
this.fVectorizer = makeVectorizer(this.fData);
}
@Override
public int hashCode() {
return getClass().hashCode();
}
@Override
public boolean handles(int c) {
return fScript == UCharacter.getIntPropertyValue(c, UProperty.SCRIPT);
}
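// Adds the vector-matrix product a * b into result (a.length == rows of b, result.length == columns of b).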
static private void addDotProductTo(final float [] a, final float[][] b, float[] result) {
assert a.length == b.length;
assert b[0].length == result.length;
for (int i = 0; i < result.length; i++) {
for (int j = 0; j < a.length; j++) {
result[i] += a[j] * b[j][i];
}
}
}
static private void addTo(final float [] a, float[] result) {
assert a.length == result.length;
for (int i = 0; i < result.length; i++) {
result[i] += a[i];
}
}
static private void hadamardProductTo(final float [] a, float[] result) {
assert a.length == result.length;
for (int i = 0; i < result.length; i++) {
result[i] *= a[i];
}
}
static private void addHadamardProductTo(final float [] a, final float [] b, float[] result) {
assert a.length == result.length;
assert b.length == result.length;
for (int i = 0; i < result.length; i++) {
result[i] += a[i] * b[i];
}
}
static private void sigmoid(float [] result, int start, int length) {
assert start < result.length;
assert start + length <= result.length;
for (int i = start; i < start + length; i++) {
result[i] = (float)(1.0/(1.0 + Math.exp(-result[i])));
}
}
static private void tanh(float [] result, int start, int length) {
assert start < result.length;
assert start + length <= result.length;
for (int i = start; i < start + length; i++) {
result[i] = (float)Math.tanh(result[i]);
}
}
static private int maxIndex(float [] data) {
int index = 0;
float max = data[0];
for (int i = 1; i < data.length; i++) {
if (data[i] > max) {
max = data[i];
index = i;
}
}
return index;
}
/*
static private void print(float [] data) {
for (int i=0; i < data.length; i++) {
System.out.format(" %e", data[i]);
if (i % 4 == 3) {
System.out.println();
}
}
System.out.println();
}
*/
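/**
 * One step of an LSTM cell. Computes the gate pre-activations ifco = x*W + h*U + B, applies
 * sigmoid to the input, forget and output gates and tanh to the cell candidate, updates the
 * cell state in place (c = f . c + i . candidate) and returns the new hidden state
 * h = o . tanh(c).
 */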
private float[] compute(final float[][] W, final float[][] U, final float[] B,
final float[] x, float[] h, float[] c) {
// ifco = x * W + h * U + b
float[] ifco = Arrays.copyOf(B, B.length);
addDotProductTo(x, W, ifco);
addDotProductTo(h, U, ifco);
int hunits = B.length / 4;
sigmoid(ifco, 0*hunits, hunits); // i
sigmoid(ifco, 1*hunits, hunits); // f
tanh(ifco, 2*hunits, hunits); // c_
sigmoid(ifco, 3*hunits, hunits); // o
hadamardProductTo(Arrays.copyOfRange(ifco, hunits, 2*hunits), c);
addHadamardProductTo(Arrays.copyOf(ifco, hunits),
Arrays.copyOfRange(ifco, 2*hunits, 3*hunits), c);
h = Arrays.copyOf(c, c.length);
tanh(h, 0, h.length);
hadamardProductTo(Arrays.copyOfRange(ifco, 3*hunits, 4*hunits), h);
// System.out.println("c");
// print(c);
// System.out.println("h");
// print(h);
return h;
}
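/**
 * Segments [rangeStart, rangeEnd) of the text: vectorizes it, runs the backward LSTM over the
 * whole range, then runs the forward LSTM and the output layer together, pushing a break
 * before every token classified as BEGIN or SINGLE. Returns the number of breaks added.
 */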
@Override
public int divideUpDictionaryRange(CharacterIterator fIter, int rangeStart, int rangeEnd,
DequeI foundBreaks, boolean isPhraseBreaking) {
int beginSize = foundBreaks.size();
if ((rangeEnd - rangeStart) < MIN_WORD_SPAN) {
return 0; // Not enough characters for two words; nothing to divide
}
List<Integer> offsets = new ArrayList<Integer>(rangeEnd - rangeStart);
List<Integer> indicies = new ArrayList<Integer>(rangeEnd - rangeStart);
fVectorizer.vectorize(fIter, rangeStart, rangeEnd, offsets, indicies);
// To reduce memory usage, the following differs from the Python and ICU4X implementations:
// we first run the backward LSTM over the whole range, and then merge the forward LSTM
// iteration with the output layer, because the forward pass only needs to remember h[t-1].
int inputSeqLength = indicies.size();
int hunits = this.fData.fForwardU.length;
float c[] = new float[hunits];
// TODO: limit size of hBackward. If input_seq_len is too big, we could
// run out of memory.
// Backward LSTM
float hBackward[][] = new float[inputSeqLength][hunits];
for (int i = inputSeqLength - 1; i >= 0; i--) {
if (i != inputSeqLength - 1) {
hBackward[i] = Arrays.copyOf(hBackward[i+1], hunits);
}
// System.out.println("Backward LSTM " + i);
hBackward[i] = compute(this.fData.fBackwardW, this.fData.fBackwardU, this.fData.fBackwardB,
this.fData.fEmbedding[indicies.get(i)],
hBackward[i], c);
}
c = new float[hunits];
float forwardH[] = new float[hunits];
float both[] = new float[2*hunits];
// The following loop merges the forward LSTM and the output layer.
for (int i = 0 ; i < inputSeqLength; i++) {
// Forward LSTM
forwardH = compute(this.fData.fForwardW, this.fData.fForwardU, this.fData.fForwardB,
this.fData.fEmbedding[indicies.get(i)],
forwardH, c);
System.arraycopy(forwardH, 0, both, 0, hunits);
System.arraycopy(hBackward[i], 0, both, hunits, hunits);
//System.out.println("Merged " + i);
//print(both);
// Output layer
// logp = fbRow * fOutputW + fOutputB
float logp[] = Arrays.copyOf(this.fData.fOutputB, this.fData.fOutputB.length);
addDotProductTo(both, this.fData.fOutputW, logp);
int current = maxIndex(logp);
// BIES logic: a word starts at every token classified as BEGIN or SINGLE, so push a break
// before it (except at the start of the range).
if (current == LSTMClass.BEGIN.ordinal() ||
current == LSTMClass.SINGLE.ordinal()) {
if (i != 0) {
foundBreaks.push(offsets.get(i));
}
}
}
return foundBreaks.size() - beginSize;
}
public static LSTMData createData(UResourceBundle bundle) {
return new LSTMData(bundle);
}
private static String defaultLSTM(int script) {
ICUResourceBundle rb = (ICUResourceBundle)UResourceBundle.getBundleInstance(ICUData.ICU_BRKITR_BASE_NAME);
return rb.getStringWithFallback("lstm/" + UScript.getShortName(script));
}
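/**
 * Loads the default LSTM model for the given script from the break-iterator resource data.
 * Only Khmer, Lao, Myanmar and Thai models are available; other scripts return null.
 */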
public static LSTMData createData(int script) {
if (script != UScript.KHMER && script != UScript.LAO && script != UScript.MYANMAR && script != UScript.THAI) {
return null;
}
String name = defaultLSTM(script);
name = name.substring(0, name.indexOf("."));
UResourceBundle rb = UResourceBundle.getBundleInstance(
ICUData.ICU_BRKITR_BASE_NAME, name,
ICUResourceBundle.ICU_DATA_CLASS_LOADER);
return createData(rb);
}
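/**
 * Builds an engine over the characters of the given script that have LineBreak=SA
 * ("complex context"). A hypothetical direct use (normally the engine is created by the
 * break iterator machinery, so the calls below are illustrative only):
 *   LSTMData data = LSTMBreakEngine.createData(UScript.THAI);
 *   LSTMBreakEngine engine = LSTMBreakEngine.create(UScript.THAI, data);
 */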
public static LSTMBreakEngine create(int script, LSTMData data) {
String setExpr = "[[:" + UScript.getShortName(script) + ":]&[:LineBreak=SA:]]";
UnicodeSet set = new UnicodeSet();
set.applyPattern(setExpr);
set.compact();
return new LSTMBreakEngine(script, set, data);
}
}