import React, {Component} from 'react'
import {connect} from 'react-redux'
import { Progress, Button } from 'semantic-ui-react'
import * as tf from '@tensorflow/tfjs'
import {labelEncodeData} from "./ModelScripts/labelEncodeData";
import {
    XYPlot,
    XAxis,
    MarkSeries,
    YAxis,
    FlexibleWidthXYPlot,
    VerticalBarSeries,
    Hint
} from 'react-vis';
import { getData } from "./data";

class AssessIndex extends Component {
    constructor(props) {
        super(props);

        this.state = {
            isTraining: true,
            trainingProgress: 0,
            epochs: [],
            loss: [],
            trainingAccuracy: [],
            testAccuracy: 0,
            data: {},
            model: null
        }

        // this.downloadModel = this.downloadModel.bind(this);
    }

    componentDidMount(){
        // tf.tidy(() => {
            let targetValue = [];
            let data = [];

            const dataAndClassesValues = labelEncodeData(this.props.fileInformation.slice(1, this.props.fileInformation.length));

            const labelEncodedData = dataAndClassesValues['data']
            const labelClasses = dataAndClassesValues['labelClasses']

            const [xTrain, yTrain, xTest, yTest, xTrainRaw, yTrainRaw, xTestRaw, yTestRaw] = getData((parseInt(this.props.testSplit) / 100), labelEncodedData, labelClasses);
            this.setState({
                data: {
                    xTest: xTest,
                    yTest: yTest
                }
            })

            const HIDDEN_SIZE = parseInt(this.props.mlHyperParameters.neuralNetwork.units)
            const layers = parseInt(this.props.mlHyperParameters.neuralNetwork.layers)

            let model = tf.sequential()
            model.add(tf.layers.dense(
                {units: 10, activation: 'sigmoid', inputShape: [xTrain.shape[1]]}
            ));
            model.add(tf.layers.dense(
                {units: labelClasses.length, activation: 'softmax'}
            ));

            model.compile({
                optimizer: this.props.mlHyperParameters.optimizer === "adam" ? tf.train.adam(.001) : tf.train.sgd(.001),
                loss: this.props.mlHyperParameters.loss === "mse" ? tf.losses.meanSquaredError : tf.losses.softmaxCrossEntropy,
                metrics: ['accuracy']
            })

            model.fit(xTrain, yTrain, {
                epochs: 51,
                // validationData: [xTest, yTest],
                callbacks: {
                    onEpochEnd: async (epoch, logs) => {
                        // Plot the loss and accuracy values at the end of every training epoch.
                        if (epoch % 10 === 0) {
                            this.setState({ trainingProgress: epoch * 2 })

                            // if (epoch === 50){
                            //     this.setState({ isTraining: false })
                            // }
                        }
                    },
                }
            }).then((info) => {
                console.log(info)
                const xTest = this.state.data.xTest;
                const yTest = this.state.data.yTest;

                const xData = xTest.dataSync();
                const yTrue = yTest.argMax(-1).dataSync();
                const predictOut = model.predict(xTest);
                const yPred = predictOut.argMax(-1).dataSync();
                console.log(xData)
                console.log(yTrue)
                console.log(yPred)

                let accuracy = 0.0;

                for (let i=0; i<yPred.length; i++){
                    if (yPred[i] === yTrue[i]){
                        accuracy += 1.0
                    }
                }

                const testAccuracy = accuracy / yTrue.length;

                this.setState({
                    isTraining: false,
                    epochs: info.epoch,
                    loss: info.history.loss,
                    testAccuracy: testAccuracy,
                    trainingAccuracy: info.history.acc,
                    model: model
                })

            }).catch(err => console.log(err))
    }



    getProgressText = () => {
        let text = "";

        if (this.state.isTraining){
            text = "Your model is training..." + ( this.state.trainingProgress > 0 ?
                "(" + this.state.trainingProgress.toString() + "%)" : "" )
        }

        return text
    }

    changeStep = (step) => {
        this.props.changeStep(step)
    }

    download = () => {
        this.state.model.save('downloads://my-model');
    }

    render() {
        if (this.state.isTraining){
            return (
                <div>
                    <Progress percent={this.state.trainingProgress} indicating />
                    <h6 style={{textAlign: "center"}}>{this.getProgressText()}</h6>
                </div>
            )
        } else {
            console.log(this.state)
            return (
                <div>
                    <div className="row">
                        <div className="col-sm-3">
                            <Button color="yellow"
                                    fluid
                                    icon="left arrow"
                                    content="Re-select Hyperparameters"
                                    onClick={(e) => this.changeStep(2)}/>
                        </div>
                        <div className="col-sm-1" />
                        <div className="col-sm-4">
                            <div style={{boxShadow: "0 4px 8px 0 rgba(0,0,0,0.2),0 6px 20px 0 rgba(0,0,0,0.19)",
                                textAlign: "center", padding: "15px", borderRadius: "5px"}}>
                                <h4>Test Accuracy</h4>
                                <h1>{Math.floor(this.state.testAccuracy * 100).toString()}%</h1>
                            </div>
                        </div>
                        <div className="col-sm-1" />
                        <div className="col-sm-3">
                            <Button icon="download" fluid color="green" content="Download Model" onClick={this.download} />
                        </div>
                    </div>
                    <br/><br/>
                    <div className="row">
                        <div className="col-sm-1" />
                        <div className="col-sm-5">
                            <h4>Error per Training Step</h4>
                            <FlexibleWidthXYPlot
                                xType="ordinal"
                                height={400}>
                                <XAxis
                                    title="Training Step"
                                    tickValues={[10, 20, 30, 40, 50]}
                                />
                                <YAxis title="Error" />
                                <MarkSeries
                                    data={this.state.epochs.map((x, idx) => { return { x: x, y: this.state.loss[idx] } })}
                                />
                            </FlexibleWidthXYPlot>
                        </div>
                        <div className="col-sm-5">
                            <h4>Accuracy per Training Step</h4>
                            <FlexibleWidthXYPlot
                                xType="ordinal"
                                color="orange"
                                height={400}>
                                <XAxis
                                    title="Training Step"
                                    tickValues={[10, 20, 30, 40, 50]}
                                />
                                <YAxis title="Error" />
                                <MarkSeries
                                    data={this.state.epochs.map((x, idx) => { return { x: x, y: this.state.trainingAccuracy[idx] } })}
                                />
                            </FlexibleWidthXYPlot>
                        </div>
                        <div className="col-sm-1" />
                    </div>

                </div>
            )
        }
    }
}

// Picks fileInformation, mlHyperParameters and testSplit out of state.mainState
// and passes them to AssessIndex as props of the same names.
const mapStateToProps = ({mainState: {fileInformation, mlHyperParameters, testSplit}}) => ({
    fileInformation,
    mlHyperParameters,
    testSplit
});

// No action creators are bound for this screen.
const mapActionsToProps = {};

// AssessIndex wired to the Redux store via mapStateToProps.
const ConnectedAssessIndex = connect(mapStateToProps, mapActionsToProps)(AssessIndex);

export default ConnectedAssessIndex;