zynq petalinux AMP双核运行linux+ucos(或者裸机)之间进行 IPI 软中断通讯的实现

博主分享了自己实现的Linux驱动程序,用于在Zynq平台的双核CPU间进行启动、停止及中断通讯。驱动简化了官方remoteprocAMP程序的复杂性,提供了启动CPU1、停止CPU1和IPI中断通知的功能。通过设备树配置和ioctl接口,驱动能够加载CPU1的bin文件并在两CPU间进行数据交换。应用示例展示了如何在Linux用户空间触发和响应这些操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前几天试验了xilinx官方remoteproc AMP程序,运行是运行起来了,但是感觉太复杂了,我只想要一个能启动停止cpu1,能在两个cpu之间方便通讯的功能就行了,看着remoteproc框架一堆的代码,编译出来的elf文件体积还超级大心里就非常不爽,想着干脆自己实现一个简单处理驱动程序算了,so...经过几天的痛苦研究,也算完整实现了想要的功能

由于linux应用程序不能处理硬件中断,因此这个IPI通讯只能在驱动层进行处理,所以第一步,要编写一个驱动程序,实现cpu1的bin(不是elf)文件加载,控制启动和停止cpu1,处理cpu0和cpu1之间的ipi中断,完整驱动代码如下

/*  ampipidevice.c - The simplest kernel module.

* Copyright (C) 2013 - 2016 Xilinx, Inc
*
*   This program is free software; you can redistribute it and/or modify
*   it under the terms of the GNU General Public License as published by
*   the Free Software Foundation; either version 2 of the License, or
*   (at your option) any later version.

*   This program is distributed in the hope that it will be useful,
*   but WITHOUT ANY WARRANTY; without even the implied warranty of
*   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*   GNU General Public License for more details.
*
*   You should have received a copy of the GNU General Public License along
*   with this program. If not, see <http://www.gnu.org/licenses/>.

*/
#include <linux/kernel.h>
#include <linux/init.h>
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/io.h>
#include <linux/interrupt.h>
#include <linux/uaccess.h> 
#include <linux/cdev.h>
#include <linux/of_address.h>
#include <linux/of_device.h>
#include <linux/of_platform.h>
#include <linux/of_irq.h>
#include <linux/fs.h>
#include <linux/fcntl.h> 

#include <../../arch/arm/mach-zynq/common.h>
#include <linux/irqchip/arm-gic.h> 

extern int zynq_cpun_stop(int cpu);



/* Module metadata: license, author and description strings. */
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Xilinx Inc.");
MODULE_DESCRIPTION("ampipidevice - loadable module template generated by petalinux-create -t modules");

/*
 * Name used for the platform driver, the requested memory region, the
 * character-device region, the device class and the created device node.
 */
#define DRIVER_NAME "ampipi_Device"
/* Number of minor numbers reserved by alloc_chrdev_region(). */
#define DRIVER_NUM   1                      



/*
 * ioctl command codes handled by ampipi_ioctl():
 *  StarCpu1 - take CPU1 offline in Linux, then stop it and restart it at
 *             the start of the memory region from the device tree.
 *  StopCpu1 - bring CPU1 back online in Linux via cpu_up(1).
 *  KickCpu1 - raise the cpu0_to_cpu1_ipi software interrupt on CPU1.
 */
#define StarCpu1			0x10000000
#define StopCpu1			0x20000000
#define KickCpu1			0x30000000

/*
 * Per-device state filled in by ampipidevice_probe() from the platform
 * resource and the device-tree properties.
 */
struct ampipidevice_local 
{
	/* IPI number raised on CPU1 by the KickCpu1 ioctl (read from "ipi13"). */
	int cpu0_to_cpu1_ipi;
	/* IPI number whose handler notifies user space (read from "ipi12"). */
	int cpu1_to_cpu0_ipi;
	/* First physical address of the memory resource. */
	unsigned long mem_start;
	/* Last physical address (inclusive); lengths use mem_end - mem_start + 1. */
	unsigned long mem_end;
	/* Kernel mapping of the region returned by ioremap(). */
	void __iomem *base_addr;
};


/* Character-device bookkeeping for the single driver instance. */
struct ampipi_dev 
{ 
	dev_t devid;                   /*  device number  */ 
	struct cdev chdev;            	/* cdev structure  */ 
	struct class *class;         	/*  device class  */ 
	struct device *device;     		/*  device  */ 
	struct ampipidevice_local *param;	/* set by ampipi_init() */
	struct fasync_struct *async_queue; 	/* SIGIO subscribers, managed by fasync_helper() */
}; 

