self & __set__ __get__ & static、class、abstract

本文详细解析了Python中self参数的作用及用法,并介绍了不同类型的类方法,包括静态方法、类方法和抽象方法的区别与应用场景。
  • 部分转载自这里这里
  • 这里
  • 刚开始学习Python的类写法的时候觉得很是麻烦,为什么定义时需要而调用时又不需要,为什么不能内部简化从而减少我们敲击键盘的次数?你看完这篇文章后就会明白所有的疑问
  • 首先明确的是self只有在类的方法中才会有,独立的函数或方法是不必带有self的。self在定义类的方法时是必须有的,虽然在调用时不必传入相应的参数

self

self代表类的实例,而非类

class Test:
    def prt(self):
        print(self)
        print(self.__class__)
t = Test()
t.prt()
  • 执行结果
<__main__.Test object at 0x000000000284E080>
<class '__main__.Test'>
  • 从上面的例子中可以很明显的看出,self代表的是类的实例。而self.class则指向类。

self不必非写成self

  • self名称不是必须的,在python中self不是关键词,你可以定义成a或b或其它名字都可以,但是约定成俗,不要搞另类,大家会不明白的
class Test:
    def prt(this):
        print(this)
        print(this.__class__)
t = Test()
t.prt()

self可以不写吗

  • 在Python的解释器内部,当我们调用t.prt()时,实际上Python解释成Test.prt(t),也就是说把self替换成类的实例。
  • 有兴趣的童鞋可以把上面的t.prt()一行改写一下,运行后的实际结果完全相同。
  • 实际上已经部分说明了self在定义时不可以省略,如果非要试一下,那么请看下面:
class Test:
    def prt():
        print(self)
t = Test()
t.prt()
  • 运行时提醒错误如下:prt在定义时没有参数,但是我们运行时强行传了一个参数。
  • 由于上面解释过了t.prt()等同于Test.prt(t),所以程序提醒我们多传了一个参数t。
Traceback (most recent call last):
  File "h.py", line 6, in <module>
    t.prt()
TypeError: prt() takes 0 positional arguments but 1 was given
  • 当然,如果我们的定义和调用时均不传类实例是可以的,这就是类方法。
class Test:
    def prt():
        print(__class__)
Test.prt()
  • 运行结果如下
<class '__main__.Test'>
  • 在继承时,传入的是哪个实例,就是那个传入的实例,而不是指定义了self的类的实例
class Parent:
    def pprt(self):
        print(self)
class Child(Parent):
    def cprt(self):
        print(self)
c = Child()
c.cprt()
c.pprt()
p = Parent()
p.pprt()
  • 运行结果如下
<__main__.Child object at 0x0000000002A47080>
<__main__.Child object at 0x0000000002A47080>
<__main__.Parent object at 0x0000000002A47240>
  • 运行c.cprt()时应该没有理解问题,指的是Child类的实例。
  • 但是在运行c.pprt()时,等同于Child.pprt(c),所以self指的依然是Child类的实例,由于self中没有定义pprt()方法,所以沿着继承树往上找,发现在父类Parent中定义了pprt()方法,所以就会成功调用。

在描述符类中,self指的是描述符类的实例

class Desc:
    def __get__(self, ins, cls):
        print('self in Desc: %s ' % self )
        print(self, ins, cls)
class Test:
    x = Desc()
    def prt(self):
        print('self in Test: %s' % self)
t = Test()
t.prt()
t.x
  • 运行结果如下:
self in Test: <__main__.Test object at 0x0000000002A570B8>
self in Desc: <__main__.Desc object at 0x000000000283E208>
<__main__.Desc object at 0x000000000283E208> <__main__.Test object at 0x0000000002A570B8> <class '__main__.Test'>
  • 大部分童鞋开始有疑问了,为什么在Desc类中定义的self不是应该是调用它的实例t吗?怎么变成了Desc类的实例了呢?
  • 注意:此处需要睁大眼睛看清楚了,这里调用的是t.x,也就是说是Test类的实例t的属性x,由于实例t中并没有定义属性x,所以找到了类属性x,而该属性是描述符属性,为Desc类的实例而已,所以此处并没有顶用Test的任何方法。
  • 那么我们如果直接通过类来调用属性x也可以得到相同的结果。
  • 下面是把t.x改为Test.x运行的结果。
self in Test: <__main__.Test object at 0x00000000022570B8>
self in Desc: <__main__.Desc object at 0x000000000223E208>
<__main__.Desc object at 0x000000000223E208> None <class '__main__.Test'>
  • 题外话:由于在很多时候描述符类中仍然需要知道调用该描述符的实例是谁,所以在描述符类中存在第二个参数ins,用来表示调用它的类实例,所以t.x时可以看到第三行中的运行结果中第二项为

