/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.runtime.tasks;

import java.nio.file.Path;
import java.util.HashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nonnegative;
import javax.annotation.Nullable;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.DoneFuture;
import org.apache.flink.runtime.state.InputStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.LocalRecoveryConfig;
import org.apache.flink.runtime.state.LocalSnapshotDirectoryProvider;
import org.apache.flink.runtime.state.LocalSnapshotDirectoryProviderImpl;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OutputStateHandle;
import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.TaskExecutorStateChangelogStoragesManager;
import org.apache.flink.runtime.state.TaskLocalStateStore;
import org.apache.flink.runtime.state.TaskLocalStateStoreImpl;
import org.apache.flink.runtime.state.TaskStateManager;
import org.apache.flink.runtime.state.TaskStateManagerImpl;
import org.apache.flink.runtime.state.TestTaskStateManager;
import org.apache.flink.runtime.state.changelog.StateChangelogStorage;
import org.apache.flink.runtime.state.changelog.inmemory.InMemoryStateChangelogStorage;
import org.apache.flink.runtime.taskmanager.CheckpointResponder;
import org.apache.flink.runtime.taskmanager.TestCheckpointResponder;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.streaming.runtime.tasks.AsyncCheckpointRunnable;
import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment;
import org.apache.flink.streaming.runtime.tasks.StreamTaskITCase;
import org.apache.flink.testutils.junit.utils.TempDirUtils;
import org.apache.flink.util.concurrent.Executors;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.AtomicBooleanAssert;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mockito;

class LocalStateForwardingTest {
    @TempDir
    private Path temporaryFolder;

    LocalStateForwardingTest() {
    }

    @Test
    void testReportingFromSnapshotToTaskStateManager() throws Exception {
        TestTaskStateManager taskStateManager = new TestTaskStateManager();
        StreamMockEnvironment streamMockEnvironment = new StreamMockEnvironment(new Configuration(), new Configuration(), new ExecutionConfig(), 0x100000L, new MockInputSplitProvider(), 0, (TaskStateManager)taskStateManager);
        StreamTaskITCase.NoOpStreamTask testStreamTask = new StreamTaskITCase.NoOpStreamTask((Environment)streamMockEnvironment);
        CheckpointMetaData checkpointMetaData = new CheckpointMetaData(0L, 0L);
        CheckpointMetricsBuilder checkpointMetrics = new CheckpointMetricsBuilder();
        HashMap<OperatorID, OperatorSnapshotFutures> snapshots = new HashMap<OperatorID, OperatorSnapshotFutures>(1);
        OperatorSnapshotFutures osFuture = new OperatorSnapshotFutures();
        osFuture.setKeyedStateManagedFuture(LocalStateForwardingTest.createSnapshotResult(KeyedStateHandle.class));
        osFuture.setKeyedStateRawFuture(LocalStateForwardingTest.createSnapshotResult(KeyedStateHandle.class));
        osFuture.setOperatorStateManagedFuture(LocalStateForwardingTest.createSnapshotResult(OperatorStateHandle.class));
        osFuture.setOperatorStateRawFuture(LocalStateForwardingTest.createSnapshotResult(OperatorStateHandle.class));
        osFuture.setInputChannelStateFuture(LocalStateForwardingTest.createSnapshotCollectionResult(InputStateHandle.class));
        osFuture.setResultSubpartitionStateFuture(LocalStateForwardingTest.createSnapshotCollectionResult(OutputStateHandle.class));
        OperatorID operatorID = new OperatorID();
        snapshots.put(operatorID, osFuture);
        AsyncCheckpointRunnable checkpointRunnable = new AsyncCheckpointRunnable(snapshots, checkpointMetaData, checkpointMetrics, 0L, testStreamTask.getName(), asyncCheckpointRunnable -> {}, testStreamTask.getEnvironment(), testStreamTask, false, false, () -> true);
        checkpointMetrics.setAlignmentDurationNanos(0L);
        checkpointMetrics.setBytesProcessedDuringAlignment(0L);
        checkpointRunnable.run();
        TaskStateSnapshot lastJobManagerTaskStateSnapshot = taskStateManager.getLastJobManagerTaskStateSnapshot();
        TaskStateSnapshot lastTaskManagerTaskStateSnapshot = taskStateManager.getLastTaskManagerTaskStateSnapshot();
        OperatorSubtaskState jmState = lastJobManagerTaskStateSnapshot.getSubtaskStateByOperatorID(operatorID);
        OperatorSubtaskState tmState = lastTaskManagerTaskStateSnapshot.getSubtaskStateByOperatorID(operatorID);
        LocalStateForwardingTest.performCheck(osFuture.getKeyedStateManagedFuture(), jmState.getManagedKeyedState(), tmState.getManagedKeyedState());
        LocalStateForwardingTest.performCheck(osFuture.getKeyedStateRawFuture(), jmState.getRawKeyedState(), tmState.getRawKeyedState());
        LocalStateForwardingTest.performCheck(osFuture.getOperatorStateManagedFuture(), jmState.getManagedOperatorState(), tmState.getManagedOperatorState());
        LocalStateForwardingTest.performCheck(osFuture.getOperatorStateRawFuture(), jmState.getRawOperatorState(), tmState.getRawOperatorState());
        LocalStateForwardingTest.performCollectionCheck(osFuture.getInputChannelStateFuture(), jmState.getInputChannelState(), tmState.getInputChannelState());
        LocalStateForwardingTest.performCollectionCheck(osFuture.getResultSubpartitionStateFuture(), jmState.getResultSubpartitionState(), tmState.getResultSubpartitionState());
    }