/* The one and only instance; all file operations act on it. */
static struct ampipi_dev ampipi;  


/*********************************************************************
*
*
*
**********************************************************************/
static int ampipi_open(struct inode *inode, struct file *filp)
{
	/*
	 * Every open requests that CPU1 be brought online in Linux; the
	 * result of that request is not examined.
	 * NOTE(review): this also runs when CPU1 is executing the loaded
	 * firmware - confirm that is intended.
	 */
	(void)cpu_up(1);

	printk("ampipi Dev: open success  \r\n");
	return 0;
}

/*********************************************************************
*
*
*
**********************************************************************/
/*
 * Copy user data into the mapped memory region at the current file offset.
 *
 * The transfer is clamped to the end of the region so a large count or a
 * stray offset can never write past the ioremap()ed window.
 *
 * Returns the number of bytes written, -ENOSPC when the offset is outside
 * the region, or -EFAULT when the user buffer cannot be read.
 */
static ssize_t ampipi_write(struct file *filp, const char __user *buf,size_t cnt, loff_t *offt) 
{
	loff_t region_len = (loff_t)(ampipi.param->mem_end - ampipi.param->mem_start) + 1;
	unsigned char *dst;

	if (*offt < 0 || *offt >= region_len)
		return -ENOSPC;

	if ((loff_t)cnt > region_len - *offt)
		cnt = region_len - *offt;

	dst = (unsigned char*)(ampipi.param->base_addr + (*offt));

	/*
	 * copy_from_user() returns the number of bytes it could NOT copy
	 * (never a negative value), so any non-zero result is a failure.
	 */
	if (copy_from_user(dst, buf, cnt))
	{
		printk(KERN_ERR "ampipi Dev: Failed to copy data from user space \r\n"); 
		return -EFAULT; 			
	}

	printk("ampipi Dev: write add: %08X  cnt: %zu\r\n",(unsigned int)dst,cnt); 

	*offt = *offt + cnt;

	return cnt;
}

/*********************************************************************
*
*
*
**********************************************************************/
/*
 * Copy data from the mapped memory region at the current file offset to
 * user space.
 *
 * The transfer is clamped to the end of the region. Returns the number of
 * bytes read, 0 at or beyond the end of the region (end of file), -EINVAL
 * for a negative offset, or -EFAULT when the user buffer cannot be written.
 */
static ssize_t ampipi_read(struct file *filp, char __user *buf, size_t cnt, loff_t *offt) 
{ 
	loff_t region_len = (loff_t)(ampipi.param->mem_end - ampipi.param->mem_start) + 1;
	unsigned char *src;

	if (*offt < 0)
		return -EINVAL;

	if (*offt >= region_len)
		return 0;

	if ((loff_t)cnt > region_len - *offt)
		cnt = region_len - *offt;

	src = (unsigned char*)(ampipi.param->base_addr + (*offt));

	/*
	 * copy_to_user() returns the number of bytes left uncopied, never a
	 * negative value; report the fault instead of claiming success.
	 */
	if (copy_to_user(buf, src, cnt))
	{ 
		printk(KERN_ERR "ampipi Dev: Failed to copy data to user space \r\n"); 
		return -EFAULT;
	} 

	*offt = *offt + cnt;

	return cnt;  
} 
/*********************************************************************
*
*
*
**********************************************************************/
static long ampipi_ioctl(struct file *filp,unsigned int cmd, unsigned long arg) 
{
	  int ret = 0; 
		printk("cmd %d \n" ,cmd);
		

		switch(cmd)
		{
			case StarCpu1:
				

					ret = cpu_down(1);
				
					if (ret && (ret != -EBUSY)) 
					{
						printk("Can't release cpu1 %d\n",ret);
						return ret;
					}
					
					zynq_cpun_stop(1);
					zynq_cpun_start((u32)ampipi.param->mem_start, 1);
					
					printk("StarCpu1\n");

				break;
				
			case StopCpu1:
				

				ret = cpu_up(1);
				if (ret)
				{
					printk("Can't power on cpu1 %d\n", ret);
				}
				
				printk("StopCpu1\n");
				break;
				
				case KickCpu1:
					
					gic_raise_softirq(cpumask_of(1), ampipi.param->cpu0_to_cpu1_ipi);
					
					printk("KickCpu %d  %d \n",1,ampipi.param->cpu0_to_cpu1_ipi);
					
					break;		
		}
		
		return 0;
}
/*********************************************************************
*
*
*
**********************************************************************/
/*
 * Reposition the file offset inside the mapped memory region.
 *
 * SEEK_SET and SEEK_CUR behave as usual. For SEEK_END a non-positive @off
 * is applied relative to the region length, while a positive @off is
 * clamped to the end of the region. The resulting position must lie in
 * [0, region length]; anything else yields -EINVAL.
 *
 * Returns the new offset on success.
 */
