diff --git a/tutorials/experts/source_en/others/mixed_precision.md b/tutorials/experts/source_en/others/mixed_precision.md index c3bf768e1495a492bde8e7a9ffdf41f04b40c5f2..2c6b274c7684f50d873fb9a993680093a68fe404 100644 --- a/tutorials/experts/source_en/others/mixed_precision.md +++ b/tutorials/experts/source_en/others/mixed_precision.md @@ -10,21 +10,21 @@ Generally, when a neural network model is trained, the default data type is FP32 Floating-point data types include double-precision (FP64), single-precision (FP32), and half-precision (FP16). In a training process of a neural network model, an FP32 data type is generally used by default to indicate a network model weight and other parameters. The following is a brief introduction to floating-point data types. -According to IEEE 754, floating-point data types are classified into double-precision (FP64), single-precision (FP32), and half-precision (FP16). Each type is represented by three different bits. FP64 indicates a data type that uses 8 bytes (64 bits in total) for encoding and storage. FP32 indicates a data type that uses 4 bytes (32 bits in total) and FP16 indicates a data type that uses 2 bytes (16 bits in total). As shown in the following figure: +According to [IEEE 754](https://en.wikipedia.org/wiki/IEEE_754), floating-point data types are classified into double-precision (FP64), single-precision (FP32), and half-precision (FP16). Each type is represented by three different bits. FP64 indicates a data type that uses 8 bytes (64 bits in total) for encoding and storage. FP32 indicates a data type that uses 4 bytes (32 bits in total) and FP16 indicates a data type that uses 2 bytes (16 bits in total). As shown in the following figure: ![fp16_vs_FP32](./images/fp16_vs_fp32.png) As shown in the figure, the storage space of FP16 is half that of FP32, and the storage space of FP32 is half that of FP64. It consists of three parts: -- The leftmost bit indicates the sign bit. +- The highest bit indicates the sign bit. - The middle bits indicate exponent bits. -- The rightmost bits indicate fraction bits. +- The low bits indicate fraction bits. -FP16 is used as an example. The first sign bit sign indicates a positive or negative sign, the next five bits indicate an exponent, and the last 10 bits indicate a fraction. The formula is as follows: +FP16 is used as an example. The first sign bit sign indicates a positive or negative sign, and the next five bits indicate an exponent. All 0s and all 1s have special uses, so the binary range is 00001~11110. The last 10 bits indicate a fraction. Suppose `S` denotes the decimal value of sign bit, `E` denotes the decimal value of exponent, and `fraction` denotes the decimal value of fraction. The formula is as follows: $$x=(-1)^{S}\times2^{E-15}\times(1+\frac{fraction}{1024})$$ -Similarly, the true value of a formatted FP32 is as follows: +Similarly, suppose `M` is score value, the true value of a formatted FP32 is as follows: $$x=(-1)^{S}\times2^{E-127}\times(1.M)$$ @@ -36,11 +36,13 @@ The maximum value that can be represented by FP16 is 0 11110 1111111111, which i $$(-1)^0\times2^{30-15}\times1.1111111111 = 1.1111111111(b)\times2^15 = 1.9990234375(d)\times2^15 = 65504$$ +where `b` indicates binary value, and `d` indicates decimal value. + The minimum value that can be represented by FP16 is 0 00001 0000000000, which is calculated as follows: $$ (-1)^{1}\times2^{1-15}=2^{-14}=6.104×10^{-5}=-65504$$ -Therefore, the maximum value range of FP16 is [-65504,66504], and the precision range is $2^{-24}$. If the value is beyond this range, the value is set to 0. +Therefore, the maximum value range of FP16 is [-65504, 65504], and the precision range is $2^{-24}$. If the value is beyond this range, the value is set to 0. ## FP16 Training Issues @@ -71,25 +73,65 @@ The following figure shows the typical computation process of mixed precision in This document describes the computation process by using examples of automatic and manual mixed precision. -## MindSpore Mixed-precision +## Loss Scale -### Automatic Mixed-Precision +Loss Scale is mainly used in the process of mixed-precision training. -To use the automatic mixed-precision, you need to call the `Model` API to transfer the network to be trained and optimizer as the input. This API converts the network model operators into FP16 operators. +In the process of mixed precision training, the FP16 type is used instead of the FP32 type for data storage, so as to achieve the effect of reducing memory and improving the computing speed. However, because the FP16 type is much smaller than the range represented by the FP32 type, data underflow occurs when parameters (such as gradients) become very small during training. The Loss Scale is proposed to solve the underflow of FP16 type data. -> Due to precision problems, the `BatchNorm` operator and operators involved in loss still use FP32. +The main idea is to enlarge the loss by a certain multiple when calculating the loss. Due to the existence of the chain rule, the gradient also expands accordingly, and then the corresponding multiple is reduced when the optimizer updates the weight, thus avoiding the situation of data underflow without affecting the calculation result. -The specific implementation steps for using the `Model` interface are: +There are two ways of implementing Loss Scale in MindSpore, users can either use the functional programming writeup and manually call the `scale` and `unscale` methods of `StaticLossScaler` or `DynamicLossScaler` to scale the loss or gradient during training; or they can configure the loss or gradient based on the `Model` interface and configure the mixed precision `amp_level` and the Loss Scale method `loss_scale_manager` as `FixedLossScaleManager` or `DynamicLossScaleManager` when building the model by using `Model`. -1. Introduce the MindSpore model API `Model`. +First, let's take a look at why mixing accuracy is needed. The advantages of using FP16 to train a neural network are: -2. Define a network: This step is the same as that for defining a common network (no new configuration is required). +- **Reduce memory occupation**: The bit width of FP16 is half that of FP32, so the memory occupied by parameters such as weights is also half of the original, and the saved memory can be used to put a larger network model or use more data for training. +- **Accelerate communication efficiency**: For distributed training, especially in the process of large model training, the overhead of communication restricts the overall performance of network model training, and the less bit width of communication means that communication performance can be improved. Waiting time is reduced, and data circulation can be accelerated. +- **Higher computing effciency**: On special AI-accelerated chips such as Huawei's Ascend 910 and 310 series, or GPUs of the Titan V and Tesla V100 of the NVIDIA VOLTA architecture, the performance of performing operations using FP16 is faster than that of the FP32. -3. Create a dataset: For this step, refer to [Data Processing](https://www.mindspore.cn/tutorials/en/master/advanced/dataset.html). +But using FP16 also brings some problems, the most important of which are precision overflow and rounding error, and Loss Scale is to solve the precision overflow and proposed. -4. Use the `Model` API to encapsulate the network model, optimizer, and loss function, and set the `amp_level` parameter. For details, see [MindSpore API](https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.Model.html#mindspore.Model). In this step, MindSpore automatically selects an appropriate operator to convert FP32 to FP16. +As shown in the figure, if only FP32 training is used, the model converges better, but if mixed-precision training is used, there will be a situation where the network model cannot converge. The reason is that the value of the gradient is too small, and using the FP16 representation will cause the problem of underflow under the data, resulting in the model not converging, as shown in the gray part of the figure. Loss Scale needs to be introduced. + +![loss_scale1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/others/images/loss_scale1.png) + +The following is in the network model training stage, a layer of activation function gradient distribution, of which 68% of the network model activation parameter bit 0. Another 4% of the accuracy in the $2^{-32}, 2^{-20}$ interval, directly use FP16 to represent the data inside, which truncates the underflow data. All gradient values will become 0. + +![loss_scale2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/others/images/loss_scale2.png) + +In order to solve the problem of ladder overflowing over small data, the forward calculated Loss value is amplified, that is, the parameters of FP32 are multiplied by a factor coefficient, and the possible overflowing decimal data is moved forward and panned to the data range that FP16 can represent. According to the chain differentiation law, amplifying the Loss acts on each gradient of backpropagation, which is more efficient than amplifying on each gradient. + +![loss_scale3](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/others/images/loss_scale3.png) + +Loss amplification needs to be achieved in combination with mixing accuracy, and its main main ideas are: + +- **Scale up stage**: After the network model forward calculation, the resulting loss change value DLoss is increased by a factor of $2^K$ before the repercussion propagation. +- **Scale down stage**: After backpropagation, the weight gradient is reduced by $2^K$, and the FP32 value is restored for storage. + +**Dynamic Loss Scale**: The loss scale mentioned above is to use a default value to scale the loss value, in order to make full use of the dynamic range of FP16, you can better mitigate the rounding error, and try to use a relatively large magnification. To summarize the dynamic loss scaling algorithm, it is to reduce the loss scale whenever the gradient overflows, and intermittently try to increase the loss scale, so as to achieve the use of the highest loss scale factor without causing overflow, and better restore accuracy. + +The dynamic loss scale algorithm is as follows: + +1. The algorithm of dynamic loss scaling starts with a relatively high scaling factor (such as $2^{24}$), then starts training and checks whether the number overflows in the iteration (Infs/Nans); +2. If there is no gradient overflow, the scale factor is not adjusted and the iteration continues; if the gradient overflow is detected, the scale factor is halved and the gradient update is reconfirmed until the parameter does not appear in the overflow range; +3. In the later stages of training, the loss has become stable and convergent, and the amplitude of the gradient update is often small, which can allow a higher loss scaling factor to prevent data underflow again. +4. Therefore, the dynamic loss scaling algorithm attempts to increase the loss scaling by the F multiple every N (N=2000) iterations, and then performs step 2 to check for overflow. + +## Using Mixed Precision and Loss Scale in MindSpore + +MindSpore provides two ways of using mixed precision and loss scale. -The following is a basic code example. First, import the required libraries and declarations, and define the LeNet-5 network model. +- Use functional programming: use `auto_mixed_precision` for automatic mixing accuracy, `all_finite` for overflow judgments, and `StaticLossScaler` and `DynamicLossScaler` for manual scaling of gradients and losses. + +- Using the training interface `Model`: configure the input `amp_level` to set the execution policy for mixed precision and the input `loss_scale_manager` to `FixedLossScaleManager` or `DynamicLossScaleManager` to implement loss scaling. + +## Using a Functional Programming for Mixed Precision and Loss Scale + +MindSpore provides a functional interface for mixed precision scenarios. Users can use `auto_mixed_precision` for automatic mixed precision, `all_finite` for overflow judgments during training, and `StaticLossScaler` and `DynamicLossScaler` to manually perform gradient and loss scaling. + +Common uses of LossScaler under functional. + +First import the relevant libraries and define a LeNet5 network: ```python import numpy as np @@ -99,8 +141,6 @@ import mindspore as ms from mindspore.common.initializer import Normal from mindspore import dataset as ds -ms.set_context(mode=ms.GRAPH_MODE) -ms.set_context(device_target="CPU") class LeNet5(nn.Cell): """ @@ -112,9 +152,8 @@ class LeNet5(nn.Cell): Returns: Tensor, output tensor - - """ + def __init__(self, num_class=10, num_channel=1): super(LeNet5, self).__init__() self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') @@ -136,6 +175,109 @@ class LeNet5(nn.Cell): return x ``` +Perform auto mixed precision on the network. + +`auto_mixed_precision` implements the meanings of automatic mixed precision configuration as follows: + +- 'O0': keep FP32. +- 'O1': cast as FP16 by whitelist. +- 'O2': keep FP32 by blacklist and the rest cast as FP16. +- 'O3': fully cast to FP16. + +> The current black and white list is Cell granularity. + +```python +from mindspore import amp +from mindspore import ops + +net = LeNet5(10) +amp.auto_mixed_precision(net, 'O1') +``` + +Instantiate the LossScaler and manually scale up the loss value when defining the forward network. + +```python +loss_fn = nn.BCELoss(reduction='mean') +opt = nn.Adam(generator.trainable_params(), learning_rate=0.01) + +# Define LossScaler +loss_scaler = amp.DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=50) + +def net_forward(data, label): + out = net(data) + loss_value = loss_fn(out, label) + # scale up the loss value + scaled_loss = loss_scaler.scale(loss_value) + return scaled_loss, out +``` + +Reverse acquisition of gradients. + +```python +grad_fn = ops.value_and_grad(net_forward, None, net.trainable_params()) +``` + +Define the training step: calculate the current gradient value and recover the loss. Use `all_finite` to determine whether there is a gradient underflow problem, if there is no overflow, recover the gradient and update the network weights; if there is overflow, skip this step. + +```python +@ms_function +def train_step(x, y): + (loss_value, _), grads = grad_fn(x, y) + loss_value = loss_scaler.unscale(loss_value) + + is_finite = amp.all_finite(grads) + if is_finite: + grads = loss_scaler.unscale(grads) + loss_value = ops.depend(loss_value, opt(grads)) + loss_scaler.adjust(is_finite) + return loss_value +``` + +Execute training. + +```python +epochs = 5 +for epoch in range(epochs): + for data, label in datasets: + loss = train_step(data, label) +``` + +## Mixed-precision and Loss Scale by Using the Training Interface `Model` + +### Mixed-Precision + +The `Model` interface provides the input `amp_level` to achieve automatic mixed precision, or the user can set the operator involved in the Cell to FP16 via `to_float(ms.float16)` to achieve manual mixed precision. + +#### Automatic Mixed-Precision + +To use the automatic mixed-precision, you need to call the `Model` API to transfer the network to be trained and optimizer as the input. This API converts the network model operators into FP16 operators. + +> Due to precision problems, the `BatchNorm` operator and operators involved in loss still use FP32. + +The specific implementation steps for using the `Model` interface are: + +1. Introduce the MindSpore model API `Model`. + +2. Define a network: This step is the same as that for defining a common network (no new configuration is required). + +3. Create a dataset: For this step, refer to [Data Processing](https://www.mindspore.cn/tutorials/en/master/advanced/dataset.html). + +4. Use the `Model` API to encapsulate the network model, optimizer, and loss function, and set the `amp_level` parameter. For details, see [MindSpore API](https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.Model.html#mindspore.Model). In this step, MindSpore automatically selects an appropriate operator to convert FP32 to FP16. + +The following is a basic code example. First, import the required libraries and declarations. + +```python +import numpy as np +import mindspore.nn as nn +from mindspore.nn import Accuracy +import mindspore as ms +from mindspore.common.initializer import Normal +from mindspore import dataset as ds + +ms.set_context(mode=ms.GRAPH_MODE) +ms.set_context(device_target="CPU") +``` + Create a virtual random dataset for data input of the sample model. ```python @@ -158,7 +300,7 @@ def create_dataset(num_data=1024, batch_size=32, repeat_size=1): return input_data ``` -Set the `amp_level` parameter and use the `Model` API to encapsulate the network model, optimizer, and loss function. +Taking the LeNet5 as an example, set the `amp_level` parameter and use the `Model` API to encapsulate the network model, optimizer, and loss function. ```python ds_train = create_dataset() @@ -168,18 +310,19 @@ network = LeNet5(10) # Define Loss and Optimizer net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean") -net_opt = nn.Momentum(network.trainable_params(),learning_rate=0.01, momentum=0.9) -model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O3", loss_scale_manager=None) +net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9) +# Set amp level +model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O3") # Run training model.train(epoch=10, train_dataset=ds_train) ``` -## Manual Mixed-Precision +#### Manual Mixed-Precision MindSpore also supports manual mixed-precision. (Manual mixed-precision is not recommended unless you want to customize special networks and features.) -Assume that only one dense layer on the network uses FP16 for computation and other layers use FP32. +Assume that only one Conv layer on the network uses FP16 for computation and other layers use FP32. > The mixed-precision is configured in the unit of Cell. The default type of a Cell is FP32. @@ -204,252 +347,78 @@ import mindspore.ops as ops ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU") ``` -The network is defined in the same way regardless of whether FP32 or FP16 is used. The difference is that after the network is defined, the dense layer is declared to use FP16 for computing when the network model is initialized, that is, `net.dense.to_float(mstype.float16)`. +After initializing the network model, declare that the Conv1 layer in LeNet5 is computed by using FP16, i.e. `network.conv1.to_float(mstype.float16)`. ```python -class LeNet5(nn.Cell): - """ - Lenet network - - Args: - num_class (int): Number of classes. Default: 10. - num_channel (int): Number of channels. Default: 1. - - Returns: - Tensor, output tensor - """ - - def __init__(self, num_class=10, num_channel=1): - super(LeNet5, self).__init__() - self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') - self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') - self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) - self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) - self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.flatten = nn.Flatten() - self.cast = ops.Cast() - - def construct(self, x): - x = self.conv1(x) - x = self.cast(x, mstype.float32) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.max_pool2d(self.relu(self.conv2(x))) - x = self.flatten(x) - x = self.relu(self.fc1(x)) - x = self.relu(self.fc2(x)) - x = self.fc3(x) - return x - - - -# create dataset -def get_data(num, img_size=(1, 32, 32), num_classes=10, is_onehot=True): - for _ in range(num): - img = np.random.randn(*img_size) - target = np.random.randint(0, num_classes) - target_ret = np.array([target]).astype(np.float32) - if is_onehot: - target_onehot = np.zeros(shape=(num_classes,)) - target_onehot[target] = 1 - target_ret = target_onehot.astype(np.float32) - yield img.astype(np.float32), target_ret - -def create_dataset(num_data=1024, batch_size=32, repeat_size=1): - input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label']) - input_data = input_data.batch(batch_size, drop_remainder=True) - input_data = input_data.repeat(repeat_size) - return input_data - - ds_train = create_dataset() network = LeNet5(10) net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean") -net_opt = nn.Momentum(network.trainable_params(),learning_rate=0.01, momentum=0.9) -network.conv1.to_float(mstype.float16) +net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9) +network.conv1.to_float(ms.float16) model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2") - model.train(epoch=2, train_dataset=ds_train) ``` -> Constraint: When mixed-precision is used, the backward network can be generated only by the automatic differential function. Otherwise, MindSpore may generate exception information indicating that the data format does not match. - -## Loss Scale - -Loss Scale is mainly used in the process of mixed-precision training. - -In the process of mixed precision training, the FP16 type is used instead of the FP32 type for data storage, so as to achieve the effect of reducing memory and improving the computing speed. However, because the FP16 type is much smaller than the range represented by the FP32 type, data underflow occurs when parameters (such as gradients) become very small during training. The Loss Scale is proposed to solve the underflow of FP16 type data. - -The main idea is to enlarge the loss by a certain multiple when calculating the loss. Due to the existence of the chain rule, the gradient also expands accordingly, and then the corresponding multiple is reduced when the optimizer updates the weight, thus avoiding the situation of data underflow without affecting the calculation result. - -Two ways to scale are available in MindSpore, namely `FixedLossScaleManager` and `DynamicLossScaleManager`, which need to be used with the Model. When building models by using the Model, the mixed-precision strategy `amp_level` and the Loss Scale approach `loss_scale_manager` can be configured. - -First, let's take a look at why mixing accuracy is needed. The advantages of using FP16 to train a neural network are: +> When mixed-precision is used, the backward network can be generated only by the automatic differential function, not by user-defined inverse networks. Otherwise, MindSpore may generate exception information indicating that the data format does not match. -- **Reduce memory occupation**: The bit width of FP16 is half that of FP32, so the memory occupied by parameters such as weights is also half of the original, and the saved memory can be used to put a larger network model or use more data for training. -- **Accelerate communication efficiency**: For distributed training, especially in the process of large model training, the overhead of communication restricts the overall performance of network model training, and the less bit width of communication means that communication performance can be improved. Waiting time is reduced, and data circulation can be accelerated. -- **Higher computing effciency**: On special AI-accelerated chips such as Huawei's Ascend 910 and 310 series, or GPUs of the Titan V and Tesla V100 of the NVIDIA VOLTA architecture, the performance of performing operations using FP16 is faster than that of the FP32. +### Loss scale -But using FP16 also brings some problems, the most important of which are precision overflow and rounding error, and Loss Scale is to solve the precision overflow and proposed. +The following two APIs in MindSpore that use the loss scaling algorithm are described separately [FixedLossScaleManager](https://www.mindspore.cn/docs/en/master/api_python/amp/mindspore.amp.FixedLossScaleManager.html#mindspore.amp.FixedLossScaleManager) and [DynamicLossScaleManager](https://www.mindspore.cn/docs/en/master/api_python/amp/mindspore.amp.DynamicLossScaleManager.html#mindspore.amp.DynamicLossScaleManager). -As shown in the figure, if only FP32 training is used, the model converges better, but if mixed-precision training is used, there will be a situation where the network model cannot converge. The reason is that the value of the gradient is too small, and using the FP16 representation will cause the problem of underflow under the data, resulting in the model not converging, as shown in the gray part of the figure. Loss Scale needs to be introduced. +#### FixedLossScaleManager -![loss_scale1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/others/images/loss_scale1.png) - -The following is in the network model training stage, a layer of activation function gradient distribution, of which 68% of the network model activation parameter bit 0. Another 4% of the accuracy in the $2^{-32}, 2^{-20}$ interval, directly use FP16 to represent the data inside, which truncates the underflow data. All gradient values will become 0. - -![loss_scale2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/others/images/loss_scale2.png) - -In order to solve the problem of ladder overflowing over small data, the forward calculated Loss value is amplified, that is, the parameters of FP32 are multiplied by a factor coefficient, and the possible overflowing decimal data is moved forward and panned to the data range that FP16 can represent. According to the chain differentiation law, amplifying the Loss acts on each gradient of backpropagation, which is more efficient than amplifying on each gradient. - -![loss_scale3](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/experts/source_zh_cn/others/images/loss_scale3.png) +`FixedLossScaleManager` does not change the size of the scale when scaling, and the value of the scale is controlled by the input parameter loss_scale, which can be specified by the user. The default value is taken if it is not specified. -Loss amplification needs to be achieved in combination with mixing accuracy, and its main main ideas are: +Another parameter of `FixedLossScaleManager` is `drop_overflow_update`, which controls whether parameters are updated in the event of an overflow. -- **Scale up stage**: After the network model forward calculation, the resulting loss change value DLoss is increased by a factor of $2^K$ before the repercussion propagation. -- **Scale down stage**: After backpropagation, the weight gradient is reduced by $2^K$, and the FP32 value is restored for storage. +In general, the LossScale function does not need to be used with the optimizer, but when using `FixedLossScaleManager`, if `drop_overflow_update` is False, the optimizer needs to set the value of `loss_scale` and the value of `loss_scale` should be the same as that of `FixedLossScaleManager`. -**Dynamic Loss Scale**: The loss scale mentioned above is to use a default value to scale the loss value, in order to make full use of the dynamic range of FP16, you can better mitigate the rounding error, and try to use a relatively large magnification. To summarize the dynamic loss scaling algorithm, it is to reduce the loss scale whenever the gradient overflows, and intermittently try to increase the loss scale, so as to achieve the use of the highest loss scale factor without causing overflow, and better restore accuracy. +The detailed use of `FixedLossScaleManager` is as follows: -The dynamic loss scale algorithm is as follows: +Import the necessary libraries and declare execution using graph mode. -1. The algorithm of dynamic loss scaling starts with a relatively high scaling factor (such as $2^{24}$), then starts training and checks whether the number overflows in the iteration (Infs/Nans); -2. If there is no gradient overflow, the scale factor is not adjusted and the iteration continues; if the gradient overflow is detected, the scale factor is halved and the gradient update is reconfirmed until the parameter does not appear in the overflow range; -3. In the later stages of training, the loss has become stable and convergent, and the amplitude of the gradient update is often small, which can allow a higher loss scaling factor to prevent data underflow again. -4. Therefore, the dynamic loss scaling algorithm attempts to increase the loss scaling by the F multiple every N (N=2000) iterations, and then performs step 2 to check for overflow. +```python +import numpy as np +import mindspore as ms +import mindspore.nn as nn +from mindspore import amp +from mindspore.nn import Accuracy +from mindspore.common.initializer import Normal +from mindspore import dataset as ds -## Loss scale used in MindSpore +ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU") +``` -The following two APIs in MindSpore that use the loss scaling algorithm are described separately [FixedLossScaleManager](https://www.mindspore.cn/docs/en/master/api_python/amp/mindspore.amp.FixedLossScaleManager.html#mindspore.amp.FixedLossScaleManager) and [DynamicLossScaleManager](https://www.mindspore.cn/docs/en/master/api_python/amp/mindspore.amp.DynamicLossScaleManager.html#mindspore.amp.DynamicLossScaleManager). +Define the network model by using LeNet5 as an example; define the dataset and the interfaces commonly used in the training process. -### FixedLossScaleManager +```python +ds_train = create_dataset() +# Initialize network +network = LeNet5(10) +# Define Loss and Optimizer +net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean") +``` -`FixedLossScaleManager` does not change the size of the scale when scaling, and the value of the scale is controlled by the input parameter loss_scale, which can be specified by the user. The default value is taken if it is not specified. +Use Loss Scale API to act in optimizers and models. -Another parameter of `FixedLossScaleManager` is `drop_overflow_update`, which controls whether parameters are updated in the event of an overflow. +```python +# Define Loss Scale, optimizer and model +#1) Drop the parameter update if there is an overflow +loss_scale_manager = amp.FixedLossScaleManager() +net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9) +model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O0", loss_scale_manager=loss_scale_manager) -In general, the LossScale function does not need to be used with the optimizer, but when using `FixedLossScaleManager`, if `drop_overflow_update` is False, the optimizer needs to set the value of `loss_scale` and the value of `loss_scale` should be the same as that of `FixedLossScaleManager`. +#2) Execute parameter update even if overflow occurs +loss_scale = 1024.0 +loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False) +net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9, loss_scale=loss_scale) +model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O0", loss_scale_manager=loss_scale_manager) -The detailed use of `FixedLossScaleManager` is as follows: +# Run training +model.train(epoch=10, train_dataset=ds_train, callbacks=[ms.LossMonitor()]) +``` -1. Import the necessary libraries and declare execution using graph mode. - - ```python - import numpy as np - import mindspore as ms - import mindspore.nn as nn - import mindspore as ms - from mindspore.nn import Accuracy - from mindspore.common.initializer import Normal - from mindspore import dataset as ds - - ms.set_seed(0) - ms.set_context(mode=ms.GRAPH_MODE) - ``` - -2. Define the LeNet5 network model, and any network model can use the Loss Scale mechanism. - - ```python - class LeNet5(nn.Cell): - """ - Lenet network - - Args: - num_class (int): Number of classes. Default: 10. - num_channel (int): Number of channels. Default: 1. - - Returns: - Tensor, output tensor - """ - - def __init__(self, num_class=10, num_channel=1): - super(LeNet5, self).__init__() - self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') - self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') - self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) - self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) - self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.flatten = nn.Flatten() - - def construct(self, x): - x = self.max_pool2d(self.relu(self.conv1(x))) - x = self.max_pool2d(self.relu(self.conv2(x))) - x = self.flatten(x) - x = self.relu(self.fc1(x)) - x = self.relu(self.fc2(x)) - x = self.fc3(x) - return x - ``` - -3. Define common interfaces in datasets and training processes. - - ```python - # create dataset - def get_data(num, img_size=(1, 32, 32), num_classes=10, is_onehot=True): - for _ in range(num): - img = np.random.randn(*img_size) - target = np.random.randint(0, num_classes) - target_ret = np.array([target]).astype(np.float32) - if is_onehot: - target_onehot = np.zeros(shape=(num_classes,)) - target_onehot[target] = 1 - target_ret = target_onehot.astype(np.float32) - yield img.astype(np.float32), target_ret - - def create_dataset(num_data=1024, batch_size=32, repeat_size=1): - input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label']) - input_data = input_data.batch(batch_size, drop_remainder=True) - input_data = input_data.repeat(repeat_size) - return input_data - - ds_train = create_dataset() - - # Initialize network - network = LeNet5(10) - - # Define Loss and Optimizer - net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean") - ``` - -4. The API interface that really uses Loss Scale acts on the optimizer and model. - - ```python - # Define Loss Scale, optimizer and model - #1) Drop the parameter update if there is an overflow - loss_scale_manager = ms.FixedLossScaleManager() - net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9) - model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O0", loss_scale_manager=loss_scale_manager) - - #2) Execute parameter update even if overflow occurs - loss_scale = 1024.0 - loss_scale_manager = ms.FixedLossScaleManager(loss_scale, False) - net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9, loss_scale=loss_scale) - model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O0", loss_scale_manager=loss_scale_manager) - - # Run training - model.train(epoch=10, train_dataset=ds_train, callbacks=[ms.LossMonitor()]) - ``` - - The running result is as follows: - - ```text - epoch: 1 step: 32, loss is 2.3018966 - epoch: 2 step: 32, loss is 2.2965345 - epoch: 3 step: 32, loss is 2.3021417 - epoch: 4 step: 32, loss is 2.2995133 - epoch: 5 step: 32, loss is 2.3040886 - epoch: 6 step: 32, loss is 2.3131478 - epoch: 7 step: 32, loss is 2.2919555 - epoch: 8 step: 32, loss is 2.311748 - epoch: 9 step: 32, loss is 2.304955 - epoch: 10 step: 32, loss is 2.2682834 - ``` - -### LossScale and Optimizer +#### LossScale and Optimizer As mentioned earlier, the optimizer needs to be used together when using `FixedLossScaleManager` and `drop_overflow_update` is False. @@ -498,7 +467,7 @@ class CustomTrainOneStepCell(nn.TrainOneStepCell): - scale_grad function: Used for division between the gradient and the `loss_scale` coefficient to restore the gradient. - construct function: Referring to `nn. TrainOneStepCell`, defines the computational logic for `construct` and calls `scale_grad` after acquiring the gradient. -After defining `TrainOneStepCell`, the training network needs to be manually built, which is as follows: +After customizing `TrainOneStepCell`, the training network needs to be manually built, which is as follows: ```python import mindspore as ms @@ -541,7 +510,7 @@ model.train(epoch=epochs, train_dataset=ds_train) When training with `Model` in this scenario, the `loss_scale_manager` and `amp_level` do not need to be configured, as the `CustomTrainOneStepCell` already includes mixed-precision calculation logic. -### DynamicLossScaleManager +#### DynamicLossScaleManager `DynamicLossScaleManager` can dynamically change the size of the scale during training, keeping the scale as large as possible without overflow. diff --git a/tutorials/experts/source_zh_cn/others/mixed_precision.ipynb b/tutorials/experts/source_zh_cn/others/mixed_precision.ipynb index b618967cb7aa2a9a0f6e356eec012e152118a230..1cf80a357ef6385fc374219397e8dd8516492e27 100644 --- a/tutorials/experts/source_zh_cn/others/mixed_precision.ipynb +++ b/tutorials/experts/source_zh_cn/others/mixed_precision.ipynb @@ -134,7 +134,7 @@ "\n", "- 使用训练接口 `Model` :配置入参 `amp_level` 设置混合精度的执行策略,配置入参 `loss_scale_manager` 为 `FixedLossScaleManager` 或 `DynamicLossScaleManager` 实现损失缩放;\n", "\n", - "### 使用函数式编程方式实现混合精度与损失缩放\n", + "## 使用函数式编程方式实现混合精度与损失缩放\n", "\n", "MindSpore提供了函数式接口用于混合精度场景,用户可以使用 `auto_mixed_precision` 实现自动混合精度,训练过程中通过 `all_finite` 做溢出判断,使用 `StaticLossScaler` 和 `DynamicLossScaler` 手动执行梯度及损失的缩放。\n", "\n", @@ -353,13 +353,13 @@ { "cell_type": "markdown", "source": [ - "### 使用训练接口 `Model` 实现混合精度与损失缩放\n", + "## 使用训练接口 `Model` 实现混合精度与损失缩放\n", "\n", - "#### 混合精度\n", + "### 混合精度\n", "\n", "`Model` 接口提供了入参 `amp_level` 实现自动混合精度,用户也可以通过 `to_float(ms.float16)` 把Cell中涉及的算子设置成FP16,实现手动混合精度。\n", "\n", - "##### 自动混合精度\n", + "#### 自动混合精度\n", "\n", "使用自动混合精度,需要调用`Model`接口,将待训练网络和优化器作为输入传入,该接口会根据设定策略把对应的网络模型的的算子转换成FP16算子。\n", "\n", @@ -480,7 +480,7 @@ { "cell_type": "markdown", "source": [ - "##### 手动混合精度\n", + "#### 手动混合精度\n", "\n", "MindSpore目前还支持手动混合精度(一般不建议使用手动混合精度,除非自定义特殊网络和特性开发)。\n", "\n", @@ -559,11 +559,11 @@ "source": [ "> 使用混合精度时,只能由自动微分功能生成反向网络,不能由用户自定义生成反向网络,否则可能会导致MindSpore产生数据格式不匹配的异常信息。\n", "\n", - "#### 损失缩放\n", + "### 损失缩放\n", "\n", "下面将会分别介绍MindSpore中,配合 `Model` 接口使用损失缩放算法的主要两个API [FixedLossScaleManager](https://www.mindspore.cn/docs/zh-CN/master/api_python/amp/mindspore.amp.FixedLossScaleManager.html)和[DynamicLossScaleManager](https://www.mindspore.cn/docs/zh-CN/master/api_python/amp/mindspore.amp.DynamicLossScaleManager.html)。\n", "\n", - "##### FixedLossScaleManager\n", + "#### FixedLossScaleManager\n", "\n", "`FixedLossScaleManager`在进行缩放的时候,不会改变scale的大小,scale的值由入参loss_scale控制,可以由用户指定,不指定则取默认值。\n", "\n", @@ -667,7 +667,7 @@ { "cell_type": "markdown", "source": [ - "##### LossScale与优化器\n", + "#### LossScale与优化器\n", "\n", "前面提到了使用`FixedLossScaleManager`且`drop_overflow_update`为False时,优化器需要配合使用。\n", "\n", @@ -810,7 +810,7 @@ "source": [ "在此场景下使用`Model`进行训练时,`loss_scale_manager`和`amp_level`无需配置,因为`CustomTrainOneStepCell`中已经包含了混合精度的计算逻辑。\n", "\n", - "##### DynamicLossScaleManager\n", + "#### DynamicLossScaleManager\n", "\n", "`DynamicLossScaleManager`在训练过程中可以动态改变scale的大小,在没有发生溢出的情况下,要尽可能保持较大的scale。\n", "\n", diff --git a/tutorials/source_en/advanced/derivation.md b/tutorials/source_en/advanced/derivation.md index 4dfc155ecf628b8226dbd9ca53beb3fc5f26fbcf..30df83d4423a482e16959d7429dd3a69b576a682 100644 --- a/tutorials/source_en/advanced/derivation.md +++ b/tutorials/source_en/advanced/derivation.md @@ -2,19 +2,20 @@ -The `GradOperation` API provided by the `mindspore.ops` module can be used to generate a gradient of a network model. The following describes how to use the `GradOperation` API to perform first-order and second-order derivations and how to stop gradient computation. +The `grad` and `value_and_grad` provided by the `mindspore.ops` module generate the gradients of the network model. `grad` computes the network gradient, and `value_and_grad` computes both the forward output and the gradient of the network. This article focuses on how to use the main functions of the `grad`, including first-order and second-order derivations, derivation of the input or network weights separately, returning auxiliary variables, and stopping calculating the gradient. -> For details about `GradOperation`, see [API](https://mindspore.cn/docs/en/master/api_python/ops/mindspore.ops.GradOperation.html#mindspore.ops.GradOperation). +> For more information about the derivative interface, please refer to the [API documentation](https://www.mindspore.cn/docs/en/master/api_python/ops/mindspore.ops.grad.html). ## First-order Derivation -Method: `mindspore.ops.GradOperation()`. The parameter usage is as follows: +Method: `mindspore.ops.grad`. The parameter usage is as follows: -- `get_all`: If this parameter is set to `False`, the derivation is performed only on the first input. If this parameter is set to `True`, the derivation is performed on all inputs. -- `get_by_list`: If this parameter is set to `False`, the weight derivation is not performed. If this parameter is set to `True`, the weight derivation is performed. -- `sens_param`: The output value of the network is scaled to change the final gradient. Therefore, the dimension is the same as the output dimension. +- `fn`: the function or network to be derived. +- `grad_position`: specifies the index of the input position to be derived. If the index is int type, it means to derive for a single input; if tuple type, it means to derive for the position of the index within the tuple, where the index starts from 0; and if None, it means not to derive for the input. In this scenario, `weights` is non-None. Default: 0. +- `weights`: the network variables that need to return the gradients in the training network. Generally the network variables can be obtained by `weights = net.trainable_params()`. Default: None. +- `has_aux`: symbol for whether to return auxiliary arguments. If True, the number of `fn` outputs must be more than one, where only the first output of `fn` is involved in the derivation and the other output values will be returned directly. Default: False. -The [MatMul](https://mindspore.cn/docs/en/master/api_python/ops/mindspore.ops.MatMul.html#mindspore.ops.MatMul) operator is used to build a customized network model `Net`, and then perform first-order derivation on the model. The following formula is an example to describe how to use the `GradOperation` API: +The following is a brief introduction to the use of the `grad` by first constructing a customized network model `Net` and then performing a first-order derivative on it: $$f(x, y)=(x * z) * y \tag{1}$$ @@ -22,174 +23,90 @@ First, define the network model `Net`, input `x`, and input `y`. ```python import numpy as np +from mindspore import ops, Tensor import mindspore.nn as nn -import mindspore.ops as ops import mindspore as ms # Define the inputs x and y. -x = ms.Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=ms.float32) -y = ms.Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=ms.float32) +x = Tensor([3.0], dtype=ms.float32) +y = Tensor([5.0], dtype=ms.float32) -class Net(nn.Cell): - """Define the matrix multiplication network Net.""" +class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.matmul = ops.MatMul() self.z = ms.Parameter(ms.Tensor(np.array([1.0], np.float32)), name='z') def construct(self, x, y): - x = x * self.z - out = self.matmul(x, y) + out = x * x * y * self.z return out ``` -### Computing the Input Derivative +### Computing the First-order Derivative for Input -Compute the input derivative. The code is as follows: +To derive the inputs `x` and `y`, set `grad_position` to (0, 1): -```python -class GradNetWrtX(nn.Cell): - """Define the first-order derivation of network input.""" - - def __init__(self, net): - super(GradNetWrtX, self).__init__() - self.net = net - self.grad_op = ops.GradOperation() - - def construct(self, x, y): - gradient_function = self.grad_op(self.net) - return gradient_function(x, y) +$$\frac{\partial f}{\partial x}=2 * x * y * z \tag{2}$$ -output = GradNetWrtX(Net())(x, y) -print(output) -``` +$$\frac{\partial f}{\partial y}=x * x * z \tag{3}$$ -```text - [[4.5099998 2.7 3.6000001] - [4.5099998 2.7 3.6000001]] +```python +net = Net() +grad_fn = ops.grad(net, grad_position=(0, 1)) +gradients = grad_fn(x, y) +print(gradients) ``` -The preceding result is explained as follows. To facilitate analysis, the preceding inputs `x` and `y`, and weight `z` are expressed as follows: - ```text -x = ms.Tensor([[x1, x2, x3], [x4, x5, x6]]) -y = ms.Tensor([[y1, y2, y3], [y4, y5, y6], [y7, y8, y9]]) -z = ms.Tensor([z]) + (Tensor(shape=[1], dtype=Float32, value= [ 3.00000000e+01]), Tensor(shape=[1], dtype=Float32, value= [ 9.00000000e+00])) ``` -The following forward result can be obtained based on the definition of the MatMul operator: - -$$output = [[(x_1 \cdot y_1 + x_2 \cdot y_4 + x_3 \cdot y_7) \cdot z, (x_1 \cdot y_2 + x_2 \cdot y_5 + x_3 \cdot y_8) \cdot z, (x_1 \cdot y_3 + x_2 \cdot y_6 + x_3 \cdot y_9) \cdot z],$$ +### Computing the Derivative for Weight -$$[(x_4 \cdot y_1 + x_5 \cdot y_4 + x_6 \cdot y_7) \cdot z, (x_4 \cdot y_2 + x_5 \cdot y_5 + x_6 \cdot y_8) \cdot z, (x_4 \cdot y_3 + x_5 \cdot y_6 + x_6 \cdot y_9) \cdot z]] \tag{2}$$ +Derive for the weight `z`, where it is not necessary to derive for the inputs, and set `grad_position` to None: -MindSpore uses the reverse-mode automatic differentiation mechanism during gradient computation. The output result is summed and then the derivative of the input `x` is computed. - -1. Sum formula: - - $$\sum{output} = [(x_1 \cdot y_1 + x_2 \cdot y_4 + x_3 \cdot y_7) + (x_1 \cdot y_2 + x_2 \cdot y_5 + x_3 \cdot y_8) + (x_1 \cdot y_3 + x_2 \cdot y_6 + x_3 \cdot y_9)$$ - - $$+ (x_4 \cdot y_1 + x_5 \cdot y_4 + x_6 \cdot y_7) + (x_4 \cdot y_2 + x_5 \cdot y_5 + x_6 \cdot y_8) + (x_4 \cdot y_3 + x_5 \cdot y_6 + x_6 \cdot y_9)] \cdot z \tag{3}$$ - -2. Derivation formula: - - $$\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[(y_1 + y_2 + y_3) \cdot z, (y_4 + y_5 + y_6) \cdot z, (y_7 + y_8 + y_9) \cdot z],$$ - - $$[(y_1 + y_2 + y_3) \cdot z, (y_4 + y_5 + y_6) \cdot z, (y_7 + y_8 + y_9) \cdot z]] \tag{4}$$ - -3. Computation result: - - $$\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[4.51 \quad 2.7 \quad 3.6] [4.51 \quad 2.7 \quad 3.6]] \tag{5}$$ - - > If the derivatives of the `x` and `y` inputs are considered, you only need to set `self.grad_op = GradOperation(get_all=True)` in `GradNetWrtX`. - -### Computing the Weight Derivative - -Compute the weight derivative. The sample code is as follows: +$$\frac{\partial f}{\partial z}=x * x * y \tag{4}$$ ```python -class GradNetWrtZ(nn.Cell): - """Define the first-order derivation of network weight."" - - def __init__(self, net): - super(GradNetWrtZ, self).__init__() - self.net = net - self.params = ms.ParameterTuple(net.trainable_params()) - self.grad_op = ops.GradOperation(get_by_list=True) - - def construct(self, x, y): - gradient_function = self.grad_op(self.net, self.params) - return gradient_function(x, y) +params = ms.ParameterTuple(net.trainable_params()) -output = GradNetWrtZ(Net())(x, y) -print(output[0]) +output = ops.grad(net, grad_position=None, weights=params)(x, y) +print(output) ``` ```text - [21.536] + (Tensor(shape=[1], dtype=Float32, value= [ 4.50000000e+01]),) ``` -The following formula is used to explain the preceding result. A derivation formula for the weight is: - -$$\frac{\mathrm{d}(\sum{output})}{\mathrm{d}z} = (x_1 \cdot y_1 + x_2 \cdot y_4 + x_3 \cdot y_7) + (x_1 \cdot y_2 + x_2 \cdot y_5 + x_3 \cdot y_8) + (x_1 \cdot y_3 + x_2 \cdot y_6 + x_3 \cdot y_9)$$ - -$$+ (x_4 \cdot y_1 + x_5 \cdot y_4 + x_6 \cdot y_7) + (x_4 \cdot y_2 + x_5 \cdot y_5 + x_6 \cdot y_8) + (x_4 \cdot y_3 + x_5 \cdot y_6 + x_6 \cdot y_9) \tag{6}$$ - -Computation result: - -$$\frac{\mathrm{d}(\sum{output})}{\mathrm{d}z} = [2.1536e+01] \tag{7}$$ - -### Gradient Value Scaling +### Returning Auxiliary Variables -You can use the `sens_param` parameter to control the scaling of the gradient value. +Simultaneous derivation for the inputs and weights, where only the first output is involved in the derivation, with the following sample code: ```python -class GradNetWrtN(nn.Cell): - """Define the first-order derivation of the network and control gradient value scaling.""" - def __init__(self, net): - super(GradNetWrtN, self).__init__() - self.net = net - self.grad_op = ops.GradOperation(sens_param=True) +net = nn.Dense(10, 1) +loss_fn = nn.MSELoss() - # Define gradient value scaling. - self.grad_wrt_output = ms.Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=ms.float32) - def construct(self, x, y): - gradient_function = self.grad_op(self.net) - return gradient_function(x, y, self.grad_wrt_output) +def forward(inputs, labels): + logits = net(inputs) + loss = loss_fn(logits, labels) + return loss, logits -output = GradNetWrtN(Net())(x, y) -print(output) -``` -```text - [[2.211 0.51 1.49 ] - [5.588 2.68 4.07 ]] -``` +inputs = Tensor(np.random.randn(16, 10).astype(np.float32)) +labels = Tensor(np.random.randn(16, 1).astype(np.float32)) +weights = net.trainable_params() -To facilitate the explanation of the preceding result, `self.grad_wrt_output` is recorded as follows: +# Aux value does not contribute to the gradient. +grad_fn = ops.grad(forward, grad_position=0, weights=None, has_aux=True) +inputs_gradient, (aux_logits,) = grad_fn(inputs, labels) +print(len(inputs_gradient), aux_logits.shape) +``` ```text -self.grad_wrt_output = ms.Tensor([[s1, s2, s3], [s4, s5, s6]]) + 16, (16, 1) ``` -The output value after scaling is the product of the original output value and the element corresponding to `self.grad_wrt_output`. The formula is as follows: - -$$output = [[(x_1 \cdot y_1 + x_2 \cdot y_4 + x_3 \cdot y_7) \cdot z \cdot s_1, (x_1 \cdot y_2 + x_2 \cdot y_5 + x_3 \cdot y_8) \cdot z \cdot s_2, (x_1 \cdot y_3 + x_2 \cdot y_6 + x_3 \cdot y_9) \cdot z \cdot s_3], $$ - -$$[(x_4 \cdot y_1 + x_5 \cdot y_4 + x_6 \cdot y_7) \cdot z \cdot s_4, (x_4 \cdot y_2 + x_5 \cdot y_5 + x_6 \cdot y_8) \cdot z \cdot s_5, (x_4 \cdot y_3 + x_5 \cdot y_6 + x_6 \cdot y_9) \cdot z \cdot s_6]] \tag{8}$$ - -The derivation formula is changed to compute the derivative of the sum of the output values to each element of `x`. - -$$\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[(s_1 \cdot y_1 + s_2 \cdot y_2 + s_3 \cdot y_3) \cdot z, (s_1 \cdot y_4 + s_2 \cdot y_5 + s_3 \cdot y_6) \cdot z, (s_1 \cdot y_7 + s_2 \cdot y_8 + s_3 \cdot y_9) \cdot z],$$ - -$$[(s_4 \cdot y_1 + s_5 \cdot y_2 + s_6 \cdot y_3) \cdot z, (s_4 \cdot y_4 + s_5 \cdot y_5 + s_6 \cdot y_6) \cdot z, (s_4 \cdot y_7 + s_5 \cdot y_8 + s_6 \cdot y_9) \cdot z]] \tag{9}$$ - -Computation result: - -$$\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[2.211 \quad 0.51 \quad 1.49][5.588 \quad 2.68 \quad 4.07]] \tag{10}$$ - ### Stopping Gradient Computation You can use `stop_gradient` to stop computing the gradient of a specified operator to eliminate the impact of the operator on the gradient. @@ -203,140 +120,54 @@ class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.matmul = ops.MatMul() def construct(self, x, y): - out1 = self.matmul(x, y) - out2 = self.matmul(x, y) + out1 = x * y + out2 = x * y out2 = ops.stop_gradient(out2) # Stop computing the gradient of the out2 operator. out = out1 + out2 return out -class GradNetWrtX(nn.Cell): - def __init__(self, net): - super(GradNetWrtX, self).__init__() - self.net = net - self.grad_op = ops.GradOperation() - - def construct(self, x, y): - gradient_function = self.grad_op(self.net) - return gradient_function(x, y) - -output = GradNetWrtX(Net())(x, y) +net = Net() +grad_fn = ops.grad(net) +output = grad_fn(x, y) print(output) ``` ```text - [[4.5099998 2.7 3.6000001] - [4.5099998 2.7 3.6000001]] + [5.0] ``` According to the preceding information, `stop_gradient` is set for `out2`. Therefore, `out2` does not contribute to gradient computation. The output result is the same as that when `out2` is not added. -Delete `out2 = stop_gradient(out2)` and check the output. An example of the code is as follows: +Delete `out2 = stop_gradient(out2)` and check the output result. An example of the code is as follows: ```python class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - self.matmul = ops.MatMul() def construct(self, x, y): - out1 = self.matmul(x, y) - out2 = self.matmul(x, y) + out1 = x * y + out2 = x * y # out2 = stop_gradient(out2) out = out1 + out2 return out -class GradNetWrtX(nn.Cell): - def __init__(self, net): - super(GradNetWrtX, self).__init__() - self.net = net - self.grad_op = ops.GradOperation() - - def construct(self, x, y): - gradient_function = self.grad_op(self.net) - return gradient_function(x, y) -output = GradNetWrtX(Net())(x, y) +net = Net() +grad_fn = ops.grad(net) +output = grad_fn(x, y) print(output) ``` ```text - [[9.0199995 5.4 7.2000003] - [9.0199995 5.4 7.2000003]] + [10.0] ``` According to the printed result, after the gradient of the `out2` operator is computed, the gradients generated by the `out2` and `out1` operators are the same. Therefore, the value of each item in the result is twice the original value (accuracy error exists). -### Customized Backward Propagation Function - -When MindSpore is used to build a neural network, the `nn.Cell` class needs to be inherited. When there are some operations that do not define backward propagation rules on the network, or when you want to control the gradient computation process of the entire network, you can use the function of customizing the backward propagation function of the `nn.Cell` object. The format is as follows: - -```python -def bprop(self, ..., out, dout): - return ... -``` - -- Input parameters: Input parameters in the forward porpagation plus `out` and `dout`. `out` indicates the computation result of the forward porpagation, and `dout` indicates the gradient returned to the `nn.Cell` object. -- Return values: Gradient of each input in the forward porpagation. The number of return values must be the same as the number of inputs in the forward porpagation. - -A complete example is as follows: - -```python -import mindspore.nn as nn -import mindspore as ms -import mindspore.ops as ops - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.matmul = ops.MatMul() - - def construct(self, x, y): - out = self.matmul(x, y) - return out - - def bprop(self, x, y, out, dout): - dx = x + 1 - dy = y + 1 - return dx, dy - - -class GradNet(nn.Cell): - def __init__(self, net): - super(GradNet, self).__init__() - self.net = net - self.grad_op = ops.GradOperation(get_all=True) - - def construct(self, x, y): - gradient_function = self.grad_op(self.net) - return gradient_function(x, y) - - -x = ms.Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=ms.float32) -y = ms.Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=ms.float32) -out = GradNet(Net())(x, y) -print(out) -``` - -```text - (Tensor(shape=[2, 3], dtype=Float32, value= - [[ 1.50000000e+00, 1.60000002e+00, 1.39999998e+00], - [ 2.20000005e+00, 2.29999995e+00, 2.09999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value= - [[ 1.00999999e+00, 1.29999995e+00, 2.09999990e+00], - [ 1.10000002e+00, 1.20000005e+00, 2.29999995e+00], - [ 3.09999990e+00, 2.20000005e+00, 4.30000019e+00]])) -``` - -Constraints - -- If the number of return values of the `bprop` function is 1, the return value must be written in the tuple format, that is, `return (dx,)`. -- In graph mode, the `bprop` function needs to be converted into a graph IR. Therefore, the static graph syntax must be complied with. For details, see [Static Graph Syntax Support](https://www.mindspore.cn/docs/en/master/note/static_graph_syntax_support.html). -- Only the gradient of the forward porpagation input can be returned. The gradient of the `Parameter` cannot be returned. -- `Parameter` cannot be used in `bprop`. - ## High-order Derivation High-order differentiation is used in domains such as AI-supported scientific computing and second-order optimization. For example, in the molecular dynamics simulation, when the potential energy is trained using the neural network, the derivative of the neural network output to the input needs to be computed in the loss function, and then the second-order cross derivative of the loss function to the input and the weight exists in backward propagation. @@ -369,6 +200,7 @@ import mindspore as ms class Net(nn.Cell): """Feedforward network model""" + def __init__(self): super(Net, self).__init__() self.sin = ops.Sin() @@ -377,28 +209,6 @@ class Net(nn.Cell): out = self.sin(x) return out -class Grad(nn.Cell): - """First-order derivation""" - def __init__(self, network): - super(Grad, self).__init__() - self.grad = ops.GradOperation() - self.network = network - - def construct(self, x): - gout = self.grad(self.network)(x) - return gout - -class GradSec(nn.Cell): - """Second order derivation""" - def __init__(self, network): - super(GradSec, self).__init__() - self.grad = ops.GradOperation() - self.network = network - - def construct(self, x): - gout = self.grad(self.network)(x) - return gout - x_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32) net = Net() @@ -455,28 +265,6 @@ class Net(nn.Cell): out2 = self.cos(x) return out1, out2 -class Grad(nn.Cell): - """First-order derivation""" - def __init__(self, network): - super(Grad, self).__init__() - self.grad = ops.GradOperation() - self.network = network - - def construct(self, x): - gout = self.grad(self.network)(x) - return gout - -class GradSec(nn.Cell): - """Second order derivation""" - def __init__(self, network): - super(GradSec, self).__init__() - self.grad = ops.GradOperation() - self.network = network - - def construct(self, x): - gout = self.grad(self.network)(x) - return gout - x_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32) net = Net() @@ -547,34 +335,12 @@ class Net(nn.Cell): out2 = self.cos(x) - self.sin(y) return out1, out2 -class Grad(nn.Cell): - """First-order derivation""" - def __init__(self, network): - super(Grad, self).__init__() - self.grad = ops.GradOperation(get_all=True) - self.network = network - - def construct(self, x, y): - gout = self.grad(self.network)(x, y) - return gout - -class GradSec(nn.Cell): - """Second order derivation""" - def __init__(self, network): - super(GradSec, self).__init__() - self.grad = ops.GradOperation(get_all=True) - self.network = network - - def construct(self, x, y): - gout = self.grad(self.network)(x, y) - return gout - x_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32) y_train = ms.Tensor(np.array([3.1415926]), dtype=ms.float32) net = Net() -firstgrad = Grad(net) -secondgrad = GradSec(firstgrad) +firstgrad = ops.grad(net, grad_position=(0, 1)) +secondgrad = ops.grad(firstgrad, grad_position=(0, 1)) output = secondgrad(x_train, y_train) # Print the result. diff --git a/tutorials/source_en/beginner/autograd.md b/tutorials/source_en/beginner/autograd.md index d1a66367507c96825e0c9f8077a9508f683f74f4..6df8ed94525fb05b10addd707bbd161f65054dec 100644 --- a/tutorials/source_en/beginner/autograd.md +++ b/tutorials/source_en/beginner/autograd.md @@ -6,10 +6,10 @@ Automatic differentiation can calculate a derivative value of a derivative funct MindSpore uses `ops.grad` and `ops.value_and_grad` to calculate the first-order derivative. `ops.grad` only returns gradient, while `ops.value_and_grad` returns the network forward calculation result and gradient. The `ops.value_and_grad` attributes are as follows: -+ `fn`: the function or network to be derived. -+ `grad_position`: specifies the index of the input position to be derived. If the index is int type, it means to derive for a single input; if tuple type, it means to derive for the position of the index within the tuple, where the index starts from 0; and if None, it means not to derive for the input. In this scenario, `weights` is non-None. Default: 0. -+ `weights`: the network variables that need to return the gradients in the training network. Generally the network variables can be obtained by `weights = net.trainable_params()`. Default: None. -+ `has_aux`: symbol for whether to return auxiliary arguments. If True, the number of `fn` outputs must be more than one, where only the first output of `fn` is involved in the derivation and the other output values will be returned directly. Default: False. +- `fn`: the function or network to be derived. +- `grad_position`: specifies the index of the input position to be derived. If the index is int type, it means to derive for a single input; if tuple type, it means to derive for the position of the index within the tuple, where the index starts from 0; and if None, it means not to derive for the input. In this scenario, `weights` is non-None. Default: 0. +- `weights`: the network variables that need to return the gradients in the training network. Generally the network variables can be obtained by `weights = net.trainable_params()`. Default: None. +- `has_aux`: symbol for whether to return auxiliary arguments. If True, the number of `fn` outputs must be more than one, where only the first output of `fn` is involved in the derivation and the other output values will be returned directly. Default: False. This chapter uses `ops.value_and_grad` in MindSpore to find first-order derivatives of the network. diff --git a/tutorials/source_zh_cn/advanced/derivation.ipynb b/tutorials/source_zh_cn/advanced/derivation.ipynb index 41ce617ca7d025da82a54f86ecf59f07efd2e82f..60c8cb32d9d489aec63c612d3befd21183ce2e2b 100644 --- a/tutorials/source_zh_cn/advanced/derivation.ipynb +++ b/tutorials/source_zh_cn/advanced/derivation.ipynb @@ -226,7 +226,7 @@ "id": "9895bafe", "metadata": {}, "source": [ - "从上面的打印可以看出,由于对`out2`设置了`stop_gradient`, 所以`out2`没有对梯度计算有任何的贡献,其输出结果与未加`out2`算子时一致。\n", + "从上面的打印可以看出,由于对`out2`设置了`stop_gradient`,所以`out2`没有对梯度计算有任何的贡献,其输出结果与未加`out2`算子时一致。\n", "\n", "下面删除`out2 = stop_gradient(out2)`,再来看一下输出结果。示例代码为:" ] @@ -373,7 +373,7 @@ "\n", "$$f_2(x) = cos(x) \\tag{3}$$\n", "\n", - "梯度计算时由于MindSpore采用的是反向自动微分机制, 会对输出结果求和后再对输入求导。 因此其一阶导数是:\n", + "梯度计算时由于MindSpore采用的是反向自动微分机制,会对输出结果求和后再对输入求导。因此其一阶导数是:\n", "\n", "$$f'(x) = cos(x) -sin(x) \\tag{4}$$\n", "\n", @@ -530,7 +530,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "从上面的打印结果可以看出,输出对输入$x$的二阶导数$-sin(3.1415926) - cos(3.1415926)$的值接近于$1$, 输出对输入$y$的二阶导数$sin(3.1415926) + cos(3.1415926)$的值接近于$-1$。\n", + "从上面的打印结果可以看出,输出对输入$x$的二阶导数$-sin(3.1415926) - cos(3.1415926)$的值接近于$1$,输出对输入$y$的二阶导数$sin(3.1415926) + cos(3.1415926)$的值接近于$-1$。\n", "\n", "> 由于不同计算平台的精度可能存在差异,因此本章节中的代码在不同平台上的执行结果会存在微小的差别。" ]