switch_class

switch 类

1. init/exit:

/* Module load hook: make sure the "switch" device class exists. */
static int __init switch_class_init(void)
{
	int err;

	err = create_switch_class();
	return err;
}

/* Module unload hook: destroy the "switch" device class. */
static void __exit switch_class_exit(void)
{
	struct class *cls = switch_class;

	class_destroy(cls);
}

module_init(switch_class_init);
module_exit(switch_class_exit);
2. create_switch_class()

/* The "switch" device class; NULL until create_switch_class() succeeds. */
struct class *switch_class;
/*
 * Counter used to hand out a unique index (minor number) to each registered
 * switch device. It is a counter, not a lock.
 */
static atomic_t device_count;

/*
 * create_switch_class - create the "switch" device class if it does not exist yet.
 *
 * Safe to call repeatedly: returns 0 immediately when the class already exists.
 * The class is created into a local first and published to switch_class only on
 * success. Previously a failed class_create() left an ERR_PTR in switch_class.
 * Because that pointer is non-NULL, later lazy-creation checks treated the class
 * as existing, and device_create() was then handed an error pointer.
 *
 * Return: 0 on success, or the negative errno from class_create().
 */
static int create_switch_class(void)
{
	struct class *cls;

	if (switch_class)
		return 0;

	cls = class_create(THIS_MODULE, "switch");
	if (IS_ERR(cls))
		return PTR_ERR(cls);	/* switch_class stays NULL, so a later call can retry */

	switch_class = cls;
	atomic_set(&device_count, 0);
	return 0;
}
3. switch_dev_register(): 设备注册函数,提供给别的驱动调用,用来注册switch设备。

/**
 * switch_dev_register - register a switch device under the "switch" class
 * @sdev: caller-provided device; @sdev->name must be set, and @print_name /
 *        @print_state may optionally be set.
 *
 * Creates the class on first use. Assigns @sdev->index from an increasing
 * counter and creates the device. Adds the read-only "state" and "name" sysfs
 * attributes, then sets @sdev->state to 0.
 *
 * Two fixes relative to the original:
 *  - @sdev->name is passed through "%s" rather than used as the printf-style
 *    format string of device_create().
 *  - drvdata is attached by device_create() itself, and @sdev->state is set,
 *    before the sysfs files exist. Previously dev_set_drvdata() ran after
 *    device_create_file(), so a read in that window dereferenced NULL drvdata
 *    in state_show()/name_show().
 *
 * Return: 0 on success, negative errno on failure.
 */
int switch_dev_register(struct switch_dev *sdev)
{
	int ret;

	if (!switch_class) {
		ret = create_switch_class();
		if (ret < 0)
			return ret;
	}

	/* Unique index per device, used as the minor number; never reused. */
	sdev->index = atomic_inc_return(&device_count);
	sdev->dev = device_create(switch_class, NULL, MKDEV(0, sdev->index),
				  sdev, "%s", sdev->name);
	if (IS_ERR(sdev->dev))
		return PTR_ERR(sdev->dev);

	/* Initialize state before the attribute that reports it becomes visible. */
	sdev->state = 0;

	ret = device_create_file(sdev->dev, &dev_attr_state);
	if (ret < 0)
		goto err_create_file_1;
	ret = device_create_file(sdev->dev, &dev_attr_name);
	if (ret < 0)
		goto err_create_file_2;

	return 0;

err_create_file_2:
	device_remove_file(sdev->dev, &dev_attr_state);
err_create_file_1:
	device_destroy(switch_class, MKDEV(0, sdev->index));
	printk(KERN_ERR "switch: Failed to register driver %s\n", sdev->name);
	return ret;
}
EXPORT_SYMBOL_GPL(switch_dev_register);
4. 注册完一个switch设备之后,就可以给该switch设备设置状态了, 即: switch_set_state()

/**
 * switch_set_state - change a switch device's state and notify userspace
 * @sdev:  a device registered with switch_dev_register()
 * @state: the new state value
 *
 * Does nothing if @state equals the current state. Otherwise it stores @state
 * and sends a KOBJ_CHANGE uevent on the device. The uevent carries
 * SWITCH_NAME=<name> and SWITCH_STATE=<state>; each variable is included only
 * when its show routine produced output, with a trailing newline removed.
 *
 * A zeroed page is used as scratch space. If that allocation fails, an error
 * is logged and a plain KOBJ_CHANGE uevent (without variables) is sent instead.
 * The page is allocated with GFP_KERNEL, so this is not meant for atomic context.
 */