loff_t ampipi_llseek(struct file *filp, loff_t off, int whence)
{
	/* mem_end is inclusive, hence the + 1. */
	loff_t region_len = (loff_t)(ampipi.param->mem_end - ampipi.param->mem_start) + 1;
	loff_t newpos;

	switch(whence)
	{
	case SEEK_SET:
		newpos = off;
		break;

	case SEEK_CUR:
		newpos = filp->f_pos + off;
		break;

	case SEEK_END:
		/* The decision depends on the offset, not on whence. */
		if (off > 0)
			newpos = region_len;
		else
			newpos = region_len + off;
		break;

	default:
		return -EINVAL;
	}

	if (newpos < 0 || newpos > region_len)
		return -EINVAL;

	filp->f_pos = newpos;

	return newpos;
}


/*********************************************************************
*
*
*
**********************************************************************/
/*
 * fasync handler: adds or removes @filp from the driver's asynchronous
 * notification list by delegating to fasync_helper().
 */
static int ampipi_fasync(int fd, struct file *filp, int on)
{
	int result = fasync_helper(fd, filp, on, &ampipi.async_queue);

	return result;
}
/*********************************************************************
*
*
*
**********************************************************************/
/*
 * release handler: drops @filp from the asynchronous notification list by
 * calling ampipi_fasync() with fd = -1 and on = 0.
 */
static int ampipi_fasync_release(struct inode *inode, struct file *filp)
{
	int result = ampipi_fasync(-1, filp, 0);

	return result;
}

/*********************************************************************
*
*
*
**********************************************************************/
/* File operations exposed through the character device. */
static struct file_operations ampipi_fops =
{
	.owner          = THIS_MODULE,

	/* lifetime */
	.open           = ampipi_open,
	.release        = ampipi_fasync_release,

	/* data access to the mapped memory region */
	.llseek         = ampipi_llseek,
	.read           = ampipi_read,
	.write          = ampipi_write,

	/* control and asynchronous notification */
	.unlocked_ioctl = ampipi_ioctl,
	.fasync         = ampipi_fasync,
};

/*********************************************************************
*
*
*
**********************************************************************/
/*
 * Publish the probed device parameters in the global driver state and take
 * CPU1 offline in Linux. The result of cpu_down() is not examined; the
 * function always returns 0.
 */
static int ampipi_init(struct ampipidevice_local *nd)
{
	ampipi.param = nd;

	(void)cpu_down(1);

	return 0;
}



/*********************************************************************
*
* IPI handler: tells Linux on CPU0 that there is a message to process
* by sending SIGIO to the registered asynchronous listeners.
*
**********************************************************************/
static void cpu1_to_cpu0_ipi_kick(void)
{
	/* Nobody has enabled asynchronous notification: only log the event. */
	if (!ampipi.async_queue)
	{
		printk("ipi %d kick but async_queue is null \n",ampipi.param->cpu1_to_cpu0_ipi);
		return;
	}

	printk("ipi %d  kick SIGIO \n",ampipi.param->cpu1_to_cpu0_ipi);
	kill_fasync(&ampipi.async_queue, SIGIO, POLL_IN);
}

