/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.mining.word2vec;

import com.hankcs.hanlp.mining.word2vec.CacheCorpus;
import com.hankcs.hanlp.mining.word2vec.Config;
import com.hankcs.hanlp.mining.word2vec.Corpus;
import com.hankcs.hanlp.mining.word2vec.TextFileCorpus;
import com.hankcs.hanlp.mining.word2vec.Utils;
import com.hankcs.hanlp.mining.word2vec.VocabWord;
import com.hankcs.hanlp.utility.Predefine;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.charset.Charset;
import java.util.Comparator;

class Word2VecTraining {
    /** Number of entries the static initializer precomputes in {@link #expTable} (indices 0..999). */
    static final int EXP_TABLE_SIZE = 1000;
    /** Half-width of the input range covered by {@link #expTable}: the table spans [-6, 6). */
    static final int MAX_EXP = 6;
    /** Length of the negative-sampling lookup {@link #table} allocated by initUnigramTable. */
    static final int TABLE_SIZE = 100000000;
    /** Maximum number of word indices the training loop buffers per sentence (its buffer holds 1001 ints). */
    static final int MAX_SENTENCE_LENGTH = 1000;
    /** Character encoding used when writing the word-vector output file. */
    static final Charset ENCODING = Charset.forName("UTF-8");
    /** Wall-clock start of training (ms); set by trainModel before workers start and copied by each worker. */
    long timeStart;
    // NOTE(review): the three weight arrays below are static, i.e. shared by every
    // Word2VecTraining instance in the JVM; concurrent trainings would overwrite each other.
    /** Word vectors, laid out row-major as [wordIndex * layer1Size + dim]; written to the output file. */
    static double[] syn0;
    /** Weights used by the hierarchical-softmax branch; allocated by initNet only when it is enabled. */
    static double[] syn1;
    /** Weights used by the negative-sampling branch; allocated by initNet only when negative > 0. */
    static double[] syn1neg;
    /** Lookup table of word indices built by initUnigramTable and read when drawing negative samples. */
    int[] table;
    /** Training configuration supplied at construction time. */
    private final Config config;
    /** Precomputed values of exp(x) / (exp(x) + 1), filled in by the static initializer. */
    static final double[] expTable;
    /** Number of worker threads that have not finished yet; workers decrement it while holding this object's monitor. */
    int threadCount;

    /**
     * Creates a trainer bound to the given configuration.
     *
     * @param config training settings; stored as-is and read later by {@code trainModel}
     */
    public Word2VecTraining(final Config config) {
        this.config = config;
    }

    /**
     * Returns the configuration this trainer was constructed with.
     *
     * @return the same {@code Config} instance that was passed to the constructor
     */
    public Config getConfig() {
        return config;
    }