总结

  • self在定义时需要定义,但是在调用时会自动传入。
  • self的名字并不是规定死的,但是最好还是按照约定是用self
  • self总是指调用时的类的实例。

set get 等解释 & descriptor

  • 如果你和我一样,曾经对method和function以及对它们的各种访问方式包括self参数的隐含传递迷惑不解,建议你耐心的看下去。这里还提到了Python属性查找策略,使你清楚的知道Python处理obj.attr和obj.attr=val时,到底做了哪些工作
  • Python中,对象的方法也是也可以认为是属性,所以下面所说的属性包含方法在内
  • 先定义下面这个类,还定义了它的一个实例,留着后面用
class T(object):
    name = 'name'
    def hello(self):
        print ("hello")
t = T()
print (dir(t))
['__class__', '__delattr__', '__dict__', '__doc__', '__getattribute__',
 '__hash__', '__init__', '__module__', '__new__', '__reduce__', '__reduce_ex__',
 '__repr__', '__setattr__', '__str__', '__weakref__', 'hello', 'name']
  • 属性可以分为两类,一类是Python自动产生的,如classhash等,另一类是我们自定义的,如上面的hello,name。我们只关心自定义属性。
  • 类和实例对象(实际上,Python中一切都是对象,类是type的实例)都有dict属性,里面存放它们的自定义属性(对与类,里面还存放了别的东西)。
>>> t.__dict__
{}
>>> T.__dict__
<dictproxy object at 0x00CD0FF0>
>>> dict(T.__dict__)            #由于T.__dict__并没有直接返回dict对象,这里进行转换,以方便观察其中的内容
{'__module__': '__main__', 'name': 'name',
 'hello': <function hello at 0x00CC2470>,
 '__dict__': <attribute '__dict__' of 'T' objects>,
 '__weakref__': <attribute '__weakref__' of 'T' objects>, '__doc__': None}
  • 有些内建类型,如list和string,它们没有dict属性,随意没办法在它们上面附加自定义属性
  • 到现在为止t.dict是一个空的字典,因为我们并没有在t上自定义任何属性,它的有效属性hello和name都是从T得到的。T的dict中包含hello和name。当遇到t.name语句时,Python怎么找到t的name属性呢?
  • 首先,Python判断name属性是否是个自动产生的属性,如果是自动产生的属性,就按特别的方法找到这个属性,当然,这里的name不是自动产生的属性,而是我们自己定义的,Python于是到t的dict中寻找。还是没找到。
  • 接着,Python找到了t所属的类T,搜索T.dict,期望找到name,很幸运,直接找到了,于是返回name的值:字符串‘name’。如果在T.dict中还没有找到,Python会接着到T的父类(如果T有父类的话)的dict中继续查找。
  • 这不足以解决我们的困惑,因为事情远没有这么简单,上面说的其实是个简化的步骤
  • 继续上面的例子,对于name属性T.name和T.dict[‘name’]是完全一样的
>>> T.name
'name'
>>> T.__dict__['name']
'name'
  • 但是对于hello,情形就有些不同了
>>> T.hello
<unbound method T.hello>
>>> T.__dict__['hello']
<function hello at 0x00CC2470>
  • 可以发现,T.hello是个unbound method。而T.dict[‘hello’]是个函数(不是方法)。
  • 推断:方法在类的dict中是以函数的形式存在的(方法的定义和函数的定义简直一样,除了要把第一个参数设为self)。那么T.hello得到的应该也是个函数啊,怎么成了unbound method了。
  • 再看看从实例t中访问hello
>>> t.hello
<bound method T.hello of <__main__.T object at 0x00CD0E50>>
  • 是一个bound method。
  • 有意思,按照上面的查找策略,既然在T的dict中hello是个函数,那么T.hello和t.hello应该都是同一个函数才对。到底是怎么变成方法的,而且还分为unbound method和bound method。
  • 关于unbound和bound到还好理解,我们不妨先作如下设想:方法是要从实例调用的嘛(指实例方法,classmethod和staticmethod后面讲),如果从类中访问,如T.hello,hello没有和任何实例发生联系,也就是没绑定(unbound)到任何实例上,所以是个unbound,对t.hello的访问方式,hello和t发生了联系,因此是bound。
  • 但从函数到方法