/*********************************************************************
*
*驱动探测,初始化
*
**********************************************************************/
static int ampipidevice_probe(struct platform_device *pdev)
{
	struct resource *r_mem; /* IO mem resources */
	struct device *dev = &pdev->dev;
	struct ampipidevice_local *lp = NULL;

	int rc = 0;
	
	printk("ampipi Device Tree Probing\n");
	/* Get iospace for the device */
	r_mem = platform_get_resource(pdev, IORESOURCE_MEM, 0);
	if (!r_mem) 
	{
		dev_err(dev, "invalid address\n");
		return -ENODEV;
	}
	
	lp = (struct ampipidevice_local *) kmalloc(sizeof(struct ampipidevice_local), GFP_KERNEL);
	if (!lp) 
	{
		dev_err(dev, "Cound not allocate ampipidevice device\n");
		return -ENOMEM;
	}
	
	
	dev_set_drvdata(dev, lp);
	
	lp->mem_start = r_mem->start;
	lp->mem_end = r_mem->end;

	if (!request_mem_region(lp->mem_start,lp->mem_end - lp->mem_start + 1,DRIVER_NAME)) 
	{
		dev_err(dev, "Couldn't lock memory region at %p\n",
			(void *)lp->mem_start);
		rc = -EBUSY;
		goto error1;
	}

	lp->base_addr = ioremap(lp->mem_start, lp->mem_end - lp->mem_start + 1);
	if (!lp->base_addr) 
	{
		dev_err(dev, "ampipidevice: Could not allocate iomem\n");
		rc = -EIO;
		goto error2;
	}

	lp->cpu0_to_cpu1_ipi = 12;
	lp->cpu1_to_cpu0_ipi = 13;//默认值

	/* Read ipi12 ipi number 用于 ucos 通知 linux */
	rc = of_property_read_u32(pdev->dev.of_node, "ipi12",&lp->cpu1_to_cpu0_ipi);
	if (rc < 0) 
	{
		dev_err(&pdev->dev, "unable to read property ipi 12");
		goto error3;
	}

	rc = set_ipi_handler(lp->cpu1_to_cpu0_ipi, cpu1_to_cpu0_ipi_kick,"cpu 1 kick cpu0");
	if (rc) 
	{
		dev_err(&pdev->dev, "IPI 12 handler already registered\n");
		goto error3;
	}

	/* Read ipi13 ipi number 用于 linux 通知 ucos */
	rc = of_property_read_u32(pdev->dev.of_node, "ipi13",&lp->cpu0_to_cpu1_ipi);
	if (rc < 0) 
	{
		dev_err(&pdev->dev, "unable to read property ipi 13");
		goto error3;
	}

	printk("ampipidevice at 0x%08x mapped to 0x%08x len= %08X , ipi= %d\n",
		(unsigned int __force)lp->mem_start,
		(unsigned int __force)lp->base_addr,
		(unsigned int __force)(lp->mem_end - lp->mem_start+1),
		lp->cpu0_to_cpu1_ipi);
		
	
	ampipi_init(lp); 
	
	rc = alloc_chrdev_region(&ampipi.devid, 0, DRIVER_NUM, DRIVER_NAME); 
	if(rc)
	{
		dev_err(&pdev->dev, "unable to alloc_chrdev_region ");
	}
	
	ampipi.chdev.owner = THIS_MODULE; 
	cdev_init(&ampipi.chdev, &ampipi_fops); 
	
	rc = cdev_add(&ampipi.chdev, ampipi.devid, 1); 
	if(rc)
	{
		dev_err(&pdev->dev, "unable to cdev_add ");
	}	
	
	ampipi.class = class_create(THIS_MODULE, DRIVER_NAME); 
	if (IS_ERR(ampipi.class)) 
	{ 
		dev_err(&pdev->dev, "unable to class_create ");
	} 	
	
	ampipi.device = device_create(ampipi.class, &pdev->dev,ampipi.devid, NULL, DRIVER_NAME); 
	if (IS_ERR(ampipi.device)) 
	{ 
		dev_err(&pdev->dev, "unable to device_create ");
	} 	
	
	return 0;
	
error3:
	clear_ipi_handler(lp->cpu0_to_cpu1_ipi);	
error2:
	release_mem_region(lp->mem_start, lp->mem_end - lp->mem_start + 1);
error1:
	kfree(lp);
	dev_set_drvdata(dev, NULL);
	return rc;
}

/*
 * Platform remove: tears down the character device, uninstalls the
 * CPU1 -> CPU0 IPI handler, unmaps and releases the memory region, frees
 * the per-device state and finally brings CPU1 back online in Linux.
 *
 * Always returns 0.
 */
static int ampipidevice_remove(struct platform_device *pdev)
{
	struct device *dev = &pdev->dev;
	struct ampipidevice_local *lp = dev_get_drvdata(dev);

	device_destroy(ampipi.class, ampipi.devid); 
	class_destroy(ampipi.class); 	
	cdev_del(&ampipi.chdev); 	
	unregister_chrdev_region(ampipi.devid, DRIVER_NUM); 	

	/*
	 * probe() installed the handler on cpu1_to_cpu0_ipi; clearing any
	 * other IPI would leave the handler pointing into unloaded code.
	 */
	clear_ipi_handler(lp->cpu1_to_cpu0_ipi);

	/* lp is about to be freed; do not leave a dangling global pointer. */
	ampipi.param = NULL;

	iounmap(lp->base_addr);

	release_mem_region(lp->mem_start, lp->mem_end - lp->mem_start + 1);

	kfree(lp);

	dev_set_drvdata(dev, NULL);

	cpu_up(1);

	return 0;
}