void switch_set_state(struct switch_dev *sdev, int state)
{
	char name_env[120];
	char state_env[120];
	char *env[3];
	char *buf;
	int nenv = 0;
	int len;

	if (sdev->state == state)
		return;

	sdev->state = state;

	buf = (char *)get_zeroed_page(GFP_KERNEL);
	if (!buf) {
		printk(KERN_ERR "out of memory in switch_set_state\n");
		/* Still tell userspace that something changed, just without details. */
		kobject_uevent(&sdev->dev->kobj, KOBJ_CHANGE);
		return;
	}

	len = name_show(sdev->dev, NULL, buf);
	if (len > 0) {
		if (buf[len - 1] == '\n')
			buf[len - 1] = '\0';
		snprintf(name_env, sizeof(name_env), "SWITCH_NAME=%s", buf);
		env[nenv++] = name_env;
	}

	len = state_show(sdev->dev, NULL, buf);
	if (len > 0) {
		if (buf[len - 1] == '\n')
			buf[len - 1] = '\0';
		snprintf(state_env, sizeof(state_env), "SWITCH_STATE=%s", buf);
		env[nenv++] = state_env;
	}

	env[nenv] = NULL;	/* the uevent variable list is NULL-terminated */
	kobject_uevent_env(&sdev->dev->kobj, KOBJ_CHANGE, env);
	free_page((unsigned long)buf);
}
EXPORT_SYMBOL_GPL(switch_set_state);
5. name 和 state 属性操作(sysfs 属性的 show 函数)

/*
 * state_show - read handler for the "state" sysfs attribute.
 *
 * If the driver supplied print_state() and it returns >= 0, that output is
 * used. Otherwise sdev->state is printed as a decimal integer followed by a
 * newline. @attr is unused; switch_set_state() passes NULL for it.
 */
static ssize_t state_show(struct device *dev, struct device_attribute *attr,
		char *buf)
{
	struct switch_dev *sdev = dev_get_drvdata(dev);

	if (sdev->print_state) {
		/* Keep the callback's ssize_t result; don't narrow it to int. */
		ssize_t ret = sdev->print_state(sdev, buf);

		if (ret >= 0)
			return ret;
	}
	return sprintf(buf, "%d\n", sdev->state);
}

/*
 * name_show - read handler for the "name" sysfs attribute.
 *
 * If the driver supplied print_name() and it returns >= 0, that output is
 * used. Otherwise sdev->name is printed followed by a newline. @attr is
 * unused; switch_set_state() passes NULL for it.
 */
static ssize_t name_show(struct device *dev, struct device_attribute *attr,
		char *buf)
{
	struct switch_dev *sdev = dev_get_drvdata(dev);

	if (sdev->print_name) {
		ssize_t ret = sdev->print_name(sdev, buf);

		if (ret >= 0)
			return ret;
	}
	return sprintf(buf, "%s\n", sdev->name);
}

/*
 * Read-only attributes. The original mode also included S_IWUSR, but there is
 * no store handler, so advertising write permission was misleading: writes
 * could never succeed.
 */
static DEVICE_ATTR(state, S_IRUGO, state_show, NULL);
static DEVICE_ATTR(name, S_IRUGO, name_show, NULL);

6. 一个重要的数据结构:struct switch_dev,用于描述一个 switch 设备。

/*
 * struct switch_dev - one device registered under the "switch" class.
 * Filled in partly by the driver (name, optional callbacks) and partly by
 * switch_dev_register() (dev, index, state).
 */
struct switch_dev {
        const char    *name;    /* device name; also the fallback output of the "name" attribute */
        struct device    *dev;    /* device created by switch_dev_register() */
        int        index;    /* unique index from a counter; used as the minor in MKDEV(0, index) */
        int        state;    /* current state; set to 0 at registration, updated by switch_set_state() */
        ssize_t    (*print_name)(struct switch_dev *sdev, char *buf);    /* optional; a negative return falls back to name */
        ssize_t    (*print_state)(struct switch_dev *sdev, char *buf);    /* optional; a negative return falls back to "%d\n" of state */
    };

 