class Descriptor(object):
    def __get__(self, obj, type=None):
            return 'get', self, obj, type
    def __set__(self, obj, val):
        print ('set', self, obj, val)
    def __delete__(self, obj):
        print ('delete', self, obj)
  • 这里setdelete其实可以不出现,不过为了后面的说明,暂时把它们全写上。
  • 下面解释一下三个方法的参数:
  • self当然不用说,指的是当前Descriptor的实例。obj值拥有属性的对象。这应该不难理解,前面已经说了,descriptor是对象的稍微有点特殊的属性,这里的obj就是拥有它的对象,要注意的是,如果是直接用类访问descriptor(别嫌啰嗦,descriptor是个属性,直接用类访问descriptor就是直接用类访问类的属性),obj的值是None。type是obj的类型,刚才说过,如果直接通过类访问descriptor,obj是None,此时type就是类本身。
  • 三个方法的意义,假设T是一个类,t是它的一个实例,d是T的一个descriptor属性(牛什么啊,不就是有个get方法吗!),value是一个有效值:
  • 读取属性时,如T.d,返回的是d._get_(None, T)的结果,t.d返回的是d._get_(t, T)的结果
  • 设置属性时,t.d = value,实际上调用d._set_(t, value),T.d = value,这是真正的赋值,T.d的值从此变成value。删除属性和设置属性类似。
  • 下面用例子说明,看看Python中执行是怎么样的:
  • 重新定义我们的类T和实例t
// 这个代码有个小bug
class T(object):
    d = Descriptor()
t = T()
  • d是T的类属性,作为Descriptor的实例,它有get等方法,显然,d满足了所有的条件,现在它就是一个descriptor!
>>> t.d         #t.d,返回的实际是d.__get__(t, T)
('get', <__main__.Descriptor object at 0x00CD9450>, <__main__.T object at 0x00CD0E50>, <class '__main__.T'>)
>>> T.d        #T.d,返回的实际是d.__get__(None, T),所以obj的位置为None
('get', <__main__.Descriptor object at 0x00CD9450>, None, <class '__main__.T'>)
>>> t.d = 'hello'   #在实例上对descriptor设置值。要注意的是,现在显示不是返回值,而是__set__方法中print语句输出的。
set <__main__.Descriptor object at 0x00CD9450> <__main__.T object at 0x00CD0E50> hello
>>> t.d         #可见,调用了Python调用了__set__方法,并没有改变t.d的值
('get', <__main__.Descriptor object at 0x00CD9450>, <__main__.T object at 0x00CD0E50>, <class '__main__.T'>)
>>> T.d = 'hello'   #没有调用__set__方法
>>> T.d                #确实改变了T.d的值
'hello'
>>> t.d               #t.d的值也变了,这可以理解,按我们上面说的属性查找策略,t.d是从T.__dict__中得到的T.__dict__['d']的值是'hello',t.d当然也是'hello'
'hello'
  • data descriptor和non-data descriptor
  • 象上面的d,同时具有getset方法,这样的descriptor叫做data descriptor,如果只有get方法,则叫做non-data descriptor。容易想到,由于non-data descriptor没有set方法,所以在通过实例对属性赋值时,例如上面的t.d = ‘hello’,不会再调用set方法,会直接把t.d的值变成’hello’吗?口说无凭,实例为证:
class Descriptor(object):
    def __get__(self, obj, type=None):
            return 'get', self, obj, type
class T(object):
       d = Descriptor()
t = T()
>>> t.d
('get', <__main__.Descriptor object at 0x00CD9550>, <__main__.T object at 0x00CD9510>, <class '__main__.T'>)
>>> t.d = 'hello'
>>> t.d
'hello'
  • 在实例上对non-data descriptor赋值隐藏了实例上的non-data descriptor!

总结

  • 是时候坦白真正详细的属性查找策略 了,对于obj.attr(注意:obj可以是一个类):
  • 1.如果attr是一个Python自动产生的属性,找到!(优先级非常高!)
  • 2.查找obj.class.dict,如果attr存在并且是data descriptor,返回data descriptor的get方法的结果,如果没有继续在obj.class的父类以及祖先类中寻找data descriptor
  • 3.在obj.dict中查找,这一步分两种情况,第一种情况是obj是一个普通实例,找到就直接返回,找不到进行下一步。第二种情况是obj是一个类,依次在obj和它的父类、祖先类的dict中查找,如果找到一个descriptor就返回descriptor的get方法的结果,否则直接返回attr。如果没有找到,进行下一步。
  • 4.在obj.class.dict中查找,如果找到了一个descriptor(插一句:这里的descriptor一定是non-data descriptor,如果它是data descriptor,第二步就找到它了)descriptor的get方法的结果。如果找到一个普通属性,直接返回属性值。如果没找到,进行下一步。
  • 5.很不幸,Python终于受不了。在这一步,它raise AttributeError
  • 利用这个,我们简单分析一下上面为什么要强调descriptor要在类中才行。我们感兴趣的查找步骤是2,3,4。第2步和第4步都是在类中查找。对于第3步,如果在普通实例中找到了,直接返回,没有判断它有没有get()方法。
  • 对属性赋值时的查找策略 ,对于obj.attr = value
  • 1.查找obj.class.dict,如果attr存在并且是一个data descriptor,调用attr的set方法,结束。如果不存在,会继续到obj.class的父类和祖先类中查找,找到 data descriptor则调用其set方法。没找到则进入下一步。
  • 2.直接在obj.dict中加入obj.dict[‘attr’] = value
  • 顺便分析下为什么在实例上对non-data descriptor赋值隐藏了实例上的non-data descriptor。
  • 接上面的non-data descriptor例子