/*
 * Device-tree match table: the driver binds to nodes whose "compatible"
 * property is "my_qq_is,260869626". The table is never modified, so it is
 * declared const.
 */
static const struct of_device_id ampipidevice_of_match[] = 
{
	{ .compatible = "my_qq_is,260869626", },
	{ /* end of list */ },
};
MODULE_DEVICE_TABLE(of, ampipidevice_of_match);


/* Platform driver glue: probe/remove callbacks plus device-tree matching. */
static struct platform_driver ampipidevice_driver =
{
	.probe  = ampipidevice_probe,
	.remove = ampipidevice_remove,
	.driver =
	{
		.owner          = THIS_MODULE,
		.name           = DRIVER_NAME,
		.of_match_table = ampipidevice_of_match,
	},
};

/*
 * Module entry point: logs a banner and registers the platform driver.
 * Returns the result of platform_driver_register().
 */
static int __init ampipidevice_init(void)
{
	int status;

	printk("ampipidevice module init...! \n");

	status = platform_driver_register(&ampipidevice_driver);
	return status;
}


/* Module exit point: unregisters the platform driver, then logs a banner. */
static void __exit ampipidevice_exit(void)
{
	platform_driver_unregister(&ampipidevice_driver);

	printk(KERN_ALERT "ampipidevice module exit\n");
}

/* Register the module's load and unload entry points. */
module_init(ampipidevice_init);
module_exit(ampipidevice_exit);

然后在设备树里面添加下面一段

/* Node bound by the ampipi driver through its "compatible" string. */
ampipi_instance: ampipimod@0 {
		compatible = "my_qq_is,260869626";
		/* Memory window mapped by the driver: base 0x38000000, length 0x8000000. */
		/* NOTE(review): presumably this range must be kept out of Linux's own RAM - confirm. */
		reg = <0x38000000 0x8000000>;
		/* IPI number the driver installs its handler on (CPU1 -> CPU0). */
		ipi12 = <12>;
		/* IPI number the driver raises on CPU1 (CPU0 -> CPU1). */
		ipi13 = <13>;
		
	};

然后编译内核和驱动即可。这个驱动提供了一个字符设备接口:write会将要运行在cpu1上的bin文件加载到设备树reg属性指定的内存区域里;ioctl提供StarCpu1(0x10000000)、StopCpu1(0x20000000)、KickCpu1(0x30000000)3个命令,分别对应启动cpu1、停止cpu1、向cpu1发送ipi中断通知的操作。

实际linux端应用编写如下

#include <stdio.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <signal.h>
#include <fcntl.h>

/*
 * Buffer that main() writes to the device as the CPU1 image.
 * NOTE(review): nothing in this listing fills it, so it is all zeroes as
 * shown - presumably the real .bin contents are loaded or embedded here.
 */
unsigned char cpu1_bin_data[114704];

/*
 * SIGIO handler: invoked when the driver signals that CPU1 raised its IPI,
 * i.e. Linux has a message to process. Prints a running count of
 * notifications.
 *
 * NOTE(review): printf() is not async-signal-safe; acceptable for a demo,
 * but production code should only set a flag here.
 */
static void sigio_signal_func(int signum)
{
	static int ii = 0;	/* explicit type: implicit int is not valid C99+ */

	(void)signum;
	printf("CPU0 Kick up %d \n",ii++);
}



/*
 * Demo application: enables SIGIO notification on the driver node, writes
 * the CPU1 image into the shared memory region, starts CPU1 and then kicks
 * it with an IPI every five seconds.
 *
 * Returns -1 when any setup step fails; otherwise loops forever.
 */
