Java regression in DL4J: predicting the next time step

I have a trained multi-layer network, but I'm stuck on how to predict additional time steps.

Following the character-iterator (character modelling) example, I tried to write this method:

public float[] sampleFromNetwork(INDArray testingData, int numTimeSteps, DataSetIterator iter){
    int inputCount = this.getNumOfInputs();
    int outputCount = this.getOutputCount();

    float[] samples = new float[numTimeSteps];

    //Sample from network (and feed samples back into input) one value at a time (for all samples)
    //Sampling is done in parallel here
    this.network.rnnClearPreviousState();
    INDArray output = this.network.rnnTimeStep(testingData);
    output = output.tensorAlongDimension(output.size(2)-1,1,0); //Gets the last time step output

    for( int i=0; i<numTimeSteps; ++i ){
        //Set up next input (single time step) by sampling from previous output
        INDArray nextInput = Nd4j.zeros(1,inputCount);

        //Output is a probability distribution. Sample from this for each example we want to generate, and add it to the new input
        double[] outputProbDistribution = new double[outputCount];
        for( int j=0; j<outputProbDistribution.length; j++ ) {
            outputProbDistribution[j] = output.getDouble(j);
        }
        int nextValue = sampleFromDistribution(outputProbDistribution, new Random());

        nextInput.putScalar(new int[]{0,nextValue}, 1.0f);      //Prepare next time step input
        samples[i] = nextValue;   //Store the sampled value for this time step
        output = this.network.rnnTimeStep(nextInput);   //Do one time step of forward pass
    }

    return samples;
}

But sampleFromDistribution() doesn't make sense here, because I'm not predicting discrete classes.

Any ideas?
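A plain-Java sketch (no DL4J dependency) of the distinction the question runs into: with a softmax output layer the network emits a probability distribution, so the character example draws a class index from it; with an identity activation the output already is the predicted value and can be used (and fed back) directly. `softmax`, `sampleFromDistribution` and `regressionPrediction` below are illustrative re-implementations, not DL4J APIs; `sampleFromDistribution` follows the general idea of the example's helper (inverse-CDF sampling), not necessarily its exact code.

```java
import java.util.Random;

public class OutputReadout {
    // Softmax turns raw network outputs into a probability distribution over classes.
    public static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double sum = 0.0;
        double[] p = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            p[i] = Math.exp(z[i] - max); // subtract the max for numerical stability
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
    }

    // Classification: draw a class index with probability dist[i] (inverse-CDF sampling).
    public static int sampleFromDistribution(double[] dist, Random rng) {
        double u = rng.nextDouble(); // uniform in [0, 1)
        double cumulative = 0.0;
        for (int i = 0; i < dist.length; i++) {
            cumulative += dist[i];
            if (u < cumulative) return i;
        }
        return dist.length - 1; // guard against rounding leaving cumulative slightly below 1
    }

    // Regression with an identity activation and one output unit:
    // the output value is the prediction itself, nothing to sample.
    public static double regressionPrediction(double[] output) {
        return output[0];
    }

    public static void main(String[] args) {
        double[] raw = {2.0, 1.0, 0.1};
        int cls = sampleFromDistribution(softmax(raw), new Random(42));
        System.out.println("sampled class: " + cls);
        System.out.println("regression value: " + regressionPrediction(new double[]{0.73}));
    }
}
```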


1 Answer

  1. # Answer 1

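    I solved this by changing my network to use an identity activation (so the output is no longer a probability distribution) and using the resulting values directly. There is still a lot of tuning to do, but it works.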

    import java.util.ArrayList;
    import java.util.List;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    // Method of the class that owns the trained network (this.network); assumes a single output unit
    public float[] sampleFromNetwork(INDArray priori, int numTimeSteps){
        int inputCount = this.getNumOfInputs();
        float[] samples = new float[numTimeSteps];
    
        // priori has shape [miniBatch, inputCount, timeSteps]
        if(priori.size(1) != inputCount) {
            String format = String.format("the priori should have the same number of inputs [%s] as the trained network [%s]", priori.size(1), inputCount);
            throw new IllegalArgumentException(format);
        }
        if(priori.size(2) < inputCount) {
            String format = String.format("the priori should have enough timesteps [%s] to prime the new inputs [%s]", priori.size(2), inputCount);
            throw new IllegalArgumentException(format);
        }
    
        // Prime the RNN state with the known history
        this.network.rnnClearPreviousState();
        INDArray output = this.network.rnnTimeStep(priori);
    
        // One output value per time step; keep them to build later inputs
        output = output.ravel();
        List<Float> prevOutput = new ArrayList<>(); // O(1) random access for the window loop
        for (int i = 0; i < output.length(); i++) {
            prevOutput.add(output.getFloat(i));
        }
    
        for( int i=0; i<numTimeSteps; ++i ){
            samples[i] = prevOutput.get(prevOutput.size()-1);
            // Next input (single time step): the last inputCount outputs, oldest first, newest last
            float[] newInputs = new float[inputCount];
            int start = prevOutput.size() - inputCount;
            for( int j=0; j<inputCount; j++ ) {
                newInputs[j] = prevOutput.get(start + j);
            }
    
            INDArray nextInput = Nd4j.create(new float[][]{newInputs}); // shape [1, inputCount]
            output = this.network.rnnTimeStep(nextInput); // Do one time step of forward pass
            // Append the newest prediction to the end of the output history
            prevOutput.add(output.ravel().getFloat(output.length()-1));
        }
        return samples;
    }
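The window index arithmetic is the easy part to get wrong, so here is a stateless, plain-Java sketch of the same recursive (autoregressive) forecasting loop. It assumes the model maps the last `windowSize` values, oldest first, to one next value; `forecast` and the `model` function are hypothetical stand-ins for `rnnTimeStep`, which additionally carries hidden state between calls and is not modelled here.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

public class RecursiveForecast {
    /**
     * Rolls a one-step-ahead model forward numTimeSteps times, feeding each
     * prediction back in as the newest element of the input window.
     *
     * @param history    known values, oldest first (at least windowSize of them)
     * @param windowSize number of lagged values the model takes (inputCount above)
     * @param model      hypothetical stand-in for the network: window -> next value
     */
    public static float[] forecast(List<Float> history, int windowSize, int numTimeSteps,
                                   Function<float[], Float> model) {
        if (history.size() < windowSize) {
            throw new IllegalArgumentException(
                "need at least " + windowSize + " values, got " + history.size());
        }
        List<Float> values = new ArrayList<>(history);
        float[] predictions = new float[numTimeSteps];
        for (int i = 0; i < numTimeSteps; i++) {
            // Last windowSize values: oldest at index 0, newest at the end.
            float[] window = new float[windowSize];
            int start = values.size() - windowSize;
            for (int j = 0; j < windowSize; j++) {
                window[j] = values.get(start + j);
            }
            float next = model.apply(window);
            predictions[i] = next;
            values.add(next); // feed the prediction back for the next step
        }
        return predictions;
    }

    public static void main(String[] args) {
        // Toy model: linear extrapolation from the two most recent values.
        Function<float[], Float> linear =
            w -> w[w.length - 1] + (w[w.length - 1] - w[w.length - 2]);
        float[] out = forecast(Arrays.asList(1f, 2f, 3f), 2, 3, linear);
        System.out.println(Arrays.toString(out)); // prints [4.0, 5.0, 6.0]
    }
}
```

Because each prediction becomes part of the next window, errors compound over long horizons; that is inherent to recursive multi-step forecasting rather than a bug in the loop.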