>>> t.__dict__
{'d': 'hello'}
  • 在t的dict里出现了d这个属性。根据对属性赋值的查找策略,第1步,确实在t.class.dict也就是T.dict中找到了属性d,但它是一个non-data descriptor,不满足data descriptor的要求,进入第2步,直接在t的dict属性中加入了属性和属性值。当获取t.d时,执行查找策略,第2步在T.dict中找到了d,但它是non-data descriptor,步满足要求,进行第3步,在t的dict中找到了d,直接返回了它的值’hello’。
  • 说了这么半天,还没到函数和方法!
  • 算了,明天在说吧
  • 简单提一下,所有的函数(方法)都有_get_方法,当它们在类的_dict_中是,它们就是non-data descriptor。

static、class、abstract方法

  • 方法就是一个函数,它作为一个类属性而存在,你可以用如下方式来声明、访问一个函数:
>>> class Pizza(object):
...     def __init__(self, size):
...         self.size = size
...     def get_size(self):
...         return self.size
...
>>> Pizza.get_size
<unbound method Pizza.get_size>
  • Python在告诉你,属性_get_size是类Pizza的一个未绑定方法。这是什么意思呢?很快我们就会知道答案:
>>> Pizza.get_size()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unbound method get_size() must be called with Pizza instance as first argument (got nothing instead)
  • 我们不能这么调用,因为它还没有绑定到Pizza类的任何实例上,它需要一个实例作为第一个参数传递进去(Python2必须是该类的实例,Python3中可以是任何东西),尝试一下:
>>> Pizza.get_size(Pizza(42))
42
  • 太棒了,现在用一个实例作为它的的第一个参数来调用,整个世界都清静了,如果我说这种调用方式还不是最方便的,你也会这么认为的;没错,现在每次调用这个方法的时候我们都不得不引用这个类,如果不知道哪个类是我们的对象,长期看来这种方式是行不通的。
  • 那么Python为我们做了什么呢,它绑定了所有来自类_Pizza的方法以及该类的任何一个实例的方法。也就意味着现在属性get_size是Pizza的一个实例对象的绑定方法,这个方法的第一个参数就是该实例本身。
>>> Pizza(42).get_size
<bound method Pizza.get_size of <__main__.Pizza object at 0x7f3138827910>>
>>> Pizza(42).get_size()
42
  • 和我们预期的一样,现在不再需要提供任何参数给_get_size,因为它已经是绑定的,它的self参数会自动地设置给Pizza实例,下面代码是最好的证明:
>>> m = Pizza(42).get_size
>>> m()
42
  • 更有甚者,你都没必要使用持有Pizza对象的引用了,因为该方法已经绑定到了这个对象,所以这个方法对它自己来说是已经足够了。
  • 也许,如果你想知道这个绑定的方法是绑定在哪个对象上,下面这种手段就能得知:
>>> m = Pizza(42).get_size
>>> m.__self__
<__main__.Pizza object at 0x7f3138827910>
>>> # You could guess, look at this:
...
>>> m == m.__self__.get_size
True
  • 显然,该对象仍然有一个引用存在,只要你愿意你还是可以把它找回来。
  • 在Python3中,依附在类上的函数不再当作是未绑定的方法,而是把它当作一个简单地函数,如果有必要它会绑定到一个对象身上去,原则依然和Python2保持一致,但是模块更简洁:
>>> class Pizza(object):
...     def __init__(self, size):
...         self.size = size
...     def get_size(self):
...         return self.size
...
>>> Pizza.get_size
<function Pizza.get_size at 0x7f307f984dd0>

静态方法

  • 静态方法是一类特殊的方法,有时你可能需要写一个属于这个类的方法,但是这些代码完全不会使用到实例对象本身,例如:
class Pizza(object):
    @staticmethod
    def mix_ingredients(x, y):
        return x + y
    def cook(self):
        return self.mix_ingredients(self.cheese, self.vegetables)
  • 这个例子中,如果把_mix_ingredients作为非静态方法同样可以运行,但是它要提供self参数,而这个参数在方法中根本不会被使用到。这里的@staticmethod装饰器可以给我们带来一些好处:
  • Python不再需要为Pizza对象实例初始化一个绑定方法,绑定方法同样是对象,但是创建他们需要成本,而静态方法就可以避免这些。
