chainopy.nn
Functions
| divergance_analysis(mc, nn) | KL divergence between MarkovChain.tpm and MarkovChain().fit(MarkovChainNeuralNetwork.simulate_random_walk).tpm. |
Classes
| MarkovChainNeuralNetwork(markov_chain, num_layers) | Neural network for simulating Markov chain behavior. |
- class chainopy.nn.MarkovChainNeuralNetwork(markov_chain, num_layers)
Neural network for simulating Markov chain behavior.
- Parameters:
markov_chain (chainopy.MarkovChain) – Markov chain object.
num_layers (int) – Number of layers in the neural network.
- Raises:
ValueError – If markov_chain is not of type MarkovChain.
- __init__(markov_chain, num_layers)
Initializes the network for the given markov_chain with num_layers layers (see the class parameters above).
- forward(x)
Forward pass of the neural network.
- Parameters:
x (torch.Tensor) – Input data.
- Returns:
torch.Tensor: Output data after passing through the network.
- get_weights()
Returns the weights of the model.
- Returns:
dict: Dictionary containing layer names and corresponding weights.
- simulate_random_walk(start_state, steps)
Simulates a random walk that begins at start_state and runs for steps steps, using the trained model.
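For reference, a random walk on a Markov chain repeatedly draws the next state from the row of the transition probability matrix (tpm) that belongs to the current state. The sketch below samples from a known tpm directly rather than from a trained network; random_walk_from_tpm is a hypothetical helper, not part of chainopy, and including start_state as the first element of the output is an assumption, not confirmed behavior of simulate_random_walk.

```python
import random
from typing import List, Optional, Sequence


def random_walk_from_tpm(
    tpm: Sequence[Sequence[float]],
    states: Sequence[str],
    start_state: str,
    steps: int,
    seed: Optional[int] = None,
) -> List[str]:
    """Hypothetical helper: sample a walk from a known transition matrix.

    This does not call MarkovChainNeuralNetwork; it only shows the
    sampling step that a random walk over a Markov chain performs.
    """
    rng = random.Random(seed)
    index = {state: i for i, state in enumerate(states)}
    current = index[start_state]
    walk = [start_state]  # assumption: the start state is part of the output
    for _ in range(steps):
        # Row `current` holds P(next = j | current); draw j with those weights.
        current = rng.choices(range(len(states)), weights=tpm[current], k=1)[0]
        walk.append(states[current])
    return walk


walk = random_walk_from_tpm([[0.9, 0.1], [0.5, 0.5]], ["A", "B"], "A", steps=10, seed=42)
print(walk)  # 11 states, beginning with "A"
```

A trained MarkovChainNeuralNetwork presumably replaces the explicit tpm lookup with the network's output; comparing its walks against walks like these is the idea behind divergance_analysis.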
- train_model(num_samples, epochs, learning_rate, momentum=0.9, verbose=True, patience=500, factor=0.5)
Trains the neural network model.
- Parameters:
num_samples (int) – Number of training samples.
epochs (int) – Number of training epochs.
learning_rate (float) – Learning rate for optimization.
momentum (float, optional) – Momentum factor (default is 0.9).
verbose (bool, optional) – If True, prints training progress (default is True).
patience (int, optional) – Patience parameter for the learning rate scheduler (default is 500).
factor (float, optional) – Factor by which the learning rate will be reduced (default is 0.5).
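The patience and factor parameters match the usual "reduce on plateau" rule: once the loss has gone more than patience epochs without improving, the learning rate is multiplied by factor. The sketch below assumes, without confirmation from this page, that train_model's scheduler behaves like that; plateau_schedule is a hypothetical, simplified helper (no improvement threshold or cooldown) meant only to show how the two parameters interact.

```python
from typing import List


def plateau_schedule(
    losses: List[float],
    learning_rate: float,
    patience: int = 500,
    factor: float = 0.5,
) -> List[float]:
    """Hypothetical reduce-on-plateau rule: learning rate in effect after each epoch.

    Assumed (not confirmed) to approximate train_model's scheduler.
    """
    best = float("inf")
    bad_epochs = 0
    lr = learning_rate
    history = []
    for loss in losses:
        if loss < best:
            # Any strict improvement resets the stall counter.
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Cut the learning rate once the loss has stalled for more than `patience` epochs.
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


print(plateau_schedule([1.0, 0.9, 0.9, 0.9, 0.9], learning_rate=0.1, patience=2))
```

With the defaults (patience=500, factor=0.5), the learning rate would halve only after 501 consecutive epochs without a new best loss.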
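- chainopy.nn.divergance_analysis(mc, nn)
KL divergence between MarkovChain.tpm and MarkovChain().fit(MarkovChainNeuralNetwork.simulate_random_walk).tpm, that is, between the transition probability matrix of the original chain and that of a chain fitted to a random walk simulated by the network.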
- Parameters:
mc (MarkovChain) – Original Markov chain used to fit the MarkovChainNeuralNetwork.
nn (MarkovChainNeuralNetwork) – The fitted MarkovChainNeuralNetwork.
- Return type:
float
- Returns:
KL divergence. The lower the KL divergence, the better the fit.
Notes
See KL divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>.
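The comparison divergance_analysis performs can be sketched in two steps: estimate a tpm Q from a simulated walk by counting consecutive state pairs and normalizing each row, then compare each row of the original tpm P against it with KL(P_i || Q_i) = sum_j P_ij log(P_ij / Q_ij). How chainopy aggregates the per-row values (mean, sum, or weighting by state frequency) and how it handles zero entries in Q are not stated on this page; estimate_tpm and mean_row_kl below are hypothetical helpers that assume an unweighted mean over rows and an eps floor on Q.

```python
import math
from typing import List, Sequence


def estimate_tpm(walk: Sequence[str], states: Sequence[str]) -> List[List[float]]:
    """Hypothetical helper: maximum-likelihood tpm from consecutive pairs in a walk."""
    index = {state: i for i, state in enumerate(states)}
    n = len(states)
    counts = [[0] * n for _ in range(n)]
    for current, nxt in zip(walk, walk[1:]):
        counts[index[current]][index[nxt]] += 1
    tpm = []
    for row in counts:
        total = sum(row)
        # A state never left during the walk gets an all-zero row.
        tpm.append([c / total if total else 0.0 for c in row])
    return tpm


def mean_row_kl(
    p: Sequence[Sequence[float]],
    q: Sequence[Sequence[float]],
    eps: float = 1e-10,
) -> float:
    """Hypothetical helper: unweighted mean over rows i of KL(P_i || Q_i).

    Terms with P_ij == 0 contribute nothing; Q_ij is floored at eps so a
    transition the walk never produced gives a large but finite penalty.
    """
    total = 0.0
    for p_row, q_row in zip(p, q):
        total += sum(
            pij * math.log(pij / max(qij, eps))
            for pij, qij in zip(p_row, q_row)
            if pij > 0
        )
    return total / len(p)


states = ["A", "B"]
original = [[0.5, 0.5], [1.0, 0.0]]
fitted = estimate_tpm(["A", "B", "A", "A", "B"], states)
print(fitted)  # rows estimated from the pairs A->B, B->A, A->A, A->B
print(mean_row_kl(original, fitted))
```

As the Returns entry above notes, a value near 0 means the simulated walk reproduces the original transition probabilities closely.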