/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.parser.lexparser.BinaryGrammar;
import edu.stanford.nlp.parser.lexparser.BinaryRule;
import edu.stanford.nlp.parser.lexparser.IntTaggedWord;
import edu.stanford.nlp.parser.lexparser.Lexicon;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.UnaryGrammar;
import edu.stanford.nlp.parser.lexparser.UnaryRule;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ThreeDimensionalMap;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.TwoDimensionalMap;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

public class SplittingGrammarExtractor {
    // Intended debug window for DEBUG(): verbose output for iterations in
    // [MIN_DEBUG_ITERATION, MAX_DEBUG_ITERATION). With both set to 0 the window is empty.
    static final int MIN_DEBUG_ITERATION = 0;
    static final int MAX_DEBUG_ITERATION = 0;
    // NOTE(review): not referenced in this part of the file; presumably an upper bound on
    // training iterations -- confirm against the training loop.
    static final int MAX_ITERATIONS = Integer.MAX_VALUE;
    // Current training iteration; read by DEBUG().
    int iteration = 0;
    // Parser options: supplies the language pack (start symbols) and the lexicon factory.
    Options op;
    // NOTE(review): not referenced in this part of the file.
    Index<String> stateIndex;
    // Word and tag indices belonging to the current lexicon.
    Index<String> wordIndex;
    Index<String> tagIndex;
    // Start symbols of the language pack; these labels are never split into substates.
    List<String> startSymbols;
    // Training trees.
    List<Tree> trees = new ArrayList<Tree>();
    // Weight of each training tree, keyed by object identity rather than Tree.equals.
    // NOTE(review): raw-type constructor (here and below) -- consider adding type arguments.
    Counter<Tree> treeWeights = new ClassicCounter(MapFactory.identityHashMapFactory());
    // Training-set size handed to Lexicon.initializeTraining.
    double trainSize;
    // Distinct non-leaf labels seen in the training trees (filled by countOriginalStates).
    Set<String> originalStates = new HashSet<String>();
    // Current number of substates for each label.
    IntCounter<String> stateSplitCounts = new IntCounter();
    // Log rule weights for binary rules, keyed (parent, left, right) and indexed
    // [parentSubstate][leftSubstate][rightSubstate].
    ThreeDimensionalMap<String, String, String, double[][][]> binaryBetas = new ThreeDimensionalMap();
    // Log rule weights for unary rules, keyed (parent, child) and indexed
    // [parentSubstate][childSubstate].
    TwoDimensionalMap<String, String, double[][]> unaryBetas = new TwoDimensionalMap();
    // Lexicon trained over split tags.
    Lexicon lex;
    // Lexicon and indices being re-estimated by recalculateTemporaryBetas; swapped into
    // lex/wordIndex/tagIndex by useNewBetas and then cleared.
    transient Index<String> tempWordIndex;
    transient Index<String> tempTagIndex;
    transient Lexicon tempLex;
    // NOTE(review): not referenced in this part of the file; presumably the extracted grammars.
    Pair<UnaryGrammar, BinaryGrammar> bgug;
    // Fixed seed so the random symmetry-breaking noise used in splitBetas is reproducible.
    Random random = new Random(87543875943265L);
    // Presumably the fraction of a preterminal's mass spread uniformly over its substates
    // when the lexicon is trained -- confirm against recalculateTemporaryBetas.
    static final double LEX_SMOOTH = 1.0E-4;
    // NOTE(review): not referenced in this part of the file.
    static final double STATE_SMOOTH = 0.0;
    // Presumably the per-entry tolerance used when testing beta convergence.
    static final double EPSILON = 1.0E-4;

    /**
     * Reports whether verbose debugging output should be produced for the current iteration.
     *
     * <p>The debug window is {@code [MIN_DEBUG_ITERATION, MAX_DEBUG_ITERATION)}. The named
     * constants are used so that changing them actually enables debugging; with the default
     * values (both 0) the window is empty and this always returns false.
     *
     * @return true iff {@code iteration} lies inside the debug window
     */
    boolean DEBUG() {
        return this.iteration >= MIN_DEBUG_ITERATION && this.iteration < MAX_DEBUG_ITERATION;
    }

    /**
     * Creates an extractor configured by the given options.
     *
     * @param op parser options; its language pack supplies the start symbols,
     *           which are never split into substates
     */
    public SplittingGrammarExtractor(Options op) {
        String[] roots = op.langpack().startSymbols();
        this.startSymbols = Arrays.asList(roots);
        this.op = op;
    }

    /**
     * Allocates an array of log weights, every entry initialised to negative infinity (log 0).
     *
     * @param size number of entries
     * @return a fresh array holding {@code size} copies of {@link Double#NEGATIVE_INFINITY}
     */
    double[] neginfDoubles(int size) {
        double[] logZeros = new double[size];
        Arrays.fill(logZeros, Double.NEGATIVE_INFINITY);
        return logZeros;
    }

    /**
     * Prints {@code tree} together with its per-node transition weights to standard
     * output, starting at indentation depth zero.
     *
     * @param tree tree to print
     * @param unaryTransitions log transition weights of unary nodes, keyed by node identity
     * @param binaryTransitions log transition weights of binary nodes, keyed by node identity
     */
    public void outputTransitions(Tree tree, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        final int rootDepth = 0;
        outputTransitions(tree, rootDepth, unaryTransitions, binaryTransitions);
    }

    /**
     * Recursively prints {@code tree}, one node per line, indenting each level by one space.
     * For every non-preterminal node, each substate combination is listed with its stored
     * log transition weight and that weight's exponent. Recursion stops at preterminals.
     *
     * @param tree subtree to print
     * @param depth indentation depth (number of leading spaces) of {@code tree}
     * @param unaryTransitions log transition weights of unary nodes, keyed by node identity
     * @param binaryTransitions log transition weights of binary nodes, keyed by node identity
     */
    public void outputTransitions(Tree tree, int depth, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        StringBuilder pad = new StringBuilder();
        for (int level = 0; level < depth; ++level) {
            pad.append(' ');
        }
        String indent = pad.toString();
        System.out.print(indent);
        if (tree.isLeaf()) {
            System.out.println(tree.label().value());
            return;
        }
        Tree[] kids = tree.children();
        String header = tree.label().value() + " -> " + kids[0].label().value();
        if (kids.length == 1) {
            System.out.println(header);
            // Preterminals have no transition table; only the tag -> word line is shown.
            if (!tree.isPreTerminal()) {
                double[][] weights = unaryTransitions.get(tree);
                for (int p = 0; p < weights.length; ++p) {
                    for (int c = 0; c < weights[0].length; ++c) {
                        double w = weights[p][c];
                        System.out.println(indent + "  " + p + "," + c + ": " + w + " | " + Math.exp(w));
                    }
                }
            }
        } else {
            System.out.println(header + " " + kids[1].label().value());
            double[][][] weights = binaryTransitions.get(tree);
            for (int p = 0; p < weights.length; ++p) {
                for (int l = 0; l < weights[0].length; ++l) {
                    for (int r = 0; r < weights[0][0].length; ++r) {
                        double w = weights[p][l][r];
                        System.out.println(indent + "  " + p + "," + l + "," + r + ": " + w + " | " + Math.exp(w));
                    }
                }
            }
        }
        if (tree.isPreTerminal()) {
            return;
        }
        for (Tree kid : kids) {
            this.outputTransitions(kid, depth + 1, unaryTransitions, binaryTransitions);
        }
    }