>>> Pizza().cook is Pizza().cook
False
>>> Pizza().mix_ingredients is Pizza.mix_ingredients
True
>>> Pizza().mix_ingredients is Pizza().mix_ingredients
True
  • 可读性更好的代码,看到@staticmethod我们就知道这个方法并不需要依赖对象本身的状态。
  • 可以在子类中被覆盖,如果是把mix_ingredients作为模块的顶层函数,那么继承自Pizza的子类就没法改变pizza的mix_ingredients了如果不覆盖cook的话。

类方法

  • 话虽如此,什么是类方法呢?类方法不是绑定到对象上,而是绑定在类上的方法。
>>> class Pizza(object):
...     radius = 42
...     @classmethod
...     def get_radius(cls):
...         return cls.radius
...
>>>
>>> Pizza.get_radius
<bound method type.get_radius of <class '__main__.Pizza'>>
>>> Pizza().get_radius
<bound method type.get_radius of <class '__main__.Pizza'>>
>>> Pizza.get_radius is Pizza().get_radius
True
>>> Pizza.get_radius()
42
  • 无论你用哪种方式访问这个方法,它总是绑定到了这个类身上,它的第一个参数是这个类本身(记住:类也是对象)。
  • 什么时候使用这种方法呢?类方法通常在以下两种场景是非常有用的:
  • 工厂方法:它用于创建类的实例,例如一些预处理。如果使用@staticmethod代替,那我们不得不硬编码Pizza类名在函数中,这使得任何继承Pizza的类都不能使用我们这个工厂方法给它自己用。
class Pizza(object):
    def __init__(self, ingredients):
        self.ingredients = ingredients
    @classmethod
    def from_fridge(cls, fridge):
        return cls(fridge.get_cheese() + fridge.get_vegetables())
  • 调用静态类:如果你把一个静态方法拆分成多个静态方法,除非你使用类方法,否则你还是得硬编码类名。使用这种方式声明方法,Pizza类名明永远都不会在被直接引用,继承和方法覆盖都可以完美的工作。
class Pizza(object):
    def __init__(self, radius, height):
        self.radius = radius
        self.height = height
    @staticmethod
    def compute_area(radius):
         return math.pi * (radius ** 2)
    @classmethod
    def compute_volume(cls, height, radius):
         return height * cls.compute_area(radius)
    def get_volume(self):
        return self.compute_volume(self.height, self.radius)

抽象方法

  • 抽象方法是定义在基类中的一种方法,它没有提供任何实现,类似于Java中接口(Interface)里面的方法。
  • 在Python中实现抽象方法最简单地方式是:
class Pizza(object):
    def get_radius(self):
        raise NotImplementedError
  • 任何继承自_Pizza的类必须覆盖实现方法get_radius,否则会抛出异常。
  • 这种抽象方法的实现有它的弊端,如果你写一个类继承Pizza,但是忘记实现get_radius,异常只有在你真正使用的时候才会抛出来。
>>> Pizza()
<__main__.Pizza object at 0x7fb747353d90>
>>> Pizza().get_radius()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 3, in get_radius
NotImplementedError
  • 还有一种方式可以让错误更早的触发,使用Python提供的abc模块,对象被初始化之后就可以抛出异常:
import abc
class BasePizza(object):
    __metaclass__  = abc.ABCMeta
    @abc.abstractmethod
    def get_radius(self):
         """Method that should do something."""
  • 使用abc后,当你尝试初始化BasePizza或者任何子类的时候立马就会得到一个TypeError,而无需等到真正调用get_radius的时候才发现异常。
>>> BasePizza()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Can't instantiate abstract class BasePizza with abstract methods get_radius

混合静态方法、类方法、抽象方法

  • 当你开始构建类和继承结构时,混合使用这些装饰器的时候到了,所以这里列出了一些技巧。
  • 记住,声明一个抽象的方法,不会固定方法的原型,这就意味着虽然你必须实现它,但是我可以用任何参数列表来实现:
import abc
class BasePizza(object):
    __metaclass__  = abc.ABCMeta
    @abc.abstractmethod
    def get_ingredients(self):
         """Returns the ingredient list."""
class Calzone(BasePizza):
    def get_ingredients(self, with_egg=False):
        egg = Egg() if with_egg else None
        return self.ingredients + egg
  • 这样是允许的,因为Calzone满足BasePizza对象所定义的接口需求。同样我们也可以用一个类方法或静态方法来实现:
import abc
class BasePizza(object):
    __metaclass__  = abc.ABCMeta
    @abc.abstractmethod
    def get_ingredients(self):
         """Returns the ingredient list."""
class DietPizza(BasePizza):
    @staticmethod
    def get_ingredients():
        return None
  • 这同样是正确的,因为它遵循抽象类BasePizza设定的契约。事实上get_ingredients方法并不需要知道返回结果是什么,结果是实现细节,不是契约条件。
  • 因此,你不能强制抽象方法的实现是一个常规方法、或者是类方法还是静态方法,也没什么可争论的。从Python3开始(在Python2中不能如你期待的运行,见issue5867),在abstractmethod方法上面使用@staticmethod和@classmethod装饰器成为可能。
import abc
class BasePizza(object):
    __metaclass__  = abc.ABCMeta
    ingredient = ['cheese']
    @classmethod
    @abc.abstractmethod
    def get_ingredients(cls):
         """Returns the ingredient list."""
         return cls.ingredients
  • 别误会了,如果你认为它会强制子类作为一个类方法来实现get_ingredients那你就错了,它仅仅表示你实现的get_ingredients在BasePizza中是一个类方法。
  • 可以在抽象方法中做代码的实现?没错,Python与Java接口中的方法相反,你可以在抽象方法编写实现代码通过super()来调用它。(译注:在Java8中,接口也提供的默认方法,允许在接口中写方法的实现)
import abc
class BasePizza(object):
    __metaclass__  = abc.ABCMeta
    default_ingredients = ['cheese']
    @classmethod
    @abc.abstractmethod
    def get_ingredients(cls):
         """Returns the ingredient list."""
         return cls.default_ingredients
class DietPizza(BasePizza):
    def get_ingredients(self):
        return ['egg'] + super(DietPizza, self).get_ingredients()
  • 这个例子中,你构建的每个pizza都通过继承BasePizza的方式,你不得不覆盖get_ingredients方法,但是能够使用默认机制通过super()来获取ingredient列表。