    /**
     * Learns the vocabulary, trains the network on {@code config.getNumThreads()} worker
     * threads and writes the resulting word vectors to {@code config.getOutputFile()}.
     *
     * <p>Output format: a header line {@code "<vocabSize> <layer1Size>"} followed by one line
     * per word consisting of the word and its {@code layer1Size} vector components, all
     * separated by single spaces and encoded with {@link #ENCODING}.
     *
     * <p>If no output file is configured the method returns right after the vocabulary has
     * been learned and logged; no training takes place.
     *
     * <p>If the calling thread is interrupted while waiting for the workers, the method keeps
     * waiting until they finish and then restores the thread's interrupt status.
     *
     * @throws IOException if reading the corpus or writing the output file fails
     */
    public void trainModel() throws IOException {
        final int layer1Size = this.config.getLayer1Size();
        final TextFileCorpus corpus = new TextFileCorpus(this.config);
        Predefine.logger.info("learning vocabulary");
        corpus.learnVocab();
        Predefine.logger.info("sorting vocabulary");
        corpus.sortVocab();
        final int vocabSize = corpus.getVocabSize();
        final VocabWord[] vocab = corpus.getVocab();
        Predefine.logger.info("Vocab size: " + vocabSize);
        Predefine.logger.info("Words in train file: " + corpus.getTrainWords());
        if (this.config.getOutputFile() == null) {
            // NOTE(review): the corpus is not closed on this path (unchanged behaviour);
            // confirm whether TextFileCorpus still holds resources at this point.
            return;
        }
        this.initNet(corpus);
        if (this.config.getNegative() > 0) {
            this.initUnigramTable(corpus);
        }
        this.timeStart = System.currentTimeMillis();
        // The progress counter is static and would otherwise carry over from a previous
        // training run in the same JVM, starting this run with a decayed learning rate.
        TrainModelThread.wordCountActual = 0;
        // Read the thread count once so the counter and the number of started workers agree.
        final int numThreads = this.config.getNumThreads();
        this.threadCount = numThreads;
        for (int t = 0; t < numThreads; ++t) {
            new TrainModelThread(this, new CacheCorpus(corpus), this.config, t).start();
        }
        corpus.shutdown();
        // Wait until every worker has decremented threadCount. An interrupt must not abort
        // the wait (the workers are still mutating the shared weights), so it is remembered
        // and re-asserted afterwards instead of being dropped.
        boolean interrupted = false;
        synchronized (this) {
            while (this.threadCount > 0) {
                try {
                    this.wait();
                }
                catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        System.err.println();
        Predefine.logger.info(String.format("finished training in %s", Utils.humanTime(System.currentTimeMillis() - this.timeStart)));
        // Drop references that are no longer needed before writing the output.
        syn1 = null;
        this.table = null;
        FileOutputStream os = null;
        OutputStreamWriter w = null;
        PrintWriter pw = null;
        try {
            os = new FileOutputStream(this.config.getOutputFile());
            w = new OutputStreamWriter(os, ENCODING);
            pw = new PrintWriter(w);
            Predefine.logger.info("now saving the word vectors to the file " + this.config.getOutputFile());
            // Locale.ROOT keeps digits and the decimal separator ('.') independent of the
            // JVM's default locale, so the file looks the same on every machine.
            pw.printf(java.util.Locale.ROOT, "%d %d\n", vocabSize, layer1Size);
            for (int wordIndex = 0; wordIndex < vocabSize; ++wordIndex) {
                pw.print(vocab[wordIndex].word);
                final int rowStart = wordIndex * layer1Size;
                for (int dim = 0; dim < layer1Size; ++dim) {
                    pw.printf(java.util.Locale.ROOT, " %f", syn0[rowStart + dim]);
                }
                pw.println();
            }
            // PrintWriter never throws on write failures; checkError() flushes and reports them.
            if (pw.checkError()) {
                throw new IOException("failed to write word vectors to " + this.config.getOutputFile());
            }
        }
        finally {
            // Close the output first so that a failing corpus.close() cannot leak the streams.
            Utils.closeQuietly(pw);
            Utils.closeQuietly(w);
            Utils.closeQuietly(os);
            corpus.close();
        }
    }

    /**
     * Builds the negative-sampling lookup {@link #table}.
     *
     * <p>The table has {@link #TABLE_SIZE} slots. Word {@code i} occupies a contiguous run of
     * slots whose length is proportional to {@code cn(i)^0.75 / sum_k cn(k)^0.75}, so picking a
     * uniformly random slot draws a word from the smoothed unigram distribution.
     *
     * <p>With an empty vocabulary the table is allocated but left zero-filled.
     *
     * @param corpus source of the vocabulary array and its size
     */
    void initUnigramTable(Corpus corpus) {
        final int vocabSize = corpus.getVocabSize();
        final VocabWord[] vocab = corpus.getVocab();
        final double power = 0.75;
        this.table = new int[TABLE_SIZE];
        if (vocabSize <= 0) {
            return;
        }
        // Accumulate in double: truncating to long after every addition would under-estimate
        // the normaliser and skew the distribution.
        double trainWordsPow = 0.0;
        for (int k = 0; k < vocabSize; ++k) {
            trainWordsPow += Math.pow(vocab[k].cn, power);
        }
        int wordIndex = 0;
        // Cumulative probability mass of words 0..wordIndex.
        double cumulative = Math.pow(vocab[wordIndex].cn, power) / trainWordsPow;
        for (int slot = 0; slot < TABLE_SIZE; ++slot) {
            this.table[slot] = wordIndex;
            // Advance to the next word only once this slot lies past the current word's share.
            // The bound keeps wordIndex on the last word (rounding can push the slot position
            // past the final cumulative value) and avoids reading vocab[vocabSize].
            if ((double)slot / (double)TABLE_SIZE > cumulative && wordIndex < vocabSize - 1) {
                ++wordIndex;
                cumulative += Math.pow(vocab[wordIndex].cn, power) / trainWordsPow;
            }
        }
    }

    /**
     * Allocates and initialises the network weights, then asks the corpus to build its
     * binary tree.
     *
     * <ul>
     *   <li>{@code syn0} is always allocated and filled with pseudo-random values in
     *       {@code [-0.5, 0.5) / layer1Size}, generated from the fixed seed 1.</li>
     *   <li>{@code syn1} is allocated (all zeros) only when hierarchical softmax is enabled.</li>
     *   <li>{@code syn1neg} is allocated (all zeros) only when {@code negative > 0}.</li>
     * </ul>
     *
     * @param corpus supplies the vocabulary size and builds the binary tree
     * @throws IllegalArgumentException if {@code vocabSize * layer1Size} does not fit in an
     *         {@code int} and therefore cannot be used as an array length
     */
    void initNet(Corpus corpus) {
        final int layer1Size = this.config.getLayer1Size();
        final int vocabSize = corpus.getVocabSize();
        // Compute the element count in long: an int product could wrap around silently and
        // yield an array that is too small for the indexing done during training.
        final long cells = (long)vocabSize * (long)layer1Size;
        if (cells > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("vocabulary size " + vocabSize + " x layer size " + layer1Size
                    + " exceeds the maximum array length");
        }
        final int weightCount = (int)cells;
        syn0 = Word2VecTraining.posixMemAlign128(weightCount);
        // Freshly allocated Java arrays are already zero-filled, so syn1 / syn1neg need no
        // explicit clearing pass.
        if (this.config.useHierarchicalSoftmax()) {
            syn1 = Word2VecTraining.posixMemAlign128(weightCount);
        }
        if (this.config.getNegative() > 0) {
            syn1neg = Word2VecTraining.posixMemAlign128(weightCount);
        }
        // Row-major walk over [word][dim]; the order of random draws matches the nested loops
        // (word outer, dim inner).
        long seed = 1L;
        for (int k = 0; k < weightCount; ++k) {
            seed = Word2VecTraining.nextRandom(seed);
            syn0[k] = ((double)(seed & 0xFFFFL) / 65536.0 - 0.5) / (double)layer1Size;
        }
        corpus.createBinaryTree();
    }

    /**
     * Allocates a zero-filled {@code double} array whose length is {@code size} rounded up to
     * the next multiple of 128 (a {@code size} that is already a multiple of 128 is kept).
     *
     * @param size minimum number of elements required
     * @return a new array with at least {@code size} elements
     */
    static double[] posixMemAlign128(int size) {
        final int remainder = size % 128;
        // size - remainder is the largest multiple of 128 not above size; add one more chunk
        // only when size is not already aligned.
        final int length = remainder > 0 ? size - remainder + 128 : size;
        return new double[length];
    }

    /**
     * Advances the linear congruential generator used throughout training:
     * {@code next = current * 25214903917 + 11}, with ordinary 64-bit wrap-around.
     *
     * @param nextRandom current generator state
     * @return the following generator state
     */
    static long nextRandom(long nextRandom) {
        final long multiplier = 25214903917L;
        final long increment = 11L;
        return multiplier * nextRandom + increment;
    }

    static {
        expTable = new double[1001];
        for (int i = 0; i < 1000; ++i) {
            Word2VecTraining.expTable[i] = Math.exp(((double)i / 1000.0 * 2.0 - 1.0) * 6.0);
            Word2VecTraining.expTable[i] = expTable[i] / (expTable[i] + 1.0);
        }
    }

    /**
     * Orders {@code VocabWord}s by their count {@code cn} in descending order, so the word
     * with the largest count sorts first.
     */
    static class VocabWordComparator
    implements Comparator<VocabWord> {
        VocabWordComparator() {
        }

        /**
         * @return a negative value if {@code o1} has the larger count, a positive value if
         *         {@code o2} has the larger count, and 0 if both counts are equal
         */
        @Override
        public int compare(VocabWord o1, VocabWord o2) {
            // Explicit comparison instead of "o2.cn - o1.cn": a subtraction can overflow and
            // flip the sign when the operands have opposite signs and large magnitude.
            return o1.cn < o2.cn ? 1 : (o1.cn > o2.cn ? -1 : 0);
        }
    }

    /**
     * Worker thread that trains the shared weight arrays ({@code syn0}, {@code syn1},
     * {@code syn1neg}) on its share of the corpus.
     *
     * <p>The weights are updated without any locking. When the thread ends, whether normally
     * or because of an exception, it decrements {@code vec.threadCount} and notifies the
     * thread waiting on {@code vec}.
     */
    static class TrainModelThread
    extends Thread {
        /** Owning trainer; its monitor guards {@code threadCount}. */
        final Word2VecTraining vec;
        /** Corpus view this worker reads word indices from. */
        final Corpus corpus;
        final Config config;
        /** Current learning rate; decays from {@link #startingAlpha} as training progresses. */
        float alpha;
        final float startingAlpha;
        /** Value of {@code corpus.getTrainWords()} captured at construction time. */
        final float trainWords;
        /** Worker index; also used as the seed of this worker's random sequence. */
        final int id;
        final int vocabSize;
        final long timeStart;
        /** Negative-sampling lookup table (only read when {@code negative > 0}). */
        final int[] table;
        final VocabWord[] vocab;
        // NOTE(review): shared by all workers and updated without synchronisation, so the
        // count (and the learning rate / progress derived from it) is only approximate.
        static int wordCountActual = 0;

        public TrainModelThread(Word2VecTraining vec, Corpus corpus, Config config, int id) {
            this.vec = vec;
            this.corpus = corpus;
            this.config = config;
            this.startingAlpha = this.alpha = config.getAlpha();
            this.id = id;
            this.table = vec.table;
            this.trainWords = corpus.getTrainWords();
            this.timeStart = vec.timeStart;
            this.vocabSize = corpus.getVocabSize();
            this.vocab = corpus.getVocab();
        }

        /**
         * Runs the training loop and always signals completion to the owning trainer.
         *
         * @throws RuntimeException wrapping any {@link IOException} raised while reading the corpus
         */
        @Override
        public void run() {
            try {
                this.train();
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
            finally {
                // Must run even when training fails; otherwise the thread blocked in
                // trainModel() would wait for this worker forever.
                synchronized (this.vec) {
                    --this.vec.threadCount;
                    this.vec.notify();
                }
            }
        }

        /**
         * The actual training loop: repeatedly reads a sentence, then performs one CBOW or
         * skip-gram update per sentence position, for {@code config.getIter()} passes.
         */
        private void train() throws IOException {
            final float iter = this.config.getIter();
            final int layer1Size = this.config.getLayer1Size();
            final int numThreads = this.config.getNumThreads();
            final int window = this.config.getWindow();
            final int negative = this.config.getNegative();
            final boolean cbow = this.config.useContinuousBagOfWords();
            final boolean hs = this.config.useHierarchicalSoftmax();
            final float sample = this.config.getSample();

            int word = 0;
            int sentenceLength = 0;
            int sentencePosition = 0;
            final int[] sen = new int[MAX_SENTENCE_LENGTH + 1];
            long wordCount = 0L;
            long lastWordCount = 0L;
            long localIter = (int)iter;
            long rnd = this.id;
            final double[] neu1 = new double[layer1Size];
            final double[] neu1e = new double[layer1Size];

            this.corpus.rewind(numThreads, this.id);
            while (true) {
                // --- progress report and learning-rate decay, roughly every 10000 words ---
                if (wordCount - lastWordCount > 10000L) {
                    wordCountActual = (int)((long)wordCountActual + (wordCount - lastWordCount));
                    lastWordCount = wordCount;
                    long timeNow = System.currentTimeMillis();
                    float percent = (float)wordCountActual / (iter * this.trainWords + 1.0f);
                    long costTime = timeNow - this.timeStart + 1L;
                    if (this.config.getCallback() == null) {
                        // %c with 13 emits a carriage return so the status line is overwritten in place.
                        System.err.printf("%cAlpha: %f  iter: %d  Progress: %.2f%%  Words/thread/sec: %.2fk", 13, Float.valueOf(this.alpha), localIter, Float.valueOf(percent * 100.0f), Float.valueOf((float)wordCountActual / (float)costTime));
                        String etd = Utils.humanTime((long)((float)costTime / percent * (1.0f - percent)));
                        if (etd.length() > 0) {
                            System.err.printf("  ETD: %s", etd);
                        }
                        System.err.flush();
                    } else {
                        this.config.getCallback().training(this.alpha, percent * 100.0f);
                    }
                    this.alpha = this.startingAlpha * (1.0f - (float)wordCountActual / (iter * this.trainWords + 1.0f));
                    // Never let the learning rate fall below 0.01% of its starting value.
                    if ((double)this.alpha < (double)this.startingAlpha * 1.0E-4) {
                        this.alpha = this.startingAlpha * 1.0E-4f;
                    }
                }

                // --- read the next sentence into sen[] ---
                // readWordIndex() return values as handled here: -2 stops reading and triggers
                // the end-of-pass branch below; -1 is skipped without being counted; -3 is
                // counted and terminates the current sentence. Presumably these mean
                // end-of-data / unknown word / end-of-sentence - confirm against Corpus.
                if (sentenceLength == 0) {
                    while ((word = this.corpus.readWordIndex()) != -2) {
                        if (word == -1) {
                            continue;
                        }
                        ++wordCount;
                        if (word == -3) {
                            break;
                        }
                        if (sample > 0.0f) {
                            // Sub-sampling: the keep probability shrinks as the word's count grows.
                            double ran = (Math.sqrt((float)this.vocab[word].cn / (sample * this.trainWords)) + 1.0) * (double)(sample * this.trainWords) / (double)this.vocab[word].cn;
                            rnd = Word2VecTraining.nextRandom(rnd);
                            if (ran < (double)(rnd & 0xFFFFL) / 65536.0) {
                                continue;
                            }
                        }
                        sen[sentenceLength] = word;
                        // Stop at the buffer limit; without this the next word would be
                        // written past the end of sen[].
                        if (++sentenceLength >= MAX_SENTENCE_LENGTH) {
                            break;
                        }
                    }
                    sentencePosition = 0;
                }

                // --- end of this worker's share of the corpus for the current pass ---
                if (word == -2 || (float)wordCount > this.trainWords / (float)numThreads) {
                    wordCountActual = (int)((long)wordCountActual + (wordCount - lastWordCount));
                    if (--localIter == 0L) {
                        break;
                    }
                    wordCount = 0L;
                    lastWordCount = 0L;
                    sentenceLength = 0;
                    this.corpus.rewind(numThreads, this.id);
                    continue;
                }

                word = sen[sentencePosition];
                if (word == -1) {
                    continue;
                }
                for (int c = 0; c < layer1Size; ++c) {
                    neu1[c] = 0.0;
                }
                for (int c = 0; c < layer1Size; ++c) {
                    neu1e[c] = 0.0;
                }

                // Random shrink of the context window. The int cast can be negative, which
                // would make the window larger than configured, so fold it into [0, window).
                rnd = Word2VecTraining.nextRandom(rnd);
                int b = (int)rnd % window;
                if (b < 0) {
                    b += window;
                }

                if (cbow) {
                    // ---------- CBOW: average the context vectors, predict the centre word ----------
                    long cw = 0L;
                    for (int a = b; a < window * 2 + 1 - b; ++a) {
                        if (a == window) {
                            continue;
                        }
                        int pos = sentencePosition - window + a;
                        if (pos < 0 || pos >= sentenceLength) {
                            continue;
                        }
                        int lastWord = sen[pos];
                        if (lastWord == -1) {
                            continue;
                        }
                        for (int c = 0; c < layer1Size; ++c) {
                            neu1[c] += syn0[c + lastWord * layer1Size];
                        }
                        ++cw;
                    }
                    if (cw != 0L) {
                        for (int c = 0; c < layer1Size; ++c) {
                            neu1[c] /= (double)cw;
                        }
                        if (hs) {
                            for (int d = 0; d < this.vocab[word].codelen; ++d) {
                                double f = 0.0;
                                int l2 = this.vocab[word].point[d] * layer1Size;
                                for (int c = 0; c < layer1Size; ++c) {
                                    f += neu1[c] * syn1[c + l2];
                                }
                                if (f <= -6.0 || f >= 6.0) {
                                    continue;
                                }
                                // 83 == EXP_TABLE_SIZE / MAX_EXP / 2 in integer arithmetic (1000 / 6 / 2).
                                f = expTable[(int)((f + 6.0) * 83.0)];
                                double g = ((double)(1 - this.vocab[word].code[d]) - f) * (double)this.alpha;
                                for (int c = 0; c < layer1Size; ++c) {
                                    neu1e[c] += g * syn1[c + l2];
                                }
                                for (int c = 0; c < layer1Size; ++c) {
                                    syn1[c + l2] += g * neu1[c];
                                }
                            }
                        }
                        if (negative > 0) {
                            for (int d = 0; d < negative + 1; ++d) {
                                int target;
                                long label;
                                if (d == 0) {
                                    // First round is the positive example.
                                    target = word;
                                    label = 1L;
                                } else {
                                    rnd = Word2VecTraining.nextRandom(rnd);
                                    target = this.table[Math.abs((int)((rnd >> 16) % (long)TABLE_SIZE))];
                                    if (target == 0) {
                                        target = Math.abs((int)(rnd % (long)(this.vocabSize - 1) + 1L));
                                    }
                                    if (target == word) {
                                        continue;
                                    }
                                    label = 0L;
                                }
                                int l2 = target * layer1Size;
                                double f = 0.0;
                                for (int c = 0; c < layer1Size; ++c) {
                                    f += neu1[c] * syn1neg[c + l2];
                                }
                                double g;
                                if (f > 6.0) {
                                    g = (double)((float)(label - 1L) * this.alpha);
                                } else if (f < -6.0) {
                                    g = (double)((float)(label - 0L) * this.alpha);
                                } else {
                                    g = ((double)label - expTable[(int)((f + 6.0) * 83.0)]) * (double)this.alpha;
                                }
                                for (int c = 0; c < layer1Size; ++c) {
                                    neu1e[c] += g * syn1neg[c + l2];
                                }
                                for (int c = 0; c < layer1Size; ++c) {
                                    syn1neg[c + l2] += g * neu1[c];
                                }
                            }
                        }
                        // Propagate the accumulated error back to every context word vector.
                        for (int a = b; a < window * 2 + 1 - b; ++a) {
                            if (a == window) {
                                continue;
                            }
                            int pos = sentencePosition - window + a;
                            if (pos < 0 || pos >= sentenceLength) {
                                continue;
                            }
                            int lastWord = sen[pos];
                            if (lastWord == -1) {
                                continue;
                            }
                            for (int c = 0; c < layer1Size; ++c) {
                                syn0[c + lastWord * layer1Size] += neu1e[c];
                            }
                        }
                    }
                } else {
                    // ---------- skip-gram: one update per (context word, centre word) pair ----------
                    for (int a = b; a < window * 2 + 1 - b; ++a) {
                        if (a == window) {
                            continue;
                        }
                        int pos = sentencePosition - window + a;
                        if (pos < 0 || pos >= sentenceLength) {
                            continue;
                        }
                        int lastWord = sen[pos];
                        if (lastWord == -1) {
                            continue;
                        }
                        int l1 = lastWord * layer1Size;
                        for (int c = 0; c < layer1Size; ++c) {
                            neu1e[c] = 0.0;
                        }
                        if (hs) {
                            for (int d = 0; d < this.vocab[word].codelen; ++d) {
                                double f = 0.0;
                                int l2 = this.vocab[word].point[d] * layer1Size;
                                for (int c = 0; c < layer1Size; ++c) {
                                    f += syn0[c + l1] * syn1[c + l2];
                                }
                                if (f <= -6.0 || f >= 6.0) {
                                    continue;
                                }
                                f = expTable[(int)((f + 6.0) * 83.0)];
                                double g = ((double)(1 - this.vocab[word].code[d]) - f) * (double)this.alpha;
                                for (int c = 0; c < layer1Size; ++c) {
                                    neu1e[c] += g * syn1[c + l2];
                                }
                                for (int c = 0; c < layer1Size; ++c) {
                                    syn1[c + l2] += g * syn0[c + l1];
                                }
                            }
                        }
                        if (negative > 0) {
                            for (int d = 0; d < negative + 1; ++d) {
                                int target;
                                long label;
                                if (d == 0) {
                                    target = word;
                                    label = 1L;
                                } else {
                                    rnd = Word2VecTraining.nextRandom(rnd);
                                    target = this.table[Math.abs((int)((rnd >> 16) % (long)TABLE_SIZE))];
                                    if (target == 0) {
                                        target = Math.abs((int)(rnd % (long)(this.vocabSize - 1) + 1L));
                                    }
                                    if (target == word) {
                                        continue;
                                    }
                                    label = 0L;
                                }
                                int l2 = target * layer1Size;
                                double f = 0.0;
                                for (int c = 0; c < layer1Size; ++c) {
                                    f += syn0[c + l1] * syn1neg[c + l2];
                                }
                                double g;
                                if (f > 6.0) {
                                    g = (double)((float)(label - 1L) * this.alpha);
                                } else if (f < -6.0) {
                                    g = (double)((float)(label - 0L) * this.alpha);
                                } else {
                                    g = ((double)label - expTable[(int)((f + 6.0) * 83.0)]) * (double)this.alpha;
                                }
                                for (int c = 0; c < layer1Size; ++c) {
                                    neu1e[c] += g * syn1neg[c + l2];
                                }
                                for (int c = 0; c < layer1Size; ++c) {
                                    syn1neg[c + l2] += g * syn0[c + l1];
                                }
                            }
                        }
                        for (int c = 0; c < layer1Size; ++c) {
                            syn0[c + l1] += neu1e[c];
                        }
                    }
                }

                // Move to the next position; an exhausted sentence triggers a new read.
                if (++sentencePosition >= sentenceLength) {
                    sentenceLength = 0;
                }
            }
            this.corpus.shutdown();
        }
    }
}