# Snippet quoted in a reader comment asking for a brief analysis, reconstructed
# from a whitespace-collapsed paste. torch_utils and quat_from_euler_xyz come
# from the snippet's original module.
def _init_from_random_skill(self, env_ids, motion_ids, motion_times):
    """Re-initialise some environments from a state borrowed from another skill.

    For each env (unless ``self.state_random_flags[ind]`` is set), with
    probability ``self.cfg['env']['state_switch_prob']`` a "switch" motion and
    frame is chosen. It comes either from ``self.state_search_graph`` (when
    ``self.state_search_to_align_reward``) or from
    ``self._motion_data.sample_switch_motions``.

    The switch frame's initial state is re-expressed in the source motion's
    frame by a yaw rotation plus an xy translation onto the source root. It is
    then written into the ``self.init_*`` buffers. ``self.hoi_data_batch[env_id]``
    is refreshed from the source motion.

    Args:
        env_ids: environment ids to (possibly) re-initialise.
        motion_ids: per-env source motion ids, indexed in parallel with env_ids.
        motion_times: per-env source motion times, indexed in parallel with env_ids.

    Returns:
        ``(motion_ids, motion_times)``, which this function does not modify.
    """
    state_switch_flags = [np.random.rand() < self.cfg['env']['state_switch_prob'] for _ in env_ids]
    for ind, env_id in enumerate(env_ids):
        # NOTE(review): state_random_flags is indexed by position `ind`, not by
        # `env_id`; confirm that is intended.
        if not state_switch_flags[ind] or self.state_random_flags[ind]:
            continue

        source_motion_class = self._motion_data.motion_class[motion_ids[ind]]
        source_motion_id = motion_ids[ind:ind + 1]
        source_motion_time = motion_times[ind:ind + 1]

        if self.state_search_to_align_reward:
            # Candidates are (class, motion id, time, similarity) tuples for this source frame.
            candidates = self.state_search_graph[source_motion_class][source_motion_id.item()][source_motion_time.item()]
            if len(candidates) == 0:
                # random.choice() would raise IndexError on an empty sequence.
                continue
            switch_motion_class, switch_motion_id, switch_motion_time, max_sim = random.choice(candidates)
            if switch_motion_id is None or switch_motion_time is None:
                # No usable switch target (the original checked `and`, letting a
                # half-None candidate reach torch.tensor([None]) and crash).
                continue
            self.max_sim[env_id] = max_sim
            switch_motion_id = torch.tensor([switch_motion_id], device=self.device)
            switch_motion_time = torch.tensor([switch_motion_time], device=self.device)
        else:
            switch_motion_id = self._motion_data.sample_switch_motions(source_motion_id)
            motion_len = self._motion_data.motion_lengths[switch_motion_id].item()
            # Frame index drawn from [2, motion_len - 2) (upper bound exclusive);
            # requires motion_len > 4.
            switch_motion_time = torch.randint(2, motion_len - 2, (1,), device=self.device)
            # Previously undefined on this path, so the isTest print below raised
            # NameError. Assumes switch_motion_id is a single-element id — TODO confirm.
            switch_motion_class = self._motion_data.motion_class[int(switch_motion_id)]

        # Initial state of the switch frame, to be re-expressed in the source frame.
        (_, init_root_pos_switch, init_root_rot_switch, init_root_pos_vel_switch, init_root_rot_vel_switch,
         init_dof_pos_switch, init_dof_pos_vel_switch,
         init_obj_pos_switch, init_obj_pos_vel_switch, init_obj_rot_switch, init_obj_rot_vel_switch) = \
            self._motion_data.get_initial_state(env_ids[ind:ind + 1], switch_motion_id, switch_motion_time)

        # Alignment targets taken from the source frame; also refreshes hoi_data_batch for this env.
        (self.hoi_data_batch[env_id], init_root_pos_source, init_root_rot_source,
         _, _, _, _, init_obj_pos_source, _, _, _) = \
            self._motion_data.get_initial_state(env_ids[ind:ind + 1], source_motion_id, source_motion_time)

        # Heading change taking the switch frame onto the source frame, wrapped to [-pi, pi).
        yaw_source = torch_utils.quat_to_euler(init_root_rot_source)[0][2]
        yaw_switch = torch_utils.quat_to_euler(init_root_rot_switch)[0][2]
        yaw_diff = yaw_source - yaw_switch
        yaw_diff = (yaw_diff + torch.pi) % (2 * torch.pi) - torch.pi

        # Root position. get_initial_state outputs appear to be (1, 3), hence [0].
        # Assumes rotate_xy(point, pivot, yaw, target) — TODO confirm. Rotating the
        # switch root about itself leaves it fixed, so this moves its xy onto the source root's xy.
        self.init_root_pos[env_id] = self.rotate_xy(init_root_pos_switch[0], init_root_pos_switch[0],
                                                    yaw_diff, init_root_pos_source[0])
        # Root rotation: pre-multiply by a pure-yaw quaternion.
        yaw_quat = quat_from_euler_xyz(torch.zeros_like(yaw_diff), torch.zeros_like(yaw_diff), yaw_diff)
        self.init_root_rot[env_id] = torch_utils.quat_multiply(yaw_quat, init_root_rot_switch)
        # Object position: rotated about the switch root and carried along with it.
        self.init_obj_pos[env_id] = self.rotate_xy(init_obj_pos_switch[0], init_root_pos_switch[0],
                                                   yaw_diff, init_root_pos_source[0])
        # Object rotation is left as-is: the object is a ball and its rotation is not rewarded.

        # Velocities and DoF values are copied without frame alignment.
        # NOTE(review): if root/object linear velocities are world-frame, they
        # arguably should be rotated by yaw_diff as well — confirm.
        self.init_root_pos_vel[env_id] = init_root_pos_vel_switch
        self.init_root_rot_vel[env_id] = init_root_rot_vel_switch
        self.init_dof_pos[env_id] = init_dof_pos_switch
        self.init_dof_pos_vel[env_id] = init_dof_pos_vel_switch
        self.init_obj_pos_vel[env_id] = init_obj_pos_vel_switch
        self.init_obj_rot_vel[env_id] = init_obj_rot_vel_switch

        if self.isTest:
            print(f"Switched from skill {switch_motion_class} to {source_motion_class} for env {env_id}")

    return motion_ids, motion_times
最新发布
09-18
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值