int main()
{
	int fd, flags;
	ssize_t ret;
	size_t done = 0;

    /* The driver creates its node under DRIVER_NAME ("ampipi_Device"). */
    fd = open("/dev/ampipi_Device", O_RDWR);
    if(0 > fd)
    {
    	 printf("acdp dev open failed!\r\n");
    	 return -1;
    }

    signal(SIGIO, sigio_signal_func); 				/* install the SIGIO handler */

    /* Tell the kernel which process should receive SIGIO. */
    if (fcntl(fd, F_SETOWN, getpid()) < 0)
    {
    	 perror("fcntl(F_SETOWN)");
    	 close(fd);
    	 return -1;
    }

    /*
     * Enable asynchronous notification. The file STATUS flags must be read
     * with F_GETFL (F_GETFD returns descriptor flags, a different set).
     */
    flags = fcntl(fd, F_GETFL);
    if (flags < 0 || fcntl(fd, F_SETFL, flags | FASYNC) < 0)
    {
    	 perror("fcntl(F_SETFL)");
    	 close(fd);
    	 return -1;
    }

    if (lseek(fd,0,SEEK_SET) < 0)
    {
    	 perror("lseek");
    	 close(fd);
    	 return -1;
    }

    /* Write the CPU1 bin image into the memory region, handling short writes. */
    while (done < sizeof(cpu1_bin_data))
    {
    	 ret = write(fd, cpu1_bin_data + done, sizeof(cpu1_bin_data) - done);
    	 if(0 >= ret)
    	 {
    	 	 printf("write Failed!\r\n");
    	 	 close(fd);
    	 	 return -1;	/* never start CPU1 on a partial image */
    	 }
    	 done += (size_t)ret;
    }

    /* Start CPU1 (StarCpu1 command). */
    if (ioctl(fd,0x10000000,NULL) < 0)
    {
    	 perror("ioctl(StarCpu1)");
    	 close(fd);
    	 return -1;
    }

while(1)
{
    sleep(5);
    /* Raise the IPI on CPU1 to tell it a message has arrived (KickCpu1). */
    if (ioctl(fd,0x30000000,NULL) < 0)
    {
    	 perror("ioctl(KickCpu1)");
    }
}

    close(fd);


    return 0;
}

实际ucos端程序如下

#include  <stdio.h>
#include  <Source/os.h>
#include  <ucos_bsp.h>

#include "xgpiops.h"
#include "xil_io.h"
#include "xscugic.h"

/*
 * Software-generated interrupt (SGI) IDs 0..15. Only SGI12_INTR_ID and
 * SGI13_INTR_ID are used by the code in this listing.
 */
#define  SGI0_INTR_ID  0
#define  SGI1_INTR_ID  1
#define  SGI2_INTR_ID  2
#define  SGI3_INTR_ID  3
#define  SGI4_INTR_ID  4
#define  SGI5_INTR_ID  5
#define  SGI6_INTR_ID  6
#define  SGI7_INTR_ID  7
#define  SGI8_INTR_ID  8
#define  SGI9_INTR_ID  9
#define  SGI10_INTR_ID  10
#define  SGI11_INTR_ID  11
#define  SGI12_INTR_ID  12
#define  SGI13_INTR_ID  13
#define  SGI14_INTR_ID  14
#define  SGI15_INTR_ID  15

/* Not referenced by the code in this listing. */
/* NOTE(review): presumably target-filter values for raw SGI register writes - confirm. */
#define  TRIGGER_SELECTED  0x00000000
#define  TRIGGER_SELF  0x02000000
#define  TRIGGER_OTHER  0x01000000

/* CPU indices passed to the interrupt-controller calls below. */
#define  CPU_NO0  0
#define  CPU_NO1  1
/* Not referenced by the code in this listing. */
#define  CPU_ID_LIST  0x00010000

/* Not referenced by the code in this listing. */
/* NOTE(review): presumably GIC distributor register addresses - confirm against the Zynq TRM. */
#define  ICDSGIR  0xF8F01F00
#define  ICDIPTR  0xF8F01800

/* Device ID used to look up the interrupt-controller configuration. */
#define  INTC_DEVICE_ID   XPAR_PS7_SCUGIC_0_DEVICE_ID

XScuGic InterruptController; /* Instance of the Interrupt Controller */

/* SGI number this CPU listens on; assigned in MainTask(). */
u32 SGI_INTR;

/* Set to 1 by the SGI handler; cleared in MainTask() before setup. */
int SGI_trigered;

/* NOTE(review): no definition with this name is visible here; the handler actually defined is SGI1_INTR_ID_ISR. */
void  SGI0_INTR_ID_ISR (void);

/* Not referenced by the code in this listing. */
#define  DELAY  50000000

/*
 * Fixed-address volatile words; none are referenced by the code in this
 * listing.
 * NOTE(review): presumably a shared-memory mailbox between the CPUs - confirm.
 */
#define  COMM_VAL (*(volatile unsigned long *)(0xFFFF8000))

#define  COMM_TX_FLAG (*(volatile unsigned long *)(0xFFFF9000))

#define  COMM_TX_DATA (*(volatile unsigned long *)(0xFFFF9004))