    /**
     * Dumps every unary and binary beta table to standard output. Each entry is printed
     * both as the stored log weight and as its exponent.
     */
    public void outputBetas() {
        System.out.println("UNARY:");
        for (String parent : this.unaryBetas.firstKeySet()) {
            Map<String, double[][]> byChild = this.unaryBetas.get(parent);
            for (String child : byChild.keySet()) {
                System.out.println("  " + parent + "->" + child);
                double[][] table = byChild.get(child);
                int rows = table.length;
                int width = table[0].length;
                for (int p = 0; p < rows; ++p) {
                    for (int c = 0; c < width; ++c) {
                        double logBeta = table[p][c];
                        System.out.println("    " + p + "->" + c + " " + logBeta + " | " + Math.exp(logBeta));
                    }
                }
            }
        }
        System.out.println("BINARY:");
        for (String parent : this.binaryBetas.firstKeySet()) {
            TwoDimensionalMap<String, String, double[][][]> byLeft = this.binaryBetas.get(parent);
            for (String left : byLeft.firstKeySet()) {
                Map<String, double[][][]> byRight = byLeft.get(left);
                for (String right : byRight.keySet()) {
                    System.out.println("  " + parent + "->" + left + "," + right);
                    double[][][] table = byRight.get(right);
                    int rows = table.length;
                    int leftWidth = table[0].length;
                    int rightWidth = table[0][0].length;
                    for (int p = 0; p < rows; ++p) {
                        for (int l = 0; l < leftWidth; ++l) {
                            for (int r = 0; r < rightWidth; ++r) {
                                double logBeta = table[p][l][r];
                                System.out.println("    " + p + "->" + l + "," + r + " " + logBeta + " | " + Math.exp(logBeta));
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Builds the name of substate {@code i} of {@code tag}, written {@code tag^i}.
     * Start symbols and the ".$$." tag are never split, so they are returned unchanged.
     *
     * @param tag unsplit label
     * @param i substate index
     * @return the substate name
     */
    public String state(String tag, int i) {
        boolean unsplit = this.startSymbols.contains(tag) || tag.equals(".$$.");
        return unsplit ? tag : tag + "^" + i;
    }

    /**
     * Returns the current number of substates for the label of {@code tree}'s root.
     *
     * @param tree node whose label is looked up
     * @return the split count recorded for that label
     */
    public int getStateSplitCount(Tree tree) {
        String rootLabel = tree.label().value();
        return this.stateSplitCounts.getIntCount(rootLabel);
    }

    /**
     * Returns the current number of substates for {@code label}.
     *
     * @param label unsplit label
     * @return the split count recorded in {@code stateSplitCounts}
     */
    public int getStateSplitCount(String label) {
        int splits = this.stateSplitCounts.getIntCount(label);
        return splits;
    }

    /**
     * Collects every non-leaf label appearing in the training trees into
     * {@code originalStates}, then increments each such label's split count by one.
     *
     * <p>NOTE(review): the counts are incremented rather than set, so calling this twice
     * gives each label a count of 2.
     */
    public void countOriginalStates() {
        this.originalStates.clear();
        for (Tree t : this.trees) {
            countOriginalStates(t);
        }
        for (String seenLabel : this.originalStates) {
            this.stateSplitCounts.incrementCount(seenLabel, 1);
        }
    }

    /**
     * Adds the labels of {@code tree} and all of its non-leaf descendants to
     * {@code originalStates}. Leaves (words) are ignored.
     *
     * @param tree subtree to scan
     */
    private void countOriginalStates(Tree tree) {
        if (!tree.isLeaf()) {
            this.originalStates.add(tree.label().value());
            for (Tree child : tree.children()) {
                if (!child.isLeaf()) {
                    countOriginalStates(child);
                }
            }
        }
    }

    /**
     * Resets the word and tag indices, then builds a fresh lexicon and the initial
     * single-substate betas from all training trees, each tree weighted by its entry
     * in {@code treeWeights}.
     */
    private void initialBetasAndLexicon() {
        this.wordIndex = new HashIndex<String>();
        this.tagIndex = new HashIndex<String>();
        Lexicon freshLex = this.op.tlpParams.lex(this.op, this.wordIndex, this.tagIndex);
        this.lex = freshLex;
        freshLex.initializeTraining(this.trainSize);
        for (Tree t : this.trees) {
            double w = this.treeWeights.getCount(t);
            freshLex.incrementTreesRead(w);
            initialBetasAndLexicon(t, 0, w);
        }
        freshLex.finishTraining();
    }

    /**
     * Walks one binarized tree, training the lexicon with substate-0 tags at each
     * preterminal. Every unary or binary rule not yet seen gets a 1x1 (or 1x1x1) beta
     * table holding log weight 0.
     *
     * @param tree subtree to process
     * @param position index of the first word covered by {@code tree}
     * @param weight2 weight of the enclosing training tree
     * @return the position just after the last word covered by {@code tree}
     * @throws RuntimeException if a non-preterminal node has neither 1 nor 2 children
     */
    private int initialBetasAndLexicon(Tree tree, int position, double weight2) {
        if (tree.isLeaf()) {
            return position;
        }
        Tree[] kids = tree.children();
        if (tree.isPreTerminal()) {
            String tag = tree.label().value();
            String word = kids[0].label().value();
            this.lex.train(new TaggedWord(word, this.state(tag, 0)), position, weight2);
            return position + 1;
        }
        switch (kids.length) {
            case 2: {
                String parentLabel = tree.label().value();
                String leftLabel = tree.getChild(0).label().value();
                String rightLabel = tree.getChild(1).label().value();
                if (!this.binaryBetas.contains(parentLabel, leftLabel, rightLabel)) {
                    // Java zero-initialises the array, i.e. log weight 0 (probability 1).
                    this.binaryBetas.put(parentLabel, leftLabel, rightLabel, new double[1][1][1]);
                }
                break;
            }
            case 1: {
                String parentLabel = tree.label().value();
                String childLabel = tree.getChild(0).label().value();
                if (!this.unaryBetas.contains(parentLabel, childLabel)) {
                    this.unaryBetas.put(parentLabel, childLabel, new double[1][1]);
                }
                break;
            }
            default:
                throw new RuntimeException("Trees should have been binarized, expected 1 or 2 children");
        }
        int next = position;
        for (Tree kid : kids) {
            next = initialBetasAndLexicon(kid, next, weight2);
        }
        return next;
    }

    /**
     * Doubles the substate count of every label, except that start symbols and the
     * ".$$." tag are capped at a single substate.
     */
    private void splitStateCounts() {
        IntCounter<String> doubled = new IntCounter<String>();
        // Adding the old counts twice doubles every label's count.
        for (int copy = 0; copy < 2; ++copy) {
            doubled.addAll(this.stateSplitCounts);
        }
        List<String> unsplittable = new ArrayList<String>(this.startSymbols);
        unsplittable.add(".$$.");
        for (String label : unsplittable) {
            if (doubled.getCount(label) > 1.0) {
                doubled.setCount(label, 1);
            }
        }
        this.stateSplitCounts = doubled;
    }

    /**
     * Splits every substate of every rule's beta table in two, producing new log-weight tables.
     *
     * <p>Parent substates (except start symbols) are duplicated: both copies keep the original
     * log weights. Child substates are divided: each original child entry becomes two entries
     * whose probabilities are w and 1-w, with w drawn uniformly from [0.45, 0.55). The random
     * perturbation breaks the symmetry between the two halves.
     *
     * <p>The results depend on the exact order in which {@code random} is consumed: unary rules
     * first, then binary rules, left split before right split.
     * This method does not itself update {@code stateSplitCounts}.
     */
    public void splitBetas() {
        TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
        ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
        for (String parent : this.unaryBetas.firstKeySet()) {
            for (String child : this.unaryBetas.get(parent).keySet()) {
                int j;
                int i;
                double[][] newBetas;
                double[][] betas = this.unaryBetas.get(parent, child);
                int parentStates = betas.length;
                int childStates = betas[0].length;
                // Duplicate each parent substate; start symbols stay unsplit.
                if (!this.startSymbols.contains(parent)) {
                    newBetas = new double[parentStates * 2][childStates];
                    for (i = 0; i < parentStates; ++i) {
                        for (j = 0; j < childStates; ++j) {
                            newBetas[i * 2][j] = betas[i][j];
                            newBetas[i * 2 + 1][j] = betas[i][j];
                        }
                    }
                    parentStates *= 2;
                    betas = newBetas;
                }
                // Divide each child substate's mass between its two halves (log domain);
                // the ".$$." tag is never split.
                if (!child.equals(".$$.")) {
                    newBetas = new double[parentStates][childStates * 2];
                    for (i = 0; i < parentStates; ++i) {
                        for (j = 0; j < childStates; ++j) {
                            double childWeight = 0.45 + this.random.nextDouble() * 0.1;
                            newBetas[i][j * 2] = betas[i][j] + Math.log(childWeight);
                            newBetas[i][j * 2 + 1] = betas[i][j] + Math.log(1.0 - childWeight);
                        }
                    }
                    betas = newBetas;
                }
                tempUnaryBetas.put(parent, child, betas);
            }
        }
        for (String parent : this.binaryBetas.firstKeySet()) {
            for (String left : this.binaryBetas.get(parent).firstKeySet()) {
                for (String right : this.binaryBetas.get(parent).get(left).keySet()) {
                    int k;
                    int j;
                    int i;
                    double[][][] newBetas;
                    double[][][] betas = this.binaryBetas.get(parent, left, right);
                    int parentStates = betas.length;
                    int leftStates = betas[0].length;
                    int rightStates = betas[0][0].length;
                    // Duplicate each parent substate; start symbols stay unsplit.
                    if (!this.startSymbols.contains(parent)) {
                        newBetas = new double[parentStates * 2][leftStates][rightStates];
                        for (i = 0; i < parentStates; ++i) {
                            for (j = 0; j < leftStates; ++j) {
                                for (k = 0; k < rightStates; ++k) {
                                    newBetas[i * 2][j][k] = betas[i][j][k];
                                    newBetas[i * 2 + 1][j][k] = betas[i][j][k];
                                }
                            }
                        }
                        parentStates *= 2;
                        betas = newBetas;
                    }
                    // The left child is always split (no ".$$." check on this side).
                    newBetas = new double[parentStates][leftStates * 2][rightStates];
                    for (i = 0; i < parentStates; ++i) {
                        for (j = 0; j < leftStates; ++j) {
                            for (k = 0; k < rightStates; ++k) {
                                double leftWeight = 0.45 + this.random.nextDouble() * 0.1;
                                newBetas[i][j * 2][k] = betas[i][j][k] + Math.log(leftWeight);
                                newBetas[i][j * 2 + 1][k] = betas[i][j][k] + Math.log(1.0 - leftWeight);
                            }
                        }
                    }
                    leftStates *= 2;
                    betas = newBetas;
                    // The right child is split unless it is the ".$$." tag; in that case
                    // newBetas still refers to the left-split table, which is stored as-is.
                    if (!right.equals(".$$.")) {
                        newBetas = new double[parentStates][leftStates][rightStates * 2];
                        for (i = 0; i < parentStates; ++i) {
                            for (j = 0; j < leftStates; ++j) {
                                for (k = 0; k < rightStates; ++k) {
                                    double rightWeight = 0.45 + this.random.nextDouble() * 0.1;
                                    newBetas[i][j][k * 2] = betas[i][j][k] + Math.log(rightWeight);
                                    newBetas[i][j][k * 2 + 1] = betas[i][j][k] + Math.log(1.0 - rightWeight);
                                }
                            }
                        }
                    }
                    tempBinaryBetas.put(parent, left, right, newBetas);
                }
            }
        }
        this.unaryBetas = tempUnaryBetas;
        this.binaryBetas = tempBinaryBetas;
    }

    /**
     * Runs one re-estimation pass over the training trees, optionally splitting all
     * substates first, and installs the new betas and lexicon.
     *
     * @param splitStates whether to split the betas before re-estimating
     * @return whether the betas converged; always false when {@code splitStates} is true,
     *         because convergence is only tested on passes without a split
     */
    public boolean recalculateBetas(boolean splitStates) {
        if (splitStates) {
            if (DEBUG()) {
                System.out.println("Pre-split betas");
                outputBetas();
            }
            splitBetas();
            if (DEBUG()) {
                System.out.println("Post-split betas");
                outputBetas();
            }
        }
        TwoDimensionalMap<String, String, double[][]> nextUnary = new TwoDimensionalMap<String, String, double[][]>();
        ThreeDimensionalMap<String, String, String, double[][][]> nextBinary = new ThreeDimensionalMap<String, String, String, double[][][]>();
        // No per-label state-mass totals are needed for this pass.
        recalculateTemporaryBetas(splitStates, null, nextUnary, nextBinary);
        boolean checkConvergence = !splitStates;
        boolean converged = useNewBetas(checkConvergence, nextUnary, nextBinary);
        if (DEBUG()) {
            outputBetas();
        }
        return converged;
    }

    /**
     * Normalizes the re-estimated betas and optionally tests them for convergence against
     * the current ones. It then replaces the current betas, lexicon and indices with the
     * temporary versions and clears the temporary fields.
     *
     * @param testConverged whether to run the convergence test
     * @param tempUnaryBetas re-estimated unary log weights (rescaled in place)
     * @param tempBinaryBetas re-estimated binary log weights (rescaled in place)
     * @return true iff {@code testConverged} was requested and the test passed
     * @throws RuntimeIOException if dumping the lexicon in debug mode fails
     */
    public boolean useNewBetas(boolean testConverged, TwoDimensionalMap<String, String, double[][]> tempUnaryBetas, ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
        this.rescaleTemporaryBetas(tempUnaryBetas, tempBinaryBetas);
        // Compare against the current betas before they are overwritten below.
        boolean converged = false;
        if (testConverged) {
            converged = this.testConvergence(tempUnaryBetas, tempBinaryBetas);
        }
        this.unaryBetas = tempUnaryBetas;
        this.binaryBetas = tempBinaryBetas;
        this.lex = this.tempLex;
        this.wordIndex = this.tempWordIndex;
        this.tagIndex = this.tempTagIndex;
        if (this.DEBUG()) {
            System.out.println("LEXICON");
            try {
                // Only flushed, never closed: closing would close System.out.
                OutputStreamWriter writer = new OutputStreamWriter(System.out, "utf-8");
                this.lex.writeData(writer);
                writer.flush();
            } catch (IOException ioe) {
                throw new RuntimeIOException(ioe);
            }
        }
        this.tempLex = null;
        this.tempWordIndex = null;
        this.tempTagIndex = null;
        return converged;
    }

    /**
     * Re-estimates betas and a fresh temporary lexicon from all training trees. Results are
     * written to the supplied maps and to {@code tempLex}/{@code tempWordIndex}/{@code tempTagIndex}.
     *
     * @param splitStates forwarded to the per-tree recount
     * @param totalStateMass if non-null, accumulates per-label substate mass
     * @param tempUnaryBetas receives the accumulated unary log weights
     * @param tempBinaryBetas receives the accumulated binary log weights
     */
    public void recalculateTemporaryBetas(boolean splitStates, Map<String, double[]> totalStateMass, TwoDimensionalMap<String, String, double[][]> tempUnaryBetas, ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
        this.tempWordIndex = new HashIndex<String>();
        this.tempTagIndex = new HashIndex<String>();
        Lexicon nextLex = this.op.tlpParams.lex(this.op, this.tempWordIndex, this.tempTagIndex);
        this.tempLex = nextLex;
        nextLex.initializeTraining(this.trainSize);
        for (Tree t : this.trees) {
            double w = this.treeWeights.getCount(t);
            if (this.DEBUG()) {
                System.out.println("Incrementing trees read: " + w);
            }
            nextLex.incrementTreesRead(w);
            this.recalculateTemporaryBetas(t, splitStates, totalStateMass, tempUnaryBetas, tempBinaryBetas);
        }
        nextLex.finishTraining();
    }

    /**
     * Reports whether the re-estimated betas are within {@link #EPSILON} of the current ones.
     *
     * <p>Every rule in the current {@code unaryBetas}/{@code binaryBetas} is compared, entry by
     * entry, with its counterpart in the temporary maps. A rule missing from the temporary
     * maps, or whose table has a different shape, counts as "not converged" instead of raising
     * a NullPointerException or ArrayIndexOutOfBoundsException. An entry whose difference is
     * NaN (e.g. both values are negative infinity) does not block convergence.
     *
     * @param tempUnaryBetas newly estimated unary log weights
     * @param tempBinaryBetas newly estimated binary log weights
     * @return true iff no entry changed by more than {@code EPSILON}
     */
    public boolean testConvergence(TwoDimensionalMap<String, String, double[][]> tempUnaryBetas, ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
        for (String parentLabel : this.unaryBetas.firstKeySet()) {
            for (String childLabel : this.unaryBetas.get(parentLabel).keySet()) {
                double[][] betas = this.unaryBetas.get(parentLabel, childLabel);
                double[][] newBetas = tempUnaryBetas.get(parentLabel, childLabel);
                if (newBetas == null || newBetas.length != betas.length) {
                    return false;
                }
                int parentStates = betas.length;
                int childStates = betas[0].length;
                for (int i = 0; i < parentStates; ++i) {
                    if (newBetas[i].length != childStates) {
                        return false;
                    }
                    for (int j = 0; j < childStates; ++j) {
                        // A NaN difference compares false here, so it is treated as unchanged.
                        if (Math.abs(newBetas[i][j] - betas[i][j]) > EPSILON) {
                            return false;
                        }
                    }
                }
            }
        }
        for (String parentLabel : this.binaryBetas.firstKeySet()) {
            for (String leftLabel : this.binaryBetas.get(parentLabel).firstKeySet()) {
                for (String rightLabel : this.binaryBetas.get(parentLabel).get(leftLabel).keySet()) {
                    double[][][] betas = this.binaryBetas.get(parentLabel, leftLabel, rightLabel);
                    double[][][] newBetas = tempBinaryBetas.get(parentLabel, leftLabel, rightLabel);
                    if (newBetas == null || newBetas.length != betas.length) {
                        return false;
                    }
                    int parentStates = betas.length;
                    int leftStates = betas[0].length;
                    int rightStates = betas[0][0].length;
                    for (int i = 0; i < parentStates; ++i) {
                        if (newBetas[i].length != leftStates) {
                            return false;
                        }
                        for (int j = 0; j < leftStates; ++j) {
                            if (newBetas[i][j].length != rightStates) {
                                return false;
                            }
                            for (int k = 0; k < rightStates; ++k) {
                                if (Math.abs(newBetas[i][j][k] - betas[i][j][k]) > EPSILON) {
                                    return false;
                                }
                            }
                        }
                    }
                }
            }
        }
        return true;
    }

    /**
     * Re-estimates betas from a single training tree. It computes per-node transition weights
     * via {@code recountTree}, then pushes the tree's log weight down from the root
     * (a single-entry weight vector).
     *
     * @param tree training tree
     * @param splitStates forwarded to {@code recountTree}
     * @param totalStateMass if non-null, accumulates per-label substate mass
     * @param tempUnaryBetas receives the accumulated unary log weights
     * @param tempBinaryBetas receives the accumulated binary log weights
     */
    public void recalculateTemporaryBetas(Tree tree, boolean splitStates, Map<String, double[]> totalStateMass, TwoDimensionalMap<String, String, double[][]> tempUnaryBetas, ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
        boolean verbose = this.DEBUG();
        if (verbose) {
            System.out.println("Recalculating temporary betas for tree " + tree);
        }
        double[] rootWeights = { Math.log(this.treeWeights.getCount(tree)) };
        IdentityHashMap<Tree, double[][]> unaryByNode = new IdentityHashMap<Tree, double[][]>();
        IdentityHashMap<Tree, double[][][]> binaryByNode = new IdentityHashMap<Tree, double[][][]>();
        this.recountTree(tree, splitStates, unaryByNode, binaryByNode);
        if (this.DEBUG()) {
            System.out.println("  Transitions:");
            this.outputTransitions(tree, unaryByNode, binaryByNode);
        }
        this.recalculateTemporaryBetas(tree, rootWeights, 0, unaryByNode, binaryByNode, totalStateMass, tempUnaryBetas, tempBinaryBetas);
    }

    /**
     * Accumulates expected rule counts for one subtree into the temporary betas and trains
     * the temporary lexicon, propagating per-substate log weights down the tree.
     *
     * <p>{@code stateWeights[s]} is the log weight flowing into substate {@code s} of
     * {@code tree}'s label. At each unary or binary node, the weight reaching parent
     * substate {@code i} is added (in log space) to the node's transition weights. The
     * result is log-added into the matching beta entry and into the per-child weights that
     * are passed to the children. At a preterminal, the exponentiated substate weights train
     * {@code tempLex}: a {@link #LEX_SMOOTH} fraction of the total mass is spread uniformly
     * over the substates, and the result is rescaled so the total mass stays unchanged.
     *
     * @param tree subtree to process
     * @param stateWeights log weight of each substate of {@code tree}'s label
     * @param position index of the first word covered by {@code tree}
     * @param unaryTransitions log transition weights of unary nodes, keyed by node identity
     * @param binaryTransitions log transition weights of binary nodes, keyed by node identity
     * @param totalStateMass if non-null, accumulates exp(stateWeights) per label
     * @param tempUnaryBetas receives the log-summed unary counts
     * @param tempBinaryBetas receives the log-summed binary counts
     * @return the position just after the last word covered by {@code tree}
     */
    public int recalculateTemporaryBetas(Tree tree, double[] stateWeights, int position, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions, Map<String, double[]> totalStateMass, TwoDimensionalMap<String, String, double[][]> tempUnaryBetas, ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
        if (tree.isLeaf()) {
            return position;
        }
        if (totalStateMass != null) {
            String label = tree.label().value();
            double[] stateTotal = totalStateMass.get(label);
            if (stateTotal == null) {
                stateTotal = new double[stateWeights.length];
                totalStateMass.put(label, stateTotal);
            }
            for (int i = 0; i < stateWeights.length; ++i) {
                stateTotal[i] += Math.exp(stateWeights[i]);
            }
        }
        if (tree.isPreTerminal()) {
            String tag = tree.label().value();
            String word = tree.children()[0].label().value();
            double total = 0.0;
            for (int state = 0; state < stateWeights.length; ++state) {
                total += Math.exp(stateWeights[state]);
            }
            if (total <= 0.0) {
                return position + 1;
            }
            // Each substate gets LEX_SMOOTH * total / n extra mass; dividing by (1 + LEX_SMOOTH)
            // keeps the total mass added to the lexicon equal to `total`. Both expressions are
            // compile-time constants identical to the previously hard-coded values.
            double scale = 1.0 / (1.0 + LEX_SMOOTH);
            double smoothing = total * LEX_SMOOTH / (double) stateWeights.length;
            for (int state = 0; state < stateWeights.length; ++state) {
                TaggedWord tw = new TaggedWord(word, this.state(tag, state));
                this.tempLex.train(tw, position, (Math.exp(stateWeights[state]) + smoothing) * scale);
            }
            return position + 1;
        }
        Tree[] kids = tree.children();
        String parentLabel = tree.label().value();
        if (kids.length == 1) {
            String childLabel = kids[0].label().value();
            double[][] transitions = unaryTransitions.get(tree);
            int parentStates = transitions.length;
            int childStates = transitions[0].length;
            double[][] betas = tempUnaryBetas.get(parentLabel, childLabel);
            if (betas == null) {
                betas = new double[parentStates][childStates];
                for (double[] row : betas) {
                    Arrays.fill(row, Double.NEGATIVE_INFINITY);
                }
                tempUnaryBetas.put(parentLabel, childLabel, betas);
            }
            double[] childWeights = this.neginfDoubles(childStates);
            for (int i = 0; i < parentStates; ++i) {
                for (int j = 0; j < childStates; ++j) {
                    double logMass = transitions[i][j] + stateWeights[i];
                    betas[i][j] = SloppyMath.logAdd(betas[i][j], logMass);
                    childWeights[j] = SloppyMath.logAdd(childWeights[j], logMass);
                }
            }
            return this.recalculateTemporaryBetas(kids[0], childWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas);
        }
        String leftLabel = kids[0].label().value();
        String rightLabel = kids[1].label().value();
        double[][][] transitions = binaryTransitions.get(tree);
        int parentStates = transitions.length;
        int leftStates = transitions[0].length;
        int rightStates = transitions[0][0].length;
        double[][][] betas = tempBinaryBetas.get(parentLabel, leftLabel, rightLabel);
        if (betas == null) {
            betas = new double[parentStates][leftStates][rightStates];
            for (double[][] plane : betas) {
                for (double[] row : plane) {
                    Arrays.fill(row, Double.NEGATIVE_INFINITY);
                }
            }
            tempBinaryBetas.put(parentLabel, leftLabel, rightLabel, betas);
        }
        double[] leftWeights = this.neginfDoubles(leftStates);
        double[] rightWeights = this.neginfDoubles(rightStates);
        for (int i = 0; i < parentStates; ++i) {
            for (int j = 0; j < leftStates; ++j) {
                for (int k = 0; k < rightStates; ++k) {
                    double logMass = transitions[i][j][k] + stateWeights[i];
                    betas[i][j][k] = SloppyMath.logAdd(betas[i][j][k], logMass);
                    leftWeights[j] = SloppyMath.logAdd(leftWeights[j], logMass);
                    rightWeights[k] = SloppyMath.logAdd(rightWeights[k], logMass);
                }
            }
        }
        int next = this.recalculateTemporaryBetas(kids[0], leftWeights, position, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas);
        return this.recalculateTemporaryBetas(kids[1], rightWeights, next, unaryTransitions, binaryTransitions, totalStateMass, tempUnaryBetas, tempBinaryBetas);
    }

    /**
     * Normalizes, in place, each parent substate's row of log weights so that it sums to 1
     * in probability space. A row whose log-sum is infinite (e.g. all entries are negative
     * infinity) is replaced by the uniform distribution over its entries.
     *
     * @param tempUnaryBetas unary log weights, rescaled in place
     * @param tempBinaryBetas binary log weights, rescaled in place
     */
    public void rescaleTemporaryBetas(TwoDimensionalMap<String, String, double[][]> tempUnaryBetas, ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas) {
        for (String parent : tempUnaryBetas.firstKeySet()) {
            Map<String, double[][]> byChild = tempUnaryBetas.get(parent);
            for (String child : byChild.keySet()) {
                double[][] table = byChild.get(child);
                int width = table[0].length;
                for (double[] row : table) {
                    double logTotal = Double.NEGATIVE_INFINITY;
                    for (int c = 0; c < width; ++c) {
                        logTotal = SloppyMath.logAdd(logTotal, row[c]);
                    }
                    if (Double.isInfinite(logTotal)) {
                        for (int c = 0; c < width; ++c) {
                            row[c] = -Math.log(width);
                        }
                    } else {
                        for (int c = 0; c < width; ++c) {
                            row[c] -= logTotal;
                        }
                    }
                }
            }
        }
        for (String parent : tempBinaryBetas.firstKeySet()) {
            TwoDimensionalMap<String, String, double[][][]> byLeft = tempBinaryBetas.get(parent);
            for (String left : byLeft.firstKeySet()) {
                Map<String, double[][][]> byRight = byLeft.get(left);
                for (String right : byRight.keySet()) {
                    double[][][] table = byRight.get(right);
                    int leftWidth = table[0].length;
                    int rightWidth = table[0][0].length;
                    for (double[][] plane : table) {
                        double logTotal = Double.NEGATIVE_INFINITY;
                        for (int l = 0; l < leftWidth; ++l) {
                            for (int r = 0; r < rightWidth; ++r) {
                                logTotal = SloppyMath.logAdd(logTotal, plane[l][r]);
                            }
                        }
                        if (Double.isInfinite(logTotal)) {
                            for (int l = 0; l < leftWidth; ++l) {
                                for (int r = 0; r < rightWidth; ++r) {
                                    plane[l][r] = -Math.log(leftWidth * rightWidth);
                                }
                            }
                        } else {
                            for (int l = 0; l < leftWidth; ++l) {
                                double[] row = plane[l];
                                for (int r = 0; r < rightWidth; ++r) {
                                    row[r] -= logTotal;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Computes per-node transition weights for {@code tree}, using fresh scratch maps for the
     * intermediate inside/outside scores.
     *
     * @param tree training tree
     * @param splitStates forwarded to the inside pass
     * @param unaryTransitions receives per-node unary transition weights
     * @param binaryTransitions receives per-node binary transition weights
     */
    public void recountTree(Tree tree, boolean splitStates, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        // The inside/outside tables are only needed while processing this one tree.
        this.recountTree(tree, splitStates, new IdentityHashMap<Tree, double[]>(), new IdentityHashMap<Tree, double[]>(), unaryTransitions, binaryTransitions);
    }

    /**
     * Runs the inside pass, then the outside pass, then derives the per-node transition
     * weights for {@code tree}.
     *
     * <p>NOTE(review): {@code probIn}/{@code probOut} are presumably per-substate inside and
     * outside log scores; recountInside/recountOutside are not visible here -- confirm.
     *
     * @param tree training tree
     * @param splitStates forwarded to the inside pass
     * @param probIn filled with inside scores per node
     * @param probOut filled with outside scores per node
     * @param unaryTransitions receives per-node unary transition weights
     * @param binaryTransitions receives per-node binary transition weights
     */
    public void recountTree(Tree tree, boolean splitStates, IdentityHashMap<Tree, double[]> probIn, IdentityHashMap<Tree, double[]> probOut, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        final int startPosition = 0;
        this.recountInside(tree, splitStates, startPosition, probIn);
        if (this.DEBUG()) {
            double rootInside = probIn.get(tree)[0];
            System.out.println("ROOT PROBABILITY: " + rootInside);
        }
        this.recountOutside(tree, probIn, probOut);
        this.recountWeights(tree, probIn, probOut, unaryTransitions, binaryTransitions);
    }

    /**
     * Builds, for every internal node below a preterminal-free path, a log-space
     * table of child-substate posteriors given each parent substate.
     *
     * <p>Each cell is {@code outside(parent) + inside(children) + beta}. Each
     * parent row is then normalized so it log-sums to 0. A row whose log-sum is
     * infinite is replaced by a uniform distribution over its cells.
     *
     * @param tree current node; leaves and preterminals are ignored
     * @param probIn inside log scores computed for this tree
     * @param probOut outside log scores computed for this tree
     * @param unaryTransitions receives the table for each unary node
     * @param binaryTransitions receives the table for each binary node
     */
    public void recountWeights(Tree tree, IdentityHashMap<Tree, double[]> probIn, IdentityHashMap<Tree, double[]> probOut, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        if (tree.isLeaf() || tree.isPreTerminal()) {
            return;
        }
        Tree[] kids = tree.children();
        String parentLabel = tree.label().value();
        double[] parentOutside = probOut.get(tree);

        if (kids.length == 1) {
            Tree child = kids[0];
            double[][] betas = this.unaryBetas.get(parentLabel, child.label().value());
            double[] childInside = probIn.get(child);
            int parentStates = betas.length;
            int childStates = betas[0].length;
            double[][] transitions = new double[parentStates][childStates];
            unaryTransitions.put(tree, transitions);
            for (int p = 0; p < parentStates; ++p) {
                double[] row = transitions[p];
                double rowTotal = Double.NEGATIVE_INFINITY;
                for (int c = 0; c < childStates; ++c) {
                    row[c] = parentOutside[p] + childInside[c] + betas[p][c];
                    rowTotal = SloppyMath.logAdd(rowTotal, row[c]);
                }
                if (Double.isInfinite(rowTotal)) {
                    // No usable mass: fall back to a uniform distribution over child substates.
                    Arrays.fill(row, -Math.log(childStates));
                } else {
                    for (int c = 0; c < childStates; ++c) {
                        row[c] -= rowTotal;
                    }
                }
            }
            this.recountWeights(child, probIn, probOut, unaryTransitions, binaryTransitions);
            return;
        }

        Tree left = kids[0];
        Tree right = kids[1];
        double[][][] betas = this.binaryBetas.get(parentLabel, left.label().value(), right.label().value());
        double[] leftInside = probIn.get(left);
        double[] rightInside = probIn.get(right);
        int parentStates = betas.length;
        int leftStates = betas[0].length;
        int rightStates = betas[0][0].length;
        double[][][] transitions = new double[parentStates][leftStates][rightStates];
        binaryTransitions.put(tree, transitions);
        for (int p = 0; p < parentStates; ++p) {
            double[][] slab = transitions[p];
            double slabTotal = Double.NEGATIVE_INFINITY;
            for (int l = 0; l < leftStates; ++l) {
                for (int r = 0; r < rightStates; ++r) {
                    slab[l][r] = parentOutside[p] + leftInside[l] + rightInside[r] + betas[p][l][r];
                    slabTotal = SloppyMath.logAdd(slabTotal, slab[l][r]);
                }
            }
            if (Double.isInfinite(slabTotal)) {
                // No usable mass: fall back to a uniform distribution over (left, right) pairs.
                double uniform = -Math.log(leftStates * rightStates);
                for (double[] row : slab) {
                    Arrays.fill(row, uniform);
                }
            } else {
                for (double[] row : slab) {
                    for (int r = 0; r < row.length; ++r) {
                        row[r] -= slabTotal;
                    }
                }
            }
        }
        this.recountWeights(left, probIn, probOut, unaryTransitions, binaryTransitions);
        this.recountWeights(right, probIn, probOut, unaryTransitions, binaryTransitions);
    }

    /**
     * Starts the top-down outside pass at the root of {@code tree}.
     *
     * <p>The root is seeded with a one-element outside score of 0.0 (log 1),
     * and the children are then filled in recursively.
     *
     * @param tree root of the tree
     * @param probIn inside log scores already computed for this tree
     * @param probOut receives the outside log scores of every node
     */
    public void recountOutside(Tree tree, IdentityHashMap<Tree, double[]> probIn, IdentityHashMap<Tree, double[]> probOut) {
        probOut.put(tree, new double[]{0.0});
        this.recurseOutside(tree, probIn, probOut);
    }

    /**
     * Passes outside scores from {@code tree} down to its children.
     *
     * <p>Dispatches to the unary or binary overload of {@code recountOutside}
     * depending on the child count. Leaves and preterminals are ignored.
     */
    public void recurseOutside(Tree tree, IdentityHashMap<Tree, double[]> probIn, IdentityHashMap<Tree, double[]> probOut) {
        boolean expandsToRule = !tree.isLeaf() && !tree.isPreTerminal();
        if (!expandsToRule) {
            return;
        }
        Tree[] kids = tree.children();
        if (kids.length != 1) {
            this.recountOutside(kids[0], kids[1], tree, probIn, probOut);
        } else {
            this.recountOutside(kids[0], tree, probIn, probOut);
        }
    }

    /**
     * Computes the outside log scores of the only child of a unary node.
     *
     * <p>The result is {@code child[c] = logsum_p(beta[p][c] + outside(parent)[p])}.
     * The method then continues the outside pass below {@code child}.
     *
     * @param child the unary node's only child
     * @param parent the unary node, whose outside scores must already be present
     * @param probIn inside log scores for this tree
     * @param probOut outside log scores; the child's entry is added here
     */
    public void recountOutside(Tree child, Tree parent, IdentityHashMap<Tree, double[]> probIn, IdentityHashMap<Tree, double[]> probOut) {
        double[] parentOutside = probOut.get(parent);
        double[][] betas = this.unaryBetas.get(parent.label().value(), child.label().value());
        int childStates = betas[0].length;
        double[] childOutside = this.neginfDoubles(childStates);
        probOut.put(child, childOutside);
        for (int p = 0; p < betas.length; ++p) {
            double[] betaRow = betas[p];
            double parentScore = parentOutside[p];
            for (int c = 0; c < childStates; ++c) {
                childOutside[c] = SloppyMath.logAdd(childOutside[c], betaRow[c] + parentScore);
            }
        }
        this.recurseOutside(child, probIn, probOut);
    }

    /**
     * Computes the outside log scores of both children of a binary node.
     *
     * <p>The left child's outside score sums over the parent's outside score,
     * the rule beta, and the right sibling's inside score. The right child's
     * score is symmetric. The method then continues the outside pass below
     * each child.
     *
     * @param left the binary node's left child
     * @param right the binary node's right child
     * @param parent the binary node, whose outside scores must already be present
     * @param probIn inside log scores for this tree (read for the siblings)
     * @param probOut outside log scores; entries for both children are added here
     */
    public void recountOutside(Tree left, Tree right, Tree parent, IdentityHashMap<Tree, double[]> probIn, IdentityHashMap<Tree, double[]> probOut) {
        double[] leftInside = probIn.get(left);
        double[] rightInside = probIn.get(right);
        double[] parentOutside = probOut.get(parent);
        double[][][] betas = this.binaryBetas.get(parent.label().value(), left.label().value(), right.label().value());
        int parentStates = betas.length;
        int leftStates = betas[0].length;
        int rightStates = betas[0][0].length;
        double[] leftOutside = this.neginfDoubles(leftStates);
        probOut.put(left, leftOutside);
        double[] rightOutside = this.neginfDoubles(rightStates);
        probOut.put(right, rightOutside);
        for (int p = 0; p < parentStates; ++p) {
            for (int l = 0; l < leftStates; ++l) {
                for (int r = 0; r < rightStates; ++r) {
                    double ruleAndParent = betas[p][l][r] + parentOutside[p];
                    leftOutside[l] = SloppyMath.logAdd(leftOutside[l], ruleAndParent + rightInside[r]);
                    rightOutside[r] = SloppyMath.logAdd(rightOutside[r], ruleAndParent + leftInside[l]);
                }
            }
        }
        this.recurseOutside(left, probIn, probOut);
        this.recurseOutside(right, probIn, probOut);
    }

    /**
     * Computes inside log scores bottom-up and stores one array per node in {@code probIn}.
     *
     * <p>For a preterminal, scores come from {@code this.lex}. When
     * {@code splitStates} is set (and the tag is not {@code .$$.}), each
     * pre-split tag substate's lexical log score is divided between its two new
     * substates. The division uses a random proportion in [0.45, 0.55). For a
     * unary or binary node, scores are log-summed over the children's inside
     * scores plus the rule betas.
     *
     * @param tree the node to score; must not be a leaf
     * @param splitStates whether preterminal scores should be split in two
     * @param loc word position of the first word in {@code tree}
     * @param probIn receives the inside log scores of every node in {@code tree}
     * @return the word position just after {@code tree}
     * @throws RuntimeException if {@code tree} is a leaf
     */
    public int recountInside(Tree tree, boolean splitStates, int loc, IdentityHashMap<Tree, double[]> probIn) {
        if (tree.isLeaf()) {
            throw new RuntimeException();
        }

        if (tree.isPreTerminal()) {
            int stateCount = this.getStateSplitCount(tree);
            String word = tree.children()[0].label().value();
            String tag = tree.label().value();
            double[] scores = new double[stateCount];
            probIn.put(tree, scores);
            if (splitStates && !tag.equals(".$$.")) {
                for (int base = 0; base < stateCount / 2; ++base) {
                    IntTaggedWord tw = new IntTaggedWord(word, this.state(tag, base), this.wordIndex, this.tagIndex);
                    double logProb = this.lex.score(tw, loc, word, null);
                    double wordWeight = 0.45 + this.random.nextDouble() * 0.1;
                    int even = 2 * base;
                    scores[even] = logProb + Math.log(wordWeight);
                    scores[even + 1] = logProb + Math.log(1.0 - wordWeight);
                    if (this.DEBUG()) {
                        System.out.println("Lexicon log prob " + this.state(tag, base) + "-" + word + ": " + logProb);
                        System.out.println("  Log Split -> " + scores[even] + "," + scores[even + 1]);
                    }
                }
            } else {
                for (int s = 0; s < stateCount; ++s) {
                    IntTaggedWord tw = new IntTaggedWord(word, this.state(tag, s), this.wordIndex, this.tagIndex);
                    double logProb = this.lex.score(tw, loc, word, null);
                    if (this.DEBUG()) {
                        System.out.println("Lexicon log prob " + this.state(tag, s) + "-" + word + ": " + logProb);
                    }
                    scores[s] = logProb;
                }
            }
            // A preterminal covers exactly one word.
            return loc + 1;
        }

        Tree[] kids = tree.children();
        if (kids.length == 1) {
            Tree child = kids[0];
            loc = this.recountInside(child, splitStates, loc, probIn);
            double[] childScores = probIn.get(child);
            String parentLabel = tree.label().value();
            String childLabel = child.label().value();
            double[][] betas = this.unaryBetas.get(parentLabel, childLabel);
            int parentStates = betas.length;
            int childStates = betas[0].length;
            double[] scores = this.neginfDoubles(parentStates);
            probIn.put(tree, scores);
            for (int p = 0; p < parentStates; ++p) {
                for (int c = 0; c < childStates; ++c) {
                    scores[p] = SloppyMath.logAdd(scores[p], childScores[c] + betas[p][c]);
                }
            }
            if (this.DEBUG()) {
                System.out.println(parentLabel + " -> " + childLabel);
                for (int p = 0; p < parentStates; ++p) {
                    System.out.println("  " + p + ":" + scores[p]);
                    for (int c = 0; c < childStates; ++c) {
                        System.out.println("    " + p + "," + c + ": " + betas[p][c] + " | " + Math.exp(betas[p][c]));
                    }
                }
            }
            return loc;
        }

        Tree leftChild = kids[0];
        Tree rightChild = kids[1];
        loc = this.recountInside(leftChild, splitStates, loc, probIn);
        loc = this.recountInside(rightChild, splitStates, loc, probIn);
        double[] leftScores = probIn.get(leftChild);
        double[] rightScores = probIn.get(rightChild);
        String parentLabel = tree.label().value();
        String leftLabel = leftChild.label().value();
        String rightLabel = rightChild.label().value();
        double[][][] betas = this.binaryBetas.get(parentLabel, leftLabel, rightLabel);
        int parentStates = betas.length;
        int leftStates = betas[0].length;
        int rightStates = betas[0][0].length;
        double[] scores = this.neginfDoubles(parentStates);
        probIn.put(tree, scores);
        for (int p = 0; p < parentStates; ++p) {
            for (int l = 0; l < leftStates; ++l) {
                for (int r = 0; r < rightStates; ++r) {
                    scores[p] = SloppyMath.logAdd(scores[p], leftScores[l] + rightScores[r] + betas[p][l][r]);
                }
            }
        }
        if (this.DEBUG()) {
            System.out.println(parentLabel + " -> " + leftLabel + "," + rightLabel);
            for (int p = 0; p < parentStates; ++p) {
                System.out.println("  " + p + ":" + scores[p]);
                for (int l = 0; l < leftStates; ++l) {
                    for (int r = 0; r < rightStates; ++r) {
                        System.out.println("    " + p + "," + l + "," + r + ": " + betas[p][l][r] + " | " + Math.exp(betas[p][l][r]));
                    }
                }
            }
        }
        return loc;
    }

    /**
     * Undoes the fraction {@code op.trainOptions.splitRecombineRate} of the
     * current sub-state splits whose merge deltas are largest, i.e. whose
     * merge is estimated to cost the least log likelihood.
     *
     * <p>Deltas are accumulated over all training trees by
     * {@code countMergeEffects}. The selected merges are applied to the betas
     * via {@code recalculateMergedBetas}, and each merged label's split count
     * is decremented once per merge. At least one candidate split is always
     * left unmerged. Does nothing when the recombine rate is not positive or
     * when there are no candidates.
     */
    public void mergeStates() {
        if (this.op.trainOptions.splitRecombineRate <= 0.0) {
            return;
        }
        TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
        ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
        HashMap<String, double[]> totalStateMass = new HashMap<String, double[]>();
        // Of this call's outputs, only totalStateMass is read below.
        this.recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);

        HashMap<String, double[]> deltaAnnotations = new HashMap<String, double[]>();
        for (Tree tree : this.trees) {
            this.countMergeEffects(tree, totalStateMass, deltaAnnotations);
        }

        // One candidate per (label, even substate index) pair: (label, 2*i, delta).
        // Declared as List (not ArrayList) so the subList(...) assignment below type-checks.
        List<Triple<String, Integer, Double>> sortedDeltas = new ArrayList<Triple<String, Integer, Double>>();
        for (Map.Entry<String, double[]> entry : deltaAnnotations.entrySet()) {
            double[] scores = entry.getValue();
            for (int i = 0; i < scores.length; ++i) {
                sortedDeltas.add(new Triple<String, Integer, Double>(entry.getKey(), i * 2, scores[i]));
            }
        }
        // Descending by delta: the cheapest merges come first.
        Collections.sort(sortedDeltas, new Comparator<Triple<String, Integer, Double>>() {
            @Override
            public int compare(Triple<String, Integer, Double> first, Triple<String, Integer, Double> second) {
                return Double.compare(second.third(), first.third());
            }
        });

        int splitsToMerge = (int)((double)sortedDeltas.size() * this.op.trainOptions.splitRecombineRate);
        // Leave at least one candidate unmerged, then clamp at 0 last so an empty
        // candidate list yields 0 rather than -1 (subList(0, -1) would throw).
        splitsToMerge = Math.min(sortedDeltas.size() - 1, splitsToMerge);
        splitsToMerge = Math.max(0, splitsToMerge);
        sortedDeltas = sortedDeltas.subList(0, splitsToMerge);
        System.out.println();
        System.out.println(sortedDeltas);

        Map<String, int[]> mergeCorrespondence = this.buildMergeCorrespondence(sortedDeltas);
        this.recalculateMergedBetas(mergeCorrespondence);
        for (Triple<String, Integer, Double> merged : sortedDeltas) {
            this.stateSplitCounts.decrementCount(merged.first(), 1);
        }
    }

    /**
     * Re-estimates rule betas and a fresh temporary lexicon under the state
     * merge described by {@code mergeCorrespondence}.
     *
     * <p>For each training tree, this method does four things:
     * <ol>
     *   <li>recounts the tree's transitions over the current, unmerged substates;</li>
     *   <li>collapses them with {@code mergeTransitions}, seeding the root with
     *       log(tree weight);</li>
     *   <li>accumulates the result with {@code recalculateTemporaryBetas};</li>
     *   <li>after all trees, installs the new betas with {@code useNewBetas(false, ...)}.</li>
     * </ol>
     *
     * @param mergeCorrespondence per label, maps each unmerged substate index to
     *     its merged substate index
     */
    public void recalculateMergedBetas(Map<String, int[]> mergeCorrespondence) {
        this.tempWordIndex = new HashIndex<String>();
        this.tempTagIndex = new HashIndex<String>();
        this.tempLex = this.op.tlpParams.lex(this.op, this.tempWordIndex, this.tempTagIndex);
        this.tempLex.initializeTraining(this.trainSize);
        TwoDimensionalMap<String, String, double[][]> mergedUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
        ThreeDimensionalMap<String, String, String, double[][][]> mergedBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
        for (Tree tree : this.trees) {
            double weight = this.treeWeights.getCount(tree);
            this.tempLex.incrementTreesRead(weight);
            IdentityHashMap<Tree, double[][]> splitUnary = new IdentityHashMap<Tree, double[][]>();
            IdentityHashMap<Tree, double[][][]> splitBinary = new IdentityHashMap<Tree, double[][][]>();
            this.recountTree(tree, false, splitUnary, splitBinary);
            // The same single-element array seeds both the merge and the beta accumulation.
            double[] rootWeights = new double[]{Math.log(weight)};
            IdentityHashMap<Tree, double[][]> mergedUnary = new IdentityHashMap<Tree, double[][]>();
            IdentityHashMap<Tree, double[][][]> mergedBinary = new IdentityHashMap<Tree, double[][][]>();
            this.mergeTransitions(tree, splitUnary, splitBinary, mergedUnary, mergedBinary, rootWeights, mergeCorrespondence);
            this.recalculateTemporaryBetas(tree, rootWeights, 0, mergedUnary, mergedBinary, null, mergedUnaryBetas, mergedBinaryBetas);
        }
        this.tempLex.finishTraining();
        this.useNewBetas(false, mergedUnaryBetas, mergedBinaryBetas);
    }

    /**
     * Collapses the per-node transition tables of {@code parent}'s subtree
     * according to {@code mergeCorrespondence}, recursing down to the preterminals.
     *
     * <p>At each internal node, each old cell is log-summed into the merged
     * cell it maps to, weighted by the incoming log weight of its old parent
     * substate. Each merged parent row is then renormalized in log space, and
     * rows with no mass become uniform. The weights handed to each child are
     * the weighted, log-summed old transitions into that child's unmerged
     * substates.
     *
     * @param parent root of the subtree to process; leaves and preterminals are ignored
     * @param oldUnaryTransitions unary transition tables over the unmerged substates
     * @param oldBinaryTransitions binary transition tables over the unmerged substates
     * @param newUnaryTransitions receives the merged, renormalized unary tables
     * @param newBinaryTransitions receives the merged, renormalized binary tables
     * @param stateWeights log weight of each unmerged substate of {@code parent}
     * @param mergeCorrespondence per label, maps unmerged substate index to merged
     *     substate index; assumed non-decreasing (as built by
     *     buildMergeCorrespondence) so the last entry + 1 is the merged count
     */
    public void mergeTransitions(Tree parent, IdentityHashMap<Tree, double[][]> oldUnaryTransitions, IdentityHashMap<Tree, double[][][]> oldBinaryTransitions, IdentityHashMap<Tree, double[][]> newUnaryTransitions, IdentityHashMap<Tree, double[][][]> newBinaryTransitions, double[] stateWeights, Map<String, int[]> mergeCorrespondence) {
        if (parent.isPreTerminal() || parent.isLeaf()) {
            return;
        }
        if (parent.children().length == 1) {
            int j;
            int i;
            double[][] oldTransitions = oldUnaryTransitions.get(parent);
            String parentLabel = parent.label().value();
            int[] parentCorrespondence = mergeCorrespondence.get(parentLabel);
            // Merged substate counts: last mapped index + 1.
            int parentStates = parentCorrespondence[parentCorrespondence.length - 1] + 1;
            String childLabel = parent.children()[0].label().value();
            int[] childCorrespondence = mergeCorrespondence.get(childLabel);
            int childStates = childCorrespondence[childCorrespondence.length - 1] + 1;
            // Start every merged cell at log(0).
            double[][] newTransitions = new double[parentStates][childStates];
            for (i = 0; i < parentStates; ++i) {
                for (int j2 = 0; j2 < childStates; ++j2) {
                    newTransitions[i][j2] = Double.NEGATIVE_INFINITY;
                }
            }
            newUnaryTransitions.put(parent, newTransitions);
            // Accumulate each old cell, weighted by its old parent substate, into its merged cell.
            for (i = 0; i < oldTransitions.length; ++i) {
                int ti = parentCorrespondence[i];
                for (j = 0; j < oldTransitions[0].length; ++j) {
                    int tj = childCorrespondence[j];
                    newTransitions[ti][tj] = SloppyMath.logAdd(newTransitions[ti][tj], oldTransitions[i][j] + stateWeights[i]);
                }
            }
            // Renormalize each merged parent row so it log-sums to 0; uniform if the row is empty.
            for (i = 0; i < parentStates; ++i) {
                int j3;
                double total = Double.NEGATIVE_INFINITY;
                for (j3 = 0; j3 < childStates; ++j3) {
                    total = SloppyMath.logAdd(total, newTransitions[i][j3]);
                }
                if (Double.isInfinite(total)) {
                    for (j3 = 0; j3 < childStates; ++j3) {
                        newTransitions[i][j3] = -Math.log(childStates);
                    }
                    continue;
                }
                j3 = 0;
                while (j3 < childStates) {
                    double[] dArray = newTransitions[i];
                    int n = j3++;
                    dArray[n] = dArray[n] - total;
                }
            }
            // Incoming log weight for each (unmerged) child substate, taken from the old transitions.
            double[] childWeights = this.neginfDoubles(oldTransitions[0].length);
            for (int i2 = 0; i2 < oldTransitions.length; ++i2) {
                for (j = 0; j < oldTransitions[0].length; ++j) {
                    double weight2 = oldTransitions[i2][j];
                    childWeights[j] = SloppyMath.logAdd(childWeights[j], weight2 + stateWeights[i2]);
                }
            }
            this.mergeTransitions(parent.children()[0], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, childWeights, mergeCorrespondence);
        } else {
            int j;
            int k;
            int i;
            double[][][] oldTransitions = oldBinaryTransitions.get(parent);
            String parentLabel = parent.label().value();
            int[] parentCorrespondence = mergeCorrespondence.get(parentLabel);
            // Merged substate counts: last mapped index + 1.
            int parentStates = parentCorrespondence[parentCorrespondence.length - 1] + 1;
            String leftLabel = parent.children()[0].label().value();
            int[] leftCorrespondence = mergeCorrespondence.get(leftLabel);
            int leftStates = leftCorrespondence[leftCorrespondence.length - 1] + 1;
            String rightLabel = parent.children()[1].label().value();
            int[] rightCorrespondence = mergeCorrespondence.get(rightLabel);
            int rightStates = rightCorrespondence[rightCorrespondence.length - 1] + 1;
            // Start every merged cell at log(0).
            double[][][] newTransitions = new double[parentStates][leftStates][rightStates];
            for (i = 0; i < parentStates; ++i) {
                for (int j4 = 0; j4 < leftStates; ++j4) {
                    for (int k2 = 0; k2 < rightStates; ++k2) {
                        newTransitions[i][j4][k2] = Double.NEGATIVE_INFINITY;
                    }
                }
            }
            newBinaryTransitions.put(parent, newTransitions);
            // Accumulate each old cell, weighted by its old parent substate, into its merged cell.
            for (i = 0; i < oldTransitions.length; ++i) {
                int ti = parentCorrespondence[i];
                for (int j5 = 0; j5 < oldTransitions[0].length; ++j5) {
                    int tj = leftCorrespondence[j5];
                    for (k = 0; k < oldTransitions[0][0].length; ++k) {
                        int tk = rightCorrespondence[k];
                        newTransitions[ti][tj][tk] = SloppyMath.logAdd(newTransitions[ti][tj][tk], oldTransitions[i][j5][k] + stateWeights[i]);
                    }
                }
            }
            // Renormalize each merged parent slab so it log-sums to 0; uniform if the slab is empty.
            for (i = 0; i < parentStates; ++i) {
                double total = Double.NEGATIVE_INFINITY;
                for (j = 0; j < leftStates; ++j) {
                    for (k = 0; k < rightStates; ++k) {
                        total = SloppyMath.logAdd(total, newTransitions[i][j][k]);
                    }
                }
                if (Double.isInfinite(total)) {
                    for (j = 0; j < leftStates; ++j) {
                        for (k = 0; k < rightStates; ++k) {
                            newTransitions[i][j][k] = -Math.log(leftStates * rightStates);
                        }
                    }
                    continue;
                }
                for (j = 0; j < leftStates; ++j) {
                    k = 0;
                    while (k < rightStates) {
                        double[] dArray = newTransitions[i][j];
                        int n = k++;
                        dArray[n] = dArray[n] - total;
                    }
                }
            }
            // Incoming log weights for each (unmerged) left and right substate, from the old transitions.
            double[] leftWeights = this.neginfDoubles(oldTransitions[0].length);
            double[] rightWeights = this.neginfDoubles(oldTransitions[0][0].length);
            for (int i3 = 0; i3 < oldTransitions.length; ++i3) {
                for (j = 0; j < oldTransitions[0].length; ++j) {
                    for (k = 0; k < oldTransitions[0][0].length; ++k) {
                        double weight3 = oldTransitions[i3][j][k];
                        leftWeights[j] = SloppyMath.logAdd(leftWeights[j], weight3 + stateWeights[i3]);
                        rightWeights[k] = SloppyMath.logAdd(rightWeights[k], weight3 + stateWeights[i3]);
                    }
                }
            }
            this.mergeTransitions(parent.children()[0], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, leftWeights, mergeCorrespondence);
            this.mergeTransitions(parent.children()[1], oldUnaryTransitions, oldBinaryTransitions, newUnaryTransitions, newBinaryTransitions, rightWeights, mergeCorrespondence);
        }
    }

    /**
     * Builds, for every original state, a map from current substate index to
     * post-merge substate index.
     *
     * <p>Each delta {@code (label, split, _)} folds substate {@code split + 1}
     * into {@code split}: every index above {@code split} for that label is
     * shifted down by one. Multiple deltas on the same label accumulate.
     *
     * @param deltas the selected merges as (label, even substate index, delta)
     * @return label to correspondence array (one entry per current substate)
     */
    Map<String, int[]> buildMergeCorrespondence(List<Triple<String, Integer, Double>> deltas) {
        HashMap<String, int[]> correspondences = new HashMap<String, int[]>();
        // Start from the identity mapping for every state.
        for (String label : this.originalStates) {
            int count = this.getStateSplitCount(label);
            int[] identity = new int[count];
            for (int s = 0; s < count; ++s) {
                identity[s] = s;
            }
            correspondences.put(label, identity);
        }
        for (Triple<String, Integer, Double> delta : deltas) {
            String label = delta.first();
            int count = this.getStateSplitCount(label);
            int split = delta.second();
            int[] mapping = correspondences.get(label);
            for (int s = split + 1; s < count; ++s) {
                mapping[s]--;
            }
        }
        return correspondences;
    }

    /**
     * Recounts {@code tree} and accumulates estimated merge deltas for every
     * node below the root into {@code deltaAnnotations}.
     *
     * @param tree a training tree
     * @param totalStateMass per label, the mass of each current substate
     * @param deltaAnnotations per label, accumulated log-likelihood change for
     *     merging each substate pair; created on first use
     */
    public void countMergeEffects(Tree tree, Map<String, double[]> totalStateMass, Map<String, double[]> deltaAnnotations) {
        IdentityHashMap<Tree, double[]> probIn = new IdentityHashMap<Tree, double[]>();
        IdentityHashMap<Tree, double[]> probOut = new IdentityHashMap<Tree, double[]>();
        // recountTree requires transition tables; they are not needed here.
        this.recountTree(tree, false, probIn, probOut, new IdentityHashMap<Tree, double[][]>(), new IdentityHashMap<Tree, double[][][]>());
        Tree[] kids = tree.children();
        for (int c = 0; c < kids.length; ++c) {
            this.countMergeEffects(kids[c], totalStateMass, deltaAnnotations, probIn, probOut);
        }
    }

    /**
     * Accumulates, for each substate pair (2i, 2i+1) of every node's label, the
     * estimated change in log likelihood from merging the pair at this node.
     *
     * <p>The merged inside score mixes the pair's inside scores using the pair's
     * <em>relative</em> masses {@code m[2i] / (m[2i] + m[2i+1])} and
     * {@code m[2i+1] / (m[2i] + m[2i+1])}, so the two mixture weights sum to 1.
     * The previous code divided by the label's total mass over all substates.
     * Those weights summed to well under 1 whenever a label had more than one
     * pair, which biased deltas by pair frequency. The merged outside score is
     * the log-sum of the pair's outside scores. If a pair has no mass at all,
     * equal weights are used rather than computing 0/0. Nodes labelled
     * {@code .$$.} and leaves are skipped.
     *
     * @param tree current node
     * @param totalStateMass per label, the mass of each current substate
     * @param deltaAnnotations per label, accumulated delta per substate pair
     * @param probIn inside log scores for this tree
     * @param probOut outside log scores for this tree
     */
    public void countMergeEffects(Tree tree, Map<String, double[]> totalStateMass, Map<String, double[]> deltaAnnotations, IdentityHashMap<Tree, double[]> probIn, IdentityHashMap<Tree, double[]> probOut) {
        if (tree.isLeaf()) {
            return;
        }
        String label = tree.label().value();
        if (label.equals(".$$.")) {
            return;
        }
        double[] stateMass = totalStateMass.get(label);
        double[] nodeProbIn = probIn.get(tree);
        double[] nodeProbOut = probOut.get(tree);
        int pairs = nodeProbIn.length / 2;
        double[] nodeDelta = deltaAnnotations.get(label);
        if (nodeDelta == null) {
            nodeDelta = new double[pairs];
            deltaAnnotations.put(label, nodeDelta);
        }
        for (int i = 0; i < pairs; ++i) {
            int a = i * 2;
            int b = a + 1;
            double pairMass = stateMass[a] + stateMass[b];
            double logWeightA;
            double logWeightB;
            if (pairMass > 0.0) {
                logWeightA = Math.log(stateMass[a] / pairMass);
                logWeightB = Math.log(stateMass[b] / pairMass);
            } else {
                // Degenerate pair: avoid 0/0 and treat both substates as equally likely.
                logWeightA = Math.log(0.5);
                logWeightB = logWeightA;
            }
            double probInMerged = SloppyMath.logAdd(logWeightA + nodeProbIn[a], logWeightB + nodeProbIn[b]);
            double probOutMerged = SloppyMath.logAdd(nodeProbOut[a], nodeProbOut[b]);
            double probMerged = probInMerged + probOutMerged;
            double probUnmerged = SloppyMath.logAdd(nodeProbIn[a] + nodeProbOut[a], nodeProbIn[b] + nodeProbOut[b]);
            nodeDelta[i] = nodeDelta[i] + probMerged - probUnmerged;
        }
        if (tree.isPreTerminal()) {
            return;
        }
        for (Tree child : tree.children()) {
            this.countMergeEffects(child, totalStateMass, deltaAnnotations, probIn, probOut);
        }
    }

    /**
     * Rebuilds {@code stateIndex} from scratch so that it contains every
     * substate name {@code state(base, s)} for each base state and each
     * {@code s} below that state's split count.
     */
    public void buildStateIndex() {
        this.stateIndex = new HashIndex<String>();
        for (String baseState : this.stateSplitCounts.keySet()) {
            int splits = this.stateSplitCounts.getIntCount(baseState);
            for (int s = 0; s < splits; ++s) {
                this.stateIndex.indexOf(this.state(baseState, s), true);
            }
        }
    }

    /**
     * Builds the final unary and binary grammars over {@code stateIndex} from
     * freshly recalculated betas, and stores them in {@code bgug}.
     *
     * <p>Each rule score is its beta minus the log of the parent substate's
     * total mass. Parent substates with mass below 1e-4 contribute no rules.
     */
    public void buildGrammars() {
        TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<String, String, double[][]>();
        ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<String, String, String, double[][][]>();
        HashMap<String, double[]> totalStateMass = new HashMap<String, double[]>();
        this.recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);

        BinaryGrammar bg = new BinaryGrammar(this.stateIndex);
        for (String parent : tempBinaryBetas.firstKeySet()) {
            int parentStates = this.getStateSplitCount(parent);
            double[] stateTotal = totalStateMass.get(parent);
            for (String left : tempBinaryBetas.get(parent).firstKeySet()) {
                int leftStates = this.getStateSplitCount(left);
                for (String right : tempBinaryBetas.get(parent).get(left).keySet()) {
                    int rightStates = this.getStateSplitCount(right);
                    double[][][] betas = tempBinaryBetas.get(parent, left, right);
                    for (int p = 0; p < parentStates; ++p) {
                        if (stateTotal[p] < 1.0E-4) {
                            continue;
                        }
                        int parentIndex = this.stateIndex.indexOf(this.state(parent, p));
                        double logNormalizer = Math.log(stateTotal[p]);
                        for (int l = 0; l < leftStates; ++l) {
                            int leftIndex = this.stateIndex.indexOf(this.state(left, l));
                            for (int r = 0; r < rightStates; ++r) {
                                int rightIndex = this.stateIndex.indexOf(this.state(right, r));
                                bg.addRule(new BinaryRule(parentIndex, leftIndex, rightIndex, betas[p][l][r] - logNormalizer));
                            }
                        }
                    }
                }
            }
        }

        UnaryGrammar ug = new UnaryGrammar(this.stateIndex);
        for (String parent : tempUnaryBetas.firstKeySet()) {
            int parentStates = this.getStateSplitCount(parent);
            double[] stateTotal = totalStateMass.get(parent);
            for (String child : tempUnaryBetas.get(parent).keySet()) {
                int childStates = this.getStateSplitCount(child);
                double[][] betas = tempUnaryBetas.get(parent, child);
                for (int p = 0; p < parentStates; ++p) {
                    if (stateTotal[p] < 1.0E-4) {
                        continue;
                    }
                    int parentIndex = this.stateIndex.indexOf(this.state(parent, p));
                    double logNormalizer = Math.log(stateTotal[p]);
                    for (int c = 0; c < childStates; ++c) {
                        int childIndex = this.stateIndex.indexOf(this.state(child, c));
                        ug.addRule(new UnaryRule(parentIndex, childIndex, betas[p][c] - logNormalizer));
                    }
                }
            }
        }
        this.bgug = new Pair<UnaryGrammar, BinaryGrammar>(ug, bg);
    }

    /**
     * Replaces the stored training trees with {@code trees1} (each weighted
     * {@code weight1}), followed by {@code trees2} (each weighted
     * {@code weight2}). The second collection is used only when it is non-null
     * and {@code weight2 >= 0}. {@code trainSize} is reset to the summed weight
     * of the stored trees.
     */
    public void saveTrees(Collection<Tree> trees1, double weight1, Collection<Tree> trees2, double weight2) {
        this.trainSize = 0.0;
        this.trees.clear();
        this.treeWeights.clear();
        this.addTreesWithWeight(trees1, weight1);
        int treeCount = trees1.size();
        boolean useSecondSet = trees2 != null && weight2 >= 0.0;
        if (useSecondSet) {
            this.addTreesWithWeight(trees2, weight2);
            treeCount += trees2.size();
        }
        System.err.println("Found " + treeCount + " trees with total weight " + this.trainSize);
    }

    /** Appends every tree in {@code source} with the given weight and adds that weight to {@code trainSize}. */
    private void addTreesWithWeight(Collection<Tree> source, double weight) {
        for (Tree tree : source) {
            this.trees.add(tree);
            this.treeWeights.incrementCount(tree, weight);
            this.trainSize += weight;
        }
    }

    /**
     * Trains on a single treebank at unit weight, with no secondary treebank.
     *
     * @param treeList the training trees
     */
    public void extract(Collection<Tree> treeList) {
        Collection<Tree> noSecondaryTrees = null;
        this.extract(treeList, 1.0, noSecondaryTrees, 0.0);
    }

    /**
     * Runs the full split/EM/merge training loop and then builds the final
     * state index and grammars.
     *
     * <p>For each of {@code op.trainOptions.splitCount} cycles, this method does
     * three things:
     * <ol>
     *   <li>splits the state counts and re-estimates the betas once in split mode;</li>
     *   <li>repeats {@code recalculateBetas(false)} until it reports convergence
     *       (bounded by {@code MAX_ITERATIONS});</li>
     *   <li>merges back the least useful splits.</li>
     * </ol>
     *
     * @param trees1 primary training trees
     * @param weight1 weight of each primary tree
     * @param trees2 optional secondary training trees (may be null)
     * @param weight2 weight of each secondary tree; negative disables them
     */
    public void extract(Collection<Tree> trees1, double weight1, Collection<Tree> trees2, double weight2) {
        this.saveTrees(trees1, weight1, trees2, weight2);
        this.countOriginalStates();
        this.initialBetasAndLexicon();
        for (int cycle = 0; cycle < this.op.trainOptions.splitCount; ++cycle) {
            this.splitStateCounts();
            this.recalculateBetas(true);
            this.iteration = 0;
            for (boolean converged = false; !converged && this.iteration < MAX_ITERATIONS; ++this.iteration) {
                if (this.DEBUG()) {
                    System.out.println();
                    System.out.println();
                    System.out.println("-------------------");
                    System.out.println("Iteration " + this.iteration);
                }
                converged = this.recalculateBetas(false);
            }
            System.err.println("Converged for cycle " + cycle + " in " + this.iteration + " iterations");
            this.mergeStates();
        }
        this.buildStateIndex();
        this.buildGrammars();
    }
}