以下代码有定义小车的运动学属性吗# coding=utf-8 from .collision import * from .graphics import load_texture from .utils import get_file_path class WorldObj: def __init__(self, obj, domain_rand, safety_radius_mult): &quot;&quot;&quot; Initializes the object and its properties &quot;&quot;&quot; # XXX this is relied on by things but it is not always set # (Static analysis complains) self.visible = True # same self.color = (0, 0, 0) # maybe have an abstract method is_visible, get_color() self.process_obj_dict(obj, safety_radius_mult) self.domain_rand = domain_rand self.angle = self.y_rot * (math.pi / 180) self.generate_geometry() def generate_geometry(self): # Find corners and normal vectors assoc w. object self.obj_corners = generate_corners(self.pos, self.min_coords, self.max_coords, self.angle, self.scale) self.obj_norm = generate_norm(self.obj_corners) def process_obj_dict(self, obj, safety_radius_mult): self.kind = obj[&#39;kind&#39;] self.mesh = obj[&#39;mesh&#39;] self.pos = obj[&#39;pos&#39;] self.scale = obj[&#39;scale&#39;] self.y_rot = obj[&#39;y_rot&#39;] self.optional = obj[&#39;optional&#39;] self.min_coords = obj[&#39;mesh&#39;].min_coords self.max_coords = obj[&#39;mesh&#39;].max_coords self.static = obj[&#39;static&#39;] self.safety_radius = safety_radius_mult *\ calculate_safety_radius(self.mesh, self.scale) def render(self, draw_bbox): &quot;&quot;&quot; Renders the object to screen &quot;&quot;&quot; if not self.visible: return from pyglet import gl # Draw the bounding box if draw_bbox: gl.glColor3f(1, 0, 0) gl.glBegin(gl.GL_LINE_LOOP) gl.glVertex3f(self.obj_corners.T[0, 0], 0.01, self.obj_corners.T[1, 0]) gl.glVertex3f(self.obj_corners.T[0, 1], 0.01, self.obj_corners.T[1, 1]) gl.glVertex3f(self.obj_corners.T[0, 2], 0.01, self.obj_corners.T[1, 2]) gl.glVertex3f(self.obj_corners.T[0, 3], 0.01, self.obj_corners.T[1, 3]) gl.glEnd() gl.glPushMatrix() gl.glTranslatef(*self.pos) gl.glScalef(self.scale, self.scale, self.scale) gl.glRotatef(self.y_rot, 0, 1, 0) gl.glColor3f(*self.color) self.mesh.render() gl.glPopMatrix() # Below are the functions that need to # be reimplemented for any dynamic object def check_collision(self, agent_corners, agent_norm): &quot;&quot;&quot; See if the agent collided with this object For static, return false (static collisions checked w numpy in a batch operation) &quot;&quot;&quot; if not self.static: raise NotImplementedError return False def proximity(self, agent_pos, agent_safety_rad): &quot;&quot;&quot; See if the agent is too close to this object For static, return 0 (static safedriving checked w numpy in a batch operation) &quot;&quot;&quot; if not self.static: raise NotImplementedError return 0.0 def step(self, delta_time): &quot;&quot;&quot; Use a motion model to move the object in the world &quot;&quot;&quot; if not self.static: raise NotImplementedError class DuckiebotObj(WorldObj): def __init__(self, obj, domain_rand, safety_radius_mult, wheel_dist, robot_width, robot_length, gain=2.0, trim=0.0, radius=0.0318, k=27.0, limit=1.0): WorldObj.__init__(self, obj, domain_rand, safety_radius_mult) if self.domain_rand: self.follow_dist = np.random.uniform(0.3, 0.4) self.velocity = np.random.uniform(0.05, 0.15) else: self.follow_dist = 0.3 self.velocity = 0.1 self.max_iterations = 1000 # TODO: Make these DR as well self.gain = gain self.trim = trim self.radius = radius self.k = k self.limit = limit self.wheel_dist = wheel_dist self.robot_width = robot_width self.robot_length = robot_length # FIXME: this does not follow the same signature as WorldOb def step(self, delta_time, closest_curve_point, objects): &quot;&quot;&quot; Take a step, implemented as a PID controller &quot;&quot;&quot; # Find the curve point closest to the agent, and the tangent at that point closest_point, closest_tangent = closest_curve_point(self.pos, self.angle) iterations = 0 lookup_distance = self.follow_dist curve_point = None while iterations &lt; self.max_iterations: # Project a point ahead along the curve tangent, # then find the closest point to to that follow_point = closest_point + closest_tangent * lookup_distance curve_point, _ = closest_curve_point(follow_point, self.angle) # If we have a valid point on the curve, stop if curve_point is not None: break iterations += 1 lookup_distance *= 0.5 # Compute a normalized vector to the curve point point_vec = curve_point - self.pos point_vec /= np.linalg.norm(point_vec) dot = np.dot(self.get_right_vec(self.angle), point_vec) steering = self.gain * -dot self._update_pos([self.velocity, steering], delta_time) def get_dir_vec(self, angle): x = math.cos(angle) z = -math.sin(angle) return np.array([x, 0, z]) def get_right_vec(self, angle): x = math.sin(angle) z = math.cos(angle) return np.array([x, 0, z]) def check_collision(self, agent_corners, agent_norm): &quot;&quot;&quot; See if the agent collided with this object &quot;&quot;&quot; return intersects_single_obj( agent_corners, self.obj_corners.T, agent_norm, self.obj_norm ) def proximity(self, agent_pos, agent_safety_rad): &quot;&quot;&quot; See if the agent is too close to this object based on a heuristic for the &quot;overlap&quot; between their safety circles &quot;&quot;&quot; d = np.linalg.norm(agent_pos - self.pos) score = d - agent_safety_rad - self.safety_radius return min(0, score) def _update_pos(self, action, deltaTime): vel, angle = action # assuming same motor constants k for both motors k_r = self.k k_l = self.k # adjusting k by gain and trim k_r_inv = (self.gain + self.trim) / k_r k_l_inv = (self.gain - self.trim) / k_l omega_r = (vel + 0.5 * angle * self.wheel_dist) / self.radius omega_l = (vel - 0.5 * angle * self.wheel_dist) / self.radius # conversion from motor rotation rate to duty cycle u_r = omega_r * k_r_inv u_l = omega_l * k_l_inv # limiting output to limit, which is 1.0 for the duckiebot u_r_limited = max(min(u_r, self.limit), -self.limit) u_l_limited = max(min(u_l, self.limit), -self.limit) # If the wheel velocities are the same, then there is no rotation if u_l_limited == u_r_limited: self.pos = self.pos + deltaTime * u_l_limited * self.get_dir_vec(self.angle) return # Compute the angular rotation velocity about the ICC (center of curvature) w = (u_r_limited - u_l_limited) / self.wheel_dist # Compute the distance to the center of curvature r = (self.wheel_dist * (u_l_limited + u_r_limited)) / (2 * (u_l_limited - u_r_limited)) # Compute the rotation angle for this time step rotAngle = w * deltaTime # Rotate the robot&#39;s position around the center of rotation r_vec = self.get_right_vec(self.angle) px, py, pz = self.pos cx = px + r * r_vec[0] cz = pz + r * r_vec[2] npx, npz = rotate_point(px, pz, cx, cz, rotAngle) # Update position self.pos = np.array([npx, py, npz]) # Update the robot&#39;s direction angle self.angle += rotAngle self.y_rot += rotAngle * 180 / np.pi # Recompute the bounding boxes (BB) for the duckiebot self.obj_corners = agent_boundbox( self.pos, self.robot_width, self.robot_length, self.get_dir_vec(self.angle), self.get_right_vec(self.angle) ) class DuckieObj(WorldObj): def __init__(self, obj, domain_rand, safety_radius_mult, walk_distance): WorldObj.__init__(self, obj, domain_rand, safety_radius_mult) self.walk_distance = walk_distance + 0.25 # Dynamic duckie stuff # Randomize velocity and wait time if self.domain_rand: self.pedestrian_wait_time = np.random.randint(3, 20) self.vel = np.abs(np.random.normal(0.02, 0.005)) else: self.pedestrian_wait_time = 8 self.vel = 0.02 # Movement parameters self.heading = heading_vec(self.angle) self.start = np.copy(self.pos) self.center = self.pos self.pedestrian_active = False # Walk wiggle parameter self.wiggle = np.random.choice([14, 15, 16], 1) self.wiggle = np.pi / self.wiggle self.time = 0 def check_collision(self, agent_corners, agent_norm): &quot;&quot;&quot; See if the agent collided with this object &quot;&quot;&quot; return intersects_single_obj( agent_corners, self.obj_corners.T, agent_norm, self.obj_norm ) def proximity(self, agent_pos, agent_safety_rad): &quot;&quot;&quot; See if the agent is too close to this object based on a heuristic for the &quot;overlap&quot; between their safety circles &quot;&quot;&quot; d = np.linalg.norm(agent_pos - self.center) score = d - agent_safety_rad - self.safety_radius return min(0, score) def step(self, delta_time): &quot;&quot;&quot; Use a motion model to move the object in the world &quot;&quot;&quot; self.time += delta_time # If not walking, no need to do anything if not self.pedestrian_active: self.pedestrian_wait_time -= delta_time if self.pedestrian_wait_time &lt;= 0: self.pedestrian_active = True return # Update centers and bounding box vel_adjust = self.heading * self.vel self.center += vel_adjust self.obj_corners += vel_adjust[[0, -1]] distance = np.linalg.norm(self.center - self.start) if distance &gt; self.walk_distance: self.finish_walk() self.pos = self.center angle_delta = self.wiggle * math.sin(48 * self.time) self.y_rot = (self.angle + angle_delta) * (180 / np.pi) self.obj_norm = generate_norm(self.obj_corners) def finish_walk(self): &quot;&quot;&quot; After duckie crosses, update relevant attributes (vel, rot, wait time until next walk) &quot;&quot;&quot; self.start = np.copy(self.center) self.angle += np.pi self.pedestrian_active = False if self.domain_rand: # Assign a random velocity (in opp. direction) and a wait time # TODO: Fix this: This will go to 0 over time self.vel = -1 * np.sign(self.vel) * np.abs(np.random.normal(0.02, 0.005)) self.pedestrian_wait_time = np.random.randint(3, 20) else: # Just give it the negative of its current velocity self.vel *= -1 self.pedestrian_wait_time = 8 class TrafficLightObj(WorldObj): def __init__(self, obj, domain_rand, safety_radius_mult): WorldObj.__init__(self, obj, domain_rand, safety_radius_mult) self.texs = [ load_texture(get_file_path(&quot;textures&quot;, &quot;trafficlight_card0&quot;, &quot;jpg&quot;)), load_texture(get_file_path(&quot;textures&quot;, &quot;trafficlight_card1&quot;, &quot;jpg&quot;)) ] self.time = 0 # Frequency and current pattern of the lights if self.domain_rand: self.freq = np.random.randint(4, 7) self.pattern = np.random.randint(0, 2) else: self.freq = 5 self.pattern = 0 # Use the selected pattern self.mesh.textures[0] = self.texs[self.pattern] def step(self, delta_time): &quot;&quot;&quot; Changes the light color periodically &quot;&quot;&quot; self.time += delta_time if round(self.time, 3) % self.freq == 0: # Swap patterns self.pattern ^= 1 self.mesh.textures[0] = self.texs[self.pattern] def is_green(self, direction=&#39;N&#39;): if direction == &#39;N&#39; or direction == &#39;S&#39;: if self.y_rot == 45 or self.y_rot == 135: return self.pattern == 0 elif self.y_rot == 225 or self.y_rot == 315: return self.pattern == 1 elif direction == &#39;E&#39; or direction == &#39;W&#39;: if self.y_rot == 45 or self.y_rot == 135: return self.pattern == 1 elif self.y_rot == 225 or self.y_rot == 315: return self.pattern == 0 return False
07-12
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值