#define  COMM_RX_FLAG (*(volatile unsigned long *)(0xFFFF9008))

#define  COMM_RX_DATA (*(volatile unsigned long *)(0xFFFF900C))


/*
 * Initialise the interrupt controller instance and hook up one SGI.
 *
 * IntcInstancePtr - interrupt controller instance to initialise
 * Handler         - function connected to the interrupt
 * DeveiceId       - device ID used to look up the controller configuration
 * SgiIntr         - interrupt ID to configure, connect and enable
 * CpuNo           - CPU the interrupt is mapped to
 *
 * Returns XST_SUCCESS, or XST_FAILURE when the configuration lookup, the
 * controller initialisation or the handler connection fails.
 */
u32 SetupSGIIntrSystem(XScuGic *IntcInstancePtr,Xil_InterruptHandler Handler, u32 DeveiceId, u32 SgiIntr, u32 CpuNo)
{
    XScuGic_Config *GicCfg;
    int Result;

    GicCfg = XScuGic_LookupConfig(DeveiceId);
    if (GicCfg == NULL)
        return XST_FAILURE;

    Result = XScuGic_CfgInitialize(IntcInstancePtr, GicCfg, GicCfg->CpuBaseAddress);
    if (Result != XST_SUCCESS)
        return XST_FAILURE;

    /* Priority 0xd0, trigger type 0x3 for this interrupt. */
    XScuGic_SetPriorityTriggerType(IntcInstancePtr, SgiIntr, 0xd0, 0x3);

    Result = XScuGic_Connect(IntcInstancePtr, SgiIntr, (Xil_ExceptionHandler)Handler, 0);
    if (Result != XST_SUCCESS)
        return XST_FAILURE;

    /* Enable the interrupt, then route it to the requested CPU. */
    XScuGic_Enable(IntcInstancePtr, SgiIntr);
    XScuGic_InterruptMaptoCpu(IntcInstancePtr, CpuNo, SgiIntr);

    return XST_SUCCESS;
}

/*
 * Initialise exception handling, register the interrupt controller's
 * dispatcher for the interrupt exception with IntcInstancePtr as its
 * argument, and enable exceptions.
 */
void  ExceptionSetup(XScuGic *IntcInstancePtr)
{
    Xil_ExceptionInit();

    Xil_ExceptionRegisterHandler(XIL_EXCEPTION_ID_INT,
                                 (Xil_ExceptionHandler)XScuGic_InterruptHandler,
                                 IntcInstancePtr);

    Xil_ExceptionEnable();
}

/*
 * SGI handler connected in MainTask(): logs that the software interrupt
 * fired and raises the SGI_trigered flag.
 */
void  SGI1_INTR_ID_ISR (void)
{
    printf("CPU0: The software interrupt0 has been triggered\n\r");

    /* Record the event for whoever polls the flag. */
    SGI_trigered = 1;
}

/*
*********************************************************************************************************
*                                      LOCAL FUNCTION PROTOTYPES
*********************************************************************************************************
*/

void  MainTask (void *p_arg);


/*
*********************************************************************************************************
*                                               main()
*
* Description : Entry point for C code.
*
*********************************************************************************************************
*/

int main()
{
    /* Hand control to the uC/OS startup code with MainTask as the startup task. */
    UCOSStartup(MainTask);

    return 0;
}


/*
*********************************************************************************************************
*                                             MainTask()
*
* Description : Startup task example code.
*
* Returns     : none.
*
* Created by  : main().
*********************************************************************************************************
*/

/*
 * Startup task.
 *
 * Hooks SGI13 to SGI1_INTR_ID_ISR on CPU1, sets up PS GPIO pin 0 as an
 * output, then loops forever: pin high for 200 ms, pin low for 200 ms,
 * then raise SGI12 towards CPU0.
 *
 * When the GPIO lookup or initialisation fails the pin is left alone, but
 * the delays and the SGI towards CPU0 still run.
 *
 * p_arg is unused.
 */