    @Test
    void testReportingFromTaskStateManagerToResponderAndTaskLocalStateStore() throws Exception {
        final JobID jobID = new JobID();
        AllocationID allocationID = new AllocationID();
        final CheckpointMetaData checkpointMetaData = new CheckpointMetaData(42L, 4711L);
        final CheckpointMetrics checkpointMetrics = new CheckpointMetrics();
        int subtaskIdx = 42;
        JobVertexID jobVertexID = new JobVertexID();
        final ExecutionAttemptID executionAttemptID = ExecutionGraphTestUtils.createExecutionAttemptId((JobVertexID)jobVertexID, (int)42, (int)0);
        TaskStateSnapshot jmSnapshot = new TaskStateSnapshot();
        final TaskStateSnapshot tmSnapshot = new TaskStateSnapshot();
        final AtomicBoolean jmReported = new AtomicBoolean(false);
        final AtomicBoolean tmReported = new AtomicBoolean(false);
        TestCheckpointResponder checkpointResponder = new TestCheckpointResponder(){

            public void acknowledgeCheckpoint(JobID lJobID, ExecutionAttemptID lExecutionAttemptID, long lCheckpointId, CheckpointMetrics lCheckpointMetrics, TaskStateSnapshot lSubtaskState) {
                Assertions.assertThat((Comparable)lJobID).isEqualTo((Object)jobID);
                Assertions.assertThat((Object)lExecutionAttemptID).isEqualTo((Object)executionAttemptID);
                Assertions.assertThat((long)lCheckpointId).isEqualTo(checkpointMetaData.getCheckpointId());
                Assertions.assertThat((Object)lCheckpointMetrics).isEqualTo((Object)checkpointMetrics);
                jmReported.set(true);
            }
        };
        Executor executor = Executors.directExecutor();
        LocalSnapshotDirectoryProviderImpl directoryProvider = new LocalSnapshotDirectoryProviderImpl(TempDirUtils.newFolder((Path)this.temporaryFolder), jobID, jobVertexID, 42);
        LocalRecoveryConfig localRecoveryConfig = LocalRecoveryConfig.backupAndRecoveryEnabled((LocalSnapshotDirectoryProvider)directoryProvider);
        TaskLocalStateStoreImpl taskLocalStateStore = new TaskLocalStateStoreImpl(jobID, allocationID, jobVertexID, 42, localRecoveryConfig, executor){

            public void storeLocalState(@Nonnegative long checkpointId, @Nullable TaskStateSnapshot localState) {
                Assertions.assertThat((Object)localState).isEqualTo((Object)tmSnapshot);
                tmReported.set(true);
            }
        };
        InMemoryStateChangelogStorage stateChangelogStorage = new InMemoryStateChangelogStorage();
        TaskStateManagerImpl taskStateManager = new TaskStateManagerImpl(jobID, executionAttemptID, (TaskLocalStateStore)taskLocalStateStore, null, (StateChangelogStorage)stateChangelogStorage, new TaskExecutorStateChangelogStoragesManager(), null, (CheckpointResponder)checkpointResponder);
        taskStateManager.reportTaskStateSnapshots(checkpointMetaData, checkpointMetrics, jmSnapshot, tmSnapshot);
        ((AtomicBooleanAssert)Assertions.assertThat((AtomicBoolean)jmReported).as("Reporting for JM state was not called.", new Object[0])).isTrue();
        ((AtomicBooleanAssert)Assertions.assertThat((AtomicBoolean)tmReported).as("Reporting for TM state was not called.", new Object[0])).isTrue();
    }

    private static <T extends StateObject> void performCheck(Future<SnapshotResult<T>> resultFuture, StateObjectCollection<T> jmState, StateObjectCollection<T> tmState) {
        SnapshotResult<T> snapshotResult;
        try {
            snapshotResult = resultFuture.get();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        Assertions.assertThat((Object)((StateObject)jmState.iterator().next())).isEqualTo((Object)snapshotResult.getJobManagerOwnedSnapshot());
        Assertions.assertThat((Object)((StateObject)tmState.iterator().next())).isEqualTo((Object)snapshotResult.getTaskLocalSnapshot());
    }

    private static <T extends StateObject> void performCollectionCheck(Future<SnapshotResult<StateObjectCollection<T>>> resultFuture, StateObjectCollection<T> jmState, StateObjectCollection<T> tmState) {
        SnapshotResult<StateObjectCollection<T>> snapshotResult;
        try {
            snapshotResult = resultFuture.get();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        Assertions.assertThat(jmState).isEqualTo((Object)snapshotResult.getJobManagerOwnedSnapshot());
        Assertions.assertThat(tmState).isEqualTo((Object)snapshotResult.getTaskLocalSnapshot());
    }

    private static <T extends StateObject> RunnableFuture<SnapshotResult<T>> createSnapshotResult(Class<T> clazz) {
        return DoneFuture.of((Object)SnapshotResult.withLocalState((StateObject)((StateObject)Mockito.mock(clazz)), (StateObject)((StateObject)Mockito.mock(clazz))));
    }

    private static <T extends StateObject> RunnableFuture<SnapshotResult<StateObjectCollection<T>>> createSnapshotCollectionResult(Class<T> clazz) {
        return DoneFuture.of((Object)SnapshotResult.withLocalState((StateObject)StateObjectCollection.singleton((StateObject)((StateObject)Mockito.mock(clazz))), (StateObject)StateObjectCollection.singleton((StateObject)((StateObject)Mockito.mock(clazz)))));
    }
}