void  MainTask (void *p_arg)
{
	OS_ERR  os_err;
	static XGpioPs psGpioInstancePtr;
	XGpioPs_Config *GpioConfigPtr;
	int xStatus;
	int Status;
	int GpioReady = 0;

	(void)p_arg;

    SGI_INTR=SGI13_INTR_ID;
    SGI_trigered=0;
    Status =  SetupSGIIntrSystem(&InterruptController,(Xil_ExceptionHandler)SGI1_INTR_ID_ISR,INTC_DEVICE_ID, SGI_INTR,CPU_NO1);
	if (Status != XST_SUCCESS)
	{
		UCOS_Print("FAILED Xil_ExceptionHandler \n\r");
	}
	ExceptionSetup(&InterruptController);

	GpioConfigPtr =XGpioPs_LookupConfig(XPAR_PS7_GPIO_0_DEVICE_ID);
	if(GpioConfigPtr == NULL)
	{
		UCOS_Print("PS GPIO INIT FAILED1 \n\r");
	}
	else
	{
		/* Only dereference the configuration once the lookup succeeded. */
		xStatus =XGpioPs_CfgInitialize(&psGpioInstancePtr,GpioConfigPtr,GpioConfigPtr->BaseAddr);
		if(XST_SUCCESS != xStatus)
			UCOS_Print("PS GPIO INIT FAILED2 \n\r");
		else
			GpioReady = 1;
	}

	if (GpioReady)
	{
		XGpioPs_SetDirectionPin(&psGpioInstancePtr,0,1);
		XGpioPs_SetOutputEnablePin(&psGpioInstancePtr,0,1);
	}

	printf("Starting application...\n");

	while(1)
	{
		if (GpioReady)
			XGpioPs_WritePin(&psGpioInstancePtr,0,1);
		OSTimeDlyHMSM(0, 0, 0, 200, OS_OPT_TIME_HMSM_STRICT, &os_err);
		if (GpioReady)
			XGpioPs_WritePin(&psGpioInstancePtr,0,0);
		OSTimeDlyHMSM(0, 0, 0, 200, OS_OPT_TIME_HMSM_STRICT, &os_err);

		/* Notify CPU0 (Linux) with SGI12. */
		XScuGic_SoftwareIntr(&InterruptController,SGI12_INTR_ID,0x1<<CPU_NO0);

	}
}

 

实际运行时当cpu1的ucos运行起来后每400ms会触发一次ipi中断,linux收到中断后打印 printf("CPU0 Kick up %d \n",ii++);然后linux发一次ipi中断给ucos,ucos收到后打印printf("CPU0: The software interrupt0 has been triggered\n\r");

串口截图如下

显示是乱的,实际是正常的,因为linux和ucos使用的是同一个串口,两个系统竞争使用导致打印错乱了,这个并不影响试验结果!

 

 

 

### AMP 技术简介 AMP(Automatic Mixed Precision)是一种用于加速深度学习训练的技术,它通过动态调整模型计算过程中的精度来减少内存占用并提升性能。以下是关于 AMP 技术及其实现方法的具体说明。 #### 1. PyTorch 原生支持的 AMP 实现方式 PyTorch 提供了 `torch.cuda.amp` 模块,其中包含了两个核心组件:`autocast` 和 `GradScaler`。 - **Autocast**: 自动将操作符切换到较低精度(FP16),从而提高 GPU 利用率和速度[^1]。 - **GradScaler**: 处理梯度缩放问题,在低精度环境下防止梯度消失或爆炸现象的发生。 以下是一个简单的代码示例展示如何使用 PyTorch 的原生 AMP 功能: ```python import torch from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() def train_step(model, optimizer, data, target): model.train() with autocast(): # 启用自动混合精度 output = model(data) loss = criterion(output, target) scaler.scale(loss).backward() # 缩放损失以防止梯度消失 scaler.step(optimizer) # 更新权重 scaler.update() # 调整缩放因子 ``` #### 2. 使用 Apex 库实现 AMP Apex 是 NVIDIA 开发的一个库,专门用于优化深度学习框架的性能表现。通过安装 Apex 并导入模块可以快速启用 AMP 支持。 具体步骤如下: - 安装 Apex 库; - 导入所需模块 (`from apex import amp`); - 初始化模型与优化器以便于它们能够兼容 FP16 计算环境。 下面是一段基于 Apex 的代码片段: ```python import torch from apex import amp model, optimizer = initialize_model_and_optimizer() model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # O1 表示推荐级别 for input, target in dataloader: output = model(input) loss = loss_fn(output, target) with amp.scale_loss(loss, optimizer) as scaled_loss: # 对损失进行缩放处理 scaled_loss.backward() optimizer.step() optimizer.zero_grad() ``` #### 3. AMP 技术的应用场景 AMP 主要应用于需要高性能计算资源的大规模机器学习项目中,尤其是在图像识别、自然语言处理等领域有广泛用途。由于它可以显著降低显存消耗以及缩短迭代时间,因此对于大规模分布式训练尤其重要。 --- ###
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值